[Mlir-commits] [mlir] 5c3b205 - [mlir] Update LLVMIR Fastmath flags use of MLIR BitEnum functionality
llvmlistbot at llvm.org
Tue May 17 11:19:18 PDT 2022
Author: jfurtek
Date: 2022-05-17T18:19:14Z
New Revision: 5c3b20520b5716d61833c8ce45d19faa48ce8db7
URL: https://github.com/llvm/llvm-project/commit/5c3b20520b5716d61833c8ce45d19faa48ce8db7
DIFF: https://github.com/llvm/llvm-project/commit/5c3b20520b5716d61833c8ce45d19faa48ce8db7.diff
LOG: [mlir] Update LLVMIR Fastmath flags use of MLIR BitEnum functionality
This diff updates the LLVMIR dialect Fastmath flags attribute to use recently
added features of `BitEnum` attributes. Specifically, this diff uses the bit
enum "group" case to represent the `fast` value as an alias for a combination
of other values (`ninf`, `nnan`, ...), instead of using a separate integer
value. (This is in line with LLVM's fastmath flags representation.) This diff
also leverages the `printBitEnumPrimaryGroups` `tblgen` field for concise
enum printing.
The `BitEnum` features were developed for an upcoming diff that adds `fastmath`
support to the arithmetic dialect. This diff simply applies some of the relevant
new features to the LLVM dialect attribute.
Reviewed By: ftynse, Mogball
Differential Revision: https://reviews.llvm.org/D124720
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/IR/EnumAttr.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
mlir/test/Dialect/LLVMIR/roundtrip.mlir
mlir/test/IR/attribute.mlir
mlir/test/lib/Dialect/Test/TestOps.td
mlir/tools/mlir-tblgen/EnumsGen.cpp
mlir/unittests/TableGen/EnumsGenTest.cpp
mlir/unittests/TableGen/enums.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 9824a3710709e..039294a6e200b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -28,14 +28,17 @@ def FMFarcp : I32BitEnumAttrCaseBit<"arcp", 3>;
def FMFcontract : I32BitEnumAttrCaseBit<"contract", 4>;
def FMFafn : I32BitEnumAttrCaseBit<"afn", 5>;
def FMFreassoc : I32BitEnumAttrCaseBit<"reassoc", 6>;
-def FMFfast : I32BitEnumAttrCaseBit<"fast", 7>;
+def FMFfast : I32BitEnumAttrCaseGroup<"fast",
+ [ FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc]>;
def FastmathFlags : I32BitEnumAttr<
"FastmathFlags",
"LLVM fastmath flags",
[FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast
]> {
+ let separator = ", ";
let cppNamespace = "::mlir::LLVM";
+ let printBitEnumPrimaryGroups = 1;
}
def LLVM_FMFAttr : DialectAttr<
diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index 66a557ce41ab5..a655d3456acd6 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -267,13 +267,20 @@ class BitEnumAttr<I intType, string name, string summary,
// bits together.
let symbolToStringFnRetType = "std::string";
- // The delimiter used to separate bit enum cases in strings.
+ // The delimiter used to separate bit enum cases in strings. Only "|" and
+ // "," (along with optional spaces) are supported due to the use of the
+ // parseSeparatorFn in parameterParser below.
+ // Spaces in the separator string are used for printing, but will be optional
+ // for parsing.
string separator = "|";
+ assert !or(!ge(!find(separator, "|"), 0), !ge(!find(separator, ","), 0)),
+ "separator must contain '|' or ',' for parameter parsing";
// Parsing function that corresponds to the enum separator. Only
// "," and "|" are supported by this definition.
- string parseSeparatorFn = !if(!eq(separator,"|"),"parseOptionalVerticalBar",
- "parseOptionalComma");
+ string parseSeparatorFn = !if(!ge(!find(separator, "|"), 0),
+ "parseOptionalVerticalBar",
+ "parseOptionalComma");
// Parse a keyword and pass it to `stringToSymbol`. Emit an error if a the
// symbol is not valid.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 8a3a9b13130f6..ca28fdf64ad7f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2884,26 +2884,9 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}
-static constexpr const FastmathFlags fastmathFlagsList[] = {
- // clang-format off
- FastmathFlags::nnan,
- FastmathFlags::ninf,
- FastmathFlags::nsz,
- FastmathFlags::arcp,
- FastmathFlags::contract,
- FastmathFlags::afn,
- FastmathFlags::reassoc,
- FastmathFlags::fast,
- // clang-format on
-};
-
void FMFAttr::print(AsmPrinter &printer) const {
printer << "<";
- auto flags = llvm::make_filter_range(fastmathFlagsList, [&](auto flag) {
- return bitEnumContains(this->getFlags(), flag);
- });
- llvm::interleaveComma(flags, printer,
- [&](auto flag) { printer << stringifyEnum(flag); });
+ printer << stringifyFastmathFlags(this->getFlags());
printer << ">";
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 10531e034d202..56e0834e2c873 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -157,7 +157,6 @@ static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
{FastmathFlags::contract, &llvmFMF::setAllowContract},
{FastmathFlags::afn, &llvmFMF::setApproxFunc},
{FastmathFlags::reassoc, &llvmFMF::setAllowReassoc},
- {FastmathFlags::fast, &llvmFMF::setFast},
// clang-format on
};
llvm::FastMathFlags ret;
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index cc4da8e095a79..0158a05b521db 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -445,7 +445,7 @@ func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32) {
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32
%8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<>} : f32
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
- %9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
+ %9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan,ninf>} : f32
// CHECK: {{.*}} = llvm.fneg %arg0 : f32
%10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : f32
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 6b14296ca6636..feb2a93e9af42 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -413,9 +413,9 @@ func.func @disallowed_case7_fail() {
// CHECK-LABEL: func @allowed_cases_pass
func.func @allowed_cases_pass() {
- // CHECK: test.op_with_bit_enum <read,write>
+ // CHECK: test.op_with_bit_enum <read, write>
"test.op_with_bit_enum"() {value = #test.bit_enum<read, write>} : () -> ()
- // CHECK: test.op_with_bit_enum <read,execute>
+ // CHECK: test.op_with_bit_enum <read, execute>
test.op_with_bit_enum <read,execute>
return
}
@@ -424,11 +424,11 @@ func.func @allowed_cases_pass() {
// CHECK-LABEL: func @allowed_cases_pass
func.func @allowed_cases_pass() {
- // CHECK: test.op_with_bit_enum_vbar <user|group>
+ // CHECK: test.op_with_bit_enum_vbar <user | group>
"test.op_with_bit_enum_vbar"() {
value = #test.bit_enum_vbar<user|group>
} : () -> ()
- // CHECK: test.op_with_bit_enum_vbar <user|group|other>
+ // CHECK: test.op_with_bit_enum_vbar <user | group | other>
test.op_with_bit_enum_vbar <user | group | other>
return
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index fbe4104ee0fde..7826f842bcba8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -324,7 +324,7 @@ def TestBitEnum
]> {
let genSpecializedAttr = 0;
let cppNamespace = "test";
- let separator = ",";
+ let separator = ", ";
}
// Define the enum attribute.
@@ -347,7 +347,7 @@ def TestBitEnumVerticalBar
]> {
let genSpecializedAttr = 0;
let cppNamespace = "test";
- let separator = "|";
+ let separator = " | ";
}
def TestBitEnumVerticalBarAttr
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index 1e71cdb16fe03..d473ee4173406 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -277,6 +277,7 @@ static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
std::string underlyingType = std::string(enumAttr.getUnderlyingType());
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
StringRef separator = enumDef.getValueAsString("separator");
+ StringRef separatorTrimmed = separator.trim();
auto enumerants = enumAttr.getAllCases();
auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
@@ -292,15 +293,16 @@ static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
// Split the string to get symbols for all the bits.
os << " ::llvm::SmallVector<::llvm::StringRef, 2> symbols;\n";
- os << formatv(" str.split(symbols, \"{0}\");\n\n", separator);
+ // Remove whitespace from the separator string when parsing.
+ os << formatv(" str.split(symbols, \"{0}\");\n\n", separatorTrimmed);
os << formatv(" {0} val = 0;\n", underlyingType);
os << " for (auto symbol : symbols) {\n";
// Convert each symbol to the bit ordinal and set the corresponding bit.
- os << formatv(
- " auto bit = llvm::StringSwitch<::llvm::Optional<{0}>>(symbol)\n",
- underlyingType);
+ os << formatv(" auto bit = "
+ "llvm::StringSwitch<::llvm::Optional<{0}>>(symbol.trim())\n",
+ underlyingType);
for (const auto &enumerant : enumerants) {
// Skip the special enumerant for None.
if (auto val = enumerant.getValue())
diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp
index 29d274c915515..d971a9b9e0abe 100644
--- a/mlir/unittests/TableGen/EnumsGenTest.cpp
+++ b/mlir/unittests/TableGen/EnumsGenTest.cpp
@@ -80,7 +80,7 @@ TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) {
EXPECT_EQ(stringifyBitEnumWithNone(BitEnumWithNone::Bit3), "Bit3");
EXPECT_EQ(
stringifyBitEnumWithNone(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3),
- "Bit0|Bit3");
+ "Bit0 | Bit3");
EXPECT_EQ(stringifyBitEnum64_Test(BitEnum64_Test::Bit1), "Bit1");
EXPECT_EQ(
@@ -96,7 +96,7 @@ TEST(EnumsGenTest, GeneratedStringToSymbolForBitEnum) {
BitEnumWithNone::Bit3 | BitEnumWithNone::Bit0);
EXPECT_EQ(symbolizeBitEnumWithNone("Bit2"), llvm::None);
- EXPECT_EQ(symbolizeBitEnumWithNone("Bit3|Bit4"), llvm::None);
+ EXPECT_EQ(symbolizeBitEnumWithNone("Bit3 | Bit4"), llvm::None);
EXPECT_EQ(symbolizeBitEnumWithoutNone("None"), llvm::None);
}
@@ -129,11 +129,11 @@ TEST(EnumsGenTest, GeneratedSymbolToStringFnForPrimaryGroupBitEnum) {
EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 |
BitEnumPrimaryGroup::Bit2 |
BitEnumPrimaryGroup::Bit3),
- "Bit0,Bit2,Bit3");
+ "Bit0, Bit2, Bit3");
EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 |
BitEnumPrimaryGroup::Bit4 |
BitEnumPrimaryGroup::Bit5),
- "Bits4And5,Bit0");
+ "Bits4And5, Bit0");
EXPECT_EQ(stringifyBitEnumPrimaryGroup(
BitEnumPrimaryGroup::Bit0 | BitEnumPrimaryGroup::Bit1 |
BitEnumPrimaryGroup::Bit2 | BitEnumPrimaryGroup::Bit3 |
diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td
index 5c48b2c770907..8500d7337a0e1 100644
--- a/mlir/unittests/TableGen/enums.td
+++ b/mlir/unittests/TableGen/enums.td
@@ -33,7 +33,9 @@ def Bit4 : I32BitEnumAttrCaseBit<"Bit4", 4>;
def Bit5 : I32BitEnumAttrCaseBit<"Bit5", 5>;
def BitEnumWithNone : I32BitEnumAttr<"BitEnumWithNone", "A test enum",
- [NoBits, Bit0, Bit3]>;
+ [NoBits, Bit0, Bit3]> {
+ let separator = " | ";
+}
def BitEnumWithoutNone : I32BitEnumAttr<"BitEnumWithoutNone", "A test enum",
[Bit0, Bit3]>;
@@ -46,12 +48,14 @@ def Bits0To5 : I32BitEnumAttrCaseGroup<"Bits0To5",
[Bits0To3, Bits4And5]>;
def BitEnumWithGroup : I32BitEnumAttr<"BitEnumWithGroup", "A test enum",
- [Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]>;
+ [Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]> {
+ let separator = "|";
+}
def BitEnumPrimaryGroup : I32BitEnumAttr<"BitEnumPrimaryGroup", "test enum",
[Bit0, Bit1, Bit2, Bit3, Bit4, Bit5,
Bits0To3, Bits4And5, Bits0To5]> {
- let separator = ",";
+ let separator = ", ";
let printBitEnumPrimaryGroups = 1;
}
More information about the Mlir-commits mailing list