[Mlir-commits] [mlir] [MLIR][LLVM] Add `llvm.experimental.constrained.fptrunc` operation (PR #86260)
Victor Perez
llvmlistbot at llvm.org
Mon Mar 25 08:46:32 PDT 2024
https://github.com/victor-eds updated https://github.com/llvm/llvm-project/pull/86260
>From 8a1312ce158266e7987bc7eff081fb39fdd666b2 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Fri, 22 Mar 2024 09:18:20 +0000
Subject: [PATCH 1/3] [MLIR][LLVM] Add `llvm.experimental.constrained.fptrunc`
operation
Add operation mapping to the LLVM
`llvm.experimental.constrained.fptrunc.*` intrinsic.
The new operation implements the new
`LLVM::ExceptionBehaviorOpInterface` and
`LLVM::RoundingModeOpInterface` interfaces.
Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td | 57 ++++++++++++
.../mlir/Dialect/LLVMIR/LLVMInterfaces.td | 67 +++++++++++++++
.../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 41 +++++++++
.../include/mlir/Target/LLVMIR/ModuleImport.h | 5 ++
.../mlir/Target/LLVMIR/ModuleTranslation.h | 4 +
mlir/lib/Target/LLVMIR/ModuleImport.cpp | 86 +++++++++++++++++++
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 23 +++++
mlir/test/Dialect/LLVMIR/roundtrip.mlir | 19 ++++
mlir/test/Target/LLVMIR/Import/intrinsic.ll | 19 ++++
.../test/Target/LLVMIR/llvmir-intrinsics.mlir | 31 +++++++
10 files changed, 352 insertions(+)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index a7b269eb41ee2e..19fc69dda16696 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -705,4 +705,61 @@ def FramePointerKindEnum : LLVM_EnumAttr<
let cppNamespace = "::mlir::LLVM::framePointerKind";
}
+//===----------------------------------------------------------------------===//
+// RoundingMode
+//===----------------------------------------------------------------------===//
+
+// These values must match llvm::RoundingMode ones.
+// See llvm/include/llvm/ADT/FloatingPointMode.h.
+def RoundTowardZero
+ : LLVM_EnumAttrCase<"TowardZero", "towardzero", "TowardZero", 0>;
+def RoundNearestTiesToEven
+ : LLVM_EnumAttrCase<"NearestTiesToEven", "tonearest", "NearestTiesToEven", 1>;
+def RoundTowardPositive
+ : LLVM_EnumAttrCase<"TowardPositive", "upward", "TowardPositive", 2>;
+def RoundTowardNegative
+ : LLVM_EnumAttrCase<"TowardNegative", "downward", "TowardNegative", 3>;
+def RoundNearestTiesToAway
+ : LLVM_EnumAttrCase<"NearestTiesToAway", "tonearestaway", "NearestTiesToAway", 4>;
+def RoundDynamic
+ : LLVM_EnumAttrCase<"Dynamic", "dynamic", "Dynamic", 7>;
+// Needed as llvm::RoundingMode defines this.
+def RoundInvalid
+ : LLVM_EnumAttrCase<"Invalid", "invalid", "Invalid", -1>;
+
+// RoundingModeAttr should not be used in operations definitions.
+// Use ValidRoundingModeAttr instead.
+def RoundingModeAttr : LLVM_EnumAttr<
+ "RoundingMode",
+ "::llvm::RoundingMode",
+ "LLVM Rounding Mode",
+ [RoundTowardZero, RoundNearestTiesToEven, RoundTowardPositive,
+ RoundTowardNegative, RoundNearestTiesToAway, RoundDynamic, RoundInvalid]> {
+ let cppNamespace = "::mlir::LLVM";
+}
+
+def ValidRoundingModeAttr : ConfinedAttr<RoundingModeAttr, [IntMinValue<0>]>;
+
+//===----------------------------------------------------------------------===//
+// ExceptionBehavior
+//===----------------------------------------------------------------------===//
+
+// These values must match llvm::fp::ExceptionBehavior ones.
+// See llvm/include/llvm/IR/FPEnv.h.
+def ExceptionBehaviorIgnore
+ : LLVM_EnumAttrCase<"Ignore", "ignore", "ebIgnore", 0>;
+def ExceptionBehaviorMayTrap
+ : LLVM_EnumAttrCase<"MayTrap", "maytrap", "ebMayTrap", 1>;
+def ExceptionBehaviorStrict
+ : LLVM_EnumAttrCase<"Strict", "strict", "ebStrict", 2>;
+
+def ExceptionBehaviorAttr : LLVM_EnumAttr<
+ "ExceptionBehavior",
+ "::llvm::fp::ExceptionBehavior",
+ "LLVM Exception Behavior",
+ [ExceptionBehaviorIgnore, ExceptionBehaviorMayTrap,
+ ExceptionBehaviorStrict]> {
+ let cppNamespace = "::mlir::LLVM";
+}
+
#endif // LLVMIR_ENUMS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index e7a1da8ee560ef..ce91fbe1e2b24a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -290,6 +290,73 @@ def GetResultPtrElementType : OpInterface<"GetResultPtrElementType"> {
];
}
+def ExceptionBehaviorOpInterface : OpInterface<"ExceptionBehaviorOpInterface"> {
+ let description = [{
+ An interface for operations receiving an exception behavior attribute
+ controlling FP exception behavior.
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns a ExceptionBehavior attribute for the operation",
+ /*returnType=*/ "ExceptionBehaviorAttr",
+ /*methodName=*/ "getExceptionBehaviorAttr",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getExceptionbehaviorAttr();
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the ExceptionBehaviorAttr attribute
+ for the operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getExceptionBehaviorAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "exceptionbehavior";
+ }]
+ >
+ ];
+}
+
+def RoundingModeOpInterface : OpInterface<"RoundingModeOpInterface"> {
+ let description = [{
+ An interface for operations receiving a rounding mode attribute
+ controlling FP rounding mode.
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns a RoundingMode attribute for the operation",
+ /*returnType=*/ "RoundingModeAttr",
+ /*methodName=*/ "getRoundingModeAttr",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getRoundingmodeAttr();
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the RoundingModeAttr attribute
+ for the operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getRoundingModeAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "roundingmode";
+ }]
+ >,
+ ];
+}
//===----------------------------------------------------------------------===//
// LLVM dialect type interfaces.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index b88f1186a44b49..6a2b9b90350e1a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -311,6 +311,47 @@ def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2],
"qualified(type($ptr))";
}
+// Constrained Floating-Point Intrinsics
+
+class LLVM_ConstrainedIntr<string mnem, int numArgs, bit hasRoundingMode>
+ : LLVM_OneResultIntrOp<"experimental.constrained." # mnem,
+ /*overloadedResults=*/[0],
+ /*overloadedOperands=*/[0],
+ /*traits=*/[Pure, DeclareOpInterfaceMethods<ExceptionBehaviorOpInterface>]
+ # !cond(
+ !gt(hasRoundingMode, 0) : [DeclareOpInterfaceMethods<RoundingModeOpInterface>],
+ true : []),
+ /*requiresFastmath=*/0,
+ /*immArgPositions=*/[],
+ /*immArgAttrNames=*/[]> {
+ dag regularArgs = !dag(ins, !listsplat(LLVM_Type, numArgs), !foreach(i, !range(numArgs), "arg_" #i));
+ dag attrArgs = !con(!cond(!gt(hasRoundingMode, 0) : (ins ValidRoundingModeAttr:$roundingmode),
+ true : (ins)),
+ (ins ExceptionBehaviorAttr:$exceptionbehavior));
+ let arguments = !con(regularArgs, attrArgs);
+ let llvmBuilder = [{
+ $res = LLVM::detail::createConstrainedIntrinsicCall(
+ builder, moduleTranslation, &opInst, llvm::Intrinsic::experimental_constrained_}]
+ # mnem
+ # [{);
+ }];
+ let mlirBuilder = [{
+ auto op = moduleImport.translateConstrainedIntrinsic(
+ $_location, $_resultType, llvmOperands,
+ $_qualCppClassName::getOperationName());
+ if (!op)
+ return failure();
+ $res = op;
+ }];
+}
+
+def LLVM_ConstrainedFPTruncIntr
+ : LLVM_ConstrainedIntr<"fptrunc", /*numArgs=*/1, /*hasRoundingMode=*/1> {
+ let assemblyFormat = [{
+ $arg_0 $roundingmode $exceptionbehavior attr-dict `:` type($arg_0) `to` type(results)
+ }];
+}
+
// Intrinsics with multiple returns.
class LLVM_ArithWithOverflowOp<string mnem>
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index b49d2f539453e6..16f9994c126e06 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -232,6 +232,11 @@ class ModuleImport {
SmallVectorImpl<Value> &valuesOut,
SmallVectorImpl<NamedAttribute> &attrsOut);
+ /// Import constrained intrinsic.
+ Value translateConstrainedIntrinsic(Location loc, Type type,
+ ArrayRef<llvm::Value *> llvmOperands,
+ StringRef opName);
+
private:
/// Clears the accumulated state before processing a new region.
void clearRegionState() {
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index fb4392eb223c7f..458ed585167bc0 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -423,6 +423,10 @@ llvm::CallInst *createIntrinsicCall(
ArrayRef<unsigned> immArgPositions,
ArrayRef<StringLiteral> immArgAttrNames);
+llvm::CallInst *createConstrainedIntrinsicCall(
+ llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
+ Operation *intrOp, llvm::Intrinsic::ID intrinsic);
+
} // namespace detail
} // namespace LLVM
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index d63ea12ecd49b1..85a543c174f51d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1258,6 +1258,92 @@ LogicalResult ModuleImport::convertIntrinsicArguments(
return success();
}
+static RoundingModeAttr metadataToRoundingMode(Builder &builder,
+ llvm::Metadata *metadata) {
+ auto *mdstr = dyn_cast<llvm::MDString>(metadata);
+ if (!mdstr)
+ return {};
+ std::optional<llvm::RoundingMode> optLLVM =
+ llvm::convertStrToRoundingMode(mdstr->getString());
+ if (!optLLVM)
+ return {};
+ return builder.getAttr<RoundingModeAttr>(
+ convertRoundingModeFromLLVM(*optLLVM));
+}
+
+static ExceptionBehaviorAttr
+metadataToExceptionBehavior(Builder &builder, llvm::Metadata *metadata) {
+ auto *mdstr = dyn_cast<llvm::MDString>(metadata);
+ if (!mdstr)
+ return {};
+ std::optional<llvm::fp::ExceptionBehavior> optLLVM =
+ llvm::convertStrToExceptionBehavior(mdstr->getString());
+ if (!optLLVM)
+ return {};
+ return builder.getAttr<ExceptionBehaviorAttr>(
+ convertExceptionBehaviorFromLLVM(*optLLVM));
+}
+
+static void
+splitMetadataAndValues(ArrayRef<llvm::Value *> inputs,
+ SmallVectorImpl<llvm::Value *> &values,
+ SmallVectorImpl<llvm::Metadata *> &metadata) {
+ for (llvm::Value *in : inputs) {
+ if (auto *mdval = dyn_cast<llvm::MetadataAsValue>(in)) {
+ metadata.push_back(mdval->getMetadata());
+ } else {
+ values.push_back(in);
+ }
+ }
+}
+
+Value ModuleImport::translateConstrainedIntrinsic(
+ Location loc, Type type, ArrayRef<llvm::Value *> llvmOperands,
+ StringRef opName) {
+ // Split metadata values from regular ones.
+ SmallVector<llvm::Value *> values;
+ SmallVector<llvm::Metadata *> metadata;
+ splitMetadataAndValues(llvmOperands, values, metadata);
+
+ // Expect 1 or 2 metadata values.
+ assert((metadata.size() == 1 || metadata.size() == 2) &&
+ "Unexpected number of arguments");
+
+ SmallVector<Value> mlirOperands;
+ SmallVector<NamedAttribute> mlirAttrs;
+ if (failed(
+ convertIntrinsicArguments(values, {}, {}, mlirOperands, mlirAttrs))) {
+ return {};
+ }
+
+ // Create operation as usual.
+ StringAttr opNameAttr = builder.getStringAttr(opName);
+ Operation *op =
+ builder.create(loc, opNameAttr, mlirOperands, type, mlirAttrs);
+
+ // Set exception behavior attribute.
+ auto exceptionBehaviorOp = cast<ExceptionBehaviorOpInterface>(op);
+ ExceptionBehaviorAttr attr =
+ metadataToExceptionBehavior(builder, metadata.back());
+ if (!attr)
+ return {};
+ op->setAttr(exceptionBehaviorOp.getExceptionBehaviorAttrName(), attr);
+
+ // If avaialbe, set rounding mode attribute.
+ if (auto roundingModeOp = dyn_cast<RoundingModeOpInterface>(op)) {
+ assert(metadata.size() > 1 && "Unexpected number of arguments");
+ // rounding_mode present
+ RoundingModeAttr attr = metadataToRoundingMode(builder, metadata[0]);
+ if (!attr)
+ return {};
+ roundingModeOp->setAttr(roundingModeOp.getRoundingModeAttrName(), attr);
+ } else {
+ assert(metadata.size() == 1 && "Unexpected number of arguments");
+ }
+
+ return op->getResult(0);
+}
+
IntegerAttr ModuleImport::matchIntegerAttr(llvm::Value *value) {
IntegerAttr integerAttr;
FailureOr<Value> converted = convertValue(value);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index f90495d407fdfe..9100b1a18c71c6 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -862,6 +862,29 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
return builder.CreateCall(llvmIntr, args);
}
+llvm::CallInst *mlir::LLVM::detail::createConstrainedIntrinsicCall(
+ llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
+ Operation *intrOp, llvm::Intrinsic::ID intrinsic) {
+ llvm::Module *module = builder.GetInsertBlock()->getModule();
+ SmallVector<llvm::Type *> overloadedTypes{
+ moduleTranslation.convertType(intrOp->getResult(0).getType()),
+ moduleTranslation.convertType(intrOp->getOperand(0).getType())};
+ llvm::Function *callee =
+ llvm::Intrinsic::getDeclaration(module, intrinsic, overloadedTypes);
+ SmallVector<llvm::Value *> args =
+ moduleTranslation.lookupValues(intrOp->getOperands());
+ std::optional<llvm::RoundingMode> rounding;
+ if (auto roundingModeOp = dyn_cast<RoundingModeOpInterface>(intrOp)) {
+ rounding = convertRoundingModeToLLVM(
+ roundingModeOp.getRoundingModeAttr().getValue());
+ }
+ llvm::fp::ExceptionBehavior except =
+ convertExceptionBehaviorToLLVM(cast<ExceptionBehaviorOpInterface>(intrOp)
+ .getExceptionBehaviorAttr()
+ .getValue());
+ return builder.CreateConstrainedFPCall(callee, args, "", rounding, except);
+}
+
/// Given a single MLIR operation, create the corresponding LLVM IR operation
/// using the `builder`.
LogicalResult ModuleTranslation::convertOperation(Operation &op,
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index b157cf00141842..cc415e7b3662be 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -647,3 +647,22 @@ llvm.func @experimental_noalias_scope_decl() {
llvm.intr.experimental.noalias.scope.decl #alias_scope
llvm.return
}
+
+// CHECK-LABEL: @experimental_constrained_fptrunc
+llvm.func @experimental_constrained_fptrunc(%in: f64) -> f32 {
+ // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
+ %0 = llvm.intr.experimental.constrained.fptrunc %in towardzero ignore : f64 to f32
+ // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
+ %1 = llvm.intr.experimental.constrained.fptrunc %in tonearest maytrap : f64 to f32
+ // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
+ %2 = llvm.intr.experimental.constrained.fptrunc %in upward strict : f64 to f32
+ // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
+ %3 = llvm.intr.experimental.constrained.fptrunc %in downward ignore : f64 to f32
+ // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
+ %4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
+ %tmp0 = llvm.fadd %0, %1 : f32
+ %tmp1 = llvm.fadd %2, %3 : f32
+ %tmp2 = llvm.fadd %tmp0, %tmp1 : f32
+ %res = llvm.fadd %tmp2, %4 : f32
+ llvm.return %res : f32
+}
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 1ec9005458c50b..85561839f31a70 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -894,6 +894,23 @@ define float @ssa_copy(float %0) {
ret float %2
}
+; CHECK-LABEL: experimental_constrained_fptrunc
+define void @experimental_constrained_fptrunc(double %s, <4 x double> %v) {
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
+ %1 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.towardzero", metadata !"fpexcept.ignore")
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
+ %2 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearest", metadata !"fpexcept.maytrap")
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
+ %3 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.upward", metadata !"fpexcept.strict")
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
+ %4 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.downward", metadata !"fpexcept.ignore")
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
+ %5 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
+ ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : vector<4xf64> to vector<4xf16>
+ %6 = call <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double> %v, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
+ ret void
+}
+
declare float @llvm.fmuladd.f32(float, float, float)
declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>)
declare float @llvm.fma.f32(float, float, float)
@@ -1120,3 +1137,5 @@ declare void @llvm.assume(i1)
declare float @llvm.ssa.copy.f32(float returned)
declare <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float>, <4 x float>, i64)
declare <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float>, i64)
+declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double>, metadata, metadata)
+declare float @llvm.experimental.constrained.fptrunc.f32.f64(double, metadata, metadata)
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index fc2e0fd201a738..0013522582a727 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -964,6 +964,35 @@ llvm.func @ssa_copy(%arg: f32) -> f32 {
llvm.return %0 : f32
}
+// CHECK-LABEL: @experimental_constrained_fptrunc
+llvm.func @experimental_constrained_fptrunc(%s: f64, %v: vector<4xf32>) {
+ // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+ // CHECK: metadata !"round.towardzero"
+ // CHECK: metadata !"fpexcept.ignore"
+ %0 = llvm.intr.experimental.constrained.fptrunc %s towardzero ignore : f64 to f32
+ // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+ // CHECK: metadata !"round.tonearest"
+ // CHECK: metadata !"fpexcept.maytrap"
+ %1 = llvm.intr.experimental.constrained.fptrunc %s tonearest maytrap : f64 to f32
+ // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+ // CHECK: metadata !"round.upward"
+ // CHECK: metadata !"fpexcept.strict"
+ %2 = llvm.intr.experimental.constrained.fptrunc %s upward strict : f64 to f32
+ // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+ // CHECK: metadata !"round.downward"
+ // CHECK: metadata !"fpexcept.ignore"
+ %3 = llvm.intr.experimental.constrained.fptrunc %s downward ignore : f64 to f32
+ // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+ // CHECK: metadata !"round.tonearestaway"
+ // CHECK: metadata !"fpexcept.ignore"
+ %4 = llvm.intr.experimental.constrained.fptrunc %s tonearestaway ignore : f64 to f32
+ // CHECK: call <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f32(
+ // CHECK: metadata !"round.upward"
+ // CHECK: metadata !"fpexcept.strict"
+ %5 = llvm.intr.experimental.constrained.fptrunc %v upward strict : vector<4xf32> to vector<4xf16>
+ llvm.return
+}
+
// Check that intrinsics are declared with appropriate types.
// CHECK-DAG: declare float @llvm.fma.f32(float, float, float)
// CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0
@@ -1126,3 +1155,5 @@ llvm.func @ssa_copy(%arg: f32) -> f32 {
// CHECK-DAG: declare ptr addrspace(1) @llvm.stacksave.p1()
// CHECK-DAG: declare void @llvm.stackrestore.p0(ptr)
// CHECK-DAG: declare void @llvm.stackrestore.p1(ptr addrspace(1))
+// CHECK-DAG: declare float @llvm.experimental.constrained.fptrunc.f32.f64(double, metadata, metadata)
+// CHECK-DAG: declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f32(<4 x float>, metadata, metadata)
>From c882a29ff2d26c4850e74044b5c78ef99f7d1fbc Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Mon, 25 Mar 2024 13:23:21 +0000
Subject: [PATCH 2/3] Apply suggestions
Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td | 16 +--
.../mlir/Dialect/LLVMIR/LLVMInterfaces.td | 18 ++--
.../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 81 +++++++++++---
.../include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 4 +
.../include/mlir/Target/LLVMIR/ModuleImport.h | 13 ++-
.../mlir/Target/LLVMIR/ModuleTranslation.h | 11 +-
mlir/lib/Target/LLVMIR/ModuleImport.cpp | 100 ++++--------------
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 35 +++---
mlir/test/Dialect/LLVMIR/roundtrip.mlir | 8 +-
.../tools/mlir-tblgen/LLVMIRConversionGen.cpp | 4 +
10 files changed, 142 insertions(+), 148 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index 19fc69dda16696..04d797031245e3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -741,24 +741,24 @@ def RoundingModeAttr : LLVM_EnumAttr<
def ValidRoundingModeAttr : ConfinedAttr<RoundingModeAttr, [IntMinValue<0>]>;
//===----------------------------------------------------------------------===//
-// ExceptionBehavior
+// FPExceptionBehavior
//===----------------------------------------------------------------------===//
// These values must match llvm::fp::ExceptionBehavior ones.
// See llvm/include/llvm/IR/FPEnv.h.
-def ExceptionBehaviorIgnore
+def FPExceptionBehaviorIgnore
: LLVM_EnumAttrCase<"Ignore", "ignore", "ebIgnore", 0>;
-def ExceptionBehaviorMayTrap
+def FPExceptionBehaviorMayTrap
: LLVM_EnumAttrCase<"MayTrap", "maytrap", "ebMayTrap", 1>;
-def ExceptionBehaviorStrict
+def FPExceptionBehaviorStrict
: LLVM_EnumAttrCase<"Strict", "strict", "ebStrict", 2>;
-def ExceptionBehaviorAttr : LLVM_EnumAttr<
- "ExceptionBehavior",
+def FPExceptionBehaviorAttr : LLVM_EnumAttr<
+ "FPExceptionBehavior",
"::llvm::fp::ExceptionBehavior",
"LLVM Exception Behavior",
- [ExceptionBehaviorIgnore, ExceptionBehaviorMayTrap,
- ExceptionBehaviorStrict]> {
+ [FPExceptionBehaviorIgnore, FPExceptionBehaviorMayTrap,
+ FPExceptionBehaviorStrict]> {
let cppNamespace = "::mlir::LLVM";
}
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index ce91fbe1e2b24a..cee752aeb269b7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -290,7 +290,7 @@ def GetResultPtrElementType : OpInterface<"GetResultPtrElementType"> {
];
}
-def ExceptionBehaviorOpInterface : OpInterface<"ExceptionBehaviorOpInterface"> {
+def FPExceptionBehaviorOpInterface : OpInterface<"FPExceptionBehaviorOpInterface"> {
let description = [{
An interface for operations receiving an exception behavior attribute
controlling FP exception behavior.
@@ -300,25 +300,25 @@ def ExceptionBehaviorOpInterface : OpInterface<"ExceptionBehaviorOpInterface"> {
let methods = [
InterfaceMethod<
- /*desc=*/ "Returns a ExceptionBehavior attribute for the operation",
- /*returnType=*/ "ExceptionBehaviorAttr",
- /*methodName=*/ "getExceptionBehaviorAttr",
+ /*desc=*/ "Returns a FPExceptionBehavior attribute for the operation",
+ /*returnType=*/ "FPExceptionBehaviorAttr",
+ /*methodName=*/ "getFPExceptionBehaviorAttr",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
- return op.getExceptionbehaviorAttr();
+ return op.getFpExceptionBehaviorAttr();
}]
>,
StaticInterfaceMethod<
- /*desc=*/ [{Returns the name of the ExceptionBehaviorAttr attribute
- for the operation}],
+ /*desc=*/ [{Returns the name of the FPExceptionBehaviorAttr
+ attribute for the operation}],
/*returnType=*/ "StringRef",
- /*methodName=*/ "getExceptionBehaviorAttrName",
+ /*methodName=*/ "getFPExceptionBehaviorAttrName",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
- return "exceptionbehavior";
+ return "fpExceptionBehavior";
}]
>
];
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 6a2b9b90350e1a..e86369b47b49a4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -313,11 +313,15 @@ def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2],
// Constrained Floating-Point Intrinsics
-class LLVM_ConstrainedIntr<string mnem, int numArgs, bit hasRoundingMode>
+class LLVM_ConstrainedIntr<string mnem, int numArgs,
+ bit overloadedResult, list<int> overloadedOperands,
+ bit hasRoundingMode>
: LLVM_OneResultIntrOp<"experimental.constrained." # mnem,
- /*overloadedResults=*/[0],
- /*overloadedOperands=*/[0],
- /*traits=*/[Pure, DeclareOpInterfaceMethods<ExceptionBehaviorOpInterface>]
+ /*overloadedResults=*/
+ !cond(!gt(overloadedResult, 0) : [0],
+ true : []),
+ overloadedOperands,
+ /*traits=*/[Pure, DeclareOpInterfaceMethods<FPExceptionBehaviorOpInterface>]
# !cond(
!gt(hasRoundingMode, 0) : [DeclareOpInterfaceMethods<RoundingModeOpInterface>],
true : []),
@@ -327,28 +331,73 @@ class LLVM_ConstrainedIntr<string mnem, int numArgs, bit hasRoundingMode>
dag regularArgs = !dag(ins, !listsplat(LLVM_Type, numArgs), !foreach(i, !range(numArgs), "arg_" #i));
dag attrArgs = !con(!cond(!gt(hasRoundingMode, 0) : (ins ValidRoundingModeAttr:$roundingmode),
true : (ins)),
- (ins ExceptionBehaviorAttr:$exceptionbehavior));
+ (ins FPExceptionBehaviorAttr:$fpExceptionBehavior));
let arguments = !con(regularArgs, attrArgs);
let llvmBuilder = [{
- $res = LLVM::detail::createConstrainedIntrinsicCall(
- builder, moduleTranslation, &opInst, llvm::Intrinsic::experimental_constrained_}]
- # mnem
- # [{);
+ SmallVector<llvm::Value *> args =
+ moduleTranslation.lookupValues(opInst.getOperands());
+ SmallVector<llvm::Type *> overloadedTypes; }] #
+ !cond(!gt(overloadedResult, 0) : [{
+ // Take into account overloaded result type
+ overloadedTypes.push_back($_resultType); }],
+ // No overloaded result type
+ true : "") # [{
+ llvm::transform(ArrayRef<unsigned>}] # overloadedOperandsCpp # [{,
+ std::back_inserter(overloadedTypes),
+ [&args](unsigned index) { return args[index]->getType(); });
+ llvm::Module *module = builder.GetInsertBlock()->getModule();
+ llvm::Function *callee =
+ llvm::Intrinsic::getDeclaration(module,
+ llvm::Intrinsic::experimental_constrained_}] #
+ mnem # [{, overloadedTypes); }] #
+ !cond(!gt(hasRoundingMode, 0) : [{
+ // Get rounding mode using interface
+ llvm::RoundingMode rounding =
+ moduleTranslation.translateRoundingMode($roundingmode); }],
+ true : [{
+ // No rounding mode
+ std::optional<llvm::RoundingMode> rounding; }]) # [{
+ llvm::fp::ExceptionBehavior except =
+ moduleTranslation.translateFPExceptionBehavior($fpExceptionBehavior);
+ $res = builder.CreateConstrainedFPCall(callee, args, "", rounding, except);
}];
let mlirBuilder = [{
- auto op = moduleImport.translateConstrainedIntrinsic(
- $_location, $_resultType, llvmOperands,
- $_qualCppClassName::getOperationName());
- if (!op)
+ SmallVector<Value> mlirOperands;
+ SmallVector<NamedAttribute> mlirAttrs;
+ if (failed(moduleImport.convertIntrinsicArguments(
+ llvmOperands.take_front( }] # numArgs # [{),
+ {}, {}, mlirOperands, mlirAttrs))) {
+ return failure();
+ }
+
+ FPExceptionBehaviorAttr fpExceptionBehaviorAttr =
+ $_fpExceptionBehavior_attr($fpExceptionBehavior);
+ if (!fpExceptionBehaviorAttr)
+ return failure();
+ mlirAttrs.push_back(
+ $_builder.getNamedAttr(
+ $_qualCppClassName::getFPExceptionBehaviorAttrName(),
+ fpExceptionBehaviorAttr)); }] #
+ !cond(!gt(hasRoundingMode, 0) : [{
+ RoundingModeAttr roundingModeAttr =
+ $_roundingMode_attr($roundingmode);
+ if (!roundingModeAttr)
return failure();
- $res = op;
+ mlirAttrs.push_back(
+ $_builder.getNamedAttr($_qualCppClassName::getRoundingModeAttrName(),
+ roundingModeAttr));
+ }], true : "") # [{
+ $res = $_builder.create<$_qualCppClassName>($_location,
+ $_resultType, mlirOperands, mlirAttrs);
}];
}
def LLVM_ConstrainedFPTruncIntr
- : LLVM_ConstrainedIntr<"fptrunc", /*numArgs=*/1, /*hasRoundingMode=*/1> {
+ : LLVM_ConstrainedIntr<"fptrunc", /*numArgs=*/1,
+ /*overloadedResult=*/1, /*overloadedOperands=*/[0],
+ /*hasRoundingMode=*/1> {
let assemblyFormat = [{
- $arg_0 $roundingmode $exceptionbehavior attr-dict `:` type($arg_0) `to` type(results)
+ $arg_0 $roundingmode $fpExceptionBehavior attr-dict `:` type($arg_0) `to` type(results)
}];
}
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index b6aa73dad22970..7b9a9cf017c537 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -170,6 +170,10 @@ class LLVM_OpBase<Dialect dialect, string mnemonic, list<Trait> traits = []> :
// - $_float_attr - substituted by a call to a float attribute matcher;
// - $_var_attr - substituted by a call to a variable attribute matcher;
// - $_label_attr - substituted by a call to a label attribute matcher;
+ // - $_roundingMode_attr - substituted by a call to a rounding mode
+ // attribute matcher;
+ // - $_fpExceptionBehavior_attr - substituted by a call to an FP exception
+ // behavior attribute matcher;
// - $_resultType - substituted with the MLIR result type;
// - $_location - substituted with the MLIR location;
// - $_builder - substituted with the MLIR builder;
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 16f9994c126e06..b551eb937cfe8d 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -152,6 +152,14 @@ class ModuleImport {
/// Converts `value` to a label attribute. Asserts if the matching fails.
DILabelAttr matchLabelAttr(llvm::Value *value);
+ /// Converts `value` to an FP exception behavior attribute. Asserts if the
+ /// matching fails.
+ FPExceptionBehaviorAttr matchFPExceptionBehaviorAttr(llvm::Value *value);
+
+ /// Converts `value` to a rounding mode attribute. Asserts if the matching
+ /// fails.
+ RoundingModeAttr matchRoundingModeAttr(llvm::Value *value);
+
/// Converts `value` to an array of alias scopes or returns failure if the
/// conversion fails.
FailureOr<SmallVector<AliasScopeAttr>>
@@ -232,11 +240,6 @@ class ModuleImport {
SmallVectorImpl<Value> &valuesOut,
SmallVectorImpl<NamedAttribute> &attrsOut);
- /// Import constrained intrinsic.
- Value translateConstrainedIntrinsic(Location loc, Type type,
- ArrayRef<llvm::Value *> llvmOperands,
- StringRef opName);
-
private:
/// Clears the accumulated state before processing a new region.
void clearRegionState() {
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 458ed585167bc0..310a43e0de96b3 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -201,6 +201,13 @@ class ModuleTranslation {
/// Translates the given LLVM debug info metadata.
llvm::Metadata *translateDebugInfo(LLVM::DINodeAttr attr);
+ /// Translates the given LLVM rounding mode metadata.
+ llvm::RoundingMode translateRoundingMode(LLVM::RoundingMode rounding);
+
+ /// Translates the given LLVM FP exception behavior metadata.
+ llvm::fp::ExceptionBehavior
+ translateFPExceptionBehavior(LLVM::FPExceptionBehavior exceptionBehavior);
+
/// Translates the contents of the given block to LLVM IR using this
/// translator. The LLVM IR basic block corresponding to the given block is
/// expected to exist in the mapping of this translator. Uses `builder` to
@@ -423,10 +430,6 @@ llvm::CallInst *createIntrinsicCall(
ArrayRef<unsigned> immArgPositions,
ArrayRef<StringLiteral> immArgAttrNames);
-llvm::CallInst *createConstrainedIntrinsicCall(
- llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
- Operation *intrOp, llvm::Intrinsic::ID intrinsic);
-
} // namespace detail
} // namespace LLVM
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 85a543c174f51d..374127a19ebb9f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1258,21 +1258,8 @@ LogicalResult ModuleImport::convertIntrinsicArguments(
return success();
}
-static RoundingModeAttr metadataToRoundingMode(Builder &builder,
- llvm::Metadata *metadata) {
- auto *mdstr = dyn_cast<llvm::MDString>(metadata);
- if (!mdstr)
- return {};
- std::optional<llvm::RoundingMode> optLLVM =
- llvm::convertStrToRoundingMode(mdstr->getString());
- if (!optLLVM)
- return {};
- return builder.getAttr<RoundingModeAttr>(
- convertRoundingModeFromLLVM(*optLLVM));
-}
-
-static ExceptionBehaviorAttr
-metadataToExceptionBehavior(Builder &builder, llvm::Metadata *metadata) {
+static FPExceptionBehaviorAttr
+metadataToFPExceptionBehavior(Builder &builder, llvm::Metadata *metadata) {
auto *mdstr = dyn_cast<llvm::MDString>(metadata);
if (!mdstr)
return {};
@@ -1280,68 +1267,8 @@ metadataToExceptionBehavior(Builder &builder, llvm::Metadata *metadata) {
llvm::convertStrToExceptionBehavior(mdstr->getString());
if (!optLLVM)
return {};
- return builder.getAttr<ExceptionBehaviorAttr>(
- convertExceptionBehaviorFromLLVM(*optLLVM));
-}
-
-static void
-splitMetadataAndValues(ArrayRef<llvm::Value *> inputs,
- SmallVectorImpl<llvm::Value *> &values,
- SmallVectorImpl<llvm::Metadata *> &metadata) {
- for (llvm::Value *in : inputs) {
- if (auto *mdval = dyn_cast<llvm::MetadataAsValue>(in)) {
- metadata.push_back(mdval->getMetadata());
- } else {
- values.push_back(in);
- }
- }
-}
-
-Value ModuleImport::translateConstrainedIntrinsic(
- Location loc, Type type, ArrayRef<llvm::Value *> llvmOperands,
- StringRef opName) {
- // Split metadata values from regular ones.
- SmallVector<llvm::Value *> values;
- SmallVector<llvm::Metadata *> metadata;
- splitMetadataAndValues(llvmOperands, values, metadata);
-
- // Expect 1 or 2 metadata values.
- assert((metadata.size() == 1 || metadata.size() == 2) &&
- "Unexpected number of arguments");
-
- SmallVector<Value> mlirOperands;
- SmallVector<NamedAttribute> mlirAttrs;
- if (failed(
- convertIntrinsicArguments(values, {}, {}, mlirOperands, mlirAttrs))) {
- return {};
- }
-
- // Create operation as usual.
- StringAttr opNameAttr = builder.getStringAttr(opName);
- Operation *op =
- builder.create(loc, opNameAttr, mlirOperands, type, mlirAttrs);
-
- // Set exception behavior attribute.
- auto exceptionBehaviorOp = cast<ExceptionBehaviorOpInterface>(op);
- ExceptionBehaviorAttr attr =
- metadataToExceptionBehavior(builder, metadata.back());
- if (!attr)
- return {};
- op->setAttr(exceptionBehaviorOp.getExceptionBehaviorAttrName(), attr);
-
- // If avaialbe, set rounding mode attribute.
- if (auto roundingModeOp = dyn_cast<RoundingModeOpInterface>(op)) {
- assert(metadata.size() > 1 && "Unexpected number of arguments");
- // rounding_mode present
- RoundingModeAttr attr = metadataToRoundingMode(builder, metadata[0]);
- if (!attr)
- return {};
- roundingModeOp->setAttr(roundingModeOp.getRoundingModeAttrName(), attr);
- } else {
- assert(metadata.size() == 1 && "Unexpected number of arguments");
- }
-
- return op->getResult(0);
+ return builder.getAttr<FPExceptionBehaviorAttr>(
+ convertFPExceptionBehaviorFromLLVM(*optLLVM));
}
IntegerAttr ModuleImport::matchIntegerAttr(llvm::Value *value) {
@@ -1376,6 +1303,25 @@ DILabelAttr ModuleImport::matchLabelAttr(llvm::Value *value) {
return debugImporter->translate(node);
}
+FPExceptionBehaviorAttr
+ModuleImport::matchFPExceptionBehaviorAttr(llvm::Value *value) {
+ auto *metadata = cast<llvm::MetadataAsValue>(value);
+ auto *mdstr = cast<llvm::MDString>(metadata->getMetadata());
+ std::optional<llvm::fp::ExceptionBehavior> optLLVM =
+ llvm::convertStrToExceptionBehavior(mdstr->getString());
+ return builder.getAttr<FPExceptionBehaviorAttr>(
+ convertFPExceptionBehaviorFromLLVM(*optLLVM));
+}
+
+RoundingModeAttr ModuleImport::matchRoundingModeAttr(llvm::Value *value) {
+ auto *metadata = cast<llvm::MetadataAsValue>(value);
+ auto *mdstr = cast<llvm::MDString>(metadata->getMetadata());
+ std::optional<llvm::RoundingMode> optLLVM =
+ llvm::convertStrToRoundingMode(mdstr->getString());
+ return builder.getAttr<RoundingModeAttr>(
+ convertRoundingModeFromLLVM(*optLLVM));
+}
+
FailureOr<SmallVector<AliasScopeAttr>>
ModuleImport::matchAliasScopeAttrs(llvm::Value *value) {
auto *nodeAsVal = cast<llvm::MetadataAsValue>(value);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 9100b1a18c71c6..1523f4195b38c8 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -862,29 +862,6 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
return builder.CreateCall(llvmIntr, args);
}
-llvm::CallInst *mlir::LLVM::detail::createConstrainedIntrinsicCall(
- llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
- Operation *intrOp, llvm::Intrinsic::ID intrinsic) {
- llvm::Module *module = builder.GetInsertBlock()->getModule();
- SmallVector<llvm::Type *> overloadedTypes{
- moduleTranslation.convertType(intrOp->getResult(0).getType()),
- moduleTranslation.convertType(intrOp->getOperand(0).getType())};
- llvm::Function *callee =
- llvm::Intrinsic::getDeclaration(module, intrinsic, overloadedTypes);
- SmallVector<llvm::Value *> args =
- moduleTranslation.lookupValues(intrOp->getOperands());
- std::optional<llvm::RoundingMode> rounding;
- if (auto roundingModeOp = dyn_cast<RoundingModeOpInterface>(intrOp)) {
- rounding = convertRoundingModeToLLVM(
- roundingModeOp.getRoundingModeAttr().getValue());
- }
- llvm::fp::ExceptionBehavior except =
- convertExceptionBehaviorToLLVM(cast<ExceptionBehaviorOpInterface>(intrOp)
- .getExceptionBehaviorAttr()
- .getValue());
- return builder.CreateConstrainedFPCall(callee, args, "", rounding, except);
-}
-
/// Given a single MLIR operation, create the corresponding LLVM IR operation
/// using the `builder`.
LogicalResult ModuleTranslation::convertOperation(Operation &op,
@@ -1744,6 +1721,18 @@ llvm::Metadata *ModuleTranslation::translateDebugInfo(LLVM::DINodeAttr attr) {
return debugTranslation->translate(attr);
}
+/// Translates the given LLVM rounding mode metadata.
+llvm::RoundingMode
+ModuleTranslation::translateRoundingMode(LLVM::RoundingMode rounding) {
+ return convertRoundingModeToLLVM(rounding);
+}
+
+/// Translates the given LLVM FP exception behavior metadata.
+llvm::fp::ExceptionBehavior ModuleTranslation::translateFPExceptionBehavior(
+ LLVM::FPExceptionBehavior exceptionBehavior) {
+ return convertFPExceptionBehaviorToLLVM(exceptionBehavior);
+}
+
llvm::NamedMDNode *
ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
return llvmModule->getOrInsertNamedMetadata(name);
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index cc415e7b3662be..31acf2b95e4638 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -649,7 +649,7 @@ llvm.func @experimental_noalias_scope_decl() {
}
// CHECK-LABEL: @experimental_constrained_fptrunc
-llvm.func @experimental_constrained_fptrunc(%in: f64) -> f32 {
+llvm.func @experimental_constrained_fptrunc(%in: f64) {
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
%0 = llvm.intr.experimental.constrained.fptrunc %in towardzero ignore : f64 to f32
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
@@ -660,9 +660,5 @@ llvm.func @experimental_constrained_fptrunc(%in: f64) -> f32 {
%3 = llvm.intr.experimental.constrained.fptrunc %in downward ignore : f64 to f32
// CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
%4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
- %tmp0 = llvm.fadd %0, %1 : f32
- %tmp1 = llvm.fadd %2, %3 : f32
- %tmp2 = llvm.fadd %tmp0, %tmp1 : f32
- %res = llvm.fadd %tmp2, %4 : f32
- llvm.return %res : f32
+ llvm.return
}
diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
index 23bc9b00dc902d..2c7acec3b1b853 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -272,6 +272,10 @@ static LogicalResult emitOneMLIRBuilder(const Record &record, raw_ostream &os,
bs << "moduleImport.matchLocalVariableAttr";
} else if (name == "_label_attr") {
bs << "moduleImport.matchLabelAttr";
+ } else if (name == "_fpExceptionBehavior_attr") {
+ bs << "moduleImport.matchFPExceptionBehaviorAttr";
+ } else if (name == "_roundingMode_attr") {
+ bs << "moduleImport.matchRoundingModeAttr";
} else if (name == "_resultType") {
bs << "moduleImport.convertType(inst->getType())";
} else if (name == "_location") {
>From f5b9e6f711757a7c2628e0761b589cee228b4e4c Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Mon, 25 Mar 2024 15:46:11 +0000
Subject: [PATCH 3/3] Assert on always expected attribute
---
mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 9 +++------
1 file changed, 3 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index e86369b47b49a4..246c1027f8ff30 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -372,17 +372,14 @@ class LLVM_ConstrainedIntr<string mnem, int numArgs,
FPExceptionBehaviorAttr fpExceptionBehaviorAttr =
$_fpExceptionBehavior_attr($fpExceptionBehavior);
- if (!fpExceptionBehaviorAttr)
- return failure();
+ assert(fpExceptionBehaviorAttr && "Expecting FP exception behavior");
mlirAttrs.push_back(
$_builder.getNamedAttr(
$_qualCppClassName::getFPExceptionBehaviorAttrName(),
fpExceptionBehaviorAttr)); }] #
!cond(!gt(hasRoundingMode, 0) : [{
- RoundingModeAttr roundingModeAttr =
- $_roundingMode_attr($roundingmode);
- if (!roundingModeAttr)
- return failure();
+ RoundingModeAttr roundingModeAttr = $_roundingMode_attr($roundingmode);
+ assert(roundingModeAttr && "Expecting rounding mode");
mlirAttrs.push_back(
$_builder.getNamedAttr($_qualCppClassName::getRoundingModeAttrName(),
roundingModeAttr));
More information about the Mlir-commits
mailing list