[llvm-branch-commits] [mlir] c1d58c2 - [mlir] Add fastmath flags support to some LLVM dialect ops

Alex Zinenko via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jan 7 05:05:01 PST 2021


Author: Ivan Butygin
Date: 2021-01-07T14:00:09+01:00
New Revision: c1d58c2b0023cd41f0da128f5190fa887d8f6c69

URL: https://github.com/llvm/llvm-project/commit/c1d58c2b0023cd41f0da128f5190fa887d8f6c69
DIFF: https://github.com/llvm/llvm-project/commit/c1d58c2b0023cd41f0da128f5190fa887d8f6c69.diff

LOG: [mlir] Add fastmath flags support to some LLVM dialect ops

Add fastmath enum, attributes to some llvm dialect ops, `FastmathFlagsInterface` op interface, and `translateModuleToLLVMIR` support.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D92485

Added: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
    mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/CMakeLists.txt
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Dialect/LLVMIR/roundtrip.mlir
    mlir/test/Target/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 6166f3632607..29cef3f0032d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -10,6 +10,8 @@ add_public_tablegen_target(MLIRLLVMOpsIncGen)
 
 add_mlir_doc(LLVMOps -gen-op-doc LLVMOps Dialects/)
 
+add_mlir_interface(LLVMOpsInterfaces)
+
 set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
 mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)
 mlir_tablegen(LLVMConversionEnumsToLLVM.inc -gen-enum-to-llvmir-conversions)

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 630bad4914b1..22ff1517f77b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -29,6 +29,7 @@
 #include "llvm/IR/Type.h"
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.h.inc"
+#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.h.inc"
 
 namespace llvm {
 class Type;
@@ -46,8 +47,23 @@ class LLVMDialect;
 namespace detail {
 struct LLVMTypeStorage;
 struct LLVMDialectImpl;
+struct BitmaskEnumStorage;
 } // namespace detail
 
+/// An attribute that specifies LLVM instruction fastmath flags. The flags are
+/// stored as a raw bitmask in detail::BitmaskEnumStorage.
+class FMFAttr : public Attribute::AttrBase<FMFAttr, Attribute,
+                                           detail::BitmaskEnumStorage> {
+public:
+  using Base::Base;
+
+  /// Returns the attribute holding `flags` in `context`.
+  static FMFAttr get(FastmathFlags flags, MLIRContext *context);
+
+  /// Returns the flags held by this attribute.
+  FastmathFlags getFlags() const;
+
+  /// Parses the attribute body that follows the `fastmath` keyword.
+  static Attribute parse(DialectAsmParser &parser);
+
+  /// Prints the attribute in its `fastmath<...>` textual form.
+  void print(DialectAsmPrinter &p) const;
+};
+
 } // namespace LLVM
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 428ca6783afd..53c42540aa48 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -14,10 +14,39 @@
 #define LLVMIR_OPS
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
+def FMFnnan     : BitEnumAttrCase<"nnan", 0x1>;
+def FMFninf     : BitEnumAttrCase<"ninf", 0x2>;
+def FMFnsz      : BitEnumAttrCase<"nsz", 0x4>;
+def FMFarcp     : BitEnumAttrCase<"arcp", 0x8>;
+def FMFcontract : BitEnumAttrCase<"contract", 0x10>;
+def FMFafn      : BitEnumAttrCase<"afn", 0x20>;
+def FMFreassoc  : BitEnumAttrCase<"reassoc", 0x40>;
+def FMFfast     : BitEnumAttrCase<"fast", 0x80>;
+
+def FastmathFlags : BitEnumAttr<
+    "FastmathFlags",
+    "LLVM fastmath flags",
+    [FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast
+    ]> {
+  let cppNamespace = "::mlir::LLVM";
+}
+
+def LLVM_FMFAttr : DialectAttr<
+    LLVM_Dialect,
+    CPred<"$_self.isa<::mlir::LLVM::FMFAttr>()">,
+    "LLVM fastmath flags"> {
+  let storageType = "::mlir::LLVM::FMFAttr";
+  let returnType = "::mlir::LLVM::FastmathFlags";
+  let convertFromStorage = "$_self.getFlags()";
+  let constBuilderCall =
+          "::mlir::LLVM::FMFAttr::get($0, $_builder.getContext())";
+}
+
 class LLVM_Builder<string builder> {
   string llvmBuilder = builder;
 }
