[Mlir-commits] [mlir] [mlir][tensor] Fix empty tensor with cast encoding fold (PR #187963)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Mar 22 20:07:52 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor
Author: Hocky Yudhiono (hockyy)
<details>
<summary>Changes</summary>
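Fixes a TODO in `FoldEmptyTensorWithCastOp`: when folding a `tensor.cast` of a `tensor.empty`, the pattern dropped the tensor encoding of the cast's result type. The replacement `tensor.empty` now carries that encoding. The patch also simplifies the size-matching loop to use structured bindings over `llvm::zip`.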
---
Full diff: https://github.com/llvm/llvm-project/pull/187963.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+3-5)
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+13)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ce0f8540d884a..4d266a635b7cf 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1205,9 +1205,7 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
newMixedSizes.reserve(currMixedSizes.size());
assert(resultShape.size() == currMixedSizes.size() &&
"mismatch in result shape and sizes of empty op");
- for (auto it : llvm::zip(resultShape, currMixedSizes)) {
- int64_t newDim = std::get<0>(it);
- OpFoldResult currDim = std::get<1>(it);
+ for (auto [newDim, currDim] : llvm::zip(resultShape, currMixedSizes)) {
// Case 1: The empty tensor dim is static. Check that the tensor cast
// result dim matches.
if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
@@ -1236,9 +1234,9 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
newMixedSizes.push_back(currDim);
}
- // TODO: Do not drop tensor encoding.
rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
- resultType.getElementType());
+ resultType.getElementType(),
+ resultType.getEncoding());
return success();
}
};
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index b251313b6580b..67b7ab99c5d18 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2279,6 +2279,19 @@ func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> {
// -----
+func.func @fold_empty_tensor_with_cast_encoding(%arg0 : index)
+ -> tensor<1x12xf32, "foo"> {
+ %0 = tensor.empty(%arg0) : tensor<?x12xf32, "foo">
+ %1 = tensor.cast %0 : tensor<?x12xf32, "foo"> to tensor<1x12xf32, "foo">
+ return %1 : tensor<1x12xf32, "foo">
+}
+// CHECK-LABEL: func @fold_empty_tensor_with_cast_encoding
+// CHECK: %[[T0:.+]] = tensor.empty() : tensor<1x12xf32, "foo">
+// CHECK-NOT: tensor.cast
+// CHECK: return %[[T0]] : tensor<1x12xf32, "foo">
+
+// -----
+
func.func private @some_use(%i : index, %j : index)
// CHECK-LABEL: func @empty_tensor_canonicalize
``````````
</details>
https://github.com/llvm/llvm-project/pull/187963
More information about the Mlir-commits
mailing list