[Mlir-commits] [mlir] [mlir][Vector] Add support for masks in castAwayContractionLeadingOneDim (PR #81906)

Andrzej Warzyński llvmlistbot at llvm.org
Sat Mar 16 11:58:52 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/81906

>From b7cebcac397facc7bc6b46ac37d4d68a3f8e2b7b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 4 Mar 2024 09:23:36 +0000
Subject: [PATCH 1/4] [mlir][vector] Adds pattern rewrite for maskable Ops

Adds a generic pattern rewrite for maskable Ops that we would like to
work for both masked and un-masked cases, e.g. for both `vector.mask
{vector.contract}` and `vector.contract` (this is a very contrived
example - just to demonstrate the idea).

This helps to reduce code-duplication and standardise how we implement
such patterns.

Fixes #78787
---
 .../Vector/Transforms/LowerVectorContract.cpp | 251 ++++++++++--------
 1 file changed, 143 insertions(+), 108 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 0eaf9f71a37d21..6480b295c49cb6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -41,7 +41,6 @@ using namespace mlir::vector;
 //===----------------------------------------------------------------------===//
 // Helper functions
 //===----------------------------------------------------------------------===//
-
 // Helper to find an index in an affine map.
 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
@@ -212,6 +211,64 @@ static Value createMul(Location loc, Value x, Value y, bool isInt,
 
 namespace {
 
+/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
+/// masked (i.e. inside `vector.mask` Op region). In particular:
+///   1. It matches `SourceOp` operation, Op.
+///   2. If Op is masked, retrieves the mask and updates the insertion point to
+///   avoid inserting new ops into `vector.mask` Op region (which only allows
+///   one Op). If the Op is not masked, this step is a nop.
+///   3. Invokes `matchAndRewriteMaskableOp` on Op that might be nested (not
+///   required) in the matched `vector.mask` operation from step 2.
+///
+/// It frees the patterns implementing this class from worrying about the
+/// logic to update the insertion point. However, those patterns are still
+/// responsible for providing an updated version of:
+///   * the source Op when mask _is not_ present,
+///   * the source Op *and* the mask Op when mask _is_ present.
+template <class SourceOp>
+struct MaybeMaskedOpRewritePattern : OpRewritePattern<SourceOp> {
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+
+private:
+  LogicalResult matchAndRewrite(SourceOp sourceOp,
+                                PatternRewriter &rewriter) const final {
+    auto maskableOp =
+        dyn_cast_if_present<MaskableOpInterface>(sourceOp.getOperation());
+    if (!maskableOp)
+      return failure();
+
+    // Retrieve the mask if present
+    MaskingOpInterface maskOp;
+    if (maskableOp.isMasked())
+      maskOp = maskableOp.getMaskingOp();
+
+    // If this Op is masked, update the insertion point to avoid inserting into
+    // the vector.mask Op region.
+    OpBuilder::InsertionGuard guard(rewriter);
+    Operation *rootOp = sourceOp;
+    if (maskOp) {
+      rewriter.setInsertionPoint(maskOp);
+      rootOp = maskOp;
+    }
+
+    FailureOr<Value> newOp =
+        matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
+    if (failed(newOp))
+      return failure();
+
+    rewriter.replaceOp(rootOp, *newOp);
+    return success();
+  }
+
+public:
+  // Matches SourceOp that can potentially be masked with `maskingOp`. If the
+  // latter is present, returns an updated masking op (with a replacement for
+  // `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`.
+  virtual FailureOr<Value>
+  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const = 0;
+};
+
 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
 /// semantics to:
 /// ```
@@ -226,9 +283,9 @@ namespace {
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
 /// the vector.contract op is a row-major matrix multiply.
 class ContractionOpToMatmulOpLowering
-    : public OpRewritePattern<vector::ContractionOp> {
+    : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
+  using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
 
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
@@ -241,12 +298,13 @@ class ContractionOpToMatmulOpLowering
       vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, PatternBenefit benefit = 1,
       FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions),
         filter(std::move(constraint)) {}
 
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
@@ -270,9 +328,9 @@ class ContractionOpToMatmulOpLowering
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
 /// the vector.contract op is a row-major matrix multiply.
 class ContractionOpToOuterProductOpLowering
-    : public OpRewritePattern<vector::ContractionOp> {
+    : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
+  using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
 
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
@@ -285,12 +343,13 @@ class ContractionOpToOuterProductOpLowering
       vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, PatternBenefit benefit = 1,
       FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions),
         filter(std::move(constraint)) {}
 
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
@@ -317,9 +376,9 @@ class ContractionOpToOuterProductOpLowering
 /// This only kicks in when VectorTransformsOptions is set to Dot and
 /// the vector.contract op is a row-major matmul or matvec.
 class ContractionOpToDotLowering
-    : public OpRewritePattern<vector::ContractionOp> {
+    : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
+  using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
 
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
@@ -332,11 +391,12 @@ class ContractionOpToDotLowering
       vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, PatternBenefit benefit = 1,
       const FilterConstraintType &constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
 
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
@@ -344,23 +404,10 @@ class ContractionOpToDotLowering
   FilterConstraintType filter;
 };
 
-/// Progressive lowering of ContractionOp.
-///
-/// One:
-///   %x = vector.contract with at least one free/batch dimension
-/// is replaced by:
-///   %a = vector.contract with one less free/batch dimension
-///   %b = vector.contract with one less free/batch dimension
-///   ..
-///   %x = combine %a %b ..
-/// until a pure contraction is reached (no free/batch dimensions),
-/// which is replaced by a dot-product.
-///
-/// This only kicks in when either VectorTransformsOptions is set
-/// to Dot or when other contraction patterns fail.
-class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
+class ContractionOpLowering
+    : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
+  using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
 
@@ -371,12 +418,13 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
   ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                         MLIRContext *context, PatternBenefit benefit = 1,
                         FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions),
         filter(std::move(constraint)) {}
 
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
@@ -634,8 +682,10 @@ struct UnrolledOuterProductGenerator
 ///
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
 /// otherwise supports any layout permutation of the matrix-multiply.
-LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
-    vector::ContractionOp op, PatternRewriter &rewriter) const {
+FailureOr<Value>
+ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rewriter) const {
   if (vectorTransformOptions.vectorContractLowering !=
       vector::VectorContractLowering::OuterProduct)
     return failure();
@@ -643,43 +693,25 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
   if (failed(filter(op)))
     return failure();
 
-  // Vector mask setup.
-  OpBuilder::InsertionGuard guard(rewriter);
-  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
-  Operation *rootOp;
-  if (maskableOp.isMasked()) {
-    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
-    rootOp = maskableOp.getMaskingOp();
-  } else {
-    rootOp = op;
-  }
-
   UnrolledOuterProductGenerator e(rewriter, op);
   FailureOr<Value> matmatRes = e.matmat();
   if (succeeded(matmatRes)) {
-    rewriter.replaceOp(rootOp, *matmatRes);
-    return success();
+    return matmatRes;
   }
   FailureOr<Value> matvecRes = e.matvec();
   if (succeeded(matvecRes)) {
-    rewriter.replaceOp(rootOp, *matvecRes);
-    return success();
-  }
-  FailureOr<Value> tmatvecRes = e.tmatvec();
-  if (succeeded(tmatvecRes)) {
-    rewriter.replaceOp(rootOp, *tmatvecRes);
-    return success();
+    return matvecRes;
   }
 
-  return failure();
+  FailureOr<Value> tmatvecRes = e.tmatvec();
+  return tmatvecRes;
 }
 
-LogicalResult
-ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
-                                            PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rewriter) const {
   // TODO: Support vector.mask.
-  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
-  if (maskableOp.isMasked())
+  if (maskOp)
     return failure();
 
   if (failed(filter(op)))
@@ -788,15 +820,14 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
   }
   if (auto acc = op.getAcc())
     res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
-  rewriter.replaceOp(op, res);
-  return success();
+  return res;
 }
 
 /// Lower vector.contract with all size one reduction dimensions to
 /// elementwise ops when possible.
 struct ContractOpToElementwise
