[Mlir-commits] [mlir] 547113f - [mlir][ODS] Add `ConstantEnumCase` (#78992)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 30 01:05:08 PST 2024


Author: Markus Böck
Date: 2024-01-30T10:05:04+01:00
New Revision: 547113fd1f52e2a3d08e3b71ddcd47505ca4a21a

URL: https://github.com/llvm/llvm-project/commit/547113fd1f52e2a3d08e3b71ddcd47505ca4a21a
DIFF: https://github.com/llvm/llvm-project/commit/547113fd1f52e2a3d08e3b71ddcd47505ca4a21a.diff

LOG: [mlir][ODS] Add `ConstantEnumCase` (#78992)

Specifying an enum case of an enum attr currently requires the use of
either `NativeCodeCall` or a `ConstantAttr` specifying the full C++ name
of the enum case. Both approaches produce less readable code, since they
embed C++ expressions, and perform almost no checks: a mistake surfaces as
generated C++ code that fails to compile rather than as a TableGen error.

This PR adds `ConstantEnumCase`, a kind of `ConstantAttr` which
automatically derives the correct value representation from a given enum
and the string representation of an enum case. It supports both
`EnumAttrInfo`s (enums wrapping `IntegerAttr`) and `EnumAttr`s (proper
dialect attributes). It also supports bit-enums: multiple enum cases can be
listed, separated by `|`, and are combined into one value. If an enum case
is not found, a TableGen assertion fails with a descriptive error message.

Besides the tests, it was also used to simplify DRR patterns in the
arith dialect.

Added: 
    

Modified: 
    mlir/include/mlir/IR/EnumAttr.td
    mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/IR/enum-attr-roundtrip.mlir
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index cb918b5eceb1a..f4dc480647783 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -417,4 +417,56 @@ class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
   let assemblyFormat = "$value";
 }
 
+// Internal helper: resolves `case`, the string form of one enumerant of
+// `enumAttrInfo`, to the qualified C++ name `<cppType>::<symbol>` of that
+// enumerant, exposed as `value`. Fails the assert below if no enumerant's
+// `str` equals `case`.
+class _symbolToValue<EnumAttrInfo enumAttrInfo, string case> {
+  // Enumerants whose string form equals `case`; the first one is used.
+  defvar cases =
+    !filter(iter, enumAttrInfo.enumerants, !eq(iter.str, case));
+
+  assert !not(!empty(cases)), "failed to find enum-case '" # case # "'";
+
+  // Guard `!head` with `!empty` so that an unknown case does not cause an
+  // unrelated error here; the assert above reports it with a proper error
+  // message instead.
+  string value = enumAttrInfo.cppType # "::"
+    # !if(!empty(cases), "", !head(cases).symbol);
+}
+
+// Internal helper for bit enums: `case` is a list of enumerant strings
+// separated by `|` (surrounding whitespace is not stripped). Each element is
+// resolved with `_symbolToValue`, and the resulting C++ names are joined with
+// `|`, e.g. "a|b" -> "<cppType>::<symA>|<cppType>::<symB>".
+class _bitSymbolsToValue<BitEnumAttr bitEnumAttr, string case> {
+  // Position of the first `|` in `case`; -1 when `case` names a single case.
+  defvar pos = !find(case, "|");
+
+  // Base case: a single enumerant. Otherwise resolve the part before the
+  // first `|` and recursively instantiate this class on the remainder.
+  string value = !if(
+    !eq(pos, -1), /*baseCase=*/_symbolToValue<bitEnumAttr, case>.value,
+    /*rec=*/_symbolToValue<bitEnumAttr, !substr(case, 0, pos)>.value # "|"
+    # _bitSymbolsToValue<bitEnumAttr, !substr(case, !add(pos, 1))>.value
+  );
+}
+
+// A `ConstantAttr` of `attribute` whose value is the C++ expression naming
+// `case` in `enumAttrInfo`. For a `BitEnumAttr`, `case` may list several
+// `|`-separated cases (see `_bitSymbolsToValue`); any other enum takes a
+// single case (see `_symbolToValue`).
+class ConstantEnumCaseBase<Attr attribute,
+    EnumAttrInfo enumAttrInfo, string case>
+  : ConstantAttr<attribute,
+  // Choose the lookup helper depending on whether the enum is a bit enum.
+  !if(!isa<BitEnumAttr>(enumAttrInfo),
+    _bitSymbolsToValue<!cast<BitEnumAttr>(enumAttrInfo), case>.value,
+    _symbolToValue<enumAttrInfo, case>.value
+  )
+>;
+
+/// Attribute constraint matching a constant enum case. `attribute` should be
+/// one of `EnumAttrInfo` or `EnumAttr` and `case` the string representation
+/// of an enum case. Multiple enum values of a bit-enum can be combined using
+/// `|` as a separator. Note that there mustn't be any whitespace around the
+/// separator.
+/// This attribute constraint is additionally buildable, making it possible to
+/// use it in result patterns.
+///
+/// Examples:
+/// * ConstantEnumCase<Arith_IntegerOverflowAttr, "nsw|nuw">
+/// * ConstantEnumCase<Arith_CmpIPredicateAttr, "slt">
+class ConstantEnumCase<Attr attribute, string case>
+  : ConstantEnumCaseBase<attribute,
+    // An `EnumAttrInfo` is used directly; for an `EnumAttr` the wrapped
+    // `EnumAttrInfo` is taken from its `enum` field.
+    !if(!isa<EnumAttrInfo>(attribute), !cast<EnumAttrInfo>(attribute),
+          !cast<EnumAttr>(attribute).enum), case> {
+  // NOTE(review): for an `attribute` that is neither kind, the `!cast` in the
+  // base-class arguments may already report its own error before this
+  // assert is reached — confirm which diagnostic users actually see.
+  assert !or(!isa<EnumAttr>(attribute), !isa<EnumAttrInfo>(attribute)),
+    "attribute must be one of 'EnumAttr' or 'EnumAttrInfo'";
+}
+
 #endif // ENUMATTR_TD

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 18ceeb0054045..11c4a29718e1d 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -28,7 +28,7 @@ def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
 // flags and always reset them to default (wraparound) which is safe but can
 // inhibit later optimizations. Individual patterns must be reviewed for
 // better handling of overflow flags.
-def DefOverflow : NativeCodeCall<"getDefOverflowFlags($_builder)">;
+// Default overflow flags for ops created by the patterns below: the `none`
+// case of `Arith_IntegerOverflowAttr`. Being a `defvar` bound to a buildable
+// `ConstantEnumCase`, it is used as a bare operand (`DefOverflow`) rather
+// than as a call (`(DefOverflow)`).
+defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
 
+// `cast<"T">` emits `::mlir::cast<T>($0)` on its first argument.
 class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;
 
@@ -45,7 +45,7 @@ def AddIAddConstant :
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // addi(subi(x, c0), c1) -> addi(x, c1 - c0)
 def AddISubConstantRHS :
@@ -53,7 +53,7 @@ def AddISubConstantRHS :
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // addi(subi(c0, x), c1) -> subi(c0 + c1, x)
 def AddISubConstantLHS :
@@ -61,7 +61,7 @@ def AddISubConstantLHS :
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
-            (DefOverflow))>;
+            DefOverflow)>;
 
 def IsScalarOrSplatNegativeOne :
     Constraint<And<[
@@ -73,7 +73,7 @@ def AddIMulNegativeOneRhs :
     Pat<(Arith_AddIOp
            $x,
            (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0), $ovf1), $ovf2),
