[Mlir-commits] [mlir] [mlir][arith] Overflow flags propagation in arith canonicalizations. (PR #91646)

Ivan Butygin llvmlistbot at llvm.org
Thu May 9 12:19:08 PDT 2024


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/91646

None

>From 14f003295c1ae5cb7d3a55ebc4c6233f6979db03 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 9 May 2024 21:15:47 +0200
Subject: [PATCH] [mlir][arith] Overflow flags propagation in arith
 canonicalizations.

---
 .../Dialect/Arith/IR/ArithCanonicalization.td |  35 ++--
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |   7 +
 mlir/test/Dialect/Arith/canonicalize.mlir     | 153 ++++++++++++++++++
 3 files changed, 178 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 02d05780a7ac1..6f8e0b868a179 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -24,10 +24,10 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
 // Multiply two integer attributes and create a new one with the result.
 def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
 
-// TODO: Canonicalizations currently doesn't take into account integer overflow
-// flags and always reset them to default (wraparound) which is safe but can
-// inhibit later optimizations. Individual patterns must be reviewed for
-// better handling of overflow flags.
+// Merge overflow flags from 2 ops, selecting most conservative combination.
+def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">;
+
+// Default overflow flag (wraparound).
 defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
 
 class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
@@ -45,7 +45,7 @@ def AddIAddConstant :
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 // addi(subi(x, c0), c1) -> addi(x, c1 - c0)
 def AddISubConstantRHS :
@@ -53,7 +53,7 @@ def AddISubConstantRHS :
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 // addi(subi(c0, x), c1) -> subi(c0 + c1, x)
 def AddISubConstantLHS :
@@ -61,7 +61,7 @@ def AddISubConstantLHS :
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 def IsScalarOrSplatNegativeOne :
     Constraint<And<[
@@ -73,7 +73,7 @@ def AddIMulNegativeOneRhs :
     Pat<(Arith_AddIOp
            $x,
            (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
-        (Arith_SubIOp $x, $y, DefOverflow),
+        (Arith_SubIOp $x, $y, (MergeOverflow $ovf1, $ovf2)),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // addi(muli(x, -1), y) -> subi(y, x)
@@ -81,7 +81,7 @@ def AddIMulNegativeOneLhs :
     Pat<(Arith_AddIOp
            (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
            $y, $ovf2),
-        (Arith_SubIOp $y, $x, DefOverflow),
+        (Arith_SubIOp $y, $x, (MergeOverflow $ovf1, $ovf2)),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // muli(muli(x, c0), c1) -> muli(x, c0 * c1)
@@ -90,7 +90,7 @@ def MulIMulIConstant :
           (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
@@ -113,7 +113,7 @@ def SubIRHSAddConstant :
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)),
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 // subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
 def SubILHSAddConstant :
@@ -121,7 +121,7 @@ def SubILHSAddConstant :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 // subi(subi(x, c0), c1) -> subi(x, c0 + c1)
 def SubIRHSSubConstantRHS :
@@ -129,7 +129,7 @@ def SubIRHSSubConstantRHS :
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 // subi(subi(c0, x), c1) -> subi(c0 - c1, x)
 def SubIRHSSubConstantLHS :
@@ -137,7 +137,7 @@ def SubIRHSSubConstantLHS :
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 // subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
 def SubILHSSubConstantRHS :
@@ -145,7 +145,7 @@ def SubILHSSubConstantRHS :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 // subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
 def SubILHSSubConstantLHS :
@@ -153,12 +153,13 @@ def SubILHSSubConstantLHS :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
-            DefOverflow)>;
+            (MergeOverflow $ovf1, $ovf2))>;
 
 // subi(subi(a, b), a) -> subi(0, b)
 def SubISubILHSRHSLHS :
     Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
-        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, DefOverflow)>;
+        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y,
+            (MergeOverflow $ovf1, $ovf2))>;
 
 //===----------------------------------------------------------------------===//
 // MulSIExtendedOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index a1568d0ebba3a..d86dfe91771b5 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -64,6 +64,13 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
   return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
 }
 
