[Mlir-commits] [mlir] 370a7ea - [mlir][spirv] Handle zero-sized memrefs
Jakub Kuderski
llvmlistbot at llvm.org
Fri Aug 4 22:10:57 PDT 2023
Author: Jakub Kuderski
Date: 2023-08-05T01:10:15-04:00
New Revision: 370a7eae352abb9646e8f86aae02930c38135e23
URL: https://github.com/llvm/llvm-project/commit/370a7eae352abb9646e8f86aae02930c38135e23
DIFF: https://github.com/llvm/llvm-project/commit/370a7eae352abb9646e8f86aae02930c38135e23.diff
LOG: [mlir][spirv] Handle zero-sized memrefs
Make sure to check type conversion results. Add missing tests.
Fix some typos in the surrounding code.
Fixes: https://github.com/llvm/llvm-project/issues/64409
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D157166
Added:
Modified:
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 16bedc6b9858e1..1d85e64bdfbfc3 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -304,6 +304,9 @@ AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
// Get the SPIR-V type for the allocation.
Type spirvType = getTypeConverter()->convertType(allocType);
+ if (!spirvType)
+ return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");
+
rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
spirv::StorageClass::Function,
/*initializer=*/nullptr);
@@ -323,6 +326,8 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
// Get the SPIR-V type for the allocation.
Type spirvType = getTypeConverter()->convertType(allocType);
+ if (!spirvType)
+ return rewriter.notifyMatchFailure(operation, "type conversion failed");
// Insert spirv.GlobalVariable for this allocation.
Operation *parent =
@@ -467,7 +472,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
int dstBits = dstType.getIntOrFloatBitWidth();
assert(dstBits % srcBits == 0);
- // If the rewrited load op has the same bit width, use the loading value
+ // If the rewritten load op has the same bit width, use the loading value
// directly.
if (srcBits == dstBits) {
Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain);
@@ -701,12 +706,16 @@ LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
Value result = adaptor.getSource();
Type resultPtrType = typeConverter.convertType(resultType);
+ if (!resultPtrType)
+ return rewriter.notifyMatchFailure(addrCastOp,
+ "failed to convert memref type");
+
Type genericPtrType = resultPtrType;
// SPIR-V doesn't have a general address space cast operation. Instead, it has
// conversions to and from generic pointers. To implement the general case,
// we use specific-to-generic conversions when the source class is not
// generic. Then when the result storage class is not generic, we convert the
- // generic pointer (either the input on ar intermediate result) to theat
+ // generic pointer (either the input or an intermediate result) to that
// class. This also means that we'll need the intermediate generic pointer
// type if neither the source or destination have it.
if (sourceSc != spirv::StorageClass::Generic &&
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 51ae2c10087e50..7fea90a7dc8f42 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -563,6 +563,12 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
+ if (*memrefSize == 0) {
+ LLVM_DEBUG(llvm::dbgs()
+ << type << " illegal: zero-element memrefs are not supported\n");
+ return nullptr;
+ }
+
int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index 1b9b715b6b831b..7037051573bd61 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -51,7 +51,6 @@ module attributes {
// CHECK: %{{.+}} = spirv.AtomicAnd "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spirv.ptr<i32, Workgroup>
// CHECK: %{{.+}} = spirv.AtomicOr "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spirv.ptr<i32, Workgroup>
-
// -----
module attributes {
@@ -92,7 +91,6 @@ module attributes {
// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<4 x vector<4xf32>>)>, Workgroup>
// CHECK-LABEL: func @two_allocs_vector()
-
// -----
module attributes {
@@ -179,3 +177,19 @@ module attributes {
// CHECK: %[[STOREPTR:.+]] = spirv.AccessChain %[[PTR]]
// CHECK: spirv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : f32
// CHECK-NOT: memref.dealloc
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+ }
+{
+ func.func @zero_size() {
+ %0 = memref.alloc() : memref<0xf32, #spirv.storage_class<Workgroup>>
+ return
+ }
+}
+
+// Zero-sized allocations are not handled yet. Just make sure we do not crash.
+// CHECK-LABEL: func @zero_size()
diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
index fb270235adf54d..58847d114df007 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
@@ -16,7 +16,6 @@ module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader
// CHECK: %[[STOREPTR:.+]] = spirv.AccessChain %[[VAR]]
// CHECK: spirv.Store "Function" %[[STOREPTR]], %[[VAL]] : f32
-
// -----
module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader], []>, #spirv.resource_limits<>>} {
@@ -69,3 +68,15 @@ module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader
return %1: f32
}
}
+
+// -----
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader], []>, #spirv.resource_limits<>>} {
+ func.func @zero_size() {
+ %0 = memref.alloca() : memref<0xf32, #spirv.storage_class<Function>>
+ return
+ }
+}
+
+// Zero-sized allocations are not handled yet. Just make sure we do not crash.
+// CHECK-LABEL: func @zero_size
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 2b3678542e8db4..284aa698947213 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -608,4 +608,14 @@ func.func @cast_to_static(%arg: memref<4x?xf32, #spirv.storage_class<CrossWorkgr
return %ret : memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>
}
+// TODO: Not supported yet
+// CHECK-LABEL: func.func @cast_to_static_zero_elems
+// CHECK-SAME: (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>)
+func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<0xf32, #spirv.storage_class<CrossWorkgroup>> {
+// CHECK: %[[MEM1:.*]] = memref.cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<0xf32, #spirv.storage_class<CrossWorkgroup>>
+// CHECK: return %[[MEM1]]
+ %ret = memref.cast %arg : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<0xf32, #spirv.storage_class<CrossWorkgroup>>
+ return %ret : memref<0xf32, #spirv.storage_class<CrossWorkgroup>>
+}
+
}
More information about the Mlir-commits
mailing list