[Mlir-commits] [mlir] f8ad6ea - [mlir] Refactor transform dialect's gpu block func
Guray Ozen
llvmlistbot at llvm.org
Tue Sep 27 03:27:26 PDT 2022
Author: Guray Ozen
Date: 2022-09-27T12:27:17+02:00
New Revision: f8ad6eaf92ac241420ef99db855d003d5b6d274e
URL: https://github.com/llvm/llvm-project/commit/f8ad6eaf92ac241420ef99db855d003d5b6d274e
DIFF: https://github.com/llvm/llvm-project/commit/f8ad6eaf92ac241420ef99db855d003d5b6d274e.diff
LOG: [mlir] Refactor transform dialect's gpu block func
This revision refactors the GPU block id generator lambda that is used in the transform dialect. It removes the lambda and instead uses a static function named generateGpuBlockIds.
It also simplifies the arguments that the function takes.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D134724
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 11a34cc573d84..2e7b247d60ca2 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -132,7 +132,7 @@ FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
/// dim sizes are currently not supported.
LogicalResult rewriteTopLevelForeachThreadToGpuBlocks(
RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
- function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
+ function_ref<void(RewriterBase &, scf::ForeachThreadOp,
SmallVector<Value> &)>
blockIdGenerator,
SmallVector<int64_t> &gridDims);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8f6a6b10b29f5..a5331c1e40bd1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1374,7 +1374,7 @@ transform::MapNestedForeachThreadToGpuThreads::applyToOne(
LogicalResult mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
- function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
+ function_ref<void(RewriterBase &, scf::ForeachThreadOp,
SmallVector<Value> &)>
blockIdGenerator,
SmallVector<int64_t> &gridDims) {
@@ -1397,9 +1397,8 @@ LogicalResult mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
for (OpFoldResult ofr : *potentialGridDim)
gridDims.push_back(getConstantIntValue(ofr).value());
- IndexType indexType = rewriter.getIndexType();
SmallVector<Value> blockOps;
- blockIdGenerator(foreachThreadOp, gridDims, indexType, blockOps);
+ blockIdGenerator(rewriter, foreachThreadOp, blockOps);
// Step 1. Move the body of foreachThreadOp.
// Erase the terminator first, it will not be used since we are on buffers.
@@ -1485,6 +1484,23 @@ createGpuLaunch(RewriterBase &rewriter, Location loc,
return launchOp;
}
+/// Helper used only by rewriteTopLevelForeachThreadToGpuBlocks. It creates
+/// one gpu.block_id op for each GPU dimension (x, y, z), in that order,
+/// immediately before `foreachOp`, and appends their results to `blockOps`.
+static void generateGpuBlockIds(RewriterBase &rewriter,
+                                scf::ForeachThreadOp foreachOp,
+                                SmallVector<Value> &blockOps) {
+  Location loc = foreachOp->getLoc();
+  // Restore the caller's insertion point when this helper returns.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(foreachOp);
+  IndexType indexType = rewriter.getIndexType();
+  // Always one block id per GPU dimension, independent of the foreach rank.
+  for (gpu::Dimension dim :
+       {gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z})
+    blockOps.push_back(rewriter.create<gpu::BlockIdOp>(loc, indexType, dim));
+}
+
DiagnosedSilenceableFailure
transform::MapNestedForeachThreadToGpuBlocks::applyToOne(
Operation *target, SmallVectorImpl<Operation *> &results,
@@ -1520,22 +1536,9 @@ transform::MapNestedForeachThreadToGpuBlocks::applyToOne(
dyn_cast<scf::ForeachThreadOp>(newForeachThreadOp);
}
- auto generateBlocks = [&](Operation *op, const SmallVector<int64_t> &gridDims,
- IndexType indexType, SmallVector<Value> &blockOps) {
- Location loc = op->getLoc();
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(op);
- SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
- gpu::Dimension::z};
- for (int64_t idx : llvm::seq<int64_t>(0, gridDims.size())) {
- blockOps.push_back(
- rewriter.create<gpu::BlockIdOp>(loc, indexType, gpuDims[idx]));
- }
- };
-
SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
if (failed(mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
- rewriter, topLevelForeachThreadOp, generateBlocks, gridDim)))
+ rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim)))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
if (failed(alterGpuLaunch(rewriter, gpuLaunch, gridDim[0], gridDim[1],
More information about the Mlir-commits
mailing list