[Mlir-commits] [mlir] 370a7ea - [mlir][spirv] Handle zero-sized memrefs

Jakub Kuderski llvmlistbot at llvm.org
Fri Aug 4 22:10:57 PDT 2023


Author: Jakub Kuderski
Date: 2023-08-05T01:10:15-04:00
New Revision: 370a7eae352abb9646e8f86aae02930c38135e23

URL: https://github.com/llvm/llvm-project/commit/370a7eae352abb9646e8f86aae02930c38135e23
DIFF: https://github.com/llvm/llvm-project/commit/370a7eae352abb9646e8f86aae02930c38135e23.diff

LOG: [mlir][spirv] Handle zero-sized memrefs

Make sure to check type conversion results. Add missing tests.

Fix some typos in the surrounding code.

Fixes: https://github.com/llvm/llvm-project/issues/64409

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D157166

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
    mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
    mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 16bedc6b9858e1..1d85e64bdfbfc3 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -304,6 +304,9 @@ AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
 
   // Get the SPIR-V type for the allocation.
   Type spirvType = getTypeConverter()->convertType(allocType);
+  if (!spirvType)
+    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");
+
   rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                  spirv::StorageClass::Function,
                                                  /*initializer=*/nullptr);
@@ -323,6 +326,8 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
 
   // Get the SPIR-V type for the allocation.
   Type spirvType = getTypeConverter()->convertType(allocType);
+  if (!spirvType)
+    return rewriter.notifyMatchFailure(operation, "type conversion failed");
 
   // Insert spirv.GlobalVariable for this allocation.
   Operation *parent =
@@ -467,7 +472,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
   int dstBits = dstType.getIntOrFloatBitWidth();
   assert(dstBits % srcBits == 0);
 
-  // If the rewrited load op has the same bit width, use the loading value
+  // If the rewritten load op has the same bit width, use the loaded value
   // directly.
   if (srcBits == dstBits) {
     Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain);
@@ -701,12 +706,16 @@ LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
 
   Value result = adaptor.getSource();
   Type resultPtrType = typeConverter.convertType(resultType);
+  if (!resultPtrType)
+    return rewriter.notifyMatchFailure(addrCastOp,
+                                       "failed to convert memref type");
+
   Type genericPtrType = resultPtrType;
   // SPIR-V doesn't have a general address space cast operation. Instead, it has
   // conversions to and from generic pointers. To implement the general case,
   // we use specific-to-generic conversions when the source class is not
   // generic. Then when the result storage class is not generic, we convert the
-  // generic pointer (either the input on ar intermediate result) to theat
+  // generic pointer (either the input or an intermediate result) to that
   // class. This also means that we'll need the intermediate generic pointer
   // type if neither the source or destination have it.
   if (sourceSc != spirv::StorageClass::Generic &&

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 51ae2c10087e50..7fea90a7dc8f42 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -563,6 +563,12 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
+  if (*memrefSize == 0) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: zero-element memrefs are not supported\n");
+    return nullptr;
+  }
+
   int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
   int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index 1b9b715b6b831b..7037051573bd61 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -51,7 +51,6 @@ module attributes {
 //       CHECK:   %{{.+}} = spirv.AtomicAnd "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spirv.ptr<i32, Workgroup>
 //       CHECK:   %{{.+}} = spirv.AtomicOr "Workgroup" "AcquireRelease" %[[PTR]], %{{.+}} : !spirv.ptr<i32, Workgroup>
 
-
 // -----
 
 module attributes {
@@ -92,7 +91,6 @@ module attributes {
 //  CHECK-SAME:   !spirv.ptr<!spirv.struct<(!spirv.array<4 x vector<4xf32>>)>, Workgroup>
 // CHECK-LABEL: func @two_allocs_vector()
 
-
 // -----
 
 module attributes {
@@ -179,3 +177,19 @@ module attributes {
 //       CHECK:   %[[STOREPTR:.+]] = spirv.AccessChain %[[PTR]]
 //       CHECK:   spirv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : f32
 //   CHECK-NOT:   memref.dealloc
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+  }
+{
+  func.func @zero_size() {
+    %0 = memref.alloc() : memref<0xf32, #spirv.storage_class<Workgroup>>
+    return
+  }
+}
+
+// Zero-sized allocations are not handled yet. Just make sure we do not crash.
+// CHECK-LABEL: func @zero_size()

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
index fb270235adf54d..58847d114df007 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
@@ -16,7 +16,6 @@ module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader
 //       CHECK:   %[[STOREPTR:.+]] = spirv.AccessChain %[[VAR]]
 //       CHECK:   spirv.Store "Function" %[[STOREPTR]], %[[VAL]] : f32
 
-
 // -----
 
 module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader], []>, #spirv.resource_limits<>>} {
@@ -69,3 +68,15 @@ module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader
     return %1: f32
   }
 }
+
+// -----
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader], []>, #spirv.resource_limits<>>} {
+  func.func @zero_size() {
+    %0 = memref.alloca() : memref<0xf32, #spirv.storage_class<Function>>
+    return
+  }
+}
+
+// Zero-sized allocations are not handled yet. Just make sure we do not crash.
+// CHECK-LABEL: func @zero_size

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 2b3678542e8db4..284aa698947213 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -608,4 +608,14 @@ func.func @cast_to_static(%arg: memref<4x?xf32, #spirv.storage_class<CrossWorkgr
   return %ret : memref<4x4xf32, #spirv.storage_class<CrossWorkgroup>>
 }
 
+// TODO: Casting to a zero-element memref is not supported yet; the cast is left unconverted.
+// CHECK-LABEL: func.func @cast_to_static_zero_elems
+//  CHECK-SAME:  (%[[MEM:.*]]: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>)
+func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> memref<0xf32, #spirv.storage_class<CrossWorkgroup>> {
+//       CHECK:  %[[MEM1:.*]] =  memref.cast %[[MEM]] : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<0xf32, #spirv.storage_class<CrossWorkgroup>>
+//       CHECK:  return %[[MEM1]]
+  %ret = memref.cast %arg : memref<?xf32, #spirv.storage_class<CrossWorkgroup>> to memref<0xf32, #spirv.storage_class<CrossWorkgroup>>
+  return %ret : memref<0xf32, #spirv.storage_class<CrossWorkgroup>>
+}
+
 }


        


More information about the Mlir-commits mailing list