[Mlir-commits] [mlir] 93eda08 - [mlir][spirv] Support `gpu` in `convert-to-spirv` pass (#105010)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Aug 20 10:17:20 PDT 2024


Author: Angel Zhang
Date: 2024-08-20T13:17:17-04:00
New Revision: 93eda08babe95188ee41400035abaade79cda7d1

URL: https://github.com/llvm/llvm-project/commit/93eda08babe95188ee41400035abaade79cda7d1
DIFF: https://github.com/llvm/llvm-project/commit/93eda08babe95188ee41400035abaade79cda7d1.diff

LOG: [mlir][spirv] Support `gpu` in `convert-to-spirv` pass (#105010)

This PR adds GPU dialect conversion patterns to the `convert-to-spirv` pass,
introduced in #95942. The pass can now convert each `gpu.module` nested in a
`builtin.module`, together with its ops, into a `spirv.module`.

**Future Plans**
- Use `gpu.launch_func` to invoke kernels from host functions
- Potentially integrate into the `mlir-vulkan-runner` for end-to-end testing

Added: 
    mlir/test/Conversion/ConvertToSPIRV/argmax-kernel.mlir
    mlir/test/Conversion/ConvertToSPIRV/gpu.mlir

Modified: 
    mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
    mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
index dde561e9dbf4dc..863ef9603da385 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass
   MLIRArithToSPIRV
   MLIRArithTransforms
   MLIRFuncToSPIRV
+  MLIRGPUToSPIRV
   MLIRIndexToSPIRV
   MLIRIR
   MLIRMemRefToSPIRV

diff  --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index fbf80a8b510dff..9e57b923ea6894 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
+#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
@@ -79,6 +80,7 @@ struct ConvertToSPIRVPass final
     arith::populateArithToSPIRVPatterns(typeConverter, patterns);
     populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
     populateFuncToSPIRVPatterns(typeConverter, patterns);
+    populateGPUToSPIRVPatterns(typeConverter, patterns);
     index::populateIndexToSPIRVPatterns(typeConverter, patterns);
     populateMemRefToSPIRVPatterns(typeConverter, patterns);
     populateVectorToSPIRVPatterns(typeConverter, patterns);

diff  --git a/mlir/test/Conversion/ConvertToSPIRV/argmax-kernel.mlir b/mlir/test/Conversion/ConvertToSPIRV/argmax-kernel.mlir
new file mode 100644
index 00000000000000..5cd1fead2527b1
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/argmax-kernel.mlir
@@ -0,0 +1,99 @@
+// RUN: mlir-opt -convert-to-spirv -cse %s | FileCheck %s
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, Groups, GroupNonUniformArithmetic, GroupNonUniformBallot], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+  // CHECK-LABEL: spirv.module @{{.*}} Logical GLSL450
+  // CHECK-DAG: spirv.GlobalVariable @[[$LOCALINVOCATIONIDVAR:.*]] built_in("LocalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
+  // CHECK-LABEL: spirv.func @argmax
+  // CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+  // CHECK-SAME: %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>
+  gpu.module @kernels {
+    gpu.func @argmax(%input : memref<4xf32>, %output : memref<i32>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1]>} {
+      // CHECK: %[[C0:.*]] = spirv.Constant 0 : i32
+      // CHECK: %[[C1:.*]] = spirv.Constant 1 : i32
+      // CHECK: %[[C32:.*]] = spirv.Constant 32 : i32
+      // CHECK: %[[ADDRESSLOCALINVOCATIONID:.*]] = spirv.mlir.addressof @[[$LOCALINVOCATIONIDVAR]]
+      // CHECK: %[[LOCALINVOCATIONID:.*]] = spirv.Load "Input" %[[ADDRESSLOCALINVOCATIONID]]
+      // CHECK: %[[LOCALINVOCATIONIDX:.*]] = spirv.CompositeExtract %[[LOCALINVOCATIONID]]{{\[}}0 : i32{{\]}}
+      // CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[C0]], %[[LOCALINVOCATIONIDX]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+      // CHECK: %[[LOAD0:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : f32
+      // CHECK: %[[FUNC0:.*]] = spirv.Variable : !spirv.ptr<i32, Function>
+      // CHECK: %[[FUNC1:.*]] = spirv.Variable : !spirv.ptr<f32, Function>
+      %cst_0_idx = arith.constant 0 : index
+      %cst_1_i32 = arith.constant 1 : i32
+      %cst_1_idx = arith.constant 1 : index
+      %cst_32 = arith.constant 32 : i32
+      %num_batches = arith.divui %cst_1_i32, %cst_32 : i32
+      %tx = gpu.thread_id x
+      %tx_i32 = index.castu %tx : index to i32
+      %ub = index.castu %num_batches : i32 to index
+      %lane_res_init = arith.constant 0 : i32
+      %lane_max_init = memref.load %input[%tx] : memref<4xf32>
+
+      // CHECK: spirv.mlir.loop {
+      // CHECK:   spirv.Branch ^[[HEADER:.*]](%[[C1]], %[[C0]], %[[LOAD0]] : i32, i32, f32)
+      // CHECK: ^[[HEADER]](%[[INDVAR0:.*]]: i32, %[[INDVAR1:.*]]: i32, %[[INDVAR2:.*]]: f32):
+      // CHECK:   %[[SLT:.*]] = spirv.SLessThan %[[INDVAR0]], %[[C0]] : i32
+      // CHECK:   spirv.BranchConditional %[[SLT]], ^[[BODY:.*]], ^[[MERGE:.*]]
+      // CHECK: ^[[BODY]]:
+      // CHECK:   %[[MUL:.*]] = spirv.IMul %[[INDVAR0]], %[[C32]] : i32
+      // CHECK:   %[[ADD:.*]] = spirv.IAdd %[[MUL]], %[[LOCALINVOCATIONIDX]] : i32
+      // CHECK:   %[[AC1:.*]] = spirv.AccessChain %[[ARG0]][%[[C0]], %[[ADD]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+      // CHECK:   %[[LOAD1:.*]] = spirv.Load "StorageBuffer" %[[AC1]] : f32
+      // CHECK:   %[[OGT:.*]] = spirv.FOrdGreaterThan %[[LOAD1]], %[[INDVAR2]] : f32
+      // CHECK:   %[[SELECT0:.*]] = spirv.Select %[[OGT]], %[[ADD]], %[[INDVAR1]] : i1, i32
+      // CHECK:   %[[SELECT1:.*]] = spirv.Select %[[OGT]], %[[LOAD1]], %[[INDVAR2]] : i1, f32
+      // CHECK:   spirv.Store "Function" %[[FUNC0]], %[[SELECT0]] : i32
+      // CHECK:   spirv.Store "Function" %[[FUNC1]], %[[SELECT1]] : f32
+      // CHECK:   %[[ADD1:.*]] = spirv.IAdd %[[INDVAR0]], %[[C1]] : i32
+      // CHECK:   spirv.Branch ^[[HEADER]](%[[ADD1]], %[[SELECT0]], %[[SELECT1]] : i32, i32, f32)
+      // CHECK: ^[[MERGE]]:
+      // CHECK:   spirv.mlir.merge
+      // CHECK: }
+      // CHECK-DAG: %[[LANE_RES:.*]] = spirv.Load "Function" %[[FUNC0]] : i32
+      // CHECK-DAG: %[[LANE_MAX:.*]] = spirv.Load "Function" %[[FUNC1]] : f32
+      %lane_res, %lane_max = scf.for %iter = %cst_1_idx to %ub step %cst_1_idx
+      iter_args(%lane_res_iter = %lane_res_init, %lane_max_iter = %lane_max_init) -> (i32, f32) {
+        %iter_i32 = index.castu %iter : index to i32
+        %mul = arith.muli %cst_32, %iter_i32 : i32
+        %idx_i32 = arith.addi %mul, %tx_i32 : i32
+        %idx = index.castu %idx_i32 : i32 to index
+        %elem = memref.load %input[%idx] : memref<4xf32>
+        %gt = arith.cmpf ogt, %elem, %lane_max_iter : f32
+        %lane_res_next = arith.select %gt, %idx_i32, %lane_res_iter : i32
+        %lane_max_next = arith.select %gt, %elem, %lane_max_iter : f32
+        scf.yield %lane_res_next, %lane_max_next : i32, f32
+      }
+
+      // CHECK: %[[SUBGROUP_MAX:.*]] = spirv.GroupNonUniformFMax "Subgroup" "Reduce" %[[LANE_MAX]] : f32
+      // CHECK: %[[OEQ:.*]] = spirv.FOrdEqual %[[LANE_MAX]], %[[SUBGROUP_MAX]] : f32
+      // CHECK: %[[BALLOT:.*]] = spirv.GroupNonUniformBallot <Subgroup> %[[OEQ]] : vector<4xi32>
+      // CHECK: %[[BALLOTLSB:.*]] = spirv.GroupNonUniformBallotFindLSB <Subgroup> %[[BALLOT]] : vector<4xi32>, i32
+      // CHECK: %[[EQ:.*]] = spirv.IEqual %[[LOCALINVOCATIONIDX]], %[[C1]] : i32
+      %subgroup_max = gpu.subgroup_reduce maximumf %lane_max : (f32) -> (f32)
+      %eq = arith.cmpf oeq, %lane_max, %subgroup_max : f32
+      %ballot = spirv.GroupNonUniformBallot <Subgroup> %eq : vector<4xi32>
+      %lsb = spirv.GroupNonUniformBallotFindLSB <Subgroup> %ballot : vector<4xi32>, i32
+      %cond = arith.cmpi eq, %cst_1_i32, %tx_i32 : i32
+
+      // CHECK: spirv.mlir.selection {
+      // CHECK:   spirv.BranchConditional %[[EQ]], ^[[TRUE:.*]], ^[[FALSE:.*]]
+      // CHECK: ^[[TRUE]]:
+      // CHECK:   %[[AC2:.*]] = spirv.AccessChain %[[ARG1]][%[[C0]], %[[C0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
+      // CHECK:   spirv.Store "StorageBuffer" %[[AC2]], %[[LANE_RES]] : i32
+      // CHECK:   spirv.Branch ^[[FALSE]]
+      // CHECK: ^[[FALSE]]:
+      // CHECK:   spirv.mlir.merge
+      // CHECK: }
+      scf.if %cond {
+        memref.store %lane_res, %output[] : memref<i32>
+      }
+
+      // CHECK: spirv.Return
+      gpu.return
+    }
+  }
+}

