[Mlir-commits] [mlir] [mlir][vector] Adds pattern rewrite for maskable Ops (PR #83827)

Diego Caballero llvmlistbot at llvm.org
Mon Mar 4 12:12:21 PST 2024


================
@@ -212,6 +211,64 @@ static Value createMul(Location loc, Value x, Value y, bool isInt,
 
 namespace {
 
+/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
+/// masked (i.e. nested inside the region of a `vector.mask` Op). In particular:
+///   1. It matches a `SourceOp` operation, Op.
+///   2. If Op is masked, retrieves the masking Op and moves the insertion point
+///   in front of it, so that no new ops are inserted into the `vector.mask` Op
+///   region (which only allows one Op). If Op is not masked, this step is a
+///   nop.
+///   3. Invokes `matchAndRewriteMaskableOp` on Op, which might be (but is not
+///   required to be) nested in the `vector.mask` Op retrieved in step 2.
+///   4. Replaces the root Op (the masking Op if present, otherwise Op) with the
+///   value returned in step 3, or erases the root Op if it has no results.
+///
+/// It frees the patterns deriving from this class from worrying about the
+/// logic to update the insertion point. However, those patterns are still
+/// responsible for providing a replacement value for:
+///   * the source Op when the mask _is not_ present,
+///   * the masking Op when the mask _is_ present (i.e. the pattern itself has
+///     to re-create the masking Op around the rewritten source Op).
+/// Since a single `Value` is used as the replacement, the root Op is expected
+/// to produce at most one result.
+template <class SourceOp>
+struct MaybeMaskedOpRewritePattern : OpRewritePattern<SourceOp> {
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+
+private:
+  LogicalResult matchAndRewrite(SourceOp sourceOp,
+                                PatternRewriter &rewriter) const final {
+    // `sourceOp` is never null here, so a plain `dyn_cast` suffices.
+    auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
+    if (!maskableOp)
+      return failure();
+
+    // If this Op is masked, retrieve the masking Op and update the insertion
+    // point to avoid inserting into the vector.mask Op region. The masking Op
+    // then becomes the root Op that gets replaced.
+    OpBuilder::InsertionGuard guard(rewriter);
+    Operation *rootOp = sourceOp;
+    MaskingOpInterface maskOp;
+    if (maskableOp.isMasked()) {
+      maskOp = maskableOp.getMaskingOp();
+      rewriter.setInsertionPoint(maskOp);
+      rootOp = maskOp;
+    }
+
+    FailureOr<Value> replacement =
+        matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
+    if (failed(replacement))
+      return failure();
+
+    // A root Op without results has nothing to be replaced with: passing a
+    // single value to `replaceOp` would mismatch its (zero) result count, so
+    // erase the root Op instead.
+    if (rootOp->getNumResults() == 0) {
+      rewriter.eraseOp(rootOp);
+      return success();
+    }
+
+    assert(*replacement && "expected a replacement value for the root Op");
+    rewriter.replaceOp(rootOp, *replacement);
+    return success();
+  }
+
+public:
+  /// Matches `sourceOp`, which can potentially be masked with `maskingOp`
+  /// (null when `sourceOp` is not masked), and rewrites it. Returns:
+  ///   * a replacement for the result of `maskingOp` when the latter is
+  ///     present (e.g. the result of a new masking op with the rewritten
+  ///     `sourceOp` nested inside),
+  ///   * a replacement for the result of `sourceOp` otherwise,
+  ///   * a null `Value` when the op being replaced produces no results,
+  ///   * `failure()` when the pattern does not apply.
+  virtual FailureOr<Value>
+  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const = 0;
+};
----------------
dcaballe wrote:

Probably better to place this base class somewhere more public. Perhaps somewhere in vector utils or some other header?

https://github.com/llvm/llvm-project/pull/83827


More information about the Mlir-commits mailing list