[llvm] [mlir] [mlir][spirv] Support `memref` in `convert-to-spirv` pass (PR #102534)

Angel Zhang via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 8 13:55:51 PDT 2024


https://github.com/angelz913 created https://github.com/llvm/llvm-project/pull/102534

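This PR adds MemRef conversion patterns to the `convert-to-spirv` pass (introduced in #95942). It also adds a step that maps MemRef memory spaces to SPIR-V storage classes; this step runs before the final dialect conversion phase. The OpenCL storage-class mapping is used when the target environment allows the `Kernel` capability, and the Vulkan mapping is used otherwise.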

**Future Plans**
- Add tests for MemRef ops other than `memref.load` and `memref.store`

>From a8a301a5b0ba414b5a240643f35463a9597115c8 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 7 Aug 2024 22:47:50 +0000
Subject: [PATCH] [mlir][spirv] Support MemRef in convert-to-spirv pass

---
 .../Conversion/ConvertToSPIRV/CMakeLists.txt  |  1 +
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     | 13 +++++++
 .../Conversion/ConvertToSPIRV/memref.mlir     | 36 +++++++++++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |  1 +
 4 files changed, 51 insertions(+)
 create mode 100644 mlir/test/Conversion/ConvertToSPIRV/memref.mlir

diff --git a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
index c9d962d2de23fa..dde561e9dbf4dc 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass
   MLIRFuncToSPIRV
   MLIRIndexToSPIRV
   MLIRIR
+  MLIRMemRefToSPIRV
   MLIRPass
   MLIRRewrite
   MLIRSCFToSPIRV
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 4694a147e1e94d..fbf80a8b510dff 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
+#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
@@ -62,12 +63,24 @@ struct ConvertToSPIRVPass final
     RewritePatternSet patterns(context);
     ScfToSPIRVContext scfToSPIRVContext;
 
+    // Map MemRef memory space to SPIR-V storage class.
+    spirv::TargetEnv targetEnv(targetAttr);
+    bool targetEnvSupportsKernelCapability =
+        targetEnv.allows(spirv::Capability::Kernel);
+    spirv::MemorySpaceToStorageClassMap memorySpaceMap =
+        targetEnvSupportsKernelCapability
+            ? spirv::mapMemorySpaceToOpenCLStorageClass
+            : spirv::mapMemorySpaceToVulkanStorageClass;
+    spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+    spirv::convertMemRefTypesAndAttrs(op, converter);
+
     // Populate patterns for each dialect.
     arith::populateCeilFloorDivExpandOpsPatterns(patterns);
     arith::populateArithToSPIRVPatterns(typeConverter, patterns);
     populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
     populateFuncToSPIRVPatterns(typeConverter, patterns);
     index::populateIndexToSPIRVPatterns(typeConverter, patterns);
+    populateMemRefToSPIRVPatterns(typeConverter, patterns);
     populateVectorToSPIRVPatterns(typeConverter, patterns);
     populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
     ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
diff --git a/mlir/test/Conversion/ConvertToSPIRV/memref.mlir b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
new file mode 100644
index 00000000000000..b4ca98400e2029
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -cse -split-input-file %s | FileCheck %s
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @load_store_float_rank_zero
+//  CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[CST0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : f32
+//       CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[CST0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : f32
+//       CHECK: spirv.Return
+func.func @load_store_float_rank_zero(%arg0: memref<f32>, %arg1: memref<f32>) {
+  %0 = memref.load %arg0[] : memref<f32>
+  memref.store %0, %arg1[] : memref<f32>
+  return
+}
+
+// CHECK-LABEL: @load_store_int_rank_one
+//  CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, %[[ARG2:.*]]: i32
+//       CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[ARG2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : i32
+//       CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[ARG2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : i32
+//       CHECK: spirv.Return
+func.func @load_store_int_rank_one(%arg0: memref<4xi32>, %arg1: memref<4xi32>, %arg2 : index) {
+  %0 = memref.load %arg0[%arg2] : memref<4xi32>
+  memref.store %0, %arg1[%arg2] : memref<4xi32>
+  return
+}
+
+} // end module
\ No newline at end of file
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 84938231140127..6373e53b16c975 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8316,6 +8316,7 @@ cc_library(
         ":FuncToSPIRV",
         ":IR",
         ":IndexToSPIRV",
+        ":MemRefToSPIRV",
         ":Pass",
         ":Rewrite",
         ":SCFToSPIRV",



More information about the llvm-commits mailing list