[flang-commits] [flang] [mlir] [RFC][mlir] Conditional support for fast-math attributes. (PR #125620)

via flang-commits flang-commits at lists.llvm.org
Mon Feb 3 18:23:43 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-arith

Author: Slava Zakharin (vzakhari)

<details>
<summary>Changes</summary>

This patch suggests changes for operations that support
arith::ArithFastMathInterface/LLVM::FastmathFlagsInterface.
Some of these operations may carry fast-math flags other than `none`
only if they operate on floating-point values.

This is inspired by https://llvm.org/docs/LangRef.html#fastmath-return-types
and my goal to add fast-math support for `arith.select` operation
that may produce results of any type.

The changes add new isArithFastMathApplicable/isFastmathApplicable
methods to the above interfaces that tell whether an operation
supporting the interface may have non-none fast-math flags.

The LLVM dialect's isFastmathApplicable implementation is based on https://github.com/llvm/llvm-project/blob/bac62ee5b473e70981a6bd9759ec316315fca07d/llvm/include/llvm/IR/Operator.h#L380
The Arith dialect's isArithFastMathApplicable implementation is more
relaxed, because it has to support custom MLIR types. This is the area
where improvements are needed (see the TODO comments), and I would
appreciate feedback here.
The HLFIR dialect is another example where conditional fast-math
support can currently be applied.


---

Patch is 32.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125620.diff


17 Files Affected:

- (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+26) 
- (modified) flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h (+5) 
- (modified) flang/include/flang/Optimizer/HLFIR/HLFIROps.td (+54) 
- (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+1-3) 
- (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+10-2) 
- (modified) flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp (+17) 
- (modified) flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir (+1-1) 
- (modified) flang/test/Fir/tbaa.fir (+3-3) 
- (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+14) 
- (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td (+49-20) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td (+50-20) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+11) 
- (modified) mlir/lib/Dialect/Arith/IR/ArithDialect.cpp (+47) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+40) 
- (modified) mlir/test/Dialect/LLVMIR/inlining.mlir (+6-6) 
- (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+16-10) 
- (modified) mlir/test/Target/LLVMIR/omptarget-depend.mlir (+3-3) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 8dbc9df9f553de..497d099fbe9366 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2494,6 +2494,21 @@ def fir_CallOp : fir_Op<"call",
                          llvm::cast<mlir::SymbolRefAttr>(callee));
       setOperand(0, llvm::cast<mlir::Value>(callee));
     }
+
+    /// Always allow FastMathFlags for fir.call's.
+    /// It is required to be able to propagate the call site's
+    /// FastMathFlags to the operations resulting from inlining
+    /// (if any) of a fir.call (see SimplifyIntrinsics pass).
+    /// We could analyze the arguments' data types to see if there are
+    /// any floating point types, but this is unreliable. For example,
+    /// the runtime calls mostly take !fir.box<none> arguments,
+    /// and tracking them to their definitions may not be easy.
+    /// TODO: this should be restricted to fir.runtime calls,
+    /// because FastMathFlags for the user calls must come
+    /// from the function body, not the call site.
+    bool isArithFastMathApplicable() {
+      return true;
+    }
   }];
 }
 
@@ -2672,6 +2687,15 @@ def fir_CmpcOp : fir_Op<"cmpc",
     }
 
     static mlir::arith::CmpFPredicate getPredicateByName(llvm::StringRef name);
+
+    /// Always allow FastMathFlags on fir.cmpc.
+    /// It does not produce a floating point result, but
+    /// LLVM is currently relying on fast-math flags attached
+    /// to floating point comparison.
+    /// This can be removed whenever LLVM stops doing it.
+    bool isArithFastMathApplicable() {
+      return true;
+    }
   }];
 }
 
@@ -2735,6 +2759,8 @@ def fir_ConvertOp : fir_SimpleOneResultOp<"convert", [NoMemoryEffect]> {
     static bool isPointerCompatible(mlir::Type ty);
     static bool canBeConverted(mlir::Type inType, mlir::Type outType);
     static bool areVectorsCompatible(mlir::Type inTy, mlir::Type outTy);
+
+    // FIXME: fir.convert should support ArithFastMathInterface.
   }];
   let hasCanonicalizer = 1;
 }
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
index 15296aa7e8c75c..0e6d536d9bde5d 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
@@ -139,6 +139,11 @@ bool mayHaveAllocatableComponent(mlir::Type ty);
 /// Scalar integer or a sequence of integers (via boxed array or expr).
 bool isFortranIntegerScalarOrArrayObject(mlir::Type type);
 
