[Mlir-commits] [llvm] [mlir] [mlir][spirv] Integrate `convert-to-spirv` into `mlir-vulkan-runner` (PR #106082)

Angel Zhang llvmlistbot at llvm.org
Mon Aug 26 13:05:19 PDT 2024


https://github.com/angelz913 updated https://github.com/llvm/llvm-project/pull/106082

>From 410aab8c0e613e1f3dd4b13d8e61a2dc7255f3f4 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Sat, 24 Aug 2024 03:10:36 +0000
Subject: [PATCH 1/5] [mlir][spirv] Integrate convert-to-spirv into
 mlir-vulkan-runner

---
 mlir/include/mlir/Conversion/Passes.td        |  5 +-
 .../Conversion/ConvertToSPIRV/CMakeLists.txt  |  1 +
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     | 81 +++++++++++++------
 .../mlir-vulkan-runner/mlir-vulkan-runner.cpp |  5 +-
 .../llvm-project-overlay/mlir/BUILD.bazel     |  1 +
 5 files changed, 66 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7bde9e490e4f4e..244827ade66be6 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -50,7 +50,10 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
     "Run function signature conversion to convert vector types">,
     Option<"runVectorUnrolling", "run-vector-unrolling", "bool",
     /*default=*/"true",
-    "Run vector unrolling to convert vector types in function bodies">
+    "Run vector unrolling to convert vector types in function bodies">,
+    Option<"runOnGPUModules", "run-on-gpu-modules", "bool",
+    /*default=*/"false",
+    "Clone and convert only the GPU modules for integration testing">
   ];
 }
 
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
index 863ef9603da385..124a4c453e75c5 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass
   MLIRArithToSPIRV
   MLIRArithTransforms
   MLIRFuncToSPIRV
+  MLIRGPUDialect
   MLIRGPUToSPIRV
   MLIRIndexToSPIRV
   MLIRIR
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 9e57b923ea6894..f624999decac48 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
@@ -40,6 +41,35 @@ using namespace mlir;
 
 namespace {
 
+/// Map MemRef memory spaces in \p op to OpenCL or Vulkan storage classes.
+void mapToMemRef(Operation *op, spirv::TargetEnvAttr targetAttr) {
+  spirv::TargetEnv targetEnv(targetAttr);
+  bool targetEnvSupportsKernelCapability =
+      targetEnv.allows(spirv::Capability::Kernel);
+  spirv::MemorySpaceToStorageClassMap memorySpaceMap =
+      targetEnvSupportsKernelCapability
+          ? spirv::mapMemorySpaceToOpenCLStorageClass
+          : spirv::mapMemorySpaceToVulkanStorageClass;
+  spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+  spirv::convertMemRefTypesAndAttrs(op, converter);
+}
+
+/// Populate \p patterns with every rewrite pattern this pass applies.
+void populateConvertToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
+                                    ScfToSPIRVContext &scfToSPIRVContext,
+                                    RewritePatternSet &patterns) {
+  arith::populateCeilFloorDivExpandOpsPatterns(patterns);
+  arith::populateArithToSPIRVPatterns(typeConverter, patterns);
+  populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
+  populateFuncToSPIRVPatterns(typeConverter, patterns);
+  populateGPUToSPIRVPatterns(typeConverter, patterns);
+  index::populateIndexToSPIRVPatterns(typeConverter, patterns);
+  populateMemRefToSPIRVPatterns(typeConverter, patterns);
+  populateVectorToSPIRVPatterns(typeConverter, patterns);
+  populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
+  ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
+}
+
 /// A pass to perform the SPIR-V conversion.
 struct ConvertToSPIRVPass final
     : impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
@@ -64,31 +94,32 @@ struct ConvertToSPIRVPass final
     RewritePatternSet patterns(context);
     ScfToSPIRVContext scfToSPIRVContext;
 
