[Mlir-commits] [mlir] [mlir][vector] Adds pattern rewrite for maskable Ops (PR #83827)
Diego Caballero
llvmlistbot at llvm.org
Mon Mar 4 12:12:22 PST 2024
================
@@ -634,52 +682,36 @@ struct UnrolledOuterProductGenerator
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
-LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
- vector::ContractionOp op, PatternRewriter &rewriter) const {
+FailureOr<Value>
+ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
+ vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const {
if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::OuterProduct)
return failure();
if (failed(filter(op)))
return failure();
- // Vector mask setup.
- OpBuilder::InsertionGuard guard(rewriter);
- auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
- Operation *rootOp;
- if (maskableOp.isMasked()) {
- rewriter.setInsertionPoint(maskableOp.getMaskingOp());
- rootOp = maskableOp.getMaskingOp();
- } else {
- rootOp = op;
- }
-
UnrolledOuterProductGenerator e(rewriter, op);
FailureOr<Value> matmatRes = e.matmat();
if (succeeded(matmatRes)) {
- rewriter.replaceOp(rootOp, *matmatRes);
- return success();
+ return matmatRes;
----------------
dcaballe wrote:
Why don't we replace the root op here? That would preserve the same structure as in the other patterns.
https://github.com/llvm/llvm-project/pull/83827
More information about the Mlir-commits mailing list