Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,30 @@ package object config {
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("0s")

private[spark] val MANAGED_CONSUMER_ENABLED =
ConfigBuilder("spark.memory.managedConsumer.enabled")
.doc("If true, UnifiedMemoryManager will consult registered ManagedConsumer " +
"instances (via their shrink() method) before falling back to evicting " +
"internal cached blocks or shrinking the storage pool for execution. This " +
"ordering protects user-explicit persist() blocks from being dropped before " +
"best-effort external caches (e.g. native columnar caches like Velox " +
"AsyncDataCache exposed through Gluten).")
.withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE)
.version("4.1.0")
.booleanConf
.createWithDefault(false)

private[spark] val MANAGED_CONSUMER_SHRINK_WARN_THRESHOLD_MS =
ConfigBuilder("spark.memory.managedConsumer.shrinkWarnThresholdMs")
.doc("If a ManagedConsumer.shrink() call takes longer than this many " +
"milliseconds, log a warning. Because shrink() runs inside the MemoryManager " +
"monitor, long shrink calls block other acquisition requests; this threshold " +
"helps surface slow implementations.")
.withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE)
.version("4.1.0")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("100ms")

private[spark] val STORAGE_UNROLL_MEMORY_THRESHOLD =
ConfigBuilder("spark.storage.unrollMemoryThreshold")
.doc("Initial memory to request before unrolling any block")
Expand Down
70 changes: 70 additions & 0 deletions core/src/main/scala/org/apache/spark/memory/ManagedConsumer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.memory

/**
* Storage-memory counterpart of [[org.apache.spark.memory.MemoryConsumer]]: holds bytes
* acquired via [[MemoryManager.acquireStorageMemory]] and synchronously releases them on
* Spark's request. Typical implementor: a native off-heap cache (e.g. Velox AsyncDataCache
* via Gluten) sharing `spark.memory.offHeap.size` with Spark's MemoryStore.
*
* == Contract ==
* - [[name]] is the registry key (JVM-unique; ON_HEAP and OFF_HEAP share one namespace).
* Use the SAME instance for register / acquire / unregister so identity-based
* self-exclusion works during shrink rounds.
* - A component that also implements [[UnmanagedMemoryConsumer]] MUST NOT report the same
* bytes through both APIs -- they would be subtracted twice from `effectiveMaxMemory`.
* - `MemoryManager.shrinkExternal` owns storage-pool accounting: it deducts exactly
* `shrink`'s return value from the pool. Implementations MUST NOT call
* [[MemoryManager.releaseStorageMemory]] from inside [[shrink]].
* - [[shrink]] runs inside the `MemoryManager` monitor; it MUST NOT cycle back into
* `MemoryStore.{putBytes, remove, evictBlocksToFreeSpace}` (lock-order cycle on
* `MemoryStore.entries`) and SHOULD return within
* `spark.memory.managedConsumer.shrinkWarnThresholdMs` (default 100ms) to avoid
* blocking other acquisitions.
* - [[shrink]] MUST be synchronous (claimed bytes reclaimable on return). Over-release
* and partial release are fine; negative return is a contract violation. Exceptions
* are caught and treated as 0-byte release.
*/
trait ManagedConsumer {
Copy link
Copy Markdown
Member

@zhztheplayer zhztheplayer May 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the trait name clear? Existing MemoryConsumer is already managed by Spark.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point — "managed vs unmanaged" is a weak axis here (everything in Spark is managed in some sense). Better to name on the differentiating capability:

Trait Layer Core verb
MemoryConsumer (Java) task spill()
UnmanagedMemoryConsumer executor getMemBytesUsed
this PR executor shrink()

Proposal: rename to ShrinkableMemoryConsumer. Matches the SPI's defining verb (shrink()), and forms a clean three-way distinction from the two existing traits — no overlap with either. WDYT?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more piece of evidence: "shrinkable" is already a first-class concept in the API surface — getShrinkableMemoryBytes is the cheap snapshot the framework uses to skip consumers with nothing to give back. A consumer that always returns 0 is effectively un-shrinkable. So the trait name ShrinkableMemoryConsumer just makes explicit what the method names already imply.


/**
* Registry key and log identifier; MUST be non-empty and JVM-unique. Defaults to
* `getClass.getSimpleName`; override for anonymous classes (where the default is "")
* or when multiple instances of the same class may coexist.
*/
def name: String = getClass.getSimpleName

/** Memory mode this consumer manages; [[shrink]] is only invoked when modes match. */
def memoryMode: MemoryMode = MemoryMode.OFF_HEAP

/**
* Cheap snapshot of bytes currently releasable via [[shrink]]; used to skip empty
* consumers and order candidates largest-first. Non-negative; 0 means nothing to
* release right now.
*/
def getShrinkableMemoryBytes: Long

/**
* Synchronously release approximately `numBytes`. See class scaladoc for the full
* contract.
*
* @return actual bytes released, >= 0. Framework deducts this value from the storage pool.
*/
def shrink(numBytes: Long): Long
}
83 changes: 83 additions & 0 deletions core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ package org.apache.spark.memory