-        (Arith_SubIOp $x, $y, (DefOverflow)),
+        (Arith_SubIOp $x, $y, DefOverflow),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // addi(muli(x, -1), y) -> subi(y, x)
@@ -81,7 +81,7 @@ def AddIMulNegativeOneLhs :
     Pat<(Arith_AddIOp
            (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0), $ovf1),
            $y, $ovf2),
-        (Arith_SubIOp $y, $x, (DefOverflow)),
+        (Arith_SubIOp $y, $x, DefOverflow),
         [(IsScalarOrSplatNegativeOne $c0)]>;
 
 // muli(muli(x, c0), c1) -> muli(x, c0 * c1)
@@ -90,7 +90,7 @@ def MulIMulIConstant :
           (Arith_MulIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_MulIOp $x, (Arith_ConstantOp (MulIntAttrs $res, $c0, $c1)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
@@ -100,7 +100,7 @@ def MulIMulIConstant :
 // uses. Since the 'overflow' result is unused, any replacement value will do.
 def AddUIExtendedToAddI:
     Pattern<(Arith_AddUIExtendedOp:$res $x, $y),
-             [(Arith_AddIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
+             // First result: addi with `DefOverflow`; second result: `$x`.
+             [(Arith_AddIOp $x, $y, DefOverflow), (replaceWithValue $x)],
              [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 //===----------------------------------------------------------------------===//
@@ -113,7 +113,7 @@ def SubIRHSAddConstant :
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(c1, addi(x, c0)) -> subi(c1 - c0, x)
 def SubILHSAddConstant :
@@ -121,7 +121,7 @@ def SubILHSAddConstant :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_AddIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)), $x,
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(subi(x, c0), c1) -> subi(x, c0 + c1)
 def SubIRHSSubConstantRHS :
@@ -129,7 +129,7 @@ def SubIRHSSubConstantRHS :
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp $x, (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(subi(c0, x), c1) -> subi(c0 - c1, x)
 def SubIRHSSubConstantLHS :
@@ -137,7 +137,7 @@ def SubIRHSSubConstantLHS :
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1),
           (ConstantLikeMatcher APIntAttr:$c1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (SubIntAttrs $res, $c0, $c1)), $x,
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(c1, subi(x, c0)) -> subi(c0 + c1, x)
 def SubILHSSubConstantRHS :
@@ -145,7 +145,7 @@ def SubILHSSubConstantRHS :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_SubIOp $x, (ConstantLikeMatcher APIntAttr:$c0), $ovf1), $ovf2),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x,
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(c1, subi(c0, x)) -> addi(x, c1 - c0)
 def SubILHSSubConstantLHS :
@@ -153,12 +153,12 @@ def SubILHSSubConstantLHS :
           (ConstantLikeMatcher APIntAttr:$c1),
           (Arith_SubIOp (ConstantLikeMatcher APIntAttr:$c0), $x, $ovf1), $ovf2),
         (Arith_AddIOp $x, (Arith_ConstantOp (SubIntAttrs $res, $c1, $c0)),
-            (DefOverflow))>;
+            DefOverflow)>;
 
 // subi(subi(a, b), a) -> subi(0, b)
 def SubISubILHSRHSLHS :
     Pat<(Arith_SubIOp:$res (Arith_SubIOp $x, $y, $ovf1), $x, $ovf2),
