[Mlir-commits] [mlir] b56e65d - [mlir][arith] Initial support for fastmath flag attributes in the Arithmetic dialect (v2)

Slava Zakharin llvmlistbot at llvm.org
Wed Oct 26 11:56:45 PDT 2022


Author: Jeremy Furtek
Date: 2022-10-26T11:56:16-07:00
New Revision: b56e65d31825fe4a1ae02fdcbad58bb7993d63a7

URL: https://github.com/llvm/llvm-project/commit/b56e65d31825fe4a1ae02fdcbad58bb7993d63a7
DIFF: https://github.com/llvm/llvm-project/commit/b56e65d31825fe4a1ae02fdcbad58bb7993d63a7.diff

LOG: [mlir][arith] Initial support for fastmath flag attributes in the Arithmetic dialect (v2)

This diff adds initial (partial) support for "fastmath" attributes for floating
point operations in the arithmetic dialect. The "fastmath" attributes are
implemented using a default-valued bit enum. The defined flags currently mirror
the fastmath flags in the LLVM dialect (and in LLVM itself). Extending the
set of flags (if necessary) is left as a future task.

In this diff:
- Definition of FastMathAttr as a custom attribute in the Arithmetic dialect
  that inherits from the EnumAttr class.
- Definition of ArithFastMathInterface, which is an interface that is
  implemented by operations that have an arith::fastmath attribute.
- Declaration of a default-valued fastmath attribute for unary and (some) binary
  floating point operations in the Arithmetic dialect.
- Conversion code to lower arithmetic fastmath flags to LLVM fastmath flags.

NOT in this diff (but planned or currently in progress):
- Documentation of flag meanings
- Addition of FastMathAttr attributes to other dialects that might lower to the
  Arithmetic dialect (e.g. Math and Complex)
- Folding/rewrite implementations that are enabled by fastmath flags
- Specification of fastmath values from Python bindings (pending other
  in-progress diffs)

Reviewed By: mehdi_amini, vzakhari

Differential Revision: https://reviews.llvm.org/D126305

Added: 
    mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
    mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
    mlir/include/mlir/Dialect/Arith/IR/Arith.h
    mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td
    mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
    mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
    mlir/lib/Conversion/LLVMCommon/Pattern.cpp
    mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
    mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
    mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/lib/Dialect/Arith/IR/CMakeLists.txt
    mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
    mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
    mlir/test/CAPI/ir.c
    mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
    mlir/test/Dialect/Arith/ops.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 9de4334a9d70b..90e62aa11787c 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -22,8 +22,10 @@ namespace detail {
 /// and given operands.
 LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
                               ValueRange operands,
+                              ArrayRef<NamedAttribute> targetAttrs,
                               LLVMTypeConverter &typeConverter,
                               ConversionPatternRewriter &rewriter);
+
 } // namespace detail
 } // namespace LLVM
 
@@ -197,7 +199,7 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
-                                         adaptor.getOperands(),
+                                         adaptor.getOperands(), op->getAttrs(),
                                          *this->getTypeConverter(), rewriter);
   }
 };

diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index cae1b1cf3892d..d115c2d2f58fe 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -56,14 +56,34 @@ LogicalResult handleMultidimensionalVectors(
 
 LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
                                     ValueRange operands,
+                                    ArrayRef<NamedAttribute> targetAttrs,
                                     LLVMTypeConverter &typeConverter,
                                     ConversionPatternRewriter &rewriter);
 } // namespace detail
 } // namespace LLVM
 
+// Default attribute converter: passes all source attributes through as-is.
+template <typename SourceOp, typename TargetOp>
+class AttrConvertPassThrough {
+public:
+  explicit AttrConvertPassThrough(SourceOp srcOp)
+      : srcAttrs(srcOp->getAttrs()) {}
+  ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
+
+private:
+  // Non-owning view of the source op's attribute list.
+  ArrayRef<NamedAttribute> srcAttrs;
+};
+
 /// Basic lowering implementation to rewrite Ops with just one result to the
 /// LLVM Dialect. This supports higher-dimensional vector types.
