[Mlir-commits] [mlir] dbf42f0 - Reland "[mlir][arith] Add canonicalization patterns for 'mul*i_extended'"

Jakub Kuderski llvmlistbot at llvm.org
Tue Dec 13 11:35:26 PST 2022


Author: Jakub Kuderski
Date: 2022-12-13T14:33:31-05:00
New Revision: dbf42f0b1269eacb928efc079fd6ee54975c91da

URL: https://github.com/llvm/llvm-project/commit/dbf42f0b1269eacb928efc079fd6ee54975c91da
DIFF: https://github.com/llvm/llvm-project/commit/dbf42f0b1269eacb928efc079fd6ee54975c91da.diff

LOG: Reland "[mlir][arith] Add canonicalization patterns for 'mul*i_extended'"

- Add a fold for `mulsi_extended(x, 1)`
- Add folds to demote wide integer multiplication to `mul*i_extended` when the result is shifted
   and truncated: `trunci(shrui(mul(*ext(x), *ext(y)), c)) -> mul*i_extended(x, y)`

Reviewed By: Mogball, jpienaar

Differential Revision: https://reviews.llvm.org/D139778

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 594ba46d62acb..ce61890f7784d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1071,6 +1071,7 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
   }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index a2f45c25146f8..6f88a77be6798 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -122,6 +122,22 @@ def MulSIExtendedToMulI :
         [(Arith_MulIOp $x, $y), (replaceWithValue $x)],
         [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
+
+// Holds when attribute $0 is an integer attribute, or a splat of integers,
+// whose (splat) value equals 1. The first predicate guards the `.value()`
+// access in the second one.
+def IsScalarOrSplatOne :
+    Constraint<And<[
+      CPred<"succeeded(getIntOrSplatIntValue($0))">,
+      CPred<"getIntOrSplatIntValue($0).value() == 1">]>>;
+
+// mulsi_extended(x, 1) -> [x, extsi(cmpi slt, x, 0)]
+// The first (low) result is replaced with `$x` itself; the second (high)
+// result is replaced with the sign extension of the comparison `$x < 0`.
+// Only fires when the RHS is a constant satisfying IsScalarOrSplatOne.
+def MulSIExtendedRHSOne :
+    Pattern<(Arith_MulSIExtendedOp $x, (ConstantLikeMatcher AnyAttr:$c1)),
+            [(replaceWithValue $x),
+             (Arith_ExtSIOp(Arith_CmpIOp
+                              (NativeCodeCall<"arith::CmpIPredicate::slt">),
+                              $x,
+                              (Arith_ConstantOp (GetZeroAttr $x))))],
+            [(IsScalarOrSplatOne $c1)]>;
+
 //===----------------------------------------------------------------------===//
 // MulUIExtendedOp
 //===----------------------------------------------------------------------===//
@@ -251,6 +267,54 @@ def OrOfExtSI :
         (Arith_ExtSIOp (Arith_OrIOp $x, $y)),
         [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
 
+//===----------------------------------------------------------------------===//
+// TruncIOp
+//===----------------------------------------------------------------------===//
+
+// Holds when the three values $0, $1 and $2 all have the same type.
+def ValuesWithSameType :
+    Constraint<
+      CPred<"llvm::all_equal({$0.getType(), $1.getType(), $2.getType()})">>;
+
+// Holds when the scalar/element bit width of $0 is strictly greater than that
+// of $1. The second predicate rejects the -1 that getScalarOrElementWidth
+// returns for types that are neither integer nor float.
+def ValueWiderThan :
+    Constraint<And<[
+      CPred<"getScalarOrElementWidth($0) > getScalarOrElementWidth($1)">,
+      CPred<"getScalarOrElementWidth($1) > 0">]>>;
+
+// Holds when attribute $2 is an integer (or integer splat) constant equal to
+// the difference between the scalar/element bit widths of $0 and $1, i.e. the
+// shift amount equals the number of bits by which $0 is wider than $1.
+def TruncationMatchesShiftAmount :
+    Constraint<And<[
+      CPred<"succeeded(getIntOrSplatIntValue($2))">,
+      CPred<"(getScalarOrElementWidth($0) - getScalarOrElementWidth($1)) == "
+              "getIntOrSplatIntValue($2).value()">]>>;
+
+// trunci(shrsi(x, c)) -> trunci(shrui(x, c))
+// Requires `c` to equal the number of bits dropped by the truncation: the
+// bits filled in by the shift are then exactly the ones truncated away, so
+// the arithmetic and the logical shift produce the same truncated value.
+def TruncIShrSIToTrunciShrUI :
+    Pat<(Arith_TruncIOp:$tr
+          (Arith_ShRSIOp $x, (ConstantLikeMatcher AnyAttr:$c0))),
+        (Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp $c0))),
+        [(TruncationMatchesShiftAmount $x, $tr, $c0)]>;
+
+// trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
+// Replaces the truncated value with the second (high) result of
+// `mulsi_extended`. Constraints:
+//  * the truncated value, `x` and `y` all have the same type;
+//  * the wide multiplication is exactly twice as wide as `x`. For any other
+//    width the bits selected by the shift and truncation are not the high
+//    half of the double-width product (e.g. for i32 operands extended to
+//    i128 and a shift of 96 they are only sign-fill bits);
+//  * the shift amount equals the width difference between `mul` and `x`.
+def TruncIShrUIMulIToMulSIExtended :
+    Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
+                              (Arith_MulIOp:$mul
+                                (Arith_ExtSIOp $x), (Arith_ExtSIOp $y)),
+                              (ConstantLikeMatcher AnyAttr:$c0))),
+        (Arith_MulSIExtendedOp:$res__1 $x, $y),
+      [(ValuesWithSameType $tr, $x, $y),
+       (ValueWiderThan $mul, $x),
+       (Constraint<CPred<"getScalarOrElementWidth($0) == "
+                         "2 * getScalarOrElementWidth($1)">> $mul, $x),
+       (TruncationMatchesShiftAmount $mul, $x, $c0)]>;
+
+// trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)
+// Replaces the truncated value with the second (high) result of
+// `mului_extended`. Constraints:
+//  * the truncated value, `x` and `y` all have the same type;
+//  * the wide multiplication is exactly twice as wide as `x`. For any other
+//    width the bits selected by the shift and truncation are not the high
+//    half of the double-width product (e.g. for i32 operands extended to
+//    i128 and a shift of 96 they are always zero);
+//  * the shift amount equals the width difference between `mul` and `x`.
+def TruncIShrUIMulIToMulUIExtended :
+    Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
+                              (Arith_MulIOp:$mul
+                                (Arith_ExtUIOp $x), (Arith_ExtUIOp $y)),
+                              (ConstantLikeMatcher AnyAttr:$c0))),
+        (Arith_MulUIExtendedOp:$res__1 $x, $y),
+      [(ValuesWithSameType $tr, $x, $y),
+       (ValueWiderThan $mul, $x),
+       (Constraint<CPred<"getScalarOrElementWidth($0) == "
+                         "2 * getScalarOrElementWidth($1)">> $mul, $x),
+       (TruncationMatchesShiftAmount $mul, $x, $c0)]>;
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 25a3dd425cf22..590ea4d0716e8 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include <cassert>
+#include <cstdint>
 #include <utility>
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -74,6 +75,29 @@ static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
                                        invertPredicate(pred.getValue()));
 }
 
