[Mlir-commits] [mlir] [mlir][vector] Update `CombineContractBroadcastMask` (PR #140050)

Andrzej Warzyński llvmlistbot at llvm.org
Mon May 19 01:21:14 PDT 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/140050

>From 568f955773045c5fd2e8618776db3f7d735ec2c7 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 14 May 2025 09:37:58 +0100
Subject: [PATCH 1/2] [mlir][vector] Update `CombineContractBroadcastMask`
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch updates `CombineContractBroadcastMask` to inherit from
`MaskableOpRewritePattern`, enabling it to handle masked
`vector.contract` operations. The pattern rewrites:
```mlir
  %a_bc = vector.broadcast %a
  %res = vector.contract %a_bc, %b, ...
```

into:
```mlir
  // Move the broadcast into vector.contract (by updating the indexing
  // maps)
  %res = vector.contract %a, %b, ...
```

The main challenge is supporting cases where the pattern drops a leading
unit dimension. For example:
```mlir
func.func @contract_broadcast_unit_dim_reduction_masked(
    %arg0 : vector<8x4xi32>,
    %arg1 : vector<8x4xi32>,
    %arg2 : vector<8x8xi32>,
    %mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> {

  %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
  %1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32>
  %result = vector.mask %mask {
    vector.contract {
      indexing_maps = [#map0, #map1, #map2],
      iterator_types = ["reduction", "parallel", "parallel", "reduction"],
      kind = #vector.kind<add>
    } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32>
  } : vector<1x8x8x4xi1> -> vector<8x8xi32>

  return %result : vector<8x8xi32>
}
```

Here, the leading unit dimension is dropped. To handle this, the mask is
cast to the correct shape using a `vector.shape_cast`:

```mlir
func.func @contract_broadcast_unit_dim_reduction_masked(
    %arg0: vector<8x4xi32>,
    %arg1: vector<8x4xi32>,
    %arg2: vector<8x8xi32>,
    %arg3: vector<1x8x8x4xi1>) -> vector<8x8xi32> {

  %mask_sc = vector.shape_cast %arg3 : vector<1x8x8x4xi1> to vector<8x8x4xi1>
  %res = vector.mask %mask_sc {
    vector.contract {
      indexing_maps = [#map, #map1, #map2],
      iterator_types = ["parallel", "parallel", "reduction"],
      kind = #vector.kind<add>
    } %arg0, %arg1, %arg2 : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
  } : vector<8x8x4xi1> -> vector<8x8xi32>

  return %res : vector<8x8xi32>
}
```

While this isn't ideal — since it introduces a `vector.shape_cast` that
must be cleaned up later — it reflects the best we can do once the input
reaches `CombineContractBroadcastMask`. A more robust solution may
involve simplifying the input earlier. I am leaving that as a TODO for
myself to explore this further. Posting this now to unblock downstream
work.

LIMITATIONS

Currently, this pattern assumes:
* Only leading dimensions are dropped in the mask.
* All dropped dimensions must be unit-sized.

TODO: Check whether any other cases are possible.
---
 .../Vector/Transforms/VectorTransforms.cpp    | 251 +++++++++++-------
 .../Vector/vector-reduce-to-contract.mlir     |  77 +++++-
 2 files changed, 232 insertions(+), 96 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b94c5fce64f83..3cd25c3cb2fc2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -264,109 +264,172 @@ struct CombineContractResultTranspose final
 ///    iterator_types = ["parallel", "parallel", "reduction"],
 ///    kind = add} %arg0, %arg1, %cst_f0
 ///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
-///  ```
-struct CombineContractBroadcast
-    : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<AffineMap> maps =
-        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
-    Value lhs = contractOp.getLhs();
-    Value rhs = contractOp.getRhs();
-    size_t index = 0;
-    bool changed = false;
-    for (Value *operand : {&lhs, &rhs}) {
-      AffineMap &map = maps[index++];
-      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
-      if (!broadcast)
-        continue;
-      // contractionOp can only take vector as operands.
-      auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
-      if (!srcType ||
-          srcType.getRank() == broadcast.getResultVectorType().getRank())
-        continue;
-      int64_t rankDiff =
-          broadcast.getResultVectorType().getRank() - srcType.getRank();
-      bool innerDimBroadcast = false;
-      SmallVector<AffineExpr> originalDims;
-      for (const auto &dim : llvm::enumerate(srcType.getShape())) {
-        if (dim.value() != broadcast.getResultVectorType().getDimSize(
-                               rankDiff + dim.index())) {
-          innerDimBroadcast = true;
-          break;
-        }
-        originalDims.push_back(
-            rewriter.getAffineDimExpr(dim.index() + rankDiff));
+/// ```
+///
+/// For masked vector.contract, the mask requires updating when a dimension is
+/// dropped. In such cases, the dropped dimensions must correspond to the mask's
+/// leading unit dimensions. More generic cases (e.g. non-unit dims) are not
+/// supported.
+FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
+                                             MaskingOpInterface maskingOp,
+                                             PatternRewriter &rewriter) {
+  SmallVector<AffineMap> maps =
+      llvm::to_vector<4>(contractOp.getIndexingMapsArray());
+  Value lhs = contractOp.getLhs();
+  Value rhs = contractOp.getRhs();
+  size_t index = 0;
+  bool changed = false;
+  for (Value *operand : {&lhs, &rhs}) {
+    AffineMap &map = maps[index++];
+    auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
+    if (!broadcast)
+      continue;
+    // contractionOp can only take vector as operands.
+    auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!srcType ||
+        srcType.getRank() == broadcast.getResultVectorType().getRank())
+      continue;
+    int64_t rankDiff =
+        broadcast.getResultVectorType().getRank() - srcType.getRank();
+    bool innerDimBroadcast = false;
+    SmallVector<AffineExpr> originalDims;
+    for (const auto &dim : llvm::enumerate(srcType.getShape())) {
+      if (dim.value() !=
+          broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
+        innerDimBroadcast = true;
+        break;
       }
-      // Contract doesn't support inner dimension broadcast. Once this is
-      // relaxed we can remove this case.
-      if (innerDimBroadcast)
-        continue;
+      originalDims.push_back(rewriter.getAffineDimExpr(dim.index() + rankDiff));
+    }
+    // Contract doesn't support inner dimension broadcast. Once this is
+    // relaxed we can remove this case.
+    if (innerDimBroadcast)
+      continue;
 
-      // It would be incorrect to fold a broadcast onto a reduction dimension
-      // of non-unit size.
-      bool nonUnitDimReductionBroadcast = false;
-      for (int64_t i = 0; i < rankDiff; ++i) {
-        if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
-            isReductionIterator(contractOp.getIteratorTypes()
-                                    .getValue()[map.getDimPosition(i)])) {
-          nonUnitDimReductionBroadcast = true;
-          break;
-        }
+    // It would be incorrect to fold a broadcast onto a reduction dimension
+    // of non-unit size.
+    bool nonUnitDimReductionBroadcast = false;
+    for (int64_t i = 0; i < rankDiff; ++i) {
+      if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
+          isReductionIterator(contractOp.getIteratorTypes()
+                                  .getValue()[map.getDimPosition(i)])) {
+        nonUnitDimReductionBroadcast = true;
+        break;
       }
-      if (nonUnitDimReductionBroadcast)
-        continue;
-
-      AffineMap broadcastMap =
-          AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
-                         originalDims, contractOp.getContext());
-      map = broadcastMap.compose(map);
-      *operand = broadcast.getSource();
-      changed = true;
     }
+    if (nonUnitDimReductionBroadcast)
+      continue;
 
-    if (!changed)
-      return failure();
+    AffineMap broadcastMap =
+        AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
+                       originalDims, contractOp.getContext());
+    map = broadcastMap.compose(map);
+    *operand = broadcast.getSource();
+    changed = true;
+  }
 