-template <typename SourceOp, typename TargetOp>
+/// The AttrConvert template template parameter should be a template class
+/// with SourceOp and TargetOp type parameters, a constructor that takes
+/// a SourceOp instance, and a getAttrs() method that returns
+/// ArrayRef<NamedAttribute>.
+template <typename SourceOp, typename TargetOp,
+          template <typename, typename> typename AttrConvert =
+              AttrConvertPassThrough>
 class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
 public:
   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
@@ -75,9 +95,12 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
     static_assert(
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
         "expected single result op");
+    // Determine attributes for the target op
+    AttrConvert<SourceOp, TargetOp> attrConvert(op);
+
     return LLVM::detail::vectorOneToOneRewrite(
         op, TargetOp::getOperationName(), adaptor.getOperands(),
-        *this->getTypeConverter(), rewriter);
+        attrConvert.getAttrs(), *this->getTypeConverter(), rewriter);
   }
 };
 } // namespace mlir

diff  --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 0ecd293d7778c..3e14e4d346753 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -17,6 +17,7 @@
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
+#include "llvm/ADT/StringExtras.h"
 
 //===----------------------------------------------------------------------===//
 // ArithDialect
@@ -29,6 +30,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/ArithOpsEnums.h.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Arith Interfaces
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.h.inc"
 
 //===----------------------------------------------------------------------===//
 // Arith Dialect Operations

diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index aaaa3b0d5b52f..13d252cf056e5 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -23,6 +23,7 @@ def Arith_Dialect : Dialect {
   }];
 
   let hasConstantMaterializer = 1;
+  let useDefaultAttributePrinterParser = 1;
 }
 
 // The predicate indicates the type of the comparison to perform:
@@ -92,4 +93,32 @@ def AtomicRMWKindAttr : I64EnumAttr<
   let cppNamespace = "::mlir::arith";
 }
 
+def FASTMATH_NONE            : I32BitEnumAttrCaseNone<"none"      >;
+def FASTMATH_REASSOC         : I32BitEnumAttrCaseBit<"reassoc",  0>;
+def FASTMATH_NO_NANS         : I32BitEnumAttrCaseBit<"nnan",     1>;
+def FASTMATH_NO_INFS         : I32BitEnumAttrCaseBit<"ninf",     2>;
+def FASTMATH_NO_SIGNED_ZEROS : I32BitEnumAttrCaseBit<"nsz",      3>;
+def FASTMATH_ALLOW_RECIP     : I32BitEnumAttrCaseBit<"arcp",     4>;
+def FASTMATH_ALLOW_CONTRACT  : I32BitEnumAttrCaseBit<"contract", 5>;
+def FASTMATH_APPROX_FUNC     : I32BitEnumAttrCaseBit<"afn",      6>;
+def FASTMATH_FAST            : I32BitEnumAttrCaseGroup<
+    "fast",
+    [
+      FASTMATH_REASSOC,         FASTMATH_NO_NANS,     FASTMATH_NO_INFS,
+      FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP, FASTMATH_ALLOW_CONTRACT,
+      FASTMATH_APPROX_FUNC]>;
+
+def FastMathFlags : I32BitEnumAttr<
+    "FastMathFlags",
+    "Floating point fast math flags",
+    [
+      FASTMATH_NONE,           FASTMATH_REASSOC,         FASTMATH_NO_NANS,
+      FASTMATH_NO_INFS,        FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP,
+      FASTMATH_ALLOW_CONTRACT, FASTMATH_APPROX_FUNC,     FASTMATH_FAST]> {
+  let separator = ",";
+  let cppNamespace = "::mlir::arith";
+  let genSpecializedAttr = 0;
+  let printBitEnumPrimaryGroups = 1;
+}
+
 #endif // ARITH_BASE

diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 692338eb8370e..f12a1a33f6912 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -10,6 +10,7 @@
 #define ARITH_OPS
 
 include "mlir/Dialect/Arith/IR/ArithBase.td"
+include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -17,6 +18,12 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/EnumAttr.td"
+
+def Arith_FastMathAttr :
+    EnumAttr<Arith_Dialect, FastMathFlags, "fastmath"> {
+  let assemblyFormat = "`<` $value `>`";
+}
 
 // Base class for Arith dialect ops. Ops in this dialect have no side
 // effects and can be applied element-wise to vectors and tensors.
