[Mlir-commits] [mlir] [mlir][vector] Adds pattern rewrite for maskable Ops (PR #83827)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Mar 8 07:13:31 PST 2024
================
@@ -212,6 +211,64 @@ static Value createMul(Location loc, Value x, Value y, bool isInt,
namespace {
+/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
+/// masked (i.e. nested inside the region of a `vector.mask` Op). It:
+///   1. Matches a `SourceOp` operation, Op.
+///   2. If Op is masked, retrieves the masking Op and moves the insertion
+///      point right before it, so that no new ops are inserted into the
+///      `vector.mask` region (which only allows one Op). If Op is not masked,
+///      this step is a no-op.
+///   3. Invokes `matchAndRewriteMaskableOp` on Op, passing the masking Op
+///      from step 2 (null when Op is not masked).
+///   4. On success, replaces the root Op (the masking Op if present,
+///      otherwise Op itself) with the returned value, or erases the root Op
+///      if it has no results.
+///
+/// This frees patterns derived from this class from having to manage the
+/// insertion point. However, those patterns are still responsible for
+/// providing an updated version of:
+///   * the source Op when the mask _is not_ present,
+///   * the source Op *and* the masking Op when the mask _is_ present.
+template <class SourceOp>
+struct MaybeMaskedOpRewritePattern : OpRewritePattern<SourceOp> {
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+
+private:
+  LogicalResult matchAndRewrite(SourceOp sourceOp,
+                                PatternRewriter &rewriter) const final {
+    // The matched op is never null here, so a plain `dyn_cast` suffices.
+    auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
+    if (!maskableOp)
+      return failure();
+
+    // If this Op is masked, the op to replace is the enclosing masking op, and
+    // new ops must be created before it rather than inside its region. The
+    // guard restores the caller's insertion point on every exit path.
+    OpBuilder::InsertionGuard guard(rewriter);
+    Operation *rootOp = sourceOp;
+    MaskingOpInterface maskOp;
+    if (maskableOp.isMasked()) {
+      maskOp = maskableOp.getMaskingOp();
+      rewriter.setInsertionPoint(maskOp);
+      rootOp = maskOp;
+    }
+
+    FailureOr<Value> newOp =
+        matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
+    if (failed(newOp))
+      return failure();
+
+    // A root op without results has nothing to replace: `replaceOp` expects
+    // one replacement value per result, so erase the root op instead.
+    if (rootOp->getNumResults() == 0) {
+      rewriter.eraseOp(rootOp);
+      return success();
+    }
+
+    // NOTE: a single replacement value is only valid for single-result root
+    // ops; multi-result root ops are not supported by this interface.
+    assert(*newOp && "cannot replace an op's results with a null value");
+    rewriter.replaceOp(rootOp, *newOp);
+    return success();
+  }
+
+public:
+  /// Matches `sourceOp`, which may be masked by `maskingOp` (null when the op
+  /// is not masked). On success, returns the value that replaces the root op:
+  /// the result of an updated masking op (with a replacement for `sourceOp`
+  /// nested inside) when `maskingOp` is present, otherwise the result of an
+  /// updated `sourceOp`. The returned value is not used when the root op has
+  /// no results. Returns failure if the pattern does not apply.
+  virtual FailureOr<Value>
+  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const = 0;
+};
----------------
banach-space wrote:
Done!
https://github.com/llvm/llvm-project/pull/83827
More information about the Mlir-commits
mailing list