[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #69665)
Mehdi Amini
llvmlistbot at llvm.org
Thu Oct 19 22:25:19 PDT 2023
================
@@ -0,0 +1,58 @@
+//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+class Operation;
+
+namespace mesh {
+
+// Nested array of 32-bit integers. ShardingOption stores one inner array per
+// loop, listing the mesh axes that loop is sharded on.
+using ShardingArray = SmallVector<SmallVector<int32_t>>;
+// ArrayRef counterpart of ShardingArray.
+using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
+
+struct ShardingOption {
+ // An array of integer arrays. The sub-array at the i-th position lists the
+ // mesh axes on which the i-th loop will be sharded.
+ ShardingArray shardingArray;
+ SymbolRefAttr cluster;
+ // When `empty` is true, no sharding information can be inferred at present.
+ // Note that this is different from the operation being not sharded.
+ bool empty = false;
+ ShardingOption() = default;
+ ShardingOption(const ShardingArray &shardingArray, SymbolRefAttr cluster)
+ : shardingArray(shardingArray), cluster(cluster) {}
+};
+
+// Returns the string "sharding_array".
+// NOTE(review): presumably the name under which a ShardingArray is attached
+// to an operation -- confirm against the users of this helper.
+constexpr StringRef getShardingArrayName() { return "sharding_array"; }
+
+// Returns the string "mesh_cluster".
+// NOTE(review): presumably the name under which the mesh cluster symbol
+// reference is attached to an operation -- confirm against the users.
+constexpr StringRef getMeshClusterName() { return "mesh_cluster"; }
----------------
joker-eph wrote:
I have some concerns about exposing string-based APIs; we should shield this behind the interface.
https://github.com/llvm/llvm-project/pull/69665
More information about the Mlir-commits mailing list