[flang-commits] [flang] [mlir] [RFC][mlir] Conditional support for fast-math attributes. (PR #125620)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Fri Feb 7 12:42:51 PST 2025


https://github.com/vzakhari updated https://github.com/llvm/llvm-project/pull/125620

>From 56e9d9578299854a92fd6b17d0bc0d29879754c6 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Mon, 3 Feb 2025 17:21:10 -0800
Subject: [PATCH 1/3] [RFC][mlir] Conditional support for fast-math attributes.

This patch suggests changes for operations that support
arith::ArithFastMathInterface/LLVM::FastmathFlagsInterface.
Some of the operations may have fast-math flags not equal to `none`
only if they operate on floating point values.

This is inspired by https://llvm.org/docs/LangRef.html#fastmath-return-types
and my goal to add fast-math support for `arith.select` operation
that may produce results of any type.

The changes add new isArithFastMathApplicable/isFastmathApplicable
methods to the above interfaces that tell whether an operation
supporting the interface may have non-none fast-math flags.

The LLVM dialect's isFastmathApplicable implementation is based on https://github.com/llvm/llvm-project/blob/bac62ee5b473e70981a6bd9759ec316315fca07d/llvm/include/llvm/IR/Operator.h#L380
The ARITH dialect's isArithFastMathApplicable is more relaxed, because
it has to support custom MLIR types. This is the area where
improvements are needed (see TODO comments). I would appreciate
feedback here.
The HLFIR dialect is another example where conditional fast-math
support can currently be applied.
---
 .../include/flang/Optimizer/Dialect/FIROps.td | 26 +++++++
 .../flang/Optimizer/HLFIR/HLFIRDialect.h      |  5 ++
 .../include/flang/Optimizer/HLFIR/HLFIROps.td | 54 ++++++++++++++
 flang/lib/Optimizer/Builder/FIRBuilder.cpp    |  4 +-
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       | 12 +++-
 flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp | 17 +++++
 flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir |  2 +-
 flang/test/Fir/tbaa.fir                       |  6 +-
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 14 ++++
 .../Dialect/Arith/IR/ArithOpsInterfaces.td    | 69 ++++++++++++------
 .../mlir/Dialect/LLVMIR/LLVMInterfaces.td     | 70 +++++++++++++------
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   | 11 +++
 mlir/lib/Dialect/Arith/IR/ArithDialect.cpp    | 47 +++++++++++++
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 40 +++++++++++
 mlir/test/Dialect/LLVMIR/inlining.mlir        | 12 ++--
 mlir/test/Dialect/LLVMIR/roundtrip.mlir       | 26 ++++---
 mlir/test/Target/LLVMIR/omptarget-depend.mlir |  6 +-
 17 files changed, 353 insertions(+), 68 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 8dbc9df9f553def..497d099fbe9366a 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2494,6 +2494,21 @@ def fir_CallOp : fir_Op<"call",
                          llvm::cast<mlir::SymbolRefAttr>(callee));
       setOperand(0, llvm::cast<mlir::Value>(callee));
     }
+
+    /// Always allow FastMathFlags on fir.call operations.
+    /// This is required to propagate the call site's FastMathFlags
+    /// to the operations that result from inlining a fir.call
+    /// (if any), e.g. in the SimplifyIntrinsics pass.
+    /// We could analyze the arguments' data types to see if there are
+    /// any floating point types, but this is unreliable. For example,
+    /// runtime calls mostly take !fir.box<none> arguments, and
+    /// tracing those back to their definitions may not be easy.
+    /// TODO: this should be restricted to fir.runtime calls,
+    /// because FastMathFlags for user calls must come
+    /// from the function body, not the call site.
+    bool isArithFastMathApplicable() {
+      return true;
+    }
   }];
 }
 
@@ -2672,6 +2687,15 @@ def fir_CmpcOp : fir_Op<"cmpc",
     }
 
     static mlir::arith::CmpFPredicate getPredicateByName(llvm::StringRef name);
+
+    /// Always allow FastMathFlags on fir.cmpc.
+    /// It does not produce a floating point result, but
+    /// LLVM currently relies on fast-math flags attached
+    /// to floating point comparisons.
+    /// This can be removed once LLVM stops doing so.
+    bool isArithFastMathApplicable() {
+      return true;
+    }
   }];
 }
 
