[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #69665)
Chengji Yao
llvmlistbot at llvm.org
Thu Oct 19 17:41:37 PDT 2023
https://github.com/yaochengji created https://github.com/llvm/llvm-project/pull/69665
- add `ShardingInterface` and default implementations of its methods
- implement `ShardingInterface` for element-wise and matmul ops in the TOSA dialect
- add a sharding propagation pass
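As a small illustration (taken from the sharding-propagation.mlir test added
below), given a single `mesh.shard` annotation on the result of an element-wise
op:

    mesh.cluster @mesh_2d(rank = 2, dim_sizes = [2, 4])

    func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
      %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
      %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
      return %1 : tensor<8x16xf32>
    }

the pass also annotates the operand, inserting
`mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> annotate_for_users` in front of
the `tosa.sigmoid`.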
From 25d68a6caf9c32747ab011b6d955d914573c2f55 Mon Sep 17 00:00:00 2001
From: Chengji Yao <yaochengji at hotmail.com>
Date: Wed, 18 Oct 2023 18:54:49 +0000
Subject: [PATCH] [MLIR][Mesh] Add sharding propagation pass
---
mlir/include/mlir/Dialect/Mesh/CMakeLists.txt | 2 +
mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td | 34 ++
mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h | 22 +
.../Dialect/Mesh/Interfaces/CMakeLists.txt | 4 +
.../Mesh/Interfaces/ShardingInterface.h | 58 ++
.../Mesh/Interfaces/ShardingInterface.td | 87 +++
.../Dialect/Mesh/Transforms/CMakeLists.txt | 6 +
.../mlir/Dialect/Mesh/Transforms/Passes.h | 41 ++
.../mlir/Dialect/Mesh/Transforms/Passes.td | 33 ++
.../Dialect/Tosa/IR/ShardingInterfaceImpl.h | 23 +
.../mlir/Dialect/Utils/IndexingUtils.h | 3 +
mlir/include/mlir/IR/AffineMap.h | 12 +
mlir/include/mlir/IR/Builders.h | 1 +
mlir/include/mlir/InitAllDialects.h | 2 +
mlir/include/mlir/InitAllPasses.h | 2 +
mlir/lib/Dialect/Mesh/CMakeLists.txt | 2 +
mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 32 +-
.../Dialect/Mesh/Interfaces/CMakeLists.txt | 15 +
.../Mesh/Interfaces/ShardingInterface.cpp | 544 ++++++++++++++++++
.../Dialect/Mesh/Transforms/CMakeLists.txt | 17 +
.../Mesh/Transforms/ShardingPropagation.cpp | 155 +++++
mlir/lib/Dialect/Tosa/CMakeLists.txt | 14 +
.../Dialect/Tosa/IR/ShardingInterfaceImpl.cpp | 111 ++++
mlir/lib/Dialect/Utils/IndexingUtils.cpp | 13 +
mlir/lib/IR/AffineMap.cpp | 13 +
mlir/lib/IR/Builders.cpp | 9 +
.../Dialect/Mesh/sharding-propagation.mlir | 167 ++++++
27 files changed, 1421 insertions(+), 1 deletion(-)
create mode 100644 mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
create mode 100644 mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
create mode 100644 mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
create mode 100644 mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
create mode 100644 mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
create mode 100644 mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
create mode 100644 mlir/test/Dialect/Mesh/sharding-propagation.mlir
diff --git a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 39d24595ec1c446..b6623ed818f0770 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -49,6 +49,22 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
let cppNamespace = "::mlir::mesh";
}
+// Mesh_IteratorType and Mesh_Partial annotate different aspects of distributed
+// tensors: Mesh_IteratorType describes an operation's loops, while Mesh_Partial
+// describes whether a tensor is partial, i.e. which pending reduction combines
+// its values across mesh axes.
+def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
+ I32EnumAttrCase<"Parallel", 1, "parallel">,
+ I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
+ I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
+ I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
+ I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
+ I32EnumAttrCase<"Invalid", 100, "invalid">
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::mesh";
+}
+
//===----------------------------------------------------------------------===//
// Mesh Attribute
//===----------------------------------------------------------------------===//
@@ -122,6 +138,24 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
$partial_axes^ `]`)? `>`
}];
+ let builders = [
+ AttrBuilder<(ins "SymbolRefAttr":$cluster,
+ "ArrayRef<SmallVector<int32_t>>":$split_axes,
+ "ArrayRef<int32_t>": $partial_axes,
+ "mesh::Partial": $partial_type), [{
+ SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::to_vector(
+ llvm::map_range(split_axes, [&](ArrayRef<int32_t> array) {
+ return DenseI32ArrayAttr::get($_ctxt, array);
+ }));
+ return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
+ partial_type);
+ }]>,
+ AttrBuilder<(ins "SymbolRefAttr":$cluster,
+ "ArrayRef<SmallVector<int32_t>>":$split_axes), [{
+ return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
+ }]>
+ ];
+
let genVerifyDecl = 1;
}
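As a C++ usage sketch of the new builders (the surrounding OpBuilder `builder`
and the mesh cluster symbol are assumed for illustration):

    // Shard tensor dim 0 along mesh axis 0 and dim 1 along mesh axis 1 of
    // @mesh_2d; the two-argument builder defaults to empty partial axes.
    SymbolRefAttr cluster = SymbolRefAttr::get(builder.getContext(), "mesh_2d");
    SmallVector<SmallVector<int32_t>> splitAxes = {{0}, {1}};
    auto sharding =
        MeshShardingAttr::get(builder.getContext(), cluster, splitAxes);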
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9dfeca84d012165..cb86887091330c8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -24,4 +24,26 @@
#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
+namespace mlir {
+namespace mesh {
+
+bool isReductionLoop(IteratorType iType);
+
+bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
+
+template <typename T>
+void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
+ for (int64_t i = array.size() - 1; i >= 0; i--) {
+ if (array[i].empty())
+ array.pop_back();
+ else
+ break;
+ }
+}
+
+Partial getPartialTypeFromReduction(IteratorType iType);
+
+} // namespace mesh
+} // namespace mlir
+
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
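A short sketch of the helper's behavior:

    // Only trailing empty sub-arrays are removed; interior ones are kept:
    // {{0}, {}, {1}, {}, {}}  ->  {{0}, {}, {1}}
    SmallVector<SmallVector<int32_t>> axes = {{0}, {}, {1}, {}, {}};
    mesh::removeTrailingEmptySubArray(axes);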
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
new file mode 100644
index 000000000000000..b3a44f3b0089abc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS ShardingInterface.td)
+mlir_tablegen(ShardingInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(ShardingInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRShardingInterfaceIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
new file mode 100644
index 000000000000000..1d19e41ac1fc555
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -0,0 +1,58 @@
+//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+class Operation;
+
+namespace mesh {
+
+using ShardingArray = SmallVector<SmallVector<int32_t>>;
+using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
+
+struct ShardingOption {
+  // An array of integer arrays: the sub-array at the i-th position lists the
+  // mesh axes the i-th loop will be sharded along.
+  ShardingArray shardingArray;
+  SymbolRefAttr cluster;
+  // `empty` being true indicates that no sharding information can be inferred
+  // at present; note that this is different from the operation being unsharded.
+  bool empty = false;
+ ShardingOption() = default;
+ ShardingOption(const ShardingArray &shardingArray, SymbolRefAttr cluster)
+ : shardingArray(shardingArray), cluster(cluster) {}
+};
+
+constexpr StringRef getShardingArrayName() { return "sharding_array"; }
+
+constexpr StringRef getMeshClusterName() { return "mesh_cluster"; }
+
+namespace detail {
+
+FailureOr<ShardingOption> defaultGetShardingOption(Operation *op, OpBuilder &b);
+
+LogicalResult
+defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
+ const ShardingOption &shardingOption);
+
+} // namespace detail
+
+} // namespace mesh
+
+} // namespace mlir
+
+/// Include the ODS generated interface header files.
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
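To make `shardingArray` concrete, a sketch (the loop order matches the
tosa.matmul interface implementation added below; `cluster` is assumed):

    // tosa.matmul has loops [batch, i, j, k(reduction_sum)]; sharding the
    // batch loop on mesh axis 0 and the i loop on mesh axis 1 yields:
    ShardingOption option(/*shardingArray=*/{{0}, {1}}, cluster);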
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
new file mode 100644
index 000000000000000..c98b9f081492997
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -0,0 +1,87 @@
+//===- ShardingInterface.td --------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+
+include "mlir/IR/OpBase.td"
+
+def ShardingInterface : OpInterface<"ShardingInterface"> {
+ let description = [{
+ Interface for allowing operations to expose information needed to
+ shard them.
+ }];
+ let cppNamespace = "::mlir::mesh";
+
+ let methods = [
+ InterfaceMethod<
+      /*desc=*/[{
+        Returns a list of iterator types, one per loop of the operation.
+      }],
+      /*retTy=*/"SmallVector<::mlir::mesh::IteratorType>",
+ /*methodName=*/"getLoopIteratorTypes",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return {};"
+ >,
+ InterfaceMethod<
+      /*desc=*/[{
+        Returns the indexing maps of the current operation, one per operand
+        and result.
+      }],
+ /*retTy=*/"SmallVector<AffineMap>",
+ /*methodName=*/"getIndexingMaps",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return {};"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Given that certain operands or results of the operation may have
+ sharding annotations, this method leverages this information to deduce
+ how the operation should be sharded.
+ }],
+ /*retTy=*/"FailureOr<ShardingOption>",
+ /*methodName=*/"getShardingOption",
+ /*args=*/(ins
+ "OpBuilder &":$b
+ ),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return detail::defaultGetShardingOption(
+ $_op.getOperation(), b);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Based on a given ShardingOption, this method adds `mesh.shard`
+ operations for the operands and results that previously lacked
+ sharding annotations.
+ }],
+ /*retTy=*/"LogicalResult",
+ /*methodName=*/"addShardingAnnotations",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "const ShardingOption &":$shardingOption
+ ),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return detail::defaultAddShardingAnnotations(
+ $_op.getOperation(), b, shardingOption);
+ }]
+ >
+ ];
+
+ let extraClassDeclaration = [{
+ LogicalResult verifyShardingInterfaceImpl();
+
+ void printLoopTypesAndIndexingMaps(raw_ostream &os);
+ }];
+}
+
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
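A sketch of how a transform is expected to drive this interface (this mirrors
the propagation pass added below; `op` and `builder` are assumed):

    if (auto shardingOp = llvm::dyn_cast<mesh::ShardingInterface>(op)) {
      FailureOr<mesh::ShardingOption> option =
          shardingOp.getShardingOption(builder);
      // An empty option means nothing could be inferred yet; otherwise add
      // mesh.shard ops for the still-unannotated operands and results.
      if (succeeded(option) && !option->empty)
        (void)shardingOp.addShardingAnnotations(builder, *option);
    }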
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..8d768485103b65f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh)
+add_public_tablegen_target(MLIRMeshPassIncGen)
+add_dependencies(mlir-headers MLIRMeshPassIncGen)
+
+add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
new file mode 100644
index 000000000000000..aa3555f7f186f24
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -0,0 +1,41 @@
+//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace func {
+class FuncOp;
+}
+
+namespace mesh {
+
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+std::unique_ptr<OperationPass<func::FuncOp>> createShardingPropagationPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
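A sketch of invoking the pass from C++ (a module-level pass manager and a
ModuleOp `module` are assumed; the pass itself is anchored on func.func):

    PassManager pm(module->getContext());
    pm.addNestedPass<func::FuncOp>(mesh::createShardingPropagationPass());
    if (failed(pm.run(module)))
      llvm::errs() << "sharding propagation failed\n";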
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
new file mode 100644
index 000000000000000..d36adfe476a72ac
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -0,0 +1,33 @@
+//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+//===----------------------------------------------------------------------===//
+// ShardingPropagation
+//===----------------------------------------------------------------------===//
+
+def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
+ let summary = "sharding propagation";
+ let description = [{
+ Propagates sharding information throughout the graph. After this pass, each
+ of the operations' operands and results is annotated with a `mesh.shard`
+    operation, and the operations themselves carry sharding option
+    attributes.
+ }];
+ let constructor = "mlir::mesh::createShardingPropagationPass()";
+ let dependentDialects = [
+ "mesh::MeshDialect"
+ ];
+}
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
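From `mlir-opt` the pass is exposed as `-sharding-propagation`, as exercised by
the RUN line of the new sharding-propagation.mlir test below.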
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
new file mode 100644
index 000000000000000..16427919dace5da
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace tosa {
+
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index f51a8b28b7548ed..b24164cfb552b4f 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -245,6 +245,9 @@ computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
unsigned dropBack = 0);
+/// Helper to return a vector of int32_t sub-vectors from an array of arrays.
+SmallVector<SmallVector<int32_t>> getArrayOfI32Array(ArrayAttr arrayAttr);
+
/// Compute linear index from provided strides and indices, assuming strided
/// layout.
/// Returns AffineExpr and list of values to apply to it, e.g.:
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 3430db2b99c3f2e..18e2313ef2b446b 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -101,6 +101,18 @@ class AffineMap {
static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
MLIRContext *context);
+ /// Returns an affine map with `numDims` input dimensions and results
+ /// specified by `targets`.
+ ///
+ /// Examples:
+ /// * getMultiDimMapWithTargets(3, [0, 2, 1])
+ /// -> affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+ /// * getMultiDimMapWithTargets(3, [2, 1])
+ /// -> affine_map<(d0, d1, d2) -> (d2, d1)>
+ static AffineMap getMultiDimMapWithTargets(unsigned numDims,
+ ArrayRef<int64_t> targets,
+ MLIRContext *context);
+
/// Returns a vector of AffineMaps; each with as many results as
/// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
/// symbols as the largest symbol in `exprs`.
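A usage sketch matching the documented examples (an MLIRContext *ctx is
assumed):

    // Produces affine_map<(d0, d1, d2) -> (d2, d1)>.
    AffineMap map = AffineMap::getMultiDimMapWithTargets(3, {2, 1}, ctx);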
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 5e54d4ea49e8251..3988835622b7629 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -168,6 +168,7 @@ class Builder {
ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
ArrayAttr getTypeArrayAttr(TypeRange values);
+ ArrayAttr getArrayOfI32ArrayAttr(ArrayRef<SmallVector<int32_t>> values);
// Affine expressions and affine maps.
AffineExpr getAffineDimExpr(unsigned position);
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 00f400aab5d50a0..3556f82023828b2 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -79,6 +79,7 @@
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
@@ -170,6 +171,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
tensor::registerSubsetInsertionOpInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
+ tosa::registerShardingInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 5489a13a8040bdb..27711417ed91a8c 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -29,6 +29,7 @@
#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -73,6 +74,7 @@ inline void registerAllPasses() {
LLVM::registerLLVMPasses();
math::registerMathPasses();
memref::registerMemRefPasses();
+ mesh::registerMeshPasses();
ml_program::registerMLProgramPasses();
registerSCFPasses();
registerShapePasses();
diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index fc91fd994f12dc2..0521147ba2fdff9 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -41,6 +41,37 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
return arith::ConstantOp::materialize(builder, value, type, loc);
}
+//===----------------------------------------------------------------------===//
+// Mesh utilities
+//===----------------------------------------------------------------------===//
+
+bool mesh::isReductionLoop(IteratorType iType) {
+ return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
+}
+
+bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
+ return (partial == Partial::Generic &&
+ iType == IteratorType::ReductionGeneric) ||
+ (partial == Partial::Sum && iType == IteratorType::ReductionSum) ||
+ (partial == Partial::Max && iType == IteratorType::ReductionMax) ||
+ (partial == Partial::Min && iType == IteratorType::ReductionMin);
+}
+
+Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
+ switch (iType) {
+ case IteratorType::ReductionGeneric:
+ return Partial::Generic;
+ case IteratorType::ReductionSum:
+ return Partial::Sum;
+ case IteratorType::ReductionMax:
+ return Partial::Max;
+ case IteratorType::ReductionMin:
+ return Partial::Min;
+ default:
+    llvm_unreachable("no corresponding partial type can be found");
+ }
+}
+
//===----------------------------------------------------------------------===//
// mesh.cluster op
//===----------------------------------------------------------------------===//
@@ -95,7 +126,6 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
}
if (failed(checkMeshAxis(partialAxes)))
return failure();
-
return success();
}
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
new file mode 100644
index 000000000000000..1010756f1fe279a
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_library(MLIRShardingInterface
+ ShardingInterface.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+
+ DEPENDS
+ MLIRShardingInterfaceIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRDialectUtils
+ MLIRIR
+ MLIRMeshDialect
+ MLIRSupport
+)
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
new file mode 100644
index 000000000000000..7d5c73851bb1f4e
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -0,0 +1,544 @@
+//===- ShardingInterface.cpp ------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/Debug.h"
+
+#include <algorithm>
+#include <utility>
+
+#define DEBUG_TYPE "sharding-interface"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// Common utility functions
+//===----------------------------------------------------------------------===//
+
+static FailureOr<ShardingOption> getShardingOptionFromAttr(Operation *op) {
+ auto arrayAttr = op->getAttrOfType<ArrayAttr>(getShardingArrayName());
+ if (!arrayAttr)
+ return failure();
+ auto symbolRefAttr = op->getAttrOfType<SymbolRefAttr>(getMeshClusterName());
+ if (!symbolRefAttr)
+ return failure();
+ return ShardingOption(getArrayOfI32Array(arrayAttr), symbolRefAttr);
+}
+
+// Retrieves the mesh sharding attribute (MeshShardingAttr) for the given
+// operation result.
+static FailureOr<MeshShardingAttr>
+getMeshShardingAttr(OpResult result, bool useOperandSharding) {
+ Value val = result.cast<Value>();
+ bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) {
+ auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+ if (!shardOp)
+ return false;
+ return !shardOp.getAnnotateForUsers();
+ });
+
+ if (anyShardedForDef) {
+ assert(val.hasOneUse() &&
+ "expected to has exact one use if it has a use of mesh.shard "
+ "without unit attr annotate_for_users");
+ auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin());
+ return shardOp.getShard();
+ } else if (useOperandSharding) {
+ bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) {
+ auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
+ if (!shardOp)
+ return false;
+ return shardOp.getAnnotateForUsers();
+ });
+ if (anyShardedForUsers) {
+ SmallVector<ShardOp> shardOps;
+ for (Operation *user : val.getUsers()) {
+ ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
+ if (shardOp)
+ shardOps.push_back(shardOp);
+ }
+ MeshShardingAttr shardForDef = shardOps[0].getShard();
+ for (size_t i = 1; i < shardOps.size(); ++i) {
+ // TODO: Deduce a reasonable mesh sharding attr for def when they are
+ // different
+ assert(shardOps[i].getShard() == shardForDef &&
+ "only support all shard ops have the same mesh sharding attr");
+ }
+ return shardForDef;
+ }
+ }
+
+ return failure();
+}
+
+// Retrieves the mesh sharding attribute (MeshShardingAttr) for the given
+// operation operand.
+static FailureOr<std::pair<bool, MeshShardingAttr>>
+getMeshShardingAttr(OpOperand &opOperand) {
+ Value val = opOperand.get();
+ if (ShardOp shardOp = val.getDefiningOp<ShardOp>()) {
+ return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());
+ }
+
+ return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// ShardingInterface::verifyShardingInterfaceImpl
+//===----------------------------------------------------------------------===//
+
+static LogicalResult
+checkOperandAffineExprRecursively(AffineExpr expr,
+ SmallVectorImpl<bool> &seenIds) {
+ switch (expr.getKind()) {
+ case AffineExprKind::Add: {
+ auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ AffineExpr lhs = binOpExpr.getLHS();
+ AffineExpr rhs = binOpExpr.getRHS();
+ if (failed(checkOperandAffineExprRecursively(lhs, seenIds)))
+ return failure();
+ if (failed(checkOperandAffineExprRecursively(rhs, seenIds)))
+ return failure();
+ return success();
+ }
+ case AffineExprKind::Mul: {
+ auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
+ AffineExpr lhs = binOpExpr.getLHS();
+ AffineExpr rhs = binOpExpr.getRHS();
+ AffineExpr dimExpr;
+ if (lhs.getKind() == AffineExprKind::DimId) {
+ dimExpr = lhs;
+ if (rhs.getKind() != AffineExprKind::Constant)
+ return failure();
+ } else if (rhs.getKind() == AffineExprKind::DimId &&
+ lhs.getKind() == AffineExprKind::Constant) {
+ dimExpr = rhs;
+ } else
+ return failure();
+ unsigned position = dimExpr.cast<AffineDimExpr>().getPosition();
+ if ((size_t)position >= seenIds.size() || seenIds[position])
+ return failure();
+ seenIds[position] = true;
+ return success();
+ }
+ case AffineExprKind::DimId: {
+ unsigned position = expr.cast<AffineDimExpr>().getPosition();
+ if ((size_t)position >= seenIds.size() || seenIds[position])
+ return failure();
+ seenIds[position] = true;
+ return success();
+ }
+ default:
+ return failure();
+ }
+}
+
+static FailureOr<llvm::SmallSet<unsigned, 2>>
+checkOperandAffineExpr(AffineExpr expr, unsigned numDims) {
+ SmallVector<bool> seenIds(numDims, false);
+ if (failed(checkOperandAffineExprRecursively(expr, seenIds)))
+ return failure();
+
+ llvm::SmallSet<unsigned, 2> positions;
+ for (auto it : llvm::enumerate(seenIds)) {
+ if (it.value())
+ positions.insert((unsigned)it.index());
+ }
+ return positions;
+}
+
+LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
+ Operation *op = getOperation();
+
+ // check operands and results type
+ for (Type type : op->getOperandTypes())
+ if (!llvm::isa<RankedTensorType>(type))
+ return failure();
+ for (Type type : op->getResultTypes())
+ if (!llvm::isa<RankedTensorType>(type))
+ return failure();
+
+ // check loop types
+ SmallVector<IteratorType> loopTypes = getLoopIteratorTypes();
+ if (loopTypes.size() == 0)
+ return failure();
+
+ // check maps
+ SmallVector<AffineMap> maps = getIndexingMaps();
+ if (maps.size() == 0)
+ return failure();
+ unsigned numOperands = op->getNumOperands();
+ unsigned numResults = op->getNumResults();
+ if (numOperands + numResults != maps.size())
+ return failure();
+
+ for (OpResult result : op->getResults()) {
+ auto resultType = result.getType().dyn_cast<RankedTensorType>();
+ if (!resultType)
+ return failure();
+ AffineMap map = maps[numOperands + result.getResultNumber()];
+ if (!map.isProjectedPermutation()) {
+ return failure();
+ }
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ShardingInterface::printLoopTypesAndIndexingMaps
+//===----------------------------------------------------------------------===//
+
+void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
+ os << "print loop types and indexing maps for: \n";
+ getOperation()->print(os);
+ os << "\n";
+ os << "loop types: [";
+ for (IteratorType type : getLoopIteratorTypes()) {
+ os << stringifyEnum(type) << " ";
+ }
+ os << "]\n";
+ os << "indexing maps: \n";
+ for (AffineMap map : getIndexingMaps())
+ os << map << "\n";
+ os << "\n";
+}
+
+//===----------------------------------------------------------------------===//
+// detail::defaultGetShardingOption
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Update the given `shardingOption` according to `meshAxes` and `loopIdx`
+static LogicalResult
+fillShardingOption(Operation *op, ShardingOption &shardingOption,
+ SymbolRefAttr cluster, ArrayRef<int32_t> meshAxes,
+ unsigned loopIdx, bool ignoreIfConflicted = false) {
+ if ((shardingOption.cluster && cluster &&
+ shardingOption.cluster != cluster) ||
+ (!shardingOption.shardingArray[loopIdx].empty() &&
+ shardingOption.shardingArray[loopIdx] != meshAxes)) {
+ if (ignoreIfConflicted)
+ return success();
+ else
+ return op->emitOpError()
+ << "sharding option confilicts on loop iterator " << loopIdx;
+ }
+ for (size_t i = 0; i < shardingOption.shardingArray.size(); ++i) {
+ if (i != loopIdx) {
+ for (int32_t axis : meshAxes)
+ if (std::find(shardingOption.shardingArray[i].begin(),
+ shardingOption.shardingArray[i].end(),
+ axis) != shardingOption.shardingArray[i].end()) {
+ if (ignoreIfConflicted)
+ return success();
+ else
+ return op->emitOpError()
+ << "sharding option confilicts because of mesh axis " << axis
+ << " duplicates";
+ }
+ }
+ }
+ if (cluster)
+ shardingOption.cluster = cluster;
+ if (shardingOption.shardingArray[loopIdx].empty())
+ shardingOption.shardingArray[loopIdx].append(meshAxes.begin(),
+ meshAxes.end());
+ return success();
+}
+
+} // namespace
+
+FailureOr<ShardingOption> mesh::detail::defaultGetShardingOption(Operation *op,
+ OpBuilder &b) {
+
+ // 1. If a valid sharding attribute exists, use it.
+ FailureOr<ShardingOption> shardingOptionFromAttr =
+ getShardingOptionFromAttr(op);
+ if (succeeded(shardingOptionFromAttr))
+ return shardingOptionFromAttr;
+
+ ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
+ ShardingOption shardingOption;
+
+ if (failed(shardingOp.verifyShardingInterfaceImpl()))
+ return op->emitOpError() << "invalid sharding interface implementation";
+ SmallVector<IteratorType> loopTypes = shardingOp.getLoopIteratorTypes();
+ SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
+ unsigned numOperands = op->getNumOperands();
+ shardingOption.shardingArray.resize(loopTypes.size());
+ llvm::SmallVector<int32_t> partialMeshAxes;
+ Partial partialType;
+ llvm::SmallSet<unsigned, 4> visitedLoopIndices;
+ bool anyShardingInResultsOrOperands = false;
+
+ // 2. Fill sharding option based on op results
+ for (OpResult result : op->getResults()) {
+ AffineMap map = maps[numOperands + result.getResultNumber()];
+ FailureOr<MeshShardingAttr> shardAttr = getMeshShardingAttr(result, true);
+ if (failed(shardAttr))
+ continue;
+ anyShardingInResultsOrOperands = true;
+ // Handle the split axes: calculate the corresponding loop index for each
+ // split axes sub-array, and then store the sub-array to
+ // shardingOption[index]
+ for (auto it : llvm::zip(map.getResults(), shardAttr->getSplitAxes())) {
+ AffineExpr expr = std::get<0>(it);
+ ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
+ auto dim = expr.cast<AffineDimExpr>();
+ unsigned index = dim.getPosition();
+ visitedLoopIndices.insert(index);
+ if (failed(fillShardingOption(op, shardingOption, shardAttr->getCluster(),
+ axes, index)))
+ return failure();
+ }
+
+ // Handle the partial axes: at this stage, the exact loop index/indices
+ // cannot be decided because there could be multiple reduction loops.
+ ArrayRef<int32_t> partialAxes = shardAttr->getPartialAxes();
+ if (!partialAxes.empty()) {
+ if (!partialMeshAxes.empty())
+ return op->emitOpError() << "at most one result with partial axes is "
+ "supported at present";
+ partialType = shardAttr->getPartialType();
+ partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
+ // Add all the reduction loop indices to `visitedLoopIndices` if
+ // `partialAxes` is not empty
+ for (size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
+ if (isReductionLoop(loopTypes[loopIdx]))
+ visitedLoopIndices.insert(loopIdx);
+ }
+ }
+ }
+
+ // 3. Fill sharding option based on operands
+ for (OpOperand &opOperand : op->getOpOperands()) {
+ FailureOr<std::pair<bool, MeshShardingAttr>> maybeShardAttr =
+ getMeshShardingAttr(opOperand);
+ if (failed(maybeShardAttr))
+ continue;
+
+ anyShardingInResultsOrOperands = true;
+ bool annotateForUsers = maybeShardAttr->first;
+ MeshShardingAttr shardAttr = maybeShardAttr->second;
+ AffineMap map = maps[opOperand.getOperandNumber()];
+ unsigned numDims = map.getNumDims();
+
+    // Handle the split axes; partial axes don't need to be handled because
+    // they only affect the defining op of the operand.
+ //
+ // TODO: Change to process the operands with single loop index first and
+ // then the operands with multiple loop indices
+ for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) {
+ AffineExpr expr = std::get<0>(it);
+ ArrayRef<int32_t> axes = std::get<1>(it).asArrayRef();
+ FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
+ checkOperandAffineExpr(expr, numDims);
+ if (failed(loopIndices))
+ return op->emitOpError()
+ << "operand's affine expression is restricted to const_i * "
+ "dim_i + const_j + dim_j + ...";
+ if (loopIndices->empty())
+ continue;
+ if (loopIndices->size() == 1) {
+ unsigned loopIdx = *loopIndices->begin();
+ visitedLoopIndices.insert(loopIdx);
+ if (failed(fillShardingOption(op, shardingOption,
+ shardAttr.getCluster(), axes, loopIdx,
+ !annotateForUsers)))
+ return failure();
+ }
+ // If multiple loop indices correspond to a dimension of an operand, it is
+ // difficult to infer which loop indices are responsible for sharding.
+ // Therefore, the exact loop index must be specified by others.
+ if (loopIndices->size() > 1) {
+ bool seenLoopIndices = false;
+ for (unsigned loopIdx : *loopIndices) {
+ if (visitedLoopIndices.contains(loopIdx)) {
+ seenLoopIndices = true;
+ break;
+ }
+ }
+ if (!seenLoopIndices)
+ return op->emitOpError()
+ << "the operand " << opOperand.getOperandNumber()
+ << " has multiple loop indices in a dimension, but none of "
+ "them could be found in the exactly specified annotation "
+ "of op results or operands.";
+ }
+ }
+ }
+
+ // 4. Finalize sharding option
+ if (!partialMeshAxes.empty()) {
+ bool anyNonEmptyReductionLoop = llvm::any_of(
+ llvm::enumerate(shardingOption.shardingArray), [&](auto it) {
+ SmallVector<int32_t> &subArray = it.value();
+ int64_t idx = it.index();
+ return isReductionLoop(loopTypes[idx]) && !subArray.empty();
+ });
+ if (!anyNonEmptyReductionLoop) {
+ bool filled = false;
+ for (size_t idx = 0; idx < loopTypes.size(); ++idx) {
+ if (isReductionLoop(loopTypes[idx]) &&
+ areReductionAndPartialMatch(loopTypes[idx], partialType)) {
+ std::ignore = fillShardingOption(op, shardingOption, nullptr,
+ partialMeshAxes, idx);
+ filled = true;
+ break;
+ }
+ }
+ if (!filled)
+ return op->emitOpError() << "no matched reduction loop found for the "
+ "result's partial type";
+ }
+ }
+ removeTrailingEmptySubArray(shardingOption.shardingArray);
+ if (!anyShardingInResultsOrOperands)
+ shardingOption.empty = true;
+ return shardingOption;
+}
+
+//===----------------------------------------------------------------------===//
+// detail::defaultAddShardingAnnotations
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Adds a `mesh.shard` op for the given result, based on the details provided
+// in `shardingOption`, `map`, and `loopTypes`.
+static LogicalResult addShardOp(OpBuilder &b, OpResult result,
+ const ShardingOption &shardingOption,
+ AffineMap map,
+ ArrayRef<IteratorType> loopTypes) {
+ if (succeeded(getMeshShardingAttr(result, false)))
+ return success();
+
+ auto resultType = result.getType().cast<RankedTensorType>();
+ SmallVector<SmallVector<int32_t>> splitAxes(resultType.getRank());
+ SmallVector<int32_t> partialAxes;
+
+ // process the split axes
+ for (auto it : llvm::enumerate(map.getResults())) {
+ AffineExpr expr = it.value();
+ auto dim = expr.cast<AffineDimExpr>();
+ unsigned loopIdx = dim.getPosition();
+ if (loopIdx < shardingOption.shardingArray.size())
+ splitAxes[it.index()].append(shardingOption.shardingArray[loopIdx]);
+ }
+
+ // process the partial axes
+ Partial partialType;
+ for (auto it : llvm::zip(loopTypes, shardingOption.shardingArray)) {
+ IteratorType iType = std::get<0>(it);
+ if (isReductionLoop(iType)) {
+ Partial curPartialType = getPartialTypeFromReduction(iType);
+ if (!partialAxes.empty())
+ assert(partialType == curPartialType &&
+ "Only one reduction type is supported");
+ partialType = curPartialType;
+ const SmallVector<int32_t> &axis = std::get<1>(it);
+ partialAxes.append(axis);
+ }
+ }
+
+ removeTrailingEmptySubArray(splitAxes);
+ MeshShardingAttr shardAttr =
+ MeshShardingAttr::get(b.getContext(), shardingOption.cluster, splitAxes,
+ partialAxes, partialType);
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointAfterValue(result);
+ auto shardOp = b.create<ShardOp>(result.getLoc(), resultType, result,
+                                   shardAttr, /*annotate_for_users=*/false);
+ result.replaceAllUsesExcept(shardOp, shardOp);
+ return success();
+}
+
+// Adds a `mesh.shard` op for the given operand, based on the details provided
+// in `shardingOption`, `map`, and `loopTypes`.
+static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand,
+ const ShardingOption &shardingOption,
+ AffineMap map,
+ ArrayRef<IteratorType> loopTypes) {
+ auto maybeShardingAttr = getMeshShardingAttr(opOperand);
+ if (succeeded(maybeShardingAttr) && maybeShardingAttr->first)
+ return success();
+ Value operand = opOperand.get();
+ auto operandType = operand.getType().cast<RankedTensorType>();
+ SmallVector<SmallVector<int32_t>> splitAxes(operandType.getRank());
+ unsigned numDims = map.getNumDims();
+ for (auto it : llvm::enumerate(map.getResults())) {
+ int64_t idx = it.index();
+ AffineExpr expr = it.value();
+ FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
+ checkOperandAffineExpr(expr, numDims);
+ if (failed(loopIndices))
+ return failure();
+ SmallVector<unsigned> shardedLoopIndices;
+ for (unsigned loopIdx : *loopIndices) {
+ if ((size_t)loopIdx < shardingOption.shardingArray.size() &&
+ !shardingOption.shardingArray[loopIdx].empty())
+ shardedLoopIndices.push_back(loopIdx);
+ }
+    // At most one sharded loop index is accepted.
+ if (shardedLoopIndices.size() > 1)
+ return failure();
+ if (shardedLoopIndices.size() == 1) {
+ splitAxes[idx].append(
+ shardingOption.shardingArray[shardedLoopIndices[0]]);
+ }
+ }
+
+ removeTrailingEmptySubArray(splitAxes);
+ MeshShardingAttr shardAttr =
+ MeshShardingAttr::get(b.getContext(), shardingOption.cluster, splitAxes);
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPoint(opOperand.getOwner());
+ auto shardOp = b.create<ShardOp>(operand.getLoc(), operandType, operand,
+                                   shardAttr, /*annotate_for_users=*/true);
+ opOperand.set(shardOp);
+
+ return success();
+}
+
+} // namespace
+
+LogicalResult mesh::detail::defaultAddShardingAnnotations(
+ Operation *op, OpBuilder &b, const ShardingOption &shardingOption) {
+ ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
+ SmallVector<IteratorType> loopTypes = shardingOp.getLoopIteratorTypes();
+ SmallVector<AffineMap> maps = shardingOp.getIndexingMaps();
+ unsigned numOperands = op->getNumOperands();
+
+ // 1. add mesh.shard ops for all op results
+ for (OpResult result : op->getResults()) {
+ if (failed(addShardOp(b, result, shardingOption,
+ maps[numOperands + result.getResultNumber()],
+ loopTypes)))
+ return failure();
+ }
+
+ // 2. add mesh.shard ops for all operands
+ for (OpOperand &opOperand : op->getOpOperands()) {
+ if (failed(addShardOp(b, opOperand, shardingOption,
+ maps[opOperand.getOperandNumber()], loopTypes)))
+ return failure();
+ }
+
+ return success();
+}
\ No newline at end of file
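For reference, the propagation pass below records a successfully inferred
option on the op as `mesh_cluster` and `sharding_array` attributes. A rough
sketch of the resulting IR, assuming a tosa.matmul whose batch loop is sharded
on mesh axis 0 (tensor shapes are hypothetical and attribute printing is
approximate):

    %0 = tosa.matmul %a, %b {mesh_cluster = @mesh_1d,
                             sharding_array = [[0 : i32]]}
      : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>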
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..9f85d8e9cb22d5b
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRMeshTransforms
+ ShardingPropagation.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
+
+ DEPENDS
+ MLIRMeshPassIncGen
+ MLIRShardingInterface
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRMeshDialect
+ MLIRShardingInterface
+ MLIRSupport
+ MLIRTosaShardingInterfaceImpl
+)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
new file mode 100644
index 000000000000000..a3f305444afd835
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -0,0 +1,155 @@
+//===- ShardingPropagation.cpp ----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/Debug.h"
+#include <vector>
+
+namespace mlir {
+namespace mesh {
+#define GEN_PASS_DEF_SHARDINGPROPAGATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+} // namespace mesh
+} // namespace mlir
+
+#define DEBUG_TYPE "sharding-propagation"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+static std::vector<Operation *> getOperationsVector(Block &block) {
+  std::vector<Operation *> res;
+  for (Operation &op : block)
+    res.push_back(&op);
+  return res;
+}
+
+static std::vector<Operation *> getReversedOperationsVector(Block &block) {
+  std::vector<Operation *> res;
+  for (Operation &op : llvm::reverse(block))
+    res.push_back(&op);
+  return res;
+}
+
+// For each operation that implements the ShardingInterface, infer the sharding
+// option of the operation from its operands and/or results using the
+// `getShardingOption` method. If the inferred sharding option is not empty, add
+// a `mesh.shard` operation for all remaining operands and results that do not
+// have sharding annotations.
+LogicalResult visitOp(Operation *op, OpBuilder &builder) {
+ if (op->hasTrait<OpTrait::IsTerminator>() || llvm::isa<mesh::ShardOp>(op))
+ return success();
+
+ ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op);
+ if (!shardingOp) {
+ op->emitOpError() << "sharding interface is not implemented.";
+ return failure();
+ }
+
+ FailureOr<ShardingOption> shardingOption =
+ shardingOp.getShardingOption(builder);
+ if (failed(shardingOption)) {
+ op->emitOpError() << "fail to get sharding option from results.";
+ return failure();
+ }
+ // sharding info is empty, return immediately
+ if (shardingOption->empty)
+ return success();
+
+ ArrayAttr shardingArrayAttr =
+ builder.getArrayOfI32ArrayAttr(shardingOption->shardingArray);
+ LLVM_DEBUG(DBGS() << "mesh cluster: " << shardingOption->cluster << "\n");
+ LLVM_DEBUG(DBGS() << "sharding array: " << shardingArrayAttr << "\n");
+ op->setAttr(getMeshClusterName(), shardingOption->cluster);
+  op->setAttr(getShardingArrayName(), shardingArrayAttr);
+
+ if (failed(shardingOp.addShardingAnnotations(builder, *shardingOption))) {
+ op->emitOpError() << "fail to set sharding annotations.";
+ return failure();
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ShardingPropagationPass
+//===----------------------------------------------------------------------===//
+struct ShardingPropagationPass
+ : public mesh::impl::ShardingPropagationBase<ShardingPropagationPass> {
+ void runOnOperation() override {
+ func::FuncOp funcOp = getOperation();
+ MLIRContext *ctx = funcOp.getContext();
+ Region ®ion = funcOp.getBody();
+ OpBuilder builder(ctx);
+ if (!region.hasOneBlock()) {
+ funcOp.emitOpError() << "only one block is supported!";
+      return signalPassFailure();
+ }
+ Block &block = region.front();
+
+ // clang-format off
+ LLVM_DEBUG(
+ DBGS() << "print all the ops' iterator types and indexing maps in the "
+ "block.\n";
+ DenseSet<ShardingInterface> ops;
+ for (Operation &op : block.getOperations()) {
+ if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op)) {
+ ops.insert(shardingOp);
+ }
+ }
+ for (ShardingInterface shardingOp : ops) {
+ shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
+ }
+ );
+ // clang-format on
+
+ // 1. propagate in reversed order
+ {
+ std::vector<Operation *> curOps = getReversedOperationsVector(block);
+ for (Operation *op : curOps) {
+ if (failed(visitOp(op, builder)))
+ return signalPassFailure();
+ }
+ }
+
+ LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
+ << funcOp << "\n");
+
+ // 2. propagate in original order
+ {
+ std::vector<Operation *> curOps = getOperationsVector(block);
+ for (Operation *op : curOps) {
+ if (failed(visitOp(op, builder)))
+ return signalPassFailure();
+ }
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::mesh::createShardingPropagationPass() {
+ return std::make_unique<ShardingPropagationPass>();
+}
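Note that the two walks are complementary: the reversed walk lets annotations
on graph outputs (e.g. a trailing `mesh.shard` on a returned value) reach their
defining ops, while the forward walk lets annotations on graph inputs reach
their users. The tests below exercise both directions.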
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
index 8e32579e0e4c2e3..ba5343dcd7ac6c1 100644
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -26,4 +26,18 @@ add_mlir_dialect_library(MLIRTosaDialect
MLIRViewLikeInterface
)
+add_mlir_dialect_library(MLIRTosaShardingInterfaceImpl
+ IR/ShardingInterfaceImpl.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRMeshDialect
+ MLIRShardingInterface
+ MLIRSupport
+ MLIRTosaDialect
+ )
+
add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
new file mode 100644
index 000000000000000..dace86533c0e231
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -0,0 +1,111 @@
+//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "tosa-sharding-impl"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+using namespace mlir;
+using namespace mlir::tosa;
+using namespace mlir::mesh;
+
+namespace {
+
+template <typename ElemwiseOp>
+struct ElemwiseSharding
+ : public ShardingInterface::ExternalModel<ElemwiseSharding<ElemwiseOp>,
+ ElemwiseOp> {
+ SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+ Value val = op->getOperand(0);
+ auto type = val.getType().dyn_cast<RankedTensorType>();
+ if (!type)
+ return {};
+ SmallVector<IteratorType> types(type.getRank(), IteratorType::Parallel);
+ return types;
+ }
+
+ SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+ MLIRContext *ctx = op->getContext();
+ Value val = op->getOperand(0);
+ auto type = val.getType().dyn_cast<RankedTensorType>();
+ if (!type)
+ return {};
+ int64_t rank = type.getRank();
+ int64_t num = op->getNumOperands() + op->getNumResults();
+ SmallVector<AffineMap> maps(num,
+ AffineMap::getMultiDimIdentityMap(rank, ctx));
+ return maps;
+ }
+};
+
+// loop types: [parallel, parallel, parallel, reduction_sum]
+// indexing maps:
+// (d0, d1, d2, d3) -> (d0, d1, d3)
+// (d0, d1, d2, d3) -> (d0, d3, d2)
+// (d0, d1, d2, d3) -> (d0, d1, d2)
+struct MatMulOpSharding
+ : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
+ SmallVector<IteratorType> getLoopIteratorTypes(Operation *op) const {
+ auto tensorType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
+ if (!tensorType)
+ return {};
+
+ SmallVector<IteratorType> types(tensorType.getRank() + 1,
+ IteratorType::Parallel);
+ types[tensorType.getRank()] = IteratorType::ReductionSum;
+ return types;
+ }
+
+ SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+ auto tensorType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
+ if (!tensorType)
+ return {};
+ MLIRContext *ctx = op->getContext();
+ SmallVector<AffineMap> maps;
+ maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
+ maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
+ maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
+ return maps;
+ }
+};
+
+template <typename OpType>
+static void registerElemwiseOne(MLIRContext *ctx) {
+ OpType::template attachInterface<ElemwiseSharding<OpType>>(*ctx);
+}
+
+/// Variadic helper function.
+template <typename... OpTypes>
+static void registerElemwiseAll(MLIRContext *ctx) {
+ (registerElemwiseOne<OpTypes>(ctx), ...);
+}
+
+} // namespace
+
+void mlir::tosa::registerShardingInterfaceExternalModels(
+    DialectRegistry &registry) {
+
+ registry.addExtension(+[](MLIRContext *ctx, TosaDialect *dialect) {
+ registerElemwiseAll<
+ ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
+ BitwiseOrOp, BitwiseXorOp, DivOp, LogicalAndOp, LogicalLeftShiftOp,
+ LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
+ MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
+ LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
+ GreaterOp, GreaterEqualOp>(ctx);
+
+ MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
+ });
+}
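A sketch of wiring these external models into a registry outside of
registerAllDialects (the dialect set is assumed):

    DialectRegistry registry;
    registry.insert<tosa::TosaDialect, mesh::MeshDialect>();
    tosa::registerShardingInterfaceExternalModels(registry);
    MLIRContext context(registry);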
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index f4e29539214b4b6..b247b8cc694eca9 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -259,6 +259,19 @@ SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
return res;
}
+SmallVector<SmallVector<int32_t>>
+mlir::getArrayOfI32Array(ArrayAttr arrayAttr) {
+ SmallVector<SmallVector<int32_t>> arrayOfI32Array;
+ for (auto attr : arrayAttr) {
+ arrayOfI32Array.push_back(llvm::to_vector(
+ llvm::map_range(llvm::cast<ArrayAttr>(attr), [&](Attribute intAttr) {
+ return static_cast<int32_t>(
+ llvm::cast<IntegerAttr>(intAttr).getInt());
+ })));
+ }
+ return arrayOfI32Array;
+}
+
// TODO: do we have any common utily for this?
static MLIRContext *getContext(OpFoldResult val) {
assert(val && "Invalid value");
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
index 9cdac964710ca86..db5087ce42809f8 100644
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -235,6 +235,19 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
return permutationMap;
}
+AffineMap AffineMap::getMultiDimMapWithTargets(unsigned numDims,
+ ArrayRef<int64_t> targets,
+ MLIRContext *context) {
+ AffineMap result =
+ AffineMap::get(/*dimCount=*/numDims, /*symbolCount=*/0, context);
+ int64_t pos = 0;
+ for (int64_t t : targets) {
+ result = result.insertResult(getAffineDimExpr(t, context), pos);
+ pos += 1;
+ }
+ return result;
+}
+
template <typename AffineExprContainer>
static SmallVector<AffineMap, 4>
inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index ab20f4863e11c23..b86ee432a53af04 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -316,6 +316,15 @@ ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
return getArrayAttr(attrs);
}
+ArrayAttr
+Builder::getArrayOfI32ArrayAttr(ArrayRef<SmallVector<int32_t>> values) {
+ auto attrs =
+ llvm::map_to_vector<8>(values, [this](ArrayRef<int32_t> v) -> Attribute {
+ return getI32ArrayAttr(v);
+ });
+ return getArrayAttr(attrs);
+}
+
ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
auto attrs = llvm::map_to_vector<8>(
values, [](Type v) -> Attribute { return TypeAttr::get(v); });
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
new file mode 100644
index 000000000000000..4c0809dc5e58636
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -0,0 +1,167 @@
+// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
+
+mesh.cluster @mesh_1d(rank = 1)
+mesh.cluster @mesh_2d(rank = 2, dim_sizes = [2, 4])
+mesh.cluster @mesh_3d(rank = 3)
+
+// CHECK-LABEL: func.func @element_wise_empty_sharding_info
+func.func @element_wise_empty_sharding_info(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: tosa.sigmoid
+ %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: return
+ return %0 : tensor<8x16xf32>
+}
+
+// CHECK-LABEL: func.func @element_wise_on_def
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_def(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
+ %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+ %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V2]]
+ return %1 : tensor<8x16xf32>
+}
+
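+// The dual direction: a sharding annotated on a use (annotate_for_users) is
+// propagated forward to the result of the consuming op.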
+// CHECK-LABEL: func.func @element_wise_on_use
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_use(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+ %0 = mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
+ %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V2]]
+ return %1 : tensor<8x16xf32>
+}
+
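+// An annotate_for_users shard on a returned value is completed with the
+// matching producer-side annotation on the defining op's result.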
+// CHECK-LABEL: func.func @element_wise_on_graph_output
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_graph_output(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = tosa.sigmoid %[[V0]]
+ %0 = tosa.sigmoid %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = mesh.shard %[[V1]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+ %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V3]]
+ return %1 : tensor<8x16xf32>
+}
+
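+// A producer-side shard on a block argument is complemented with an
+// annotate_for_users copy before the value is consumed.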
+// CHECK-LABEL: func.func @element_wise_on_graph_input
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @element_wise_on_graph_input(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[V0]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+ %0 = mesh.shard %arg0 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = tosa.sigmoid %[[V1]]
+ %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V3]]
+ return %1 : tensor<8x16xf32>
+}
+
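+// One producer (tosa.tanh) feeds two consumers; each branch re-annotates the
+// shared value with its own annotate_for_users shard before consuming it.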
+// CHECK-LABEL: func.func @arrow_structure
+// CHECK-SAME: %[[ARG:.*]]: tensor<8x16xf32>
+func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) {
+ // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V2:.*]] = tosa.tanh %[[V1]]
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+ %0 = tosa.tanh %arg0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V4:.*]] = mesh.shard %[[V3]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V5:.*]] = tosa.abs %[[V4]]
+ // CHECK-NEXT: %[[V6:.*]] = mesh.shard %[[V5]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+ %1 = tosa.abs %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ // CHECK-NEXT: %[[V7:.*]] = mesh.shard %[[V3]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK-NEXT: %[[V8:.*]] = tosa.negate %[[V7]]
+ // CHECK-NEXT: %[[V9:.*]] = mesh.shard %[[V8]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<8x16xf32>
+ %2 = tosa.negate %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ %3 = mesh.shard %2 to <@mesh_2d, [[0], [1]]> : tensor<8x16xf32>
+ // CHECK-NEXT: return %[[V6]], %[[V9]]
+ return %1, %3 : tensor<8x16xf32>, tensor<8x16xf32>
+}
+
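+// Sharding batch and m of the matmul result maps directly onto the lhs
+// (batch, m, k); the rhs (batch, k, n) only shares the batch dimension, so it
+// is sharded as [[0]].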
+// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
+func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}0], [1]]> annotate_for_users : tensor<2x16x8xf32>
+ // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}0]]> annotate_for_users : tensor<2x8x32xf32>
+ // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
+ %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}0], [1]]> : tensor<2x16x32xf32>
+ %1 = mesh.shard %0 to <@mesh_2d, [[0], [1]]> : tensor<2x16x32xf32>
+ // CHECK-NEXT: return %[[V3]]
+ return %1 : tensor<2x16x32xf32>
+}
+
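+// Annotating the result as partial = sum[0] (a pending reduction along mesh
+// axis 0) back-propagates a sharding of the contraction dimension k onto both
+// operands.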
+// CHECK-LABEL: func.func @matmul_on_def_shard_m_and_k
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
+func.func @matmul_on_def_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+ // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
+ // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
+ %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+ %1 = mesh.shard %0 to <@mesh_2d, [[], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+ // CHECK-NEXT: return %[[V3]]
+ return %1 : tensor<2x16x32xf32>
+}
+
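+// The same m/k sharding, this time seeded from an operand annotation and
+// propagated forward to the matmul result as a partial sum.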
+// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_k
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
+func.func @matmul_on_use_shard_m_and_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+ %0 = mesh.shard %arg0 to <@mesh_2d, [[], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+ // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
+ // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
+ %1 = tosa.matmul %0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+ // CHECK-NEXT: return %[[V3]]
+ return %1 : tensor<2x16x32xf32>
+}
+
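+// Both operands carry consistent annotations of the k dimension on mesh axis
+// 0; the redundant information agrees and propagates to a single partial-sum
+// result.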
+// CHECK-LABEL: func.func @matmul_on_use_shard_m_and_duplicated_k
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
+func.func @matmul_on_use_shard_m_and_duplicated_k(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+ // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to <@mesh_2d, {{\[\[}}], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+ %0 = mesh.shard %arg0 to <@mesh_2d, [[], [1], [0]]> annotate_for_users : tensor<2x16x8xf32>
+ // CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to <@mesh_2d, {{\[\[}}], [0]]> annotate_for_users : tensor<2x8x32xf32>
+ %1 = mesh.shard %arg1 to <@mesh_2d, [[], [0]]> annotate_for_users : tensor<2x8x32xf32>
+ // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
+ %2 = tosa.matmul %0, %1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+ // CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to <@mesh_2d, {{\[\[}}], [1]], partial = sum[0]> : tensor<2x16x32xf32>
+ // CHECK-NEXT: return %[[V3]]
+ return %2 : tensor<2x16x32xf32>
+}
+
+// https://arxiv.org/abs/2211.05102 Figure 2(a)
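+// In this and the following test, the propagated shardings are verified
+// through the mesh_cluster and sharding_array attributes the pass records on
+// each op.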
+// CHECK-LABEL: func.func @mlp_1d_weight_stationary
+func.func @mlp_1d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[], [], [0]]> : tensor<2x4x8xf32>
+ // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_1d, sharding_array = {{\[\[}}], [], [0 : i32]]}
+ %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
+ // CHECK: tosa.sigmoid {{.*}} {mesh_cluster = @mesh_1d, sharding_array = {{\[\[}}], [], [0 : i32]]}
+ %2 = tosa.sigmoid %1 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+ // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_1d, sharding_array = {{\[\[}}], [], [], [0 : i32]]}
+ %3 = tosa.matmul %2, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
+ %4 = mesh.shard %3 to <@mesh_1d, [[], [], []], partial = sum[0]> : tensor<2x4x8xf32>
+ %5 = mesh.shard %4 to <@mesh_1d, [[], [], [0]]> annotate_for_users : tensor<2x4x8xf32>
+ return %5 : tensor<2x4x8xf32>
+}
+
+// https://arxiv.org/abs/2211.05102 Figure 2(b)
+// CHECK-LABEL: func.func @mlp_2d_weight_stationary
+func.func @mlp_2d_weight_stationary(%arg0: tensor<2x4x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<2x32x8xf32>) -> tensor<2x4x8xf32> {
+ %0 = mesh.shard %arg0 to <@mesh_3d, [[], [], [0, 1, 2]]> : tensor<2x4x8xf32>
+ // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_3d, sharding_array = {{\[\[}}], [], [1 : i32, 2 : i32], [0 : i32]]}
+ %1 = tosa.matmul %0, %arg1 : (tensor<2x4x8xf32>, tensor<2x8x32xf32>) -> tensor<2x4x32xf32>
+ %2 = mesh.shard %1 to <@mesh_3d, [[], [], [1, 2]], partial = sum[0]> : tensor<2x4x32xf32>
+ // CHECK: tosa.sigmoid {{.*}} {mesh_cluster = @mesh_3d, sharding_array = {{\[\[}}], [], [1 : i32, 2 : i32]]}
+ %3 = tosa.sigmoid %2 : (tensor<2x4x32xf32>) -> tensor<2x4x32xf32>
+ // CHECK: tosa.matmul {{.*}} {mesh_cluster = @mesh_3d, sharding_array = {{\[\[}}], [], [0 : i32], [1 : i32, 2 : i32]]}
+ %4 = tosa.matmul %3, %arg2 : (tensor<2x4x32xf32>, tensor<2x32x8xf32>) -> tensor<2x4x8xf32>
+ %5 = mesh.shard %4 to <@mesh_3d, [[], [], [0]], partial = sum[1, 2]> : tensor<2x4x8xf32>
+ %6 = mesh.shard %5 to <@mesh_3d, [[], [], [0, 1, 2]]> annotate_for_users : tensor<2x4x8xf32>
+ return %6 : tensor<2x4x8xf32>
+}