Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 83 additions & 61 deletions crates/core/src/task_center.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub use monitoring::*;
pub use runtime::*;
pub use task::*;
pub use task_kind::*;
use tokio::runtime::LocalOptions;

use std::collections::HashMap;
use std::future::Future;
Expand All @@ -37,7 +38,6 @@ use futures::future::BoxFuture;
use metrics::counter;
use parking_lot::Mutex;
use tokio::sync::oneshot;
use tokio::task::LocalSet;
use tokio::task_local;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
Expand All @@ -47,6 +47,7 @@ use tracing::{debug, error, info, trace, warn};
use crate::metric_definitions::{STATUS_COMPLETED, STATUS_FAILED, TC_FINISHED, TC_SPAWN};
use crate::{Metadata, ShutdownError, ShutdownSourceErr};
use restate_memory::MemoryController;
use restate_platform::prelude::*;
use restate_types::SharedString;
use restate_types::cluster_state::ClusterState;
use restate_types::config::Configuration;
Expand Down Expand Up @@ -75,8 +76,12 @@ struct GlobalOverrides {

/// Errors returned while starting a managed runtime.
// NOTE(review): this text reads like a rendered diff, so the two consecutive `#[error]`
// attributes on `AlreadyExists` are presumably the old and new spelling of one attribute;
// confirm which one the real file keeps.
#[derive(Debug, thiserror::Error)]
pub enum RuntimeError {
/// A runtime with the requested name already exists; carries that name.
#[error("Runtime with name {0} already exists")]
#[error("runtime with name {0} already exists")]
AlreadyExists(String),
/// Spawning the runtime's thread through `std::thread::Builder` failed.
#[error("[io] cannot start thread:{0}")]
Io(std::io::Error),
/// Presumably an OS-level failure while starting the thread; no construction site is
/// visible in this view -- TODO confirm.
#[error("[os] cannot start thread:{0}")]
Os(GenericError),
/// Wraps a [`ShutdownError`] (converted via `#[from]`) and forwards its message unchanged.
#[error(transparent)]
Shutdown(#[from] ShutdownError),
}
Expand Down Expand Up @@ -716,8 +721,17 @@ impl TaskCenterInner {
if self.shutdown_requested.load(Ordering::Relaxed) {
return Err(ShutdownError.into());
}

let cancel = CancellationToken::new();
let tc = self.clone();
let runtime_name: SharedString = runtime_name.into();
// todo: configure the runtime according to a new runtime kind perhaps?
let thread_builder = std::thread::Builder::new().name(format!("rt:{runtime_name}"));
// capture the runtime handle if it started successfully. Uses std mpsc rather than
// tokio's oneshot so we can block on recv() from inside a tokio runtime without panicking.
let (start_tx, start_rx) = std::sync::mpsc::sync_channel::<tokio::runtime::Handle>(1);

let (result_tx, result_rx) = oneshot::channel();

// hold a lock while creating the runtime to avoid concurrent runtimes with the same name
let mut runtimes_guard = self.managed_runtimes.lock();
Expand All @@ -729,60 +743,77 @@ impl TaskCenterInner {
return Err(RuntimeError::AlreadyExists(runtime_name.into_owned()));
}

// todo: configure the runtime according to a new runtime kind perhaps?
let thread_builder = std::thread::Builder::new().name(format!("rt:{runtime_name}"));
let mut builder = tokio::runtime::Builder::new_current_thread();

#[cfg(any(test, feature = "test-util"))]
builder.start_paused(self.pause_time);

let rt = builder
.enable_all()
.build()
.expect("runtime builder succeeds");
let tc = self.clone();

let rt_handle = Arc::new(rt);

runtimes_guard.insert(
runtime_name.clone(),
OwnedRuntimeHandle::new(cancel.clone(), rt_handle.clone()),
);

// release the lock.
drop(runtimes_guard);

let id = TaskId::default();
let context = TaskContext {
id,
name: runtime_name.clone(),
kind: root_task_kind,
cancellation_token: cancel.clone(),
partition_id,
};

let (result_tx, result_rx) = oneshot::channel();

let start = Arc::new(OnceLock::new());
// start the work on the runtime
let _ = thread_builder
let thread_handle: std::thread::JoinHandle<()> = thread_builder
.spawn({
let start = start.clone();
#[cfg(any(test, feature = "test-util"))]
let pause_time = self.pause_time;
let cancel = cancel.clone();
let runtime_name = runtime_name.clone();
move || {
let local_set = LocalSet::new();
let result = rt_handle.block_on(local_set.run_until(unmanaged_wrapper(
tc.clone(),
context,
root_future(),
)));

drop(rt_handle);
// todo: setup the thread (core affinity, thread-locals, etc.)
let mut builder = tokio::runtime::Builder::new_current_thread();

#[cfg(any(test, feature = "test-util"))]
builder.start_paused(pause_time);

let id = TaskId::default();
let context = TaskContext {
id,
name: runtime_name.clone(),
kind: root_task_kind,
cancellation_token: cancel.clone(),
partition_id,
};

// try to insert the runtime into the list of managed runtimes
let rt = builder
.enable_all()
.build_local(LocalOptions::default())
.expect("runtime builder succeeds");
if start_tx.send(rt.handle().clone()).is_err() {
warn!(
"Will not start runtime {runtime_name} since task center is no longer interested in it"
);
return;
}

// Wait until the runtime is registered in the map before doing any real work.
start.wait();

let result = rt.block_on(unmanaged_wrapper(tc.clone(), context, root_future()));

tc.drop_runtime(runtime_name);

// need to use a oneshot here since we cannot await a thread::JoinHandle :-(
let _ = result_tx.send(result);
}
})
.unwrap();
}).map_err(RuntimeError::Io)?;

let rt_handle = match start_rx.recv() {
Ok(rt_handle) => rt_handle,
Err(_) => {
cancel.cancel();
if thread_handle.is_finished()
&& let Err(e) = thread_handle.join()
{
std::panic::resume_unwind(e);
} else {
error!("Runtime {runtime_name} was not started successfully");
}
return Err(RuntimeError::Shutdown(ShutdownError));
}
};

runtimes_guard.insert(
runtime_name.clone(),
OwnedRuntimeHandle::new(cancel.clone(), rt_handle, thread_handle),
);
drop(runtimes_guard);
// let the runtime go ahead.
let _ = start.set(true);

Ok(RuntimeTaskHandle::new(runtime_name, cancel, result_rx))
}
Expand All @@ -791,14 +822,9 @@ impl TaskCenterInner {
/// the runtime handle.
fn drop_runtime(self: &Arc<Self>, name: SharedString) {
// Hold the registry lock while the entry for `name` is removed.
let mut runtimes_guard = self.managed_runtimes.lock();
// NOTE(review): the two `if` headers below (and the `Arc::into_inner` section further down)
// look like the before/after sides of a rendered diff rather than nested checks -- confirm
// against the real file before relying on this text.
if let Some(runtime) = runtimes_guard.remove(&name) {
if runtimes_guard.remove(&name).is_some() {
// We must be the only owner of runtime at this point.
debug!("Runtime {} completed", name);
// `Arc::into_inner` yields the runtime only when no other strong reference remains; in that
// case the runtime is shut down with a 2 second timeout.
let owner = Arc::into_inner(runtime.into_inner());
if let Some(runtime) = owner {
runtime.shutdown_timeout(Duration::from_secs(2));
trace!("Runtime {} shutdown completed", name);
}
trace!("Runtime {name} shutdown completed");
}
}

Expand Down Expand Up @@ -1217,13 +1243,9 @@ impl TaskCenterInner {
/// Drains every entry from `managed_runtimes` while holding its lock and asks each drained
/// runtime to stop.
fn shutdown_managed_runtimes(self: &Arc<Self>) {
let mut runtimes = self.managed_runtimes.lock();
// NOTE(review): the `Arc::into_inner` branch and the trailing `runtime.cancel()` look like the
// before/after sides of a rendered diff (the branch consumes `runtime` before `cancel` is
// called) -- confirm against the real file.
for (_, runtime) in runtimes.drain() {
if let Some(runtime) = Arc::into_inner(runtime.into_inner()) {
// This really isn't doing much, but it's left here for completion.
// The reason is: If the runtime is still running, then it'll hold the Arc until it
// finishes gracefully, yielding None here. If the runtime completed, it'll
// self-shutdown prior to reaching this point.
runtime.shutdown_background();
}
// This won't do anything if the shutdown was already initiated via
// `initiate_managed_runtimes_shutdown`
runtime.cancel();
}
}
}
Expand Down
14 changes: 6 additions & 8 deletions crates/core/src/task_center/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
// by the Apache License, Version 2.0.

use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, ready};

use futures::FutureExt;
Expand Down Expand Up @@ -63,32 +62,31 @@ impl<T> std::future::Future for RuntimeTaskHandle<T> {

/// Bookkeeping record for one managed runtime: the token used to cancel it, a way to reach the
/// tokio runtime, and the join handle of its thread.
pub(super) struct OwnedRuntimeHandle {
/// Token that `cancel` triggers to request shutdown of the runtime's root task.
cancellation_token: CancellationToken,
// NOTE(review): the two `inner` declarations below look like the before/after sides of a
// rendered diff (`Arc<Runtime>` replaced by `Handle`) -- confirm against the real file.
inner: Arc<tokio::runtime::Runtime>,
inner: tokio::runtime::Handle,
/// Join handle of the runtime's thread; stored but never read in the visible code.
_thread_handle: std::thread::JoinHandle<()>,
}

// NOTE(review): several adjacent lines in this impl (the two `runtime` parameters of `new`, the
// two bodies of `runtime_handle`) look like the before/after sides of a rendered diff -- confirm
// against the real file.
impl OwnedRuntimeHandle {
/// Bundles the cancellation token, the runtime and the thread join handle into one record.
pub fn new(
cancellation_token: CancellationToken,
runtime: Arc<tokio::runtime::Runtime>,
runtime: tokio::runtime::Handle,
thread_handle: std::thread::JoinHandle<()>,
) -> Self {
Self {
cancellation_token,
inner: runtime,
_thread_handle: thread_handle,
}
}

/// Returns a reference to the tokio handle of the wrapped runtime.
pub fn runtime_handle(&self) -> &tokio::runtime::Handle {
self.inner.handle()
&self.inner
}

/// Trigger graceful shutdown of the runtime root task. Shutdown is not guaranteed, it depends
/// on whether the root task awaits the cancellation token or not.
pub fn cancel(&self) {
self.cancellation_token.cancel()
}

/// Consumes the record and returns the stored `inner` value.
pub fn into_inner(self) -> Arc<tokio::runtime::Runtime> {
self.inner
}
}
Loading