+static IntegerOverflowFlagsAttr
+mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
+                   IntegerOverflowFlagsAttr val2) {
+  return IntegerOverflowFlagsAttr::get(val1.getContext(),
+                                       val1.getValue() & val2.getValue());
+}
+
 /// Invert an integer comparison predicate.
 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
   switch (pred) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f7ce2123a93c6..1b0999fa8a798 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -833,6 +833,30 @@ func.func @tripleAddAdd(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleAddAddOvf1
+//       CHECK:   %[[cres:.+]] = arith.constant 59 : index
+//       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleAddAddOvf1(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.addi %c17, %arg0 overflow<nsw, nuw> : index
+  %add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
+// CHECK-LABEL: @tripleAddAddOvf2
+//       CHECK:   %[[cres:.+]] = arith.constant 59 : index
+//       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
+//       CHECK:   return %[[add]]
+func.func @tripleAddAddOvf2(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.addi %c17, %arg0 overflow<nsw> : index
+  %add2 = arith.addi %c42, %add1 overflow<nuw> : index
+  return %add2 : index
+}
+
 // CHECK-LABEL: @tripleAddSub0
 //       CHECK:   %[[cres:.+]] = arith.constant 59 : index
 //       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -845,6 +869,18 @@ func.func @tripleAddSub0(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleAddSub0Ovf
+//       CHECK:   %[[cres:.+]] = arith.constant 59 : index
+//       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleAddSub0Ovf(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
+  %add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
 // CHECK-LABEL: @tripleAddSub1
 //       CHECK:   %[[cres:.+]] = arith.constant 25 : index
 //       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -857,6 +893,18 @@ func.func @tripleAddSub1(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleAddSub1Ovf
+//       CHECK:   %[[cres:.+]] = arith.constant 25 : index
+//       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleAddSub1Ovf(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
+  %add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
 // CHECK-LABEL: @tripleSubAdd0
 //       CHECK:   %[[cres:.+]] = arith.constant 25 : index
 //       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -869,6 +917,18 @@ func.func @tripleSubAdd0(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleSubAdd0Ovf
+//       CHECK:   %[[cres:.+]] = arith.constant 25 : index
+//       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleSubAdd0Ovf(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.addi %c17, %arg0 overflow<nsw, nuw> : index
+  %add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
 // CHECK-LABEL: @tripleSubAdd1
 //       CHECK:   %[[cres:.+]] = arith.constant -25 : index
 //       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -881,6 +941,18 @@ func.func @tripleSubAdd1(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleSubAdd1Ovf
+//       CHECK:   %[[cres:.+]] = arith.constant -25 : index
+//       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleSubAdd1Ovf(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.addi %c17, %arg0 overflow<nsw, nuw> : index
+  %add2 = arith.subi %add1, %c42 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
 // CHECK-LABEL: @subSub0
 //       CHECK:   %[[c0:.+]] = arith.constant 0 : index
 //       CHECK:   %[[add:.+]] = arith.subi %[[c0]], %arg1 : index
@@ -891,6 +963,16 @@ func.func @subSub0(%arg0: index, %arg1: index) -> index {
   return %sub2 : index
 }
 
+// CHECK-LABEL: @subSub0Ovf
+//       CHECK:   %[[c0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[add:.+]] = arith.subi %[[c0]], %arg1 overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @subSub0Ovf(%arg0: index, %arg1: index) -> index {
+  %sub1 = arith.subi %arg0, %arg1 overflow<nsw, nuw> : index
+  %sub2 = arith.subi %sub1, %arg0 overflow<nsw, nuw> : index
+  return %sub2 : index
+}
+
 // CHECK-LABEL: @tripleSubSub0
 //       CHECK:   %[[cres:.+]] = arith.constant 25 : index
 //       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -903,6 +985,19 @@ func.func @tripleSubSub0(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleSubSub0Ovf
