[flang-commits] [flang] [mlir] [RFC][mlir] Conditional support for fast-math attributes. (PR #125620)
Slava Zakharin via flang-commits
flang-commits at lists.llvm.org
Fri Feb 7 12:42:51 PST 2025
https://github.com/vzakhari updated https://github.com/llvm/llvm-project/pull/125620
>From 56e9d9578299854a92fd6b17d0bc0d29879754c6 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Mon, 3 Feb 2025 17:21:10 -0800
Subject: [PATCH 1/3] [RFC][mlir] Conditional support for fast-math attributes.
This patch suggests changes for operations that support
arith::ArithFastMathInterface/LLVM::FastmathFlagsInterface.
Some of the operations may have fast-math flags not equal to `none`
only if they operate on floating point values.
This is inspired by https://llvm.org/docs/LangRef.html#fastmath-return-types
and my goal to add fast-math support for `arith.select` operation
that may produce results of any type.
The changes add new isArithFastMathApplicable/isFastmathApplicable
methods to the above interfaces that tell whether an operation
supporting the interface may have non-none fast-math flags.
LLVM dialect isFastmathApplicable implementation is based on https://github.com/llvm/llvm-project/blob/bac62ee5b473e70981a6bd9759ec316315fca07d/llvm/include/llvm/IR/Operator.h#L380
Arith dialect isArithFastMathApplicable is more relaxed, because
it has to support custom MLIR types. This is the area where
improvements are needed (see TODO comments). I would appreciate
feedback here.
HLFIR dialect is another example where conditional fast-math
support can already be applied.
---
.../include/flang/Optimizer/Dialect/FIROps.td | 26 +++++++
.../flang/Optimizer/HLFIR/HLFIRDialect.h | 5 ++
.../include/flang/Optimizer/HLFIR/HLFIROps.td | 54 ++++++++++++++
flang/lib/Optimizer/Builder/FIRBuilder.cpp | 4 +-
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 12 +++-
flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp | 17 +++++
flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir | 2 +-
flang/test/Fir/tbaa.fir | 6 +-
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 14 ++++
.../Dialect/Arith/IR/ArithOpsInterfaces.td | 69 ++++++++++++------
.../mlir/Dialect/LLVMIR/LLVMInterfaces.td | 70 +++++++++++++------
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 11 +++
mlir/lib/Dialect/Arith/IR/ArithDialect.cpp | 47 +++++++++++++
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 40 +++++++++++
mlir/test/Dialect/LLVMIR/inlining.mlir | 12 ++--
mlir/test/Dialect/LLVMIR/roundtrip.mlir | 26 ++++---
mlir/test/Target/LLVMIR/omptarget-depend.mlir | 6 +-
17 files changed, 353 insertions(+), 68 deletions(-)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 8dbc9df9f553def..497d099fbe9366a 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2494,6 +2494,21 @@ def fir_CallOp : fir_Op<"call",
llvm::cast<mlir::SymbolRefAttr>(callee));
setOperand(0, llvm::cast<mlir::Value>(callee));
}
+
+ /// Always allow FastMathFlags for fir.call's.
+ /// It is required to be able to propagate the call site's
+ /// FastMathFlags to the operations resulting from inlining
+ /// (if any) of a fir.call (see SimplifyIntrinsics pass).
+ /// We could analyze the arguments' data types to see if there are
+ /// any floating point types, but this is unreliable. For example,
+ /// the runtime calls mostly take !fir.box<none> arguments,
+ /// and tracking them to the definitions may be not easy.
+ /// TODO: this should be restricted to fir.runtime calls,
+ /// because FastMathFlags for the user calls must come
+ /// from the function body, not the call site.
+ bool isArithFastMathApplicable() {
+ return true;
+ }
}];
}
@@ -2672,6 +2687,15 @@ def fir_CmpcOp : fir_Op<"cmpc",
}
static mlir::arith::CmpFPredicate getPredicateByName(llvm::StringRef name);
+
+ /// Always allow FastMathFlags on fir.cmpc.
+ /// It does not produce a floating point result, but
+ /// LLVM is currently relying on fast-math flags attached
+ /// to floating point comparison.
+ /// This can be removed whenever LLVM stops doing it.
+ bool isArithFastMathApplicable() {
+ return true;
+ }
}];
}
@@ -2735,6 +2759,8 @@ def fir_ConvertOp : fir_SimpleOneResultOp<"convert", [NoMemoryEffect]> {
static bool isPointerCompatible(mlir::Type ty);
static bool canBeConverted(mlir::Type inType, mlir::Type outType);
static bool areVectorsCompatible(mlir::Type inTy, mlir::Type outTy);
+
+ // FIXME: fir.convert should support ArithFastMathInterface.
}];
let hasCanonicalizer = 1;
}
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
index 15296aa7e8c75c2..0e6d536d9bde5d0 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
@@ -139,6 +139,11 @@ bool mayHaveAllocatableComponent(mlir::Type ty);
/// Scalar integer or a sequence of integers (via boxed array or expr).
bool isFortranIntegerScalarOrArrayObject(mlir::Type type);
+/// Return true iff FastMathFlagsAttr is applicable
+/// to the given HLFIR dialect operation that supports
+/// ArithFastMathInterface.
+bool isArithFastMathApplicable(mlir::Operation *op);
+
} // namespace hlfir
#endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index f4102538efc3c28..f90ef8ed019ceb7 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -434,6 +434,12 @@ def hlfir_MaxvalOp : hlfir_Op<"maxval", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
@@ -461,6 +467,12 @@ def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
@@ -487,6 +499,12 @@ def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
@@ -513,6 +531,12 @@ def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
@@ -539,6 +563,12 @@ def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_SetLengthOp : hlfir_Op<"set_length",
@@ -604,6 +634,12 @@ def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_DotProductOp : hlfir_Op<"dot_product",
@@ -628,6 +664,12 @@ def hlfir_DotProductOp : hlfir_Op<"dot_product",
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_MatmulOp : hlfir_Op<"matmul",
@@ -655,6 +697,12 @@ def hlfir_MatmulOp : hlfir_Op<"matmul",
let hasCanonicalizeMethod = 1;
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_TransposeOp : hlfir_Op<"transpose",
@@ -697,6 +745,12 @@ def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_CShiftOp
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index d9779c46ae79e71..d749fc9c633d7c5 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -786,9 +786,7 @@ mlir::Value fir::FirOpBuilder::genAbsentOp(mlir::Location loc,
void fir::FirOpBuilder::setCommonAttributes(mlir::Operation *op) const {
auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
- if (fmi) {
- // TODO: use fmi.setFastMathFlagsAttr() after D137114 is merged.
- // For now set the attribute by the name.
+ if (fmi && fmi.isArithFastMathApplicable()) {
llvm::StringRef arithFMFAttrName = fmi.getFastMathAttrName();
if (fastMathFlags != mlir::arith::FastMathFlags::none)
op->setAttr(arithFMFAttrName, mlir::arith::FastMathFlagsAttr::get(
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index cb4eb8303a4959e..fca3fb077d0a3fb 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -589,10 +589,18 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
// Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
attrConvert(call);
- rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
- call, resultTys, adaptor.getOperands(),
+ auto llvmCall = rewriter.create<mlir::LLVM::CallOp>(
+ call.getLoc(), resultTys, adaptor.getOperands(),
addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
adaptor.getOperands().size()));
+ auto fmi =
+ mlir::cast<mlir::LLVM::FastmathFlagsInterface>(llvmCall.getOperation());
+ if (!fmi.isFastmathApplicable())
+ llvmCall->setAttr(
+ mlir::LLVM::CallOp::getFastmathAttrName(),
+ mlir::LLVM::FastmathFlagsAttr::get(call.getContext(),
+ mlir::LLVM::FastmathFlags::none));
+ rewriter.replaceOp(call, llvmCall);
return mlir::success();
}
};
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index cb77aef74acd560..53637f2090f2eff 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -237,3 +237,20 @@ bool hlfir::isFortranIntegerScalarOrArrayObject(mlir::Type type) {
mlir::Type elementType = getFortranElementType(unwrappedType);
return mlir::isa<mlir::IntegerType>(elementType);
}
+
+bool hlfir::isArithFastMathApplicable(mlir::Operation *op) {
+  // Fast-math flags apply iff some result or operand has an element type
+  // that is (or contains) a floating point type.
+  auto hasFastMathCompatibleType = [](mlir::Value v) {
+    return mlir::arith::ArithFastMathInterface::isCompatibleType(
+        getFortranElementType(v.getType()));
+  };
+  if (llvm::any_of(op->getResults(), hasFastMathCompatibleType))
+    return true;
+  // Operands are checked too: e.g. minloc/maxloc may take floating point
+  // arguments while producing non-floating point results.
+  if (llvm::any_of(op->getOperands(), hasFastMathCompatibleType))
+    return true;
+  // No floating point data involved: the flags are not applicable.
+  return false;
+}
diff --git a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
index 0827e378c7c07e8..b04188d3ee1d9ca 100644
--- a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
+++ b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
@@ -56,7 +56,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
%45 = llvm.call @_FortranACUFDataTransferPtrPtr(%14, %25, %2, %11, %13, %5) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
gpu.launch_func @cuda_device_mod::@_QMmod1Psub1 blocks in (%7, %7, %7) threads in (%12, %7, %7) : i64 dynamic_shared_memory_size %11 args(%14 : !llvm.ptr)
%46 = llvm.call @_FortranACUFDataTransferPtrPtr(%25, %14, %2, %10, %13, %4) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
- %47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
+ %47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) : (i32, !llvm.ptr, i32) -> !llvm.ptr
%48 = llvm.mlir.constant(9 : i32) : i32
%49 = llvm.mlir.zero : !llvm.ptr
%50 = llvm.getelementptr %49[1] : (!llvm.ptr) -> !llvm.ptr, i32
diff --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index 401ebbc8c49fe6b..c2c9ad362370f6f 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -136,7 +136,7 @@ module {
// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(-1 : i32) : i32
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @_QFEx : !llvm.ptr
// CHECK: %[[VAL_8:.*]] = llvm.mlir.addressof @_QQclX2E2F64756D6D792E66393000 : !llvm.ptr
-// CHECK: %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
+// CHECK: %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) : (i32, !llvm.ptr, i32) -> !llvm.ptr
// CHECK: %[[VAL_11:.*]] = llvm.mlir.constant(64 : i32) : i32
// CHECK: "llvm.intr.memcpy"(%[[VAL_3]], %[[VAL_7]], %[[VAL_11]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}>
// CHECK: %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
@@ -188,8 +188,8 @@ module {
// CHECK: %[[VAL_59:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_58]][7, 0, 2] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
// CHECK: %[[VAL_61:.*]] = llvm.insertvalue %[[VAL_52]], %[[VAL_59]][0] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
// CHECK: llvm.store %[[VAL_61]], %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>, !llvm.ptr
-// CHECK: %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr, !llvm.ptr) -> i1
-// CHECK: %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr) -> i32
+// CHECK: %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) : (!llvm.ptr, !llvm.ptr) -> i1
+// CHECK: %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) : (!llvm.ptr) -> i32
// CHECK: llvm.return
// CHECK: }
// CHECK: llvm.func @_FortranAioBeginExternalListOutput(i32, !llvm.ptr, i32) -> !llvm.ptr attributes {fir.io, fir.runtime, sym_visibility = "private"}
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ea9b0f6509b80b6..bd23890556ffddd 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1211,6 +1211,9 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
The destination type must to be strictly wider than the source type.
When operating on vectors, casts elementwise.
}];
+  // No isArithFastMathApplicable() override is needed here: extf always
+  // produces floating point results (or shaped containers of them), which
+  // the interface's default applicability check already accepts.
let hasVerifier = 1;
let hasFolder = 1;
@@ -1545,6 +1548,17 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf",
let hasCanonicalizer = 1;
let assemblyFormat = [{ $predicate `,` $lhs `,` $rhs (`fastmath` `` $fastmath^)?
attr-dict `:` type($lhs)}];
+
+ let extraClassDeclaration = [{
+ /// Always allow FastMathFlags on arith.cmpf.
+ /// It does not produce a floating point result, but
+ /// LLVM is currently relying on fast-math flags attached
+ /// to floating point comparison.
+ /// This can be removed whenever LLVM stops doing it.
+ bool isArithFastMathApplicable() {
+ return true;
+ }
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 82d6c9ad6b03da7..860c096ef2e8b9c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -22,31 +22,60 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
let cppNamespace = "::mlir::arith";
- let methods = [
- InterfaceMethod<
- /*desc=*/ "Returns a FastMathFlagsAttr attribute for the operation",
- /*returnType=*/ "FastMathFlagsAttr",
- /*methodName=*/ "getFastMathFlagsAttr",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
+ let methods =
+ [InterfaceMethod<
+ /*desc=*/"Returns a FastMathFlagsAttr attribute for the operation",
+ /*returnType=*/"FastMathFlagsAttr",
+ /*methodName=*/"getFastMathFlagsAttr",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
return op.getFastmathAttr();
- }]
- >,
- StaticInterfaceMethod<
- /*desc=*/ [{Returns the name of the FastMathFlagsAttr attribute
+ }]>,
+ StaticInterfaceMethod<
+ /*desc=*/[{Returns the name of the FastMathFlagsAttr attribute
for the operation}],
- /*returnType=*/ "StringRef",
- /*methodName=*/ "getFastMathAttrName",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
+ /*returnType=*/"StringRef",
+ /*methodName=*/"getFastMathAttrName",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
return "fastmath";
- }]
- >
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{Returns true iff FastMathFlagsAttr attribute
+ is applicable to the operation that supports
+ ArithFastMathInterface. If it returns false,
+ then the FastMathFlagsAttr of the operation
+ must be nullptr or have 'none' value}],
+ /*returnType=*/"bool",
+ /*methodName=*/"isArithFastMathApplicable",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
+ return ::mlir::cast<::mlir::arith::ArithFastMathInterface>(this->getOperation()).isApplicableImpl();
+ }]>];
- ];
+ let extraClassDeclaration = [{
+ /// Returns true iff the given type is a floating point type
+ /// or contains one.
+ static bool isCompatibleType(::mlir::Type);
+
+ /// Default implementation of isArithFastMathApplicable().
+ /// It returns true iff any of the results of the operations
+ /// has a type that is compatible with fast-math.
+ bool isApplicableImpl();
+ }];
+
+ let verify = [{
+ auto fmi = ::mlir::cast<::mlir::arith::ArithFastMathInterface>($_op);
+ auto attr = fmi.getFastMathFlagsAttr();
+ if (attr && attr.getValue() != ::mlir::arith::FastMathFlags::none &&
+ !fmi.isArithFastMathApplicable())
+ return $_op->emitOpError() << "FastMathFlagsAttr is not applicable";
+ return ::mlir::success();
+ }];
}
def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 5ccddef158d9c2b..ca55f933e4efad3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -22,30 +22,60 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
let cppNamespace = "::mlir::LLVM";
- let methods = [
- InterfaceMethod<
- /*desc=*/ "Returns a FastmathFlagsAttr attribute for the operation",
- /*returnType=*/ "::mlir::LLVM::FastmathFlagsAttr",
- /*methodName=*/ "getFastmathAttr",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
+ let methods =
+ [InterfaceMethod<
+ /*desc=*/"Returns a FastmathFlagsAttr attribute for the operation",
+ /*returnType=*/"::mlir::LLVM::FastmathFlagsAttr",
+ /*methodName=*/"getFastmathAttr",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getFastmathFlagsAttr();
- }]
- >,
- StaticInterfaceMethod<
- /*desc=*/ [{Returns the name of the FastmathFlagsAttr attribute
+ }]>,
+ StaticInterfaceMethod<
+ /*desc=*/[{Returns the name of the FastmathFlagsAttr attribute
for the operation}],
- /*returnType=*/ "::llvm::StringRef",
- /*methodName=*/ "getFastmathAttrName",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
+ /*returnType=*/"::llvm::StringRef",
+ /*methodName=*/"getFastmathAttrName",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
return "fastmathFlags";
- }]
- >
- ];
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{Returns true iff FastmathFlagsAttr attribute
+ is applicable to the operation that supports
+ FastmathInterface. If it returns false,
+ then the FastmathFlagsAttr of the operation
+ must be nullptr or have 'none' value}],
+ /*returnType=*/"bool",
+ /*methodName=*/"isFastmathApplicable",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
+ return ::mlir::cast<::mlir::LLVM::FastmathFlagsInterface>(this->getOperation()).isApplicableImpl();
+ }]>];
+
+ let extraClassDeclaration = [{
+ /// Returns true iff the given type is a floating point type
+ /// or contains one.
+ static bool isCompatibleType(::mlir::Type);
+
+ /// Default implementation of isFastmathApplicable().
+ /// It returns true iff any of the results of the operations
+ /// has a type that is compatible with fast-math.
+ bool isApplicableImpl();
+ }];
+
+ let verify = [{
+ auto fmi = ::mlir::cast<::mlir::LLVM::FastmathFlagsInterface>($_op);
+ auto attr = fmi.getFastmathAttr();
+ if (attr && attr.getValue() != ::mlir::LLVM::FastmathFlags::none &&
+ !fmi.isFastmathApplicable())
+ return $_op->emitOpError() << "FastmathFlagsAttr is not applicable";
+ return ::mlir::success();
+ }];
}
def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index ee6e10efed4f16b..17267efc17a3a93 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -225,6 +225,17 @@ def LLVM_FCmpOp : LLVM_ArithmeticCmpOp<"fcmp", [
// Set the $predicate index to -1 to indicate there is no matching operand
// and decrement the following indices.
list<int> llvmArgIndices = [-1, 0, 1, 2];
+
+ let extraClassDeclaration = [{
+ /// Always allow FastmathFlags on llvm.fcmp.
+ /// It does not produce a floating point result, but
+ /// LLVM is currently relying on fast-math flags attached
+ /// to floating point comparison.
+ /// This can be removed whenever LLVM stops doing it.
+ bool isFastmathApplicable() {
+ return true;
+ }
+ }];
}
// Floating point binary operations.
diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
index 042acf610090001..32f7cc2bd9b12b4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -66,3 +67,49 @@ Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder,
return ConstantOp::materialize(builder, value, type, loc);
}
+
+/// Return true if the type is compatible with fast math, i.e.
+/// it is a float type or contains a float type.
+bool arith::ArithFastMathInterface::isCompatibleType(Type type) {
+  if (isa<FloatType>(type))
+    return true;
+
+  // ShapedTypes with ValueSemantics represent containers
+  // passed around as values (not references), so look inside
+  // them to see if the element type is compatible with FastMath.
+  if (type.hasTrait<ValueSemantics>())
+    if (auto shapedType = dyn_cast<ShapedType>(type))
+      return isCompatibleType(shapedType.getElementType());
+
+  // ComplexType may have a floating point or an integer element type,
+  // so only complex numbers of floats are compatible with FastMath.
+  if (auto complexType = dyn_cast<ComplexType>(type))
+    return isCompatibleType(complexType.getElementType());
+
+  // TODO: what about TupleType and custom dialect struct-like types?
+  // They may deserve an interface that exposes their element types.
+  //
+  // NOTE: LLVM only allows fast-math flags for instructions producing
+  // structures with homogeneous floating point members. This restriction
+  // is intentionally not enforced here, because custom MLIR operations
+  // may be converted such that the original operation's FastMathFlags
+  // still need to be propagated to the target operations.
+
+  return false;
+}
+
+/// Default implementation of isArithFastMathApplicable(): fast-math
+/// flags are applicable iff at least one result type of the operation
+/// is, or contains, a floating point type (see isCompatibleType).
+///
+/// TODO: the results often have the same type, and traversing
+/// the same type again and again is not very efficient.
+/// We can cache it here for the duration of the processing.
+/// Other ideas?
+bool arith::ArithFastMathInterface::isApplicableImpl() {
+  for (Type resultType : getOperation()->getResultTypes())
+    if (isCompatibleType(resultType))
+      return true;
+  // No result carries floating point data.
+  return false;
+}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index a6e996f3fb810db..2593cdad1e65a81 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3783,3 +3783,43 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
return op->hasTrait<OpTrait::SymbolTable>() &&
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}
+
+/// Return true if the type is compatible with fast math, i.e. it is a float
+/// type or contains one. Mirrors LLVM's FPMathOperator type check.
+bool mlir::LLVM::FastmathFlagsInterface::isCompatibleType(Type type) {
+  // Only literal structs whose members all have the same type qualify;
+  // the check then continues with that member type.
+  if (auto structType = dyn_cast<LLVMStructType>(type)) {
+    if (structType.isIdentified())
+      return false;
+    ArrayRef<Type> elementTypes = structType.getBody();
+    if (elementTypes.empty() || !llvm::all_equal(elementTypes))
+      return false;
+    type = elementTypes[0];
+  } else if (auto arrayType = dyn_cast<LLVMArrayType>(type)) {
+    // Look through arbitrarily nested arrays. The extra parentheses make
+    // the intentional assignment in the condition explicit (-Wparentheses).
+    do {
+      type = arrayType.getElementType();
+    } while ((arrayType = dyn_cast<LLVMArrayType>(type)));
+  }
+
+  // Finally, look through vectors to the scalar element type.
+  type =
+      TypeSwitch<Type, Type>(type)
+          .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>(
+              [](auto containerType) { return containerType.getElementType(); })
+          .Default(type);
+  return isa<FloatType>(type);
+}
+
+/// Default implementation of isFastmathApplicable(): fast-math flags are
+/// applicable iff at least one result type of the operation is, or
+/// contains, a floating point type (see isCompatibleType).
+bool mlir::LLVM::FastmathFlagsInterface::isApplicableImpl() {
+  for (Type resultType : getOperation()->getResultTypes())
+    if (isCompatibleType(resultType))
+      return true;
+  // No result carries floating point data.
+  return false;
+}
diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
index eb249a477175349..cf88c5d1d78ec3d 100644
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -74,18 +74,18 @@ func.func @llvm_ret(%arg0 : i32) -> i32 {
// -----
// Include all function attributes that don't prevent inlining
-llvm.func internal fastcc @callee() -> (i32) attributes { function_entry_count = 42 : i64, dso_local } {
- %0 = llvm.mlir.constant(42 : i32) : i32
- llvm.return %0 : i32
+llvm.func internal fastcc @callee() -> (f32) attributes { function_entry_count = 42 : i64, dso_local } {
+ %0 = llvm.mlir.constant(42.0 : f32) : f32
+ llvm.return %0 : f32
}
// CHECK-LABEL: llvm.func @caller
// CHECK-NEXT: %[[CST:.+]] = llvm.mlir.constant
// CHECK-NEXT: llvm.return %[[CST]]
-llvm.func @caller() -> (i32) {
+llvm.func @caller() -> (f32) {
// Include all call attributes that don't prevent inlining.
- %0 = llvm.call fastcc @callee() { fastmathFlags = #llvm.fastmath<nnan, ninf>, branch_weights = dense<42> : vector<1xi32> } : () -> (i32)
- llvm.return %0 : i32
+ %0 = llvm.call fastcc @callee() { fastmathFlags = #llvm.fastmath<nnan, ninf>, branch_weights = dense<42> : vector<1xi32> } : () -> (f32)
+ llvm.return %0 : f32
}
// -----
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 88660ce598f3c22..0f49539267c6df0 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -102,15 +102,15 @@ func.func @ops(%arg0: i32, %arg1: f32,
// Variadic calls
// CHECK: llvm.call @vararg_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) : (i32, i32) -> ()
-// CHECK: llvm.call @vararg_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : (i32, i32) -> ()
-// CHECK: %[[VARIADIC_FUNC:.*]] = llvm.mlir.addressof @vararg_func : !llvm.ptr
-// CHECK: llvm.call %[[VARIADIC_FUNC]](%[[I32]], %[[I32]]) vararg(!llvm.func<void (i32, ...)>) : !llvm.ptr, (i32, i32) -> ()
-// CHECK: llvm.call %[[VARIADIC_FUNC]](%[[I32]], %[[I32]]) vararg(!llvm.func<void (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : !llvm.ptr, (i32, i32) -> ()
+// CHECK: llvm.call @vararg_func_f32(%arg0, %arg0) vararg(!llvm.func<f32 (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : (i32, i32) -> f32
+// CHECK: %[[VARIADIC_FUNC:.*]] = llvm.mlir.addressof @vararg_func_f32 : !llvm.ptr
+// CHECK: llvm.call %[[VARIADIC_FUNC]](%[[I32]], %[[I32]]) vararg(!llvm.func<f32 (i32, ...)>) : !llvm.ptr, (i32, i32) -> f32
+// CHECK: llvm.call %[[VARIADIC_FUNC]](%[[I32]], %[[I32]]) vararg(!llvm.func<f32 (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : !llvm.ptr, (i32, i32) -> f32
llvm.call @vararg_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) : (i32, i32) -> ()
- llvm.call @vararg_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : (i32, i32) -> ()
- %variadic_func = llvm.mlir.addressof @vararg_func : !llvm.ptr
- llvm.call %variadic_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) : !llvm.ptr, (i32, i32) -> ()
- llvm.call %variadic_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : !llvm.ptr, (i32, i32) -> ()
+ %tmp1 = llvm.call @vararg_func_f32(%arg0, %arg0) vararg(!llvm.func<f32 (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : (i32, i32) -> f32
+ %variadic_func = llvm.mlir.addressof @vararg_func_f32 : !llvm.ptr
+ llvm.call %variadic_func(%arg0, %arg0) vararg(!llvm.func<f32 (i32, ...)>) : !llvm.ptr, (i32, i32) -> f32
+ %tmp2 = llvm.call %variadic_func(%arg0, %arg0) vararg(!llvm.func<f32 (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : !llvm.ptr, (i32, i32) -> f32
// Function call attributes
// CHECK: llvm.call @baz() {convergent} : () -> ()
@@ -618,6 +618,9 @@ llvm.func @useInlineAsm(%arg0: i32) {
llvm.return
}
+// CHECK-LABEL: @fastmathStructReturn
+llvm.func @fastmathStructReturn(%arg0: i32) -> !llvm.struct<(f32, f32)>
+
// CHECK-LABEL: @fastmathFlags
func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: vector<2 x f32>, %arg4: vector<2 x f32>) {
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : f32
@@ -643,8 +646,8 @@ func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: vector<2 x f
// CHECK: {{.*}} = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<fast>} : f32
%6 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<fast>} : f32
-// CHECK: {{.*}} = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (i32) -> !llvm.struct<(i32, f64, i32)>
- %7 = llvm.call @foo(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (i32) -> !llvm.struct<(i32, f64, i32)>
+// CHECK: {{.*}} = llvm.call @fastmathStructReturn(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (i32) -> !llvm.struct<(f32, f32)>
+ %7 = llvm.call @fastmathStructReturn(%arg2) {fastmathFlags = #llvm.fastmath<fast>} : (i32) -> !llvm.struct<(f32, f32)>
// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32
%8 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<none>} : f32
@@ -700,6 +703,9 @@ llvm.func @invariant_group_intrinsics(%p: !llvm.ptr) {
llvm.return
}
+// CHECK-LABEL: @vararg_func_f32
+llvm.func @vararg_func_f32(%arg0: i32, ...) -> f32
+
// CHECK-LABEL: @vararg_func
llvm.func @vararg_func(%arg0: i32, ...) {
// CHECK: %[[C:.*]] = llvm.mlir.constant(1 : i32)
diff --git a/mlir/test/Target/LLVMIR/omptarget-depend.mlir b/mlir/test/Target/LLVMIR/omptarget-depend.mlir
index 71fecd0fa5fd0a7..723246c323f2caa 100644
--- a/mlir/test/Target/LLVMIR/omptarget-depend.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-depend.mlir
@@ -113,9 +113,9 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a
llvm.func @main(%arg0: i32, %arg1: !llvm.ptr, %arg2: !llvm.ptr) -> i32 {
%0 = llvm.mlir.constant(0 : i32) : i32
%1 = llvm.mlir.zero : !llvm.ptr
- llvm.call @_FortranAProgramStart(%arg0, %arg1, %arg2, %1) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
- llvm.call @_QQmain() {fastmathFlags = #llvm.fastmath<contract>} : () -> ()
- llvm.call @_FortranAProgramEndStatement() {fastmathFlags = #llvm.fastmath<contract>} : () -> ()
+ llvm.call @_FortranAProgramStart(%arg0, %arg1, %arg2, %1) : (i32, !llvm.ptr, !llvm.ptr, !llvm.ptr) -> ()
+ llvm.call @_QQmain() : () -> ()
+ llvm.call @_FortranAProgramEndStatement() : () -> ()
llvm.return %0 : i32
}
}
>From a517b834582c7ab6b927cb372fab4ab73249fe40 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Tue, 4 Feb 2025 10:24:42 -0800
Subject: [PATCH 2/3] Addressed review comments.
---
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 5 +----
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 3 ---
mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td | 5 ++++-
mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td | 5 ++++-
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 2 +-
mlir/test/Dialect/LLVMIR/invalid.mlir | 8 ++++++++
6 files changed, 18 insertions(+), 10 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index fca3fb077d0a3fb..109524635b3b6a4 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -596,10 +596,7 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
auto fmi =
mlir::cast<mlir::LLVM::FastmathFlagsInterface>(llvmCall.getOperation());
if (!fmi.isFastmathApplicable())
- llvmCall->setAttr(
- mlir::LLVM::CallOp::getFastmathAttrName(),
- mlir::LLVM::FastmathFlagsAttr::get(call.getContext(),
- mlir::LLVM::FastmathFlags::none));
+ llvmCall.setFastmathFlags(mlir::LLVM::FastmathFlags::none);
rewriter.replaceOp(call, llvmCall);
return mlir::success();
}
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index bd23890556ffddd..bd5e05977b90991 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1211,9 +1211,6 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
The destination type must to be strictly wider than the source type.
When operating on vectors, casts elementwise.
}];
- let extraClassDeclaration = [{
- bool isApplicable() { return true; }
- }];
let hasVerifier = 1;
let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 860c096ef2e8b9c..446563a23d104d4 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -73,7 +73,10 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
auto attr = fmi.getFastMathFlagsAttr();
if (attr && attr.getValue() != ::mlir::arith::FastMathFlags::none &&
!fmi.isArithFastMathApplicable())
- return $_op->emitOpError() << "FastMathFlagsAttr is not applicable";
+ return $_op->emitOpError()
+ << "has flag(s) `" << stringifyEnum(attr.getValue())
+ << "`, but fast-math flags are not applicable "
+ "(`isArithFastMathApplicable()` returns false)";
return ::mlir::success();
}];
}
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index ca55f933e4efad3..b63cbb5c362cf5e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -73,7 +73,10 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
auto attr = fmi.getFastmathAttr();
if (attr && attr.getValue() != ::mlir::LLVM::FastmathFlags::none &&
!fmi.isFastmathApplicable())
- return $_op->emitOpError() << "FastmathFlagsAttr is not applicable";
+ return $_op->emitOpError()
+ << "has flag(s) `" << stringifyEnum(attr.getValue())
+ << "`, but fast-math flags are not applicable "
+ "(`isFastmathApplicable()` returns false)";
return ::mlir::success();
}];
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 2593cdad1e65a81..07d26a089317b3a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3798,7 +3798,7 @@ bool mlir::LLVM::FastmathFlagsInterface::isCompatibleType(Type type) {
} else if (auto arrayType = dyn_cast<LLVMArrayType>(type)) {
do {
type = arrayType.getElementType();
- } while (arrayType = dyn_cast<LLVMArrayType>(type));
+ } while ((arrayType = dyn_cast<LLVMArrayType>(type)));
}
if (isa<FloatType>(type))
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 5c939318fe3ed67..fab48a89dbb1d99 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1703,3 +1703,11 @@ llvm.func @wrong_number_of_bundle_tags() {
} : (i32, i32) -> ()
llvm.return
}
+
+// -----
+
+func.func @call_invalid_fastmath(%callee : !llvm.ptr) { // indirect call with an i32 result: non-none fast-math flags must be rejected
+ // expected-error@+1 {{has flag(s) `nsz, afn`, but fast-math flags are not applicable (`isFastmathApplicable()` returns false)}}
+ llvm.call %callee() {fastmathFlags = #llvm.fastmath<nsz,afn>} : !llvm.ptr, () -> i32
+ llvm.return
+}
>From 226bb38dd172529beb13bc5e504b7b993330c0b0 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Fri, 7 Feb 2025 12:42:20 -0800
Subject: [PATCH 3/3] Addressed review comments.
---
mlir/lib/Dialect/Arith/IR/ArithDialect.cpp | 21 +--------------------
1 file changed, 1 insertion(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
index 32f7cc2bd9b12b4..4123d88d57551e4 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
@@ -71,7 +71,7 @@ Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder,
/// Return true if the type is compatible with fast math, i.e.
/// it is a float type or contains a float type.
bool arith::ArithFastMathInterface::isCompatibleType(Type type) {
- if (isa<FloatType>(type))
+ if (isa<FloatType>(type) || isa<ComplexType>(type))
return true;
// ShapeType's with ValueSemantics represent containers
@@ -81,31 +81,12 @@ bool arith::ArithFastMathInterface::isCompatibleType(Type type) {
if (auto shapedType = dyn_cast<ShapedType>(type))
return isCompatibleType(shapedType.getElementType());
- // ComplexType's element type is always a FloatType.
- if (auto complexType = dyn_cast<ComplexType>(type))
- return true;
-
- // TODO: what about TupleType and custom dialect struct-like types?
- // It seems that they worth an interface to get to the list of element types.
- //
- // NOTE: LLVM only allows fast-math flags for instructions producing
- // structures with homogeneous floating point members. I think
- // this restriction must not be asserted here, because custom
- // MLIR operations may be converted such that the original operation's
- // FastMathFlags still need to be propagated to the target
- // operations.
-
return false;
}
/// Return true if any of the results of the operation
/// has a type compatible with fast math, i.e. it is a float type
/// or contains a float type.
-///
-/// TODO: the results often have the same type, and traversing
-/// the same type again and again is not very efficient.
-/// We can cache it here for the duration of the processing.
-/// Other ideas?
bool arith::ArithFastMathInterface::isApplicableImpl() {
Operation *op = getOperation();
if (llvm::any_of(op->getResults(),
More information about the flang-commits
mailing list