[Mlir-commits] [mlir] 3fa88f0 - [mlir][tensor] Fix empty tensor with cast encoding fold (#187963)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 23 02:06:42 PDT 2026


Author: Hocky Yudhiono
Date: 2026-03-23T09:06:36Z
New Revision: 3fa88f0846b21503f0a14689a2696b979dd25c5d

URL: https://github.com/llvm/llvm-project/commit/3fa88f0846b21503f0a14689a2696b979dd25c5d
DIFF: https://github.com/llvm/llvm-project/commit/3fa88f0846b21503f0a14689a2696b979dd25c5d.diff

LOG: [mlir][tensor] Fix empty tensor with cast encoding fold (#187963)

Resolved a TODO in the fold of tensor.empty with tensor.cast: the folded
tensor.empty now keeps the cast result type's encoding attribute instead of
dropping it.

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ce0f8540d884a..4d266a635b7cf 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1205,9 +1205,7 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
     newMixedSizes.reserve(currMixedSizes.size());
     assert(resultShape.size() == currMixedSizes.size() &&
            "mismatch in result shape and sizes of empty op");
-    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
-      int64_t newDim = std::get<0>(it);
-      OpFoldResult currDim = std::get<1>(it);
+    for (auto [newDim, currDim] : llvm::zip(resultShape, currMixedSizes)) {
       // Case 1: The empty tensor dim is static. Check that the tensor cast
       // result dim matches.
       if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
@@ -1236,9 +1234,9 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
       newMixedSizes.push_back(currDim);
     }
 
-    // TODO: Do not drop tensor encoding.
     rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
-                                         resultType.getElementType());
+                                         resultType.getElementType(),
+                                         resultType.getEncoding());
     return success();
   }
 };

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index b251313b6580b..67b7ab99c5d18 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2279,6 +2279,19 @@ func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> {
 
 // -----
 
+func.func @fold_empty_tensor_with_cast_encoding(%arg0 : index)
+    -> tensor<1x12xf32, "foo"> {
+  %0 = tensor.empty(%arg0) : tensor<?x12xf32, "foo">
+  %1 = tensor.cast %0 : tensor<?x12xf32, "foo"> to tensor<1x12xf32, "foo">
+  return %1 : tensor<1x12xf32, "foo">
+}
+// CHECK-LABEL: func @fold_empty_tensor_with_cast_encoding
+//       CHECK:   %[[T0:.+]] = tensor.empty() : tensor<1x12xf32, "foo">
+//   CHECK-NOT:   tensor.cast
+//       CHECK:   return %[[T0]] : tensor<1x12xf32, "foo">
+
+// -----
+
 func.func private @some_use(%i : index, %j : index)
 
 // CHECK-LABEL: func @empty_tensor_canonicalize


        


More information about the Mlir-commits mailing list