[Mlir-commits] [mlir] [mlir][linalg] Add pattern to bubble-up pack through expand shape op (PR #93529)
Prashant Kumar
llvmlistbot at llvm.org
Thu Jun 13 02:24:00 PDT 2024
================
@@ -694,6 +696,107 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
return success();
}
+/// Projects each dimension position in `dimsPos` onto the index of the
+/// reassociation group in `reassocIndices` that contains it, i.e. onto its
+/// collapsed position.
+///
+/// For example, given dimsPos [0, 1, 2, 4] and reassocIndices
+/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]: dim 0 lies in group 0,
+/// dims 1 and 2 both lie in group 1, and dim 4 lies in group 3.
+///
+/// Every entry of `dimsPos` is expected to appear in some group of
+/// `reassocIndices`; this is checked by an assertion.
+static SmallVector<int64_t>
+projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
+                             ArrayRef<ReassociationIndices> reassocIndices) {
+  SmallVector<int64_t> projectedPos;
+  // Exactly one entry is produced per input position.
+  projectedPos.reserve(dimsPos.size());
+
+  // Map each dimension to the position of the reassociation group holding it.
+  for (int64_t pos : dimsPos) {
+    for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
+      // The first group containing the dimension gives its projected
+      // (collapsed) position.
+      if (llvm::is_contained(indices, pos)) {
+        projectedPos.push_back(idx);
+        break;
+      }
+    }
+  }
+  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
+
+  return projectedPos;
+}
+
+/// Bubble up pack op through expand shape op.
+static LogicalResult
+bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
+ tensor::PackOp packOp,
+ PatternRewriter &rewriter) {
+ // Outer dimensions permutation is not supported currently.
+ // TODO: Handle outer_dims_perm variants.
+ ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+ if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
+ return rewriter.notifyMatchFailure(packOp,
+ "non-identity outer dims perm NYI");
+ }
+
+ // Validate dimensions' relations between shape expansion and packing.
+ SmallVector<ReassociationIndices, 4> reassoc =
+ expandOp.getReassociationIndices();
+ ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
+ llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
+ packInnerDims.end());
+
+ for (auto [idx, indices] : llvm::enumerate(reassoc)) {
+ llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
+ llvm::SetVector<int64_t> packedDims =
+ llvm::set_intersection(packDimsPos, expandDimPos);
+
+ // The expanded dimension is not packed - simply continue.
+ if (packedDims.empty())
+ continue;
+ // Shape expansion cannot be propagated when multiple expanded dimension are
+ // packed.
+ if (packedDims.size() > 1)
+ return rewriter.notifyMatchFailure(
+ packOp, "only one of the expanded dimensions can be packed");
+ // Only the inner-most dim should be packed. Otherwise, elements order will
+ // be affected after operation reordering.
+ if (packedDims[0] != indices.back())
----------------
pashu123 wrote:
```suggestion
if (packedDims.front() != indices.back())
```
https://github.com/llvm/llvm-project/pull/93529
More information about the Mlir-commits
mailing list