-    : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
+  using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
   static LogicalResult defaultFilter(vector::ContractionOp op) {
@@ -806,14 +837,15 @@ struct ContractOpToElementwise
       vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, PatternBenefit benefit = 1,
       const FilterConstraintType &constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
 
-  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     // TODO: Support vector.mask.
-    auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
-    if (maskableOp.isMasked())
+    if (maskOp)
       return failure();
 
     if (failed(filter(contractOp)))
@@ -903,8 +935,10 @@ struct ContractOpToElementwise
     std::optional<Value> result =
         createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
                               contractOp.getKind(), rewriter, isInt);
-    rewriter.replaceOp(contractOp, {*result});
-    return success();
+    if (result)
+      return *result;
+
+    return failure();
   }
 
 private:
@@ -930,9 +964,9 @@ struct ContractOpToElementwise
 // TODO: break down into transpose/reshape/cast ops
 //               when they become available to avoid code dup
 // TODO: investigate lowering order impact on performance
-LogicalResult
-ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
-                                       PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rewriter) const {
   if (failed(filter(op)))
     return failure();
 
@@ -951,29 +985,36 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
 
   // TODO: implement benefits, cost models.
   MLIRContext *ctx = op.getContext();
+
   ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
-  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
-    return success();
+  FailureOr<Value> newVal1 =
+      pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+  if (!failed(newVal1))
+    return newVal1;
+
   ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
-  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
-    return success();
+  FailureOr<Value> newVal2 =
+      pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+  if (!failed(newVal2))
+    return newVal2;
+
   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
-  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
-    return success();
+  FailureOr<Value> newVal3 =
+      pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+  if (!failed(newVal3))
+    return newVal3;
+
   ContractOpToElementwise pat4(vectorTransformOptions, ctx);
-  if (succeeded(pat4.matchAndRewrite(op, rewriter)))
-    return success();
+  FailureOr<Value> newVal4 =
+      pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+  if (!failed(newVal4))
+    return newVal4;
 
   // Vector mask setup.
-  OpBuilder::InsertionGuard guard(rewriter);
-  Operation *rootOp = op;
-  Value mask;
-  if (op.isMasked()) {
-    rewriter.setInsertionPoint(op.getMaskingOp());
-    rootOp = op.getMaskingOp();
-    mask = op.getMaskingOp().getMask();
-  }
 
+  Value mask;
+  if (maskOp)
+    mask = maskOp.getMask();
   // Find first batch dimension in LHS/RHS, and lower when found.
   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
   if (!batchDimMap.empty()) {
@@ -982,8 +1023,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
     auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
     if (failed(newOp))
       return failure();
-    rewriter.replaceOp(rootOp, *newOp);
-    return success();
+    return newOp;
   }
 
   // Collect contracting dimensions.
@@ -1003,8 +1043,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
       auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
       if (failed(newOp))
         return failure();
-      rewriter.replaceOp(rootOp, *newOp);
-      return success();
+      return newOp;
     }
   }
 