-        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, (DefOverflow))>;
+        // The rebuilt subi ignores $ovf1/$ovf2 and uses `DefOverflow`.
+        (Arith_SubIOp (Arith_ConstantOp (GetZeroAttr $y)), $y, DefOverflow)>;
 
 //===----------------------------------------------------------------------===//
 // MulSIExtendedOp
@@ -168,7 +168,7 @@ def SubISubILHSRHSLHS :
 // Since the `high` result it not used, any replacement value will do.
 def MulSIExtendedToMulI :
     Pattern<(Arith_MulSIExtendedOp:$res $x, $y),
-        [(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
+        // First result: muli with `DefOverflow`; second result: `$x`.
+        [(Arith_MulIOp $x, $y, DefOverflow), (replaceWithValue $x)],
         [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 
@@ -182,9 +182,9 @@ def MulSIExtendedRHSOne :
     Pattern<(Arith_MulSIExtendedOp $x, (ConstantLikeMatcher AnyAttr:$c1)),
             [(replaceWithValue $x),
              (Arith_ExtSIOp(Arith_CmpIOp
-                              (NativeCodeCall<"arith::CmpIPredicate::slt">),
-                              $x,
-                              (Arith_ConstantOp (GetZeroAttr $x))))],
+                ConstantEnumCase<Arith_CmpIPredicateAttr, "slt">,
+                $x,
+                (Arith_ConstantOp (GetZeroAttr $x))))],
             [(IsScalarOrSplatOne $c1)]>;
 
 //===----------------------------------------------------------------------===//
