[Mlir-commits] [mlir] 589764a - [mlir][math] Initial support for fastmath flag attributes for Math dialect.

Slava Zakharin llvmlistbot at llvm.org
Fri Nov 4 10:44:27 PDT 2022


Author: Slava Zakharin
Date: 2022-11-04T10:41:56-07:00
New Revision: 589764a382642ae8374cfe21a6b10f839c8596da

URL: https://github.com/llvm/llvm-project/commit/589764a382642ae8374cfe21a6b10f839c8596da
DIFF: https://github.com/llvm/llvm-project/commit/589764a382642ae8374cfe21a6b10f839c8596da.diff

LOG: [mlir][math] Initial support for fastmath flag attributes for Math dialect.

Added arith::FastMathAttr and ArithFastMathInterface support for Math dialect
floating point operations.

This change-set creates ArithCommon conversion utils that currently
provide classes and methods to aid with arith::FastMathAttr conversion
into LLVM::FastmathFlags. These utils are used in the ArithToLLVM and
MathToLLVM converters, but may eventually be used by other converters
that need to convert fastmath attributes.

Since Math dialect operations use arith::FastMathAttr, MathOps.td now
has to include enum and attribute definitions from the Arith dialect.
To minimize the amount of TD code included from the Arith dialect,
I moved the FastMathAttr definition into ArithBase.td.

Differential Revision: https://reviews.llvm.org/D136312

Added: 
    mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
    mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
    mlir/lib/Conversion/ArithCommon/CMakeLists.txt

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/include/mlir/Dialect/Math/IR/Math.h
    mlir/include/mlir/Dialect/Math/IR/MathOps.td
    mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
    mlir/lib/Conversion/ArithToLLVM/CMakeLists.txt
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Conversion/MathToLLVM/CMakeLists.txt
    mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
    mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
    mlir/test/Dialect/Math/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
new file mode 100644
index 0000000000000..f27f7bb5975ec
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -0,0 +1,81 @@
+//===- AttrToLLVMConverter.h - Arith attributes conversion ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H
+#define MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+
+//===----------------------------------------------------------------------===//
+// Support for converting Arith FastMathFlags to LLVM FastmathFlags
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace arith {
+// Map each arith fastmath flag set in arithFMF to the matching LLVM flag.
+LLVM::FastmathFlags
+convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);
+
+// Build an LLVM fastmath attribute equivalent to fmfAttr (same MLIRContext).
+LLVM::FastmathFlagsAttr
+convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
+
+// Attribute converter: copies all source-op attributes, removes the arith
+// fastmath attribute and, if it was present, re-adds it under the target op's
+// fastmath attribute name as the equivalent LLVM fastmath attribute.
+template <typename SourceOp, typename TargetOp>
+class AttrConvertFastMathToLLVM {
+public:
+  AttrConvertFastMathToLLVM(SourceOp srcOp) {
+    // Copy the source attributes.
+    convertedAttr = NamedAttrList{srcOp->getAttrs()};
+    // Get the name of the arith fastmath attribute.
+    llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
+    // Remove the source fastmath attribute (null if absent or another kind).
+    auto arithFMFAttr =
+        convertedAttr.erase(arithFMFAttrName)
+            .template dyn_cast_or_null<arith::FastMathFlagsAttr>();
+    if (arithFMFAttr) { // Only re-add when an arith fastmath attr was found.
+      llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
+      convertedAttr.set(targetAttrName,
+                        convertArithFastMathAttrToLLVM(arithFMFAttr));
+    }
+  }
+
+  ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+
+private:
+  NamedAttrList convertedAttr;
+};
+
+// Attribute converter that populates a NamedAttrList by removing the fastmath
+// attribute from the source operation attributes. This may be useful for
+// target operations that do not require the fastmath attribute, or for targets
+// that do not yet support the LLVM fastmath attribute.
+template <typename SourceOp, typename TargetOp>
+class AttrDropFastMath {
+public:
+  AttrDropFastMath(SourceOp srcOp) {
+    // Copy the source attributes.
+    convertedAttr = NamedAttrList{srcOp->getAttrs()};
+    // Get the name of the arith fastmath attribute.
+    llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
+    // Remove the source fastmath attribute.
+    convertedAttr.erase(arithFMFAttrName);
+  }
+
+  ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+
+private:
+  NamedAttrList convertedAttr;
+};
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H

diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index 13d252cf056e5..78fd7bdf012f8 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -121,4 +121,9 @@ def FastMathFlags : I32BitEnumAttr<
   let printBitEnumPrimaryGroups = 1;
 }
 
+def Arith_FastMathAttr : // Moved from ArithOps.td so Math ops can reuse it.
+    EnumAttr<Arith_Dialect, FastMathFlags, "fastmath"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 #endif // ARITH_BASE

diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 6ca74392f0565..3d6cef9705ebe 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -20,11 +20,6 @@ include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/EnumAttr.td"
 
-def Arith_FastMathAttr :
-    EnumAttr<Arith_Dialect, FastMathFlags, "fastmath"> {
-  let assemblyFormat = "`<` $value `>`";
-}
-
 // Base class for Arith dialect ops. Ops in this dialect have no side
 // effects and can be applied element-wise to vectors and tensors.
 class Arith_Op<string mnemonic, list<Trait> traits = []> :

diff  --git a/mlir/include/mlir/Dialect/Math/IR/Math.h b/mlir/include/mlir/Dialect/Math/IR/Math.h
index 6af358bf57b37..98416d1c9abdf 100644
--- a/mlir/include/mlir/Dialect/Math/IR/Math.h
+++ b/mlir/include/mlir/Dialect/Math/IR/Math.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_MATH_IR_MATH_H_
 #define MLIR_DIALECT_MATH_IR_MATH_H_
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"

diff  --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 99e209000c0f5..a5b28bd0891c5 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -9,6 +9,8 @@
 #ifndef MATH_OPS
 #define MATH_OPS
 
+include "mlir/Dialect/Arith/IR/ArithBase.td"
+include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
 include "mlir/Dialect/Math/IR/MathBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/VectorInterfaces.td"
@@ -36,11 +38,16 @@ class Math_IntegerUnaryOp<string mnemonic, list<Trait> traits = []> :
 // operand and result of the same type. This type can be a floating point type,
 // vector or tensor thereof.
 class Math_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
-    Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
-  let arguments = (ins FloatLike:$operand);
+    Math_Op<mnemonic,
+        traits # [SameOperandsAndResultType,
+                  DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+  let arguments = (ins FloatLike:$operand,
+      DefaultValuedAttr<Arith_FastMathAttr, // Defaults to 'none'.
+                        "::mlir::arith::FastMathFlags::none">:$fastmath);
   let results = (outs FloatLike:$result);
 
-  let assemblyFormat = "$operand attr-dict `:` type($result)";
+  let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
+                          attr-dict `:` type($result) }];
 }
 
 // Base class for binary math operations on integer types. Require two
@@ -58,22 +65,32 @@ class Math_IntegerBinaryOp<string mnemonic, list<Trait> traits = []> :
 // operands and one result of the same type. This type can be a floating point
 // type, vector or tensor thereof.
 class Math_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
-    Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
-  let arguments = (ins FloatLike:$lhs, FloatLike:$rhs);
+    Math_Op<mnemonic,
+        traits # [SameOperandsAndResultType,
+                  DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+  let arguments = (ins FloatLike:$lhs, FloatLike:$rhs,
+      DefaultValuedAttr<Arith_FastMathAttr, // Defaults to 'none'.
+                        "::mlir::arith::FastMathFlags::none">:$fastmath);
   let results = (outs FloatLike:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
+  let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
+                          attr-dict `:` type($result) }];
 }
 
 // Base class for floating point ternary operations. Require three operands and
 // one result of the same type. This type can be a floating point type, vector
 // or tensor thereof.
 class Math_FloatTernaryOp<string mnemonic, list<Trait> traits = []> :
