[Mlir-commits] [mlir] b7324b6 - [mlir][vector] Adds pattern rewrite for maskable Ops (#83827)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 20 14:04:10 PDT 2024


Author: Andrzej Warzyński
Date: 2024-03-20T21:04:06Z
New Revision: b7324b6a9c6bd43786ea853bf1a9730486b4bc88

URL: https://github.com/llvm/llvm-project/commit/b7324b6a9c6bd43786ea853bf1a9730486b4bc88
DIFF: https://github.com/llvm/llvm-project/commit/b7324b6a9c6bd43786ea853bf1a9730486b4bc88.diff

LOG: [mlir][vector] Adds pattern rewrite for maskable Ops (#83827)

Adds a generic rewrite pattern for maskable Ops, `MaskableOpRewritePattern`,
that works for both masked and unmasked cases, e.g. for both:

* `vector.mask {vector.contract}` (masked), and
* `vector.contract` (not masked).

This helps to reduce code duplication and to standardise how we implement such
patterns.

Fixes #78787

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
    mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 3ce16ef361f37d..35e76a8b623a32 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -112,6 +112,64 @@ SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
                                             Operation *xfer,
                                             RewriterBase &rewriter);
 
+/// A rewrite pattern for ops that implement `MaskableOpInterface` and that
+/// _might_ be masked (i.e. nested inside a `vector.mask` Op region):
+///   1. Matches a `SourceOp` operation, Op; fails if Op does not implement
+///      `MaskableOpInterface`.
+///   2. If Op is masked, retrieves the masking Op, maskOp, and sets the
+///      insertion point before it so that new ops are not created inside the
+///      `vector.mask` region (which only allows one Op). If Op is not
+///      masked, this step is skipped.
+///   3. Invokes `matchAndRewriteMaskableOp` on Op and maskOp (a null
+///      `MaskingOpInterface` when Op is not masked).
+///   4. On success, replaces maskOp (if present) or Op with the returned
+///      Value. The original insertion point is restored on return.
+///
+/// Derived patterns only implement `matchAndRewriteMaskableOp` and must
+/// return the Value that replaces the result of:
+///   a. the source Op when a mask _is not_ present,
+///   b. the masking Op (which wraps the source Op) when a mask _is_ present.
+template <class SourceOp>
+struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+
+private:
+  LogicalResult matchAndRewrite(SourceOp sourceOp,
+                                PatternRewriter &rewriter) const final {
+    auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
+    if (!maskableOp)
+      return failure(); // Op does not implement MaskableOpInterface.
+
+    Operation *rootOp = sourceOp; // The Op replaced on success.
+
+    // If masked, insert new ops before the masking Op (its region only allows
+    // one Op); the guard restores the insertion point on return.
+    OpBuilder::InsertionGuard guard(rewriter);
+    MaskingOpInterface maskOp;
+    if (maskableOp.isMasked()) {
+      maskOp = maskableOp.getMaskingOp();
+      rewriter.setInsertionPoint(maskOp);
+      rootOp = maskOp; // Replace the whole vector.mask Op instead.
+    }
+
+    FailureOr<Value> newOp =
+        matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
+    if (failed(newOp))
+      return failure();
+
+    rewriter.replaceOp(rootOp, *newOp); // newOp holds a Value, not an Op.
+    return success();
+  }
+
+public:
+  /// Implemented by derived patterns. `maskingOp` is null unless `sourceOp`
+  /// is masked. Returns the Value that replaces `maskingOp` if present, else
+  /// `sourceOp`. On failure(), the wrapper performs no replacement.
+  virtual FailureOr<Value>
+  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const = 0;
+};
+
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 0eaf9f71a37d21..ba1c96805ff831 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -41,7 +41,6 @@ using namespace mlir::vector;
 //===----------------------------------------------------------------------===//
 // Helper functions
 //===----------------------------------------------------------------------===//
-
 // Helper to find an index in an affine map.
 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
