[Mlir-commits] [mlir] [mlir][arith] fix canonicalization of mulsi_extended (PR #90150)

llvmlistbot at llvm.org
Thu Apr 25 16:44:58 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-arith

Author: Semyon Khechnev (s-khechnev)

<details>
<summary>Changes</summary>

There is a `MulSIExtendedRHSOne` canonicalization for `arith.mulsi_extended`, defined as `mulsi_extended(x, 1) -> [x, extsi(cmpi slt, x, 0)]`. Its implementation guards the second operand with the `IsScalarOrSplatOne` constraint. However, this constraint does not correctly handle `i1` operands: in a signed multiplication, the constant `1 : i1` (`true`) actually represents -1, so the op multiplies by -1 rather than by 1 and the rewrite is not a valid identity for `i1`. Therefore, an additional constraint has been added that requires the second operand to be strictly positive.
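To see why the rewrite breaks only for `i1`, here is a minimal bit-level sketch in Python. It assumes `mulsi_extended` returns the low and high halves of the full signed product of its two sign-extended operands. The helper names (`to_signed`, `mulsi_extended`, `rhs_one_rewrite`, `rewrite_is_sound`) are hypothetical and only model the semantics; they are not MLIR APIs.

```python
# Hypothetical bit-level model of arith.mulsi_extended (not MLIR code).
# All values are raw bit patterns of the given width.

def to_signed(bits: int, width: int) -> int:
    """Interpret a width-bit pattern as a two's-complement integer."""
    return bits - (1 << width) if (bits >> (width - 1)) & 1 else bits


def mulsi_extended(a: int, b: int, width: int) -> tuple:
    """Full signed product of two width-bit values, split into (low, high)."""
    mask = (1 << width) - 1
    product = to_signed(a, width) * to_signed(b, width)
    wide = product & ((1 << (2 * width)) - 1)  # 2*width-bit two's complement
    return wide & mask, (wide >> width) & mask


def rhs_one_rewrite(x: int, width: int) -> tuple:
    """Result of the rewrite mulsi_extended(x, 1) -> [x, extsi(cmpi slt, x, 0)]."""
    mask = (1 << width) - 1
    high = mask if to_signed(x, width) < 0 else 0  # sign-extended "x < 0" flag
    return x, high


def rewrite_is_sound(width: int) -> bool:
    """Compare the rewrite with the real semantics for every possible x."""
    return all(
        mulsi_extended(x, 1, width) == rhs_one_rewrite(x, width)
        for x in range(1 << width)
    )


for w in (1, 2, 8):
    print(f"i{w}: rewrite sound = {rewrite_is_sound(w)}")
# i1: rewrite sound = False
# i2: rewrite sound = True
# i8: rewrite sound = True

# For i1 the pattern 1 means -1, so true * true = (-1) * (-1) = 1.
print(mulsi_extended(1, 1, 1), rhs_one_rewrite(1, 1))  # (1, 0) (1, 1)
```

The low half still matches; only the high half differs, because the true product of `true * true` is +1, while the rewrite assumes the result has the sign of `x`.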

Fixes #88732

---
Full diff: https://github.com/llvm/llvm-project/pull/90150.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+1) 
- (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+22) 


``````````diff
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index caca2ff81964f7..02d05780a7ac1d 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -175,6 +175,7 @@ def MulSIExtendedToMulI :
 def IsScalarOrSplatOne :
     Constraint<And<[
       CPred<"succeeded(getIntOrSplatIntValue($0))">,
+      CPred<"getIntOrSplatIntValue($0)->isStrictlyPositive()">,
       CPred<"*getIntOrSplatIntValue($0) == 1">]>>;
 
 // mulsi_extended(x, 1) -> [x, extsi(cmpi slt, x, 0)]
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 79a318565e98f9..6c4193bc06ca2d 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1223,6 +1223,28 @@ func.func @mulsiExtendedOneRhsSplat(%arg0: vector<3xi32>) -> (vector<3xi32>, vec
   return %low, %high : vector<3xi32>, vector<3xi32>
 }
 
+// CHECK-LABEL: @mulsiExtendedOneRhsI1
+//  CHECK-SAME:   (%[[ARG:.+]]: i1) -> (i1, i1)
+//  CHECK-NEXT:   %[[T:.+]]  = arith.constant true
+//  CHECK-NEXT:   %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[ARG]], %[[T]] : i1
+//  CHECK-NEXT:   return %[[LOW]], %[[HIGH]] : i1, i1
+func.func @mulsiExtendedOneRhsI1(%arg0: i1) -> (i1, i1) {
+  %one = arith.constant true
+  %low, %high = arith.mulsi_extended %arg0, %one: i1
+  return %low, %high : i1, i1
+}
+
+// CHECK-LABEL: @mulsiExtendedOneRhsSplatI1
+//  CHECK-SAME:   (%[[ARG:.+]]: vector<3xi1>) -> (vector<3xi1>, vector<3xi1>)
+//  CHECK-NEXT:   %[[TS:.+]]  = arith.constant dense<true> : vector<3xi1>
+//  CHECK-NEXT:   %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[ARG]], %[[TS]] : vector<3xi1>
+//  CHECK-NEXT:   return %[[LOW]], %[[HIGH]] : vector<3xi1>, vector<3xi1>
+func.func @mulsiExtendedOneRhsSplatI1(%arg0: vector<3xi1>) -> (vector<3xi1>, vector<3xi1>) {
+  %one = arith.constant dense<true> : vector<3xi1>
+  %low, %high = arith.mulsi_extended %arg0, %one: vector<3xi1>
+  return %low, %high : vector<3xi1>, vector<3xi1>
+}
+
 // CHECK-LABEL: @mulsiExtendedUnusedHigh
 //  CHECK-SAME:   (%[[ARG:.+]]: i32) -> i32
 //  CHECK-NEXT:   %[[RES:.+]] = arith.muli %[[ARG]], %[[ARG]] : i32

``````````
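The added `CPred` works because the existing `== 1` check compares against a `uint64_t`, which (as `APInt::operator==(uint64_t)` is understood to do) uses the zero-extended value, so the 1-bit pattern `1` passes. A sign-aware check such as `APInt::isStrictlyPositive()` rejects it, because that bit is the sign bit. The sketch below models both predicates in Python; the function names are hypothetical stand-ins, not the TableGen constraint itself.

```python
# Hypothetical model of the IsScalarOrSplatOne predicate before and after the
# patch. `bits` is the constant's raw bit pattern and `width` its bitwidth.

def is_strictly_positive(bits: int, width: int) -> bool:
    # Mirrors APInt::isStrictlyPositive(): sign bit clear and value non-zero.
    sign_bit_set = (bits >> (width - 1)) & 1
    return not sign_bit_set and bits != 0


def equals_one(bits: int) -> bool:
    # Mirrors `*getIntOrSplatIntValue($0) == 1`: an unsigned (zero-extended)
    # comparison, so the i1 pattern 1 (i.e. `true`, signed -1) still matches.
    return bits == 1


def old_constraint(bits: int, width: int) -> bool:
    return equals_one(bits)


def new_constraint(bits: int, width: int) -> bool:
    return is_strictly_positive(bits, width) and equals_one(bits)


for width in (1, 32):
    print(f"i{width}: old={old_constraint(1, width)} new={new_constraint(1, width)}")
# i1: old=True new=False
# i32: old=True new=True
```

With the extra check, the pattern no longer fires for `true : i1`, which is what the two new `@mulsiExtendedOneRhs*I1` tests expect, while wider constant `1` operands are still canonicalized.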

</details>


https://github.com/llvm/llvm-project/pull/90150

