[Mlir-commits] [mlir] [MLIR] Fixes arith.sub folder crash on dynamically shaped tensors (PR #118908)
Mehdi Amini
llvmlistbot at llvm.org
Thu Dec 5 21:19:55 PST 2024
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/118908
>From 57bab0520ea754b44f4188c91ad23a23a3c43a37 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 5 Dec 2024 16:52:35 -0800
Subject: [PATCH] [MLIR] Fixes arith.sub folder crash on dynamically shaped
tensors
The subi(x, x) -> 0 fold cannot create a zero constant for a value with a dynamic shape, so the fold is now skipped in that case.
Fixes #118772
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 8 ++++++--
mlir/test/Dialect/Arith/canonicalize.mlir | 21 +++++++++++++++++++++
2 files changed, 27 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5f445231b80fdf..700258a1b6254a 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -393,8 +393,12 @@ void arith::AddUIExtendedOp::getCanonicalizationPatterns(
OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
// subi(x,x) -> 0
- if (getOperand(0) == getOperand(1))
- return Builder(getContext()).getZeroAttr(getType());
+ if (getOperand(0) == getOperand(1)) {
+ auto shapedType = dyn_cast<ShapedType>(getType());
+ // We can't generate a constant with a dynamic shaped tensor.
+ if (!shapedType || shapedType.hasStaticShape())
+ return Builder(getContext()).getZeroAttr(getType());
+ }
// subi(x,0) -> x
if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 69df83d42f543e..f1e36c2707a8f0 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -869,6 +869,27 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index {
return %add2 : index
}
+
+// CHECK-LABEL: @foldSubXX_tensor
+// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
+// CHECK: %[[sub:.+]] = arith.subi
+// CHECK: return %[[c0]], %[[sub]]
+func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor<?x?xi32>) -> (tensor<10xi32>, tensor<?x?xi32>) {
+ %static_sub = arith.subi %static, %static : tensor<10xi32>  // static shape: subi(x, x) is expected to fold to a dense<0> constant
+ %dyn_sub = arith.subi %dyn, %dyn : tensor<?x?xi32>  // dynamic shape: no constant can be built for this type, so the subi must remain (regression test for the crash)
+ return %static_sub, %dyn_sub : tensor<10xi32>, tensor<?x?xi32>
+}
+
+// CHECK-LABEL: @foldSubXX_vector
+// CHECK-DAG: %[[c0:.+]] = arith.constant dense<0> : vector<8xi32>
+// CHECK-DAG: %[[c0_scalable:.+]] = arith.constant dense<0> : vector<[4]xi32>
+// CHECK: return %[[c0]], %[[c0_scalable]]
+func.func @foldSubXX_vector(%static : vector<8xi32>, %dyn : vector<[4]xi32>) -> (vector<8xi32>, vector<[4]xi32>) {
+ %static_sub = arith.subi %static, %static : vector<8xi32>  // fixed-length vector: expected to fold to a dense<0> constant
+ %dyn_sub = arith.subi %dyn, %dyn : vector<[4]xi32>  // scalable (not dynamically shaped) vector: the CHECK-DAG lines expect this to fold to dense<0> as well; NOTE(review): %dyn would read better as %scalable
+ return %static_sub, %dyn_sub : vector<8xi32>, vector<[4]xi32>
+}
+
// CHECK-LABEL: @tripleAddSub0
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
More information about the Mlir-commits
mailing list