@@ -2735,6 +2759,8 @@ def fir_ConvertOp : fir_SimpleOneResultOp<"convert", [NoMemoryEffect]> {
     static bool isPointerCompatible(mlir::Type ty);
     static bool canBeConverted(mlir::Type inType, mlir::Type outType);
     static bool areVectorsCompatible(mlir::Type inTy, mlir::Type outTy);
+
+    // FIXME: fir.convert should support ArithFastMathInterface.
   }];
   let hasCanonicalizer = 1;
 }
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
index 15296aa7e8c75c2..0e6d536d9bde5d0 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
@@ -139,6 +139,11 @@ bool mayHaveAllocatableComponent(mlir::Type ty);
 /// Scalar integer or a sequence of integers (via boxed array or expr).
 bool isFortranIntegerScalarOrArrayObject(mlir::Type type);
 
+/// Return true iff FastMathFlagsAttr may have a non-'none' value on
+/// the given HLFIR dialect operation that supports
+/// ArithFastMathInterface.
+bool isArithFastMathApplicable(mlir::Operation *op);
+
 } // namespace hlfir
 
 #endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index f4102538efc3c28..f90ef8ed019ceb7 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -434,6 +434,12 @@ def hlfir_MaxvalOp : hlfir_Op<"maxval", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() { // Decided by result/operand types.
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
@@ -461,6 +467,12 @@ def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() { // Decided by result/operand types.
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
@@ -487,6 +499,12 @@ def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() { // Decided by result/operand types.
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
@@ -513,6 +531,12 @@ def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() { // Decided by result/operand types.
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
@@ -539,6 +563,12 @@ def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() { // Decided by result/operand types.
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_SetLengthOp : hlfir_Op<"set_length",
@@ -604,6 +634,12 @@ def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() { // Decided by result/operand types.
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_DotProductOp : hlfir_Op<"dot_product",
@@ -628,6 +664,12 @@ def hlfir_DotProductOp : hlfir_Op<"dot_product",
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() { // Decided by result/operand types.
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_MatmulOp : hlfir_Op<"matmul",
@@ -655,6 +697,12 @@ def hlfir_MatmulOp : hlfir_Op<"matmul",
   let hasCanonicalizeMethod = 1;
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() { // Decided by result/operand types.
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_TransposeOp : hlfir_Op<"transpose",
@@ -697,6 +745,12 @@ def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
   }];
 
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() { // Decided by result/operand types.
+      return hlfir::isArithFastMathApplicable(getOperation());
+    }
+  }];
 }
 
 def hlfir_CShiftOp
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index d9779c46ae79e71..d749fc9c633d7c5 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -786,9 +786,7 @@ mlir::Value fir::FirOpBuilder::genAbsentOp(mlir::Location loc,
 
 void fir::FirOpBuilder::setCommonAttributes(mlir::Operation *op) const {
   auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
-  if (fmi) {
-    // TODO: use fmi.setFastMathFlagsAttr() after D137114 is merged.
-    //       For now set the attribute by the name.
+  if (fmi && fmi.isArithFastMathApplicable()) {
     llvm::StringRef arithFMFAttrName = fmi.getFastMathAttrName();
     if (fastMathFlags != mlir::arith::FastMathFlags::none)
       op->setAttr(arithFMFAttrName, mlir::arith::FastMathFlagsAttr::get(
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index cb4eb8303a4959e..fca3fb077d0a3fb 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -589,10 +589,18 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
     // Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
     mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
         attrConvert(call);
-    rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
-        call, resultTys, adaptor.getOperands(),
+    auto llvmCall = rewriter.create<mlir::LLVM::CallOp>(
+        call.getLoc(), resultTys, adaptor.getOperands(),
         addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
                              adaptor.getOperands().size()));
+    auto fmi =
+        mlir::cast<mlir::LLVM::FastmathFlagsInterface>(llvmCall.getOperation());
+    if (!fmi.isFastmathApplicable())
+      llvmCall->setAttr(
+          mlir::LLVM::CallOp::getFastmathAttrName(),
+          mlir::LLVM::FastmathFlagsAttr::get(call.getContext(),
+                                             mlir::LLVM::FastmathFlags::none));
+    rewriter.replaceOp(call, llvmCall);
     return mlir::success();
   }
 };
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index cb77aef74acd560..53637f2090f2eff 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -237,3 +237,20 @@ bool hlfir::isFortranIntegerScalarOrArrayObject(mlir::Type type) {
   mlir::Type elementType = getFortranElementType(unwrappedType);
   return mlir::isa<mlir::IntegerType>(elementType);
 }
+
+/// Return true iff any result or operand of \p op has a Fortran element
+/// type that is compatible with fast-math (e.g. a floating point type).
+bool hlfir::isArithFastMathApplicable(mlir::Operation *op) {
+  auto isFastMathCompatible = [](mlir::Value v) {
+    mlir::Type elementType = getFortranElementType(v.getType());
+    return mlir::arith::ArithFastMathInterface::isCompatibleType(
+        elementType);
+  };
+  if (llvm::any_of(op->getResults(), isFastMathCompatible))
+    return true;
+  // Results alone are not enough: an operation may produce non-floating
+  // point results from floating point operands, so check operands too.
+  if (llvm::any_of(op->getOperands(), isFastMathCompatible))
+    return true;
+  return false; // No fast-math compatible data is involved.
+}
diff --git a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
index 0827e378c7c07e8..b04188d3ee1d9ca 100644
--- a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
+++ b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
@@ -56,7 +56,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
     %45 = llvm.call @_FortranACUFDataTransferPtrPtr(%14, %25, %2, %11, %13, %5) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
     gpu.launch_func  @cuda_device_mod::@_QMmod1Psub1 blocks in (%7, %7, %7) threads in (%12, %7, %7) : i64 dynamic_shared_memory_size %11 args(%14 : !llvm.ptr)
     %46 = llvm.call @_FortranACUFDataTransferPtrPtr(%25, %14, %2, %10, %13, %4) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
-    %47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
+    %47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) : (i32, !llvm.ptr, i32) -> !llvm.ptr
     %48 = llvm.mlir.constant(9 : i32) : i32
     %49 = llvm.mlir.zero : !llvm.ptr
     %50 = llvm.getelementptr %49[1] : (!llvm.ptr) -> !llvm.ptr, i32
diff --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index 401ebbc8c49fe6b..c2c9ad362370f6f 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -136,7 +136,7 @@ module {
 // CHECK:           %[[VAL_6:.*]] = llvm.mlir.constant(-1 : i32) : i32
 // CHECK:           %[[VAL_7:.*]] = llvm.mlir.addressof @_QFEx : !llvm.ptr
 // CHECK:           %[[VAL_8:.*]] = llvm.mlir.addressof @_QQclX2E2F64756D6D792E66393000 : !llvm.ptr
-// CHECK:           %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
+// CHECK:           %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) : (i32, !llvm.ptr, i32) -> !llvm.ptr
 // CHECK:           %[[VAL_11:.*]] = llvm.mlir.constant(64 : i32) : i32
 // CHECK:           "llvm.intr.memcpy"(%[[VAL_3]], %[[VAL_7]], %[[VAL_11]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}>
 // CHECK:           %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
@@ -188,8 +188,8 @@ module {
 // CHECK:           %[[VAL_59:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_58]][7, 0, 2] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
 // CHECK:           %[[VAL_61:.*]] = llvm.insertvalue %[[VAL_52]], %[[VAL_59]][0] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
 // CHECK:           llvm.store %[[VAL_61]], %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>, !llvm.ptr
-// CHECK:           %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr, !llvm.ptr) -> i1
-// CHECK:           %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr) -> i32
+// CHECK:           %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) : (!llvm.ptr, !llvm.ptr) -> i1
+// CHECK:           %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) : (!llvm.ptr) -> i32
 // CHECK:           llvm.return
 // CHECK:         }
 // CHECK:         llvm.func @_FortranAioBeginExternalListOutput(i32, !llvm.ptr, i32) -> !llvm.ptr attributes {fir.io, fir.runtime, sym_visibility = "private"}
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ea9b0f6509b80b6..bd23890556ffddd 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1211,6 +1211,9 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
     The destination type must to be strictly wider than the source type.
     When operating on vectors, casts elementwise.
   }];
+  let extraClassDeclaration = [{
+    bool isArithFastMathApplicable() { return true; } // Result is always FP.
+  }];
   let hasVerifier = 1;
   let hasFolder = 1;
 
@@ -1545,6 +1548,17 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf",
   let hasCanonicalizer = 1;
   let assemblyFormat = [{ $predicate `,` $lhs `,` $rhs (`fastmath` `` $fastmath^)?
                           attr-dict `:` type($lhs)}];
+
+  let extraClassDeclaration = [{
+    /// Always allow FastMathFlags on arith.cmpf.
+    /// It does not produce a floating point result, but
+    /// LLVM currently relies on fast-math flags attached
+    /// to floating point comparisons.
+    /// This can be removed once LLVM stops doing so.
+    bool isArithFastMathApplicable() {
+      return true;
+    }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 82d6c9ad6b03da7..860c096ef2e8b9c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -22,31 +22,60 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
 
   let cppNamespace = "::mlir::arith";
 
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/        "Returns a FastMathFlagsAttr attribute for the operation",
-      /*returnType=*/  "FastMathFlagsAttr",
-      /*methodName=*/  "getFastMathFlagsAttr",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+  let methods =
+      [InterfaceMethod<
+           /*desc=*/"Returns a FastMathFlagsAttr attribute for the operation",
+           /*returnType=*/"FastMathFlagsAttr",
+           /*methodName=*/"getFastMathFlagsAttr",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         ConcreteOp op = cast<ConcreteOp>(this->getOperation());
         return op.getFastmathAttr();
-      }]
-      >,
-    StaticInterfaceMethod<
-      /*desc=*/        [{Returns the name of the FastMathFlagsAttr attribute
+      }]>,
+       StaticInterfaceMethod<
+           /*desc=*/[{Returns the name of the FastMathFlagsAttr attribute
                          for the operation}],
-      /*returnType=*/  "StringRef",
-      /*methodName=*/  "getFastMathAttrName",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+           /*returnType=*/"StringRef",
+           /*methodName=*/"getFastMathAttrName",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         return "fastmath";
-      }]
-      >
+      }]>,
+       InterfaceMethod<
+           /*desc=*/[{Returns true iff FastMathFlagsAttr attribute
+                         is applicable to the operation that supports
+                         ArithFastMathInterface. If it returns false,
+                         then the FastMathFlagsAttr of the operation
+                         must be nullptr or have 'none' value}],
+           /*returnType=*/"bool",
+           /*methodName=*/"isArithFastMathApplicable",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
+        return ::mlir::cast<::mlir::arith::ArithFastMathInterface>(this->getOperation()).isApplicableImpl();
+      }]>];
 
