[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #69665)
Boian Petkantchin
llvmlistbot at llvm.org
Mon Oct 23 15:05:40 PDT 2023
================
@@ -0,0 +1,87 @@
+//===- ShardingInterfaces.td -------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+
+include "mlir/IR/OpBase.td"
+
+def ShardingInterface : OpInterface<"ShardingInterface"> {
+ let description = [{
+ Interface for allowing operations to expose information needed to
+ shard them.
+ }];
+ let cppNamespace = "::mlir::mesh";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns a list of iterator types that describe the number of loops.
+ }],
+ /*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
+ /*methodName=*/"getLoopIteratorTypes",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return {};"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the indexing maps attribute within the current operation.
+ }],
+ /*retTy=*/"SmallVector<AffineMap>",
+ /*methodName=*/"getIndexingMaps",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return {};"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Given that certain operands or results of the operation may have
+ sharding annotations, this method leverages this information to deduce
+ how the operation should be sharded.
+ }],
+ /*retTy=*/"FailureOr<ShardingOption>",
+ /*methodName=*/"getShardingOption",
+ /*args=*/(ins
+ "OpBuilder &":$b
+ ),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return detail::defaultGetShardingOption(
+ $_op.getOperation(), b);
+ }]
+ >,
+ InterfaceMethod<
----------------
sogartar wrote:
If you have more complicated indexing maps, you would probably need a custom implementation.
One example is [linalg.conv_2d_nhwc_hwcf](https://github.com/llvm/llvm-project/blob/b997ff41c11cc69cfcb6c8a3ed39ed47229cf891/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml#L2295).
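To make the point concrete, here is a minimal, self-contained C++ sketch that does not use MLIR at all. It models each indexing-map result as the set of loop dimensions it sums over, dropping coefficients such as strides and dilations. It also assumes, hypothetically, that a generic default like `detail::defaultGetShardingOption` can only map a sharded tensor dimension back to a loop when that dimension is indexed by exactly one loop. `ToyAffineExpr`, `ToyAffineMap`, `loopDimForTensorDim`, `isSimplyShardable` and the `kConv*` maps are illustrative names, not part of this PR or the MLIR API. The loop order (n, oh, ow, f, kh, kw, c) is our reading of the linked conv definition.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical, simplified model of one indexing-map result expression:
// the loop dimensions it sums over. Coefficients such as the strides and
// dilations of a convolution are deliberately dropped.
struct ToyAffineExpr {
  std::vector<int> dims;
};

// Hypothetical model of an indexing map: one expression per tensor dimension.
using ToyAffineMap = std::vector<ToyAffineExpr>;

// Returns the loop dimension that indexes tensor dimension `tensorDim`, or
// std::nullopt when the expression combines several loops (e.g. oh + kh).
// In that case a sharded tensor dimension has no single loop to map back to.
std::optional<int> loopDimForTensorDim(const ToyAffineMap &map,
                                       std::size_t tensorDim) {
  assert(tensorDim < map.size() && "tensor dimension out of range");
  const ToyAffineExpr &expr = map[tensorDim];
  if (expr.dims.size() != 1)
    return std::nullopt;
  return expr.dims.front();
}

// Assumed rule for a generic default: sharding can be propagated through the
// map only if every tensor dimension is indexed by exactly one loop.
bool isSimplyShardable(const ToyAffineMap &map) {
  for (std::size_t i = 0; i < map.size(); ++i)
    if (!loopDimForTensorDim(map, i))
      return false;
  return true;
}

// Toy versions of the conv_2d_nhwc_hwcf maps, with loops ordered
// (d0..d6) = (n, oh, ow, f, kh, kw, c) and strides/dilations dropped.
const ToyAffineMap kConvInput = {{{0}}, {{1, 4}}, {{2, 5}}, {{6}}};  // (n, oh + kh, ow + kw, c)
const ToyAffineMap kConvFilter = {{{4}}, {{5}}, {{6}}, {{3}}};       // (kh, kw, c, f)
const ToyAffineMap kConvOutput = {{{0}}, {{1}}, {{2}}, {{3}}};       // (n, oh, ow, f)
```

The filter and output maps pass the check. Input dimensions 1 and 2 each mix an output-spatial loop with a kernel loop, so sharding the input along H or W cannot be expressed as sharding a single loop. Under the assumption above, an op like this would provide its own `getShardingOption` instead of relying on the default.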
https://github.com/llvm/llvm-project/pull/69665