[Mlir-commits] [mlir] [MLIR][Arith] Add rounding mode attribute to `truncf` (PR #86152)

Victor Perez llvmlistbot at llvm.org
Thu Mar 28 10:12:00 PDT 2024


https://github.com/victor-eds updated https://github.com/llvm/llvm-project/pull/86152

>From eeb49dea6ff0470cef0a44e341283b16cd6a5d89 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 21 Mar 2024 16:38:19 +0000
Subject: [PATCH 1/4] [MLIR][Arith] Add rounding mode attribute to `truncf`

Add rounding mode attribute to `arith`. This attribute can be used in
different FP `arith` operations to control rounding mode. Rounding
modes correspond to IEEE 754-specified rounding modes.

As this is not supported in other dialects, conversion should fail for
now in case this attribute is present.

Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
 .../mlir/Dialect/Arith/IR/ArithBase.td        | 25 ++++++++++++++
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 21 ++++++++++--
 .../Dialect/Arith/IR/ArithOpsInterfaces.td    | 33 +++++++++++++++++++
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           |  3 ++
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    | 28 +++++++++++++++-
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  |  9 +++++
 .../Dialect/Arith/Transforms/ExpandOps.cpp    |  6 ++++
 mlir/test/Dialect/Arith/ops.mlir              | 10 ++++++
 8 files changed, 131 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index c8a42c43c880b0..a9d976d9e4e28c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -156,4 +156,29 @@ def Arith_IntegerOverflowAttr :
   let assemblyFormat = "`<` $value `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// Arith_RoundingMode
+//===----------------------------------------------------------------------===//
+
+// These name the same rounding modes as llvm::RoundingMode, declared in
+// llvm/include/llvm/ADT/FloatingPointMode.h; the integer encodings differ.
+
+def Arith_RToNearestTiesToEven       // Round to nearest, ties to even
+    : I32EnumAttrCase<"tonearesteven", 0>;
+def Arith_RDownward                  // Round toward -inf
+    : I32EnumAttrCase<"downward", 1>;
+def Arith_RUpward                    // Round toward +inf
+    : I32EnumAttrCase<"upward", 2>;
+def Arith_RTowardZero                // Round toward 0
+    : I32EnumAttrCase<"towardzero", 3>;
+def Arith_RToNearestTiesAwayFromZero // Round to nearest, ties away from zero
+    : I32EnumAttrCase<"tonearestaway", 4>;
+
+def Arith_RoundingModeAttr : I32EnumAttr<
+    "RoundingMode", "Floating point rounding mode",
+    [Arith_RToNearestTiesToEven, Arith_RDownward, Arith_RUpward,
+     Arith_RTowardZero, Arith_RToNearestTiesAwayFromZero]> {
+  let cppNamespace = "::mlir::arith";
+}
+
 #endif // ARITH_BASE
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c9df50d0395d1f..ead19c69a0831c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1227,17 +1227,32 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
 // TruncFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_TruncFOp : Arith_FToFCastOp<"truncf"> {
+def Arith_TruncFOp :
+    Arith_Op<"truncf",
+      [Pure, SameOperandsAndResultShape,
+       DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+       DeclareOpInterfaceMethods<CastOpInterface>]>,
+    Arguments<(ins FloatLike:$in,
+                   OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
+    Results<(outs FloatLike:$out)> {
   let summary = "cast from floating-point to narrower floating-point";
   let description = [{
     Truncate a floating-point value to a smaller floating-point-typed value.
     The destination type must be strictly narrower than the source type.
-    If the value cannot be exactly represented, it is rounded using the default
-    rounding mode. When operating on vectors, casts elementwise.
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
+    When operating on vectors, casts elementwise.
   }];
+  let builders = [
+    OpBuilder<(ins "Type":$out, "Value":$in), [{
+      $_state.addOperands(in);
+      $_state.addTypes(out);
+    }]>
+  ];
 
   let hasFolder = 1;
   let hasVerifier = 1;
+  let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 73a5d9c32ef205..82d6c9ad6b03da 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -106,4 +106,37 @@ def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsI
   ];
 }
 
+def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
+  let description = [{
+    Access to op rounding mode.
+  }];
+
+  let cppNamespace = "::mlir::arith";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns a RoundingModeAttr attribute for the operation",
+      /*returnType=*/  "RoundingModeAttr",
+      /*methodName=*/  "getRoundingModeAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getRoundingmodeAttr();
+      }]
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the RoundingModeAttr attribute for
+                         the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getRoundingModeAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "roundingmode";
+      }]
+    >
+  ];
+}
+
 #endif // ARITH_OPS_INTERFACES
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index b51a13ae362e92..0113a3df0b8e3d 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -175,6 +175,9 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
 }
 
 LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
