[Mlir-commits] [llvm] [mlir] [mlir][spirv] Support `memref` in `convert-to-spirv` pass (PR #102534)
Angel Zhang
llvmlistbot at llvm.org
Fri Aug 9 06:20:21 PDT 2024
https://github.com/angelz913 updated https://github.com/llvm/llvm-project/pull/102534
>From a90ccd1b9d969a706cbc3d164f5badac324bea46 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Wed, 7 Aug 2024 22:47:50 +0000
Subject: [PATCH 1/3] [mlir][spirv] Support MemRef in convert-to-spirv pass
---
.../Conversion/ConvertToSPIRV/CMakeLists.txt | 1 +
.../ConvertToSPIRV/ConvertToSPIRVPass.cpp | 13 +++++++
.../Conversion/ConvertToSPIRV/memref.mlir | 36 +++++++++++++++++++
.../llvm-project-overlay/mlir/BUILD.bazel | 1 +
4 files changed, 51 insertions(+)
create mode 100644 mlir/test/Conversion/ConvertToSPIRV/memref.mlir
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
index c9d962d2de23fa..dde561e9dbf4dc 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass
MLIRFuncToSPIRV
MLIRIndexToSPIRV
MLIRIR
+ MLIRMemRefToSPIRV
MLIRPass
MLIRRewrite
MLIRSCFToSPIRV
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 4694a147e1e94d..fbf80a8b510dff 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -10,6 +10,7 @@
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
+#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
@@ -62,12 +63,24 @@ struct ConvertToSPIRVPass final
RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;
+ // Map MemRef memory space to SPIR-V storage class.
+ spirv::TargetEnv targetEnv(targetAttr);
+ bool targetEnvSupportsKernelCapability =
+ targetEnv.allows(spirv::Capability::Kernel);
+ spirv::MemorySpaceToStorageClassMap memorySpaceMap =
+ targetEnvSupportsKernelCapability
+ ? spirv::mapMemorySpaceToOpenCLStorageClass
+ : spirv::mapMemorySpaceToVulkanStorageClass;
+ spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+ spirv::convertMemRefTypesAndAttrs(op, converter);
+
// Populate patterns for each dialect.
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
populateFuncToSPIRVPatterns(typeConverter, patterns);
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
+ populateMemRefToSPIRVPatterns(typeConverter, patterns);
populateVectorToSPIRVPatterns(typeConverter, patterns);
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
diff --git a/mlir/test/Conversion/ConvertToSPIRV/memref.mlir b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
new file mode 100644
index 00000000000000..338cb4c6feb934
--- /dev/null
+++ b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -cse -split-input-file %s | FileCheck %s
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @load_store_float_rank_zero
+// CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[CST0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : f32
+// CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[CST0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : f32
+// CHECK: spirv.Return
+func.func @load_store_float_rank_zero(%arg0: memref<f32>, %arg1: memref<f32>) {
+ %0 = memref.load %arg0[] : memref<f32>
+ memref.store %0, %arg1[] : memref<f32>
+ return
+}
+
+// CHECK-LABEL: @load_store_int_rank_one
+// CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, %[[ARG2:.*]]: i32
+// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[ARG2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : i32
+// CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[ARG2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : i32
+// CHECK: spirv.Return
+func.func @load_store_int_rank_one(%arg0: memref<4xi32>, %arg1: memref<4xi32>, %arg2 : index) {
+ %0 = memref.load %arg0[%arg2] : memref<4xi32>
+ memref.store %0, %arg1[%arg2] : memref<4xi32>
+ return
+}
+
+} // end module
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 84938231140127..6373e53b16c975 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8316,6 +8316,7 @@ cc_library(
":FuncToSPIRV",
":IR",
":IndexToSPIRV",
+ ":MemRefToSPIRV",
":Pass",
":Rewrite",
":SCFToSPIRV",
>From 1c645ee65caa6ade3723578a78649b27dce23f7d Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Fri, 9 Aug 2024 08:35:20 -0400
Subject: [PATCH 2/3] Update mlir/test/Conversion/ConvertToSPIRV/memref.mlir
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
mlir/test/Conversion/ConvertToSPIRV/memref.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Conversion/ConvertToSPIRV/memref.mlir b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
index 338cb4c6feb934..8d670b5ee4a93a 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -cse -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-to-spirv="run-signature-conversion=false run-vector-unrolling=false" -cse %s | FileCheck %s
module attributes {
spirv.target_env = #spirv.target_env<
>From 25416d54606280652f6641b5cd65bee03c6ea0a7 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Fri, 9 Aug 2024 13:18:45 +0000
Subject: [PATCH 3/3] Add tests for larger memref and vector
---
.../Conversion/ConvertToSPIRV/memref.mlir | 29 +++++++++++++++++++
1 file changed, 29 insertions(+)
diff --git a/mlir/test/Conversion/ConvertToSPIRV/memref.mlir b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
index 8d670b5ee4a93a..5af8bfc842ea13 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/memref.mlir
@@ -33,4 +33,33 @@ func.func @load_store_int_rank_one(%arg0: memref<4xi32>, %arg1: memref<4xi32>, %
return
}
+// CHECK-LABEL: @load_store_larger_memref
+// CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i32, stride=4> [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i32, stride=4> [0])>, StorageBuffer>, %[[ARG2:.*]]: i32
+// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[ARG2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<8 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : i32
+// CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[ARG2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<8 x i32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : i32
+// CHECK: spirv.Return
+func.func @load_store_larger_memref(%arg0: memref<8xi32>, %arg1: memref<8xi32>, %arg2 : index) {
+ %0 = memref.load %arg0[%arg2] : memref<8xi32>
+ memref.store %0, %arg1[%arg2] : memref<8xi32>
+ return
+}
+
+
+// CHECK-LABEL: @load_store_vector
+// CHECK-SAME: %[[ARG0:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<1 x vector<4xi32>, stride=16> [0])>, StorageBuffer>, %[[ARG1:.*]]: !spirv.ptr<!spirv.struct<(!spirv.array<1 x vector<4xi32>, stride=16> [0])>, StorageBuffer>
+// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
+// CHECK: %[[AC0:.*]] = spirv.AccessChain %[[ARG0]][%[[CST0]], %[[CST0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x vector<4xi32>, stride=16> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[LOAD:.*]] = spirv.Load "StorageBuffer" %[[AC0]] : vector<4xi32>
+// CHECK: %[[AC1:.*]] = spirv.AccessChain %[[ARG1]][%[[CST0]], %[[CST0]]] : !spirv.ptr<!spirv.struct<(!spirv.array<1 x vector<4xi32>, stride=16> [0])>, StorageBuffer>, i32, i32
+// CHECK: spirv.Store "StorageBuffer" %[[AC1]], %[[LOAD]] : vector<4xi32>
+// CHECK: spirv.Return
+func.func @load_store_vector(%arg0: memref<vector<4xi32>>, %arg1: memref<vector<4xi32>>) {
+ %0 = memref.load %arg0[] : memref<vector<4xi32>>
+ memref.store %0, %arg1[] : memref<vector<4xi32>>
+ return
+}
+
} // end module
More information about the Mlir-commits mailing list