[Mlir-commits] [mlir] 8854b73 - [mlir][spirv] Convert memref.alloca to spv.Variable

Lei Zhang llvmlistbot at llvm.org
Thu Apr 28 05:17:40 PDT 2022


Author: Lei Zhang
Date: 2022-04-28T08:13:40-04:00
New Revision: 8854b736065c228270000df552bdd9dc7b152453

URL: https://github.com/llvm/llvm-project/commit/8854b736065c228270000df552bdd9dc7b152453
DIFF: https://github.com/llvm/llvm-project/commit/8854b736065c228270000df552bdd9dc7b152453.diff

LOG: [mlir][spirv] Convert memref.alloca to spv.Variable

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D124542

Added: 
    mlir/test/Conversion/MemRefToSPIRV/alloca.mlir

Modified: 
    mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
    mlir/test/Conversion/MemRefToSPIRV/alloc.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 7fb0f025e3826..a947d182865ad 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "llvm/Support/Debug.h"
@@ -85,15 +86,27 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                                                    offset);
 }
 
-/// Returns true if the allocations of type `t` can be lowered to SPIR-V.
-static bool isAllocationSupported(MemRefType t) {
-  // Currently only support workgroup local memory allocations with static
-  // shape and int or float or vector of int or float element type.
-  if (!(t.hasStaticShape() &&
-        SPIRVTypeConverter::getMemorySpaceForStorageClass(
-            spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt()))
+/// Returns true if the allocations of memref `type` generated from `allocOp`
+/// can be lowered to SPIR-V.
+static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
+  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
+    if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
+            spirv::StorageClass::Workgroup) != type.getMemorySpaceAsInt())
+      return false;
+  } else if (isa<memref::AllocaOp>(allocOp)) {
+    if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
+            spirv::StorageClass::Function) != type.getMemorySpaceAsInt())
+      return false;
+  } else {
     return false;
-  Type elementType = t.getElementType();
+  }
+
+  // Currently only support static shape and int or float or vector of int or
+  // float element type.
+  if (!type.hasStaticShape())
+    return false;
+
+  Type elementType = type.getElementType();
   if (auto vecType = elementType.dyn_cast<VectorType>())
     elementType = vecType.getElementType();
   return elementType.isIntOrFloat();
@@ -102,10 +115,10 @@ static bool isAllocationSupported(MemRefType t) {
 /// Returns the scope to use for atomic operations use for emulating store
 /// operations of unsupported integer bitwidths, based on the memref
 /// type. Returns None on failure.
-static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
+static Optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
   Optional<spirv::StorageClass> storageClass =
       SPIRVTypeConverter::getStorageClassForMemorySpace(
-          t.getMemorySpaceAsInt());
+          type.getMemorySpaceAsInt());
   if (!storageClass)
     return {};
   switch (*storageClass) {
@@ -149,6 +162,16 @@ static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
 
 namespace {
 
+/// Rewrites `memref.alloca` into a Function-storage-class `spv.Variable`.
+class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+
 /// Converts an allocation operation to SPIR-V. Currently only supports lowering
 /// to Workgroup memory when the size is constant.  Note that this pattern needs
 /// to be applied in a pass that runs at least at spv.module scope since it wil
@@ -215,6 +238,25 @@ class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
 
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// AllocaOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter) const {
+  if (!isAllocationSupported(allocaOp, allocaOp.getType()))
+    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
+  // convertType() yields a null Type on failure; never build a null-typed op.
+  Type spirvType = getTypeConverter()->convertType(allocaOp.getType());
+  if (!spirvType)
+    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");
+  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
+                                                 spirv::StorageClass::Function,
+                                                 /*initializer=*/nullptr);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // AllocOp
 //===----------------------------------------------------------------------===//
@@ -223,8 +265,8 @@ LogicalResult
 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
   MemRefType allocType = operation.getType();
-  if (!isAllocationSupported(allocType))
-    return operation.emitError("unhandled allocation type");
+  if (!isAllocationSupported(operation, allocType))
+    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
 
   // Get the SPIR-V type for the allocation.
   Type spirvType = getTypeConverter()->convertType(allocType);
@@ -262,8 +304,8 @@ DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                   OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
   MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
-  if (!isAllocationSupported(deallocType))
-    return operation.emitError("unhandled deallocation type");
+  if (!isAllocationSupported(operation, deallocType))
+    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
   rewriter.eraseOp(operation);
   return success();
 }
@@ -505,8 +547,9 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
 namespace mlir {
 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns) {
-  patterns.add<AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
-               IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
-      typeConverter, patterns.getContext());
+  MLIRContext *context = patterns.getContext();
+  patterns.add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern,
+               IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
+               StoreOpPattern>(typeConverter, context);
 }
 } // namespace mlir

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
index adaed0ea3a910..598e03fed55f9 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -100,10 +100,12 @@ module attributes {
     #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
   }
 {
-  func.func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : index) {
-    // expected-error @+1 {{unhandled allocation type}}
+  // CHECK-LABEL: func @alloc_dynamic_size
+  func.func @alloc_dynamic_size(%arg0 : index) -> f32 {
+    // CHECK: memref.alloc
     %0 = memref.alloc(%arg0) : memref<4x?xf32, 3>
-    return
+    %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, 3>
+    return %1: f32
   }
 }
 
