[Mlir-commits] [mlir] [mlir][vector] Update `CombineContractBroadcastMask` (PR #140050)

Han-Chung Wang llvmlistbot at llvm.org
Fri May 16 15:28:43 PDT 2025


================
@@ -264,109 +264,172 @@ struct CombineContractResultTranspose final
 ///    iterator_types = ["parallel", "parallel", "reduction"],
 ///    kind = add} %arg0, %arg1, %cst_f0
 ///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
-///  ```
-struct CombineContractBroadcast
-    : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<AffineMap> maps =
-        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
-    Value lhs = contractOp.getLhs();
-    Value rhs = contractOp.getRhs();
-    size_t index = 0;
-    bool changed = false;
-    for (Value *operand : {&lhs, &rhs}) {
-      AffineMap &map = maps[index++];
-      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
-      if (!broadcast)
-        continue;
-      // contractionOp can only take vector as operands.
-      auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
-      if (!srcType ||
-          srcType.getRank() == broadcast.getResultVectorType().getRank())
-        continue;
-      int64_t rankDiff =
-          broadcast.getResultVectorType().getRank() - srcType.getRank();
-      bool innerDimBroadcast = false;
-      SmallVector<AffineExpr> originalDims;
-      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
-        if (dim.value() != broadcast.getResultVectorType().getDimSize(
-                               rankDiff + dim.index())) {
-          innerDimBroadcast = true;
-          break;
-        }
-        originalDims.push_back(
-            rewriter.getAffineDimExpr(dim.index() + rankDiff));
+/// ```
+///
+/// For masked vector.contract, the mask requires updating when a dimension is
+/// dropped. In such cases, the dropped dimensions must correspond to the mask's
+/// leading unit dimensions. More generic cases (e.g. non-unit dims) are not
+/// supported.
+FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
+                                             MaskingOpInterface maskingOp,
+                                             PatternRewriter &rewriter) {
+  SmallVector<AffineMap> maps =
+      llvm::to_vector<4>(contractOp.getIndexingMapsArray());
+  Value lhs = contractOp.getLhs();
+  Value rhs = contractOp.getRhs();
+  size_t index = 0;
+  bool changed = false;
+  for (Value *operand : {&lhs, &rhs}) {
+    AffineMap &map = maps[index++];
+    auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
+    if (!broadcast)
+      continue;
+    // contractionOp can only take vector as operands.
+    auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!srcType ||
+        srcType.getRank() == broadcast.getResultVectorType().getRank())
+      continue;
+    int64_t rankDiff =
+        broadcast.getResultVectorType().getRank() - srcType.getRank();
+    bool innerDimBroadcast = false;
+    SmallVector<AffineExpr> originalDims;
+    for (const auto &dim : llvm::enumerate(srcType.getShape())) {
+      if (dim.value() !=
+          broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
+        innerDimBroadcast = true;
+        break;
       }
-      // Contract doesn't support inner dimension broadcast. Once this is
-      // relaxed we can remove this case.
-      if (innerDimBroadcast)
-        continue;
+      originalDims.push_back(rewriter.getAffineDimExpr(dim.index() + rankDiff));
+    }
+    // Contract doesn't support inner dimension broadcast. Once this is
+    // relaxed we can remove this case.
+    if (innerDimBroadcast)
+      continue;
 
-      // It would be incorrect to fold a broadcast onto a reduction dimension
-      // of non-unit size.
-      bool nonUnitDimReductionBroadcast = false;
-      for (int64_t i = 0; i < rankDiff; ++i) {
-        if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
-            isReductionIterator(contractOp.getIteratorTypes()
-                                    .getValue()[map.getDimPosition(i)])) {
-          nonUnitDimReductionBroadcast = true;
-          break;
-        }
+    // It would be incorrect to fold a broadcast onto a reduction dimension
+    // of non-unit size.
+    bool nonUnitDimReductionBroadcast = false;
+    for (int64_t i = 0; i < rankDiff; ++i) {
+      if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
+          isReductionIterator(contractOp.getIteratorTypes()
+                                  .getValue()[map.getDimPosition(i)])) {
+        nonUnitDimReductionBroadcast = true;
+        break;
       }
-      if (nonUnitDimReductionBroadcast)
-        continue;
-
-      AffineMap broadcastMap =
-          AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
-                         originalDims, contractOp.getContext());
-      map = broadcastMap.compose(map);
-      *operand = broadcast.getSource();
-      changed = true;
     }
+    if (nonUnitDimReductionBroadcast)
+      continue;
 
