[llvm-branch-commits] [mlir] fbd2926 - Revert "Revert "[mlir][tensor] Loosen restrictions on folding dynamic reshape…"
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jun 4 11:06:49 PDT 2025
Author: Ian Wood
Date: 2025-06-04T11:06:46-07:00
New Revision: fbd2926fb3c197c6d5dfd9502bff0d2a5e77749a
URL: https://github.com/llvm/llvm-project/commit/fbd2926fb3c197c6d5dfd9502bff0d2a5e77749a
DIFF: https://github.com/llvm/llvm-project/commit/fbd2926fb3c197c6d5dfd9502bff0d2a5e77749a.diff
LOG: Revert "Revert "[mlir][tensor] Loosen restrictions on folding dynamic reshape…"
This reverts commit f5a2f00da9b741f4f2fe925a434f608aa217cee2.
Added:
mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
Modified:
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
mlir/unittests/Dialect/Utils/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 1a04d702e0559..3b1fdb69e8ef1 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -10,6 +10,10 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
#include <numeric>
#include <optional>
@@ -28,67 +32,329 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
return std::nullopt;
}
-std::optional<SmallVector<ReassociationIndices>>
-mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
- ArrayRef<int64_t> targetShape) {
- if (sourceShape.size() <= targetShape.size())
- return std::nullopt;
- unsigned sourceDim = 0;
- SmallVector<ReassociationIndices> reassociationMap;
- reassociationMap.reserve(targetShape.size());
+namespace {
+/// A simple struct to represent ReassociationIndices as an inclusive interval
+/// [leftIdx, rightIdx]. It's designed to be minimal, so the call sites are
+/// responsible for managing the validity of the range.
+struct ReassociationIndexRange {
+  /// FIXME: Signed type is used for consistency with ReassociationIndices.
+  /// We should consider refactoring all reassociation utilities to use unsigned
+  /// types.
+  int64_t leftIdx = 0, rightIdx = 0;
+
+  /// Util for manual checks of the range's validity: the range must start at a
+  /// non-negative index and must not be empty.
+  LogicalResult verify() const {
+    return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
+  }
+
+  /// Checks range's containment within another range. Treats the edges
+  /// non-exclusively.
+  bool isInRange(const ReassociationIndexRange &outerRange) const {
+    return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
+  }
+
+  /// Number of indices in the range. Asserts that the range is valid.
+  unsigned size() const {
+    assert(succeeded(verify()));
+    return rightIdx - leftIdx + 1;
+  }
+  bool containsSingleIndex() const { return size() == 1; }
+
+  /// Collects, in ascending order, the indices contained in exactly one of
+  /// this range and `rhs` (i.e. their symmetric difference). Both ranges are
+  /// expected to be valid.
+  ReassociationIndices
+  getNonOverlappingIndicesWith(const ReassociationIndexRange &rhs) const {
+    if (rightIdx < rhs.leftIdx || rhs.rightIdx < leftIdx) {
+      // The intervals do not overlap - concatenate the indices from both,
+      // lower interval first so that the result stays sorted.
+      const bool thisIsLower = rightIdx < rhs.leftIdx;
+      ReassociationIndices jointFullIndices =
+          thisIsLower ? getFullIndices() : rhs.getFullIndices();
+      jointFullIndices.append(thisIsLower ? rhs.getFullIndices()
+                                          : getFullIndices());
+      return jointFullIndices;
+    }
+    ReassociationIndices result;
+    // Handle the chunk left of the overlapping range: [minLeft, maxLeft).
+    int64_t leftStart = std::min(leftIdx, rhs.leftIdx);
+    int64_t leftEnd = std::max(leftIdx, rhs.leftIdx);
+    llvm::append_range(result, llvm::seq(leftStart, leftEnd));
+    // Handle the chunk right of the overlapping range: (minRight, maxRight].
+    // Symmetrically, we skip the edge of the overlap AND include the rightmost
+    // index. The chunk is non-empty whenever the right edges differ - even by
+    // exactly one index, hence `<=`. The guard itself is still required since
+    // `seq_inclusive` expects its start not to exceed its end.
+    int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1;
+    int64_t rightEnd = std::max(rightIdx, rhs.rightIdx);
+    if (rightStart <= rightEnd)
+      llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));
+    return result;
+  }
+
+  /// Converts the range into ReassociationIndices.
+  ReassociationIndices getFullIndices() const {
+    ReassociationIndices result;
+    for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
+      result.push_back(idx);
+    }
+    return result;
+  }
+};
+} // namespace
+
+/// Scans `sourceShape` rightwards from `sourceStartIdx` for the first dynamic
+/// dimension and returns the range spanning from `sourceStartIdx` up to and
+/// including it, i.e. the sequence that can be collapsed into a dynamic target
+/// dimension. Fails when no dynamic dimension is found or when
+/// `sourceStartIdx` lies outside the shape.
+/// With `matchGreedily` set to `true`, the returned range is instead stretched
+/// up to the last source dimension.
+static FailureOr<ReassociationIndexRange>
+findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
+                                    int64_t sourceStartIdx,
+                                    bool matchGreedily = false) {
+  const unsigned numSourceDims = sourceShape.size();
+  const ReassociationIndexRange shapeBounds{0, numSourceDims - 1};
+
+  // Grow the candidate one dimension at a time until its right edge either
+  // lands on a dynamic dimension or leaves the shape.
+  ReassociationIndexRange candidate{sourceStartIdx, sourceStartIdx};
+  while (candidate.isInRange(shapeBounds) &&
+         sourceShape[candidate.rightIdx] != ShapedType::kDynamic)
+    ++candidate.rightIdx;
+
+  // Having left the shape means that no dynamic dimension was found.
+  if (!candidate.isInRange(shapeBounds))
+    return failure();
+  if (matchGreedily)
+    candidate.rightIdx = shapeBounds.rightIdx;
+  return candidate;
+}
- ReassociationIndices currIndices;
+/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
+/// sequence of static dimensions such that their product matches `targetSize`.
+/// By default, lazily returns once the product matches the target size. Setting
+/// `matchGreedily` as `true` will also append all unit dimensions (dimensions
+/// of 1) that directly follow the match.
+/// Dynamic dimensions are never part of the match: hitting one restarts the
+/// search at the next source dimension. Fails if no match is found.
+static FailureOr<ReassociationIndexRange>
+findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
+ int64_t sourceStartIdx, int64_t targetSize,
+ bool matchGreedily = false) {
+ const unsigned numSourceDims = sourceShape.size();
+ ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+ std::optional<ReassociationIndexRange> resultRange = std::nullopt;
+
+ // The search is a sliding window: `iterationRange` grows to the right one
+ // dimension at a time and shrinks from the left whenever the running
+ // product overshoots `targetSize`.
+ ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
int64_t prodOfCollapsedDims = 1;
- while (sourceDim < sourceShape.size()) {
- unsigned targetDim = reassociationMap.size();
- // If we have mapped all the target dimensions stop and handle the remaining
- // tail of size-1 dimensions explicitly.
- if (targetDim == targetShape.size())
+ while (iterationRange.isInRange(sourceShapeAsRange)) {
+ int64_t sourceSize = sourceShape[iterationRange.rightIdx];
+ if (sourceSize == ShapedType::kDynamic) {
+ // Reassociation for a static dim cannot include a dynamic dim. Reset
+ // induction variables to essentially restart the loop from the next
+ // source dimension.
+ prodOfCollapsedDims = 1;
+ iterationRange = {iterationRange.rightIdx + 1,
+ iterationRange.rightIdx + 1};
+ continue;
+ }
+ prodOfCollapsedDims *= sourceSize;
+ // If the target size has been exceeded without matching, we need to shift
+ // the range start right. From the start of the range, roll back the
+ // multiplication until the product no longer exceeds the target size (or
+ // only a single dimension is left in the range).
+ while (prodOfCollapsedDims > targetSize &&
+ !iterationRange.containsSingleIndex()) {
+ int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
+ // `frontSourceSize` is one of the factors of the running product, so the
+ // division is exact. NOTE(review): a 0-sized dim makes the product 0,
+ // which presumably never exceeds a (non-negative) static target size, so
+ // this should never divide by zero - confirm.
+ prodOfCollapsedDims /= frontSourceSize;
+ // Shrink the range rightwards
+ iterationRange.leftIdx++;
+ }
+ // We could've reached the target size with the current dimension,
+ // also as a result of the above shift to right.
+ if (prodOfCollapsedDims == targetSize) {
+ resultRange = iterationRange;
break;
+ }
+ // Increment the iteration range
+ iterationRange.rightIdx++;
+ }
+ if (!resultRange)
+ return failure();
+ if (matchGreedily) {
+ // We now want to collect all unit dimensions directly after the target
+ // product match. Advance the iterator to avoid OOB when the product match
+ // happens at the last element.
+ // `iterationRange` still equals `resultRange` here, since the loop above
+ // breaks right after recording the match.
+ iterationRange.rightIdx++;
+ while (iterationRange.isInRange(sourceShapeAsRange) &&
+ sourceShape[iterationRange.rightIdx] == 1) {
+ resultRange = iterationRange;
+ iterationRange.rightIdx++;
+ }
+ }
+ return *resultRange;
+}
- int64_t currTargetShape = targetShape[targetDim];
- while (sourceDim < (sourceShape.size() - 1) &&
- sourceShape[sourceDim] != ShapedType::kDynamic &&
- prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
- prodOfCollapsedDims *= sourceShape[sourceDim];
- currIndices.push_back(sourceDim++);
+/// Attempts to find a valid collapsing reassociation of `sourceShape` into
+/// `targetShape` through a simple traversal. If successful, one source index
+/// range is returned per target dimension, in target order. The resulting
+/// ranges shall fully cover the `sourceShape` without overlaps.
+///
+/// The algorithm is essentially a lazy one, searching for non-greedy matches -
+/// it will only yield a greedy match for the last target dimension.
+/// FIXME: The algorithm can only backtrack when it needs to append an offset
+/// for a static target dimension to the preceding dynamic one (this retains the
+/// linear complexity). As feasible, consider adding further backtracking
+/// routines to enable more reassociations, e.g.:
+/// - ?x2x?x2 into ?x2
+///
+/// NOTE: Expects at least one target dimension - the final coverage check
+/// reads `reassocRanges.back()`. Scalar targets are handled separately by
+/// `getReassociationIndicesForCollapse` before this search is run.
+static FailureOr<SmallVector<ReassociationIndexRange>>
+findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> targetShape) {
+ unsigned numSourceDims = sourceShape.size(),
+ numTargetDims = targetShape.size();
+ assert(numSourceDims > numTargetDims);
+ ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+
+ SmallVector<ReassociationIndexRange> reassocRanges;
+ reassocRanges.reserve(numTargetDims);
+ // Remember the previous target size to enable pseudo-backtracking for simple
+ // cases: source dims skipped while matching a static target dim are appended
+ // to the range of the preceding dynamic target dim, e.g.:
+ // - ?x2x3x5 into ?x15
+ std::optional<int64_t> prevTargetSize = std::nullopt;
+ for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
+ targetDimIdx < numTargetDims; ++targetDimIdx) {
+ int64_t targetSize = targetShape[targetDimIdx];
+ // Simply check if there are any subsequent target dimensions left - if not,
+ // the match must be made greedily.
+ bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;
+ FailureOr<ReassociationIndexRange> sourceRange;
+ if (targetSize == ShapedType::kDynamic) {
+ sourceRange = findReassociationRangeForDynamicDim(
+ sourceShape, sourceDimIdx, shouldMatchGreedily);
+ } else {
+ sourceRange = findReassociationRangeForSize(
+ sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
}
- // If the current expanded dimension is dynamic, then the collapsed
- // dimensions should also be dynamic and product of all previous unprocessed
- // dimensions of the expanded shape should be 1.
- if (sourceShape[sourceDim] == ShapedType::kDynamic &&
- (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
- return std::nullopt;
-
- // If the collapsed dim is dynamic, the current expanded dim should also
- // be dynamic.
- if (currTargetShape == ShapedType::kDynamic &&
- sourceShape[sourceDim] != ShapedType::kDynamic)
- return std::nullopt;
-
- // For static shapes, if the product of dimensions of the expanded shape
- // should match the collapsed dimension shape.
- if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
- return std::nullopt;
-
- currIndices.push_back(sourceDim++);
- reassociationMap.emplace_back(ReassociationIndices{});
- std::swap(reassociationMap.back(), currIndices);
- prodOfCollapsedDims = 1;
+ // Run sanity checks on the returned index range.
+ if (failed(sourceRange) || failed(sourceRange->verify()) ||
+ !sourceRange->isInRange(sourceShapeAsRange))
+ return failure();
+ if (sourceRange->leftIdx > sourceDimIdx) {
+ // If some source dimensions had to be skipped in order to find a match,
+ // they must be collapsed into the directly preceding dynamic dimension.
+ if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
+ return failure();
+ reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
+ }
+
+ // Store the gathered information as required for the next iteration.
+ prevTargetSize = targetSize;
+ sourceDimIdx = sourceRange->rightIdx + 1;
+ reassocRanges.push_back(*sourceRange);
}
- // All the dimensions in the target must have been processed.
- if (reassociationMap.size() != targetShape.size())
+ // Fail if the source shape wasn't a full match for the target shape. We only
+ // need to check the last recorded index - any other gaps should have been
+ // mended by the main loop.
+ if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
+ return failure();
+ return reassocRanges;
+}
+
+/// Direction-aware overload of `findReassociationRangesForCollapse(...)`.
+/// With `iterateRightToLeft` unset, this simply forwards to the left-to-right
+/// search. Otherwise both shapes are mirrored, searched left-to-right, and the
+/// resulting ranges are mapped back onto the original source indices and the
+/// original target dimension order.
+static FailureOr<SmallVector<ReassociationIndexRange>>
+findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
+                                   ArrayRef<int64_t> targetShape,
+                                   bool iterateRightToLeft) {
+  if (iterateRightToLeft) {
+    // NB: The mirrored shapes live in named locals because the search takes
+    // ArrayRefs, which must not refer to temporaries.
+    // FIXME: Mirroring copies both shapes. At the moment, this is accepted for
+    // the sake of readability of the main implementation.
+    std::vector<int64_t> mirroredSource(sourceShape.rbegin(),
+                                        sourceShape.rend());
+    std::vector<int64_t> mirroredTarget(targetShape.rbegin(),
+                                        targetShape.rend());
+    FailureOr<SmallVector<ReassociationIndexRange>> mirroredRanges =
+        findReassociationRangesForCollapse(mirroredSource, mirroredTarget);
+    if (failed(mirroredRanges))
+      return failure();
+
+    // Mirror each range's endpoints back into original source coordinates,
+    // visiting the ranges back-to-front so that the result follows the
+    // original target order.
+    const unsigned numSourceDims = sourceShape.size();
+    SmallVector<ReassociationIndexRange> restoredRanges;
+    restoredRanges.reserve(mirroredRanges->size());
+    for (auto it = mirroredRanges->rbegin(), end = mirroredRanges->rend();
+         it != end; ++it)
+      restoredRanges.push_back(
+          {numSourceDims - 1 - it->rightIdx, numSourceDims - 1 - it->leftIdx});
+    return restoredRanges;
+  }
+  return findReassociationRangesForCollapse(sourceShape, targetShape);
+}
+
+/// Computes a collapsing reassociation of `sourceShape` into `targetShape`.
+/// Returns std::nullopt when the source rank does not exceed the target rank,
+/// when no reassociation is found, or when the left-to-right and right-to-left
+/// searches disagree on anything other than the placement of unit dimensions.
+std::optional<SmallVector<ReassociationIndices>>
+mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> targetShape) {
+ unsigned numSourceDims = sourceShape.size(),
+ numTargetDims = targetShape.size();
+ // We're supposed to search for a collapsing reassociation. If the source rank
+ // does not exceed the target rank, there's no actual collapsing taking place -
+ // it's either a no-op, an expansion, or a `tensor.reshape`-style
+ // reassociation (that would be beyond the scope of this utility).
+ if (numSourceDims <= numTargetDims)
+ return std::nullopt;
+ // Early handling for scalar target types.
+ if (numTargetDims == 0) {
+ ReassociationIndices allSourceIndices;
+ allSourceIndices.reserve(numSourceDims);
+ for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
+ ++sourceDimIdx) {
+ int64_t sourceSize = sourceShape[sourceDimIdx];
+ // All source dimensions must be unit or dynamic.
+ if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
+ return std::nullopt;
+ allSourceIndices.push_back(sourceDimIdx);
+ }
+ return SmallVector<ReassociationIndices>{allSourceIndices};
+ }
+
+ // Collect source ranges by iterating over the target shape left-to-right.
+ FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
+ findReassociationRangesForCollapse(sourceShape, targetShape);
+ if (failed(maybeForwardRanges))
+ return std::nullopt;
+ auto &ranges = *maybeForwardRanges;
+ // Now do the same in reverse. We need to get another valid reassociation
+ // through some other strategy, and then compare the results in order to
+ // detect ambiguous mixed subshapes, such as:
+ // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
+ // This leads us to lose some of the reassociation opportunities that can only
+ // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
+ // backtracking, the algorithm will fail right-to-left. However, this is the
+ // best way to preserve correctness.
+ FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
+ findReassociationRangesForCollapse(sourceShape, targetShape,
+ /*iterateRightToLeft=*/true);
+ if (failed(maybeReverseRanges))
+ return std::nullopt;
+ auto &reverseRanges = *maybeReverseRanges;
+
+ // Defensive check: each successful search pushes exactly one range per
+ // target dimension.
+ if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
return std::nullopt;
- // Process any remaining entries in the source shape. They all need to be
- // 1 or dynamic.
- for (; sourceDim < sourceShape.size(); sourceDim++) {
- if (sourceShape[sourceDim] != ShapedType::kDynamic &&
- sourceShape[sourceDim] != 1)
- return std::nullopt;
- // The map is empty when the target type is a scalar.
- if (!reassociationMap.empty())
- reassociationMap.back().push_back(sourceDim);
+ // Now we can check for ambiguity of each target dimension's reassociation. If
+ // successful, we put the full indices into our result map for the target
+ // shape.
+ SmallVector<ReassociationIndices> reassociationMap(numTargetDims);
+ for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
+ ++targetDimIdx) {
+ ReassociationIndexRange &range = ranges[targetDimIdx];
+ ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
+ // Get non-overlapping indices between the ranges
+ ReassociationIndices nonMatchingIndices =
+ range.getNonOverlappingIndicesWith(reverseRange);
+ // Unit dimensions can be collapsed wherever - this is the only ambiguity
+ // that we allow.
+ for (int64_t sourceDimIdx : nonMatchingIndices) {
+ if (sourceShape[sourceDimIdx] != 1)
+ return std::nullopt;
+ }
+ // Both directions agree up to unit dims; keep the left-to-right grouping.
+ reassociationMap[targetDimIdx] = range.getFullIndices();
}
return reassociationMap;
}
diff --git a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
index 51350e5bc8498..6979770154bab 100644
--- a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
@@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
// -----
// CHECK-LABEL: func.func @unpack_dynamic
-// CHECK-NOT: tensor.collapse
-// CHECK: linalg.unpack
+// CHECK: tensor.collapse
+// CHECK-NOT: linalg.unpack
func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0abec7e01d184..646b2197d9aa6 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1117,7 +1117,7 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf3
// -----
-func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
-> tensor<?x4x?xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
: tensor<?x4x?xf32> into tensor<?x?xf32>
@@ -1125,12 +1125,28 @@ func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: ind
: tensor<?x?xf32> into tensor<?x4x?xf32>
return %1 : tensor<?x4x?xf32>
}
-// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape
// CHECK-NOT: tensor.{{.*}}_shape
// -----
-func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+// The expand re-splits the first collapsed group (?x4) into ?x4 again, while
+// the second group (?x2) stays collapsed, so the pair is expected to fold into
+// a single collapse_shape of the original operand.
+func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index)
+ -> tensor<?x4x?xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
+ : tensor<?x4x?x2xf32> into tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+ : tensor<?x?xf32> into tensor<?x4x?xf32>
+ return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape
+// CHECK-NOT: tensor.expand_shape
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]]
+// CHECK-SAME: : tensor<?x4x?x2xf32> into tensor<?x4x?xf32>
+// CHECK-NEXT: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
-> tensor<?x?x?xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
: tensor<?x?x?xf32> into tensor<?x?xf32>
@@ -1138,7 +1154,22 @@ func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1:
: tensor<?x?xf32> into tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
-// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic
+// CHECK: tensor.collapse_shape
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
+// CHECK: return %[[EXPAND]]
+
+// -----
+
+// Three adjacent dynamic dims are collapsed and re-expanded into two dynamic
+// dims; how the source dims would group is ambiguous, so both reshapes are
+// expected to remain.
+func.func @no_fold_expand_of_collapse_adjacent_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index)
+ -> tensor<?x?xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
+ : tensor<?x?x?xf32> into tensor<?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1]] output_shape [%arg1, %arg2]
+ : tensor<?xf32> into tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @no_fold_expand_of_collapse_adjacent_dynamic
// CHECK: tensor.collapse_shape
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
// CHECK: return %[[EXPAND]]
diff --git a/mlir/unittests/Dialect/Utils/CMakeLists.txt b/mlir/unittests/Dialect/Utils/CMakeLists.txt
index 61b9cdcb3b8f3..e921c8bcfb4e5 100644
--- a/mlir/unittests/Dialect/Utils/CMakeLists.txt
+++ b/mlir/unittests/Dialect/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_unittest(MLIRDialectUtilsTests
StructuredOpsUtilsTest.cpp
+ ReshapeOpsUtilsTest.cpp
IndexingUtilsTest.cpp
)
mlir_target_link_libraries(MLIRDialectUtilsTests
diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
new file mode 100644
index 0000000000000..db1a87a4de2d5
--- /dev/null
+++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
@@ -0,0 +1,203 @@
+//===- ReshapeOpsUtilsTest.cpp - ReshapeOpsUtils unit tests ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "llvm/ADT/STLExtras.h"
+#include "gtest/gtest.h"
+#include <optional>
+
+using namespace mlir;
+
+/// Wraps a braced list of reassociation groups into the optional type returned
+/// by `getReassociationIndicesForCollapse`, so that expectations read cleanly.
+static std::optional<SmallVector<ReassociationIndices>>
+makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
+  SmallVector<ReassociationIndices> indices(list);
+  return indices;
+}
+
+TEST(ReassociationIndicesForCollapse, ScalarTest) {
+  constexpr int64_t kDyn = ShapedType::kDynamic;
+  auto collapse = [](ArrayRef<int64_t> source, ArrayRef<int64_t> target) {
+    return getReassociationIndicesForCollapse(source, target);
+  };
+  // Collapsing into a rank-0 target groups every source dimension together.
+  EXPECT_EQ(collapse({1}, {}), makeOptionalIndices({{0}}));
+  EXPECT_EQ(collapse({1, 1}, {}), makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(collapse({kDyn}, {}), makeOptionalIndices({{0}}));
+  EXPECT_EQ(collapse({1, kDyn, kDyn, 1, kDyn}, {}),
+            makeOptionalIndices({{0, 1, 2, 3, 4}}));
+}
+
+TEST(ReassociationIndicesForCollapse, ScalarTestFailure) {
+  auto collapse = [](ArrayRef<int64_t> source, ArrayRef<int64_t> target) {
+    return getReassociationIndicesForCollapse(source, target);
+  };
+  // Rank-0 source, or a target that is not a true collapse of the source.
+  EXPECT_EQ(collapse({}, {}), std::nullopt);
+  EXPECT_EQ(collapse({}, {1}), std::nullopt);
+  // A non-unit static dimension cannot disappear into a scalar.
+  EXPECT_EQ(collapse({2}, {}), std::nullopt);
+  EXPECT_EQ(collapse({1, 2, ShapedType::kDynamic, 1}, {}), std::nullopt);
+}
+
+TEST(ReassociationIndicesForCollapse, StaticTest) {
+  auto collapse = [](ArrayRef<int64_t> source, ArrayRef<int64_t> target) {
+    return getReassociationIndicesForCollapse(source, target);
+  };
+  // Fully static shapes: groups are determined by matching products.
+  EXPECT_EQ(collapse({10, 20}, {200}), makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(collapse({10, 20, 30}, {10, 600}),
+            makeOptionalIndices({{0}, {1, 2}}));
+  EXPECT_EQ(collapse({10, 20, 30}, {200, 30}),
+            makeOptionalIndices({{0, 1}, {2}}));
+}
+
+TEST(ReassociationIndicesForCollapse, StaticTestFailure) {
+  auto collapse = [](ArrayRef<int64_t> source, ArrayRef<int64_t> target) {
+    return getReassociationIndicesForCollapse(source, target);
+  };
+  // Equal ranks: a no-op, not a collapse.
+  EXPECT_EQ(collapse({10, 20}, {10, 20}), std::nullopt);
+  // Source dimensions whose products cannot form the target sizes.
+  EXPECT_EQ(collapse({10, 20}, {10}), std::nullopt);
+  EXPECT_EQ(collapse({10, 20, 30}, {200, 300}), std::nullopt);
+  // Higher target rank: an expansion, not a collapse.
+  EXPECT_EQ(collapse({10, 20, 30}, {1, 10, 20, 30}), std::nullopt);
+}
+
+TEST(ReassociationIndicesForCollapse, StaticTestUnitDims) {
+  auto collapse = [](ArrayRef<int64_t> source, ArrayRef<int64_t> target) {
+    return getReassociationIndicesForCollapse(source, target);
+  };
+  // Unit dimensions are folded into an adjacent group.
+  EXPECT_EQ(collapse({10, 1}, {10}), makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(collapse({1, 20, 30}, {600}), makeOptionalIndices({{0, 1, 2}}));
+  EXPECT_EQ(collapse({1, 1, 1}, {1}), makeOptionalIndices({{0, 1, 2}}));
+  EXPECT_EQ(collapse({1, 1, 1, 1}, {1, 1, 1}),
+            makeOptionalIndices({{0}, {1}, {2, 3}}));
+}
+
+TEST(ReassociationIndicesForCollapse, DynamicTest) {
+ // Unit and dynamic source dims collapsed into dynamic target dims.
+ EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1},
+ {ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1, 1},
+ {ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1, 2}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {1, ShapedType::kDynamic, 1, ShapedType::kDynamic, 1},
+ {ShapedType::kDynamic, ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1}, {2, 3, 4}}));
+ EXPECT_EQ(
+ getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, ShapedType::kDynamic}, {ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {1, ShapedType::kDynamic, ShapedType::kDynamic},
+ {1, ShapedType::kDynamic}),
+ makeOptionalIndices({{0}, {1, 2}}));
+
+ // Static non-unit source dims absorbed into an adjacent dynamic target dim.
+ EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10},
+ {ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {1, ShapedType::kDynamic, ShapedType::kDynamic},
+ {ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1, 2}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic},
+ {ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 1, 2, ShapedType::kDynamic, 10},
+ {ShapedType::kDynamic, 10}),
+ makeOptionalIndices({{0, 1, 2, 3}, {4}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
+ {ShapedType::kDynamic, 20}),
+ makeOptionalIndices({{0, 1}, {2}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic, 20},
+ {ShapedType::kDynamic, 20}),
+ makeOptionalIndices({{0, 1}, {2}}));
+ // The static target 20 is matched by 2x5x2; the skipped 3 joins the
+ // preceding dynamic group.
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 3, 2, 5, 2}, {ShapedType::kDynamic, 20}),
+ makeOptionalIndices({{0, 1}, {2, 3, 4}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {10, ShapedType::kDynamic, 20, ShapedType::kDynamic, 1},
+ {ShapedType::kDynamic, 20, ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1}, {2}, {3, 4}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 1},
+ {ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1, 2}}));
+ // A unit dim next to dynamic dims may join either group; the result follows
+ // the left-to-right search.
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, ShapedType::kDynamic, 1},
+ {ShapedType::kDynamic, ShapedType::kDynamic}),
+ makeOptionalIndices({{0}, {1, 2}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {1, ShapedType::kDynamic, ShapedType::kDynamic},
+ {ShapedType::kDynamic, ShapedType::kDynamic}),
+ makeOptionalIndices({{0, 1}, {2}}));
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 1, ShapedType::kDynamic},
+ {ShapedType::kDynamic, ShapedType::kDynamic}),
+ makeOptionalIndices({{0}, {1, 2}}));
+}
+
+TEST(ReassociationIndicesForCollapse, DynamicTestFailure) {
+ // The trailing 20 cannot be folded into the static target dim 10.
+ EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
+ {ShapedType::kDynamic, 10}),
+ std::nullopt);
+ // Ambiguous: a static dim between dynamic dims could join either group.
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 10, ShapedType::kDynamic},
+ {ShapedType::kDynamic, ShapedType::kDynamic}),
+ std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {20, ShapedType::kDynamic, 10, ShapedType::kDynamic},
+ {ShapedType::kDynamic, ShapedType::kDynamic}),
+ std::nullopt);
+ // No contiguous run of static source dims has a product of 20.
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 5, 3, 2, 2}, {ShapedType::kDynamic, 20}),
+ std::nullopt);
+ // Ambiguous grouping around adjacent dynamic dims.
+ EXPECT_EQ(
+ getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic},
+ {ShapedType::kDynamic, ShapedType::kDynamic}),
+ std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, ShapedType::kDynamic, 10, 1,
+ ShapedType::kDynamic},
+ {ShapedType::kDynamic, ShapedType::kDynamic}),
+ std::nullopt);
+ // Ambiguous: any of the three 10s could match the static target dim.
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic},
+ {ShapedType::kDynamic, 10, ShapedType::kDynamic}),
+ std::nullopt);
+ // No run of static source dims has a product of 2.
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic},
+ {ShapedType::kDynamic, 2, 2, ShapedType::kDynamic}),
+ std::nullopt);
+ // Ambiguous: both 3x4 and 4x3 (resp. 8x4 and 2x16) match the static target.
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 3, 4, 3, ShapedType::kDynamic},
+ {ShapedType::kDynamic, 12, ShapedType::kDynamic}),
+ std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 8, 4, 2, 16, ShapedType::kDynamic},
+ {ShapedType::kDynamic, 32, ShapedType::kDynamic}),
+ std::nullopt);
+
+ //===----------------------------------------------------------------------===//
+ // TODO: Reassociation for the following examples can be computed, but isn't
+ // supported by `getReassociationIndicesForCollapse`.
+ //===----------------------------------------------------------------------===//
+
+ // TODO: Fails because there's no backtracking when some source dimensions
+ // remain unmatched at either edge.
+ EXPECT_EQ(getReassociationIndicesForCollapse(
+ {ShapedType::kDynamic, 10, ShapedType::kDynamic, 10},
+ {ShapedType::kDynamic, 10}),
+ std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 2, 2},
+ {1, ShapedType::kDynamic, 2}),
+ std::nullopt);
+ EXPECT_EQ(getReassociationIndicesForCollapse({2, 2, ShapedType::kDynamic, 1},
+ {2, ShapedType::kDynamic}),
+ std::nullopt);
+}
More information about the llvm-branch-commits mailing list