[Mlir-commits] [mlir] [mlir][VectorOps] Add fold `ExtractOp(CreateMask) -> CreateMask` (PR #69456)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Oct 18 07:29:42 PDT 2023


================
@@ -1983,6 +1983,66 @@ class ExtractOpNonSplatConstantFolder final
   }
 };
 
+// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
+class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto createMaskOp =
+        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return failure();
+
+    ArrayRef<int64_t> position = extractOp.getStaticPosition();
+    auto maskOperands = createMaskOp.getOperands();
+    VectorType maskType = createMaskOp.getVectorType();
+    VectorType::Builder newMaskType(maskType);
+
+    bool allFalse = false;
+    bool containsUnknownDims = false;
+    for (auto [i, pos] : llvm::enumerate(position)) {
----------------
banach-space wrote:

`i` is effectively the dimension index, right? i.e. `dimIdx`?

https://github.com/llvm/llvm-project/pull/69456


More information about the Mlir-commits mailing list