[Mlir-commits] [mlir] [mlir][Affine] Generalize the linearize(delinearize()) simplifications (PR #117637)

Krzysztof Drewniak llvmlistbot at llvm.org
Mon Dec 2 09:31:03 PST 2024


================
@@ -4980,38 +5008,202 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
   }
 };
 
-/// Cancel out linearize_index(delinearize_index(x, B), B).
+/// Return the product of `terms`, creating an `affine.apply` if any of them are
+/// non-constant values. If any of `terms` is `nullptr`, return `nullptr`.
+static OpFoldResult computeProduct(Location loc, OpBuilder &builder,
+                                   ArrayRef<OpFoldResult> terms) {
+  int64_t nDynamic = 0;
+  SmallVector<Value> dynamicPart;
+  AffineExpr result = builder.getAffineConstantExpr(1);
+  for (OpFoldResult term : terms) {
+    if (!term)
+      return term;
+    std::optional<int64_t> maybeConst = getConstantIntValue(term);
+    if (maybeConst) {
+      result = result * builder.getAffineConstantExpr(*maybeConst);
+    } else {
+      dynamicPart.push_back(term.get<Value>());
+      result = result * builder.getAffineSymbolExpr(nDynamic++);
+    }
+  }
+  if (auto constant = dyn_cast<AffineConstantExpr>(result))
+    return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
+  return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
+}
+
+/// If consecutive outputs of a delinearize_index are linearized with the same
+/// bounds, canonicalize away the redundant arithmetic.
+///
+/// That is, if we have
+/// ```
+/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
+/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
+///   by (...e, B1, B2, ..., BK, ...f)
+/// ```
 ///
-/// That is, rewrite
+/// We can rewrite this to
+/// ```
+/// B = B1 * B2 ... BK
+/// %sMerged:(N-K+1) = affine.delinearize_index %x into (...a, B, ...b)
+/// %t = affine.linearize_index [...c, %sMerged#I, ...d] by (...e, B, ...f)
 /// ```
-/// %0:N = affine.delinearize_index %x by (%b1, %b2, ... %bN)
-/// %y = affine.linearize_index [%0#0, %0#1, ... %0#(N-1)] by (%b1, %b2, ...
-/// %bN)
+/// where we replace all results of %s unaffected by the change with results
+/// from %sMerged.
+///
+/// As a special case, if all results of the delinearize are merged in this way
+/// we can replace those usages with %x, thus cancelling the delinearization
+/// entirely, as in
 /// ```
-/// to replacing `%y` with `%x`.
-struct CancelLinearizeOfDelinearizeExact final
+/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
+/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
+/// ```
+/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
+struct CancelLinearizeOfDelinearizePortion final
     : OpRewritePattern<affine::AffineLinearizeIndexOp> {
   using OpRewritePattern::OpRewritePattern;
 
+  struct Match {
+    AffineDelinearizeIndexOp delinearize;
+    unsigned linStart = 0;
+    unsigned delinStart = 0;
+    unsigned length = 0;
+  };
+
   LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
                                 PatternRewriter &rewriter) const override {
-    auto delinearizeOp = linearizeOp.getMultiIndex()
-                             .front()
-                             .getDefiningOp<affine::AffineDelinearizeIndexOp>();
-    if (!delinearizeOp)
-      return rewriter.notifyMatchFailure(
-          linearizeOp, "last entry doesn't come from a delinearize");
+    SmallVector<Match> matches;
+
+    const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
+    ArrayRef<OpFoldResult> linBasisRef = linBasis;
+
+    ValueRange multiIndex = linearizeOp.getMultiIndex();
+    unsigned numLinArgs = multiIndex.size();
+    unsigned linArgIdx = 0;
+    // We only want to replace one run from the same delinearize op per
+    // pattern invocation lest we run into invalidation issues.
+    llvm::SmallPtrSet<Operation *, 2> seen;
+    while (linArgIdx < numLinArgs) {
+      auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
+      if (!asResult) {
+        linArgIdx++;
+        continue;
+      }
 
-    if (linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
-      return rewriter.notifyMatchFailure(
-          linearizeOp, "basis of linearize and delinearize don't match exactly "
-                       "(excluding outer bounds)");
+      auto delinearizeOp =
+          dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
+      if (!delinearizeOp) {
+        linArgIdx++;
+        continue;
+      }
+
+      /// Result 0 of the delinearize and argument 0 of the linearize can
+      /// leave their maximum value unspecified. However, even if this happens
+      /// we can still sometimes start the match process. Specifically, if
+      /// - The argument we're matching is result 0 and argument 0 (so the
+      /// bounds don't matter). For example,
+      ///
+      ///     %s:2 = affine.delinearize_index %x into (8) : index, index
+      ///     %1 = affine.linearize_index [%s#0, %s#1, ...] by (8, ...)
+      /// allows cancellation
+      /// - The delinearization doesn't specify a bound, but the linearization
+      ///  is `disjoint`, which asserts that the bound on the linearization is
+      ///  correct.
+      unsigned firstDelinArg = asResult.getResultNumber();
----------------
krzysz00 wrote:

Yeah, solid point. I went with `delinArgIdx`

https://github.com/llvm/llvm-project/pull/117637


More information about the Mlir-commits mailing list