[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #69665)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 19 17:42:41 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Chengji Yao (yaochengji)
<details>
<summary>Changes</summary>
- add `ShardingInterface` and the methods' default implementation
- add `ShardingInterface` implementation for element-wise and matmul ops in TOSA dialect
- add sharding propagation pass
---
Patch is 66.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69665.diff
27 Files Affected:
- (modified) mlir/include/mlir/Dialect/Mesh/CMakeLists.txt (+2)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+34)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+22)
- (added) mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt (+4)
- (added) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+58)
- (added) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td (+87)
- (added) mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt (+6)
- (added) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h (+41)
- (added) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td (+33)
- (added) mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h (+23)
- (modified) mlir/include/mlir/Dialect/Utils/IndexingUtils.h (+3)
- (modified) mlir/include/mlir/IR/AffineMap.h (+12)
- (modified) mlir/include/mlir/IR/Builders.h (+1)
- (modified) mlir/include/mlir/InitAllDialects.h (+2)
- (modified) mlir/include/mlir/InitAllPasses.h (+2)
- (modified) mlir/lib/Dialect/Mesh/CMakeLists.txt (+2)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+31-1)
- (added) mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt (+15)
- (added) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+544)
- (added) mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt (+17)
- (added) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+155)
- (modified) mlir/lib/Dialect/Tosa/CMakeLists.txt (+14)
- (added) mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp (+111)
- (modified) mlir/lib/Dialect/Utils/IndexingUtils.cpp (+13)
- (modified) mlir/lib/IR/AffineMap.cpp (+13)
- (modified) mlir/lib/IR/Builders.cpp (+9)
- (added) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+167)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 39d24595ec1c446..b6623ed818f0770 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -49,6 +49,22 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
let cppNamespace = "::mlir::mesh";
}
+// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
+// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
+// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
+// is partial.
+def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
+ I32EnumAttrCase<"Parallel", 1, "parallel">,
+ I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
+ I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
+ I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
+ I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
+ I32EnumAttrCase<"Invalid", 100, "invalid">
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::mesh";
+}
+
//===----------------------------------------------------------------------===//
// Mesh Attribute
//===----------------------------------------------------------------------===//
@@ -122,6 +138,24 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
$partial_axes^ `]`)? `>`
}];
+ let builders = [
+ AttrBuilder<(ins "SymbolRefAttr":$cluster,
+ "ArrayRef<SmallVector<int32_t>>":$split_axes,
+ "ArrayRef<int32_t>": $partial_axes,
+ "mesh::Partial": $partial_type), [{
+ SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::to_vector(
+ llvm::map_range(split_axes, [&](ArrayRef<int32_t> array) {
+ return DenseI32ArrayAttr::get($_ctxt, array);
+ }));
+ return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
+ partial_type);
+ }]>,
+ AttrBuilder<(ins "SymbolRefAttr":$cluster,
+ "ArrayRef<SmallVector<int32_t>>":$split_axes), [{
+ return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
+ }]>
+ ];
+
let genVerifyDecl = 1;
}
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9dfeca84d012165..cb86887091330c8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -24,4 +24,26 @@
#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
+namespace mlir {
+namespace mesh {
+
+bool isReductionLoop(IteratorType iType);
+
+bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
+
+template <typename T>
+void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
+ for (int64_t i = array.size() - 1; i >= 0; i--) {
+ if (array[i].empty())
+ array.pop_back();
+ else
+ break;
+ }
+}
+
+Partial getPartialTypeFromReduction(IteratorType iType);
+
+} // namespace mesh
+} // namespace mlir
+
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
new file mode 100644
index 000000000000000..b3a44f3b0089abc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS ShardingInterface.td)
+mlir_tablegen(ShardingInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(ShardingInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRShardingInterfaceIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
new file mode 100644
index 000000000000000..1d19e41ac1fc555
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -0,0 +1,58 @@
+//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+class Operation;
+
+namespace mesh {
+
+using ShardingArray = SmallVector<SmallVector<int32_t>>;
+using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
+
+struct ShardingOption {
+ // An array of int array. The sub-array at the i-th position signifies the
+ // mesh axes the i-th loop will be sharded on.
+ ShardingArray shardingArray;
+ SymbolRefAttr cluster;
+ // `empty` being true indicates that no sharding information can be inferred
+ // at present. Note that this differs from the operation not being sharded.
+ bool empty = false;
+ ShardingOption() = default;
+ ShardingOption(const ShardingArray &shardingArray, SymbolRefAttr cluster)
+ : shardingArray(shardingArray), cluster(cluster) {}
+};
+
+constexpr StringRef getShardingArrayName() { return "sharding_array"; }
+
+constexpr StringRef getMeshClusterName() { return "mesh_cluster"; }
+
+namespace detail {
+
+FailureOr<ShardingOption> defaultGetShardingOption(Operation *op, OpBuilder &b);
+
+LogicalResult
+defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
+ const ShardingOption &shardingOption);
+
+} // namespace detail
+
+} // namespace mesh
+
+} // namespace mlir
+
+/// Include the ODS generated interface header files.
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
new file mode 100644
index 000000000000000..c98b9f081492997
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -0,0 +1,87 @@
+//===- ShardingInterface.td --------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+
+include "mlir/IR/OpBase.td"
+
+def ShardingInterface : OpInterface<"ShardingInterface"> {
+ let description = [{
+ Interface for allowing operations to expose information needed to
+ shard them.
+ }];
+ let cppNamespace = "::mlir::mesh";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns a list of iterator types that describe the number of loops.
+ }],
+ /*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
+ /*methodName=*/"getLoopIteratorTypes",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return {};"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the indexing maps attribute within the current operation.
+ }],
+ /*retTy=*/"SmallVector<AffineMap>",
+ /*methodName=*/"getIndexingMaps",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return {};"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Given that certain operands or results of the operation may have
+ sharding annotations, this method leverages this information to deduce
+ how the operation should be sharded.
+ }],
+ /*retTy=*/"FailureOr<ShardingOption>",
+ /*methodName=*/"getShardingOption",
+ /*args=*/(ins
+ "OpBuilder &":$b
+ ),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return detail::defaultGetShardingOption(
+ $_op.getOperation(), b);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Based on a given ShardingOption, this method adds `mesh.shard`
+ operations for the operands and results that previously lacked
+ sharding annotations.
+ }],
+ /*retTy=*/"LogicalResult",
+ /*methodName=*/"addShardingAnnotations",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "const ShardingOption &":$shardingOption
+ ),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return detail::defaultAddShardingAnnotations(
+ $_op.getOperation(), b, shardingOption);
+ }]
+ >
+ ];
+
+ let extraClassDeclaration = [{
+ LogicalResult verifyShardingInterfaceImpl();
+
+ void printLoopTypesAndIndexingMaps(raw_ostream &os);
+ }];
+}
+
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..8d768485103b65f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh)
+add_public_tablegen_target(MLIRMeshPassIncGen)
+add_dependencies(mlir-headers MLIRMeshPassIncGen)
+
+add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
new file mode 100644
index 000000000000000..aa3555f7f186f24
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -0,0 +1,41 @@
+//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace func {
+class FuncOp;
+}
+
+namespace mesh {
+
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+std::unique_ptr<OperationPass<func::FuncOp>> createShardingPropagationPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
new file mode 100644
index 000000000000000..d36adfe476a72ac
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -0,0 +1,33 @@
+//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+//===----------------------------------------------------------------------===//
+// ShardingPropagation
+//===----------------------------------------------------------------------===//
+
+def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
+ let summary = "sharding propagation";
+ let description = [{
+ Propagates sharding information throughout the graph. After this pass, each
+ of the operations' operands and results is annotated with a `mesh.shard`
+ operation, and the operations themselves are annotated with sharding option
+ attributes.
+ }];
+ let constructor = "mlir::mesh::createShardingPropagationPass()";
+ let dependentDialects = [
+ "mesh::MeshDialect"
+ ];
+}
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
new file mode 100644
index 000000000000000..16427919dace5da
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace tosa {
+
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index f51a8b28b7548ed..b24164cfb552b4f 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -245,6 +245,9 @@ computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
unsigned dropBack = 0);
+/// Helper to return a vector of sub-vectors of int32_t
+SmallVector<SmallVector<int32_t>> getArrayOfI32Array(ArrayAttr arrayAttr);
+
/// Compute linear index from provided strides and indices, assuming strided
/// layout.
/// Returns AffineExpr and list of values to apply to it, e.g.:
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 3430db2b99c3f2e..18e2313ef2b446b 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -101,6 +101,18 @@ class AffineMap {
static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
MLIRContext *context);
+ /// Returns an affine map with `numDims` input dimensions and results
+ /// specified by `targets`.
+ ///
+ /// Examples:
+ /// * getMultiDimMapWithTargets(3, [0, 2, 1])
+ /// -> affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+ /// * getMultiDimMapWithTargets(3, [2, 1])
+ /// -> affine_map<(d0, d1, d2) -> (d2, d1)>
+ static AffineMap getMultiDimMapWithTargets(unsigned numDims,
+ ArrayRef<int64_t> targets,
+ MLIRContext *context);
+
/// Returns a vector of AffineMaps; each with as many results as
/// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
/// symbols as the largest symbol in `exprs`.
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 5e54d4ea49e8251..3988835622b7629 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -168,6 +168,7 @@ class Builder {
ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
ArrayAttr getTypeArrayAttr(TypeRange values);
+ ArrayAttr getArrayOfI32ArrayAttr(ArrayRef<SmallVector<int32_t>> values);
// Affine expressions and affine maps.
AffineExpr getAffineDimExpr(unsigned position);
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 00f400aab5d50a0..3556f82023828b2 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -79,6 +79,7 @@
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
@@ -170,6 +171,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
tensor::registerSubsetInsertionOpInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
+ tosa::registerShardingInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 5489a13a8040bdb..27711417ed91a8c 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -29,6 +29,7 @@
#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -73,6 +74,7 @@ inline void registerAllPasses() {
LLVM::registerLLVMPasses();
math::registerMathPasses();
memref::registerMemRefPasses();
+ mesh::registerMeshPasses();
ml_program::registerMLProgramPasses();
registerSCFPasses();
registerShapePasses();
diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index fc91fd994f12dc2..0521147ba2fdff9 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -41,6 +41,37 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
return arith::ConstantOp::materialize(builder, value, type, loc);
}
+//===----------------------------------------------------------------------===//
+// Mesh utilities
+//===----------------------------------------------------------------------===//
+
+bool mesh::isReductionLoop(IteratorType iType) {
+ return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
+}
+
+bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
+ return (partial == Partial::Generic &&
+ iType == IteratorType::ReductionGeneric) ||
+ (partial == Partial::Sum && iType == IteratorType::ReductionSum) ||
+ (partial == Partial::Max && iType == IteratorType::ReductionMax) ||
+ (partial == Partial::Min && iType == IteratorType::ReductionMin);
+}
+
+Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
+ switch (iType) {
+ case IteratorType::ReductionGeneric:
+ return Partial::Generic;
+ case IteratorType::ReductionSum:
+ return Partial::Sum;
+ case IteratorType::ReductionMax:
+ return Partial::Max;
+ case Iterator...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/69665
More information about the Mlir-commits
mailing list