[Mlir-commits] [mlir] [mlir][vector] Adds pattern rewrite for maskable Ops (PR #83827)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Mar 5 08:52:07 PST 2024


================
@@ -634,52 +682,36 @@ struct UnrolledOuterProductGenerator
 ///
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
 /// otherwise supports any layout permutation of the matrix-multiply.
-LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
-    vector::ContractionOp op, PatternRewriter &rewriter) const {
+FailureOr<Value>
+ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rewriter) const {
   if (vectorTransformOptions.vectorContractLowering !=
       vector::VectorContractLowering::OuterProduct)
     return failure();
 
   if (failed(filter(op)))
     return failure();
 
-  // Vector mask setup.
-  OpBuilder::InsertionGuard guard(rewriter);
-  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
-  Operation *rootOp;
-  if (maskableOp.isMasked()) {
-    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
-    rootOp = maskableOp.getMaskingOp();
-  } else {
-    rootOp = op;
-  }
-
   UnrolledOuterProductGenerator e(rewriter, op);
   FailureOr<Value> matmatRes = e.matmat();
   if (succeeded(matmatRes)) {
-    rewriter.replaceOp(rootOp, *matmatRes);
-    return success();
+    return matmatRes;
----------------
banach-space wrote:

We have two cases though:
* `vector.contract` --> here we should replace `vector.contract`
* `vector.mask {vector.contract}` --> here we should replace `vector.mask` (no need to "replace" `vector.contract`)
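
A minimal sketch of the two cases. It uses a hypothetical `ToyOp` struct in place of MLIR's `Operation`, so it is not the MLIR API, and `getRootOpToReplace` is an illustrative name. It models only the decision of which op gets replaced:

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for an MLIR operation (not the MLIR API).
struct ToyOp {
  std::string name;
  // Non-null when this op is wrapped in a `vector.mask` region.
  ToyOp *maskingOp = nullptr;
};

// Returns the op that a lowering pattern should replace:
//  * unmasked `vector.contract`            -> the contraction itself;
//  * `vector.mask { vector.contract }`     -> the enclosing `vector.mask`.
ToyOp *getRootOpToReplace(ToyOp &op) {
  return op.maskingOp ? op.maskingOp : &op;
}
```

For an unmasked contraction the function returns the op itself. For a masked one it returns the wrapping `vector.mask`, which is what the pattern has to replace.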

This **can be** implemented, but then every pattern would need logic like this:
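```cpp
  // Restore the rewriter's insertion point when this scope exits.
  OpBuilder::InsertionGuard guard(rewriter);
  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
  // The op to replace: the enclosing `vector.mask` if the op is masked,
  // otherwise the op itself.
  Operation *rootOp;
  if (maskableOp.isMasked()) {
    // Emit the replacement before `vector.mask`, not inside its region.
    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
    rootOp = maskableOp.getMaskingOp();
  } else {
    rootOp = op;
  }
```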

That's the logic I am removing from the individual patterns and moving to the base class, `MaskableOpRewritePattern`. But it also means a pattern can't "replace" anymore, because the logic that decides "what" to replace is no longer in the pattern.
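
A rough sketch of the resulting split of responsibilities, assuming a toy IR rather than MLIR. `ToyOp`, `ToyValue`, `ToyMaskableOpRewritePattern` and `ToyContractLowering` are all hypothetical stand-ins, and the rewriter and insertion points are deliberately left out. The base class picks the root op and performs the replacement; the derived pattern, like `matchAndRewriteMaskableOp` in the diff above, only returns the replacement value:

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical toy IR (not the MLIR API).
struct ToyOp {
  std::string name;
  ToyOp *maskingOp = nullptr;  // enclosing `vector.mask`, if any
  std::string replacedWith;    // set when this op is "replaced"
};

using ToyValue = std::string;

// Hypothetical analogue of `MaskableOpRewritePattern`: the base class owns
// the "which op do I replace?" decision and the replacement itself.
class ToyMaskableOpRewritePattern {
public:
  virtual ~ToyMaskableOpRewritePattern() = default;

  bool matchAndRewrite(ToyOp &op) const {
    // Masked: replace the `vector.mask`; unmasked: replace the op itself.
    ToyOp *rootOp = op.maskingOp ? op.maskingOp : &op;
    std::optional<ToyValue> newValue =
        matchAndRewriteMaskableOp(op, op.maskingOp);
    if (!newValue)
      return false;  // no match: nothing is replaced
    rootOp->replacedWith = *newValue;
    return true;
  }

protected:
  // Derived patterns compute the replacement value instead of replacing.
  virtual std::optional<ToyValue>
  matchAndRewriteMaskableOp(ToyOp &op, ToyOp *maskOp) const = 0;
};

// Hypothetical lowering that only handles `vector.contract`.
class ToyContractLowering : public ToyMaskableOpRewritePattern {
protected:
  std::optional<ToyValue>
  matchAndRewriteMaskableOp(ToyOp &op, ToyOp *maskOp) const override {
    if (op.name != "vector.contract")
      return std::nullopt;
    return maskOp ? ToyValue("masked_outer_products")
                  : ToyValue("outer_products");
  }
};
```

In the toy, a masked contraction leaves `replacedWith` empty on the inner op and sets it on the `vector.mask` instead, so derived patterns never need to know which op is the root.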

Am I overthinking this? 🤔 

https://github.com/llvm/llvm-project/pull/83827

