[Mlir-commits] [mlir] 7ef1754 - [mlir][arith] Fix canon pattern for large ints in chained arith (#68900)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 13 03:35:10 PDT 2023


Author: Rik Huijzer
Date: 2023-10-13T12:35:04+02:00
New Revision: 7ef1754301a88ea0cbcffae53c2027abad3cc357

URL: https://github.com/llvm/llvm-project/commit/7ef1754301a88ea0cbcffae53c2027abad3cc357
DIFF: https://github.com/llvm/llvm-project/commit/7ef1754301a88ea0cbcffae53c2027abad3cc357.diff

LOG: [mlir][arith] Fix canon pattern for large ints in chained arith (#68900)

The canonicalization patterns for chained basic arithmetic operations
(add, sub, mul) in the `arith` dialect folded constants using `getInt()`
on `IntegerAttr`. Because `getInt()` returns an `int64_t`, it fails for
constants that do not fit in 64 bits. Specifically, in
https://github.com/llvm/llvm-project/issues/64774 the following
assertion failed:

```
Assertion failed: (getSignificantBits() <= 64 && "Too many bits for int64_t"), function getSExtValue, file APInt.h, line 1510.
```

A comment on `getInt()` says that callers should be migrated to
`getValue()`, which returns an arbitrary-width `APInt`:

https://github.com/llvm/llvm-project/blob/ab6a66dbec61654d0962f6abf6d6c5b776937584/mlir/include/mlir/IR/BuiltinAttributes.td#L707-L708

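This patch fixes https://github.com/llvm/llvm-project/issues/64774 by
making that replacement, so the constant arithmetic is performed on
`APInt` values instead of `int64_t`. It also adds a regression test with
`i256` constants to `canonicalize.mlir`.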

---------

Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index ae8a6ef350ce191..3892e8fa0a32f2d 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -39,26 +39,26 @@ using namespace mlir::arith;
 static IntegerAttr
 applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
                     Attribute rhs,
-                    function_ref<int64_t(int64_t, int64_t)> binFn) {
-  return builder.getIntegerAttr(res.getType(),
-                                binFn(llvm::cast<IntegerAttr>(lhs).getInt(),
-                                      llvm::cast<IntegerAttr>(rhs).getInt()));
+                    function_ref<APInt(const APInt &, const APInt &)> binFn) {
+  APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
+  APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
+  APInt value = binFn(lhsVal, rhsVal);
+  return IntegerAttr::get(res.getType(), value);
 }
 
 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
                                    Attribute lhs, Attribute rhs) {
-  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<int64_t>());
+  return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus<APInt>());
 }
 
 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
                                    Attribute lhs, Attribute rhs) {
-  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<int64_t>());
+  return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus<APInt>());
 }
 
 static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
                                    Attribute lhs, Attribute rhs) {
-  return applyToIntegerAttrs(builder, res, lhs, rhs,
-                             std::multiplies<int64_t>());
+  return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
 }
 
 /// Invert an integer comparison predicate.

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f697f3d01458eee..5e4476a21df04ea 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -909,6 +909,18 @@ func.func @tripleMulIMulII32(%arg0: i32) -> i32 {
   return %mul2 : i32
 }
 
+// CHECK-LABEL: @tripleMulLargeInt
+//       CHECK:   %[[cres:.+]] = arith.constant 3618502788666131213697322783095070105623107215331596699973092056135872020482 : i256
+//       CHECK:   %[[addi:.+]] = arith.addi %arg0, %[[cres]] : i256
+//       CHECK:   return %[[addi]]
+func.func @tripleMulLargeInt(%arg0: i256) -> i256 {
+  %0 = arith.constant 3618502788666131213697322783095070105623107215331596699973092056135872020481 : i256
+  %1 = arith.constant 1 : i256
+  %2 = arith.addi %arg0, %0 : i256
+  %3 = arith.addi %2, %1 : i256
+  return %3 : i256
+}
+
 // CHECK-LABEL: @addiMuliToSubiRhsI32
 //  CHECK-SAME:   (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
 //       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32


        


More information about the Mlir-commits mailing list