[Mlir-commits] [mlir] 1801fb4 - [MLIR] Fixes arith.sub folder crash on dynamically shaped tensors (#118908)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 6 06:24:32 PST 2024


Author: Mehdi Amini
Date: 2024-12-06T06:24:28-08:00
New Revision: 1801fb4bd358cd6be0d085f9b74aacbeea951a17

URL: https://github.com/llvm/llvm-project/commit/1801fb4bd358cd6be0d085f9b74aacbeea951a17
DIFF: https://github.com/llvm/llvm-project/commit/1801fb4bd358cd6be0d085f9b74aacbeea951a17.diff

LOG: [MLIR] Fixes arith.sub folder crash on dynamically shaped tensors (#118908)

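We can't create a constant for a value with a dynamic shape, so the subi(x, x) -> 0 fold now applies only when the result type is not a shaped type or has a static shape.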

Fixes #118772

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5f445231b80fdf..700258a1b6254a 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -393,8 +393,12 @@ void arith::AddUIExtendedOp::getCanonicalizationPatterns(
 
 OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
   // subi(x,x) -> 0
-  if (getOperand(0) == getOperand(1))
-    return Builder(getContext()).getZeroAttr(getType());
+  if (getOperand(0) == getOperand(1)) {
+    auto shapedType = dyn_cast<ShapedType>(getType());
+    // We can't generate a constant with a dynamic shaped tensor.
+    if (!shapedType || shapedType.hasStaticShape())
+      return Builder(getContext()).getZeroAttr(getType());
+  }
   // subi(x,0) -> x
   if (matchPattern(adaptor.getRhs(), m_Zero()))
     return getLhs();

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 69df83d42f543e..f1e36c2707a8f0 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -869,6 +869,27 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index {
   return %add2 : index
 }
 
+
+// CHECK-LABEL: @foldSubXX_tensor
+//       CHECK:   %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32> 
+//       CHECK:   %[[sub:.+]] = arith.subi
+//       CHECK:   return %[[c0]], %[[sub]]
+func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor<?x?xi32>) -> (tensor<10xi32>, tensor<?x?xi32>) {
+  %static_sub = arith.subi %static, %static : tensor<10xi32>
+  %dyn_sub = arith.subi %dyn, %dyn : tensor<?x?xi32>
+  return %static_sub, %dyn_sub : tensor<10xi32>, tensor<?x?xi32>
+}
+
+// CHECK-LABEL: @foldSubXX_vector
+//       CHECK-DAG:  %[[c0:.+]] = arith.constant dense<0> : vector<8xi32>
+//       CHECK-DAG:  %[[c0_scalable:.+]] = arith.constant dense<0> : vector<[4]xi32>
+//       CHECK:   return %[[c0]], %[[c0_scalable]]
+func.func @foldSubXX_vector(%static : vector<8xi32>, %dyn : vector<[4]xi32>) -> (vector<8xi32>, vector<[4]xi32>) {
+  %static_sub = arith.subi %static, %static : vector<8xi32>
+  %dyn_sub = arith.subi %dyn, %dyn : vector<[4]xi32>
+  return %static_sub, %dyn_sub : vector<8xi32>, vector<[4]xi32>
+}
+
 // CHECK-LABEL: @tripleAddSub0
 //       CHECK:   %[[cres:.+]] = arith.constant 59 : index
 //       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 : index


        


More information about the Mlir-commits mailing list