[Mlir-commits] [mlir] [mlir][arith] Fix canon pattern for large ints in chained arith (PR #68900)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 12 08:13:41 PDT 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Rik Huijzer (rikhuijzer)

<details>
<summary>Changes</summary>

The canonicalization logic for chained basic arithmetic operations in the `arith` dialect called `getInt()` on `IntegerAttr`. `getInt()` returns an `int64_t`, so it asserts when a constant needs more than 64 bits. Specifically, in https://github.com/llvm/llvm-project/issues/64774 the following assertion failed:

```
Assertion failed: (getSignificantBits() <= 64 && "Too many bits for int64_t"), function getSExtValue, file APInt.h, line 1510.
```

According to a comment on `getInt()`, calls to it should be replaced by `getValue()`, which returns the full-width `APInt`:

https://github.com/llvm/llvm-project/blob/ab6a66dbec61654d0962f6abf6d6c5b776937584/mlir/include/mlir/IR/BuiltinAttributes.td#L707-L708

This patch fixes https://github.com/llvm/llvm-project/issues/64774 by making that replacement in `applyToIntegerAttrs` and in its add, sub, and mul wrappers.
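
For readers unfamiliar with the failure mode, the following is a minimal, standalone sketch of why narrowing a wide constant to `int64_t` is unsafe while full-width arithmetic is not. It does not use LLVM: `Wide`, `fitsInInt64`, `narrowChecked`, and `addWrapping` are hypothetical names invented for illustration, and the 256-bit limb layout is an assumption chosen to match the `i256` test case, not how `APInt` is implemented.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for a 256-bit two's-complement integer, stored as
// four 64-bit limbs with the least significant limb first. Not llvm::APInt.
using Wide = std::array<std::uint64_t, 4>;

// True if the value survives a round trip through int64_t, i.e. every upper
// limb is just the sign extension of the lowest limb.
bool fitsInInt64(const Wide &v) {
  const bool negative = (v[0] >> 63) != 0;
  const std::uint64_t fill = negative ? ~std::uint64_t{0} : std::uint64_t{0};
  for (std::size_t i = 1; i < v.size(); ++i)
    if (v[i] != fill)
      return false;
  return true;
}

// Mirrors the shape of the failing check: narrowing asserts when the value
// needs more than 64 bits. (The unsigned-to-signed cast is well defined as
// two's complement since C++20 and behaves that way on common compilers.)
std::int64_t narrowChecked(const Wide &v) {
  assert(fitsInInt64(v) && "Too many bits for int64_t");
  return static_cast<std::int64_t>(v[0]);
}

// Full-width addition modulo 2^256: no narrowing, so it is safe for any
// operand. The carry out of the top limb is discarded.
Wide addWrapping(const Wide &a, const Wide &b) {
  Wide out{};
  std::uint64_t carry = 0;
  for (std::size_t i = 0; i < a.size(); ++i) {
    std::uint64_t sum = a[i] + b[i];
    std::uint64_t carryOut = (sum < a[i]) ? 1 : 0;
    sum += carry;
    if (sum < carry)
      carryOut = 1;
    out[i] = sum;
    carry = carryOut;
  }
  return out;
}
```

In this sketch, a folding helper written against `narrowChecked` would abort on any operand with a nonzero upper limb, which is analogous to the old `getInt()` path. One written against `addWrapping` works at the full width, which is the approach the patch takes by operating on `APInt` values directly.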


---
Full diff: https://github.com/llvm/llvm-project/pull/68900.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+17-8) 
- (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+10) 


``````````diff
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 0ecc288f3b07701..25578b1c52f331b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -39,26 +39,35 @@ using namespace mlir::arith;
 static IntegerAttr
 applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
                     Attribute rhs,
-                    function_ref<int64_t(int64_t, int64_t)> binFn) {
-  return builder.getIntegerAttr(res.getType(),
-                                binFn(llvm::cast<IntegerAttr>(lhs).getInt(),
-                                      llvm::cast<IntegerAttr>(rhs).getInt()));
+                    function_ref<APInt(APInt, APInt&)> binFn) {
+  auto lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
+  auto rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
+  auto value = binFn(lhsVal, rhsVal);
+  return IntegerAttr::get(res.getType(), value);
 }
 
 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
                                    Attribute lhs, Attribute rhs) {
-  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<int64_t>());
+  auto binFn = [](APInt a, APInt& b) -> APInt {
+    return std::move(a) + b;
+  };
+  return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
 }
 
 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
                                    Attribute lhs, Attribute rhs) {
-  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<int64_t>());
+  auto binFn = [](APInt a, APInt& b) -> APInt {
+    return std::move(a) - b;
+  };
+  return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
 }
 
 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
                                    Attribute lhs, Attribute rhs) {
-  return applyToIntegerAttrs(builder, res, lhs, rhs,
-                             std::multiplies<int64_t>());
+  auto binFn = [](APInt a, APInt& b) -> APInt {
+    return std::move(a) * b;
+  };
+  return applyToIntegerAttrs(builder, res, lhs, rhs, binFn);
 }
 
 /// Invert an integer comparison predicate.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1b0547c9e8f804a..b18f5cfcb3f9a12 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -985,6 +985,16 @@ func.func @tripleMulIMulII32(%arg0: i32) -> i32 {
   return %mul2 : i32
 }
 
+// CHECK-LABEL: @tripleMulLargeInt
+//       CHECK:   return
+func.func @tripleMulLargeInt(%arg0: i256) -> i256 {
+  %0 = arith.constant 3618502788666131213697322783095070105623107215331596699973092056135872020481 : i256
+  %c5 = arith.constant 5 : i256
+  %mul1 = arith.muli %arg0, %0 : i256
+  %mul2 = arith.muli %mul1, %c5 : i256
+  return %mul2 : i256
+}
+
 // CHECK-LABEL: @addiMuliToSubiRhsI32
 //  CHECK-SAME:   (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
 //       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32

``````````

</details>


https://github.com/llvm/llvm-project/pull/68900

