[Mlir-commits] [llvm] [mlir] [mlir][spirv] Support `memref` in `convert-to-spirv` pass (PR #102534)

Angel Zhang llvmlistbot at llvm.org
Fri Aug 9 06:20:21 PDT 2024


https://github.com/angelz913 updated https://github.com/llvm/llvm-project/pull/102534

>From a90ccd1b9d969a706cbc3d164f5badac324bea46 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 7 Aug 2024 22:47:50 +0000
Subject: [PATCH 1/3] [mlir][spirv] Support MemRef in convert-to-spirv pass

---
 .../Conversion/ConvertToSPIRV/CMakeLists.txt  |  1 +
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     | 13 +++++++
 .../Conversion/ConvertToSPIRV/memref.mlir     | 36 +++++++++++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |  1 +
 4 files changed, 51 insertions(+)
 create mode 100644 mlir/test/Conversion/ConvertToSPIRV/memref.mlir

diff --git a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
index c9d962d2de23fa..dde561e9dbf4dc 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass
   MLIRFuncToSPIRV
   MLIRIndexToSPIRV
   MLIRIR
+  MLIRMemRefToSPIRV
   MLIRPass
   MLIRRewrite
   MLIRSCFToSPIRV
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 4694a147e1e94d..fbf80a8b510dff 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
+#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
@@ -62,12 +63,24 @@ struct ConvertToSPIRVPass final
     RewritePatternSet patterns(context);
     ScfToSPIRVContext scfToSPIRVContext;
 
+    // Map MemRef memory space to SPIR-V storage class.
+    spirv::TargetEnv targetEnv(targetAttr);
+    bool targetEnvSupportsKernelCapability =
+        targetEnv.allows(spirv::Capability::Kernel);
+    spirv::MemorySpaceToStorageClassMap memorySpaceMap =
+        targetEnvSupportsKernelCapability
+            ? spirv::mapMemorySpaceToOpenCLStorageClass
+            : spirv::mapMemorySpaceToVulkanStorageClass;
+    spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+    spirv::convertMemRefTypesAndAttrs(op, converter);
+
     // Populate patterns for each dialect.
     arith::populateCeilFloorDivExpandOpsPatterns(patterns);
     arith::populateArithToSPIRVPatterns(typeConverter, patterns);
     populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
     populateFuncToSPIRVPatterns(typeConverter, patterns);
     index::populateIndexToSPIRVPatterns(typeConverter, patterns);
+    populateMemRefToSPIRVPatterns(typeConverter, patterns);
     populateVectorToSPIRVPatterns(typeConverter, patterns);
     populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
     ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
diff --git a/mlir/test/Conversion/ConvertToSPIRV/memref.mlir b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
new file mode 100644
index 00000000000000..338cb4c6feb934
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -cse -split-input-file %s | FileCheck %s
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @load_store_float_rank_zero
+//  CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[CST0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : f32
+//       CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[CST0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : f32
+//       CHECK: spirv.Return
+func.func @load_store_float_rank_zero(%arg0: memref<f32>, %arg1: memref<f32>) {
+  %0 = memref.load %arg0[] : memref<f32>
+  memref.store %0, %arg1[] : memref<f32>
+  return
+}
+
+// CHECK-LABEL: @load_store_int_rank_one
+//  CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, %[[ARG2:.*]]: i32
+//       CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[ARG2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : i32
+//       CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[ARG2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : i32
+//       CHECK: spirv.Return
+func.func @load_store_int_rank_one(%arg0: memref<4xi32>, %arg1: memref<4xi32>, %arg2 : index) {
+  %0 = memref.load %arg0[%arg2] : memref<4xi32>
+  memref.store %0, %arg1[%arg2] : memref<4xi32>
+  return
+}
+
+} // end module
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 84938231140127..6373e53b16c975 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8316,6 +8316,7 @@ cc_library(
         ":FuncToSPIRV",
         ":IR",
         ":IndexToSPIRV",
+        ":MemRefToSPIRV",
         ":Pass",
         ":Rewrite",
         ":SCFToSPIRV",

>From 1c645ee65caa6ade3723578a78649b27dce23f7d Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Fri, 9 Aug 2024 08:35:20 -0400
Subject: [PATCH 2/3] Update mlir/test/Conversion/ConvertToSPIRV/memref.mlir

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/test/Conversion/ConvertToSPIRV/memref.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Conversion/ConvertToSPIRV/memref.mlir b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
index 338cb4c6feb934..8d670b5ee4a93a 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -cse -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -cse %s | FileCheck %s
 
 module attributes {
   spirv.target_env = #spirv.target_env<

>From 25416d54606280652f6641b5cd65bee03c6ea0a7 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Fri, 9 Aug 2024 13:18:45 +0000
Subject: [PATCH 3/3] Add tests for larger memref and vector

---
 .../Conversion/ConvertToSPIRV/memref.mlir     | 29 +++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/mlir/test/Conversion/ConvertToSPIRV/memref.mlir b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
index 8d670b5ee4a93a..5af8bfc842ea13 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
@@ -33,4 +33,33 @@ func.func @load_store_int_rank_one(%arg0: memref<4xi32>, %arg1: memref<4xi32>, %
   return
 }
 
+// CHECK-LABEL: @load_store_larger_memref
+//  CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i32, stride=4> [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i32, stride=4> [0])>, StorageBuffer>, %[[ARG2:.*]]: i32
+//       CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[ARG2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<8 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : i32
+//       CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[ARG2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<8 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : i32
+//       CHECK: spirv.Return
+func.func @load_store_larger_memref(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2 : index) {
+  %0 = memref.load %arg0[%arg2] : memref<8xi32>
+  memref.store %0, %arg1[%arg2] : memref<8xi32>
+  return
+}
+
+
+// CHECK-LABEL: @load_store_vector
+//  CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<1 x vector<4xi32>, stride=16> [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<1 x vector<4xi32>, stride=16> [0])>, StorageBuffer>
+//       CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+//       CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[CST0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x vector<4xi32>, stride=16> [0])>, StorageBuffer>, i32, i32
+//       CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : vector<4xi32>
+//       CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[CST0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x vector<4xi32>, stride=16> [0])>, StorageBuffer>, i32, i32
+//       CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : vector<4xi32>
+//       CHECK: spirv.Return
+func.func @load_store_vector(%arg0: memref<vector<4xi32>>, %arg1: memref<vector<4xi32>>) {
+  %0 = memref.load %arg0[] : memref<vector<4xi32>>
+  memref.store %0, %arg1[] : memref<vector<4xi32>>
+  return
+}
+
 } // end module



More information about the Mlir-commits mailing list