[Mlir-commits] [mlir] [mlir][arith] Fix canon pattern for large ints in chained arith (PR #68900)
llvmlistbot at llvm.org
Thu Oct 12 08:13:41 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Rik Huijzer (rikhuijzer)
<details>
<summary>Changes</summary>
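The logic for chained basic arithmetic operations in the `arith` dialect (for example, folding `(x * c1) * c2` into `x * (c1 * c2)`) used `getInt()` on `IntegerAttr`. `getInt()` returns an `int64_t` via `APInt::getSExtValue()`, so it fails for integer constants that do not fit in 64 bits. Specifically, in https://github.com/llvm/llvm-project/issues/64774 the following assertion failed: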
```
Assertion failed: (getSignificantBits() <= 64 && "Too many bits for int64_t"), function getSExtValue, file APInt.h, line 1510.
```
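As a rough illustration of why this assertion fires, the stdlib-only C++ sketch below mimics the condition checked by `APInt::getSExtValue()`. A value's significant bits are its bit width minus its leading sign bits plus one, and anything above 64 cannot be returned as an `int64_t`. `U256`, `parseDecimal`, `significantBits`, and `fitsInInt64` are hypothetical helpers written for this illustration, not LLVM APIs.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical stand-in for a 256-bit two's-complement integer: eight 32-bit
// limbs, least significant first, each stored in a uint64_t so that
// limb * 10 + carry cannot overflow.
using U256 = std::array<uint64_t, 8>;

// Parses a non-negative decimal literal into 256 bits (wrapping modulo 2^256).
U256 parseDecimal(const std::string &digits) {
  U256 v{};
  for (char c : digits) {
    assert(c >= '0' && c <= '9' && "expected a decimal digit");
    uint64_t carry = static_cast<uint64_t>(c - '0');
    for (uint64_t &limb : v) {
      const uint64_t t = limb * 10 + carry;
      limb = t & 0xFFFFFFFFu;
      carry = t >> 32;
    }
  }
  return v;
}

// Follows the formula used by APInt::getSignificantBits():
// BitWidth - (number of leading sign bits) + 1.
unsigned significantBits(const U256 &v) {
  const unsigned width = 256;
  const bool negative = ((v[7] >> 31) & 1u) != 0;
  unsigned signBits = 0;
  for (int bit = static_cast<int>(width) - 1; bit >= 0; --bit) {
    const bool set = ((v[bit / 32] >> (bit % 32)) & 1u) != 0;
    if (set != negative)
      break;
    ++signBits;
  }
  return width - signBits + 1;
}

// The condition getSExtValue() asserts on before returning an int64_t.
bool fitsInInt64(const U256 &v) { return significantBits(v) <= 64; }
```

Applied to the `i256` constant in the new `@tripleMulLargeInt` test, `significantBits` returns 253, far past the 64-bit limit, which is why `getInt()` cannot be used for it.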
According to a comment on `getInt()`, calls to `getInt()` should be replaced by `getValue()`, which returns the full-width `APInt`:
https://github.com/llvm/llvm-project/blob/ab6a66dbec61654d0962f6abf6d6c5b776937584/mlir/include/mlir/IR/BuiltinAttributes.td#L707-L708
This patch fixes https://github.com/llvm/llvm-project/issues/64774 by making that replacement and performing the addition, subtraction, and multiplication on `APInt` values instead of `int64_t`.
---
Full diff: https://github.com/llvm/llvm-project/pull/68900.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+17-8)
- (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+10)
``````````diff
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 0ecc288f3b07701..25578b1c52f331b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -39,26 +39,35 @@ using namespace mlir::arith;
 static IntegerAttr
 applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
                     Attribute rhs,
-                    function_ref<int64_t(int64_t, int64_t)> binFn) {
-  return builder.getIntegerAttr(res.getType(),
-                                binFn(llvm::cast<IntegerAttr>(lhs).getInt(),
-                                      llvm::cast<IntegerAttr>(rhs).getInt()));
+                    function_ref<APInt(APInt, APInt&)> binFn) {
+  auto lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
+  auto rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
+  auto value = binFn(lhsVal, rhsVal);
+  return IntegerAttr::get(res.getType(), value);
 }
 
 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
                                    Attribute lhs, Attribute rhs) {
-  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<int64_t>());
+  auto binFn = [](APInt a, APInt& b) -> APInt {
+    return std::move(a) + b;
+  };
+  return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
 }
 
 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
                                    Attribute lhs, Attribute rhs) {
-  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<int64_t>());
+  auto binFn = [](APInt a, APInt& b) -> APInt {
+    return std::move(a) - b;
+  };
+  return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
 }
 
 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
                                    Attribute lhs, Attribute rhs) {
-  return applyToIntegerAttrs(builder, res, lhs, rhs,
-                             std::multiplies<int64_t>());
+  auto binFn = [](APInt a, APInt& b) -> APInt {
+    return std::move(a) * b;
+  };
+  return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
 }
 
 /// Invert an integer comparison predicate.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1b0547c9e8f804a..b18f5cfcb3f9a12 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -985,6 +985,16 @@ func.func @tripleMulIMulII32(%arg0: i32) -> i32 {
   return %mul2 : i32
 }
 
+// CHECK-LABEL: @tripleMulLargeInt
+// CHECK: return
+func.func @tripleMulLargeInt(%arg0: i256) -> i256 {
+  %0 = arith.constant 3618502788666131213697322783095070105623107215331596699973092056135872020481 : i256
+  %c5 = arith.constant 5 : i256
+  %mul1 = arith.muli %arg0, %0 : i256
+  %mul2 = arith.muli %mul1, %c5 : i256
+  return %mul2 : i256
+}
+
 // CHECK-LABEL: @addiMuliToSubiRhsI32
 // CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
 // CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32
``````````
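The patched helpers compute the folded constant with `APInt` arithmetic, which wraps modulo 2^N at the operands' bit width instead of being capped at 64 bits. The stdlib-only sketch below models that idea for `i256`, assuming eight 32-bit limbs in place of `APInt`, and applies it to the `(x * c1) * c2 -> x * (c1 * c2)` rewrite that `@tripleMulLargeInt` exercises. `U256`, `fromU64`, `mulWrap`, `foldChainedMul`, and `checkFold` are hypothetical names for illustration, not MLIR or LLVM APIs.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical fixed-width stand-in for an i256 APInt: eight 32-bit limbs,
// least significant first, each held in a uint64_t.
using U256 = std::array<uint64_t, 8>;

U256 fromU64(uint64_t x) {
  U256 v{};
  v[0] = x & 0xFFFFFFFFu;
  v[1] = x >> 32;
  return v;
}

// Schoolbook multiplication truncated to 256 bits, i.e. modulo 2^256: like
// APInt's operator*, overflow wraps instead of widening the result.
U256 mulWrap(const U256 &a, const U256 &b) {
  U256 r{};
  for (std::size_t i = 0; i < r.size(); ++i) {
    uint64_t carry = 0;
    // Partial products that would land at limb index >= 8 are dropped (the wraparound).
    for (std::size_t j = 0; i + j < r.size(); ++j) {
      // r < 2^32, a * b <= (2^32 - 1)^2 and carry < 2^32, so t fits in 64 bits.
      const uint64_t t = r[i + j] + a[i] * b[j] + carry;
      r[i + j] = t & 0xFFFFFFFFu;
      carry = t >> 32;
    }
  }
  return r;
}

// Hypothetical model of the rewrite (x * c1) * c2 -> x * (c1 * c2): the two
// constants are folded at full width rather than through int64_t.
U256 foldChainedMul(const U256 &c1, const U256 &c2) { return mulWrap(c1, c2); }

// Checks on one input that the folded form matches step-by-step evaluation.
void checkFold(const U256 &x, const U256 &c1, const U256 &c2) {
  assert(mulWrap(mulWrap(x, c1), c2) == mulWrap(x, foldChainedMul(c1, c2)));
}
```

Because multiplication modulo 2^256 is associative, folding the constants first agrees with evaluating the chain one step at a time, which is what `checkFold` asserts for a given input. For the test's constant, 2^251 + 17 * 2^192 + 1, multiplying by 5 still fits in 256 bits, whereas multiplying by 32 wraps around and drops the 2^256 term.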
</details>
https://github.com/llvm/llvm-project/pull/68900