[Mlir-commits] [mlir] [mlir][VectorOps] Add fold `ExtractOp(CreateMask) -> CreateMask` (PR #69456)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Oct 19 03:40:10 PDT 2023
================
@@ -1983,6 +1983,66 @@ class ExtractOpNonSplatConstantFolder final
}
};
+// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
+class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ auto createMaskOp =
+ extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return failure();
+
+ ArrayRef<int64_t> position = extractOp.getStaticPosition();
+ auto maskOperands = createMaskOp.getOperands();
+ VectorType maskType = createMaskOp.getVectorType();
+ VectorType::Builder newMaskType(maskType);
+
+ bool allFalse = false;
+ bool containsUnknownDims = false;
+ for (auto [i, pos] : llvm::enumerate(position)) {
+ newMaskType.dropDim(0);
+ Value operand = maskOperands[i];
+ auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
+ if (!constantOp) {
+ // Bounds of this dim unknown.
+ containsUnknownDims = true;
+ continue;
+ }
+
+ int64_t createMaskBound =
+ llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
+ if (pos == ShapedType::kDynamic) {
+ // Extractions must be in-bounds. So if the corresponding `create_mask`
+ // size is 0 or the size of the dim, we know this dim is false or true.
+ if (createMaskBound == 0)
+ allFalse = true;
+ else if (createMaskBound < maskType.getDimSize(i))
+ // Unknown if this dim is within the true or false region.
+ containsUnknownDims = true;
+ } else {
----------------
MacDue wrote:
I've moved the `all-false` mask detection to `getMaskFormat()`, which applies to a number of rewrites. The remaining checks rely on knowledge of the extraction position, so I'm unsure how they could be applied more generally.
https://github.com/llvm/llvm-project/pull/69456
More information about the Mlir-commits
mailing list