[Mlir-commits] [mlir] [mlir][spirv] Add target width to SPIR-V ABI (PR #88555)

Hsiangkai Wang llvmlistbot at llvm.org
Fri Apr 12 11:43:43 PDT 2024


https://github.com/Hsiangkai created https://github.com/llvm/llvm-project/pull/88555

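Some execution modes require a target width as an extra operand; SignedZeroInfNanPreserve is one of them. This patch adds `target_width` as an optional parameter of the SPIR-V entry point ABI attribute (`spirv.entry_point_abi`). When ABI attributes are lowered, a `target_width` is consumed by emitting the SignedZeroInfNanPreserve execution mode with that width, provided the target environment allows the mode's capabilities.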

>From c2be7ee825404ba3f7536fa4680d98e7e86d476f Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Fri, 12 Apr 2024 15:01:20 +0100
Subject: [PATCH] [mlir][spirv] Add target width to SPIR-V ABI

Some execution modes require a target width as an extra operand;
SignedZeroInfNanPreserve is one of them. This patch adds `target_width`
as an optional parameter of the SPIR-V entry point ABI attribute.
---
 .../mlir/Dialect/SPIRV/IR/SPIRVAttributes.td  |  3 ++-
 .../mlir/Dialect/SPIRV/IR/TargetAndABI.h      |  3 ++-
 mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp    | 11 +++++-----
 .../Transforms/LowerABIAttributesPass.cpp     | 20 ++++++++++++++++---
 .../Conversion/GPUToSPIRV/entry-point.mlir    |  4 ++++
 .../lib/Dialect/SPIRV/TestEntryPointAbi.cpp   | 10 +++++++++-
 6 files changed, 39 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
index 74d36445e31138..89c2c4b447f926 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
@@ -35,7 +35,8 @@ class SPIRV_Attr<string attrName, string attrMnemonic>
 def SPIRV_EntryPointABIAttr : SPIRV_Attr<"EntryPointABI", "entry_point_abi"> {
   let parameters = (ins
     OptionalParameter<"DenseI32ArrayAttr">:$workgroup_size,
-    OptionalParameter<"std::optional<int>">:$subgroup_size
+    OptionalParameter<"std::optional<int>">:$subgroup_size,
+    OptionalParameter<"std::optional<int>">:$target_width
   );
   let assemblyFormat = "`<` struct(params) `>`";
 }
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h
index c35a8c26c2bc9b..d651549574b989 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/TargetAndABI.h
@@ -89,7 +89,8 @@ StringRef getEntryPointABIAttrName();
 /// Gets the EntryPointABIAttr given its fields.
 EntryPointABIAttr getEntryPointABIAttr(MLIRContext *context,
                                        ArrayRef<int32_t> workgroupSize = {},
-                                       std::optional<int> subgroupSize = {});
+                                       std::optional<int> subgroupSize = {},
+                                       std::optional<int> targetWidth = {});
 
 /// Queries the entry point ABI on the nearest function-like op containing the
 /// given `op`. Returns null attribute if not found.
diff --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
index 5b7c0a59ba4200..bbc318e17300ad 100644
--- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
@@ -120,17 +120,16 @@ bool spirv::needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr) {
 
 StringRef spirv::getEntryPointABIAttrName() { return "spirv.entry_point_abi"; }
 
-spirv::EntryPointABIAttr
-spirv::getEntryPointABIAttr(MLIRContext *context,
-                            ArrayRef<int32_t> workgroupSize,
-                            std::optional<int> subgroupSize) {
+spirv::EntryPointABIAttr spirv::getEntryPointABIAttr(
+    MLIRContext *context, ArrayRef<int32_t> workgroupSize,
+    std::optional<int> subgroupSize, std::optional<int> targetWidth) {
   DenseI32ArrayAttr workgroupSizeAttr;
   if (!workgroupSize.empty()) {
     assert(workgroupSize.size() == 3);
     workgroupSizeAttr = DenseI32ArrayAttr::get(context, workgroupSize);
   }
-  return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr,
-                                       subgroupSize);
+  return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr, subgroupSize,
+                                       targetWidth);
 }
 
 spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) {
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 6150b5ee17851d..2024a2e5279ffc 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -157,7 +157,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
       // Erase workgroup size.
       entryPointAttr = spirv::EntryPointABIAttr::get(
           entryPointAttr.getContext(), DenseI32ArrayAttr(),
-          entryPointAttr.getSubgroupSize());
+          entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth());
     }
   }
   if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) {
@@ -170,10 +170,24 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
       // Erase subgroup size.
       entryPointAttr = spirv::EntryPointABIAttr::get(
           entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
-          std::nullopt);
+          std::nullopt, entryPointAttr.getTargetWidth());
     }
   }
-  if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize())
+  if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) {
+    std::optional<ArrayRef<spirv::Capability>> caps =
+        spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
+    if (!caps || targetEnv.allows(*caps)) {
+      builder.create<spirv::ExecutionModeOp>(
+          funcOp.getLoc(), funcOp,
+          spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
+      // Erase target width.
+      entryPointAttr = spirv::EntryPointABIAttr::get(
+          entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
+          entryPointAttr.getSubgroupSize(), std::nullopt);
+    }
+  }
+  if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() ||
+      entryPointAttr.getTargetWidth())
     funcOp->setAttr(entryPointAttrName, entryPointAttr);
   else
     funcOp->removeAttr(entryPointAttrName);
diff --git a/mlir/test/Conversion/GPUToSPIRV/entry-point.mlir b/mlir/test/Conversion/GPUToSPIRV/entry-point.mlir
index 99369d11a4ba39..ff2c546faf66e8 100644
--- a/mlir/test/Conversion/GPUToSPIRV/entry-point.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/entry-point.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt -test-spirv-entry-point-abi %s | FileCheck %s -check-prefix=DEFAULT
 // RUN: mlir-opt -test-spirv-entry-point-abi="workgroup-size=32" %s | FileCheck %s -check-prefix=WG32
+// RUN: mlir-opt -test-spirv-entry-point-abi="target-width=32" %s | FileCheck %s -check-prefix=TW32
 
 //      DEFAULT: gpu.func @foo()
 // DEFAULT-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>
@@ -7,6 +8,9 @@
 //      WG32: gpu.func @foo()
 // WG32-SAME:  spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1]>
 
+//      TW32: gpu.func @foo()
+// TW32-SAME:  spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1], target_width = 32>
+
 gpu.module @kernels {
   gpu.func @foo() kernel {
     gpu.return
diff --git a/mlir/test/lib/Dialect/SPIRV/TestEntryPointAbi.cpp b/mlir/test/lib/Dialect/SPIRV/TestEntryPointAbi.cpp
index 47b06e4531152d..b16dd86cc306b3 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestEntryPointAbi.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestEntryPointAbi.cpp
@@ -45,6 +45,11 @@ struct TestSpirvEntryPointABIPass
           "Workgroup size to use for all gpu.func kernels in the module, "
           "specified with x-dimension first, y-dimension next and z-dimension "
           "last. Unspecified dimensions will be set to 1")};
+  Pass::Option<int> targetWidth{
+      *this, "target-width",
+      llvm::cl::desc(
+          "Specify the component width of floating-point instructions"),
+      llvm::cl::init(0)};
 };
 } // namespace
 
@@ -60,7 +65,10 @@ void TestSpirvEntryPointABIPass::runOnOperation() {
                                              workgroupSize.end());
     workgroupSizeVec.resize(3, 1);
     gpuFunc->setAttr(attrName,
-                     spirv::getEntryPointABIAttr(context, workgroupSizeVec));
+                     spirv::getEntryPointABIAttr(
+                         context, workgroupSizeVec, {},
+                         (targetWidth == 0) ? std::nullopt
+                                            : std::optional<int>(targetWidth)));
   }
 }
 



More information about the Mlir-commits mailing list