[Mlir-commits] [mlir] [mlir][vector] Update `CombineContractBroadcastMask` (PR #140050)
Han-Chung Wang
llvmlistbot at llvm.org
Fri May 16 15:28:43 PDT 2025
================
@@ -264,109 +264,172 @@ struct CombineContractResultTranspose final
/// iterator_types = ["parallel", "parallel", "reduction"],
/// kind = add} %arg0, %arg1, %cst_f0
/// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
-/// ```
-struct CombineContractBroadcast
- : public OpRewritePattern<vector::ContractionOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
- PatternRewriter &rewriter) const override {
- SmallVector<AffineMap> maps =
- llvm::to_vector<4>(contractOp.getIndexingMapsArray());
- Value lhs = contractOp.getLhs();
- Value rhs = contractOp.getRhs();
- size_t index = 0;
- bool changed = false;
- for (Value *operand : {&lhs, &rhs}) {
- AffineMap &map = maps[index++];
- auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
- if (!broadcast)
- continue;
- // contractionOp can only take vector as operands.
- auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
- if (!srcType ||
- srcType.getRank() == broadcast.getResultVectorType().getRank())
- continue;
- int64_t rankDiff =
- broadcast.getResultVectorType().getRank() - srcType.getRank();
- bool innerDimBroadcast = false;
- SmallVector<AffineExpr> originalDims;
- for (const auto &dim : llvm::enumerate(srcType.getShape())) {
- if (dim.value() != broadcast.getResultVectorType().getDimSize(
- rankDiff + dim.index())) {
- innerDimBroadcast = true;
- break;
- }
- originalDims.push_back(
- rewriter.getAffineDimExpr(dim.index() + rankDiff));
+/// ```
+///
+/// For masked vector.contract, the mask requires updating when a dimension is
+/// dropped. In such cases, the dropped dimensions must correspond to the mask's
+/// leading unit dimensions. More generic cases (e.g. non-unit dims) are not
+/// supported.
+FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
+ MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) {
+ SmallVector<AffineMap> maps =
+ llvm::to_vector<4>(contractOp.getIndexingMapsArray());
+ Value lhs = contractOp.getLhs();
+ Value rhs = contractOp.getRhs();
+ size_t index = 0;
+ bool changed = false;
+ for (Value *operand : {&lhs, &rhs}) {
+ AffineMap &map = maps[index++];
+ auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
+ if (!broadcast)
+ continue;
+ // contractionOp can only take vector as operands.
+ auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
+ if (!srcType ||
+ srcType.getRank() == broadcast.getResultVectorType().getRank())
+ continue;
+ int64_t rankDiff =
+ broadcast.getResultVectorType().getRank() - srcType.getRank();
+ bool innerDimBroadcast = false;
+ SmallVector<AffineExpr> originalDims;
+ for (const auto &dim : llvm::enumerate(srcType.getShape())) {
+ if (dim.value() !=
+ broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
+ innerDimBroadcast = true;
+ break;
}
- // Contract doesn't support inner dimension broadcast. Once this is
- // relaxed we can remove this case.
- if (innerDimBroadcast)
- continue;
+ originalDims.push_back(rewriter.getAffineDimExpr(dim.index() + rankDiff));
+ }
+ // Contract doesn't support inner dimension broadcast. Once this is
+ // relaxed we can remove this case.
+ if (innerDimBroadcast)
+ continue;
- // It would be incorrect to fold a broadcast onto a reduction dimension
- // of non-unit size.
- bool nonUnitDimReductionBroadcast = false;
- for (int64_t i = 0; i < rankDiff; ++i) {
- if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
- isReductionIterator(contractOp.getIteratorTypes()
- .getValue()[map.getDimPosition(i)])) {
- nonUnitDimReductionBroadcast = true;
- break;
- }
+ // It would be incorrect to fold a broadcast onto a reduction dimension
+ // of non-unit size.
+ bool nonUnitDimReductionBroadcast = false;
+ for (int64_t i = 0; i < rankDiff; ++i) {
+ if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
+ isReductionIterator(contractOp.getIteratorTypes()
+ .getValue()[map.getDimPosition(i)])) {
+ nonUnitDimReductionBroadcast = true;
+ break;
}
- if (nonUnitDimReductionBroadcast)
- continue;
-
- AffineMap broadcastMap =
- AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
- originalDims, contractOp.getContext());
- map = broadcastMap.compose(map);
- *operand = broadcast.getSource();
- changed = true;
}
+ if (nonUnitDimReductionBroadcast)
+ continue;
- if (!changed)
- return failure();
+ AffineMap broadcastMap =
+ AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
+ originalDims, contractOp.getContext());
+ map = broadcastMap.compose(map);
+ *operand = broadcast.getSource();
+ changed = true;
+ }
- // Determine which dims are usused, now that the maps have been composed
- // with the broadcast maps.
- llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
- // Compress unused dims.
- for (auto &m : maps)
- m = compressDims(m, unusedDimsBitVector);
- // Compute the combined iterators.
- SmallVector<Attribute> iterators;
- for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
- if (!unusedDimsBitVector.test(i))
- iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
- }
- // Check that compressing unused dims isn't removing all reduction dimension
- // pairs. For example, if the vector.contract had only one reduction
- // iterator and that was a unit-dimension created by a broadcast,
- // then we should bail here, otherwise we would create a contract without
- // a reduction dimension pair.
- bool hasReductionIteratorApplyingOnBothSides = false;
- for (unsigned i = 0; i < iterators.size(); ++i) {
- if (!isReductionIterator(iterators[i]))
- continue;
- if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
- hasReductionIteratorApplyingOnBothSides = true;
+ if (!changed)
+ return failure();
+
+ // Determine which dims are usused, now that the maps have been composed
+ // with the broadcast maps.
+ llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
+ // Compress unused dims.
+ for (auto &m : maps)
+ m = compressDims(m, unusedDimsBitVector);
+ // Compute the combined iterators.
+ SmallVector<Attribute> iterators;
+ for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
+ if (!unusedDimsBitVector.test(i))
+ iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
+ }
+
+ // Check whether any of the unused dims is non-unit, e.g.:
+ // * vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32>
+ // This is only required when collapsing a mask. If there is no mask, skip.
+ VectorType oldMaskType;
+ bool isAnyUnusedDimNonUnit = false;
+ if (maskingOp) {
+ oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
+ for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
+ if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
+ isAnyUnusedDimNonUnit = true;
break;
}
}
- if (!hasReductionIteratorApplyingOnBothSides)
- return failure();
+ }
- // If the compressed maps have a dimension that is not used by either LHS or
- // RHS then the ContractionOp verifier would fail.
- if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
- return failure();
- rewriter.replaceOpWithNewOp<vector::ContractionOp>(
- contractOp, lhs, rhs, contractOp.getAcc(),
- rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
- return success();
+ // Check that compressing unused dims isn't removing all reduction dimension
+ // pairs. For example, if the vector.contract had only one reduction
+ // iterator and that was a unit-dimension created by a broadcast,
+ // then we should bail here, otherwise we would create a contract without
+ // a reduction dimension pair.
+ bool hasReductionIteratorApplyingOnBothSides = false;
+ for (unsigned i = 0; i < iterators.size(); ++i) {
+ if (!isReductionIterator(iterators[i]))
+ continue;
+ if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
+ hasReductionIteratorApplyingOnBothSides = true;
+ break;
+ }
+ }
+ if (!hasReductionIteratorApplyingOnBothSides)
+ return failure();
+
+ // If the compressed maps have a dimension that is not used by either LHS or
+ // RHS then the ContractionOp verifier would fail.
+ if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
+ return failure();
+
+ Operation *newOp = rewriter.create<vector::ContractionOp>(
+ contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
+ rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
+
+ // Handle the mask.
+ if (maskingOp) {
+ if (isAnyUnusedDimNonUnit)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Cannont drop non-unit mask dim.");
+ assert(unusedDimsBitVector.size() ==
+ static_cast<size_t>(oldMaskType.getRank()) &&
+ "The mask rank is incorrect!");
+
+ // If a dimension has been dropped, update the mask accordingly. Otherwise,
+ // keep it as is.
+ Value mask = maskingOp.getMask();
+ if (unusedDimsBitVector.count() != 0) {
+ // At this point, two assumptions are made:
+ // * The unused dimensions are the leading mask dimensions
+ // (vector.contract does not support inner dim broadcasting).
+ // * The unused dimensions are all unit.
+ // These conditions are effectively verified in the blocks preceeding this
+ // one.
+ auto newShape =
+ oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
+ auto newShapeScalableDims =
+ oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
+ VectorType maskOpType =
+ VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
----------------
hanhanW wrote:
Do we have scalable vector support in this method/pattern? I can't find such a test case in [vector-reduce-to-contract.mlir](https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir). Should we add a test for it?
https://github.com/llvm/llvm-project/pull/140050
More information about the Mlir-commits
mailing list