@@ -1015,8 +1054,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
       auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
       if (failed(newOp))
         return failure();
-      rewriter.replaceOp(rootOp, *newOp);
-      return success();
+      return newOp;
     }
   }
 
@@ -1025,8 +1063,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
     auto newOp = lowerReduction(rewriter, op, mask);
     if (failed(newOp))
       return failure();
-    rewriter.replaceOp(rootOp, *newOp);
-    return success();
+    return newOp;
   }
 
   return failure();
@@ -1291,12 +1328,11 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
 /// vector.transpose operations are inserted if the vector.contract op is not a
 /// row-major matrix multiply.
-LogicalResult
-ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
-                                                 PatternRewriter &rew) const {
+FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rew) const {
   // TODO: Support vector.mask.
-  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
-  if (maskableOp.isMasked())
+  if (maskOp)
     return failure();
 
   if (vectorTransformOptions.vectorContractLowering !=
@@ -1379,8 +1415,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
           : static_cast<Value>(
                 rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
 
-  rew.replaceOp(op, res);
-  return success();
+  return res;
 }
 } // namespace
 

>From f534bddf57f115fb35cdc5619d564c535ef372c6 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 5 Mar 2024 16:42:26 +0000
Subject: [PATCH 2/4] fixup! [mlir][vector] Adds pattern rewrite for maskable
 Ops

Rename the pattern, simplify code, restore docs
---
 .../Vector/Transforms/LowerVectorContract.cpp | 66 +++++++++++--------
 1 file changed, 40 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6480b295c49cb6..af502270c68db2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -226,27 +226,27 @@ namespace {
 ///   * the source Op when mask _is not_ present,
 ///   * the source Op *and* the mask Op when mask _is_ present.
 template <class SourceOp>
-struct MaybeMaskedOpRewritePattern : OpRewritePattern<SourceOp> {
+struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
   using OpRewritePattern<SourceOp>::OpRewritePattern;
 
 private:
   LogicalResult matchAndRewrite(SourceOp sourceOp,
                                 PatternRewriter &rewriter) const final {
-    auto maskableOp =
-        dyn_cast_if_present<MaskableOpInterface>(sourceOp.getOperation());
+    auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
     if (!maskableOp)
       return failure();
 
-    // Retrieve the mask if present
-    MaskingOpInterface maskOp;
-    if (maskableOp.isMasked())
-      maskOp = maskableOp.getMaskingOp();
+    // Op to update
+    Operation *rootOp = sourceOp;
 
-    // If this Op is masked, update the insertion point to avoid inserting into
-    // the vector.mask Op region.
+    // If this Op is masked:
+    //    * update the insertion point to avoid inserting into the vector.mask
+    //      Op region,
+    //    * update the Op to rewrite so that it's the parent vector.mask Op
     OpBuilder::InsertionGuard guard(rewriter);
-    Operation *rootOp = sourceOp;
-    if (maskOp) {
+    MaskingOpInterface maskOp;
+    if (maskableOp.isMasked()) {
+      maskOp = maskableOp.getMaskingOp();
       rewriter.setInsertionPoint(maskOp);
       rootOp = maskOp;
     }
@@ -283,9 +283,9 @@ struct MaybeMaskedOpRewritePattern : OpRewritePattern<SourceOp> {
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
 /// the vector.contract op is a row-major matrix multiply.
 class ContractionOpToMatmulOpLowering
-    : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
+    : public MaskableOpRewritePattern<vector::ContractionOp> {
 public:
-  using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
@@ -298,7 +298,7 @@ class ContractionOpToMatmulOpLowering
       vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, PatternBenefit benefit = 1,
       FilterConstraintType constraint = defaultFilter)
-      : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions),
         filter(std::move(constraint)) {}
 
@@ -328,9 +328,9 @@ class ContractionOpToMatmulOpLowering
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
 /// the vector.contract op is a row-major matrix multiply.
 class ContractionOpToOuterProductOpLowering
-    : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
+    : public MaskableOpRewritePattern<vector::ContractionOp> {
 public:
-  using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
@@ -343,7 +343,7 @@ class ContractionOpToOuterProductOpLowering
       vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, PatternBenefit benefit = 1,
       FilterConstraintType constraint = defaultFilter)
-      : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions),
         filter(std::move(constraint)) {}
 
@@ -376,9 +376,9 @@ class ContractionOpToOuterProductOpLowering
 /// This only kicks in when VectorTransformsOptions is set to Dot and
 /// the vector.contract op is a row-major matmul or matvec.
 class ContractionOpToDotLowering
-    : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
+    : public MaskableOpRewritePattern<vector::ContractionOp> {
 public:
-  using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
@@ -391,7 +391,7 @@ class ContractionOpToDotLowering
       vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, PatternBenefit benefit = 1,
       const FilterConstraintType &constraint = defaultFilter)
-      : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
 
   FailureOr<Value>
@@ -404,10 +404,24 @@ class ContractionOpToDotLowering
   FilterConstraintType filter;
 };
 
+/// Progressive lowering of ContractionOp.
+///
+/// One:
+///   %x = vector.contract with at least one free/batch dimension
+/// is replaced by:
+///   %a = vector.contract with one less free/batch dimension
+///   %b = vector.contract with one less free/batch dimension
+///   ..
+///   %x = combine %a %b ..
+/// until a pure contraction is reached (no free/batch dimensions),
+/// which is replaced by a dot-product.
+///
+/// This only kicks in when either VectorTransformsOptions is set
+/// to Dot or when other contraction patterns fail.
 class ContractionOpLowering
-    : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
+    : public MaskableOpRewritePattern<vector::ContractionOp> {
 public:
-  using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
 
@@ -418,7 +432,7 @@ class ContractionOpLowering
   ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                         MLIRContext *context, PatternBenefit benefit = 1,
                         FilterConstraintType constraint = defaultFilter)
-      : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions),
         filter(std::move(constraint)) {}
 
@@ -826,8 +840,8 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
 /// Lower vector.contract with all size one reduction dimensions to
 /// elementwise ops when possible.
 struct ContractOpToElementwise
-    : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
-  using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
+    : public MaskableOpRewritePattern<vector::ContractionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
   static LogicalResult defaultFilter(vector::ContractionOp op) {
@@ -837,7 +851,7 @@ struct ContractOpToElementwise
       vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, PatternBenefit benefit = 1,
       const FilterConstraintType &constraint = defaultFilter)
-      : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
 
   FailureOr<Value>

>From 158e6dd66a7ed87e65caa81e27130f324c888f0b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 8 Mar 2024 15:12:08 +0000
Subject: [PATCH 3/4] fixup! [mlir][vector] Adds pattern rewrite for maskable
 Ops

Move pattern
---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   | 58 ++++++++++++++++++
 .../Vector/Transforms/LowerVectorContract.cpp | 60 +------------------
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 31 ++++++++++
 3 files changed, 90 insertions(+), 59 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index f6b03a0f2c8007..3ac5ba7bf1134c 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -98,6 +98,64 @@ bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
 std::optional<StaticTileOffsetRange>
 createUnrollIterator(VectorType vType, int64_t targetRank = 1);
 
+/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
+/// masked (i.e. inside `vector.mask` Op region). In particular:
+///   1. It matches `SourceOp` operation, Op.
+///   2. If Op is masked, retrieves the mask and updates the insertion point to
+///   avoid inserting new ops into `vector.mask` Op region (which only allows
+///   one Op). If the Op is not masked, this step is a nop.
+///   3. Invokes `matchAndRewriteMaskableOp` on Op that might be nested (not
+///   required) in the matched `vector.mask` operation from step 2.
+///
+/// It frees the patterns implementing this class from worrying about the
+/// logic to update the insertion point. However, those patterns are still
+/// responsible for providing an updated version of:
+///   * the source Op when mask _is not_ present,
+///   * the source Op *and* the mask Op when mask _is_ present.
+template <class SourceOp>
+struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+
+private:
+  LogicalResult matchAndRewrite(SourceOp sourceOp,
+                                PatternRewriter &rewriter) const final {
+    auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
+    if (!maskableOp)
+      return failure();
+
+    // Op to update
+    Operation *rootOp = sourceOp;
+
+    // If this Op is masked:
+    //    * update the insertion point to avoid inserting into the vector.mask
+    //      Op region,
+    //    * update the Op to rewrite so that it's the parent vector.mask Op
+    OpBuilder::InsertionGuard guard(rewriter);
+    MaskingOpInterface maskOp;
+    if (maskableOp.isMasked()) {
+      maskOp = maskableOp.getMaskingOp();
+      rewriter.setInsertionPoint(maskOp);
+      rootOp = maskOp;
+    }
+
+    FailureOr<Value> newOp =
+        matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
+    if (failed(newOp))
+      return failure();
+
+    rewriter.replaceOp(rootOp, *newOp);
+    return success();
+  }
+
+public:
+  // Matches SourceOp that can potentially be masked with `maskingOp`. If the
+  // latter is present, returns an updated masking op (with a replacement for
+  // `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`.
+  virtual FailureOr<Value>
+  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const = 0;
+};
+
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index af502270c68db2..ba1c96805ff831 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -211,64 +211,6 @@ static Value createMul(Location loc, Value x, Value y, bool isInt,
 
 namespace {
 
-/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
-/// masked (i.e. inside `vector.mask` Op region). In particular:
-///   1. It matches `SourceOp` operation, Op.
-///   2. If Op is masked, retrieves the mask and updates the insertion point to
-///   avoid inserting new ops into `vector.mask` Op region (which only allows
-///   one Op). If the Op is not masked, this step is a nop.
-///   3. Invokes `matchAndRewriteMaskableOp` on Op that might be nested (not
-///   required) in the matched `vector.mask` operation from step 2.
-///
-/// It frees the patterns implementing this class from worrying about the
-/// logic to update the insertion point. However, those patterns are still
-/// responsible for providing an updated version of:
-///   * the source Op when mask _is not_ present,
-///   * the source Op *and* the mask Op when mask _is_ present.
-template <class SourceOp>
-struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
-  using OpRewritePattern<SourceOp>::OpRewritePattern;
-
-private:
-  LogicalResult matchAndRewrite(SourceOp sourceOp,
-                                PatternRewriter &rewriter) const final {
-    auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
-    if (!maskableOp)
-      return failure();
-
-    // Op to update
-    Operation *rootOp = sourceOp;
-
-    // If this Op is masked:
-    //    * update the insertion point to avoid inserting into the vector.mask
-    //      Op region,
-    //    * update the Op to rewrite so that it's the parent vector.mask Op
-    OpBuilder::InsertionGuard guard(rewriter);
-    MaskingOpInterface maskOp;
-    if (maskableOp.isMasked()) {
-      maskOp = maskableOp.getMaskingOp();
-      rewriter.setInsertionPoint(maskOp);
-      rootOp = maskOp;
-    }
-
-    FailureOr<Value> newOp =
-        matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
-    if (failed(newOp))
-      return failure();
-
-    rewriter.replaceOp(rootOp, *newOp);
-    return success();
-  }
-
-public:
-  // Matches SourceOp that can potentially be masked with `maskingOp`. If the
-  // latter is present, returns an updated masking op (with a replacement for
-  // `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`.
-  virtual FailureOr<Value>
-  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
-                            PatternRewriter &rewriter) const = 0;
-};
-
 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
 /// semantics to:
 /// ```
@@ -283,7 +225,7 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
 /// the vector.contract op is a row-major matrix multiply.
 class ContractionOpToMatmulOpLowering
-    : public MaskableOpRewritePattern<vector::ContractionOp> {
+    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
 public:
   using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d613672608c3ad..02e2efa308d0b1 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -279,6 +279,37 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
   return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
 }
 
+// template <class SourceOp>
+// LogicalResult vector::MaskableOpRewritePattern::matchAndRewrite(
+//     SourceOp sourceOp, PatternRewriter &rewriter) const final {
+//   auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
+//   if (!maskableOp)
+//     return failure();
+
+//   // Op to update
+//   Operation *rootOp = sourceOp;
+
+//   // If this Op is masked:
+//   //    * update the insertion point to avoid inserting into the vector.mask
+//   //      Op region,
+//   //    * update the Op to rewrite so that it's the parent vector.mask Op
+//   OpBuilder::InsertionGuard guard(rewriter);
+//   MaskingOpInterface maskOp;
+//   if (maskableOp.isMasked()) {
+//     maskOp = maskableOp.getMaskingOp();
+//     rewriter.setInsertionPoint(maskOp);
+//     rootOp = maskOp;
+//   }
+
+//   FailureOr<Value> newOp =
+//       matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
+//   if (failed(newOp))
+//     return failure();
+
+//   rewriter.replaceOp(rootOp, *newOp);
+//   return success();
+// }
+
 std::optional<StaticTileOffsetRange>
 vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
   if (vType.getRank() <= targetRank)

>From 1e9c7a1c5aa73cd74db29a65b644d0d5d95f6418 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 15 Feb 2024 19:12:40 +0000
Subject: [PATCH 4/4] [mlir][Vector] Add support for masks in
 castAwayContractionLeadingOneDim

Partial fix for #78787
---
 .../Vector/Transforms/VectorTransforms.h      |   6 +-
 .../Transforms/VectorDropLeadUnitDim.cpp      |  50 +++++----
 .../vector-dropleadunitdim-transforms.mlir    | 104 +++++++++++++-----
 3 files changed, 109 insertions(+), 51 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
index 08d3bb157a0e39..1f7d6411cd5a46 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h
@@ -110,8 +110,10 @@ void transferOpflowOpt(RewriterBase &rewriter, Operation *rootOp);
 
 /// Cast away the leading unit dim, if exists, for the given contract op.
 /// Return success if the transformation applies; return failure otherwise.
-LogicalResult castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
-                                               RewriterBase &rewriter);
+FailureOr<Value>
+castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
+                                 MaskingOpInterface maskingOp,
+                                 RewriterBase &rewriter);
 
 } // namespace vector
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 74382b027c2f48..50c7d33068799c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -329,12 +329,12 @@ struct CastAwayTransferWriteLeadingOneDim
 
 } // namespace
 
