[Mlir-commits] [mlir] [mlir][tensor] Fold `tensor.reshape` for dynamic reshape (PR #88961)

Rob Suderman llvmlistbot at llvm.org
Thu Apr 18 11:32:35 PDT 2024


https://github.com/rsuderman updated https://github.com/llvm/llvm-project/pull/88961

>From 825760b52502c31f09e20e0a875f4254d17bd7b8 Mon Sep 17 00:00:00 2001
From: Rob Suderman <rob.suderman at gmail.com>
Date: Tue, 16 Apr 2024 11:56:57 -0700
Subject: [PATCH 1/4] [mlir][tensor] Fold `tensor.reshape` for dynamic reshape

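If the shape operand of `tensor.reshape` is a `tensor.from_elements` of
`tensor.dim %src, 0`, `tensor.dim %src, 1`, ... (or of constants equal to
the corresponding static sizes), and the source and result types are
identical, we know that the reshape is a no-op. Checking for this case lets
us fold away the computation.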
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 42 +++++++++++++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir | 47 ++++++++++++++++++++++
 2 files changed, 89 insertions(+)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0ce40e81371209..50d3cd45a2dfe9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1580,6 +1580,48 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
           getResult().getType()))
     return reshapedSource;
+
+  auto source = getSource();
+  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
+  auto resultTy = dyn_cast<RankedTensorType>(getType());
+
+  if (!sourceTy || !resultTy || sourceTy != resultTy)
+    return {};
+
+  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
+    auto elements = fromElements.getElements();
+    bool dynamicNoop =
+        sourceTy.getRank() == static_cast<int64_t>(elements.size());
+    for (auto [id, element] : llvm::enumerate(elements)) {
+      APSInt cstElement;
+      if (matchPattern(element, m_ConstantInt(&cstElement))) {
+        if (cstElement.getExtValue() != sourceTy.getDimSize(id)) {
+          dynamicNoop = false;
+          break;
+        }
+        continue;
+      }
+
+      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
+        if (dimOp.getSource() != source) {
+          dynamicNoop = false;
+          break;
+        }
+
+        APSInt dim;
+        if (!matchPattern(dimOp.getIndex(), m_ConstantInt(&dim)) ||
+            dim.getExtValue() != static_cast<int64_t>(id)) {
+          dynamicNoop = false;
+          break;
+        }
+        continue;
+      }
+    }
+
+    if (dynamicNoop)
+      return source;
+  }
+
   return {};
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ac365c9d297e88..751c57eacd7ae5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2403,6 +2403,53 @@ func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xinde
 
 // -----
 
+// CHECK-LABEL: @reshape_fold_2d
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
+func.func @reshape_fold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
+  %ds = tensor.from_elements %d0, %d1 : tensor<2xindex>
+  %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
+  // CHECK: return %[[ARG0]]
+  return %reshape : tensor<?x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reshape_nofold_2d
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
+func.func @reshape_nofold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
+  %ds = tensor.from_elements %d1, %d0 : tensor<2xindex>
+  // CHECK: tensor.reshape
+  %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
+  return %reshape : tensor<?x?xi32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @reshape_fold_3d_cst
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x?x?xi32>
+func.func @reshape_fold_3d_cst(%arg0 : tensor<5x?x?xi32>) -> tensor<5x?x?xi32> {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %d0 = arith.constant 5 : index
+  %d1 = tensor.dim %arg0, %c1 : tensor<5x?x?xi32>
+  %d2 = tensor.dim %arg0, %c2 : tensor<5x?x?xi32>
+  %ds = tensor.from_elements %d0, %d1, %d2 : tensor<3xindex>
+  %reshape = tensor.reshape %arg0(%ds) : (tensor<5x?x?xi32>, tensor<3xindex>) -> tensor<5x?x?xi32>
+  // CHECK: return %[[ARG0]]
+  return %reshape : tensor<5x?x?xi32>
+}
+
+// -----
+
 // Test case: This test fails to fold because the index of tensor.dim is out_of_bounds
 // CHECK-LABEL: func @dim_out_of_bounds(
 //       CHECK: %[[IDX:.*]] = index.constant 28

>From d25829784b8ffd8025330f9519878817e5107632 Mon Sep 17 00:00:00 2001
From: Rob Suderman <rob.suderman at gmail.com>
Date: Wed, 17 Apr 2024 11:14:23 -0700
Subject: [PATCH 2/4] comments

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 20 +++++++-------------
 1 file changed, 7 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 50d3cd45a2dfe9..c41e8110fe8942 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1593,27 +1593,21 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
     bool dynamicNoop =
         sourceTy.getRank() == static_cast<int64_t>(elements.size());
     for (auto [id, element] : llvm::enumerate(elements)) {
+      if (!dynamicNoop)
+        break;
+
       APSInt cstElement;
       if (matchPattern(element, m_ConstantInt(&cstElement))) {
-        if (cstElement.getExtValue() != sourceTy.getDimSize(id)) {
-          dynamicNoop = false;
-          break;
-        }
+        dynamicNoop &= cstElement.getExtValue() == sourceTy.getDimSize(id);
         continue;
       }
 
       if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
-        if (dimOp.getSource() != source) {
-          dynamicNoop = false;
-          break;
-        }
+        dynamicNoop &= dimOp.getSource() == source;
 
         APSInt dim;
-        if (!matchPattern(dimOp.getIndex(), m_ConstantInt(&dim)) ||
-            dim.getExtValue() != static_cast<int64_t>(id)) {
-          dynamicNoop = false;
-          break;
-        }
+        dynamicNoop &= matchPattern(dimOp.getIndex(), m_ConstantInt(&dim)) &&
+                       dim.getExtValue() == static_cast<int64_t>(id);
         continue;
       }
     }

>From f5a42236a99bea178d046c7597826478368b400f Mon Sep 17 00:00:00 2001
From: Rob Suderman <rob.suderman at gmail.com>
Date: Wed, 17 Apr 2024 14:35:35 -0700
Subject: [PATCH 3/4] more fixup

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index c41e8110fe8942..1adbe1452ed4cf 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1592,9 +1592,8 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
     auto elements = fromElements.getElements();
     bool dynamicNoop =
         sourceTy.getRank() == static_cast<int64_t>(elements.size());
-    for (auto [id, element] : llvm::enumerate(elements)) {
-      if (!dynamicNoop)
-        break;
+    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
+      auto element = elements[id];
 
       APSInt cstElement;
       if (matchPattern(element, m_ConstantInt(&cstElement))) {

>From eed930fccd305fcc1960c6ee48b7a296018ff580 Mon Sep 17 00:00:00 2001
From: Rob Suderman <rob.suderman at gmail.com>
Date: Thu, 18 Apr 2024 11:32:13 -0700
Subject: [PATCH 4/4] rework for Matthias's comments

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 1adbe1452ed4cf..8d387284caba37 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1595,9 +1595,8 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
     for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
       auto element = elements[id];
 
-      APSInt cstElement;
-      if (matchPattern(element, m_ConstantInt(&cstElement))) {
-        dynamicNoop &= cstElement.getExtValue() == sourceTy.getDimSize(id);
+      if (auto cst = getConstantIntValue(element)) {
+        dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
         continue;
       }
 
@@ -1605,8 +1604,12 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
         dynamicNoop &= dimOp.getSource() == source;
 
-        APSInt dim;
-        dynamicNoop &= matchPattern(dimOp.getIndex(), m_ConstantInt(&dim)) &&
-                       dim.getExtValue() == static_cast<int64_t>(id);
+        auto cst = getConstantIntValue(dimOp.getIndex());
+        dynamicNoop &=
+            cst.has_value() && cst.value() == static_cast<int64_t>(id);
         continue;
       }
+
+      // Any other producer (e.g. a block argument or an arith op result) is
+      // not provably equal to the source dimension, so the reshape must stay.
+      dynamicNoop = false;
     }



More information about the Mlir-commits mailing list