[Mlir-commits] [mlir] 4d6f44f - [mlir][spirv] Lower allocation/deallocations of workgroup memory.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 27 09:53:38 PDT 2020


Author: MaheshRavishankar
Date: 2020-05-27T09:53:16-07:00
New Revision: 4d6f44f5f0925f2d05431065d9f197644d07b1b5

URL: https://github.com/llvm/llvm-project/commit/4d6f44f5f0925f2d05431065d9f197644d07b1b5
DIFF: https://github.com/llvm/llvm-project/commit/4d6f44f5f0925f2d05431065d9f197644d07b1b5.diff

LOG: [mlir][spirv] Lower allocation/deallocations of workgroup memory.

Allocations of workgroup memory are lowered to spv.globalVariable ops.
Only static-size allocations with an int or float element type are
handled. The lowering does account for element types that are not
supported in the lowered spv.module (based on its
extensions/capabilities) and adjusts the number of elements so that the
byte length stays the same.
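
For illustration, an allocation such as

    %0 = alloc() : memref<4x5xf32, 3>

lowers roughly to the following (variable name and types as checked in
the added alloc.mlir test; this is a sketch, not verbatim output):

    spv.globalVariable @__workgroup_mem__0 :
        !spv.ptr<!spv.struct<!spv.array<20 x f32, stride=4>>, Workgroup>
    ...
    %0 = spv._address_of @__workgroup_mem__0 :
        !spv.ptr<!spv.struct<!spv.array<20 x f32, stride=4>>, Workgroup>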

Differential Revision: https://reviews.llvm.org/D80411

Added: 
    mlir/test/Conversion/StandardToSPIRV/alloc.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
    mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
    mlir/test/Conversion/GPUToSPIRV/load-store.mlir
    mlir/test/Conversion/GPUToSPIRV/loop.mlir
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
index ba0b7ea0714c..1fa668d7ddc0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -41,6 +41,12 @@ class SPIRVTypeConverter : public TypeConverter {
 public:
   explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr);
 
+  /// Gets the number of bytes used for a type when converted to a SPIR-V
+  /// type. Note that it doesn't account for whether the type is legal for a
+  /// SPIR-V target (described by spirv::TargetEnvAttr). Returns None on
+  /// failure.
+  static Optional<int64_t> getConvertedTypeNumBytes(Type);
+
   /// Gets the SPIR-V correspondence for the standard index type.
   static Type getIndexType(MLIRContext *context);
 

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 560bc4acf436..facdbf7d096a 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -169,22 +169,51 @@ bool isUnsignedOp() {
     return true;                                                               \
   }
 
-CHECK_UNSIGNED_OP(spirv::AtomicUMaxOp);
-CHECK_UNSIGNED_OP(spirv::AtomicUMinOp);
-CHECK_UNSIGNED_OP(spirv::BitFieldUExtractOp);
-CHECK_UNSIGNED_OP(spirv::ConvertUToFOp);
-CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMaxOp);
-CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMinOp);
-CHECK_UNSIGNED_OP(spirv::UConvertOp);
-CHECK_UNSIGNED_OP(spirv::UDivOp);
-CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp);
-CHECK_UNSIGNED_OP(spirv::UGreaterThanOp);
-CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp);
-CHECK_UNSIGNED_OP(spirv::ULessThanOp);
-CHECK_UNSIGNED_OP(spirv::UModOp);
+CHECK_UNSIGNED_OP(spirv::AtomicUMaxOp)
+CHECK_UNSIGNED_OP(spirv::AtomicUMinOp)
+CHECK_UNSIGNED_OP(spirv::BitFieldUExtractOp)
+CHECK_UNSIGNED_OP(spirv::ConvertUToFOp)
+CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMaxOp)
+CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMinOp)
+CHECK_UNSIGNED_OP(spirv::UConvertOp)
+CHECK_UNSIGNED_OP(spirv::UDivOp)
+CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp)
+CHECK_UNSIGNED_OP(spirv::UGreaterThanOp)
+CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp)
+CHECK_UNSIGNED_OP(spirv::ULessThanOp)
+CHECK_UNSIGNED_OP(spirv::UModOp)
 
 #undef CHECK_UNSIGNED_OP
 
+/// Returns true if the allocations of type `t` can be lowered to SPIR-V.
+static bool isAllocationSupported(MemRefType t) {
+  // Currently only support workgroup local memory allocations with static
+  // shape and int or float element type.
+  return t.hasStaticShape() &&
+         SPIRVTypeConverter::getMemorySpaceForStorageClass(
+             spirv::StorageClass::Workgroup) == t.getMemorySpace() &&
+         t.getElementType().isIntOrFloat();
+}
+
+/// Returns the scope to use for atomic operations used for emulating store
+/// operations of unsupported integer bitwidths, based on the memref
+/// type. Returns None on failure.
+static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
+  Optional<spirv::StorageClass> storageClass =
+      SPIRVTypeConverter::getStorageClassForMemorySpace(t.getMemorySpace());
+  if (!storageClass)
+    return {};
+  switch (*storageClass) {
+  case spirv::StorageClass::StorageBuffer:
+    return spirv::Scope::Device;
+  case spirv::StorageClass::Workgroup:
+    return spirv::Scope::Workgroup;
+  default: {
+  }
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
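
In MLIR terms, a sketch of what isAllocationSupported accepts (memory
space 3 is the one mapped to the Workgroup storage class):

    %0 = alloc() : memref<4x5xf32, 3>    // supported: static shape, workgroup memory
    %1 = alloc(%n) : memref<4x?xf32, 3>  // rejected: dynamic shape
    %2 = alloc() : memref<4x5xf32>       // rejected: not workgroup memory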
@@ -195,6 +224,67 @@ CHECK_UNSIGNED_OP(spirv::UModOp);
 
 namespace {
 
+/// Converts an allocation operation to SPIR-V. Currently only supports lowering
+/// to Workgroup memory when the size is constant. Note that this pattern needs
+/// to be applied in a pass that runs at least at spv.module scope since it will
+/// add global variables into the spv.module.
+class AllocOpPattern final : public SPIRVOpLowering<AllocOp> {
+public:
+  using SPIRVOpLowering<AllocOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(AllocOp operation, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType allocType = operation.getType();
+    if (!isAllocationSupported(allocType))
+      return operation.emitError("unhandled allocation type");
+
+    // Get the SPIR-V type for the allocation.
+    Type spirvType = typeConverter.convertType(allocType);
+
+    // Insert spv.globalVariable for this allocation.
+    Operation *parent =
+        SymbolTable::getNearestSymbolTable(operation.getParentOp());
+    if (!parent)
+      return failure();
+    Location loc = operation.getLoc();
+    spirv::GlobalVariableOp varOp;
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      Block &entryBlock = *parent->getRegion(0).begin();
+      rewriter.setInsertionPointToStart(&entryBlock);
+      auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
+      std::string varName =
+          std::string("__workgroup_mem__") +
+          std::to_string(std::distance(varOps.begin(), varOps.end()));
+      varOp = rewriter.create<spirv::GlobalVariableOp>(
+          loc, TypeAttr::get(spirvType), varName,
+          /*initializer = */ nullptr);
+    }
+
+    // Get pointer to global variable at the current scope.
+    rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
+    return success();
+  }
+};
+
+/// Removes a deallocation if it is for a supported allocation. Currently only
+/// removes deallocations if the memory space is workgroup memory.
+class DeallocOpPattern final : public SPIRVOpLowering<DeallocOp> {
+public:
+  using SPIRVOpLowering<DeallocOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(DeallocOp operation, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
+    if (!isAllocationSupported(deallocType))
+      return operation.emitError("unhandled deallocation type");
+    rewriter.eraseOp(operation);
+    return success();
+  }
+};
+
 /// Converts unary and binary standard operations to SPIR-V operations.
 template <typename StdOp, typename SPIRVOp>
 class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
@@ -823,12 +913,15 @@ IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
       shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter);
   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                    srcBits, dstBits, rewriter);
+  Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
+  if (!scope)
+    return failure();
   Value result = rewriter.create<spirv::AtomicAndOp>(
-      loc, dstType, adjustedPtr, spirv::Scope::Device,
-      spirv::MemorySemantics::AcquireRelease, clearBitsMask);
+      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
+      clearBitsMask);
   result = rewriter.create<spirv::AtomicOrOp>(
-      loc, dstType, adjustedPtr, spirv::Scope::Device,
-      spirv::MemorySemantics::AcquireRelease, storeVal);
+      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
+      storeVal);
 
   // The AtomicOrOp has no side effect. Since it is already inserted, we can
   // just remove the original StoreOp. Note that rewriter.replaceOp()
@@ -913,6 +1006,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
       UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
       UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
       UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
+      AllocOpPattern, DeallocOpPattern,
       BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
       BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
       BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern,

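To sketch the effect on the i16/i8 store emulation above: for a memref in
workgroup memory, the emitted atomics now use Workgroup scope instead of
the previously hardcoded Device scope, roughly (operands illustrative):

    %old = spv.AtomicAnd "Workgroup" "AcquireRelease" %ptr, %mask : !spv.ptr<i32, Workgroup>
    %new = spv.AtomicOr "Workgroup" "AcquireRelease" %ptr, %val : !spv.ptr<i32, Workgroup>
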
diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
index c9f2983e232b..dfc2728ef710 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -218,6 +218,10 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
   return llvm::None;
 }
 
+Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) {
+  return getTypeNumBytes(t);
+}
+
 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
 static Optional<Type>
 convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
@@ -383,8 +387,11 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
   auto arrayType =
       spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
 
-  // Wrap in a struct to satisfy Vulkan interface requirements.
-  auto structType = spirv::StructType::get(arrayType, 0);
+  // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
+  // workgroup storage class do not need the struct to be laid out explicitly.
+  auto structType = *storageClass == spirv::StorageClass::Workgroup
+                        ? spirv::StructType::get(arrayType)
+                        : spirv::StructType::get(arrayType, 0);
   return spirv::PointerType::get(structType, *storageClass);
 }
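
For example, the same memref shape now converts differently per storage
class (sketch; the trailing "[0]" is the explicit member-offset
decoration used for Vulkan interface storage classes):

    memref<4x5xf32>    -> !spv.ptr<!spv.struct<!spv.array<20 x f32, stride=4> [0]>, StorageBuffer>
    memref<4x5xf32, 3> -> !spv.ptr<!spv.struct<!spv.array<20 x f32, stride=4>>, Workgroup>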
 
@@ -574,35 +581,40 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
     SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
     ArrayRef<Value> indices, Location loc, OpBuilder &builder) {
   // Get base and offset of the MemRefType and verify they are static.
+
   int64_t offset;
   SmallVector<int64_t, 4> strides;
   if (failed(getStridesAndOffset(baseType, strides, offset)) ||
-      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
+      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
+      offset == MemRefType::getDynamicStrideOrOffset()) {
     return nullptr;
   }
 
   auto indexType = typeConverter.getIndexType(builder.getContext());
-
-  Value ptrLoc = nullptr;
-  assert(indices.size() == strides.size() &&
-         "must provide indices for all dimensions");
-  for (auto index : enumerate(indices)) {
-    Value strideVal = builder.create<spirv::ConstantOp>(
-        loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
-    Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
-    ptrLoc =
-        (ptrLoc ? builder.create<spirv::IAddOp>(loc, ptrLoc, update).getResult()
-                : update);
-  }
   SmallVector<Value, 2> linearizedIndices;
   // Add a '0' at the start to index into the struct.
   auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
   linearizedIndices.push_back(zero);
-  // If it is a zero-rank memref type, extract the element directly.
-  if (!ptrLoc) {
-    ptrLoc = zero;
+
+  if (baseType.getRank() == 0) {
+    linearizedIndices.push_back(zero);
+  } else {
+    // TODO: Instead of this logic, use affine.apply and add patterns for
+    // lowering affine.apply to standard ops. These will get lowered to SPIR-V
+    // ops by the DialectConversion framework.
+    Value ptrLoc = builder.create<spirv::ConstantOp>(
+        loc, indexType, IntegerAttr::get(indexType, offset));
+    assert(indices.size() == strides.size() &&
+           "must provide indices for all dimensions");
+    for (auto index : enumerate(indices)) {
+      Value strideVal = builder.create<spirv::ConstantOp>(
+          loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
+      Value update =
+          builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+      ptrLoc = builder.create<spirv::IAddOp>(loc, ptrLoc, update);
+    }
+    linearizedIndices.push_back(ptrLoc);
   }
-  linearizedIndices.push_back(ptrLoc);
   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
 }
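
As a worked sketch of the new linearization: a load from
memref<12x4xf32> at [%i, %j] has strides [4, 1] and offset 0, so the
generated index computation is roughly (names illustrative; compare the
updated CHECK lines in load-store.mlir below):

    %zero = spv.constant 0 : i32     // index into the wrapping struct
    %off = spv.constant 0 : i32      // static memref offset
    %s0 = spv.constant 4 : i32
    %u0 = spv.IMul %s0, %i : i32
    %p0 = spv.IAdd %off, %u0 : i32
    %s1 = spv.constant 1 : i32
    %u1 = spv.IMul %s1, %j : i32
    %p1 = spv.IAdd %p0, %u1 : i32
    %ptr = spv.AccessChain %base[%zero, %p1] : !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer>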
 

diff  --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
index 543f364c93f0..12a5d9df61a8 100644
--- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -58,13 +58,15 @@ module attributes {
       %12 = addi %arg3, %0 : index
       // CHECK: %[[INDEX2:.*]] = spv.IAdd %[[ARG4]], %[[LOCALINVOCATIONIDX]]
       %13 = addi %arg4, %3 : index
+      // CHECK: %[[ZERO:.*]] = spv.constant 0 : i32
+      // CHECK: %[[OFFSET1_0:.*]] = spv.constant 0 : i32
       // CHECK: %[[STRIDE1_1:.*]] = spv.constant 4 : i32
-      // CHECK: %[[OFFSET1_1:.*]] = spv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32
+      // CHECK: %[[UPDATE1_1:.*]] = spv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32
+      // CHECK: %[[OFFSET1_1:.*]] = spv.IAdd %[[OFFSET1_0]], %[[UPDATE1_1]] : i32
       // CHECK: %[[STRIDE1_2:.*]] = spv.constant 1 : i32
       // CHECK: %[[UPDATE1_2:.*]] = spv.IMul %[[STRIDE1_2]], %[[INDEX2]] : i32
       // CHECK: %[[OFFSET1_2:.*]] = spv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32
-      // CHECK: %[[ZERO1:.*]] = spv.constant 0 : i32
-      // CHECK: %[[PTR1:.*]] = spv.AccessChain %[[ARG0]]{{\[}}%[[ZERO1]], %[[OFFSET1_2]]{{\]}}
+      // CHECK: %[[PTR1:.*]] = spv.AccessChain %[[ARG0]]{{\[}}%[[ZERO]], %[[OFFSET1_2]]{{\]}}
       // CHECK-NEXT: %[[VAL1:.*]] = spv.Load "StorageBuffer" %[[PTR1]]
       %14 = load %arg0[%12, %13] : memref<12x4xf32>
       // CHECK: %[[PTR2:.*]] = spv.AccessChain %[[ARG1]]{{\[}}{{%.*}}, {{%.*}}{{\]}}

diff  --git a/mlir/test/Conversion/GPUToSPIRV/loop.mlir b/mlir/test/Conversion/GPUToSPIRV/loop.mlir
index 5bc44cf0ba05..7c5df798438f 100644
--- a/mlir/test/Conversion/GPUToSPIRV/loop.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/loop.mlir
@@ -28,13 +28,17 @@ module attributes {
       // CHECK:        %[[CMP:.*]] = spv.SLessThan %[[INDVAR]], %[[UB]] : i32
       // CHECK:        spv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
       // CHECK:      ^[[BODY]]:
-      // CHECK:        %[[STRIDE1:.*]] = spv.constant 1 : i32
-      // CHECK:        %[[INDEX1:.*]] = spv.IMul %[[STRIDE1]], %[[INDVAR]] : i32
       // CHECK:        %[[ZERO1:.*]] = spv.constant 0 : i32
+      // CHECK:        %[[OFFSET1:.*]] = spv.constant 0 : i32
+      // CHECK:        %[[STRIDE1:.*]] = spv.constant 1 : i32
+      // CHECK:        %[[UPDATE1:.*]] = spv.IMul %[[STRIDE1]], %[[INDVAR]] : i32
+      // CHECK:        %[[INDEX1:.*]] = spv.IAdd %[[OFFSET1]], %[[UPDATE1]] : i32
       // CHECK:        spv.AccessChain {{%.*}}{{\[}}%[[ZERO1]], %[[INDEX1]]{{\]}}
-      // CHECK:        %[[STRIDE2:.*]] = spv.constant 1 : i32
-      // CHECK:        %[[INDEX2:.*]] = spv.IMul %[[STRIDE2]], %[[INDVAR]] : i32
       // CHECK:        %[[ZERO2:.*]] = spv.constant 0 : i32
+      // CHECK:        %[[OFFSET2:.*]] = spv.constant 0 : i32
+      // CHECK:        %[[STRIDE2:.*]] = spv.constant 1 : i32
+      // CHECK:        %[[UPDATE2:.*]] = spv.IMul %[[STRIDE2]], %[[INDVAR]] : i32
+      // CHECK:        %[[INDEX2:.*]] = spv.IAdd %[[OFFSET2]], %[[UPDATE2]] : i32
       // CHECK:        spv.AccessChain {{%.*}}[%[[ZERO2]], %[[INDEX2]]]
       // CHECK:        %[[INCREMENT:.*]] = spv.IAdd %[[INDVAR]], %[[STEP]] : i32
       // CHECK:        spv.Branch ^[[HEADER]](%[[INCREMENT]] : i32)

diff  --git a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
new file mode 100644
index 000000000000..3cbeda1cafb0
--- /dev/null
+++ b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir
@@ -0,0 +1,144 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-std-to-spirv -canonicalize -verify-diagnostics %s -o - | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// std allocation/deallocation ops
+//===----------------------------------------------------------------------===//
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+  }
+{
+  func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
+    %0 = alloc() : memref<4x5xf32, 3>
+    %1 = load %0[%arg0, %arg1] : memref<4x5xf32, 3>
+    store %1, %0[%arg0, %arg1] : memref<4x5xf32, 3>
+    dealloc %0 : memref<4x5xf32, 3>
+    return
+  }
+}
+//     CHECK: spv.globalVariable @[[VAR:.+]] : !spv.ptr<!spv.struct<!spv.array<20 x f32, stride=4>>, Workgroup>
+//     CHECK: func @alloc_dealloc_workgroup_mem
+// CHECK-NOT:   alloc
+//     CHECK:   %[[PTR:.+]] = spv._address_of @[[VAR]]
+//     CHECK:   %[[LOADPTR:.+]] = spv.AccessChain %[[PTR]]
+//     CHECK:   %[[VAL:.+]] = spv.Load "Workgroup" %[[LOADPTR]] : f32
+//     CHECK:   %[[STOREPTR:.+]] = spv.AccessChain %[[PTR]]
+//     CHECK:   spv.Store "Workgroup" %[[STOREPTR]], %[[VAL]] : f32
+// CHECK-NOT:   dealloc
+//     CHECK:   spv.Return
+
+// -----
+
+// TODO: Uncomment this test when the extension handling correctly
+// converts an i16 type to i32 type and handles the load/stores
+// correctly.
+
+// module attributes {
+//   spv.target_env = #spv.target_env<
+//     #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+//     {max_compute_workgroup_invocations = 128 : i32,
+//      max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+//   }
+// {
+//   func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) {
+//     %0 = alloc() : memref<4x5xi16, 3>
+//     %1 = load %0[%arg0, %arg1] : memref<4x5xi16, 3>
+//     store %1, %0[%arg0, %arg1] : memref<4x5xi16, 3>
+//     dealloc %0 : memref<4x5xi16, 3>
+//     return
+//   }
+// }
+
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+  }
+{
+  func @two_allocs() {
+    %0 = alloc() : memref<4x5xf32, 3>
+    %1 = alloc() : memref<2x3xi32, 3>
+    return
+  }
+}
+
+//  CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
+// CHECK-SAME:   !spv.ptr<!spv.struct<!spv.array<6 x i32, stride=4>>, Workgroup>
+//  CHECK-DAG: spv.globalVariable @__workgroup_mem__{{[0-9]+}}
+// CHECK-SAME:   !spv.ptr<!spv.struct<!spv.array<20 x f32, stride=4>>, Workgroup>
+//      CHECK: spv.func @two_allocs()
+//      CHECK: spv.Return
+
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+  }
+{
+  func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : index) {
+    // expected-error @+2 {{unhandled allocation type}}
+    // expected-error @+1 {{'std.alloc' op operand #0 must be index}}
+    %0 = alloc(%arg0) : memref<4x?xf32, 3>
+    return
+  }
+}
+
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+  }
+{
+  func @alloc_dealloc_mem() {
+    // expected-error @+1 {{unhandled allocation type}}
+    %0 = alloc() : memref<4x5xf32>
+    return
+  }
+}
+
+
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+  }
+{
+  func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : memref<4x?xf32, 3>) {
+    // expected-error @+2 {{unhandled deallocation type}}
+    // expected-error @+1 {{'std.dealloc' op operand #0 must be memref of any type values}}
+    dealloc %arg0 : memref<4x?xf32, 3>
+    return
+  }
+}
+
+// -----
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+  }
+{
+  func @alloc_dealloc_mem(%arg0 : memref<4x5xf32>) {
+    // expected-error @+2 {{unhandled deallocation type}}
+    // expected-error @+1 {{op operand #0 must be memref of any type values}}
+    dealloc %arg0 : memref<4x5xf32>
+    return
+  }
+}

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index bf54dbaadb18..3fe24d05dd2e 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -747,9 +747,11 @@ func @load_i8(%arg0: memref<i8>) {
 // CHECK-LABEL: @load_i16
 //       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32)
 func @load_i16(%arg0: memref<10xi16>, %index : index) {
-  //     CHECK: %[[ONE:.+]] = spv.constant 1 : i32
-  //     CHECK: %[[FLAT_IDX:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32
   //     CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  //     CHECK: %[[OFFSET:.+]] = spv.constant 0 : i32
+  //     CHECK: %[[ONE:.+]] = spv.constant 1 : i32
+  //     CHECK: %[[UPDATE:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32
+  //     CHECK: %[[FLAT_IDX:.+]] = spv.IAdd %[[OFFSET]], %[[UPDATE]] : i32
   //     CHECK: %[[TWO1:.+]] = spv.constant 2 : i32
   //     CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[FLAT_IDX]], %[[TWO1]] : i32
   //     CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
@@ -811,9 +813,11 @@ func @store_i8(%arg0: memref<i8>, %value: i8) {
 // CHECK-LABEL: @store_i16
 //       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
 func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
-  //     CHECK: %[[ONE:.+]] = spv.constant 1 : i32
-  //     CHECK: %[[FLAT_IDX:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32
   //     CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  //     CHECK: %[[OFFSET:.+]] = spv.constant 0 : i32
+  //     CHECK: %[[ONE:.+]] = spv.constant 1 : i32
+  //     CHECK: %[[UPDATE:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32
+  //     CHECK: %[[FLAT_IDX:.+]] = spv.IAdd %[[OFFSET]], %[[UPDATE]] : i32
   //     CHECK: %[[TWO:.+]] = spv.constant 2 : i32
   //     CHECK: %[[SIXTEEN:.+]] = spv.constant 16 : i32
   //     CHECK: %[[IDX:.+]] = spv.SMod %[[FLAT_IDX]], %[[TWO]] : i32


More information about the Mlir-commits mailing list