[Mlir-commits] [mlir] 115711c - [mlir][LinAlg][Transform][GPU] Add GPU memory hierarchy to the transform.promote op
Alex Zinenko
llvmlistbot at llvm.org
Mon Feb 27 07:34:06 PST 2023
Author: Amir Mohammad Tavakkoli
Date: 2023-02-27T16:33:58+01:00
New Revision: 115711c19cd287c098a872c63a00478ca635f642
URL: https://github.com/llvm/llvm-project/commit/115711c19cd287c098a872c63a00478ca635f642
DIFF: https://github.com/llvm/llvm-project/commit/115711c19cd287c098a872c63a00478ca635f642.diff
LOG: [mlir][LinAlg][Transform][GPU] Add GPU memory hierarchy to the transform.promote op
In this patch we add support for copying a `memref.subview` to shared or private memory on the GPU. The global-to-shared-memory copy is adapted from code implemented in IREE (https://github.com/iree-org/iree), but the private-memory copy part has not been implemented in IREE. This patch enables transferring a subview `global->shared`, `global->private`, and `shared->private`.
Our final aim is to provide a copy layout as an affine map to the `transform.promote` op to support transposed memory copies. This map is a permutation of the original affine index map. Although this has been implemented and users can copy data to an arbitrary layout, that attempt is not included in this patch since we still have problems making `linalg.generic` operations switch their index maps to the transformed ones. You can find more in the following links: ([[ https://github.com/tavakkoliamirmohammad/iree-llvm-fork/commit/4fd5f93355951ad0fb338858393ff409bd9c62f8 | Initial attempt to support layout map in promote op in transform dialect ]]) ([[ https://github.com/tavakkoliamirmohammad/iree-llvm-fork/commit/9062b5849f91d4defb84996392b71087dadf7a8c | Fix data transpose in shared memory ]])
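For illustration, the new `mapping` attribute is used on `transform.structured.promote` as in the tests added by this patch (the matmul payload and the 16x16x16 tile sizes are just the test's example):

  transform.sequence failures(propagate) {
  ^bb0(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
    %1, %loops:3 = transform.structured.tile %0 [16, 16, 16] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
    // Promote the two matmul inputs into GPU workgroup (shared) memory.
    %2 = transform.structured.promote %1 { operands_to_promote = [0, 1], mapping = [#gpu.memory_space<workgroup>] }
  }

Using mapping = [#gpu.memory_space<private>] instead promotes the operands into private memory.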
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D144666
Added:
Modified:
mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/test/Dialect/Linalg/promote.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
index c03bdb651a6dc..3b261acdee83a 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
@@ -85,4 +85,23 @@ def GPUBlockMappingAttr : GPU_Attr<"GPUBlockMapping", "block", [
}];
}
+
+def GPUMemorySpaceMappingAttr : GPU_Attr<"GPUMemorySpaceMapping", "memory_space", [
+ DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] > {
+ let parameters = (ins
+ EnumParameter<GPU_AddressSpaceEnum>:$address_space
+ );
+ let assemblyFormat = "`<` params `>`";
+ let description = [{
+ An attribute that allows defining memory hierarchy for GPU devices.
+
+ GPU memory has three memory spaces: global, workgroup, and private. Global memory
+ is visible to all workitems and workgroups, workgroup memory is only available to workitems
+ within the same workgroup, and private memory is only visible to a single workitem. This attribute
+ indicates that using the memory hierarchy is desired. It can be consumed by lowerings to
+ move data to a specific address space in GPU code.
+ }];
+}
+
+
#endif // GPU_DEVICE_MAPPING_ATTR
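As a small usage sketch (attribute and type forms taken from the tests added in this patch; SSA names are illustrative), the attribute selects the address space in which promoted buffers are allocated:

  // mapping = [#gpu.memory_space<workgroup>] produces workgroup allocations:
  %shared = memref.alloc() : memref<16x16xf32, #gpu.address_space<workgroup>>
  // mapping = [#gpu.memory_space<private>] produces stack-like private allocations:
  %priv = memref.alloca() : memref<16x16xf32, #gpu.address_space<private>>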
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index c53497839ea9a..41c5daf6744d0 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -765,6 +765,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
// PromoteOp
//===----------------------------------------------------------------------===//
+
def PromoteOp : Op<Transform_Dialect, "structured.promote",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
@@ -791,6 +792,7 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
DefaultValuedAttr<BoolArrayAttr, "{}">:$use_full_tile_buffers,
UnitAttr:$use_full_tiles_by_default,
UnitAttr:$use_alloca,
+ OptionalAttr<DeviceMappingArrayAttr>:$mapping,
OptionalAttr<I64Attr>:$alignment);
let results = (outs PDL_Operation:$transformed);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8e77b54251aa7..ea645e8973c0b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -393,6 +393,32 @@ promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
const LinalgPromotionOptions &options);
+/// Allocate the subview in the GPU workgroup memory.
+Optional<Value> allocateWorkgroupMemory(OpBuilder &builder,
+ memref::SubViewOp subview,
+ ArrayRef<Value> sizeBounds,
+ DataLayout &);
+
+/// In the case of GPU workgroup memory there is no need to deallocate.
+LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value /*buffer*/);
+
+/// Create a memref copy operation and add gpu.barrier guards before and after
+/// the copy to ensure data integrity.
+LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst);
+
+/// Allocate the subview in the GPU private memory.
+Optional<Value> allocateGPUPrivateMemory(OpBuilder &builder,
+ memref::SubViewOp subview,
+ ArrayRef<Value> sizeBounds,
+ DataLayout &);
+
+/// Normal copy between src and dst.
+LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst);
+
+/// In the case of GPU private memory there is no need to deallocate since the
+/// memory is freed when leaving the scope.
+LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
+
/// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
/// are used to vectorize this operation. `inputVectorSizes` must match the rank
/// of the iteration space of the operation and the sizes must be smaller or
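For reference, the copy helpers declared above are expected to produce IR of the following shape (types and SSA names taken from the tests added in this patch): `copyToWorkgroupMemory` emits a copy guarded by `gpu.barrier` on both sides, while `copyToGPUPrivateMemory` emits a plain copy:

  // copyToWorkgroupMemory: barrier-guarded copy into workgroup memory.
  gpu.barrier
  memref.copy %subview, %shared : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<workgroup>>
  gpu.barrier
  // copyToGPUPrivateMemory: plain copy, no synchronization required.
  memref.copy %subview, %priv : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<private>>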
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 159892e7bd09a..0ec5877f80361 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -50,6 +50,10 @@ int64_t GPUThreadMappingAttr::getMappingId() const {
return static_cast<int64_t>(getThread());
}
+int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
+ return static_cast<int64_t>(getAddressSpace());
+}
+
//===----------------------------------------------------------------------===//
// MMAMatrixType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 2d102383ddbe0..5a8f9816aefd1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1802,6 +1802,35 @@ transform::PromoteOp::applyToOne(LinalgOp target,
if (getAlignment().has_value())
promotionOptions = promotionOptions.setAlignment(*getAlignment());
+ if (getMapping().has_value()) {
+ // The mapping should contain a single element.
+ auto mapping = *getMapping();
+ if (mapping.size() > 1)
+ return emitDefaultDefiniteFailure(target);
+
+ auto addressSpace = mapping[0].cast<gpu::GPUMemorySpaceMappingAttr>();
+
+ if (addressSpace.getAddressSpace() ==
+ gpu::GPUDialect::getWorkgroupAddressSpace()) {
+ promotionOptions =
+ promotionOptions
+ .setAllocationDeallocationFns(allocateWorkgroupMemory,
+ deallocateWorkgroupMemory)
+ .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
+ .setUseFullTileBuffers({false, false});
+ } else if (addressSpace.getAddressSpace() ==
+ gpu::GPUDialect::getPrivateAddressSpace()) {
+ promotionOptions =
+ promotionOptions
+ .setAllocationDeallocationFns(allocateGPUPrivateMemory,
+ deallocateGPUPrivateMemory)
+ .setCopyInOutFns(copyToGPUPrivateMemory, copyToGPUPrivateMemory)
+ .setUseFullTileBuffers({false, false});
+ } else {
+ return emitDefaultDefiniteFailure(target);
+ }
+ }
+
if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
return emitDefaultDefiniteFailure(target);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index fd41ed30ca9ab..abc5e0039b5b5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -13,6 +13,8 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -397,3 +399,87 @@ mlir::linalg::promoteSubViews(OpBuilder &builder, LinalgOp linalgOp,
return failure();
return res;
}
+
+/// Allocate the given subview in a GPU memory address space by creating an
+/// allocation operation and setting the memref type's address space to the
+/// desired address space.
+static Optional<Value> allocateSubviewGPUMemoryInAddressSpace(
+ OpBuilder &builder, memref::SubViewOp subview, ArrayRef<Value> sizeBounds,
+ gpu::AddressSpace addressSpace) {
+ OpBuilder::InsertionGuard guard(builder);
+
+ func::FuncOp funcOp = subview->getParentOfType<func::FuncOp>();
+ if (!funcOp)
+ return std::nullopt;
+
+ // The subview size bounds are expected to be constant; they specify the shape
+ // of the allocation.
+ SmallVector<int64_t> shape;
+ for (Value bound : sizeBounds) {
+ APInt value;
+ if (!matchPattern(bound, m_ConstantInt(&value)))
+ return std::nullopt;
+ shape.push_back(value.getSExtValue());
+ }
+
+ builder.setInsertionPoint(&funcOp.front(), funcOp.front().begin());
+ auto type = MemRefType::get(
+ shape, subview.getType().getElementType(), MemRefLayoutAttrInterface{},
+ gpu::AddressSpaceAttr::get(builder.getContext(), addressSpace));
+ Value buffer;
+ if (addressSpace == gpu::GPUDialect::getWorkgroupAddressSpace()) {
+ buffer = builder.create<memref::AllocOp>(funcOp.getLoc(), type);
+ } else if (addressSpace == gpu::GPUDialect::getPrivateAddressSpace()) {
+ buffer = builder.create<memref::AllocaOp>(funcOp.getLoc(), type);
+ } else {
+ return std::nullopt;
+ }
+ return buffer;
+}
+
+/// Allocate the subview in the GPU workgroup memory.
+Optional<Value> mlir::linalg::allocateWorkgroupMemory(
+ OpBuilder &builder, memref::SubViewOp subview, ArrayRef<Value> sizeBounds,
+ DataLayout &) {
+ return allocateSubviewGPUMemoryInAddressSpace(
+ builder, subview, sizeBounds,
+ gpu::GPUDialect::getWorkgroupAddressSpace());
+}
+
+/// In the case of GPU workgroup memory there is no need to deallocate.
+LogicalResult mlir::linalg::deallocateWorkgroupMemory(OpBuilder &,
+ Value /*buffer*/) {
+ return success();
+}
+
+/// Create a memref copy operation and add gpu.barrier guards before and after
+/// the copy to ensure data integrity.
+LogicalResult mlir::linalg::copyToWorkgroupMemory(OpBuilder &b, Value src,
+ Value dst) {
+ b.create<gpu::BarrierOp>(src.getLoc());
+ Operation *copyOp = b.create<memref::CopyOp>(src.getLoc(), src, dst);
+ b.create<gpu::BarrierOp>(copyOp->getLoc());
+ return success();
+}
+
+/// Allocate the subview in the GPU private memory.
+Optional<Value> mlir::linalg::allocateGPUPrivateMemory(
+ OpBuilder &builder, memref::SubViewOp subview, ArrayRef<Value> sizeBounds,
+ DataLayout &) {
+ return allocateSubviewGPUMemoryInAddressSpace(
+ builder, subview, sizeBounds, gpu::GPUDialect::getPrivateAddressSpace());
+}
+
+/// Normal copy between src and dst.
+LogicalResult mlir::linalg::copyToGPUPrivateMemory(OpBuilder &b, Value src,
+ Value dst) {
+ b.create<memref::CopyOp>(src.getLoc(), src, dst);
+ return success();
+}
+
+/// In the case of GPU private memory there is no need to deallocate since the
+/// memory is freed when leaving the scope.
+LogicalResult mlir::linalg::deallocateGPUPrivateMemory(OpBuilder &,
+ Value /*buffer*/) {
+ return success();
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
index 085c1b7ae714b..b34a86ec901e7 100644
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -142,6 +142,94 @@ transform.sequence failures(propagate) {
%1 = transform.structured.promote %0
}
+// -----
+func.func @gemm_shared(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
+{
+ linalg.matmul ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%c: memref<?x?xf32>)
+ return
+}
+
+// CHECK: func @gemm_shared
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK: %[[alloc_A:.*]] = memref.alloc() : memref<16x16xf32, #gpu.address_space<workgroup>>
+// CHECK: %[[alloc_B:.*]] = memref.alloc() : memref<16x16xf32, #gpu.address_space<workgroup>>
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1
+// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: %[[subview_A:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK: %[[subview_B:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK: %[[subview_C:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+
+// CHECK: %[[shared_A:.*]] = memref.subview %[[alloc_B]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x16xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<workgroup>>
+// CHECK: %[[shared_B:.*]] = memref.subview %[[alloc_A]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x16xf32, #gpu.address_space<workgroup>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<workgroup>>
+
+// CHECK-NEXT: gpu.barrier
+// CHECK-NEXT: memref.copy %[[subview_A]], %[[shared_A]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<workgroup>>
+// CHECK-NEXT: gpu.barrier
+
+// CHECK-NEXT: gpu.barrier
+// CHECK-NEXT: memref.copy %[[subview_B]], %[[shared_B]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<workgroup>>
+// CHECK-NEXT: gpu.barrier
+
+// CHECK: linalg.matmul ins(%[[shared_A]], %[[shared_B]]{{.*}} outs(%[[subview_C]]
+
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1, %loops:3 = transform.structured.tile %0 [16, 16, 16] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
+ %2 = transform.structured.promote %1 { operands_to_promote = [0, 1], mapping = [#gpu.memory_space<workgroup>] }
+}
+
+
+// -----
+
+func.func @gemm_private(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>)
+{
+ linalg.matmul ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%c: memref<?x?xf32>)
+ return
+}
+
+// CHECK: func @gemm_private
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK: %[[alloc_A:.*]] = memref.alloca() : memref<16x16xf32, #gpu.address_space<private>>
+// CHECK: %[[alloc_B:.*]] = memref.alloca() : memref<16x16xf32, #gpu.address_space<private>>
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1
+// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: %[[subview_A:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK: %[[subview_B:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+// CHECK: %[[subview_C:.*]] = memref.subview {{.*}} : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+
+// CHECK: %[[private_A:.*]] = memref.subview %[[alloc_B]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x16xf32, #gpu.address_space<private>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<private>>
+// CHECK: %[[private_B:.*]] = memref.subview %[[alloc_A]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x16xf32, #gpu.address_space<private>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<private>>
+
+// CHECK-NEXT: memref.copy %[[subview_A]], %[[private_A]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<private>>
+// CHECK-NEXT: memref.copy %[[subview_B]], %[[private_B]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[16, 1]>, #gpu.address_space<private>>
+
+// CHECK: linalg.matmul ins(%[[private_A]], %[[private_B]]{{.*}} outs(%[[subview_C]]
+
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ %1, %loops:3 = transform.structured.tile %0 [16, 16, 16] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation)
+ %2 = transform.structured.promote %1 { operands_to_promote = [0, 1], mapping = [#gpu.memory_space<private>] }
+}
+
+
// -----
#map6 = affine_map<(d0, d1, d2) -> (d0, d2)>