-    // Map MemRef memory space to SPIR-V storage class.
-    spirv::TargetEnv targetEnv(targetAttr);
-    bool targetEnvSupportsKernelCapability =
-        targetEnv.allows(spirv::Capability::Kernel);
-    spirv::MemorySpaceToStorageClassMap memorySpaceMap =
-        targetEnvSupportsKernelCapability
-            ? spirv::mapMemorySpaceToOpenCLStorageClass
-            : spirv::mapMemorySpaceToVulkanStorageClass;
-    spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
-    spirv::convertMemRefTypesAndAttrs(op, converter);
-
-    // Populate patterns for each dialect.
-    arith::populateCeilFloorDivExpandOpsPatterns(patterns);
-    arith::populateArithToSPIRVPatterns(typeConverter, patterns);
-    populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
-    populateFuncToSPIRVPatterns(typeConverter, patterns);
-    populateGPUToSPIRVPatterns(typeConverter, patterns);
-    index::populateIndexToSPIRVPatterns(typeConverter, patterns);
-    populateMemRefToSPIRVPatterns(typeConverter, patterns);
-    populateVectorToSPIRVPatterns(typeConverter, patterns);
-    populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
-    ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
-
-    if (failed(applyPartialConversion(op, *target, std::move(patterns))))
-      return signalPassFailure();
+    if (runOnGPUModules) {
+      SmallVector<Operation *, 1> gpuModules;
+      OpBuilder builder(context);
+      op->walk([&](gpu::GPUModuleOp gpuModule) {
+        builder.setInsertionPoint(gpuModule);
+        gpuModules.push_back(builder.clone(*gpuModule));
+      });
+      // Run conversion for each module independently as they can have
+      // different TargetEnv attributes.
+      for (Operation *gpuModule : gpuModules) {
+        spirv::TargetEnvAttr targetAttr =
+            spirv::lookupTargetEnvOrDefault(gpuModule);
+        mapToMemRef(gpuModule, targetAttr);
+        populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
+                                       patterns);
+        if (failed(
+                applyFullConversion(gpuModule, *target, std::move(patterns))))
+          return signalPassFailure();
+      }
+    } else {
+      mapToMemRef(op, targetAttr);
+      populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
+                                     patterns);
+      if (failed(applyPartialConversion(op, *target, std::move(patterns))))
+        return signalPassFailure();
+    }
   }
 };
 
diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index 032f5760361f4b..93ca922c57084f 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -12,6 +12,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
@@ -64,7 +65,9 @@ static LogicalResult runMLIRPasses(Operation *op,
   passManager.addPass(createGpuKernelOutliningPass());
   passManager.addPass(memref::createFoldMemRefAliasOpsPass());
 
-  passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));
+  ConvertToSPIRVPassOptions convertToSPIRVOptions{};
+  convertToSPIRVOptions.runOnGPUModules = true;
+  passManager.addPass(createConvertToSPIRVPass(convertToSPIRVOptions));
   OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
   modulePM.addPass(spirv::createSPIRVLowerABIAttributesPass());
   modulePM.addPass(spirv::createSPIRVUpdateVCEPass());
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index ddb08f12f04976..2b70df3c9430cf 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8409,6 +8409,7 @@ cc_library(
         ":ArithTransforms",
         ":ConversionPassIncGen",
         ":FuncToSPIRV",
+        ":GPUDialect",
         ":GPUToSPIRV",
         ":IR",
         ":IndexToSPIRV",

>From 2359648932140551dc0bdd83a6d721664207353c Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Mon, 26 Aug 2024 11:12:23 -0400
Subject: [PATCH 2/5] Update mlir/include/mlir/Conversion/Passes.td

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/include/mlir/Conversion/Passes.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 244827ade66be6..e6190c525bd6a0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -53,7 +53,7 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
     "Run vector unrolling to convert vector types in function bodies">,
     Option<"runOnGPUModules", "run-on-gpu-modules", "bool",
     /*default=*/"false",
-    "Clone and convert only the GPU modules for integration testing">
+    "Clone and convert GPU modules">
   ];
 }
 