@@ -58,15 +65,27 @@ class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
 
 // Base class for floating point unary operations.
 class Arith_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
-    Arith_UnaryOp<mnemonic, traits>,
-    Arguments<(ins FloatLike:$operand)>,
-    Results<(outs FloatLike:$result)>;
+    Arith_UnaryOp<mnemonic,
+      !listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>],
+                  traits)>,
+    Arguments<(ins FloatLike:$operand,
+      DefaultValuedAttr<Arith_FastMathAttr, "FastMathFlags::none">:$fastmath)>,
+    Results<(outs FloatLike:$result)> {
+  let assemblyFormat = [{ $operand custom<ArithFastMathAttr>($fastmath)
+                          attr-dict `:` type($result) }];
+}
 
 // Base class for floating point binary operations.
 class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
-    Arith_BinaryOp<mnemonic, traits>,
-    Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>,
-    Results<(outs FloatLike:$result)>;
+    Arith_BinaryOp<mnemonic,
+      !listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>],
+                  traits)>,
+    Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
+      DefaultValuedAttr<Arith_FastMathAttr, "FastMathFlags::none">:$fastmath)>,
+    Results<(outs FloatLike:$result)> {
+  let assemblyFormat = [{ $lhs `,` $rhs `` custom<ArithFastMathAttr>($fastmath)
+                          attr-dict `:` type($result) }];
+}
 
 // Base class for arithmetic cast operations. Requires a single operand and
 // result. If either is a shaped type, then the other must be of the same shape.

diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
new file mode 100644
index 0000000000000..acaecf6f409dc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -0,0 +1,52 @@
+//===- ArithOpsInterfaces.td - Arith op interfaces ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the operation interfaces of the Arith dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ARITH_OPS_INTERFACES
+#define ARITH_OPS_INTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
+  let description = [{
+    Access to operation fastmath flags.
+  }];
+
+  let cppNamespace = "::mlir::arith";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns a FastMathFlagsAttr attribute for the operation",
+      /*returnType=*/  "FastMathFlagsAttr",
+      /*methodName=*/  "getFastMathFlagsAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getFastmathAttr();
+      }]
+      >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the FastMathFlagsAttr attribute
+                         for the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getFastMathAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "fastmath";
+      }]
+      >
+
+  ];
+}
+
+#endif // ARITH_OPS_INTERFACES

diff  --git a/mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt
index 93ff719e677be..5cdde2edd50f1 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt
@@ -1,5 +1,14 @@
 set(LLVM_TARGET_DEFINITIONS ArithOps.td)
 mlir_tablegen(ArithOpsEnums.h.inc -gen-enum-decls)
 mlir_tablegen(ArithOpsEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(ArithOpsAttributes.h.inc -gen-attrdef-decls
+              -attrdefs-dialect=arith)
+mlir_tablegen(ArithOpsAttributes.cpp.inc -gen-attrdef-defs
+              -attrdefs-dialect=arith)
 add_mlir_dialect(ArithOps arith)
 add_mlir_doc(ArithOps ArithOps Dialects/ -gen-dialect-doc)
+
+set(LLVM_TARGET_DEFINITIONS ArithOpsInterfaces.td)
+mlir_tablegen(ArithOpsInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(ArithOpsInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRArithOpsInterfacesIncGen)

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index ca38f072d9df6..64e5e0abfd763 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -48,7 +48,6 @@ class SmartMutex;
 namespace mlir {
 namespace LLVM {
 class LLVMDialect;
-class LoopOptionsAttrBuilder;
 
 namespace detail {
 struct LLVMTypeStorage;

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td
index 671e19a36d5e9..d9c1a41bd2b68 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td
@@ -23,8 +23,28 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
   let cppNamespace = "::mlir::LLVM";
 
   let methods = [
-    InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags",
-                    "getFastmathFlags">,
+    InterfaceMethod<
+      /*desc=*/        "Returns a FastmathFlagsAttr attribute for the operation",
+      /*returnType=*/  "FastmathFlagsAttr",
+      /*methodName=*/  "getFastmathAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getFastmathFlagsAttr();
+      }]
+      >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the FastmathFlagsAttr attribute
+                         for the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getFastmathAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "fastmathFlags";
+      }]
+      >
   ];
 }
 

