[llvm] [mlir] [mlir][spirv] Support `memref` in `convert-to-spirv` pass (PR #102534)
Angel Zhang via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 8 13:55:51 PDT 2024
https://github.com/angelz913 created https://github.com/llvm/llvm-project/pull/102534
This PR adds MemRef conversion patterns to the `convert-to-spirv` pass (introduced in #95942). It also adds the conversion from MemRef memory spaces to SPIR-V storage classes, which runs before the final dialect conversion phase.
**Future Plans**
- Add tests for ops other than `memref.load` and `memref.store`
From a8a301a5b0ba414b5a240643f35463a9597115c8 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 7 Aug 2024 22:47:50 +0000
Subject: [PATCH] [mlir][spirv] Support MemRef in convert-to-spirv pass
---
.../Conversion/ConvertToSPIRV/CMakeLists.txt | 1 +
.../ConvertToSPIRV/ConvertToSPIRVPass.cpp | 13 +++++++
.../Conversion/ConvertToSPIRV/memref.mlir | 36 +++++++++++++++++++
.../llvm-project-overlay/mlir/BUILD.bazel | 1 +
4 files changed, 51 insertions(+)
create mode 100644 mlir/test/Conversion/ConvertToSPIRV/memref.mlir
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
index c9d962d2de23fa..dde561e9dbf4dc 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass
MLIRFuncToSPIRV
MLIRIndexToSPIRV
MLIRIR
+ MLIRMemRefToSPIRV
MLIRPass
MLIRRewrite
MLIRSCFToSPIRV
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 4694a147e1e94d..fbf80a8b510dff 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -10,6 +10,7 @@
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
+#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
@@ -62,12 +63,24 @@ struct ConvertToSPIRVPass final
RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;
+ // Map MemRef memory space to SPIR-V storage class.
+ spirv::TargetEnv targetEnv(targetAttr);
+ bool targetEnvSupportsKernelCapability =
+ targetEnv.allows(spirv::Capability::Kernel);
+ spirv::MemorySpaceToStorageClassMap memorySpaceMap =
+ targetEnvSupportsKernelCapability
+ ? spirv::mapMemorySpaceToOpenCLStorageClass
+ : spirv::mapMemorySpaceToVulkanStorageClass;
+ spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+ spirv::convertMemRefTypesAndAttrs(op, converter);
+
// Populate patterns for each dialect.
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
populateFuncToSPIRVPatterns(typeConverter, patterns);
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
+ populateMemRefToSPIRVPatterns(typeConverter, patterns);
populateVectorToSPIRVPatterns(typeConverter, patterns);
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
diff --git a/mlir/test/Conversion/ConvertToSPIRV/memref.mlir b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
new file mode 100644
index 00000000000000..b4ca98400e2029
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -cse -split-input-file %s | FileCheck %s
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @load_store_float_rank_zero
+// CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[CST0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : f32
+// CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[CST0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : f32
+// CHECK: spirv.Return
+func.func @load_store_float_rank_zero(%arg0: memref<f32>, %arg1: memref<f32>) {
+ %0 = memref.load %arg0[] : memref<f32>
+ memref.store %0, %arg1[] : memref<f32>
+ return
+}
+
+// CHECK-LABEL: @load_store_int_rank_one
+// CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, %[[ARG2:.*]]: i32
+// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[ARG2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : i32
+// CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[ARG2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : i32
+// CHECK: spirv.Return
+func.func @load_store_int_rank_one(%arg0: memref<4xi32>, %arg1: memref<4xi32>, %arg2 : index) {
+ %0 = memref.load %arg0[%arg2] : memref<4xi32>
+ memref.store %0, %arg1[%arg2] : memref<4xi32>
+ return
+}
+
+} // end module
\ No newline at end of file
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 84938231140127..6373e53b16c975 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8316,6 +8316,7 @@ cc_library(
":FuncToSPIRV",
":IR",
":IndexToSPIRV",
+ ":MemRefToSPIRV",
":Pass",
":Rewrite",
":SCFToSPIRV",
More information about the llvm-commits mailing list