[Mlir-commits] [mlir] [mlir][vector] Add a pattern to fuse extract(constant_mask) (PR #81057)

Diego Caballero llvmlistbot at llvm.org
Wed Feb 7 22:23:58 PST 2024


================
@@ -2039,6 +2039,60 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
   }
 };
 
+// Pattern to rewrite
+// ExtractOp(ConstantMaskOp) -> ConstantMaskOp
+// or
+// ExtractOp(ConstantMaskOp) -> ConstantOp
+class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto constantMaskOp =
+        extractOp.getVector().getDefiningOp<vector::ConstantMaskOp>();
+    if (!constantMaskOp)
+      return failure();
+
+    auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
+    auto resultTy = extractOp.getResult().getType();
+    if (resultTy.isa<mlir::VectorType>()) {
+      auto resultVectorTy = resultTy.cast<mlir::VectorType>();
+      auto resultRank = resultVectorTy.getRank();
+      auto n = maskDimSizes.size();
+      std::vector<int64_t> indices;
+      for (size_t i = n - resultRank; i < n; i++)
+        indices.push_back(cast<IntegerAttr>(maskDimSizes[i]).getInt());
+
+      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
+          extractOp, resultVectorTy,
+          vector::getVectorSubscriptAttr(rewriter, indices));
+
+      return success();
+    } else if (resultTy.isa<mlir::IntegerType>()) {
+      // Extract a scalar. All indices must be static.
+      ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
+      if (maskDimSizes.size() != extractOpPos.size())
+        return failure();
+
+      auto boolType = rewriter.getI1Type();
+      auto setAttr = IntegerAttr::get(boolType, 1);
+      for (size_t i = 0; i < extractOpPos.size(); i++) {
+        if (cast<IntegerAttr>(maskDimSizes[i]).getInt() <= extractOpPos[i]) {
+          setAttr = IntegerAttr::get(boolType, 0);
+          break;
+        }
+      }
+
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, boolType,
+                                                     setAttr);
----------------
dcaballe wrote:

Some comments about what this code is doing would be helpful.

https://github.com/llvm/llvm-project/pull/81057


More information about the Mlir-commits mailing list