[Mlir-commits] [mlir] [mlir][spirv] Handle all zero-element memref types (PR #73351)
Jakub Kuderski
llvmlistbot at llvm.org
Fri Nov 24 09:02:02 PST 2023
https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/73351
Bail out of type conversion for zero-element bool (i1) and sub-byte memref types instead of crashing.
From 37d23b4868486b0442ab088d3729b4330234e173 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Fri, 24 Nov 2023 12:00:05 -0500
Subject: [PATCH] [mlir][spirv] Handle all zero-element memref types
Bail out of type conversion instead of crashing.
---
.../lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 12 ++++++++++++
mlir/test/Conversion/MemRefToSPIRV/alloc.mlir | 2 ++
2 files changed, 14 insertions(+)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c75d217663a9e09..2b79c8022b8e85b 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -469,6 +469,12 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
return wrapInStructAndGetPointer(arrayType, storageClass);
}
+ if (type.getNumElements() == 0) {
+ LLVM_DEBUG(llvm::dbgs()
+ << type << " illegal: zero-element memrefs are not supported\n");
+ return nullptr;
+ }
+
int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
@@ -500,6 +506,12 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
return wrapInStructAndGetPointer(arrayType, storageClass);
}
+ if (type.getNumElements() == 0) {
+ LLVM_DEBUG(llvm::dbgs()
+ << type << " illegal: zero-element memrefs are not supported\n");
+ return nullptr;
+ }
+
int64_t memrefSize =
llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index 7037051573bd610..2a5f81544f20a86 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -187,6 +187,8 @@ module attributes {
{
func.func @zero_size() {
%0 = memref.alloc() : memref<0xf32, #spirv.storage_class<Workgroup>>
+ %1 = memref.alloc() : memref<0xi1, #spirv.storage_class<Workgroup>>
+ %2 = memref.alloc() : memref<0xi4, #spirv.storage_class<Workgroup>>
return
}
}
More information about the Mlir-commits mailing list