-  ];
+  let extraClassDeclaration = [{
+    /// Returns true iff the given type is a floating point type
+    /// or contains one (e.g. via a value-semantic shaped type).
+    static bool isCompatibleType(::mlir::Type);
+
+    /// Default implementation of isArithFastMathApplicable().
+    /// It returns true iff any of the results of the operation
+    /// has a type that is compatible with fast-math.
+    bool isApplicableImpl();
+  }];
+
+  let verify = [{
+    auto iface = ::mlir::cast<::mlir::arith::ArithFastMathInterface>($_op);
+    auto flags = iface.getFastMathFlagsAttr();
+    if (!flags || flags.getValue() == ::mlir::arith::FastMathFlags::none ||
+        iface.isArithFastMathApplicable())
+      return ::mlir::success();
+    return $_op->emitOpError() << "FastMathFlagsAttr is not applicable";
+  }];
 }
 
 def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 5ccddef158d9c2b..ca55f933e4efad3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -22,30 +22,60 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
 
   let cppNamespace = "::mlir::LLVM";
 
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/        "Returns a FastmathFlagsAttr attribute for the operation",
-      /*returnType=*/  "::mlir::LLVM::FastmathFlagsAttr",
-      /*methodName=*/  "getFastmathAttr",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+  let methods =
+      [InterfaceMethod<
+           /*desc=*/"Returns a FastmathFlagsAttr attribute for the operation",
+           /*returnType=*/"::mlir::LLVM::FastmathFlagsAttr",
+           /*methodName=*/"getFastmathAttr",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         auto op = cast<ConcreteOp>(this->getOperation());
         return op.getFastmathFlagsAttr();
-      }]
-      >,
-    StaticInterfaceMethod<
-      /*desc=*/        [{Returns the name of the FastmathFlagsAttr attribute
+      }]>,
+       StaticInterfaceMethod<
+           /*desc=*/[{Returns the name of the FastmathFlagsAttr attribute
                          for the operation}],
