[Mlir-commits] [mlir] dd38f89 - [mlir][LLVMIR] Update LLVMIR fastmath to use EnumAttr tblgen classes
Slava Zakharin
llvmlistbot at llvm.org
Mon Oct 17 15:06:01 PDT 2022
Author: Jeremy Furtek
Date: 2022-10-17T15:03:47-07:00
New Revision: dd38f899803465dd2765d1601b3989df3bd53863
URL: https://github.com/llvm/llvm-project/commit/dd38f899803465dd2765d1601b3989df3bd53863
DIFF: https://github.com/llvm/llvm-project/commit/dd38f899803465dd2765d1601b3989df3bd53863.diff
LOG: [mlir][LLVMIR] Update LLVMIR fastmath to use EnumAttr tblgen classes
This diff updates the `fastmath` attribute in the LLVMIR dialect to use `tblgen`
classes that were developed after the initial LLVMIR `fastmath` implementation.
Using the `EnumAttr` `tblgen` classes brings the LLVMIR `fastmath` attribute in
line with other dialects, and eliminates some of the custom printing and parsing
code in the LLVMIR dialect.
Subsequent commits will further reduce the custom processing code for the LLVMIR
`fastmath` attribute by unifying printing/parsing functionality between the
LLVMIR and `arith` `fastmath` attributes. (The actual attributes will remain
separate, but the printing and parsing will be made generic, and will be usable
by other dialects/attributes.)
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D135289
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/LLVMIR/roundtrip.mlir
mlir/test/Target/LLVMIR/llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index abdb0b7e01004..56b8e2d4be880 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -1,10 +1,5 @@
add_subdirectory(Transforms)
-set(LLVM_TARGET_DEFINITIONS LLVMAttrDefs.td)
-mlir_tablegen(LLVMOpsAttrDefs.h.inc -gen-attrdef-decls)
-mlir_tablegen(LLVMOpsAttrDefs.cpp.inc -gen-attrdef-defs)
-add_public_tablegen_target(MLIRLLVMAttrsIncGen)
-
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMOps.h.inc -gen-op-decls)
mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs)
@@ -12,6 +7,10 @@ mlir_tablegen(LLVMOpsDialect.h.inc -gen-dialect-decls)
mlir_tablegen(LLVMOpsDialect.cpp.inc -gen-dialect-defs)
mlir_tablegen(LLVMOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(LLVMOpsAttrDefs.h.inc -gen-attrdef-decls
+ -attrdefs-dialect=llvm)
+mlir_tablegen(LLVMOpsAttrDefs.cpp.inc -gen-attrdef-defs
+ -attrdefs-dialect=llvm)
add_public_tablegen_target(MLIRLLVMOpsIncGen)
set(LLVM_TARGET_DEFINITIONS LLVMIntrinsicOps.td)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 2de2f7a7cb6a4..95d4a90710f93 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -15,17 +15,6 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
// All of the attributes will extend this class.
class LLVM_Attr<string name> : AttrDef<LLVM_Dialect, name>;
-// The "FastMath" flags associated with floating point LLVM instructions.
-def FastmathFlagsAttr : LLVM_Attr<"FMF"> {
- let mnemonic = "fastmath";
-
- // List of type parameters.
- let parameters = (ins
- "FastmathFlags":$flags
- );
- let hasCustomAssemblyFormat = 1;
-}
-
// Attribute definition for the LLVM Linkage enum.
def LinkageAttr : LLVM_Attr<"Linkage"> {
let mnemonic = "linkage";
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b25d151fa2351..b4429fc6d13fe 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -13,7 +13,9 @@
#ifndef LLVMIR_OPS
#define LLVMIR_OPS
+include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/IR/EnumAttr.td"
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
@@ -21,6 +23,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+def FMFnone : I32BitEnumAttrCaseNone<"none">;
def FMFnnan : I32BitEnumAttrCaseBit<"nnan", 0>;
def FMFninf : I32BitEnumAttrCaseBit<"ninf", 1>;
def FMFnsz : I32BitEnumAttrCaseBit<"nsz", 2>;
@@ -34,22 +37,18 @@ def FMFfast : I32BitEnumAttrCaseGroup<"fast",
def FastmathFlags : I32BitEnumAttr<
"FastmathFlags",
"LLVM fastmath flags",
- [FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast
- ]> {
+ [FMFnone, FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn,
+ FMFreassoc, FMFfast]> {
let separator = ", ";
let cppNamespace = "::mlir::LLVM";
+ let genSpecializedAttr = 0;
let printBitEnumPrimaryGroups = 1;
}
-def LLVM_FMFAttr : DialectAttr<
- LLVM_Dialect,
- CPred<"$_self.isa<::mlir::LLVM::FMFAttr>()">,
- "LLVM fastmath flags"> {
- let storageType = "::mlir::LLVM::FMFAttr";
- let returnType = "::mlir::LLVM::FastmathFlags";
- let convertFromStorage = "$_self.getFlags()";
- let constBuilderCall =
- "::mlir::LLVM::FMFAttr::get($_builder.getContext(), $0)";
+// The "FastMath" flags associated with floating point LLVM instructions.
+def LLVM_FastmathFlagsAttr :
+ EnumAttr<LLVM_Dialect, FastmathFlags, "fastmath"> {
+ let assemblyFormat = "`<` $value `>`";
}
def LOptDisableUnroll : I32EnumAttrCase<"disable_unroll", 1>;
@@ -229,7 +228,8 @@ class LLVM_FloatArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
!listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>], traits)> {
- dag fmfArg = (ins DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
+ dag fmfArg = (
+ ins DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
let arguments = !con(commonArgs, fmfArg);
}
@@ -239,7 +239,9 @@ class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
LLVM_Op<mnemonic,
!listconcat([Pure, SameOperandsAndResultType, DeclareOpInterfaceMethods<FastmathFlagsInterface>], traits)>,
LLVM_Builder<"$res = builder.Create" # instName # "($operand);"> {
- let arguments = (ins type:$operand, DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
+ let arguments = (
+ ins type:$operand,
+ DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
let results = (outs type:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "$operand custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
@@ -354,7 +356,8 @@ def LLVM_FCmpOp : LLVM_ArithmeticCmpOp<"fcmp", [
let arguments = (ins FCmpPredicate:$predicate,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$lhs,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$rhs,
- DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
+ DefaultValuedAttr<LLVM_FastmathFlagsAttr,
+ "{}">:$fastmathFlags);
let builders = [
OpBuilder<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)>
];
@@ -747,7 +750,8 @@ def LLVM_CallOp : LLVM_Op<"call",
let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>,
- DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
+ DefaultValuedAttr<LLVM_FastmathFlagsAttr,
+ "{}">:$fastmathFlags);
let results = (outs Optional<LLVM_Type>:$result);
let builders = [
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index a07c434ac4a39..f37d47d744cfe 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -72,7 +72,7 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
Value real = complexStruct.real(rewriter, op.getLoc());
Value imag = complexStruct.imaginary(rewriter, op.getLoc());
- auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
+ auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
Value sqNorm = rewriter.create<LLVM::FAddOp>(
loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
@@ -180,7 +180,7 @@ struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to add complex numbers.
- auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
+ auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
Value real =
rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
Value imag =
@@ -208,7 +208,7 @@ struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to add complex numbers.
- auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
+ auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
Value rhsRe = arg.rhs.real();
Value rhsIm = arg.rhs.imag();
Value lhsRe = arg.lhs.real();
@@ -253,7 +253,7 @@ struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to add complex numbers.
- auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
+ auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
Value rhsRe = arg.rhs.real();
Value rhsIm = arg.rhs.imag();
Value lhsRe = arg.lhs.real();
@@ -290,7 +290,7 @@ struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to substract complex numbers.
- auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
+ auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
Value real =
rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
Value imag =
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 3e9c235a53633..0ed3294b3d845 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -54,7 +54,8 @@ static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
SmallVector<NamedAttribute, 8> filteredAttrs(
llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
if (attr.getName() == "fastmathFlags") {
- auto defAttr = FMFAttr::get(attr.getValue().getContext(), {});
+ auto defAttr =
+ FastmathFlagsAttr::get(attr.getValue().getContext(), {});
return defAttr != attr.getValue();
}
return true;
@@ -2563,7 +2564,7 @@ OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//
void LLVMDialect::initialize() {
- addAttributes<FMFAttr, LinkageAttr, CConvAttr, LoopOptionsAttr>();
+ addAttributes<FastmathFlagsAttr, LinkageAttr, CConvAttr, LoopOptionsAttr>();
// clang-format off
addTypes<LLVMVoidType,
@@ -2809,39 +2810,6 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}
-void FMFAttr::print(AsmPrinter &printer) const {
- printer << "<";
- printer << stringifyFastmathFlags(this->getFlags());
- printer << ">";
-}
-
-Attribute FMFAttr::parse(AsmParser &parser, Type type) {
- if (failed(parser.parseLess()))
- return {};
-
- FastmathFlags flags = {};
- if (failed(parser.parseOptionalGreater())) {
- auto parseFlags = [&]() -> ParseResult {
- StringRef elemName;
- if (failed(parser.parseKeyword(&elemName)))
- return failure();
-
- auto elem = symbolizeFastmathFlags(elemName);
- if (!elem)
- return parser.emitError(parser.getNameLoc(), "Unknown fastmath flag: ")
- << elemName;
-
- flags = flags | *elem;
- return success();
- };
- if (failed(parser.parseCommaSeparatedList(parseFlags)) ||
- parser.parseGreater())
- return {};
- }
-
- return FMFAttr::get(parser.getContext(), flags);
-}
-
void LinkageAttr::print(AsmPrinter &printer) const {
printer << "<";
if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage())
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index ad26fd7436fa9..5f565bb7ce1b0 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -477,12 +477,12 @@ func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: vector<2 x f
%7 = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (i32) -> !llvm.struct<(i32, f64, i32)>
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32
- %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<>} : f32
+ %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<none>} : f32
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
%9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan,ninf>} : f32
// CHECK: {{.*}} = llvm.fneg %arg0 : f32
- %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : f32
+ %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<none>} : f32
return
}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index f7b021133465d..51538743027b0 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1666,7 +1666,7 @@ llvm.func @fastmathFlags(%arg0: f32) {
// CHECK: {{.*}} = call afn float @fastmathFlagsFunc({{.*}})
// CHECK: {{.*}} = call reassoc float @fastmathFlagsFunc({{.*}})
// CHECK: {{.*}} = call fast float @fastmathFlagsFunc({{.*}})
- %8 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<>} : (f32) -> (f32)
+ %8 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<none>} : (f32) -> (f32)
%9 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<nnan>} : (f32) -> (f32)
%10 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<ninf>} : (f32) -> (f32)
%11 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<nsz>} : (f32) -> (f32)
More information about the Mlir-commits mailing list