[Mlir-commits] [mlir] [mlir][spirv] Add definitions and (de)serialization for FPRoundingMode (PR #101546)
Andrea Faulds
llvmlistbot at llvm.org
Thu Aug 1 12:04:13 PDT 2024
https://github.com/andfau-amd created https://github.com/llvm/llvm-project/pull/101546
None
>From 47b94197a7241de785c4aa80ecac236819fa1740 Mon Sep 17 00:00:00 2001
From: Andrea Faulds <andrea.faulds at amd.com>
Date: Thu, 1 Aug 2024 21:00:14 +0200
Subject: [PATCH] [mlir][spirv] Add definitions and (de)serialization for
FPRoundingMode
---
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td | 13 +++++++++++++
.../Target/SPIRV/Deserialization/Deserializer.cpp | 9 +++++++++
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp | 10 ++++++++++
mlir/test/Target/SPIRV/decorations.mlir | 10 ++++++++++
4 files changed, 42 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 6ec97e17c5dcc..b38978272c5bd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3249,6 +3249,19 @@ def SPIRV_FC_OptNoneINTEL : I32BitEnumAttrCaseBit<"OptNoneINTEL", 16> {
];
}
+def SPIRV_FPRM_RTE : I32EnumAttrCase<"RTE", 0>;
+def SPIRV_FPRM_RTZ : I32EnumAttrCase<"RTZ", 1>;
+def SPIRV_FPRM_RTP : I32EnumAttrCase<"RTP", 2>;
+def SPIRV_FPRM_RTN : I32EnumAttrCase<"RTN", 3>;
+
+// TODO: Enforce the SPIR-V validation rule that, with the Shader capability,
+// FPRoundingMode is only permitted on a value stored to certain storage
+// classes. (The OpenCL environment has different FPRoundingMode rules.)
+def SPIRV_FPRoundingModeAttr :
+ SPIRV_I32EnumAttr<"FPRoundingMode", "valid SPIR-V FPRoundingMode", "fp_rounding_mode", [
+ SPIRV_FPRM_RTE, SPIRV_FPRM_RTZ, SPIRV_FPRM_RTP, SPIRV_FPRM_RTN
+ ]>;
+
def SPIRV_FunctionControlAttr :
SPIRV_BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", "function_control", [
SPIRV_FC_None, SPIRV_FC_Inline, SPIRV_FC_DontInline, SPIRV_FC_Pure, SPIRV_FC_Const,
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d7a308548cf4d..12980879b20ab 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -250,6 +250,15 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
static_cast<FPFastMathMode>(words[2])));
break;
+ case spirv::Decoration::FPRoundingMode:
+ if (words.size() != 3) {
+ return emitError(unknownLoc, "OpDecorate with ")
+ << decorationName << " needs a single integer literal";
+ }
+ decorations[words[0]].set(
+ symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
+ static_cast<FPRoundingMode>(words[2])));
+ break;
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Binding:
if (words.size() != 3) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 4c4fef177317e..714a3edfb5657 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -214,6 +214,9 @@ static std::string getDecorationName(StringRef attrName) {
// expected FPFastMathMode.
if (attrName == "fp_fast_math_mode")
return "FPFastMathMode";
+ // Likewise, "fp_rounding_mode" must map to FPRoundingMode, not FpRoundingMode.
+ if (attrName == "fp_rounding_mode")
+ return "FPRoundingMode";
return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
}
@@ -242,6 +245,13 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
}
return emitError(loc, "expected FPFastMathModeAttr attribute for ")
<< stringifyDecoration(decoration);
+ case spirv::Decoration::FPRoundingMode:
+ if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
+ args.push_back(static_cast<uint32_t>(intAttr.getValue()));
+ break;
+ }
+ return emitError(loc, "expected FPRoundingModeAttr attribute for ")
+ << stringifyDecoration(decoration);
case spirv::Decoration::Binding:
case spirv::Decoration::DescriptorSet:
case spirv::Decoration::Location:
diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir
index 195773735431e..0a29290b6a6fa 100644
--- a/mlir/test/Target/SPIRV/decorations.mlir
+++ b/mlir/test/Target/SPIRV/decorations.mlir
@@ -97,3 +97,13 @@ spirv.func @fmul_decorations(%arg: f32) -> f32 "None" {
spirv.ReturnValue %0 : f32
}
}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel, Float16], []> {
+spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" {
+ // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
+ %0 = spirv.FConvert %arg {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
+ spirv.ReturnValue %0 : f16
+}
+}
More information about the Mlir-commits
mailing list