@@ -226,9 +225,9 @@ namespace {
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
 /// the vector.contract op is a row-major matrix multiply.
 class ContractionOpToMatmulOpLowering
-    : public OpRewritePattern<vector::ContractionOp> {
+    : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
@@ -241,12 +240,13 @@ class ContractionOpToMatmulOpLowering
       vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, PatternBenefit benefit = 1,
       FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions),
         filter(std::move(constraint)) {}
 
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
@@ -270,9 +270,9 @@ class ContractionOpToMatmulOpLowering
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
 /// the vector.contract op is a row-major matrix multiply.
 class ContractionOpToOuterProductOpLowering
-    : public OpRewritePattern<vector::ContractionOp> {
+    : public MaskableOpRewritePattern<vector::ContractionOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
@@ -285,12 +285,13 @@ class ContractionOpToOuterProductOpLowering
       vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, PatternBenefit benefit = 1,
       FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions),
         filter(std::move(constraint)) {}
 
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
@@ -317,9 +318,9 @@ class ContractionOpToOuterProductOpLowering
 /// This only kicks in when VectorTransformsOptions is set to Dot and
 /// the vector.contract op is a row-major matmul or matvec.
 class ContractionOpToDotLowering
-    : public OpRewritePattern<vector::ContractionOp> {
+    : public MaskableOpRewritePattern<vector::ContractionOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
@@ -332,11 +333,12 @@ class ContractionOpToDotLowering
       vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, PatternBenefit benefit = 1,
       const FilterConstraintType &constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
 
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
@@ -358,9 +360,10 @@ class ContractionOpToDotLowering
 ///
 /// This only kicks in when either VectorTransformsOptions is set
 /// to Dot or when other contraction patterns fail.
-class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
+class ContractionOpLowering
+    : public MaskableOpRewritePattern<vector::ContractionOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
 
@@ -371,12 +374,13 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
   ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                         MLIRContext *context, PatternBenefit benefit = 1,
                         FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions),
         filter(std::move(constraint)) {}
 
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
@@ -634,8 +638,10 @@ struct UnrolledOuterProductGenerator
 ///
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
 /// otherwise supports any layout permutation of the matrix-multiply.
-LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
-    vector::ContractionOp op, PatternRewriter &rewriter) const {
+FailureOr<Value>
+ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rewriter) const {
   if (vectorTransformOptions.vectorContractLowering !=
       vector::VectorContractLowering::OuterProduct)
     return failure();
@@ -643,43 +649,25 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
   if (failed(filter(op)))
     return failure();
 
-  // Vector mask setup.
-  OpBuilder::InsertionGuard guard(rewriter);
-  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
-  Operation *rootOp;
-  if (maskableOp.isMasked()) {
-    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
-    rootOp = maskableOp.getMaskingOp();
-  } else {
-    rootOp = op;
-  }
-
   UnrolledOuterProductGenerator e(rewriter, op);
   FailureOr<Value> matmatRes = e.matmat();
   if (succeeded(matmatRes)) {
-    rewriter.replaceOp(rootOp, *matmatRes);
-    return success();
+    return matmatRes;
   }
   FailureOr<Value> matvecRes = e.matvec();
   if (succeeded(matvecRes)) {
-    rewriter.replaceOp(rootOp, *matvecRes);
-    return success();
-  }
-  FailureOr<Value> tmatvecRes = e.tmatvec();
-  if (succeeded(tmatvecRes)) {
-    rewriter.replaceOp(rootOp, *tmatvecRes);
-    return success();
+    return matvecRes;
   }
 
-  return failure();
+  FailureOr<Value> tmatvecRes = e.tmatvec();
+  return tmatvecRes;
 }
 
-LogicalResult
-ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
-                                            PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rewriter) const {
   // TODO: Support vector.mask.
-  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
-  if (maskableOp.isMasked())
+  if (maskOp)
     return failure();
 
   if (failed(filter(op)))
@@ -788,15 +776,14 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
   }
   if (auto acc = op.getAcc())
     res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
-  rewriter.replaceOp(op, res);
-  return success();
+  return res;
 }
 
 /// Lower vector.contract with all size one reduction dimensions to
 /// elementwise ops when possible.
 struct ContractOpToElementwise
