[Mlir-commits] [mlir] [mlir][ODS] Add `ConstantEnumCase` (PR #78992)

Markus Böck llvmlistbot at llvm.org
Tue Jan 23 00:44:21 PST 2024


https://github.com/zero9178 updated https://github.com/llvm/llvm-project/pull/78992

>From 6019e58ef7bb20e9fc91653e3d63dbee4f64c1fe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02 at gmail.com>
Date: Mon, 22 Jan 2024 16:05:04 +0100
Subject: [PATCH 1/2] [mlir][ODS] Add `ConstantEnumCase`

Specifying an enum case of an enum attr currently requires the use of `NativeCodeCall`. The disadvantages of doing so are less readable code, since it embeds C++ expressions, and almost no checking of any kind: a mistake produces C++ code that fails to compile instead of a proper diagnostic.

This PR adds `ConstantEnumCase`, a kind of `ConstantAttr` that automatically derives the correct C++ value expression given an enum attribute and the string representation of one of its enum cases. It supports both `EnumAttrInfo`s (enums wrapping `IntegerAttr`) and `EnumAttr`s (proper dialect attributes).
It even supports bit-enums, allowing one to list multiple enum cases separated by `|` and have them combined.
If an enum case is not found, a TableGen assertion fails with a descriptive error message.

Besides the tests, it was also used to simplify DRR patterns in the arith dialect.
---
 mlir/include/mlir/IR/EnumAttr.td              | 50 +++++++++++++++++++
 .../Dialect/Arith/IR/ArithCanonicalization.td | 40 +++++++--------
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  5 --
 mlir/test/IR/enum-attr-roundtrip.mlir         |  9 ++++
 mlir/test/lib/Dialect/Test/TestOps.td         | 12 +++--
 5 files changed, 87 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index cb918b5eceb1a19..3726a2f6ebbbf8c 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -417,4 +417,54 @@ class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
   let assemblyFormat = "$value";
 }
 
+class _symbolToValue<EnumAttrInfo enumAttrInfo, string case> {
+  defvar cases =
+    !filter(iter, enumAttrInfo.enumerants, !eq(iter.str, case));
+
+  assert !not(!empty(cases)), "failed to find enum-case '" # case # "'";
+
+  // `!empty` check to not cause an error if the cases are empty.
+  // The assertion catches the issue later and emits a proper error message.
+  string value = enumAttrInfo.cppType # "::"
+    # !if(!empty(cases), "", !head(cases).symbol);
+}
+
+class _bitSymbolsToValue<BitEnumAttr bitEnumAttr, string case> {
+  defvar pos = !find(case, "|");
+
+  // Recursive instantiation looking up the symbol before the `|` in
+  // enum cases.
+  string value = !if(
+    !eq(pos, -1), /*baseCase=*/_symbolToValue<bitEnumAttr, case>.value,
+    /*rec=*/_symbolToValue<bitEnumAttr, !substr(case, 0, pos)>.value # "|"
+    # _bitSymbolsToValue<bitEnumAttr, !substr(case, !add(pos, 1))>.value
+  );
+}
+
+class ConstantEnumCaseBase<Attr attribute,
+    EnumAttrInfo enumAttrInfo, string case>
+  : ConstantAttr<attribute,
+  !if(!isa<BitEnumAttr>(enumAttrInfo),
+    _bitSymbolsToValue<!cast<BitEnumAttr>(enumAttrInfo), case>.value,
+    _symbolToValue<enumAttrInfo, case>.value
+  )
+>;
+
+/// Constant attribute for defining enum values. `attribute` should be one of
+/// `EnumAttrInfo` or `EnumAttr` and `symbol` the string representation of an
+/// enum case. Multiple enum values of a bit-enum can be combined using `|` as
+/// a separator. Note that there mustn't be any whitespace around the
+/// separator.
+///
+/// Examples:
+/// * ConstantEnumCase<Arith_IntegerOverflowAttr, "nsw|nuw">
+/// * ConstantEnumCase<Arith_CmpIPredicateAttr, "slt">
+class ConstantEnumCase<Attr attribute, string case>
+  : ConstantEnumCaseBase<attribute,
+    !if(!isa<EnumAttrInfo>(attribute), !cast<EnumAttrInfo>(attribute),
+          !cast<EnumAttr>(attribute).enum), case> {
+  assert !or(!isa<EnumAttr>(attribute), !isa<EnumAttrInfo>(attribute)),
+    "attribute must be one of 'EnumAttr' or 'EnumAttrInfo'";
+}
+
 #endif // ENUMATTR_TD
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 18ceeb0054045e8..11c4a29718e1d96 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -28,7 +28,7 @@ def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
 // flags and always reset them to default (wraparound) which is safe but can
 // inhibit later optimizations. Individual patterns must be reviewed for
 // better handling of overflow flags.
