[Mlir-commits] [mlir] [MLIR][Arith] Add rounding mode attribute to `truncf` (PR #86152)

Victor Perez llvmlistbot at llvm.org
Thu Mar 21 09:50:26 PDT 2024


https://github.com/victor-eds updated https://github.com/llvm/llvm-project/pull/86152

>From e05f6cb36140af2bece7f408795d9b5a50af3332 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 21 Mar 2024 16:38:19 +0000
Subject: [PATCH 1/2] [MLIR][Arith] Add rounding mode attribute to `truncf`

Add a rounding mode attribute to the `arith` dialect and allow it on
`arith.truncf`. The attribute can be used by other floating-point `arith`
operations to control rounding; its cases correspond to the rounding modes
specified by IEEE 754.

As other dialects do not support rounding modes yet, conversions out of
`arith` fail for now when this attribute is present.

Signed-off-by: Victor Perez <victor.perez at codeplay.com>
---
 .../mlir/Dialect/Arith/IR/ArithBase.td        | 25 ++++++++++++++
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 21 ++++++++++--
 .../Dialect/Arith/IR/ArithOpsInterfaces.td    | 33 +++++++++++++++++++
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           |  3 ++
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    | 28 +++++++++++++++-
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  |  9 +++++
 .../Dialect/Arith/Transforms/ExpandOps.cpp    |  6 ++++
 mlir/test/Dialect/Arith/ops.mlir              | 10 ++++++
 8 files changed, 131 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
index c8a42c43c880b0..a9d976d9e4e28c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td
@@ -156,4 +156,29 @@ def Arith_IntegerOverflowAttr :
   let assemblyFormat = "`<` $value `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// Arith_RoundingMode
+//===----------------------------------------------------------------------===//
+
+// Cases follow the IEEE 754 rounding-direction attributes. NOTE(review): the
+// values appear to differ from llvm::RoundingMode (ADT/FloatingPointMode.h).
+
+def Arith_RToNearestTiesToEven       // Round to nearest, ties to even
+    : I32EnumAttrCase<"tonearesteven", 0>;
+def Arith_RDownward                  // Round toward -inf
+    : I32EnumAttrCase<"downward", 1>;
+def Arith_RUpward                    // Round toward +inf
+    : I32EnumAttrCase<"upward", 2>;
+def Arith_RTowardZero                // Round toward 0
+    : I32EnumAttrCase<"towardzero", 3>;
+def Arith_RToNearestTiesAwayFromZero // Round to nearest, ties away from zero
+    : I32EnumAttrCase<"tonearestaway", 4>;
+
+def Arith_RoundingModeAttr : I32EnumAttr<
+    "RoundingMode", "Floating point rounding mode",
+    [Arith_RToNearestTiesToEven, Arith_RDownward, Arith_RUpward,
+     Arith_RTowardZero, Arith_RToNearestTiesAwayFromZero]> {
+  let cppNamespace = "::mlir::arith";
+}
+
 #endif // ARITH_BASE
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index c9df50d0395d1f..ead19c69a0831c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1227,17 +1227,32 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
 // TruncFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_TruncFOp : Arith_FToFCastOp<"truncf"> {
+def Arith_TruncFOp :
+    Arith_Op<"truncf",
+      [Pure, SameOperandsAndResultShape,
+       DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+       DeclareOpInterfaceMethods<CastOpInterface>]>,
+    Arguments<(ins FloatLike:$in,
+                   OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
+    Results<(outs FloatLike:$out)> {
   let summary = "cast from floating-point to narrower floating-point";
   let description = [{
     Truncate a floating-point value to a smaller floating-point-typed value.
     The destination type must be strictly narrower than the source type.
-    If the value cannot be exactly represented, it is rounded using the default
-    rounding mode. When operating on vectors, casts elementwise.
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode, or the default rounding mode if none is given.
+    When operating on vectors, casts elementwise.
   }];
+  let builders = [
+    OpBuilder<(ins "Type":$out, "Value":$in), [{
+      $_state.addOperands(in);
+      $_state.addTypes(out);
+    }]>
+  ];
 
   let hasFolder = 1;
   let hasVerifier = 1;
+  let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
index 73a5d9c32ef205..82d6c9ad6b03da 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
@@ -106,4 +106,37 @@ def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsI
   ];
 }
 
+def ArithRoundingModeInterface : OpInterface<"ArithRoundingModeInterface"> {
+  let description = [{
+    Access to an operation's optional floating-point rounding mode attribute.
+  }];
+
+  let cppNamespace = "::mlir::arith";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns a RoundingModeAttr attribute for the operation",
+      /*returnType=*/  "RoundingModeAttr",
+      /*methodName=*/  "getRoundingModeAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getRoundingmodeAttr();
+      }]
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the RoundingModeAttr attribute for
+                         the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getRoundingModeAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "roundingmode";
+      }]
+    >
+  ];
+}
+
 #endif // ARITH_OPS_INTERFACES
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index b51a13ae362e92..0113a3df0b8e3d 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -175,6 +175,9 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
 }
 
 LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