+/// Returns the bit width of `type`, or of its element type as produced by
+/// `getElementTypeOrSelf`, when that type is an integer or float type.
+/// Returns -1 for every other type.
+static int64_t getScalarOrElementWidth(Type type) {
+  Type scalarTy = getElementTypeOrSelf(type);
+  if (!scalarTy.isIntOrFloat())
+    return -1;
+
+  return scalarTy.getIntOrFloatBitWidth();
+}
+
+/// Returns the scalar or element bit width of the type of `value`, or -1 when
+/// that type is neither an integer/float nor a container of them.
+static int64_t getScalarOrElementWidth(Value value) {
+  Type valueTy = value.getType();
+  return getScalarOrElementWidth(valueTy);
+}
+
+/// Extracts the integer carried by `attr`: the value of an `IntegerAttr`, or
+/// the splat value of a `SplatElementsAttr` whose element type is an
+/// `IntegerType`. Returns failure for any other attribute.
+static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
+  if (auto scalarAttr = attr.dyn_cast<IntegerAttr>())
+    return scalarAttr.getValue();
+
+  auto splat = attr.dyn_cast<SplatElementsAttr>();
+  if (!splat || !splat.getElementType().isa<IntegerType>())
+    return failure();
+
+  return splat.getSplatValue<APInt>();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd canonicalization patterns
 //===----------------------------------------------------------------------===//
