[Mlir-commits] [mlir] [mlir][VectorOps] Add fold `ExtractOp(CreateMask) -> CreateMask` (PR #69456)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Oct 19 03:40:10 PDT 2023


================
@@ -1983,6 +1983,66 @@ class ExtractOpNonSplatConstantFolder final
   }
 };
 
+// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
+class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto createMaskOp =
+        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return failure();
+
+    ArrayRef<int64_t> position = extractOp.getStaticPosition();
+    auto maskOperands = createMaskOp.getOperands();
+    VectorType maskType = createMaskOp.getVectorType();
+    VectorType::Builder newMaskType(maskType);
+
+    bool allFalse = false;
+    bool containsUnknownDims = false;
+    for (auto [i, pos] : llvm::enumerate(position)) {
+      newMaskType.dropDim(0);
+      Value operand = maskOperands[i];
+      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
+      if (!constantOp) {
+        // Bounds of this dim unknown.
+        containsUnknownDims = true;
+        continue;
+      }
+
+      int64_t createMaskBound =
+          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
+      if (pos == ShapedType::kDynamic) {
+        // Extractions must be in-bounds. So if the corresponding `create_mask`
+        // size is 0 or the size of the dim, we know this dim is false or true.
+        if (createMaskBound == 0)
+          allFalse = true;
+        else if (createMaskBound < maskType.getDimSize(i))
+          // Unknown if this dim is within the true or false region.
+          containsUnknownDims = true;
+      } else {
----------------
MacDue wrote:

I've moved the `all-false` mask detection into `getMaskFormat()`, so it now applies to a number of rewrites. The remaining checks rely on knowing the extraction position, so I'm unsure how they would apply more generally.

https://github.com/llvm/llvm-project/pull/69456


More information about the Mlir-commits mailing list