diff  --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 1610e5cee8b7d..cbaa67c21532c 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -24,16 +24,93 @@ using namespace mlir;
 
 namespace {
 
+// Map arithmetic fastmath flags onto the equivalent LLVM dialect flags.
+static LLVM::FastmathFlags
+convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
+  const std::pair<arith::FastMathFlags, LLVM::FastmathFlags> flagMap[] = {
+      {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan},
+      {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf},
+      {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz},
+      {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp},
+      {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
+      {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
+      {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
+  // Start from the empty set and OR in each LLVM flag whose arith bit is set.
+  LLVM::FastmathFlags result{};
+  for (const auto &[arithFlag, llvmFlag] : flagMap)
+    if (bitEnumContainsAny(arithFMF, arithFlag))
+      result = result | llvmFlag;
+  return result;
+}
+
+// Build the LLVM dialect fastmath attribute equivalent to an arith one.
+static LLVM::FastmathFlagsAttr
+convertArithFastMathAttr(arith::FastMathFlagsAttr fmfAttr) {
+  LLVM::FastmathFlags llvmFlags =
+      convertArithFastMathFlagsToLLVM(fmfAttr.getValue());
+  return LLVM::FastmathFlagsAttr::get(fmfAttr.getContext(), llvmFlags);
+}
+
+// Attribute converter that copies the source operation attributes, removes the
+// arith fastmath attribute, and, if one was present, replaces it with the
+// equivalent LLVM fastmath attribute under the target op's attribute name.
+template <typename SourceOp, typename TargetOp>
+class AttrConvertFastMath {
+public:
+  explicit AttrConvertFastMath(SourceOp srcOp)
+      : convertedAttr(srcOp->getAttrs()) {
+    // Remove the source fastmath attribute. erase() returns the removed
+    // attribute, or null if the source op carried no fastmath attribute.
+    Attribute removed = convertedAttr.erase(SourceOp::getFastMathAttrName());
+    if (auto arithFMFAttr =
+            removed.dyn_cast_or_null<arith::FastMathFlagsAttr>()) {
+      // Re-attach the flags in LLVM form, keyed by the target op's name.
+      convertedAttr.set(TargetOp::getFastmathAttrName(),
+                        convertArithFastMathAttr(arithFMFAttr));
+    }
+  }
+
+  ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+
+private:
+  // Source attributes with the fastmath attribute converted to LLVM form.
+  NamedAttrList convertedAttr;
+};
+
+// Attribute converter that copies the source operation attributes and drops
+// the arith fastmath attribute. This may be useful for target operations that
+// do not require the fastmath attribute, or for targets that do not yet
+// support the LLVM fastmath attribute.
+template <typename SourceOp, typename TargetOp>
+class AttrDropFastMath {
+public:
+  explicit AttrDropFastMath(SourceOp srcOp)
+      : convertedAttr(srcOp->getAttrs()) {
+    // Remove the source fastmath attribute. NamedAttrList::erase() simply
+    // returns null when the attribute is absent. TargetOp is unused because
+    // nothing is attached in place of the dropped attribute.
+    convertedAttr.erase(SourceOp::getFastMathAttrName());
+  }
+
+  ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+
+private:
+  // Source attributes minus the fastmath attribute.
+  NamedAttrList convertedAttr;
+};
+
 //===----------------------------------------------------------------------===//
 // Straightforward Op Lowerings
 //===----------------------------------------------------------------------===//
 
-using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>;
+using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
+                                                  AttrConvertFastMath>;
 using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
 using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
 using BitcastOpLowering =
     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
-using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>;
+using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
+                                                  AttrConvertFastMath>;
 using DivSIOpLowering =
     VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
 using DivUIOpLowering =
@@ -47,23 +124,29 @@ using FPToSIOpLowering =
     VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
 using FPToUIOpLowering =
     VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
+// TODO: Add LLVM intrinsic support for fastmath
 using MaxFOpLowering =
-    VectorConvertToLLVMPattern<arith::MaxFOp, LLVM::MaxNumOp>;
+    VectorConvertToLLVMPattern<arith::MaxFOp, LLVM::MaxNumOp, AttrDropFastMath>;
 using MaxSIOpLowering =
     VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
 using MaxUIOpLowering =
     VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
+// TODO: Add LLVM intrinsic support for fastmath
 using MinFOpLowering =
-    VectorConvertToLLVMPattern<arith::MinFOp, LLVM::MinNumOp>;
+    VectorConvertToLLVMPattern<arith::MinFOp, LLVM::MinNumOp, AttrDropFastMath>;
 using MinSIOpLowering =
     VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
 using MinUIOpLowering =
     VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
-using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>;
+using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
+                                                  AttrConvertFastMath>;
 using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
-using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>;
+using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
+                                                  AttrConvertFastMath>;
 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
-using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>;
+// TODO: Add LLVM intrinsic support for fastmath
+using RemFOpLowering =
+    VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp, AttrDropFastMath>;
 using RemSIOpLowering =
     VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
 using RemUIOpLowering =
@@ -77,7 +160,8 @@ using ShRUIOpLowering =
     VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
 using SIToFPOpLowering =
     VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
-using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>;
+using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
+                                                  AttrConvertFastMath>;
 using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
 using TruncFOpLowering =
     VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
@@ -153,7 +237,7 @@ LogicalResult
 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                                     ConversionPatternRewriter &rewriter) const {
   return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
-                                       adaptor.getOperands(),
+                                       adaptor.getOperands(), op->getAttrs(),
                                        *getTypeConverter(), rewriter);
 }
 

diff  --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index f37d47d744cfe..14f32d66d2447 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -90,7 +90,7 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
                   ConversionPatternRewriter &rewriter) const override {
     return LLVM::detail::oneToOneRewrite(
         op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
-        *getTypeConverter(), rewriter);
+        op->getAttrs(), *getTypeConverter(), rewriter);
   }
 };
 