@@ -77,29 +106,35 @@ class LLVM_ArithmeticOpBase<Type type, string mnemonic,
     LLVM_Op<mnemonic,
            !listconcat([NoSideEffect, SameOperandsAndResultType], traits)>,
     LLVM_Builder<"$res = builder." # builderFunc # "($lhs, $rhs);"> {
-  let arguments = (ins LLVM_ScalarOrVectorOf<type>:$lhs,
-                   LLVM_ScalarOrVectorOf<type>:$rhs);
+  dag commonArgs = (ins LLVM_ScalarOrVectorOf<type>:$lhs,
+                    LLVM_ScalarOrVectorOf<type>:$rhs);
   let results = (outs LLVM_ScalarOrVectorOf<type>:$res);
   let builders = [LLVM_OneResultOpBuilder];
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($res)";
+  let assemblyFormat = "$lhs `,` $rhs custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
 }
 class LLVM_IntArithmeticOp<string mnemonic, string builderFunc,
                            list<OpTrait> traits = []> :
-    LLVM_ArithmeticOpBase<LLVM_AnyInteger, mnemonic, builderFunc, traits>;
+    LLVM_ArithmeticOpBase<LLVM_AnyInteger, mnemonic, builderFunc, traits> {
+  let arguments = commonArgs;
+}
 class LLVM_FloatArithmeticOp<string mnemonic, string builderFunc,
                              list<OpTrait> traits = []> :
-    LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, builderFunc, traits>;
+    LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, builderFunc,
+    !listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>], traits)> {
+  dag fmfArg = (ins DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
+  let arguments = !con(commonArgs, fmfArg);
+}
 
 // Class for arithmetic unary operations.
-class LLVM_UnaryArithmeticOp<Type type, string mnemonic,
+class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
                              string builderFunc, list<OpTrait> traits = []> :
     LLVM_Op<mnemonic,
-           !listconcat([NoSideEffect, SameOperandsAndResultType], traits)>,
+           !listconcat([NoSideEffect, SameOperandsAndResultType, DeclareOpInterfaceMethods<FastmathFlagsInterface>], traits)>,
     LLVM_Builder<"$res = builder." # builderFunc # "($operand);"> {
-  let arguments = (ins type:$operand);
+  let arguments = (ins type:$operand, DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
   let results = (outs type:$res);
   let builders = [LLVM_OneResultOpBuilder];
-  let assemblyFormat = "$operand attr-dict `:` type($res)";
+  let assemblyFormat = "$operand custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
 }
 
 // Integer binary operations.
@@ -185,20 +220,24 @@ def FCmpPredicate : I64EnumAttr<
   let cppNamespace = "::mlir::LLVM";
 }
 
-// Other integer operations.
-def LLVM_FCmpOp : LLVM_Op<"fcmp", [NoSideEffect]> {
+// Other floating-point operations.
+def LLVM_FCmpOp : LLVM_Op<"fcmp", [
+    NoSideEffect, DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
   let arguments = (ins FCmpPredicate:$predicate,
                    LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$lhs,
-                   LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$rhs);
+                   LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$rhs,
+                   DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
   let results = (outs LLVM_ScalarOrVectorOf<LLVM_i1>:$res);
   let llvmBuilder = [{
     $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
   }];
   let builders = [
-    OpBuilderDAG<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs),
+    OpBuilderDAG<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs,
+                  CArg<"FastmathFlags", "{}">:$fmf),
     [{
       build($_builder, $_state, LLVMIntegerType::get(lhs.getType().getContext(), 1),
-            $_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
+            $_builder.getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs,
+            ::mlir::LLVM::FMFAttr::get(fmf, $_builder.getContext()));
     }]>];
   let parser = [{ return parseCmpOp<FCmpPredicate>(parser, result); }];
   let printer = [{ printFCmpOp(p, *this); }];
