[Mlir-commits] [mlir] [mlir][vector] Adds pattern rewrite for maskable Ops (PR #83827)

Diego Caballero llvmlistbot at llvm.org
Mon Mar 4 12:12:22 PST 2024


================
@@ -634,52 +682,36 @@ struct UnrolledOuterProductGenerator
 ///
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
 /// otherwise supports any layout permutation of the matrix-multiply.
-LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
-    vector::ContractionOp op, PatternRewriter &rewriter) const {
+FailureOr<Value>
+ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rewriter) const {
   if (vectorTransformOptions.vectorContractLowering !=
       vector::VectorContractLowering::OuterProduct)
     return failure();
 
   if (failed(filter(op)))
     return failure();
 
-  // Vector mask setup.
-  OpBuilder::InsertionGuard guard(rewriter);
-  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
-  Operation *rootOp;
-  if (maskableOp.isMasked()) {
-    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
-    rootOp = maskableOp.getMaskingOp();
-  } else {
-    rootOp = op;
-  }
-
   UnrolledOuterProductGenerator e(rewriter, op);
   FailureOr<Value> matmatRes = e.matmat();
   if (succeeded(matmatRes)) {
-    rewriter.replaceOp(rootOp, *matmatRes);
-    return success();
+    return matmatRes;
----------------
dcaballe wrote:

Why don't we replace the root op here? That would preserve the same structure as the other patterns.

https://github.com/llvm/llvm-project/pull/83827


More information about the Mlir-commits mailing list