[Mlir-commits] [mlir] [mlir][arith][spirv] Convert arith.truncf rounding mode to SPIR-V (PR #101547)

Andrea Faulds llvmlistbot at llvm.org
Fri Aug 2 05:40:08 PDT 2024


https://github.com/andfau-amd updated https://github.com/llvm/llvm-project/pull/101547

>From b6c61487097a52220c7734b0eab6223287cd73e0 Mon Sep 17 00:00:00 2001
From: Andrea Faulds <andrea.faulds at amd.com>
Date: Fri, 2 Aug 2024 13:57:02 +0200
Subject: [PATCH] [mlir][arith][spirv] Convert arith.truncf rounding mode to
 SPIR-V

Resolves #87050.
---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 36 ++++++++++++++++---
 .../arith-to-spirv-unsupported.mlir           | 17 +++++++++
 .../ArithToSPIRV/arith-to-spirv.mlir          | 19 ++++++++--
 3 files changed, 65 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 4c3237b24b786..e6c01f063e8b8 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -807,6 +807,25 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
 // TypeCastingOp
 //===----------------------------------------------------------------------===//
 
+static std::optional<spirv::FPRoundingMode>
+convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
+  switch (roundingMode) {
+  case arith::RoundingMode::downward:
+    return spirv::FPRoundingMode::RTN;
+  case arith::RoundingMode::to_nearest_even:
+    return spirv::FPRoundingMode::RTE;
+  case arith::RoundingMode::toward_zero:
+    return spirv::FPRoundingMode::RTZ;
+  case arith::RoundingMode::upward:
+    return spirv::FPRoundingMode::RTP;
+  case arith::RoundingMode::to_nearest_away:
+    // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
+    // (as of SPIR-V 1.6)
+    return std::nullopt;
+  }
+  llvm_unreachable("Unhandled rounding mode");
+}
+
 /// Converts type-casting standard operations to SPIR-V operations.
 template <typename Op, typename SPIRVOp>
 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
@@ -829,15 +848,22 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
       // Then we can just erase this operation by forwarding its operand.
       rewriter.replaceOp(op, adaptor.getOperands().front());
     } else {
-      rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
-                                                    adaptor.getOperands());
+      auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
+          op, dstType, adaptor.getOperands());
       if (auto roundingModeOp =
               dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
         if (arith::RoundingModeAttr roundingMode =
                 roundingModeOp.getRoundingModeAttr()) {
-          // TODO: Perform rounding mode attribute conversion and attach to new
-          // operation when defined in the dialect.
-          return failure();
+          if (auto rm =
+                  convertArithRoundingModeToSPIRV(roundingMode.getValue())) {
+            newOp->setAttr(
+                getDecorationString(spirv::Decoration::FPRoundingMode),
+                spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
+          } else {
+            return rewriter.notifyMatchFailure(
+                op->getLoc(),
+                llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
+          }
         }
       }
     }
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 2512254b443db..24a0bab352c34 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -1,5 +1,22 @@
 // RUN: mlir-opt -split-input-file -convert-arith-to-spirv -verify-diagnostics %s
 
+///===----------------------------------------------------------------------===//
+// Cast ops
+//===----------------------------------------------------------------------===//
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Float16, Kernel], []>, #spirv.resource_limits<>>
+} {
+
+func.func @experimental_constrained_fptrunc(%arg0 : f32) {
+  // expected-error@+1 {{failed to legalize operation 'arith.truncf'}}
+  %3 = arith.truncf %arg0 to_nearest_away : f32 to f16
+  return
+}
+
+} // end module
+
 ///===----------------------------------------------------------------------===//
 // Binary ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index beb2c8d2d242c..1abe0fd2ec468 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -221,7 +221,7 @@ func.func @one_elem_vector(%arg0: vector<1xi32>) {
 // -----
 
 //===----------------------------------------------------------------------===//
-// std bit ops
+// Bit ops
 //===----------------------------------------------------------------------===//
 
 module attributes {
@@ -653,7 +653,7 @@ func.func @corner_cases() {
 // -----
 
 //===----------------------------------------------------------------------===//
-// std cast ops
+// Cast ops
 //===----------------------------------------------------------------------===//
 
 module attributes {
@@ -754,6 +754,21 @@ func.func @fptrunc2(%arg0: f32) -> f16 {
   return %0 : f16
 }
 
+
+// CHECK-LABEL: @experimental_constrained_fptrunc
+func.func @experimental_constrained_fptrunc(%arg0 : f64) {
+  // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTE>} : f64 to f32
+  %0 = arith.truncf %arg0 to_nearest_even : f64 to f32
+  // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f64 to f32
+  %1 = arith.truncf %arg0 downward : f64 to f32
+  // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTP>} : f64 to f32
+  %2 = arith.truncf %arg0 upward : f64 to f32
+  // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTZ>} : f64 to f32
+  %3 = arith.truncf %arg0 toward_zero : f64 to f32
+  return
+}
+
+
 // CHECK-LABEL: @sitofp1
 func.func @sitofp1(%arg0 : i32) -> f32 {
   // CHECK: spirv.ConvertSToF %{{.*}} : i32 to f32



More information about the Mlir-commits mailing list