[Mlir-commits] [mlir] [mlir][arith] Fix multiplication canonicalizations (PR #144787)
llvmlistbot at llvm.org
Wed Jun 18 12:57:48 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Tobias Gysi (gysit)
<details>
<summary>Changes</summary>
The Arith dialect includes patterns that canonicalize the following sequences:
- trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
- trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)
These patterns return the high word of an extended multiplication, which is only correct when the shift amount equals the bit width of the original operands. This check was missing, leading to incorrect canonicalizations when the shift amount did not equal the bit width, for instance when it was smaller.
For example, the following code:
```mlir
%x = arith.extui %a: i32 to i33
%y = arith.extui %b: i32 to i33
%m = arith.muli %x, %y: i33
%c1 = arith.constant 1: i33
%sh = arith.shrui %m, %c1 : i33
%hi = arith.trunci %sh: i33 to i32
```
would incorrectly be canonicalized to:
```mlir
%lo, %hi = arith.mului_extended %a, %b : i32
```
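To make the miscompile concrete, here is a minimal Python sketch that models the unsigned case with plain integers. The helper names `mului_extended_high` and `shifted_pattern` are hypothetical and only mirror the semantics described above (zero-extend, wrap-around multiply, logical shift, truncate); they are not part of MLIR.

```python
def mului_extended_high(a: int, b: int, width: int = 32) -> int:
    """High `width` bits of the full unsigned product (the second result of mului_extended)."""
    mask = (1 << width) - 1
    return ((a & mask) * (b & mask)) >> width


def shifted_pattern(a: int, b: int, shift: int, wide: int, width: int = 32) -> int:
    """Model of trunci(shrui(muli(extui(a), extui(b)), shift)) with `wide`-bit intermediates."""
    narrow_mask = (1 << width) - 1
    wide_mask = (1 << wide) - 1
    # extui followed by muli: the product wraps around at `wide` bits.
    product = ((a & narrow_mask) * (b & narrow_mask)) & wide_mask
    # shrui by `shift`, then trunci back to `width` bits.
    return (product >> shift) & narrow_mask


# The i33 example with a shift of 1: the original sequence yields 3,
# but the high word of 2 * 3 is 0, so the rewrite changes the result.
print(shifted_pattern(2, 3, shift=1, wide=33))  # 3
print(mului_extended_high(2, 3))                # 0

# With i64 intermediates and a shift of 32, both computations agree.
big = 0xFFFFFFFF
print(shifted_pattern(big, big, shift=32, wide=64) == mului_extended_high(big, big))  # True
```

Under this model the two computations agree only when the shift selects exactly the upper half of a double-width product, which is the condition the new constraint enforces.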
---
Full diff: https://github.com/llvm/llvm-project/pull/144787.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+11-3)
- (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+31-1)
``````````diff
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 13eb97a910bd4..2f7beed549108 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -273,7 +273,7 @@ def RedundantSelectFalse :
Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
(SelectOp $pred, $a, $c)>;
-// select(pred, false, true) => not(pred)
+// select(pred, false, true) => not(pred)
def SelectI1ToNot :
Pat<(SelectOp $pred,
(ConstantLikeMatcher ConstantAttr<I1Attr, "0">),
@@ -376,6 +376,12 @@ def TruncationMatchesShiftAmount :
CPred<"(getScalarOrElementWidth($0) - getScalarOrElementWidth($1)) == "
"*getIntOrSplatIntValue($2)">]>>;
+def ValueWidthMatchesShiftAmount :
+ Constraint<And<[
+ CPred<"succeeded(getIntOrSplatIntValue($1))">,
+ CPred<"getScalarOrElementWidth($0) == "
+ "*getIntOrSplatIntValue($1)">]>>;
+
// trunci(extsi(x)) -> extsi(x), when only the sign-extension bits are truncated
def TruncIExtSIToExtSI :
Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x)),
@@ -406,7 +412,8 @@ def TruncIShrUIMulIToMulSIExtended :
(Arith_MulSIExtendedOp:$res__1 $x, $y),
[(ValuesWithSameType $tr, $x, $y),
(ValueWiderThan $mul, $x),
- (TruncationMatchesShiftAmount $mul, $x, $c0)]>;
+ (TruncationMatchesShiftAmount $mul, $x, $c0),
+ (ValueWidthMatchesShiftAmount $x, $c0)]>;
// trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)
def TruncIShrUIMulIToMulUIExtended :
@@ -417,7 +424,8 @@ def TruncIShrUIMulIToMulUIExtended :
(Arith_MulUIExtendedOp:$res__1 $x, $y),
[(ValuesWithSameType $tr, $x, $y),
(ValueWiderThan $mul, $x),
- (TruncationMatchesShiftAmount $mul, $x, $c0)]>;
+ (TruncationMatchesShiftAmount $mul, $x, $c0),
+ (ValueWidthMatchesShiftAmount $x, $c0)]>;
//===----------------------------------------------------------------------===//
// TruncIOp
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index b6188c81ff912..542603722ab8a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1000,7 +1000,7 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index {
// CHECK-LABEL: @foldSubXX_tensor
-// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
+// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
// CHECK: %[[sub:.+]] = arith.subi
// CHECK: return %[[c0]], %[[sub]]
func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor<?x?xi32>) -> (tensor<10xi32>, tensor<?x?xi32>) {
@@ -2966,6 +2966,21 @@ func.func @wideMulToMulSIExtended(%a: i32, %b: i32) -> i32 {
return %hi : i32
}
+// Verify that the signed extended multiplication pattern does not match
+// if the right shift does not match the bitwidth of the multipliers.
+
+// CHECK-LABEL: @wideMulToMulSIExtendedWithWrongShift
+// CHECK-NOT: arith.mulsi_extended
+func.func @wideMulToMulSIExtendedWithWrongShift(%a: i32, %b: i32) -> i32 {
+ %x = arith.extsi %a: i32 to i33
+ %y = arith.extsi %b: i32 to i33
+ %m = arith.muli %x, %y: i33
+ %c1 = arith.constant 1: i33
+ %sh = arith.shrui %m, %c1 : i33
+ %hi = arith.trunci %sh: i33 to i32
+ return %hi : i32
+}
+
// CHECK-LABEL: @wideMulToMulSIExtendedVector
// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : vector<3xi32>
@@ -2994,6 +3009,21 @@ func.func @wideMulToMulUIExtended(%a: i32, %b: i32) -> i32 {
return %hi : i32
}
+// Verify that the unsigned extended multiplication pattern does not match
+// if the right shift does not match the bitwidth of the multipliers.
+
+// CHECK-LABEL: @wideMulToMulUIExtendedWithWrongShift
+// CHECK-NOT: arith.mului_extended
+func.func @wideMulToMulUIExtendedWithWrongShift(%a: i32, %b: i32) -> i32 {
+ %x = arith.extui %a: i32 to i33
+ %y = arith.extui %b: i32 to i33
+ %m = arith.muli %x, %y: i33
+ %c1 = arith.constant 1: i33
+ %sh = arith.shrui %m, %c1 : i33
+ %hi = arith.trunci %sh: i33 to i32
+ return %hi : i32
+}
+
// CHECK-LABEL: @wideMulToMulUIExtendedVector
// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : vector<3xi32>
``````````
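The interaction of the two constraints in the diff can be sketched as a simplified Python model. It assumes the shift amount is a known constant (the `succeeded(getIntOrSplatIntValue(...))` check is represented by `None` for a non-constant shift) and works on scalar bit widths only. The function names are illustrative and do not correspond to an MLIR API.

```python
from typing import Optional


def truncation_matches_shift_amount(mul_width: int, x_width: int, shift: Optional[int]) -> bool:
    """Pre-existing check: the bits dropped by trunci equal the shift amount."""
    return shift is not None and (mul_width - x_width) == shift


def value_width_matches_shift_amount(x_width: int, shift: Optional[int]) -> bool:
    """Check added by this patch: the shift amount equals the operand width."""
    return shift is not None and x_width == shift


def pattern_applies(mul_width: int, x_width: int, shift: Optional[int]) -> bool:
    """Conjunction of the width constraints on the patched patterns."""
    return (
        mul_width > x_width  # models ValueWiderThan
        and truncation_matches_shift_amount(mul_width, x_width, shift)
        and value_width_matches_shift_amount(x_width, shift)
    )


# i32 operands extended to i33 and shifted by 1: the old check alone accepts this.
print(truncation_matches_shift_amount(33, 32, 1))  # True
print(pattern_applies(33, 32, 1))                  # False

# i32 operands extended to i64 and shifted by 32: still canonicalized.
print(pattern_applies(64, 32, 32))                 # True
```

Taken together, the two equalities imply that the multiplication width is exactly twice the operand width, so in this model the patterns only fire when the shift extracts the true high word.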
</details>
https://github.com/llvm/llvm-project/pull/144787