[Mlir-commits] [mlir] [MLIR][LLVM] Add `llvm.experimental.constrained.fptrunc` operation (PR #86260)

Victor Perez llvmlistbot at llvm.org
Tue Mar 26 01:17:57 PDT 2024


https://github.com/victor-eds updated https://github.com/llvm/llvm-project/pull/86260

>From 8a1312ce158266e7987bc7eff081fb39fdd666b2 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Fri, 22 Mar 2024 09:18:20 +0000
Subject: [PATCH 1/4] [MLIR][LLVM] Add `llvm.experimental.constrained.fptrunc`
 operation

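Add an operation mapping to the LLVM
`llvm.experimental.constrained.fptrunc.*` intrinsic.

The operation implements two new interfaces,
`LLVM::ExceptionBehaviorOpInterface` and
`LLVM::RoundingModeOpInterface`, which expose the FP exception behavior
and rounding mode attributes of constrained floating-point operations.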

Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td | 57 ++++++++++++
 .../mlir/Dialect/LLVMIR/LLVMInterfaces.td     | 67 +++++++++++++++
 .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td   | 41 +++++++++
 .../include/mlir/Target/LLVMIR/ModuleImport.h |  5 ++
 .../mlir/Target/LLVMIR/ModuleTranslation.h    |  4 +
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       | 86 +++++++++++++++++++
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  | 23 +++++
 mlir/test/Dialect/LLVMIR/roundtrip.mlir       | 19 ++++
 mlir/test/Target/LLVMIR/Import/intrinsic.ll   | 19 ++++
 .../test/Target/LLVMIR/llvmir-intrinsics.mlir | 31 +++++++
 10 files changed, 352 insertions(+)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index a7b269eb41ee2e..19fc69dda16696 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -705,4 +705,61 @@ def FramePointerKindEnum : LLVM_EnumAttr<
   let cppNamespace = "::mlir::LLVM::framePointerKind";
 }
 
+//===----------------------------------------------------------------------===//
+// RoundingMode
+//===----------------------------------------------------------------------===//
+
+// These values must match llvm::RoundingMode ones.
+// See llvm/include/llvm/ADT/FloatingPointMode.h.
+def RoundTowardZero
+    : LLVM_EnumAttrCase<"TowardZero", "towardzero", "TowardZero", 0>;
+def RoundNearestTiesToEven
+    : LLVM_EnumAttrCase<"NearestTiesToEven", "tonearest", "NearestTiesToEven", 1>;
+def RoundTowardPositive
+    : LLVM_EnumAttrCase<"TowardPositive", "upward", "TowardPositive", 2>;
+def RoundTowardNegative
+    : LLVM_EnumAttrCase<"TowardNegative", "downward", "TowardNegative", 3>;
+def RoundNearestTiesToAway
+    : LLVM_EnumAttrCase<"NearestTiesToAway", "tonearestaway", "NearestTiesToAway", 4>;
+def RoundDynamic
+    : LLVM_EnumAttrCase<"Dynamic", "dynamic", "Dynamic", 7>;
+// Needed as llvm::RoundingMode defines this.
+def RoundInvalid
+    : LLVM_EnumAttrCase<"Invalid", "invalid", "Invalid", -1>;
+
+// RoundingModeAttr should not be used in operations definitions.
+// Use ValidRoundingModeAttr instead.
+def RoundingModeAttr : LLVM_EnumAttr<
+    "RoundingMode",
+    "::llvm::RoundingMode",
+    "LLVM Rounding Mode",
+    [RoundTowardZero, RoundNearestTiesToEven, RoundTowardPositive,
+     RoundTowardNegative, RoundNearestTiesToAway, RoundDynamic, RoundInvalid]> {
+  let cppNamespace = "::mlir::LLVM";
+}
+
+def ValidRoundingModeAttr : ConfinedAttr<RoundingModeAttr, [IntMinValue<0>]>;
+
+//===----------------------------------------------------------------------===//
+// ExceptionBehavior
+//===----------------------------------------------------------------------===//
+
+// These values must match llvm::fp::ExceptionBehavior ones.
+// See llvm/include/llvm/IR/FPEnv.h.
+def ExceptionBehaviorIgnore
+    : LLVM_EnumAttrCase<"Ignore", "ignore", "ebIgnore", 0>;
+def ExceptionBehaviorMayTrap
+    : LLVM_EnumAttrCase<"MayTrap", "maytrap", "ebMayTrap", 1>;
+def ExceptionBehaviorStrict
+    : LLVM_EnumAttrCase<"Strict", "strict", "ebStrict", 2>;
+
+def ExceptionBehaviorAttr : LLVM_EnumAttr<
+    "ExceptionBehavior",
+    "::llvm::fp::ExceptionBehavior",
+    "LLVM Exception Behavior",
+    [ExceptionBehaviorIgnore, ExceptionBehaviorMayTrap,
+     ExceptionBehaviorStrict]> {
+  let cppNamespace = "::mlir::LLVM";
+}
+
 #endif // LLVMIR_ENUMS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index e7a1da8ee560ef..ce91fbe1e2b24a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -290,6 +290,73 @@ def GetResultPtrElementType : OpInterface<"GetResultPtrElementType"> {
   ];
 }
 
+def ExceptionBehaviorOpInterface : OpInterface<"ExceptionBehaviorOpInterface"> {
+  let description = [{
+    An interface for operations receiving an exception behavior attribute
+    controlling FP exception behavior.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns a ExceptionBehavior attribute for the operation",
+      /*returnType=*/  "ExceptionBehaviorAttr",
+      /*methodName=*/  "getExceptionBehaviorAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getExceptionbehaviorAttr();
+      }]
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the ExceptionBehaviorAttr attribute
+                         for the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getExceptionBehaviorAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "exceptionbehavior";
+      }]
+    >
+  ];
+}
+
+def RoundingModeOpInterface : OpInterface<"RoundingModeOpInterface"> {
+  let description = [{
+    An interface for operations receiving a rounding mode attribute
+    controlling FP rounding mode.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns a RoundingMode attribute for the operation",
+      /*returnType=*/  "RoundingModeAttr",
+      /*methodName=*/  "getRoundingModeAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getRoundingmodeAttr();
+      }]
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the RoundingModeAttr attribute
+                         for the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getRoundingModeAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "roundingmode";
+      }]
+    >,
+  ];
+}
 
 //===----------------------------------------------------------------------===//
 // LLVM dialect type interfaces.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index b88f1186a44b49..6a2b9b90350e1a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -311,6 +311,47 @@ def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2],
       "qualified(type($ptr))";
 }
 
+// Constrained Floating-Point Intrinsics
+
+class LLVM_ConstrainedIntr<string mnem, int numArgs, bit hasRoundingMode>
+    : LLVM_OneResultIntrOp<"experimental.constrained." # mnem,
+                           /*overloadedResults=*/[0],
+                           /*overloadedOperands=*/[0],
+                           /*traits=*/[Pure, DeclareOpInterfaceMethods<ExceptionBehaviorOpInterface>]
+                           # !cond(
+                               !gt(hasRoundingMode, 0) : [DeclareOpInterfaceMethods<RoundingModeOpInterface>],
+                               true : []),
+                           /*requiresFastmath=*/0,
+                           /*immArgPositions=*/[],
+                           /*immArgAttrNames=*/[]> {
+  dag regularArgs = !dag(ins, !listsplat(LLVM_Type, numArgs), !foreach(i, !range(numArgs), "arg_" #i));
+  dag attrArgs = !con(!cond(!gt(hasRoundingMode, 0) : (ins ValidRoundingModeAttr:$roundingmode),
+                            true : (ins)),
+                      (ins ExceptionBehaviorAttr:$exceptionbehavior));
+  let arguments = !con(regularArgs, attrArgs);
+  let llvmBuilder = [{
+    $res = LLVM::detail::createConstrainedIntrinsicCall(
+      builder, moduleTranslation, &opInst, llvm::Intrinsic::experimental_constrained_}]
+       # mnem
+       # [{);
+  }];
+  let mlirBuilder = [{
+    auto op = moduleImport.translateConstrainedIntrinsic(
+      $_location, $_resultType, llvmOperands,
+      $_qualCppClassName::getOperationName());
+    if (!op)
+      return failure();
+    $res = op;
+  }];
+}
+
+def LLVM_ConstrainedFPTruncIntr
+    : LLVM_ConstrainedIntr<"fptrunc", /*numArgs=*/1, /*hasRoundingMode=*/1> {
+  let assemblyFormat = [{
+    $arg_0 $roundingmode $exceptionbehavior attr-dict `:` type($arg_0) `to` type(results)
+  }];
+}
+
 // Intrinsics with multiple returns.
 
 class LLVM_ArithWithOverflowOp<string mnem>
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index b49d2f539453e6..16f9994c126e06 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -232,6 +232,11 @@ class ModuleImport {
                             SmallVectorImpl<Value> &valuesOut,
                             SmallVectorImpl<NamedAttribute> &attrsOut);
 
+  /// Import constrained intrinsic.
+  Value translateConstrainedIntrinsic(Location loc, Type type,
+                                      ArrayRef<llvm::Value *> llvmOperands,
+                                      StringRef opName);
+
 private:
   /// Clears the accumulated state before processing a new region.
   void clearRegionState() {
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index fb4392eb223c7f..458ed585167bc0 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -423,6 +423,10 @@ llvm::CallInst *createIntrinsicCall(
     ArrayRef<unsigned> immArgPositions,
     ArrayRef<StringLiteral> immArgAttrNames);
 
+llvm::CallInst *createConstrainedIntrinsicCall(
+    llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
+    Operation *intrOp, llvm::Intrinsic::ID intrinsic);
+
 } // namespace detail
 
 } // namespace LLVM
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index d63ea12ecd49b1..85a543c174f51d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1258,6 +1258,92 @@ LogicalResult ModuleImport::convertIntrinsicArguments(
   return success();
 }
 
+static RoundingModeAttr metadataToRoundingMode(Builder &builder,
+                                               llvm::Metadata *metadata) {
+  auto *mdstr = dyn_cast<llvm::MDString>(metadata);
+  if (!mdstr)
+    return {};
+  std::optional<llvm::RoundingMode> optLLVM =
+      llvm::convertStrToRoundingMode(mdstr->getString());
+  if (!optLLVM)
+    return {};
+  return builder.getAttr<RoundingModeAttr>(
+      convertRoundingModeFromLLVM(*optLLVM));
+}
+
+static ExceptionBehaviorAttr
+metadataToExceptionBehavior(Builder &builder, llvm::Metadata *metadata) {
+  auto *mdstr = dyn_cast<llvm::MDString>(metadata);
+  if (!mdstr)
+    return {};
+  std::optional<llvm::fp::ExceptionBehavior> optLLVM =
+      llvm::convertStrToExceptionBehavior(mdstr->getString());
+  if (!optLLVM)
+    return {};
+  return builder.getAttr<ExceptionBehaviorAttr>(
+      convertExceptionBehaviorFromLLVM(*optLLVM));
+}
+
+static void
+splitMetadataAndValues(ArrayRef<llvm::Value *> inputs,
+                       SmallVectorImpl<llvm::Value *> &values,
+                       SmallVectorImpl<llvm::Metadata *> &metadata) {
+  for (llvm::Value *in : inputs) {
+    if (auto *mdval = dyn_cast<llvm::MetadataAsValue>(in)) {
+      metadata.push_back(mdval->getMetadata());
+    } else {
+      values.push_back(in);
+    }
+  }
+}
+
+Value ModuleImport::translateConstrainedIntrinsic(
+    Location loc, Type type, ArrayRef<llvm::Value *> llvmOperands,
+    StringRef opName) {
+  // Split metadata values from regular ones.
+  SmallVector<llvm::Value *> values;
+  SmallVector<llvm::Metadata *> metadata;
+  splitMetadataAndValues(llvmOperands, values, metadata);
+
+  // Expect 1 or 2 metadata values.
+  assert((metadata.size() == 1 || metadata.size() == 2) &&
+         "Unexpected number of arguments");
+
+  SmallVector<Value> mlirOperands;
+  SmallVector<NamedAttribute> mlirAttrs;
+  if (failed(
+          convertIntrinsicArguments(values, {}, {}, mlirOperands, mlirAttrs))) {
+    return {};
+  }
+
+  // Create operation as usual.
+  StringAttr opNameAttr = builder.getStringAttr(opName);
+  Operation *op =
+      builder.create(loc, opNameAttr, mlirOperands, type, mlirAttrs);
+
+  // Set exception behavior attribute.
+  auto exceptionBehaviorOp = cast<ExceptionBehaviorOpInterface>(op);
+  ExceptionBehaviorAttr attr =
+      metadataToExceptionBehavior(builder, metadata.back());
+  if (!attr)
+    return {};
+  op->setAttr(exceptionBehaviorOp.getExceptionBehaviorAttrName(), attr);
+
+  // If avaialbe, set rounding mode attribute.
+  if (auto roundingModeOp = dyn_cast<RoundingModeOpInterface>(op)) {
+    assert(metadata.size() > 1 && "Unexpected number of arguments");
+    // rounding_mode present
+    RoundingModeAttr attr = metadataToRoundingMode(builder, metadata[0]);
+    if (!attr)
+      return {};
+    roundingModeOp->setAttr(roundingModeOp.getRoundingModeAttrName(), attr);
+  } else {
+    assert(metadata.size() == 1 && "Unexpected number of arguments");
+  }
+
+  return op->getResult(0);
+}
+
 IntegerAttr ModuleImport::matchIntegerAttr(llvm::Value *value) {
   IntegerAttr integerAttr;
   FailureOr<Value> converted = convertValue(value);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index f90495d407fdfe..9100b1a18c71c6 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -862,6 +862,29 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
   return builder.CreateCall(llvmIntr, args);
 }
 
+llvm::CallInst *mlir::LLVM::detail::createConstrainedIntrinsicCall(
+    llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
+    Operation *intrOp, llvm::Intrinsic::ID intrinsic) {
+  llvm::Module *module = builder.GetInsertBlock()->getModule();
+  SmallVector<llvm::Type *> overloadedTypes{
+      moduleTranslation.convertType(intrOp->getResult(0).getType()),
+      moduleTranslation.convertType(intrOp->getOperand(0).getType())};
+  llvm::Function *callee =
+      llvm::Intrinsic::getDeclaration(module, intrinsic, overloadedTypes);
+  SmallVector<llvm::Value *> args =
+      moduleTranslation.lookupValues(intrOp->getOperands());
+  std::optional<llvm::RoundingMode> rounding;
+  if (auto roundingModeOp = dyn_cast<RoundingModeOpInterface>(intrOp)) {
+    rounding = convertRoundingModeToLLVM(
+        roundingModeOp.getRoundingModeAttr().getValue());
+  }
+  llvm::fp::ExceptionBehavior except =
+      convertExceptionBehaviorToLLVM(cast<ExceptionBehaviorOpInterface>(intrOp)
+                                         .getExceptionBehaviorAttr()
+                                         .getValue());
+  return builder.CreateConstrainedFPCall(callee, args, "", rounding, except);
+}
+
 /// Given a single MLIR operation, create the corresponding LLVM IR operation
 /// using the `builder`.
 LogicalResult ModuleTranslation::convertOperation(Operation &op,
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index b157cf00141842..cc415e7b3662be 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -647,3 +647,22 @@ llvm.func @experimental_noalias_scope_decl() {
   llvm.intr.experimental.noalias.scope.decl #alias_scope
   llvm.return
 }
+
+// CHECK-LABEL: @experimental_constrained_fptrunc
+llvm.func @experimental_constrained_fptrunc(%in: f64) -> f32 {
+  // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
+  %0 = llvm.intr.experimental.constrained.fptrunc %in towardzero ignore : f64 to f32
+  // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
+  %1 = llvm.intr.experimental.constrained.fptrunc %in tonearest maytrap : f64 to f32
+  // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
+  %2 = llvm.intr.experimental.constrained.fptrunc %in upward strict : f64 to f32
+  // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
+  %3 = llvm.intr.experimental.constrained.fptrunc %in downward ignore : f64 to f32
+  // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
+  %4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
+  %tmp0 = llvm.fadd %0, %1 : f32
+  %tmp1 = llvm.fadd %2, %3 : f32
+  %tmp2 = llvm.fadd %tmp0, %tmp1 : f32
+  %res = llvm.fadd %tmp2, %4 : f32
+  llvm.return %res : f32
+}
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 1ec9005458c50b..85561839f31a70 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -894,6 +894,23 @@ define float @ssa_copy(float %0) {
   ret float %2
 }
 
+; CHECK-LABEL: experimental_constrained_fptrunc
+define void @experimental_constrained_fptrunc(double %s, <4 x double> %v) {
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
+  %1 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.towardzero", metadata !"fpexcept.ignore")
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
+  %2 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearest", metadata !"fpexcept.maytrap")
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} upward strict : f64 to f32
+  %3 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.upward", metadata !"fpexcept.strict")
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} downward ignore : f64 to f32
+  %4 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.downward", metadata !"fpexcept.ignore")
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
+  %5 = call float @llvm.experimental.constrained.fptrunc.f32.f64(double %s, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
+  ; CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : vector<4xf64> to vector<4xf16>
+  %6 = call <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double> %v, metadata !"round.tonearestaway", metadata !"fpexcept.ignore")
+  ret void
+}
+
 declare float @llvm.fmuladd.f32(float, float, float)
 declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>)
 declare float @llvm.fma.f32(float, float, float)
@@ -1120,3 +1137,5 @@ declare void @llvm.assume(i1)
 declare float @llvm.ssa.copy.f32(float returned)
 declare <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float>, <4 x float>, i64)
 declare <4 x float> @llvm.vector.extract.v4f32.nxv4f32(<vscale x 4 x float>, i64)
+declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f64(<4 x double>, metadata, metadata)
+declare float @llvm.experimental.constrained.fptrunc.f32.f64(double, metadata, metadata)
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index fc2e0fd201a738..0013522582a727 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -964,6 +964,35 @@ llvm.func @ssa_copy(%arg: f32) -> f32 {
   llvm.return %0 : f32
 }
 
+// CHECK-LABEL: @experimental_constrained_fptrunc
+llvm.func @experimental_constrained_fptrunc(%s: f64, %v: vector<4xf32>) {
+  // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+  // CHECK: metadata !"round.towardzero"
+  // CHECK: metadata !"fpexcept.ignore"
+  %0 = llvm.intr.experimental.constrained.fptrunc %s towardzero ignore : f64 to f32
+  // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+  // CHECK: metadata !"round.tonearest"
+  // CHECK: metadata !"fpexcept.maytrap"
+  %1 = llvm.intr.experimental.constrained.fptrunc %s tonearest maytrap : f64 to f32
+  // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+  // CHECK: metadata !"round.upward"
+  // CHECK: metadata !"fpexcept.strict"
+  %2 = llvm.intr.experimental.constrained.fptrunc %s upward strict : f64 to f32
+  // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+  // CHECK: metadata !"round.downward"
+  // CHECK: metadata !"fpexcept.ignore"
+  %3 = llvm.intr.experimental.constrained.fptrunc %s downward ignore : f64 to f32
+  // CHECK: call float @llvm.experimental.constrained.fptrunc.f32.f64(
+  // CHECK: metadata !"round.tonearestaway"
+  // CHECK: metadata !"fpexcept.ignore"
+  %4 = llvm.intr.experimental.constrained.fptrunc %s tonearestaway ignore : f64 to f32
+  // CHECK: call <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f32(
+  // CHECK: metadata !"round.upward"
+  // CHECK: metadata !"fpexcept.strict"
+  %5 = llvm.intr.experimental.constrained.fptrunc %v upward strict : vector<4xf32> to vector<4xf16>
+  llvm.return
+}
+
 // Check that intrinsics are declared with appropriate types.
 // CHECK-DAG: declare float @llvm.fma.f32(float, float, float)
 // CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0
@@ -1126,3 +1155,5 @@ llvm.func @ssa_copy(%arg: f32) -> f32 {
 // CHECK-DAG: declare ptr addrspace(1) @llvm.stacksave.p1()
 // CHECK-DAG: declare void @llvm.stackrestore.p0(ptr)
 // CHECK-DAG: declare void @llvm.stackrestore.p1(ptr addrspace(1))
+// CHECK-DAG: declare float @llvm.experimental.constrained.fptrunc.f32.f64(double, metadata, metadata)
+// CHECK-DAG: declare <4 x half> @llvm.experimental.constrained.fptrunc.v4f16.v4f32(<4 x float>, metadata, metadata)

>From c882a29ff2d26c4850e74044b5c78ef99f7d1fbc Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Mon, 25 Mar 2024 13:23:21 +0000
Subject: [PATCH 2/4] Apply review suggestions (FPExceptionBehavior rename, TableGen-generated builders)

Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td |  16 +--
 .../mlir/Dialect/LLVMIR/LLVMInterfaces.td     |  18 ++--
 .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td   |  81 +++++++++++---
 .../include/mlir/Dialect/LLVMIR/LLVMOpBase.td |   4 +
 .../include/mlir/Target/LLVMIR/ModuleImport.h |  13 ++-
 .../mlir/Target/LLVMIR/ModuleTranslation.h    |  11 +-
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       | 100 ++++--------------
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  35 +++---
 mlir/test/Dialect/LLVMIR/roundtrip.mlir       |   8 +-
 .../tools/mlir-tblgen/LLVMIRConversionGen.cpp |   4 +
 10 files changed, 142 insertions(+), 148 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index 19fc69dda16696..04d797031245e3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -741,24 +741,24 @@ def RoundingModeAttr : LLVM_EnumAttr<
 def ValidRoundingModeAttr : ConfinedAttr<RoundingModeAttr, [IntMinValue<0>]>;
 
 //===----------------------------------------------------------------------===//
-// ExceptionBehavior
+// FPExceptionBehavior
 //===----------------------------------------------------------------------===//
 
 // These values must match llvm::fp::ExceptionBehavior ones.
 // See llvm/include/llvm/IR/FPEnv.h.
-def ExceptionBehaviorIgnore
+def FPExceptionBehaviorIgnore
     : LLVM_EnumAttrCase<"Ignore", "ignore", "ebIgnore", 0>;
-def ExceptionBehaviorMayTrap
+def FPExceptionBehaviorMayTrap
     : LLVM_EnumAttrCase<"MayTrap", "maytrap", "ebMayTrap", 1>;
-def ExceptionBehaviorStrict
+def FPExceptionBehaviorStrict
     : LLVM_EnumAttrCase<"Strict", "strict", "ebStrict", 2>;
 
-def ExceptionBehaviorAttr : LLVM_EnumAttr<
-    "ExceptionBehavior",
+def FPExceptionBehaviorAttr : LLVM_EnumAttr<
+    "FPExceptionBehavior",
     "::llvm::fp::ExceptionBehavior",
     "LLVM Exception Behavior",
-    [ExceptionBehaviorIgnore, ExceptionBehaviorMayTrap,
-     ExceptionBehaviorStrict]> {
+    [FPExceptionBehaviorIgnore, FPExceptionBehaviorMayTrap,
+     FPExceptionBehaviorStrict]> {
   let cppNamespace = "::mlir::LLVM";
 }
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index ce91fbe1e2b24a..cee752aeb269b7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -290,7 +290,7 @@ def GetResultPtrElementType : OpInterface<"GetResultPtrElementType"> {
   ];
 }
 
-def ExceptionBehaviorOpInterface : OpInterface<"ExceptionBehaviorOpInterface"> {
+def FPExceptionBehaviorOpInterface : OpInterface<"FPExceptionBehaviorOpInterface"> {
   let description = [{
     An interface for operations receiving an exception behavior attribute
     controlling FP exception behavior.
@@ -300,25 +300,25 @@ def ExceptionBehaviorOpInterface : OpInterface<"ExceptionBehaviorOpInterface"> {
 
   let methods = [
     InterfaceMethod<
-      /*desc=*/        "Returns a ExceptionBehavior attribute for the operation",
-      /*returnType=*/  "ExceptionBehaviorAttr",
-      /*methodName=*/  "getExceptionBehaviorAttr",
+      /*desc=*/        "Returns a FPExceptionBehavior attribute for the operation",
+      /*returnType=*/  "FPExceptionBehaviorAttr",
+      /*methodName=*/  "getFPExceptionBehaviorAttr",
       /*args=*/        (ins),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
         auto op = cast<ConcreteOp>(this->getOperation());
-        return op.getExceptionbehaviorAttr();
+        return op.getFpExceptionBehaviorAttr();
       }]
     >,
     StaticInterfaceMethod<
-      /*desc=*/        [{Returns the name of the ExceptionBehaviorAttr attribute
-                         for the operation}],
+      /*desc=*/        [{Returns the name of the FPExceptionBehaviorAttr
+                        attribute for the operation}],
       /*returnType=*/  "StringRef",
-      /*methodName=*/  "getExceptionBehaviorAttrName",
+      /*methodName=*/  "getFPExceptionBehaviorAttrName",
       /*args=*/        (ins),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        return "exceptionbehavior";
+        return "fpExceptionBehavior";
       }]
     >
   ];
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 6a2b9b90350e1a..e86369b47b49a4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -313,11 +313,15 @@ def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2],
 
 // Constrained Floating-Point Intrinsics
 
-class LLVM_ConstrainedIntr<string mnem, int numArgs, bit hasRoundingMode>
+class LLVM_ConstrainedIntr<string mnem, int numArgs,
+                           bit overloadedResult, list<int> overloadedOperands,
+                           bit hasRoundingMode>
     : LLVM_OneResultIntrOp<"experimental.constrained." # mnem,
-                           /*overloadedResults=*/[0],
-                           /*overloadedOperands=*/[0],
-                           /*traits=*/[Pure, DeclareOpInterfaceMethods<ExceptionBehaviorOpInterface>]
+                           /*overloadedResults=*/
+                           !cond(!gt(overloadedResult, 0) : [0],
+                                 true : []),
+                           overloadedOperands,
+                           /*traits=*/[Pure, DeclareOpInterfaceMethods<FPExceptionBehaviorOpInterface>]
                            # !cond(
                                !gt(hasRoundingMode, 0) : [DeclareOpInterfaceMethods<RoundingModeOpInterface>],
                                true : []),
@@ -327,28 +331,73 @@ class LLVM_ConstrainedIntr<string mnem, int numArgs, bit hasRoundingMode>
   dag regularArgs = !dag(ins, !listsplat(LLVM_Type, numArgs), !foreach(i, !range(numArgs), "arg_" #i));
   dag attrArgs = !con(!cond(!gt(hasRoundingMode, 0) : (ins ValidRoundingModeAttr:$roundingmode),
                             true : (ins)),
-                      (ins ExceptionBehaviorAttr:$exceptionbehavior));
+                      (ins FPExceptionBehaviorAttr:$fpExceptionBehavior));
   let arguments = !con(regularArgs, attrArgs);
   let llvmBuilder = [{
-    $res = LLVM::detail::createConstrainedIntrinsicCall(
-      builder, moduleTranslation, &opInst, llvm::Intrinsic::experimental_constrained_}]
-       # mnem
-       # [{);
+    SmallVector<llvm::Value *> args =
+      moduleTranslation.lookupValues(opInst.getOperands());
+    SmallVector<llvm::Type *> overloadedTypes; }] #
+    !cond(!gt(overloadedResult, 0) : [{
+    // Take into account overloaded result type
+    overloadedTypes.push_back($_resultType); }],
+    // No overloaded result type
+          true : "") # [{
+    llvm::transform(ArrayRef<unsigned>}] # overloadedOperandsCpp # [{,
+                    std::back_inserter(overloadedTypes),
+                    [&args](unsigned index) { return args[index]->getType(); });
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *callee =
+      llvm::Intrinsic::getDeclaration(module,
+        llvm::Intrinsic::experimental_constrained_}] #
+    mnem # [{, overloadedTypes); }] #
+    !cond(!gt(hasRoundingMode, 0) : [{
+    // Get rounding mode using interface
+    llvm::RoundingMode rounding =
+        moduleTranslation.translateRoundingMode($roundingmode); }],
+          true : [{
+    // No rounding mode
+    std::optional<llvm::RoundingMode> rounding; }]) # [{
+    llvm::fp::ExceptionBehavior except =
+      moduleTranslation.translateFPExceptionBehavior($fpExceptionBehavior);
+    $res = builder.CreateConstrainedFPCall(callee, args, "", rounding, except);
   }];
   let mlirBuilder = [{
-    auto op = moduleImport.translateConstrainedIntrinsic(
-      $_location, $_resultType, llvmOperands,
-      $_qualCppClassName::getOperationName());
-    if (!op)
+    SmallVector<Value> mlirOperands;
+    SmallVector<NamedAttribute> mlirAttrs;
+    if (failed(moduleImport.convertIntrinsicArguments(
+        llvmOperands.take_front( }] # numArgs # [{),
+        {}, {}, mlirOperands, mlirAttrs))) {
+      return failure();
+    }
+
+    FPExceptionBehaviorAttr fpExceptionBehaviorAttr =
+        $_fpExceptionBehavior_attr($fpExceptionBehavior);
+    if (!fpExceptionBehaviorAttr)
+      return failure();
+    mlirAttrs.push_back(
+        $_builder.getNamedAttr(
+            $_qualCppClassName::getFPExceptionBehaviorAttrName(),
+            fpExceptionBehaviorAttr)); }] #
+    !cond(!gt(hasRoundingMode, 0) : [{
+    RoundingModeAttr roundingModeAttr =
+        $_roundingMode_attr($roundingmode);
+    if (!roundingModeAttr)
       return failure();
-    $res = op;
+    mlirAttrs.push_back(
+        $_builder.getNamedAttr($_qualCppClassName::getRoundingModeAttrName(),
+                               roundingModeAttr));
+    }], true : "") # [{
+    $res = $_builder.create<$_qualCppClassName>($_location,
+      $_resultType, mlirOperands, mlirAttrs);
   }];
 }
 
 def LLVM_ConstrainedFPTruncIntr
-    : LLVM_ConstrainedIntr<"fptrunc", /*numArgs=*/1, /*hasRoundingMode=*/1> {
+    : LLVM_ConstrainedIntr<"fptrunc", /*numArgs=*/1,
+        /*overloadedResult=*/1, /*overloadedOperands=*/[0],
+        /*hasRoundingMode=*/1> {
   let assemblyFormat = [{
-    $arg_0 $roundingmode $exceptionbehavior attr-dict `:` type($arg_0) `to` type(results)
+    $arg_0 $roundingmode $fpExceptionBehavior attr-dict `:` type($arg_0) `to` type(results)
   }];
 }
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index b6aa73dad22970..7b9a9cf017c537 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -170,6 +170,10 @@ class LLVM_OpBase<Dialect dialect, string mnemonic, list<Trait> traits = []> :
   //   - $_float_attr - substituted by a call to a float attribute matcher;
   //   - $_var_attr - substituted by a call to a variable attribute matcher;
   //   - $_label_attr - substituted by a call to a label attribute matcher;
+  //   - $_roundingMode_attr - substituted by a call to a rounding mode
+  //     attribute matcher;
+  //   - $_fpExceptionBehavior_attr - substituted by a call to a FP exception
+  //     behavior attribute matcher;
   //   - $_resultType - substituted with the MLIR result type;
   //   - $_location - substituted with the MLIR location;
   //   - $_builder - substituted with the MLIR builder;
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 16f9994c126e06..b551eb937cfe8d 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -152,6 +152,14 @@ class ModuleImport {
   /// Converts `value` to a label attribute. Asserts if the matching fails.
   DILabelAttr matchLabelAttr(llvm::Value *value);
 
+  /// Converts `value` to a FP exception behavior attribute. Asserts if the
+  /// matching fails.
+  FPExceptionBehaviorAttr matchFPExceptionBehaviorAttr(llvm::Value *value);
+
+  /// Converts `value` to a rounding mode attribute. Asserts if the matching
+  /// fails.
+  RoundingModeAttr matchRoundingModeAttr(llvm::Value *value);
+
   /// Converts `value` to an array of alias scopes or returns failure if the
   /// conversion fails.
   FailureOr<SmallVector<AliasScopeAttr>>
@@ -232,11 +240,6 @@ class ModuleImport {
                             SmallVectorImpl<Value> &valuesOut,
                             SmallVectorImpl<NamedAttribute> &attrsOut);
 
-  /// Import constrained intrinsic.
-  Value translateConstrainedIntrinsic(Location loc, Type type,
-                                      ArrayRef<llvm::Value *> llvmOperands,
-                                      StringRef opName);
-
 private:
   /// Clears the accumulated state before processing a new region.
   void clearRegionState() {
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 458ed585167bc0..310a43e0de96b3 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -201,6 +201,13 @@ class ModuleTranslation {
   /// Translates the given LLVM debug info metadata.
   llvm::Metadata *translateDebugInfo(LLVM::DINodeAttr attr);
 
+  /// Converts the given LLVM dialect rounding mode to llvm::RoundingMode.
+  llvm::RoundingMode translateRoundingMode(LLVM::RoundingMode rounding);
+
+  /// Converts the given FP exception behavior to llvm::fp::ExceptionBehavior.
+  llvm::fp::ExceptionBehavior
+  translateFPExceptionBehavior(LLVM::FPExceptionBehavior exceptionBehavior);
+
   /// Translates the contents of the given block to LLVM IR using this
   /// translator. The LLVM IR basic block corresponding to the given block is
   /// expected to exist in the mapping of this translator. Uses `builder` to
@@ -423,10 +430,6 @@ llvm::CallInst *createIntrinsicCall(
     ArrayRef<unsigned> immArgPositions,
     ArrayRef<StringLiteral> immArgAttrNames);
 
-llvm::CallInst *createConstrainedIntrinsicCall(
-    llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
-    Operation *intrOp, llvm::Intrinsic::ID intrinsic);
-
 } // namespace detail
 
 } // namespace LLVM
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 85a543c174f51d..374127a19ebb9f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1258,21 +1258,8 @@ LogicalResult ModuleImport::convertIntrinsicArguments(
   return success();
 }
 
-static RoundingModeAttr metadataToRoundingMode(Builder &builder,
-                                               llvm::Metadata *metadata) {
-  auto *mdstr = dyn_cast<llvm::MDString>(metadata);
-  if (!mdstr)
-    return {};
-  std::optional<llvm::RoundingMode> optLLVM =
-      llvm::convertStrToRoundingMode(mdstr->getString());
-  if (!optLLVM)
-    return {};
-  return builder.getAttr<RoundingModeAttr>(
-      convertRoundingModeFromLLVM(*optLLVM));
-}
-
-static ExceptionBehaviorAttr
-metadataToExceptionBehavior(Builder &builder, llvm::Metadata *metadata) {
+static FPExceptionBehaviorAttr
+metadataToFPExceptionBehavior(Builder &builder, llvm::Metadata *metadata) {
   auto *mdstr = dyn_cast<llvm::MDString>(metadata);
   if (!mdstr)
     return {};
@@ -1280,68 +1267,8 @@ metadataToExceptionBehavior(Builder &builder, llvm::Metadata *metadata) {
       llvm::convertStrToExceptionBehavior(mdstr->getString());
   if (!optLLVM)
     return {};
-  return builder.getAttr<ExceptionBehaviorAttr>(
-      convertExceptionBehaviorFromLLVM(*optLLVM));
-}
-
-static void
-splitMetadataAndValues(ArrayRef<llvm::Value *> inputs,
-                       SmallVectorImpl<llvm::Value *> &values,
-                       SmallVectorImpl<llvm::Metadata *> &metadata) {
-  for (llvm::Value *in : inputs) {
-    if (auto *mdval = dyn_cast<llvm::MetadataAsValue>(in)) {
-      metadata.push_back(mdval->getMetadata());
-    } else {
-      values.push_back(in);
-    }
-  }
-}
-
-Value ModuleImport::translateConstrainedIntrinsic(
-    Location loc, Type type, ArrayRef<llvm::Value *> llvmOperands,
-    StringRef opName) {
-  // Split metadata values from regular ones.
-  SmallVector<llvm::Value *> values;
-  SmallVector<llvm::Metadata *> metadata;
-  splitMetadataAndValues(llvmOperands, values, metadata);
-
-  // Expect 1 or 2 metadata values.
-  assert((metadata.size() == 1 || metadata.size() == 2) &&
-         "Unexpected number of arguments");
-
-  SmallVector<Value> mlirOperands;
-  SmallVector<NamedAttribute> mlirAttrs;
-  if (failed(
-          convertIntrinsicArguments(values, {}, {}, mlirOperands, mlirAttrs))) {
-    return {};
-  }
-
-  // Create operation as usual.
-  StringAttr opNameAttr = builder.getStringAttr(opName);
-  Operation *op =
-      builder.create(loc, opNameAttr, mlirOperands, type, mlirAttrs);
-
-  // Set exception behavior attribute.
-  auto exceptionBehaviorOp = cast<ExceptionBehaviorOpInterface>(op);
-  ExceptionBehaviorAttr attr =
-      metadataToExceptionBehavior(builder, metadata.back());
-  if (!attr)
-    return {};
-  op->setAttr(exceptionBehaviorOp.getExceptionBehaviorAttrName(), attr);
-
-  // If avaialbe, set rounding mode attribute.
-  if (auto roundingModeOp = dyn_cast<RoundingModeOpInterface>(op)) {
-    assert(metadata.size() > 1 && "Unexpected number of arguments");
-    // rounding_mode present
-    RoundingModeAttr attr = metadataToRoundingMode(builder, metadata[0]);
-    if (!attr)
-      return {};
-    roundingModeOp->setAttr(roundingModeOp.getRoundingModeAttrName(), attr);
-  } else {
-    assert(metadata.size() == 1 && "Unexpected number of arguments");
-  }
-
-  return op->getResult(0);
+  return builder.getAttr<FPExceptionBehaviorAttr>(
+      convertFPExceptionBehaviorFromLLVM(*optLLVM));
 }
 
 IntegerAttr ModuleImport::matchIntegerAttr(llvm::Value *value) {
@@ -1376,6 +1303,25 @@ DILabelAttr ModuleImport::matchLabelAttr(llvm::Value *value) {
   return debugImporter->translate(node);
 }
 
+FPExceptionBehaviorAttr
+ModuleImport::matchFPExceptionBehaviorAttr(llvm::Value *value) {
+  auto *metadata = cast<llvm::MetadataAsValue>(value);
+  auto *mdstr = cast<llvm::MDString>(metadata->getMetadata());
+  std::optional<llvm::fp::ExceptionBehavior> optLLVM =
+      llvm::convertStrToExceptionBehavior(mdstr->getString());
+  return builder.getAttr<FPExceptionBehaviorAttr>(
+      convertFPExceptionBehaviorFromLLVM(*optLLVM));
+}
+
+RoundingModeAttr ModuleImport::matchRoundingModeAttr(llvm::Value *value) {
+  auto *metadata = cast<llvm::MetadataAsValue>(value);
+  auto *mdstr = cast<llvm::MDString>(metadata->getMetadata());
+  std::optional<llvm::RoundingMode> optLLVM =
+      llvm::convertStrToRoundingMode(mdstr->getString());
+  return builder.getAttr<RoundingModeAttr>(
+      convertRoundingModeFromLLVM(*optLLVM));
+}
+
 FailureOr<SmallVector<AliasScopeAttr>>
 ModuleImport::matchAliasScopeAttrs(llvm::Value *value) {
   auto *nodeAsVal = cast<llvm::MetadataAsValue>(value);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 9100b1a18c71c6..1523f4195b38c8 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -862,29 +862,6 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
   return builder.CreateCall(llvmIntr, args);
 }
 
-llvm::CallInst *mlir::LLVM::detail::createConstrainedIntrinsicCall(
-    llvm::IRBuilderBase &builder, ModuleTranslation &moduleTranslation,
-    Operation *intrOp, llvm::Intrinsic::ID intrinsic) {
-  llvm::Module *module = builder.GetInsertBlock()->getModule();
-  SmallVector<llvm::Type *> overloadedTypes{
-      moduleTranslation.convertType(intrOp->getResult(0).getType()),
-      moduleTranslation.convertType(intrOp->getOperand(0).getType())};
-  llvm::Function *callee =
-      llvm::Intrinsic::getDeclaration(module, intrinsic, overloadedTypes);
-  SmallVector<llvm::Value *> args =
-      moduleTranslation.lookupValues(intrOp->getOperands());
-  std::optional<llvm::RoundingMode> rounding;
-  if (auto roundingModeOp = dyn_cast<RoundingModeOpInterface>(intrOp)) {
-    rounding = convertRoundingModeToLLVM(
-        roundingModeOp.getRoundingModeAttr().getValue());
-  }
-  llvm::fp::ExceptionBehavior except =
-      convertExceptionBehaviorToLLVM(cast<ExceptionBehaviorOpInterface>(intrOp)
-                                         .getExceptionBehaviorAttr()
-                                         .getValue());
-  return builder.CreateConstrainedFPCall(callee, args, "", rounding, except);
-}
-
 /// Given a single MLIR operation, create the corresponding LLVM IR operation
 /// using the `builder`.
 LogicalResult ModuleTranslation::convertOperation(Operation &op,
@@ -1744,6 +1721,18 @@ llvm::Metadata *ModuleTranslation::translateDebugInfo(LLVM::DINodeAttr attr) {
   return debugTranslation->translate(attr);
 }
 
+/// Translates the given LLVM rounding mode metadata.
+llvm::RoundingMode
+ModuleTranslation::translateRoundingMode(LLVM::RoundingMode rounding) {
+  return convertRoundingModeToLLVM(rounding);
+}
+
+/// Translates the given LLVM FP exception behavior metadata.
+llvm::fp::ExceptionBehavior ModuleTranslation::translateFPExceptionBehavior(
+    LLVM::FPExceptionBehavior exceptionBehavior) {
+  return convertFPExceptionBehaviorToLLVM(exceptionBehavior);
+}
+
 llvm::NamedMDNode *
 ModuleTranslation::getOrInsertNamedModuleMetadata(StringRef name) {
   return llvmModule->getOrInsertNamedMetadata(name);
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index cc415e7b3662be..31acf2b95e4638 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -649,7 +649,7 @@ llvm.func @experimental_noalias_scope_decl() {
 }
 
 // CHECK-LABEL: @experimental_constrained_fptrunc
-llvm.func @experimental_constrained_fptrunc(%in: f64) -> f32 {
+llvm.func @experimental_constrained_fptrunc(%in: f64) {
   // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} towardzero ignore : f64 to f32
   %0 = llvm.intr.experimental.constrained.fptrunc %in towardzero ignore : f64 to f32
   // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearest maytrap : f64 to f32
@@ -660,9 +660,5 @@ llvm.func @experimental_constrained_fptrunc(%in: f64) -> f32 {
   %3 = llvm.intr.experimental.constrained.fptrunc %in downward ignore : f64 to f32
   // CHECK: llvm.intr.experimental.constrained.fptrunc %{{.*}} tonearestaway ignore : f64 to f32
   %4 = llvm.intr.experimental.constrained.fptrunc %in tonearestaway ignore : f64 to f32
-  %tmp0 = llvm.fadd %0, %1 : f32
-  %tmp1 = llvm.fadd %2, %3 : f32
-  %tmp2 = llvm.fadd %tmp0, %tmp1 : f32
-  %res = llvm.fadd %tmp2, %4 : f32
-  llvm.return %res : f32
+  llvm.return
 }
diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
index 23bc9b00dc902d..2c7acec3b1b853 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -272,6 +272,10 @@ static LogicalResult emitOneMLIRBuilder(const Record &record, raw_ostream &os,
       bs << "moduleImport.matchLocalVariableAttr";
     } else if (name == "_label_attr") {
       bs << "moduleImport.matchLabelAttr";
+    } else if (name == "_fpExceptionBehavior_attr") {
+      bs << "moduleImport.matchFPExceptionBehaviorAttr";
+    } else if (name == "_roundingMode_attr") {
+      bs << "moduleImport.matchRoundingModeAttr";
     } else if (name == "_resultType") {
       bs << "moduleImport.convertType(inst->getType())";
     } else if (name == "_location") {

>From f5b9e6f711757a7c2628e0761b589cee228b4e4c Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Mon, 25 Mar 2024 15:46:11 +0000
Subject: [PATCH 3/4] Assert on always expected attribute

---
 mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index e86369b47b49a4..246c1027f8ff30 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -372,17 +372,14 @@ class LLVM_ConstrainedIntr<string mnem, int numArgs,
 
     FPExceptionBehaviorAttr fpExceptionBehaviorAttr =
         $_fpExceptionBehavior_attr($fpExceptionBehavior);
-    if (!fpExceptionBehaviorAttr)
-      return failure();
+    assert(fpExceptionBehaviorAttr && "Expecting FP exception behavior");
     mlirAttrs.push_back(
         $_builder.getNamedAttr(
             $_qualCppClassName::getFPExceptionBehaviorAttrName(),
             fpExceptionBehaviorAttr)); }] #
     !cond(!gt(hasRoundingMode, 0) : [{
-    RoundingModeAttr roundingModeAttr =
-        $_roundingMode_attr($roundingmode);
-    if (!roundingModeAttr)
-      return failure();
+    RoundingModeAttr roundingModeAttr = $_roundingMode_attr($roundingmode);
+    assert(roundingModeAttr && "Expecting rounding mode");
     mlirAttrs.push_back(
         $_builder.getNamedAttr($_qualCppClassName::getRoundingModeAttrName(),
                                roundingModeAttr));

>From f60b9970a1b90b67d7702f15d0f6d40f5efb601e Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Tue, 26 Mar 2024 08:16:07 +0000
Subject: [PATCH 4/4] Apply NITs

---
 .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td       | 12 +++++-------
 mlir/lib/Target/LLVMIR/ModuleImport.cpp           | 15 ++-------------
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp      |  2 --
 3 files changed, 7 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 246c1027f8ff30..f4bac9376f2ea0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -311,7 +311,7 @@ def LLVM_InvariantEndOp : LLVM_ZeroResultIntrOp<"invariant.end", [2],
       "qualified(type($ptr))";
 }
 
-// Constrained Floating-Point Intrinsics
+// Constrained Floating-Point Intrinsics.
 
 class LLVM_ConstrainedIntr<string mnem, int numArgs,
                            bit overloadedResult, list<int> overloadedOperands,
@@ -338,9 +338,9 @@ class LLVM_ConstrainedIntr<string mnem, int numArgs,
       moduleTranslation.lookupValues(opInst.getOperands());
     SmallVector<llvm::Type *> overloadedTypes; }] #
     !cond(!gt(overloadedResult, 0) : [{
-    // Take into account overloaded result type
+    // Take into account overloaded result type.
     overloadedTypes.push_back($_resultType); }],
-    // No overloaded result type
+    // No overloaded result type.
           true : "") # [{
     llvm::transform(ArrayRef<unsigned>}] # overloadedOperandsCpp # [{,
                     std::back_inserter(overloadedTypes),
@@ -351,11 +351,11 @@ class LLVM_ConstrainedIntr<string mnem, int numArgs,
         llvm::Intrinsic::experimental_constrained_}] #
     mnem # [{, overloadedTypes); }] #
     !cond(!gt(hasRoundingMode, 0) : [{
-    // Get rounding mode using interface
+    // Translate the op's rounding mode attribute.
     llvm::RoundingMode rounding =
         moduleTranslation.translateRoundingMode($roundingmode); }],
           true : [{
-    // No rounding mode
+    // No rounding mode.
     std::optional<llvm::RoundingMode> rounding; }]) # [{
     llvm::fp::ExceptionBehavior except =
       moduleTranslation.translateFPExceptionBehavior($fpExceptionBehavior);
@@ -372,14 +372,12 @@ class LLVM_ConstrainedIntr<string mnem, int numArgs,
 
     FPExceptionBehaviorAttr fpExceptionBehaviorAttr =
         $_fpExceptionBehavior_attr($fpExceptionBehavior);
-    assert(fpExceptionBehaviorAttr && "Expecting FP exception behavior");
     mlirAttrs.push_back(
         $_builder.getNamedAttr(
             $_qualCppClassName::getFPExceptionBehaviorAttrName(),
             fpExceptionBehaviorAttr)); }] #
     !cond(!gt(hasRoundingMode, 0) : [{
     RoundingModeAttr roundingModeAttr = $_roundingMode_attr($roundingmode);
-    assert(roundingModeAttr && "Expecting rounding mode");
     mlirAttrs.push_back(
         $_builder.getNamedAttr($_qualCppClassName::getRoundingModeAttrName(),
                                roundingModeAttr));
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 374127a19ebb9f..6e70d52fa760b6 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1258,19 +1258,6 @@ LogicalResult ModuleImport::convertIntrinsicArguments(
   return success();
 }
 
-static FPExceptionBehaviorAttr
-metadataToFPExceptionBehavior(Builder &builder, llvm::Metadata *metadata) {
-  auto *mdstr = dyn_cast<llvm::MDString>(metadata);
-  if (!mdstr)
-    return {};
-  std::optional<llvm::fp::ExceptionBehavior> optLLVM =
-      llvm::convertStrToExceptionBehavior(mdstr->getString());
-  if (!optLLVM)
-    return {};
-  return builder.getAttr<FPExceptionBehaviorAttr>(
-      convertFPExceptionBehaviorFromLLVM(*optLLVM));
-}
-
 IntegerAttr ModuleImport::matchIntegerAttr(llvm::Value *value) {
   IntegerAttr integerAttr;
   FailureOr<Value> converted = convertValue(value);
@@ -1309,6 +1296,7 @@ ModuleImport::matchFPExceptionBehaviorAttr(llvm::Value *value) {
   auto *mdstr = cast<llvm::MDString>(metadata->getMetadata());
   std::optional<llvm::fp::ExceptionBehavior> optLLVM =
       llvm::convertStrToExceptionBehavior(mdstr->getString());
+  assert(optLLVM && "Expecting FP exception behavior");
   return builder.getAttr<FPExceptionBehaviorAttr>(
       convertFPExceptionBehaviorFromLLVM(*optLLVM));
 }
@@ -1318,6 +1306,7 @@ RoundingModeAttr ModuleImport::matchRoundingModeAttr(llvm::Value *value) {
   auto *mdstr = cast<llvm::MDString>(metadata->getMetadata());
   std::optional<llvm::RoundingMode> optLLVM =
       llvm::convertStrToRoundingMode(mdstr->getString());
+  assert(optLLVM && "Expecting rounding mode");
   return builder.getAttr<RoundingModeAttr>(
       convertRoundingModeFromLLVM(*optLLVM));
 }
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 1523f4195b38c8..669b95a9c6a5be 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1721,13 +1721,11 @@ llvm::Metadata *ModuleTranslation::translateDebugInfo(LLVM::DINodeAttr attr) {
   return debugTranslation->translate(attr);
 }
 
-/// Translates the given LLVM rounding mode metadata.
 llvm::RoundingMode
 ModuleTranslation::translateRoundingMode(LLVM::RoundingMode rounding) {
   return convertRoundingModeToLLVM(rounding);
 }
 
-/// Translates the given LLVM FP exception behavior metadata.
 llvm::fp::ExceptionBehavior ModuleTranslation::translateFPExceptionBehavior(
     LLVM::FPExceptionBehavior exceptionBehavior) {
   return convertFPExceptionBehaviorToLLVM(exceptionBehavior);



More information about the Mlir-commits mailing list