-LogicalResult
+FailureOr<Value>
 mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
+                                               MaskingOpInterface maskingOp,
                                                RewriterBase &rewriter) {
-  // TODO(#78787): Not supported masked op yet.
-  if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked())
-    return failure();
+  auto isMasked =
+      cast<MaskableOpInterface>(contractOp.getOperation()).isMasked();
   VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
   if (oldAccType == nullptr)
     return failure();
@@ -368,6 +368,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
   SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                  contractOp.getAcc()};
   SmallVector<Value> newOperands;
+  auto loc = contractOp.getLoc();
 
   for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
     // Check if the dim to be dropped exists as a leading dim in the operand
@@ -405,7 +406,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
         map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                              contractOp.getContext());
         operands[it.index()] = rewriter.create<vector::TransposeOp>(
-            contractOp.getLoc(), operands[it.index()], perm);
+            loc, operands[it.index()], perm);
       }
     }
     // We have taken care to have the dim to be dropped be
@@ -429,18 +430,27 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
     // Extract if its a valid extraction, otherwise use the operand
     // without extraction.
     newOperands.push_back(
-        validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(),
-                                                          operands[it.index()],
-                                                          splatZero(dropDim))
+        validExtract ? rewriter.create<vector::ExtractOp>(
+                           loc, operands[it.index()], splatZero(dropDim))
                      : operands[it.index()]);
   }
