[Mlir-commits] [mlir] 5c3b205 - [mlir] Update LLVMIR Fastmath flags use of MLIR BitEnum functionality

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 17 11:19:18 PDT 2022


Author: jfurtek
Date: 2022-05-17T18:19:14Z
New Revision: 5c3b20520b5716d61833c8ce45d19faa48ce8db7

URL: https://github.com/llvm/llvm-project/commit/5c3b20520b5716d61833c8ce45d19faa48ce8db7
DIFF: https://github.com/llvm/llvm-project/commit/5c3b20520b5716d61833c8ce45d19faa48ce8db7.diff

LOG: [mlir] Update LLVMIR Fastmath flags use of MLIR BitEnum functionality

This diff updates the LLVMIR dialect Fastmath flags attribute to use recently
added features of `BitEnum` attributes. Specifically, this diff uses the bit
enum "group" case to represent the `fast` value as an alias for a combination
of other values (`ninf`, `nnan`, ...), instead of using a separate integer
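value. (This is in line with LLVM's fastmath flags representation.) This diff
also leverages the `printBitEnumPrimaryGroups` `tblgen` field for concise
enum printing. It also allows a `BitEnum` separator to contain spaces: the
spaces are used when printing but are optional when parsing.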

The `BitEnum` features were developed for an upcoming diff that adds `fastmath`
support to the arithmetic dialect. This diff simply applies some of the relevant
new features to the LLVM dialect attribute.

Reviewed By: ftynse, Mogball

Differential Revision: https://reviews.llvm.org/D124720

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/IR/EnumAttr.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
    mlir/test/Dialect/LLVMIR/roundtrip.mlir
    mlir/test/IR/attribute.mlir
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/tools/mlir-tblgen/EnumsGen.cpp
    mlir/unittests/TableGen/EnumsGenTest.cpp
    mlir/unittests/TableGen/enums.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 9824a3710709e..039294a6e200b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -28,14 +28,17 @@ def FMFarcp     : I32BitEnumAttrCaseBit<"arcp", 3>;
 def FMFcontract : I32BitEnumAttrCaseBit<"contract", 4>;
 def FMFafn      : I32BitEnumAttrCaseBit<"afn", 5>;
 def FMFreassoc  : I32BitEnumAttrCaseBit<"reassoc", 6>;
-def FMFfast     : I32BitEnumAttrCaseBit<"fast", 7>;
+def FMFfast     : I32BitEnumAttrCaseGroup<"fast",
+  [ FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc]>;
 
 def FastmathFlags : I32BitEnumAttr<
     "FastmathFlags",
     "LLVM fastmath flags",
     [FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast
     ]> {
+  let separator = ", ";
   let cppNamespace = "::mlir::LLVM";
+  let printBitEnumPrimaryGroups = 1;
 }
 
 def LLVM_FMFAttr : DialectAttr<

diff  --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index 66a557ce41ab5..a655d3456acd6 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -267,13 +267,20 @@ class BitEnumAttr<I intType, string name, string summary,
   // bits together.
   let symbolToStringFnRetType = "std::string";
 
-  // The delimiter used to separate bit enum cases in strings.
+  // The delimiter used to separate bit enum cases in strings. Only "|" and
+  // "," (along with optional spaces) are supported due to the use of the
+  // parseSeparatorFn in parameterParser below.
+  // Spaces in the separator string are used for printing, but will be optional
+  // for parsing.
   string separator = "|";
+  assert !or(!ge(!find(separator, "|"), 0), !ge(!find(separator, ","), 0)),
+      "separator must contain '|' or ',' for parameter parsing";
 
   // Parsing function that corresponds to the enum separator. Only
   // "," and "|" are supported by this definition.
-  string parseSeparatorFn = !if(!eq(separator,"|"),"parseOptionalVerticalBar",
-                                                   "parseOptionalComma");
+  string parseSeparatorFn = !if(!ge(!find(separator, "|"), 0),
+                                "parseOptionalVerticalBar",
+                                "parseOptionalComma");
 
   // Parse a keyword and pass it to `stringToSymbol`. Emit an error if a the
   // symbol is not valid.

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 8a3a9b13130f6..ca28fdf64ad7f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2884,26 +2884,9 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
 }
 
-static constexpr const FastmathFlags fastmathFlagsList[] = {
-    // clang-format off
-    FastmathFlags::nnan,
-    FastmathFlags::ninf,
-    FastmathFlags::nsz,
-    FastmathFlags::arcp,
-    FastmathFlags::contract,
-    FastmathFlags::afn,
-    FastmathFlags::reassoc,
-    FastmathFlags::fast,
-    // clang-format on
-};
-
 void FMFAttr::print(AsmPrinter &printer) const {
   printer << "<";
-  auto flags = llvm::make_filter_range(fastmathFlagsList, [&](auto flag) {
-    return bitEnumContains(this->getFlags(), flag);
-  });
-  llvm::interleaveComma(flags, printer,
-                        [&](auto flag) { printer << stringifyEnum(flag); });
+  printer << stringifyFastmathFlags(this->getFlags());
   printer << ">";
 }
 

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 10531e034d202..56e0834e2c873 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -157,7 +157,6 @@ static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
       {FastmathFlags::contract, &llvmFMF::setAllowContract},
       {FastmathFlags::afn,      &llvmFMF::setApproxFunc},
       {FastmathFlags::reassoc,  &llvmFMF::setAllowReassoc},
-      {FastmathFlags::fast,     &llvmFMF::setFast},
       // clang-format on
   };
   llvm::FastMathFlags ret;

diff  --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index cc4da8e095a79..0158a05b521db 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -445,7 +445,7 @@ func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32) {
 // CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32
   %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<>} : f32
 // CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
-  %9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
+  %9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan,ninf>} : f32
 
 // CHECK: {{.*}} = llvm.fneg %arg0 : f32
   %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : f32

diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 6b14296ca6636..feb2a93e9af42 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -413,9 +413,9 @@ func.func @disallowed_case7_fail() {
 
 // CHECK-LABEL: func @allowed_cases_pass
 func.func @allowed_cases_pass() {
-  // CHECK: test.op_with_bit_enum <read,write>
+  // CHECK: test.op_with_bit_enum <read, write>
   "test.op_with_bit_enum"() {value = #test.bit_enum<read, write>} : () -> ()
-  // CHECK: test.op_with_bit_enum <read,execute>
+  // CHECK: test.op_with_bit_enum <read, execute>
   test.op_with_bit_enum <read,execute>
   return
 }
@@ -424,11 +424,11 @@ func.func @allowed_cases_pass() {
 
 // CHECK-LABEL: func @allowed_cases_pass
 func.func @allowed_cases_pass() {
-  // CHECK: test.op_with_bit_enum_vbar <user|group>
+  // CHECK: test.op_with_bit_enum_vbar <user | group>
   "test.op_with_bit_enum_vbar"() {
     value = #test.bit_enum_vbar<user|group>
   } : () -> ()
-  // CHECK: test.op_with_bit_enum_vbar <user|group|other>
+  // CHECK: test.op_with_bit_enum_vbar <user | group | other>
   test.op_with_bit_enum_vbar <user | group | other>
   return
 }

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index fbe4104ee0fde..7826f842bcba8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -324,7 +324,7 @@ def TestBitEnum
       ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "test";
-  let separator = ",";
+  let separator = ", ";
 }
 
 // Define the enum attribute.
@@ -347,7 +347,7 @@ def TestBitEnumVerticalBar
       ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "test";
-  let separator = "|";
+  let separator = " | ";
 }
 
 def TestBitEnumVerticalBarAttr

diff  --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index 1e71cdb16fe03..d473ee4173406 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -277,6 +277,7 @@ static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
   std::string underlyingType = std::string(enumAttr.getUnderlyingType());
   StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
   StringRef separator = enumDef.getValueAsString("separator");
+  StringRef separatorTrimmed = separator.trim();
   auto enumerants = enumAttr.getAllCases();
   auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
 
@@ -292,15 +293,16 @@ static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
 
   // Split the string to get symbols for all the bits.
   os << "  ::llvm::SmallVector<::llvm::StringRef, 2> symbols;\n";
-  os << formatv("  str.split(symbols, \"{0}\");\n\n", separator);
+  // Remove whitespace from the separator string when parsing.
+  os << formatv("  str.split(symbols, \"{0}\");\n\n", separatorTrimmed);
 
   os << formatv("  {0} val = 0;\n", underlyingType);
   os << "  for (auto symbol : symbols) {\n";
 
   // Convert each symbol to the bit ordinal and set the corresponding bit.
-  os << formatv(
-      "    auto bit = llvm::StringSwitch<::llvm::Optional<{0}>>(symbol)\n",
-      underlyingType);
+  os << formatv("    auto bit = "
+                "llvm::StringSwitch<::llvm::Optional<{0}>>(symbol.trim())\n",
+                underlyingType);
   for (const auto &enumerant : enumerants) {
     // Skip the special enumerant for None.
     if (auto val = enumerant.getValue())

diff  --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp
index 29d274c915515..d971a9b9e0abe 100644
--- a/mlir/unittests/TableGen/EnumsGenTest.cpp
+++ b/mlir/unittests/TableGen/EnumsGenTest.cpp
@@ -80,7 +80,7 @@ TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) {
   EXPECT_EQ(stringifyBitEnumWithNone(BitEnumWithNone::Bit3), "Bit3");
   EXPECT_EQ(
       stringifyBitEnumWithNone(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3),
-      "Bit0|Bit3");
+      "Bit0 | Bit3");
 
   EXPECT_EQ(stringifyBitEnum64_Test(BitEnum64_Test::Bit1), "Bit1");
   EXPECT_EQ(
@@ -96,7 +96,7 @@ TEST(EnumsGenTest, GeneratedStringToSymbolForBitEnum) {
             BitEnumWithNone::Bit3 | BitEnumWithNone::Bit0);
 
   EXPECT_EQ(symbolizeBitEnumWithNone("Bit2"), llvm::None);
-  EXPECT_EQ(symbolizeBitEnumWithNone("Bit3|Bit4"), llvm::None);
+  EXPECT_EQ(symbolizeBitEnumWithNone("Bit3 | Bit4"), llvm::None);
 
   EXPECT_EQ(symbolizeBitEnumWithoutNone("None"), llvm::None);
 }
@@ -129,11 +129,11 @@ TEST(EnumsGenTest, GeneratedSymbolToStringFnForPrimaryGroupBitEnum) {
   EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 |
                                          BitEnumPrimaryGroup::Bit2 |
                                          BitEnumPrimaryGroup::Bit3),
-            "Bit0,Bit2,Bit3");
+            "Bit0, Bit2, Bit3");
   EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 |
                                          BitEnumPrimaryGroup::Bit4 |
                                          BitEnumPrimaryGroup::Bit5),
-            "Bits4And5,Bit0");
+            "Bits4And5, Bit0");
   EXPECT_EQ(stringifyBitEnumPrimaryGroup(
                 BitEnumPrimaryGroup::Bit0 | BitEnumPrimaryGroup::Bit1 |
                 BitEnumPrimaryGroup::Bit2 | BitEnumPrimaryGroup::Bit3 |

diff  --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td
index 5c48b2c770907..8500d7337a0e1 100644
--- a/mlir/unittests/TableGen/enums.td
+++ b/mlir/unittests/TableGen/enums.td
@@ -33,7 +33,9 @@ def Bit4 : I32BitEnumAttrCaseBit<"Bit4", 4>;
 def Bit5 : I32BitEnumAttrCaseBit<"Bit5", 5>;
 
 def BitEnumWithNone : I32BitEnumAttr<"BitEnumWithNone", "A test enum",
-                                     [NoBits, Bit0, Bit3]>;
+                                     [NoBits, Bit0, Bit3]> {
+  let separator = " | ";
+}
 
 def BitEnumWithoutNone : I32BitEnumAttr<"BitEnumWithoutNone", "A test enum",
                                         [Bit0, Bit3]>;
@@ -46,12 +48,14 @@ def Bits0To5 : I32BitEnumAttrCaseGroup<"Bits0To5",
                                        [Bits0To3, Bits4And5]>;
 
 def BitEnumWithGroup : I32BitEnumAttr<"BitEnumWithGroup", "A test enum",
-                                      [Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]>;
+                                      [Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]> {
+  let separator = "|";
+}
 
 def BitEnumPrimaryGroup : I32BitEnumAttr<"BitEnumPrimaryGroup", "test enum",
                                         [Bit0, Bit1, Bit2, Bit3, Bit4, Bit5,
                                          Bits0To3, Bits4And5, Bits0To5]> {
-  let separator = ",";
+  let separator = ", ";
   let printBitEnumPrimaryGroups = 1;
 }
 


        


More information about the Mlir-commits mailing list