[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #69665)
Boian Petkantchin
llvmlistbot at llvm.org
Mon Oct 23 15:05:40 PDT 2023
================
@@ -0,0 +1,529 @@
+//===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/Debug.h"
+
+#include <algorithm>
+#include <utility>
+
+#define DEBUG_TYPE "sharding-interface"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// common util functions
+//===----------------------------------------------------------------------===//
+
+// Looks up the mesh sharding attribute (MeshShardingAttr) associated with an
+// operation result by inspecting the result's `mesh.shard` users.
+//
+// A `mesh.shard` user without the `annotate_for_users` unit attribute
+// annotates the definition itself. In that case the result must have exactly
+// one use (that shard op) and its sharding is returned; otherwise failure.
+// When no such op exists and `useOperandSharding` is true, the sharding of the
+// `mesh.shard` users carrying `annotate_for_users` is returned instead (they
+// are asserted to agree). Any other situation yields failure.
+static FailureOr<MeshShardingAttr>
+getMeshShardingAttr(OpResult result, bool useOperandSharding) {
+  Value val = result;
+
+  // Split the `mesh.shard` users by whether they annotate the definition or
+  // the users of the value.
+  bool hasShardForDef = false;
+  SmallVector<ShardOp> shardOpsForUsers;
+  for (Operation *user : val.getUsers()) {
+    auto shardOp = llvm::dyn_cast<ShardOp>(user);
+    if (!shardOp)
+      continue;
+    if (shardOp.getAnnotateForUsers())
+      shardOpsForUsers.push_back(shardOp);
+    else
+      hasShardForDef = true;
+  }
+
+  if (hasShardForDef) {
+    // A definition-side annotation is only accepted as the value's sole use.
+    if (!val.hasOneUse())
+      return failure();
+    return llvm::cast<ShardOp>(*val.getUsers().begin()).getShard();
+  }
+
+  if (!useOperandSharding || shardOpsForUsers.empty())
+    return failure();
+
+  MeshShardingAttr shardForDef = shardOpsForUsers.front().getShard();
+  // TODO: Deduce a reasonable mesh sharding attr for def when they are
+  // different
+  for (ShardOp other : llvm::drop_begin(shardOpsForUsers)) {
+    assert(other.getShard() == shardForDef &&
+           "only support all shard ops have the same mesh sharding attr");
+    (void)other;
+  }
+  return shardForDef;
+}
+
+// Looks up the sharding annotation feeding an operation operand. Succeeds only
+// when the operand value is produced by a `mesh.shard` op, returning that op's
+// `annotate_for_users` flag paired with its sharding attribute.
+static FailureOr<std::pair<bool, MeshShardingAttr>>
+getMeshShardingAttr(OpOperand &opOperand) {
+  auto producer = opOperand.get().getDefiningOp<ShardOp>();
+  if (!producer)
+    return failure();
+  return std::make_pair(producer.getAnnotateForUsers(), producer.getShard());
+}
+
+// Walks an operand indexing expression and marks in `seenIds` every dimension
+// it references. Accepted shapes are a bare `dN`, `dN * c` or `c * dN` (c a
+// constant), and sums (Add) of accepted shapes. Each dimension may appear at
+// most once and its position must be below `seenIds.size()`; any other
+// expression kind yields failure.
+static LogicalResult
+checkOperandAffineExprRecursively(AffineExpr expr,
+                                  SmallVectorImpl<bool> &seenIds) {
+  // Marks the dimension held by `dimExpr` as used, rejecting out-of-range or
+  // repeated positions.
+  auto markDimSeen = [&seenIds](AffineExpr dimExpr) -> LogicalResult {
+    unsigned position = dimExpr.cast<AffineDimExpr>().getPosition();
+    if (static_cast<size_t>(position) >= seenIds.size() || seenIds[position])
+      return failure();
+    seenIds[position] = true;
+    return success();
+  };
+
+  switch (expr.getKind()) {
+  case AffineExprKind::DimId:
+    return markDimSeen(expr);
+  case AffineExprKind::Add: {
+    auto sum = expr.cast<AffineBinaryOpExpr>();
+    // The RHS is only visited once the LHS has been accepted.
+    if (failed(checkOperandAffineExprRecursively(sum.getLHS(), seenIds)))
+      return failure();
+    return checkOperandAffineExprRecursively(sum.getRHS(), seenIds);
+  }
+  case AffineExprKind::Mul: {
+    auto product = expr.cast<AffineBinaryOpExpr>();
+    AffineExpr lhs = product.getLHS();
+    AffineExpr rhs = product.getRHS();
+    AffineExprKind lhsKind = lhs.getKind();
+    AffineExprKind rhsKind = rhs.getKind();
+    if (lhsKind == AffineExprKind::DimId && rhsKind == AffineExprKind::Constant)
+      return markDimSeen(lhs);
+    if (lhsKind == AffineExprKind::Constant && rhsKind == AffineExprKind::DimId)
+      return markDimSeen(rhs);
+    return failure();
+  }
+  default:
+    return failure();
+  }
+}
+
+// Validates `expr` via checkOperandAffineExprRecursively against `numDims`
+// dimensions and, on success, returns the set of dimension positions that the
+// expression references.
+static FailureOr<llvm::SmallSet<unsigned, 2>>
+checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
+  SmallVector<bool> seenIds(numDims, false);
+  if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
+    return failure();
+
+  llvm::SmallSet<unsigned, 2> usedDims;
+  for (unsigned dim = 0; dim < numDims; ++dim) {
+    if (seenIds[dim])
+      usedDims.insert(dim);
+  }
+  return usedDims;
+}
+
+//===----------------------------------------------------------------------===//
+// ShardingInterface::verifyShardingInterfaceImpl
+//===----------------------------------------------------------------------===//
+
+LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
----------------
sogartar wrote:
Is the intent to use this function outside of `mesh::detail::defaultGetShardingOption`?
https://github.com/llvm/llvm-project/pull/69665
More information about the Mlir-commits
mailing list