-  auto newContractOp = rewriter.create<vector::ContractionOp>(
-      contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
+  Operation *newContractOp = rewriter.create<vector::ContractionOp>(
+      loc, newOperands[0], newOperands[1], newOperands[2],
       rewriter.getAffineMapArrayAttr(newIndexingMaps),
       rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
-  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-      contractOp, contractOp->getResultTypes()[0], newContractOp);
-  return success();
+
+  if (maskingOp) {
+    auto newMask = rewriter.create<vector::ExtractOp>(loc, maskingOp.getMask(),
+                                                      splatZero(dropDim));
+
+    newContractOp =
+        mlir::vector::maskOperation(rewriter, newContractOp, newMask);
+  }
+
+  return rewriter
+      .create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0],
+                                   newContractOp->getResults()[0])
+      .getResult();
 }
 
 namespace {
@@ -450,12 +460,14 @@ namespace {
 /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required
 /// prior to extract.
 struct CastAwayContractionLeadingOneDim
-    : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
-                                PatternRewriter &rewriter) const override {
-    return castAwayContractionLeadingOneDim(contractOp, rewriter);
+    : public MaskableOpRewritePattern<vector::ContractionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
+                            MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    return castAwayContractionLeadingOneDim(contractOp, maskingOp, rewriter);
   }
 };
 
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index af6e636245b04e..4ba51c5953d13c 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -30,6 +30,80 @@ func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %ar
 }
 
 // -----
