[Mlir-commits] [mlir] [mlir][arith] fix canonicalization of mulsi_extended for i1 (PR #90150)

Jakub Kuderski llvmlistbot at llvm.org
Thu Apr 25 20:35:00 PDT 2024


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/90150

>From 76da2118cac676e6b211eeb49a3ed221e33e80f6 Mon Sep 17 00:00:00 2001
From: Semyon Khechnev <hechnev at gmail.com>
Date: Fri, 26 Apr 2024 01:50:19 +0300
Subject: [PATCH] [mlir][arith] fix canonicalization of mulsi_extended

The `MulSIExtendedRHSOne` canonicalization for arith.mulsi_extended is defined as `mulsi_extended(x, 1) -> [x, extsi(cmpi slt, x, 0)]`. Its implementation applies the `IsScalarOrSplatOne` constraint to the second argument. However, this constraint does not correctly handle i1 operands: read as a signed i1, the bit pattern 1 represents -1, not 1. Therefore, an additional constraint has been added that checks that the second argument is strictly positive.
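
The i1 case can be checked by hand with a small model. The following is only a sketch in Python, not MLIR code: it assumes two's-complement semantics for arith.mulsi_extended, and the helper names (`to_signed`, `mulsi_extended`, `rewritten_rhs_one`, `is_strictly_positive`) are hypothetical, introduced here for illustration.

```python
def to_signed(bits, width):
    """Interpret the low `width` bits of `bits` as a two's-complement integer."""
    bits &= (1 << width) - 1
    return bits - (1 << width) if bits >> (width - 1) else bits


def mulsi_extended(lhs, rhs, width):
    """Assumed reference semantics: signed multiply, returned as (low, high) bit patterns."""
    mask = (1 << width) - 1
    product = to_signed(lhs, width) * to_signed(rhs, width)
    product &= (1 << (2 * width)) - 1  # keep the 2*width-bit two's-complement result
    return product & mask, product >> width


def rewritten_rhs_one(lhs, width):
    """Result of the rewrite [x, extsi(cmpi slt, x, 0)] as (low, high) bit patterns."""
    mask = (1 << width) - 1
    high = mask if to_signed(lhs, width) < 0 else 0
    return lhs & mask, high


def is_strictly_positive(bits, width):
    """Model of the added guard: the signed value must be greater than zero."""
    return to_signed(bits, width) > 0


# i32: the constant 1 is strictly positive and the rewrite matches the reference.
for x in (0, 1, 0x7FFFFFFF, 0x80000000, 0xFFFFFFFF):
    assert mulsi_extended(x, 1, 32) == rewritten_rhs_one(x, 32)

# i1: the bit pattern 1 (`true`) is -1 when signed, so the rewrite is wrong for x = 1.
print(mulsi_extended(1, 1, 1), rewritten_rhs_one(1, 1))
print(is_strictly_positive(1, 1), is_strictly_positive(1, 32))
```

In this model the high result differs for i1 when x is `true`: the reference gives 0 while the rewrite gives 1. The strict-positivity guard rejects the i1 constant and still accepts the i32 constant 1.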

fix #88732
---
 .../Dialect/Arith/IR/ArithCanonicalization.td |  1 +
 mlir/test/Dialect/Arith/canonicalize.mlir     | 22 +++++++++++++++++++
 2 files changed, 23 insertions(+)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index caca2ff81964f7..02d05780a7ac1d 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -175,6 +175,7 @@ def MulSIExtendedToMulI :
 def IsScalarOrSplatOne :
     Constraint<And<[
       CPred<"succeeded(getIntOrSplatIntValue($0))">,
+      CPred<"getIntOrSplatIntValue($0)->isStrictlyPositive()">,
       CPred<"*getIntOrSplatIntValue($0) == 1">]>>;
 
 // mulsi_extended(x, 1) -> [x, extsi(cmpi slt, x, 0)]
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 79a318565e98f9..6c4193bc06ca2d 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1223,6 +1223,28 @@ func.func @mulsiExtendedOneRhsSplat(%arg0: vector<3xi32>) -> (vector<3xi32>, vec
   return %low, %high : vector<3xi32>, vector<3xi32>
 }
 
+// CHECK-LABEL: @mulsiExtendedOneRhsI1
+//  CHECK-SAME:   (%[[ARG:.+]]: i1) -> (i1, i1)
+//  CHECK-NEXT:   %[[T:.+]]  = arith.constant true
+//  CHECK-NEXT:   %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[ARG]], %[[T]] : i1
+//  CHECK-NEXT:   return %[[LOW]], %[[HIGH]] : i1, i1
+func.func @mulsiExtendedOneRhsI1(%arg0: i1) -> (i1, i1) {
+  %one = arith.constant true
+  %low, %high = arith.mulsi_extended %arg0, %one: i1
+  return %low, %high : i1, i1
+}
+
+// CHECK-LABEL: @mulsiExtendedOneRhsSplatI1
+//  CHECK-SAME:   (%[[ARG:.+]]: vector<3xi1>) -> (vector<3xi1>, vector<3xi1>)
+//  CHECK-NEXT:   %[[TS:.+]]  = arith.constant dense<true> : vector<3xi1>
+//  CHECK-NEXT:   %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[ARG]], %[[TS]] : vector<3xi1>
+//  CHECK-NEXT:   return %[[LOW]], %[[HIGH]] : vector<3xi1>, vector<3xi1>
+func.func @mulsiExtendedOneRhsSplatI1(%arg0: vector<3xi1>) -> (vector<3xi1>, vector<3xi1>) {
+  %one = arith.constant dense<true> : vector<3xi1>
+  %low, %high = arith.mulsi_extended %arg0, %one: vector<3xi1>
+  return %low, %high : vector<3xi1>, vector<3xi1>
+}
+
 // CHECK-LABEL: @mulsiExtendedUnusedHigh
 //  CHECK-SAME:   (%[[ARG:.+]]: i32) -> i32
 //  CHECK-NEXT:   %[[RES:.+]] = arith.muli %[[ARG]], %[[ARG]] : i32


