[Mlir-commits] [mlir] [mlir][Affine] Extend linearize/delinearize cancelation to partial tails (PR #116872)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 19 12:38:48 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

<details>
<summary>Changes</summary>

Existing patterns would cancel out linearize_index / delinearize_index pairs that had exactly the same basis, like

    %0 = affine.linearize_index disjoint [%w, %x, %y, %z] by (X, Y, Z) : index
    %1:4 = affine.delinearize_index %0 into (W, X, Y, Z) : index, ...

This commit extends the canonicalization to handle instances where only a trailing portion of the basis matches, as in

    %0 = affine.linearize_index disjoint [%w, %x, %y, %z] by (X, Y, Z) : index
    %1:3 = affine.delinearize_index %0 into (WX, Y, Z) : index, ...

Here we can replace the last two results of the delinearize_index operation with the last two inputs of the linearize_index (%y and %z). Only the leading, non-matching part is left to linearize and delinearize. The result is more canonical because fewer total computations need to be performed.

---
Full diff: https://github.com/llvm/llvm-project/pull/116872.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+45-11) 
- (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+18) 


``````````diff
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 4cf07bc167eab9..67d7da622a3550 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4666,14 +4666,16 @@ struct DropUnitExtentBasis
 };
 
 /// If a `affine.delinearize_index`'s input is a `affine.linearize_index
-/// disjoint` and the two operations have the same basis, replace the
-/// delinearizeation results with the inputs of the `affine.linearize_index`
-/// since they are exact inverses of each other.
+/// disjoint` and the two operations end with the same basis elements,
+/// cancel those parts of the operations out because they are inverses
+/// of each other.
+///
+/// If the operations have the same basis, cancel them entirely.
 ///
 /// The `disjoint` flag is needed on the `affine.linearize_index` because
 /// otherwise, there is no guarantee that the inputs to the linearization are
 /// in-bounds the way the outputs of the delinearization would be.
-struct CancelDelinearizeOfLinearizeDisjointExact
+struct CancelDelinearizeOfLinearizeDisjointExactTail
     : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -4685,12 +4687,45 @@ struct CancelDelinearizeOfLinearizeDisjointExact
       return rewriter.notifyMatchFailure(delinearizeOp,
                                          "index doesn't come from linearize");
 
-    if (!linearizeOp.getDisjoint() ||
-        linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
+    if (!linearizeOp.getDisjoint())
+      return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
+
+    ValueRange linearizeIns = linearizeOp.getMultiIndex();
+    // Note: we use the full basis so we don't lose outer bounds later.
+    SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
+    SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
+    size_t numMatches = 0;
+    for (auto [linSize, delinSize] : llvm::zip(
+             llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
+      if (linSize != delinSize)
+        break;
+      ++numMatches;
+    }
+
+    if (numMatches == 0)
       return rewriter.notifyMatchFailure(
-          linearizeOp, "not disjoint or basis doesn't match delinearize");
+          delinearizeOp, "final basis element doesn't match linearize");
+
+    // The easy case: everything lines up and the bases match up completely.
+    if (numMatches == linearizeBasis.size() &&
+        numMatches == delinearizeBasis.size() &&
+        linearizeIns.size() == delinearizeOp.getNumResults()) {
+      rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
+      return success();
+    }
 
-    rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
+    Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
+        linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
+        ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
+        linearizeOp.getDisjoint());
+    auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        delinearizeOp.getLoc(), newLinearize,
+        ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
+        delinearizeOp.hasOuterBound());
+    SmallVector<Value> mergedResults(newDelinearize.getResults());
+    mergedResults.append(linearizeIns.take_back(numMatches).begin(),
+                         linearizeIns.take_back(numMatches).end());
+    rewriter.replaceOp(delinearizeOp, mergedResults);
     return success();
   }
 };
@@ -4698,9 +4733,8 @@ struct CancelDelinearizeOfLinearizeDisjointExact
 
 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns
-      .insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis>(
-          context);
+  patterns.insert<CancelDelinearizeOfLinearizeDisjointExactTail,
+                  DropUnitExtentBasis>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index b54a13cffe7771..5384977151b47f 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1739,6 +1739,24 @@ func.func @cancel_delinearize_linearize_disjoint_delinearize_extra_bound(%arg0:
 
 // -----
 
+// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_partial(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (%[[ARG3]], 4) : index
+//       CHECK:     %[[DELIN:.+]]:2 = affine.delinearize_index %[[LIN]] into (8) : index, index
+//       CHECK:     return %[[DELIN]]#0, %[[DELIN]]#1, %[[ARG2]]
+func.func @cancel_delinearize_linearize_disjoint_partial(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
+  %1:3 = affine.delinearize_index %0 into (8, %arg4)
+      : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
 // Without `disjoint`, the cancelation isn't guaranteed to be the identity.
 // CHECK-LABEL: func @no_cancel_delinearize_linearize_exact(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,

``````````

</details>


https://github.com/llvm/llvm-project/pull/116872


More information about the Mlir-commits mailing list