-    Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
-  let arguments = (ins FloatLike:$a, FloatLike:$b, FloatLike:$c);
+    Math_Op<mnemonic,
+        traits # [SameOperandsAndResultType,
+                  DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+  let arguments = (ins FloatLike:$a, FloatLike:$b, FloatLike:$c,
+      DefaultValuedAttr<Arith_FastMathAttr, // Defaults to 'none'.
+                        "::mlir::arith::FastMathFlags::none">:$fastmath);
   let results = (outs FloatLike:$result);
 
-  let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($result)";
+  let assemblyFormat = [{ $a `,` $b `,` $c (`fastmath` `` $fastmath^)?
+                          attr-dict `:` type($result) }];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
new file mode 100644
index 0000000000000..8c5d76f9f2d72
--- /dev/null
+++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
@@ -0,0 +1,38 @@
+//===- AttrToLLVMConverter.cpp - Arith attributes conversion to LLVM ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+
+using namespace mlir;
+
+// Map arithmetic fastmath enum values to LLVMIR enum values, flag by flag.
+LLVM::FastmathFlags
+mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
+  LLVM::FastmathFlags llvmFMF{}; // Value-initialized: no bits set.
+  const std::pair<arith::FastMathFlags, LLVM::FastmathFlags> flags[] = {
+      {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan},
+      {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf},
+      {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz},
+      {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp},
+      {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
+      {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
+      {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
+  for (auto fmfMap : flags) { // OR in each LLVM flag whose arith bit is set.
+    if (bitEnumContainsAny(arithFMF, fmfMap.first))
+      llvmFMF = llvmFMF | fmfMap.second;
+  }
+  return llvmFMF;
+}
+
+// Build an LLVM fastmath attribute equivalent to fmfAttr (same MLIRContext).
+LLVM::FastmathFlagsAttr
+mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
+  arith::FastMathFlags arithFMF = fmfAttr.getValue();
+  return LLVM::FastmathFlagsAttr::get(
+      fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
+}

diff  --git a/mlir/lib/Conversion/ArithCommon/CMakeLists.txt b/mlir/lib/Conversion/ArithCommon/CMakeLists.txt
new file mode 100644
index 0000000000000..888c45f2e52fe
--- /dev/null
+++ b/mlir/lib/Conversion/ArithCommon/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_mlir_conversion_library(MLIRArithAttrToLLVMConversion
+  AttrToLLVMConverter.cpp # Used by MLIRArithToLLVM and MLIRMathToLLVM.
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRLLVMDialect
+  )

diff  --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index f2814b56d4d34..1409b7fe1bca8 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -24,93 +25,20 @@ using namespace mlir;
 
 namespace {
 
-// Map arithmetic fastmath enum values to LLVMIR enum values.
-static LLVM::FastmathFlags
-convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
-  LLVM::FastmathFlags llvmFMF{};
-  const std::pair<arith::FastMathFlags, LLVM::FastmathFlags> flags[] = {
-      {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan},
-      {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf},
-      {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz},
-      {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp},
-      {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
-      {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
-      {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
-  for (auto fmfMap : flags) {
-    if (bitEnumContainsAny(arithFMF, fmfMap.first))
-      llvmFMF = llvmFMF | fmfMap.second;
-  }
-  return llvmFMF;
-}
-
-// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
-static LLVM::FastmathFlagsAttr
-convertArithFastMathAttr(arith::FastMathFlagsAttr fmfAttr) {
-  arith::FastMathFlags arithFMF = fmfAttr.getValue();
-  return LLVM::FastmathFlagsAttr::get(
-      fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
-}
-
-// Attribute converter that populates a NamedAttrList by removing the fastmath
-// attribute from the source operation attributes, and replacing it with an
-// equivalent LLVM fastmath attribute.
-template <typename SourceOp, typename TargetOp>
-class AttrConvertFastMath {
-public:
-  AttrConvertFastMath(SourceOp srcOp) {
-    // Copy the source attributes.
-    convertedAttr = NamedAttrList{srcOp->getAttrs()};
-    // Get the name of the arith fastmath attribute.
-    llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
-    // Remove the source fastmath attribute.
-    auto arithFMFAttr = convertedAttr.erase(arithFMFAttrName)
-                            .template dyn_cast_or_null<arith::FastMathFlagsAttr>();
-    if (arithFMFAttr) {
-      llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
-      convertedAttr.set(targetAttrName, convertArithFastMathAttr(arithFMFAttr));
-    }
-  }
-
-  ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
-
-private:
-  NamedAttrList convertedAttr;
-};
-
-// Attribute converter that populates a NamedAttrList by removing the fastmath
-// attribute from the source operation attributes. This may be useful for
-// target operations that do not require the fastmath attribute, or for targets
-// that do not yet support the LLVM fastmath attribute.
-template <typename SourceOp, typename TargetOp>
-class AttrDropFastMath {
-public:
-  AttrDropFastMath(SourceOp srcOp) {
-    // Copy the source attributes.
-    convertedAttr = NamedAttrList{srcOp->getAttrs()};
-    // Get the name of the arith fastmath attribute.
-    llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
-    // Remove the source fastmath attribute.
-    convertedAttr.erase(arithFMFAttrName);
-  }
-
-  ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
-
-private:
-  NamedAttrList convertedAttr;
-};
-
 //===----------------------------------------------------------------------===//
 // Straightforward Op Lowerings
 //===----------------------------------------------------------------------===//
 
-using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
-                                                  AttrConvertFastMath>;
+using AddFOpLowering = // Converts fastmath to the LLVM fastmath attribute.
+    VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
+                               arith::AttrConvertFastMathToLLVM>;
 using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
 using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
 using BitcastOpLowering =
     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
-using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
-                                                  AttrConvertFastMath>;
+using DivFOpLowering = // Converts fastmath to the LLVM fastmath attribute.
+    VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
+                               arith::AttrConvertFastMathToLLVM>;
 using DivSIOpLowering =
     VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
 using DivUIOpLowering =
@@ -125,28 +53,30 @@ using FPToSIOpLowering =
 using FPToUIOpLowering =
     VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
 // TODO: Add LLVM intrinsic support for fastmath
-using MaxFOpLowering =
-    VectorConvertToLLVMPattern<arith::MaxFOp, LLVM::MaxNumOp, AttrDropFastMath>;
+using MaxFOpLowering = VectorConvertToLLVMPattern<arith::MaxFOp, LLVM::MaxNumOp,
+                                                  arith::AttrDropFastMath>;
 using MaxSIOpLowering =
     VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
 using MaxUIOpLowering =
     VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
 // TODO: Add LLVM intrinsic support for fastmath
-using MinFOpLowering =
-    VectorConvertToLLVMPattern<arith::MinFOp, LLVM::MinNumOp, AttrDropFastMath>;
+using MinFOpLowering = VectorConvertToLLVMPattern<arith::MinFOp, LLVM::MinNumOp,
+                                                  arith::AttrDropFastMath>;
 using MinSIOpLowering =
     VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
 using MinUIOpLowering =
     VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
-using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
-                                                  AttrConvertFastMath>;
+using MulFOpLowering = // Converts fastmath to the LLVM fastmath attribute.
+    VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
+                               arith::AttrConvertFastMathToLLVM>;
 using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
-using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
-                                                  AttrConvertFastMath>;
+using NegFOpLowering = // Converts fastmath to the LLVM fastmath attribute.
+    VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
+                               arith::AttrConvertFastMathToLLVM>;
 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
 // TODO: Add LLVM intrinsic support for fastmath
-using RemFOpLowering =
-    VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp, AttrDropFastMath>;
+using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
+                                                  arith::AttrDropFastMath>;
 using RemSIOpLowering =
     VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
 using RemUIOpLowering =
@@ -160,8 +90,9 @@ using ShRUIOpLowering =
     VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
 using SIToFPOpLowering =
     VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
-using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
-                                                  AttrConvertFastMath>;
+using SubFOpLowering = // Converts fastmath to the LLVM fastmath attribute.
+    VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
+                               arith::AttrConvertFastMathToLLVM>;
 using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
 using TruncFOpLowering =
     VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;

diff  --git a/mlir/lib/Conversion/ArithToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArithToLLVM/CMakeLists.txt
index 45ee8708aa155..bb1fa2fbb6577 100644
--- a/mlir/lib/Conversion/ArithToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToLLVM/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRArithToLLVM
   Core
 
   LINK_LIBS PUBLIC
+  MLIRArithAttrToLLVMConversion
   MLIRArithDialect
   MLIRLLVMCommonConversion
   MLIRLLVMDialect

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index a65814d36b5b4..62dae19a31344 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_subdirectory(AffineToStandard)
 add_subdirectory(AMDGPUToROCDL)
+add_subdirectory(ArithCommon)
 add_subdirectory(ArithToLLVM)
 add_subdirectory(ArithToSPIRV)
 add_subdirectory(ArmNeon2dToIntr)

diff  --git a/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt
index a6e6b4f56d37e..97393fc849691 100644
--- a/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMathToLLVM
   Core
 
   LINK_LIBS PUBLIC
+  MLIRArithAttrToLLVMConversion
   MLIRLLVMCommonConversion
   MLIRLLVMDialect
   MLIRMathDialect

diff  --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index b67a86f443b5c..b5ce019b20832 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
@@ -24,31 +25,39 @@ namespace mlir {
 using namespace mlir;
 
 namespace {
-using AbsFOpLowering = VectorConvertToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
-using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
+// Converts the arith fastmath attribute into the LLVM fastmath attribute.
+template <typename SourceOp, typename TargetOp>
+using ConvertFastMath = arith::AttrConvertFastMathToLLVM<SourceOp, TargetOp>;
+// VectorConvertToLLVMPattern that also converts the fastmath attribute.
+template <typename SourceOp, typename TargetOp>
+using ConvertFMFMathToLLVMPattern =
+    VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath>;
+
+using AbsFOpLowering = ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
+using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
 using CopySignOpLowering =
-    VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
-using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
+    ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
+using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
 using CtPopFOpLowering =
     VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
-using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
-using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
+using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
+using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
 using FloorOpLowering =
-    VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
-using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
+    ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
+using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
 using Log10OpLowering =
-    VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
-using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
-using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
-using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
+    ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
+using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
+using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
+using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
 using RoundEvenOpLowering =
-    VectorConvertToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
+    ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
 using RoundOpLowering =
-    VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
-using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
-using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
+    ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
+using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
+using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
 using FTruncOpLowering =
-    VectorConvertToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
+    ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
 
 // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
 template <typename MathOp, typename LLVMOp>
@@ -113,6 +122,8 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
     auto resultType = op.getResult().getType();
     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+    ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
+    ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
 
     if (!operandType.isa<LLVM::LLVMArrayType>()) {
       LLVM::ConstantOp one;
@@ -123,8 +134,10 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
       } else {
         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
       }
-      auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand());
-      rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
+      auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand(),
+                                              expAttrs.getAttrs());
+      rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
+          op, operandType, ValueRange{exp, one}, subAttrs.getAttrs());
       return success();
     }
 
@@ -142,9 +155,10 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
               floatOne);
           auto one =
               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
-          auto exp =
-              rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
-          return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
+          auto exp = rewriter.create<LLVM::ExpOp>(
+              loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs());
+          return rewriter.create<LLVM::FSubOp>(
+              loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs());
         },
         rewriter);
   }
@@ -166,6 +180,8 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
     auto resultType = op.getResult().getType();
     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+    ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
+    ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
 
     if (!operandType.isa<LLVM::LLVMArrayType>()) {
       LLVM::ConstantOp one =
@@ -176,9 +192,11 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
                                            floatOne))
               : rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
 
-      auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
-                                               adaptor.getOperand());
-      rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
+      auto add = rewriter.create<LLVM::FAddOp>(
+          loc, operandType, ValueRange{one, adaptor.getOperand()},
+          addAttrs.getAttrs());
+      rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, ValueRange{add},
+                                               logAttrs.getAttrs());
       return success();
     }
 