diff  --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 96d83eec18056..8413dcfc83958 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -308,7 +308,8 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
 /// and given operands.
 LogicalResult LLVM::detail::oneToOneRewrite(
     Operation *op, StringRef targetOp, ValueRange operands,
-    LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
+    ArrayRef<NamedAttribute> targetAttrs, LLVMTypeConverter &typeConverter,
+    ConversionPatternRewriter &rewriter) {
   unsigned numResults = op->getNumResults();
 
   SmallVector<Type> resultTypes;
@@ -322,7 +323,7 @@ LogicalResult LLVM::detail::oneToOneRewrite(
   // Create the operation through state since we don't know its C++ type.
   Operation *newOp =
       rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
-                      resultTypes, op->getAttrs());
+                      resultTypes, targetAttrs);
 
   // If the operation produced 0 or 1 result, return them immediately.
   if (numResults == 0)

diff  --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index 2f0091f99dd3b..e95c702d79f38 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -105,7 +105,8 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
 
 LogicalResult LLVM::detail::vectorOneToOneRewrite(
     Operation *op, StringRef targetOp, ValueRange operands,
-    LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
+    ArrayRef<NamedAttribute> targetAttrs, LLVMTypeConverter &typeConverter,
+    ConversionPatternRewriter &rewriter) {
   assert(!operands.empty());
 
   // Cannot convert ops if their operands are not of LLVM type.
@@ -114,13 +115,14 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
 
   auto llvmNDVectorTy = operands[0].getType();
   if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
-    return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
+    return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
+                           rewriter);
 
-  auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy,
-                                            ValueRange operands) {
+  auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy,
+                                                         ValueRange operands) {
     return rewriter
         .create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
-                llvm1DVectorTy, op->getAttrs())
+                llvm1DVectorTy, targetAttrs)
         ->getResult(0);
   };
 

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 2cb5a553634bb..a30ba2eff6412 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -215,17 +215,21 @@ def OrOfExtSI :
 //===----------------------------------------------------------------------===//
 
 // mulf(negf(x), negf(y)) -> mulf(x,y)
+// (retain fastmath flags of original mulf)
 def MulFOfNegF :
