[Mlir-commits] [mlir] [mlir][vector] Add a pattern to fuse extract(constant_mask) (PR #81057)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Feb 26 11:01:32 PST 2024
================
@@ -2040,6 +2040,77 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
}
};
+// Patterns to rewrite ExtractOp(ConstantMaskOp).
+//
+// When the result of ExtractOp is a subvector of the input, we can rewrite it
+// as a ConstantMaskOp of the extracted subvector type.
+//
+// ExtractOp(ConstantMaskOp) -> ConstantMaskOp
+//
+// When the result of ExtractOp is a scalar, we can fold it directly into a
+// constant holding that scalar mask value.
+//
+// ExtractOp(ConstantMaskOp) -> ConstantOp
+class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ auto constantMaskOp =
+ extractOp.getVector().getDefiningOp<vector::ConstantMaskOp>();
+ if (!constantMaskOp)
+ return failure();
+
+ // All indices must be static.
+ ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
+ unsigned dynamicPosCount =
+ llvm::count_if(extractOpPos, ShapedType::isDynamic);
+ // If there is any dynamic position in ExtractOp, we cannot determine the
+ // scalar value.
+ if (dynamicPosCount)
+ return failure();
+
+ ArrayRef<Attribute> maskDimSizes =
+ constantMaskOp.getMaskDimSizes().getValue();
+ Type resultTy = extractOp.getResult().getType();
+ if (resultTy.isa<mlir::VectorType>()) {
+ auto resultVectorTy = resultTy.cast<mlir::VectorType>();
----------------
kuhar wrote:
also here
https://github.com/llvm/llvm-project/pull/81057
More information about the Mlir-commits
mailing list