import javax.annotation.concurrent.GuardedBy

import scala.util.control.NonFatal

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys._
import org.apache.spark.internal.config._
import org.apache.spark.storage.BlockId
import org.apache.spark.storage.memory.MemoryStore
Expand Down Expand Up @@ -65,6 +68,9 @@ private[spark] abstract class MemoryManager(
offHeapExecutionMemoryPool.incrementPoolSize(maxOffHeapMemory - offHeapStorageMemory)
offHeapStorageMemoryPool.incrementPoolSize(offHeapStorageMemory)

private val managedConsumerEnabled = conf.get(MANAGED_CONSUMER_ENABLED)
private val shrinkWarnThresholdMs = conf.get(MANAGED_CONSUMER_SHRINK_WARN_THRESHOLD_MS)

/**
* Total available on heap memory for storage, in bytes. This amount can vary over time,
* depending on the MemoryManager implementation.
Expand Down Expand Up @@ -105,6 +111,76 @@ private[spark] abstract class MemoryManager(
*/
def acquireUnrollMemory(blockId: BlockId, numBytes: Long, memoryMode: MemoryMode): Boolean

/**
* Acquire `numBytes` of storage memory on behalf of `self`. Bytes are added to the storage
* pool but never enter [[MemoryStore]]'s `entries` map. `self` is excluded by reference
* equality from its own shrink-candidate round.
*/
def acquireStorageMemory(
self: ManagedConsumer,
numBytes: Long,
memoryMode: MemoryMode): Boolean

/**
* Snapshot of [[ManagedConsumer]]s able to free `memoryMode` memory, filtered to those
* reporting positive shrinkable bytes and ordered largest-first. Caller MUST hold the
* [[MemoryManager]] monitor while invoking `shrink` on the result. Default: empty (so
* non-[[UnifiedMemoryManager]] backends disable the integration).
*/
private[spark] def getShrinkableConsumers(
memoryMode: MemoryMode): Iterable[ManagedConsumer] = Iterable.empty

/**
* Ask registered [[ManagedConsumer]]s to release up to `requested` bytes of `memoryMode`
* storage; returns the growth in `pool.memoryFree` over the call. The framework deducts
* each `shrink` return value from the storage pool, so consumers MUST NOT call
* [[releaseStorageMemory]] from inside `shrink`. Caller MUST hold `this` monitor.
* Returns 0 if the SPI is disabled or `requested <= 0`.
*
* @param exclude caller's own consumer, if any, to skip (compared by `eq`).
*/
private[memory] final def shrinkExternal(
requested: Long,
memoryMode: MemoryMode,
exclude: Option[ManagedConsumer] = None): Long = {
if (!managedConsumerEnabled || requested <= 0L) return 0L
val pool = memoryMode match {
case MemoryMode.ON_HEAP => onHeapStorageMemoryPool
case MemoryMode.OFF_HEAP => offHeapStorageMemoryPool
}
val freedAtStart = pool.memoryFree
val candidates = getShrinkableConsumers(memoryMode).iterator
.filterNot(c => exclude.exists(_ eq c))
var stillNeeded = requested
while (candidates.hasNext && stillNeeded > 0L) {
val c = candidates.next()
val (released, elapsedMs) = Utils.timeTakenMs {
try {
c.shrink(stillNeeded)
} catch {
case NonFatal(t) =>
logWarning(log"ManagedConsumer ${MDC(OBJECT_ID, MemoryManager.consumerLogName(c))}" +
log" threw from shrink(); treating as 0 release: ${MDC(ERROR, t.getMessage)}", t)
0L
}
}
require(released >= 0L,
s"ManagedConsumer ${MemoryManager.consumerLogName(c)} returned negative bytes from " +
s"shrink(): $released")
if (released > 0L) {
pool.releaseMemory(released)
}
if (elapsedMs > shrinkWarnThresholdMs) {
logWarning(log"ManagedConsumer ${MDC(OBJECT_ID, MemoryManager.consumerLogName(c))} took" +
log" ${MDC(TIME, elapsedMs)}ms to shrink (warn threshold " +
log"${MDC(THRESHOLD, shrinkWarnThresholdMs)}ms); MemoryManager monitor was " +
log"held throughout - consider smaller shrink requests or async preparation")
}
stillNeeded -= released
}
math.max(0L, pool.memoryFree - freedAtStart)
}

/**
* Try to acquire up to `numBytes` of execution memory for the current task and return the
* number of bytes obtained, or 0 if none can be allocated.
Expand Down Expand Up @@ -280,3 +356,10 @@ private[spark] abstract class MemoryManager(
}
}
}

private[memory] object MemoryManager {
private[memory] def consumerLogName(c: ManagedConsumer): String = {
val n = if (c.name != null) c.name else ""
if (n.nonEmpty) n else c.getClass.getName
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ private[memory] class StorageMemoryPool(
acquireMemory(blockId, numBytes, numBytesToFree)
}

/**
* Acquire `numBytes` for a [[ManagedConsumer]]: external bytes that never enter
* [[memoryStore]]'s `entries`, falling back to LRU eviction for any deficit. Caller is
* responsible for [[MemoryManager.shrinkExternal]] BEFORE this call; self-exclusion is
* handled in [[MemoryManager.acquireStorageMemory(self:ManagedConsumer,*]].
*/
def acquireMemoryForManagedConsumer(numBytes: Long): Boolean = lock.synchronized {
acquireMemoryInternal(None, numBytes, math.max(0L, numBytes - memoryFree))
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the API have to be bridged with storage memory pool?

IIUC Gluten demands a global memory area in the UMM that is not accounted to particular tasks. Would it be simpler to start from the unmanaged memory API added in #51778?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the look! Two points:

1. UnmanagedMemoryConsumer is not real-time — wrong shape for caches.

UMC is purely pull-mode: getMemBytesUsed: Long is polled every spark.memory.unmanagedMemoryPollingInterval (default 0s = disabled) and subtracted from effectiveMaxMemory. Readings lag real usage, and there is no callback to ask the consumer to release. External caches need to evict synchronously under pressure and grow synchronously when storage is idle — interval-bounded staleness can't deliver that.

2. The storage pool is the master ledger — must bridge into it.

MemoryStore, execution borrow-from-storage, and maybeGrowExecutionPool all arbitrate against storage pool accounting. To share spark.memory.offHeap.size with MemoryStore and participate in the same borrow/spill/evict ordering, an external cache has to be visible in the pool, not "outside it" via UMC subtraction.

On "global, not per-task" — agreed, that's what this SPI delivers: ManagedConsumer is registered on UnifiedMemoryManager (executor singleton), no per-task accounting.

WDYT?


/**
* Acquire N bytes of storage memory for the given block, evicting existing ones if necessary.
*
Expand All @@ -82,15 +92,21 @@ private[memory] class StorageMemoryPool(
blockId: BlockId,
numBytesToAcquire: Long,
numBytesToFree: Long): Boolean = lock.synchronized {
acquireMemoryInternal(Some(blockId), numBytesToAcquire, numBytesToFree)
}

private def acquireMemoryInternal(
blockId: Option[BlockId],
numBytesToAcquire: Long,
numBytesToFree: Long): Boolean = {
assert(numBytesToAcquire >= 0)
assert(numBytesToFree >= 0)
assert(memoryUsed <= poolSize)
if (numBytesToFree > 0) {
memoryStore.evictBlocksToFreeSpace(Some(blockId), numBytesToFree, memoryMode)
memoryStore.evictBlocksToFreeSpace(blockId, numBytesToFree, memoryMode)
}
// NOTE: If the memory store evicts blocks, then those evictions will synchronously call
// back into this StorageMemoryPool in order to free memory. Therefore, these variables
// should have been updated.
// NOTE: If the memory store evicts blocks, those evictions synchronously call back
// into this StorageMemoryPool to free memory, so _memoryUsed is already updated.
val enoughMemory = numBytesToAcquire <= memoryFree
if (enoughMemory) {
_memoryUsed += numBytesToAcquire
Expand Down
Loading