[Mlir-commits] [mlir] f7907bc - [mlir] Add map_nested_foreach_thread_to_gpu_blocks op to transform dialect
Guray Ozen
llvmlistbot at llvm.org
Fri Sep 23 07:27:23 PDT 2022
Author: Guray Ozen
Date: 2022-09-23T16:27:10+02:00
New Revision: f7907bc536892e7ab1d5656a49ec708750d790f9
URL: https://github.com/llvm/llvm-project/commit/f7907bc536892e7ab1d5656a49ec708750d790f9
DIFF: https://github.com/llvm/llvm-project/commit/f7907bc536892e7ab1d5656a49ec708750d790f9.diff
LOG: [mlir] Add map_nested_foreach_thread_to_gpu_blocks op to transform dialect
This revision adds a new op, `map_nested_foreach_thread_to_gpu_blocks`, to the transform dialect.
If the `generate_gpu_launch` attribute is given, the op first generates a `gpu.launch` op and moves the top-level `scf.foreach_thread` inside it; otherwise, `target` must already be a `gpu.launch` op. The op searches for the top-level `scf.foreach_thread` inside the `gpu.launch` and distributes it to GPU blocks by rewriting its induction variables to `gpu.block_id`.
The loop-to-block mapping is explicit and given by the `map_nested_foreach_thread_to_gpu_blocks` op. The mapping is one-to-one, so the mapped loops disappear.
It also adds the `gpu` dialect as a dependent dialect, since the new op can create a `gpu::LaunchOp` for a given `scf::ForeachThreadOp`.
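For illustration, a minimal transform script in the style of the new transform-gpu.mlir tests added below (the payload matched by `func.func` and the chosen `blockDim` are assumptions for the example, not requirements of the op):

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
  transform.sequence %arg0 failures(propagate) {
  ^bb1(%arg1: !pdl.operation):
    // Match the function containing the top-level scf.foreach_thread.
    %funcop = transform.structured.match ops{["func.func"]} in %arg0
    // Create a gpu.launch, move the top-level scf.foreach_thread inside it,
    // and rewrite its induction variables to gpu.block_id.
    %gpuLaunch = transform.structured.map_nested_foreach_thread_to_gpu_blocks %funcop { generate_gpu_launch }
    // Optionally map the nested scf.foreach_thread to gpu.thread_id as well.
    transform.structured.map_nested_foreach_thread_to_gpu_threads %gpuLaunch { blockDim = [32, 4, 1] }
  }
}

The first op creates the `gpu.launch` and maps the outer loop to blocks; the second maps the nested loop to threads, as in the saxpy4d test added in this revision.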
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D134190
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-gpu.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index d5fafcabb3af..e60131b893a3 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -846,6 +846,63 @@ def MapNestedForeachThreadToGpuThreads :
}];
}
+def MapNestedForeachThreadToGpuBlocks : Op<Transform_Dialect,
+ "structured.map_nested_foreach_thread_to_gpu_blocks",
+ [FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ TransformOpInterface,
+ TransformEachOpTrait]> {
+ let description = [{
+ Target the gpu_launch op and rewrite the top level `scf.foreach_thread`
+ to distribute it over gpu.block_id. If the `generate_gpu_launch` attribute
+ is set, the op first generates a `gpu_launch` and moves the top level
+ `scf.foreach_thread` inside it.
+
+ The operation searches for top level `scf.foreach_thread` ops under
+ `gpu_launch` and maps each such op to GPU blocks. Mapping is
+ one-to-one and the induction variables of `scf.foreach_thread` are
+ rewritten to gpu.block_id according to the `thread_dim_mapping` attribute.
+
+ Dynamic `scf.foreach_thread` trip counts are currently not supported.
+ Dynamic block dim sizes are currently not supported.
+
+ Only **bufferized** scf.foreach_thread ops are currently supported.
+ Only scf.foreach_thread ops distributed to **at most 3 dimensions** are
+ currently supported.
+
+ The operation alters the grid size of the given gpu_launch using the
+ gridDim argument.
+
+ #### Return modes:
+
+ This operation ignores non-gpu_launch ops and drops them in the return.
+
+ If any scf.foreach_thread with tensors is found, the transform definitely
+ fails.
+
+ If all the scf.foreach_thread operations contained within the LaunchOp
+ referred to by the `target` PDLOperation lower to GPU properly, the
+ transform succeeds. Otherwise the transform definitely fails.
+
+ The returned handle points to the same LaunchOp operand, consuming it and
+ producing a new SSA value to satisfy chaining and linearity of the IR
+ properties.
+ }];
+
+ let arguments = (ins PDL_Operation:$target,
+ DefaultValuedAttr<I64ArrayAttr, "{}">:$gridDim,
+ UnitAttr:$generate_gpu_launch);
+ let results = (outs PDL_Operation:$result);
+
+ let assemblyFormat = "$target attr-dict";
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::Operation *target,
+ ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformEachOpTrait, TransformOpInterface]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b6586718ad48..11a34cc573d8 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -125,6 +125,21 @@ bool areElementwiseOpsFusable(OpOperand *fusedOperand);
FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
OpOperand *fusedOperand);
+/// Maps the top level `scf.foreach_thread` op to GPU Thread Blocks. Mapping is
+/// one-to-one and the induction variables of `scf.foreach_thread` are rewritten
+/// to gpu.block_id according to the thread_dim_mapping attribute. Dynamic
+/// `scf.foreach_thread` trip counts are currently not supported. Dynamic block
+/// dim sizes are currently not supported.
+LogicalResult rewriteTopLevelForeachThreadToGpuBlocks(
+ RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
+ function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
+ SmallVector<Value> &)>
+ blockIdGenerator,
+ SmallVector<int64_t> &gridDims);
+
+/// Finds the top level scf::ForeachThreadOp of the given target.
+FailureOr<scf::ForeachThreadOp> findTopLevelForeachThreadOp(Operation *target);
+
/// Searches `scf.foreach_thread` ops nested under `target` and maps each such
/// op to GPU threads. Mapping is one-to-one and the induction variables of
/// `scf.foreach_thread` are rewritten to gpu.thread_id according to the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ca3c932c77cc..18df9ed49267 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1285,25 +1285,56 @@ mlir::WalkResult mlir::linalg::rewriteMapNestedForeachThreadToGpuThreads(
return walkResult;
}
-// Alter blockDim of the given kernel
-static LogicalResult alterGpuLaunchBlockDim(SimpleRewriter &rewriter,
- gpu::LaunchOp gpuLaunch,
- SmallVector<int64_t> blockDim) {
- gpu::KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
- if (blockDim[0] < 1 || blockDim[1] < 1 || blockDim[2] < 1) {
- gpuLaunch->emitError() << "Given blockDim(" << blockDim[0] << ","
- << blockDim[1] << "," << blockDim[2]
- << ") is invalid";
+static LogicalResult
+checkGpuLimits(Optional<int64_t> gridDimX, Optional<int64_t> gridDimY,
+ Optional<int64_t> gridDimZ, Optional<int64_t> blockDimX,
+ Optional<int64_t> blockDimY, Optional<int64_t> blockDimZ) {
+ // TODO: These limits should live in the gpu dialect, but they are not
+ // defined there yet. Read them from the gpu dialect once they are.
+ if ((blockDimX.value_or(1) * blockDimY.value_or(1) * blockDimZ.value_or(1)) >
+ 1024 ||
+ gridDimY.value_or(1) > 65535 || gridDimZ.value_or(1) > 65535 ||
+ gridDimX.value_or(1) > 2147483647)
+ return failure();
+ return success();
+}
+
+/// Alter grid or block dimensions of the given kernel
+static LogicalResult alterGpuLaunch(SimpleRewriter &rewriter,
+ gpu::LaunchOp gpuLaunch,
+ Optional<int64_t> gridDimX = llvm::None,
+ Optional<int64_t> gridDimY = llvm::None,
+ Optional<int64_t> gridDimZ = llvm::None,
+ Optional<int64_t> blockDimX = llvm::None,
+ Optional<int64_t> blockDimY = llvm::None,
+ Optional<int64_t> blockDimZ = llvm::None) {
+ if (failed(checkGpuLimits(gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY,
+ blockDimZ))) {
+ gpuLaunch->emitError(
+ "Requested kernel thread configuration is larger than the limits");
return failure();
}
+
+ gpu::KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
+ OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfterValue(currentBlockdim.x);
- auto createBlockDimValue = [&](int64_t dim) {
+ auto createConstValue = [&](int dim) {
return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(),
dim);
};
- gpuLaunch.blockSizeXMutable().assign(createBlockDimValue(blockDim[0]));
- gpuLaunch.blockSizeYMutable().assign(createBlockDimValue(blockDim[1]));
- gpuLaunch.blockSizeZMutable().assign(createBlockDimValue(blockDim[2]));
+
+ if (gridDimX.has_value())
+ gpuLaunch.gridSizeXMutable().assign(createConstValue(gridDimX.value()));
+ if (gridDimY.has_value())
+ gpuLaunch.gridSizeYMutable().assign(createConstValue(gridDimY.value()));
+ if (gridDimZ.has_value())
+ gpuLaunch.gridSizeZMutable().assign(createConstValue(gridDimZ.value()));
+ if (blockDimX.has_value())
+ gpuLaunch.blockSizeXMutable().assign(createConstValue(blockDimX.value()));
+ if (blockDimY.has_value())
+ gpuLaunch.blockSizeYMutable().assign(createConstValue(blockDimY.value()));
+ if (blockDimZ.has_value())
+ gpuLaunch.blockSizeZMutable().assign(createConstValue(blockDimZ.value()));
return success();
}
@@ -1327,7 +1358,9 @@ transform::MapNestedForeachThreadToGpuThreads::applyToOne(
if (walkResult.wasInterrupted())
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
- LogicalResult result = alterGpuLaunchBlockDim(rewriter, gpuLaunch, blockDim);
+ LogicalResult result =
+ alterGpuLaunch(rewriter, gpuLaunch, llvm::None, llvm::None, llvm::None,
+ blockDim[0], blockDim[1], blockDim[2]);
if (failed(result))
return DiagnosedSilenceableFailure::definiteFailure();
@@ -1335,6 +1368,184 @@ transform::MapNestedForeachThreadToGpuThreads::applyToOne(
return DiagnosedSilenceableFailure(success());
}
+//===----------------------------------------------------------------------===//
+// MapNestedForeachThreadToGpuBlocks
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
+ RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
+ function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
+ SmallVector<Value> &)>
+ blockIdGenerator,
+ SmallVector<int64_t> &gridDims) {
+ if (foreachThreadOp.getNumResults() > 0)
+ return foreachThreadOp->emitError(
+ "only bufferized scf.foreach_thread lowers to gpu.block_id");
+ if (foreachThreadOp.getNumThreads().size() > 3)
+ return foreachThreadOp->emitError(
+ "scf.foreach_thread with rank > 3 does not lower to gpu.block_id");
+
+ // Step 0. Outline the compute workload region and set up the workload
+ // operands.
+ auto potentialGridDim = foreachThreadOp.getPermutedNumThreads(rewriter);
+ if (failed(potentialGridDim) ||
+ llvm::any_of(*potentialGridDim, [](OpFoldResult ofr) {
+ return !getConstantIntValue(ofr).has_value();
+ }))
+ return foreachThreadOp->emitError("unsupported dynamic gridDim");
+
+ for (OpFoldResult ofr : *potentialGridDim)
+ gridDims.push_back(getConstantIntValue(ofr).value());
+
+ IndexType indexType = rewriter.getIndexType();
+ SmallVector<Value> blockOps;
+ blockIdGenerator(foreachThreadOp, gridDims, indexType, blockOps);
+
+ // Step 1. Move the body of foreachThreadOp.
+ // Erase the terminator first, it will not be used since we are on buffers.
+ rewriter.eraseOp(foreachThreadOp.getTerminator());
+ Block *targetBlock = foreachThreadOp->getBlock();
+ Block::iterator insertionPoint = Block::iterator(foreachThreadOp);
+ Block &sourceBlock = foreachThreadOp.getRegion().front();
+ targetBlock->getOperations().splice(insertionPoint,
+ sourceBlock.getOperations());
+
+ // Step 2. RAUW thread indices to thread ops.
+ SmallVector<Value> threadIndices =
+ *foreachThreadOp.getPermutedThreadIndices();
+ assert(blockOps.size() == 3 && "3 block id ops are required");
+ for (auto it : llvm::zip(threadIndices, blockOps)) {
+ Value val = std::get<0>(it);
+ if (!val)
+ continue;
+ for (Operation *user : llvm::make_early_inc_range(val.getUsers())) {
+ rewriter.updateRootInPlace(
+ user, [&]() { user->replaceUsesOfWith(val, std::get<1>(it)); });
+ }
+ }
+
+ // Step 3. Erase old op.
+ rewriter.eraseOp(foreachThreadOp);
+
+ return success();
+}
+
+FailureOr<scf::ForeachThreadOp>
+mlir::linalg::findTopLevelForeachThreadOp(Operation *target) {
+ scf::ForeachThreadOp topLevelForeachThreadOp;
+ auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
+ if (foreachThreadOp->getParentOfType<scf::ForeachThreadOp>())
+ return WalkResult::advance();
+ if (topLevelForeachThreadOp)
+ // TODO: Handle multiple foreach_thread ops if there are no dependences between them
+ return WalkResult::interrupt();
+ topLevelForeachThreadOp = foreachThreadOp;
+ return WalkResult::advance();
+ });
+
+ if (walkResult.wasInterrupted())
+ return target->emitError(
+ "could not find a unique topLevel scf.foreach_thread");
+
+ return topLevelForeachThreadOp;
+}
+
+/// Create a gpu::LaunchOp with the given kernel configuration.
+static FailureOr<gpu::LaunchOp>
+createGpuLaunch(RewriterBase &rewriter, Location loc,
+ Optional<int64_t> gridDimX = llvm::None,
+ Optional<int64_t> gridDimY = llvm::None,
+ Optional<int64_t> gridDimZ = llvm::None,
+ Optional<int64_t> blockDimX = llvm::None,
+ Optional<int64_t> blockDimY = llvm::None,
+ Optional<int64_t> blockDimZ = llvm::None) {
+ if (failed(checkGpuLimits(gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY,
+ blockDimZ)))
+ return failure();
+ auto createConstant = [&](int dim) {
+ return rewriter.create<arith::ConstantIndexOp>(loc, dim);
+ };
+ Value one = createConstant(1);
+ Value gridSizeX =
+ gridDimX.has_value() ? createConstant(gridDimX.value()) : one;
+ Value gridSizeY =
+ gridDimY.has_value() ? createConstant(gridDimY.value()) : one;
+ Value gridSizeZ =
+ gridDimZ.has_value() ? createConstant(gridDimZ.value()) : one;
+ Value blockSizeX =
+ blockDimX.has_value() ? createConstant(blockDimX.value()) : one;
+ Value blockSizeY =
+ blockDimY.has_value() ? createConstant(blockDimY.value()) : one;
+ Value blockSizeZ =
+ blockDimZ.has_value() ? createConstant(blockDimZ.value()) : one;
+ auto launchOp = rewriter.create<gpu::LaunchOp>(
+ loc, gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ);
+ rewriter.setInsertionPointToEnd(&launchOp.body().front());
+ rewriter.create<gpu::TerminatorOp>(loc);
+ return launchOp;
+}
+
+DiagnosedSilenceableFailure
+transform::MapNestedForeachThreadToGpuBlocks::applyToOne(
+ Operation *target, SmallVectorImpl<Operation *> &results,
+ transform::TransformState &state) {
+ gpu::LaunchOp gpuLaunch = dyn_cast<gpu::LaunchOp>(target);
+ SimpleRewriter rewriter(getContext());
+
+ if (!getGenerateGpuLaunch() && !gpuLaunch) {
+ target->emitError("Given target is not gpu.launch, set "
+ "`generate_gpu_launch` attribute");
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+
+ auto res = mlir::linalg::findTopLevelForeachThreadOp(target);
+ if (failed(res))
+ return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+
+ scf::ForeachThreadOp topLevelForeachThreadOp = *res;
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(topLevelForeachThreadOp);
+
+ // Generate gpu launch here and move the foreach_thread inside
+ if (getGenerateGpuLaunch()) {
+ FailureOr<gpu::LaunchOp> maybeGpuLaunch =
+ createGpuLaunch(rewriter, target->getLoc());
+ if (failed(maybeGpuLaunch))
+ return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+ gpuLaunch = *maybeGpuLaunch;
+ rewriter.setInsertionPointToStart(&gpuLaunch.body().front());
+ Operation *newForeachThreadOp = rewriter.clone(*topLevelForeachThreadOp);
+ rewriter.eraseOp(topLevelForeachThreadOp);
+ topLevelForeachThreadOp =
+ dyn_cast<scf::ForeachThreadOp>(newForeachThreadOp);
+ }
+
+ auto generateBlocks = [&](Operation *op, const SmallVector<int64_t> &gridDims,
+ IndexType indexType, SmallVector<Value> &blockOps) {
+ Location loc = op->getLoc();
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(op);
+ SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
+ gpu::Dimension::z};
+ for (int64_t idx : llvm::seq<int64_t>(0, gridDims.size())) {
+ blockOps.push_back(
+ rewriter.create<gpu::BlockIdOp>(loc, indexType, gpuDims[idx]));
+ }
+ };
+
+ SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
+ if (failed(mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
+ rewriter, topLevelForeachThreadOp, generateBlocks, gridDim)))
+ return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+
+ if (failed(alterGpuLaunch(rewriter, gpuLaunch, gridDim[0], gridDim[1],
+ gridDim[2])))
+ return DiagnosedSilenceableFailure::definiteFailure();
+
+ results.assign({gpuLaunch});
+ return DiagnosedSilenceableFailure(success());
+}
+
//===----------------------------------------------------------------------===//
// TileToForeachThreadOp
//===----------------------------------------------------------------------===//
@@ -1562,6 +1773,7 @@ class LinalgTransformDialectExtension
declareGeneratedDialect<arith::ArithmeticDialect>();
declareGeneratedDialect<scf::SCFDialect>();
declareGeneratedDialect<vector::VectorDialect>();
+ declareGeneratedDialect<gpu::GPUDialect>();
registerTransformOps<
#define GET_OP_LIST
diff --git a/mlir/test/Dialect/Linalg/transform-gpu.mlir b/mlir/test/Dialect/Linalg/transform-gpu.mlir
index 00b750eb7927..fbd7bcb6ccf8 100644
--- a/mlir/test/Dialect/Linalg/transform-gpu.mlir
+++ b/mlir/test/Dialect/Linalg/transform-gpu.mlir
@@ -1,4 +1,45 @@
-// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file -canonicalize -cse %s | FileCheck %s
+
+!type = memref<2 x 32 x f32>
+!type1d = memref<32 x f32>
+
+// CHECK-LABEL: func.func @saxpy2dblock(
+// CHECK-SAME: %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME: %[[ARGY:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME: %[[ARGT:[0-9a-z]+]]: memref<32xf32>
+func.func @saxpy2dblock(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+ %c9 = arith.constant 9 : index
+ %c7 = arith.constant 7 : index
+ %one = arith.constant 1 : index
+// CHECK: gpu.launch
+// CHECK: %[[BLKX:.*]] = gpu.block_id x
+// CHECK: %[[BLKY:.*]] = gpu.block_id y
+// CHECK: memref.load %[[ARGX]][%[[BLKX]], %[[BLKY]]]
+// CHECK: memref.load %[[ARGY]][%[[BLKX]], %[[BLKY]]]
+ %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+ threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+ {
+ scf.foreach_thread (%i, %j) in (%c7, %c9) {
+ %4 = memref.load %x[%i, %j] : !type
+ %5 = memref.load %y[%i, %j] : !type
+ %6 = math.fma %alpha, %4, %5 : f32
+ memref.store %6, %y[%i, %j] : !type
+ } {thread_dim_mapping = [0, 1, 2]}
+ gpu.terminator
+ }
+ return %y : !type
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
+ transform.structured.map_nested_foreach_thread_to_gpu_blocks %funcop { blockDim = [12, 9, 1]}
+ }
+}
+
+// -----
!type = memref<2 x 32 x f32>
!type1d = memref<32 x f32>
@@ -12,21 +53,20 @@ func.func @saxpy2d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !g
%c12 = arith.constant 12 : index
%c9 = arith.constant 9 : index
%c7 = arith.constant 7 : index
-// CHECK: gpu.launch
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C12:.*]] = arith.constant 12 : index
+// CHECK: %[[C9:.*]] = arith.constant 9 : index
+// CHECK: %[[C7:.*]] = arith.constant 7 : index
+// CHECK: gpu.launch async [%{{.*}}] blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C1]], %{{.*}} = %[[C1]], %{{.*}} = %[[C1]]) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C12]], %{{.*}} = %[[C9]], %{{.*}} = %[[C1]])
// CHECK: %[[TIDX:.*]] = gpu.thread_id x
// CHECK: %[[TIDY:.*]] = gpu.thread_id y
-// CHECK: %[[C9:.*]] = arith.constant 9 : index
// CHECK: arith.cmpi ult, %[[TIDX]], %[[C9]] : index
-// CHECK: %[[C7:.*]] = arith.constant 7 : index
// CHECK: arith.cmpi ult, %[[TIDY]], %[[C7]] : index
// CHECK: memref.load %[[ARGX]][%[[TIDY]], %[[TIDX]]]
// CHECK: memref.load %[[ARGY]][%[[TIDY]], %[[TIDX]]]
// CHECK: gpu.barrier
-// CHECK: %[[TIDX2:.*]] = gpu.thread_id x
-// CHECK: %[[TIDY2:.*]] = gpu.thread_id y
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: arith.cmpi ult, %[[TIDY2]], %[[C1]] : index
-// CHECK: memref.load %[[ARGT]][%[[TIDX2]]]
+// CHECK: arith.cmpi ult, %[[TIDY]], %[[C1]] : index
+// CHECK: memref.load %[[ARGT]][%[[TIDX]]]
// CHECK: gpu.barrier
%name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
@@ -56,3 +96,45 @@ transform.with_pdl_patterns {
}
}
+// -----
+
+!type4d = memref<32x64x4x32xf32>
+
+// CHECK-LABEL: func.func @saxpy4d(
+// CHECK-SAME: %[[ARGX:[0-9a-z]+]]: memref<32x64x4x32xf32>
+// CHECK-SAME: %[[ARGY:[0-9a-z]+]]: memref<32x64x4x32xf32>
+func.func @saxpy4d(%x: !type4d, %y: !type4d, %alpha : f32) -> !type4d {
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ %c4 = arith.constant 4 : index
+// CHECK: %[[C32:.*]] = arith.constant 32 : index
+// CHECK: %[[C64:.*]] = arith.constant 64 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C32]], %{{.*}} = %[[C64]], %{{.*}} = %[[C1]]) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C32]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1]])
+// CHECK: %[[BLKX:.*]] = gpu.block_id x
+// CHECK: %[[BLKY:.*]] = gpu.block_id y
+// CHECK: %[[TIDX:.*]] = gpu.thread_id x
+// CHECK: %[[TIDY:.*]] = gpu.thread_id y
+// CHECK: memref.load %[[ARGX]][%[[BLKX]], %[[BLKY]], %[[TIDY]], %[[TIDX]]]
+// CHECK: memref.load %[[ARGY]][%[[BLKX]], %[[BLKY]], %[[TIDY]], %[[TIDX]]]
+ scf.foreach_thread (%i, %j) in (%c32, %c64) {
+ scf.foreach_thread (%k, %l) in (%c4, %c32) {
+ %4 = memref.load %x[%i, %j, %k, %l] : !type4d
+ %5 = memref.load %y[%i, %j, %k, %l] : !type4d
+ %6 = math.fma %alpha, %4, %5 : f32
+ memref.store %6, %y[%i, %j, %k, %l] : !type4d
+ } {thread_dim_mapping = [1, 0, 2]}
+ } {thread_dim_mapping = [0, 1, 2]}
+ return %y : !type4d
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ transform.sequence %arg0 failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %funcop = transform.structured.match ops{["func.func"]} in %arg0
+ %gpuLaunch = transform.structured.map_nested_foreach_thread_to_gpu_blocks %funcop { generate_gpu_launch }
+ transform.structured.map_nested_foreach_thread_to_gpu_threads %gpuLaunch { blockDim = [32, 4, 1] }
+ }
+}