[Mlir-commits] [mlir] 8854b73 - [mlir][spirv] Convert memref.alloca to spv.Variable
Lei Zhang
llvmlistbot at llvm.org
Thu Apr 28 05:17:40 PDT 2022
Author: Lei Zhang
Date: 2022-04-28T08:13:40-04:00
New Revision: 8854b736065c228270000df552bdd9dc7b152453
URL: https://github.com/llvm/llvm-project/commit/8854b736065c228270000df552bdd9dc7b152453
DIFF: https://github.com/llvm/llvm-project/commit/8854b736065c228270000df552bdd9dc7b152453.diff
LOG: [mlir][spirv] Convert memref.alloca to spv.Variable
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D124542
Added:
mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
Modified:
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 7fb0f025e3826..a947d182865ad 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "llvm/Support/Debug.h"
@@ -85,15 +86,27 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
offset);
}
-/// Returns true if the allocations of type `t` can be lowered to SPIR-V.
-static bool isAllocationSupported(MemRefType t) {
- // Currently only support workgroup local memory allocations with static
- // shape and int or float or vector of int or float element type.
- if (!(t.hasStaticShape() &&
- SPIRVTypeConverter::getMemorySpaceForStorageClass(
- spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt()))
+/// Returns true if a memref of `type` that is allocated (memref.alloc,
+/// memref.alloca) or released (memref.dealloc) by `allocOp` can be lowered to
+/// SPIR-V. Any other kind of `allocOp` is rejected.
+///
+/// Requirements checked here:
+/// - memref.alloc / memref.dealloc: `type` must be in the memory space mapped
+///   to the Workgroup storage class;
+/// - memref.alloca: `type` must be in the memory space mapped to the Function
+///   storage class;
+/// - in all cases the shape must be static and the element type must be an
+///   int or float, or a vector of int or float.
+static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
+ // Check the memory space against the storage class this op kind lowers to.
+ if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
+ if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
+ spirv::StorageClass::Workgroup) != type.getMemorySpaceAsInt())
+ return false;
+ } else if (isa<memref::AllocaOp>(allocOp)) {
+ if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
+ spirv::StorageClass::Function) != type.getMemorySpaceAsInt())
+ return false;
+ } else {
return false;
- Type elementType = t.getElementType();
+ }
+
+ // Currently only static shapes with an int or float element type (or a
+ // vector of int or float) are supported.
+ if (!type.hasStaticShape())
+ return false;
+
+ Type elementType = type.getElementType();
if (auto vecType = elementType.dyn_cast<VectorType>())
elementType = vecType.getElementType();
return elementType.isIntOrFloat();
@@ -102,10 +115,10 @@ static bool isAllocationSupported(MemRefType t) {
/// Returns the scope to use for atomic operations use for emulating store
/// operations of unsupported integer bitwidths, based on the memref
/// type. Returns None on failure.
-static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
+static Optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
Optional<spirv::StorageClass> storageClass =
SPIRVTypeConverter::getStorageClassForMemorySpace(
- t.getMemorySpaceAsInt());
+ type.getMemorySpaceAsInt());
if (!storageClass)
return {};
switch (*storageClass) {
@@ -149,6 +162,16 @@ static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
namespace {
+/// Rewrites memref.alloca into a spv.Variable that lives in the Function
+/// storage class. The conversion itself is implemented out of line below.
+class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
+public:
+  // Reuse the base-class constructors unchanged.
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spv.module scope since it wil
@@ -215,6 +238,25 @@ class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
} // namespace
+//===----------------------------------------------------------------------===//
+// AllocaOp
+//===----------------------------------------------------------------------===//
+
+/// Replaces a supported memref.alloca (static shape, Function memory space,
+/// int/float or vector-of-int/float element type) with an uninitialized
+/// spv.Variable in the Function storage class.
+LogicalResult
+AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter) const {
+  MemRefType allocType = allocaOp.getType();
+  if (!isAllocationSupported(allocaOp, allocType))
+    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
+
+  // Get the SPIR-V type for the allocation. TypeConverter::convertType yields
+  // a null Type when the conversion fails; report a match failure instead of
+  // building a spv.Variable with a null result type.
+  Type spirvType = getTypeConverter()->convertType(allocType);
+  if (!spirvType)
+    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");
+
+  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
+                                                 spirv::StorageClass::Function,
+                                                 /*initializer=*/nullptr);
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//
@@ -223,8 +265,8 @@ LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MemRefType allocType = operation.getType();
- if (!isAllocationSupported(allocType))
- return operation.emitError("unhandled allocation type");
+ if (!isAllocationSupported(operation, allocType))
+ return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
// Get the SPIR-V type for the allocation.
Type spirvType = getTypeConverter()->convertType(allocType);
@@ -262,8 +304,8 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
- if (!isAllocationSupported(deallocType))
- return operation.emitError("unhandled deallocation type");
+  // Only deallocations that isAllocationSupported accepts (static shape,
+  // Workgroup memory space, int/float element type) are erased. For any other
+  // memref this pattern reports a match failure and does not apply.
+  if (!isAllocationSupported(operation, deallocType))
+    return rewriter.notifyMatchFailure(operation,
+                                       "unhandled deallocation type");
rewriter.eraseOp(operation);
return success();
}
@@ -505,8 +547,9 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
namespace mlir {
+/// Adds the MemRef-to-SPIR-V conversion patterns (alloca, alloc, dealloc, and
+/// the integer/plain load and store patterns) to `patterns`; every pattern
+/// shares `typeConverter` and the context owned by `patterns`.
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.add<AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
- IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
- typeConverter, patterns.getContext());
+  MLIRContext *context = patterns.getContext();
+  patterns.add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern,
+               IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
+               StoreOpPattern>(typeConverter, context);
}
} // namespace mlir
diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index adaed0ea3a910..598e03fed55f9 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -100,10 +100,12 @@ module attributes {
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
}
{
- func.func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : index) {
- // expected-error @+1 {{unhandled allocation type}}
+ // memref.alloc with a dynamic dimension is rejected even in memory space 3
+ // (Workgroup), so the op must be left unconverted. The load gives the
+ // allocated buffer a use.
+ // CHECK-LABEL: func @alloc_dynamic_size
+ func.func @alloc_dynamic_size(%arg0 : index) -> f32 {
+ // CHECK: memref.alloc
%0 = memref.alloc(%arg0) : memref<4x?xf32, 3>
- return
+ %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, 3>
+ return %1: f32
}
}
@@ -114,10 +116,12 @@ module attributes {
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
}
{
- func.func @alloc_dealloc_mem() {
- // expected-error @+1 {{unhandled allocation type}}
+ // memref.alloc is only lowered in the Workgroup memory space. This memref
+ // has no memory space, so the op must be left unconverted.
+ // CHECK-LABEL: func @alloc_unsupported_memory_space
+ func.func @alloc_unsupported_memory_space(%arg0: index) -> f32 {
+ // CHECK: memref.alloc
%0 = memref.alloc() : memref<4x5xf32>
- return
+ %1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32>
+ return %1: f32
}
}
@@ -129,8 +133,9 @@ module attributes {
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
}
{
- func.func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : memref<4x?xf32, 3>) {
- // expected-error @+1 {{unhandled deallocation type}}
+ // A dealloc of a dynamically shaped memref in memory space 3 (Workgroup) is
+ // not erased; the memref.dealloc must remain after conversion.
+ // CHECK-LABEL: func @dealloc_dynamic_size
+ func.func @dealloc_dynamic_size(%arg0 : memref<4x?xf32, 3>) {
+ // CHECK: memref.dealloc
memref.dealloc %arg0 : memref<4x?xf32, 3>
return
}
@@ -143,8 +148,9 @@ module attributes {
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
}
{
- func.func @alloc_dealloc_mem(%arg0 : memref<4x5xf32>) {
- // expected-error @+1 {{unhandled deallocation type}}
+ // memref.dealloc is only erased in the Workgroup memory space. This memref
+ // has no memory space, so the op must remain after conversion.
+ // CHECK-LABEL: func @dealloc_unsupported_memory_space
+ func.func @dealloc_unsupported_memory_space(%arg0 : memref<4x5xf32>) {
+ // CHECK: memref.dealloc
memref.dealloc %arg0 : memref<4x5xf32>
return
}
diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
new file mode 100644
index 0000000000000..2aabeed2fd8c0
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt -split-input-file -convert-memref-to-spirv -canonicalize -verify-diagnostics %s -o - | FileCheck %s
+
+// A statically shaped alloca in memory space 6 (Function) becomes a
+// Function-storage spv.Variable. The load and store then address that
+// variable through spv.AccessChain.
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
+ func.func @alloc_function_variable(%arg0 : index, %arg1 : index) {
+ %0 = memref.alloca() : memref<4x5xf32, 6>
+ %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, 6>
+ memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, 6>
+ return
+ }
+}
+
+// CHECK-LABEL: func @alloc_function_variable
+// CHECK: %[[VAR:.+]] = spv.Variable : !spv.ptr<!spv.struct<(!spv.array<20 x f32, stride=4>)>, Function>
+// CHECK: %[[LOADPTR:.+]] = spv.AccessChain %[[VAR]]
+// CHECK: %[[VAL:.+]] = spv.Load "Function" %[[LOADPTR]] : f32
+// CHECK: %[[STOREPTR:.+]] = spv.AccessChain %[[VAR]]
+// CHECK: spv.Store "Function" %[[STOREPTR]], %[[VAL]] : f32
+
+
+// -----
+
+// Each alloca is lowered to its own Function-storage spv.Variable. The two
+// variables below may be matched in either order.
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
+ func.func @two_allocs() {
+ %0 = memref.alloca() : memref<4x5xf32, 6>
+ %1 = memref.alloca() : memref<2x3xi32, 6>
+ return
+ }
+}
+
+// CHECK-LABEL: func @two_allocs
+// CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<6 x i32, stride=4>)>, Function>
+// CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<20 x f32, stride=4>)>, Function>
+
+// -----
+
+// Vector element types are accepted as well. Per the expected types below, the
+// array stride equals the vector size in bytes (8 for vector<2xi32>, 16 for
+// vector<4xf32>).
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
+ func.func @two_allocs_vector() {
+ %0 = memref.alloca() : memref<4xvector<4xf32>, 6>
+ %1 = memref.alloca() : memref<2xvector<2xi32>, 6>
+ return
+ }
+}
+
+// CHECK-LABEL: func @two_allocs_vector
+// CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<2 x vector<2xi32>, stride=8>)>, Function>
+// CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16>)>, Function>
+
+
+// -----
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
+ // An alloca with a dynamic dimension is rejected even in memory space 6
+ // (Function), so the memref.alloca must be left unconverted.
+ // CHECK-LABEL: func @alloc_dynamic_size
+ func.func @alloc_dynamic_size(%arg0 : index) -> f32 {
+ // CHECK: memref.alloca
+ %0 = memref.alloca(%arg0) : memref<4x?xf32, 6>
+ %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, 6>
+ return %1: f32
+ }
+}
+
+// -----
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
+ // memref.alloca is only lowered in the Function memory space (6 in these
+ // tests). This memref has no memory space, so the op must be left unconverted.
+ // CHECK-LABEL: func @alloc_unsupported_memory_space
+ func.func @alloc_unsupported_memory_space(%arg0: index) -> f32 {
+ // CHECK: memref.alloca
+ %0 = memref.alloca() : memref<4x5xf32>
+ %1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32>
+ return %1: f32
+ }
+}
More information about the Mlir-commits mailing list