[Mlir-commits] [mlir] [mlir][vector] Adds pattern rewrite for maskable Ops (PR #83827)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Mar 5 08:52:07 PST 2024
================
@@ -634,52 +682,36 @@ struct UnrolledOuterProductGenerator
///
/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
-LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
- vector::ContractionOp op, PatternRewriter &rewriter) const {
+FailureOr<Value>
+ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
+ vector::ContractionOp op, MaskingOpInterface maskOp,
+ PatternRewriter &rewriter) const {
if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::OuterProduct)
return failure();
if (failed(filter(op)))
return failure();
- // Vector mask setup.
- OpBuilder::InsertionGuard guard(rewriter);
- auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
- Operation *rootOp;
- if (maskableOp.isMasked()) {
- rewriter.setInsertionPoint(maskableOp.getMaskingOp());
- rootOp = maskableOp.getMaskingOp();
- } else {
- rootOp = op;
- }
-
UnrolledOuterProductGenerator e(rewriter, op);
FailureOr<Value> matmatRes = e.matmat();
if (succeeded(matmatRes)) {
- rewriter.replaceOp(rootOp, *matmatRes);
- return success();
+ return matmatRes;
----------------
banach-space wrote:
We have two cases though:
* `vector.contract` --> here we should replace `vector.contract`
* `vector.mask {vector.contract}` --> here we should replace `vector.mask` (there is no need to replace `vector.contract` itself)
This **can** be implemented, but then every pattern would need logic like this:
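```cpp
// Restore the rewriter's insertion point when this scope exits.
OpBuilder::InsertionGuard guard(rewriter);
auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
// The op to replace: the enclosing `vector.mask` if there is one,
// otherwise the maskable op itself.
Operation *rootOp;
if (maskableOp.isMasked()) {
  // Build the replacement right before the `vector.mask`.
  rewriter.setInsertionPoint(maskableOp.getMaskingOp());
  rootOp = maskableOp.getMaskingOp();
} else {
  rootOp = op;
}
```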
That's the logic I am removing from the individual patterns and moving into the base class, `MaskableOpRewritePattern`. But it also means the patterns can no longer "replace" anything themselves, because the logic that decides "what" to replace no longer lives in them.
Am I overthinking this? 🤔
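A minimal, self-contained C++ model of the split described above may help. It assumes hypothetical stand-ins (`Op`, `Rewriter`, `Value`, `MaskableOpRewritePatternSketch`, `ContractLoweringSketch`) rather than the real MLIR classes: the base class picks the root op (the `vector.mask` when present, otherwise the op itself) and performs the single replacement, while the derived pattern only returns the new value, mirroring the `FailureOr<Value>` return type of `matchAndRewriteMaskableOp` in the diff. Here `std::optional` stands in for `FailureOr`.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for an MLIR op; `maskingOp` is non-null when the op
// is wrapped in a `vector.mask`.
struct Op {
  std::string name;
  Op *maskingOp = nullptr;
  bool isMasked() const { return maskingOp != nullptr; }
};

// Hypothetical stand-in for the replacement value.
using Value = std::string;

// Hypothetical rewriter that only records what happened.
struct Rewriter {
  std::string insertionPoint;
  std::string replacedOp;
  Value replacement;
  void setInsertionPoint(Op *op) { insertionPoint = op->name; }
  void replaceOp(Op *op, const Value &v) {
    replacedOp = op->name;
    replacement = v;
  }
};

// Hypothetical base class: owns root-op selection and the final replaceOp,
// so derived patterns never decide "what" to replace.
struct MaskableOpRewritePatternSketch {
  virtual ~MaskableOpRewritePatternSketch() = default;

  // Derived patterns return the replacement value, or nullopt on failure.
  virtual std::optional<Value> matchAndRewriteMaskableOp(
      Op *op, Op *maskOp, Rewriter &rewriter) const = 0;

  bool matchAndRewrite(Op *op, Rewriter &rewriter) const {
    // Replace the `vector.mask` if present, otherwise the op itself.
    Op *rootOp = op->isMasked() ? op->maskingOp : op;
    assert(rootOp != nullptr);
    // In MLIR, an OpBuilder::InsertionGuard would scope this change.
    rewriter.setInsertionPoint(rootOp);
    std::optional<Value> newValue =
        matchAndRewriteMaskableOp(op, op->maskingOp, rewriter);
    if (!newValue)
      return false;
    rewriter.replaceOp(rootOp, *newValue);
    return true;
  }
};

// Hypothetical derived pattern: only produces a value, never replaces.
struct ContractLoweringSketch : MaskableOpRewritePatternSketch {
  std::optional<Value> matchAndRewriteMaskableOp(
      Op *op, Op *maskOp, Rewriter &) const override {
    if (op->name != "vector.contract")
      return std::nullopt;
    return maskOp ? std::string("masked_outerproducts")
                  : std::string("outerproducts");
  }
};
```

With this split a derived pattern cannot pick the wrong root, which is the point of the refactor; the flip side is the concern raised above, since a pattern no longer has its own hook for deciding what gets replaced.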
https://github.com/llvm/llvm-project/pull/83827