[Mlir-commits] [mlir] d6de6dd - [mlir][spirv] Handle dynamic/static cases differently for kernel capability

Nirvedh Meshram llvmlistbot at llvm.org
Thu Sep 29 19:35:30 PDT 2022


Author: Nirvedh Meshram
Date: 2022-09-29T19:34:42-07:00
New Revision: d6de6dde824596f5daa223f790d8e2206f2e5dce

URL: https://github.com/llvm/llvm-project/commit/d6de6dde824596f5daa223f790d8e2206f2e5dce
DIFF: https://github.com/llvm/llvm-project/commit/d6de6dde824596f5daa223f790d8e2206f2e5dce.diff

LOG: [mlir][spirv] Handle dynamic/static cases differently for kernel capability

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D134908

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
    mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
    mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 7f31488ad2b5..cfb0e5532687 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -335,8 +335,10 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                          .getPointeeType();
   Type dstType;
   if (typeConverter.allows(spirv::Capability::Kernel)) {
-    // For OpenCL Kernel, pointer will be directly pointing to the element.
-    dstType = pointeeType;
+    if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
+      dstType = arrayType.getElementType();
+    else
+      dstType = pointeeType;
   } else {
     // For Vulkan we need to extract element from wrapping struct and array.
     Type structElemType =
@@ -464,8 +466,10 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                          .getPointeeType();
   Type dstType;
   if (typeConverter.allows(spirv::Capability::Kernel)) {
-    // For OpenCL Kernel, pointer will be directly pointing to the element.
-    dstType = pointeeType;
+    if (auto arrayType = pointeeType.dyn_cast<spirv::ArrayType>())
+      dstType = arrayType.getElementType();
+    else
+      dstType = pointeeType;
   } else {
     // For Vulkan we need to extract element from wrapping struct and array.
     Type structElemType =

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 3e2d0e9db575..083b997aeaef 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -338,15 +338,16 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
-  // For OpenCL Kernel we can just emit a pointer pointing to the element.
-  if (targetEnv.allows(spirv::Capability::Kernel))
-    return spirv::PointerType::get(arrayElemType, storageClass);
 
-  // For Vulkan we need extra wrapping struct and array to satisfy interface
-  // needs.
   if (!type.hasStaticShape()) {
+    // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
+    // to the element.
+    if (targetEnv.allows(spirv::Capability::Kernel))
+      return spirv::PointerType::get(arrayElemType, storageClass);
     int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
+    // For Vulkan we need extra wrapping struct and array to satisfy interface
+    // needs.
     return wrapInStructAndGetPointer(arrayType, storageClass);
   }
 
@@ -354,7 +355,8 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
   auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
   int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
-
+  if (targetEnv.allows(spirv::Capability::Kernel))
+    return spirv::PointerType::get(arrayType, storageClass);
   return wrapInStructAndGetPointer(arrayType, storageClass);
 }
 
@@ -403,15 +405,16 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
-  // For OpenCL Kernel we can just emit a pointer pointing to the element.
-  if (targetEnv.allows(spirv::Capability::Kernel))
-    return spirv::PointerType::get(arrayElemType, storageClass);
 
-  // For Vulkan we need extra wrapping struct and array to satisfy interface
-  // needs.
   if (!type.hasStaticShape()) {
+    // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
+    // to the element.
+    if (targetEnv.allows(spirv::Capability::Kernel))
+      return spirv::PointerType::get(arrayElemType, storageClass);
     int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
+    // For Vulkan we need extra wrapping struct and array to satisfy interface
+    // needs.
     return wrapInStructAndGetPointer(arrayType, storageClass);
   }
 
@@ -425,7 +428,8 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
   int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
-
+  if (targetEnv.allows(spirv::Capability::Kernel))
+    return spirv::PointerType::get(arrayType, storageClass);
   return wrapInStructAndGetPointer(arrayType, storageClass);
 }
 
