[Mlir-commits] [mlir] [MLIR][mesh] Mesh fixes (PR #124724)
Frank Schlimbach
llvmlistbot at llvm.org
Tue Feb 11 03:48:39 PST 2025
================
@@ -0,0 +1,99 @@
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace mlir::mesh;
+
+namespace {
+
+// Sharding of arith.constant
+struct ConstantShardingInterface
+ : public ShardingInterface::ExternalModel<ConstantShardingInterface,
+ ConstantOp> {
+  // Returns one `parallel` iterator per dimension of the constant's result
+  // when it is a ranked tensor; any other result type yields no iterators.
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType());
+    // Keep the rank's own (int64_t) type instead of narrowing it into the
+    // `int` that `auto ndims = 0` would deduce.
+    const auto ndims = type ? type.getRank() : 0;
+    return SmallVector<utils::IteratorType>(ndims,
+                                            utils::IteratorType::parallel);
+  }
+
+  // Provides a single identity indexing map spanning every dimension of a
+  // ranked-tensor result; for any other result type no maps are returned.
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
+    if (!tensorType)
+      return {};
+    AffineMap identity = AffineMap::getMultiDimIdentityMap(
+        tensorType.getRank(), op->getContext());
+    return {identity};
+  }
+
+ // Indicate failure if no result sharding exists.
+ // Otherwise mirror result sharding if it is a tensor constant.
+ // Otherwise return replication option.
+ FailureOr<ShardingOption>
+ getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings) const {
+ if (!resultShardings[0]) {
+ return failure();
+ }
+ if (auto type = dyn_cast<RankedTensorType>(op->getResult(0).getType())) {
+ ShardingArray axesArray(resultShardings[0].getSplitAxes().size());
----------------
fschlimb wrote:
done
https://github.com/llvm/llvm-project/pull/124724
More information about the Mlir-commits
mailing list