-    // Determine which dims are usused, now that the maps have been composed
-    // with the broadcast maps.
-    llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
-    // Compress unused dims.
-    for (auto &m : maps)
-      m = compressDims(m, unusedDimsBitVector);
-    // Compute the combined iterators.
-    SmallVector<Attribute> iterators;
-    for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
-      if (!unusedDimsBitVector.test(i))
-        iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
-    }
-    // Check that compressing unused dims isn't removing all reduction dimension
-    // pairs. For example, if the vector.contract had only one reduction
-    // iterator and that was a unit-dimension created by a broadcast,
-    // then we should bail here, otherwise we would create a contract without
-    // a reduction dimension pair.
-    bool hasReductionIteratorApplyingOnBothSides = false;
-    for (unsigned i = 0; i < iterators.size(); ++i) {
-      if (!isReductionIterator(iterators[i]))
-        continue;
-      if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
-        hasReductionIteratorApplyingOnBothSides = true;
+  if (!changed)
+    return failure();
+
+  // Determine which dims are unused, now that the maps have been composed
+  // with the broadcast maps.
+  llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
+  // Compress unused dims.
+  for (auto &m : maps)
+    m = compressDims(m, unusedDimsBitVector);
+  // Compute the combined iterators.
+  SmallVector<Attribute> iterators;
+  for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
+    if (!unusedDimsBitVector.test(i))
+      iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
+  }
+
+  // Check whether any of the unused dims is non-unit, e.g.:
+  //  * vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32>
+  // This is only required when collapsing a mask. If there is no mask, skip.
+  VectorType oldMaskType;
+  bool isAnyUnusedDimNonUnit = false;
+  if (maskingOp) {
+    oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
+    for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
+      if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
+        isAnyUnusedDimNonUnit = true;
         break;
       }
     }
