Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/relax/transform/reorder_permute_dims_after_concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,19 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, ffi::Map<DFPattern, Expr>)>>
auto concat_attrs = concat_call->attrs.as<ConcatAttrs>();
TVM_FFI_ICHECK(concat_attrs);

auto old_concat_axis = [&]() -> size_t { return concat_attrs->axis.value_or(0); }();
Integer new_concat_axis = get_permute_dims_axes(all_permute_dims[0])[old_concat_axis];
auto permute_dims_axes = get_permute_dims_axes(all_permute_dims[0]);

int64_t old_concat_axis = concat_attrs->axis.value_or(0);
int64_t ndim = static_cast<int64_t>(permute_dims_axes.size());
if (old_concat_axis < 0) {
old_concat_axis += ndim;
}
TVM_FFI_ICHECK_GE(old_concat_axis, 0)
<< "concat axis " << old_concat_axis << " out of range for " << ndim << "-D input";
TVM_FFI_ICHECK_LT(old_concat_axis, ndim)
<< "concat axis " << old_concat_axis << " out of range for " << ndim << "-D input";

Integer new_concat_axis = permute_dims_axes[static_cast<size_t>(old_concat_axis)];
Comment thread
cchung100m marked this conversation as resolved.

auto new_concat = concat(Tuple(args), new_concat_axis->value);
auto new_permute_dims = permute_dims(new_concat, permute_axes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,5 +261,34 @@ def main(
return out


class TestNegativeConcatAxis(Base):
@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor([1, 4, 8, 8], "float32"),
y: R.Tensor([1, 4, 8, 8], "float32"),
):
with R.dataflow():
xt = R.permute_dims(x, axes=[0, 2, 3, 1])
yt = R.permute_dims(y, axes=[0, 2, 3, 1])
out = R.concat([xt, yt], axis=-1)
R.output(out)
return out

@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor([1, 4, 8, 8], "float32"),
y: R.Tensor([1, 4, 8, 8], "float32"),
):
with R.dataflow():
merged = R.concat([x, y], axis=1)
out = R.permute_dims(merged, axes=[0, 2, 3, 1])
R.output(out)
return out


if __name__ == "__main__":
tvm.testing.main()
Loading