[Mlir-commits] [mlir] ba49096 - [mlir][spirv] Lower memref with dynamic dimensions to runtime arrays

Lei Zhang llvmlistbot at llvm.org
Mon Apr 20 08:49:02 PDT 2020


Author: Lei Zhang
Date: 2020-04-20T11:48:47-04:00
New Revision: ba49096817b77932ed1534ab1fb323b46944293c

URL: https://github.com/llvm/llvm-project/commit/ba49096817b77932ed1534ab1fb323b46944293c
DIFF: https://github.com/llvm/llvm-project/commit/ba49096817b77932ed1534ab1fb323b46944293c.diff

LOG: [mlir][spirv] Lower memref with dynamic dimensions to runtime arrays

MemRef types with dynamic dimensions do not have a size known at
compile time, so they should be mapped to SPIR-V runtime array types.

Differential Revision: https://reviews.llvm.org/D78197

Added: 
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

Modified: 
    mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
    mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir

Removed: 
    mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index 0f83412f1962..c6e15d5db485 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -331,10 +331,11 @@ static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
 
 static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
                                         MemRefType type) {
-  // TODO(ravishankarm) : Handle dynamic shapes.
-  if (!type.hasStaticShape()) {
+  Optional<spirv::StorageClass> storageClass =
+      SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace());
+  if (!storageClass) {
     LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: dynamic shape unimplemented\n");
+               << type << " illegal: cannot convert memory space\n");
     return llvm::None;
   }
 
@@ -345,27 +346,33 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
     return llvm::None;
   }
 
+  auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
+  if (!arrayElemType)
+    return llvm::None;
+
   Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
-  Optional<int64_t> memrefSize = getTypeNumBytes(type);
-  if (!scalarSize || !memrefSize) {
+  if (!scalarSize) {
     LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: cannot deduce element count\n");
+               << type << " illegal: cannot deduce element size\n");
     return llvm::None;
   }
 
-  auto arrayElemCount = *memrefSize / *scalarSize;
+  if (!type.hasStaticShape()) {
+    auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *scalarSize);
+    // Wrap in a struct to satisfy Vulkan interface requirements.
+    auto structType = spirv::StructType::get(arrayType, 0);
+    return spirv::PointerType::get(structType, *storageClass);
+  }
 
-  auto storageClass =
-      SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace());
-  if (!storageClass) {
+  Optional<int64_t> memrefSize = getTypeNumBytes(type);
+  if (!memrefSize) {
     LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: cannot convert memory space\n");
+               << type << " illegal: cannot deduce element count\n");
     return llvm::None;
   }
 
-  auto arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
-  if (!arrayElemType)
-    return llvm::None;
+  auto arrayElemCount = *memrefSize / *scalarSize;
+
   Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
   if (!arrayElemSize) {
     LLVM_DEBUG(llvm::dbgs()

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
similarity index 100%
rename from mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir
rename to mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
index 89d1fe2eb1e3..b98a20a56c6a 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
@@ -486,7 +486,7 @@ func @memref_offset_strides(
 
 // -----
 
-// Check that dynamic shapes are not supported.
+// Dynamic shapes
 module attributes {
   spv.target_env = #spv.target_env<
     #spv.vce<v1.0, [], []>,
@@ -494,13 +494,17 @@ module attributes {
      max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
 } {
 
+// Check that unranked shapes are not supported.
 // CHECK-LABEL: func @unranked_memref
 // CHECK-SAME: memref<*xi32>
 func @unranked_memref(%arg0: memref<*xi32>) { return }
 
 // CHECK-LABEL: func @dynamic_dim_memref
-// CHECK-SAME: memref<8x?xi32>
-func @dynamic_dim_memref(%arg0: memref<8x?xi32>) { return }
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<i32, stride=4> [0]>, StorageBuffer>
+// CHECK-SAME: !spv.ptr<!spv.struct<!spv.rtarray<f32, stride=4> [0]>, StorageBuffer>
+func @dynamic_dim_memref(%arg0: memref<8x?xi32>,
+                         %arg1: memref<?x?xf32>)
+{ return }
 
 } // end module
 


        


More information about the Mlir-commits mailing list