+/// Return true iff FastMathFlagsAttr is applicable
+/// to the given HLFIR dialect operation that supports
+/// ArithFastMathInterface.
+bool isArithFastMathApplicable(mlir::Operation *op);
+
 } // namespace hlfir
 
 #endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index f4102538efc3c2..f90ef8ed019ceb 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -434,6 +434,12 @@ def hlfir_MaxvalOp : hlfir_Op<"maxval", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
@@ -461,6 +467,12 @@ def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
@@ -487,6 +499,12 @@ def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
@@ -513,6 +531,12 @@ def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
@@ -539,6 +563,12 @@ def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_SetLengthOp : hlfir_Op<"set_length",
@@ -604,6 +634,12 @@ def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_DotProductOp : hlfir_Op<"dot_product",
@@ -628,6 +664,12 @@ def hlfir_DotProductOp : hlfir_Op<"dot_product",
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_MatmulOp : hlfir_Op<"matmul",
@@ -655,6 +697,12 @@ def hlfir_MatmulOp : hlfir_Op<"matmul",
   let hasCanonicalizeMethod = 1;
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_TransposeOp : hlfir_Op<"transpose",
@@ -697,6 +745,12 @@ def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() {
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_CShiftOp
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index d9779c46ae79e7..d749fc9c633d7c 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -786,9 +786,7 @@ mlir::Value fir::FirOpBuilder::genAbsentOp(mlir::Location loc,
 
 void fir::FirOpBuilder::setCommonAttributes(mlir::Operation *op) const {
   auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
-  if (fmi) {
-    // TODO: use fmi.setFastMathFlagsAttr() after D137114 is merged.
-    //       For now set the attribute by the name.
+  if (fmi && fmi.isArithFastMathApplicable()) {
     llvm::StringRef arithFMFAttrName = fmi.getFastMathAttrName();
     if (fastMathFlags != mlir::arith::FastMathFlags::none)
       op->setAttr(arithFMFAttrName, mlir::arith::FastMathFlagsAttr::get(
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index cb4eb8303a4959..fca3fb077d0a3f 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -589,10 +589,18 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
     // Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
     mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
         attrConvert(call);
-    rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
-        call, resultTys, adaptor.getOperands(),
+    auto llvmCall = rewriter.create<mlir::LLVM::CallOp>(
+        call.getLoc(), resultTys, adaptor.getOperands(),
         addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
                              adaptor.getOperands().size()));
+    auto fmi =
+        mlir::cast<mlir::LLVM::FastmathFlagsInterface>(llvmCall.getOperation());
+    if (!fmi.isFastmathApplicable())
+      llvmCall->setAttr(
+          mlir::LLVM::CallOp::getFastmathAttrName(),
+          mlir::LLVM::FastmathFlagsAttr::get(call.getContext(),
+                                             mlir::LLVM::FastmathFlags::none));
+    rewriter.replaceOp(call, llvmCall);
     return mlir::success();
   }
 };
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index cb77aef74acd56..53637f2090f2ef 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -237,3 +237,20 @@ bool hlfir::isFortranIntegerScalarOrArrayObject(mlir::Type type) {
   mlir::Type elementType = getFortranElementType(unwrappedType);
   return mlir::isa<mlir::IntegerType>(elementType);
 }
+
+bool hlfir::isArithFastMathApplicable(mlir::Operation *op) {
+  // Fast-math flags are applicable iff some result or operand has a
+  // Fortran element type that ArithFastMathInterface accepts.
+  auto hasFastMathCompatibleElementType = [](mlir::Value v) {
+    mlir::Type elementType = getFortranElementType(v.getType());
+    return mlir::arith::ArithFastMathInterface::isCompatibleType(
+        elementType);
+  };
+  if (llvm::any_of(op->getResults(), hasFastMathCompatibleElementType))
+    return true;
+  if (llvm::any_of(op->getOperands(), hasFastMathCompatibleElementType))
+    return true;
+
+  // No compatible values are involved (e.g. integer-only reductions).
+  return false;
+}
diff --git a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
index 0827e378c7c07e..b04188d3ee1d9c 100644
--- a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
+++ b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
@@ -56,7 +56,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
     %45 = llvm.call @_FortranACUFDataTransferPtrPtr(%14, %25, %2, %11, %13, %5) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
     gpu.launch_func  @cuda_device_mod::@_QMmod1Psub1 blocks in (%7, %7, %7) threads in (%12, %7, %7) : i64 dynamic_shared_memory_size %11 args(%14 : !llvm.ptr)
     %46 = llvm.call @_FortranACUFDataTransferPtrPtr(%25, %14, %2, %10, %13, %4) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
-    %47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
+    %47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) : (i32, !llvm.ptr, i32) -> !llvm.ptr
     %48 = llvm.mlir.constant(9 : i32) : i32
     %49 = llvm.mlir.zero : !llvm.ptr
     %50 = llvm.getelementptr %49[1] : (!llvm.ptr) -> !llvm.ptr, i32
diff --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index 401ebbc8c49fe6..c2c9ad362370f6 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -136,7 +136,7 @@ module {
 // CHECK:           %[[VAL_6:.*]] = llvm.mlir.constant(-1 : i32) : i32
 // CHECK:           %[[VAL_7:.*]] = llvm.mlir.addressof @_QFEx : !llvm.ptr
 // CHECK:           %[[VAL_8:.*]] = llvm.mlir.addressof @_QQclX2E2F64756D6D792E66393000 : !llvm.ptr
-// CHECK:           %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
+// CHECK:           %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) : (i32, !llvm.ptr, i32) -> !llvm.ptr
 // CHECK:           %[[VAL_11:.*]] = llvm.mlir.constant(64 : i32) : i32
 // CHECK:           "llvm.intr.memcpy"(%[[VAL_3]], %[[VAL_7]], %[[VAL_11]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}>
 // CHECK:           %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
@@ -188,8 +188,8 @@ module {
 // CHECK:           %[[VAL_59:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_58]][7, 0, 2] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
 // CHECK:           %[[VAL_61:.*]] = llvm.insertvalue %[[VAL_52]], %[[VAL_59]][0] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
 // CHECK:           llvm.store %[[VAL_61]], %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>, !llvm.ptr
-// CHECK:           %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr, !llvm.ptr) -> i1
-// CHECK:           %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr) -> i32
+// CHECK:           %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) : (!llvm.ptr, !llvm.ptr) -> i1
+// CHECK:           %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) : (!llvm.ptr) -> i32
 // CHECK:           llvm.return
 // CHECK:         }
 // CHECK:         llvm.func @_FortranAioBeginExternalListOutput(i32, !llvm.ptr, i32) -> !llvm.ptr attributes {fir.io, fir.runtime, sym_visibility = "private"}
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ea9b0f6509b80b..bd23890556ffdd 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1211,6 +1211,9 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
     The destination type must to be strictly wider than the source type.
     When operating on vectors, casts elementwise.
   }];
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() { return true; }
+  }];
   let hasVerifier = 1;
   let hasFolder = 1;
 
@@ -1545,6 +1548,17 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf",
   let hasCanonicalizer = 1;
   let assemblyFormat = [{ $predicate `,` $lhs `,` $rhs (`fastmath` `` $fastmath^)?
                           attr-dict `:` type($lhs)}];
+
+  let extraClassDeclaration = [{
+    /// Always allow FastMathFlags on arith.cmpf.
+    /// It does not produce a floating point result, but
+    /// LLVM is currently relying on fast-math flags attached
+    /// to floating point comparison.
+    /// This can be removed whenever LLVM stops doing it.
+    bool isArithFastMathApplicable() {
+      return true;
+    }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 82d6c9ad6b03da..860c096ef2e8b9 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -22,31 +22,60 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
 
   let cppNamespace = "::mlir::arith";
 
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/        "Returns a FastMathFlagsAttr attribute for the operation",
-      /*returnType=*/  "FastMathFlagsAttr",
-      /*methodName=*/  "getFastMathFlagsAttr",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+  let methods =
+      [InterfaceMethod<
+           /*desc=*/"Returns a FastMathFlagsAttr attribute for the operation",
+           /*returnType=*/"FastMathFlagsAttr",
+           /*methodName=*/"getFastMathFlagsAttr",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         ConcreteOp op = cast<ConcreteOp>(this->getOperation());
         return op.getFastmathAttr();
-      }]
-      >,
-    StaticInterfaceMethod<
-      /*desc=*/        [{Returns the name of the FastMathFlagsAttr attribute
+      }]>,
+       StaticInterfaceMethod<
+           /*desc=*/[{Returns the name of the FastMathFlagsAttr attribute
                          for the operation}],
-      /*returnType=*/  "StringRef",
-      /*methodName=*/  "getFastMathAttrName",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+           /*returnType=*/"StringRef",
+           /*methodName=*/"getFastMathAttrName",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         return "fastmath";
-      }]
-      >
+      }]>,
+       InterfaceMethod<
+           /*desc=*/[{Returns true iff FastMathFlagsAttr attribute
+                         is applicable to the operation that supports
+                         ArithFastMathInterface. If it returns false,
+                         then the FastMathFlagsAttr of the operation
+                         must be nullptr or have 'none' value}],
+           /*returnType=*/"bool",
+           /*methodName=*/"isArithFastMathApplicable",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
+        return ::mlir::cast<::mlir::arith::ArithFastMathInterface>(this->getOperation()).isApplicableImpl();
+      }]>];
 
