[Mlir-commits] [mlir] eaa4bc6 - [mlir][arith] Add canon pattern for chained `arith.muli`
Jakub Kuderski
llvmlistbot at llvm.org
Fri Jul 21 15:20:40 PDT 2023
Author: Jakub Kuderski
Date: 2023-07-21T18:20:31-04:00
New Revision: eaa4bc655709520f752e81890e5154775e66d539
URL: https://github.com/llvm/llvm-project/commit/eaa4bc655709520f752e81890e5154775e66d539
DIFF: https://github.com/llvm/llvm-project/commit/eaa4bc655709520f752e81890e5154775e66d539.diff
LOG: [mlir][arith] Add canon pattern for chained `arith.muli`
@benvanik reported this as missing.
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D155907
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 4da00415abaa86..2ffd49c5034e69 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -285,6 +285,7 @@ def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
let summary = "integer multiplication operation";
let hasFolder = 1;
+ // Requires MulIOp::getCanonicalizationPatterns, implemented in ArithOps.cpp.
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index ba1f3f8bd1d86b..f3d84d0b261e8d 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -18,9 +18,12 @@ def GetZeroAttr : NativeCodeCall<"$_builder.getZeroAttr($0.getType())">;
// Add two integer attributes and create a new one with the result.
def AddIntAttrs : NativeCodeCall<"addIntegerAttrs($_builder, $0, $1, $2)">;
-// Subtract two integer attributes and createa a new one with the result.
+// Subtract two integer attributes and create a new one with the result.
def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
+// Multiply two integer attributes and create a new one with the result.
+// $0 is the value whose type the new attribute takes; $1 and $2 are the
+// IntegerAttr operands (see mulIntegerAttrs in ArithOps.cpp).
+def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
+
class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
//===----------------------------------------------------------------------===//
@@ -72,6 +75,13 @@ def AddIMulNegativeOneLhs :
(Arith_SubIOp $y, $x),
[(IsScalarOrSplatNegativeOne $c0)]>;
+// muli(muli(x, c0), c1) -> muli(x, c0 * c1)
+// Both constants must be the right-hand operands of their muli; the folded
+// constant takes the type of the outer muli ($res).
+// NOTE(review): constants on the LHS are not matched here; presumably the
+// Commutative trait's canonicalization moves them to the RHS first — confirm.
+def MulIMulIConstant :
+ Pat<(Arith_MulIOp:$res
+ (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+ (ConstantLikeMatcher APIntAttr:$c1)),
+ (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)))>;
+
//===----------------------------------------------------------------------===//
// AddUIExtendedOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 219804b005027b..1c41818c318d24 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -8,6 +8,7 @@
#include <cassert>
#include <cstdint>
+#include <functional>
#include <utility>
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -34,18 +35,29 @@ using namespace mlir::arith;
// Pattern helpers
//===----------------------------------------------------------------------===//
+/// Combines two integer constant attributes with `binFn` and returns the
+/// result as an IntegerAttr of `res`'s type.
+///
+/// The computation is done on APInt rather than int64_t. int64_t arithmetic
+/// has undefined behavior on signed overflow (e.g. folding two i64 constants
+/// 2^32 * 2^32), and IntegerAttr::getInt() asserts for values that do not fit
+/// in 64 bits (e.g. large i128 constants). APInt arithmetic wraps modulo
+/// 2^bitwidth, consistent with MulIOp::fold, which also multiplies APInts.
+/// Callers must pass attributes whose bitwidth matches `res`'s type, since
+/// APInt operations assert on mismatched widths.
+static IntegerAttr
+applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
+ Attribute rhs,
+ function_ref<APInt(const APInt &, const APInt &)> binFn) {
+ APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
+ APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
+ return builder.getIntegerAttr(res.getType(), binFn(lhsVal, rhsVal));
+}
+
+/// Returns `lhs + rhs` (wrapping) as an IntegerAttr of `res`'s type.
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
- return builder.getIntegerAttr(res.getType(),
- llvm::cast<IntegerAttr>(lhs).getInt() +
- llvm::cast<IntegerAttr>(rhs).getInt());
+ return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
}
+/// Returns `lhs - rhs` (wrapping) as an IntegerAttr of `res`'s type.
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
- return builder.getIntegerAttr(res.getType(),
- llvm::cast<IntegerAttr>(lhs).getInt() -
- llvm::cast<IntegerAttr>(rhs).getInt());
+ return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
+}
+
+/// Returns `lhs * rhs` (wrapping) as an IntegerAttr of `res`'s type.
+static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
+ Attribute lhs, Attribute rhs) {
+ return applyToIntegerAttrs(builder, res, lhs, rhs,
+ std::multiplies<APInt>());
}
/// Invert an integer comparison predicate.
@@ -382,6 +394,11 @@ OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
[](const APInt &a, const APInt &b) { return a * b; });
}
+/// Registers the declarative canonicalization patterns for arith.muli
+/// (currently MulIMulIConstant, defined in ArithCanonicalization.td).
+void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context) {
+ patterns.add<MulIMulIConstant>(context);
+}
+
//===----------------------------------------------------------------------===//
// MulSIExtendedOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index dd32e5b664f57c..5b392fe9cf58a0 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -885,6 +885,30 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
return %add : index
}
+// Two chained muli by constants (3, then 5) fold into one muli by 15.
+// CHECK-LABEL: @tripleMulIMulIIndex
+// CHECK: %[[cres:.+]] = arith.constant 15 : index
+// CHECK: %[[muli:.+]] = arith.muli %arg0, %[[cres]] : index
+// CHECK: return %[[muli]]
+func.func @tripleMulIMulIIndex(%arg0: index) -> index {
+ %c3 = arith.constant 3 : index
+ %c5 = arith.constant 5 : index
+ %mul1 = arith.muli %arg0, %c3 : index
+ %mul2 = arith.muli %mul1, %c5 : index
+ return %mul2 : index
+}
+
+// Same fold on i32 with a negative constant: -3 * 7 = -21.
+// CHECK-LABEL: @tripleMulIMulII32
+// CHECK: %[[cres:.+]] = arith.constant -21 : i32
+// CHECK: %[[muli:.+]] = arith.muli %arg0, %[[cres]] : i32
+// CHECK: return %[[muli]]
+func.func @tripleMulIMulII32(%arg0: i32) -> i32 {
+ %c_n3 = arith.constant -3 : i32
+ %c7 = arith.constant 7 : i32
+ %mul1 = arith.muli %arg0, %c_n3 : i32
+ %mul2 = arith.muli %mul1, %c7 : i32
+ return %mul2 : i32
+}
+
// CHECK-LABEL: @addiMuliToSubiRhsI32
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32
More information about the Mlir-commits mailing list