[Mlir-commits] [mlir] eaa4bc6 - [mlir][arith] Add canon pattern for chained `arith.muli`

Jakub Kuderski llvmlistbot at llvm.org
Fri Jul 21 15:20:40 PDT 2023


Author: Jakub Kuderski
Date: 2023-07-21T18:20:31-04:00
New Revision: eaa4bc655709520f752e81890e5154775e66d539

URL: https://github.com/llvm/llvm-project/commit/eaa4bc655709520f752e81890e5154775e66d539
DIFF: https://github.com/llvm/llvm-project/commit/eaa4bc655709520f752e81890e5154775e66d539.diff

LOG: [mlir][arith] Add canon pattern for chained `arith.muli`

@benvanik reported that this canonicalization (folding the constants of chained `arith.muli` ops) was missing.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D155907

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 4da00415abaa86..2ffd49c5034e69 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -285,6 +285,7 @@ def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
 def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
   let summary = "integer multiplication operation";
   let hasFolder = 1;
+  let hasCanonicalizer = 1; // getCanonicalizationPatterns is in ArithOps.cpp.
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index ba1f3f8bd1d86b..f3d84d0b261e8d 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -18,9 +18,12 @@ def GetZeroAttr : NativeCodeCall<"$_builder.getZeroAttr($0.getType())">;
 // Add two integer attributes and create a new one with the result.
 def AddIntAttrs : NativeCodeCall<"addIntegerAttrs($_builder, $0, $1, $2)">;
 
-// Subtract two integer attributes and createa a new one with the result.
+// Subtract two integer attributes and create a new one with the result.
 def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
 
+// Multiply two integer attributes and create a new one with the result.
+def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
+
 class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 
 //===----------------------------------------------------------------------===//
@@ -72,6 +75,13 @@ def AddIMulNegativeOneLhs :
         (Arith_SubIOp $y, $x),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
+// muli(muli(x, c0), c1) -> muli(x, c0 * c1) for scalar IntegerAttr c0, c1.
+def MulIMulIConstant :
+    Pat<(Arith_MulIOp:$res
+          (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+          (ConstantLikeMatcher APIntAttr:$c1)),
+        (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)))>;
+
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 219804b005027b..1c41818c318d24 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -8,6 +8,7 @@
 
 #include <cassert>
 #include <cstdint>
+#include <functional> // std::plus, std::minus, std::multiplies
 #include <utility>
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -34,18 +35,29 @@ using namespace mlir::arith;
 // Pattern helpers
 //===----------------------------------------------------------------------===//
 
+/// Applies `binFn` to the APInt values of two IntegerAttrs. APInt arithmetic
+/// wraps at the operand bitwidth: no signed-overflow UB and no 64-bit limit.
+static IntegerAttr applyToIntegerAttrs(
+    PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs,
+    function_ref<APInt(const APInt &, const APInt &)> binFn) {
+  return builder.getIntegerAttr(res.getType(),
+                                binFn(llvm::cast<IntegerAttr>(lhs).getValue(),
+                                      llvm::cast<IntegerAttr>(rhs).getValue()));
+}
+
 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
                                    Attribute lhs, Attribute rhs) {
-  return builder.getIntegerAttr(res.getType(),
-                                llvm::cast<IntegerAttr>(lhs).getInt() +
-                                    llvm::cast<IntegerAttr>(rhs).getInt());
+  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
 }
 
 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
                                    Attribute lhs, Attribute rhs) {
-  return builder.getIntegerAttr(res.getType(),
-                                llvm::cast<IntegerAttr>(lhs).getInt() -
-                                    llvm::cast<IntegerAttr>(rhs).getInt());
+  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
+}
+
+static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
+                                   Attribute lhs, Attribute rhs) {
+  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
 }
 
 /// Invert an integer comparison predicate.
@@ -382,6 +394,11 @@ OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
       [](const APInt &a, const APInt &b) { return a * b; });
 }
 
+void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  patterns.add<MulIMulIConstant>(context); // Merges nested muli constants.
+}
+
 //===----------------------------------------------------------------------===//
 // MulSIExtendedOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index dd32e5b664f57c..5b392fe9cf58a0 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -885,6 +885,30 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
   return %add : index
 }
 
+// CHECK-LABEL: @tripleMulIMulIIndex
+//       CHECK:   %[[cres:.+]] = arith.constant 15 : index
+//       CHECK:   %[[muli:.+]] = arith.muli %arg0, %[[cres]] : index
+//       CHECK:   return %[[muli]]
+func.func @tripleMulIMulIIndex(%arg0: index) -> index {
+  %c3 = arith.constant 3 : index
+  %c5 = arith.constant 5 : index
+  %mul1 = arith.muli %arg0, %c3 : index
+  %mul2 = arith.muli %mul1, %c5 : index // Expect muli %arg0, 15 (3 * 5).
+  return %mul2 : index
+}
+
+// CHECK-LABEL: @tripleMulIMulII32
+//       CHECK:   %[[cres:.+]] = arith.constant -21 : i32
+//       CHECK:   %[[muli:.+]] = arith.muli %arg0, %[[cres]] : i32
+//       CHECK:   return %[[muli]]
+func.func @tripleMulIMulII32(%arg0: i32) -> i32 {
+  %c_n3 = arith.constant -3 : i32
+  %c7 = arith.constant 7 : i32
+  %mul1 = arith.muli %arg0, %c_n3 : i32
+  %mul2 = arith.muli %mul1, %c7 : i32 // Expect muli %arg0, -21 (-3 * 7).
+  return %mul2 : i32
+}
+
 // CHECK-LABEL: @addiMuliToSubiRhsI32
 //  CHECK-SAME:   (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
 //       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32


        


More information about the Mlir-commits mailing list