diff --git a/crates/storage-query-datafusion/src/filter.rs b/crates/storage-query-datafusion/src/filter.rs index 374fe0f118..84d172576a 100644 --- a/crates/storage-query-datafusion/src/filter.rs +++ b/crates/storage-query-datafusion/src/filter.rs @@ -20,7 +20,7 @@ use datafusion::logical_expr::Operator; use datafusion::physical_expr::split_conjunction; use datafusion::physical_expr_common::physical_expr::snapshot_physical_expr; use datafusion::physical_plan::PhysicalExpr; -use datafusion::physical_plan::expressions::{BinaryExpr, Column, InListExpr, Literal}; +use datafusion::physical_plan::expressions::{BinaryExpr, Column, InListExpr, IsNullExpr, Literal}; use strum::EnumCount; use restate_storage_api::vqueue_table::Stage; @@ -90,14 +90,44 @@ impl FirstMatchingPartitionKeyExtractor { } pub fn with_service_key(self, column_name: impl Into) -> Self { - let e = MatchingColumnExtractor::new(column_name, |value: &ScalarValue| { + self.append(Self::create_service_key_partition_key_extractor( + column_name, + )) + } + + fn create_service_key_partition_key_extractor( + column_name: impl Into, + ) -> MatchingColumnExtractor anyhow::Result> { + MatchingColumnExtractor::new(column_name, |value: &ScalarValue| { let value = value .try_as_str() .context("expected string service key")? .context("unexpected null service key")?; Ok(HashPartitioner::compute_partition_key(value)) + }) + } + + /// For tables sharded by `scope_column` when scoped and by `service_key_column` when + /// unscoped (i.e. `scope_column IS NULL`). Extracts a partition key from either: + /// - `scope = '...'` / `scope IN (...)` (sharded under `hash(scope)`), or + /// - `scope IS NULL AND service_key = '...'` / `IN (...)` (sharded under `hash(service_key)`). + pub fn with_scope_or_service_key( + self, + scope_column: impl Into, + service_key_column: impl Into, + ) -> Self { + let scope_column: String = scope_column.into(); + let by_scope = MatchingColumnExtractor::new(scope_column.clone(), |value: &ScalarValue| { + let value = value + .try_as_str() + .context("expected scope")? + .context("null scopes cannot be used for partition-key matching")?; + Ok(HashPartitioner::compute_partition_key(value)) }); - self.append(e) + self.append(by_scope).append(WhenNullExtractor::new( + scope_column, + Self::create_service_key_partition_key_extractor(service_key_column), + )) } pub fn with_invocation_id(self, column_name: impl Into) -> Self { @@ -209,6 +239,64 @@ where } } +/// Gates an inner [`PartitionKeyExtractor`] on the presence of a top-level +/// ` IS NULL` conjunct. +/// +/// Used for tables that are sharded differently depending on whether a column is null +/// (e.g. `state` and `sys_promise`: scoped rows live at `hash(scope)`, unscoped rows at +/// `hash(service_key)`). When the user writes `... AND IS NULL`, scoped +/// rows are filtered out by the predicate anyway, so it's safe to narrow the scan +/// using a key derived from another column. +pub(crate) struct WhenNullExtractor { + null_column_name: String, + inner: E, +} + +impl Debug for WhenNullExtractor { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!( + "WhenNullExtractor({:?})", + self.null_column_name + )) + } +} + +impl WhenNullExtractor { + pub(crate) fn new(null_column_name: impl Into, inner: E) -> Self { + Self { + null_column_name: null_column_name.into(), + inner, + } + } +} + +impl PartitionKeyExtractor for WhenNullExtractor +where + E: PartitionKeyExtractor, +{ + fn try_extract( + &self, + filters: &[Arc], + ) -> anyhow::Result>> { + // Only accept a bare top-level `IsNullExpr` against a `Column`. An `IsNullExpr` + // nested in `Or`/`Not`/etc. does not count: e.g. `(scope IS NULL OR scope IS NOT NULL)` + // would otherwise spuriously gate the inner extractor open. + let has_null_check = filters.iter().any(|filter| { + filter + .as_any() + .downcast_ref::() + .and_then(|is_null| is_null.arg().as_any().downcast_ref::()) + .is_some_and(|column| column.name() == self.null_column_name) + }); + + if !has_null_check { + return Ok(None); + } + + self.inner.try_extract(filters) + } +} + /// A normalized representation of predicates that compare a column to literal values. /// Handles `col = lit`, `col IN (lit, ...)`, and `col = lit OR col = lit ...` patterns. struct InList<'a> { @@ -422,7 +510,9 @@ mod tests { use datafusion::common::ScalarValue; use datafusion::physical_plan::PhysicalExpr; - use datafusion::physical_plan::expressions::{BinaryExpr, Column, InListExpr, Literal}; + use datafusion::physical_plan::expressions::{ + BinaryExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, Literal, + }; use restate_storage_api::vqueue_table::Stage; use restate_types::identifiers::{InvocationId, ServiceId, StateMutationId, WithPartitionKey}; @@ -442,6 +532,14 @@ mod tests { Arc::new(Literal::new(ScalarValue::LargeUtf8(Some(value.into())))) } + fn is_null(name: &str) -> Arc { + Arc::new(IsNullExpr::new(col(name))) + } + + fn is_not_null(name: &str) -> Arc { + Arc::new(IsNotNullExpr::new(col(name))) + } + fn eq(left: Arc, right: Arc) -> Arc { Arc::new(BinaryExpr::new( left, @@ -708,6 +806,105 @@ mod tests { assert_eq!(None, got_keys); } + fn scope_or_service_key_extractor() -> FirstMatchingPartitionKeyExtractor { + FirstMatchingPartitionKeyExtractor::default() + .with_scope_or_service_key("scope", "service_key") + } + + #[test] + fn service_key_when_scope_is_null_extracts_partition_key() { + let expected = ServiceId::new(None, "svc", "k").partition_key(); + + let got = scope_or_service_key_extractor() + .try_extract(&[is_null("scope"), eq(col("service_key"), utf8_lit("k"))]) + .expect("extract") + .expect("partition key"); + + assert_eq!(1, got.len()); + assert_eq!(expected, got.into_iter().next().unwrap()); + } + + #[test] + fn service_key_in_list_when_scope_is_null() { + let expected_a = ServiceId::new(None, "svc", "a").partition_key(); + let expected_b = ServiceId::new(None, "svc", "b").partition_key(); + + let got = scope_or_service_key_extractor() + .try_extract(&[ + is_null("scope"), + in_list("service_key", vec![utf8_lit("a"), utf8_lit("b")]), + ]) + .expect("extract") + .expect("partition keys"); + + assert_eq!(2, got.len()); + assert!(got.contains(&expected_a)); + assert!(got.contains(&expected_b)); + } + + #[test] + fn service_key_without_scope_is_null_returns_none() { + // Without the explicit `scope IS NULL` guard, the extractor cannot narrow because + // scoped rows for the same service_key live at hash(scope), not hash(service_key). + let got = scope_or_service_key_extractor() + .try_extract(&[eq(col("service_key"), utf8_lit("k"))]) + .expect("extract"); + + assert_eq!(None, got); + } + + #[test] + fn scope_is_null_alone_returns_none() { + let got = scope_or_service_key_extractor() + .try_extract(&[is_null("scope")]) + .expect("extract"); + + assert_eq!(None, got); + } + + #[test] + fn scope_is_not_null_does_not_trigger() { + // IsNotNullExpr is a distinct type from IsNullExpr; the gate must stay closed. + let got = scope_or_service_key_extractor() + .try_extract(&[is_not_null("scope"), eq(col("service_key"), utf8_lit("k"))]) + .expect("extract"); + + assert_eq!(None, got); + } + + #[test] + fn scope_is_null_inside_or_does_not_trigger() { + // `(scope IS NULL OR scope IS NOT NULL)` is a top-level Or, not a bare IsNullExpr. + // The gate must stay closed so we don't narrow under a tautology. + let got = scope_or_service_key_extractor() + .try_extract(&[ + or(is_null("scope"), is_not_null("scope")), + eq(col("service_key"), utf8_lit("k")), + ]) + .expect("extract"); + + assert_eq!(None, got); + } + + #[test] + fn scope_is_null_or_service_key_does_not_trigger() { + // Single Or conjunct: neither side is a bare top-level IsNullExpr against scope. + let got = scope_or_service_key_extractor() + .try_extract(&[or(is_null("scope"), eq(col("service_key"), utf8_lit("k")))]) + .expect("extract"); + + assert_eq!(None, got); + } + + #[test] + fn scope_is_null_on_different_column_does_not_trigger() { + let got = scope_or_service_key_extractor() + .try_extract(&[is_null("other_col"), eq(col("service_key"), utf8_lit("k"))]) + .expect("extract"); + + assert_eq!(None, got); + } + #[test] fn invocation_id_filter_single_eq() { let id = make_invocation_id("key-1"); diff --git a/crates/storage-query-datafusion/src/promise/table.rs b/crates/storage-query-datafusion/src/promise/table.rs index 54ee603bc4..95290ad3e4 100644 --- a/crates/storage-query-datafusion/src/promise/table.rs +++ b/crates/storage-query-datafusion/src/promise/table.rs @@ -42,11 +42,8 @@ pub(crate) fn register_self( SysPromiseBuilder::schema(), sys_promise_sort_order(), remote_scanner_manager.create_distributed_scanner(NAME, local_scanner), - // We can no longer extract the partition key for unscoped promises based solely on - // the service_key as we might be missing scoped promises for this service key. We - // could only do this if scope is null for which we don't have a partition key extractor - // construct yet. - FirstMatchingPartitionKeyExtractor::default().with_scope("scope"), + FirstMatchingPartitionKeyExtractor::default() + .with_scope_or_service_key("scope", "service_key"), ); ctx.register_partitioned_table(NAME, Arc::new(table)) } diff --git a/crates/storage-query-datafusion/src/state/table.rs b/crates/storage-query-datafusion/src/state/table.rs index 03f2813ff7..38ad7cd66f 100644 --- a/crates/storage-query-datafusion/src/state/table.rs +++ b/crates/storage-query-datafusion/src/state/table.rs @@ -44,11 +44,8 @@ pub(crate) fn register_self( StateBuilder::schema(), state_sort_order(), remote_scanner_manager.create_distributed_scanner(NAME, local_scanner), - // We can no longer extract the partition key for unscoped state entries based solely on - // the service_key as we might be missing scoped state entries for this service key. We - // could only do this if scope is null for which we don't have a partition key extractor - // construct yet. - FirstMatchingPartitionKeyExtractor::default().with_scope("scope"), + FirstMatchingPartitionKeyExtractor::default() + .with_scope_or_service_key("scope", "service_key"), ); ctx.register_partitioned_table(NAME, Arc::new(table)) } diff --git a/crates/storage-query-datafusion/src/tests.rs b/crates/storage-query-datafusion/src/tests.rs index ffe5dfcc63..e77863cfc2 100644 --- a/crates/storage-query-datafusion/src/tests.rs +++ b/crates/storage-query-datafusion/src/tests.rs @@ -757,8 +757,7 @@ async fn query_state_with_service_key_filter() { ) ); - // Unfortunately, this query will no longer be a single partition key scan but instead a full - // partition key range scan :-( + // With `scope IS NULL` explicit, the conditional extractor narrows to hash(service_key). let null_scope_service_key = engine .execute( "SELECT service_name FROM state \