[Mlir-commits] [mlir] [mlir][vector] Adds pattern rewrite for maskable Ops (PR #83827)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Mar 20 04:28:11 PDT 2024
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/83827
>From 2f916106db49c1b6ccda0ffed68a7fb79e5465c7 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 4 Mar 2024 09:23:36 +0000
Subject: [PATCH 1/4] [mlir][vector] Adds pattern rewrite for maskable Ops
Adds a generic rewrite pattern for maskable Ops that works for both the
masked and unmasked cases, e.g. for both `vector.mask {vector.contract}`
and `vector.contract` (this is a very contrived example - just to
demonstrate the idea).
This helps to reduce code-duplication and standardise how we implement
such patterns.
Fixes #78787
---
.../Vector/Transforms/LowerVectorContract.cpp | 251 ++++++++++--------
1 file changed, 143 insertions(+), 108 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 0eaf9f71a37d21..6480b295c49cb6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -41,7 +41,6 @@ using namespace mlir::vector;
//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
-
// Helper to find an index in an affine map.
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
@@ -212,6 +211,64 @@ static Value createMul(Location loc, Value x, Value y, bool isInt,
namespace {
+/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
+/// masked (i.e. inside `vector.mask` Op region). In particular:
+/// 1. It matches `SourceOp` operation, Op.
+/// 2. If Op is masked, retrieves the mask and updates the insertion point to
+/// avoid inserting new ops into `vector.mask` Op region (which only allows
+/// one Op). If the Op is not masked, this step is a nop.
+/// 3. Invokes `matchAndRewriteMaskableOp` on Op that might be nested (not
+/// required) in the matched `vector.mask` operation from step 2.
+///
+/// It frees the patterns implementing this class from worrying about the
+/// logic to update the insertion point. However, those patterns are still
+/// responsible for providing an updated version of:
+/// * the source Op when mask _is not_ present,
+/// * the source Op *and* the mask Op when mask _is_ present.
+template <class SourceOp>
+struct MaybeMaskedOpRewritePattern : OpRewritePattern<SourceOp> {
+ using OpRewritePattern<SourceOp>::OpRewritePattern;
+
+private:
+ LogicalResult matchAndRewrite(SourceOp sourceOp,
+ PatternRewriter &rewriter) const final {
+ auto maskableOp =
+ dyn_cast_if_present<MaskableOpInterface>(sourceOp.getOperation());
+ if (!maskableOp)
+ return failure();
+
+ // Retrieve the mask if present
+ MaskingOpInterface maskOp;
+ if (maskableOp.isMasked())
+ maskOp = maskableOp.getMaskingOp();
+
+ // If this Op is masked, update the insertion point to avoid inserting into
+ // the vector.mask Op region.
+ OpBuilder::InsertionGuard guard(rewriter);
+ Operation *rootOp = sourceOp;
+ if (maskOp) {
+ rewriter.setInsertionPoint(maskOp);
+ rootOp = maskOp;
+ }
+
+ FailureOr<Value> newOp =
+ matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
+ if (failed(newOp))
+ return failure();
+
+ rewriter.replaceOp(rootOp, *newOp);
+ return success();
+ }
+
+public:
+ // Matches `sourceOp`, which may be masked by `maskingOp` (null if unmasked).
+ // On success, returns the Value that replaces `maskingOp` when it is
+ // present, or `sourceOp` otherwise (see `matchAndRewrite` above).
+ virtual FailureOr<Value>
+ matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const = 0;
+};
+
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
@@ -226,9 +283,9 @@ namespace {
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
- : public OpRewritePattern<vector::ContractionOp> {
+ : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -241,12 +298,13 @@ class ContractionOpToMatmulOpLowering
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context, benefit),
+ : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
@@ -270,9 +328,9 @@ class ContractionOpToMatmulOpLowering
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
- : public OpRewritePattern<vector::ContractionOp> {
+ : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -285,12 +343,13 @@ class ContractionOpToOuterProductOpLowering
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context, benefit),
+ : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
@@ -317,9 +376,9 @@ class ContractionOpToOuterProductOpLowering
/// This only kicks in when VectorTransformsOptions is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
class ContractionOpToDotLowering
- : public OpRewritePattern<vector::ContractionOp> {
+ : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -332,11 +391,12 @@ class ContractionOpToDotLowering
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context, benefit),
+ : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
@@ -344,23 +404,10 @@ class ContractionOpToDotLowering
FilterConstraintType filter;
};
-/// Progressive lowering of ContractionOp.
-///
-/// One:
-/// %x = vector.contract with at least one free/batch dimension
-/// is replaced by:
-/// %a = vector.contract with one less free/batch dimension
-/// %b = vector.contract with one less free/batch dimension
-/// ..
-/// %x = combine %a %b ..
-/// until a pure contraction is reached (no free/batch dimensions),
-/// which is replaced by a dot-product.
-///
-/// This only kicks in when either VectorTransformsOptions is set
-/// to Dot or when other contraction patterns fail.
-class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
+class ContractionOpLowering
+ : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -371,12 +418,13 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context, benefit),
+ : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
- LogicalResult matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const override;
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
@@ -634,8 +682,10 @@ struct UnrolledOuterProductGenerator
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
-LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
- vector::ContractionOp op, PatternRewriter &rewriter) const {
+FailureOr<Value>
+ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
+ vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const {
if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::OuterProduct)
return failure();
@@ -643,43 +693,25 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
if (failed(filter(op)))
return failure();
- // Vector mask setup.
- OpBuilder::InsertionGuard guard(rewriter);
- auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
- Operation *rootOp;
- if (maskableOp.isMasked()) {
- rewriter.setInsertionPoint(maskableOp.getMaskingOp());
- rootOp = maskableOp.getMaskingOp();
- } else {
- rootOp = op;
- }
-
UnrolledOuterProductGenerator e(rewriter, op);
FailureOr<Value> matmatRes = e.matmat();
if (succeeded(matmatRes)) {
- rewriter.replaceOp(rootOp, *matmatRes);
- return success();
+ return matmatRes;
}
FailureOr<Value> matvecRes = e.matvec();
if (succeeded(matvecRes)) {
- rewriter.replaceOp(rootOp, *matvecRes);
- return success();
- }
- FailureOr<Value> tmatvecRes = e.tmatvec();
- if (succeeded(tmatvecRes)) {
- rewriter.replaceOp(rootOp, *tmatvecRes);
- return success();
+ return matvecRes;
}
- return failure();
+ FailureOr<Value> tmatvecRes = e.tmatvec();
+ return tmatvecRes;
}
-LogicalResult
-ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
+ vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const {
// TODO: Support vector.mask.
- auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
- if (maskableOp.isMasked())
+ if (maskOp)
return failure();
if (failed(filter(op)))
@@ -788,15 +820,14 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
}
if (auto acc = op.getAcc())
res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
- rewriter.replaceOp(op, res);
- return success();
+ return res;
}
/// Lower vector.contract with all size one reduction dimensions to
/// elementwise ops when possible.
struct ContractOpToElementwise
- : public OpRewritePattern<vector::ContractionOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
+ using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
@@ -806,14 +837,15 @@ struct ContractOpToElementwise
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
- : OpRewritePattern<vector::ContractionOp>(context, benefit),
+ : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
- LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
- PatternRewriter &rewriter) const override {
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
+ MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const override {
// TODO: Support vector.mask.
- auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
- if (maskableOp.isMasked())
+ if (maskOp)
return failure();
if (failed(filter(contractOp)))
@@ -903,8 +935,10 @@ struct ContractOpToElementwise
std::optional<Value> result =
createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
contractOp.getKind(), rewriter, isInt);
- rewriter.replaceOp(contractOp, {*result});
- return success();
+ if (result)
+ return *result;
+
+ return failure();
}
private:
@@ -930,9 +964,9 @@ struct ContractOpToElementwise
// TODO: break down into transpose/reshape/cast ops
// when they become available to avoid code dup
// TODO: investigate lowering order impact on performance
-LogicalResult
-ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
+ vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const {
if (failed(filter(op)))
return failure();
@@ -951,29 +985,36 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
// TODO: implement benefits, cost models.
MLIRContext *ctx = op.getContext();
+
ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
- if (succeeded(pat1.matchAndRewrite(op, rewriter)))
- return success();
+ FailureOr<Value> newVal1 =
+ pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+ if (!failed(newVal1))
+ return newVal1;
+
ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
- if (succeeded(pat2.matchAndRewrite(op, rewriter)))
- return success();
+ FailureOr<Value> newVal2 =
+ pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+ if (!failed(newVal2))
+ return newVal2;
+
ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
- if (succeeded(pat3.matchAndRewrite(op, rewriter)))
- return success();
+ FailureOr<Value> newVal3 =
+ pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+ if (!failed(newVal3))
+ return newVal3;
+
ContractOpToElementwise pat4(vectorTransformOptions, ctx);
- if (succeeded(pat4.matchAndRewrite(op, rewriter)))
- return success();
+ FailureOr<Value> newVal4 =
+ pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+ if (!failed(newVal4))
+ return newVal4;
// Vector mask setup.
- OpBuilder::InsertionGuard guard(rewriter);
- Operation *rootOp = op;
- Value mask;
- if (op.isMasked()) {
- rewriter.setInsertionPoint(op.getMaskingOp());
- rootOp = op.getMaskingOp();
- mask = op.getMaskingOp().getMask();
- }
+ Value mask;
+ if (maskOp)
+ mask = maskOp.getMask();
// Find first batch dimension in LHS/RHS, and lower when found.
std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
if (!batchDimMap.empty()) {
@@ -982,8 +1023,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
if (failed(newOp))
return failure();
- rewriter.replaceOp(rootOp, *newOp);
- return success();
+ return newOp;
}
// Collect contracting dimensions.
@@ -1003,8 +1043,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
if (failed(newOp))
return failure();
- rewriter.replaceOp(rootOp, *newOp);
- return success();
+ return newOp;
}
}
@@ -1015,8 +1054,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
if (failed(newOp))
return failure();
- rewriter.replaceOp(rootOp, *newOp);
- return success();
+ return newOp;
}
}
@@ -1025,8 +1063,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
auto newOp = lowerReduction(rewriter, op, mask);
if (failed(newOp))
return failure();
- rewriter.replaceOp(rootOp, *newOp);
- return success();
+ return newOp;
}
return failure();
@@ -1291,12 +1328,11 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
/// This only kicks in when VectorTransformsOptions is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
/// row-major matrix multiply.
-LogicalResult
-ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
- PatternRewriter &rew) const {
+FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
+ vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rew) const {
// TODO: Support vector.mask.
- auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
- if (maskableOp.isMasked())
+ if (maskOp)
return failure();
if (vectorTransformOptions.vectorContractLowering !=
@@ -1379,8 +1415,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
: static_cast<Value>(
rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
- rew.replaceOp(op, res);
- return success();
+ return res;
}
} // namespace
>From 69814ed22c15d630263059308cd50a26a431955d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 5 Mar 2024 16:42:26 +0000
Subject: [PATCH 2/4] fixup! [mlir][vector] Adds pattern rewrite for maskable
Ops
Rename the pattern, simplify code, restore docs
---
.../Vector/Transforms/LowerVectorContract.cpp | 66 +++++++++++--------
1 file changed, 40 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 6480b295c49cb6..af502270c68db2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -226,27 +226,27 @@ namespace {
/// * the source Op when mask _is not_ present,
/// * the source Op *and* the mask Op when mask _is_ present.
template <class SourceOp>
-struct MaybeMaskedOpRewritePattern : OpRewritePattern<SourceOp> {
+struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
using OpRewritePattern<SourceOp>::OpRewritePattern;
private:
LogicalResult matchAndRewrite(SourceOp sourceOp,
PatternRewriter &rewriter) const final {
- auto maskableOp =
- dyn_cast_if_present<MaskableOpInterface>(sourceOp.getOperation());
+ auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
if (!maskableOp)
return failure();
- // Retrieve the mask if present
- MaskingOpInterface maskOp;
- if (maskableOp.isMasked())
- maskOp = maskableOp.getMaskingOp();
+ // Op to update
+ Operation *rootOp = sourceOp;
- // If this Op is masked, update the insertion point to avoid inserting into
- // the vector.mask Op region.
+ // If this Op is masked:
+ // * update the insertion point to avoid inserting into the vector.mask
+ // Op region,
+ // * update the Op to rewrite so that it's the parent vector.mask Op
OpBuilder::InsertionGuard guard(rewriter);
- Operation *rootOp = sourceOp;
- if (maskOp) {
+ MaskingOpInterface maskOp;
+ if (maskableOp.isMasked()) {
+ maskOp = maskableOp.getMaskingOp();
rewriter.setInsertionPoint(maskOp);
rootOp = maskOp;
}
@@ -283,9 +283,9 @@ struct MaybeMaskedOpRewritePattern : OpRewritePattern<SourceOp> {
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
- : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
+ : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
- using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -298,7 +298,7 @@ class ContractionOpToMatmulOpLowering
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
- : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
+ : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
@@ -328,9 +328,9 @@ class ContractionOpToMatmulOpLowering
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
- : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
+ : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
- using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -343,7 +343,7 @@ class ContractionOpToOuterProductOpLowering
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
- : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
+ : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
@@ -376,9 +376,9 @@ class ContractionOpToOuterProductOpLowering
/// This only kicks in when VectorTransformsOptions is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
class ContractionOpToDotLowering
- : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
+ : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
- using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -391,7 +391,7 @@ class ContractionOpToDotLowering
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
- : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
+ : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
FailureOr<Value>
@@ -404,10 +404,24 @@ class ContractionOpToDotLowering
FilterConstraintType filter;
};
+/// Progressive lowering of ContractionOp.
+///
+/// One:
+/// %x = vector.contract with at least one free/batch dimension
+/// is replaced by:
+/// %a = vector.contract with one less free/batch dimension
+/// %b = vector.contract with one less free/batch dimension
+/// ..
+/// %x = combine %a %b ..
+/// until a pure contraction is reached (no free/batch dimensions),
+/// which is replaced by a dot-product.
+///
+/// This only kicks in when either VectorTransformsOptions is set
+/// to Dot or when other contraction patterns fail.
class ContractionOpLowering
- : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
+ : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
- using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -418,7 +432,7 @@ class ContractionOpLowering
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
FilterConstraintType constraint = defaultFilter)
- : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
+ : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions),
filter(std::move(constraint)) {}
@@ -826,8 +840,8 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
/// Lower vector.contract with all size one reduction dimensions to
/// elementwise ops when possible.
struct ContractOpToElementwise
- : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
- using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
+ : public MaskableOpRewritePattern<vector::ContractionOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
@@ -837,7 +851,7 @@ struct ContractOpToElementwise
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context, PatternBenefit benefit = 1,
const FilterConstraintType &constraint = defaultFilter)
- : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
+ : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
FailureOr<Value>
>From 3ae7afc4862674ce912fe19c554a21716e465160 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 8 Mar 2024 15:12:08 +0000
Subject: [PATCH 3/4] fixup! [mlir][vector] Adds pattern rewrite for maskable
Ops
Move the pattern to VectorUtils.h
---
.../mlir/Dialect/Vector/Utils/VectorUtils.h | 58 ++++++++++++++++++
.../Vector/Transforms/LowerVectorContract.cpp | 60 +------------------
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 31 ++++++++++
3 files changed, 90 insertions(+), 59 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 3ce16ef361f37d..98b6f01f216fa2 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -112,6 +112,64 @@ SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
Operation *xfer,
RewriterBase &rewriter);
+/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
+/// masked (i.e. inside `vector.mask` Op region). In particular:
+/// 1. It matches `SourceOp` operation, Op.
+/// 2. If Op is masked, retrieves the mask and updates the insertion point to
+/// avoid inserting new ops into `vector.mask` Op region (which only allows
+/// one Op). If the Op is not masked, this step is a nop.
+/// 3. Invokes `matchAndRewriteMaskableOp` on Op that might be nested (not
+/// required) in the matched `vector.mask` operation from step 2.
+///
+/// It frees the patterns implementing this class from worrying about the
+/// logic to update the insertion point. However, those patterns are still
+/// responsible for providing an updated version of:
+/// * the source Op when mask _is not_ present,
+/// * the source Op *and* the mask Op when mask _is_ present.
+template <class SourceOp>
+struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
+ using OpRewritePattern<SourceOp>::OpRewritePattern;
+
+private:
+ LogicalResult matchAndRewrite(SourceOp sourceOp,
+ PatternRewriter &rewriter) const final {
+ auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
+ if (!maskableOp)
+ return failure();
+
+ // Op to update
+ Operation *rootOp = sourceOp;
+
+ // If this Op is masked:
+ // * update the insertion point to avoid inserting into the vector.mask
+ // Op region,
+ // * update the Op to rewrite so that it's the parent vector.mask Op
+ OpBuilder::InsertionGuard guard(rewriter);
+ MaskingOpInterface maskOp;
+ if (maskableOp.isMasked()) {
+ maskOp = maskableOp.getMaskingOp();
+ rewriter.setInsertionPoint(maskOp);
+ rootOp = maskOp;
+ }
+
+ FailureOr<Value> newOp =
+ matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
+ if (failed(newOp))
+ return failure();
+
+ rewriter.replaceOp(rootOp, *newOp);
+ return success();
+ }
+
+public:
+ // Matches `sourceOp`, which may be masked by `maskingOp` (null if unmasked).
+ // On success, returns the Value that replaces `maskingOp` when it is
+ // present, or `sourceOp` otherwise (see `matchAndRewrite` above).
+ virtual FailureOr<Value>
+ matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const = 0;
+};
+
} // namespace vector
/// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index af502270c68db2..ba1c96805ff831 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -211,64 +211,6 @@ static Value createMul(Location loc, Value x, Value y, bool isInt,
namespace {
-/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
-/// masked (i.e. inside `vector.mask` Op region). In particular:
-/// 1. It matches `SourceOp` operation, Op.
-/// 2. If Op is masked, retrieves the mask and updates the insertion point to
-/// avoid inserting new ops into `vector.mask` Op region (which only allows
-/// one Op). If the Op is not masked, this step is a nop.
-/// 3. Invokes `matchAndRewriteMaskableOp` on Op that might be nested (not
-/// required) in the matched `vector.mask` operation from step 2.
-///
-/// It frees the patterns implementing this class from worrying about the
-/// logic to update the insertion point. However, those patterns are still
-/// responsible for providing an updated version of:
-/// * the source Op when mask _is not_ present,
-/// * the source Op *and* the mask Op when mask _is_ present.
-template <class SourceOp>
-struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
- using OpRewritePattern<SourceOp>::OpRewritePattern;
-
-private:
- LogicalResult matchAndRewrite(SourceOp sourceOp,
- PatternRewriter &rewriter) const final {
- auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
- if (!maskableOp)
- return failure();
-
- // Op to update
- Operation *rootOp = sourceOp;
-
- // If this Op is masked:
- // * update the insertion point to avoid inserting into the vector.mask
- // Op region,
- // * update the Op to rewrite so that it's the parent vector.mask Op
- OpBuilder::InsertionGuard guard(rewriter);
- MaskingOpInterface maskOp;
- if (maskableOp.isMasked()) {
- maskOp = maskableOp.getMaskingOp();
- rewriter.setInsertionPoint(maskOp);
- rootOp = maskOp;
- }
-
- FailureOr<Value> newOp =
- matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
- if (failed(newOp))
- return failure();
-
- rewriter.replaceOp(rootOp, *newOp);
- return success();
- }
-
-public:
- // Matches SourceOp that can potentially be masked with `maskingOp`. If the
- // latter is present, returns an updated masking op (with a replacement for
- // `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`.
- virtual FailureOr<Value>
- matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
- PatternRewriter &rewriter) const = 0;
-};
-
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to:
/// ```
@@ -283,7 +225,7 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
- : public MaskableOpRewritePattern<vector::ContractionOp> {
+ : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
public:
using MaskableOpRewritePattern::MaskableOpRewritePattern;
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 63ed0947cf6ce2..ca31ff9d16ad65 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -279,6 +279,37 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
}
+// template <class SourceOp>
+// LogicalResult vector::MaskableOpRewritePattern::matchAndRewrite(
+// SourceOp sourceOp, PatternRewriter &rewriter) const final {
+// auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
+// if (!maskableOp)
+// return failure();
+
+// // Op to update
+// Operation *rootOp = sourceOp;
+
+// // If this Op is masked:
+// // * update the insertion point to avoid inserting into the vector.mask
+// // Op region,
+// // * update the Op to rewrite so that it's the parent vector.mask Op
+// OpBuilder::InsertionGuard guard(rewriter);
+// MaskingOpInterface maskOp;
+// if (maskableOp.isMasked()) {
+// maskOp = maskableOp.getMaskingOp();
+// rewriter.setInsertionPoint(maskOp);
+// rootOp = maskOp;
+// }
+
+// FailureOr<Value> newOp =
+// matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
+// if (failed(newOp))
+// return failure();
+
+// rewriter.replaceOp(rootOp, *newOp);
+// return success();
+// }
+
std::optional<StaticTileOffsetRange>
vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
if (vType.getRank() <= targetRank)
>From e80fb947ed3a519d17c853da9965e39ffa540211 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 20 Mar 2024 11:26:38 +0000
Subject: [PATCH 4/4] fixup! [mlir][vector] Adds pattern rewrite for maskable
Ops
Fix comments
---
.../mlir/Dialect/Vector/Utils/VectorUtils.h | 32 +++++++++----------
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 31 ------------------
2 files changed, 16 insertions(+), 47 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 98b6f01f216fa2..35e76a8b623a32 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -114,18 +114,21 @@ SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
/// masked (i.e. inside `vector.mask` Op region). In particular:
-/// 1. It matches `SourceOp` operation, Op.
-/// 2. If Op is masked, retrieves the mask and updates the insertion point to
-/// avoid inserting new ops into `vector.mask` Op region (which only allows
-/// one Op). If the Op is not masked, this step is a nop.
-/// 3. Invokes `matchAndRewriteMaskableOp` on Op that might be nested (not
-/// required) in the matched `vector.mask` operation from step 2.
+/// 1. Matches `SourceOp` operation, Op.
+/// 2.1. If Op is masked, retrieves the masking Op, maskOp, and updates the
+/// insertion point to avoid inserting new ops into the `vector.mask` Op
+/// region (which only allows one Op).
+/// 2.2. If Op is not masked, this step is skipped.
+/// 3. Invokes `matchAndRewriteMaskableOp` on Op and optionally maskOp if
+/// found in step 2.1.
///
-/// It frees the patterns implementing this class from worrying about the
-/// logic to update the insertion point. However, those patterns are still
-/// responsible for providing an updated version of:
-/// * the source Op when mask _is not_ present,
-/// * the source Op *and* the mask Op when mask _is_ present.
+/// This wrapper frees patterns from re-implementing the logic to update the
+/// insertion point when a maskable Op is masked. Such patterns are still
+/// responsible for providing an updated ("rewritten") version of:
+/// a. the source Op when mask _is not_ present,
+/// b. the source Op and the masking Op when mask _is_ present.
+/// Note that the return value from `matchAndRewriteMaskableOp` depends on the
+/// case above.
template <class SourceOp>
struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
using OpRewritePattern<SourceOp>::OpRewritePattern;
@@ -137,13 +140,10 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
if (!maskableOp)
return failure();
- // Op to update
Operation *rootOp = sourceOp;
- // If this Op is masked:
- // * update the insertion point to avoid inserting into the vector.mask
- // Op region,
- // * update the Op to rewrite so that it's the parent vector.mask Op
+ // If this Op is masked, update the insertion point to avoid inserting into
+ // the vector.mask Op region.
OpBuilder::InsertionGuard guard(rewriter);
MaskingOpInterface maskOp;
if (maskableOp.isMasked()) {
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index ca31ff9d16ad65..63ed0947cf6ce2 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -279,37 +279,6 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
}
-// template <class SourceOp>
-// LogicalResult vector::MaskableOpRewritePattern::matchAndRewrite(
-// SourceOp sourceOp, PatternRewriter &rewriter) const final {
-// auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
-// if (!maskableOp)
-// return failure();
-
-// // Op to update
-// Operation *rootOp = sourceOp;
-
-// // If this Op is masked:
-// // * update the insertion point to avoid inserting into the vector.mask
-// // Op region,
-// // * update the Op to rewrite so that it's the parent vector.mask Op
-// OpBuilder::InsertionGuard guard(rewriter);
-// MaskingOpInterface maskOp;
-// if (maskableOp.isMasked()) {
-// maskOp = maskableOp.getMaskingOp();
-// rewriter.setInsertionPoint(maskOp);
-// rootOp = maskOp;
-// }
-
-// FailureOr<Value> newOp =
-// matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
-// if (failed(newOp))
-// return failure();
-
-// rewriter.replaceOp(rootOp, *newOp);
-// return success();
-// }
-
std::optional<StaticTileOffsetRange>
vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
if (vType.getRank() <= targetRank)
More information about the Mlir-commits
mailing list