-    if (!changed)
-      return failure();
+    AffineMap broadcastMap =
+        AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
+                       originalDims, contractOp.getContext());
+    map = broadcastMap.compose(map);
+    *operand = broadcast.getSource();
+    changed = true;
+  }
 
-    // Determine which dims are usused, now that the maps have been composed
-    // with the broadcast maps.
-    llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
-    // Compress unused dims.
-    for (auto &m : maps)
-      m = compressDims(m, unusedDimsBitVector);
-    // Compute the combined iterators.
-    SmallVector<Attribute> iterators;
-    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
-      if (!unusedDimsBitVector.test(i))
-        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
-    }
-    // Check that compressing unused dims isn't removing all reduction dimension
-    // pairs. For example, if the vector.contract had only one reduction
-    // iterator and that was a unit-dimension created by a broadcast,
-    // then we should bail here, otherwise we would create a contract without
-    // a reduction dimension pair.
-    bool hasReductionIteratorApplyingOnBothSides = false;
-    for (unsigned i = 0; i < iterators.size(); ++i) {
-      if (!isReductionIterator(iterators[i]))
-        continue;
-      if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
-        hasReductionIteratorApplyingOnBothSides = true;
+  if (!changed)
+    return failure();
+
+  // Determine which dims are unused, now that the maps have been composed
+  // with the broadcast maps.
+  llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
+  // Compress unused dims.
+  for (auto &m : maps)
+    m = compressDims(m, unusedDimsBitVector);
+  // Compute the combined iterators.
+  SmallVector<Attribute> iterators;
+  for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
+    if (!unusedDimsBitVector.test(i))
+      iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
+  }
+
+  // Check whether any of the unused dims is non-unit, e.g.:
+  //  * vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32>
+  // This is only required when collapsing a mask. If there is no mask, skip.
+  VectorType oldMaskType;
+  bool isAnyUnusedDimNonUnit = false;
+  if (maskingOp) {
+    oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
+    for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
+      if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
+        isAnyUnusedDimNonUnit = true;
         break;
       }
     }
-    if (!hasReductionIteratorApplyingOnBothSides)
-      return failure();
+  }
 
-    // If the compressed maps have a dimension that is not used by either LHS or
-    // RHS then the ContractionOp verifier would fail.
-    if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
-      return failure();
-    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
-        contractOp, lhs, rhs, contractOp.getAcc(),
-        rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
-    return success();
+  // Check that compressing unused dims isn't removing all reduction dimension
+  // pairs. For example, if the vector.contract had only one reduction
+  // iterator and that was a unit-dimension created by a broadcast,
+  // then we should bail here, otherwise we would create a contract without
+  // a reduction dimension pair.
+  bool hasReductionIteratorApplyingOnBothSides = false;
+  for (unsigned i = 0; i < iterators.size(); ++i) {
+    if (!isReductionIterator(iterators[i]))
+      continue;
+    if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
+      hasReductionIteratorApplyingOnBothSides = true;
+      break;
+    }
+  }
+  if (!hasReductionIteratorApplyingOnBothSides)
+    return failure();
+
+  // If the compressed maps have a dimension that is not used by either LHS or
+  // RHS then the ContractionOp verifier would fail.
+  if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
+    return failure();
+
+  Operation *newOp = rewriter.create<vector::ContractionOp>(
+      contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
+      rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
+
+  // Handle the mask.
+  if (maskingOp) {
+    if (isAnyUnusedDimNonUnit)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Cannont drop non-unit mask dim.");
+    assert(unusedDimsBitVector.size() ==
+               static_cast<size_t>(oldMaskType.getRank()) &&
+           "The mask rank is incorrect!");
+
+    // If a dimension has been dropped, update the mask accordingly. Otherwise,
+    // keep it as is.
+    Value mask = maskingOp.getMask();
+    if (unusedDimsBitVector.count() != 0) {
+      // At this point, two assumptions are made:
+      //  * The unused dimensions are the leading mask dimensions
+      //  (vector.contract does not support inner dim broadcasting).
+      //  * The unused dimensions are all unit.
+      // These conditions are effectively verified in the blocks preceding this
+      // one.
+      auto newShape =
+          oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
+      auto newShapeScalableDims =
+          oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
+      VectorType maskOpType =
+          VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
----------------
hanhanW wrote:

Do we have scalable vector support in this method/pattern? I can't find such a test case in [vector-reduce-to-contract.mlir](https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir). Should we add a test for it?

https://github.com/llvm/llvm-project/pull/140050


More information about the Mlir-commits mailing list