@@ -393,7 +417,7 @@ arith::MulSIExtendedOp::fold(ArrayRef<Attribute> operands,
 
 void arith::MulSIExtendedOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.add<MulSIExtendedToMulI>(context);
+  patterns.add<MulSIExtendedToMulI, MulSIExtendedRHSOne>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1249,6 +1273,12 @@ bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
 }
 
+/// Registers the TableGen'd canonicalization patterns rooted at `trunci`.
+void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  // Same patterns, same order, one registration per line.
+  patterns.add<TruncIShrSIToTrunciShrUI>(context);
+  patterns.add<TruncIShrUIMulIToMulSIExtended>(context);
+  patterns.add<TruncIShrUIMulIToMulUIExtended>(context);
+}
+
 LogicalResult arith::TruncIOp::verify() {
   return verifyTruncateOp<IntegerType>(*this);
 }

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 644af88e298c7..a4c800cc826ac 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -761,6 +761,30 @@ func.func @mulsiExtendedZeroLhs(%arg0: i32) -> (i32, i32) {
   return %low, %high : i32, i32
 }
 
+// Scalar mulsi_extended by the constant 1: the low result becomes the operand
+// and the high result becomes extsi(cmpi slt, operand, 0).
+// CHECK-LABEL: @mulsiExtendedOneRhs
+//  CHECK-SAME:   (%[[ARG:.+]]: i32) -> (i32, i32)
+//  CHECK-NEXT:   %[[C0:.+]]  = arith.constant 0 : i32
+//  CHECK-NEXT:   %[[CMP:.+]] = arith.cmpi slt, %[[ARG]], %[[C0]] : i32
+//  CHECK-NEXT:   %[[EXT:.+]] = arith.extsi %[[CMP]] : i1 to i32
+//  CHECK-NEXT:   return %[[ARG]], %[[EXT]] : i32, i32
+func.func @mulsiExtendedOneRhs(%arg0: i32) -> (i32, i32) {
+  %one = arith.constant 1 : i32
+  %low, %high = arith.mulsi_extended %arg0, %one: i32
+  return %low, %high : i32, i32
+}
+
+// Vector variant of the test above: the RHS is a splat constant of ones.
+// CHECK-LABEL: @mulsiExtendedOneRhsSplat
+//  CHECK-SAME:   (%[[ARG:.+]]: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>)
+//  CHECK-NEXT:   %[[C0:.+]]  = arith.constant dense<0> : vector<3xi32>
+//  CHECK-NEXT:   %[[CMP:.+]] = arith.cmpi slt, %[[ARG]], %[[C0]] : vector<3xi32>
+//  CHECK-NEXT:   %[[EXT:.+]] = arith.extsi %[[CMP]] : vector<3xi1> to vector<3xi32>
+//  CHECK-NEXT:   return %[[ARG]], %[[EXT]] : vector<3xi32>, vector<3xi32>
+func.func @mulsiExtendedOneRhsSplat(%arg0: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) {
+  %one = arith.constant dense<1> : vector<3xi32>
+  %low, %high = arith.mulsi_extended %arg0, %one: vector<3xi32>
+  return %low, %high : vector<3xi32>, vector<3xi32>
+}
+
 // CHECK-LABEL: @mulsiExtendedUnusedHigh
 //  CHECK-SAME:   (%[[ARG:.+]]: i32) -> i32
 //  CHECK-NEXT:   %[[RES:.+]] = arith.muli %[[ARG]], %[[ARG]] : i32
@@ -1916,3 +1940,148 @@ func.func @andand3(%a : i32, %b : i32) -> i32 {
   %res = arith.andi %c, %b : i32
   return %res : i32
 }