@@ -776,15 +780,20 @@ Value mlir::spirv::getOpenCLElementPtr(SPIRVTypeConverter &typeConverter,
   auto indexType = typeConverter.getIndexType();
 
   SmallVector<Value, 2> linearizedIndices;
-  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
-
   Value linearIndex;
   if (baseType.getRank() == 0) {
-    linearIndex = zero;
+    linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
   } else {
     linearIndex =
         linearizeIndex(indices, strides, offset, indexType, loc, builder);
   }
+  Type pointeeType =
+      basePtr.getType().cast<spirv::PointerType>().getPointeeType();
+  if (pointeeType.isa<spirv::ArrayType>()) {
+    linearizedIndices.push_back(linearIndex);
+    return builder.create<spirv::AccessChainOp>(loc, basePtr,
+                                                linearizedIndices);
+  }
   return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
                                                  linearizedIndices);
 }

diff  --git a/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
index 35cd5ea26b26..e442ac6c8803 100644
--- a/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
@@ -9,7 +9,7 @@ module attributes {
     //       CHECK:   spirv.func
     //  CHECK-SAME:     {{%.*}}: f32
     //   CHECK-NOT:     spirv.interface_var_abi
-    //  CHECK-SAME:     {{%.*}}: !spirv.ptr<f32, CrossWorkgroup>
+    //  CHECK-SAME:     {{%.*}}: !spirv.ptr<!spirv.array<12 x f32>, CrossWorkgroup>
     //   CHECK-NOT:     spirv.interface_var_abi
     //  CHECK-SAME:     spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
     gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spirv.storage_class<CrossWorkgroup>>) kernel

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index d1979e21a436..0e39ee15ea98 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -155,3 +155,27 @@ module attributes {
     return
   }
 }
