[Mlir-commits] [mlir] [mlir][spirv] Fix LowerABIAttributesPass to generate EntryPoints for SPV1.4 (PR #118994)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 16 08:14:51 PST 2024


https://github.com/fabrizio-indirli updated https://github.com/llvm/llvm-project/pull/118994

>From 533609e824fabafeaa7f647966152c6dcd4deb41 Mon Sep 17 00:00:00 2001
From: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
Date: Fri, 6 Dec 2024 15:31:02 +0000
Subject: [PATCH] [mlir][spirv] Fix LowerABIAttributesPass to generate
 EntryPoints for SPV1.4
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Extend SPIRV::LowerABIAttributesPass to detect when the target environment
  uses SPIR-V version >= 1.4. In that case, add all of the interface variables
  referenced by the entry function (of any storage class) to the generated
  spirv.EntryPoint op, as required by the specification of OpEntryPoint:
  "Before version 1.4, the interface’s storage classes are limited to the Input and
   Output storage classes. Starting with version 1.4, the interface’s storage classes
   are all storage classes used in declaring all global variables referenced by
   the entry point’s call tree."
- Fix: generate the replacement ops (spirv.AddressOf and spirv.AccessChain) in
  the order in which the associated arguments appear in the function signature.

Signed-off-by: Fabrizio Indirli <Fabrizio.Indirli at arm.com>
---
 .../Transforms/LowerABIAttributesPass.cpp     | 32 ++++++++++-------
 .../SPIRV/Transforms/abi-interface.mlir       | 29 ++++++++++++++-
 .../SPIRV/Transforms/abi-load-store.mlir      | 36 +++++++++----------
 3 files changed, 65 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 2024a2e5279ffc..15ddcf58a3d70d 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -85,6 +85,9 @@ getInterfaceVariables(spirv::FuncOp funcOp,
   if (!module) {
     return failure();
   }
+  spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
+  spirv::TargetEnv targetEnv(targetEnvAttr);
+
   SetVector<Operation *> interfaceVarSet;
 
   // TODO: This should in reality traverse the entry function
@@ -93,18 +96,18 @@ getInterfaceVariables(spirv::FuncOp funcOp,
   funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
     auto var =
         module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
-    // TODO: Per SPIR-V spec: "Before version 1.4, the interface’s
+    // Per SPIR-V spec: "Before version 1.4, the interface's
     // storage classes are limited to the Input and Output storage classes.
-    // Starting with version 1.4, the interface’s storage classes are all
+    // Starting with version 1.4, the interface's storage classes are all
     // storage classes used in declaring all global variables referenced by the
-    // entry point’s call tree." We should consider the target environment here.
-    switch (cast<spirv::PointerType>(var.getType()).getStorageClass()) {
-    case spirv::StorageClass::Input:
-    case spirv::StorageClass::Output:
+    // entry point’s call tree."
+    const spirv::StorageClass storageClass =
+        cast<spirv::PointerType>(var.getType()).getStorageClass();
+    if ((targetEnvAttr && targetEnv.getVersion() >= spirv::Version::V_1_4) ||
+        (llvm::is_contained(
+            {spirv::StorageClass::Input, spirv::StorageClass::Output},
+            storageClass))) {
       interfaceVarSet.insert(var.getOperation());
-      break;
-    default:
-      break;
     }
   });
   for (auto &var : interfaceVarSet) {
@@ -124,6 +127,9 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
     return failure();
   }
 
+  spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
+  spirv::TargetEnv targetEnv(targetEnvAttr);
+
   OpBuilder::InsertionGuard moduleInsertionGuard(builder);
   auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
   builder.setInsertionPointToEnd(spirvModule.getBody());
@@ -135,8 +141,6 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
     return failure();
   }
 
-  spirv::TargetEnvAttr targetEnvAttr = spirv::lookupTargetEnv(funcOp);
-  spirv::TargetEnv targetEnv(targetEnvAttr);
   FailureOr<spirv::ExecutionModel> executionModel =
       spirv::getExecutionModel(targetEnvAttr);
   if (failed(executionModel))
@@ -234,6 +238,10 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
   auto indexType = typeConverter.getIndexType();
 
   auto attrName = spirv::getInterfaceVarABIAttrName();
+
+  OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
+  rewriter.setInsertionPointToStart(&funcOp.front());
+
   for (const auto &argType :
        llvm::enumerate(funcOp.getFunctionType().getInputs())) {
     auto abiInfo = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
@@ -250,8 +258,6 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
     if (!var)
       return failure();
 
-    OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
-    rewriter.setInsertionPointToStart(&funcOp.front());
     // Insert spirv::AddressOf and spirv::AccessChain operations.
     Value replacement =
         rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index 77e92da3504c62..bd51a07843652d 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -19,11 +19,11 @@ spirv.module Logical GLSL450 {
     %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32>)>, StorageBuffer>
            {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None"
   attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1], subgroup_size = 64>} {
-    // CHECK: [[ARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
     // CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
     // CHECK: [[CONST0:%.*]] = spirv.Constant 0 : i32
     // CHECK: [[ARG0PTR:%.*]] = spirv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
     // CHECK: [[ARG0:%.*]] = spirv.Load "StorageBuffer" [[ARG0PTR]]
+    // CHECK: [[ARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
     // CHECK: spirv.Return
     spirv.Return
   }
@@ -39,3 +39,30 @@ module {
 // expected-error at +1 {{'spirv.module' op missing SPIR-V target env attribute}}
 spirv.module Logical GLSL450 {}
 } // end module
+
+// -----
+
+// CHECK-LABEL: spirv.module
+// Test case with SPIRV version 1.4: all the interface's storage variables are passed to OpEntryPoint
+spirv.module Logical GLSL450 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+  //  CHECK-DAG:    spirv.GlobalVariable [[VAR0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>
+  //  CHECK-DAG:    spirv.GlobalVariable [[VAR1:@.*]] bind(0, 1) : !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32, stride=4> [0])>, StorageBuffer>
+  //      CHECK:    spirv.func [[FN:@.*]]()
+  // CHECK-SAME:      #spirv.entry_point_abi<subgroup_size = 64>
+  spirv.func @kernel(
+    %arg0: f32
+           {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0), StorageBuffer>},
+    %arg1: !spirv.ptr<!spirv.struct<(!spirv.array<12 x f32>)>, StorageBuffer>
+           {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) "None"
+  attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1], subgroup_size = 64>} {
+    // CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
+    // CHECK: [[CONST0:%.*]] = spirv.Constant 0 : i32
+    // CHECK: [[ARG0PTR:%.*]] = spirv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
+    // CHECK: [[ARG0:%.*]] = spirv.Load "StorageBuffer" [[ARG0PTR]]
+    // CHECK: [[ARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
+    // CHECK: spirv.Return
+    spirv.Return
+  }
+  // CHECK: spirv.EntryPoint "GLCompute" [[FN]], [[VAR0]], [[VAR1]]
+  // CHECK: spirv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
+} // end spirv.module
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
index 4fdb6799c97fae..54e08ff3430075 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
@@ -39,28 +39,28 @@ spirv.module Logical GLSL450 {
     %arg6: i32
     {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 6), StorageBuffer>}) "None"
   attributes  {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1]>} {
-    // CHECK: [[ADDRESSARG6:%.*]] = spirv.mlir.addressof [[VAR6]]
-    // CHECK: [[CONST6:%.*]] = spirv.Constant 0 : i32
-    // CHECK: [[ARG6PTR:%.*]] = spirv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
-    // CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG6PTR]]
-    // CHECK: [[ADDRESSARG5:%.*]] = spirv.mlir.addressof [[VAR5]]
-    // CHECK: [[CONST5:%.*]] = spirv.Constant 0 : i32
-    // CHECK: [[ARG5PTR:%.*]] = spirv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
-    // CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG5PTR]]
-    // CHECK: [[ADDRESSARG4:%.*]] = spirv.mlir.addressof [[VAR4]]
-    // CHECK: [[CONST4:%.*]] = spirv.Constant 0 : i32
-    // CHECK: [[ARG4PTR:%.*]] = spirv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
-    // CHECK: [[ARG4:%.*]] = spirv.Load "StorageBuffer" [[ARG4PTR]]
+    // CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
+    // CHECK: [[ARG0:%.*]] = spirv.Bitcast [[ADDRESSARG0]]
+    // CHECK: [[ADDRESSARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
+    // CHECK: [[ARG1:%.*]] = spirv.Bitcast [[ADDRESSARG1]]
+    // CHECK: [[ADDRESSARG2:%.*]] = spirv.mlir.addressof [[VAR2]]
+    // CHECK: [[ARG2:%.*]] = spirv.Bitcast [[ADDRESSARG2]]
     // CHECK: [[ADDRESSARG3:%.*]] = spirv.mlir.addressof [[VAR3]]
     // CHECK: [[CONST3:%.*]] = spirv.Constant 0 : i32
     // CHECK: [[ARG3PTR:%.*]] = spirv.AccessChain [[ADDRESSARG3]]{{\[}}[[CONST3]]
     // CHECK: [[ARG3:%.*]] = spirv.Load "StorageBuffer" [[ARG3PTR]]
-    // CHECK: [[ADDRESSARG2:%.*]] = spirv.mlir.addressof [[VAR2]]
-    // CHECK: [[ARG2:%.*]] = spirv.Bitcast [[ADDRESSARG2]]
-    // CHECK: [[ADDRESSARG1:%.*]] = spirv.mlir.addressof [[VAR1]]
-    // CHECK: [[ARG1:%.*]] = spirv.Bitcast [[ADDRESSARG1]]
-    // CHECK: [[ADDRESSARG0:%.*]] = spirv.mlir.addressof [[VAR0]]
-    // CHECK: [[ARG0:%.*]] = spirv.Bitcast [[ADDRESSARG0]]
+    // CHECK: [[ADDRESSARG4:%.*]] = spirv.mlir.addressof [[VAR4]]
+    // CHECK: [[CONST4:%.*]] = spirv.Constant 0 : i32
+    // CHECK: [[ARG4PTR:%.*]] = spirv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
+    // CHECK: [[ARG4:%.*]] = spirv.Load "StorageBuffer" [[ARG4PTR]]
+    // CHECK: [[ADDRESSARG5:%.*]] = spirv.mlir.addressof [[VAR5]]
+    // CHECK: [[CONST5:%.*]] = spirv.Constant 0 : i32
+    // CHECK: [[ARG5PTR:%.*]] = spirv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
+    // CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG5PTR]]
+    // CHECK: [[ADDRESSARG6:%.*]] = spirv.mlir.addressof [[VAR6]]
+    // CHECK: [[CONST6:%.*]] = spirv.Constant 0 : i32
+    // CHECK: [[ARG6PTR:%.*]] = spirv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
+    // CHECK: {{%.*}} = spirv.Load "StorageBuffer" [[ARG6PTR]] 
     %0 = spirv.mlir.addressof @__builtin_var_WorkgroupId__ : !spirv.ptr<vector<3xi32>, Input>
     %1 = spirv.Load "Input" %0 : vector<3xi32>
     %2 = spirv.CompositeExtract %1[0 : i32] : vector<3xi32>



More information about the Mlir-commits mailing list