[Mlir-commits] [mlir] [mlir][ODS] Add `ConstantEnumCase` (PR #78992)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 22 07:08:44 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir-core

Author: Markus Böck (zero9178)

<details>
<summary>Changes</summary>

Specifying a case of an enum attribute currently requires using `NativeCodeCall`. This makes the code less readable, because it embeds C++ expressions. It also performs very few checks: a mistake produces C++ code that fails to compile, rather than an error at TableGen time.

This PR adds `ConstantEnumCase`, a kind of `ConstantAttr` that automatically derives the correct value string given an enum and the string representation of one of its cases. It supports both `EnumAttrInfo`s (enums wrapping `IntegerAttr`) and `EnumAttr`s (proper dialect attributes). It even supports bit-enums, allowing one to list multiple enum cases and have them combined. If an enum case is not found, a TableGen assertion fails with a descriptive error message.

Besides the tests, it was also used to simplify DRR (Declarative Rewrite Rule) patterns in the arith dialect.

---
Full diff: https://github.com/llvm/llvm-project/pull/78992.diff


5 Files Affected:

- (modified) mlir/include/mlir/IR/EnumAttr.td (+50) 
- (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+20-20) 
- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (-5) 
- (modified) mlir/test/IR/enum-attr-roundtrip.mlir (+9) 
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+8-4) 


``````````diff
diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index cb918b5eceb1a1..3726a2f6ebbbf8 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -417,4 +417,54 @@ class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
   let assemblyFormat = "$value";
 }
 
+class _symbolToValue<EnumAttrInfo enumAttrInfo, string case> {
+  defvar cases =
+    !filter(iter, enumAttrInfo.enumerants, !eq(iter.str, case));
+
+  assert !not(!empty(cases)), "failed to find enum-case '" # case # "'";
+
+  // Guard with `!empty` so `!head` is never applied to an empty list; the
+  // assert above then reports the unknown case with a readable message.
+  string value = enumAttrInfo.cppType # "::"
+    # !if(!empty(cases), "", !head(cases).symbol);
+}
+
+class _bitSymbolsToValue<BitEnumAttr bitEnumAttr, string case> {
+  defvar pos = !find(case, "|");
+
+  // Resolve the case before the first `|`, then recurse on the remainder;
+  // the resolved C++ symbols are re-joined with `|` in the resulting value.
+  string value = !if(
+    !eq(pos, -1), /*baseCase=*/_symbolToValue<bitEnumAttr, case>.value,
+    /*rec=*/_symbolToValue<bitEnumAttr, !substr(case, 0, pos)>.value # "|"
+    # _bitSymbolsToValue<bitEnumAttr, !substr(case, !add(pos, 1))>.value
+  );
+}
+
+class ConstantEnumCaseBase<Attr attribute,
+    EnumAttrInfo enumAttrInfo, string case>
+  : ConstantAttr<attribute,
+  !if(!isa<BitEnumAttr>(enumAttrInfo),
+    _bitSymbolsToValue<!cast<BitEnumAttr>(enumAttrInfo), case>.value,
+    _symbolToValue<enumAttrInfo, case>.value
+  )
+>;
+
+/// Constant attribute for defining enum values. `attribute` must be an
+/// `EnumAttrInfo` or `EnumAttr`, and `case` is the string form (`str`) of one
+/// of its enum cases. Multiple cases of a bit-enum can be combined using `|`
+/// as a separator; there must not be any whitespace around the separator,
+/// since each piece is compared against the cases' `str` verbatim.
+///
+/// Examples:
+/// * ConstantEnumCase<Arith_IntegerOverflowAttr, "nsw|nuw">
+/// * ConstantEnumCase<Arith_CmpIPredicateAttr, "slt">
+class ConstantEnumCase<Attr attribute, string case>
+  : ConstantEnumCaseBase<attribute,
+    !if(!isa<EnumAttrInfo>(attribute), !cast<EnumAttrInfo>(attribute),
+          !cast<EnumAttr>(attribute).enum), case> {
+  assert !or(!isa<EnumAttr>(attribute), !isa<EnumAttrInfo>(attribute)),
+    "attribute must be one of 'EnumAttr' or 'EnumAttrInfo'";
+}
+
 #endif // ENUMATTR_TD
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 18ceeb0054045e..11c4a29718e1d9 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -28,7 +28,7 @@ def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
 // flags and always reset them to default (wraparound) which is safe but can
 // inhibit later optimizations. Individual patterns must be reviewed for
 // better handling of overflow flags.
