[Mlir-commits] [mlir] [mlir][Affine] Split off delinearize parts that depend on last component (PR #117015)

Krzysztof Drewniak llvmlistbot at llvm.org
Mon Nov 25 11:39:52 PST 2024


================
@@ -4729,12 +4729,97 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail
     return success();
   }
 };
+
+/// If the input to a delinearization is a disjoint linearization, and the
+/// last k > 1 components of the delinearization basis multiply to the
+/// last component of the linearization basis, break the linearization and
+/// delinearization into two parts, peeling off the last input to linearization.
+///
+/// For example:
+///    %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
+///    %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
+/// becomes
+///    %0 = affine.linearize_index [%z, %y] by (3, 2) : index
+///    %1:2 = affine.delinearize_index %0 by (2, 3) : index, index
+///    %2:2 = affine.delinearize_index %x by (8, 4) : index, index
+/// where the original %1:4 is replaced by %1:2 ++ %2:2
+struct SplitDelinearizeSpanningLastLinearizeArg final
+    : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
+                                PatternRewriter &rewriter) const override {
+    auto linearizeOp = delinearizeOp.getLinearIndex()
+                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
+    if (!linearizeOp)
+      return rewriter.notifyMatchFailure(delinearizeOp,
+                                         "index doesn't come from linearize");
+
+    if (!linearizeOp.getDisjoint())
+      return rewriter.notifyMatchFailure(linearizeOp,
+                                         "linearize isn't disjoint");
+
+    int64_t target = linearizeOp.getStaticBasis().back();
+    if (ShapedType::isDynamic(target))
+      return rewriter.notifyMatchFailure(
+          linearizeOp, "linearize ends with dynamic basis value");
+
+    int64_t sizeToSplit = 1;
+    size_t elemsToSplit = 0;
+    ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
+    for (int64_t basisElem : llvm::reverse(basis)) {
+      if (ShapedType::isDynamic(basisElem))
+        return rewriter.notifyMatchFailure(
+            delinearizeOp, "dynamic basis element while scanning for split");
+      sizeToSplit *= basisElem;
+      elemsToSplit += 1;
+
+      if (sizeToSplit > target)
+        return rewriter.notifyMatchFailure(delinearizeOp,
+                                           "overshot last argument size");
+      if (sizeToSplit == target)
+        break;
+    }
+
+    if (sizeToSplit < target)
+      return rewriter.notifyMatchFailure(
+          delinearizeOp, "product of known basis elements doesn't exceed last "
+                         "linearize argument");
+
+    if (elemsToSplit < 2)
+      return rewriter.notifyMatchFailure(
+          delinearizeOp, "don't have a non-trivial basis product");
----------------
krzysz00 wrote:

I think it's implicit in an existing test that permutes but I'll go add another one if it isn't.

https://github.com/llvm/llvm-project/pull/117015


More information about the Mlir-commits mailing list