>From 67f979960c04815e98cda78dbc3c0a39964f020c Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 26 Aug 2024 15:32:06 +0000
Subject: [PATCH 3/5] Code refactoring and comments

---
 mlir/include/mlir/Conversion/Passes.td        |  2 +-
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     | 42 ++++++++++---------
 .../mlir-vulkan-runner/mlir-vulkan-runner.cpp |  2 +-
 3 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index e6190c525bd6a0..f43d8aa08aadde 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -51,7 +51,7 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
     Option<"runVectorUnrolling", "run-vector-unrolling", "bool",
     /*default=*/"true",
     "Run vector unrolling to convert vector types in function bodies">,
-    Option<"runOnGPUModules", "run-on-gpu-modules", "bool",
+    Option<"convertGPUModules", "convert-gpu-modules", "bool",
     /*default=*/"false",
     "Clone and convert GPU modules">
   ];
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index f624999decac48..9d6885182c5bd9 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -94,31 +94,33 @@ struct ConvertToSPIRVPass final
     RewritePatternSet patterns(context);
     ScfToSPIRVContext scfToSPIRVContext;
 
-    if (runOnGPUModules) {
-      SmallVector<Operation *, 1> gpuModules;
-      OpBuilder builder(context);
-      op->walk([&](gpu::GPUModuleOp gpuModule) {
-        builder.setInsertionPoint(gpuModule);
-        gpuModules.push_back(builder.clone(*gpuModule));
-      });
-      // Run conversion for each module independently as they can have
-      // different TargetEnv attributes.
-      for (Operation *gpuModule : gpuModules) {
-        spirv::TargetEnvAttr targetAttr =
-            spirv::lookupTargetEnvOrDefault(gpuModule);
-        mapToMemRef(gpuModule, targetAttr);
-        populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
-                                       patterns);
-        if (failed(
-                applyFullConversion(gpuModule, *target, std::move(patterns))))
-          return signalPassFailure();
-      }
-    } else {
+    if (!convertGPUModules) {
       mapToMemRef(op, targetAttr);
       populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
                                      patterns);
       if (failed(applyPartialConversion(op, *target, std::move(patterns))))
         return signalPassFailure();
+      return;
+    }
+
+    // Clone each GPU kernel module for conversion, given that the GPU
+    // launch op still needs the original GPU kernel module.
+    SmallVector<Operation *, 1> gpuModules;
+    OpBuilder builder(context);
+    op->walk([&](gpu::GPUModuleOp gpuModule) {
+      builder.setInsertionPoint(gpuModule);
+      gpuModules.push_back(builder.clone(*gpuModule));
+    });
+    // Run conversion for each module independently as they can have
+    // different TargetEnv attributes.
+    for (Operation *gpuModule : gpuModules) {
+      spirv::TargetEnvAttr targetAttr =
+          spirv::lookupTargetEnvOrDefault(gpuModule);
+      mapToMemRef(gpuModule, targetAttr);
+      populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
+                                     patterns);
+      if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
+        return signalPassFailure();
     }
   }
 };
diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index 93ca922c57084f..2dd539ef83481f 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -66,7 +66,7 @@ static LogicalResult runMLIRPasses(Operation *op,
   passManager.addPass(memref::createFoldMemRefAliasOpsPass());
 
   ConvertToSPIRVPassOptions convertToSPIRVOptions{};
-  convertToSPIRVOptions.runOnGPUModules = true;
+  convertToSPIRVOptions.convertGPUModules = true;
   passManager.addPass(createConvertToSPIRVPass(convertToSPIRVOptions));
   OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
   modulePM.addPass(spirv::createSPIRVLowerABIAttributesPass());

>From ec0f3f314637417673a27afd8a9c47ab3b18f9be Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 26 Aug 2024 20:04:29 +0000
Subject: [PATCH 4/5] Fix bug: create conversion state per GPU module

---
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 9d6885182c5bd9..b04973e989fbfa 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -87,14 +87,14 @@ struct ConvertToSPIRVPass final
     if (runVectorUnrolling && failed(spirv::unrollVectorsInFuncBodies(op)))
       return signalPassFailure();
 
-    spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
-    std::unique_ptr<ConversionTarget> target =
-        SPIRVConversionTarget::get(targetAttr);
-    SPIRVTypeConverter typeConverter(targetAttr);
-    RewritePatternSet patterns(context);
-    ScfToSPIRVContext scfToSPIRVContext;
-
+    // Generic conversion.
     if (!convertGPUModules) {
+      spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
+      std::unique_ptr<ConversionTarget> target =
+          SPIRVConversionTarget::get(targetAttr);
+      SPIRVTypeConverter typeConverter(targetAttr);
+      RewritePatternSet patterns(context);
+      ScfToSPIRVContext scfToSPIRVContext;
       mapToMemRef(op, targetAttr);
       populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
                                      patterns);
@@ -114,8 +114,13 @@ struct ConvertToSPIRVPass final
     // Run conversion for each module independently as they can have
     // different TargetEnv attributes.
     for (Operation *gpuModule : gpuModules) {
       spirv::TargetEnvAttr targetAttr =
           spirv::lookupTargetEnvOrDefault(gpuModule);
+      std::unique_ptr<ConversionTarget> target =
+          SPIRVConversionTarget::get(targetAttr);
+      SPIRVTypeConverter typeConverter(targetAttr);
+      RewritePatternSet patterns(context);
+      ScfToSPIRVContext scfToSPIRVContext;
       mapToMemRef(gpuModule, targetAttr);
       populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
                                      patterns);

>From e1c87f6b975b5f948e618116d4f71e5638823a8e Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 26 Aug 2024 20:05:05 +0000
Subject: [PATCH 5/5] Add LIT tests

---
 .../ConvertToSPIRV/convert-gpu-modules.mlir   | 110 ++++++++++++++++++
 1 file changed, 110 insertions(+)
 create mode 100644 mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir

diff --git a/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir b/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir
new file mode 100644
index 00000000000000..1fde6c34418fc5
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/convert-gpu-modules.mlir
@@ -0,0 +1,110 @@
+// RUN: mlir-opt -convert-to-spirv="convert-gpu-modules=true run-signature-conversion=false run-vector-unrolling=false" -split-input-file %s | FileCheck %s
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
+} {
+  // CHECK-LABEL: func.func @main
+  // CHECK:       %[[C1:.*]] = arith.constant 1 : index
+  // CHECK:       gpu.launch_func  @[[$KERNELS_1:.*]]::@[[$BUILTIN_WG_ID_X:.*]] blocks in (%[[C1]], %[[C1]], %[[C1]]) threads in (%[[C1]], %[[C1]], %[[C1]])
+  // CHECK:       gpu.launch_func  @[[$KERNELS_2:.*]]::@[[$BUILTIN_WG_ID_Y:.*]] blocks in (%[[C1]], %[[C1]], %[[C1]]) threads in (%[[C1]], %[[C1]], %[[C1]])
+  func.func @main() {
+    %c1 = arith.constant 1 : index
+    gpu.launch_func @kernels_1::@builtin_workgroup_id_x
+        blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
+    gpu.launch_func @KERNELS_2::@builtin_workgroup_id_y
+        blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
+    return
+  }
+
+  // CHECK-LABEL:  spirv.module @{{.*}} Logical GLSL450
+  // CHECK:        spirv.func @[[$BUILTIN_WG_ID_X]]
+  // CHECK:        spirv.mlir.addressof
+  // CHECK:        spirv.Load "Input"
+  // CHECK:        spirv.CompositeExtract
+  gpu.module @kernels_1 {
+    gpu.func @builtin_workgroup_id_x() kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+      %0 = gpu.block_id x
+      gpu.return
+    }
+  }
+  // CHECK:  gpu.module @[[$KERNELS_1]]
+  // CHECK:  gpu.func @[[$BUILTIN_WG_ID_X]]
+  // CHECK:  gpu.block_id x
+  // CHECK:  gpu.return
+
+  // CHECK-LABEL:  spirv.module @{{.*}} Logical GLSL450
+  // CHECK:        spirv.func @[[$BUILTIN_WG_ID_Y]]
+  // CHECK:        spirv.mlir.addressof
+  // CHECK:        spirv.Load "Input"
+  // CHECK:        spirv.CompositeExtract
+  gpu.module @KERNELS_2 {
+    gpu.func @builtin_workgroup_id_y() kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+      %0 = gpu.block_id y
+      gpu.return
+    }
+  }
+  // CHECK:  gpu.module @[[$KERNELS_2]]
+  // CHECK:  gpu.func @[[$BUILTIN_WG_ID_Y]]
+  // CHECK:  gpu.block_id y
+  // CHECK:  gpu.return
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+  // CHECK-LABEL: func.func @main
+  // CHECK-SAME:  %[[ARG0:.*]]: memref<2xi32>, %[[ARG1:.*]]: memref<4xi32>
+  // CHECK:       %[[C1:.*]] = arith.constant 1 : index
+  // CHECK:       gpu.launch_func  @[[$KERNEL_MODULE:.*]]::@[[$KERNEL_FUNC:.*]] blocks in (%[[C1]], %[[C1]], %[[C1]]) threads in (%[[C1]], %[[C1]], %[[C1]]) args(%[[ARG0]] : memref<2xi32>, %[[ARG1]] : memref<4xi32>)
+  func.func @main(%arg0 : memref<2xi32>, %arg1 : memref<4xi32>) {
+    %c1 = arith.constant 1 : index
+    gpu.launch_func @kernels::@kernel_foo
+        blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
+        args(%arg0 : memref<2xi32>, %arg1 : memref<4xi32>)
+    return
+  }
+
+  // CHECK-LABEL: spirv.module @{{.*}} Logical GLSL450
+  // CHECK:       spirv.func @[[$KERNEL_FUNC]]
+  // CHECK-SAME:  %{{.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<2 x i32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+  // CHECK-SAME:  %{{.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+  gpu.module @kernels {
+    gpu.func @kernel_foo(%arg0 : memref<2xi32>, %arg1 : memref<4xi32>)
+      kernel attributes { spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [1, 1, 1]>} {
+      // CHECK: spirv.Constant
+      // CHECK: spirv.Constant dense<0>
+      %idx0 = arith.constant 0 : index
+      %vec0 = arith.constant dense<[0, 0]> : vector<2xi32>
+      // CHECK: spirv.AccessChain
+      // CHECK: spirv.Load "StorageBuffer"
+      %val = memref.load %arg0[%idx0] : memref<2xi32>
+      // CHECK: spirv.CompositeInsert
+      %vec = vector.insertelement %val, %vec0[%idx0 : index] : vector<2xi32>
+      // CHECK: spirv.VectorShuffle
+      %shuffle = vector.shuffle %vec, %vec[3, 2, 1, 0] : vector<2xi32>, vector<2xi32>
+      // CHECK: spirv.CompositeExtract
+      %res = vector.extractelement %shuffle[%idx0 : index] : vector<4xi32>
+      // CHECK: spirv.AccessChain
+      // CHECK: spirv.Store "StorageBuffer"
+      memref.store %res, %arg1[%idx0] : memref<4xi32>
+      // CHECK: spirv.Return
+      gpu.return
+    }
+  }
+  // CHECK:      gpu.module @[[$KERNEL_MODULE]]
+  // CHECK:      gpu.func @[[$KERNEL_FUNC]]
+  // CHECK-SAME: %{{.*}}: memref<2xi32>, %{{.*}}: memref<4xi32>
+  // CHECK:      arith.constant
+  // CHECK:      memref.load
+  // CHECK:      vector.insertelement
+  // CHECK:      vector.shuffle
+  // CHECK:      vector.extractelement
+  // CHECK:      memref.store
+  // CHECK:      gpu.return
+}



More information about the Mlir-commits mailing list