[Mlir-commits] [mlir] [mlir][tensor] Fix tensor.reshape canonicalization (PR #90141)

Rob Suderman llvmlistbot at llvm.org
Thu Apr 25 15:47:10 PDT 2024


https://github.com/rsuderman created https://github.com/llvm/llvm-project/pull/90141

Canonicalization defaulted to replacing `tensor.reshape` with its source when the input dims came from an unknown source. This is obviously incorrect. Tweaked the fold and included a test to prevent future regressions.
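The sketch below is a rough standalone C++ model of the no-op check in `ReshapeOp::fold`. It is a simplified illustration, not the MLIR code. When the shape operand is built by `tensor.from_elements`, the reshape can fold to its source only if every element provably reproduces the matching source dimension. That means either a constant equal to a static dimension, or a `tensor.dim` of the reshape's own source at the same index. Before this patch, an element with any other producer, such as the function arguments `%arg1` and `%arg2` in the new test, skipped both `continue` branches and left `dynamicNoop` set to true. The types `ConstantIndex`, `DimOfTensor`, `UnknownValue`, the `kDynamic` sentinel, and the functions `isDynamicNoop` and `exampleUsage` are hypothetical stand-ins, not MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <variant>
#include <vector>

// Hypothetical, simplified stand-ins for the values that can feed the
// tensor.from_elements shape operand. These are NOT MLIR classes.
struct ConstantIndex { int64_t value; };                      // e.g. arith.constant 4 : index
struct DimOfTensor { bool ofReshapeSource; int64_t index; };  // e.g. tensor.dim %t, %cN
struct UnknownValue {};                                       // e.g. a function argument like %arg1

using ShapeElement = std::variant<ConstantIndex, DimOfTensor, UnknownValue>;

// Marks a dynamic ('?') dimension in sourceShape; assumed to play the role
// that ShapedType::kDynamic plays in MLIR.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Returns true only when every shape element provably equals the matching
// dimension of the reshape source, i.e. folding to the source is safe.
bool isDynamicNoop(const std::vector<int64_t>& sourceShape,
                   const std::vector<ShapeElement>& elements) {
  if (sourceShape.size() != elements.size())
    return false;
  for (std::size_t id = 0; id < elements.size(); ++id) {
    const ShapeElement& element = elements[id];
    if (const auto* cst = std::get_if<ConstantIndex>(&element)) {
      // A constant can only match a static dimension of the same size.
      if (cst->value != sourceShape[id])
        return false;
      continue;
    }
    if (const auto* dim = std::get_if<DimOfTensor>(&element)) {
      // Must read dimension `id` of the reshape's own source tensor.
      if (!dim->ofReshapeSource || dim->index != static_cast<int64_t>(id))
        return false;
      continue;
    }
    // Any other producer proves nothing about the shape, so do not fold.
    // This models the `dynamicNoop = false; break;` added by the patch.
    return false;
  }
  return true;
}

// Mirrors @reshape_nofold_2d_ins: a tensor<?x?xi32> source whose new shape
// comes from two unrelated index arguments.
void exampleUsage() {
  std::vector<int64_t> shape = {kDynamic, kDynamic};
  // Shape built from tensor.dim of the source at indices 0 and 1: foldable.
  assert(isDynamicNoop(shape, {DimOfTensor{true, 0}, DimOfTensor{true, 1}}));
  // Shape built from %arg1, %arg2: must not fold.
  assert(!isDynamicNoop(shape, {UnknownValue{}, UnknownValue{}}));
}
```

Returning `false` as soon as one element fails corresponds to the `break` in the patch. A single unproven dimension is enough to make the replacement unsafe, so there is no reason to inspect the rest.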

>From 561083ca59afded9618a7e15c1a3d1128a7343a6 Mon Sep 17 00:00:00 2001
From: Rob Suderman <rob.suderman at gmail.com>
Date: Thu, 25 Apr 2024 15:45:59 -0700
Subject: [PATCH] [mlir][tensor] Fix tensor.reshape canonicalization

Canonicalization defaulted to replacing the reshape when the input dims
came from an unknown source. This is obviously incorrect. Tweaked and
included a test to prevent future regressions.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   |  3 +++
 mlir/test/Dialect/Tensor/canonicalize.mlir | 10 ++++++++++
 2 files changed, 13 insertions(+)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 3ff41ab22fbc42..5029ed4aa0387a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1609,6 +1609,9 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
             cst.has_value() && cst.value() == static_cast<int64_t>(id);
         continue;
       }
+
+      dynamicNoop = false;
+      break;
     }
 
     if (dynamicNoop)
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 751c57eacd7ae5..cc1746a7a3de81 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2431,6 +2431,16 @@ func.func @reshape_nofold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
   return %reshape : tensor<?x?xi32>
 }
 
+// -----
+
+// CHECK-LABEL: @reshape_nofold_2d_ins
+func.func @reshape_nofold_2d_ins(%arg0 : tensor<?x?xi32>, %arg1: index, %arg2: index) -> tensor<?x?xi32> {
+  %ds = tensor.from_elements %arg1, %arg2 : tensor<2xindex>
+  // CHECK: tensor.reshape
+  %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
+  return %reshape : tensor<?x?xi32>
+}
+
 
 // -----
 