[Mlir-commits] [mlir] [mlir][mesh] Handling changed halo region sizes during spmdization (PR #114238)
Matteo Franciolini
llvmlistbot at llvm.org
Thu Oct 31 14:29:07 PDT 2024
================
@@ -192,33 +192,33 @@ template <typename InShape, typename MeshShape, typename SplitAxes,
typename OutShape>
static void shardShape(const InShape &inShape, const MeshShape &meshShape,
const SplitAxes &splitAxes, OutShape &outShape,
- ArrayRef<int64_t> shardedDimsSizes = {},
+ ArrayRef<int64_t> shardedDimsOffsets = {},
ArrayRef<int64_t> haloSizes = {}) {
std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
llvm::adl_begin(outShape));
- if (!shardedDimsSizes.empty()) {
+ if (!shardedDimsOffsets.empty()) {
+ uint64_t pos = 0;
for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
- if (innerSplitAxes.empty()) {
-#ifndef NDEBUG
- for (auto dimSz : shardedDimsSizes) {
- auto inAxis = dimSz % inShape.size();
- assert(inShape[inAxis] == dimSz || dimSz == ShapedType::kDynamic ||
- inShape[inAxis] == ShapedType::kDynamic);
- }
-#endif // NDEBUG
- } else {
- // find sharded dims in sharded_dims_sizes with same static size on
- // all devices. Use kDynamic for dimensions with dynamic or non-uniform
- // sizes in sharded_dims_sizes.
- auto sz = shardedDimsSizes[tensorAxis];
- bool same = true;
- for (size_t i = tensorAxis + inShape.size();
- i < shardedDimsSizes.size(); i += inShape.size()) {
- if (shardedDimsSizes[i] != sz) {
- same = false;
- break;
+ if (!innerSplitAxes.empty()) {
+ auto sz = shardedDimsOffsets[pos];
+ bool same = !ShapedType::isDynamicShape(meshShape);
+ if (same) {
+ // find sharded dims in shardedDimsOffsets with same static size on
----------------
mfrancio wrote:
```suggestion
// Find sharded dims in shardedDimsOffsets with same static size on
```
https://github.com/llvm/llvm-project/pull/114238
More information about the Mlir-commits mailing list