[Mlir-commits] [mlir] 3fa88f0 - [mlir][tensor] Fix empty tensor with cast encoding fold (#187963)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 23 02:06:42 PDT 2026
Author: Hocky Yudhiono
Date: 2026-03-23T09:06:36Z
New Revision: 3fa88f0846b21503f0a14689a2696b979dd25c5d
URL: https://github.com/llvm/llvm-project/commit/3fa88f0846b21503f0a14689a2696b979dd25c5d
DIFF: https://github.com/llvm/llvm-project/commit/3fa88f0846b21503f0a14689a2696b979dd25c5d.diff
LOG: [mlir][tensor] Fix empty tensor with cast encoding fold (#187963)
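Fixes a TODO: folding a `tensor.cast` of a `tensor.empty` into a new `tensor.empty` used to drop the tensor encoding attribute. The folded `tensor.empty` now keeps the encoding of the cast's result type.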
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ce0f8540d884a..4d266a635b7cf 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1205,9 +1205,7 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
newMixedSizes.reserve(currMixedSizes.size());
assert(resultShape.size() == currMixedSizes.size() &&
"mismatch in result shape and sizes of empty op");
- for (auto it : llvm::zip(resultShape, currMixedSizes)) {
- int64_t newDim = std::get<0>(it);
- OpFoldResult currDim = std::get<1>(it);
+ for (auto [newDim, currDim] : llvm::zip(resultShape, currMixedSizes)) {
// Case 1: The empty tensor dim is static. Check that the tensor cast
// result dim matches.
if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
@@ -1236,9 +1234,9 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
newMixedSizes.push_back(currDim);
}
- // TODO: Do not drop tensor encoding.
rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
- resultType.getElementType());
+ resultType.getElementType(),
+ resultType.getEncoding());
return success();
}
};
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index b251313b6580b..67b7ab99c5d18 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2279,6 +2279,19 @@ func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> {
// -----
+func.func @fold_empty_tensor_with_cast_encoding(%arg0 : index)
+ -> tensor<1x12xf32, "foo"> {
+ %0 = tensor.empty(%arg0) : tensor<?x12xf32, "foo">
+ %1 = tensor.cast %0 : tensor<?x12xf32, "foo"> to tensor<1x12xf32, "foo">
+ return %1 : tensor<1x12xf32, "foo">
+}
+// CHECK-LABEL: func @fold_empty_tensor_with_cast_encoding
+// CHECK: %[[T0:.+]] = tensor.empty() : tensor<1x12xf32, "foo">
+// CHECK-NOT: tensor.cast
+// CHECK: return %[[T0]] : tensor<1x12xf32, "foo">
+
+// -----
+
func.func private @some_use(%i : index, %j : index)
// CHECK-LABEL: func @empty_tensor_canonicalize
More information about the Mlir-commits mailing list