[Mlir-commits] [mlir] [MLIR][Vector] Add warp distribution for `scf.if` (PR #157119)

Artem Kroviakov llvmlistbot at llvm.org
Tue Sep 9 03:20:57 PDT 2025


================
@@ -1713,6 +1713,209 @@ struct WarpOpInsert : public WarpDistributionPattern {
   }
 };
 
+struct WarpOpScfIfOp : public WarpDistributionPattern {
+  WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp warpOpYield = warpOp.getTerminator();
+    // Only pick up `IfOp` if it is the last op in the region.
+    Operation *lastNode = warpOpYield->getPrevNode();
+    auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
+    if (!ifOp)
+      return failure();
+
+    // The current `WarpOp` can yield two types of values:
+    // 1. Not results of `IfOp`:
+    //     Preserve them in the new `WarpOp`.
+    //     Collect their yield index.
+    // 2. Results of `IfOp`:
+    //     They are not part of the new `WarpOp` results.
+    //     Map current warp's yield operand index to `IfOp` result idx.
+    SmallVector<Value> nonIfYieldValues;
+    SmallVector<unsigned> nonIfYieldIndices;
+    llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
+    llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
+    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
+      const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
+      if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
+        nonIfYieldValues.push_back(yieldOperand.get());
+        nonIfYieldIndices.push_back(yieldOperandIdx);
+        continue;
+      }
+      OpResult ifResult = cast<OpResult>(yieldOperand.get());
+      const unsigned ifResultIdx = ifResult.getResultNumber();
+      ifResultMapping[yieldOperandIdx] = ifResultIdx;
+      // If this `ifOp` result is a vector type and it is yielded by the
+      // `WarpOp`, we keep track of the distributed type for this result.
+      if (!isa<VectorType>(ifResult.getType()))
+        continue;
+      VectorType distType =
+          cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
+      ifResultDistTypes[ifResultIdx] = distType;
+    }
+
+    // Collect `WarpOp`-defined values used in `ifOp`; the new warp op returns
+    // them.
+    auto getEscapingValues = [&](Region &branch,
----------------
akroviakov wrote:

Done

https://github.com/llvm/llvm-project/pull/157119


More information about the Mlir-commits mailing list