[Mlir-commits] [mlir] [mlir][vector] Add a pattern to fuse extract(constant_mask) (PR #81057)
Diego Caballero
llvmlistbot at llvm.org
Wed Feb 7 22:23:58 PST 2024
================
@@ -2039,6 +2039,60 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
}
};
+// Pattern to rewrite
+// ExtractOp(ConstantMaskOp) -> ConstantMaskOp
+// or
+// ExtractOp(ConstantMaskOp) -> ConstantOp
+class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ auto constantMaskOp =
+ extractOp.getVector().getDefiningOp<vector::ConstantMaskOp>();
+ if (!constantMaskOp)
+ return failure();
+
+ auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
+ auto resultTy = extractOp.getResult().getType();
+ if (resultTy.isa<mlir::VectorType>()) {
+ auto resultVectorTy = resultTy.cast<mlir::VectorType>();
+ auto resultRank = resultVectorTy.getRank();
+ auto n = maskDimSizes.size();
+ std::vector<int64_t> indices;
+ for (size_t i = n - resultRank; i < n; i++)
----------------
dcaballe wrote:
nit: use pre-increment (`++i` rather than `i++`) in the loop, per the LLVM Coding Standards.
https://github.com/llvm/llvm-project/pull/81057
More information about the Mlir-commits mailing list