[Mlir-commits] [mlir] [mlir][tensor] Loosen restrictions on folding dynamic reshapes (PR #137963)
Artem Gindinson
llvmlistbot at llvm.org
Tue May 20 15:46:11 PDT 2025
https://github.com/AGindinson updated https://github.com/llvm/llvm-project/pull/137963
>From 2ae44808cd573c31331d143d687afcdab6986a72 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Mon, 28 Apr 2025 17:24:11 +0000
Subject: [PATCH 1/8] [mlir][tensor] Loosen restrictions on folding dynamic
reshapes
The main idea behind the change is to allow expand-of-collapse folds
for reshapes like `?x?xk` -> `?` (k>1). The rationale here is that the
expand op must have a coherent index/affine expression specified in its
`output_shape` argument (see example below), and if it doesn't, the IR
has already been invalidated at an earlier stage:
```
%c32 = arith.constant 32 : index
%div = arith.divsi %<some_index>, %c32 : index
%collapsed = tensor.collapse_shape %41#1 [[0], [1, 2], [3, 4]]
: tensor<9x?x32x?x32xf32> into tensor<9x?x?xf32>
%affine = affine.apply affine_map<()[s0] -> (s0 * 32)> ()[%div]
%expanded = tensor.expand_shape %collapsed [[0], [1, 2], [3]] output_shape [9, %div, 32, %affine]
: tensor<9x?x?xf32> into tensor<9x?x32x?xf32>
```
Based on the above assumption, adjust the routine in
`getReassociationIndicesForCollapse()` to allow dynamic reshapes
beyond just `?x..?x1x1x..x1` -> `?`.
Moreover, the reassociation util was refactored to clearly distinguish
between dynamic and static subshapes. A few known caveats were noted in
a comment; it doesn't seem possible to fold all qualifying dynamic-shape
patterns in a deterministic way without looking into affine expressions
simultaneously. That would be difficult to maintain in a single general
utility. Other implementation ideas/larger refactoring could include:
- abandoning the util usage in the `ComposeExpandOfCollapseOp` pattern,
employing similar logic to `ComposeCollapseOfExpandOp`;
- providing dialect-specific implementations for Linalg/Tensor.
Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 103 ++++++++++--------
.../Dialect/Linalg/simplify-pack-unpack.mlir | 4 +-
mlir/test/Dialect/Tensor/canonicalize.mlir | 24 +++-
3 files changed, 79 insertions(+), 52 deletions(-)
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index ed40a080441bc..694783849198a 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -31,59 +31,70 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> targetShape) {
- if (sourceShape.size() <= targetShape.size())
+ unsigned numSourceDims = sourceShape.size(),
+ numTargetDims = targetShape.size();
+ if (numSourceDims <= numTargetDims)
return std::nullopt;
- unsigned sourceDim = 0;
- SmallVector<ReassociationIndices> reassociationMap;
- reassociationMap.reserve(targetShape.size());
-
- ReassociationIndices currIndices;
- int64_t prodOfCollapsedDims = 1;
- while (sourceDim < sourceShape.size()) {
- unsigned targetDim = reassociationMap.size();
- // If we have mapped all the target dimensions stop and handle the remaining
- // tail of size-1 dimensions explicitly.
- if (targetDim == targetShape.size())
- break;
+ SmallVector<ReassociationIndices, 4> reassociationMap;
+ reassociationMap.reserve(numTargetDims);
+ unsigned sourceDim = 0, targetDim = 0;
+ for (; targetDim < numTargetDims; ++targetDim) {
int64_t currTargetShape = targetShape[targetDim];
- while (sourceDim < (sourceShape.size() - 1) &&
- sourceShape[sourceDim] != ShapedType::kDynamic &&
- prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+ ReassociationIndices currIndices;
+ // 1. Target dimension is dynamic. Source shape should contain at least
+ // one dynamic dimension.
+ if (currTargetShape == ShapedType::kDynamic) {
+ // FIXME: We stop the search with the first dynamic dimension, while in
+ // fact, we can have a valid pattern like 2x?x?x4x8 -> ?x4x8. It becomes
+ // indeterministic altogether when we have neighboring dynamic dimensions
+ // in the target shape. Most of these patterns will be safely rejected,
+ // however we might achieve more correct folds by taking affine
+ // expressions into account, if these can be passed on by the call sites.
+ bool foundDynamic = false;
+ while (sourceDim < numSourceDims) {
+ currIndices.push_back(sourceDim);
+ if (sourceShape[sourceDim++] == ShapedType::kDynamic) {
+ foundDynamic = true;
+ break;
+ }
+ }
+ if (!foundDynamic)
+ return std::nullopt;
+
+ reassociationMap.push_back(currIndices);
+ continue;
+ }
+ // 2. Target dimension is static. The product of dimensions of the expanded
+ // shape should match the collapsed dimension shape.
+ int64_t prodOfCollapsedDims = 1;
+ bool reachedTargetDimSize = false;
+ while (sourceDim < numSourceDims) {
+ // Source shape cannot be dynamic if the target dim is static.
+ if (sourceShape[sourceDim] == ShapedType::kDynamic)
+ return std::nullopt;
prodOfCollapsedDims *= sourceShape[sourceDim];
- currIndices.push_back(sourceDim++);
+ if (prodOfCollapsedDims > currTargetShape)
+ break;
+ else if (prodOfCollapsedDims == currTargetShape) {
+ currIndices.push_back(sourceDim++);
+ reachedTargetDimSize = true;
+ break;
+ } else // prodOfCollapsedDims < currTargetShape
+ currIndices.push_back(sourceDim++);
}
-
- // If the current expanded dimension is dynamic, then the collapsed
- // dimensions should also be dynamic and product of all previous unprocessed
- // dimensions of the expanded shape should be 1.
- if (sourceShape[sourceDim] == ShapedType::kDynamic &&
- (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
+ if (!reachedTargetDimSize)
return std::nullopt;
-
- // If the collapsed dim is dynamic, the current expanded dim should also
- // be dynamic.
- if (currTargetShape == ShapedType::kDynamic &&
- sourceShape[sourceDim] != ShapedType::kDynamic)
- return std::nullopt;
-
- // For static shapes, if the product of dimensions of the expanded shape
- // should match the collapsed dimension shape.
- if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
- return std::nullopt;
-
- currIndices.push_back(sourceDim++);
- reassociationMap.emplace_back(ReassociationIndices{});
- std::swap(reassociationMap.back(), currIndices);
- prodOfCollapsedDims = 1;
+ reassociationMap.push_back(currIndices);
}
- // All the dimensions in the target must have been processed.
- if (reassociationMap.size() != targetShape.size())
- return std::nullopt;
- // Process any remaining entries in the source shape. They all need to be
- // 1 or dynamic.
- for (; sourceDim < sourceShape.size(); sourceDim++) {
- if (sourceShape[sourceDim] != ShapedType::kDynamic &&
+ // Now that we've mapped all the target dimensions, process any remaining
+ // entries in the source shape explicitly. Either the last target dimension
+ // is dynamic, or all remaining source entries need to be 1 or dynamic. Same
+ // applies when target shape is empty (can be the case for subshape
+ // reassociations).
+ for (; sourceDim < numSourceDims; sourceDim++) {
+ if ((targetShape.empty() || targetShape.back() != ShapedType::kDynamic) &&
+ sourceShape[sourceDim] != ShapedType::kDynamic &&
sourceShape[sourceDim] != 1)
return std::nullopt;
// The map is empty when the target type is a scalar.
diff --git a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
index 51350e5bc8498..6979770154bab 100644
--- a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
@@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
// -----
// CHECK-LABEL: func.func @unpack_dynamic
-// CHECK-NOT: tensor.collapse
-// CHECK: linalg.unpack
+// CHECK: tensor.collapse
+// CHECK-NOT: linalg.unpack
func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 85bf6fba52aa4..443f931745557 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1068,7 +1068,7 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf3
// -----
-func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
-> tensor<?x4x?xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
: tensor<?x4x?xf32> into tensor<?x?xf32>
@@ -1076,12 +1076,28 @@ func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: ind
: tensor<?x?xf32> into tensor<?x4x?xf32>
return %1 : tensor<?x4x?xf32>
}
-// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape
// CHECK-NOT: tensor.{{.*}}_shape
// -----
-func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index)
+ -> tensor<?x4x?xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
+ : tensor<?x4x?x2xf32> into tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+ : tensor<?x?xf32> into tensor<?x4x?xf32>
+ return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape
+// CHECK-NOT: tensor.expand_shape
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]]
+// CHECK-SAME: : tensor<?x4x?x2xf32> into tensor<?x4x?xf32>
+// CHECK-NEXT: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
-> tensor<?x?x?xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
: tensor<?x?x?xf32> into tensor<?x?xf32>
@@ -1089,7 +1105,7 @@ func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1:
: tensor<?x?xf32> into tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
-// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic
// CHECK: tensor.collapse_shape
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
// CHECK: return %[[EXPAND]]
>From 52ff4e0a1e81e13282975076d730e741a1da1cae Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Fri, 9 May 2025 15:12:21 +0000
Subject: [PATCH 2/8] [fixup] Algorithm rewrite
Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 140 ++++++++++++++-------
1 file changed, 93 insertions(+), 47 deletions(-)
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 694783849198a..1cd06a2757363 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -39,67 +39,113 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
reassociationMap.reserve(numTargetDims);
unsigned sourceDim = 0, targetDim = 0;
- for (; targetDim < numTargetDims; ++targetDim) {
- int64_t currTargetShape = targetShape[targetDim];
- ReassociationIndices currIndices;
- // 1. Target dimension is dynamic. Source shape should contain at least
- // one dynamic dimension.
- if (currTargetShape == ShapedType::kDynamic) {
- // FIXME: We stop the search with the first dynamic dimension, while in
- // fact, we can have a valid pattern like 2x?x?x4x8 -> ?x4x8. It becomes
- // indeterministic altogether when we have neighboring dynamic dimensions
- // in the target shape. Most of these patterns will be safely rejected,
- // however we might achieve more correct folds by taking affine
- // expressions into account, if these can be passed on by the call sites.
- bool foundDynamic = false;
- while (sourceDim < numSourceDims) {
- currIndices.push_back(sourceDim);
- if (sourceShape[sourceDim++] == ShapedType::kDynamic) {
- foundDynamic = true;
- break;
- }
- }
- if (!foundDynamic)
- return std::nullopt;
-
- reassociationMap.push_back(currIndices);
- continue;
- }
- // 2. Target dimension is static. The product of dimensions of the expanded
- // shape should match the collapsed dimension shape.
+ // Source dimensions iteration logic for static target dimensions.
+ // FIXME: Instead of lambda-capturing this function's source shape index "in
+ // place", consider refactoring this into a separate function.
+ auto collectSourceIndicesForStaticTargetDim =
+ [&](int64_t targetShape,
+ bool mayHaveOffset = false) -> FailureOr<ReassociationIndices> {
+ ReassociationIndices resultIndices;
int64_t prodOfCollapsedDims = 1;
bool reachedTargetDimSize = false;
- while (sourceDim < numSourceDims) {
+ for (; sourceDim < numSourceDims; ++sourceDim) {
// Source shape cannot be dynamic if the target dim is static.
if (sourceShape[sourceDim] == ShapedType::kDynamic)
- return std::nullopt;
+ return failure();
prodOfCollapsedDims *= sourceShape[sourceDim];
- if (prodOfCollapsedDims > currTargetShape)
- break;
- else if (prodOfCollapsedDims == currTargetShape) {
- currIndices.push_back(sourceDim++);
+ resultIndices.push_back(sourceDim);
+ if (prodOfCollapsedDims > targetShape && !mayHaveOffset)
+ return failure();
+ while (prodOfCollapsedDims > targetShape) {
+ assert(!resultIndices.empty());
+ auto frontOffsetIdx = resultIndices.begin();
+ prodOfCollapsedDims /= sourceShape[*frontOffsetIdx];
+ resultIndices.erase(frontOffsetIdx);
+ }
+ if (prodOfCollapsedDims == targetShape) {
reachedTargetDimSize = true;
+ ++sourceDim;
break;
- } else // prodOfCollapsedDims < currTargetShape
- currIndices.push_back(sourceDim++);
+ }
}
if (!reachedTargetDimSize)
+ return failure();
+ return resultIndices;
+ };
+ // Source dimensions iteration logic for dynamic target dimensions.
+ // FIXME: Instead of lambda-capturing this function's source shape index "in
+ // place", consider refactoring this into a separate function.
+ auto collectSourceIndicesForDynamicTargetDim =
+ [&](bool allowStaticNonOnes,
+ bool mapConsecutiveDynDims) -> FailureOr<ReassociationIndices> {
+ ReassociationIndices resultIndices;
+ bool foundFirstDynamic = false;
+ while (sourceDim < numSourceDims) {
+ if (sourceShape[sourceDim] == ShapedType::kDynamic) {
+ if (foundFirstDynamic && !mapConsecutiveDynDims)
+ break;
+ foundFirstDynamic |= true;
+ } else {
+ if (foundFirstDynamic)
+ break;
+ else if (sourceShape[sourceDim] > 1 && !allowStaticNonOnes)
+ return failure();
+ }
+ resultIndices.push_back(sourceDim++);
+ }
+ if (!foundFirstDynamic)
+ return failure();
+ return resultIndices;
+ };
+ // Iterate over target shape.
+ bool wasLastDimDynamic = false;
+ for (; targetDim < numTargetDims; ++targetDim) {
+ int64_t currTargetShape = targetShape[targetDim];
+ if (currTargetShape != ShapedType::kDynamic) {
+ unsigned sourceDimAtStart = sourceDim;
+ auto indices = collectSourceIndicesForStaticTargetDim(
+ currTargetShape, /*mayHaveOffset=*/wasLastDimDynamic);
+ if (failed(indices))
+ return std::nullopt;
+ if (wasLastDimDynamic) {
+ assert(!reassociationMap.empty());
+ auto &previousIndices = reassociationMap.back();
+ for (; sourceDimAtStart < indices->front(); ++sourceDimAtStart)
+ previousIndices.push_back(sourceDimAtStart);
+ }
+ reassociationMap.push_back(*indices);
+ wasLastDimDynamic = false;
+ continue;
+ }
+
+ bool isNextDimDynamic = targetDim + 1 < numTargetDims &&
+ targetShape[targetDim + 1] == ShapedType::kDynamic;
+ auto indices = collectSourceIndicesForDynamicTargetDim(
+ /*allowStaticNonOnes=*/!wasLastDimDynamic,
+ /*mapConsecutiveDynDims=*/!wasLastDimDynamic && !isNextDimDynamic);
+ if (failed(indices))
return std::nullopt;
- reassociationMap.push_back(currIndices);
+ reassociationMap.push_back(*indices);
+ wasLastDimDynamic = true;
}
// Now that we've mapped all the target dimensions, process any remaining
- // entries in the source shape explicitly. Either the last target dimension
- // is dynamic, or all remaining source entries need to be 1 or dynamic. Same
- // applies when target shape is empty (can be the case for subshape
- // reassociations).
+ // entries in the source shape explicitly.
for (; sourceDim < numSourceDims; sourceDim++) {
- if ((targetShape.empty() || targetShape.back() != ShapedType::kDynamic) &&
- sourceShape[sourceDim] != ShapedType::kDynamic &&
- sourceShape[sourceDim] != 1)
+ const bool isOne = sourceShape[sourceDim] == 1,
+ isDynamic = sourceShape[sourceDim] == ShapedType::kDynamic;
+ if (targetShape.empty()) {
+ if (!isOne && !isDynamic)
+ return std::nullopt;
+ continue;
+ }
+ if (wasLastDimDynamic && isDynamic)
+ return std::nullopt;
+ // If the last target dimension is static, only source dimensions of 1 are
+ // acceptable.
+ if (!wasLastDimDynamic && !isOne)
return std::nullopt;
- // The map is empty when the target type is a scalar.
- if (!reassociationMap.empty())
- reassociationMap.back().push_back(sourceDim);
+ assert(!reassociationMap.empty());
+ reassociationMap.back().push_back(sourceDim);
}
return reassociationMap;
}
>From 1c85a6875fecf52e6714b81bbe9bd2da81178c9e Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Fri, 9 May 2025 15:12:34 +0000
Subject: [PATCH 3/8] [fixup] Add/expand unit tests
Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
Co-authored-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
mlir/test/Dialect/Tensor/canonicalize.mlir | 15 +++
mlir/unittests/Dialect/Utils/CMakeLists.txt | 1 +
.../Dialect/Utils/ReshapeOpsUtilsTest.cpp | 125 ++++++++++++++++++
3 files changed, 141 insertions(+)
create mode 100644 mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 443f931745557..035ea850c9102 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1112,6 +1112,21 @@ func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %
// -----
+func.func @no_fold_expand_of_collapse_adjacent_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index)
+ -> tensor<?x?xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
+ : tensor<?x?x?xf32> into tensor<?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1]] output_shape [%arg1, %arg2]
+ : tensor<?xf32> into tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @no_fold_expand_of_collapse_adjacent_dynamic
+// CHECK: tensor.collapse_shape
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
+// CHECK: return %[[EXPAND]]
+
+// -----
+
func.func @compose_expand_of_collapse_last_two_dims(%arg0: tensor<?x64x1xf32>) -> tensor<?x384xf32> {
%collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
%c0 = arith.constant 0 : index
diff --git a/mlir/unittests/Dialect/Utils/CMakeLists.txt b/mlir/unittests/Dialect/Utils/CMakeLists.txt
index 61b9cdcb3b8f3..e921c8bcfb4e5 100644
--- a/mlir/unittests/Dialect/Utils/CMakeLists.txt
+++ b/mlir/unittests/Dialect/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_unittest(MLIRDialectUtilsTests
StructuredOpsUtilsTest.cpp
+ ReshapeOpsUtilsTest.cpp
IndexingUtilsTest.cpp
)
mlir_target_link_libraries(MLIRDialectUtilsTests
diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
new file mode 100644
index 0000000000000..bfcc70150e2ed
--- /dev/null
+++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
@@ -0,0 +1,125 @@
+//===- ReshapeOpsUtilsTest.cpp - ReshapeOpsUtils unit tests ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "gtest/gtest.h"
+#include <optional>
+
+using namespace mlir;
+
+/// Helper to make constructing
+/// `std::optional<SmallVector<ReassociationIndices>>` more readable.
+static std::optional<SmallVector<ReassociationIndices>>
+makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
+ return std::optional<SmallVector<ReassociationIndices>>(list);
+}
+
+TEST(ReassociationIndicesForCollapse, StaticTest) {
+ EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {200}),
+ makeOptionalIndices({{0, 1}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {10, 600}),
+ makeOptionalIndices({{0}, {1, 2}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 30}),
+ makeOptionalIndices({{0, 1}, {2}}));
+}
+
+TEST(ReassociationIndicesForCollapse, StaticTestFailure) {
+ EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10}), std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10, 20}),
+ std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 300}),
+ std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {1, 10, 20, 30}),
+ std::nullopt);
+}
+
+TEST(ReassociationIndicesForCollapse, StaticTestUnitDims) {
+ EXPECT_EQ(getReassociationIndicesForCollapse({10, 1}, {10}),
+ makeOptionalIndices({{0, 1}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({1, 20, 30}, {600}),
+ makeOptionalIndices({{0, 1, 2}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1}, {1}),
+ makeOptionalIndices({{0, 1, 2}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1}, {1, 1}),
+ makeOptionalIndices({{0}, {1, 2}}));
+}
+
+TEST(ReassociationIndicesForCollapse, DynamicTest) {
+ EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1},
+ {ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1, 1},
+ {ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1, 2}}));
+ EXPECT_EQ(
+ getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, ShapedType::kDynamic}, {ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {1, ShapedType::kDynamic, ShapedType::kDynamic},
+ {1, ShapedType::kDynamic}),
+ makeOptionalIndices({{0}, {1, 2}}));
+
+ EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10},
+ {ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {1, ShapedType::kDynamic, ShapedType::kDynamic},
+ {ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1, 2}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic},
+ {ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
+ {ShapedType::kDynamic, 20}),
+ makeOptionalIndices({{0, 1}, {2}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic, 20},
+ {ShapedType::kDynamic, 20}),
+ makeOptionalIndices({{0, 1}, {2}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 3, 2, 5, 2}, {ShapedType::kDynamic, 20}),
+ makeOptionalIndices({{0, 1}, {2, 3, 4}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {10, ShapedType::kDynamic, 20, ShapedType::kDynamic, 1},
+ {ShapedType::kDynamic, 20, ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1}, {2}, {3, 4}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 1},
+ {ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1, 2}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {1, ShapedType::kDynamic, ShapedType::kDynamic},
+ {ShapedType::kDynamic, ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1}, {2}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 1, ShapedType::kDynamic},
+ {ShapedType::kDynamic, ShapedType::kDynamic}),
+ makeOptionalIndices({{0}, {1, 2}}));
+}
+
+TEST(ReassociationIndicesForCollapse, DynamicTestFailure) {
+ EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
+ {ShapedType::kDynamic, 10}),
+ std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 10, ShapedType::kDynamic},
+ {ShapedType::kDynamic, ShapedType::kDynamic}),
+ std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {20, ShapedType::kDynamic, 10, ShapedType::kDynamic},
+ {ShapedType::kDynamic, ShapedType::kDynamic}),
+ std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 5, 3, 2, 2}, {ShapedType::kDynamic, 20}),
+ std::nullopt);
+ EXPECT_EQ(
+ getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic},
+ {ShapedType::kDynamic, ShapedType::kDynamic}),
+ std::nullopt);
+}
>From 0fe986e217e52cc519823a999e83155f1cd2d3ef Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Fri, 9 May 2025 15:15:08 +0000
Subject: [PATCH 4/8] [fixup] variable renaming
Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 39 +++++++++++-----------
1 file changed, 20 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 1cd06a2757363..a6ee21d941e17 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -38,7 +38,7 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
SmallVector<ReassociationIndices, 4> reassociationMap;
reassociationMap.reserve(numTargetDims);
- unsigned sourceDim = 0, targetDim = 0;
+ unsigned sourceDimIdx = 0, targetDimIdx = 0;
// Source dimensions iteration logic for static target dimensions.
// FIXME: Instead of lambda-capturing this function's source shape index "in
// place", consider refactoring this into a separate function.
@@ -48,12 +48,12 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
ReassociationIndices resultIndices;
int64_t prodOfCollapsedDims = 1;
bool reachedTargetDimSize = false;
- for (; sourceDim < numSourceDims; ++sourceDim) {
+ for (; sourceDimIdx < numSourceDims; ++sourceDimIdx) {
// Source shape cannot be dynamic if the target dim is static.
- if (sourceShape[sourceDim] == ShapedType::kDynamic)
+ if (sourceShape[sourceDimIdx] == ShapedType::kDynamic)
return failure();
- prodOfCollapsedDims *= sourceShape[sourceDim];
- resultIndices.push_back(sourceDim);
+ prodOfCollapsedDims *= sourceShape[sourceDimIdx];
+ resultIndices.push_back(sourceDimIdx);
if (prodOfCollapsedDims > targetShape && !mayHaveOffset)
return failure();
while (prodOfCollapsedDims > targetShape) {
@@ -64,7 +64,7 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
}
if (prodOfCollapsedDims == targetShape) {
reachedTargetDimSize = true;
- ++sourceDim;
+ ++sourceDimIdx;
break;
}
}
@@ -80,18 +80,18 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
bool mapConsecutiveDynDims) -> FailureOr<ReassociationIndices> {
ReassociationIndices resultIndices;
bool foundFirstDynamic = false;
- while (sourceDim < numSourceDims) {
- if (sourceShape[sourceDim] == ShapedType::kDynamic) {
+ while (sourceDimIdx < numSourceDims) {
+ if (sourceShape[sourceDimIdx] == ShapedType::kDynamic) {
if (foundFirstDynamic && !mapConsecutiveDynDims)
break;
foundFirstDynamic |= true;
} else {
if (foundFirstDynamic)
break;
- else if (sourceShape[sourceDim] > 1 && !allowStaticNonOnes)
+ else if (sourceShape[sourceDimIdx] > 1 && !allowStaticNonOnes)
return failure();
}
- resultIndices.push_back(sourceDim++);
+ resultIndices.push_back(sourceDimIdx++);
}
if (!foundFirstDynamic)
return failure();
@@ -99,10 +99,10 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
};
// Iterate over target shape.
bool wasLastDimDynamic = false;
- for (; targetDim < numTargetDims; ++targetDim) {
- int64_t currTargetShape = targetShape[targetDim];
+ for (; targetDimIdx < numTargetDims; ++targetDimIdx) {
+ int64_t currTargetShape = targetShape[targetDimIdx];
if (currTargetShape != ShapedType::kDynamic) {
- unsigned sourceDimAtStart = sourceDim;
+ unsigned sourceDimAtStart = sourceDimIdx;
auto indices = collectSourceIndicesForStaticTargetDim(
currTargetShape, /*mayHaveOffset=*/wasLastDimDynamic);
if (failed(indices))
@@ -118,8 +118,9 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
continue;
}
- bool isNextDimDynamic = targetDim + 1 < numTargetDims &&
- targetShape[targetDim + 1] == ShapedType::kDynamic;
+ bool isNextDimDynamic =
+ targetDimIdx + 1 < numTargetDims &&
+ targetShape[targetDimIdx + 1] == ShapedType::kDynamic;
auto indices = collectSourceIndicesForDynamicTargetDim(
/*allowStaticNonOnes=*/!wasLastDimDynamic,
/*mapConsecutiveDynDims=*/!wasLastDimDynamic && !isNextDimDynamic);
@@ -130,9 +131,9 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
}
// Now that we've mapped all the target dimensions, process any remaining
// entries in the source shape explicitly.
- for (; sourceDim < numSourceDims; sourceDim++) {
- const bool isOne = sourceShape[sourceDim] == 1,
- isDynamic = sourceShape[sourceDim] == ShapedType::kDynamic;
+ for (; sourceDimIdx < numSourceDims; sourceDimIdx++) {
+ const bool isOne = sourceShape[sourceDimIdx] == 1,
+ isDynamic = sourceShape[sourceDimIdx] == ShapedType::kDynamic;
if (targetShape.empty()) {
if (!isOne && !isDynamic)
return std::nullopt;
@@ -145,7 +146,7 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
if (!wasLastDimDynamic && !isOne)
return std::nullopt;
assert(!reassociationMap.empty());
- reassociationMap.back().push_back(sourceDim);
+ reassociationMap.back().push_back(sourceDimIdx);
}
return reassociationMap;
}
>From e3aa2394225bb23cde429cc06bea28df7257070a Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Fri, 9 May 2025 15:51:59 +0000
Subject: [PATCH 5/8] [fixup] Additional edge-case
Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 10 ++++++++--
mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp | 9 +++++++++
2 files changed, 17 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index a6ee21d941e17..8c19f20f446da 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -139,8 +139,14 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
return std::nullopt;
continue;
}
- if (wasLastDimDynamic && isDynamic)
- return std::nullopt;
+ // If the last 2 dimensions in the target were dynamic, the tail in the
+ // source shape cannot contain a dynamic value. E.g. ?x?->? is valid,
+ // however ?x?x10x?->?x? would be indeterminate.
+ if (wasLastDimDynamic && numTargetDims > 1 &&
+ targetShape[numTargetDims - 2] == ShapedType::kDynamic) {
+ if (isDynamic)
+ return std::nullopt;
+ }
// If the last target dimension is static, only source dimensions of 1 are
// acceptable.
if (!wasLastDimDynamic && !isOne)
diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
index bfcc70150e2ed..2564866fac493 100644
--- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
@@ -92,6 +92,10 @@ TEST(ReassociationIndicesForCollapse, DynamicTest) {
EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 1},
{ShapedType::kDynamic}),
makeOptionalIndices({{0, 1, 2}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, ShapedType::kDynamic, 1},
+ {ShapedType::kDynamic, ShapedType::kDynamic}),
+ makeOptionalIndices({{0}, {1, 2}}));
EXPECT_EQ(getReassociationIndicesForCollapse(
{1, ShapedType::kDynamic, ShapedType::kDynamic},
{ShapedType::kDynamic, ShapedType::kDynamic}),
@@ -122,4 +126,9 @@ TEST(ReassociationIndicesForCollapse, DynamicTestFailure) {
{ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic},
{ShapedType::kDynamic, ShapedType::kDynamic}),
std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, ShapedType::kDynamic, 10, 1,
+ ShapedType::kDynamic},
+ {ShapedType::kDynamic, ShapedType::kDynamic}),
+ std::nullopt);
}
>From 16a932c8fa45f00e6474dd18bd8b7781a4b2fac8 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Tue, 20 May 2025 20:00:24 +0000
Subject: [PATCH 6/8] [WIP] Current tests pass
---
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 430 +++++++++++++-----
.../Dialect/Utils/ReshapeOpsUtilsTest.cpp | 24 +
2 files changed, 337 insertions(+), 117 deletions(-)
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 8c19f20f446da..25dd434fc2122 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -10,6 +10,10 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
#include <numeric>
#include <optional>
@@ -28,6 +32,257 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
return std::nullopt;
}
+namespace {
+/// A simple struct to represent ReassociationIndices as an inclusive interval.
+/// It's deliberately kept minimal, so the call sites must manage the
+/// validity of the range manually.
+struct ReassociationIndexRange {
+ /// FIXME: Signed type is used for consistency with ReassociationIndices.
+ /// We should consider refactoring all reassociation utilities to use unsigned
+ /// types.
+ int64_t leftIdx = 0, rightIdx = 0;
+
+ /// Util for manual checks of the range's validity
+ LogicalResult verify() const {
+ return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
+ }
+
+ /// Checks range's containment within another range. Treats the edges
+ /// non-exclusively.
+ bool isInRange(const ReassociationIndexRange &outerRange) const {
+ return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
+ }
+
+ unsigned size() const {
+ assert(succeeded(verify()));
+ return rightIdx - leftIdx + 1;
+ }
+ bool containsSingleIndex() const { return size() == 1; }
+
+ void expandRight() { ++rightIdx; }
+ void shrinkLeft() { ++leftIdx; }
+
+ /// Implements arithmetic XOR semantics to get non-overlapping indices between
+ /// ranges.
+ ReassociationIndices operator^(ReassociationIndexRange &rhs) const {
+ ReassociationIndices result;
+    result.reserve(size() + rhs.size() / 2); // Heuristic reserve to amortize growth
+ for (int64_t idx = this->leftIdx; idx <= this->rightIdx; ++idx) {
+ if (idx < rhs.leftIdx || idx > rhs.rightIdx)
+ result.push_back(idx);
+ }
+ for (int64_t rhsIndex = rhs.leftIdx; rhsIndex <= rhs.rightIdx; ++rhsIndex) {
+ if (rhsIndex < leftIdx || rhsIndex > rightIdx)
+ result.push_back(rhsIndex);
+ }
+ return result;
+ }
+
+ /// Converts the range into ReassociationIndices.
+ ReassociationIndices getFullIndices() const {
+ ReassociationIndices result;
+ for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
+ result.push_back(idx);
+ }
+ return result;
+ }
+};
+
+/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
+/// sequence that can be collapsed into a dynamic dimension (at least one must
+/// be present in the source).
+/// By default, lazily returns once the first dynamic dimension has been found.
+/// Setting `matchGreedily` as `true` will also mark all subsequent
+/// source dimensions for collapsing into the target.
+FailureOr<ReassociationIndexRange>
+findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
+ int64_t sourceStartIdx,
+ bool matchGreedily = false) {
+ ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
+ const unsigned numSourceDims = sourceShape.size();
+ ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+ if (!iterationRange.isInRange(sourceShapeAsRange))
+ return failure();
+ auto resultRange = iterationRange;
+
+ bool foundDynamic = false;
+ for (; iterationRange.isInRange(sourceShapeAsRange);
+ iterationRange.expandRight()) {
+ int64_t sourceSize = sourceShape[iterationRange.rightIdx];
+ if (foundDynamic && !matchGreedily)
+ break;
+ if (sourceSize == ShapedType::kDynamic)
+ foundDynamic = true;
+ resultRange = iterationRange;
+ }
+ if (!foundDynamic)
+ return failure();
+ return resultRange;
+}
+
+/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
+/// sequence of static dimensions such that their product matches `targetSize`.
+/// By default, lazily returns once the product matches the target size. Setting
+/// `matchGreedily` as `true` will append all neighboring unit dimensions
+/// (dimensions of 1) to the match.
+FailureOr<ReassociationIndexRange>
+findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
+ int64_t sourceStartIdx, int64_t targetSize,
+ bool matchGreedily = false) {
+ ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
+ const unsigned numSourceDims = sourceShape.size();
+ ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+ if (!iterationRange.isInRange(sourceShapeAsRange))
+ return failure();
+ auto resultRange = iterationRange;
+
+ int64_t prodOfCollapsedDims = 1;
+ bool reachedTargetDimSize = false;
+ while (iterationRange.isInRange(sourceShapeAsRange)) {
+ int64_t sourceSize = sourceShape[iterationRange.rightIdx];
+ if (reachedTargetDimSize && !matchGreedily)
+ break;
+ if (sourceSize == ShapedType::kDynamic) {
+ if (reachedTargetDimSize)
+ break;
+ // Reassociation for a static dim cannot include a dynamic dim. Reset
+ // induction variables to essentially restart the loop from the next
+ // source dimension.
+ prodOfCollapsedDims = 1;
+ resultRange = {iterationRange.rightIdx + 1, iterationRange.rightIdx + 1};
+ iterationRange = resultRange;
+ continue;
+ }
+ prodOfCollapsedDims *= sourceSize;
+ if (prodOfCollapsedDims > targetSize && reachedTargetDimSize)
+ break;
+ // If the target size has been exceeded without matching, we need to shift
+ // the range start right. From the start of the range, roll back the
+ // multiplication until the target size exceeds the product again.
+ while (prodOfCollapsedDims > targetSize &&
+ !iterationRange.containsSingleIndex()) {
+ int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
+ prodOfCollapsedDims /= frontSourceSize;
+ iterationRange.shrinkLeft();
+ }
+ resultRange = iterationRange;
+ // We could've reached the target size with the current dimension,
+ // also as a result of the above shift to right.
+ if (prodOfCollapsedDims == targetSize)
+ reachedTargetDimSize = true;
+ // Increment the iteration range
+ iterationRange.expandRight();
+ }
+ if (!reachedTargetDimSize)
+ return failure();
+ return resultRange;
+}
+
+/// Attempts to find a valid collapsing reassociation of `sourceShape` into
+/// `targetShape` through a simple traversal. If successful, an array of source
+/// index ranges is returned, correspondingly to each dimension in the target
+/// shape. The resulting indices shall fully cover the `sourceShape` without
+/// overlaps.
+///
+/// The algorithm is essentially a lazy one, searching for non-greedy matches -
+/// it will only yield a greedy match for the last target dimension.
+/// FIXME: The algorithm can only backtrack when it needs to append an offset
+/// for a static target dimension to the preceding dynamic one (this retains the
+/// linear complexity). As feasible, consider adding further backtracking
+/// routines to enable more reassociations, e.g.:
+/// - ?x2x?x2 into ?x2
+FailureOr<SmallVector<ReassociationIndexRange>>
+findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> targetShape) {
+ unsigned numSourceDims = sourceShape.size(),
+ numTargetDims = targetShape.size();
+ assert(numSourceDims > numTargetDims);
+ ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+
+ SmallVector<ReassociationIndexRange> reassocRanges;
+ reassocRanges.reserve(numTargetDims);
+  // Remember the previously matched target dimension to enable
+  // pseudo-backtracking for simple cases, e.g.:
+  // - ?x2x3x5 into ?x15
+ std::optional<int64_t> prevTargetSize = std::nullopt;
+ for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
+ targetDimIdx < numTargetDims; ++targetDimIdx) {
+ int64_t targetSize = targetShape[targetDimIdx];
+ std::optional<int64_t> nextTargetSize = std::nullopt;
+
+ // Simply check if there are any subsequent target dimensions left - if not,
+ // the match must be made greedily.
+ bool isLastTargetDim = targetDimIdx == numTargetDims - 1;
+ bool shouldMatchGreedily = isLastTargetDim;
+ FailureOr<ReassociationIndexRange> sourceRange;
+ if (targetSize == ShapedType::kDynamic) {
+ sourceRange = findReassociationRangeForDynamicDim(
+ sourceShape, sourceDimIdx, shouldMatchGreedily);
+ } else {
+ sourceRange = findReassociationRangeForSize(
+ sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
+ }
+
+ // Run sanity checks on the returned index range.
+ if (failed(sourceRange) || failed(sourceRange->verify()) ||
+ !sourceRange->isInRange(sourceShapeAsRange))
+ return failure();
+ if (sourceRange->leftIdx > sourceDimIdx) {
+ // If some source dimensions had to be skipped in order to find a match,
+ // they must be collapsed into the directly preceding dynamic dimension.
+ if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
+ return failure();
+ reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
+ }
+
+ // Store the gathered information as required for the next iteration.
+ prevTargetSize = targetSize;
+ sourceDimIdx = sourceRange->rightIdx + 1;
+ reassocRanges.emplace_back(std::move(*sourceRange));
+ }
+ // Fail if the source shape wasn't a full match for the target shape. We only
+ // need to check the last recorded index - any other gaps should have been
+ // mended by the main loop.
+ if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
+ return failure();
+ return reassocRanges;
+}
+
+/// A variant of `findReassociationRangesForCollapse(...)` that can also scan
+/// the shapes right-to-left.
+FailureOr<SmallVector<ReassociationIndexRange>>
+findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> targetShape,
+ bool iterateRightToLeft) {
+ if (!iterateRightToLeft)
+ return findReassociationRangesForCollapse(sourceShape, targetShape);
+ // FIXME: It would be preferable to avoid the expensive copies. At the moment,
+ // this approach is chosen for readability of the main implementation.
+ auto sourceToReverse = sourceShape.vec(), targetToReverse = targetShape.vec();
+ std::reverse(sourceToReverse.begin(), sourceToReverse.end());
+ std::reverse(targetToReverse.begin(), targetToReverse.end());
+ auto invertedRanges =
+ findReassociationRangesForCollapse(sourceToReverse, targetToReverse);
+ if (failed(invertedRanges))
+ return failure();
+ auto rangesToInvert = *invertedRanges;
+ unsigned numSourceDims = sourceShape.size();
+ // We have received the ranges for inverted shapes. Now we have to invert
+ // the ranges back to correspond with the original source shape.
+ for (auto &range : rangesToInvert) {
+ if (failed(range.verify()))
+ return failure();
+ int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
+ range.leftIdx = numSourceDims - 1 - invRightIdx;
+ range.rightIdx = numSourceDims - 1 - invLeftIdx;
+ }
+ // Also invert the ordering of the ranges to correspond with the original
+ // target shape.
+ std::reverse(rangesToInvert.begin(), rangesToInvert.end());
+ return rangesToInvert;
+}
+} // namespace
+
std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> targetShape) {
@@ -35,124 +290,65 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
numTargetDims = targetShape.size();
if (numSourceDims <= numTargetDims)
return std::nullopt;
- SmallVector<ReassociationIndices, 4> reassociationMap;
- reassociationMap.reserve(numTargetDims);
-
- unsigned sourceDimIdx = 0, targetDimIdx = 0;
- // Source dimensions iteration logic for static target dimensions.
- // FIXME: Instead of lambda-capturing this function's source shape index "in
- // place", consider refactoring this into a separate function.
- auto collectSourceIndicesForStaticTargetDim =
- [&](int64_t targetShape,
- bool mayHaveOffset = false) -> FailureOr<ReassociationIndices> {
- ReassociationIndices resultIndices;
- int64_t prodOfCollapsedDims = 1;
- bool reachedTargetDimSize = false;
- for (; sourceDimIdx < numSourceDims; ++sourceDimIdx) {
- // Source shape cannot be dynamic if the target dim is static.
- if (sourceShape[sourceDimIdx] == ShapedType::kDynamic)
- return failure();
- prodOfCollapsedDims *= sourceShape[sourceDimIdx];
- resultIndices.push_back(sourceDimIdx);
- if (prodOfCollapsedDims > targetShape && !mayHaveOffset)
- return failure();
- while (prodOfCollapsedDims > targetShape) {
- assert(!resultIndices.empty());
- auto frontOffsetIdx = resultIndices.begin();
- prodOfCollapsedDims /= sourceShape[*frontOffsetIdx];
- resultIndices.erase(frontOffsetIdx);
- }
- if (prodOfCollapsedDims == targetShape) {
- reachedTargetDimSize = true;
- ++sourceDimIdx;
- break;
- }
- }
- if (!reachedTargetDimSize)
- return failure();
- return resultIndices;
- };
- // Source dimensions iteration logic for dynamic target dimensions.
- // FIXME: Instead of lambda-capturing this function's source shape index "in
- // place", consider refactoring this into a separate function.
- auto collectSourceIndicesForDynamicTargetDim =
- [&](bool allowStaticNonOnes,
- bool mapConsecutiveDynDims) -> FailureOr<ReassociationIndices> {
- ReassociationIndices resultIndices;
- bool foundFirstDynamic = false;
- while (sourceDimIdx < numSourceDims) {
- if (sourceShape[sourceDimIdx] == ShapedType::kDynamic) {
- if (foundFirstDynamic && !mapConsecutiveDynDims)
- break;
- foundFirstDynamic |= true;
- } else {
- if (foundFirstDynamic)
- break;
- else if (sourceShape[sourceDimIdx] > 1 && !allowStaticNonOnes)
- return failure();
- }
- resultIndices.push_back(sourceDimIdx++);
- }
- if (!foundFirstDynamic)
- return failure();
- return resultIndices;
- };
- // Iterate over target shape.
- bool wasLastDimDynamic = false;
- for (; targetDimIdx < numTargetDims; ++targetDimIdx) {
- int64_t currTargetShape = targetShape[targetDimIdx];
- if (currTargetShape != ShapedType::kDynamic) {
- unsigned sourceDimAtStart = sourceDimIdx;
- auto indices = collectSourceIndicesForStaticTargetDim(
- currTargetShape, /*mayHaveOffset=*/wasLastDimDynamic);
- if (failed(indices))
+ // Early handling for scalar target types.
+ if (numTargetDims == 0) {
+ ReassociationIndices allSourceIndices(numSourceDims);
+ for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
+ ++sourceDimIdx) {
+ int64_t sourceSize = sourceShape[sourceDimIdx];
+ // All source dimensions must be unit or dynamic.
+ if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
return std::nullopt;
- if (wasLastDimDynamic) {
- assert(!reassociationMap.empty());
- auto &previousIndices = reassociationMap.back();
- for (; sourceDimAtStart < indices->front(); ++sourceDimAtStart)
- previousIndices.push_back(sourceDimAtStart);
- }
- reassociationMap.push_back(*indices);
- wasLastDimDynamic = false;
- continue;
+ allSourceIndices.emplace_back(sourceDimIdx);
}
+ return SmallVector<ReassociationIndices>{allSourceIndices};
+ }
- bool isNextDimDynamic =
- targetDimIdx + 1 < numTargetDims &&
- targetShape[targetDimIdx + 1] == ShapedType::kDynamic;
- auto indices = collectSourceIndicesForDynamicTargetDim(
- /*allowStaticNonOnes=*/!wasLastDimDynamic,
- /*mapConsecutiveDynDims=*/!wasLastDimDynamic && !isNextDimDynamic);
- if (failed(indices))
+ // Collect source ranges by iterating over the target shape left-to-right.
+ auto maybeForwardRanges =
+ findReassociationRangesForCollapse(sourceShape, targetShape);
+ if (failed(maybeForwardRanges))
+ return std::nullopt;
+ auto &ranges = *maybeForwardRanges;
+ // Now do the same in reverse. We need to get another valid reassociation
+ // through some other strategy, and then compare the results in order to
+ // disambiguate mixed subshapes, such as:
+ // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
+ // This leads us to lose some of the reassociation opportunities that can only
+ // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
+ // backtracking, the algorithm will fail right-to-left. However, this is the
+ // best way to preserve correctness.
+ //
+ // NB: The reversed shapes must not be temporary as we're passing through an
+ // ArrayRef.
+ auto maybeReverseRanges = findReassociationRangesForCollapse(
+ sourceShape, targetShape, /*iterateRightToLeft=*/true);
+ if (failed(maybeReverseRanges))
+ return std::nullopt;
+ auto &reverseRanges = *maybeReverseRanges;
+
+ if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
+ return std::nullopt;
+ // Now we can check for ambiguity of each target dimension's reassociation. If
+ // successful, we put the full indices into our result map for the target
+ // shape.
+ SmallVector<ReassociationIndices> reassociationMap(numTargetDims);
+ for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
+ ++targetDimIdx) {
+ auto &range = ranges[targetDimIdx];
+ auto &reverseRange = reverseRanges[targetDimIdx];
+ // Get non-overlapping indices between the ranges
+ ReassociationIndices nonMatchingIndices = range ^ reverseRange;
+ // The ranges should overlap, at the very least
+ if (nonMatchingIndices.size() == range.size() + reverseRange.size())
return std::nullopt;
- reassociationMap.push_back(*indices);
- wasLastDimDynamic = true;
- }
- // Now that we've mapped all the target dimensions, process any remaining
- // entries in the source shape explicitly.
- for (; sourceDimIdx < numSourceDims; sourceDimIdx++) {
- const bool isOne = sourceShape[sourceDimIdx] == 1,
- isDynamic = sourceShape[sourceDimIdx] == ShapedType::kDynamic;
- if (targetShape.empty()) {
- if (!isOne && !isDynamic)
- return std::nullopt;
- continue;
- }
- // If the last 2 dimensions in the target were dynamic, the tail in the
- // source shape cannot contain a dynamic value. E.g. ?x?->? is valid,
- // however ?x?x10x?->?x? would be indeterminate.
- if (wasLastDimDynamic && numTargetDims > 1 &&
- targetShape[numTargetDims - 2] == ShapedType::kDynamic) {
- if (isDynamic)
+ // Unit dimensions can be collapsed wherever - this is the only ambiguity
+ // that we allow.
+ for (int64_t sourceDimIdx : nonMatchingIndices) {
+ if (sourceShape[sourceDimIdx] != 1)
return std::nullopt;
}
- // If the last target dimension is static, only source dimensions of 1 are
- // acceptable.
- if (!wasLastDimDynamic && !isOne)
- return std::nullopt;
- assert(!reassociationMap.empty());
- reassociationMap.back().push_back(sourceDimIdx);
+ reassociationMap[targetDimIdx] = range.getFullIndices();
}
return reassociationMap;
}
@@ -379,11 +575,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
// have proven that these are not sliced. In this case we just take
// the full extent of each dimension in the reassociation list.
if (linearizedDimensions[it.index()]) {
- llvm::append_range(
- offsetsSizesAndStrides,
- llvm::map_range(it.value(), [&](int64_t idx) -> Range {
- return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
- }));
+ llvm::append_range(offsetsSizesAndStrides,
+ llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+ return {zeroAttr, collapseShapeInputShape[idx],
+ oneAttr};
+ }));
continue;
}
diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
index 2564866fac493..a179d91129edb 100644
--- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
@@ -57,6 +57,10 @@ TEST(ReassociationIndicesForCollapse, DynamicTest) {
EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1, 1},
{ShapedType::kDynamic}),
makeOptionalIndices({{0, 1, 2}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {1, ShapedType::kDynamic, 1, ShapedType::kDynamic, 1},
+ {ShapedType::kDynamic, ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1}, {2, 3, 4}}));
EXPECT_EQ(
getReassociationIndicesForCollapse(
{ShapedType::kDynamic, ShapedType::kDynamic}, {ShapedType::kDynamic}),
@@ -76,6 +80,10 @@ TEST(ReassociationIndicesForCollapse, DynamicTest) {
EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic},
{ShapedType::kDynamic}),
makeOptionalIndices({{0, 1}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 1, 2, ShapedType::kDynamic, 10},
+ {ShapedType::kDynamic, 10}),
+ makeOptionalIndices({{0, 1, 2, 3}, {4}}));
EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
{ShapedType::kDynamic, 20}),
makeOptionalIndices({{0, 1}, {2}}));
@@ -131,4 +139,20 @@ TEST(ReassociationIndicesForCollapse, DynamicTestFailure) {
ShapedType::kDynamic},
{ShapedType::kDynamic, ShapedType::kDynamic}),
std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic},
+ {ShapedType::kDynamic, 10, ShapedType::kDynamic}),
+ std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic},
+ {ShapedType::kDynamic, 2, 2, ShapedType::kDynamic}),
+ std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 3, 4, 3, ShapedType::kDynamic},
+ {ShapedType::kDynamic, 12, ShapedType::kDynamic}),
+ std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 8, 4, 2, 16, ShapedType::kDynamic},
+ {ShapedType::kDynamic, 32, ShapedType::kDynamic}),
+ std::nullopt);
}
>From dd36c47d7a6bb402497a3c4c1757f47928132e06 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Tue, 20 May 2025 22:26:12 +0000
Subject: [PATCH 7/8] [WIP] New tests
---
.../Dialect/Utils/ReshapeOpsUtilsTest.cpp | 18 ++++++++++++++++++
1 file changed, 18 insertions(+)
diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
index a179d91129edb..124c8ce86fc9c 100644
--- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
@@ -155,4 +155,22 @@ TEST(ReassociationIndicesForCollapse, DynamicTestFailure) {
{ShapedType::kDynamic, 8, 4, 2, 16, ShapedType::kDynamic},
{ShapedType::kDynamic, 32, ShapedType::kDynamic}),
std::nullopt);
+
+ //===----------------------------------------------------------------------===//
+ // TODO: Reassociation for the following examples can be computed, but isn't
+ // supported by `getReassociationIndicesForCollapse`.
+ //===----------------------------------------------------------------------===//
+
+ // TODO: Fails because there's no backtracking when some source dimensions
+ // remain unmatched at either edge.
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 10, ShapedType::kDynamic, 10},
+ {ShapedType::kDynamic, 10}),
+ std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 2, 2},
+ {1, ShapedType::kDynamic, 2}),
+ std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse({2, 2, ShapedType::kDynamic, 1},
+ {2, ShapedType::kDynamic}),
+ std::nullopt);
}
>From 07ed33d4363d64ed32f85fe0b296ca39cc916124 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Tue, 20 May 2025 22:45:24 +0000
Subject: [PATCH 8/8] [fixup] Add scalar target tests and fix them
Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 3 ++-
.../Dialect/Utils/ReshapeOpsUtilsTest.cpp | 24 +++++++++++++++++++
2 files changed, 26 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 25dd434fc2122..209577db3272f 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -292,7 +292,8 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
return std::nullopt;
// Early handling for scalar target types.
if (numTargetDims == 0) {
- ReassociationIndices allSourceIndices(numSourceDims);
+ ReassociationIndices allSourceIndices;
+ allSourceIndices.reserve(numSourceDims);
for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
++sourceDimIdx) {
int64_t sourceSize = sourceShape[sourceDimIdx];
diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
index 124c8ce86fc9c..7abdf75c34cda 100644
--- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "gtest/gtest.h"
#include <optional>
@@ -20,6 +21,29 @@ makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
return std::optional<SmallVector<ReassociationIndices>>(list);
}
+TEST(ReassociationIndicesForCollapse, ScalarTest) {
+ EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}),
+ makeOptionalIndices({{0}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}),
+ makeOptionalIndices({{0, 1}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}),
+ makeOptionalIndices({{0}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic,
+ ShapedType::kDynamic, 1,
+ ShapedType::kDynamic},
+ {}),
+ makeOptionalIndices({{0, 1, 2, 3, 4}}));
+}
+
+TEST(ReassociationIndicesForCollapse, ScalarTestFailure) {
+ EXPECT_EQ(getReassociationIndicesForCollapse({}, {}), std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse({}, {1}), std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse({2}, {}), std::nullopt);
+ EXPECT_EQ(
+ getReassociationIndicesForCollapse({1, 2, ShapedType::kDynamic, 1}, {}),
+ std::nullopt);
+}
+
TEST(ReassociationIndicesForCollapse, StaticTest) {
EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {200}),
makeOptionalIndices({{0, 1}}));
More information about the Mlir-commits
mailing list