[Mlir-commits] [mlir] 288ae0b - [mlir][gpu] NFC change to pass threadID ops to rewriteOneForeachThreadToGpuThreads

Thomas Raoux llvmlistbot at llvm.org
Mon Feb 13 17:28:21 PST 2023


Author: Thomas Raoux
Date: 2023-02-14T01:28:11Z
New Revision: 288ae0b92f57cc6fcd77a6e5220e67fba7768ceb

URL: https://github.com/llvm/llvm-project/commit/288ae0b92f57cc6fcd77a6e5220e67fba7768ceb
DIFF: https://github.com/llvm/llvm-project/commit/288ae0b92f57cc6fcd77a6e5220e67fba7768ceb.diff

LOG: [mlir][gpu] NFC change to pass threadID ops to rewriteOneForeachThreadToGpuThreads

This allows users to provide both the thread IDs and the dimensions of the threads to distribute over.
This means the function can also be used to distribute over warps.

Reviewed By: harsh

Differential Revision: https://reviews.llvm.org/D143950

Added: 
    

Modified: 
    mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index ecad3aa48bd8..742189c2ae49 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -366,7 +366,8 @@ transform::MapForeachToBlocks::applyToOne(Operation *target,
 /// not supported. Dynamic block dim sizes are currently not supported.
 static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
-    const SmallVectorImpl<int64_t> &globalBlockDims, bool syncAfterDistribute,
+    const SmallVectorImpl<int64_t> &globalBlockDims,
+    const SmallVectorImpl<Value> &threadOps, bool syncAfterDistribute,
     std::optional<TransformOpInterface> transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   // Step 0. Target-specific verifications. There is no good place to anchor
@@ -427,28 +428,26 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
   // Step 3. Create the gpu.thread ops and map the induction variables to the
   // newly created ops.
   IndexType indexType = rewriter.getIndexType();
-  SmallVector<Value> threadOps{
-      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
-      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
-      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)};
   // Replace ids of dimension size 1 by zero to simplify the IR.
+  SmallVector<Value> threadOpsUpdated(threadOps.begin(), threadOps.end());
+  assert(threadOps.size() == globalBlockDims.size());
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   for (size_t i : llvm::seq(size_t(0), globalBlockDims.size())) {
     if (globalBlockDims[i] == 1)
-      threadOps[i] = zero;
+      threadOpsUpdated[i] = zero;
   }
   IRMapping bvm;
   for (auto [blockIdx, blockDim] :
        llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
-    bvm.map(
-        blockIdx,
-        threadOps[blockDim.cast<DeviceMappingAttrInterface>().getMappingId()]);
+    bvm.map(blockIdx,
+            threadOpsUpdated[blockDim.cast<DeviceMappingAttrInterface>()
+                                 .getMappingId()]);
   }
 
   // Step 4. Maybe create conditionals to predicate the region.
   Value predicate;
   for (auto [threadId, blockDim, globalBlockDim] :
-       llvm::zip(threadOps, blockDims, globalBlockDims)) {
+       llvm::zip(threadOpsUpdated, blockDims, globalBlockDims)) {
     if (blockDim > globalBlockDim) {
       return failureHelper(
           "The requested GPU threads are fewer than the number of loop trip "
@@ -519,9 +518,17 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
                               foreachThreadOp.getMapping(), transformOp);
     if (diag.succeeded()) {
       rewriter.setInsertionPoint(foreachThreadOp);
+      IndexType indexType = rewriter.getIndexType();
+      SmallVector<Value> threadOps{
+          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
+                                      Dimension::x),
+          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
+                                      Dimension::y),
+          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
+                                      Dimension::z)};
       diag = rewriteOneForeachThreadToGpuThreads(
-          rewriter, foreachThreadOp, blockDim, syncAfterDistribute, transformOp,
-          threadMappingAttributes);
+          rewriter, foreachThreadOp, blockDim, threadOps, syncAfterDistribute,
+          transformOp, threadMappingAttributes);
     }
     return diag.succeeded() ? WalkResult::advance() : WalkResult::interrupt();
   });


        


More information about the Mlir-commits mailing list