+
+// -----
+
+// The shift amount (32) equals the number of bits dropped by the i64 -> i32
+// truncation, so shrsi is expected to be rewritten to shrui.
+// CHECK-LABEL: @truncIShrSIToTrunciShrUI
+//  CHECK-SAME:   (%[[A:.+]]: i64)
+//  CHECK-NEXT:   %[[C32:.+]] = arith.constant 32 : i64
+//  CHECK-NEXT:   %[[SHR:.+]] = arith.shrui %[[A]], %[[C32]] : i64
+//  CHECK-NEXT:   %[[TRU:.+]] = arith.trunci %[[SHR]] : i64 to i32
+//  CHECK-NEXT:   return %[[TRU]] : i32
+func.func @truncIShrSIToTrunciShrUI(%a: i64) -> i32 {
+  %c32 = arith.constant 32: i64
+  %sh = arith.shrsi %a, %c32 : i64
+  %hi = arith.trunci %sh: i64 to i32
+  return %hi : i32
+}
+
+// Negative test: the shift amount (33) is larger than the 32 bits dropped by
+// the truncation, so shrsi must be kept.
+// CHECK-LABEL: @truncIShrSIToTrunciShrUIBadShiftAmt1
+//       CHECK:   arith.shrsi
+func.func @truncIShrSIToTrunciShrUIBadShiftAmt1(%a: i64) -> i32 {
+  %c33 = arith.constant 33: i64
+  %sh = arith.shrsi %a, %c33 : i64
+  %hi = arith.trunci %sh: i64 to i32
+  return %hi : i32
+}
+
+// Negative test: the shift amount (31) is smaller than the 32 bits dropped by
+// the truncation, so shrsi must be kept.
+// CHECK-LABEL: @truncIShrSIToTrunciShrUIBadShiftAmt2
+//  CHECK:        arith.shrsi
+func.func @truncIShrSIToTrunciShrUIBadShiftAmt2(%a: i64) -> i32 {
+  %c31 = arith.constant 31: i64
+  %sh = arith.shrsi %a, %c31 : i64
+  %hi = arith.trunci %sh: i64 to i32
+  return %hi : i32
+}
+
+// A 64-bit multiply of sign-extended i32 operands, shifted right by 32 and
+// truncated to i32, is expected to become the high result of mulsi_extended.
+// CHECK-LABEL: @wideMulToMulSIExtended
+//  CHECK-SAME:   (%[[A:.+]]: i32, %[[B:.+]]: i32)
+//  CHECK-NEXT:   %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : i32
+//  CHECK-NEXT:   return %[[HIGH]] : i32
+func.func @wideMulToMulSIExtended(%a: i32, %b: i32) -> i32 {
+  %x = arith.extsi %a: i32 to i64
+  %y = arith.extsi %b: i32 to i64
+  %m = arith.muli %x, %y: i64
+  %c32 = arith.constant 32: i64
+  %sh = arith.shrui %m, %c32 : i64
+  %hi = arith.trunci %sh: i64 to i32
+  return %hi : i32
+}
+
+// Vector variant of the signed rewrite; the shift amount is a splat constant.
+// CHECK-LABEL: @wideMulToMulSIExtendedVector
+//  CHECK-SAME:   (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
+//  CHECK-NEXT:   %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : vector<3xi32>
+//  CHECK-NEXT:   return %[[HIGH]] : vector<3xi32>
+func.func @wideMulToMulSIExtendedVector(%a: vector<3xi32>, %b: vector<3xi32>) -> vector<3xi32> {
+  %x = arith.extsi %a: vector<3xi32> to vector<3xi64>
+  %y = arith.extsi %b: vector<3xi32> to vector<3xi64>
+  %m = arith.muli %x, %y: vector<3xi64>
+  %c32 = arith.constant dense<32>: vector<3xi64>
+  %sh = arith.shrui %m, %c32 : vector<3xi64>
+  %hi = arith.trunci %sh: vector<3xi64> to vector<3xi32>
+  return %hi : vector<3xi32>
+}
+
+// A 64-bit multiply of zero-extended i32 operands, shifted right by 32 and
+// truncated to i32, is expected to become the high result of mului_extended.
+// CHECK-LABEL: @wideMulToMulUIExtended
+//  CHECK-SAME:   (%[[A:.+]]: i32, %[[B:.+]]: i32)
+//  CHECK-NEXT:   %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : i32
+//  CHECK-NEXT:   return %[[HIGH]] : i32
+func.func @wideMulToMulUIExtended(%a: i32, %b: i32) -> i32 {
+  %x = arith.extui %a: i32 to i64
+  %y = arith.extui %b: i32 to i64
+  %m = arith.muli %x, %y: i64
+  %c32 = arith.constant 32: i64
+  %sh = arith.shrui %m, %c32 : i64
+  %hi = arith.trunci %sh: i64 to i32
+  return %hi : i32
+}
+
+// Vector variant of the unsigned rewrite; the shift amount is a splat
+// constant.
+// CHECK-LABEL: @wideMulToMulUIExtendedVector
+//  CHECK-SAME:   (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
+//  CHECK-NEXT:   %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : vector<3xi32>
+//  CHECK-NEXT:   return %[[HIGH]] : vector<3xi32>
+func.func @wideMulToMulUIExtendedVector(%a: vector<3xi32>, %b: vector<3xi32>) -> vector<3xi32> {
+  %x = arith.extui %a: vector<3xi32> to vector<3xi64>
+  %y = arith.extui %b: vector<3xi32> to vector<3xi64>
+  %m = arith.muli %x, %y: vector<3xi64>
+  %c32 = arith.constant dense<32>: vector<3xi64>
+  %sh = arith.shrui %m, %c32 : vector<3xi64>
+  %hi = arith.trunci %sh: vector<3xi64> to vector<3xi32>
+  return %hi : vector<3xi32>
+}
+
+// Negative test: one operand is sign-extended and the other zero-extended, so
+// neither extended-multiplication rewrite applies and the ops must remain.
+// CHECK-LABEL: @wideMulToMulIExtendedMixedExt
+//       CHECK:   arith.muli
+//       CHECK:   arith.shrui
+//       CHECK:   arith.trunci
+func.func @wideMulToMulIExtendedMixedExt(%a: i32, %b: i32) -> i32 {
+  %x = arith.extsi %a: i32 to i64
+  %y = arith.extui %b: i32 to i64
+  %m = arith.muli %x, %y: i64
+  %c32 = arith.constant 32: i64
+  %sh = arith.shrui %m, %c32 : i64
+  %hi = arith.trunci %sh: i64 to i32
+  return %hi : i32
+}
+
+// Negative test: the operands are i16 but the truncated result is i32, so the
+// operand and result types differ and the ops must remain.
+// CHECK-LABEL: @wideMulToMulSIExtendedBadExt
+//       CHECK:   arith.muli
+//       CHECK:   arith.shrui
+//       CHECK:   arith.trunci
+func.func @wideMulToMulSIExtendedBadExt(%a: i16, %b: i16) -> i32 {
+  %x = arith.extsi %a: i16 to i64
+  %y = arith.extsi %b: i16 to i64
+  %m = arith.muli %x, %y: i64
+  %c32 = arith.constant 32: i64
+  %sh = arith.shrui %m, %c32 : i64
+  %hi = arith.trunci %sh: i64 to i32
+  return %hi : i32
+}
+
+// Negative test: the shift amount (33) does not match the 32-bit width
+// difference between the multiplication and its operands, so the ops must
+// remain.
+// CHECK-LABEL: @wideMulToMulSIExtendedBadShift1
+//       CHECK:   arith.muli
+//       CHECK:   arith.shrui
+//       CHECK:   arith.trunci
+func.func @wideMulToMulSIExtendedBadShift1(%a: i32, %b: i32) -> i32 {
+  %x = arith.extsi %a: i32 to i64
+  %y = arith.extsi %b: i32 to i64
+  %m = arith.muli %x, %y: i64
+  %c33 = arith.constant 33: i64
+  %sh = arith.shrui %m, %c33 : i64
+  %hi = arith.trunci %sh: i64 to i32
+  return %hi : i32
+}
+
+// Negative test: the shift amount (31) does not match the 32-bit width
+// difference between the multiplication and its operands, so the ops must
+// remain.
+// CHECK-LABEL: @wideMulToMulSIExtendedBadShift2
+//       CHECK:   arith.muli
+//       CHECK:   arith.shrui
+//       CHECK:   arith.trunci
+func.func @wideMulToMulSIExtendedBadShift2(%a: i32, %b: i32) -> i32 {
+  %x = arith.extsi %a: i32 to i64
+  %y = arith.extsi %b: i32 to i64
+  %m = arith.muli %x, %y: i64
+  %c31 = arith.constant 31: i64
+  %sh = arith.shrui %m, %c31 : i64
+  %hi = arith.trunci %sh: i64 to i32
+  return %hi : i32
+}


        


More information about the Mlir-commits mailing list