@@ -196,9 +214,11 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
               floatOne);
           auto one =
               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
-          auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
-                                                   operands[0]);
-          return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
+          auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy,
+                                                   ValueRange{one, operands[0]},
+                                                   addAttrs.getAttrs());
+          return rewriter.create<LLVM::LogOp>(
+              loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs());
         },
         rewriter);
   }
@@ -220,6 +240,8 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
     auto resultType = op.getResult().getType();
     auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
     auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
+    ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
+    ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
 
     if (!operandType.isa<LLVM::LLVMArrayType>()) {
       LLVM::ConstantOp one;
@@ -230,8 +252,10 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
       } else {
         one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
       }
-      auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand());
-      rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
+      auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand(),
+                                                sqrtAttrs.getAttrs());
+      rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
+          op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
       return success();
     }
 
@@ -249,9 +273,10 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
               floatOne);
           auto one =
               rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
-          auto sqrt =
-              rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
-          return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
+          auto sqrt = rewriter.create<LLVM::SqrtOp>(
+              loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs());
+          return rewriter.create<LLVM::FDivOp>(
+              loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs());
         },
         rewriter);
   }

diff  --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index bcdbad1709e93..8c7f031cb97d9 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -36,6 +36,18 @@ func.func @log1p(%arg0 : f32) {
 
 // -----
 
+// CHECK-LABEL: func @log1p_fmf(
+// CHECK-SAME: f32
+func.func @log1p_fmf(%arg0 : f32) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+  // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %arg0 {fastmathFlags = #llvm.fastmath<fast>} : f32
+  // CHECK: %[[LOG:.*]] = llvm.intr.log(%[[ADD]]) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
+  %0 = math.log1p %arg0 fastmath<fast> : f32
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @log1p_2dvector(
 func.func @log1p_2dvector(%arg0 : vector<4x3xf32>) {
   // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
@@ -49,6 +61,19 @@ func.func @log1p_2dvector(%arg0 : vector<4x3xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @log1p_2dvector_fmf(
+func.func @log1p_2dvector_fmf(%arg0 : vector<4x3xf32>) {
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf32>) : vector<3xf32>
+  // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[EXTRACT]] {fastmathFlags = #llvm.fastmath<fast>} : vector<3xf32>
+  // CHECK: %[[LOG:.*]] = llvm.intr.log(%[[ADD]]) {fastmathFlags = #llvm.fastmath<fast>} : (vector<3xf32>) -> vector<3xf32>
+  // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[LOG]], %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
+  %0 = math.log1p %arg0 fastmath<fast> : vector<4x3xf32>
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @expm1(
 // CHECK-SAME: f32
 func.func @expm1(%arg0 : f32) {
@@ -61,6 +86,42 @@ func.func @expm1(%arg0 : f32) {
 
 // -----
 
+// CHECK-LABEL: func @expm1_fmf(
+// CHECK-SAME: f32
+func.func @expm1_fmf(%arg0 : f32) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+  // CHECK: %[[EXP:.*]] = llvm.intr.exp(%arg0) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
+  // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] {fastmathFlags = #llvm.fastmath<fast>} : f32
+  %0 = math.expm1 %arg0 fastmath<fast> : f32
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @expm1_vector(
+// CHECK-SAME: vector<4xf32>
+func.func @expm1_vector(%arg0 : vector<4xf32>) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32>
+  // CHECK: %[[EXP:.*]] = llvm.intr.exp(%arg0) : (vector<4xf32>) -> vector<4xf32>
+  // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : vector<4xf32>
+  %0 = math.expm1 %arg0 : vector<4xf32>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @expm1_vector_fmf(
+// CHECK-SAME: vector<4xf32>
+func.func @expm1_vector_fmf(%arg0 : vector<4xf32>) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32>
+  // CHECK: %[[EXP:.*]] = llvm.intr.exp(%arg0) {fastmathFlags = #llvm.fastmath<fast>} : (vector<4xf32>) -> vector<4xf32>
+  // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] {fastmathFlags = #llvm.fastmath<fast>} : vector<4xf32>
+  %0 = math.expm1 %arg0 fastmath<fast> : vector<4xf32>
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @rsqrt(
 // CHECK-SAME: f32
 func.func @rsqrt(%arg0 : f32) {
@@ -148,6 +209,18 @@ func.func @rsqrt_double(%arg0 : f64) {
 
 // -----
 
+// CHECK-LABEL: func @rsqrt_double_fmf(
+// CHECK-SAME: f64
+func.func @rsqrt_double_fmf(%arg0 : f64) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f64) : f64
+  // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%arg0) {fastmathFlags = #llvm.fastmath<fast>} : (f64) -> f64
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] {fastmathFlags = #llvm.fastmath<fast>} : f64
+  %0 = math.rsqrt %arg0 fastmath<fast> : f64
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @rsqrt_vector(
 // CHECK-SAME: vector<4xf32>
 func.func @rsqrt_vector(%arg0 : vector<4xf32>) {
@@ -160,6 +233,18 @@ func.func @rsqrt_vector(%arg0 : vector<4xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @rsqrt_vector_fmf(
+// CHECK-SAME: vector<4xf32>
+func.func @rsqrt_vector_fmf(%arg0 : vector<4xf32>) {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32>
+  // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%arg0) {fastmathFlags = #llvm.fastmath<fast>} : (vector<4xf32>) -> vector<4xf32>
+  // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] {fastmathFlags = #llvm.fastmath<fast>} : vector<4xf32>
+  %0 = math.rsqrt %arg0 fastmath<fast> : vector<4xf32>
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @rsqrt_multidim_vector(
 func.func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
   // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
@@ -210,3 +295,19 @@ func.func @trunc(%arg0 : f32) {
   %0 = math.trunc %arg0 : f32
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @fastmath(
+// CHECK-SAME: f32
+func.func @fastmath(%arg0 : f32, %arg1 : vector<4xf32>) {
+  // CHECK: llvm.intr.trunc(%arg0) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
+  %0 = math.trunc %arg0 fastmath<fast> : f32
+  // CHECK: llvm.intr.pow(%arg0, %arg0) {fastmathFlags = #llvm.fastmath<afn>} : (f32, f32) -> f32
+  %1 = math.powf %arg0, %arg0 fastmath<afn> : f32
+  // CHECK: llvm.intr.sqrt(%arg0) : (f32) -> f32
+  %2 = math.sqrt %arg0 fastmath<none> : f32
+  // CHECK: llvm.intr.fma(%arg0, %arg0, %arg0) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32, f32) -> f32
+  %3 = math.fma %arg0, %arg0, %arg0 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f32
+  func.return
+}

diff  --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index d984cbb66f8c2..7e121f80dd79e 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -269,3 +269,17 @@ func.func @trunc(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
   %2 = math.trunc %t : tensor<4x4x?xf32>
   return
 }
+
+// CHECK-LABEL: func @fastmath(
+// CHECK-SAME:             %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func.func @fastmath(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+  // CHECK: %{{.*}} = math.trunc %[[F]] fastmath<fast> : f32
+  %0 = math.trunc %f fastmath<fast> : f32
+  // CHECK: %{{.*}} = math.powf %[[V]], %[[V]] fastmath<fast> : vector<4xf32>
+  %1 = math.powf %v, %v fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : vector<4xf32>
+  // CHECK: %{{.*}} = math.fma %[[T]], %[[T]], %[[T]] : tensor<4x4x?xf32>
+  %2 = math.fma %t, %t, %t fastmath<none> : tensor<4x4x?xf32>
+  // CHECK: %{{.*}} = math.absf %[[F]] fastmath<ninf> : f32
+  %3 = math.absf %f fastmath<ninf> : f32
+  return
+}


        


More information about the Mlir-commits mailing list