[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #69665)
Boian Petkantchin
llvmlistbot at llvm.org
Mon Oct 23 15:05:40 PDT 2023
================
@@ -0,0 +1,99 @@
+//===- ShardingInterfaces.td -------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+
+include "mlir/IR/OpBase.td"
+
+def ShardingInterface : OpInterface<"ShardingInterface"> {
+  let description = [{
+    Interface for allowing operations to expose information needed to
+    shard them.
+  }];
+  let cppNamespace = "::mlir::mesh";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns a list of iterator types, one per loop of the operation.
+        The iterator types determine how the operation traverses its input and
+        output tensors.
+
+        Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
+        types are parallel, parallel and reduction-sum. This indicates that M
+        and N are traversed in parallel, while the K dimension is used for
+        reduction.
+
+        Example 2: A softmax op's loop iterator types are parallel and
+        invalid. The second dimension is considered invalid because it is
+        neither parallel nor any kind of reduction.
+      }],
+      /*retTy=*/"SmallVector<::mlir::mesh::IteratorType>",
+      /*methodName=*/"getLoopIteratorTypes",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/"return {};"
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the indexing maps of the operation.
+        Indexing maps determine how indices in the iteration space map to
+        tensor indices. They are specified using `affine_map` in MLIR, which
+        provides an affine transformation of indices.
+      }],
+      /*retTy=*/"SmallVector<AffineMap>",
+      /*methodName=*/"getIndexingMaps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/"return {};"
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Given that some operands or results of the operation may carry
+        sharding annotations, this method uses that information to deduce
+        how the operation should be sharded.
+      }],
+      /*retTy=*/"FailureOr<ShardingOption>",
+      /*methodName=*/"getShardingOption",
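As an illustration of how the three methods fit together, here is a minimal, MLIR-independent C++17 sketch of the kind of deduction `getShardingOption` is described as doing, applied to the gemm from Example 1. Everything in it is hypothetical: `IteratorType`, `IndexingMap`, `DimSharding`, `LoopSharding`, `deduceLoopSharding` and `projectSharding` are not the PR's API. Indexing maps are assumed to be projected permutations, each tensor dimension is assumed to be split along at most one mesh axis, and the sketch does not check whether one mesh axis ends up splitting two different loops.

```cpp
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical stand-in for mesh::IteratorType (not the MLIR enum).
enum class IteratorType { Parallel, ReductionSum, Invalid };

// Assumption: every indexing map is a projected permutation, stored as
// "tensor dim d is indexed by loop map[d]". Real affine maps are more general.
using IndexingMap = std::vector<int>;

// Mesh axis each tensor dim is split along (std::nullopt = not split).
// Assumption: at most one mesh axis per dim.
using DimSharding = std::vector<std::optional<int>>;

struct LoopSharding {
  std::vector<std::optional<int>> meshAxisPerLoop; // one entry per loop
  bool needsAllReduce = false; // a reduction loop is split across devices
};

// Deduce a loop sharding from whichever operands/results are annotated.
// Fails (std::nullopt) if two annotations split the same loop along
// different mesh axes, or if an Invalid loop would have to be split.
// Not checked: the same mesh axis splitting two different loops.
std::optional<LoopSharding>
deduceLoopSharding(const std::vector<IteratorType> &loops,
                   const std::vector<IndexingMap> &maps,
                   const std::vector<std::optional<DimSharding>> &annotations) {
  LoopSharding result;
  result.meshAxisPerLoop.assign(loops.size(), std::nullopt);
  for (std::size_t t = 0; t < maps.size(); ++t) {
    if (!annotations[t])
      continue; // no sharding annotation on this operand/result
    for (std::size_t d = 0; d < maps[t].size(); ++d) {
      std::optional<int> axis = (*annotations[t])[d];
      if (!axis)
        continue;
      int loop = maps[t][d];
      std::optional<int> &slot = result.meshAxisPerLoop[loop];
      if (slot && *slot != *axis)
        return std::nullopt; // conflicting annotations
      if (loops[loop] == IteratorType::Invalid)
        return std::nullopt; // e.g. the softmax dim in Example 2
      slot = axis;
      if (loops[loop] == IteratorType::ReductionSum)
        result.needsAllReduce = true; // partial sums must be combined
    }
  }
  return result;
}

// Project the loop sharding back onto one tensor through its indexing map.
DimSharding projectSharding(const LoopSharding &s, const IndexingMap &map) {
  DimSharding out;
  for (int loop : map)
    out.push_back(s.meshAxisPerLoop[loop]);
  return out;
}
```

For gemm with loops (m, n, k) and maps A: (m, k), B: (k, n), C: (m, n), annotating only A's rows along mesh axis 0 yields C sharded on its rows with no reduction, whereas annotating A's columns splits the reduction loop k, so the partial results would need an all-reduce.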
----------------
sogartar wrote:
Isn't this method more related to the auto-sharding algorithm than to how an operation is sharded? It is very specific to how sharding propagation works.
Shouldn't there be a more general method that enumerates the possible ways the operation can be sharded? Sharding propagation could then choose among them based on the existing shardings of some of the operands/results.
`ShardingOption` seems to describe only one way of sharding an operation. Shouldn't it be a collection? Maybe not an explicit collection, since its size may be very large.
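The alternative suggested here can be sketched as a lazy enumeration: instead of returning a single `ShardingOption`, the op exposes a way to visit candidate shardings one at a time, and propagation stops at the first candidate consistent with the existing annotations. This is a hypothetical, MLIR-independent C++17 sketch: `Candidate`, `forEachShardingOption`, `Constraint` and `pickConsistent` are not part of the PR, a candidate simply maps each mesh axis to the loop it splits, and consistency is reduced to exact (mesh axis, loop) matches.

```cpp
#include <functional>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical: a candidate maps each mesh axis to the loop it splits,
// or to -1 when the op is replicated along that mesh axis.
using Candidate = std::vector<int>;

// Visit all (numLoops + 1)^numMeshAxes candidates lazily, in odometer
// order, without materializing them. The visitor returns true to stop early.
void forEachShardingOption(int numLoops, int numMeshAxes,
                           const std::function<bool(const Candidate &)> &visit) {
  Candidate c(numMeshAxes, -1);
  while (true) {
    if (visit(c))
      return;
    // Advance the odometer; every digit ranges over [-1, numLoops).
    int i = 0;
    while (i < numMeshAxes && ++c[i] == numLoops) {
      c[i] = -1;
      ++i;
    }
    if (i == numMeshAxes)
      return; // wrapped around: every candidate has been visited
  }
}

// An existing annotation, already mapped through the operand's indexing map:
// mesh axis `first` must split loop `second`.
using Constraint = std::pair<int, int>;

// Pick the first candidate that agrees with all existing annotations.
// A real pass would rank candidates (e.g. by communication cost) instead.
std::optional<Candidate>
pickConsistent(int numLoops, int numMeshAxes,
               const std::vector<Constraint> &required) {
  std::optional<Candidate> chosen;
  forEachShardingOption(numLoops, numMeshAxes, [&](const Candidate &c) {
    for (const auto &[axis, loop] : required)
      if (c[axis] != loop)
        return false; // inconsistent: keep enumerating
    chosen = c;
    return true; // consistent: stop here
  });
  return chosen;
}
```

With the 3 loops of a gemm and a 2-D mesh there are already (3 + 1)^2 = 16 candidates, and the count grows exponentially with the number of mesh axes, which is why a visitor fits the size concern better than an explicit collection.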
https://github.com/llvm/llvm-project/pull/69665