[Mlir-commits] [mlir] [mlir][Affine] Split off delinearize parts that depend on last component (PR #117015)
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Nov 25 11:39:52 PST 2024
================
@@ -4729,12 +4729,97 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail
return success();
}
};
+
+/// If the input to a delinearization is a disjoint linearization, and the
+/// last k > 1 components of the delinearization basis multiply to the
+/// last component of the linearization basis, break the linearization and
+/// delinearization into two parts, peeling off the last input to linearization.
+///
+/// For example:
+/// %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
+/// %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
+/// becomes
+/// %0 = affine.linearize_index [%z, %y] by (3, 2) : index
+/// %1:2 = affine.delinearize_index %0 by (2, 3) : index
+/// %2:2 = affine.delinearize_index %x by (8, 4) : index
+/// where the original %1:4 is replaced by %1:2 ++ %2:2
+struct SplitDelinearizeSpanningLastLinearizeArg final
+ : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
+ PatternRewriter &rewriter) const override {
+ auto linearizeOp = delinearizeOp.getLinearIndex()
+ .getDefiningOp<affine::AffineLinearizeIndexOp>();
+ if (!linearizeOp)
+ return rewriter.notifyMatchFailure(delinearizeOp,
+ "index doesn't come from linearize");
+
+ if (!linearizeOp.getDisjoint())
+ return rewriter.notifyMatchFailure(linearizeOp,
+ "linearize isn't disjoint");
+
+ int64_t target = linearizeOp.getStaticBasis().back();
+ if (ShapedType::isDynamic(target))
+ return rewriter.notifyMatchFailure(
+ linearizeOp, "linearize ends with dynamic basis value");
+
+ int64_t sizeToSplit = 1;
+ size_t elemsToSplit = 0;
+ ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
+ for (int64_t basisElem : llvm::reverse(basis)) {
+ if (ShapedType::isDynamic(basisElem))
+ return rewriter.notifyMatchFailure(
+ delinearizeOp, "dynamic basis element while scanning for split");
+ sizeToSplit *= basisElem;
+ elemsToSplit += 1;
+
+ if (sizeToSplit > target)
+ return rewriter.notifyMatchFailure(delinearizeOp,
+ "overshot last argument size");
+ if (sizeToSplit == target)
+ break;
+ }
+
+ if (sizeToSplit < target)
+ return rewriter.notifyMatchFailure(
+ delinearizeOp, "product of known basis elements doesn't exceed last "
+ "linearize argument");
+
+ if (elemsToSplit < 2)
+ return rewriter.notifyMatchFailure(
+ delinearizeOp, "don't have a non-trivial basis product");
----------------
krzysz00 wrote:
I think it's implicit in an existing test that permutes but I'll go add another one if it isn't.
https://github.com/llvm/llvm-project/pull/117015
More information about the Mlir-commits
mailing list