[flang-commits] [flang] [mlir] [RFC][mlir] Conditional support for fast-math attributes. (PR #125620)
via flang-commits
flang-commits at lists.llvm.org
Mon Feb 3 18:23:43 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith
Author: Slava Zakharin (vzakhari)
<details>
<summary>Changes</summary>
This patch suggests changes for operations that support
arith::ArithFastMathInterface/LLVM::FastmathFlagsInterface.
Some of the operations may have fast-math flags not equal to `none`
only if they operate on floating point values.
This is inspired by https://llvm.org/docs/LangRef.html#fastmath-return-types
and my goal to add fast-math support for `arith.select` operation
that may produce results of any type.
The changes add new isArithFastMathApplicable/isFastmathApplicable
methods to the above interfaces that tell whether an operation
supporting the interface may have non-none fast-math flags.
The LLVM dialect's isFastmathApplicable implementation is based on https://github.com/llvm/llvm-project/blob/bac62ee5b473e70981a6bd9759ec316315fca07d/llvm/include/llvm/IR/Operator.h#L380
The Arith dialect's isArithFastMathApplicable is more relaxed, because
it has to support custom MLIR types. This is the area where
improvements are needed (see the TODO comments). I would appreciate
feedback here.
The HLFIR dialect is another example where conditional fast-math
support can currently be applied.
---
Patch is 32.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125620.diff
17 Files Affected:
- (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+26)
- (modified) flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h (+5)
- (modified) flang/include/flang/Optimizer/HLFIR/HLFIROps.td (+54)
- (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+1-3)
- (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+10-2)
- (modified) flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp (+17)
- (modified) flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir (+1-1)
- (modified) flang/test/Fir/tbaa.fir (+3-3)
- (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+14)
- (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td (+49-20)
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td (+50-20)
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+11)
- (modified) mlir/lib/Dialect/Arith/IR/ArithDialect.cpp (+47)
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+40)
- (modified) mlir/test/Dialect/LLVMIR/inlining.mlir (+6-6)
- (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+16-10)
- (modified) mlir/test/Target/LLVMIR/omptarget-depend.mlir (+3-3)
``````````diff
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 8dbc9df9f553de..497d099fbe9366 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2494,6 +2494,21 @@ def fir_CallOp : fir_Op<"call",
llvm::cast<mlir::SymbolRefAttr>(callee));
setOperand(0, llvm::cast<mlir::Value>(callee));
}
+
+ /// Always allow FastMathFlags for fir.call's.
+ /// It is required to be able to propagate the call site's
+ /// FastMathFlags to the operations resulting from inlining
+ /// (if any) of a fir.call (see SimplifyIntrinsics pass).
+ /// We could analyze the arguments' data types to see if there are
+ /// any floating point types, but this is unreliable. For example,
+ /// the runtime calls mostly take !fir.box<none> arguments,
+ /// and tracking them to their definitions may not be easy.
+ /// TODO: this should be restricted to fir.runtime calls,
+ /// because FastMathFlags for the user calls must come
+ /// from the function body, not the call site.
+ bool isArithFastMathApplicable() {
+ return true;
+ }
}];
}
@@ -2672,6 +2687,15 @@ def fir_CmpcOp : fir_Op<"cmpc",
}
static mlir::arith::CmpFPredicate getPredicateByName(llvm::StringRef name);
+
+ /// Always allow FastMathFlags on fir.cmpc.
+ /// It does not produce a floating point result, but
+ /// LLVM is currently relying on fast-math flags attached
+ /// to floating point comparison.
+ /// This can be removed whenever LLVM stops doing it.
+ bool isArithFastMathApplicable() {
+ return true;
+ }
}];
}
@@ -2735,6 +2759,8 @@ def fir_ConvertOp : fir_SimpleOneResultOp<"convert", [NoMemoryEffect]> {
static bool isPointerCompatible(mlir::Type ty);
static bool canBeConverted(mlir::Type inType, mlir::Type outType);
static bool areVectorsCompatible(mlir::Type inTy, mlir::Type outTy);
+
+ // FIXME: fir.convert should support ArithFastMathInterface.
}];
let hasCanonicalizer = 1;
}
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
index 15296aa7e8c75c..0e6d536d9bde5d 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
@@ -139,6 +139,11 @@ bool mayHaveAllocatableComponent(mlir::Type ty);
/// Scalar integer or a sequence of integers (via boxed array or expr).
bool isFortranIntegerScalarOrArrayObject(mlir::Type type);
+/// Return true iff FastMathFlagsAttr is applicable
+/// to the given HLFIR dialect operation that supports
+/// ArithFastMathInterface.
+bool isArithFastMathApplicable(mlir::Operation *op);
+
} // namespace hlfir
#endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index f4102538efc3c2..f90ef8ed019ceb 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -434,6 +434,12 @@ def hlfir_MaxvalOp : hlfir_Op<"maxval", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
@@ -461,6 +467,12 @@ def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
@@ -487,6 +499,12 @@ def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
@@ -513,6 +531,12 @@ def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
@@ -539,6 +563,12 @@ def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_SetLengthOp : hlfir_Op<"set_length",
@@ -604,6 +634,12 @@ def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_DotProductOp : hlfir_Op<"dot_product",
@@ -628,6 +664,12 @@ def hlfir_DotProductOp : hlfir_Op<"dot_product",
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_MatmulOp : hlfir_Op<"matmul",
@@ -655,6 +697,12 @@ def hlfir_MatmulOp : hlfir_Op<"matmul",
let hasCanonicalizeMethod = 1;
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_TransposeOp : hlfir_Op<"transpose",
@@ -697,6 +745,12 @@ def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool isArithFastMathApplicable() {
+ return hlfir::isArithFastMathApplicable(getOperation());
+ }
+ }];
}
def hlfir_CShiftOp
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index d9779c46ae79e7..d749fc9c633d7c 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -786,9 +786,7 @@ mlir::Value fir::FirOpBuilder::genAbsentOp(mlir::Location loc,
void fir::FirOpBuilder::setCommonAttributes(mlir::Operation *op) const {
auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
- if (fmi) {
- // TODO: use fmi.setFastMathFlagsAttr() after D137114 is merged.
- // For now set the attribute by the name.
+ if (fmi && fmi.isArithFastMathApplicable()) {
llvm::StringRef arithFMFAttrName = fmi.getFastMathAttrName();
if (fastMathFlags != mlir::arith::FastMathFlags::none)
op->setAttr(arithFMFAttrName, mlir::arith::FastMathFlagsAttr::get(
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index cb4eb8303a4959..fca3fb077d0a3f 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -589,10 +589,18 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
// Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
attrConvert(call);
- rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
- call, resultTys, adaptor.getOperands(),
+ auto llvmCall = rewriter.create<mlir::LLVM::CallOp>(
+ call.getLoc(), resultTys, adaptor.getOperands(),
addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
adaptor.getOperands().size()));
+ auto fmi =
+ mlir::cast<mlir::LLVM::FastmathFlagsInterface>(llvmCall.getOperation());
+ if (!fmi.isFastmathApplicable())
+ llvmCall->setAttr(
+ mlir::LLVM::CallOp::getFastmathAttrName(),
+ mlir::LLVM::FastmathFlagsAttr::get(call.getContext(),
+ mlir::LLVM::FastmathFlags::none));
+ rewriter.replaceOp(call, llvmCall);
return mlir::success();
}
};
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
index cb77aef74acd56..53637f2090f2ef 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
@@ -237,3 +237,20 @@ bool hlfir::isFortranIntegerScalarOrArrayObject(mlir::Type type) {
mlir::Type elementType = getFortranElementType(unwrappedType);
return mlir::isa<mlir::IntegerType>(elementType);
}
+
+bool hlfir::isArithFastMathApplicable(mlir::Operation *op) {
+  // Fast-math flags are meaningful only if some result or operand
+  // has a floating-point (or FP-containing) Fortran element type.
+  auto hasFastMathCompatibleType = [](mlir::Value v) {
+    mlir::Type elementType = getFortranElementType(v.getType());
+    return mlir::arith::ArithFastMathInterface::isCompatibleType(
+        elementType);
+  };
+  if (llvm::any_of(op->getResults(), hasFastMathCompatibleType))
+    return true;
+  if (llvm::any_of(op->getOperands(), hasFastMathCompatibleType))
+    return true;
+
+  // No floating-point results or operands: fast-math is not applicable.
+  return false;
+}
diff --git a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
index 0827e378c7c07e..b04188d3ee1d9c 100644
--- a/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
+++ b/flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
@@ -56,7 +56,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
%45 = llvm.call @_FortranACUFDataTransferPtrPtr(%14, %25, %2, %11, %13, %5) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
gpu.launch_func @cuda_device_mod::@_QMmod1Psub1 blocks in (%7, %7, %7) threads in (%12, %7, %7) : i64 dynamic_shared_memory_size %11 args(%14 : !llvm.ptr)
%46 = llvm.call @_FortranACUFDataTransferPtrPtr(%25, %14, %2, %10, %13, %4) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
- %47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
+ %47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) : (i32, !llvm.ptr, i32) -> !llvm.ptr
%48 = llvm.mlir.constant(9 : i32) : i32
%49 = llvm.mlir.zero : !llvm.ptr
%50 = llvm.getelementptr %49[1] : (!llvm.ptr) -> !llvm.ptr, i32
diff --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index 401ebbc8c49fe6..c2c9ad362370f6 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -136,7 +136,7 @@ module {
// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(-1 : i32) : i32
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @_QFEx : !llvm.ptr
// CHECK: %[[VAL_8:.*]] = llvm.mlir.addressof @_QQclX2E2F64756D6D792E66393000 : !llvm.ptr
-// CHECK: %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
+// CHECK: %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) : (i32, !llvm.ptr, i32) -> !llvm.ptr
// CHECK: %[[VAL_11:.*]] = llvm.mlir.constant(64 : i32) : i32
// CHECK: "llvm.intr.memcpy"(%[[VAL_3]], %[[VAL_7]], %[[VAL_11]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}>
// CHECK: %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
@@ -188,8 +188,8 @@ module {
// CHECK: %[[VAL_59:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_58]][7, 0, 2] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
// CHECK: %[[VAL_61:.*]] = llvm.insertvalue %[[VAL_52]], %[[VAL_59]][0] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
// CHECK: llvm.store %[[VAL_61]], %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>, !llvm.ptr
-// CHECK: %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr, !llvm.ptr) -> i1
-// CHECK: %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr) -> i32
+// CHECK: %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) : (!llvm.ptr, !llvm.ptr) -> i1
+// CHECK: %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) : (!llvm.ptr) -> i32
// CHECK: llvm.return
// CHECK: }
// CHECK: llvm.func @_FortranAioBeginExternalListOutput(i32, !llvm.ptr, i32) -> !llvm.ptr attributes {fir.io, fir.runtime, sym_visibility = "private"}
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index ea9b0f6509b80b..bd23890556ffdd 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1211,6 +1211,9 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
The destination type must to be strictly wider than the source type.
When operating on vectors, casts elementwise.
}];
+ let extraClassDeclaration = [{
+ bool isApplicable() { return true; }
+ }];
let hasVerifier = 1;
let hasFolder = 1;
@@ -1545,6 +1548,17 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf",
let hasCanonicalizer = 1;
let assemblyFormat = [{ $predicate `,` $lhs `,` $rhs (`fastmath` `` $fastmath^)?
attr-dict `:` type($lhs)}];
+
+ let extraClassDeclaration = [{
+ /// Always allow FastMathFlags on arith.cmpf.
+ /// It does not produce a floating point result, but
+ /// LLVM is currently relying on fast-math flags attached
+ /// to floating point comparison.
+ /// This can be removed whenever LLVM stops doing it.
+ bool isArithFastMathApplicable() {
+ return true;
+ }
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 82d6c9ad6b03da..860c096ef2e8b9 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -22,31 +22,60 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
let cppNamespace = "::mlir::arith";
- let methods = [
- InterfaceMethod<
- /*desc=*/ "Returns a FastMathFlagsAttr attribute for the operation",
- /*returnType=*/ "FastMathFlagsAttr",
- /*methodName=*/ "getFastMathFlagsAttr",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
+ let methods =
+ [InterfaceMethod<
+ /*desc=*/"Returns a FastMathFlagsAttr attribute for the operation",
+ /*returnType=*/"FastMathFlagsAttr",
+ /*methodName=*/"getFastMathFlagsAttr",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
return op.getFastmathAttr();
- }]
- >,
- StaticInterfaceMethod<
- /*desc=*/ [{Returns the name of the FastMathFlagsAttr attribute
+ }]>,
+ StaticInterfaceMethod<
+ /*desc=*/[{Returns the name of the FastMathFlagsAttr attribute
for the operation}],
- /*returnType=*/ "StringRef",
- /*methodName=*/ "getFastMathAttrName",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
+ /*returnType=*/"StringRef",
+ /*methodName=*/"getFastMathAttrName",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
return "fastmath";
- }]
- >
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{Returns true iff FastMathFlagsAttr attribute
+ is applicable to the operation that supports
+ ArithFastMathInterface. If it returns false,
+ then the FastMathFlagsAttr of the operation
+ must be nullptr or have 'none' value}],
+ /*returnType=*/"bool",
+ /*methodName=*/"isArithFastMathApplicable",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
+ return ::mlir::cast<::mlir::arith::ArithFastMathInterface>(this->getOperation()).isApplicableImpl();
+ }]>];
- ];
+ let extraClassDeclaration = [{
+ /// Returns true iff the given type is a floating point type
+ /// or contains one.
+ static bool isCompatibleType(::mlir::Type);
+
+ /// Default implementation of isArithFastMathApplicable().
+ /// It returns true iff any of the results of the operations
+ /// has a type that is compatible with fast-math.
+ bool isApplicableImpl();
+ }];
+
+ let verify = [{
+ auto fmi = ::mlir::cast<::mlir::arith::ArithFastMathInterface>($_op);
+ auto attr = fmi.getFastMathFlagsAttr();
+ if (attr && attr.getValue() != ::mlir::arith::FastMathFlags::none &&
+ !fmi.isArithFastMathApplicable())
+ return $_op->emitOpError() << "FastMathFlagsAttr is not applicable";
+ return ::mlir::success();
+ }];
}
def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 5ccddef158d9c2..ca55f933e4efad 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -22,30 +22,60 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
let cppNamespace = "::mlir::LLVM";
- let methods = [
- InterfaceMethod<
- /*desc=*/ "Returns a FastmathFlagsAttr attribute for the operation",
- /*returnType=*/ "::mlir::LLVM::FastmathFlagsAttr",
- /*methodName=*/ "getFastmathAttr",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
+ let methods =
+ [InterfaceMethod<
+ /*desc=*/"Returns a FastmathFlagsAttr attribute for the operation",
+ /*returnType=*/"::mlir::LLVM::FastmathFlagsAttr",
+ /*methodName=*/"getFastmathAttr",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getFastmathFlagsAttr();
- }]
- >,
- StaticInterfaceMethod<
- /*desc=*/ [{Returns the name of the FastmathFlagsAttr attribute
+ }]>,
+ StaticInterfaceMethod<
+ /*desc=*/[{Returns the name of the FastmathFlagsAttr attribute
for the operation}],
- /*returnType=*/ "::llvm::StringRef",
- /*methodName=*/ "getFastmathAttrName",
- /*args=*/ (ins),
- /*methodBody=*/ [{}],
- /*defaultImpl=*/ [{
+ /*returnType=*/"::llvm::StringRef",
+ /*methodName=*/"getFastmathAttrName",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
return "fastmathFlags";
- }]
- >
- ];
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{Returns true iff FastmathFlagsAttr attribute
+ is applicable to the operation that supports
+ FastmathInterface. If it returns false,
+ then the FastmathFlagsAttr of the operation
+ must be nullptr or have 'none' value}],
+ /*returnType=*/"bool",
+ /*methodName=*/"isFastmathApplicable",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImpl=*/[{
+ return ::mlir::cast<::mlir::LLVM::FastmathFlagsInterface>(this->getOperation()).isApplicableImpl();
+ }]>];
+
+ let extraClassDeclaration = [{
+ /// Returns true iff the given type is a floating point typ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/125620
More information about the flang-commits
mailing list