[Mlir-commits] [mlir] [MLIR][Arith] Add rounding mode attribute to `truncf` (PR #86152)
Victor Perez
llvmlistbot at llvm.org
Fri Mar 29 03:36:13 PDT 2024
https://github.com/victor-eds updated https://github.com/llvm/llvm-project/pull/86152
>From eeb49dea6ff0470cef0a44e341283b16cd6a5d89 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 21 Mar 2024 16:38:19 +0000
Subject: [PATCH 1/6] [MLIR][Arith] Add rounding mode attribute to `truncf`
Add rounding mode attribute to `arith`. This attribute can be used in
different FP `arith` operations to control rounding mode. Rounding
modes correspond to IEEE 754-specified rounding modes.
As this is not yet supported by other dialects, conversions should fail for
now when this attribute is present.
Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
.../mlir/Dialect/Arith/IR/ArithBase.td | 25 ++++++++++++++
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 21 ++++++++++--
.../Dialect/Arith/IR/ArithOpsInterfaces.td | 33 +++++++++++++++++++
.../ArithToAMDGPU/ArithToAMDGPU.cpp | 3 ++
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 28 +++++++++++++++-
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 9 +++++
.../Dialect/Arith/Transforms/ExpandOps.cpp | 6 ++++
mlir/test/Dialect/Arith/ops.mlir | 10 ++++++
8 files changed, 131 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index c8a42c43c880b0..a9d976d9e4e28c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -156,4 +156,29 @@ def Arith_IntegerOverflowAttr :
let assemblyFormat = "`<` $value `>`";
}
+//===----------------------------------------------------------------------===//
+// Arith_RoundingMode
+//===----------------------------------------------------------------------===//
+
+// Same set of modes as llvm::RoundingMode (llvm/ADT/FloatingPointMode.h), but
+// with different numeric values: always convert through an explicit mapping.
+
+def Arith_RToNearestTiesToEven // Round to nearest, ties to even
+ : I32EnumAttrCase<"tonearesteven", 0>;
+def Arith_RDownward // Round toward -inf
+ : I32EnumAttrCase<"downward", 1>;
+def Arith_RUpward // Round toward +inf
+ : I32EnumAttrCase<"upward", 2>;
+def Arith_RTowardZero // Round toward 0
+ : I32EnumAttrCase<"towardzero", 3>;
+def Arith_RToNearestTiesAwayFromZero // Round to nearest, ties away from zero
+ : I32EnumAttrCase<"tonearestaway", 4>;
+
+def Arith_RoundingModeAttr : I32EnumAttr<
+ "RoundingMode", "Floating point rounding mode",
+ [Arith_RToNearestTiesToEven, Arith_RDownward, Arith_RUpward,
+ Arith_RTowardZero, Arith_RToNearestTiesAwayFromZero]> {
+ let cppNamespace = "::mlir::arith";
+}
+
#endif // ARITH_BASE
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c9df50d0395d1f..ead19c69a0831c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1227,17 +1227,32 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
// TruncFOp
//===----------------------------------------------------------------------===//
-def Arith_TruncFOp : Arith_FToFCastOp<"truncf"> {
+def Arith_TruncFOp :
+ Arith_Op<"truncf",
+ [Pure, SameOperandsAndResultShape,
+ DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+ DeclareOpInterfaceMethods<CastOpInterface>]>,
+ Arguments<(ins FloatLike:$in,
+ OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
+ Results<(outs FloatLike:$out)> {
let summary = "cast from floating-point to narrower floating-point";
let description = [{
Truncate a floating-point value to a smaller floating-point-typed value.
The destination type must be strictly narrower than the source type.
- If the value cannot be exactly represented, it is rounded using the default
- rounding mode. When operating on vectors, casts elementwise.
+ If the value cannot be exactly represented, it is rounded using the
+ optional `roundingmode` attribute, or the default rounding mode when the
+ attribute is absent. When operating on vectors, casts elementwise.
}];
+ let builders = [
+ OpBuilder<(ins "Type":$out, "Value":$in), [{
+ $_state.addOperands(in);
+ $_state.addTypes(out);
+ }]>
+ ];
let hasFolder = 1;
let hasVerifier = 1;
+ let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 73a5d9c32ef205..82d6c9ad6b03da 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -106,4 +106,37 @@ def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsI
];
}
+def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
+ let description = [{
+ Access to the (possibly absent) rounding mode attribute of an op.
+ }];
+
+ let cppNamespace = "::mlir::arith";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns a RoundingModeAttr attribute for the operation",
+ /*returnType=*/ "RoundingModeAttr",
+ /*methodName=*/ "getRoundingModeAttr",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getRoundingmodeAttr();
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the RoundingModeAttr attribute for
+ the operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getRoundingModeAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "roundingmode";
+ }]
+ >
+ ];
+}
+
#endif // ARITH_OPS_INTERFACES
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index b51a13ae362e92..0113a3df0b8e3d 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -175,6 +175,9 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
}
LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
+ // Only supporting default rounding mode as of now.
+ if (op.getRoundingmodeAttr())
+ return failure();
Type outType = op.getOut().getType();
if (auto outVecType = outType.dyn_cast<VectorType>()) {
if (outVecType.isScalable())
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 1f01f4a75c5b3e..9d1961486303a2 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -28,6 +28,31 @@ using namespace mlir;
namespace {
+/// Lowering for ops whose conversion depends on whether they carry a
+/// rounding mode attribute; it matches only if that equals `Constrained`.
+///
+/// \tparam SourceOp is the source operation; \tparam TargetOp, the operation it
+/// will lower to; \tparam Constrained, whether the rounding mode must be set;
+/// \tparam AttrConvert, the converter applied to the source op attributes.
+template <typename SourceOp, typename TargetOp, bool Constrained,
+ template <typename, typename> typename AttrConvert =
+ AttrConvertPassThrough>
+struct ConstrainedVectorConvertToLLVMPattern
+ : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
+ using VectorConvertToLLVMPattern<SourceOp, TargetOp,
+ AttrConvert>::VectorConvertToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
+ return failure();
+ return VectorConvertToLLVMPattern<SourceOp, TargetOp,
+ AttrConvert>::matchAndRewrite(op, adaptor,
+ rewriter);
+ }
+};
+
//===----------------------------------------------------------------------===//
// Straightforward Op Lowerings
//===----------------------------------------------------------------------===//
@@ -112,7 +137,8 @@ using SubIOpLowering =
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
arith::AttrConvertOverflowToLLVM>;
using TruncFOpLowering =
- VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
+ ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
+ false>;
using TruncIOpLowering =
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
using UIToFPOpLowering =
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index edf81bd7a8f396..843d9c0afaadc1 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -805,6 +805,15 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
} else {
+ if (auto roundingModeOp =
+ dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
+ if (arith::RoundingModeAttr roundingMode =
+ roundingModeOp.getRoundingModeAttr()) {
+ // Fail before rewriting (patterns must not fail after mutating IR).
+ // TODO: Convert the rounding mode once SPIR-V support is defined.
+ return failure();
+ }
+ }
rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
adaptor.getOperands());
}
return success();
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 71e14a153cfda9..0a57ad5ec7493d 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -253,6 +253,12 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
}
+ if (op.getRoundingmodeAttr()) {
+ return rewriter.notifyMatchFailure(
+ op, "only applicable to default rounding mode.");
+ }
+
+ Type i1Ty = b.getI1Type();
Type i16Ty = b.getI16Type();
Type i32Ty = b.getI32Type();
Type f32Ty = b.getF32Type();
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index e499573e324b5f..e4b23f073117e6 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -703,6 +703,21 @@ func.func @test_truncf_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xbf
return %0 : vector<[8]xbf16>
}
+// CHECK-LABEL: test_truncf_rounding_mode
+func.func @test_truncf_rounding_mode(%arg0 : f64) -> (f32, f32, f32, f32, f32) {
+// CHECK: arith.truncf %{{.*}} tonearesteven : f64 to f32
+ %0 = arith.truncf %arg0 tonearesteven : f64 to f32
+// CHECK: arith.truncf %{{.*}} downward : f64 to f32
+ %1 = arith.truncf %arg0 downward : f64 to f32
+// CHECK: arith.truncf %{{.*}} upward : f64 to f32
+ %2 = arith.truncf %arg0 upward : f64 to f32
+// CHECK: arith.truncf %{{.*}} towardzero : f64 to f32
+ %3 = arith.truncf %arg0 towardzero : f64 to f32
+// CHECK: arith.truncf %{{.*}} tonearestaway : f64 to f32
+ %4 = arith.truncf %arg0 tonearestaway : f64 to f32
+ return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32
+}
+
// CHECK-LABEL: test_uitofp
func.func @test_uitofp(%arg0 : i32) -> f32 {
%0 = arith.uitofp %arg0 : i32 to f32
>From a7e82a7dd07a2e3b8acb4265a30ad4880a80ec37 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 21 Mar 2024 16:49:59 +0000
Subject: [PATCH 2/6] Drop unused `i1Ty` variable
---
mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 0a57ad5ec7493d..dd04a599655894 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -258,7 +258,6 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
op, "only applicable to default rounding mode.");
}
- Type i1Ty = b.getI1Type();
Type i16Ty = b.getI16Type();
Type i32Ty = b.getI32Type();
Type f32Ty = b.getF32Type();
>From f144e154d7bee6e8ce649f72fc9e5d07ef1ea8ff Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 28 Mar 2024 17:03:18 +0000
Subject: [PATCH 3/6] Fold `truncf` with rounding mode; add LLVM conversion helpers
---
.../ArithCommon/AttrToLLVMConverter.h | 14 ++++++
.../ArithCommon/AttrToLLVMConverter.cpp | 31 +++++++++++++
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 45 +++++++++++++++----
mlir/test/Dialect/Arith/canonicalize.mlir | 45 +++++++++++++++++++
4 files changed, 127 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index 32d7979c32dfb2..a5a95c3834e8a1 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -36,6 +36,20 @@ convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
LLVM::IntegerOverflowFlagsAttr
convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
+/// Creates an LLVM rounding mode enum value from a given arithmetic rounding
+/// mode enum value.
+LLVM::RoundingMode
+convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode);
+
+/// Creates an LLVM rounding mode attribute from a given arithmetic rounding
+/// mode attribute.
+LLVM::RoundingModeAttr
+convertArithRoundingModeAttrToLLVM(arith::RoundingModeAttr roundingModeAttr);
+
+/// Returns an attribute for the default LLVM FP exception behavior.
+LLVM::FPExceptionBehaviorAttr
+getLLVMDefaultFPExceptionBehavior(MLIRContext &context);
+
// Attribute converter that populates a NamedAttrList by removing the fastmath
// attribute from the source operation attributes, and replacing it with an
// equivalent LLVM fastmath attribute.
diff --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
index dab064a3a954ec..595e4b9cd232cc 100644
--- a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
+++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
@@ -55,3 +55,34 @@ LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
return LLVM::IntegerOverflowFlagsAttr::get(
flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags));
}
+
+LLVM::RoundingMode
+mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) {
+ switch (roundingMode) {
+ case arith::RoundingMode::downward:
+ return LLVM::RoundingMode::TowardNegative;
+ case arith::RoundingMode::tonearestaway:
+ return LLVM::RoundingMode::NearestTiesToAway;
+ case arith::RoundingMode::tonearesteven:
+ return LLVM::RoundingMode::NearestTiesToEven;
+ case arith::RoundingMode::towardzero:
+ return LLVM::RoundingMode::TowardZero;
+ case arith::RoundingMode::upward:
+ return LLVM::RoundingMode::TowardPositive;
+ }
+ llvm_unreachable("Unhandled rounding mode");
+}
+
+LLVM::RoundingModeAttr mlir::arith::convertArithRoundingModeAttrToLLVM(
+ arith::RoundingModeAttr roundingModeAttr) {
+ assert(roundingModeAttr && "Expecting valid attribute");
+ return LLVM::RoundingModeAttr::get(
+ roundingModeAttr.getContext(),
+ convertArithRoundingModeToLLVM(roundingModeAttr.getValue()));
+}
+
+LLVM::FPExceptionBehaviorAttr
+mlir::arith::getLLVMDefaultFPExceptionBehavior(MLIRContext &context) {
+ return LLVM::FPExceptionBehaviorAttr::get(&context,
+ LLVM::FPExceptionBehavior::Ignore);
+}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 2f32d9a26e7752..95c3d4795e8f10 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -91,6 +91,29 @@ arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
llvm_unreachable("unknown cmpi predicate kind");
}
+/// Equivalent to
+/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
+///
+/// Not possible to implement as chain of calls as this would introduce a
+/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
+/// on the LLVM dialect and on translation to LLVM.
+static llvm::RoundingMode
+convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
+ switch (roundingMode) {
+ case RoundingMode::downward:
+ return llvm::RoundingMode::TowardNegative;
+ case RoundingMode::tonearestaway:
+ return llvm::RoundingMode::NearestTiesToAway;
+ case RoundingMode::tonearesteven:
+ return llvm::RoundingMode::NearestTiesToEven;
+ case RoundingMode::towardzero:
+ return llvm::RoundingMode::TowardZero;
+ case RoundingMode::upward:
+ return llvm::RoundingMode::TowardPositive;
+ }
+ llvm_unreachable("Unhandled rounding mode");
+}
+
static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
return arith::CmpIPredicateAttr::get(pred.getContext(),
invertPredicate(pred.getValue()));
@@ -1233,13 +1256,12 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
}
/// Attempts to convert `sourceValue` to an APFloat value with
-/// `targetSemantics`, without any information loss or rounding.
-static FailureOr<APFloat>
-convertFloatValue(APFloat sourceValue,
- const llvm::fltSemantics &targetSemantics) {
+/// `targetSemantics` and `roundingMode`, without any information loss.
+static FailureOr<APFloat> convertFloatValue(
+ APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
+ llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
bool losesInfo = false;
- auto status = sourceValue.convert(
- targetSemantics, llvm::RoundingMode::NearestTiesToEven, &losesInfo);
+ auto status = sourceValue.convert(targetSemantics, roundingMode, &losesInfo);
if (losesInfo || status != APFloat::opOK)
return failure();
@@ -1398,8 +1420,15 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
adaptor.getOperands(), getType(),
- [&targetSemantics](const APFloat &a, bool &castStatus) {
- FailureOr<APFloat> result = convertFloatValue(a, targetSemantics);
+ [this, &targetSemantics](const APFloat &a, bool &castStatus) {
+ FailureOr<APFloat> result;
+ if (std::optional<RoundingMode> roundingMode = getRoundingmode()) {
+ llvm::RoundingMode llvmRoundingMode =
+ convertArithRoundingModeToLLVMIR(*roundingMode);
+ result = convertFloatValue(a, targetSemantics, llvmRoundingMode);
+ } else {
+ result = convertFloatValue(a, targetSemantics);
+ }
if (failed(result)) {
castStatus = false;
return a;
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index bdc6c91d926775..e70d858cbaf33f 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -757,6 +757,51 @@ func.func @truncFPConstant() -> bf16 {
return %0 : bf16
}
+// CHECK-LABEL: @truncFPToNearestEvenConstant
+// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
+// CHECK: return %[[cres]]
+func.func @truncFPToNearestEvenConstant() -> bf16 {
+ %cst = arith.constant 1.000000e+00 : f32
+ %0 = arith.truncf %cst tonearesteven : f32 to bf16
+ return %0 : bf16
+}
+
+// CHECK-LABEL: @truncFPDownwardConstant
+// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
+// CHECK: return %[[cres]]
+func.func @truncFPDownwardConstant() -> bf16 {
+ %cst = arith.constant 1.000000e+00 : f32
+ %0 = arith.truncf %cst downward : f32 to bf16
+ return %0 : bf16
+}
+
+// CHECK-LABEL: @truncFPUpwardConstant
+// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
+// CHECK: return %[[cres]]
+func.func @truncFPUpwardConstant() -> bf16 {
+ %cst = arith.constant 1.000000e+00 : f32
+ %0 = arith.truncf %cst upward : f32 to bf16
+ return %0 : bf16
+}
+
+// CHECK-LABEL: @truncFPTowardZeroConstant
+// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
+// CHECK: return %[[cres]]
+func.func @truncFPTowardZeroConstant() -> bf16 {
+ %cst = arith.constant 1.000000e+00 : f32
+ %0 = arith.truncf %cst towardzero : f32 to bf16
+ return %0 : bf16
+}
+
+// CHECK-LABEL: @truncFPToNearestAwayConstant
+// CHECK: %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
+// CHECK: return %[[cres]]
+func.func @truncFPToNearestAwayConstant() -> bf16 {
+ %cst = arith.constant 1.000000e+00 : f32
+ %0 = arith.truncf %cst tonearestaway : f32 to bf16
+ return %0 : bf16
+}
+
// CHECK-LABEL: @truncFPVectorConstant
// CHECK: %[[cres:.+]] = arith.constant dense<[0.000000e+00, 1.000000e+00]> : vector<2xbf16>
// CHECK: return %[[cres]]
>From 3b756e4555397cfcfd11b262613875dc0a82d9cd Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 28 Mar 2024 17:11:39 +0000
Subject: [PATCH 4/6] Lower `truncf` with rounding mode to constrained `fptrunc`
---
.../ArithCommon/AttrToLLVMConverter.h | 35 +++++++++++++++++++
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 4 +++
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 15 ++++++++
3 files changed, 54 insertions(+)
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index a5a95c3834e8a1..87643be7415e8d 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -103,6 +103,45 @@ class AttrConvertOverflowToLLVM {
private:
NamedAttrList convertedAttr;
};
+
+// Attribute converter for lowering to LLVM constrained FP intrinsics. Copies
+// the source attributes, replaces the arith rounding mode attribute by the
+// equivalent LLVM one when TargetOp has a rounding mode (the source op must
+// then carry it) and sets the default ("ignore") FP exception behavior.
+template <typename SourceOp, typename TargetOp>
+class AttrConverterConstrainedFPToLLVM {
+ static_assert(TargetOp::template hasTrait<
+ LLVM::FPExceptionBehaviorOpInterface::Trait>(),
+ "Target constrained FP operations must implement "
+ "LLVM::FPExceptionBehaviorOpInterface");
+
+public:
+ AttrConverterConstrainedFPToLLVM(
+ SourceOp srcOp) {
+ // Copy the source attributes.
+ convertedAttr = NamedAttrList{srcOp->getAttrs()};
+
+ if constexpr (TargetOp::template hasTrait<
+ LLVM::RoundingModeOpInterface::Trait>()) {
+ // Get the name of the rounding mode attribute.
+ StringRef arithAttrName = srcOp.getRoundingModeAttrName();
+ // Remove the source attribute.
+ auto arithAttr =
+ cast<arith::RoundingModeAttr>(convertedAttr.erase(arithAttrName));
+ // Set the target attribute.
+ convertedAttr.set(TargetOp::getRoundingModeAttrName(),
+ convertArithRoundingModeAttrToLLVM(arithAttr));
+ }
+ convertedAttr.set(TargetOp::getFPExceptionBehaviorAttrName(),
+ getLLVMDefaultFPExceptionBehavior(*srcOp->getContext()));
+ }
+
+ ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+
+private:
+ NamedAttrList convertedAttr;
+};
+
} // namespace arith
} // namespace mlir
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 9d1961486303a2..6d00815cf9e0f0 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -139,6 +139,9 @@ using SubIOpLowering =
using TruncFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
false>;
+using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+ arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
+ arith::AttrConverterConstrainedFPToLLVM>;
using TruncIOpLowering =
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
using UIToFPOpLowering =
@@ -563,6 +566,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
SubFOpLowering,
SubIOpLowering,
TruncFOpLowering,
+ ConstrainedTruncFOpLowering,
TruncIOpLowering,
UIToFPOpLowering,
XOrIOpLowering
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 29268eef47e853..8eacc7a67181d0 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -289,6 +289,21 @@ func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
return
}
+// CHECK-LABEL: experimental_constrained_fptrunc
+func.func @experimental_constrained_fptrunc(%arg0 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
+ %0 = arith.truncf %arg0 tonearesteven : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore : f64 to f32
+ %1 = arith.truncf %arg0 downward : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore : f64 to f32
+ %2 = arith.truncf %arg0 upward : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore : f64 to f32
+ %3 = arith.truncf %arg0 towardzero : f64 to f32
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore : f64 to f32
+ %4 = arith.truncf %arg0 tonearestaway : f64 to f32
+ return
+}
+
// Check sign and zero extension and truncation of integers.
// CHECK-LABEL: @integer_extension_and_truncation
func.func @integer_extension_and_truncation(%arg0 : i3) {
>From cc02465536d0d7841f5f9d123444303067432cf3 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Fri, 29 Mar 2024 10:20:27 +0000
Subject: [PATCH 5/6] Format
---
mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index 87643be7415e8d..0891e2ba7be760 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -112,8 +112,7 @@ class AttrConverterConstrainedFPToLLVM {
"LLVM::FPExceptionBehaviorOpInterface");
public:
- AttrConverterConstrainedFPToLLVM(
- SourceOp srcOp) {
+ AttrConverterConstrainedFPToLLVM(SourceOp srcOp) {
// Copy the source attributes.
convertedAttr = NamedAttrList{srcOp->getAttrs()};
>From 1d54572182730cb5854c88f988a649a59f48683e Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Fri, 29 Mar 2024 10:35:55 +0000
Subject: [PATCH 6/6] Format and fix comment
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 95c3d4795e8f10..d79a00727033cc 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1413,8 +1413,7 @@ LogicalResult arith::TruncIOp::verify() {
//===----------------------------------------------------------------------===//
/// Perform safe const propagation for truncf, i.e., only propagate if FP value
-/// can be represented without precision loss or rounding. This is because the
-/// semantics of `arith.truncf` do not assume a specific rounding mode.
+/// can be represented without precision loss.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
@@ -1424,7 +1423,7 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
FailureOr<APFloat> result;
if (std::optional<RoundingMode> roundingMode = getRoundingmode()) {
llvm::RoundingMode llvmRoundingMode =
- convertArithRoundingModeToLLVMIR(*roundingMode);
+ convertArithRoundingModeToLLVMIR(*roundingMode);
result = convertFloatValue(a, targetSemantics, llvmRoundingMode);
} else {
result = convertFloatValue(a, targetSemantics);
More information about the Mlir-commits
mailing list