-      /*returnType=*/  "::llvm::StringRef",
-      /*methodName=*/  "getFastmathAttrName",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
+           /*returnType=*/"::llvm::StringRef",
+           /*methodName=*/"getFastmathAttrName",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
         return "fastmathFlags";
-      }]
-      >
-  ];
+      }]>,
+       InterfaceMethod<
+           /*desc=*/[{Returns true iff FastmathFlagsAttr attribute
+                         is applicable to the operation that supports
+                         FastmathInterface. If it returns false,
+                         then the FastmathFlagsAttr of the operation
+                         must be nullptr or have 'none' value}],
+           /*returnType=*/"bool",
+           /*methodName=*/"isFastmathApplicable",
+           /*args=*/(ins),
+           /*methodBody=*/[{}],
+           /*defaultImpl=*/[{
+        return ::mlir::cast<::mlir::LLVM::FastmathFlagsInterface>(this->getOperation()).isApplicableImpl();
+      }]>];
+
+  let extraClassDeclaration = [{
+    /// Returns true iff the given type is a floating point type, or a
+    /// vector, array or homogeneous literal struct built from one.
+    static bool isCompatibleType(::mlir::Type);
+
+    /// Default implementation of isFastmathApplicable().
+    /// It returns true iff any of the results of the operation
+    /// has a type that is compatible with fast-math.
+    bool isApplicableImpl();
+  }];
+
+  let verify = [{
+    auto iface = ::mlir::cast<::mlir::LLVM::FastmathFlagsInterface>($_op);
+    auto flags = iface.getFastmathAttr();
+    if (!flags || flags.getValue() == ::mlir::LLVM::FastmathFlags::none ||
+        iface.isFastmathApplicable())
+      return ::mlir::success();
+    return $_op->emitOpError() << "FastmathFlagsAttr is not applicable";
+  }];
 }
 
 def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index ee6e10efed4f16b..17267efc17a3a93 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -225,6 +225,17 @@ def LLVM_FCmpOp : LLVM_ArithmeticCmpOp<"fcmp", [
   // Set the $predicate index to -1 to indicate there is no matching operand
   // and decrement the following indices.
   list<int> llvmArgIndices = [-1, 0, 1, 2];
+
+  let extraClassDeclaration = [{
+    /// Always allow FastmathFlags on llvm.fcmp.
+    /// It does not produce a floating point result, but
+    /// LLVM currently relies on fast-math flags attached
+    /// to floating point comparisons.
+    /// This can be removed once LLVM stops doing so.
+    bool isFastmathApplicable() {
+      return true;
+    }
+  }];
 }
 
 // Floating point binary operations.
diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
index 042acf610090001..32f7cc2bd9b12b4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -66,3 +67,49 @@ Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder,
 
   return ConstantOp::materialize(builder, value, type, loc);
 }
+
+/// Return true if the type is compatible with fast math, i.e.
+/// it is a float type or contains a float type.
+bool arith::ArithFastMathInterface::isCompatibleType(Type type) {
+  if (isa<FloatType>(type))
+    return true;
+
+  // ShapedTypes with ValueSemantics represent containers
+  // passed around as values (not references), so look inside
+  // them to see if the element type is compatible with FastMath.
+  if (type.hasTrait<ValueSemantics>())
+    if (auto shapedType = dyn_cast<ShapedType>(type))
+      return isCompatibleType(shapedType.getElementType());
+
+  // ComplexType elements may be integer or floating point: check them.
+  if (auto complexType = dyn_cast<ComplexType>(type))
+    return isCompatibleType(complexType.getElementType());
+
+  // TODO: what about TupleType and custom dialect struct-like types?
+  // It seems they deserve an interface to get to the list of element types.
+  //
+  // NOTE: LLVM only allows fast-math flags for instructions producing
+  // structures with homogeneous floating point members. I think
+  // this restriction must not be asserted here, because custom
+  // MLIR operations may be converted such that the original operation's
+  // FastMathFlags still need to be propagated to the target
+  // operations.
+
+  return false;
+}
+
+/// Default applicability check: fast-math flags apply iff at least
+/// one result of the operation has a type that is a float type or
+/// contains one (see isCompatibleType()).
+///
+/// TODO: results often share the same type, so traversing that type
+/// again and again is wasteful. The answer could be cached for the
+/// duration of the processing.
+bool arith::ArithFastMathInterface::isApplicableImpl() {
+  Operation *op = getOperation();
+  auto hasCompatibleType = [](Value result) {
+    return isCompatibleType(result.getType());
+  };
+  // Only results are inspected, not operands.
+  return llvm::any_of(op->getResults(), hasCompatibleType);
+}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index a6e996f3fb810db..2593cdad1e65a81 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3783,3 +3783,43 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
   return op->hasTrait<OpTrait::SymbolTable>() &&
          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
 }