@@ -114,10 +116,12 @@ module attributes {
     #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
   }
 {
-  func.func @alloc_dealloc_mem() {
-    // expected-error @+1 {{unhandled allocation type}}
+  // CHECK-LABEL: func @alloc_unsupported_memory_space
+  func.func @alloc_unsupported_memory_space(%arg0: index) -> f32 {
+    // CHECK: memref.alloc
     %0 = memref.alloc() : memref<4x5xf32>
-    return
+    %1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32>
+    return %1: f32
   }
 }
 
@@ -129,8 +133,9 @@ module attributes {
     #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
   }
 {
-  func.func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : memref<4x?xf32, 3>) {
-    // expected-error @+1 {{unhandled deallocation type}}
+  // CHECK-LABEL: func @dealloc_dynamic_size
+  func.func @dealloc_dynamic_size(%arg0 : memref<4x?xf32, 3>) {
+    // CHECK: memref.dealloc
     memref.dealloc %arg0 : memref<4x?xf32, 3>
     return
   }
@@ -143,8 +148,9 @@ module attributes {
     #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
   }
 {
-  func.func @alloc_dealloc_mem(%arg0 : memref<4x5xf32>) {
-    // expected-error @+1 {{unhandled deallocation type}}
+  // CHECK-LABEL: func @dealloc_unsupported_memory_space
+  func.func @dealloc_unsupported_memory_space(%arg0 : memref<4x5xf32>) {
+    // CHECK: memref.dealloc
     memref.dealloc %arg0 : memref<4x5xf32>
     return
   }

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
new file mode 100644
index 0000000000000..2aabeed2fd8c0
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
@@ -0,0 +1,89 @@
+// RUN: mlir-opt -split-input-file -convert-memref-to-spirv -canonicalize -verify-diagnostics %s -o - | FileCheck %s
+
+// Static-shaped alloca in memory space 6 lowers to a Function spv.Variable.
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
+  func.func @alloc_function_variable(%arg0 : index, %arg1 : index) {
+    %0 = memref.alloca() : memref<4x5xf32, 6>
+    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, 6>
+    memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, 6>
+    return
+  }
+}
+
+// CHECK-LABEL: func @alloc_function_variable
+//       CHECK:   %[[VAR:.+]] = spv.Variable : !spv.ptr<!spv.struct<(!spv.array<20 x f32, stride=4>)>, Function>
+//       CHECK:   %[[LOADPTR:.+]] = spv.AccessChain %[[VAR]]
+//       CHECK:   %[[VAL:.+]] = spv.Load "Function" %[[LOADPTR]] : f32
+//       CHECK:   %[[STOREPTR:.+]] = spv.AccessChain %[[VAR]]
+//       CHECK:   spv.Store "Function" %[[STOREPTR]], %[[VAL]] : f32
+
+
+// -----
+
+// Multiple allocas in one function each become a separate spv.Variable.
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
+  func.func @two_allocs() {
+    %0 = memref.alloca() : memref<4x5xf32, 6>
+    %1 = memref.alloca() : memref<2x3xi32, 6>
+    return
+  }
+}
+
+// CHECK-LABEL: func @two_allocs
+//   CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<6 x i32, stride=4>)>, Function>
+//   CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<20 x f32, stride=4>)>, Function>
+
+// -----
+
+// Vector element types are supported as well.
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
+  func.func @two_allocs_vector() {
+    %0 = memref.alloca() : memref<4xvector<4xf32>, 6>
+    %1 = memref.alloca() : memref<2xvector<2xi32>, 6>
+    return
+  }
+}
+
+// CHECK-LABEL: func @two_allocs_vector
+//   CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<2 x vector<2xi32>, stride=8>)>, Function>
+//   CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16>)>, Function>
+
+
+// -----
+
+// Dynamic shapes are unsupported; the alloca is left unconverted.
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
+  // CHECK-LABEL: func @alloc_dynamic_size
+  func.func @alloc_dynamic_size(%arg0 : index) -> f32 {
+    // CHECK: memref.alloca
+    %0 = memref.alloca(%arg0) : memref<4x?xf32, 6>
+    %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, 6>
+    return %1: f32
+  }
+}
+
+// -----
+
+// Only memory space 6 maps to Function storage; the default space is rejected.
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
+  // CHECK-LABEL: func @alloc_unsupported_memory_space
+  func.func @alloc_unsupported_memory_space(%arg0: index) -> f32 {
+    // CHECK: memref.alloca
+    %0 = memref.alloca() : memref<4x5xf32>
+    %1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32>
+    return %1: f32
+  }
+}
+
+// -----
+
+// Workgroup memory (space 3), which memref.alloc accepts, is rejected here.
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], []>, {}>} {
+  // CHECK-LABEL: func @alloc_workgroup_memory_space
+  func.func @alloc_workgroup_memory_space(%arg0: index) -> f32 {
+    // CHECK: memref.alloca
+    %0 = memref.alloca() : memref<4x5xf32, 3>
+    %1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32, 3>
+    return %1: f32
+  }
+}


        


More information about the Mlir-commits mailing list