diff --git a/tensorflow/lite/micro/kernels/circular_buffer_common.cc b/tensorflow/lite/micro/kernels/circular_buffer_common.cc index bf45c06f61c..9560dfec43c 100644 --- a/tensorflow/lite/micro/kernels/circular_buffer_common.cc +++ b/tensorflow/lite/micro/kernels/circular_buffer_common.cc @@ -53,6 +53,7 @@ TfLiteStatus CircularBufferPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, 1, input->dims->data[1]); TF_LITE_ENSURE_EQ(context, input->dims->data[2], output->dims->data[2]); TF_LITE_ENSURE_EQ(context, output->dims->data[3], input->dims->data[3]); + TF_LITE_ENSURE(context, output->dims->data[1] > 0); TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); diff --git a/tensorflow/lite/micro/kernels/circular_buffer_test.cc b/tensorflow/lite/micro/kernels/circular_buffer_test.cc index 9f07ba496c7..5ff50d0f993 100644 --- a/tensorflow/lite/micro/kernels/circular_buffer_test.cc +++ b/tensorflow/lite/micro/kernels/circular_buffer_test.cc @@ -234,4 +234,42 @@ TEST(CircularBufferTest, OutputTensorLength5) { } } +TEST(CircularBufferTest, RejectsZeroNumSlots) { + constexpr int depth = 4; + int8_t input_data[depth]; + int8_t output_data[1]; + + memset(output_data, 0, sizeof(output_data)); + int input_dims[] = {4, 1, 1, 1, depth}; + int output_dims[] = {4, 1, 0, 1, depth}; + TfLiteIntArray* input_tensor_dims = + tflite::testing::IntArrayFromInts(input_dims); + TfLiteIntArray* output_tensor_dims = + tflite::testing::IntArrayFromInts(output_dims); + + constexpr int inputs_size = 2; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + tflite::testing::CreateQuantizedTensor(input_data, input_tensor_dims, 1, + 0), + tflite::testing::CreateQuantizedTensor(output_data, output_tensor_dims, 1, + 0), + }; + + int inputs_array_data[] = {1, 0}; + TfLiteIntArray* inputs_array = + tflite::testing::IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 1}; + TfLiteIntArray* outputs_array = + tflite::testing::IntArrayFromInts(outputs_array_data); + + const TFLMRegistration* registration = tflite::Register_CIRCULAR_BUFFER(); + tflite::micro::KernelRunner runner = tflite::micro::KernelRunner( + *registration, tensors, tensors_size, inputs_array, outputs_array, + /*builtin_data=*/nullptr); + + EXPECT_EQ(kTfLiteError, runner.InitAndPrepare()); +} + TF_LITE_MICRO_TESTS_MAIN