Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,12 @@ public int objectSize() {
public Variant getFieldByKey(String key) {
return handleObject(value, pos, (size, idSize, offsetSize, idStart, offsetStart, dataStart) -> {
// Use linear search for a short list. Switch to binary search when the length reaches
// `BINARY_SEARCH_THRESHOLD`.
// `BINARY_SEARCH_THRESHOLD` and the object fields are sorted by key name (indicated by
// bit 5 of the type info in the header byte).
final int BINARY_SEARCH_THRESHOLD = 32;
if (size < BINARY_SEARCH_THRESHOLD) {
int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
boolean sorted = (typeInfo & 0x20) != 0;
if (size < BINARY_SEARCH_THRESHOLD || !sorted) {
for (int i = 0; i < size; ++i) {
int id = readUnsigned(value, idStart + idSize * i, idSize);
if (key.equals(getMetadataKey(metadata, id))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.types.variant.VariantBuilder
import org.apache.spark.types.variant.{Variant, VariantBuilder}
import org.apache.spark.types.variant.VariantUtil._
import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
import org.apache.spark.util.collection.Utils.createArray
Expand Down Expand Up @@ -1200,4 +1200,46 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
// Null input.
checkEvaluation(IsValidVariant(Literal.create(null, VariantType)), null)
}

test("SPARK-56637: getFieldByKey on unsorted object with >= 32 fields") {
// Build a variant object with 32 fields where field IDs are NOT in sorted order.
// The sort bit (bit 5 of type info) is NOT set, so binary search must not be used.
val numFields = 32
// Create metadata with 32 keys: "key00", "key01", ..., "key31" (sorted alphabetically by id).
val keys = (0 until numFields).map(i => f"key$i%02d")
// Build metadata: version byte, dict size (1 byte), (numFields+1) offsets (1 byte each), data.
val metadataBytes = new java.io.ByteArrayOutputStream()
metadataBytes.write(VERSION) // version=1, offsetSize=1 (upper 2 bits = 0)
metadataBytes.write(numFields) // dict size
// Offsets: each key is 5 bytes ("keyXX")
for (i <- 0 to numFields) metadataBytes.write(i * 5)
// String data
for (k <- keys) metadataBytes.write(k.getBytes("UTF-8"))
val metadata = metadataBytes.toByteArray

// Build value: an object with 32 fields, field IDs in REVERSE order (unsorted).
// Each field value is a small integer (INT1 primitive).
// Object header byte: basic_type=OBJECT(2), type_info with sort bit=0, largeSize=0,
// idSize=1(b3b2=00), offsetSize=1(b1b0=00). Full byte = 0b00_000000_10 = 0x02.
// (sort bit NOT set means fields may be unsorted)
val header: Byte = (OBJECT).toByte // type_info=0, basic_type=OBJECT
val size = numFields
// Field IDs in reverse order: 31, 30, ..., 0
val fieldIds = (0 until numFields).reverse.map(_.toByte).toArray
// Each field value is 2 bytes (header + 1 byte int), so offsets are 0, 2, 4, ...
val offsets = (0 to numFields).map(i => (i * 2).toByte).toArray
// Field data: INT1 values (header byte + 1 value byte)
val fieldData = (0 until numFields).flatMap(i => Seq(primitiveHeader(INT1), i.toByte)).toArray

val value = Array(header, size.toByte) ++ fieldIds ++ offsets ++ fieldData
val v = new Variant(value, metadata)

// Look up "key00" which has id=0. In the reversed id list, id=0 is at position 31.
// Binary search on unsorted data would fail to find it.
val result = v.getFieldByKey("key00")
assert(result != null, "getFieldByKey should find 'key00' in unsorted object")
// The field at position 31 (where id=0 is) has value = the INT1 at offset 31*2 = 62,
// which is the 31st field data entry (value byte = 31).
assert(result.getLong == 31)
}
}