@@ -210,8 +249,8 @@ def LLVM_FSubOp : LLVM_FloatArithmeticOp<"fsub", "CreateFSub">;
 def LLVM_FMulOp : LLVM_FloatArithmeticOp<"fmul", "CreateFMul">;
 def LLVM_FDivOp : LLVM_FloatArithmeticOp<"fdiv", "CreateFDiv">;
 def LLVM_FRemOp : LLVM_FloatArithmeticOp<"frem", "CreateFRem">;
-def LLVM_FNegOp : LLVM_UnaryArithmeticOp<LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
-                                         "fneg", "CreateFNeg">;
+def LLVM_FNegOp : LLVM_UnaryFloatArithmeticOp<
+  LLVM_ScalarOrVectorOf<LLVM_AnyFloat>, "fneg", "CreateFNeg">;
 
 // Common code definition that is used to verify and set the alignment attribute
 // of LLVM ops that accept such an attribute.
@@ -405,7 +444,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
   let printer = [{ printLandingpadOp(p, *this); }];
 }
 
-def LLVM_CallOp : LLVM_Op<"call"> {
+def LLVM_CallOp : LLVM_Op<"call",
+                          [DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
   let summary = "Call to an LLVM function.";
   let description = [{
 
@@ -436,7 +476,8 @@ def LLVM_CallOp : LLVM_Op<"call"> {
     ```
   }];
   let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
-                   Variadic<LLVM_Type>);
+                   Variadic<LLVM_Type>,
+                   DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
   let results = (outs Variadic<LLVM_Type>);
   let builders = [
     OpBuilderDAG<(ins "LLVMFuncOp":$func, "ValueRange":$operands,

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td
new file mode 100644
index 000000000000..d31ae81ab2dd
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td
@@ -0,0 +1,30 @@
+//===-- LLVMOpsInterfaces.td - LLVM op interfaces ----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the LLVM IR interfaces definition file.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_OPS_INTERFACES
+#define LLVM_OPS_INTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
+  let description = [{
+    Access to op fastmath flags.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags", "fastmathFlags">,
+  ];
+}
+
+#endif // LLVM_OPS_INTERFACES

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 2ebb24b5aaeb..78927fbcd457 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -828,7 +828,8 @@ class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
     rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
         operation, dstType,
         rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
-        operation.operand1(), operation.operand2());
+        operation.operand1(), operation.operand2(),
+        LLVM::FMFAttr::get({}, operation.getContext()));
     return success();
   }
 };

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 39680a28a33e..5e270881656c 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1836,10 +1836,11 @@ struct AddCFOpLowering : public ConvertOpToLLVMPattern<AddCFOp> {
     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
 
     // Emit IR to add complex numbers.
+    auto fmf = LLVM::FMFAttr::get({}, op.getContext());
     Value real =
-        rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real());
+        rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
     Value imag =
-        rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag());
+        rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
     result.setReal(rewriter, loc, real);
     result.setImaginary(rewriter, loc, imag);
 
@@ -1863,10 +1864,11 @@ struct SubCFOpLowering : public ConvertOpToLLVMPattern<SubCFOp> {
     auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
 
     // Emit IR to substract complex numbers.
+    auto fmf = LLVM::FMFAttr::get({}, op.getContext());
     Value real =
-        rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real());
+        rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
     Value imag =
-        rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag());
+        rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
     result.setReal(rewriter, loc, real);
     result.setImaginary(rewriter, loc, imag);
 
@@ -3155,11 +3157,12 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<CmpFOp> {
                   ConversionPatternRewriter &rewriter) const override {
     CmpFOpAdaptor transformed(operands);
 
+    auto fmf = LLVM::FMFAttr::get({}, cmpfOp.getContext());
     rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
         cmpfOp, typeConverter->convertType(cmpfOp.getResult().getType()),
         rewriter.getI64IntegerAttr(static_cast<int64_t>(
             convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
-        transformed.lhs(), transformed.rhs());
+        transformed.lhs(), transformed.rhs(), fmf);
 
     return success();
   }

diff  --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index 91fb02db9601..c2f88d06062c 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRLLVMIR
 
   DEPENDS
   MLIRLLVMOpsIncGen
+  MLIRLLVMOpsInterfacesIncGen
   MLIROpenMPOpsIncGen
   intrinsics_gen
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 0a9b61628384..b7f7789ee44b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -36,6 +36,51 @@ static constexpr const char kVolatileAttrName[] = "volatile_";
 static constexpr const char kNonTemporalAttrName[] = "nontemporal";
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
+#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+/// Attribute storage whose entire key is a raw 64-bit mask; two storages
+/// compare equal exactly when their masks are equal.
+struct BitmaskEnumStorage : public AttributeStorage {
+  using KeyTy = uint64_t;
+
+  BitmaskEnumStorage(KeyTy bits) : value(bits) {}
+
+  /// Compares this storage with a lookup key.
+  bool operator==(const KeyTy &other) const { return other == value; }
+
+  /// Placement-constructs a new storage instance in memory obtained from the
+  /// attribute storage allocator.
+  static BitmaskEnumStorage *construct(AttributeStorageAllocator &alloc,
+                                       const KeyTy &key) {
+    void *mem = alloc.allocate<BitmaskEnumStorage>();
+    return new (mem) BitmaskEnumStorage(key);
+  }
+
+  /// The raw mask bits.
+  KeyTy value = 0;
+};
+} // namespace detail
+} // namespace LLVM
+} // namespace mlir
+
+/// Returns a copy of `attrs` that omits the `fastmathFlags` entry when that
+/// entry holds the default (empty) flag set, so printers can elide it. All
+/// other attributes are kept in their original order.
+static SmallVector<NamedAttribute, 8>
+processFMFAttr(ArrayRef<NamedAttribute> attrs) {
+  SmallVector<NamedAttribute, 8> kept;
+  for (NamedAttribute attr : attrs) {
+    bool isDefaultFMF =
+        attr.first == "fastmathFlags" &&
+        attr.second == FMFAttr::get({}, attr.second.getContext());
+    if (!isDefaultFMF)
+      kept.push_back(attr);
+  }
+  return kept;
+}
+
+/// Parser for the `custom<LLVMOpAttrs>(attr-dict)` directive: reads an
+/// optional attribute dictionary into `result`.
+static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
+                                    NamedAttrList &result) {
+  if (failed(parser.parseOptionalAttrDict(result)))
+    return failure();
+  return success();
+}
+
+/// Printer for the `custom<LLVMOpAttrs>(attr-dict)` directive: prints the
+/// attribute dictionary with a default-valued `fastmathFlags` entry removed.
+static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
+                             DictionaryAttr attrs) {
+  auto visibleAttrs = processFMFAttr(attrs.getValue());
+  printer.printOptionalAttrDict(visibleAttrs);
+}
 
 //===----------------------------------------------------------------------===//
 // Printing/parsing for LLVM::CmpOp.
@@ -50,7 +95,7 @@ static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) {
 static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) {
   p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate())
     << "\" " << op.getOperand(0) << ", " << op.getOperand(1);
-  p.printOptionalAttrDict(op.getAttrs(), {"predicate"});
+  p.printOptionalAttrDict(processFMFAttr(op.getAttrs()), {"predicate"});
   p << " : " << op.lhs().getType();
 }
 
@@ -771,7 +816,7 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
 
   auto args = op.getOperands().drop_front(isDirect ? 0 : 1);
   p << '(' << args << ')';
-  p.printOptionalAttrDict(op.getAttrs(), {"callee"});
+  p.printOptionalAttrDict(processFMFAttr(op.getAttrs()), {"callee"});
 
   // Reconstruct the function MLIR function type from operand and result types.
   p << " : "
@@ -2041,6 +2086,8 @@ static LogicalResult verify(FenceOp &op) {
 //===----------------------------------------------------------------------===//
 
 void LLVMDialect::initialize() {
+  addAttributes<FMFAttr>();
+
   // clang-format off
   addTypes<LLVMVoidType,
            LLVMHalfType,
@@ -2172,3 +2219,87 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
   return op->hasTrait<OpTrait::SymbolTable>() &&
          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
 }
+
+/// Returns the fastmath attribute holding `flags`, keyed by their raw bits.
+FMFAttr FMFAttr::get(FastmathFlags flags, MLIRContext *context) {
+  uint64_t rawBits = static_cast<uint64_t>(flags);
+  return Base::get(context, rawBits);
+}
+
+/// Converts the stored raw mask back into the FastmathFlags enum.
+FastmathFlags FMFAttr::getFlags() const {
+  uint64_t rawBits = getImpl()->value;
+  return static_cast<FastmathFlags>(rawBits);
+}
+
+/// Every individual fastmath flag, in the order used when printing.
+static constexpr const FastmathFlags FastmathFlagsList[] = {
+    // clang-format off
+    FastmathFlags::nnan,
+    FastmathFlags::ninf,
+    FastmathFlags::nsz,
+    FastmathFlags::arcp,
+    FastmathFlags::contract,
+    FastmathFlags::afn,
+    FastmathFlags::reassoc,
+    FastmathFlags::fast,
+    // clang-format on
+};
+
+/// Prints the attribute as `fastmath<flag, flag, ...>`, listing the set flags
+/// in FastmathFlagsList order; an empty set prints as `fastmath<>`.
+void FMFAttr::print(DialectAsmPrinter &printer) const {
+  FastmathFlags setFlags = getFlags();
+  printer << "fastmath<";
+  bool first = true;
+  for (FastmathFlags flag : FastmathFlagsList) {
+    if (!bitEnumContains(setFlags, flag))
+      continue;
+    if (!first)
+      printer << ", ";
+    printer << stringifyEnum(flag);
+    first = false;
+  }
+  printer << ">";
+}
+
+/// Parses the body of a fastmath attribute that follows the `fastmath`
+/// keyword. The body is `<` followed by either `>` or a comma-separated list
+/// of flag keywords and a closing `>`. On failure a diagnostic is emitted and
+/// a null attribute is returned.
+Attribute FMFAttr::parse(DialectAsmParser &parser) {
+  if (failed(parser.parseLess()))
+    return {};
+
+  FastmathFlags flags = {};
+  if (failed(parser.parseOptionalGreater())) {
+    do {
+      // Record where this flag starts so that an "unknown flag" diagnostic
+      // points at the offending keyword rather than at the start of the
+      // whole attribute.
+      llvm::SMLoc flagLoc = parser.getCurrentLocation();
+      StringRef elemName;
+      if (failed(parser.parseKeyword(&elemName)))
+        return {};
+
+      auto elem = symbolizeFastmathFlags(elemName);
+      if (!elem) {
+        parser.emitError(flagLoc, "Unknown fastmath flag: ") << elemName;
+        return {};
+      }
+
+      flags = flags | *elem;
+    } while (succeeded(parser.parseOptionalComma()));
+
+    if (failed(parser.parseGreater()))
+      return {};
+  }
+
+  return FMFAttr::get(flags, parser.getBuilder().getContext());
+}
+
+/// Parses an LLVM dialect attribute. Only the `fastmath<...>` kind is
+/// recognized. Attributes written with an explicit type are rejected, and an
+/// unrecognized kind keyword produces a diagnostic and a null attribute.
+Attribute LLVMDialect::parseAttribute(DialectAsmParser &parser,
+                                      Type type) const {
+  if (type) {
+    parser.emitError(parser.getNameLoc(), "unexpected type");
+    return {};
+  }
+  StringRef attrKind;
+  if (parser.parseKeyword(&attrKind))
+    return {};
+
+  if (attrKind == "fastmath")
+    return FMFAttr::parse(parser);
+
+  parser.emitError(parser.getNameLoc(), "Unknown attribute type: ")
+      << attrKind;
+  return {};
+}
+
+/// Prints an LLVM dialect attribute. FMFAttr is the only kind handled; any
+/// other attribute reaching here is treated as unreachable.
+void LLVMDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
+  auto fmf = attr.dyn_cast<FMFAttr>();
+  if (!fmf)
+    llvm_unreachable("Unknown attribute type");
+  fmf.print(os);
+}

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 5ffb11e76a93..7700867bb461 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -666,6 +666,29 @@ ModuleTranslation::convertOmpOperation(Operation &opInst,
       });
 }
 
+/// Translates the MLIR fastmath flags reported by `op` into
+/// llvm::FastMathFlags. Each MLIR flag bit that is present invokes the
+/// corresponding llvm::FastMathFlags setter with `true`, in table order.
+static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
+  using Setter = void (llvm::FastMathFlags::*)(bool);
+  struct FlagMapping {
+    FastmathFlags mlirFlag;
+    Setter setter;
+  };
+  const FlagMapping mappings[] = {
+      // clang-format off
+      {FastmathFlags::nnan,     &llvm::FastMathFlags::setNoNaNs},
+      {FastmathFlags::ninf,     &llvm::FastMathFlags::setNoInfs},
+      {FastmathFlags::nsz,      &llvm::FastMathFlags::setNoSignedZeros},
+      {FastmathFlags::arcp,     &llvm::FastMathFlags::setAllowReciprocal},
+      {FastmathFlags::contract, &llvm::FastMathFlags::setAllowContract},
+      {FastmathFlags::afn,      &llvm::FastMathFlags::setApproxFunc},
+      {FastmathFlags::reassoc,  &llvm::FastMathFlags::setAllowReassoc},
+      {FastmathFlags::fast,     &llvm::FastMathFlags::setFast},
+      // clang-format on
+  };
+  FastmathFlags requested = op.fastmathFlags();
+  llvm::FastMathFlags result;
+  for (const FlagMapping &mapping : mappings) {
+    if (bitEnumContains(requested, mapping.mlirFlag))
+      (result.*mapping.setter)(true);
+  }
+  return result;
+}
+
 /// Given a single MLIR operation, create the corresponding LLVM IR operation
 /// using the `builder`.  LLVM IR Builder does not have a generic interface so
/// this has to be a long chain of `if`s calling different functions with a
@@ -680,6 +703,10 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
     return position;
   };
 
+  llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder);
+  if (auto fmf = dyn_cast<FastmathFlagsInterface>(opInst))
+    builder.setFastMathFlags(getFastmathFlags(fmf));
+
 #include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
 
   // Emit function calls.  If the "callee" attribute is present, this is a

diff  --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index fc9ff686d78f..05d83810e179 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -387,3 +387,35 @@ llvm.func @useInlineAsm(%arg0: !llvm.i32) {
 
   llvm.return
 }
+
+// CHECK-LABEL: @fastmathFlags
+func @fastmathFlags(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32) {
+// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+// CHECK: {{.*}} = llvm.fsub %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+// CHECK: {{.*}} = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+// CHECK: {{.*}} = llvm.fdiv %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+// CHECK: {{.*}} = llvm.frem %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+  %0 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+  %1 = llvm.fsub %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+  %2 = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+  %3 = llvm.fdiv %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+  %4 = llvm.frem %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+
+// CHECK: {{.*}} = llvm.fcmp "oeq" %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+  %5 = llvm.fcmp "oeq" %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+
+// CHECK: {{.*}} = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+  %6 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<fast>} : !llvm.float
+
+// CHECK: {{.*}} = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (!llvm.i32) -> !llvm.struct<(i32, double, i32)>
+  %7 = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (!llvm.i32) -> !llvm.struct<(i32, double, i32)>
+
+// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : !llvm.float
+  %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<>} : !llvm.float
+// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+  %9 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+
+// CHECK: {{.*}} = llvm.fneg %arg0 : !llvm.float
+  %10 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<>} : !llvm.float
+  return
+}

diff  --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
index 099b8c96cb16..921c3e87fdae 100644
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1360,6 +1360,50 @@ llvm.func @useInlineAsm(%arg0: !llvm.i32) {
 
 // -----
 
+llvm.func @fastmathFlagsFunc(!llvm.float) -> !llvm.float
+
+// CHECK-LABEL: @fastmathFlags
+llvm.func @fastmathFlags(%arg0: !llvm.float) {
+// CHECK: {{.*}} = fadd nnan ninf float {{.*}}, {{.*}}
+// CHECK: {{.*}} = fsub nnan ninf float {{.*}}, {{.*}}
+// CHECK: {{.*}} = fmul nnan ninf float {{.*}}, {{.*}}
+// CHECK: {{.*}} = fdiv nnan ninf float {{.*}}, {{.*}}
+// CHECK: {{.*}} = frem nnan ninf float {{.*}}, {{.*}}
+  %0 = llvm.fadd %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+  %1 = llvm.fsub %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+  %2 = llvm.fmul %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+  %3 = llvm.fdiv %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+  %4 = llvm.frem %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+
+// CHECK: {{.*}} = fcmp nnan ninf oeq {{.*}}, {{.*}}
+  %5 = llvm.fcmp "oeq" %arg0, %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+
+// CHECK: {{.*}} = fneg nnan ninf float {{.*}}
+  %6 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : !llvm.float
+
+// CHECK: {{.*}} = call float @fastmathFlagsFunc({{.*}})
+// CHECK: {{.*}} = call nnan float @fastmathFlagsFunc({{.*}})
+// CHECK: {{.*}} = call ninf float @fastmathFlagsFunc({{.*}})
+// CHECK: {{.*}} = call nsz float @fastmathFlagsFunc({{.*}})
+// CHECK: {{.*}} = call arcp float @fastmathFlagsFunc({{.*}})
+// CHECK: {{.*}} = call contract float @fastmathFlagsFunc({{.*}})
+// CHECK: {{.*}} = call afn float @fastmathFlagsFunc({{.*}})
+// CHECK: {{.*}} = call reassoc float @fastmathFlagsFunc({{.*}})
+// CHECK: {{.*}} = call fast float @fastmathFlagsFunc({{.*}})
+  %8 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<>} : (!llvm.float) -> (!llvm.float)
+  %9 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<nnan>} : (!llvm.float) -> (!llvm.float)
+  %10 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<ninf>} : (!llvm.float) -> (!llvm.float)
+  %11 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<nsz>} : (!llvm.float) -> (!llvm.float)
+  %12 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<arcp>} : (!llvm.float) -> (!llvm.float)
+  %13 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.float) -> (!llvm.float)
+  %14 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<afn>} : (!llvm.float) -> (!llvm.float)
+  %15 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<reassoc>} : (!llvm.float) -> (!llvm.float)
+  %16 = llvm.call @fastmathFlagsFunc(%arg0) {fastmathFlags = #llvm.fastmath<fast>} : (!llvm.float) -> (!llvm.float)
+  llvm.return
+}
+
+// -----
+
 // CHECK-LABEL: @switch_args
 llvm.func @switch_args(%arg0: !llvm.i32) {
   %0 = llvm.mlir.constant(5 : i32) : !llvm.i32


        


More information about the llvm-branch-commits mailing list