[Mlir-commits] [mlir] [mlir][VectorOps] Add fold `ExtractOp(CreateMask) -> CreateMask` (PR #69456)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Oct 18 07:29:42 PDT 2023
================
@@ -1983,6 +1983,66 @@ class ExtractOpNonSplatConstantFolder final
}
};
+// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
+class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ auto createMaskOp =
+ extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return failure();
+
+ ArrayRef<int64_t> position = extractOp.getStaticPosition();
+ auto maskOperands = createMaskOp.getOperands();
+ VectorType maskType = createMaskOp.getVectorType();
+ VectorType::Builder newMaskType(maskType);
+
+ bool allFalse = false;
+ bool containsUnknownDims = false;
+ for (auto [i, pos] : llvm::enumerate(position)) {
----------------
banach-space wrote:
`i` is effectively the dimension index, right? i.e. `dimIdx`?
https://github.com/llvm/llvm-project/pull/69456
More information about the Mlir-commits
mailing list