[Mlir-commits] [mlir] [mlir][arith] Fix multiplication canonicalizations (PR #144787)

Tobias Gysi llvmlistbot at llvm.org
Wed Jun 18 12:57:17 PDT 2025


https://github.com/gysit created https://github.com/llvm/llvm-project/pull/144787

The Arith dialect includes patterns that canonicalize a sequence of:

- trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
- trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)

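These patterns return the high word of an extended multiplication, which is only correct when the shift amount equals the bit width of the original operands. This check was missing: the existing `TruncationMatchesShiftAmount` constraint only verifies that the shift amount equals the difference between the width of the multiplication and the width of the original operands. As a result, the patterns fired incorrectly when the shift amount was less than the operand bit width.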
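As a rough illustration of why the existing constraint alone was insufficient, the two width checks can be modeled as plain predicates. This is a minimal Python sketch, not the TableGen implementation: the function names mirror the constraint names in the patch, `pattern_applies` and its `with_fix` flag are hypothetical, `ValueWiderThan` is assumed to compare bit widths, and the `ValuesWithSameType` check is omitted.

```python
def truncation_matches_shift_amount(mul_width: int, x_width: int, shift: int) -> bool:
    # Existing constraint: the shift equals the number of bits that the
    # truncation drops (width of the multiplication minus operand width).
    return mul_width - x_width == shift


def value_width_matches_shift_amount(x_width: int, shift: int) -> bool:
    # Constraint added by the patch: the shift equals the operand bit width.
    return x_width == shift


def pattern_applies(mul_width: int, x_width: int, shift: int, with_fix: bool) -> bool:
    # Hypothetical model of the pattern's width-related constraints only.
    # Assumption: ValueWiderThan means "mul is strictly wider than x".
    ok = mul_width > x_width
    ok = ok and truncation_matches_shift_amount(mul_width, x_width, shift)
    if with_fix:
        ok = ok and value_width_matches_shift_amount(x_width, shift)
    return ok


# 33-bit multiplication of 32-bit operands, shifted right by 1.
print(pattern_applies(33, 32, 1, with_fix=False))
print(pattern_applies(33, 32, 1, with_fix=True))
# 64-bit multiplication of 32-bit operands, shifted right by 32.
print(pattern_applies(64, 32, 32, with_fix=True))
```

With a 33-bit multiplication, 32-bit operands, and a shift of 1, the sketch reports a match without the added check and no match with it, while the intended case of a 64-bit multiplication shifted by 32 still matches.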

For example, the following code:
```mlir
  %x = arith.extui %a: i32 to i33
  %y = arith.extui %b: i32 to i33
  %m = arith.muli %x, %y: i33
  %c1 = arith.constant 1: i33
  %sh = arith.shrui %m, %c1 : i33
  %hi = arith.trunci %sh: i33 to i32
```
would incorrectly be canonicalized to:
```mlir
%lo, %hi = arith.mului_extended %a, %b : i32
```
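To see that the two forms really compute different values, the arithmetic can be replayed on concrete unsigned inputs. The following is a small Python sketch under stated assumptions: the helper names are hypothetical, values are treated as unsigned 32-bit integers (the `extui` case), and each step imitates the corresponding op with integer masking rather than calling into MLIR.

```python
def wide_mul_shift_trunc(a: int, b: int, ext_width: int, shift: int, width: int = 32) -> int:
    # Imitates: extui to ext_width, muli (wraps at ext_width bits),
    # shrui by `shift`, then trunci back to `width` bits.
    product = (a * b) & ((1 << ext_width) - 1)
    shifted = product >> shift
    return shifted & ((1 << width) - 1)


def mului_extended_high(a: int, b: int, width: int = 32) -> int:
    # High word of the full 2*width-bit unsigned product.
    return ((a * b) >> width) & ((1 << width) - 1)


a, b = 2, 3

# i33 extension with a shift of 1: (2 * 3) >> 1 == 3.
original = wide_mul_shift_trunc(a, b, ext_width=33, shift=1)
# High 32 bits of 2 * 3 are all zero.
rewritten = mului_extended_high(a, b)
print(original, rewritten)  # 3 0

# i64 extension with a shift of 32: both forms agree.
big = 0xFFFFFFFF
print(wide_mul_shift_trunc(big, big, ext_width=64, shift=32) == mului_extended_high(big, big))  # True
```

For inputs as small as 2 and 3 the original sequence yields 3 while the rewritten op yields 0, so the rewrite changes program behavior unless the shift equals the operand width.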

From dce14244c82938dd3281a0db44b915e0df1e87b5 Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Wed, 18 Jun 2025 19:36:12 +0000
Subject: [PATCH] [mlir][arith] Fix multiplication canonicalizations

The Arith dialect includes patterns that canonicalize a sequence of:

- trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
- trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)

These patterns return the high word of an extended multiplication,
which assumes that the shift amount is equal to the bit width of the
original operands. This check was missing, leading to incorrect
canonicalizations when the shift amount was less than the bit width.

For example, the following code:
```mlir
  %x = arith.extui %a: i32 to i33
  %y = arith.extui %b: i32 to i33
  %m = arith.muli %x, %y: i33
  %c1 = arith.constant 1: i33
  %sh = arith.shrui %m, %c1 : i33
  %hi = arith.trunci %sh: i33 to i32
```
would incorrectly be canonicalized to:
```mlir
_, %hi = arith.mului_extended %a, %b : i32
```
---
 .../Dialect/Arith/IR/ArithCanonicalization.td | 14 ++++++--
 mlir/test/Dialect/Arith/canonicalize.mlir     | 32 ++++++++++++++++++-
 2 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 13eb97a910bd4..2f7beed549108 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -273,7 +273,7 @@ def RedundantSelectFalse :
     Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
         (SelectOp $pred, $a, $c)>;
 
-// select(pred, false, true) => not(pred) 
+// select(pred, false, true) => not(pred)
 def SelectI1ToNot :
     Pat<(SelectOp $pred,
                   (ConstantLikeMatcher ConstantAttr<I1Attr, "0">),
@@ -376,6 +376,12 @@ def TruncationMatchesShiftAmount :
       CPred<"(getScalarOrElementWidth($0) - getScalarOrElementWidth($1)) == "
               "*getIntOrSplatIntValue($2)">]>>;
 
+def ValueWidthMatchesShiftAmount :
+    Constraint<And<[
+      CPred<"succeeded(getIntOrSplatIntValue($1))">,
+      CPred<"getScalarOrElementWidth($0) == "
+              "*getIntOrSplatIntValue($1)">]>>;
+
 // trunci(extsi(x)) -> extsi(x), when only the sign-extension bits are truncated
 def TruncIExtSIToExtSI :
     Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x)),
@@ -406,7 +412,8 @@ def TruncIShrUIMulIToMulSIExtended :
         (Arith_MulSIExtendedOp:$res__1 $x, $y),
       [(ValuesWithSameType $tr, $x, $y),
        (ValueWiderThan $mul, $x),
-       (TruncationMatchesShiftAmount $mul, $x, $c0)]>;
+       (TruncationMatchesShiftAmount $mul, $x, $c0),
+       (ValueWidthMatchesShiftAmount $x, $c0)]>;
 
 // trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)
 def TruncIShrUIMulIToMulUIExtended :
@@ -417,7 +424,8 @@ def TruncIShrUIMulIToMulUIExtended :
         (Arith_MulUIExtendedOp:$res__1 $x, $y),
       [(ValuesWithSameType $tr, $x, $y),
        (ValueWiderThan $mul, $x),
-       (TruncationMatchesShiftAmount $mul, $x, $c0)]>;
+       (TruncationMatchesShiftAmount $mul, $x, $c0),
+       (ValueWidthMatchesShiftAmount $x, $c0)]>;
 
 //===----------------------------------------------------------------------===//
 // TruncIOp
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index b6188c81ff912..542603722ab8a 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1000,7 +1000,7 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index {
 
 
 // CHECK-LABEL: @foldSubXX_tensor
-//       CHECK:   %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32> 
+//       CHECK:   %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
 //       CHECK:   %[[sub:.+]] = arith.subi
 //       CHECK:   return %[[c0]], %[[sub]]
 func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor<?x?xi32>) -> (tensor<10xi32>, tensor<?x?xi32>) {
@@ -2966,6 +2966,21 @@ func.func @wideMulToMulSIExtended(%a: i32, %b: i32) -> i32 {
   return %hi : i32
 }
 
+// Verify that the signed extended multiplication pattern does not match
+// if the right shift does not match the bitwidth of the multipliers.
+
+// CHECK-LABEL: @wideMulToMulSIExtendedWithWrongShift
+//   CHECK-NOT: arith.mulsi_extended
+func.func @wideMulToMulSIExtendedWithWrongShift(%a: i32, %b: i32) -> i32 {
+  %x = arith.extsi %a: i32 to i33
+  %y = arith.extsi %b: i32 to i33
+  %m = arith.muli %x, %y: i33
+  %c1 = arith.constant 1: i33
+  %sh = arith.shrui %m, %c1 : i33
+  %hi = arith.trunci %sh: i33 to i32
+  return %hi : i32
+}
+
 // CHECK-LABEL: @wideMulToMulSIExtendedVector
 //  CHECK-SAME:   (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
 //  CHECK-NEXT:   %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : vector<3xi32>
@@ -2994,6 +3009,21 @@ func.func @wideMulToMulUIExtended(%a: i32, %b: i32) -> i32 {
   return %hi : i32
 }
 
+// Verify that the unsigned extended multiplication pattern does not match
+// if the right shift does not match the bitwidth of the multipliers.
+
+// CHECK-LABEL: @wideMulToMulUIExtendedWithWrongShift
+//   CHECK-NOT: arith.mului_extended
+func.func @wideMulToMulUIExtendedWithWrongShift(%a: i32, %b: i32) -> i32 {
+  %x = arith.extui %a: i32 to i33
+  %y = arith.extui %b: i32 to i33
+  %m = arith.muli %x, %y: i33
+  %c1 = arith.constant 1: i33
+  %sh = arith.shrui %m, %c1 : i33
+  %hi = arith.trunci %sh: i33 to i32
+  return %hi : i32
+}
+
 // CHECK-LABEL: @wideMulToMulUIExtendedVector
 //  CHECK-SAME:   (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
 //  CHECK-NEXT:   %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : vector<3xi32>


