[Mlir-commits] [mlir] 0eabb88 - [mlir][gpu] NFC let user pick the threadID values when distributing foreach_thread
Thomas Raoux
llvmlistbot at llvm.org
Thu Feb 16 19:27:18 PST 2023
Author: Thomas Raoux
Date: 2023-02-17T03:25:15Z
New Revision: 0eabb884abebe748a5f71273f0e670ecdbabd400
URL: https://github.com/llvm/llvm-project/commit/0eabb884abebe748a5f71273f0e670ecdbabd400
DIFF: https://github.com/llvm/llvm-project/commit/0eabb884abebe748a5f71273f0e670ecdbabd400.diff
LOG: [mlir][gpu] NFC let user pick the threadID values when distributing foreach_thread
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D144219
Added:
Modified:
mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
index 97eb323ddb254..9b6485523c1c9 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
@@ -42,8 +42,11 @@ namespace gpu {
/// supported. Dynamic block dim sizes are currently not supported.
DiagnosedSilenceableFailure mapNestedForeachToThreadsImpl(
RewriterBase &rewriter, Operation *target,
- const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
- std::optional<TransformOpInterface> transformOp,
+ const SmallVectorImpl<int64_t> &blockDim,
+ function_ref<void(RewriterBase &, scf::ForeachThreadOp,
+ SmallVectorImpl<Value> &)>
+ threadIdGenerator,
+ bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes);
/// Maps the top level `scf.foreach_thread` op to GPU Thread Blocks. Mapping is
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index abb58369aa36f..90f04f3713cf3 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -502,8 +502,11 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
RewriterBase &rewriter, Operation *target,
- const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
- std::optional<TransformOpInterface> transformOp,
+ const SmallVectorImpl<int64_t> &blockDim,
+ function_ref<void(RewriterBase &, scf::ForeachThreadOp,
+ SmallVectorImpl<Value> &)>
+ threadIdGenerator,
+ bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
@@ -517,14 +520,8 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
foreachThreadOp.getMapping(), transformOp);
if (diag.succeeded()) {
rewriter.setInsertionPoint(foreachThreadOp);
- IndexType indexType = rewriter.getIndexType();
- SmallVector<Value> threadOps{
- rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
- Dimension::x),
- rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
- Dimension::y),
- rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
- Dimension::z)};
+ SmallVector<Value> threadOps;
+ threadIdGenerator(rewriter, foreachThreadOp, threadOps);
diag = rewriteOneForeachThreadToGpuThreads(
rewriter, foreachThreadOp, blockDim, threadOps, syncAfterDistribute,
transformOp, threadMappingAttributes);
@@ -562,10 +559,20 @@ DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
GPUThreadMappingAttr::get(ctx, Threads::DimX),
GPUThreadMappingAttr::get(ctx, Threads::DimY),
GPUThreadMappingAttr::get(ctx, Threads::DimZ)};
-
+ auto threadIdGenerator = [](RewriterBase &rewriter,
+ scf::ForeachThreadOp foreachThreadOp,
+ SmallVectorImpl<Value> &threadIds) {
+ IndexType indexType = rewriter.getIndexType();
+ threadIds.assign({rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
+ indexType, Dimension::x),
+ rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
+ indexType, Dimension::y),
+ rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
+ indexType, Dimension::z)});
+ };
diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
- rewriter, target, blockDim, getSyncAfterDistribute(), transformOp,
- threadMappingAttributes);
+ rewriter, target, blockDim, threadIdGenerator, getSyncAfterDistribute(),
+ transformOp, threadMappingAttributes);
if (diag.succeeded()) {
diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
More information about the Mlir-commits mailing list