[Mlir-commits] [mlir] 6ba6039 - [mlir][spirv] Handle all zero-element memref types (#73351)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 24 11:13:39 PST 2023


Author: Jakub Kuderski
Date: 2023-11-24T14:13:34-05:00
New Revision: 6ba60390cc4b6a8f7f0815df2a64d7b912f25e83

URL: https://github.com/llvm/llvm-project/commit/6ba60390cc4b6a8f7f0815df2a64d7b912f25e83
DIFF: https://github.com/llvm/llvm-project/commit/6ba60390cc4b6a8f7f0815df2a64d7b912f25e83.diff

LOG: [mlir][spirv] Handle all zero-element memref types (#73351)

Bail out of type conversion for zero-element bool (i1) and sub-byte (e.g., i4) memref types instead of crashing.

Fixes: https://github.com/llvm/llvm-project/issues/73289

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/MemRefToSPIRV/alloc.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c75d217663a9e09..2b79c8022b8e85b 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -469,6 +469,12 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
     return wrapInStructAndGetPointer(arrayType, storageClass);
   }
 
+  if (type.getNumElements() == 0) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: zero-element memrefs are not supported\n");
+    return nullptr;
+  }
+
   int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
   int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
   int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
@@ -500,6 +506,12 @@ static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
     return wrapInStructAndGetPointer(arrayType, storageClass);
   }
 
+  if (type.getNumElements() == 0) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: zero-element memrefs are not supported\n");
+    return nullptr;
+  }
+
   int64_t memrefSize =
       llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
   int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index 7037051573bd610..2a5f81544f20a86 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -187,6 +187,8 @@ module attributes {
 {
   func.func @zero_size() {
     %0 = memref.alloc() : memref<0xf32, #spirv.storage_class<Workgroup>>
+    %1 = memref.alloc() : memref<0xi1, #spirv.storage_class<Workgroup>>
+    %2 = memref.alloc() : memref<0xi4, #spirv.storage_class<Workgroup>>
     return
   }
 }


        


More information about the Mlir-commits mailing list