+  // Only the default rounding mode is supported for now; bail out otherwise.
+  if (op.getRoundingmodeAttr())
+    return failure();
   Type outType = op.getOut().getType();
   if (auto outVecType = outType.dyn_cast<VectorType>()) {
     if (outVecType.isScalable())
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 1f01f4a75c5b3e..9d1961486303a2 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -28,6 +28,31 @@ using namespace mlir;
 
 namespace {
 
+/// Lowering pattern that only matches when the presence of a rounding mode
+/// attribute on the source op equals \p Constrained; otherwise it fails.
+///
+/// \tparam SourceOp is the source operation; \tparam TargetOp, the operation it
+/// will lower to; \tparam Constrained, whether a rounding mode is required;
+/// \tparam AttrConvert, the attribute conversion forwarded to the base pattern.
+template <typename SourceOp, typename TargetOp, bool Constrained,
+          template <typename, typename> typename AttrConvert =
+              AttrConvertPassThrough>
+struct ConstrainedVectorConvertToLLVMPattern
+    : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
+  using VectorConvertToLLVMPattern<SourceOp, TargetOp,
+                                   AttrConvert>::VectorConvertToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
+      return failure();
+    return VectorConvertToLLVMPattern<SourceOp, TargetOp,
+                                      AttrConvert>::matchAndRewrite(op, adaptor,
+                                                                    rewriter);
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Straightforward Op Lowerings
 //===----------------------------------------------------------------------===//
@@ -112,7 +137,8 @@ using SubIOpLowering =
     VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
                                arith::AttrConvertOverflowToLLVM>;
 using TruncFOpLowering =
-    VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;
+    ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
+                                          false>;
 using TruncIOpLowering =
     VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
 using UIToFPOpLowering =
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index edf81bd7a8f396..843d9c0afaadc1 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -805,6 +805,15 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
     } else {
+      if (auto roundingModeOp =
+              dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
+        // Check before rewriting so a failed match leaves the IR untouched.
+        if (roundingModeOp.getRoundingModeAttr()) {
+          // TODO: Perform rounding mode attribute conversion and attach to new
+          // operation when defined in the dialect.
+          return rewriter.notifyMatchFailure(op, "rounding mode not supported");
+        }
+      }
       rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
                                                     adaptor.getOperands());
     }
     return success();
   }
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 7f246daf99ff3c..b45be8b6bd8a4c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -261,6 +261,12 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
       return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
     }
 
+    if (op.getRoundingmodeAttr()) {
+      return rewriter.notifyMatchFailure(
+          op, "only applicable to default rounding mode.");
+    }
+
+    Type i1Ty = b.getI1Type();
     Type i16Ty = b.getI16Type();
     Type i32Ty = b.getI32Type();
     Type f32Ty = b.getF32Type();
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index e499573e324b5f..e4b23f073117e6 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -703,6 +703,16 @@ func.func @test_truncf_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xbf
   return %0 : vector<[8]xbf16>
 }
 
+// CHECK-LABEL: test_truncf_rounding_mode
+func.func @test_truncf_rounding_mode(%arg0 : f64) -> (f32, f32, f32, f32, f32) {
+  %0 = arith.truncf %arg0 tonearesteven : f64 to f32
+  %1 = arith.truncf %arg0 downward : f64 to f32
+  %2 = arith.truncf %arg0 upward : f64 to f32
+  %3 = arith.truncf %arg0 towardzero : f64 to f32
+  %4 = arith.truncf %arg0 tonearestaway : f64 to f32
+  return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32
+}
+
 // CHECK-LABEL: test_uitofp
 func.func @test_uitofp(%arg0 : i32) -> f32 {
   %0 = arith.uitofp %arg0 : i32 to f32

>From 485c36096bb7cb8a44607bcfd78ad5bc4bb822a9 Mon Sep 17 00:00:00 2001
From: Victor Perez <victor.perez at codeplay.com>
Date: Thu, 21 Mar 2024 16:49:59 +0000
Subject: [PATCH 2/2] Drop unused `i1Ty` variable

---
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index b45be8b6bd8a4c..f72be9754d8869 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -266,7 +266,6 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
           op, "only applicable to default rounding mode.");
     }
 
-    Type i1Ty = b.getI1Type();
     Type i16Ty = b.getI16Type();
     Type i32Ty = b.getI32Type();
     Type f32Ty = b.getF32Type();



More information about the Mlir-commits mailing list