-    : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public MaskableOpRewritePattern<vector::ContractionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
   static LogicalResult defaultFilter(vector::ContractionOp op) {
@@ -806,14 +793,15 @@ struct ContractOpToElementwise
       vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, PatternBenefit benefit = 1,
       const FilterConstraintType &constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
 
-  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     // TODO: Support vector.mask.
-    auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
-    if (maskableOp.isMasked())
+    if (maskOp)
       return failure();
 
     if (failed(filter(contractOp)))
@@ -903,8 +891,10 @@ struct ContractOpToElementwise
     std::optional<Value> result =
         createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
                               contractOp.getKind(), rewriter, isInt);
-    rewriter.replaceOp(contractOp, {*result});
-    return success();
+    if (result)
+      return *result;
+
+    return failure();
   }
 
 private:
@@ -930,9 +920,9 @@ struct ContractOpToElementwise
 // TODO: break down into transpose/reshape/cast ops
 //               when they become available to avoid code dup
 // TODO: investigate lowering order impact on performance
-LogicalResult
-ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
-                                       PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rewriter) const {
   if (failed(filter(op)))
     return failure();
 
@@ -951,29 +941,36 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
 
   // TODO: implement benefits, cost models.
   MLIRContext *ctx = op.getContext();
+
   ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
-  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
-    return success();
+  FailureOr<Value> newVal1 =
+      pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+  if (!failed(newVal1))
+    return newVal1;
+
   ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
-  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
-    return success();
+  FailureOr<Value> newVal2 =
+      pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+  if (!failed(newVal2))
+    return newVal2;
+
   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
-  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
-    return success();
+  FailureOr<Value> newVal3 =
+      pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+  if (!failed(newVal3))
+    return newVal3;
+
   ContractOpToElementwise pat4(vectorTransformOptions, ctx);
-  if (succeeded(pat4.matchAndRewrite(op, rewriter)))
-    return success();
+  FailureOr<Value> newVal4 =
+      pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+  if (!failed(newVal4))
+    return newVal4;
 
   // Vector mask setup.
-  OpBuilder::InsertionGuard guard(rewriter);
-  Operation *rootOp = op;
-  Value mask;
-  if (op.isMasked()) {
-    rewriter.setInsertionPoint(op.getMaskingOp());
-    rootOp = op.getMaskingOp();
-    mask = op.getMaskingOp().getMask();
-  }
 
+  Value mask;
+  if (maskOp)
+    mask = maskOp.getMask();
   // Find first batch dimension in LHS/RHS, and lower when found.
   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
   if (!batchDimMap.empty()) {
@@ -982,8 +979,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
     auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
     if (failed(newOp))
       return failure();
-    rewriter.replaceOp(rootOp, *newOp);
-    return success();
+    return newOp;
   }
 
   // Collect contracting dimensions.
@@ -1003,8 +999,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
       auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
       if (failed(newOp))
         return failure();
-      rewriter.replaceOp(rootOp, *newOp);
-      return success();
+      return newOp;
     }
   }
 
@@ -1015,8 +1010,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
       auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
       if (failed(newOp))
         return failure();
-      rewriter.replaceOp(rootOp, *newOp);
-      return success();
+      return newOp;
     }
   }
 
@@ -1025,8 +1019,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
     auto newOp = lowerReduction(rewriter, op, mask);
     if (failed(newOp))
       return failure();
-    rewriter.replaceOp(rootOp, *newOp);
-    return success();
+    return newOp;
   }
 
   return failure();
@@ -1291,12 +1284,11 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
 /// vector.transpose operations are inserted if the vector.contract op is not a
 /// row-major matrix multiply.
-LogicalResult
-ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
-                                                 PatternRewriter &rew) const {
+FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rew) const {
   // TODO: Support vector.mask.
-  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
-  if (maskableOp.isMasked())
+  if (maskOp)
     return failure();
 
   if (vectorTransformOptions.vectorContractLowering !=
@@ -1379,8 +1371,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
           : static_cast<Value>(
                 rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
 
-  rew.replaceOp(op, res);
-  return success();
+  return res;
 }
 } // namespace
 


        


More information about the Mlir-commits mailing list