@@ -195,7 +195,7 @@ def MulSIExtendedRHSOne :
 // Since the `high` result it not used, any replacement value will do.
 def MulUIExtendedToMulI :
     Pattern<(Arith_MulUIExtendedOp:$res $x, $y),
-        [(Arith_MulIOp $x, $y, (DefOverflow)), (replaceWithValue $x)],
+        // First result: muli with `DefOverflow`; second result: `$x`.
+        [(Arith_MulIOp $x, $y, DefOverflow), (replaceWithValue $x)],
         [(Constraint<CPred<"$0.getUses().empty()">> $res__1)]>;
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 2d124ce4980fa..ff72becc8dfa7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -61,11 +61,6 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
   return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
 }
 
-static IntegerOverflowFlagsAttr getDefOverflowFlags(OpBuilder &builder) {
-  return IntegerOverflowFlagsAttr::get(builder.getContext(),
-                                       IntegerOverflowFlags::none);
-}
-
 /// Invert an integer comparison predicate.
 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
   switch (pred) {

diff  --git a/mlir/test/IR/enum-attr-roundtrip.mlir b/mlir/test/IR/enum-attr-roundtrip.mlir
index 8ef4495f0bf03..0b4d379cfb7d5 100644
--- a/mlir/test/IR/enum-attr-roundtrip.mlir
+++ b/mlir/test/IR/enum-attr-roundtrip.mlir
@@ -26,3 +26,12 @@ func.func @test_match_op_with_enum() -> () {
   test.op_with_enum first tag 0 : i32
   return
 }
+
+// CHECK-LABEL: @test_match_op_with_bit_enum
+func.func @test_match_op_with_bit_enum() -> () {
+  // `<write>` alone does not match the `write|execute` bit-enum pattern in
+  // TestOps.td, so this op is expected to be printed unchanged.
+  // CHECK: test.op_with_bit_enum <write> tag 0 : i32
+  test.op_with_bit_enum <write> tag 0 : i32
+  // `<execute, write>` with tag 0 matches that pattern and is expected to be
+  // rewritten to `execute|read` with tag 1.
+  // CHECK: test.op_with_bit_enum <read, execute> tag 1 : i32
+  test.op_with_bit_enum <execute, write> tag 0 : i32
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 11e409c6072f7..91ce0af9cd7fd 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -396,11 +396,9 @@ def OpWithEnum : TEST_Op<"op_with_enum"> {
 }
 
 // Define a pattern that matches and creates an enum attribute.
-def : Pat<(OpWithEnum ConstantAttr<TestEnumAttr,
-                                   "::test::TestEnum::First">:$value,
+// Rewrites an `OpWithEnum` holding `first` with tag 0 into one holding
+// `second` with tag 1.
+def : Pat<(OpWithEnum ConstantEnumCase<TestEnumAttr, "first">:$value,
                       ConstantAttr<I32Attr, "0">:$tag),
-          (OpWithEnum ConstantAttr<TestEnumAttr,
-                                   "::test::TestEnum::Second">,
+          (OpWithEnum ConstantEnumCase<TestEnumAttr, "second">,
                       ConstantAttr<I32Attr, "1">)>;
 
 //===----------------------------------------------------------------------===//
@@ -430,6 +428,12 @@ def OpWithBitEnumVerticalBar : TEST_Op<"op_with_bit_enum_vbar"> {
   let assemblyFormat = "$value (`tag` $tag^)? attr-dict";
 }
 
+// Define a pattern that matches and creates a bit enum attribute: an
+// `OpWithBitEnum` with value `write|execute` and tag 0 is rewritten to one
+// with value `execute|read` and tag 1. Cases are `|`-separated, with no
+// whitespace around the separator.
+def : Pat<(OpWithBitEnum ConstantEnumCase<TestBitEnumAttr, "write|execute">,
+                         ConstantAttr<I32Attr, "0">),
+          (OpWithBitEnum ConstantEnumCase<TestBitEnumAttr, "execute|read">,
+                         ConstantAttr<I32Attr, "1">)>;
+
 //===----------------------------------------------------------------------===//
 // Test Regions
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list