[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #69665)
Chengji Yao
llvmlistbot at llvm.org
Sat Oct 21 14:11:48 PDT 2023
================
@@ -0,0 +1,571 @@
+//===- ShardingInterface.cpp -------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/Debug.h"
+
+#include <algorithm>
+#include <utility>
+
+#define DEBUG_TYPE "sharding-interface"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// common util functions
+//===----------------------------------------------------------------------===//
+
+// This method aims to retrieve the mesh sharding attribute (MeshShardingAttr)
+// for a given operation result.
+static FailureOr<MeshShardingAttr>
+getMeshShardingAttr(OpResult result, bool useOperandSharding) {
+ Value val = result.cast<Value>();
+ bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
+ auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+ if (!shardOp)
+ return false;
+ return !shardOp.getAnnotateForUsers();
+ });
+
+ if (anyShardedForDef) {
+ assert(val.hasOneUse() &&
+ "expected to has exact one use if it has a use of mesh.shard "
+ "without unit attr annotate_for_users");
----------------
yaochengji wrote:
Change the assert to `return failure()` if `!val.hasOneUse()`, i.e. fail gracefully instead of asserting when the value has more than one use.
https://github.com/llvm/llvm-project/pull/69665
More information about the Mlir-commits
mailing list