[Mlir-commits] [mlir] [mlir][spirv] Add some op decorations (PR #72809)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 19 14:55:27 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

<details>
<summary>Changes</summary>

Add the NoSignedWrap, NoUnsignedWrap, and FPFastMathMode decorations (attribute definition, serialization, and deserialization).

---
Full diff: https://github.com/llvm/llvm-project/pull/72809.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+24) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+12-1) 
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+20-3) 
- (modified) mlir/test/Target/SPIRV/decorations.mlir (+21) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 1013cbc8ca562b7..8eaf2a98a58560e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4742,4 +4742,28 @@ class SPIRV_NvVendorOp<string mnemonic, list<Trait> traits = []> :
   SPIRV_VendorOp<mnemonic, "NV", traits> {
 }
 
+def SPIRV_FPFMM_None         : I32BitEnumAttrCaseNone<"None">;
+def SPIRV_FPFMM_NotNaN       : I32BitEnumAttrCaseBit<"NotNaN", 0>;
+def SPIRV_FPFMM_NotInf       : I32BitEnumAttrCaseBit<"NotInf", 1>;
+def SPIRV_FPFMM_NSZ          : I32BitEnumAttrCaseBit<"NSZ", 2>;
+def SPIRV_FPFMM_AllowRecip   : I32BitEnumAttrCaseBit<"AllowRecip", 3>;
+def SPIRV_FPFMM_Fast         : I32BitEnumAttrCaseBit<"Fast", 4>;
+def SPIRV_FPFMM_AllowContractFastINTEL : I32BitEnumAttrCaseBit<"AllowContractFastINTEL", 16> {
+  list<Availability> availability = [
+    Capability<[SPIRV_C_FPFastMathModeINTEL]>
+  ];
+}
+def SPIRV_FPFMM_AllowReassocINTEL : I32BitEnumAttrCaseBit<"AllowReassocINTEL", 17> {
+  list<Availability> availability = [
+    Capability<[SPIRV_C_FPFastMathModeINTEL]>
+  ];
+}
+// Operand of the FPFastMathMode decoration, spelled #spirv.fastmath_mode<...>.
+def SPIRV_FPFastMathModeAttr :
+    SPIRV_BitEnumAttr<"FPFastMathMode", "Indicates a floating-point fast math flag", "fastmath_mode", [
+      SPIRV_FPFMM_None, SPIRV_FPFMM_NotNaN, SPIRV_FPFMM_NotInf, SPIRV_FPFMM_NSZ,
+      SPIRV_FPFMM_AllowRecip, SPIRV_FPFMM_Fast, SPIRV_FPFMM_AllowContractFastINTEL,
+      SPIRV_FPFMM_AllowReassocINTEL
+    ]>;
+
 #endif // MLIR_DIALECT_SPIRV_IR_BASE
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index ce8b3ab3894606c..89e2e7ad52fa7d1 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -242,6 +242,15 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
   auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
   auto symbol = opBuilder.getStringAttr(attrName);
   switch (static_cast<spirv::Decoration>(words[1])) {
+  case spirv::Decoration::FPFastMathMode: // {<id>, decoration, mask}
+    if (words.size() != 3) {
+      return emitError(unknownLoc, "OpDecorate with ")
+             << decorationName << " needs a single integer literal";
+    }
+    decorations[words[0]].set(
+        symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
+                                        static_cast<FPFastMathMode>(words[2])));
+    break;
   case spirv::Decoration::DescriptorSet:
   case spirv::Decoration::Binding:
     if (words.size() != 3) {
@@ -295,8 +304,10 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
   case spirv::Decoration::NonReadable:
   case spirv::Decoration::NonWritable:
   case spirv::Decoration::NoPerspective:
-  case spirv::Decoration::Restrict:
+  case spirv::Decoration::NoSignedWrap:
+  case spirv::Decoration::NoUnsignedWrap:
   case spirv::Decoration::RelaxedPrecision:
+  case spirv::Decoration::Restrict:
     if (words.size() != 2) {
       return emitError(unknownLoc, "OpDecoration with ")
              << decorationName << "needs a single target <id>";
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 22fcc4939317be9..9e9a16456cc1022 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -206,11 +206,19 @@ void Serializer::processMemoryModel() {
   encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
 }
 
+static std::string getDecorationName(StringRef attrName) {
+  // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
+  // expected FPFastMathMode.
+  if (attrName == "fp_fast_math_mode")
+    return "FPFastMathMode";
+  // Every other attribute name maps snake_case -> CamelCase directly.
+  return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
+}
+
 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
                                             NamedAttribute attr) {
   auto attrName = attr.getName().strref();
-  auto decorationName =
-      llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
+  auto decorationName = getDecorationName(attrName);
   auto decoration = spirv::symbolizeDecoration(decorationName);
   if (!decoration) {
     return emitError(
@@ -232,6 +240,13 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
     args.push_back(static_cast<uint32_t>(linkageType));
     break;
   }
+  case spirv::Decoration::FPFastMathMode:
+    if (auto modeAttr = dyn_cast<FPFastMathModeAttr>(attr.getValue())) {
+      args.push_back(static_cast<uint32_t>(modeAttr.getValue()));
+      break;
+    }
+    return emitError(loc, "expected FPFastMathModeAttr attribute for ")
+           << attrName;
   case spirv::Decoration::Binding:
   case spirv::Decoration::DescriptorSet:
   case spirv::Decoration::Location:
@@ -256,8 +271,10 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
   case spirv::Decoration::NonReadable:
   case spirv::Decoration::NonWritable:
   case spirv::Decoration::NoPerspective:
-  case spirv::Decoration::Restrict:
+  case spirv::Decoration::NoSignedWrap:
+  case spirv::Decoration::NoUnsignedWrap:
   case spirv::Decoration::RelaxedPrecision:
+  case spirv::Decoration::Restrict:
     // For unit attributes, the args list has no values so we do nothing
     if (auto unitAttr = dyn_cast<UnitAttr>(attr.getValue()))
       break;
diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir
index aadf64c340b3445..04cb059f931863d 100644
--- a/mlir/test/Target/SPIRV/decorations.mlir
+++ b/mlir/test/Target/SPIRV/decorations.mlir
@@ -55,6 +55,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   // CHECK: relaxed_precision
   spirv.GlobalVariable @var {location = 0 : i32, relaxed_precision} : !spirv.ptr<vector<4xf32>, Output>
 }
+
 // -----
 
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
@@ -66,3 +67,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
     >
   } : !spirv.ptr<f32, Private>
 }
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
+spirv.func @iadd_scalar(%arg: i32) -> i32 "None" {
+  // CHECK: spirv.IAdd %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap}
+  %0 = spirv.IAdd %arg, %arg {no_signed_wrap, no_unsigned_wrap} : i32
+  spirv.ReturnValue %0 : i32
+}
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
+spirv.func @fadd_scalar(%arg: f32) -> f32 "None" {
+  // CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>}
+  %0 = spirv.FAdd %arg, %arg {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>} : f32
+  spirv.ReturnValue %0 : f32
+}
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/72809


More information about the Mlir-commits mailing list