[Mlir-commits] [mlir] f8ad6ea - [mlir] Refactor transform dialect's gpu block func

Guray Ozen llvmlistbot at llvm.org
Tue Sep 27 03:27:26 PDT 2022


Author: Guray Ozen
Date: 2022-09-27T12:27:17+02:00
New Revision: f8ad6eaf92ac241420ef99db855d003d5b6d274e

URL: https://github.com/llvm/llvm-project/commit/f8ad6eaf92ac241420ef99db855d003d5b6d274e
DIFF: https://github.com/llvm/llvm-project/commit/f8ad6eaf92ac241420ef99db855d003d5b6d274e.diff

LOG: [mlir] Refactor transform dialect's gpu block func

This revision refactors the gpu block id generator lambda that is used in the transform dialect. It removes the lambda and instead uses a static function named generateGpuBlockIds.

It also simplifies arguments that the function takes.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D134724

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 11a34cc573d84..2e7b247d60ca2 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -132,7 +132,7 @@ FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
 /// dim sizes are currently not supported.
 LogicalResult rewriteTopLevelForeachThreadToGpuBlocks(
     RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
-    function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
+    function_ref<void(RewriterBase &, scf::ForeachThreadOp,
                       SmallVector<Value> &)>
         blockIdGenerator,
     SmallVector<int64_t> &gridDims);

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8f6a6b10b29f5..a5331c1e40bd1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1374,7 +1374,7 @@ transform::MapNestedForeachThreadToGpuThreads::applyToOne(
 
 LogicalResult mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
     RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
-    function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
+    function_ref<void(RewriterBase &, scf::ForeachThreadOp,
                       SmallVector<Value> &)>
         blockIdGenerator,
     SmallVector<int64_t> &gridDims) {
@@ -1397,9 +1397,8 @@ LogicalResult mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
   for (OpFoldResult ofr : *potentialGridDim)
     gridDims.push_back(getConstantIntValue(ofr).value());
 
-  IndexType indexType = rewriter.getIndexType();
   SmallVector<Value> blockOps;
-  blockIdGenerator(foreachThreadOp, gridDims, indexType, blockOps);
+  blockIdGenerator(rewriter, foreachThreadOp, blockOps);
 
   // Step 1. Move the body of foreachThreadOp.
   // Erase the terminator first, it will not be used since we are on buffers.
@@ -1485,6 +1484,23 @@ createGpuLaunch(RewriterBase &rewriter, Location loc,
   return launchOp;
 }
 
+/// This is an helper that is only used in
+/// rewriteTopLevelForeachThreadToGpuBlocks. It generates GPU dialects block_id
+static void generateGpuBlockIds(RewriterBase &rewriter,
+                                scf::ForeachThreadOp foreachOp,
+                                SmallVector<Value> &blockOps) {
+  Location loc = foreachOp->getLoc();
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(foreachOp);
+  IndexType indexType = rewriter.getIndexType();
+  SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
+                                      gpu::Dimension::z};
+  for (int64_t idx : llvm::seq<int64_t>(0, gpuDims.size())) {
+    blockOps.push_back(
+        rewriter.create<gpu::BlockIdOp>(loc, indexType, gpuDims[idx]));
+  }
+}
+
 DiagnosedSilenceableFailure
 transform::MapNestedForeachThreadToGpuBlocks::applyToOne(
     Operation *target, SmallVectorImpl<Operation *> &results,
@@ -1520,22 +1536,9 @@ transform::MapNestedForeachThreadToGpuBlocks::applyToOne(
         dyn_cast<scf::ForeachThreadOp>(newForeachThreadOp);
   }
 
-  auto generateBlocks = [&](Operation *op, const SmallVector<int64_t> &gridDims,
-                            IndexType indexType, SmallVector<Value> &blockOps) {
-    Location loc = op->getLoc();
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPoint(op);
-    SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
-                                        gpu::Dimension::z};
-    for (int64_t idx : llvm::seq<int64_t>(0, gridDims.size())) {
-      blockOps.push_back(
-          rewriter.create<gpu::BlockIdOp>(loc, indexType, gpuDims[idx]));
-    }
-  };
-
   SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
   if (failed(mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
-          rewriter, topLevelForeachThreadOp, generateBlocks, gridDim)))
+          rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim)))
     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
 
   if (failed(alterGpuLaunch(rewriter, gpuLaunch, gridDim[0], gridDim[1],


        


More information about the Mlir-commits mailing list