[Mlir-commits] [mlir] [mlir][arith][spirv] Convert arith.truncf rounding mode to SPIR-V (PR #101547)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 1 12:06:23 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andrea Faulds (andfau-amd)

<details>
<summary>Changes</summary>

Resolves #87050.

---
Full diff: https://github.com/llvm/llvm-project/pull/101547.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+13) 
- (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+29-5) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+9) 
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+10) 
- (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir (+15) 
- (modified) mlir/test/Target/SPIRV/decorations.mlir (+10) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 6ec97e17c5dcc..b38978272c5bd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3249,6 +3249,19 @@ def SPIRV_FC_OptNoneINTEL : I32BitEnumAttrCaseBit<"OptNoneINTEL", 16> {
   ];
 }
 
+def SPIRV_FPRM_RTE : I32EnumAttrCase<"RTE", 0>;
+def SPIRV_FPRM_RTZ : I32EnumAttrCase<"RTZ", 1>;
+def SPIRV_FPRM_RTP : I32EnumAttrCase<"RTP", 2>;
+def SPIRV_FPRM_RTN : I32EnumAttrCase<"RTN", 3>;
+
+// TODO: Enforce SPIR-V spec validation rule for Shader capability: only permit
+//       FPRoundingMode on a value stored to certain storage classes?
+//       (The OpenCL environment also has FPRoundingMode rules, but different.)
+def SPIRV_FPRoundingModeAttr :
+    SPIRV_I32EnumAttr<"FPRoundingMode", "valid SPIR-V FPRoundingMode", "fp_rounding_mode", [
+      SPIRV_FPRM_RTE, SPIRV_FPRM_RTZ, SPIRV_FPRM_RTP, SPIRV_FPRM_RTN
+    ]>;
+
 def SPIRV_FunctionControlAttr :
     SPIRV_BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", "function_control", [
       SPIRV_FC_None, SPIRV_FC_Inline, SPIRV_FC_DontInline, SPIRV_FC_Pure, SPIRV_FC_Const,
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 4c3237b24b786..f2b9a18f60eca 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -807,6 +807,25 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
 // TypeCastingOp
 //===----------------------------------------------------------------------===//
 
+static std::optional<spirv::FPRoundingMode>
+convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
+  switch (roundingMode) {
+  case arith::RoundingMode::downward:
+    return spirv::FPRoundingMode::RTN;
+  case arith::RoundingMode::to_nearest_even:
+    return spirv::FPRoundingMode::RTE;
+  case arith::RoundingMode::toward_zero:
+    return spirv::FPRoundingMode::RTZ;
+  case arith::RoundingMode::upward:
+    return spirv::FPRoundingMode::RTP;
+  case arith::RoundingMode::to_nearest_away:
+    // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
+    // (as of SPIR-V 1.6)
+    return {};
+  }
+  llvm_unreachable("Unhandled rounding mode");
+}
+
 /// Converts type-casting standard operations to SPIR-V operations.
 template <typename Op, typename SPIRVOp>
 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
@@ -829,15 +848,20 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
       // Then we can just erase this operation by forwarding its operand.
       rewriter.replaceOp(op, adaptor.getOperands().front());
     } else {
-      rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
-                                                    adaptor.getOperands());
+      auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
+          op, dstType, adaptor.getOperands());
       if (auto roundingModeOp =
               dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
         if (arith::RoundingModeAttr roundingMode =
                 roundingModeOp.getRoundingModeAttr()) {
-          // TODO: Perform rounding mode attribute conversion and attach to new
-          // operation when defined in the dialect.
-          return failure();
+          if (auto rm =
+                  convertArithRoundingModeToSPIRV(roundingMode.getValue())) {
+            newOp->setAttr(
+                getDecorationString(spirv::Decoration::FPRoundingMode),
+                spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
+          } else {
+            return failure(); // unsupported rounding mode
+          }
         }
       }
     }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d7a308548cf4d..12980879b20ab 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -250,6 +250,15 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
         symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
                                         static_cast<FPFastMathMode>(words[2])));
     break;
+  case spirv::Decoration::FPRoundingMode:
+    if (words.size() != 3) {
+      return emitError(unknownLoc, "OpDecorate with ")
+             << decorationName << " needs a single integer literal";
+    }
+    decorations[words[0]].set(
+        symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
+                                        static_cast<FPRoundingMode>(words[2])));
+    break;
   case spirv::Decoration::DescriptorSet:
   case spirv::Decoration::Binding:
     if (words.size() != 3) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 4c4fef177317e..714a3edfb5657 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -214,6 +214,9 @@ static std::string getDecorationName(StringRef attrName) {
   // expected FPFastMathMode.
   if (attrName == "fp_fast_math_mode")
     return "FPFastMathMode";
+  // Likewise: camel-casing would give FpRoundingMode, not FPRoundingMode.
+  if (attrName == "fp_rounding_mode")
+    return "FPRoundingMode";
 
   return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
 }
@@ -242,6 +245,13 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
     }
     return emitError(loc, "expected FPFastMathModeAttr attribute for ")
            << stringifyDecoration(decoration);
+  case spirv::Decoration::FPRoundingMode:
+    if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
+      args.push_back(static_cast<uint32_t>(intAttr.getValue()));
+      break;
+    }
+    return emitError(loc, "expected FPRoundingModeAttr attribute for ")
+           << stringifyDecoration(decoration);
   case spirv::Decoration::Binding:
   case spirv::Decoration::DescriptorSet:
   case spirv::Decoration::Location:
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index beb2c8d2d242c..4c5b7664bb1aa 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -754,6 +754,21 @@ func.func @fptrunc2(%arg0: f32) -> f16 {
   return %0 : f16
 }
 
+
+// CHECK-LABEL: @experimental_constrained_fptrunc
+func.func @experimental_constrained_fptrunc(%arg0 : f64) {
+  // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTE>} : f64 to f32
+  %0 = arith.truncf %arg0 to_nearest_even : f64 to f32
+  // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f64 to f32
+  %1 = arith.truncf %arg0 downward : f64 to f32
+  // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTP>} : f64 to f32
+  %2 = arith.truncf %arg0 upward : f64 to f32
+  // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTZ>} : f64 to f32
+  %3 = arith.truncf %arg0 toward_zero : f64 to f32
+  return
+}
+
+
 // CHECK-LABEL: @sitofp1
 func.func @sitofp1(%arg0 : i32) -> f32 {
   // CHECK: spirv.ConvertSToF %{{.*}} : i32 to f32
diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir
index 195773735431e..0a29290b6a6fa 100644
--- a/mlir/test/Target/SPIRV/decorations.mlir
+++ b/mlir/test/Target/SPIRV/decorations.mlir
@@ -97,3 +97,13 @@ spirv.func @fmul_decorations(%arg: f32) -> f32 "None" {
   spirv.ReturnValue %0 : f32
 }
 }
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel, Float16], []> {
+spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" {
+  // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
+  %0 = spirv.FConvert %arg {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
+  spirv.ReturnValue %0 : f16
+}
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/101547


More information about the Mlir-commits mailing list