diff --git a/src/relax/transform/reorder_permute_dims_after_concat.cc b/src/relax/transform/reorder_permute_dims_after_concat.cc index bc542ccf91ef..01eadc37f376 100644 --- a/src/relax/transform/reorder_permute_dims_after_concat.cc +++ b/src/relax/transform/reorder_permute_dims_after_concat.cc @@ -151,8 +151,19 @@ std::tuple)>> auto concat_attrs = concat_call->attrs.as(); TVM_FFI_ICHECK(concat_attrs); - auto old_concat_axis = [&]() -> size_t { return concat_attrs->axis.value_or(0); }(); - Integer new_concat_axis = get_permute_dims_axes(all_permute_dims[0])[old_concat_axis]; + auto permute_dims_axes = get_permute_dims_axes(all_permute_dims[0]); + + int64_t old_concat_axis = concat_attrs->axis.value_or(0); + int64_t ndim = static_cast(permute_dims_axes.size()); + if (old_concat_axis < 0) { + old_concat_axis += ndim; + } + TVM_FFI_ICHECK_GE(old_concat_axis, 0) + << "concat axis " << old_concat_axis << " out of range for " << ndim << "-D input"; + TVM_FFI_ICHECK_LT(old_concat_axis, ndim) + << "concat axis " << old_concat_axis << " out of range for " << ndim << "-D input"; + + Integer new_concat_axis = permute_dims_axes[static_cast(old_concat_axis)]; auto new_concat = concat(Tuple(args), new_concat_axis->value); auto new_permute_dims = permute_dims(new_concat, permute_axes); diff --git a/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py b/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py index f93daa4c1e00..2da6cfcda99b 100644 --- a/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py +++ b/tests/python/relax/test_transform_reorder_permute_dims_after_concat.py @@ -261,5 +261,34 @@ def main( return out +class TestNegativeConcatAxis(Base): + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor([1, 4, 8, 8], "float32"), + y: R.Tensor([1, 4, 8, 8], "float32"), + ): + with R.dataflow(): + xt = R.permute_dims(x, axes=[0, 2, 3, 1]) + yt = R.permute_dims(y, axes=[0, 2, 3, 1]) + out = R.concat([xt, yt], axis=-1) + R.output(out) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor([1, 4, 8, 8], "float32"), + y: R.Tensor([1, 4, 8, 8], "float32"), + ): + with R.dataflow(): + merged = R.concat([x, y], axis=1) + out = R.permute_dims(merged, axes=[0, 2, 3, 1]) + R.output(out) + return out + + if __name__ == "__main__": tvm.testing.main()