[Mlir-commits] [mlir] [MLIR][mesh] Mesh fixes (PR #124724)

Frank Schlimbach llvmlistbot at llvm.org
Tue Feb 11 03:00:24 PST 2025


================
@@ -0,0 +1,99 @@
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace mlir::mesh;
+
+namespace {
+
+// Sharding of arith.constant
+struct ConstantShardingInterface
+    : public ShardingInterface::ExternalModel<ConstantShardingInterface,
+                                              ConstantOp> {
+  /// Returns one `parallel` iterator type per dimension of the op's first
+  /// result when that result is a ranked tensor; otherwise (e.g. a scalar
+  /// constant) returns an empty list.
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    // int64_t matches RankedTensorType::getRank(); the previous
+    // `auto ndims = 0` deduced `int` and silently narrowed the rank.
+    int64_t ndims = 0;
+    if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType()))
+      ndims = type.getRank();
+    return SmallVector<utils::IteratorType>(ndims,
+                                            utils::IteratorType::parallel);
+  }
+
+  /// Returns a single identity indexing map over all dimensions of the op's
+  /// first result when it is a ranked tensor, and an empty list otherwise.
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
+    if (!tensorType)
+      return {};
+    AffineMap identity = AffineMap::getMultiDimIdentityMap(
+        tensorType.getRank(), op->getContext());
+    return SmallVector<AffineMap>(1, identity);
+  }
+
+  // Indicate failure if no result sharding exists.
+  // Otherwise mirror result sharding if it is a tensor constant.
+  // Otherwise return replication option.
+  FailureOr<ShardingOption>
+  getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
+                    ArrayRef<MeshSharding> resultShardings) const {
+    if (!resultShardings[0]) {
+      return failure();
+    }
+    if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+      ShardingArray axesArray(resultShardings[0].getSplitAxes().size());
----------------
fschlimb wrote:

Yes, an `arith.constant` has a single result. In general, `resultShardings` provides one `MeshSharding` per result of the operation, which is why the interface specifies an `ArrayRef` for it.
Extra shardings are simply ignored. We could add a check, though, since in most cases extra shardings would probably indicate an error of some kind.

https://github.com/llvm/llvm-project/pull/124724


More information about the Mlir-commits mailing list