rutkoor commented on code in PR #15679: URL: https://github.com/apache/tvm/pull/15679#discussion_r1338498012
########## src/relax/transform/alter_op_impl.cc: ########## @@ -176,16 +190,41 @@ class AlterOpImplMutator : public ExprMutator { Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map, const TensorStructInfo& old_tensor_sinfo, const Array<IntImm>& axis_separator) { + if (IsScalarConstant(expr) || index_map.get() == nullptr) { + return expr; + } Array<PrimExpr> old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo); Array<Range> initial_ranges = ConstructRangeFromShape(old_shape); arith::Analyzer analyzer; auto [inverse_index_map, padding_predicate] = index_map.NonSurjectiveInverse(initial_ranges, &analyzer); - ICHECK(tir::is_zero(padding_predicate)) - << "Only bijective transformations on input/output buffers are supported, but found " - "padding predicate " - << padding_predicate << " on initial range " << initial_ranges; - return TransformLayout(expr, inverse_index_map, axis_separator); + + if (tir::is_zero(padding_predicate)) { + return TransformLayout(expr, inverse_index_map, axis_separator); + } else { + auto padded_expr = + builder_->Normalize(TransformLayout(expr, inverse_index_map, axis_separator)); + const auto& tensor_sinfo = Downcast<TensorStructInfo>(padded_expr->struct_info_); + Array<PrimExpr> padded_shape = GetShapeFromTensorStructInfo(tensor_sinfo); + + te::Tensor placeholder_tensor = te::placeholder(padded_shape, tensor_sinfo->dtype, "input"); + te::Tensor output_tensor = te::compute( + old_shape, + [&placeholder_tensor](const Array<tir::Var>& indices) { + return placeholder_tensor(indices); + }, + "output", topi::kElementWise); + + String op_name = "remove_pad"; + PrimFunc remove_pad_with_frozen_layout = + WithAttr(CreatePrimFunc({placeholder_tensor, output_tensor}), kOperatorName, op_name); + + GlobalVar gv_remove_pad = builder_->AddFunction(remove_pad_with_frozen_layout, op_name); + builder_->UpdateFunction(gv_remove_pad, + WithoutAttr(remove_pad_with_frozen_layout, "global_symbol")); Review Comment: I have introduced 
a `GetOrCreateRemovePadOp` function, which returns the `GlobalVar` of the `remove_pad` PrimFunc if it is already present in the `remove_pad_map_` map. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org