-    if (!hasReductionIteratorApplyingOnBothSides)
-      return failure();
+  }
 
-    // If the compressed maps have a dimension that is not used by either LHS or
-    // RHS then the ContractionOp verifier would fail.
-    if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
-      return failure();
-    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
-        contractOp, lhs, rhs, contractOp.getAcc(),
-        rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
-    return success();
+  // Check that compressing unused dims isn't removing all reduction dimension
+  // pairs. For example, if the vector.contract had only one reduction
+  // iterator and that was a unit-dimension created by a broadcast,
+  // then we should bail here, otherwise we would create a contract without
+  // a reduction dimension pair.
+  bool hasReductionIteratorApplyingOnBothSides = false;
+  for (unsigned i = 0; i < iterators.size(); ++i) {
+    if (!isReductionIterator(iterators[i]))
+      continue;
+    if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
+      hasReductionIteratorApplyingOnBothSides = true;
+      break;
+    }
+  }
+  if (!hasReductionIteratorApplyingOnBothSides)
+    return failure();
+
+  // If the compressed maps have a dimension that is not used by either LHS or
+  // RHS then the ContractionOp verifier would fail.
+  if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
+    return failure();
+
+  Operation *newOp = rewriter.create<vector::ContractionOp>(
+      contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
+      rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
+
+  // Handle the mask.
+  if (maskingOp) {
+    if (isAnyUnusedDimNonUnit)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Cannont drop non-unit mask dim.");
+    assert(unusedDimsBitVector.size() ==
+               static_cast<size_t>(oldMaskType.getRank()) &&
+           "The mask rank is incorrect!");
+
+    // If a dimension has been dropped, update the mask accordingly. Otherwise,
+    // keep it as is.
+    Value mask = maskingOp.getMask();
+    if (unusedDimsBitVector.count() != 0) {
+      // At this point, two assumptions are made:
+      //  * The unused dimensions are the leading mask dimensions
+      //  (vector.contract does not support inner dim broadcasting).
+      //  * The unused dimensions are all unit.
+      // These conditions are effectively verified in the blocks preceding this
+      // one.
+      auto newShape =
+          oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
+      auto newShapeScalableDims =
+          oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
+      VectorType maskOpType =
+          VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
+      mask = rewriter
+                 .create<vector::ShapeCastOp>(contractOp.getLoc(), maskOpType,
+                                              maskingOp.getMask())
+                 .getResult();
+    }
+
+    newOp = mlir::vector::maskOperation(rewriter, newOp, mask);
+  }
+  return newOp->getResult(0);
+}
+
+struct CombineContractBroadcastMask
+    : public MaskableOpRewritePattern<vector::ContractionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+  FailureOr<Value>
+
+  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
+                            MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
   }
 };
 
