[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #71261)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 3 18:02:31 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Chengji Yao (yaochengji)
<details>
<summary>Changes</summary>
Add a pass that propagates sharding information throughout the graph.
After this pass, each of the operations' operands and results is
annotated with a mesh.shard operation.
The pass is driven by a newly added ShardingInterface, and an implementation
for element-wise and matmul ops in the TOSA dialect is provided.
---
Patch is 69.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71261.diff
23 Files Affected:
- (modified) mlir/include/mlir/Dialect/Mesh/CMakeLists.txt (+2)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+34)
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+18)
- (added) mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt (+4)
- (added) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+68)
- (added) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td (+102)
- (added) mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt (+6)
- (added) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h (+39)
- (added) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td (+32)
- (added) mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h (+23)
- (modified) mlir/include/mlir/IR/AffineMap.h (+12)
- (modified) mlir/include/mlir/InitAllDialects.h (+2)
- (modified) mlir/include/mlir/InitAllPasses.h (+2)
- (modified) mlir/lib/Dialect/Mesh/CMakeLists.txt (+2)
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+31-1)
- (added) mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt (+15)
- (added) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+524)
- (added) mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt (+19)
- (added) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+210)
- (modified) mlir/lib/Dialect/Tosa/CMakeLists.txt (+14)
- (added) mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp (+111)
- (modified) mlir/lib/IR/AffineMap.cpp (+12-4)
- (added) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+188)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 39d24595ec1c446..a91ef569347bff1 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -49,6 +49,22 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
let cppNamespace = "::mlir::mesh";
}
+// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
+// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
+// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
+// is partial.
+def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
+ I32EnumAttrCase<"Parallel", 1, "parallel">,
+ I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
+ I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
+ I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
+ I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
+ I32EnumAttrCase<"Invalid", 100, "invalid">
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::mesh";
+}
+
//===----------------------------------------------------------------------===//
// Mesh Attribute
//===----------------------------------------------------------------------===//
@@ -122,6 +138,24 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
$partial_axes^ `]`)? `>`
}];
+ let builders = [
+ AttrBuilder<(ins "SymbolRefAttr":$cluster,
+ "ArrayRef<SmallVector<int32_t>>":$split_axes,
+ "ArrayRef<int32_t>": $partial_axes,
+ "mesh::Partial": $partial_type), [{
+ SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::map_to_vector(
+ split_axes, [&](ArrayRef<int32_t> array) {
+ return DenseI32ArrayAttr::get($_ctxt, array);
+ });
+ return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
+ partial_type);
+ }]>,
+ AttrBuilder<(ins "SymbolRefAttr":$cluster,
+ "ArrayRef<SmallVector<int32_t>>":$split_axes), [{
+ return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
+ }]>
+ ];
+
let genVerifyDecl = 1;
}
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9dfeca84d012165..05eba66a89949b6 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -24,4 +24,22 @@
#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
+namespace mlir {
+namespace mesh {
+
+bool isReductionLoop(IteratorType iType);
+
+bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
+
+/// Strips empty sub-arrays from the back of `array`, stopping as soon as the
+/// last entry is non-empty or `array` itself has become empty.
+template <typename T>
+void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
+  for (;;) {
+    // Nothing left to trim, or the tail already holds data.
+    if (array.empty() || !array.back().empty())
+      return;
+    array.pop_back();
+  }
+}
+
+Partial getPartialTypeFromReduction(IteratorType iType);
+
+} // namespace mesh
+} // namespace mlir
+
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
new file mode 100644
index 000000000000000..b3a44f3b0089abc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS ShardingInterface.td)
+mlir_tablegen(ShardingInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(ShardingInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRShardingInterfaceIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
new file mode 100644
index 000000000000000..d860628cf371aa9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -0,0 +1,68 @@
+//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+class Operation;
+
+namespace mesh {
+
+using ShardingArray = SmallVector<SmallVector<int32_t>>;
+using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
+
+/// Bundles a per-loop sharding array with a cluster symbol reference.
+struct ShardingOption {
+  // One sub-array per loop: the entry at position i lists the mesh axes that
+  // the i-th loop will be sharded on.
+  ShardingArray shardingArray;
+  SymbolRefAttr cluster;
+  // True when no sharding information can be inferred at present. This is a
+  // distinct state from "the operation is not sharded".
+  bool empty = false;
+
+  ShardingOption() = default;
+  ShardingOption(ShardingArray axesPerLoop, SymbolRefAttr clusterRef)
+      : shardingArray(std::move(axesPerLoop)), cluster(clusterRef) {}
+};
+
+// This method retrieves the 'MeshShardingAttr' attribute from a given operation
+// result and includes the 'annotate_for_users' information.
+FailureOr<std::pair<bool, MeshShardingAttr>>
+getMeshShardingAttr(OpResult result);
+
+// This method retrieves the 'MeshShardingAttr' attribute from a given operation
+// operand and includes the 'annotate_for_users' information.
+FailureOr<std::pair<bool, MeshShardingAttr>>
+getMeshShardingAttr(OpOperand &opOperand);
+
+namespace detail {
+
+FailureOr<ShardingOption>
+defaultGetShardingOption(Operation *op,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings);
+
+LogicalResult
+defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
+ const ShardingOption &shardingOption);
+
+} // namespace detail
+
+} // namespace mesh
+
+} // namespace mlir
+
+/// Include the ODS generated interface header files.
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
new file mode 100644
index 000000000000000..21b6c8d4f599a8d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -0,0 +1,102 @@
+//===- ShardingInterface.td --------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+
+include "mlir/IR/OpBase.td"
+
+def ShardingInterface : OpInterface<"ShardingInterface"> {
+ let description = [{
+ Interface for allowing operations to expose information needed to
+ shard them.
+ }];
+ let cppNamespace = "::mlir::mesh";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns a list of iterator types that describe the number of loops.
+ The iterator types determine how the operation traverses its input and
+ output tensors.
+
+ Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
+ types are parallel, parallel, reduction-sum. This indicates that M and
+ N are traversed in parallel, while the K dimension is used for
+ reduction.
+
+ Example 2: A softmax op's loop iterator types are parallel and
+ invalid. The second dimension is considered as invalid because it is
+ neither parallel nor any kind of reduction.
+ }],
+ /*retType=*/"SmallVector<::mlir::mesh::IteratorType>",
+ /*methodName=*/"getLoopIteratorTypes",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return {};"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the indexing maps attribute within the current operation.
+ Indexing maps determine how indices in the iteration space map to
+ tensor indices. They are specified using `affine_map` in MLIR, which
+ provides an affine transformation of indices.
+ }],
+ /*retTy=*/"SmallVector<AffineMap>",
+ /*methodName=*/"getIndexingMaps",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return {};"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Given that certain operands or results of the operation may have
+ sharding annotations, this method leverages this information to deduce
+ how the operation should be sharded.
+ }],
+ /*retTy=*/"FailureOr<ShardingOption>",
+ /*methodName=*/"getShardingOption",
+ /*args=*/(ins
+ "ArrayRef<MeshShardingAttr>": $operandShardings,
+ "ArrayRef<MeshShardingAttr>": $resultShardings
+ ),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return detail::defaultGetShardingOption(
+ $_op.getOperation(), operandShardings, resultShardings);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Based on a given ShardingOption, this method adds `mesh.shard`
+ operations for the operands and results that previously lacked
+ sharding annotations.
+ }],
+ /*retTy=*/"LogicalResult",
+ /*methodName=*/"addShardingAnnotations",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "const ShardingOption &":$shardingOption
+ ),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return detail::defaultAddShardingAnnotations(
+ $_op.getOperation(), b, shardingOption);
+ }]
+ >
+ ];
+
+ let extraClassDeclaration = [{
+ LogicalResult verifyShardingInterfaceImpl();
+
+ void printLoopTypesAndIndexingMaps(raw_ostream &os);
+ }];
+}
+
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..8d768485103b65f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh)
+add_public_tablegen_target(MLIRMeshPassIncGen)
+add_dependencies(mlir-headers MLIRMeshPassIncGen)
+
+add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
new file mode 100644
index 000000000000000..83399d10beaae48
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -0,0 +1,39 @@
+//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace func {
+class FuncOp;
+}
+
+namespace mesh {
+
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
new file mode 100644
index 000000000000000..c09cf3e710d4278
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -0,0 +1,32 @@
+//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+//===----------------------------------------------------------------------===//
+// ShardingPropagation
+//===----------------------------------------------------------------------===//
+
+def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
+ let summary = "sharding propagation";
+ let description = [{
+ Propagates sharding information throughout the graph. After this pass, each
+ of the operations' operands and results is annotated with a `mesh.shard`
+ operation, and the operations themselves are added with sharding option
+ attributes.
+ }];
+ let dependentDialects = [
+ "mesh::MeshDialect"
+ ];
+}
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
new file mode 100644
index 000000000000000..16427919dace5da
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace tosa {
+
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 5af7835258f6bd2..f691a3daf8889c5 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -104,6 +104,18 @@ class AffineMap {
static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
MLIRContext *context);
+ /// Returns an affine map with `numDims` input dimensions and results
+ /// specified by `targets`.
+ ///
+ /// Examples:
+ /// * getMultiDimMapWithTargets(3, [0, 2, 1])
+ /// -> affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+ /// * getMultiDimMapWithTargets(3, [2, 1])
+ /// -> affine_map<(d0, d1, d2) -> (d2, d1)>
+ static AffineMap getMultiDimMapWithTargets(unsigned numDims,
+ ArrayRef<unsigned> targets,
+ MLIRContext *context);
+
/// Returns a vector of AffineMaps; each with as many results as
/// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
/// symbols as the largest symbol in `exprs`.
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 621110d130818d3..395d899f9ad84b0 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -79,6 +79,7 @@
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
@@ -171,6 +172,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
tensor::registerSubsetOpInterfaceExternalModels(registry);
tensor::registerTilingInterfaceExternalModels(registry);
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
+ tosa::registerShardingInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
vector::registerSubsetOpInterfaceExternalModels(registry);
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 80894094484b999..f22980036ffcfa1 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -30,6 +30,7 @@
#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Mesh/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -74,6 +75,7 @@ inline void registerAllPasses() {
LLVM::registerLLVMPasses();
math::registerMathPasses();
memref::registerMemRefPasses();
+ mesh::registerMeshPasses();
ml_program::registerMLProgramPasses();
registerSCFPasses();
registerShapePasses();
diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index fc91fd994f12dc2..0521147ba2fdff9 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -41,6 +41,37 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
return arith::ConstantOp::materialize(builder, value, type, loc);
}
+//===----------------------------------------------------------------------===//
+// Mesh utilities
+//===----------------------------------------------------------------------===//
+
+/// Returns true when `iType` is any iterator type other than `Parallel` and
+/// `Invalid`.
+bool mesh::isReductionLoop(IteratorType iType) {
+  const bool isParallel = iType == IteratorType::Parallel;
+  const bool isInvalid = iType == IteratorType::Invalid;
+  return !(isParallel || isInvalid);
+}
+
+/// Returns true when the reduction iterator type `iType` and the partial type
+/// `partial` name the same kind of reduction (generic, sum, max or min).
+/// Every other combination yields false.
+bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
+  switch (partial) {
+  case Partial::Generic:
+    return iType == IteratorType::ReductionGeneric;
+  case Partial::Sum:
+    return iType == IteratorType::ReductionSum;
+  case Partial::Max:
+    return iType == IteratorType::ReductionMax;
+  case Partial::Min:
+    return iType == IteratorType::ReductionMin;
+  default:
+    return false;
+  }
+}
+
+/// Returns the `Partial` kind that corresponds to the reduction iterator type
+/// `iType`: ReductionGeneric -> Generic, ReductionSum -> Sum,
+/// ReductionMax -> Max, ReductionMin -> Min.
+///
+/// Any other iterator type has no partial counterpart; passing one is a
+/// programming error.
+Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
+  switch (iType) {
+  case IteratorType::ReductionGeneric:
+    return Partial::Generic;
+  case IteratorType::ReductionSum:
+    return Partial::Sum;
+  case IteratorType::ReductionMax:
+    return Partial::Max;
+  case IteratorType::ReductionMin:
+    return Partial::Min;
+  default:
+    // `assert(0 && ...)` is compiled out under NDEBUG, which left this
+    // non-void function with a path that has no return statement.
+    // `llvm_unreachable` (llvm/Support/ErrorHandling.h) is treated as
+    // non-returning in every build mode and still reports the message in
+    // assertion-enabled builds.
+    llvm_unreachable("No corresponding partial type can be found");
+  }
+}
+
//===--------------------------------------------...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/71261
More information about the Mlir-commits
mailing list