diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 86e5422a85515..5ab4d74b6dfb1 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -512,6 +512,30 @@ package object config { .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("0s") + private[spark] val MANAGED_CONSUMER_ENABLED = + ConfigBuilder("spark.memory.managedConsumer.enabled") + .doc("If true, UnifiedMemoryManager will consult registered ManagedConsumer " + + "instances (via their shrink() method) before falling back to evicting " + + "internal cached blocks or shrinking the storage pool for execution. This " + + "ordering protects user-explicit persist() blocks from being dropped before " + + "best-effort external caches (e.g. native columnar caches like Velox " + + "AsyncDataCache exposed through Gluten).") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) + .version("4.1.0") + .booleanConf + .createWithDefault(false) + + private[spark] val MANAGED_CONSUMER_SHRINK_WARN_THRESHOLD_MS = + ConfigBuilder("spark.memory.managedConsumer.shrinkWarnThresholdMs") + .doc("If a ManagedConsumer.shrink() call takes longer than this many " + + "milliseconds, log a warning. 
Because shrink() runs inside the MemoryManager " + + "monitor, long shrink calls block other acquisition requests; this threshold " + + "helps surface slow implementations.") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) + .version("4.1.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("100ms") + private[spark] val STORAGE_UNROLL_MEMORY_THRESHOLD = ConfigBuilder("spark.storage.unrollMemoryThreshold") .doc("Initial memory to request before unrolling any block") diff --git a/core/src/main/scala/org/apache/spark/memory/ManagedConsumer.scala b/core/src/main/scala/org/apache/spark/memory/ManagedConsumer.scala new file mode 100644 index 0000000000000..08b86b8688779 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/ManagedConsumer.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +/** + * Storage-memory counterpart of [[org.apache.spark.memory.MemoryConsumer]]: holds bytes + * acquired via [[MemoryManager.acquireStorageMemory]] and synchronously releases them on + * Spark's request. Typical implementor: a native off-heap cache (e.g. 
Velox AsyncDataCache + * via Gluten) sharing `spark.memory.offHeap.size` with Spark's MemoryStore. + * + * == Contract == + * - [[name]] is the registry key (JVM-unique; ON_HEAP and OFF_HEAP share one namespace). + * Use the SAME instance for register / acquire / unregister so identity-based + * self-exclusion works during shrink rounds. + * - A component that also implements [[UnmanagedMemoryConsumer]] MUST NOT report the same + * bytes through both APIs -- they would be subtracted twice from `effectiveMaxMemory`. + * - `MemoryManager.shrinkExternal` owns storage-pool accounting: it deducts exactly + * `shrink`'s return value from the pool. Implementations MUST NOT call + * [[MemoryManager.releaseStorageMemory]] from inside [[shrink]]. + * - [[shrink]] runs inside the `MemoryManager` monitor; it MUST NOT cycle back into + * `MemoryStore.{putBytes, remove, evictBlocksToFreeSpace}` (lock-order cycle on + * `MemoryStore.entries`) and SHOULD return within + * `spark.memory.managedConsumer.shrinkWarnThresholdMs` (default 100ms) to avoid + * blocking other acquisitions. + * - [[shrink]] MUST be synchronous (claimed bytes reclaimable on return). Over-release + * and partial release are fine; negative return is a contract violation. Exceptions + * are caught and treated as 0-byte release. + */ +trait ManagedConsumer { + + /** + * Registry key and log identifier; MUST be non-empty and JVM-unique. Defaults to + * `getClass.getSimpleName`; override for anonymous classes (where the default is "") + * or when multiple instances of the same class may coexist. + */ + def name: String = getClass.getSimpleName + + /** Memory mode this consumer manages; [[shrink]] is only invoked when modes match. */ + def memoryMode: MemoryMode = MemoryMode.OFF_HEAP + + /** + * Cheap snapshot of bytes currently releasable via [[shrink]]; used to skip empty + * consumers and order candidates largest-first. Non-negative; 0 means nothing to + * release right now. 
+ */ + def getShrinkableMemoryBytes: Long + + /** + * Synchronously release approximately `numBytes`. See class scaladoc for the full + * contract. + * + * @return actual bytes released, >= 0. Framework deducts this value from the storage pool. + */ + def shrink(numBytes: Long): Long +} diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index 639b82b6080b3..0a85e3734d21f 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -19,8 +19,11 @@ package org.apache.spark.memory import javax.annotation.concurrent.GuardedBy +import scala.util.control.NonFatal + import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.internal.LogKeys._ import org.apache.spark.internal.config._ import org.apache.spark.storage.BlockId import org.apache.spark.storage.memory.MemoryStore @@ -65,6 +68,9 @@ private[spark] abstract class MemoryManager( offHeapExecutionMemoryPool.incrementPoolSize(maxOffHeapMemory - offHeapStorageMemory) offHeapStorageMemoryPool.incrementPoolSize(offHeapStorageMemory) + private val managedConsumerEnabled = conf.get(MANAGED_CONSUMER_ENABLED) + private val shrinkWarnThresholdMs = conf.get(MANAGED_CONSUMER_SHRINK_WARN_THRESHOLD_MS) + /** * Total available on heap memory for storage, in bytes. This amount can vary over time, * depending on the MemoryManager implementation. @@ -105,6 +111,76 @@ private[spark] abstract class MemoryManager( */ def acquireUnrollMemory(blockId: BlockId, numBytes: Long, memoryMode: MemoryMode): Boolean + /** + * Acquire `numBytes` of storage memory on behalf of `self`. Bytes are added to the storage + * pool but never enter [[MemoryStore]]'s `entries` map. `self` is excluded by reference + * equality from its own shrink-candidate round. 
+ */ + def acquireStorageMemory( + self: ManagedConsumer, + numBytes: Long, + memoryMode: MemoryMode): Boolean + + /** + * Snapshot of [[ManagedConsumer]]s able to free `memoryMode` memory, filtered to those + * reporting positive shrinkable bytes and ordered largest-first. Caller MUST hold the + * [[MemoryManager]] monitor while invoking `shrink` on the result. Default: empty (so + * non-[[UnifiedMemoryManager]] backends disable the integration). + */ + private[spark] def getShrinkableConsumers( + memoryMode: MemoryMode): Iterable[ManagedConsumer] = Iterable.empty + + /** + * Ask registered [[ManagedConsumer]]s to release up to `requested` bytes of `memoryMode` + * storage; returns the growth in `pool.memoryFree` over the call. The framework deducts + * each `shrink` return value from the storage pool, so consumers MUST NOT call + * [[releaseStorageMemory]] from inside `shrink`. Caller MUST hold `this` monitor. + * Returns 0 if the SPI is disabled or `requested <= 0`. + * + * @param exclude caller's own consumer, if any, to skip (compared by `eq`). 
+   */
+  private[memory] final def shrinkExternal(
+      requested: Long,
+      memoryMode: MemoryMode,
+      exclude: Option[ManagedConsumer] = None): Long = {
+    if (!managedConsumerEnabled || requested <= 0L) return 0L
+    val pool = memoryMode match {
+      case MemoryMode.ON_HEAP => onHeapStorageMemoryPool
+      case MemoryMode.OFF_HEAP => offHeapStorageMemoryPool
+    }
+    val freeAtStart = pool.memoryFree
+    val candidates = getShrinkableConsumers(memoryMode).iterator
+      .filterNot(c => exclude.exists(_ eq c))
+    var stillNeeded = requested
+    while (candidates.hasNext && stillNeeded > 0L) {
+      val c = candidates.next()
+      // Time the call inline (monotonic clock) so a slow consumer can be reported below.
+      val startNs = System.nanoTime()
+      val released = try {
+        c.shrink(stillNeeded)
+      } catch {
+        case NonFatal(t) =>
+          logWarning(log"ManagedConsumer ${MDC(OBJECT_ID, MemoryManager.consumerLogName(c))}" +
+            log" threw from shrink(); treating as 0 release: ${MDC(ERROR, t.getMessage)}", t)
+          0L
+      }
+      val elapsedMs = math.max(0L, (System.nanoTime() - startNs) / 1000000L)
+      require(released >= 0L,
+        s"ManagedConsumer ${MemoryManager.consumerLogName(c)} returned negative bytes from " +
+          s"shrink(): $released")
+      // The consumer only reports a byte count; the storage pool is credited here.
+      if (released > 0L) pool.releaseMemory(released)
+      if (elapsedMs > shrinkWarnThresholdMs) {
+        logWarning(log"ManagedConsumer ${MDC(OBJECT_ID, MemoryManager.consumerLogName(c))} took" +
+          log" ${MDC(TIME, elapsedMs)}ms to shrink (warn threshold " +
+          log"${MDC(THRESHOLD, shrinkWarnThresholdMs)}ms); MemoryManager monitor was " +
+          log"held throughout - consider smaller shrink requests or async preparation")
+      }
+      stillNeeded -= released
+    }
+    math.max(0L, pool.memoryFree - freeAtStart)
+  }
+
   /**
    * Try to acquire up to `numBytes` of execution memory for the current task and return the
    * number of bytes obtained, or 0 if none can be allocated.
@@ -280,3 +356,12 @@ private[spark] abstract class MemoryManager(
     }
   }
 }
+
+private[memory] object MemoryManager {
+  /**
+   * Identifier used for `c` in log and error messages: its `name` when that is non-null
+   * and non-empty, otherwise the fully-qualified name of its runtime class.
+   */
+  private[memory] def consumerLogName(c: ManagedConsumer): String =
+    Option(c.name).filter(_.nonEmpty).getOrElse(c.getClass.getName)
+}
diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala
index 0f15254f3a080..228abf46a20ad 100644
--- a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala
+++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala
@@ -70,6 +70,16 @@ private[memory] class StorageMemoryPool(
     acquireMemory(blockId, numBytes, numBytesToFree)
   }
 
+  /**
+   * Acquire `numBytes` for a [[ManagedConsumer]]: external bytes that never enter
+   * [[memoryStore]]'s `entries`, falling back to LRU eviction for any deficit. Caller is
+   * responsible for [[MemoryManager.shrinkExternal]] BEFORE this call; self-exclusion is
+   * handled in [[MemoryManager.acquireStorageMemory(self:ManagedConsumer,*]].
+   */
+  def acquireMemoryForManagedConsumer(numBytes: Long): Boolean = lock.synchronized {
+    acquireMemoryInternal(None, numBytes, math.max(0L, numBytes - memoryFree))
+  }
+
   /**
    * Acquire N bytes of storage memory for the given block, evicting existing ones if necessary.
* @@ -82,15 +92,21 @@ private[memory] class StorageMemoryPool( blockId: BlockId, numBytesToAcquire: Long, numBytesToFree: Long): Boolean = lock.synchronized { + acquireMemoryInternal(Some(blockId), numBytesToAcquire, numBytesToFree) + } + + private def acquireMemoryInternal( + blockId: Option[BlockId], + numBytesToAcquire: Long, + numBytesToFree: Long): Boolean = { assert(numBytesToAcquire >= 0) assert(numBytesToFree >= 0) assert(memoryUsed <= poolSize) if (numBytesToFree > 0) { - memoryStore.evictBlocksToFreeSpace(Some(blockId), numBytesToFree, memoryMode) + memoryStore.evictBlocksToFreeSpace(blockId, numBytesToFree, memoryMode) } - // NOTE: If the memory store evicts blocks, then those evictions will synchronously call - // back into this StorageMemoryPool in order to free memory. Therefore, these variables - // should have been updated. + // NOTE: If the memory store evicts blocks, those evictions synchronously call back + // into this StorageMemoryPool to free memory, so _memoryUsed is already updated. val enoughMemory = numBytesToAcquire <= memoryFree if (enoughMemory) { _memoryUsed += numBytesToAcquire diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 6b278c47f32f1..4ea168138615a 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -159,17 +159,18 @@ private[spark] class UnifiedMemoryManager( */ def maybeGrowExecutionPool(extraMemoryNeeded: Long): Unit = { if (extraMemoryNeeded > 0) { - // There is not enough free memory in the execution pool, so try to reclaim memory from - // storage. We can reclaim any free memory from the storage pool. If the storage pool - // has grown to become larger than `storageRegionSize`, we can evict blocks and reclaim - // the memory that storage has borrowed from execution. 
+ // Compute the reclaim cap BEFORE asking externals to shrink, otherwise shrunk bytes + // would let execution claim into the protected storage region. val memoryReclaimableFromStorage = math.max( storagePool.memoryFree, storagePool.poolSize - storageRegionSize) + val targetReclaim = math.min(extraMemoryNeeded, memoryReclaimableFromStorage) + val shrinkNeeded = math.max(0L, targetReclaim - storagePool.memoryFree) + if (shrinkNeeded > 0L) { + shrinkExternal(shrinkNeeded, memoryMode) + } if (memoryReclaimableFromStorage > 0) { - // Only reclaim as much space as is necessary and available: - val spaceToReclaim = storagePool.freeSpaceToShrinkPool( - math.min(extraMemoryNeeded, memoryReclaimableFromStorage)) + val spaceToReclaim = storagePool.freeSpaceToShrinkPool(targetReclaim) storagePool.decrementPoolSize(spaceToReclaim) executionPool.incrementPoolSize(spaceToReclaim) } @@ -206,9 +207,63 @@ private[spark] class UnifiedMemoryManager( override def acquireStorageMemory( blockId: BlockId, numBytes: Long, - memoryMode: MemoryMode): Boolean = synchronized { + memoryMode: MemoryMode): Boolean = { + acquireStorageMemoryUnified( + numBytes, + memoryMode, + exclude = None, + (effective, unmanaged) => + logInfo(log"Will not store ${MDC(BLOCK_ID, blockId)} as the required space" + + log" (${MDC(NUM_BYTES, numBytes)} bytes) exceeds our" + + log" memory limit (${MDC(NUM_BYTES_MAX, effective)} bytes)" + + (if (unmanaged > 0) log" (unmanaged memory usage: ${MDC(NUM_BYTES, unmanaged)} bytes)" + else log"")), + (pool, n) => pool.acquireMemory(blockId, n)) + } + + override def acquireUnrollMemory( + blockId: BlockId, + numBytes: Long, + memoryMode: MemoryMode): Boolean = { + acquireStorageMemory(blockId, numBytes, memoryMode) + } + + override def acquireStorageMemory( + self: ManagedConsumer, + numBytes: Long, + memoryMode: MemoryMode): Boolean = { + require(self != null, "self ManagedConsumer must not be null") + require(self.memoryMode == memoryMode, + s"requested 
memoryMode=$memoryMode does not match self.memoryMode=${self.memoryMode}; " + + "a ManagedConsumer may only acquire memory in the mode it manages") + acquireStorageMemoryUnified( + numBytes, + memoryMode, + exclude = Some(self), + (effective, unmanaged) => + logInfo(log"Will not grant external storage memory request of " + + log"${MDC(NUM_BYTES, numBytes)} bytes as it exceeds the " + + log"effective limit (${MDC(NUM_BYTES_MAX, effective)} bytes)" + + (if (unmanaged > 0) log" (unmanaged memory usage: ${MDC(NUM_BYTES, unmanaged)} bytes)" + else log"")), + (pool, n) => pool.acquireMemoryForManagedConsumer(n)) + } + + /** + * Shared body for the two [[acquireStorageMemory]] overloads. `logFailFast` and + * `acquireFromPool` capture the only per-overload differences. Order is: fail-fast on + * `effectiveMaxMemory`, borrow free execution memory, [[shrinkExternal]] for any remaining + * deficit, then delegate to the pool. Borrow runs before shrink because it is free + * (no eviction). + */ + private def acquireStorageMemoryUnified( + numBytes: Long, + memoryMode: MemoryMode, + exclude: Option[ManagedConsumer], + logFailFast: (Long, Long) => Unit, + acquireFromPool: (StorageMemoryPool, Long) => Boolean): Boolean = synchronized { assertInvariants() - assert(numBytes >= 0) + require(numBytes >= 0, s"numBytes must be >= 0, got $numBytes") val (executionPool, storagePool, maxMemory) = memoryMode match { case MemoryMode.ON_HEAP => ( onHeapExecutionMemoryPool, @@ -220,39 +275,29 @@ private[spark] class UnifiedMemoryManager( maxOffHeapStorageMemory) } - // Factor in unmanaged memory usage for the specific memory mode val unmanagedMemory = getUnmanagedMemoryUsed(memoryMode) val effectiveMaxMemory = math.max(0L, maxMemory - unmanagedMemory) if (numBytes > effectiveMaxMemory) { - // Fail fast if the block simply won't fit - logInfo(log"Will not store ${MDC(BLOCK_ID, blockId)} as the required space" + - log" (${MDC(NUM_BYTES, numBytes)} bytes) exceeds our" + - log" memory limit 
(${MDC(NUM_BYTES_MAX, effectiveMaxMemory)} bytes)" + - (if (unmanagedMemory > 0) { - log" (unmanaged memory usage: ${MDC(NUM_BYTES, unmanagedMemory)} bytes)" - } else { - log"" - })) + logFailFast(effectiveMaxMemory, unmanagedMemory) return false } if (numBytes > storagePool.memoryFree) { - // There is not enough free memory in the storage pool, so try to borrow free memory from - // the execution pool. val memoryBorrowedFromExecution = Math.min(executionPool.memoryFree, numBytes - storagePool.memoryFree) executionPool.decrementPoolSize(memoryBorrowedFromExecution) storagePool.incrementPoolSize(memoryBorrowedFromExecution) } - storagePool.acquireMemory(blockId, numBytes) + val deficitAfterBorrow = math.max(0L, numBytes - storagePool.memoryFree) + if (deficitAfterBorrow > 0L) { + shrinkExternal(deficitAfterBorrow, memoryMode, exclude) + } + acquireFromPool(storagePool, numBytes) } - override def acquireUnrollMemory( - blockId: BlockId, - numBytes: Long, - memoryMode: MemoryMode): Boolean = synchronized { - acquireStorageMemory(blockId, numBytes, memoryMode) - } + override private[spark] def getShrinkableConsumers( + memoryMode: MemoryMode): Iterable[ManagedConsumer] = + UnifiedMemoryManager.getShrinkableConsumers(memoryMode) } object UnifiedMemoryManager extends Logging { @@ -294,6 +339,26 @@ object UnifiedMemoryManager extends Logging { unmanagedMemoryConsumer: UnmanagedMemoryConsumer): Unit = { val id = unmanagedMemoryConsumer.unmanagedMemoryConsumerId unmanagedMemoryConsumers.put(id, unmanagedMemoryConsumer) + unmanagedMemoryConsumer match { + case mc: ManagedConsumer if isRegisteredManaged(mc) => warnCrossRegistered(mc) + case _ => + } + } + + private def isRegisteredManaged(mc: ManagedConsumer): Boolean = { + val n = mc.name + n != null && n.nonEmpty && (managedConsumers.get(n) eq mc) + } + + private def isRegisteredUnmanaged(umc: UnmanagedMemoryConsumer): Boolean = { + val id = umc.unmanagedMemoryConsumerId + unmanagedMemoryConsumers.get(id) eq umc + } + + 
private def warnCrossRegistered(mc: ManagedConsumer): Unit = { + logWarning(log"Object ${MDC(LogKeys.OBJECT_ID, MemoryManager.consumerLogName(mc))} " + + log"is registered as BOTH ManagedConsumer and UnmanagedMemoryConsumer; the same " + + log"bytes will be subtracted twice from effectiveMaxMemory. Pick exactly one SPI.") } /** @@ -341,6 +406,95 @@ object UnifiedMemoryManager extends Logging { unmanagedOffHeapUsed.set(0L) } + // -- Managed consumer registry -- + + private val managedConsumers = + new ConcurrentHashMap[String, ManagedConsumer]() + + /** + * Register a [[ManagedConsumer]] as a candidate for [[MemoryManager.shrinkExternal]] + * (requires `spark.memory.managedConsumer.enabled=true`). + * + * The registry is JVM-global and does NOT propagate across the cluster. Keyed by + * [[ManagedConsumer.name]] (ON_HEAP and OFF_HEAP share one namespace); re-registering + * the SAME instance is idempotent, a DIFFERENT instance under an already-taken name + * fails. Callers MUST invoke [[unregisterManagedConsumer]] on shutdown -- the registry + * holds strong references. + */ + def registerManagedConsumer(consumer: ManagedConsumer): Unit = { + require(consumer != null, "ManagedConsumer must not be null") + val n = consumer.name + require(n != null && n.nonEmpty, + "ManagedConsumer.name must be non-empty (used as the registry key)") + val prior = managedConsumers.putIfAbsent(n, consumer) + if (prior != null && (prior ne consumer)) { + throw new IllegalArgumentException( + s"A different ManagedConsumer is already registered under name '$n'. " + + s"Existing: ${prior.getClass.getName}, new: ${consumer.getClass.getName}. " + + "Names must be unique within this JVM (ON_HEAP and OFF_HEAP share one namespace).") + } + consumer match { + case umc: UnmanagedMemoryConsumer if isRegisteredUnmanaged(umc) => + warnCrossRegistered(consumer) + case _ => + } + } + + /** + * Unregister a [[ManagedConsumer]]. 
Removes only when the registered instance is the
+   * one passed here (a stale unregister cannot evict a later re-registration under the
+   * same name). No-op for null / empty-name / not-registered.
+   */
+  private[spark] def unregisterManagedConsumer(consumer: ManagedConsumer): Unit = {
+    val n = if (consumer == null) null else consumer.name
+    if (n != null && n.nonEmpty) {
+      // Identity check, as in registerManagedConsumer; remove(n, consumer) would use equals().
+      managedConsumers.computeIfPresent(n, (_, cur) => if (cur eq consumer) null else cur)
+    }
+  }
+
+  /**
+   * Snapshot of registered managed consumers for `memoryMode`, filtered to those reporting
+   * positive [[ManagedConsumer.getShrinkableMemoryBytes]], sorted DESC (tie-break unspecified).
+   * Iteration is weakly-consistent; consumers that throw or return negative are coerced to 0
+   * and filtered out.
+   */
+  private[spark] def getShrinkableConsumers(
+      memoryMode: MemoryMode): Iterable[ManagedConsumer] = {
+    if (managedConsumers.isEmpty) return Iterable.empty
+    def safeGetShrinkableBytes(c: ManagedConsumer): Long = {
+      try {
+        val b = c.getShrinkableMemoryBytes
+        if (b < 0L) {
+          logWarning(log"ManagedConsumer ${MDC(LogKeys.OBJECT_ID,
+            MemoryManager.consumerLogName(c))} returned negative " +
+            log"getShrinkableMemoryBytes=${MDC(LogKeys.NUM_BYTES, b)}; treating as 0")
+          0L
+        } else {
+          b
+        }
+      } catch {
+        case NonFatal(t) =>
+          logWarning(log"ManagedConsumer ${MDC(LogKeys.OBJECT_ID,
+            MemoryManager.consumerLogName(c))} threw from getShrinkableMemoryBytes; " +
+            log"treating as 0: ${MDC(LogKeys.ERROR, t.getMessage)}", t)
+          0L
+      }
+    }
+    managedConsumers.values().asScala.iterator
+      .filter(_.memoryMode == memoryMode)
+      .map(c => (c, safeGetShrinkableBytes(c)))
+      .filter(_._2 > 0L)
+      .toSeq
+      .sortBy(-_._2)
+      .map(_._1)
+  }
+
+  /** Test-only: clear all managed consumers. 
*/ + private[spark] def clearManagedConsumers(): Unit = { + managedConsumers.clear() + } + // Shared polling infrastructure - only one polling thread per JVM @volatile private var unmanagedMemoryPoller: ScheduledExecutorService = _ diff --git a/core/src/test/scala/org/apache/spark/memory/ManagedConsumerSuite.scala b/core/src/test/scala/org/apache/spark/memory/ManagedConsumerSuite.scala new file mode 100644 index 0000000000000..03261fb52b11f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/memory/ManagedConsumerSuite.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import org.apache.spark.SparkFunSuite + +/** + * Trait-level compile and behavior contract tests for [[ManagedConsumer]]. End-to-end + * registry behavior (identity-based dedup, registration idempotency), self-exclusion via + * object identity, [[MemoryManager.shrinkExternal]] orchestration, and integration with + * [[UnifiedMemoryManager]] storage acquire / execution reclaim paths are covered in + * `UnifiedMemoryManagerSuite`. 
+ */ +class ManagedConsumerSuite extends SparkFunSuite { + + test("a minimal ManagedConsumer compiles and exposes the SPI members") { + val consumer = new ManagedConsumer { + private var held = 1000L + override val memoryMode: MemoryMode = MemoryMode.OFF_HEAP + override def getShrinkableMemoryBytes: Long = held + override def shrink(numBytes: Long): Long = { + val released = math.min(numBytes, held) + held -= released + released + } + } + + // `name` has a default (getClass.getSimpleName); anonymous classes get "". + // MemoryManager.consumerLogName falls back to the FQ name in that case. + assert(consumer.name === "" || consumer.name.nonEmpty) + assert(MemoryManager.consumerLogName(consumer).nonEmpty) + assert(consumer.memoryMode === MemoryMode.OFF_HEAP) + assert(consumer.getShrinkableMemoryBytes === 1000L) + assert(consumer.shrink(300L) === 300L) + assert(consumer.getShrinkableMemoryBytes === 700L) + } + + test("named ManagedConsumer overrides `name` for logs") { + class NamedCache extends ManagedConsumer { + override val name: String = "MyVeloxCache:executor-7" + override val memoryMode: MemoryMode = MemoryMode.OFF_HEAP + override def getShrinkableMemoryBytes: Long = 0L + override def shrink(numBytes: Long): Long = 0L + } + val c = new NamedCache + assert(c.name === "MyVeloxCache:executor-7") + assert(MemoryManager.consumerLogName(c) === "MyVeloxCache:executor-7") + } + + test("ManagedConsumer is independent of UnmanagedMemoryConsumer") { + // The two traits are NOT in an inheritance relation. A consumer may implement either, + // both, or neither. Verifying the type system here documents that contract. 
+ val pureManaged: ManagedConsumer = new ManagedConsumer { + override val memoryMode: MemoryMode = MemoryMode.OFF_HEAP + override def getShrinkableMemoryBytes: Long = 0L + override def shrink(numBytes: Long): Long = 0L + } + val pureUnmanaged: UnmanagedMemoryConsumer = new UnmanagedMemoryConsumer { + override val unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId = + UnmanagedMemoryConsumerId("PureUnmanaged", "k") + override val memoryMode: MemoryMode = MemoryMode.OFF_HEAP + override def getMemBytesUsed: Long = 0L + } + assert(!pureManaged.isInstanceOf[UnmanagedMemoryConsumer]) + assert(!pureUnmanaged.isInstanceOf[ManagedConsumer]) + + // A single component MAY implement both, with the mutual-exclusion contract from the + // scaladoc: getMemBytesUsed must NOT double-report bytes already accounted via + // acquireStorageMemory. + val both = new ManagedConsumer with UnmanagedMemoryConsumer { + override val unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId = + UnmanagedMemoryConsumerId("Both", "k") + override val memoryMode: MemoryMode = MemoryMode.OFF_HEAP + override def getShrinkableMemoryBytes: Long = 0L + override def shrink(numBytes: Long): Long = 0L + override def getMemBytesUsed: Long = 0L + } + assert(both.isInstanceOf[ManagedConsumer]) + assert(both.isInstanceOf[UnmanagedMemoryConsumer]) + } + + test("MemoryManager.getShrinkableConsumers default is empty for non-Unified backends") { + // Any MemoryManager that does not override getShrinkableConsumers (e.g., TestMemoryManager, + // alternative backends, future SPI implementations) must transparently disable the push-mode + // shrink integration. Without this default, MemoryManager.shrinkExternal would crash or + // silently miss a hard-coded UnifiedMemoryManager dependency. 
+ val mm = new TestMemoryManager(new org.apache.spark.SparkConf(false)) + assert(mm.getShrinkableConsumers(MemoryMode.ON_HEAP).isEmpty) + assert(mm.getShrinkableConsumers(MemoryMode.OFF_HEAP).isEmpty) + } + + test("MemoryManager.consumerLogName falls back to FQ class name for empty `name`") { + val anon = new ManagedConsumer { + override val memoryMode: MemoryMode = MemoryMode.ON_HEAP + override def getShrinkableMemoryBytes: Long = 0L + override def shrink(numBytes: Long): Long = 0L + } + val resolved = MemoryManager.consumerLogName(anon) + assert(resolved.nonEmpty, "log name must never be empty (would lose context in WARN logs)") + // For anonymous classes getSimpleName is "" so the fallback is the FQ name. + if (anon.name.isEmpty) { + assert(resolved === anon.getClass.getName) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala index fb41f1ab287f6..d1ee1daae8a32 100644 --- a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala +++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala @@ -95,6 +95,14 @@ class TestMemoryManager(conf: SparkConf) true } + override def acquireStorageMemory( + self: ManagedConsumer, + numBytes: Long, + memoryMode: MemoryMode): Boolean = { + require(numBytes >= 0) + true + } + override def releaseStorageMemory(numBytes: Long, memoryMode: MemoryMode): Unit = { require(numBytes >= 0) } diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 9f0e622b1d515..ad4518ad23245 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -627,4 +627,449 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes UnifiedMemoryManager.clearUnmanagedMemoryUsers() 
} } + + // -- ManagedConsumer (push-mode SPI, sibling of SPARK-53001 pull-mode UnmanagedMemoryConsumer) + + /** + * Mock that DOES NOT call [[MemoryManager.releaseStorageMemory]] in shrink. Useful only + * for registry-level tests where shrink() is never actually invoked. + */ + private class MockManagedConsumer( + bytesHeld: Long, + instanceKey: String = "instance-1", + mode: MemoryMode = MemoryMode.ON_HEAP, + bytesThrowsOnGet: Boolean = false, + bytesReturnsNegative: Boolean = false) extends ManagedConsumer { + override val name: String = s"MockManagedConsumer:$instanceKey" + override def memoryMode: MemoryMode = mode + override def getShrinkableMemoryBytes: Long = { + if (bytesThrowsOnGet) { + throw new RuntimeException("boom from getShrinkableMemoryBytes") + } else if (bytesReturnsNegative) { + -7L + } else { + bytesHeld + } + } + override def shrink(numBytes: Long): Long = 0L + } + + /** + * Mock that does NOT touch the storage pool itself: it merely tracks its own held-bytes + * counter and returns the would-be-released value from `shrink`. The framework + * ([[MemoryManager.shrinkExternal]]) is responsible for calling `pool.releaseMemory` on + * the returned value -- these tests validate that invariant end-to-end. Each instance + * acquires `initialBytes` via [[MemoryManager.acquireStorageMemory]] at construction so + * the pool has bytes to charge releases against. 
+ */ + private class MockShrinker( + mm: MemoryManager, + initialBytes: Long, + mode: MemoryMode, + instanceKey: String, + shrinkBehavior: Long => Long = identity, + throwOnShrink: Option[() => Nothing] = None, + shrinkDelayMs: Long = 0L) extends ManagedConsumer { + @volatile private var heldBytes: Long = 0L + @volatile var shrinkCallCount: Int = 0 + + override val name: String = s"MockShrinker:$instanceKey" + override def memoryMode: MemoryMode = mode + override def getShrinkableMemoryBytes: Long = heldBytes + + /* Reserve the initial bytes through the ManagedConsumer overload of acquireStorageMemory + * while this instance is still being constructed; test setup fails via require if the + * reservation is refused. */ + if (initialBytes > 0L) { + require(mm.acquireStorageMemory(this, initialBytes, mode), + s"test setup failed: could not reserve $initialBytes for $instanceKey") + heldBytes = initialBytes + } + + def currentHeldBytes: Long = heldBytes + + /** + * Counts the call, optionally sleeps for shrinkDelayMs and invokes throwOnShrink, then asks + * shrinkBehavior how much of min(numBytes, heldBytes) to release. heldBytes is reduced only + * for a positive result, but the result is returned as-is, so a non-positive value produced + * by shrinkBehavior reaches the caller unchanged. + */ + override def shrink(numBytes: Long): Long = { + shrinkCallCount += 1 + if (shrinkDelayMs > 0) Thread.sleep(shrinkDelayMs) + throwOnShrink.foreach(t => t()) + val candidate = math.min(numBytes, heldBytes) + val toRelease = shrinkBehavior(candidate) + if (toRelease > 0) { + heldBytes -= toRelease + } + toRelease + } + } + + /** + * Runs the inherited teardown and then, even if it throws, calls + * UnifiedMemoryManager.clearManagedConsumers() (presumably so that registrations made by one + * test do not leak into the next). + */ + override def afterEach(): Unit = { + try super.afterEach() + finally UnifiedMemoryManager.clearManagedConsumers() + } + + // -- Registry-level tests + + test("registerManagedConsumer / unregisterManagedConsumer round-trip") { + val c1 = new MockManagedConsumer(100L, "k1") + val c2 = new MockManagedConsumer(200L, "k2") + UnifiedMemoryManager.registerManagedConsumer(c1) + UnifiedMemoryManager.registerManagedConsumer(c2) + assert(UnifiedMemoryManager.getShrinkableConsumers(MemoryMode.ON_HEAP).size === 2) + + UnifiedMemoryManager.unregisterManagedConsumer(c1) + val remaining = UnifiedMemoryManager.getShrinkableConsumers(MemoryMode.ON_HEAP).toSeq + assert(remaining.size === 1 && (remaining.head eq c2)) + + UnifiedMemoryManager.unregisterManagedConsumer(c2) + assert(UnifiedMemoryManager.getShrinkableConsumers(MemoryMode.ON_HEAP).isEmpty) + } + + test("registerManagedConsumer is idempotent for the same instance 
(name dedup)") { + val c1 = new MockManagedConsumer(100L, "shared-key") + UnifiedMemoryManager.registerManagedConsumer(c1) + // Re-registering the *same* instance is idempotent. + UnifiedMemoryManager.registerManagedConsumer(c1) + assert(UnifiedMemoryManager.getShrinkableConsumers(MemoryMode.ON_HEAP).size === 1) + } + + test("registerManagedConsumer tracks distinct instances under distinct names") { + val a = new MockManagedConsumer(100L, "a") + val b = new MockManagedConsumer(200L, "b") + UnifiedMemoryManager.registerManagedConsumer(a) + UnifiedMemoryManager.registerManagedConsumer(b) + val all = UnifiedMemoryManager.getShrinkableConsumers(MemoryMode.ON_HEAP).toSeq + assert(all.size === 2) + assert(all.exists(_ eq a) && all.exists(_ eq b)) + } + + test("registerManagedConsumer rejects a different instance reusing the same name") { + val a = new MockManagedConsumer(100L, "dup") + val b = new MockManagedConsumer(200L, "dup") + UnifiedMemoryManager.registerManagedConsumer(a) + val ex = intercept[IllegalArgumentException] { + UnifiedMemoryManager.registerManagedConsumer(b) + } + assert(ex.getMessage.contains("'MockManagedConsumer:dup'")) + // Existing registration is preserved; the rejected instance is not in the registry. 
+ val all = UnifiedMemoryManager.getShrinkableConsumers(MemoryMode.ON_HEAP).toSeq + assert(all.size === 1 && (all.head eq a)) + } + + test("registerManagedConsumer rejects null and empty name") { + intercept[IllegalArgumentException] { + UnifiedMemoryManager.registerManagedConsumer(null) + } + val anonymous = new ManagedConsumer { + override val name: String = "" + override def memoryMode: MemoryMode = MemoryMode.ON_HEAP + override def getShrinkableMemoryBytes: Long = 0L + override def shrink(numBytes: Long): Long = 0L + } + intercept[IllegalArgumentException] { + UnifiedMemoryManager.registerManagedConsumer(anonymous) + } + } + + test("unregisterManagedConsumer with a stale name-collider is a safe no-op") { + val a = new MockManagedConsumer(100L, "name-collision") + val b = new MockManagedConsumer(200L, "name-collision") + UnifiedMemoryManager.registerManagedConsumer(a) + // Try to unregister 'b' (same name, different instance). Must NOT remove 'a'. + UnifiedMemoryManager.unregisterManagedConsumer(b) + val all = UnifiedMemoryManager.getShrinkableConsumers(MemoryMode.ON_HEAP).toSeq + assert(all.size === 1 && (all.head eq a)) + } + + test("getShrinkableConsumers filters by memoryMode and sorts DESC, skipping zeros") { + val small = new MockManagedConsumer(100L, "small", MemoryMode.ON_HEAP) + val big = new MockManagedConsumer(500L, "big", MemoryMode.ON_HEAP) + val medium = new MockManagedConsumer(300L, "medium", MemoryMode.ON_HEAP) + val zero = new MockManagedConsumer(0L, "zero", MemoryMode.ON_HEAP) + val offHeap = new MockManagedConsumer(900L, "off-heap", MemoryMode.OFF_HEAP) + Seq(small, big, medium, zero, offHeap).foreach( + UnifiedMemoryManager.registerManagedConsumer) + + val onHeap = UnifiedMemoryManager.getShrinkableConsumers(MemoryMode.ON_HEAP).toSeq + assert(onHeap === Seq(big, medium, small)) + val offHeapOnly = UnifiedMemoryManager.getShrinkableConsumers(MemoryMode.OFF_HEAP).toSeq + assert(offHeapOnly === Seq(offHeap)) + } + + test("getShrinkableConsumers 
defensively coerces throwing or negative size to 0") { + val good = new MockManagedConsumer(100L, "good") + val throwing = new MockManagedConsumer(0L, "throws", bytesThrowsOnGet = true) + val negative = new MockManagedConsumer(0L, "negative", bytesReturnsNegative = true) + Seq(good, throwing, negative).foreach(UnifiedMemoryManager.registerManagedConsumer) + assert(UnifiedMemoryManager.getShrinkableConsumers(MemoryMode.ON_HEAP).toSeq === Seq(good)) + } + + // -- acquireStorageMemory(self: ManagedConsumer, ...) tests + + /** Probe consumer used only to provide a `self` reference for self-exclusion in tests + * that exercise the consumer-overload of `acquireStorageMemory` without needing a + * registered shrink candidate. */ + private def newProbeConsumer(key: String, mode: MemoryMode = MemoryMode.ON_HEAP) + : ManagedConsumer = new ManagedConsumer { + override val name: String = s"Probe:$key" + override val memoryMode: MemoryMode = mode + override def getShrinkableMemoryBytes: Long = 0L + override def shrink(numBytes: Long): Long = 0L + } + + test("acquireStorageMemory(self) grants and books bytes into the storage pool") { + val maxMemory = 1000L + val (mm, _) = makeThings(maxMemory) + val external = newProbeConsumer("external") + + assert(mm.storageMemoryUsed === 0L) + assert(mm.acquireStorageMemory(external, 600L, MemoryMode.ON_HEAP)) + assert(mm.storageMemoryUsed === 600L) + mm.releaseStorageMemory(600L, MemoryMode.ON_HEAP) + assert(mm.storageMemoryUsed === 0L) + } + + test("acquireStorageMemory(self) fails fast when request exceeds effective max") { + val maxMemory = 1000L + val (mm, _) = makeThings(maxMemory) + assert(!mm.acquireStorageMemory(newProbeConsumer("x"), maxMemory + 1, MemoryMode.ON_HEAP)) + assert(mm.storageMemoryUsed === 0L) + } + + test("acquireStorageMemory(self) can borrow free execution memory") { + val maxMemory = 1000L + val (mm, _) = makeThings(maxMemory) + // Storage region is 0.5 * 1000 = 500; a 700-byte request must borrow 200 from 
execution. + assert(mm.acquireStorageMemory(newProbeConsumer("x"), 700L, MemoryMode.ON_HEAP)) + assert(mm.storageMemoryUsed === 700L) + } + + // -- shrinkExternal orchestration via the storage acquire path + + test("default-off: registered consumers are not consulted on storage acquire") { + val (mm, _) = makeThings(1000L) + // Feature is OFF by default. + val c = new MockShrinker(mm, 400L, MemoryMode.ON_HEAP, "ignored", + throwOnShrink = Some(() => throw new IllegalStateException( + "consumer must not be consulted when MANAGED_CONSUMER_ENABLED is false"))) + UnifiedMemoryManager.registerManagedConsumer(c) + // Request a tiny block that fits without eviction. + // NOTE(review): c holds 400 and the request is 100, so it presumably fits in free memory + // and shrink would not be called even with the feature enabled; a request that forces a + // deficit would make this check meaningful - confirm. + assert(mm.acquireStorageMemory(dummyBlock, 100L, MemoryMode.ON_HEAP)) + assert(c.shrinkCallCount === 0) + } + + /** + * Builds a UnifiedMemoryManager, plus a MemoryStore created from it via makeMemoryStore, + * from a SparkConf with MEMORY_FRACTION = 1.0, TEST_MEMORY = maxMemory, off-heap disabled, + * the suite-level storageFraction, and MANAGED_CONSUMER_ENABLED = enabled. + */ + private def makeMM(maxMemory: Long, enabled: Boolean): (UnifiedMemoryManager, MemoryStore) = { + val conf = new SparkConf() + .set(MEMORY_FRACTION, 1.0) + .set(TEST_MEMORY, maxMemory) + .set(MEMORY_OFFHEAP_ENABLED, false) + .set(MEMORY_STORAGE_FRACTION, storageFraction) + .set(MANAGED_CONSUMER_ENABLED, enabled) + val mm = UnifiedMemoryManager(conf, numCores = 1) + val ms = makeMemoryStore(mm) + (mm, ms) + } + + test("enabled: external consumer shrinks for storage acquire deficit") { + val (mm, _) = makeMM(1000L, enabled = true) + val c = new MockShrinker(mm, 600L, MemoryMode.ON_HEAP, "ext") + UnifiedMemoryManager.registerManagedConsumer(c) + // Storage used = 600 (c), free = 400. Request another 300 -> needs 0 shrink (fits in free). + assert(mm.acquireStorageMemory(dummyBlock, 300L, MemoryMode.ON_HEAP)) + assert(c.shrinkCallCount === 0, "no shrink: deficit is 0 after fits in free") + + // Request 200 more -> only 100 is still free, so the deficit is 100 -> c must shrink. 
+ assert(mm.acquireStorageMemory(dummyBlock, 200L, MemoryMode.ON_HEAP)) + assert(c.shrinkCallCount === 1) + assert(c.currentHeldBytes === 500L) + } + + test("enabled: largest consumer is consulted first, smaller skipped once deficit met") { + val (mm, _) = makeMM(2000L, enabled = true) + val small = new MockShrinker(mm, 100L, MemoryMode.ON_HEAP, "small") + val big = new MockShrinker(mm, 800L, MemoryMode.ON_HEAP, "big") + val medium = new MockShrinker(mm, 300L, MemoryMode.ON_HEAP, "medium") + Seq(small, big, medium).foreach(UnifiedMemoryManager.registerManagedConsumer) + // storageMemoryUsed = 1200, free in storage pool = max(2000*0.5 - used_in_region, 0). + // Acquire 1000 bytes (with borrow from execution): triggers shrink for any residual deficit. + assert(mm.acquireStorageMemory(dummyBlock, 1000L, MemoryMode.ON_HEAP)) + assert(big.shrinkCallCount === 1, "biggest consumer consulted first") + assert(medium.shrinkCallCount === 0, "medium not needed (big already covered the deficit)") + assert(small.shrinkCallCount === 0) + } + + test("enabled: shrink() RuntimeException is swallowed; treated as 0 release") { + val (mm, _) = makeMM(1000L, enabled = true) + val bomb = new MockShrinker(mm, 400L, MemoryMode.ON_HEAP, "bomb", + throwOnShrink = Some(() => throw new RuntimeException("boom"))) + val backup = new MockShrinker(mm, 400L, MemoryMode.ON_HEAP, "backup") + Seq(bomb, backup).foreach(UnifiedMemoryManager.registerManagedConsumer) + // Used 800; request 300 -> storage free = 200 (after borrow if any), deficit ~100. + // NOTE(review): bomb and backup both hold 400 bytes, so bomb is not actually bigger; the + // bomb.shrinkCallCount === 1 assertion appears to rely on bomb being consulted first, which + // depends on how equal-sized consumers are ordered - confirm, or give bomb a larger hold. + // bomb throws; the exception is caught and treated as 0; backup picks up. 
+ assert(mm.acquireStorageMemory(dummyBlock, 300L, MemoryMode.ON_HEAP)) + assert(bomb.shrinkCallCount === 1) + assert(backup.shrinkCallCount === 1, "next consumer must be consulted after a thrown shrink") + } + + test("enabled: shrink() returning negative triggers IllegalArgumentException") { + val (mm, _) = makeMM(1000L, enabled = true) + val c = new MockShrinker(mm, 800L, MemoryMode.ON_HEAP, "bad", + shrinkBehavior = _ => -1L) + UnifiedMemoryManager.registerManagedConsumer(c) + intercept[IllegalArgumentException] { + mm.acquireStorageMemory(dummyBlock, 400L, MemoryMode.ON_HEAP) + } + } + + test("enabled: framework owns pool accounting; shrink return value drives the loop") { + val (mm, _) = makeMM(1000L, enabled = true) + // The MockShrinker no longer touches the pool itself; only the framework calls + // pool.releaseMemory(). This test verifies the pool actually grows by the + // consumer-reported release. + val c = new MockShrinker(mm, 500L, MemoryMode.ON_HEAP, "framework-owned") + UnifiedMemoryManager.registerManagedConsumer(c) + // storageMemoryUsed = 500 after c.acquire. Request 200 more -> storage free = 500 + // (after borrow), already covers; no shrink needed. + val usedBefore = mm.storageMemoryUsed + assert(mm.acquireStorageMemory(dummyBlock, 200L, MemoryMode.ON_HEAP)) + assert(c.shrinkCallCount === 0, "no shrink expected when free covers the request") + assert(mm.storageMemoryUsed === usedBefore + 200L) + + // Now request 400 -> needs 100 from external. shrink returns 100; framework deducts 100. 
+ assert(mm.acquireStorageMemory(dummyBlock, 400L, MemoryMode.ON_HEAP)) + assert(c.shrinkCallCount === 1) + assert(c.currentHeldBytes === 400L, + "MockShrinker.heldBytes must track its own view; framework released 100 from pool") + assert(mm.storageMemoryUsed === usedBefore + 200L + 400L - 100L, + "framework must deduct exactly shrink()'s return value from storageMemoryUsed") + } + + test("enabled: self-exclusion skips the caller's own consumer") { + val (mm, _) = makeMM(1000L, enabled = true) + val self = new MockShrinker(mm, 400L, MemoryMode.ON_HEAP, "self") + val other = new MockShrinker(mm, 400L, MemoryMode.ON_HEAP, "other") + Seq(self, other).foreach(UnifiedMemoryManager.registerManagedConsumer) + // self passes its own reference to acquire; self must NOT be asked to shrink. + assert(mm.acquireStorageMemory(self, 300L, MemoryMode.ON_HEAP)) + assert(self.shrinkCallCount === 0, "caller must be excluded from its own shrink candidates") + assert(other.shrinkCallCount === 1, "the other consumer must be consulted") + } + + test("enabled: consumers of a different MemoryMode are not consulted") { + val (mm, _) = makeMM(1000L, enabled = true) + val onHeap = new MockShrinker(mm, 400L, MemoryMode.ON_HEAP, "on-heap") + val offHeap = new MockShrinker(mm, 0L, MemoryMode.OFF_HEAP, "off-heap", + throwOnShrink = Some(() => throw new IllegalStateException( + "OFF_HEAP consumer must not be called for an ON_HEAP acquire"))) + Seq(onHeap, offHeap).foreach(UnifiedMemoryManager.registerManagedConsumer) + assert(mm.acquireStorageMemory(dummyBlock, 700L, MemoryMode.ON_HEAP)) + assert(offHeap.shrinkCallCount === 0) + } + + // -- shrinkExternal orchestration via the execution reclaim path + + test("enabled: maybeGrowExecutionPool shrinks externals for the reclaim deficit") { + val (mm, _) = makeMM(1000L, enabled = true) + // Acquire 700 via external; storage region=500, so 200 is borrowed from execution. 
+ val c = new MockShrinker(mm, 700L, MemoryMode.ON_HEAP, "ext") + UnifiedMemoryManager.registerManagedConsumer(c) + // executionPool.memoryFree = 300, storagePool.memoryFree = 0, storagePool.poolSize = 700. + // Acquire 800 execution -> extra = 800 - 300 = 500. + // memoryReclaimable = max(memoryFree=0, poolSize-region = 700-500 = 200) = 200. + // target = min(500, 200) = 200. shrinkNeeded = max(0, 200-0) = 200. + // c.shrink(200) releases 200 -> execution grants 300 + 200 = 500. + assert(mm.acquireExecutionMemory(800L, 0L, MemoryMode.ON_HEAP) === 500L) + assert(c.shrinkCallCount === 1) + assert(c.currentHeldBytes === 500L) + } + + test("enabled: storage region protection preserved with external shrink (pre-shrink cap)") { + val (mm, _) = makeMM(1000L, enabled = true) + // Externals together hold 700: region=500, borrowed=200. Only 200 should be reclaimable + // by execution even though externals collectively hold more (the rest is in the + // protected storage region). + val a = new MockShrinker(mm, 400L, MemoryMode.ON_HEAP, "a") + val b = new MockShrinker(mm, 300L, MemoryMode.ON_HEAP, "b") + Seq(a, b).foreach(UnifiedMemoryManager.registerManagedConsumer) + assert(mm.acquireExecutionMemory(1000L, 0L, MemoryMode.ON_HEAP) === 500L, + "execution-free (300) + reclaim cap from borrowed-portion (200) = 500") + val totalReleased = (400L - a.currentHeldBytes) + (300L - b.currentHeldBytes) + assert(totalReleased <= 200L, + s"storage region protection violated: externals released $totalReleased bytes > 200 cap") + } + + test("enabled: maybeGrowExecutionPool skips shrink when memoryFree alone covers target") { + val (mm, _) = makeMM(1000L, enabled = true) + val c = new MockShrinker(mm, 100L, MemoryMode.ON_HEAP, "ext-small", + throwOnShrink = Some(() => throw new IllegalStateException( + "must not be called when memoryFree already covers the reclaim target"))) + UnifiedMemoryManager.registerManagedConsumer(c) + // Used 100, all inside the 500-byte storage region (nothing borrowed by the external 
acquire). memoryFree = 400. + // Execution acquire 200 -> all from execution-free; no reclaim needed. + // NOTE(review): because execution already has room for 200, this request may never reach + // the storage-reclaim path; a request above execution-free memory but within storage + // memoryFree would exercise the case named in the test title - confirm. + assert(mm.acquireExecutionMemory(200L, 0L, MemoryMode.ON_HEAP) === 200L) + assert(c.shrinkCallCount === 0) + } + + // -- Argument validation on acquireStorageMemory(self, ...) + + test("acquireStorageMemory(self) rejects null self with IllegalArgumentException") { + val (mm, _) = makeThings(1000L) + intercept[IllegalArgumentException] { + mm.acquireStorageMemory(null.asInstanceOf[ManagedConsumer], 100L, MemoryMode.ON_HEAP) + } + } + + test("acquireStorageMemory(self) rejects mismatched memoryMode with IllegalArgumentException") { + val (mm, _) = makeThings(1000L) + val onHeap = newProbeConsumer("on", MemoryMode.ON_HEAP) + intercept[IllegalArgumentException] { + mm.acquireStorageMemory(onHeap, 100L, MemoryMode.OFF_HEAP) + } + } + + test("acquireStorageMemory(self) rejects negative numBytes with IllegalArgumentException") { + val (mm, _) = makeThings(1000L) + intercept[IllegalArgumentException] { + mm.acquireStorageMemory(newProbeConsumer("p"), -1L, MemoryMode.ON_HEAP) + } + } + + // -- Cross-SPI mutual-exclusion guard (warn, not enforce) + + test("registering the same object as both Managed and Unmanaged logs a WARN") { + class Both(uniqueName: String) extends ManagedConsumer with UnmanagedMemoryConsumer { + override val name: String = uniqueName + override def memoryMode: MemoryMode = MemoryMode.ON_HEAP + override def getShrinkableMemoryBytes: Long = 0L + override def shrink(numBytes: Long): Long = 0L + override def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId = + UnmanagedMemoryConsumerId("TEST", uniqueName) + override def getMemBytesUsed: Long = 0L + } + + val both = new Both("cross-spi-1") + val appender = new LogAppender("cross-SPI warn", maxEvents = 100) + try { + // Register order 1: managed first, then unmanaged -> warning from unmanaged register. 
+ UnifiedMemoryManager.registerManagedConsumer(both) + withLogAppender(appender) { + UnifiedMemoryManager.registerUnmanagedMemoryConsumer(both) + } + assert(appender.loggingEvents.exists(_.getMessage.getFormattedMessage.contains( + "registered as BOTH ManagedConsumer and UnmanagedMemoryConsumer"))) + } finally { + UnifiedMemoryManager.unregisterUnmanagedMemoryConsumer(both) + UnifiedMemoryManager.unregisterManagedConsumer(both) + } + + val both2 = new Both("cross-spi-2") + val appender2 = new LogAppender("cross-SPI warn reverse", maxEvents = 100) + try { + // Register order 2: unmanaged first, then managed -> warning from managed register. + UnifiedMemoryManager.registerUnmanagedMemoryConsumer(both2) + withLogAppender(appender2) { + UnifiedMemoryManager.registerManagedConsumer(both2) + } + assert(appender2.loggingEvents.exists(_.getMessage.getFormattedMessage.contains( + "registered as BOTH ManagedConsumer and UnmanagedMemoryConsumer"))) + } finally { + UnifiedMemoryManager.unregisterUnmanagedMemoryConsumer(both2) + UnifiedMemoryManager.unregisterManagedConsumer(both2) + } + } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index f5fca56e5ef77..72b993ef07b00 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -279,7 +279,8 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite // create 1 faulty block manager by injecting faulty memory manager val memManager = UnifiedMemoryManager(conf, numCores = 1) val mockedMemoryManager = spy[UnifiedMemoryManager](memManager) - doAnswer(_ => false).when(mockedMemoryManager).acquireStorageMemory(any(), any(), any()) + doAnswer(_ => false).when(mockedMemoryManager) + .acquireStorageMemory(any[BlockId](), any(), any()) val store2 = makeBlockManager(10000, "host-2", 
Some(mockedMemoryManager)) assert(master.getPeers(store1.blockManagerId).toSet ===