@@ -2237,7 +2300,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
 
 void mlir::vector::populateVectorReductionToContractPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<MultiReduceToContract, CombineContractBroadcast,
+  patterns.add<MultiReduceToContract, CombineContractBroadcastMask,
                CombineContractABTranspose, CombineContractResultTranspose>(
       patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index 24070dbf017a5..a3d135665f6b2 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -13,8 +13,7 @@ func.func @multidimreduction_contract(
   %arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>, %acc: vector<8x16xf32>) -> vector<8x16xf32> {
   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
   %1 = vector.multi_reduction <add>, %0, %acc [1] : vector<8x32x16xf32> to vector<8x16xf32>
-  return %1 : vector<8x16xf32>
-}
+  return %1 : vector<8x16xf32> }
 
 // -----
 
@@ -62,6 +61,10 @@ func.func @contract_transpose(
 
 // -----
 
+//-----------------------------------------------------------------------------
+// [Pattern: CombineContractBroadcast]
+//-----------------------------------------------------------------------------
+
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
 
@@ -87,6 +90,43 @@ func.func @contract_broadcast(
 }
 
 // -----
+
+// Same as above, but with a mask.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_broadcast_masked
+// CHECK-SAME:      %[[ARG0:.*]]: vector<32x16xf32>,
+// CHECK-SAME:      %[[ARG1:.*]]: vector<8x32x16xf32>,
+// CHECK-SAME:      %[[MASK:.*]]: vector<8x32x16xi1>) -> vector<8x32xf32> {
+// CHECK:           %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+// CHECK:           %[[R:.*]] = vector.mask %[[MASK]] {
+// CHECK-SAME:        vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
+// CHECK-SAME:        iterator_types = ["parallel", "parallel", "reduction"],
+// CHECK-SAME:        kind = #vector.kind<add>}
+// CHECK-SAME:        %[[ARG0]], %[[ARG1]], %[[C0]] : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+// CHECK-SAME:      } : vector<8x32x16xi1> -> vector<8x32xf32>
+// CHECK:           return %[[R]] : vector<8x32xf32>
+func.func @contract_broadcast_masked(
+  %arg0: vector<32x16xf32>, %arg1: vector<8x32x16xf32>, %mask: vector<8x32x16xi1>) -> vector<8x32xf32> {
+  %cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+  %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
+  %1 = vector.mask %mask {
+    vector.contract {indexing_maps = [#map0, #map0, #map1],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>
+    } %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+  } : vector<8x32x16xi1> -> vector<8x32xf32>
+  return %1 : vector<8x32xf32>
+}
+
+// -----
+
 // Test that CombineContractBroadcast is able to combine a broadcast that
 // creates a unit dim that is consumed by a reduction iterator, dropping that
 // reduction iterator, as long as there is another reduction iterator left.
@@ -116,6 +156,39 @@ func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1
     return %result : vector<8x8xi32>
 }
 
+// -----
+
+// Same as above, but with a mask
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked
+//  CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>, %[[MASK:.+]]: vector<1x8x8x4xi1>)
+//  CHECK:      %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x8x4xi1> to vector<8x8x4xi1>
+//  CHECK:      %[[R:.*]] = vector.mask %[[MASK_SC]] {
+//  CHECK-SAME: vector.contract
+//  CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
+//  CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+//  CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
+func.func @contract_broadcast_unit_dim_reduction_masked(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>, %mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> {
+    %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
+    %1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32>
+    %result = vector.mask %mask {
+      vector.contract {
+        indexing_maps = [#map0, #map1, #map2],
+        iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+        kind = #vector.kind<add>
+      } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32>
+    } : vector<1x8x8x4xi1> -> vector<8x8xi32>
+    return %result : vector<8x8xi32>
+}
+
 // -----
 // Test that CombineContractBroadcast will not combine a broadcast that creates
 // a non-unit dim that is consumed by a reduction iterator.

>From f3e3affd4f9ce4f365517f4deaae70cd957a3e70 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 19 May 2025 09:20:04 +0100
Subject: [PATCH 2/2] fixup! [mlir][vector] Update
 `CombineContractBroadcastMask`

* Add tests for scalable vectors
* Capitalize all LIT variables used for maps
* Fix punctuation
---
 .../Vector/vector-reduce-to-contract.mlir     | 153 +++++++++++++-----
 1 file changed, 117 insertions(+), 36 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index a3d135665f6b2..432e399e8f262 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -1,11 +1,15 @@
 // RUN: mlir-opt %s -test-vector-reduction-to-contract-patterns -split-input-file | FileCheck %s
 
-// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// TODO: Separate tests for vector.multi_reduction -> vector.contract and
+//  * pre-op + vector.contract -> vector.contract,
+//  * vector.contract + post-op -> vector.contract.
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 
 // CHECK-LABEL: multidimreduction_contract
 //  CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xf32>, %[[ARG1:.*]]: vector<8x32x16xf32>, %[[ARG2:.*]]: vector<8x16xf32>)
-//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]],
 //  CHECK-SAME:   iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
 //  CHECK-SAME:   %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
 //  CHECK-NEXT:   return %[[R]] : vector<8x16xf32>
@@ -17,12 +21,12 @@ func.func @multidimreduction_contract(
 
 // -----
 
-// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 
 // CHECK-LABEL: multidimreduction_contract_int
 //  CHECK-SAME: (%[[ARG0:.*]]: vector<8x32x16xi32>, %[[ARG1:.*]]: vector<8x32x16xi32>, %[[ARG2:.*]]: vector<8x16xi32>)
-//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP1]]],
 //  CHECK-SAME:   iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
 //  CHECK-SAME:   %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32>
 //  CHECK-NEXT:   return %[[R]] : vector<8x16xi32>
@@ -35,17 +39,21 @@ func.func @multidimreduction_contract_int(
 
 // -----
 
+//-----------------------------------------------------------------------------
+// [Pattern: CombineContractABTranspose]
+//-----------------------------------------------------------------------------
+
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
 
-// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
-// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
 // CHECK-LABEL: contract_transpose
 //  CHECK-SAME: (%[[ARG0:.+]]: vector<32x16x8xf32>,
 //  CHECK-NEXT:   %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
-//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
 //  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
 //  CHECK-SAME:   %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
 //  CHECK-NEXT:   return %[[R]] : vector<8x32xf32>
@@ -68,14 +76,14 @@ func.func @contract_transpose(
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
 
-// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
 // CHECK-LABEL: contract_broadcast
 //  CHECK-SAME: (%[[ARG0:.+]]: vector<32x16xf32>,
 //  CHECK-NEXT:   %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
-//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
 //  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
 //  CHECK-SAME:   %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
 //  CHECK-NEXT:   return %[[R]] : vector<8x32xf32>
@@ -127,6 +135,42 @@ func.func @contract_broadcast_masked(
 
 // -----
 
+// Same as above, but with a scalable dim.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_broadcast_masked_scalable
+// CHECK-SAME:      %[[ARG0:.*]]: vector<[32]x16xf32>,
+// CHECK-SAME:      %[[ARG1:.*]]: vector<8x[32]x16xf32>,
+// CHECK-SAME:      %[[MASK:.*]]: vector<8x[32]x16xi1>) -> vector<8x32xf32> {
+// CHECK:           %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+// CHECK:           %[[R:.*]] = vector.mask %[[MASK]] {
+// CHECK-SAME:        vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
+// CHECK-SAME:        iterator_types = ["parallel", "parallel", "reduction"],
+// CHECK-SAME:        kind = #vector.kind<add>}
+// CHECK-SAME:        %[[ARG0]], %[[ARG1]], %[[C0]] : vector<[32]x16xf32>, vector<8x[32]x16xf32> into vector<8x32xf32>
+// CHECK-SAME:      } : vector<8x[32]x16xi1> -> vector<8x32xf32>
+// CHECK:           return %[[R]] : vector<8x32xf32>
+func.func @contract_broadcast_masked_scalable(
+  %arg0: vector<[32]x16xf32>, %arg1: vector<8x[32]x16xf32>, %mask: vector<8x[32]x16xi1>) -> vector<8x32xf32> {
+  %cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+  %0 = vector.broadcast %arg0 : vector<[32]x16xf32> to vector<8x[32]x16xf32>
+  %1 = vector.mask %mask {
+    vector.contract {indexing_maps = [#map0, #map0, #map1],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>
+    } %0, %arg1, %cst : vector<8x[32]x16xf32>, vector<8x[32]x16xf32> into vector<8x32xf32>
+  } : vector<8x[32]x16xi1> -> vector<8x32xf32>
+  return %1 : vector<8x32xf32>
+}
+
+// -----
+
 // Test that CombineContractBroadcast is able to combine a broadcast that
 // creates a unit dim that is consumed by a reduction iterator, dropping that
 // reduction iterator, as long as there is another reduction iterator left.
@@ -135,14 +179,14 @@ func.func @contract_broadcast_masked(
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 
-// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
 // CHECK-LABEL: contract_broadcast_unit_dim_reduction
 //  CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
 //  CHECK: vector.contract
-//  CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
+//  CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 //  CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
 //  CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
 func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> {
@@ -158,22 +202,55 @@ func.func @contract_broadcast_unit_dim_reduction(%arg0 : vector<8x4xi32>, %arg1
 
 // -----
 
-// Same as above, but with a mask
+// Same as above, but with a mask and a scalable dim.
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
 
-// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
-// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked_scalable
+//  CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<[8]x4xi32>, %[[ARG2:.+]]: vector<8x[8]xi32>, %[[MASK:.+]]: vector<1x8x[8]x4xi1>)
+//  CHECK:      %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x[8]x4xi1> to vector<8x[8]x4xi1>
+//  CHECK:      %[[R:.*]] = vector.mask %[[MASK_SC]] {
+//  CHECK-SAME: vector.contract
+//  CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+//  CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+//  CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<[8]x4xi32> into vector<8x[8]xi32>
+func.func @contract_broadcast_unit_dim_reduction_masked_scalable(%arg0 : vector<8x4xi32>, %arg1 : vector<[8]x4xi32>, %arg2 : vector<8x[8]xi32>, %mask: vector<1x8x[8]x4xi1>) -> vector<8x[8]xi32> {
+    %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32>
+    %1 = vector.broadcast %arg1 : vector<[8]x4xi32> to vector<1x[8]x4xi32>
+    %result = vector.mask %mask {
+      vector.contract {
+        indexing_maps = [#map0, #map1, #map2],
+        iterator_types = ["reduction", "parallel", "parallel", "reduction"],
+        kind = #vector.kind<add>
+      } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x[8]x4xi32> into vector<8x[8]xi32>
+    } : vector<1x8x[8]x4xi1> -> vector<8x[8]xi32>
+    return %result : vector<8x[8]xi32>
+}
+
+// -----
+
+// Same as above, but with fixed-size (non-scalable) dims only.
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
 // CHECK-LABEL: contract_broadcast_unit_dim_reduction_masked
 //  CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>, %[[MASK:.+]]: vector<1x8x8x4xi1>)
 //  CHECK:      %[[MASK_SC:.*]] = vector.shape_cast %[[MASK]] : vector<1x8x8x4xi1> to vector<8x8x4xi1>
 //  CHECK:      %[[R:.*]] = vector.mask %[[MASK_SC]] {
 //  CHECK-SAME: vector.contract
-//  CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
+//  CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 //  CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
 //  CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32>
 func.func @contract_broadcast_unit_dim_reduction_masked(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>, %mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> {
@@ -200,16 +277,16 @@ func.func @contract_broadcast_unit_dim_reduction_masked(%arg0 : vector<8x4xi32>,
 #map1 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
 
-// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
-// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
-// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
 
 // CHECK-LABEL: contract_broadcast_non_unit_dim_reduction_with_permutation
 //  CHECK-SAME: (%[[ARG0:.+]]: vector<8x4xi32>, %[[ARG1:.+]]: vector<8x4xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
 //  CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8x4xi32> to vector<2x8x4xi32>
 //  CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<8x4xi32> to vector<2x8x4xi32>
 //  CHECK: vector.contract
-//  CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
+//  CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 //  CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel", "reduction"]
 //  CHECK-SAME: %[[BROADCAST0]], %[[BROADCAST1]], %[[ARG2]] : vector<2x8x4xi32>, vector<2x8x4xi32> into vector<8x8xi32>
 func.func @contract_broadcast_non_unit_dim_reduction_with_permutation(%arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> {
@@ -232,16 +309,16 @@ func.func @contract_broadcast_non_unit_dim_reduction_with_permutation(%arg0 : ve
 #map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
 
-// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
 
 // CHECK-LABEL: contract_broadcast_unit_dim_reduction_as_only_reduction
 //  CHECK-SAME: (%[[ARG0:.+]]: vector<8xi32>, %[[ARG1:.+]]: vector<8xi32>, %[[ARG2:.+]]: vector<8x8xi32>)
 //  CHECK: %[[BROADCAST0:.+]] = vector.broadcast %[[ARG0]] : vector<8xi32> to vector<1x8xi32>
 //  CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<8xi32> to vector<1x8xi32>
 //  CHECK: vector.contract
-//  CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
+//  CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 //  CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel"]
 //  CHECK-SAME: %[[BROADCAST0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x8xi32>, vector<1x8xi32> into vector<8x8xi32>
 func.func @contract_broadcast_unit_dim_reduction_as_only_reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : vector<8x8xi32>) -> vector<8x8xi32> {
@@ -264,15 +341,15 @@ func.func @contract_broadcast_unit_dim_reduction_as_only_reduction(%arg0 : vecto
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 #map2 = affine_map<(d0, d1, d2) -> (d1)>
 
-// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d1)>
 
 // CHECK-LABEL: contract_broadcast_dimension_would_go_unused_in_lhs_rhs
 //  CHECK-SAME: (%[[ARG0:.+]]: vector<1x2xi32>, %[[ARG1:.+]]: vector<2xi32>, %[[ARG2:.+]]: vector<1xi32>)
 //  CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<2xi32> to vector<1x1x2xi32>
 //  CHECK: vector.contract
-//  CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
+//  CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 //  CHECK-SAME: iterator_types = ["reduction", "parallel", "reduction"]
 //  CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1x2xi32>, vector<1x1x2xi32> into vector<1xi32>
 
@@ -303,7 +380,7 @@ func.func @contract_broadcast_dimension_would_go_unused_in_lhs_rhs(%arg0 : vecto
 //  CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>, %[[ARG2:.+]]: vector<1xf32>)
 //  CHECK: %[[BROADCAST1:.+]] = vector.broadcast %[[ARG1]] : vector<1xf32> to vector<1x1xf32>
 //  CHECK: vector.contract
-//  CHECK-SAME: indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]]
+//  CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 //  CHECK-SAME: iterator_types = ["parallel", "reduction"]
 //  CHECK-SAME: %[[ARG0]], %[[BROADCAST1]], %[[ARG2]] : vector<1xf32>, vector<1x1xf32> into vector<1xf32>
 
@@ -320,6 +397,10 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x
 
 // -----
 
+//-----------------------------------------------------------------------------
+// [Pattern: CombineContractResultTranspose]
+//-----------------------------------------------------------------------------
+
 // CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
 // CHECK-DAG: #[[$RHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
 // CHECK-DAG: #[[$ACC_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>



More information about the Mlir-commits mailing list