[Mlir-commits] [mlir] [mlir][VectorOps] Add fold `ExtractOp(CreateMask) -> CreateMask` (PR #69456)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Oct 18 07:29:44 PDT 2023
================
@@ -1983,6 +1983,66 @@ class ExtractOpNonSplatConstantFolder final
}
};
+// Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
+class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ auto createMaskOp =
+ extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return failure();
+
+ ArrayRef<int64_t> position = extractOp.getStaticPosition();
+ auto maskOperands = createMaskOp.getOperands();
+ VectorType maskType = createMaskOp.getVectorType();
+ VectorType::Builder newMaskType(maskType);
+
+ bool allFalse = false;
+ bool containsUnknownDims = false;
+ for (auto [i, pos] : llvm::enumerate(position)) {
+ newMaskType.dropDim(0);
+ Value operand = maskOperands[i];
+ auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
+ if (!constantOp) {
+ // Bounds of this dim unknown.
+ containsUnknownDims = true;
+ continue;
+ }
+
+ int64_t createMaskBound =
+ llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
+ if (pos == ShapedType::kDynamic) {
+ // Extractions must be in-bounds. So if the corresponding `create_mask`
+ // size is 0 or the size of the dim, we know this dim is false or true.
+ if (createMaskBound == 0)
+ allFalse = true;
+ else if (createMaskBound < maskType.getDimSize(i))
+ // Unknown if this dim is within the true or false region.
+ containsUnknownDims = true;
+ } else {
+ // If any position is outside the range from the `create_mask`, then the
+ // extracted mask will be all false.
+ allFalse |= pos >= createMaskBound;
+ }
----------------
banach-space wrote:
```suggestion
if (pos != ShapedType::kDynamic) {
// If any position is outside the range from the `create_mask`, then the
// extracted mask will be all false.
allFalse |= pos >= createMaskBound;
continue;
}
// Extractions must be in-bounds. So if the corresponding `create_mask`
// size is 0 or the size of the dim, we know this dim is false or true.
if (createMaskBound == 0)
allFalse = true;
else if (createMaskBound < maskType.getDimSize(i)) {
// Unknown if this dim is within the true or false region.
containsUnknownDims = true;
}
```
https://github.com/llvm/llvm-project/pull/69456
More information about the Mlir-commits
mailing list