[Mlir-commits] [mlir] 345f57d - [mlir][arith] Overflow flags propagation in arith canonicalizations. (#91646)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 13 04:47:42 PDT 2024
Author: Ivan Butygin
Date: 2024-05-13T14:47:37+03:00
New Revision: 345f57df16af7e4fac3a321035e504b5d49206f4
URL: https://github.com/llvm/llvm-project/commit/345f57df16af7e4fac3a321035e504b5d49206f4
DIFF: https://github.com/llvm/llvm-project/commit/345f57df16af7e4fac3a321035e504b5d49206f4.diff
LOG: [mlir][arith] Overflow flags propagation in arith canonicalizations. (#91646)
Added:
Modified:
mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 02d05780a7ac1..6d7ac2be951dd 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -24,10 +24,10 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
// Multiply two integer attributes and create a new one with the result.
def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
-// TODO: Canonicalizations currently doesn't take into account integer overflow
-// flags and always reset them to default (wraparound) which is safe but can
-// inhibit later optimizations. Individual patterns must be reviewed for
-// better handling of overflow flags.
+// Merge overflow flags from 2 ops, selecting the most conservative combination.
+//
+// NOTE(review): intersecting the flags alone does not appear sufficient for
+// the reassociating patterns that use this. The folded constant is computed
+// with APInt arithmetic, which wraps, so when that fold overflows a surviving
+// nsw/nuw can turn a well-defined result into poison. Example (i8):
+// addi(addi(%x, 100) nsw, 100) nsw with %x = -100 yields 100, but the rewrite
+// addi(%x, -56) nsw overflows (-156). Likewise for nsw in
+// muli(muli(%x, 2) nsw, 64) nsw with %x = -1, and for nuw in
+// addi(subi(%x, 100) nuw, 50) nuw with %x = 200. TODO: confirm (e.g. with
+// Alive2); one sound option is to keep a flag only when the constant fold does
+// not overflow in the matching signedness.
+def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">;
+
+// Default overflow flag (all wraparounds allowed).
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
@@ -45,7 +45,7 @@ def AddIAddConstant :
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
- DefOverflow)>;
+ (MergeOverflow $ovf1, $ovf2))>;
// addi(subi(x, c0), c1) -> addi(x, c1 - c0)
def AddISubConstantRHS :
@@ -53,7 +53,7 @@ def AddISubConstantRHS :
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
- DefOverflow)>;
+ (MergeOverflow $ovf1, $ovf2))>;
// addi(subi(c0, x), c1) -> subi(c0 + c1, x)
def AddISubConstantLHS :
@@ -61,7 +61,7 @@ def AddISubConstantLHS :
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
- DefOverflow)>;
+ (MergeOverflow $ovf1, $ovf2))>;
def IsScalarOrSplatNegativeOne :
Constraint<And<[
@@ -73,7 +73,7 @@ def AddIMulNegativeOneRhs :
Pat<(Arith_AddIOp
$x,
(Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
- (Arith_SubIOp $x, $y, DefOverflow),
+ (Arith_SubIOp $x, $y, DefOverflow), // TODO: overflow flags
[(IsScalarOrSplatNegativeOne $c0)]>;
// addi(muli(x, -1), y) -> subi(y, x)
@@ -81,7 +81,7 @@ def AddIMulNegativeOneLhs :
Pat<(Arith_AddIOp
(Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
$y, $ovf2),
- (Arith_SubIOp $y, $x, DefOverflow),
+ (Arith_SubIOp $y, $x, DefOverflow), // TODO: overflow flags
[(IsScalarOrSplatNegativeOne $c0)]>;
// muli(muli(x, c0), c1) -> muli(x, c0 * c1)
@@ -90,7 +90,7 @@ def MulIMulIConstant :
(Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
- DefOverflow)>;
+ (MergeOverflow $ovf1, $ovf2))>;
//===----------------------------------------------------------------------===//
// AddUIExtendedOp
@@ -113,7 +113,7 @@ def SubIRHSAddConstant :
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)),
- DefOverflow)>;
+ DefOverflow)>; // TODO: overflow flags
// subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
def SubILHSAddConstant :
@@ -121,7 +121,7 @@ def SubILHSAddConstant :
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
- DefOverflow)>;
+ (MergeOverflow $ovf1, $ovf2))>;
// subi(subi(x, c0), c1) -> subi(x, c0 + c1)
def SubIRHSSubConstantRHS :
@@ -129,7 +129,7 @@ def SubIRHSSubConstantRHS :
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
- DefOverflow)>;
+ (MergeOverflow $ovf1, $ovf2))>;
// subi(subi(c0, x), c1) -> subi(c0 - c1, x)
def SubIRHSSubConstantLHS :
@@ -137,7 +137,7 @@ def SubIRHSSubConstantLHS :
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
- DefOverflow)>;
+ (MergeOverflow $ovf1, $ovf2))>;
// subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
def SubILHSSubConstantRHS :
@@ -145,7 +145,7 @@ def SubILHSSubConstantRHS :
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
- DefOverflow)>;
+ (MergeOverflow $ovf1, $ovf2))>;
// subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
def SubILHSSubConstantLHS :
@@ -153,12 +153,13 @@ def SubILHSSubConstantLHS :
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
- DefOverflow)>;
+ (MergeOverflow $ovf1, $ovf2))>;
// subi(subi(a, b), a) -> subi(0, b)
def SubISubILHSRHSLHS :
Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
- (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, DefOverflow)>;
+ (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y,
+ (MergeOverflow $ovf1, $ovf2))>;
//===----------------------------------------------------------------------===//
// MulSIExtendedOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index a1568d0ebba3a..a0b50251c6b67 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -64,6 +64,14 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
}
+// Merge overflow flags from 2 ops, selecting the most conservative combination.
+// The flag sets are combined with a bitwise AND, so a flag (e.g. nsw or nuw)
+// is kept only when both input ops carry it. The result reuses val1's context.
+// NOTE(review): this is only the flag intersection; it does not check whether
+// the constant folded by the calling pattern overflowed, in which case a kept
+// nsw/nuw may not be sound for the rewritten op -- callers should confirm.
+static IntegerOverflowFlagsAttr
+mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
+ IntegerOverflowFlagsAttr val2) {
+ return IntegerOverflowFlagsAttr::get(val1.getContext(),
+ val1.getValue() & val2.getValue());
+}
+
/// Invert an integer comparison predicate.
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
switch (pred) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f7ce2123a93c6..e4f95bb0545a2 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -833,6 +833,30 @@ func.func @tripleAddAdd(%arg0: index) -> index {
return %add2 : index
}
+// NOTE(review): the *Ovf tests below only use constants (17, 42) whose folds
+// stay within the signed range, so none of them checks that nsw is dropped
+// when the folded constant overflows, which is where keeping the merged
+// flags may not be sound.
+// CHECK-LABEL: @tripleAddAddOvf1
+// CHECK: %[[cres:.+]] = arith.constant 59 : index
+// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @tripleAddAddOvf1(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.addi %c17, %arg0 overflow<nsw, nuw> : index
+ %add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
+// CHECK-LABEL: @tripleAddAddOvf2
+// CHECK: %[[cres:.+]] = arith.constant 59 : index
+// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
+// CHECK: return %[[add]]
+func.func @tripleAddAddOvf2(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.addi %c17, %arg0 overflow<nsw> : index
+ %add2 = arith.addi %c42, %add1 overflow<nuw> : index
+ return %add2 : index
+}
+
// CHECK-LABEL: @tripleAddSub0
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -845,6 +869,18 @@ func.func @tripleAddSub0(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @tripleAddSub0Ovf
+// CHECK: %[[cres:.+]] = arith.constant 59 : index
+// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @tripleAddSub0Ovf(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
+ %add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
// CHECK-LABEL: @tripleAddSub1
// CHECK: %[[cres:.+]] = arith.constant 25 : index
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -857,6 +893,18 @@ func.func @tripleAddSub1(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @tripleAddSub1Ovf
+// CHECK: %[[cres:.+]] = arith.constant 25 : index
+// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @tripleAddSub1Ovf(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
+ %add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
// CHECK-LABEL: @tripleSubAdd0
// CHECK: %[[cres:.+]] = arith.constant 25 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -869,6 +917,18 @@ func.func @tripleSubAdd0(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @tripleSubAdd0Ovf
+// CHECK: %[[cres:.+]] = arith.constant 25 : index
+// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @tripleSubAdd0Ovf(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.addi %c17, %arg0 overflow<nsw, nuw> : index
+ %add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
// CHECK-LABEL: @tripleSubAdd1
// CHECK: %[[cres:.+]] = arith.constant -25 : index
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -891,6 +951,16 @@ func.func @subSub0(%arg0: index, %arg1: index) -> index {
return %sub2 : index
}
+// CHECK-LABEL: @subSub0Ovf
+// CHECK: %[[c0:.+]] = arith.constant 0 : index
+// CHECK: %[[add:.+]] = arith.subi %[[c0]], %arg1 overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @subSub0Ovf(%arg0: index, %arg1: index) -> index {
+ %sub1 = arith.subi %arg0, %arg1 overflow<nsw, nuw> : index
+ %sub2 = arith.subi %sub1, %arg0 overflow<nsw, nuw> : index
+ return %sub2 : index
+}
+
// CHECK-LABEL: @tripleSubSub0
// CHECK: %[[cres:.+]] = arith.constant 25 : index
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -903,6 +973,19 @@ func.func @tripleSubSub0(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @tripleSubSub0Ovf
+// CHECK: %[[cres:.+]] = arith.constant 25 : index
+// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @tripleSubSub0Ovf(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
+ %add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
+
// CHECK-LABEL: @tripleSubSub1
// CHECK: %[[cres:.+]] = arith.constant -25 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -915,6 +998,18 @@ func.func @tripleSubSub1(%arg0: index) -> index {
return %add2 : index
}
+// NOTE(review): 17 - 42 wraps as an unsigned fold, yet nuw is expected on the
+// result. That is acceptable only because the nuw input is always poison:
+// %add1 = 17 - %arg0 is at most 17, so %add1 - 42 always wraps unsigned.
+// CHECK-LABEL: @tripleSubSub1Ovf
+// CHECK: %[[cres:.+]] = arith.constant -25 : index
+// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @tripleSubSub1Ovf(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
+ %add2 = arith.subi %add1, %c42 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
// CHECK-LABEL: @tripleSubSub2
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -927,6 +1022,18 @@ func.func @tripleSubSub2(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @tripleSubSub2Ovf
+// CHECK: %[[cres:.+]] = arith.constant 59 : index
+// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @tripleSubSub2Ovf(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
+ %add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
// CHECK-LABEL: @tripleSubSub3
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.subi %arg0, %[[cres]] : index
@@ -939,6 +1046,18 @@ func.func @tripleSubSub3(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @tripleSubSub3Ovf
+// CHECK: %[[cres:.+]] = arith.constant 59 : index
+// CHECK: %[[add:.+]] = arith.subi %arg0, %[[cres]] overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @tripleSubSub3Ovf(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
+ %add2 = arith.subi %add1, %c42 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
// CHECK-LABEL: @subAdd1
// CHECK-NEXT: return %arg0
func.func @subAdd1(%arg0: index, %arg1 : index) -> index {
More information about the Mlir-commits
mailing list