[clang] [mlir][arith] Fix canon pattern for large ints in chained arith (PR #68900)
Rik Huijzer via cfe-commits
cfe-commits at lists.llvm.org
Fri Oct 13 02:48:40 PDT 2023
https://github.com/rikhuijzer updated https://github.com/llvm/llvm-project/pull/68900
>From ddbde18e483d12485ba25c715e8a94480b9d6dcf Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Thu, 12 Oct 2023 16:55:22 +0200
Subject: [PATCH 1/4] [mlir][arith] Fix canon pattern for large ints in chained
arith
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 25 +++++++++++++++--------
mlir/test/Dialect/Arith/canonicalize.mlir | 10 +++++++++
2 files changed, 27 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 0ecc288f3b07701..25578b1c52f331b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -39,26 +39,35 @@ using namespace mlir::arith;
static IntegerAttr
applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
Attribute rhs,
- function_ref<int64_t(int64_t, int64_t)> binFn) {
- return builder.getIntegerAttr(res.getType(),
- binFn(llvm::cast<IntegerAttr>(lhs).getInt(),
- llvm::cast<IntegerAttr>(rhs).getInt()));
+ function_ref<APInt(APInt, APInt&)> binFn) {
+ auto lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
+ auto rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
+ auto value = binFn(lhsVal, rhsVal);
+ return IntegerAttr::get(res.getType(), value);
}
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
- return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<int64_t>());
+ auto binFn = [](APInt a, APInt& b) -> APInt {
+ return std::move(a) + b;
+ };
+ return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
}
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
- return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<int64_t>());
+ auto binFn = [](APInt a, APInt& b) -> APInt {
+ return std::move(a) - b;
+ };
+ return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
}
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
- return applyToIntegerAttrs(builder, res, lhs, rhs,
- std::multiplies<int64_t>());
+ auto binFn = [](APInt a, APInt& b) -> APInt {
+ return std::move(a) * b;
+ };
+ return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
}
/// Invert an integer comparison predicate.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1b0547c9e8f804a..b18f5cfcb3f9a12 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -985,6 +985,16 @@ func.func @tripleMulIMulII32(%arg0: i32) -> i32 {
return %mul2 : i32
}
+// CHECK-LABEL: @tripleMulLargeInt
+// CHECK: return
+func.func @tripleMulLargeInt(%arg0: i256) -> i256 {
+ %0 = arith.constant 3618502788666131213697322783095070105623107215331596699973092056135872020481 : i256
+ %c5 = arith.constant 5 : i256
+ %mul1 = arith.muli %arg0, %0 : i256
+ %mul2 = arith.muli %mul1, %c5 : i256
+ return %mul2 : i256
+}
+
// CHECK-LABEL: @addiMuliToSubiRhsI32
// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32
>From c0f3efe78fa6e71d1acc4d38f526ca2ec194ddf8 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Fri, 13 Oct 2023 10:14:16 +0200
Subject: [PATCH 2/4] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 25578b1c52f331b..b749a4444f256e7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -39,7 +39,7 @@ using namespace mlir::arith;
static IntegerAttr
applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
Attribute rhs,
- function_ref<APInt(APInt, APInt&)> binFn) {
+ function_ref<APInt(const APInt&, const APInt&)> binFn) {
auto lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
auto rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
auto value = binFn(lhsVal, rhsVal);
@@ -49,7 +49,7 @@ applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
auto binFn = [](APInt a, APInt& b) -> APInt {
- return std::move(a) + b;
+ return a + b;
};
return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
}
>From 30e1ce11d567452dcd7481e999109d1f25164065 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Fri, 13 Oct 2023 10:49:20 +0200
Subject: [PATCH 3/4] Use `const`s and check result of fold
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 20 +++++++-------------
mlir/test/Dialect/Arith/canonicalize.mlir | 12 +++++++-----
2 files changed, 14 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index b749a4444f256e7..5fe7a256cce07d1 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -39,34 +39,28 @@ using namespace mlir::arith;
static IntegerAttr
applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
Attribute rhs,
- function_ref<APInt(const APInt&, const APInt&)> binFn) {
- auto lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
- auto rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
- auto value = binFn(lhsVal, rhsVal);
+ function_ref<APInt(const APInt &, const APInt &)> binFn) {
+ APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
+ APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
+ APInt value = binFn(lhsVal, rhsVal);
return IntegerAttr::get(res.getType(), value);
}
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
- auto binFn = [](APInt a, APInt& b) -> APInt {
- return a + b;
- };
+ auto binFn = [](const APInt &a, const APInt &b) -> APInt { return a + b; };
return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
}
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
- auto binFn = [](APInt a, APInt& b) -> APInt {
- return std::move(a) - b;
- };
+ auto binFn = [](const APInt &a, const APInt &b) -> APInt { return a - b; };
return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
}
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
- auto binFn = [](APInt a, APInt& b) -> APInt {
- return std::move(a) * b;
- };
+ auto binFn = [](const APInt &a, const APInt &b) -> APInt { return a * b; };
return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index b18f5cfcb3f9a12..98788536980f939 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -986,13 +986,15 @@ func.func @tripleMulIMulII32(%arg0: i32) -> i32 {
}
// CHECK-LABEL: @tripleMulLargeInt
-// CHECK: return
+// CHECK: %[[cres:.+]] = arith.constant 3618502788666131213697322783095070105623107215331596699973092056135872020482 : i256
+// CHECK: %[[addi:.+]] = arith.addi %arg0, %[[cres]] : i256
+// CHECK: return %[[addi]]
func.func @tripleMulLargeInt(%arg0: i256) -> i256 {
%0 = arith.constant 3618502788666131213697322783095070105623107215331596699973092056135872020481 : i256
- %c5 = arith.constant 5 : i256
- %mul1 = arith.muli %arg0, %0 : i256
- %mul2 = arith.muli %mul1, %c5 : i256
- return %mul2 : i256
+ %1 = arith.constant 1 : i256
+ %2 = arith.addi %arg0, %0 : i256
+ %3 = arith.addi %2, %1 : i256
+ return %3 : i256
}
// CHECK-LABEL: @addiMuliToSubiRhsI32
>From 1ef723f17639ba473830a2e84f53eed76b2eb4e3 Mon Sep 17 00:00:00 2001
From: Rik Huijzer <github at huijzer.xyz>
Date: Fri, 13 Oct 2023 11:48:21 +0200
Subject: [PATCH 4/4] Use arith functions from `std` instead of lambda
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 9 +++------
1 file changed, 3 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 933fbd6932b0e5a..3892e8fa0a32f2d 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -48,20 +48,17 @@ applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
- auto binFn = [](const APInt &a, const APInt &b) -> APInt { return a + b; };
- return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
+ return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
}
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
- auto binFn = [](const APInt &a, const APInt &b) -> APInt { return a - b; };
- return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
+ return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
}
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
Attribute lhs, Attribute rhs) {
- auto binFn = [](const APInt &a, const APInt &b) -> APInt { return a * b; };
- return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
+ return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
}
/// Invert an integer comparison predicate.
More information about the cfe-commits
mailing list