-    Pat<(Arith_MulFOp (Arith_NegFOp $x), (Arith_NegFOp $y)), (Arith_MulFOp $x, $y),
-      [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
+    Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
+        (Arith_MulFOp $x, $y, $fmf),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
 
 //===----------------------------------------------------------------------===//
 // DivFOp
 //===----------------------------------------------------------------------===//
 
 // divf(negf(x), negf(y)) -> divf(x,y)
+// (retain fastmath flags of original divf)
 def DivFOfNegF :
-    Pat<(Arith_DivFOp (Arith_NegFOp $x), (Arith_NegFOp $y)), (Arith_DivFOp $x, $y),
-      [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
+    Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
+        (Arith_DivFOp $x, $y, $fmf),
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
 
 #endif // ARITH_PATTERNS

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
index 35cb7da2d48a4..b15f7f05b8531 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
@@ -8,12 +8,17 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::arith;
 
 #include "mlir/Dialect/Arith/IR/ArithOpsDialect.cpp.inc"
+#include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.cpp.inc"
 
 namespace {
 /// This class defines the interface for handling inlining for arithmetic
@@ -34,6 +39,10 @@ void arith::ArithDialect::initialize() {
 #define GET_OP_LIST
 #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
       >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.cpp.inc"
+      >();
   addInterfaces<ArithInlinerInterface>();
 }
 

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d1d03a549092d..5693ad1c0e8d1 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -23,6 +23,31 @@
 using namespace mlir;
 using namespace mlir::arith;
 
+//===----------------------------------------------------------------------===//
+// Floating point op parse/print helpers
+//===----------------------------------------------------------------------===//
+/// Parses an optional `fastmath<...>` clause into `attr`. When the keyword is
+/// absent, `attr` is left unset so the op's default fastmath value applies.
+/// Fails only if the keyword is present but the rest of the clause is invalid.
+static ParseResult parseArithFastMathAttr(OpAsmParser &parser,
+                                          Attribute &attr) {
+  // No fastmath attribute mnemonic present - defer attribute creation.
+  if (failed(parser.parseOptionalKeyword(FastMathFlagsAttr::getMnemonic())))
+    return success();
+  // A null attribute from FastMathFlagsAttr::parse signals a parse failure.
+  attr = FastMathFlagsAttr::parse(parser, Type{});
+  return success(static_cast<bool>(attr));
+}
+
+/// Prints ` fastmath<...>` unless `fmAttr` is null or holds `none`.
+static void printArithFastMathAttr(OpAsmPrinter &printer, Operation *op,
+                                   FastMathFlagsAttr fmAttr) {
+  if (!fmAttr || fmAttr.getValue() == FastMathFlags::none)
+    return; // fastmath<none> is elided from the custom assembly form.
+  printer << " " << FastMathFlagsAttr::getMnemonic();
+  fmAttr.print(printer);
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern helpers
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
index ffd7ee3279555..0de17bbfbd12a 100644
--- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRArithDialect
 
   DEPENDS
   MLIRArithOpsIncGen
+  MLIRArithOpsInterfacesIncGen
 
   LINK_LIBS PUBLIC
   MLIRDialect

diff  --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index ea86819faf1c0..1aee27560ea35 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -51,13 +51,13 @@ struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
     Type elementType = getSrcVectorElementType<OpTy>(op);
     unsigned bitwidth = elementType.getIntOrFloatBitWidth();
     if (bitwidth == 32)
-      return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
-                                           adaptor.getOperands(),
-                                           getTypeConverter(), rewriter);
+      return LLVM::detail::oneToOneRewrite(
+          op, Intr32OpTy::getOperationName(), adaptor.getOperands(),
+          op->getAttrs(), getTypeConverter(), rewriter);
     if (bitwidth == 64)
-      return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
-                                           adaptor.getOperands(),
-                                           getTypeConverter(), rewriter);
+      return LLVM::detail::oneToOneRewrite(
+          op, Intr64OpTy::getOperationName(), adaptor.getOperands(),
+          op->getAttrs(), getTypeConverter(), rewriter);
     return rewriter.notifyMatchFailure(
         op, "expected 'src' to be either f32 or f64");
   }

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index e81c052e6de0e..1f89a55ee363e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -160,9 +160,9 @@ static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
       // clang-format on
   };
   llvm::FastMathFlags ret;
-  auto fmf = op.getFastmathFlags();
+  ::mlir::LLVM::FastmathFlags fmfMlir = op.getFastmathAttr().getValue();
   for (auto it : handlers)
