[Mlir-commits] [mlir] [mlir][linalg] Add pattern to bubble-up pack through expand shape op (PR #93529)

Prashant Kumar llvmlistbot at llvm.org
Wed May 29 09:18:27 PDT 2024


================
@@ -694,6 +696,105 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
   return success();
 }
 
+/// Project dimsPos to their collapsed positions in the reassocIndices.
+///
+/// For example, given dimsPos [0, 1, 2, 4] and matching reassocIndices
+/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]: pos 0 belongs to
+/// group [0] at index 0, pos 1 and 2 belong to group [1, 2] at index 1,
+/// and pos 4 belongs to group [4] at index 3.
+static SmallVector<int64_t>
+projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
+                             ArrayRef<ReassociationIndices> reassocIndices) {
+  SmallVector<int64_t> projectedPos;
+
+  // Map each dimension to the position of corresponding reassociation index.
+  for (auto pos : dimsPos) {
+    for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
+      // If the dimension is present in the current indices group, the group
+      // position within the reassociation map is the desired projected
+      // dimension position.
+      if (llvm::any_of(indices,
+                       [&](int64_t expandDim) { return expandDim == pos; })) {
+        projectedPos.push_back(idx);
+        break;
+      }
+    }
+  }
+  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
+
+  return projectedPos;
+}
+
+/// Bubble up pack op through expand shape op.
+static LogicalResult
+bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
+                                 tensor::PackOp packOp,
+                                 PatternRewriter &rewriter) {
+  // Cannot propagate shape expansion if there is an outer dimensions permutation.
+  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
----------------
pashu123 wrote:

Let's say we have `tensor.expand_shape <6x8x4> -> <3x2x8x4>` followed by `tensor.pack <3x2x8x4> -> <2x3x4x2x2x2> outer_dims_perm = [1, 0, 2, 3]`. We can still bubble up with `tensor.pack <6x8x4> -> <6x4x2x2x2> outer_dims_perm = [0, 1, 2]` and `tensor.expand_shape <6x4x2x2x2> -> <2x3x4x2x2x2>`.
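
The shape arithmetic in this example can be checked with a minimal standalone sketch. It assumes the pack tiles the last two dimensions with 2x2 inner tiles and needs no padding; the comment does not state `inner_dims_pos` or `inner_tiles`, so these are inferred from the shapes. `Shape`, `projectDims` and `packedShape` are hypothetical helpers built on plain `std::vector`, not MLIR APIs; `projectDims` mirrors the `projectDimsPosIntoReassocPos` helper from the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical stand-in for projectDimsPosIntoReassocPos: maps each expanded
// dimension to the index of the reassociation group that contains it.
Shape projectDims(const Shape &dimsPos, const std::vector<Shape> &reassoc) {
  Shape projected;
  for (int64_t pos : dimsPos) {
    for (std::size_t idx = 0; idx < reassoc.size(); ++idx) {
      const Shape &group = reassoc[idx];
      if (std::find(group.begin(), group.end(), pos) != group.end()) {
        projected.push_back(static_cast<int64_t>(idx));
        break;
      }
    }
  }
  assert(projected.size() == dimsPos.size() && "Invalid dim pos projection");
  return projected;
}

// Hypothetical helper: result shape of a static pack without padding.
// An empty outerDimsPerm is treated as the identity permutation.
Shape packedShape(const Shape &src, const Shape &innerDimsPos,
                  const Shape &innerTiles, const Shape &outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size());
  // Divide every tiled dimension by its tile size to get the outer dims.
  Shape outer = src;
  for (std::size_t i = 0; i < innerDimsPos.size(); ++i) {
    assert(outer[innerDimsPos[i]] % innerTiles[i] == 0 && "padding not modeled");
    outer[innerDimsPos[i]] /= innerTiles[i];
  }
  // Permute the outer dims: result[i] = outer[perm[i]].
  Shape result;
  if (outerDimsPerm.empty()) {
    result = outer;
  } else {
    for (int64_t p : outerDimsPerm)
      result.push_back(outer[p]);
  }
  // The inner tile sizes are appended as trailing dimensions.
  result.insert(result.end(), innerTiles.begin(), innerTiles.end());
  return result;
}
```

Under these assumptions, `packedShape({3, 2, 8, 4}, {2, 3}, {2, 2}, {1, 0, 2, 3})` gives the `<2x3x4x2x2x2>` type from the comment, `projectDims({2, 3}, {{0, 1}, {2}, {3}})` moves the tiled dimensions to `[1, 2]` on the unexpanded source, and `packedShape({6, 8, 4}, {1, 2}, {2, 2}, {0, 1, 2})` gives `<6x4x2x2x2>`. The sketch checks shapes only; it does not model how the data moves under the permutation.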

https://github.com/llvm/llvm-project/pull/93529

