[Mlir-commits] [mlir] [mlir][tensor] Fix empty tensor with cast encoding fold (PR #187963)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 22 20:07:52 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tensor

Author: Hocky Yudhiono (hockyy)

<details>
<summary>Changes</summary>

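Resolves a TODO in `FoldEmptyTensorWithCastOp`: when a `tensor.cast` was folded into a `tensor.empty`, the tensor encoding (attribute) was dropped. The fold now passes the result type's encoding to the newly created `tensor.empty`, so the encoding is preserved.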

---
Full diff: https://github.com/llvm/llvm-project/pull/187963.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+3-5) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+13) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ce0f8540d884a..4d266a635b7cf 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1205,9 +1205,7 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
     newMixedSizes.reserve(currMixedSizes.size());
     assert(resultShape.size() == currMixedSizes.size() &&
            "mismatch in result shape and sizes of empty op");
-    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
-      int64_t newDim = std::get<0>(it);
-      OpFoldResult currDim = std::get<1>(it);
+    for (auto [newDim, currDim] : llvm::zip(resultShape, currMixedSizes)) {
       // Case 1: The empty tensor dim is static. Check that the tensor cast
       // result dim matches.
       if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
@@ -1236,9 +1234,9 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
       newMixedSizes.push_back(currDim);
     }
 
-    // TODO: Do not drop tensor encoding.
     rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
-                                         resultType.getElementType());
+                                         resultType.getElementType(),
+                                         resultType.getEncoding());
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index b251313b6580b..67b7ab99c5d18 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2279,6 +2279,19 @@ func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> {
 
 // -----
 
+func.func @fold_empty_tensor_with_cast_encoding(%arg0 : index)
+    -> tensor<1x12xf32, "foo"> {
+  %0 = tensor.empty(%arg0) : tensor<?x12xf32, "foo">
+  %1 = tensor.cast %0 : tensor<?x12xf32, "foo"> to tensor<1x12xf32, "foo">
+  return %1 : tensor<1x12xf32, "foo">
+}
+// CHECK-LABEL: func @fold_empty_tensor_with_cast_encoding
+//       CHECK:   %[[T0:.+]] = tensor.empty() : tensor<1x12xf32, "foo">
+//   CHECK-NOT:   tensor.cast
+//       CHECK:   return %[[T0]] : tensor<1x12xf32, "foo">
+
+// -----
+
 func.func private @some_use(%i : index, %j : index)
 
 // CHECK-LABEL: func @empty_tensor_canonicalize

``````````

</details>


https://github.com/llvm/llvm-project/pull/187963


More information about the Mlir-commits mailing list