[Mlir-commits] [mlir] [mlir][tensor] Fold identity `reshape` of 0d-tensors (PR #146375)
Markus Böck
llvmlistbot at llvm.org
Mon Jun 30 08:42:57 PDT 2025
https://github.com/zero9178 created https://github.com/llvm/llvm-project/pull/146375
Just like 1d-tensors, identity reshapes of 0d-tensors (aka scalars) are always no-ops, as they only have one possible layout. This PR extends the `fold` implementation to fold these away, as is already done for 1d tensors.
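The reasoning above can be sketched outside of MLIR. The following is a minimal, self-contained model of the rank check only, not the actual `ReshapeOp::fold` code: `ToyTensorType`, `kDynamic`, and `foldsByRankCheck` are hypothetical names introduced for illustration, and element types are ignored. The idea is that a 0d tensor has no dimensions at all, and a 1d tensor's single extent is fixed by its element count, so a reshape to the identical type cannot change anything. With two or more dynamic dimensions that no longer holds: six elements typed as `?x?` could be laid out as 2x3 or 3x2.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical marker for a dimension whose extent is unknown at compile time
// (written as '?' in MLIR). This is not MLIR's own constant.
constexpr int64_t kDynamic = -1;

// Hypothetical stand-in for a ranked tensor type. Only the shape is modeled;
// the element type is left out to keep the sketch short.
struct ToyTensorType {
  std::vector<int64_t> shape; // empty for a 0d tensor such as tensor<f32>

  int64_t rank() const { return static_cast<int64_t>(shape.size()); }
  bool operator==(const ToyTensorType &other) const {
    return shape == other.shape;
  }
};

// Models only the early-exit touched by this patch: if source and result
// types are identical and the rank is 0 or 1, the reshape is an identity.
// A 'false' result means "this check alone does not decide"; the real fold
// goes on to inspect the shape operand in that case.
bool foldsByRankCheck(const ToyTensorType &source,
                      const ToyTensorType &result) {
  if (!(source == result))
    return false;
  return source.rank() <= 1;
}
```

Under these assumptions, `foldsByRankCheck` accepts `tensor<f32>` and `tensor<?xf32>` reshaped to themselves, but not `tensor<?x?xf32>`, where the same type can still describe different runtime shapes.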
From bcf4ef70d9c6ddef57983062db6228b40d7a0bb6 Mon Sep 17 00:00:00 2001
From: Markus Böck <mboeck at nvidia.com>
Date: Mon, 30 Jun 2025 17:41:14 +0200
Subject: [PATCH] [mlir][tensor] Fold identity `reshape` of 0d-tensors
Just like 1d-tensors, identity reshapes of 0d-tensors (aka scalars) are always no-ops, as they only have one possible layout.
This PR extends the `fold` implementation to fold these away, as is already done for 1d tensors.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 6 +++---
mlir/test/Dialect/Tensor/canonicalize.mlir | 11 +++++++++++
2 files changed, 14 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 22a25fd1a5af8..0430e6fc6c63f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1872,9 +1872,9 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   if (!sourceTy || !resultTy || sourceTy != resultTy)
     return {};

-  // If the source and result are both 1D tensors and have the same type, the
-  // reshape has no effect, even if the tensor is dynamically shaped.
-  if (sourceTy.getRank() == 1)
+  // If the source and result are both 0D or 1D tensors and have the same type,
+  // the reshape has no effect, even if the tensor is dynamically shaped.
+  if (sourceTy.getRank() <= 1)
     return source;

   if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3f9236095138b..95c5b8c91edf5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -971,6 +971,17 @@ func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> te
// -----
+// CHECK-LABEL: func @fold_reshape_0d
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<0xindex>
+// CHECK: return %[[INPUT]]
+func.func @fold_reshape_0d(%input: tensor<f32>, %shape: tensor<0xindex>) -> tensor<f32> {
+ %0 = tensor.reshape %input(%shape) : (tensor<f32>, tensor<0xindex>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_extract_constant_splat
// CHECK-NOT: tensor.extract_slice
// CHECK: arith.constant dense<42> : tensor<4x4xi32>