[Mlir-commits] [mlir] [mlir][tensor] Fold identity `reshape` of 0d-tensors (PR #146375)
llvmlistbot at llvm.org
Mon Jun 30 08:43:28 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Markus Böck (zero9178)
<details>
<summary>Changes</summary>
Just like for 1d-tensors, a reshape of a 0d-tensor (i.e. a scalar) to the same type is always a no-op, since there is only one possible shape. This PR extends the `fold` implementation to fold such reshapes away, as it already does for 1d-tensors.
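The following is a minimal, self-contained C++ sketch of why ranks 0 and 1 are special. It does not use MLIR. `TensorTypeModel`, `kDynamic`, `foldsToSource`, and `numElements` are hypothetical stand-ins for `RankedTensorType`, `ShapedType::kDynamic`, and the check in `ReshapeOp::fold`. A reshape preserves the element count, so a 1d result of the same type must have the same single extent as its source, and a 0d tensor has only the empty shape, which holds exactly one element. From rank 2 onward the same type can describe different shapes (for example 2x3 versus 3x2), so the type alone does not prove the reshape is an identity.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical sentinel for a dynamic extent ("?"). MLIR uses its own
// sentinel (ShapedType::kDynamic); the value here is arbitrary.
constexpr int64_t kDynamic = -1;

// Hypothetical stand-in for a ranked tensor type, NOT MLIR's RankedTensorType.
struct TensorTypeModel {
  std::vector<int64_t> shape; // empty vector models a 0d tensor, e.g. tensor<f32>
  std::string elementType;    // e.g. "f32"

  bool operator==(const TensorTypeModel &other) const {
    return shape == other.shape && elementType == other.elementType;
  }
  int64_t rank() const { return static_cast<int64_t>(shape.size()); }
};

// Models the condition after this PR: identical source/result types of
// rank 0 or 1 mean the reshape is a no-op, even if the extent is dynamic.
bool foldsToSource(const TensorTypeModel &source,
                   const TensorTypeModel &result) {
  if (!(source == result))
    return false;
  return source.rank() <= 1;
}

// Element count of a fully static shape. The empty product is 1, so a
// 0d tensor always holds exactly one element.
int64_t numElements(const std::vector<int64_t> &staticShape) {
  int64_t n = 1;
  for (int64_t d : staticShape) {
    assert(d != kDynamic && "numElements expects a fully static shape");
    n *= d;
  }
  return n;
}
```

For instance, `foldsToSource({{kDynamic}, "f32"}, {{kDynamic}, "f32"})` holds while the rank-2 analogue does not, because `numElements({2, 3})` equals `numElements({3, 2})` even though the shapes differ.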
---
Full diff: https://github.com/llvm/llvm-project/pull/146375.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+3-3)
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+11)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 22a25fd1a5af8..0430e6fc6c63f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1872,9 +1872,9 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   if (!sourceTy || !resultTy || sourceTy != resultTy)
     return {};
 
-  // If the source and result are both 1D tensors and have the same type, the
-  // reshape has no effect, even if the tensor is dynamically shaped.
-  if (sourceTy.getRank() == 1)
+  // If the source and result are both 0D or 1D tensors and have the same type,
+  // the reshape has no effect, even if the tensor is dynamically shaped.
+  if (sourceTy.getRank() <= 1)
     return source;
 
   if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3f9236095138b..95c5b8c91edf5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -971,6 +971,17 @@ func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> te
 
 // -----
 
+// CHECK-LABEL: func @fold_reshape_0d
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<0xindex>
+// CHECK: return %[[INPUT]]
+func.func @fold_reshape_0d(%input: tensor<f32>, %shape: tensor<0xindex>) -> tensor<f32> {
+  %0 = tensor.reshape %input(%shape) : (tensor<f32>, tensor<0xindex>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract_constant_splat
 // CHECK-NOT: tensor.extract_slice
 // CHECK: arith.constant dense<42> : tensor<4x4xi32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/146375