-def DefOverflow : NativeCodeCall<"getDefOverflowFlags($_builder)">;
+defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
 
 class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 
@@ -45,7 +45,7 @@ def AddIAddConstant :
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // addi(subi(x, c0), c1) -> addi(x, c1 - c0)
 def AddISubConstantRHS :
@@ -53,7 +53,7 @@ def AddISubConstantRHS :
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // addi(subi(c0, x), c1) -> subi(c0 + c1, x)
 def AddISubConstantLHS :
@@ -61,7 +61,7 @@ def AddISubConstantLHS :
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
-            (DefOverflow))>;
+            DefOverflow)>;
 
 def IsScalarOrSplatNegativeOne :
     Constraint<And<[
@@ -73,7 +73,7 @@ def AddIMulNegativeOneRhs :
     Pat<(Arith_AddIOp
            $x,
            (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
-        (Arith_SubIOp $x, $y, (DefOverflow)),
+        (Arith_SubIOp $x, $y, DefOverflow),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // addi(muli(x, -1), y) -> subi(y, x)
@@ -81,7 +81,7 @@ def AddIMulNegativeOneLhs :
     Pat<(Arith_AddIOp
            (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
            $y, $ovf2),
-        (Arith_SubIOp $y, $x, (DefOverflow)),
+        (Arith_SubIOp $y, $x, DefOverflow),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // muli(muli(x, c0), c1) -> muli(x, c0 * c1)
@@ -90,7 +90,7 @@ def MulIMulIConstant :
           (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
@@ -100,7 +100,7 @@ def MulIMulIConstant :
 // uses. Since the 'overflow' result is unused, any replacement value will do.
 def AddUIExtendedToAddI:
     Pattern<(Arith_AddUIExtendedOp:$res $x, $y),
-             [(Arith_AddIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
+             [(Arith_AddIOp $x, $y, DefOverflow), (replaceWithValue $x)],
              [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 //===----------------------------------------------------------------------===//
@@ -113,7 +113,7 @@ def SubIRHSAddConstant :
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
 def SubILHSAddConstant :
@@ -121,7 +121,7 @@ def SubILHSAddConstant :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(subi(x, c0), c1) -> subi(x, c0 + c1)
 def SubIRHSSubConstantRHS :
@@ -129,7 +129,7 @@ def SubIRHSSubConstantRHS :
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(subi(c0, x), c1) -> subi(c0 - c1, x)
 def SubIRHSSubConstantLHS :
@@ -137,7 +137,7 @@ def SubIRHSSubConstantLHS :
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
 def SubILHSSubConstantRHS :
@@ -145,7 +145,7 @@ def SubILHSSubConstantRHS :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
 def SubILHSSubConstantLHS :
@@ -153,12 +153,12 @@ def SubILHSSubConstantLHS :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(subi(a, b), a) -> subi(0, b)
 def SubISubILHSRHSLHS :
     Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
-        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, (DefOverflow))>;
+        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, DefOverflow)>;
 
 //===----------------------------------------------------------------------===//
 // MulSIExtendedOp
@@ -168,7 +168,7 @@ def SubISubILHSRHSLHS :
 // Since the `high` result it not used, any replacement value will do.
 def MulSIExtendedToMulI :
     Pattern<(Arith_MulSIExtendedOp:$res $x, $y),
-        [(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
+        [(Arith_MulIOp $x, $y, DefOverflow), (replaceWithValue $x)],
         [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 
@@ -182,9 +182,9 @@ def MulSIExtendedRHSOne :
     Pattern<(Arith_MulSIExtendedOp $x, (ConstantLikeMatcher AnyAttr:$c1)),
             [(replaceWithValue $x),
              (Arith_ExtSIOp(Arith_CmpIOp
-                              (NativeCodeCall<"arith::CmpIPredicate::slt">),
-                              $x,
-                              (Arith_ConstantOp (GetZeroAttr $x))))],
+                ConstantEnumCase<Arith_CmpIPredicateAttr, "slt">,
+                $x,
+                (Arith_ConstantOp (GetZeroAttr $x))))],
             [(IsScalarOrSplatOne $c1)]>;
 
 //===----------------------------------------------------------------------===//
@@ -195,7 +195,7 @@ def MulSIExtendedRHSOne :
 // Since the `high` result it not used, any replacement value will do.
 def MulUIExtendedToMulI :
     Pattern<(Arith_MulUIExtendedOp:$res $x, $y),
-        [(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
+        [(Arith_MulIOp $x, $y, DefOverflow), (replaceWithValue $x)],
         [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 2d124ce4980fa4..ff72becc8dfa77 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -61,11 +61,6 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
   return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
 }
 
-static IntegerOverflowFlagsAttr getDefOverflowFlags(OpBuilder &builder) {
-  return IntegerOverflowFlagsAttr::get(builder.getContext(),
-                                       IntegerOverflowFlags::none);
-}
-
 /// Invert an integer comparison predicate.
 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
   switch (pred) {
diff --git a/mlir/test/IR/enum-attr-roundtrip.mlir b/mlir/test/IR/enum-attr-roundtrip.mlir
index 8ef4495f0bf033..0b4d379cfb7d5f 100644
--- a/mlir/test/IR/enum-attr-roundtrip.mlir
+++ b/mlir/test/IR/enum-attr-roundtrip.mlir
@@ -26,3 +26,12 @@ func.func @test_match_op_with_enum() -> () {
   test.op_with_enum first tag 0 : i32
   return
 }
+
+// CHECK-LABEL: @test_match_op_with_bit_enum
+func.func @test_match_op_with_bit_enum() -> () {
+  // CHECK: test.op_with_bit_enum <write> tag 0 : i32
+  test.op_with_bit_enum <write> tag 0 : i32
+  // CHECK: test.op_with_bit_enum <read, execute> tag 1 : i32
+  test.op_with_bit_enum <execute, write> tag 0 : i32
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 11e409c6072f7c..91ce0af9cd7fd5 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -396,11 +396,9 @@ def OpWithEnum : TEST_Op<"op_with_enum"> {
 }
 
 // Define a pattern that matches and creates an enum attribute.
-def : Pat<(OpWithEnum ConstantAttr<TestEnumAttr,
-                                   "::test::TestEnum::First">:$value,
+def : Pat<(OpWithEnum ConstantEnumCase<TestEnumAttr, "first">:$value,
                       ConstantAttr<I32Attr, "0">:$tag),
-          (OpWithEnum ConstantAttr<TestEnumAttr,
-                                   "::test::TestEnum::Second">,
+          (OpWithEnum ConstantEnumCase<TestEnumAttr, "second">,
                       ConstantAttr<I32Attr, "1">)>;
 
 //===----------------------------------------------------------------------===//
@@ -430,6 +428,12 @@ def OpWithBitEnumVerticalBar : TEST_Op<"op_with_bit_enum_vbar"> {
   let assemblyFormat = "$value (`tag` $tag^)? attr-dict";
 }
 
+// Define a pattern that matches and creates a bit enum attribute.
+def : Pat<(OpWithBitEnum ConstantEnumCase<TestBitEnumAttr, "write|execute">,
+                         ConstantAttr<I32Attr, "0">),
+          (OpWithBitEnum ConstantEnumCase<TestBitEnumAttr, "execute|read">,
+                         ConstantAttr<I32Attr, "1">)>;
+
 //===----------------------------------------------------------------------===//
 // Test Regions
 //===----------------------------------------------------------------------===//

``````````

</details>


https://github.com/llvm/llvm-project/pull/78992


More information about the Mlir-commits mailing list