[Mlir-commits] [mlir] [MLIR][Arith] Add rounding mode attribute to `truncf` (PR #86152)
Victor Perez
llvmlistbot at llvm.org
Thu Mar 21 09:50:26 PDT 2024
https://github.com/victor-eds updated https://github.com/llvm/llvm-project/pull/86152
>From e05f6cb36140af2bece7f408795d9b5a50af3332 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 21 Mar 2024 16:38:19 +0000
Subject: [PATCH 1/2] [MLIR][Arith] Add rounding mode attribute to `truncf`
Add a rounding mode attribute to the `arith` dialect and make it an
optional attribute of `arith.truncf`. This attribute can be used by
floating-point `arith` operations to control the rounding mode; the
available modes correspond to the rounding modes specified by IEEE 754.
As this is not yet supported by other dialects, conversion patterns
bail out for now when this attribute is present.
Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
.../mlir/Dialect/Arith/IR/ArithBase.td | 25 ++++++++++++++
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 21 ++++++++++--
.../Dialect/Arith/IR/ArithOpsInterfaces.td | 33 +++++++++++++++++++
.../ArithToAMDGPU/ArithToAMDGPU.cpp | 3 ++
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 28 +++++++++++++++-
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 9 +++++
.../Dialect/Arith/Transforms/ExpandOps.cpp | 6 ++++
mlir/test/Dialect/Arith/ops.mlir | 10 ++++++
8 files changed, 131 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index c8a42c43c880b0..a9d976d9e4e28c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -156,4 +156,29 @@ def Arith_IntegerOverflowAttr :
let assemblyFormat = "`<` $value `>`";
}
+//===----------------------------------------------------------------------===//
+// Arith_RoundingMode
+//===----------------------------------------------------------------------===//
+
+// The numeric values match llvm::RoundingMode (TowardZero = 0,
+// NearestTiesToEven = 1, ...) in llvm/include/llvm/ADT/FloatingPointMode.h.
+
+def Arith_RToNearestTiesToEven // Round to nearest, ties to even
+ : I32EnumAttrCase<"tonearesteven", 1>;
+def Arith_RDownward // Round toward -inf
+ : I32EnumAttrCase<"downward", 3>;
+def Arith_RUpward // Round toward +inf
+ : I32EnumAttrCase<"upward", 2>;
+def Arith_RTowardZero // Round toward 0
+ : I32EnumAttrCase<"towardzero", 0>;
+def Arith_RToNearestTiesAwayFromZero // Round to nearest, ties away from zero
+ : I32EnumAttrCase<"tonearestaway", 4>;
+
+def Arith_RoundingModeAttr : I32EnumAttr<
+ "RoundingMode", "Floating point rounding mode",
+ [Arith_RToNearestTiesToEven, Arith_RDownward, Arith_RUpward,
+ Arith_RTowardZero, Arith_RToNearestTiesAwayFromZero]> {
+ let cppNamespace = "::mlir::arith";
+}
+
#endif // ARITH_BASE
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c9df50d0395d1f..ead19c69a0831c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1227,17 +1227,32 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
// TruncFOp
//===----------------------------------------------------------------------===//
-def Arith_TruncFOp : Arith_FToFCastOp<"truncf"> {
+def Arith_TruncFOp :
+ Arith_Op<"truncf",
+ [Pure, SameOperandsAndResultShape,
+ DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+ DeclareOpInterfaceMethods<CastOpInterface>]>,
+ Arguments<(ins FloatLike:$in,
+ OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
+ Results<(outs FloatLike:$out)> {
let summary = "cast from floating-point to narrower floating-point";
let description = [{
Truncate a floating-point value to a smaller floating-point-typed value.
The destination type must be strictly narrower than the source type.
- If the value cannot be exactly represented, it is rounded using the default
- rounding mode. When operating on vectors, casts elementwise.
+ If the value cannot be exactly represented, it is rounded using the mode
+ given by the optional `roundingmode` attribute, or the default rounding
+ mode if the attribute is absent. When operating on vectors, casts elementwise.
}];
+ let builders = [
+ OpBuilder<(ins "Type":$out, "Value":$in), [{
+ $_state.addOperands(in);
+ $_state.addTypes(out);
+ }]>
+ ];
let hasFolder = 1;
let hasVerifier = 1;
+ let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 73a5d9c32ef205..82d6c9ad6b03da 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -106,4 +106,37 @@ def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsI
];
}
+def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
+ let description = [{
+ Access to the optional rounding mode attribute of an operation.
+ }];
+
+ let cppNamespace = "::mlir::arith";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns the operation's rounding mode attribute (null if unset)",
+ /*returnType=*/ "RoundingModeAttr",
+ /*methodName=*/ "getRoundingModeAttr",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getRoundingmodeAttr();
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the rounding mode attribute of the
+ operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getRoundingModeAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "roundingmode";
+ }]
+ >
+ ];
+}
+
#endif // ARITH_OPS_INTERFACES
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index b51a13ae362e92..0113a3df0b8e3d 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -175,6 +175,9 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
}
LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
+ // Only the default rounding mode is supported for now; bail out otherwise.
+ if (op.getRoundingmodeAttr())
+ return failure();
Type outType = op.getOut().getType();
if (auto outVecType = outType.dyn_cast<VectorType>()) {
if (outVecType.isScalable())
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 1f01f4a75c5b3e..9d1961486303a2 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -28,6 +28,31 @@ using namespace mlir;
namespace {
+/// Lowers SourceOp to TargetOp only if the presence of a rounding mode
+/// attribute on the op equals Constrained: when true, only ops carrying the
+/// attribute match; when false, only ops without it match. Any other op makes
+/// the pattern fail to match.
+/// AttrConvert is the attribute conversion forwarded to
+/// VectorConvertToLLVMPattern.
+template <typename SourceOp, typename TargetOp, bool Constrained,
+ template <typename, typename> typename AttrConvert =
+ AttrConvertPassThrough>
+struct ConstrainedVectorConvertToLLVMPattern
+ : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
+ using VectorConvertToLLVMPattern<SourceOp, TargetOp,
+ AttrConvert>::VectorConvertToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
+ return failure();
+ return VectorConvertToLLVMPattern<SourceOp, TargetOp,
+ AttrConvert>::matchAndRewrite(op, adaptor,
+ rewriter);
+ }
+};
+
//===----------------------------------------------------------------------===//
// Straightforward Op Lowerings
//===----------------------------------------------------------------------===//
@@ -112,7 +137,8 @@ using SubIOpLowering =
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
arith::AttrConvertOverflowToLLVM>;
using TruncFOpLowering =
- VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
+ ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
+ false>;
using TruncIOpLowering =
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
using UIToFPOpLowering =
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index edf81bd7a8f396..843d9c0afaadc1 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -805,6 +805,17 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
} else {
+ // Reject explicit rounding modes before creating the replacement: a
+ // pattern must not report failure after it has already rewritten the IR.
+ if (auto roundingModeOp =
+ dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
+ if (roundingModeOp.getRoundingModeAttr()) {
+ // TODO: Perform rounding mode attribute conversion and attach to new
+ // operation when defined in the dialect.
+ return rewriter.notifyMatchFailure(
+ op, "rounding mode conversion is not supported yet");
+ }
+ }
rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
adaptor.getOperands());
}
return success();
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 7f246daf99ff3c..b45be8b6bd8a4c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -261,6 +261,12 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
}
+ // This expansion is only valid for the default rounding mode.
+ if (op.getRoundingmodeAttr()) {
+ return rewriter.notifyMatchFailure(
+ op, "only applicable to default rounding mode.");
+ }
+
+ Type i1Ty = b.getI1Type();
Type i16Ty = b.getI16Type();
Type i32Ty = b.getI32Type();
Type f32Ty = b.getF32Type();
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index e499573e324b5f..e4b23f073117e6 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -703,6 +703,21 @@ func.func @test_truncf_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xbf
return %0 : vector<[8]xbf16>
}
+// CHECK-LABEL: test_truncf_rounding_mode
+func.func @test_truncf_rounding_mode(%arg0 : f64) -> (f32, f32, f32, f32, f32) {
+ // CHECK: arith.truncf %{{.*}} tonearesteven : f64 to f32
+ %0 = arith.truncf %arg0 tonearesteven : f64 to f32
+ // CHECK: arith.truncf %{{.*}} downward : f64 to f32
+ %1 = arith.truncf %arg0 downward : f64 to f32
+ // CHECK: arith.truncf %{{.*}} upward : f64 to f32
+ %2 = arith.truncf %arg0 upward : f64 to f32
+ // CHECK: arith.truncf %{{.*}} towardzero : f64 to f32
+ %3 = arith.truncf %arg0 towardzero : f64 to f32
+ // CHECK: arith.truncf %{{.*}} tonearestaway : f64 to f32
+ %4 = arith.truncf %arg0 tonearestaway : f64 to f32
+ return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32
+}
+
// CHECK-LABEL: test_uitofp
func.func @test_uitofp(%arg0 : i32) -> f32 {
%0 = arith.uitofp %arg0 : i32 to f32
>From 485c36096bb7cb8a44607bcfd78ad5bc4bb822a9 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 21 Mar 2024 16:49:59 +0000
Subject: [PATCH 2/2] Drop unused `i1Ty` variable in `BFloat16TruncFOpConverter`
---
mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index b45be8b6bd8a4c..f72be9754d8869 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -266,7 +266,6 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
op, "only applicable to default rounding mode.");
}
- Type i1Ty = b.getI1Type();
Type i16Ty = b.getI16Type();
Type i32Ty = b.getI32Type();
Type f32Ty = b.getF32Type();
More information about the Mlir-commits
mailing list