+
+// -----
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Kernel], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+  }
+{
+  func.func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
+    %0 = memref.alloc() : memref<4x5xf32, #spirv.storage_class<Workgroup>>
+    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, #spirv.storage_class<Workgroup>>
+    memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, #spirv.storage_class<Workgroup>>
+    memref.dealloc %0 : memref<4x5xf32, #spirv.storage_class<Workgroup>>
+    return
+  }
+}
+//     CHECK: spirv.GlobalVariable @[[VAR:.+]] : !spirv.ptr<!spirv.array<20 x f32>, Workgroup>
+//     CHECK: func @alloc_dealloc_workgroup_mem
+// CHECK-NOT:   memref.alloc
+//     CHECK:   %[[PTR:.+]] = spirv.mlir.addressof @[[VAR]]
+//     CHECK:   %[[LOADPTR:.+]] = spirv.AccessChain %[[PTR]]
+//     CHECK:   %[[VAL:.+]] = spirv.Load "Workgroup" %[[LOADPTR]] : f32
+//     CHECK:   %[[STOREPTR:.+]] = spirv.AccessChain %[[PTR]]
+//     CHECK:   spirv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : f32
+// CHECK-NOT:   memref.dealloc

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index e3ddb7398055..67556f00c2f2 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -121,16 +121,16 @@ module attributes {
 
 // CHECK-LABEL: @load_store_zero_rank_float
 func.func @load_store_zero_rank_float(%arg0: memref<f32, #spirv.storage_class<CrossWorkgroup>>, %arg1: memref<f32, #spirv.storage_class<CrossWorkgroup>>) {
-  //      CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
-  //      CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<f32, CrossWorkgroup>
+  //      CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x f32>, CrossWorkgroup>
+  //      CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<f32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x f32>, CrossWorkgroup>
   //      CHECK: [[ZERO1:%.*]] = spirv.Constant 0 : i32
-  //      CHECK: spirv.PtrAccessChain [[ARG0]][
+  //      CHECK: spirv.AccessChain [[ARG0]][
   // CHECK-SAME: [[ZERO1]]
   // CHECK-SAME: ] :
   //      CHECK: spirv.Load "CrossWorkgroup" %{{.*}} : f32
   %0 = memref.load %arg0[] : memref<f32, #spirv.storage_class<CrossWorkgroup>>
   //      CHECK: [[ZERO2:%.*]] = spirv.Constant 0 : i32
-  //      CHECK: spirv.PtrAccessChain [[ARG1]][
+  //      CHECK: spirv.AccessChain [[ARG1]][
   // CHECK-SAME: [[ZERO2]]
   // CHECK-SAME: ] :
   //      CHECK: spirv.Store "CrossWorkgroup" %{{.*}} : f32
@@ -140,16 +140,16 @@ func.func @load_store_zero_rank_float(%arg0: memref<f32, #spirv.storage_class<Cr
 
 // CHECK-LABEL: @load_store_zero_rank_int
 func.func @load_store_zero_rank_int(%arg0: memref<i32, #spirv.storage_class<CrossWorkgroup>>, %arg1: memref<i32, #spirv.storage_class<CrossWorkgroup>>) {
-  //      CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spirv.storage_class<CrossWorkgroup>> to  !spirv.ptr<i32, CrossWorkgroup>
-  //      CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spirv.storage_class<CrossWorkgroup>> to  !spirv.ptr<i32, CrossWorkgroup>
+  //      CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x i32>, CrossWorkgroup>
+  //      CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref<i32, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<1 x i32>, CrossWorkgroup>
   //      CHECK: [[ZERO1:%.*]] = spirv.Constant 0 : i32
-  //      CHECK: spirv.PtrAccessChain [[ARG0]][
+  //      CHECK: spirv.AccessChain [[ARG0]][
   // CHECK-SAME: [[ZERO1]]
   // CHECK-SAME: ] :
   //      CHECK: spirv.Load "CrossWorkgroup" %{{.*}} : i32
   %0 = memref.load %arg0[] : memref<i32, #spirv.storage_class<CrossWorkgroup>>
   //      CHECK: [[ZERO2:%.*]] = spirv.Constant 0 : i32
-  //      CHECK: spirv.PtrAccessChain [[ARG1]][
+  //      CHECK: spirv.AccessChain [[ARG1]][
   // CHECK-SAME: [[ZERO2]]
   // CHECK-SAME: ] :
   //      CHECK: spirv.Store "CrossWorkgroup" %{{.*}} : i32
@@ -173,14 +173,13 @@ func.func @load_store_unknown_dim(%i: index, %source: memref<?xi32, #spirv.stora
 // CHECK-LABEL: func @load_i1
 //  CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %[[IDX:.+]]: index)
 func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i : index) -> i1 {
-  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i8, CrossWorkgroup>
+  // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup>
   // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
   // CHECK: %[[ZERO_0:.+]] = spirv.Constant 0 : i32
-  // CHECK: %[[ZERO_1:.+]] = spirv.Constant 0 : i32
   // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
   // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32
-  // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_1]], %[[MUL]] : i32
-  // CHECK: %[[ADDR:.+]] = spirv.PtrAccessChain %[[SRC_CAST]][%[[ADD]]]
+  // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_0]], %[[MUL]] : i32
+  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ADD]]]
   // CHECK: %[[VAL:.+]] = spirv.Load "CrossWorkgroup" %[[ADDR]] : i8
   // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
   // CHECK: %[[BOOL:.+]] = spirv.IEqual %[[VAL]], %[[ONE_I8]] : i8
@@ -194,14 +193,13 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i
 //  CHECK-SAME: %[[IDX:.+]]: index
 func.func @store_i1(%dst: memref<4xi1, #spirv.storage_class<CrossWorkgroup>>, %i: index) {
   %true = arith.constant true
-  // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<i8, CrossWorkgroup>
+  // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spirv.storage_class<CrossWorkgroup>> to !spirv.ptr<!spirv.array<4 x i8>, CrossWorkgroup>
   // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]]
   // CHECK: %[[ZERO_0:.+]] = spirv.Constant 0 : i32
-  // CHECK: %[[ZERO_1:.+]] = spirv.Constant 0 : i32
   // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
   // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[IDX_CAST]] : i32
-  // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_1]], %[[MUL]] : i32
-  // CHECK: %[[ADDR:.+]] = spirv.PtrAccessChain %[[DST_CAST]][%[[ADD]]]
+  // CHECK: %[[ADD:.+]] = spirv.IAdd %[[ZERO_0]], %[[MUL]] : i32
+  // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[DST_CAST]][%[[ADD]]]
   // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8
   // CHECK: %[[ONE_I8:.+]] = spirv.Constant 1 : i8
   // CHECK: %[[RES:.+]] = spirv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8


        


More information about the Mlir-commits mailing list