[Mlir-commits] [mlir] [mlir][vector] Update `CombineContractBroadcastMask` (PR #140050)

Diego Caballero llvmlistbot at llvm.org
Tue May 20 14:21:09 PDT 2025


================
@@ -264,109 +264,172 @@ struct CombineContractResultTranspose final
 ///    iterator_types = ["parallel", "parallel", "reduction"],
 ///    kind = add} %arg0, %arg1, %cst_f0
 ///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
-///  ```
-struct CombineContractBroadcast
-    : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<AffineMap> maps =
-        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
-    Value lhs = contractOp.getLhs();
-    Value rhs = contractOp.getRhs();
-    size_t index = 0;
-    bool changed = false;
-    for (Value *operand : {&lhs, &rhs}) {
-      AffineMap &map = maps[index++];
-      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
-      if (!broadcast)
-        continue;
-      // contractionOp can only take vector as operands.
-      auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
-      if (!srcType ||
-          srcType.getRank() == broadcast.getResultVectorType().getRank())
-        continue;
-      int64_t rankDiff =
-          broadcast.getResultVectorType().getRank() - srcType.getRank();
-      bool innerDimBroadcast = false;
-      SmallVector<AffineExpr> originalDims;
-      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
-        if (dim.value() != broadcast.getResultVectorType().getDimSize(
-                               rankDiff + dim.index())) {
-          innerDimBroadcast = true;
-          break;
-        }
-        originalDims.push_back(
-            rewriter.getAffineDimExpr(dim.index() + rankDiff));
+/// ```
+///
+/// For masked vector.contract, the mask requires updating when a dimension is
+/// dropped. In such cases, the dropped dimensions must correspond to the mask's
+/// leading unit dimensions. More generic cases (e.g. non-unit dims) are not
+/// supported.
+FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
+                                             MaskingOpInterface maskingOp,
+                                             PatternRewriter &rewriter) {
+  SmallVector<AffineMap> maps =
----------------
dcaballe wrote:

auto

https://github.com/llvm/llvm-project/pull/140050


More information about the Mlir-commits mailing list