+// CHECK: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @cast_away_contraction_leading_one_dim_under_const_mask
+// CHECK:           %[[MASK:.*]] = vector.constant_mask [15, 15, 8] : vector<16x16x8xi1>
+// CHECK:           %[[R0:.*]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK:           %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK:           %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
+// CHECK:           %[[CONTRACT:.*]] = vector.mask %[[MASK]] {
+// CHECK-SAME:        vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} 
+// CHECK-SAME:          %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
+// CHECK-SAME:      } : vector<16x16x8xi1> -> vector<16x16xf32>
+// CHECK:           %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK:           return %[[RES]] : vector<1x16x16xf32>
+
+#contraction_accesses0 = [
+  affine_map<(l, i, j, k) -> (l, i, k)>,
+  affine_map<(l, i, j, k) -> (l, k, j)>,
+  affine_map<(l, i, j, k) -> (l, i, j)>
+]
+#contraction_trait0 = {
+  indexing_maps = #contraction_accesses0,
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+func.func @cast_away_contraction_leading_one_dim_under_const_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
+  %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
+  %0 = vector.mask %mask {
+    vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
+  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
+  return %0 : vector<1x16x16xf32>
+}
+
+// -----
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @cast_away_contraction_leading_one_dim_under_mask
+// CHECK:           %[[R0:.*]] = vector.extract %{{.*}} : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK:           %[[R1:.*]] = vector.extract %{{.*}} : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK:           %[[R2:.*]] = vector.extract %{{.*}} : vector<16x16xf32> from vector<1x16x16xf32>
+// CHECK:           %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1>
+// CHECK:           %[[CONTRACT:.*]] = vector.mask %[[M]] {
+// CHECK-SAME:      vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME:          %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32> 
+// CHECK-SAME:      } : vector<16x16x8xi1> -> vector<16x16xf32>
+// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK-NEXT:      return %[[RES]] : vector<1x16x16xf32>
+
+#contraction_accesses0 = [
+  affine_map<(l, i, j, k) -> (l, i, k)>,
+  affine_map<(l, i, j, k) -> (l, k, j)>,
+  affine_map<(l, i, j, k) -> (l, i, j)>
+]
+#contraction_trait0 = {
+  indexing_maps = #contraction_accesses0,
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+func.func @cast_away_contraction_leading_one_dim_under_mask(
+  %arg0: vector<1x16x8xf32>,
+  %arg1: vector<1x8x16xf32>,
+  %arg2: vector<1x16x16xf32>,
+  %mask: vector<1x16x16x8xi1>) -> vector<1x16x16xf32> {
+  %0 = vector.mask %mask {
+    vector.contract #contraction_trait0 %arg0, %arg1, %arg2  : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
+  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
+  return %0: vector<1x16x16xf32>
+}
+
+// -----
+
 // CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1) -> (d1)>
 // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)>
@@ -164,36 +238,6 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
   return %0: vector<1x1x2x16xf32>
 }
 
-// -----
-
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-
-// CHECK-LABEL: not_insert_cast_for_contraction_under_mask
-// CHECK:      %[[MASK:.+]] = vector.constant_mask
-// CHECK:      %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
-// CHECK:      %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
-// CHECK-SAME:   vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
-// CHECK:      return %[[RET]] : vector<1x16x16xf32>
-
-#contraction_accesses0 = [
-  affine_map<(l, i, j, k) -> (l, i, k)>,
-  affine_map<(l, i, j, k) -> (l, k, j)>,
-  affine_map<(l, i, j, k) -> (l, i, j)>
-]
-#contraction_trait0 = {
-  indexing_maps = #contraction_accesses0,
-  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-}
-
-func.func @not_insert_cast_for_contraction_under_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
-  %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
-  %0 = vector.mask %mask {
-    vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
-  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
-  return %0 : vector<1x16x16xf32>
-}
 
 // -----
 // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims



More information about the Mlir-commits mailing list