+
+/// Return true if the type is compatible with fast math, modeled on the
+/// rules in LLVM's Operator.h (FP scalars/vectors, homogeneous structs, arrays).
+bool mlir::LLVM::FastmathFlagsInterface::isCompatibleType(Type type) {
+  if (auto structType = dyn_cast<LLVMStructType>(type)) {
+    if (structType.isIdentified())
+      return false;
+    ArrayRef<Type> elementTypes = structType.getBody();
+    if (elementTypes.empty() || !llvm::all_equal(elementTypes))
+      return false;
+    // A homogeneous literal struct is judged by its member type.
+    type = elementTypes[0];
+  } else if (auto arrayType = dyn_cast<LLVMArrayType>(type)) {
+    do {
+      type = arrayType.getElementType();
+    } while ((arrayType = dyn_cast<LLVMArrayType>(type)));
+  }
+
+  if (isa<FloatType>(type))
+    return true;
+  // Otherwise, look through vector types to their element type.
+  type =
+      TypeSwitch<Type, Type>(type)
+          .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>(
+              [](auto containerType) { return containerType.getElementType(); })
+          .Default(type);
+
+  return isa<FloatType>(type);
+}
+
+/// Default applicability check: fast-math flags apply iff at least one
+/// result of the operation has a type accepted by isCompatibleType().
+bool mlir::LLVM::FastmathFlagsInterface::isApplicableImpl() {
+  for (Value result : getOperation()->getResults()) {
+    if (isCompatibleType(result.getType()))
+      return true;
+  }
+  // No results, or none of them has a fast-math compatible type.
+  return false;
+}
diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
index eb249a477175349..cf88c5d1d78ec3d 100644
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -74,18 +74,18 @@ func.func @llvm_ret(%arg0 : i32) -> i32 {
 // -----
 
 // Include all function attributes that don't prevent inlining
-llvm.func internal fastcc @callee() -> (i32) attributes { function_entry_count = 42 : i64, dso_local } {
-  %0 = llvm.mlir.constant(42 : i32) : i32
-  llvm.return %0 : i32
+llvm.func internal fastcc @callee() -> (f32) attributes { function_entry_count = 42 : i64, dso_local } {
+  %0 = llvm.mlir.constant(42.0 : f32) : f32
+  llvm.return %0 : f32
 }
 
 // CHECK-LABEL: llvm.func @caller
 // CHECK-NEXT: %[[CST:.+]] = llvm.mlir.constant
 // CHECK-NEXT: llvm.return %[[CST]]
-llvm.func @caller() -> (i32) {
+llvm.func @caller() -> (f32) {
   // Include all call attributes that don't prevent inlining.
-  %0 = llvm.call fastcc @callee() { fastmathFlags = #llvm.fastmath<nnan, ninf>, branch_weights = dense<42> : vector<1xi32> } : () -> (i32)
-  llvm.return %0 : i32
+  %0 = llvm.call fastcc @callee() { fastmathFlags = #llvm.fastmath<nnan, ninf>, branch_weights = dense<42> : vector<1xi32> } : () -> (f32)
+  llvm.return %0 : f32
 }
 
 // -----
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 88660ce598f3c22..0f49539267c6df0 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -102,15 +102,15 @@ func.func @ops(%arg0: i32, %arg1: f32,
 
 // Variadic calls
 // CHECK:  llvm.call @vararg_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) : (i32, i32) -> ()
-// CHECK:  llvm.call @vararg_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : (i32, i32) -> ()
-// CHECK:  %[[VARIADIC_FUNC:.*]] = llvm.mlir.addressof @vararg_func : !llvm.ptr
-// CHECK:  llvm.call %[[VARIADIC_FUNC]](%[[I32]], %[[I32]]) vararg(!llvm.func<void (i32, ...)>) : !llvm.ptr, (i32, i32) -> ()
-// CHECK:  llvm.call %[[VARIADIC_FUNC]](%[[I32]], %[[I32]]) vararg(!llvm.func<void (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : !llvm.ptr, (i32, i32) -> ()
+// CHECK:  llvm.call @vararg_func_f32(%arg0, %arg0) vararg(!llvm.func<f32 (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : (i32, i32) -> f32
+// CHECK:  %[[VARIADIC_FUNC:.*]] = llvm.mlir.addressof @vararg_func_f32 : !llvm.ptr
+// CHECK:  llvm.call %[[VARIADIC_FUNC]](%[[I32]], %[[I32]]) vararg(!llvm.func<f32 (i32, ...)>) : !llvm.ptr, (i32, i32) -> f32
+// CHECK:  llvm.call %[[VARIADIC_FUNC]](%[[I32]], %[[I32]]) vararg(!llvm.func<f32 (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : !llvm.ptr, (i32, i32) -> f32
   llvm.call @vararg_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) : (i32, i32) -> ()
-  llvm.call @vararg_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : (i32, i32) -> ()
-  %variadic_func = llvm.mlir.addressof @vararg_func : !llvm.ptr
-  llvm.call %variadic_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) : !llvm.ptr, (i32, i32) -> ()
-  llvm.call %variadic_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : !llvm.ptr, (i32, i32) -> ()
+  %tmp1 = llvm.call @vararg_func_f32(%arg0, %arg0) vararg(!llvm.func<f32 (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : (i32, i32) -> f32
+  %variadic_func = llvm.mlir.addressof @vararg_func_f32 : !llvm.ptr
+  llvm.call %variadic_func(%arg0, %arg0) vararg(!llvm.func<f32 (i32, ...)>) : !llvm.ptr, (i32, i32) -> f32
+  %tmp2 = llvm.call %variadic_func(%arg0, %arg0) vararg(!llvm.func<f32 (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : !llvm.ptr, (i32, i32) -> f32
 
 // Function call attributes
 // CHECK: llvm.call @baz() {convergent} : () -> ()
@@ -618,6 +618,9 @@ llvm.func @useInlineAsm(%arg0: i32) {
   llvm.return
 }
 
+// CHECK-LABEL: @fastmathStructReturn
+llvm.func @fastmathStructReturn(%arg0: i32) -> !llvm.struct<(f32, f32)>
+
 // CHECK-LABEL: @fastmathFlags
 func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: vector<2 x f32>, %arg4: vector<2 x f32>) {
 // CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : f32
@@ -643,8 +646,8 @@ func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: vector<2 x f
 // CHECK: {{.*}} = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<fast>} : f32
   %6 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<fast>} : f32
 
-// CHECK: {{.*}} = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (i32) -> !llvm.struct<(i32, f64, i32)>
-  %7 = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (i32) -> !llvm.struct<(i32, f64, i32)>
+// CHECK: {{.*}} = llvm.call @fastmathStructReturn(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (i32) -> !llvm.struct<(f32, f32)>
+  %7 = llvm.call @fastmathStructReturn(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (i32) -> !llvm.struct<(f32, f32)>
 
 // CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32
   %8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<none>} : f32
@@ -700,6 +703,9 @@ llvm.func @invariant_group_intrinsics(%p: !llvm.ptr) {
   llvm.return
 }
 
+// CHECK-LABEL: @vararg_func_f32
+llvm.func @vararg_func_f32(%arg0: i32, ...) -> f32
+
 // CHECK-LABEL: @vararg_func
 llvm.func @vararg_func(%arg0: i32, ...) {
   // CHECK: %[[C:.*]] = llvm.mlir.constant(1 : i32)
diff --git a/mlir/test/Target/LLVMIR/omptarget-depend.mlir b/mlir/test/Target/LLVMIR/omptarget-depend.mlir
index 71fecd0fa5fd0a7..723246c323f2caa 100644
--- a/mlir/test/Target/LLVMIR/omptarget-depend.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-depend.mlir
@@ -113,9 +113,9 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a
   llvm.func @main(%arg0: i32, %arg1: !llvm.ptr, %arg2: !llvm.ptr) -> i32 {
     %0 = llvm.mlir.constant(0 : i32) : i32
     %1 = llvm.mlir.zero : !llvm.ptr
-    llvm.call @_FortranAProgramStart(%arg0, %arg1, %arg2, %1) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
-    llvm.call @_QQmain() {fastmathFlags = #llvm.fastmath<contract>} : () -> ()
-    llvm.call @_FortranAProgramEndStatement() {fastmathFlags = #llvm.fastmath<contract>} : () -> ()
+    llvm.call @_FortranAProgramStart(%arg0, %arg1, %arg2, %1) : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
+    llvm.call @_QQmain() : () -> ()
+    llvm.call @_FortranAProgramEndStatement() : () -> ()
     llvm.return %0 : i32
   }
 }

>From a517b834582c7ab6b927cb372fab4ab73249fe40 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Tue, 4 Feb 2025 10:24:42 -0800
Subject: [PATCH 2/3] Addressed review comments.

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp                  | 5 +----
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td           | 3 ---
 mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td | 5 ++++-
 mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td       | 5 ++++-
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp               | 2 +-
 mlir/test/Dialect/LLVMIR/invalid.mlir                    | 8 ++++++++
 6 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index fca3fb077d0a3fb..109524635b3b6a4 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -596,10 +596,7 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
     auto fmi =
         mlir::cast<mlir::LLVM::FastmathFlagsInterface>(llvmCall.getOperation());
     if (!fmi.isFastmathApplicable())
-      llvmCall->setAttr(
-          mlir::LLVM::CallOp::getFastmathAttrName(),
-          mlir::LLVM::FastmathFlagsAttr::get(call.getContext(),
-                                             mlir::LLVM::FastmathFlags::none));
+      llvmCall.setFastmathFlags(mlir::LLVM::FastmathFlags::none);
     rewriter.replaceOp(call, llvmCall);
     return mlir::success();
   }
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index bd23890556ffddd..bd5e05977b90991 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1211,9 +1211,6 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
     The destination type must to be strictly wider than the source type.
     When operating on vectors, casts elementwise.
   }];
-  let extraClassDeclaration = [{
-    bool isApplicable() { return true; }
-  }];
   let hasVerifier = 1;
   let hasFolder = 1;
 
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 860c096ef2e8b9c..446563a23d104d4 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -73,7 +73,10 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
     auto attr = fmi.getFastMathFlagsAttr();
     if (attr && attr.getValue() != ::mlir::arith::FastMathFlags::none &&
         !fmi.isArithFastMathApplicable())
-      return $_op->emitOpError() << "FastMathFlagsAttr is not applicable";
+      return $_op->emitOpError()
+          << "has flag(s) `" << stringifyEnum(attr.getValue())
+          << "`, but fast-math flags are not applicable "
+             "(`isArithFastMathApplicable()` returns false)";
     return ::mlir::success();
   }];
 }
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index ca55f933e4efad3..b63cbb5c362cf5e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -73,7 +73,10 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
     auto attr = fmi.getFastmathAttr();
     if (attr && attr.getValue() != ::mlir::LLVM::FastmathFlags::none &&
         !fmi.isFastmathApplicable())
-      return $_op->emitOpError() << "FastmathFlagsAttr is not applicable";
+      return $_op->emitOpError()
+          << "has flag(s) `" << stringifyEnum(attr.getValue())
+          << "`, but fast-math flags are not applicable "
+             "(`isFastmathApplicable()` returns false)";
     return ::mlir::success();
   }];
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 2593cdad1e65a81..07d26a089317b3a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3798,7 +3798,7 @@ bool mlir::LLVM::FastmathFlagsInterface::isCompatibleType(Type type) {
   } else if (auto arrayType = dyn_cast<LLVMArrayType>(type)) {
     do {
       type = arrayType.getElementType();
-    } while (arrayType = dyn_cast<LLVMArrayType>(type));
+    } while ((arrayType = dyn_cast<LLVMArrayType>(type)));
   }
 
   if (isa<FloatType>(type))
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 5c939318fe3ed67..fab48a89dbb1d99 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1703,3 +1703,11 @@ llvm.func @wrong_number_of_bundle_tags() {
   } : (i32, i32) -> ()
   llvm.return
 }
