[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #69665)
Chengji Yao
llvmlistbot at llvm.org
Sat Oct 21 14:16:33 PDT 2023
================
@@ -0,0 +1,104 @@
+//===- ShardingInterfaces.td -------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+
+include "mlir/IR/OpBase.td"
+
+def ShardingInterface : OpInterface<"ShardingInterface"> {
+ let description = [{
+ Interface for allowing operations to expose information needed to
+ shard them.
+ }];
+ let cppNamespace = "::mlir::mesh";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns a list of iterator types that describe the number of loops.
+ The iterator types determine how the operation traverses its input
+ and output tensors.
+
+ Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
+ types are parallel, parallel, reduction-sum. This indicates that M and
+ N are traversed in parallel, while the K dimension is used for
+ reduction.
+
+ Example 2: A softmax op's loop iterator types are parallel and
+ invalid. The second dimension is considered as invalid because it is
+ neither parallel nor any kind of reduction.
+ }],
+ /*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
+ /*methodName=*/"getLoopIteratorTypes",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return {};"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the indexing maps attribute within the current operation.
+ Indexing maps determine how indices in the iteration space map to
+ tensor indices. They are specified using `affine_map` in MLIR, which
+ provides an affine transformation of indices.
+ }],
+ /*retTy=*/"SmallVector<AffineMap>",
+ /*methodName=*/"getIndexingMaps",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return {};"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Given that certain operands or results of the operation may have
+ sharding annotations, this method leverages this information to deduce
+ how the operation should be sharded.
+ }],
+ /*retTy=*/"FailureOr<ShardingOption>",
+ /*methodName=*/"getShardingOption",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return detail::defaultGetShardingOption(
+ $_op.getOperation());
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Based on a given ShardingOption, this method adds `mesh.shard`
+ operations for the operands and results that previously lacked
+ sharding annotations.
+ }],
+ /*retTy=*/"LogicalResult",
+ /*methodName=*/"addShardingAnnotations",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "const ShardingOption &":$shardingOption
+ ),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return detail::defaultAddShardingAnnotations(
+ $_op.getOperation(), b, shardingOption);
+ }]
+ >
+ ];
+
+ let extraClassDeclaration = [{
+ LogicalResult verifyShardingInterfaceImpl();
+
+ void printLoopTypesAndIndexingMaps(raw_ostream &os);
+
+ FailureOr<ShardingOption> getShardingOptionFromAttr();
+
+ void setShardingOptionAttr(Builder &b, const ShardingOption& option);
----------------
yaochengji wrote:
I have removed the logic that adds `ShardingOption` as an attribute for now, because on second thought this information might be better suited to an analysis.
BTW, the reason these methods were in `extraClassDeclaration` rather than declared as `InterfaceMethod`s is that they are not supposed to be customized for different operations.
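
To illustrate the distinction, here is a minimal sketch in plain C++, not the code that TableGen generates. All names in it (`IteratorType`, `ShardingInterfaceSketch`, `GemmSketch`, `SoftmaxSketch`, `describeLoopTypes`) are hypothetical stand-ins. The virtual method plays the role of an `InterfaceMethod` that each op may override, and the non-virtual helper plays the role of an `extraClassDeclaration` member that is the same for every op.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for the iterator kinds named in the diff.
enum class IteratorType { Parallel, ReductionSum, Invalid };

inline std::string toString(IteratorType t) {
  switch (t) {
  case IteratorType::Parallel:
    return "parallel";
  case IteratorType::ReductionSum:
    return "reduction-sum";
  case IteratorType::Invalid:
    break;
  }
  return "invalid";
}

struct ShardingInterfaceSketch {
  virtual ~ShardingInterfaceSketch() = default;

  // Analogue of an InterfaceMethod: ops may override it; the default
  // returns an empty list, like `return {};` in the diff.
  virtual std::vector<IteratorType> getLoopIteratorTypes() const {
    return {};
  }

  // Analogue of an extraClassDeclaration helper: non-virtual, so every
  // op shares this one implementation and cannot customize it.
  std::string describeLoopTypes() const {
    std::string out;
    for (IteratorType t : getLoopIteratorTypes()) {
      if (!out.empty())
        out += ", ";
      out += toString(t);
    }
    return out;
  }
};

// Example 1 from the diff: gemm loops M, N, K.
struct GemmSketch : ShardingInterfaceSketch {
  std::vector<IteratorType> getLoopIteratorTypes() const override {
    return {IteratorType::Parallel, IteratorType::Parallel,
            IteratorType::ReductionSum};
  }
};

// Example 2 from the diff: softmax, whose second loop is invalid.
struct SoftmaxSketch : ShardingInterfaceSketch {
  std::vector<IteratorType> getLoopIteratorTypes() const override {
    return {IteratorType::Parallel, IteratorType::Invalid};
  }
};
```

In this sketch, calling `describeLoopTypes()` on a `GemmSketch` uses the op's own iterator list, while the formatting logic stays in one place.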
https://github.com/llvm/llvm-project/pull/69665