[Mlir-commits] [mlir] [mlir][arith] Overflow flags propagation in arith canonicalizations. (PR #91646)
Ivan Butygin
llvmlistbot at llvm.org
Thu May 9 12:19:08 PDT 2024
https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/91646
None
>From 14f003295c1ae5cb7d3a55ebc4c6233f6979db03 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 9 May 2024 21:15:47 +0200
Subject: [PATCH] [mlir][arith] Overflow flags propagation in arith
canonicalizations.
---
.../Dialect/Arith/IR/ArithCanonicalization.td | 54 ++++--
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 41 ++++
mlir/test/Dialect/Arith/canonicalize.mlir | 164 ++++++++++++++++++
3 files changed, 242 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 02d05780a7ac1..6f8e0b868a179 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -24,10 +24,29 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
// Multiply two integer attributes and create a new one with the result.
def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
-// TODO: Canonicalizations currently doesn't take into account integer overflow
-// flags and always reset them to default (wraparound) which is safe but can
-// inhibit later optimizations. Individual patterns must be reviewed for
-// better handling of overflow flags.
+// Merge overflow flags from 2 ops, selecting the most conservative combination
+// (only flags set on both ops are kept). This alone is not sound for rewrites
+// that fold constants or negate an operand; see the variants below.
+def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">;
+
+// Like MergeOverflow, but always drops `nuw`. Used where `nsw` survives the
+// rewrite but `nuw` does not (e.g. `x + y * -1 -> x - y`).
+def MergeOverflowNsw : NativeCodeCall<
+ "mergeOverflowFlags($0, $1, /*dropNsw=*/false, /*dropNuw=*/true)">;
+
+// Merge overflow flags from 2 ops that are reassociated by folding their
+// constant operands $2 and $3 (in that order) with add/sub/mul. A flag is kept
+// only if it is set on both ops and the constant fold does not overflow in the
+// sense of that flag; otherwise the folded constant wraps and the rewritten op
+// can overflow on inputs where neither original op did.
+def MergeOverflowAddFold : NativeCodeCall<
+ "mergeOverflowFlagsAfterFold($0, $1, $2, $3, &APInt::sadd_ov, &APInt::uadd_ov)">;
+def MergeOverflowSubFold : NativeCodeCall<
+ "mergeOverflowFlagsAfterFold($0, $1, $2, $3, &APInt::ssub_ov, &APInt::usub_ov)">;
+def MergeOverflowMulFold : NativeCodeCall<
+ "mergeOverflowFlagsAfterFold($0, $1, $2, $3, &APInt::smul_ov, &APInt::umul_ov)">;
+
+// Default overflow flag (wraparound).
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
@@ -45,7 +64,7 @@ def AddIAddConstant :
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
- DefOverflow)>;
+ (MergeOverflowAddFold $ovf1, $ovf2, $c0, $c1))>;
// addi(subi(x, c0), c1) -> addi(x, c1 - c0)
def AddISubConstantRHS :
@@ -53,7 +72,7 @@ def AddISubConstantRHS :
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
- DefOverflow)>;
+ (MergeOverflowSubFold $ovf1, $ovf2, $c1, $c0))>;
// addi(subi(c0, x), c1) -> subi(c0 + c1, x)
def AddISubConstantLHS :
@@ -61,7 +80,7 @@ def AddISubConstantLHS :
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
- DefOverflow)>;
+ (MergeOverflowAddFold $ovf1, $ovf2, $c0, $c1))>;
def IsScalarOrSplatNegativeOne :
Constraint<And<[
@@ -73,7 +92,7 @@ def AddIMulNegativeOneRhs :
Pat<(Arith_AddIOp
$x,
(Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
- (Arith_SubIOp $x, $y, DefOverflow),
+ (Arith_SubIOp $x, $y, (MergeOverflowNsw $ovf1, $ovf2)),
[(IsScalarOrSplatNegativeOne $c0)]>;
// addi(muli(x, -1), y) -> subi(y, x)
@@ -81,7 +100,7 @@ def AddIMulNegativeOneLhs :
Pat<(Arith_AddIOp
(Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
$y, $ovf2),
- (Arith_SubIOp $y, $x, DefOverflow),
+ (Arith_SubIOp $y, $x, (MergeOverflowNsw $ovf1, $ovf2)),
[(IsScalarOrSplatNegativeOne $c0)]>;
// muli(muli(x, c0), c1) -> muli(x, c0 * c1)
@@ -90,7 +109,7 @@ def MulIMulIConstant :
(Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
- DefOverflow)>;
+ (MergeOverflowMulFold $ovf1, $ovf2, $c0, $c1))>;
//===----------------------------------------------------------------------===//
// AddUIExtendedOp
@@ -113,7 +132,7 @@ def SubIRHSAddConstant :
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)),
- DefOverflow)>;
+ (MergeOverflowSubFold $ovf1, $ovf2, $c0, $c1))>;
// subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
def SubILHSAddConstant :
@@ -121,7 +140,7 @@ def SubILHSAddConstant :
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
- DefOverflow)>;
+ (MergeOverflowSubFold $ovf1, $ovf2, $c1, $c0))>;
// subi(subi(x, c0), c1) -> subi(x, c0 + c1)
def SubIRHSSubConstantRHS :
@@ -129,7 +148,7 @@ def SubIRHSSubConstantRHS :
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
- DefOverflow)>;
+ (MergeOverflowAddFold $ovf1, $ovf2, $c0, $c1))>;
// subi(subi(c0, x), c1) -> subi(c0 - c1, x)
def SubIRHSSubConstantLHS :
@@ -137,7 +156,7 @@ def SubIRHSSubConstantLHS :
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
- DefOverflow)>;
+ (MergeOverflowSubFold $ovf1, $ovf2, $c0, $c1))>;
// subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
def SubILHSSubConstantRHS :
@@ -145,7 +164,7 @@ def SubILHSSubConstantRHS :
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
- DefOverflow)>;
+ (MergeOverflowAddFold $ovf1, $ovf2, $c0, $c1))>;
// subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
def SubILHSSubConstantLHS :
@@ -153,12 +172,13 @@ def SubILHSSubConstantLHS :
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
- DefOverflow)>;
+ (MergeOverflowSubFold $ovf1, $ovf2, $c1, $c0))>;
// subi(subi(a, b), a) -> subi(0, b)
def SubISubILHSRHSLHS :
Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
- (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, DefOverflow)>;
+ (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y,
+ (MergeOverflow $ovf1, $ovf2))>;
//===----------------------------------------------------------------------===//
// MulSIExtendedOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index a1568d0ebba3a..d86dfe91771b5 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -64,6 +64,47 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
}
+/// Merges the overflow flags of 2 ops, keeping only the flags set on both.
+/// `dropNsw` / `dropNuw` additionally clear `nsw` / `nuw` for rewrites that do
+/// not preserve that flag even when both ops carry it.
+static IntegerOverflowFlagsAttr
+mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
+ IntegerOverflowFlagsAttr val2, bool dropNsw = false,
+ bool dropNuw = false) {
+ IntegerOverflowFlags flags = val1.getValue() & val2.getValue();
+ // Clear a flag by masking with the other one; anything else that would also
+ // be cleared only makes the result more conservative.
+ if (dropNsw)
+ flags = flags & IntegerOverflowFlags::nuw;
+ if (dropNuw)
+ flags = flags & IntegerOverflowFlags::nsw;
+ return IntegerOverflowFlagsAttr::get(val1.getContext(), flags);
+}
+
+/// Merges the overflow flags of 2 ops that are reassociated by folding their
+/// constant operands into `lhs <op> rhs`, where `signedFold` / `unsignedFold`
+/// are the overflow-reporting APInt forms of <op> (e.g. `sadd_ov` / `uadd_ov`).
+/// A flag is kept only if it is set on both ops and the fold does not overflow
+/// in the sense of that flag: a wrapped constant can make the rewritten op
+/// overflow where neither original op did. E.g. for i8, `(x + 127) + 127` does
+/// not signed-overflow at x = -128, but the folded `x + -2` does.
+/// NOTE(review): for `index`, overflow is checked at the attribute's stored
+/// bitwidth; confirm this matches the intended `index` overflow semantics.
+static IntegerOverflowFlagsAttr mergeOverflowFlagsAfterFold(
+ IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2,
+ Attribute lhs, Attribute rhs,
+ APInt (APInt::*signedFold)(const APInt &, bool &) const,
+ APInt (APInt::*unsignedFold)(const APInt &, bool &) const) {
+ APInt lhsVal = llvm::cast<IntegerAttr>(lhs).getValue();
+ APInt rhsVal = llvm::cast<IntegerAttr>(rhs).getValue();
+ bool signedOverflow = false;
+ bool unsignedOverflow = false;
+ (void)(lhsVal.*signedFold)(rhsVal, signedOverflow);
+ (void)(lhsVal.*unsignedFold)(rhsVal, unsignedOverflow);
+ return mergeOverflowFlags(val1, val2, /*dropNsw=*/signedOverflow,
+ /*dropNuw=*/unsignedOverflow);
+}
+
/// Invert an integer comparison predicate.
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
switch (pred) {
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f7ce2123a93c6..1b0999fa8a798 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -833,6 +833,42 @@ func.func @tripleAddAdd(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @tripleAddAddOvf1
+// CHECK: %[[cres:.+]] = arith.constant 59 : index
+// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @tripleAddAddOvf1(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.addi %c17, %arg0 overflow<nsw, nuw> : index
+ %add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
+// CHECK-LABEL: @tripleAddAddOvf2
+// CHECK: %[[cres:.+]] = arith.constant 59 : index
+// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
+// CHECK: return %[[add]]
+func.func @tripleAddAddOvf2(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.addi %c17, %arg0 overflow<nsw> : index
+ %add2 = arith.addi %c42, %add1 overflow<nuw> : index
+ return %add2 : index
+}
+
+// 127 + 127 wraps as signed i8, so `nsw` must be dropped while `nuw` is kept.
+// CHECK-LABEL: @tripleAddAddOvfFoldOverflow
+// CHECK: %[[cres:.+]] = arith.constant -2 : i8
+// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nuw> : i8
+// CHECK: return %[[add]]
+func.func @tripleAddAddOvfFoldOverflow(%arg0: i8) -> i8 {
+ %c127 = arith.constant 127 : i8
+ %add1 = arith.addi %c127, %arg0 overflow<nsw, nuw> : i8
+ %add2 = arith.addi %c127, %add1 overflow<nsw, nuw> : i8
+ return %add2 : i8
+}
+
// CHECK-LABEL: @tripleAddSub0
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -845,6 +881,18 @@ func.func @tripleAddSub0(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @tripleAddSub0Ovf
+// CHECK: %[[cres:.+]] = arith.constant 59 : index
+// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @tripleAddSub0Ovf(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
+ %add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
// CHECK-LABEL: @tripleAddSub1
// CHECK: %[[cres:.+]] = arith.constant 25 : index
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -857,6 +905,18 @@ func.func @tripleAddSub1(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @tripleAddSub1Ovf
+// CHECK: %[[cres:.+]] = arith.constant 25 : index
+// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @tripleAddSub1Ovf(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
+ %add2 = arith.addi %c42, %add1 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
// CHECK-LABEL: @tripleSubAdd0
// CHECK: %[[cres:.+]] = arith.constant 25 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -869,6 +929,18 @@ func.func @tripleSubAdd0(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @tripleSubAdd0Ovf
+// CHECK: %[[cres:.+]] = arith.constant 25 : index
+// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @tripleSubAdd0Ovf(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.addi %c17, %arg0 overflow<nsw, nuw> : index
+ %add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
// CHECK-LABEL: @tripleSubAdd1
// CHECK: %[[cres:.+]] = arith.constant -25 : index
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -881,6 +953,18 @@ func.func @tripleSubAdd1(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @tripleSubAdd1Ovf
+// CHECK: %[[cres:.+]] = arith.constant -25 : index
+// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw> : index
+// CHECK: return %[[add]]
+func.func @tripleSubAdd1Ovf(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.addi %c17, %arg0 overflow<nsw, nuw> : index
+ %add2 = arith.subi %add1, %c42 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
// CHECK-LABEL: @subSub0
// CHECK: %[[c0:.+]] = arith.constant 0 : index
// CHECK: %[[add:.+]] = arith.subi %[[c0]], %arg1 : index
@@ -891,6 +975,16 @@ func.func @subSub0(%arg0: index, %arg1: index) -> index {
return %sub2 : index
}
+// CHECK-LABEL: @subSub0Ovf
+// CHECK: %[[c0:.+]] = arith.constant 0 : index
+// CHECK: %[[add:.+]] = arith.subi %[[c0]], %arg1 overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @subSub0Ovf(%arg0: index, %arg1: index) -> index {
+ %sub1 = arith.subi %arg0, %arg1 overflow<nsw, nuw> : index
+ %sub2 = arith.subi %sub1, %arg0 overflow<nsw, nuw> : index
+ return %sub2 : index
+}
+
// CHECK-LABEL: @tripleSubSub0
// CHECK: %[[cres:.+]] = arith.constant 25 : index
// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] : index
@@ -903,6 +997,18 @@ func.func @tripleSubSub0(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @tripleSubSub0Ovf
+// CHECK: %[[cres:.+]] = arith.constant 25 : index
+// CHECK: %[[add:.+]] = arith.addi %arg0, %[[cres]] overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @tripleSubSub0Ovf(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
+ %add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
// CHECK-LABEL: @tripleSubSub1
// CHECK: %[[cres:.+]] = arith.constant -25 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -915,6 +1021,18 @@ func.func @tripleSubSub1(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @tripleSubSub1Ovf
+// CHECK: %[[cres:.+]] = arith.constant -25 : index
+// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw> : index
+// CHECK: return %[[add]]
+func.func @tripleSubSub1Ovf(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.subi %c17, %arg0 overflow<nsw, nuw> : index
+ %add2 = arith.subi %add1, %c42 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
// CHECK-LABEL: @tripleSubSub2
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
@@ -927,6 +1045,18 @@ func.func @tripleSubSub2(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @tripleSubSub2Ovf
+// CHECK: %[[cres:.+]] = arith.constant 59 : index
+// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @tripleSubSub2Ovf(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
+ %add2 = arith.subi %c42, %add1 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
// CHECK-LABEL: @tripleSubSub3
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.subi %arg0, %[[cres]] : index
@@ -939,6 +1069,18 @@ func.func @tripleSubSub3(%arg0: index) -> index {
return %add2 : index
}
+// CHECK-LABEL: @tripleSubSub3Ovf
+// CHECK: %[[cres:.+]] = arith.constant 59 : index
+// CHECK: %[[add:.+]] = arith.subi %arg0, %[[cres]] overflow<nsw, nuw> : index
+// CHECK: return %[[add]]
+func.func @tripleSubSub3Ovf(%arg0: index) -> index {
+ %c17 = arith.constant 17 : index
+ %c42 = arith.constant 42 : index
+ %add1 = arith.subi %arg0, %c17 overflow<nsw, nuw> : index
+ %add2 = arith.subi %add1, %c42 overflow<nsw, nuw> : index
+ return %add2 : index
+}
+
// CHECK-LABEL: @subAdd1
// CHECK-NEXT: return %arg0
func.func @subAdd1(%arg0: index, %arg1 : index) -> index {
@@ -1018,6 +1160,17 @@ func.func @addiMuliToSubiRhsI32(%arg0: i32, %arg1: i32) -> i32 {
return %add : i32
}
+// CHECK-LABEL: @addiMuliToSubiRhsI32Ovf
+// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] overflow<nsw> : i32
+// CHECK: return %[[SUB]]
+func.func @addiMuliToSubiRhsI32Ovf(%arg0: i32, %arg1: i32) -> i32 {
+ %c-1 = arith.constant -1 : i32
+ %neg = arith.muli %arg1, %c-1 overflow<nsw, nuw> : i32
+ %add = arith.addi %arg0, %neg overflow<nsw, nuw> : i32
+ return %add : i32
+}
+
// CHECK-LABEL: @addiMuliToSubiRhsIndex
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : index
@@ -1051,6 +1204,17 @@ func.func @addiMuliToSubiLhsI32(%arg0: i32, %arg1: i32) -> i32 {
return %add : i32
}
+// CHECK-LABEL: @addiMuliToSubiLhsI32Ovf
+// CHECK-SAME: (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] overflow<nsw> : i32
+// CHECK: return %[[SUB]]
+func.func @addiMuliToSubiLhsI32Ovf(%arg0: i32, %arg1: i32) -> i32 {
+ %c-1 = arith.constant -1 : i32
+ %neg = arith.muli %arg1, %c-1 overflow<nsw, nuw> : i32
+ %add = arith.addi %neg, %arg0 overflow<nsw, nuw> : i32
+ return %add : i32
+}
+
// CHECK-LABEL: @addiMuliToSubiLhsIndex
// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
// CHECK: %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : index
More information about the Mlir-commits
mailing list