-  ];
+  let extraClassDeclaration = [{
+    /// Returns true iff the given type is a floating point type
+    /// or contains one.
+    static bool isCompatibleType(::mlir::Type);
+
+    /// Default implementation of isArithFastMathApplicable().
+    /// It returns true iff any of the results of the operation
+    /// has a type that is compatible with fast-math.
+    bool isApplicableImpl();
+  }];
+
+  let verify = [{
+    auto fmi = ::mlir::cast<::mlir::arith::ArithFastMathInterface>($_op);
+    auto attr = fmi.getFastMathFlagsAttr();
+    if (attr && attr.getValue() != ::mlir::arith::FastMathFlags::none &&
+        !fmi.isArithFastMathApplicable())
+      return $_op->emitOpError() << "FastMathFlagsAttr is not applicable";
+    return ::mlir::success();
+  }];
 }
 
 def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 5ccddef158d9c2..ca55f933e4efad 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -22,30 +22,60 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
 
   let cppNamespace = "::mlir::LLVM";
 
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/        "Returns a FastmathFlagsAttr attribute for the operation",
-      /*returnType=*/  "::mlir::LLVM::FastmathFlagsAttr",
-      /*methodName=*/  "getFastmathAttr",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+  let methods =
+      [InterfaceMethod<
+           /*desc=*/"Returns a FastmathFlagsAttr attribute for the operation",
+           /*returnType=*/"::mlir::LLVM::FastmathFlagsAttr",
+           /*methodName=*/"getFastmathAttr",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         auto op = cast<ConcreteOp>(this->getOperation());
         return op.getFastmathFlagsAttr();
-      }]
-      >,
-    StaticInterfaceMethod<
-      /*desc=*/        [{Returns the name of the FastmathFlagsAttr attribute
+      }]>,
+       StaticInterfaceMethod<
+           /*desc=*/[{Returns the name of the FastmathFlagsAttr attribute
                          for the operation}],
-      /*returnType=*/  "::llvm::StringRef",
-      /*methodName=*/  "getFastmathAttrName",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+           /*returnType=*/"::llvm::StringRef",
+           /*methodName=*/"getFastmathAttrName",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         return "fastmathFlags";
-      }]
-      >
-  ];
+      }]>,
+       InterfaceMethod<
+           /*desc=*/[{Returns true iff FastmathFlagsAttr attribute
+                         is applicable to the operation that supports
+                         FastmathFlagsInterface. If it returns false,
+                         then the FastmathFlagsAttr of the operation
+                         must be nullptr or have 'none' value}],
+           /*returnType=*/"bool",
+           /*methodName=*/"isFastmathApplicable",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
+        return ::mlir::cast<::mlir::LLVM::FastmathFlagsInterface>(this->getOperation()).isApplicableImpl();
+      }]>];
+
+  let extraClassDeclaration = [{
+    /// Returns true iff the given type is a floating point typ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/125620


More information about the flang-commits mailing list