[Mlir-commits] [mlir] 0eabb88 - [mlir][gpu] NFC let user pick the threadID values when distributing foreach_thread

Thomas Raoux llvmlistbot at llvm.org
Thu Feb 16 19:27:18 PST 2023


Author: Thomas Raoux
Date: 2023-02-17T03:25:15Z
New Revision: 0eabb884abebe748a5f71273f0e670ecdbabd400

URL: https://github.com/llvm/llvm-project/commit/0eabb884abebe748a5f71273f0e670ecdbabd400
DIFF: https://github.com/llvm/llvm-project/commit/0eabb884abebe748a5f71273f0e670ecdbabd400.diff

LOG: [mlir][gpu] NFC let user pick the threadID values when distributing foreach_thread

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D144219

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
    mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
index 97eb323ddb254..9b6485523c1c9 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
@@ -42,8 +42,11 @@ namespace gpu {
 /// supported. Dynamic block dim sizes are currently not supported.
 DiagnosedSilenceableFailure mapNestedForeachToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
-    const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
-    std::optional<TransformOpInterface> transformOp,
+    const SmallVectorImpl<int64_t> &blockDim,
+    function_ref<void(RewriterBase &, scf::ForeachThreadOp,
+                      SmallVectorImpl<Value> &)>
+        threadIdGenerator,
+    bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes);
 
 /// Maps the top level `scf.foreach_thread` op to GPU Thread Blocks. Mapping is

diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index abb58369aa36f..90f04f3713cf3 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -502,8 +502,11 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
 
 DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
-    const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
-    std::optional<TransformOpInterface> transformOp,
+    const SmallVectorImpl<int64_t> &blockDim,
+    function_ref<void(RewriterBase &, scf::ForeachThreadOp,
+                      SmallVectorImpl<Value> &)>
+        threadIdGenerator,
+    bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
   target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
@@ -517,14 +520,8 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
                               foreachThreadOp.getMapping(), transformOp);
     if (diag.succeeded()) {
       rewriter.setInsertionPoint(foreachThreadOp);
-      IndexType indexType = rewriter.getIndexType();
-      SmallVector<Value> threadOps{
-          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
-                                      Dimension::x),
-          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
-                                      Dimension::y),
-          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
-                                      Dimension::z)};
+      SmallVector<Value> threadOps;
+      threadIdGenerator(rewriter, foreachThreadOp, threadOps);
       diag = rewriteOneForeachThreadToGpuThreads(
           rewriter, foreachThreadOp, blockDim, threadOps, syncAfterDistribute,
           transformOp, threadMappingAttributes);
@@ -562,10 +559,20 @@ DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
       GPUThreadMappingAttr::get(ctx, Threads::DimX),
       GPUThreadMappingAttr::get(ctx, Threads::DimY),
       GPUThreadMappingAttr::get(ctx, Threads::DimZ)};
-
+  auto threadIdGenerator = [](RewriterBase &rewriter,
+                              scf::ForeachThreadOp foreachThreadOp,
+                              SmallVectorImpl<Value> &threadIds) {
+    IndexType indexType = rewriter.getIndexType();
+    threadIds.assign({rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
+                                                  indexType, Dimension::x),
+                      rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
+                                                  indexType, Dimension::y),
+                      rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
+                                                  indexType, Dimension::z)});
+  };
   diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
-      rewriter, target, blockDim, getSyncAfterDistribute(), transformOp,
-      threadMappingAttributes);
+      rewriter, target, blockDim, threadIdGenerator, getSyncAfterDistribute(),
+      transformOp, threadMappingAttributes);
 
   if (diag.succeeded()) {
     diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,


        


More information about the Mlir-commits mailing list