-    if (bitEnumContainsAll(fmf, it.first))
+    if (bitEnumContainsAll(fmfMlir, it.first))
       (ret.*(it.second))(true);
   return ret;
 }

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index e1d6133bd0652..308a3d87a8d1d 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -309,7 +309,7 @@ int collectStats(MlirOperation operation) {
   // clang-format off
   // CHECK-LABEL: @stats
   // CHECK: Number of operations: 12
-  // CHECK: Number of attributes: 4
+  // CHECK: Number of attributes: 5
   // CHECK: Number of blocks: 3
   // CHECK: Number of regions: 3
   // CHECK: Number of values: 9

diff  --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 05706d89de742..81f402195fb4f 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -448,3 +448,20 @@ func.func @minmaxf(%arg0 : f32, %arg1 : f32) -> f32 {
   %1 = arith.maxf %arg0, %arg1 : f32
   return %0 : f32
 }
+
+// -----
+
+// CHECK-LABEL: @fastmath
+func.func @fastmath(%arg0: f32, %arg1: f32) {
+// CHECK: {{.*}} = llvm.fadd %arg0, %arg1  {fastmathFlags = #llvm.fastmath<fast>} : f32
+  %0 = arith.addf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: {{.*}} = llvm.fmul %arg0, %arg1  {fastmathFlags = #llvm.fastmath<fast>} : f32
+  %1 = arith.mulf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: {{.*}} = llvm.fneg %arg0  {fastmathFlags = #llvm.fastmath<fast>} : f32
+  %2 = arith.negf %arg0 fastmath<fast> : f32
+// CHECK: {{.*}} = llvm.fadd %arg0, %arg1  : f32
+  %3 = arith.addf %arg0, %arg1 fastmath<none> : f32
+// CHECK: {{.*}} = llvm.fadd %arg0, %arg1  {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
+  %4 = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
+  return
+}

diff  --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index c34850ff6e305..9d5c686d73b50 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1031,3 +1031,27 @@ func.func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
   %min_unsigned = arith.minui %i1, %i2 : i32
   return
 }
+
+// CHECK-LABEL: @fastmath
+func.func @fastmath(%arg0: f32, %arg1: f32) {
+// CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath<fast> : f32
+  %0 = arith.addf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: {{.*}} = arith.subf %arg0, %arg1 fastmath<fast> : f32
+  %1 = arith.subf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: {{.*}} = arith.mulf %arg0, %arg1 fastmath<fast> : f32
+  %2 = arith.mulf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: {{.*}} = arith.divf %arg0, %arg1 fastmath<fast> : f32
+  %3 = arith.divf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: {{.*}} = arith.remf %arg0, %arg1 fastmath<fast> : f32
+  %4 = arith.remf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: {{.*}} = arith.negf %arg0 fastmath<fast> : f32
+  %5 = arith.negf %arg0 fastmath<fast> : f32
+// CHECK: {{.*}} = arith.addf %arg0, %arg1 : f32
+  %6 = arith.addf %arg0, %arg1 fastmath<none> : f32
+// CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
+  %7 = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
+// CHECK: {{.*}} = arith.mulf %arg0, %arg1 fastmath<fast> : f32
+  %8 = arith.mulf %arg0, %arg1 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f32
+
+  return
+}

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index e6ab837141f1f..9200c6117a493 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -270,7 +270,7 @@ func.func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->
 // -----
 
 func.func @generic(%arg0: memref<?x?xf32>) {
-  // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32}}
+  // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) {fastmath = #arith.fastmath<none>} : (f32, f32) -> f32}}
   linalg.generic  {
     indexing_maps = [ affine_map<(i, j) -> (i, j)> ],
     iterator_types = ["parallel", "parallel"]}

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index fcfaf86f1229e..1d0f50fc6552f 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1001,7 +1001,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
   for (FormatElement *param : dir->getArguments()) {
     if (auto *attr = dyn_cast<AttributeVariable>(param)) {
       const NamedAttribute *var = attr->getVar();
-      if (var->attr.isOptional())
+      if (var->attr.isOptional() || var->attr.hasDefaultValue())
         body << llvm::formatv("    if ({0}Attr)\n  ", var->name);
 
       body << llvm::formatv("    result.addAttribute(\"{0}\", {0}Attr);\n",


        


More information about the Mlir-commits mailing list