[llvm-branch-commits] [mlir] c1d58c2 - [mlir] Add fastmath flags support to some LLVM dialect ops
Alex Zinenko via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jan 7 05:05:01 PST 2021
Author: Ivan Butygin
Date: 2021-01-07T14:00:09+01:00
New Revision: c1d58c2b0023cd41f0da128f5190fa887d8f6c69
URL: https://github.com/llvm/llvm-project/commit/c1d58c2b0023cd41f0da128f5190fa887d8f6c69
DIFF: https://github.com/llvm/llvm-project/commit/c1d58c2b0023cd41f0da128f5190fa887d8f6c69.diff
LOG: [mlir] Add fastmath flags support to some LLVM dialect ops
Add fastmath enum, attributes to some llvm dialect ops, `FastmathFlagsInterface` op interface, and `translateModuleToLLVMIR` support.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D92485
Added:
mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td
Modified:
mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/LLVMIR/CMakeLists.txt
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Dialect/LLVMIR/roundtrip.mlir
mlir/test/Target/llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 6166f3632607..29cef3f0032d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -10,6 +10,8 @@ add_public_tablegen_target(MLIRLLVMOpsIncGen)
add_mlir_doc(LLVMOps -gen-op-doc LLVMOps Dialects/)
+add_mlir_interface(LLVMOpsInterfaces)
+
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)
mlir_tablegen(LLVMConversionEnumsToLLVM.inc -gen-enum-to-llvmir-conversions)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 630bad4914b1..22ff1517f77b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -29,6 +29,7 @@
#include "llvm/IR/Type.h"
#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.h.inc"
+#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.h.inc"
namespace llvm {
class Type;
@@ -46,8 +47,23 @@ class LLVMDialect;
namespace detail {
struct LLVMTypeStorage;
struct LLVMDialectImpl;
+struct BitmaskEnumStorage;
} // namespace detail
+/// An attribute that specifies LLVM instruction fastmath flags.
+///
+/// The flags are kept as a raw bitmask in detail::BitmaskEnumStorage. In the
+/// textual IR the attribute appears as `#llvm.fastmath<flag, ...>`, and
+/// `#llvm.fastmath<>` denotes the empty (default) set of flags.
+class FMFAttr : public Attribute::AttrBase<FMFAttr, Attribute,
+ detail::BitmaskEnumStorage> {
+public:
+ using Base::Base;
+
+ /// Returns the attribute holding `flags` in `context`.
+ static FMFAttr get(FastmathFlags flags, MLIRContext *context);
+
+ /// Returns the fastmath flags held by this attribute.
+ FastmathFlags getFlags() const;
+
+ /// Prints the attribute body as `fastmath<flag, ...>`.
+ void print(DialectAsmPrinter &p) const;
+ /// Parses the attribute body that follows the `fastmath` keyword; returns a
+ /// null attribute on failure.
+ static Attribute parse(DialectAsmParser &parser);
+};
+
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 428ca6783afd..53c42540aa48 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -14,10 +14,39 @@
#define LLVMIR_OPS
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+// Individual fastmath flag bits. The names are the LLVM IR fastmath keywords;
+// `fast` is modelled as its own bit rather than as the union of the others.
+def FMFnnan : BitEnumAttrCase<"nnan", 0x1>;
+def FMFninf : BitEnumAttrCase<"ninf", 0x2>;
+def FMFnsz : BitEnumAttrCase<"nsz", 0x4>;
+def FMFarcp : BitEnumAttrCase<"arcp", 0x8>;
+def FMFcontract : BitEnumAttrCase<"contract", 0x10>;
+def FMFafn : BitEnumAttrCase<"afn", 0x20>;
+def FMFreassoc : BitEnumAttrCase<"reassoc", 0x40>;
+def FMFfast : BitEnumAttrCase<"fast", 0x80>;
+
+// Bit enum combining the flag cases above; generates ::mlir::LLVM::FastmathFlags.
+def FastmathFlags : BitEnumAttr<
+ "FastmathFlags",
+ "LLVM fastmath flags",
+ [FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast
+ ]> {
+ let cppNamespace = "::mlir::LLVM";
+}
+
+// Op-argument constraint for fastmath flags: stored as an FMFAttr, exposed to
+// op accessors as the unpacked FastmathFlags value.
+def LLVM_FMFAttr : DialectAttr<
+ LLVM_Dialect,
+ CPred<"$_self.isa<::mlir::LLVM::FMFAttr>()">,
+ "LLVM fastmath flags"> {
+ let storageType = "::mlir::LLVM::FMFAttr";
+ let returnType = "::mlir::LLVM::FastmathFlags";
+ let convertFromStorage = "$_self.getFlags()";
+ let constBuilderCall =
+ "::mlir::LLVM::FMFAttr::get($0, $_builder.getContext())";
+}
+
class LLVM_Builder<string builder> {
string llvmBuilder = builder;
}
@@ -77,29 +106,35 @@ class LLVM_ArithmeticOpBase<Type type, string mnemonic,
LLVM_Op<mnemonic,
!listconcat([NoSideEffect, SameOperandsAndResultType], traits)>,
LLVM_Builder<"$res = builder." # builderFunc # "($lhs, $rhs);"> {
- let arguments = (ins LLVM_ScalarOrVectorOf<type>:$lhs,
- LLVM_ScalarOrVectorOf<type>:$rhs);
+ dag commonArgs = (ins LLVM_ScalarOrVectorOf<type>:$lhs,
+ LLVM_ScalarOrVectorOf<type>:$rhs);
let results = (outs LLVM_ScalarOrVectorOf<type>:$res);
let builders = [LLVM_OneResultOpBuilder];
- let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($res)";
+ let assemblyFormat = "$lhs `,` $rhs custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
}
class LLVM_IntArithmeticOp<string mnemonic, string builderFunc,
list<OpTrait> traits = []> :
- LLVM_ArithmeticOpBase<LLVM_AnyInteger, mnemonic, builderFunc, traits>;
+ LLVM_ArithmeticOpBase<LLVM_AnyInteger, mnemonic, builderFunc, traits> {
+ let arguments = commonArgs;
+}
class LLVM_FloatArithmeticOp<string mnemonic, string builderFunc,
list<OpTrait> traits = []> :
- LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, builderFunc, traits>;
+ LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, builderFunc,
+ !listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>], traits)> {
+ dag fmfArg = (ins DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
+ let arguments = !con(commonArgs, fmfArg);
+}
// Class for arithmetic unary operations.
-class LLVM_UnaryArithmeticOp<Type type, string mnemonic,
+class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
string builderFunc, list<OpTrait> traits = []> :
LLVM_Op<mnemonic,
- !listconcat([NoSideEffect, SameOperandsAndResultType], traits)>,
+ !listconcat([NoSideEffect, SameOperandsAndResultType, DeclareOpInterfaceMethods<FastmathFlagsInterface>], traits)>,
LLVM_Builder<"$res = builder." # builderFunc # "($operand);"> {
- let arguments = (ins type:$operand);
+ let arguments = (ins type:$operand, DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
let results = (outs type:$res);
let builders = [LLVM_OneResultOpBuilder];
- let assemblyFormat = "$operand attr-dict `:` type($res)";
+ let assemblyFormat = "$operand custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
}
// Integer binary operations.
@@ -185,20 +220,24 @@ def FCmpPredicate : I64EnumAttr<
let cppNamespace = "::mlir::LLVM";
}
-// Other integer operations.
-def LLVM_FCmpOp : LLVM_Op<"fcmp", [NoSideEffect]> {
+// Other floating-point operations.
+def LLVM_FCmpOp : LLVM_Op<"fcmp", [
+ NoSideEffect, DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
let arguments = (ins FCmpPredicate:$predicate,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$lhs,
- LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$rhs);
+ LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$rhs,
+ DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
let results = (outs LLVM_ScalarOrVectorOf<LLVM_i1>:$res);
let llvmBuilder = [{
$res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
}];
let builders = [
- OpBuilderDAG<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs),
+ OpBuilderDAG<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs,
+ CArg<"FastmathFlags", "{}">:$fmf),
[{
build($_builder, $_state, LLVMIntegerType::get(lhs.getType().getContext(), 1),
- $_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
+ $_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs,
+ ::mlir::LLVM::FMFAttr::get(fmf, $_builder.getContext()));
}]>];
let parser = [{ return parseCmpOp<FCmpPredicate>(parser, result); }];
let printer = [{ printFCmpOp(p, *this); }];
@@ -210,8 +249,8 @@ def LLVM_FSubOp : LLVM_FloatArithmeticOp<"fsub", "CreateFSub">;
def LLVM_FMulOp : LLVM_FloatArithmeticOp<"fmul", "CreateFMul">;
def LLVM_FDivOp : LLVM_FloatArithmeticOp<"fdiv", "CreateFDiv">;
def LLVM_FRemOp : LLVM_FloatArithmeticOp<"frem", "CreateFRem">;
-def LLVM_FNegOp : LLVM_UnaryArithmeticOp<LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
- "fneg", "CreateFNeg">;
+def LLVM_FNegOp : LLVM_UnaryFloatArithmeticOp<
+ LLVM_ScalarOrVectorOf<LLVM_AnyFloat>, "fneg", "CreateFNeg">;
// Common code definition that is used to verify and set the alignment attribute
// of LLVM ops that accept such an attribute.
@@ -405,7 +444,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
let printer = [{ printLandingpadOp(p, *this); }];
}
-def LLVM_CallOp : LLVM_Op<"call"> {
+def LLVM_CallOp : LLVM_Op<"call",
+ [DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
let summary = "Call to an LLVM function.";
let description = [{
@@ -436,7 +476,8 @@ def LLVM_CallOp : LLVM_Op<"call"> {
```
}];
let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
- Variadic<LLVM_Type>);
+ Variadic<LLVM_Type>,
+ DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
let results = (outs Variadic<LLVM_Type>);
let builders = [
OpBuilderDAG<(ins "LLVMFuncOp":$func, "ValueRange":$operands,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td
new file mode 100644
index 000000000000..d31ae81ab2dd
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td
@@ -0,0 +1,30 @@
+//===-- LLVMOpsInterfaces.td - LLVM op interfaces ----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the LLVM IR interfaces definition file.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_OPS_INTERFACES
+#define LLVM_OPS_INTERFACES
+
+include "mlir/IR/OpBase.td"
+
+// Implemented by LLVM dialect ops that carry a `fastmathFlags` attribute.
+// Module translation queries it to configure the IRBuilder's fastmath flags
+// before emitting the corresponding LLVM IR instruction.
+def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
+ let description = [{
+ Access to op fastmath flags.
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags", "fastmathFlags">,
+ ];
+}
+
+#endif // LLVM_OPS_INTERFACES
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 2ebb24b5aaeb..78927fbcd457 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -828,7 +828,8 @@ class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
operation, dstType,
rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
- operation.operand1(), operation.operand2());
+ operation.operand1(), operation.operand2(),
+ LLVM::FMFAttr::get({}, operation.getContext()));
return success();
}
};
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 39680a28a33e..5e270881656c 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1836,10 +1836,11 @@ struct AddCFOpLowering : public ConvertOpToLLVMPattern<AddCFOp> {
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to add complex numbers.
+ auto fmf = LLVM::FMFAttr::get({}, op.getContext());
Value real =
- rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real());
+ rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
Value imag =
- rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag());
+ rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
result.setReal(rewriter, loc, real);
result.setImaginary(rewriter, loc, imag);
@@ -1863,10 +1864,11 @@ struct SubCFOpLowering : public ConvertOpToLLVMPattern<SubCFOp> {
auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
// Emit IR to substract complex numbers.
+ auto fmf = LLVM::FMFAttr::get({}, op.getContext());
Value real =
- rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real());
+ rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
Value imag =
- rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag());
+ rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
result.setReal(rewriter, loc, real);
result.setImaginary(rewriter, loc, imag);
@@ -3155,11 +3157,12 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
ConversionPatternRewriter &rewriter) const override {
CmpFOpAdaptor transformed(operands);
+ auto fmf = LLVM::FMFAttr::get({}, cmpfOp.getContext());
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()),
rewriter.getI64IntegerAttr(static_cast<int64_t>(
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
- transformed.lhs(), transformed.rhs());
+ transformed.lhs(), transformed.rhs(), fmf);
return success();
}
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index 91fb02db9601..c2f88d06062c 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRLLVMIR
DEPENDS
MLIRLLVMOpsIncGen
+ MLIRLLVMOpsInterfacesIncGen
MLIROpenMPOpsIncGen
intrinsics_gen
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0a9b61628384..b7f7789ee44b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -36,6 +36,51 @@ static constexpr const char kVolatileAttrName[] = "volatile_";
static constexpr const char kNonTemporalAttrName[] = "nontemporal";
#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
+#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+/// Attribute storage holding a bitmask enum as its raw 64-bit value. The raw
+/// value doubles as the storage key compared by operator==.
+struct BitmaskEnumStorage : public AttributeStorage {
+ using KeyTy = uint64_t;
+
+ BitmaskEnumStorage(KeyTy val) : value(val) {}
+
+ bool operator==(const KeyTy &key) const { return value == key; }
+
+ /// Creates a new storage instance for `key` using `allocator`.
+ static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
+ const KeyTy &key) {
+ return new (allocator.allocate<BitmaskEnumStorage>())
+ BitmaskEnumStorage(key);
+ }
+
+ /// Raw bitmask value of the enum.
+ KeyTy value = 0;
+};
+} // namespace detail
+} // namespace LLVM
+} // namespace mlir
+
+/// Returns a copy of `attrs` with the `fastmathFlags` entry removed when it
+/// holds the default (empty) flag set, so printers can elide it. All other
+/// entries, including non-default `fastmathFlags`, are kept in order.
+static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
+ SmallVector<NamedAttribute, 8> filteredAttrs(
+ llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
+ if (attr.first == "fastmathFlags") {
+ auto defAttr = FMFAttr::get({}, attr.second.getContext());
+ return defAttr != attr.second;
+ }
+ return true;
+ }));
+ return filteredAttrs;
+}
+
+/// Parser for the `custom<LLVMOpAttrs>(attr-dict)` assembly directive: parses
+/// a plain optional attribute dictionary into `result`.
+static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
+ NamedAttrList &result) {
+ return parser.parseOptionalAttrDict(result);
+}
+
+/// Printer for the `custom<LLVMOpAttrs>(attr-dict)` assembly directive: prints
+/// the attribute dictionary, omitting a default-valued `fastmathFlags` entry.
+static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
+ DictionaryAttr attrs) {
+ printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
+}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::CmpOp.
@@ -50,7 +95,7 @@ static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) {
static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) {
p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate())
<< "\" " << op.getOperand(0) << ", " << op.getOperand(1);
- p.printOptionalAttrDict(op.getAttrs(), {"predicate"});
+ p.printOptionalAttrDict(processFMFAttr(op.getAttrs()), {"predicate"});
p << " : " << op.lhs().getType();
}
@@ -771,7 +816,7 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
auto args = op.getOperands().drop_front(isDirect ? 0 : 1);
p << '(' << args << ')';
- p.printOptionalAttrDict(op.getAttrs(), {"callee"});
+ p.printOptionalAttrDict(processFMFAttr(op.getAttrs()), {"callee"});
// Reconstruct the function MLIR function type from operand and result types.
p << " : "
@@ -2041,6 +2086,8 @@ static LogicalResult verify(FenceOp &op) {
//===----------------------------------------------------------------------===//
void LLVMDialect::initialize() {
+ addAttributes<FMFAttr>();
+
// clang-format off
addTypes<LLVMVoidType,
LLVMHalfType,
@@ -2172,3 +2219,87 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
return op->hasTrait<OpTrait::SymbolTable>() &&
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}
+
+/// Returns the fastmath attribute for `flags`; the flags are stored as their
+/// raw 64-bit bitmask value.
+FMFAttr FMFAttr::get(FastmathFlags flags, MLIRContext *context) {
+ return Base::get(context, static_cast<uint64_t>(flags));
+}
+
+/// Converts the stored raw bitmask back into FastmathFlags.
+FastmathFlags FMFAttr::getFlags() const {
+ return static_cast<FastmathFlags>(getImpl()->value);
+}
+
+/// Every individual fastmath flag, in the order FMFAttr::print emits them.
+static constexpr const FastmathFlags FastmathFlagsList[] = {
+ // clang-format off
+ FastmathFlags::nnan,
+ FastmathFlags::ninf,
+ FastmathFlags::nsz,
+ FastmathFlags::arcp,
+ FastmathFlags::contract,
+ FastmathFlags::afn,
+ FastmathFlags::reassoc,
+ FastmathFlags::fast,
+ // clang-format on
+};
+
+/// Prints the attribute as `fastmath<flag, ...>`, listing (comma-separated)
+/// each flag from FastmathFlagsList that is set; `fastmath<>` if none is set.
+void FMFAttr::print(DialectAsmPrinter &printer) const {
+ printer << "fastmath<";
+ auto flags = llvm::make_filter_range(FastmathFlagsList, [&](auto flag) {
+ return bitEnumContains(getFlags(), flag);
+ });
+ llvm::interleaveComma(flags, printer,
+ [&](auto flag) { printer << stringifyEnum(flag); });
+ printer << ">";
+}
+
+/// Parses the attribute body that follows the `fastmath` keyword: `<`, an
+/// optional comma-separated list of flag keywords, then `>`. `<>` yields an
+/// attribute with no flags set. On malformed input or an unknown flag name a
+/// diagnostic is emitted and a null attribute is returned.
+Attribute FMFAttr::parse(DialectAsmParser &parser) {
+ if (failed(parser.parseLess()))
+ return {};
+
+ FastmathFlags flags = {};
+ if (failed(parser.parseOptionalGreater())) {
+ do {
+ // Remember where this flag starts so an "unknown flag" diagnostic points
+ // at the offending keyword instead of the start of the attribute.
+ llvm::SMLoc elemLoc = parser.getCurrentLocation();
+ StringRef elemName;
+ if (failed(parser.parseKeyword(&elemName)))
+ return {};
+
+ auto elem = symbolizeFastmathFlags(elemName);
+ if (!elem) {
+ parser.emitError(elemLoc, "Unknown fastmath flag: ")
+ << elemName;
+ return {};
+ }
+
+ flags = flags | *elem;
+ } while (succeeded(parser.parseOptionalComma()));
+
+ if (failed(parser.parseGreater()))
+ return {};
+ }
+
+ return FMFAttr::get(flags, parser.getBuilder().getContext());
+}
+
+/// Parses an LLVM dialect attribute. `fastmath` is currently the only
+/// supported attribute kind; its body is handled by FMFAttr::parse. Typed
+/// attributes and unknown kinds are rejected with a diagnostic and a null
+/// attribute is returned.
+Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser,
+ Type type) const {
+ if (type) {
+ parser.emitError(parser.getNameLoc(), "unexpected type");
+ return {};
+ }
+ StringRef attrKind;
+ if (parser.parseKeyword(&attrKind))
+ return {};
+
+ if (attrKind == "fastmath")
+ return FMFAttr::parse(parser);
+
+ parser.emitError(parser.getNameLoc(), "Unknown attribute type: ")
+ << attrKind;
+ return {};
+}
+
+/// Prints an LLVM dialect attribute. FMFAttr is the only attribute kind the
+/// dialect registers, so reaching any other kind here is a programming error.
+void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
+ if (auto fmf = attr.dyn_cast<FMFAttr>())
+ fmf.print(os);
+ else
+ llvm_unreachable("Unknown attribute type");
+}
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 5ffb11e76a93..7700867bb461 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -666,6 +666,29 @@ ModuleTranslation::convertOmpOperation(Operation &opInst,
});
}
+/// Converts the fastmath flags reported by `op` into llvm::FastMathFlags,
+/// calling the matching setter for every dialect flag that is set.
+static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
+ using llvmFMF = llvm::FastMathFlags;
+ // Pointer to a FastMathFlags setter taking a bool, e.g. setNoNaNs.
+ using FuncT = void (llvmFMF::*)(bool);
+ const std::pair<FastmathFlags, FuncT> handlers[] = {
+ // clang-format off
+ {FastmathFlags::nnan, &llvmFMF::setNoNaNs},
+ {FastmathFlags::ninf, &llvmFMF::setNoInfs},
+ {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros},
+ {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal},
+ {FastmathFlags::contract, &llvmFMF::setAllowContract},
+ {FastmathFlags::afn, &llvmFMF::setApproxFunc},
+ {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc},
+ {FastmathFlags::fast, &llvmFMF::setFast},
+ // clang-format on
+ };
+ llvm::FastMathFlags ret;
+ auto fmf = op.fastmathFlags();
+ // Enable the LLVM flag for each dialect flag contained in the bitmask.
+ for (auto it : handlers)
+ if (bitEnumContains(fmf, it.first))
+ (ret.*(it.second))(true);
+ return ret;
+}
+
/// Given a single MLIR operation, create the corresponding LLVM IR operation
/// using the `builder`. LLVM IR Builder does not have a generic interface so
/// this has to be a long chain of `if`s calling different functions with a
@@ -680,6 +703,10 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
return position;
};
+ llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder);
+ if (auto fmf = dyn_cast<FastmathFlagsInterface>(opInst))
+ builder.setFastMathFlags(getFastmathFlags(fmf));
+
#include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
// Emit function calls. If the "callee" attribute is present, this is a
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index fc9ff686d78f..05d83810e179 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -387,3 +387,35 @@ llvm.func @useInlineAsm(%arg0: !llvm.i32) {
llvm.return
}
+
+// CHECK-LABEL: @fastmathFlags
+func @fastmathFlags(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32) {
+// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+// CHECK: {{.*}} = llvm.fsub %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+// CHECK: {{.*}} = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+// CHECK: {{.*}} = llvm.fdiv %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+// CHECK: {{.*}} = llvm.frem %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+ %0 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+ %1 = llvm.fsub %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+ %2 = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+ %3 = llvm.fdiv %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+ %4 = llvm.frem %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+
+// CHECK: {{.*}} = llvm.fcmp "oeq" %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+ %5 = llvm.fcmp "oeq" %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+
+// CHECK: {{.*}} = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+ %6 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+
+// CHECK: {{.*}} = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (!llvm.i32) -> !llvm.struct<(i32, double, i32)>
+ %7 = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (!llvm.i32) -> !llvm.struct<(i32, double, i32)>
+
+// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : !llvm.float
+ %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<>} : !llvm.float
+// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+ %9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+
+// CHECK: {{.*}} = llvm.fneg %arg0 : !llvm.float
+ %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : !llvm.float
+ return
+}
diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
index 099b8c96cb16..921c3e87fdae 100644
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1360,6 +1360,50 @@ llvm.func @useInlineAsm(%arg0: !llvm.i32) {
// -----
+llvm.func @fastmathFlagsFunc(!llvm.float) -> !llvm.float
+
+// CHECK-LABEL: @fastmathFlags
+llvm.func @fastmathFlags(%arg0: !llvm.float) {
+// CHECK: {{.*}} = fadd nnan ninf float {{.*}}, {{.*}}
+// CHECK: {{.*}} = fsub nnan ninf float {{.*}}, {{.*}}
+// CHECK: {{.*}} = fmul nnan ninf float {{.*}}, {{.*}}
+// CHECK: {{.*}} = fdiv nnan ninf float {{.*}}, {{.*}}
+// CHECK: {{.*}} = frem nnan ninf float {{.*}}, {{.*}}
+ %0 = llvm.fadd %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+ %1 = llvm.fsub %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+ %2 = llvm.fmul %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+ %3 = llvm.fdiv %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+ %4 = llvm.frem %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+
+// CHECK: {{.*}} = fcmp nnan ninf oeq {{.*}}, {{.*}}
+ %5 = llvm.fcmp "oeq" %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+
+// CHECK: {{.*}} = fneg nnan ninf float {{.*}}
+ %6 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+
+// CHECK: {{.*}} = call float @fastmathFlagsFunc({{.*}})
+// CHECK: {{.*}} = call nnan float @fastmathFlagsFunc({{.*}})
+// CHECK: {{.*}} = call ninf float @fastmathFlagsFunc({{.*}})
+// CHECK: {{.*}} = call nsz float @fastmathFlagsFunc({{.*}})
+// CHECK: {{.*}} = call arcp float @fastmathFlagsFunc({{.*}})
+// CHECK: {{.*}} = call contract float @fastmathFlagsFunc({{.*}})
+// CHECK: {{.*}} = call afn float @fastmathFlagsFunc({{.*}})
+// CHECK: {{.*}} = call reassoc float @fastmathFlagsFunc({{.*}})
+// CHECK: {{.*}} = call fast float @fastmathFlagsFunc({{.*}})
+ %8 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<>} : (!llvm.float) -> (!llvm.float)
+ %9 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<nnan>} : (!llvm.float) -> (!llvm.float)
+ %10 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<ninf>} : (!llvm.float) -> (!llvm.float)
+ %11 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<nsz>} : (!llvm.float) -> (!llvm.float)
+ %12 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<arcp>} : (!llvm.float) -> (!llvm.float)
+ %13 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.float) -> (!llvm.float)
+ %14 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<afn>} : (!llvm.float) -> (!llvm.float)
+ %15 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<reassoc>} : (!llvm.float) -> (!llvm.float)
+ %16 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<fast>} : (!llvm.float) -> (!llvm.float)
+ llvm.return
+}
+
+// -----
+
// CHECK-LABEL: @switch_args
llvm.func @switch_args(%arg0: !llvm.i32) {
%0 = llvm.mlir.constant(5 : i32) : !llvm.i32
More information about the llvm-branch-commits
mailing list