+
+// -----
+
+func.func @call_invalid_fastmath(%callee : !llvm.ptr) {
+  // expected-error@+1 {{has flag(s) `nsz, afn`, but fast-math flags are not applicable (`isFastmathApplicable()` returns false)}}
+  llvm.call %callee() {fastmathFlags = #llvm.fastmath<nsz,afn>} : !llvm.ptr, () -> i32
+  llvm.return
+}

>From 226bb38dd172529beb13bc5e504b7b993330c0b0 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Fri, 7 Feb 2025 12:42:20 -0800
Subject: [PATCH 3/3] Addressed review comments.

---
 mlir/lib/Dialect/Arith/IR/ArithDialect.cpp | 21 +--------------------
 1 file changed, 1 insertion(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
index 32f7cc2bd9b12b4..4123d88d57551e4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
@@ -71,7 +71,7 @@ Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder,
 /// Return true if the type is compatible with fast math, i.e.
 /// it is a float type or contains a float type.
 bool arith::ArithFastMathInterface::isCompatibleType(Type type) {
-  if (isa<FloatType>(type))
+  if (isa<FloatType>(type) || isa<ComplexType>(type))
     return true;
 
   // ShapeType's with ValueSemantics represent containers
@@ -81,31 +81,12 @@ bool arith::ArithFastMathInterface::isCompatibleType(Type type) {
     if (auto shapedType = dyn_cast<ShapedType>(type))
       return isCompatibleType(shapedType.getElementType());
 
-  // ComplexType's element type is always a FloatType.
-  if (auto complexType = dyn_cast<ComplexType>(type))
-    return true;
-
-  // TODO: what about TupleType and custom dialect struct-like types?
-  // It seems that they worth an interface to get to the list of element types.
-  //
-  // NOTE: LLVM only allows fast-math flags for instructions producing
-  // structures with homogeneous floating point members. I think
-  // this restriction must not be asserted here, because custom
-  // MLIR operations may be converted such that the original operation's
-  // FastMathFlags still need to be propagated to the target
-  // operations.
-
   return false;
 }
 
 /// Return true if any of the results of the operation
 /// has a type compatible with fast math, i.e. it is a float type
 /// or contains a float type.
-///
-/// TODO: the results often have the same type, and traversing
-/// the same type again and again is not very efficient.
-/// We can cache it here for the duration of the processing.
-/// Other ideas?
 bool arith::ArithFastMathInterface::isApplicableImpl() {
   Operation *op = getOperation();
   if (llvm::any_of(op->getResults(),



More information about the flang-commits mailing list