[Mlir-commits] [mlir] 6c9be27 - [mlir][tensor] Fold identity `reshape` of 0d-tensors (#146375)
llvmlistbot at llvm.org
Wed Jul 2 00:09:06 PDT 2025
Author: Markus Böck
Date: 2025-07-02T09:09:03+02:00
New Revision: 6c9be27b526fe1742755778948d0129ace92d357
URL: https://github.com/llvm/llvm-project/commit/6c9be27b526fe1742755778948d0129ace92d357
DIFF: https://github.com/llvm/llvm-project/commit/6c9be27b526fe1742755778948d0129ace92d357.diff
LOG: [mlir][tensor] Fold identity `reshape` of 0d-tensors (#146375)
Just like 1d-tensors, reshapes of 0d-tensors (aka scalars) are always
no-ops, as they have only one possible shape. This PR adds logic to
the `fold` implementation to optimize these away, as is already done
for 1d tensors.
Added:
Modified:
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 22a25fd1a5af8..0430e6fc6c63f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1872,9 +1872,9 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   if (!sourceTy || !resultTy || sourceTy != resultTy)
     return {};
 
-  // If the source and result are both 1D tensors and have the same type, the
-  // reshape has no effect, even if the tensor is dynamically shaped.
-  if (sourceTy.getRank() == 1)
+  // If the source and result are both 0D or 1D tensors and have the same type,
+  // the reshape has no effect, even if the tensor is dynamically shaped.
+  if (sourceTy.getRank() <= 1)
     return source;
 
   if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3f9236095138b..95c5b8c91edf5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -971,6 +971,17 @@ func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> te
 
 // -----
 
+// CHECK-LABEL: func @fold_reshape_0d
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<0xindex>
+// CHECK: return %[[INPUT]]
+func.func @fold_reshape_0d(%input: tensor<f32>, %shape: tensor<0xindex>) -> tensor<f32> {
+  %0 = tensor.reshape %input(%shape) : (tensor<f32>, tensor<0xindex>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract_constant_splat
 // CHECK-NOT: tensor.extract_slice
 // CHECK: arith.constant dense<42> : tensor<4x4xi32>