diff  --git a/mlir/test/Conversion/ConvertToSPIRV/gpu.mlir b/mlir/test/Conversion/ConvertToSPIRV/gpu.mlir
new file mode 100644
index 00000000000000..f33a66bdf5effc
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/gpu.mlir
@@ -0,0 +1,85 @@
+// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL: spirv.func @all_reduce
+  // CHECK-SAME: (%[[ARG0:.*]]: f32)
+  // CHECK: %{{.*}} = spirv.GroupNonUniformFAdd "Workgroup" "Reduce" %[[ARG0]] : f32
+  gpu.func @all_reduce(%arg0 : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    %reduced = gpu.all_reduce add %arg0 {} : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL: spirv.func @subgroup_reduce
+  // CHECK-SAME: (%[[ARG0:.*]]: f32)
+  // CHECK: %{{.*}} = spirv.GroupNonUniformFAdd "Subgroup" "Reduce" %[[ARG0]] : f32
+  gpu.func @subgroup_reduce(%arg0 : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    %reduced = gpu.subgroup_reduce add %arg0 {} : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+
+  // CHECK-LABEL: spirv.module @{{.*}} Logical GLSL450
+  // CHECK-LABEL: spirv.func @load_store
+  // CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<48 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+  // CHECK-SAME: %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<48 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+  // CHECK-SAME: %[[ARG2:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<48 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 2)>}
+  gpu.module @kernels {
+    gpu.func @load_store(%arg0: memref<12x4xf32, #spirv.storage_class<StorageBuffer>>, %arg1: memref<12x4xf32, #spirv.storage_class<StorageBuffer>>, %arg2: memref<12x4xf32, #spirv.storage_class<StorageBuffer>>, %arg3: index, %arg4: index, %arg5: index, %arg6: index) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+      // CHECK: %[[PTR1:.*]] = spirv.AccessChain %[[ARG0]]
+      // CHECK-NEXT: spirv.Load "StorageBuffer" %[[PTR1]]
+      // CHECK: %[[PTR2:.*]] = spirv.AccessChain %[[ARG1]]
+      // CHECK-NEXT: spirv.Load "StorageBuffer" %[[PTR2]]
+      // CHECK: spirv.FAdd
+      // CHECK: %[[PTR3:.*]] = spirv.AccessChain %[[ARG2]]
+      // CHECK-NEXT: spirv.Store "StorageBuffer" %[[PTR3]]
+      %0 = gpu.block_id x
+      %1 = gpu.block_id y
+      %2 = gpu.block_id z
+      %3 = gpu.thread_id x
+      %4 = gpu.thread_id y
+      %5 = gpu.thread_id z
+      %6 = gpu.grid_dim x
+      %7 = gpu.grid_dim y
+      %8 = gpu.grid_dim z
+      %9 = gpu.block_dim x
+      %10 = gpu.block_dim y
+      %11 = gpu.block_dim z
+      %12 = arith.addi %arg3, %0 : index
+      %13 = arith.addi %arg4, %3 : index
+      %14 = memref.load %arg0[%12, %13] : memref<12x4xf32, #spirv.storage_class<StorageBuffer>>
+      %15 = memref.load %arg1[%12, %13] : memref<12x4xf32, #spirv.storage_class<StorageBuffer>>
+      %16 = arith.addf %14, %15 : f32
+      memref.store %16, %arg2[%12, %13] : memref<12x4xf32, #spirv.storage_class<StorageBuffer>>
+      gpu.return
+    }
+  }
+}

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 4b95221e8587da..de069daf603f1e 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8401,6 +8401,7 @@ cc_library(
         ":ArithTransforms",
         ":ConversionPassIncGen",
         ":FuncToSPIRV",
+        ":GPUToSPIRV",
         ":IR",
         ":IndexToSPIRV",
         ":MemRefToSPIRV",


        


More information about the Mlir-commits mailing list