[Mlir-commits] [mlir] ece4e12 - [mlir][Affine] Split off delinearize parts that depend on last component (#117015)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 25 14:26:22 PST 2024


Author: Krzysztof Drewniak
Date: 2024-11-25T16:26:19-06:00
New Revision: ece4e1276e2140d84b05b8c430a0e547a1f23210

URL: https://github.com/llvm/llvm-project/commit/ece4e1276e2140d84b05b8c430a0e547a1f23210
DIFF: https://github.com/llvm/llvm-project/commit/ece4e1276e2140d84b05b8c430a0e547a1f23210.diff

LOG: [mlir][Affine] Split off delinearize parts that depend on last component (#117015)

If we have

    %0 = affine.linearize_index disjoint [%a, %b] by (A, B)
    %1:3 = affine.delinearize_index %0 into (A, B1, B2)

where B = B1 * B2 (or some more complex product), we can simplify this to

    %0 = affine.linearize_index disjoint [%a] by (A)
    %1a:1 = affine.delinearize_index %0 into (A)
    %1b:2 = affine.delinearize_index %b into (B1, B2)

Handling this case, and more complex ones, keeps us from adding terms together
only to divide them apart again.

---------

Co-authored-by: Abhishek Varma <abhvarma at amd.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/test/Dialect/Affine/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 67d7da622a3550..1c5466730a5589 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4729,12 +4729,98 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail
     return success();
   }
 };
