[Mlir-commits] [mlir] 593f6fd - [mlir][tensor] Fix tensor.reshape canonicalization (#90141)

llvmlistbot at llvm.org
Thu Apr 25 17:41:16 PDT 2024


Author: Rob Suderman
Date: 2024-04-25T17:41:12-07:00
New Revision: 593f6fdcb4bb3ff81ba4e6f89d7b16540c4b9eaf

URL: https://github.com/llvm/llvm-project/commit/593f6fdcb4bb3ff81ba4e6f89d7b16540c4b9eaf
DIFF: https://github.com/llvm/llvm-project/commit/593f6fdcb4bb3ff81ba4e6f89d7b16540c4b9eaf.diff

LOG: [mlir][tensor] Fix tensor.reshape canonicalization (#90141)

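The tensor.reshape fold defaulted to treating the reshape as a no-op when
a shape element came from an unknown source, for example a plain index
function argument. This is incorrect, because such an element is not known
to match the corresponding source dimension. The fold now bails out in that
case, and a test is included to prevent regressions.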
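
The rule behind the fix can be sketched as a small standalone model. This is
not the MLIR implementation: the names ShapeElement, Kind, kDynamic and
isProvablyNoop are hypothetical, and tensors are identified by a plain
integer id. It only illustrates that a reshape may be treated as a no-op
when every shape element provably matches the source, and that an element of
unknown origin must veto the fold.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical stand-in for how one element of the shape operand was produced.
enum class Kind { Constant, DimOfTensor, Unknown };

struct ShapeElement {
  Kind kind;
  int64_t value; // Constant: the value. DimOfTensor: the dimension index.
  int tensorId;  // DimOfTensor: the tensor the dimension was read from.
};

// Toy marker for a dynamic ('?') dimension in this model.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

ShapeElement constant(int64_t v) { return {Kind::Constant, v, -1}; }

ShapeElement dimOf(int tensorId, int64_t index) {
  assert(index >= 0 && "dimension index must be non-negative");
  return {Kind::DimOfTensor, index, tensorId};
}

ShapeElement unknown() { return {Kind::Unknown, 0, -1}; }

// Returns true only if every shape element provably equals the matching
// dimension of the source tensor, so the reshape would change nothing.
bool isProvablyNoop(int sourceId, const std::vector<int64_t> &sourceDims,
                    const std::vector<ShapeElement> &elements) {
  if (sourceDims.size() != elements.size())
    return false;

  for (std::size_t i = 0; i < elements.size(); ++i) {
    const ShapeElement &e = elements[i];
    switch (e.kind) {
    case Kind::Constant:
      // A constant only matches a static dimension of the same size.
      if (e.value != sourceDims[i])
        return false;
      break;
    case Kind::DimOfTensor:
      // Must be dimension i of the same source tensor.
      if (e.tensorId != sourceId || e.value != static_cast<int64_t>(i))
        return false;
      break;
    case Kind::Unknown:
      // The case this commit addresses: nothing is known, so do not fold.
      return false;
    }
  }
  return true;
}
```

In this model, a shape built from `dimOf(0, 0), dimOf(0, 1)` for source
tensor 0 is a no-op, while a shape built from two `unknown()` elements (the
situation in the new reshape_nofold_2d_ins test) is not.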

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 3ff41ab22fbc42..5029ed4aa0387a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1609,6 +1609,9 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
             cst.has_value() && cst.value() == static_cast<int64_t>(id);
         continue;
       }
+
+      dynamicNoop = false;
+      break;
     }
 
     if (dynamicNoop)

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 751c57eacd7ae5..9a4dd2f3b5cc11 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2431,6 +2431,15 @@ func.func @reshape_nofold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
   return %reshape : tensor<?x?xi32>
 }
 
+// -----
+
+// CHECK-LABEL: @reshape_nofold_2d_ins
+func.func @reshape_nofold_2d_ins(%arg0 : tensor<?x?xi32>, %arg1: index, %arg2: index) -> tensor<?x?xi32> {
+  %ds = tensor.from_elements %arg1, %arg2 : tensor<2xindex>
+  // CHECK: tensor.reshape
+  %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
+  return %reshape : tensor<?x?xi32>
+}
 
 // -----
 


        

