Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 43 additions & 6 deletions tensorflow/lite/micro/memory_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.

#include <cstddef>
#include <cstdint>
#include <limits>

#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
Expand Down Expand Up @@ -106,19 +107,36 @@ TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size) {

TfLiteStatus BytesRequiredForTensor(const tflite::Tensor& flatbuffer_tensor,
size_t* bytes, size_t* type_size) {
int element_count = 1;
// Use size_t for the running product so the byte computation is performed
// in the same unsigned domain as the result. Each multiplication is guarded
// against overflow because shape dimensions and the element type size all
// come from the (untrusted) flatbuffer model.
size_t element_count = 1;
// If flatbuffer_tensor.shape == nullptr, then flatbuffer_tensor is a scalar
// so has 1 element.
if (flatbuffer_tensor.shape() != nullptr) {
for (size_t n = 0; n < flatbuffer_tensor.shape()->size(); ++n) {
element_count *= flatbuffer_tensor.shape()->Get(n);
const int32_t dim = flatbuffer_tensor.shape()->Get(n);
if (dim < 0) {
return kTfLiteError;
}
const size_t udim = static_cast<size_t>(dim);
if (udim != 0 &&
element_count > std::numeric_limits<size_t>::max() / udim) {
return kTfLiteError;
}
element_count *= udim;
}
}

TfLiteType tf_lite_type;
TF_LITE_ENSURE_STATUS(
ConvertTensorType(flatbuffer_tensor.type(), &tf_lite_type));
TF_LITE_ENSURE_STATUS(TfLiteTypeSizeOf(tf_lite_type, type_size));
if (*type_size != 0 &&
element_count > std::numeric_limits<size_t>::max() / *type_size) {
return kTfLiteError;
}
*bytes = element_count * (*type_size);
return kTfLiteOk;
}
Expand All @@ -127,15 +145,28 @@ TfLiteStatus TfLiteEvalTensorByteLength(const TfLiteEvalTensor* eval_tensor,
size_t* out_bytes) {
TFLITE_DCHECK(out_bytes != nullptr);

int element_count = 1;
size_t element_count = 1;
// If eval_tensor->dims == nullptr, then tensor is a scalar so has 1 element.
if (eval_tensor->dims != nullptr) {
for (int n = 0; n < eval_tensor->dims->size; ++n) {
element_count *= eval_tensor->dims->data[n];
const int dim = eval_tensor->dims->data[n];
if (dim < 0) {
return kTfLiteError;
}
const size_t udim = static_cast<size_t>(dim);
if (udim != 0 &&
element_count > std::numeric_limits<size_t>::max() / udim) {
return kTfLiteError;
}
element_count *= udim;
}
}
size_t type_size;
TF_LITE_ENSURE_STATUS(TfLiteTypeSizeOf(eval_tensor->type, &type_size));
if (type_size != 0 &&
element_count > std::numeric_limits<size_t>::max() / type_size) {
return kTfLiteError;
}
*out_bytes = element_count * type_size;
return kTfLiteOk;
}
Expand All @@ -152,10 +183,16 @@ TfLiteStatus AllocateOutputDimensionsFromInput(TfLiteContext* context,
input = input1->dims->size > input2->dims->size ? input1 : input2;
TF_LITE_ENSURE(context, output->type == input->type);
size_t size = 0;
TfLiteTypeSizeOf(input->type, &size);
TF_LITE_ENSURE_STATUS(TfLiteTypeSizeOf(input->type, &size));
const int dimensions_count = tflite::GetTensorShape(input).DimensionsCount();
for (int i = 0; i < dimensions_count; i++) {
size *= input->dims->data[i];
const int dim = input->dims->data[i];
TF_LITE_ENSURE(context, dim >= 0);
const size_t udim = static_cast<size_t>(dim);
TF_LITE_ENSURE(context,
udim == 0 ||
size <= std::numeric_limits<size_t>::max() / udim);
size *= udim;
}
output->bytes = size;

Expand Down
52 changes: 52 additions & 0 deletions tensorflow/lite/micro/memory_helpers_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,58 @@ TEST(MemoryHelpersTest, TestBytesRequiredForTensor) {
EXPECT_EQ(static_cast<size_t>(4), type_size);
}

TEST(MemoryHelpersTest,
TfLiteEvalTensorByteLengthDoesNotTruncateAcrossInt32Boundary) {
// Shape [65536, 65536] with float32: 65536 * 65536 * 4 = 17179869184 bytes.
// The original implementation used a signed 32-bit running product, so the
// element count wrapped to 0 and *out_bytes was reported as 0 even though
// any subsequent allocation would address ~17 GiB. The hardened
// implementation performs the arithmetic in size_t and must report the
// mathematically correct value (where size_t is wide enough), or refuse the
// request via kTfLiteError when it would otherwise overflow size_t.
int dims[] = {2, 65536, 65536};
TfLiteEvalTensor eval_tensor = {};
eval_tensor.dims = tflite::testing::IntArrayFromInts(dims);
eval_tensor.type = kTfLiteFloat32;

size_t out_bytes = 0;
const TfLiteStatus status =
tflite::TfLiteEvalTensorByteLength(&eval_tensor, &out_bytes);

if (sizeof(size_t) >= 8) {
EXPECT_EQ(kTfLiteOk, status);
EXPECT_EQ(static_cast<size_t>(17179869184ULL), out_bytes);
} else {
// 32-bit size_t cannot represent the result; the hardened code must
// refuse rather than silently truncate.
EXPECT_EQ(kTfLiteError, status);
}
}

TEST(MemoryHelpersTest, TfLiteEvalTensorByteLengthRejectsSizeTOverflow) {
// A shape whose product overflows size_t even on 64-bit platforms must be
// rejected. INT32_MAX^4 * 4 vastly exceeds 2^64.
int dims[] = {4, 0x7fffffff, 0x7fffffff, 0x7fffffff, 0x7fffffff};
TfLiteEvalTensor eval_tensor = {};
eval_tensor.dims = tflite::testing::IntArrayFromInts(dims);
eval_tensor.type = kTfLiteFloat32;

size_t out_bytes = 0;
EXPECT_EQ(kTfLiteError,
tflite::TfLiteEvalTensorByteLength(&eval_tensor, &out_bytes));
}

TEST(MemoryHelpersTest, TfLiteEvalTensorByteLengthRejectsNegativeDimension) {
int dims[] = {2, -1, 4};
TfLiteEvalTensor eval_tensor = {};
eval_tensor.dims = tflite::testing::IntArrayFromInts(dims);
eval_tensor.type = kTfLiteFloat32;

size_t out_bytes = 0;
EXPECT_EQ(kTfLiteError,
tflite::TfLiteEvalTensorByteLength(&eval_tensor, &out_bytes));
}

TEST(MemoryHelpersTest, TestAllocateOutputDimensionsFromInput) {
constexpr int kDimsLen = 4;
int input1_dims[] = {1, 1};
Expand Down
Loading