+  // Only supporting default rounding mode as of now.
+  if (op.getRoundingmodeAttr())
+    return failure();
   Type outType = op.getOut().getType();
   if (auto outVecType = outType.dyn_cast<VectorType>()) {
     if (outVecType.isScalable())
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 1f01f4a75c5b3e..9d1961486303a2 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -28,6 +28,31 @@ using namespace mlir;
 
 namespace {
 
+/// Lowering for operations that may carry a rounding mode attribute. The
+/// pattern only applies when `Constrained` matches the attribute's presence.
+///
+/// \tparam SourceOp is the source operation; \tparam TargetOp, the operation it
+/// will lower to; \tparam Constrained, whether a rounding mode attribute must
+/// be present; \tparam AttrConvert, the converter handed to the base pattern.
+template <typename SourceOp, typename TargetOp, bool Constrained,
+          template <typename, typename> typename AttrConvert =
+              AttrConvertPassThrough>
+struct ConstrainedVectorConvertToLLVMPattern
+    : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
+  using VectorConvertToLLVMPattern<SourceOp, TargetOp,
+                                   AttrConvert>::VectorConvertToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
+      return failure();
+    return VectorConvertToLLVMPattern<SourceOp, TargetOp,
+                                      AttrConvert>::matchAndRewrite(op, adaptor,
+                                                                    rewriter);
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Straightforward Op Lowerings
 //===----------------------------------------------------------------------===//
@@ -112,7 +137,8 @@ using SubIOpLowering =
     VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
                                arith::AttrConvertOverflowToLLVM>;
 using TruncFOpLowering =
-    VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
+    ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
+                                          false>;
 using TruncIOpLowering =
     VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
 using UIToFPOpLowering =
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index edf81bd7a8f396..843d9c0afaadc1 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -805,6 +805,15 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
     } else {
       rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
                                                     adaptor.getOperands());
+      if (auto roundingModeOp =
+              dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
+        if (arith::RoundingModeAttr roundingMode =
+                roundingModeOp.getRoundingModeAttr()) {
+          // TODO: Perform rounding mode attribute conversion and attach to new
+          // operation when defined in the dialect.
+          return failure();
+        }
+      }
     }
     return success();
   }
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 71e14a153cfda9..0a57ad5ec7493d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -253,6 +253,12 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
       return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
     }
 
+    if (op.getRoundingmodeAttr()) {
+      return rewriter.notifyMatchFailure(
+          op, "only applicable to default rounding mode.");
+    }
+
+    Type i1Ty = b.getI1Type();
     Type i16Ty = b.getI16Type();
     Type i32Ty = b.getI32Type();
     Type f32Ty = b.getF32Type();
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index e499573e324b5f..e4b23f073117e6 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -703,6 +703,16 @@ func.func @test_truncf_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xbf
   return %0 : vector<[8]xbf16>
 }
 
