[Mlir-commits] [mlir] 288ae0b - [mlir][gpu] NFC change to pass threadID ops to rewriteOneForeachThreadToGpuThreads
Thomas Raoux
llvmlistbot at llvm.org
Mon Feb 13 17:28:21 PST 2023
Author: Thomas Raoux
Date: 2023-02-14T01:28:11Z
New Revision: 288ae0b92f57cc6fcd77a6e5220e67fba7768ceb
URL: https://github.com/llvm/llvm-project/commit/288ae0b92f57cc6fcd77a6e5220e67fba7768ceb
DIFF: https://github.com/llvm/llvm-project/commit/288ae0b92f57cc6fcd77a6e5220e67fba7768ceb.diff
LOG: [mlir][gpu] NFC change to pass threadID ops to rewriteOneForeachThreadToGpuThreads
This allows users to supply both the thread IDs and the dimensions of the threads to distribute on.
This means it can also be used to distribute on warps.
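For instance, a caller could now hand the helper warp IDs instead of raw thread IDs. A minimal sketch of that warp use case, assuming a 32-wide warp with warps laid out along the x dimension (the names warpOps, warpDims, and blockDimX/Y/Z are illustrative, not part of this commit):

  Location loc = foreachThreadOp.getLoc();
  IndexType indexType = rewriter.getIndexType();
  Value tidX = rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x);
  Value tidY = rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y);
  Value tidZ = rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z);
  // Derive a warp id along x; the y and z ids pass through unchanged.
  Value warpSize = rewriter.create<arith::ConstantIndexOp>(loc, 32);
  Value warpIdX = rewriter.create<arith::DivUIOp>(loc, tidX, warpSize);
  SmallVector<Value> warpOps{warpIdX, tidY, tidZ};
  // One dimension per id (the helper asserts the sizes match): warps
  // along x, threads along y and z.
  SmallVector<int64_t> warpDims{blockDimX / 32, blockDimY, blockDimZ};
  diag = rewriteOneForeachThreadToGpuThreads(
      rewriter, foreachThreadOp, warpDims, warpOps, syncAfterDistribute,
      transformOp, threadMappingAttributes);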
Reviewed By: harsh
Differential Revision: https://reviews.llvm.org/D143950
Added:
Modified:
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index ecad3aa48bd8..742189c2ae49 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -366,7 +366,8 @@ transform::MapForeachToBlocks::applyToOne(Operation *target,
/// not supported. Dynamic block dim sizes are currently not supported.
static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
- const SmallVectorImpl<int64_t> &globalBlockDims, bool syncAfterDistribute,
+ const SmallVectorImpl<int64_t> &globalBlockDims,
+ const SmallVectorImpl<Value> &threadOps, bool syncAfterDistribute,
std::optional<TransformOpInterface> transformOp,
const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
// Step 0. Target-specific verifications. There is no good place to anchor
@@ -427,28 +428,26 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
// Step 3. Create the gpu.thread ops and map the induction variables to the
// newly created ops.
IndexType indexType = rewriter.getIndexType();
- SmallVector<Value> threadOps{
- rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
- rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
- rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)};
// Replace ids of dimension size 1 by zero to simplify the IR.
+ SmallVector<Value> threadOpsUpdated(threadOps.begin(), threadOps.end());
+ assert(threadOps.size() == globalBlockDims.size());
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
for (size_t i : llvm::seq(size_t(0), globalBlockDims.size())) {
if (globalBlockDims[i] == 1)
- threadOps[i] = zero;
+ threadOpsUpdated[i] = zero;
}
IRMapping bvm;
for (auto [blockIdx, blockDim] :
llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
- bvm.map(
- blockIdx,
- threadOps[blockDim.cast<DeviceMappingAttrInterface>().getMappingId()]);
+ bvm.map(blockIdx,
+ threadOpsUpdated[blockDim.cast<DeviceMappingAttrInterface>()
+ .getMappingId()]);
}
// Step 4. Maybe create conditionals to predicate the region.
Value predicate;
for (auto [threadId, blockDim, globalBlockDim] :
- llvm::zip(threadOps, blockDims, globalBlockDims)) {
+ llvm::zip(threadOpsUpdated, blockDims, globalBlockDims)) {
if (blockDim > globalBlockDim) {
return failureHelper(
"The requested GPU threads are fewer than the number of loop trip "
@@ -519,9 +518,17 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
foreachThreadOp.getMapping(), transformOp);
if (diag.succeeded()) {
rewriter.setInsertionPoint(foreachThreadOp);
+ IndexType indexType = rewriter.getIndexType();
+ SmallVector<Value> threadOps{
+ rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
+ Dimension::x),
+ rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
+ Dimension::y),
+ rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
+ Dimension::z)};
diag = rewriteOneForeachThreadToGpuThreads(
- rewriter, foreachThreadOp, blockDim, syncAfterDistribute, transformOp,
- threadMappingAttributes);
+ rewriter, foreachThreadOp, blockDim, threadOps, syncAfterDistribute,
+ transformOp, threadMappingAttributes);
}
return diag.succeeded() ? WalkResult::advance() : WalkResult::interrupt();
});
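As a side note on the Step 3 simplification above, a minimal sketch of the precondition and the size-1 collapse the rewrite performs, assuming ids for a hypothetical {64, 1, 1} block (names illustrative):

  // The helper asserts one id value per dimension.
  SmallVector<int64_t> globalBlockDims{64, 1, 1};
  SmallVector<Value> ids{tidX, tidY, tidZ}; // e.g. gpu.thread_id x/y/z
  assert(ids.size() == globalBlockDims.size());
  // Ids of size-1 dimensions collapse to a constant zero, so downstream
  // predicates and index arithmetic fold away for y and z.
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  for (size_t i : llvm::seq(size_t(0), globalBlockDims.size()))
    if (globalBlockDims[i] == 1)
      ids[i] = zero; // ids becomes {tidX, zero, zero}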