[Mlir-commits] [mlir] [MLIR][Vector] Add warp distribution for `scf.if` (PR #157119)

Charitha Saumya llvmlistbot at llvm.org
Tue Sep 9 23:01:39 PDT 2025


================
@@ -1713,6 +1713,209 @@ struct WarpOpInsert : public WarpDistributionPattern {
   }
 };
 
+struct WarpOpScfIfOp : public WarpDistributionPattern {
+  WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp warpOpYield = warpOp.getTerminator();
+    // Only pick up `IfOp` if it is the last op in the region.
+    Operation *lastNode = warpOpYield->getPrevNode();
+    auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
+    if (!ifOp)
+      return failure();
+
+    // The current `WarpOp` can yield two types of values:
+    // 1. Not results of `IfOp`:
+    //     Preserve them in the new `WarpOp`.
+    //     Collect their yield index.
+    // 2. Results of `IfOp`:
+    //     They are not part of the new `WarpOp` results.
+    //     Map current warp's yield operand index to `IfOp` result idx.
+    SmallVector<Value> nonIfYieldValues;
+    SmallVector<unsigned> nonIfYieldIndices;
+    llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
+    llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
+    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
+      const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
+      if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
+        nonIfYieldValues.push_back(yieldOperand.get());
+        nonIfYieldIndices.push_back(yieldOperandIdx);
+        continue;
+      }
+      OpResult ifResult = cast<OpResult>(yieldOperand.get());
+      const unsigned ifResultIdx = ifResult.getResultNumber();
+      ifResultMapping[yieldOperandIdx] = ifResultIdx;
+      // If this `ifOp` result is vector type and it is yielded by the
+      // `WarpOp`, we keep track the distributed type for this result.
+      if (!isa<VectorType>(ifResult.getType()))
+        continue;
+      VectorType distType =
+          cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
+      ifResultDistTypes[ifResultIdx] = distType;
+    }
+
+    // Collect `WarpOp`-defined values used in `ifOp`, the new warp op returns
+    // them
+    auto getEscapingValues = [&](Region &branch,
+                                 llvm::SmallSetVector<Value, 32> &values,
+                                 SmallVector<Type> &inputTypes,
+                                 SmallVector<Type> &distTypes) {
+      if (branch.empty())
+        return;
+      mlir::visitUsedValuesDefinedAbove(branch, [&](OpOperand *operand) {
+        Operation *parent = operand->get().getParentRegion()->getParentOp();
+        if (warpOp->isAncestor(parent)) {
+          if (!values.insert(operand->get()))
+            return;
+          Type distType = operand->get().getType();
+          if (auto vecType = dyn_cast<VectorType>(distType)) {
+            AffineMap map = distributionMapFn(operand->get());
+            distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+          }
+          inputTypes.push_back(operand->get().getType());
+          distTypes.push_back(distType);
+        }
+      });
+    };
+    llvm::SmallSetVector<Value, 32> escapingValuesThen;
+    SmallVector<Type> escapingValueInputTypesThen; // inner warp op block args
+    SmallVector<Type> escapingValueDistTypesThen;  // new warp returns
+    getEscapingValues(ifOp.getThenRegion(), escapingValuesThen,
+                      escapingValueInputTypesThen, escapingValueDistTypesThen);
+    llvm::SmallSetVector<Value, 32> escapingValuesElse;
+    SmallVector<Type> escapingValueInputTypesElse; // inner warp op block args
+    SmallVector<Type> escapingValueDistTypesElse;  // new warp returns
+    getEscapingValues(ifOp.getElseRegion(), escapingValuesElse,
+                      escapingValueInputTypesElse, escapingValueDistTypesElse);
+
+    if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
+        llvm::is_contained(escapingValueDistTypesElse, Type{}))
+      return failure();
+
+    // The new `WarpOp` groups yields values in following order:
+    // 1. Branch condition
+    // 2. Escaping values then branch
+    // 3. Escaping values else branch
+    // 4. All non-`ifOp` yielded values.
+    SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
+    newWarpOpYieldValues.append(escapingValuesThen.begin(),
+                                escapingValuesThen.end());
+    newWarpOpYieldValues.append(escapingValuesElse.begin(),
+                                escapingValuesElse.end());
+    SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
+    newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
+                              escapingValueDistTypesThen.end());
+    newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
+                              escapingValueDistTypesElse.end());
+
+    llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
+    for (auto [idx, val] :
+         llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
+      origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
+      newWarpOpYieldValues.push_back(val);
+      newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
+    }
+    // Create the new `WarpOp` with the updated yield values and types.
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+    // `ifOp` returns the result of the inner warp op.
+    SmallVector<Type> newIfOpDistResTypes;
+    for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
+      Type distType = cast<Value>(res).getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(cast<Value>(res));
+        distType = ifResultDistTypes.count(i)
+                       ? ifResultDistTypes[i]
+                       : getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      newIfOpDistResTypes.push_back(distType);
+    }
+    // Create a new `IfOp` outside the new `WarpOp` region.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newIfOp = scf::IfOp::create(
+        rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
+        static_cast<bool>(ifOp.thenBlock()),
+        static_cast<bool>(ifOp.elseBlock()));
+
+    auto processBranch = [&](Block *oldIfBranch, Block *newIfBranch,
----------------
charithaintc wrote:

Let's go back to the original lambda version then; that looks cleaner. Sorry about the extra work. :-)

https://github.com/llvm/llvm-project/pull/157119


More information about the Mlir-commits mailing list