+// CHECK-LABEL: test_truncf_rounding_mode
+func.func @test_truncf_rounding_mode(%arg0 : f64) -> (f32, f32, f32, f32, f32) {
+  %0 = arith.truncf %arg0 tonearesteven : f64 to f32
+  %1 = arith.truncf %arg0 downward : f64 to f32
+  %2 = arith.truncf %arg0 upward : f64 to f32
+  %3 = arith.truncf %arg0 towardzero : f64 to f32
+  %4 = arith.truncf %arg0 tonearestaway : f64 to f32
+  return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32
+}
+
 // CHECK-LABEL: test_uitofp
 func.func @test_uitofp(%arg0 : i32) -> f32 {
   %0 = arith.uitofp %arg0 : i32 to f32

>From a7e82a7dd07a2e3b8acb4265a30ad4880a80ec37 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 21 Mar 2024 16:49:59 +0000
Subject: [PATCH 2/4] Drop unused

---
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 0a57ad5ec7493d..dd04a599655894 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -258,7 +258,6 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
           op, "only applicable to default rounding mode.");
     }
 
-    Type i1Ty = b.getI1Type();
     Type i16Ty = b.getI16Type();
     Type i32Ty = b.getI32Type();
     Type f32Ty = b.getF32Type();

>From f144e154d7bee6e8ce649f72fc9e5d07ef1ea8ff Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 28 Mar 2024 17:03:18 +0000
Subject: [PATCH 3/4] Extend

---
 .../ArithCommon/AttrToLLVMConverter.h         | 14 ++++++
 .../ArithCommon/AttrToLLVMConverter.cpp       | 31 +++++++++++++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 45 +++++++++++++++----
 mlir/test/Dialect/Arith/canonicalize.mlir     | 45 +++++++++++++++++++
 4 files changed, 127 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index 32d7979c32dfb2..a5a95c3834e8a1 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -36,6 +36,20 @@ convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
 LLVM::IntegerOverflowFlagsAttr
 convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
 
+/// Creates an LLVM rounding mode enum value from a given arithmetic rounding
+/// mode enum value.
+LLVM::RoundingMode
+convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode);
+
+/// Creates an LLVM rounding mode attribute from a given arithmetic rounding
+/// mode attribute.
+LLVM::RoundingModeAttr
+convertArithRoundingModeAttrToLLVM(arith::RoundingModeAttr roundingModeAttr);
+
+/// Returns the attribute for the default LLVM FP exception behavior (ignore).
+LLVM::FPExceptionBehaviorAttr
+getLLVMDefaultFPExceptionBehavior(MLIRContext &context);
+
 // Attribute converter that populates a NamedAttrList by removing the fastmath
 // attribute from the source operation attributes, and replacing it with an
 // equivalent LLVM fastmath attribute.
diff --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
index dab064a3a954ec..595e4b9cd232cc 100644
--- a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
+++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
@@ -55,3 +55,34 @@ LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
   return LLVM::IntegerOverflowFlagsAttr::get(
       flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags));
 }
+
+LLVM::RoundingMode
+mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) {
+  switch (roundingMode) {
+  case arith::RoundingMode::downward:
+    return LLVM::RoundingMode::TowardNegative;
+  case arith::RoundingMode::tonearestaway:
+    return LLVM::RoundingMode::NearestTiesToAway;
+  case arith::RoundingMode::tonearesteven:
+    return LLVM::RoundingMode::NearestTiesToEven;
+  case arith::RoundingMode::towardzero:
+    return LLVM::RoundingMode::TowardZero;
+  case arith::RoundingMode::upward:
+    return LLVM::RoundingMode::TowardPositive;
+  }
+  llvm_unreachable("Unhandled rounding mode");
+}
+
+LLVM::RoundingModeAttr mlir::arith::convertArithRoundingModeAttrToLLVM(
+    arith::RoundingModeAttr roundingModeAttr) {
+  assert(roundingModeAttr && "Expecting valid attribute");
+  return LLVM::RoundingModeAttr::get(
+      roundingModeAttr.getContext(),
+      convertArithRoundingModeToLLVM(roundingModeAttr.getValue()));
+}
+
+LLVM::FPExceptionBehaviorAttr
+mlir::arith::getLLVMDefaultFPExceptionBehavior(MLIRContext &context) {
+  return LLVM::FPExceptionBehaviorAttr::get(&context,
+                                            LLVM::FPExceptionBehavior::Ignore);
+}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 2f32d9a26e7752..95c3d4795e8f10 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -91,6 +91,29 @@ arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
   llvm_unreachable("unknown cmpi predicate kind");
 }
 