+//       CHECK:   %[[cres:.+]] = arith.constant 25 : index
+//       CHECK:   %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleSubSub0Ovf(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
+  %add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
+
 // CHECK-LABEL: @tripleSubSub1
 //       CHECK:   %[[cres:.+]] = arith.constant -25 : index
 //       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -915,6 +1010,18 @@ func.func @tripleSubSub1(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleSubSub1Ovf
+//       CHECK:   %[[cres:.+]] = arith.constant -25 : index
+//       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleSubSub1Ovf(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
+  %add2 = arith.subi %add1, %c42 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
 // CHECK-LABEL: @tripleSubSub2
 //       CHECK:   %[[cres:.+]] = arith.constant 59 : index
 //       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -927,6 +1034,18 @@ func.func @tripleSubSub2(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleSubSub2Ovf
+//       CHECK:   %[[cres:.+]] = arith.constant 59 : index
+//       CHECK:   %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleSubSub2Ovf(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
+  %add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
 // CHECK-LABEL: @tripleSubSub3
 //       CHECK:   %[[cres:.+]] = arith.constant 59 : index
 //       CHECK:   %[[add:.+]] = arith.subi %arg0, %[[cres]] : index
@@ -939,6 +1058,18 @@ func.func @tripleSubSub3(%arg0: index) -> index {
   return %add2 : index
 }
 
+// CHECK-LABEL: @tripleSubSub3Ovf
+//       CHECK:   %[[cres:.+]] = arith.constant 59 : index
+//       CHECK:   %[[add:.+]] = arith.subi %arg0, %[[cres]] overflow<nsw, nuw> : index
+//       CHECK:   return %[[add]]
+func.func @tripleSubSub3Ovf(%arg0: index) -> index {
+  %c17 = arith.constant 17 : index
+  %c42 = arith.constant 42 : index
+  %add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
+  %add2 = arith.subi %add1, %c42 overflow<nsw, nuw> : index
+  return %add2 : index
+}
+
 // CHECK-LABEL: @subAdd1
 //  CHECK-NEXT:   return %arg0
 func.func @subAdd1(%arg0: index, %arg1 : index) -> index {
@@ -1018,6 +1149,17 @@ func.func @addiMuliToSubiRhsI32(%arg0: i32, %arg1: i32) -> i32 {
   return %add : i32
 }
 
+// CHECK-LABEL: @addiMuliToSubiRhsI32Ovf
+//  CHECK-SAME:   (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+//       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] overflow<nsw, nuw> : i32
+//       CHECK:   return %[[SUB]]
+func.func @addiMuliToSubiRhsI32Ovf(%arg0: i32, %arg1: i32) -> i32 {
+  %c-1 = arith.constant -1 : i32
+  %neg = arith.muli %arg1, %c-1 overflow<nsw, nuw> : i32
+  %add = arith.addi %arg0, %neg overflow<nsw, nuw> : i32
+  return %add : i32
+}
+
 // CHECK-LABEL: @addiMuliToSubiRhsIndex
 //  CHECK-SAME:   (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
 //       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : index
@@ -1051,6 +1193,17 @@ func.func @addiMuliToSubiLhsI32(%arg0: i32, %arg1: i32) -> i32 {
   return %add : i32
 }
 
+// CHECK-LABEL: @addiMuliToSubiLhsI32Ovf
+//  CHECK-SAME:   (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+//       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] overflow<nsw, nuw> : i32
+//       CHECK:   return %[[SUB]]
+func.func @addiMuliToSubiLhsI32Ovf(%arg0: i32, %arg1: i32) -> i32 {
+  %c-1 = arith.constant -1 : i32
+  %neg = arith.muli %arg1, %c-1 overflow<nsw, nuw> : i32
+  %add = arith.addi %neg, %arg0 overflow<nsw, nuw> : i32
+  return %add : i32
+}
+
 // CHECK-LABEL: @addiMuliToSubiLhsIndex
 //  CHECK-SAME:   (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
 //       CHECK:   %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : index



More information about the Mlir-commits mailing list