+
+/// If the input to a delinearization is a disjoint linearization, and the
+/// last k > 1 components of the delinearization basis multiply to the
+/// last component of the linearization basis, break the linearization and
+/// delinearization into two parts, peeling off the last input to linearization.
+///
+/// For example:
+///    %0 = affine.linearize_index disjoint [%z, %y, %x] by (3, 2, 32) : index
+///    %1:4 = affine.delinearize_index %0 into (2, 3, 8, 4) : index, ...
+/// becomes
+///    %0 = affine.linearize_index disjoint [%z, %y] by (3, 2) : index
+///    %1:2 = affine.delinearize_index %0 into (2, 3) : index, index
+///    %2:2 = affine.delinearize_index %x into (8, 4) : index, index
+/// where the original %1:4 is replaced by %1:2 ++ %2:2.
+///
+/// The trailing delinearization basis elements inspected by this pattern must
+/// all be static. Their running product, taken from the last basis element
+/// backwards, must land exactly on the last linearization basis element
+/// without passing it. A split of a single element (k == 1) is not performed
+/// here.
+struct SplitDelinearizeSpanningLastLinearizeArg final
+    : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
+                                PatternRewriter &rewriter) const override {
+    auto linearizeOp = delinearizeOp.getLinearIndex()
+                           .getDefiningOp<affine::AffineLinearizeIndexOp>();
+    if (!linearizeOp)
+      return rewriter.notifyMatchFailure(delinearizeOp,
+                                         "index doesn't come from linearize");
+
+    // Disjointness promises that each linearize input stays within its basis
+    // element, which is what allows the last input to be delinearized
+    // independently of the others.
+    if (!linearizeOp.getDisjoint())
+      return rewriter.notifyMatchFailure(linearizeOp,
+                                         "linearize isn't disjoint");
+
+    // ArrayRef::back() requires a non-empty array, so reject a linearization
+    // without static basis elements before reading its last one.
+    ArrayRef<int64_t> linearizeBasis = linearizeOp.getStaticBasis();
+    if (linearizeBasis.empty())
+      return rewriter.notifyMatchFailure(linearizeOp,
+                                         "linearize has no basis elements");
+
+    // Extent of the last linearize input: the trailing delinearize basis
+    // elements must multiply to exactly this value.
+    int64_t target = linearizeBasis.back();
+    if (ShapedType::isDynamic(target))
+      return rewriter.notifyMatchFailure(
+          linearizeOp, "linearize ends with dynamic basis value");
+
+    // Accumulate the delinearize basis from its last element backwards until
+    // the product reaches `target`.
+    int64_t sizeToSplit = 1;
+    size_t elemsToSplit = 0;
+    ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
+    for (int64_t basisElem : llvm::reverse(basis)) {
+      if (ShapedType::isDynamic(basisElem))
+        return rewriter.notifyMatchFailure(
+            delinearizeOp, "dynamic basis element while scanning for split");
+      sizeToSplit *= basisElem;
+      elemsToSplit += 1;
+
+      if (sizeToSplit > target)
+        return rewriter.notifyMatchFailure(delinearizeOp,
+                                           "overshot last argument size");
+      if (sizeToSplit == target)
+        break;
+    }
+
+    // The whole static basis was consumed without reaching `target`.
+    if (sizeToSplit < target)
+      return rewriter.notifyMatchFailure(
+          delinearizeOp, "product of known basis elements doesn't reach last "
+                         "linearize argument");
+
+    if (elemsToSplit < 2)
+      return rewriter.notifyMatchFailure(
+          delinearizeOp,
+          "need at least two elements to form the basis product");
+
+    // Linearize everything except the last input. The dropped basis element
+    // is static, so the dynamic basis operands carry over unchanged.
+    Value linearizeWithoutBack =
+        rewriter.create<affine::AffineLinearizeIndexOp>(
+            linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
+            linearizeOp.getDynamicBasis(), linearizeBasis.drop_back(),
+            linearizeOp.getDisjoint());
+    // Delinearize that prefix by the leading basis elements that were not
+    // split off, keeping the original outer-bound setting. All dynamic basis
+    // operands belong to this part because the split-off tail is static.
+    auto delinearizeWithoutSplitPart =
+        rewriter.create<affine::AffineDelinearizeIndexOp>(
+            delinearizeOp.getLoc(), linearizeWithoutBack,
+            delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
+            delinearizeOp.hasOuterBound());
+    // Delinearize the peeled-off input directly. Its basis multiplies to
+    // `target`, so the first split element serves as the outer bound.
+    auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
+        basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
+    // The prefix results come first, followed by those of the peeled-off
+    // input, matching the original result order.
+    SmallVector<Value> results = llvm::to_vector(
+        llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
+                            delinearizeBack.getResults()));
+    rewriter.replaceOp(delinearizeOp, results);
+
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.insert<CancelDelinearizeOfLinearizeDisjointExactTail,
-                  DropUnitExtentBasis>(context);
+  patterns // Registers all affine.delinearize_index canonicalizations.
+      .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
+              DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 5384977151b47f..d3f61f7e503f9b 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1795,6 +1795,72 @@ func.func @no_cancel_delinearize_linearize_different_basis(%arg0: index, %arg1:
 
 // -----
 
+// The last linearize basis element (64) equals the product of the trailing
+// delinearize basis elements (8 * 8), so %arg2 is delinearized on its own.
+// CHECK-LABEL: func @split_delinearize_spanning_final_part
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 4)
+//       CHECK:     %[[DELIN1:.+]]:2 = affine.delinearize_index %[[LIN]] into (2)
+//       CHECK:     %[[DELIN2:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8)
+//       CHECK:     return %[[DELIN1]]#0, %[[DELIN1]]#1, %[[DELIN2]]#0, %[[DELIN2]]#1
+func.func @split_delinearize_spanning_final_part(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+  %1:4 = affine.delinearize_index %0 into (2, 8, 8)
+      : index, index, index, index
+  return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// After %arg2 is split off, the remaining linearize/delinearize pair has the
+// same basis (2, 4) and cancels, so %arg0 and %arg1 are returned directly.
+// CHECK-LABEL: func @split_delinearize_spanning_final_part_and_cancel
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG2]] into (8, 8)
+//       CHECK:     return %[[ARG0]], %[[ARG1]], %[[DELIN]]#0, %[[DELIN]]#1
+func.func @split_delinearize_spanning_final_part_and_cancel(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+  %1:4 = affine.delinearize_index %0 into (2, 4, 8, 8)
+      : index, index, index, index
+  return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// The running product of the trailing delinearize basis elements skips past
+// the last linearize basis element without equaling it (8 < 64, but
+// 16 * 8 = 128 > 64), so don't simplify.
+// CHECK-LABEL: func @dont_split_delinearize_overshooting_target
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]], %[[ARG2]]] by (2, 4, 64)
+//       CHECK:     %[[DELIN:.+]]:4 = affine.delinearize_index %[[LIN]] into (2, 16, 8)
+//       CHECK:     return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2, %[[DELIN]]#3
+func.func @dont_split_delinearize_overshooting_target(%arg0: index, %arg1: index, %arg2: index) -> (index, index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 4, 64) : index
+  %1:4 = affine.delinearize_index %0 into (2, 16, 8)
+      : index, index, index, index
+  return %1#0, %1#1, %1#2, %1#3 : index, index, index, index
+}
+
+// -----
+
+// The delinearize basis doesn't fully multiply to the final basis element
+// (8 * 4 = 32 < 64), so don't simplify.
+// CHECK-LABEL: func @dont_split_delinearize_undershooting_target
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (2, 64)
+//       CHECK:     %[[DELIN:.+]]:3 = affine.delinearize_index %[[LIN]] into (4, 8)
+//       CHECK:     return %[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2
+func.func @dont_split_delinearize_undershooting_target(%arg0: index, %arg1: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1] by (2, 64) : index
+  %1:3 = affine.delinearize_index %0 into (4, 8)
+      : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
 // CHECK-LABEL: @linearize_unit_basis_disjoint
 // CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
 // CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index


        


More information about the Mlir-commits mailing list