+/// Equivalent to
+/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
+///
+/// Not possible to implement as chain of calls as this would introduce a
+/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
+/// on the LLVM dialect and on translation to LLVM.
+static llvm::RoundingMode
+convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
+  switch (roundingMode) {
+  case RoundingMode::downward:
+    return llvm::RoundingMode::TowardNegative;
+  case RoundingMode::tonearestaway:
+    return llvm::RoundingMode::NearestTiesToAway;
+  case RoundingMode::tonearesteven:
+    return llvm::RoundingMode::NearestTiesToEven;
+  case RoundingMode::towardzero:
+    return llvm::RoundingMode::TowardZero;
+  case RoundingMode::upward:
+    return llvm::RoundingMode::TowardPositive;
+  }
+  llvm_unreachable("Unhandled rounding mode");
+}
+
 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
   return arith::CmpIPredicateAttr::get(pred.getContext(),
                                        invertPredicate(pred.getValue()));
@@ -1233,13 +1256,12 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
 }
 
 /// Attempts to convert `sourceValue` to an APFloat value with
-/// `targetSemantics`, without any information loss or rounding.
-static FailureOr<APFloat>
-convertFloatValue(APFloat sourceValue,
-                  const llvm::fltSemantics &targetSemantics) {
+/// `targetSemantics` and `roundingMode`, without any information loss.
+static FailureOr<APFloat> convertFloatValue(
+    APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
+    llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
   bool losesInfo = false;
-  auto status = sourceValue.convert(
-      targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
+  auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
   if (losesInfo || status != APFloat::opOK)
     return failure();
 
@@ -1398,8 +1420,15 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
   const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
   return constFoldCastOp<FloatAttr, FloatAttr>(
       adaptor.getOperands(), getType(),
-      [&targetSemantics](const APFloat &a, bool &castStatus) {
-        FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
+      [this, &targetSemantics](const APFloat &a, bool &castStatus) {
+        FailureOr<APFloat> result;
+        if (std::optional<RoundingMode> roundingMode = getRoundingmode()) {
+          llvm::RoundingMode llvmRoundingMode =
+            convertArithRoundingModeToLLVMIR(*roundingMode);
+          result = convertFloatValue(a, targetSemantics, llvmRoundingMode);
+        } else {
+          result = convertFloatValue(a, targetSemantics);
+        }
         if (failed(result)) {
           castStatus = false;
           return a;
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index bdc6c91d926775..e70d858cbaf33f 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -757,6 +757,51 @@ func.func @truncFPConstant() -> bf16 {
   return %0 : bf16
 }
 
+// CHECK-LABEL: @truncFPToNearestEvenConstant
+//       CHECK:   %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
+//       CHECK:   return %[[cres]]
+func.func @truncFPToNearestEvenConstant() -> bf16 {
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = arith.truncf %cst tonearesteven : f32 to bf16
+  return %0 : bf16
+}
+
+// CHECK-LABEL: @truncFPDownwardConstant
+//       CHECK:   %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
+//       CHECK:   return %[[cres]]
+func.func @truncFPDownwardConstant() -> bf16 {
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = arith.truncf %cst downward : f32 to bf16
+  return %0 : bf16
+}
+
+// CHECK-LABEL: @truncFPUpwardConstant
+//       CHECK:   %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
+//       CHECK:   return %[[cres]]
+func.func @truncFPUpwardConstant() -> bf16 {
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = arith.truncf %cst upward : f32 to bf16
+  return %0 : bf16
+}
+
+// CHECK-LABEL: @truncFPTowardZeroConstant
+//       CHECK:   %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
+//       CHECK:   return %[[cres]]
+func.func @truncFPTowardZeroConstant() -> bf16 {
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = arith.truncf %cst towardzero : f32 to bf16
+  return %0 : bf16
+}
+
+// CHECK-LABEL: @truncFPToNearestAwayConstant
+//       CHECK:   %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
+//       CHECK:   return %[[cres]]
+func.func @truncFPToNearestAwayConstant() -> bf16 {
+  %cst = arith.constant 1.000000e+00 : f32
+  %0 = arith.truncf %cst tonearestaway : f32 to bf16
+  return %0 : bf16
+}
+
 // CHECK-LABEL: @truncFPVectorConstant
 //       CHECK:   %[[cres:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xbf16>
 //       CHECK:   return %[[cres]]

>From 3b756e4555397cfcfd11b262613875dc0a82d9cd Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 28 Mar 2024 17:11:39 +0000
Subject: [PATCH 4/4] Extend

---
 .../ArithCommon/AttrToLLVMConverter.h         | 35 +++++++++++++++++++
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    |  4 +++
 .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 15 ++++++++
 3 files changed, 54 insertions(+)

diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index a5a95c3834e8a1..87643be7415e8d 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -103,6 +103,41 @@ class AttrConvertOverflowToLLVM {
 private:
   NamedAttrList convertedAttr;
 };
+
+template <typename SourceOp, typename TargetOp>
+class AttrConverterConstrainedFPToLLVM {
+  static_assert(TargetOp::template hasTrait<
+                    LLVM::FPExceptionBehaviorOpInterface::Trait>(),
+                "Target constrained FP operations must implement "
+                "LLVM::FPExceptionBehaviorOpInterface");
+
+public:
+  AttrConverterConstrainedFPToLLVM(
+      SourceOp srcOp) {
+    // Start from a copy of all attributes of the source operation.
+    convertedAttr = NamedAttrList{srcOp->getAttrs()};
+
+    if constexpr (TargetOp::template hasTrait<
+                      LLVM::RoundingModeOpInterface::Trait>()) {
+      // Name under which the source operation stores its rounding mode.
+      StringRef arithAttrName = srcOp.getRoundingModeAttrName();
+      // Remove the source attribute; the conversion below asserts it is set.
+      auto arithAttr =
+          cast<arith::RoundingModeAttr>(convertedAttr.erase(arithAttrName));
+      // Store the equivalent LLVM rounding mode under the target's name.
+      convertedAttr.set(TargetOp::getRoundingModeAttrName(),
+                        convertArithRoundingModeAttrToLLVM(arithAttr));
+    }
+    convertedAttr.set(TargetOp::getFPExceptionBehaviorAttrName(),
+                      getLLVMDefaultFPExceptionBehavior(*srcOp->getContext()));
+  }
+
+  ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+
+private:
+  NamedAttrList convertedAttr;
+};
+
 } // namespace arith
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 9d1961486303a2..6d00815cf9e0f0 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -139,6 +139,9 @@ using SubIOpLowering =
 using TruncFOpLowering =
     ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
                                           false>;
+using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
+    arith::AttrConverterConstrainedFPToLLVM>;
 using TruncIOpLowering =
     VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
 using UIToFPOpLowering =
@@ -563,6 +566,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     SubFOpLowering,
     SubIOpLowering,
     TruncFOpLowering,
+    ConstrainedTruncFOpLowering,
     TruncIOpLowering,
     UIToFPOpLowering,
     XOrIOpLowering
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 29268eef47e853..8eacc7a67181d0 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -289,6 +289,21 @@ func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
   return
 }
 
+// CHECK-LABEL: experimental_constrained_fptrunc
+func.func @experimental_constrained_fptrunc(%arg0 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
+  %0 = arith.truncf %arg0 tonearesteven : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore : f64 to f32
+  %1 = arith.truncf %arg0 downward : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore : f64 to f32
+  %2 = arith.truncf %arg0 upward : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore : f64 to f32
+  %3 = arith.truncf %arg0 towardzero : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore : f64 to f32
+  %4 = arith.truncf %arg0 tonearestaway : f64 to f32
+  return
+}
+
 // Check sign and zero extension and truncation of integers.
 // CHECK-LABEL: @integer_extension_and_truncation
 func.func @integer_extension_and_truncation(%arg0 : i3) {



More information about the Mlir-commits mailing list