[Mlir-commits] [mlir] [MLIR][Vector] Add warp distribution for `scf.if` (PR #157119)

Artem Kroviakov llvmlistbot at llvm.org
Tue Sep 9 03:20:51 PDT 2025


================
@@ -1713,6 +1713,209 @@ struct WarpOpInsert : public WarpDistributionPattern {
   }
 };
 
+struct WarpOpScfIfOp : public WarpDistributionPattern {
+  WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp warpOpYield = warpOp.getTerminator();
+    // Only pick up `IfOp` if it is the last op in the region.
+    Operation *lastNode = warpOpYield->getPrevNode();
+    auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
+    if (!ifOp)
+      return failure();
+
+    // The current `WarpOp` can yield two types of values:
+    // 1. Not results of `IfOp`:
+    //     Preserve them in the new `WarpOp`.
+    //     Collect their yield index.
+    // 2. Results of `IfOp`:
+    //     They are not part of the new `WarpOp` results.
+    //     Map current warp's yield operand index to `IfOp` result idx.
+    SmallVector<Value> nonIfYieldValues;
+    SmallVector<unsigned> nonIfYieldIndices;
+    llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
+    llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
+    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
+      const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
+      if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
+        nonIfYieldValues.push_back(yieldOperand.get());
+        nonIfYieldIndices.push_back(yieldOperandIdx);
+        continue;
+      }
+      OpResult ifResult = cast<OpResult>(yieldOperand.get());
+      const unsigned ifResultIdx = ifResult.getResultNumber();
+      ifResultMapping[yieldOperandIdx] = ifResultIdx;
+      // If this `ifOp` result is vector type and it is yielded by the
+      // `WarpOp`, we keep track the distributed type for this result.
+      if (!isa<VectorType>(ifResult.getType()))
+        continue;
+      VectorType distType =
+          cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
+      ifResultDistTypes[ifResultIdx] = distType;
+    }
+
+    // Collect `WarpOp`-defined values used in `ifOp`, the new warp op returns
+    // them
+    auto getEscapingValues = [&](Region &branch,
+                                 llvm::SmallSetVector<Value, 32> &values,
+                                 SmallVector<Type> &inputTypes,
+                                 SmallVector<Type> &distTypes) {
+      if (branch.empty())
+        return;
+      mlir::visitUsedValuesDefinedAbove(branch, [&](OpOperand *operand) {
+        Operation *parent = operand->get().getParentRegion()->getParentOp();
+        if (warpOp->isAncestor(parent)) {
+          if (!values.insert(operand->get()))
+            return;
+          Type distType = operand->get().getType();
+          if (auto vecType = dyn_cast<VectorType>(distType)) {
+            AffineMap map = distributionMapFn(operand->get());
+            distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+          }
+          inputTypes.push_back(operand->get().getType());
+          distTypes.push_back(distType);
+        }
+      });
+    };
+    llvm::SmallSetVector<Value, 32> escapingValuesThen;
+    SmallVector<Type> escapingValueInputTypesThen; // inner warp op block args
+    SmallVector<Type> escapingValueDistTypesThen;  // new warp returns
+    getEscapingValues(ifOp.getThenRegion(), escapingValuesThen,
+                      escapingValueInputTypesThen, escapingValueDistTypesThen);
+    llvm::SmallSetVector<Value, 32> escapingValuesElse;
+    SmallVector<Type> escapingValueInputTypesElse; // inner warp op block args
+    SmallVector<Type> escapingValueDistTypesElse;  // new warp returns
+    getEscapingValues(ifOp.getElseRegion(), escapingValuesElse,
+                      escapingValueInputTypesElse, escapingValueDistTypesElse);
+
+    if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
+        llvm::is_contained(escapingValueDistTypesElse, Type{}))
+      return failure();
+
+    // The new `WarpOp` groups yields values in following order:
+    // 1. Branch condition
+    // 2. Escaping values then branch
+    // 3. Escaping values else branch
+    // 4. All non-`ifOp` yielded values.
+    SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
+    newWarpOpYieldValues.append(escapingValuesThen.begin(),
+                                escapingValuesThen.end());
+    newWarpOpYieldValues.append(escapingValuesElse.begin(),
+                                escapingValuesElse.end());
+    SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
+    newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
+                              escapingValueDistTypesThen.end());
+    newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
+                              escapingValueDistTypesElse.end());
+
+    llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
+    for (auto [idx, val] :
+         llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
+      origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
+      newWarpOpYieldValues.push_back(val);
+      newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
+    }
+    // Create the new `WarpOp` with the updated yield values and types.
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+    // `ifOp` returns the result of the inner warp op.
+    SmallVector<Type> newIfOpDistResTypes;
+    for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
+      Type distType = cast<Value>(res).getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(cast<Value>(res));
+        distType = ifResultDistTypes.count(i)
+                       ? ifResultDistTypes[i]
+                       : getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      newIfOpDistResTypes.push_back(distType);
+    }
+    // Create a new `IfOp` outside the new `WarpOp` region.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newIfOp = scf::IfOp::create(
+        rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
+        static_cast<bool>(ifOp.thenBlock()),
+        static_cast<bool>(ifOp.elseBlock()));
+
+    auto processBranch = [&](Block *oldIfBranch, Block *newIfBranch,
----------------
akroviakov wrote:

Renamed.

>  could also try to reuse in scf.for if possible.

The outlined utility function requires more than 7 arguments just to work for `scf.if` distribution, because the lambda relies on many captured values. The mapping of escaping values to input types differs between `scf.if` and `scf.for`, so it cannot be part of the utility. 
Exposing the mapping, input values, and types as additional arguments would make the utility's argument list even more bloated.
Standardizing the two ops for better outlining is out of scope for this PR.

https://github.com/llvm/llvm-project/pull/157119


More information about the Mlir-commits mailing list