-def DefOverflow : NativeCodeCall<"getDefOverflowFlags($_builder)">;
+defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
 
 class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 
@@ -45,7 +45,7 @@ def AddIAddConstant :
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // addi(subi(x, c0), c1) -> addi(x, c1 - c0)
 def AddISubConstantRHS :
@@ -53,7 +53,7 @@ def AddISubConstantRHS :
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // addi(subi(c0, x), c1) -> subi(c0 + c1, x)
 def AddISubConstantLHS :
@@ -61,7 +61,7 @@ def AddISubConstantLHS :
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
-            (DefOverflow))>;
+            DefOverflow)>;
 
 def IsScalarOrSplatNegativeOne :
     Constraint<And<[
@@ -73,7 +73,7 @@ def AddIMulNegativeOneRhs :
     Pat<(Arith_AddIOp
            $x,
            (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
-        (Arith_SubIOp $x, $y, (DefOverflow)),
+        (Arith_SubIOp $x, $y, DefOverflow),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // addi(muli(x, -1), y) -> subi(y, x)
@@ -81,7 +81,7 @@ def AddIMulNegativeOneLhs :
     Pat<(Arith_AddIOp
            (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
            $y, $ovf2),
-        (Arith_SubIOp $y, $x, (DefOverflow)),
+        (Arith_SubIOp $y, $x, DefOverflow),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // muli(muli(x, c0), c1) -> muli(x, c0 * c1)
@@ -90,7 +90,7 @@ def MulIMulIConstant :
           (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
@@ -100,7 +100,7 @@ def MulIMulIConstant :
 // uses. Since the 'overflow' result is unused, any replacement value will do.
 def AddUIExtendedToAddI:
     Pattern<(Arith_AddUIExtendedOp:$res $x, $y),
-             [(Arith_AddIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
+             [(Arith_AddIOp $x, $y, DefOverflow), (replaceWithValue $x)],
              [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 //===----------------------------------------------------------------------===//
@@ -113,7 +113,7 @@ def SubIRHSAddConstant :
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
 def SubILHSAddConstant :
@@ -121,7 +121,7 @@ def SubILHSAddConstant :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(subi(x, c0), c1) -> subi(x, c0 + c1)
 def SubIRHSSubConstantRHS :
@@ -129,7 +129,7 @@ def SubIRHSSubConstantRHS :
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(subi(c0, x), c1) -> subi(c0 - c1, x)
 def SubIRHSSubConstantLHS :
@@ -137,7 +137,7 @@ def SubIRHSSubConstantLHS :
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
 def SubILHSSubConstantRHS :
@@ -145,7 +145,7 @@ def SubILHSSubConstantRHS :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
 def SubILHSSubConstantLHS :
@@ -153,12 +153,12 @@ def SubILHSSubConstantLHS :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(subi(a, b), a) -> subi(0, b)
 def SubISubILHSRHSLHS :
     Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
-        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, (DefOverflow))>;
+        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, DefOverflow)>;
 
 //===----------------------------------------------------------------------===//
 // MulSIExtendedOp
@@ -168,7 +168,7 @@ def SubISubILHSRHSLHS :
 // Since the `high` result it not used, any replacement value will do.
 def MulSIExtendedToMulI :
     Pattern<(Arith_MulSIExtendedOp:$res $x, $y),
-        [(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
+        [(Arith_MulIOp $x, $y, DefOverflow), (replaceWithValue $x)],
         [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 
@@ -182,9 +182,9 @@ def MulSIExtendedRHSOne :
     Pattern<(Arith_MulSIExtendedOp $x, (ConstantLikeMatcher AnyAttr:$c1)),
             [(replaceWithValue $x),
              (Arith_ExtSIOp(Arith_CmpIOp
-                              (NativeCodeCall<"arith::CmpIPredicate::slt">),
-                              $x,
-                              (Arith_ConstantOp (GetZeroAttr $x))))],
+                ConstantEnumCase<Arith_CmpIPredicateAttr, "slt">,
+                $x,
+                (Arith_ConstantOp (GetZeroAttr $x))))],
             [(IsScalarOrSplatOne $c1)]>;
 
 //===----------------------------------------------------------------------===//
@@ -195,7 +195,7 @@ def MulSIExtendedRHSOne :
 // Since the `high` result it not used, any replacement value will do.
 def MulUIExtendedToMulI :
     Pattern<(Arith_MulUIExtendedOp:$res $x, $y),
-        [(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
+        [(Arith_MulIOp $x, $y, DefOverflow), (replaceWithValue $x)],
         [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 2d124ce4980fa46..ff72becc8dfa776 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -61,11 +61,6 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
   return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
 }
 
-static IntegerOverflowFlagsAttr getDefOverflowFlags(OpBuilder &builder) {
-  return IntegerOverflowFlagsAttr::get(builder.getContext(),
-                                       IntegerOverflowFlags::none);
-}
-
 /// Invert an integer comparison predicate.
 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
   switch (pred) {
diff --git a/mlir/test/IR/enum-attr-roundtrip.mlir b/mlir/test/IR/enum-attr-roundtrip.mlir
index 8ef4495f0bf033f..0b4d379cfb7d5ff 100644
--- a/mlir/test/IR/enum-attr-roundtrip.mlir
+++ b/mlir/test/IR/enum-attr-roundtrip.mlir
@@ -26,3 +26,12 @@ func.func @test_match_op_with_enum() -> () {
   test.op_with_enum first tag 0 : i32
   return
 }
+
+// CHECK-LABEL: @test_match_op_with_bit_enum
+func.func @test_match_op_with_bit_enum() -> () {
+  // CHECK: test.op_with_bit_enum <write> tag 0 : i32
+  test.op_with_bit_enum <write> tag 0 : i32
+  // CHECK: test.op_with_bit_enum <read, execute> tag 1 : i32
+  test.op_with_bit_enum <execute, write> tag 0 : i32
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 11e409c6072f7c4..91ce0af9cd7fd5f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -396,11 +396,9 @@ def OpWithEnum : TEST_Op<"op_with_enum"> {
 }
 
 // Define a pattern that matches and creates an enum attribute.
-def : Pat<(OpWithEnum ConstantAttr<TestEnumAttr,
-                                   "::test::TestEnum::First">:$value,
+def : Pat<(OpWithEnum ConstantEnumCase<TestEnumAttr, "first">:$value,
                       ConstantAttr<I32Attr, "0">:$tag),
-          (OpWithEnum ConstantAttr<TestEnumAttr,
-                                   "::test::TestEnum::Second">,
+          (OpWithEnum ConstantEnumCase<TestEnumAttr, "second">,
                       ConstantAttr<I32Attr, "1">)>;
 
 //===----------------------------------------------------------------------===//
@@ -430,6 +428,12 @@ def OpWithBitEnumVerticalBar : TEST_Op<"op_with_bit_enum_vbar"> {
   let assemblyFormat = "$value (`tag` $tag^)? attr-dict";
 }
 
+// Define a pattern that matches and creates a bit enum attribute.
+def : Pat<(OpWithBitEnum ConstantEnumCase<TestBitEnumAttr, "write|execute">,
+                         ConstantAttr<I32Attr, "0">),
+          (OpWithBitEnum ConstantEnumCase<TestBitEnumAttr, "execute|read">,
+                         ConstantAttr<I32Attr, "1">)>;
+
 //===----------------------------------------------------------------------===//
 // Test Regions
 //===----------------------------------------------------------------------===//

>From d8323f8e5aa3c5e56724e26f3c61f91c28f37103 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02 at gmail.com>
Date: Tue, 23 Jan 2024 09:44:02 +0100
Subject: [PATCH 2/2] address review comments

---
 mlir/include/mlir/IR/EnumAttr.td | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index 3726a2f6ebbbf8c..f4dc48064778349 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -450,11 +450,13 @@ class ConstantEnumCaseBase<Attr attribute,
   )
 >;
 
-/// Constant attribute for defining enum values. `attribute` should be one of
-/// `EnumAttrInfo` or `EnumAttr` and `symbol` the string representation of an
-/// enum case. Multiple enum values of a bit-enum can be combined using `|` as
-/// a separator. Note that there mustn't be any whitespace around the
+/// Attribute constraint matching a constant enum case. `attribute` should be
+/// one of `EnumAttrInfo` or `EnumAttr` and `case` the string representation
+/// of an enum case. Multiple enum values of a bit-enum can be combined using
+/// `|` as a separator. Note that there mustn't be any whitespace around the
 /// separator.
+/// This attribute constraint is additionally buildable, making it possible to
+/// use it in result patterns.
 ///
 /// Examples:
 /// * ConstantEnumCase<Arith_IntegerOverflowAttr, "nsw|nuw">



More information about the Mlir-commits mailing list