[Mlir-commits] [mlir] 547113f - [mlir][ODS] Add `ConstantEnumCase` (#78992)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 30 01:05:08 PST 2024
Author: Markus Böck
Date: 2024-01-30T10:05:04+01:00
New Revision: 547113fd1f52e2a3d08e3b71ddcd47505ca4a21a
URL: https://github.com/llvm/llvm-project/commit/547113fd1f52e2a3d08e3b71ddcd47505ca4a21a
DIFF: https://github.com/llvm/llvm-project/commit/547113fd1f52e2a3d08e3b71ddcd47505ca4a21a.diff
LOG: [mlir][ODS] Add `ConstantEnumCase` (#78992)
Specifying an enum case of an enum attr currently requires the use of
either `NativeCodeCall` or a `ConstantAttr` specifying the full C++ name
of the enum case. Both approaches have disadvantages: the code is less
readable because it embeds C++ expressions, and there are very few checks
of any kind, so mistakes only surface as generated C++ code that does not
compile.
This PR adds `ConstantEnumCase`, a kind of `ConstantAttr` which
automatically derives the correct value representation from a given enum
and the string representation of an enum case. It supports both
`EnumAttrInfo`s (enums wrapping `IntegerAttr`) and `EnumAttr` (proper
dialect attributes). It even supports bit-enums, allowing one to list
multiple enum cases and have them be combined. If an enum case is not
found, an assertion is triggered with a proper error message.
Besides the tests, it was also used to simplify DRR patterns in the
arith dialect.
Added:
Modified:
mlir/include/mlir/IR/EnumAttr.td
mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/IR/enum-attr-roundtrip.mlir
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index cb918b5eceb1a..f4dc480647783 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -417,4 +417,56 @@ class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
let assemblyFormat = "$value";
}
+class _symbolToValue<EnumAttrInfo enumAttrInfo, string case> {
+ defvar cases =
+ !filter(iter, enumAttrInfo.enumerants, !eq(iter.str, case));
+
+ assert !not(!empty(cases)), "failed to find enum-case '" # case # "'";
+
+ // `!empty` check to not cause an error if the cases are empty.
+ // The assertion catches the issue later and emits a proper error message.
+ string value = enumAttrInfo.cppType # "::"
+ # !if(!empty(cases), "", !head(cases).symbol);
+}
+
+class _bitSymbolsToValue<BitEnumAttr bitEnumAttr, string case> {
+ defvar pos = !find(case, "|");
+
+ // Recursive instantiation looking up the symbol before the `|` in
+ // enum cases.
+ string value = !if(
+ !eq(pos, -1), /*baseCase=*/_symbolToValue<bitEnumAttr, case>.value,
+ /*rec=*/_symbolToValue<bitEnumAttr, !substr(case, 0, pos)>.value # "|"
+ # _bitSymbolsToValue<bitEnumAttr, !substr(case, !add(pos, 1))>.value
+ );
+}
+
+class ConstantEnumCaseBase<Attr attribute,
+ EnumAttrInfo enumAttrInfo, string case>
+ : ConstantAttr<attribute,
+ !if(!isa<BitEnumAttr>(enumAttrInfo),
+ _bitSymbolsToValue<!cast<BitEnumAttr>(enumAttrInfo), case>.value,
+ _symbolToValue<enumAttrInfo, case>.value
+ )
+>;
+
+/// Attribute constraint matching a constant enum case. `attribute` should be
+/// one of `EnumAttrInfo` or `EnumAttr` and `case` the string representation
+/// of an enum case. Multiple enum values of a bit-enum can be combined using
+/// `|` as a separator. Note that there mustn't be any whitespace around the
+/// separator.
+/// This attribute constraint is additionally buildable, making it possible to
+/// use it in result patterns.
+///
+/// Examples:
+/// * ConstantEnumCase<Arith_IntegerOverflowAttr, "nsw|nuw">
+/// * ConstantEnumCase<Arith_CmpIPredicateAttr, "slt">
+class ConstantEnumCase<Attr attribute, string case>
+ : ConstantEnumCaseBase<attribute,
+ !if(!isa<EnumAttrInfo>(attribute), !cast<EnumAttrInfo>(attribute),
+ !cast<EnumAttr>(attribute).enum), case> {
+ assert !or(!isa<EnumAttr>(attribute), !isa<EnumAttrInfo>(attribute)),
+ "attribute must be one of 'EnumAttr' or 'EnumAttrInfo'";
+}
+
#endif // ENUMATTR_TD
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 18ceeb0054045..11c4a29718e1d 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -28,7 +28,7 @@ def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
// flags and always reset them to default (wraparound) which is safe but can
// inhibit later optimizations. Individual patterns must be reviewed for
// better handling of overflow flags.
-def DefOverflow : NativeCodeCall<"getDefOverflowFlags($_builder)">;
+defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
@@ -45,7 +45,7 @@ def AddIAddConstant :
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
- (DefOverflow))>;
+ DefOverflow)>;
// addi(subi(x, c0), c1) -> addi(x, c1 - c0)
def AddISubConstantRHS :
@@ -53,7 +53,7 @@ def AddISubConstantRHS :
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
- (DefOverflow))>;
+ DefOverflow)>;
// addi(subi(c0, x), c1) -> subi(c0 + c1, x)
def AddISubConstantLHS :
@@ -61,7 +61,7 @@ def AddISubConstantLHS :
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
- (DefOverflow))>;
+ DefOverflow)>;
def IsScalarOrSplatNegativeOne :
Constraint<And<[
@@ -73,7 +73,7 @@ def AddIMulNegativeOneRhs :
Pat<(Arith_AddIOp
$x,
(Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
- (Arith_SubIOp $x, $y, (DefOverflow)),
+ (Arith_SubIOp $x, $y, DefOverflow),
[(IsScalarOrSplatNegativeOne $c0)]>;
// addi(muli(x, -1), y) -> subi(y, x)
@@ -81,7 +81,7 @@ def AddIMulNegativeOneLhs :
Pat<(Arith_AddIOp
(Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
$y, $ovf2),
- (Arith_SubIOp $y, $x, (DefOverflow)),
+ (Arith_SubIOp $y, $x, DefOverflow),
[(IsScalarOrSplatNegativeOne $c0)]>;
// muli(muli(x, c0), c1) -> muli(x, c0 * c1)
@@ -90,7 +90,7 @@ def MulIMulIConstant :
(Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
- (DefOverflow))>;
+ DefOverflow)>;
//===----------------------------------------------------------------------===//
// AddUIExtendedOp
@@ -100,7 +100,7 @@ def MulIMulIConstant :
// uses. Since the 'overflow' result is unused, any replacement value will do.
def AddUIExtendedToAddI:
Pattern<(Arith_AddUIExtendedOp:$res $x, $y),
- [(Arith_AddIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
+ [(Arith_AddIOp $x, $y, DefOverflow), (replaceWithValue $x)],
[(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
//===----------------------------------------------------------------------===//
@@ -113,7 +113,7 @@ def SubIRHSAddConstant :
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)),
- (DefOverflow))>;
+ DefOverflow)>;
// subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
def SubILHSAddConstant :
@@ -121,7 +121,7 @@ def SubILHSAddConstant :
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
- (DefOverflow))>;
+ DefOverflow)>;
// subi(subi(x, c0), c1) -> subi(x, c0 + c1)
def SubIRHSSubConstantRHS :
@@ -129,7 +129,7 @@ def SubIRHSSubConstantRHS :
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
- (DefOverflow))>;
+ DefOverflow)>;
// subi(subi(c0, x), c1) -> subi(c0 - c1, x)
def SubIRHSSubConstantLHS :
@@ -137,7 +137,7 @@ def SubIRHSSubConstantLHS :
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
(ConstantLikeMatcher APIntAttr:$c1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
- (DefOverflow))>;
+ DefOverflow)>;
// subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
def SubILHSSubConstantRHS :
@@ -145,7 +145,7 @@ def SubILHSSubConstantRHS :
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
(Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
- (DefOverflow))>;
+ DefOverflow)>;
// subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
def SubILHSSubConstantLHS :
@@ -153,12 +153,12 @@ def SubILHSSubConstantLHS :
(ConstantLikeMatcher APIntAttr:$c1),
(Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
(Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
- (DefOverflow))>;
+ DefOverflow)>;
// subi(subi(a, b), a) -> subi(0, b)
def SubISubILHSRHSLHS :
Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
- (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, (DefOverflow))>;
+ (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, DefOverflow)>;
//===----------------------------------------------------------------------===//
// MulSIExtendedOp
@@ -168,7 +168,7 @@ def SubISubILHSRHSLHS :
// Since the `high` result is not used, any replacement value will do.
def MulSIExtendedToMulI :
Pattern<(Arith_MulSIExtendedOp:$res $x, $y),
- [(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
+ [(Arith_MulIOp $x, $y, DefOverflow), (replaceWithValue $x)],
[(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
@@ -182,9 +182,9 @@ def MulSIExtendedRHSOne :
Pattern<(Arith_MulSIExtendedOp $x, (ConstantLikeMatcher AnyAttr:$c1)),
[(replaceWithValue $x),
(Arith_ExtSIOp(Arith_CmpIOp
- (NativeCodeCall<"arith::CmpIPredicate::slt">),
- $x,
- (Arith_ConstantOp (GetZeroAttr $x))))],
+ ConstantEnumCase<Arith_CmpIPredicateAttr, "slt">,
+ $x,
+ (Arith_ConstantOp (GetZeroAttr $x))))],
[(IsScalarOrSplatOne $c1)]>;
//===----------------------------------------------------------------------===//
@@ -195,7 +195,7 @@ def MulSIExtendedRHSOne :
// Since the `high` result is not used, any replacement value will do.
def MulUIExtendedToMulI :
Pattern<(Arith_MulUIExtendedOp:$res $x, $y),
- [(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
+ [(Arith_MulIOp $x, $y, DefOverflow), (replaceWithValue $x)],
[(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 2d124ce4980fa..ff72becc8dfa7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -61,11 +61,6 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
}
-static IntegerOverflowFlagsAttr getDefOverflowFlags(OpBuilder &builder) {
- return IntegerOverflowFlagsAttr::get(builder.getContext(),
- IntegerOverflowFlags::none);
-}
-
/// Invert an integer comparison predicate.
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
switch (pred) {
diff --git a/mlir/test/IR/enum-attr-roundtrip.mlir b/mlir/test/IR/enum-attr-roundtrip.mlir
index 8ef4495f0bf03..0b4d379cfb7d5 100644
--- a/mlir/test/IR/enum-attr-roundtrip.mlir
+++ b/mlir/test/IR/enum-attr-roundtrip.mlir
@@ -26,3 +26,12 @@ func.func @test_match_op_with_enum() -> () {
test.op_with_enum first tag 0 : i32
return
}
+
+// CHECK-LABEL: @test_match_op_with_bit_enum
+func.func @test_match_op_with_bit_enum() -> () {
+ // CHECK: test.op_with_bit_enum <write> tag 0 : i32
+ test.op_with_bit_enum <write> tag 0 : i32
+ // CHECK: test.op_with_bit_enum <read, execute> tag 1 : i32
+ test.op_with_bit_enum <execute, write> tag 0 : i32
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 11e409c6072f7..91ce0af9cd7fd 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -396,11 +396,9 @@ def OpWithEnum : TEST_Op<"op_with_enum"> {
}
// Define a pattern that matches and creates an enum attribute.
-def : Pat<(OpWithEnum ConstantAttr<TestEnumAttr,
- "::test::TestEnum::First">:$value,
+def : Pat<(OpWithEnum ConstantEnumCase<TestEnumAttr, "first">:$value,
ConstantAttr<I32Attr, "0">:$tag),
- (OpWithEnum ConstantAttr<TestEnumAttr,
- "::test::TestEnum::Second">,
+ (OpWithEnum ConstantEnumCase<TestEnumAttr, "second">,
ConstantAttr<I32Attr, "1">)>;
//===----------------------------------------------------------------------===//
@@ -430,6 +428,12 @@ def OpWithBitEnumVerticalBar : TEST_Op<"op_with_bit_enum_vbar"> {
let assemblyFormat = "$value (`tag` $tag^)? attr-dict";
}
+// Define a pattern that matches and creates a bit enum attribute.
+def : Pat<(OpWithBitEnum ConstantEnumCase<TestBitEnumAttr, "write|execute">,
+ ConstantAttr<I32Attr, "0">),
+ (OpWithBitEnum ConstantEnumCase<TestBitEnumAttr, "execute|read">,
+ ConstantAttr<I32Attr, "1">)>;
+
//===----------------------------------------------------------------------===//
// Test Regions
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list