[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #71261)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 3 18:02:31 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Chengji Yao (yaochengji)

<details>
<summary>Changes</summary>

Add a pass that propagates sharding information throughout the graph.
After this pass, every operand and result of each operation is
annotated with a `mesh.shard` operation.

The pass is driven by a newly added ShardingInterface, and an implementation
for element-wise and matmul ops in the TOSA dialect is provided.

---

Patch is 69.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71261.diff


23 Files Affected:

- (modified) mlir/include/mlir/Dialect/Mesh/CMakeLists.txt (+2) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+34) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+18) 
- (added) mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt (+4) 
- (added) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+68) 
- (added) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td (+102) 
- (added) mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt (+6) 
- (added) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h (+39) 
- (added) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td (+32) 
- (added) mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h (+23) 
- (modified) mlir/include/mlir/IR/AffineMap.h (+12) 
- (modified) mlir/include/mlir/InitAllDialects.h (+2) 
- (modified) mlir/include/mlir/InitAllPasses.h (+2) 
- (modified) mlir/lib/Dialect/Mesh/CMakeLists.txt (+2) 
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+31-1) 
- (added) mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt (+15) 
- (added) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+524) 
- (added) mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt (+19) 
- (added) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+210) 
- (modified) mlir/lib/Dialect/Tosa/CMakeLists.txt (+14) 
- (added) mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp (+111) 
- (modified) mlir/lib/IR/AffineMap.cpp (+12-4) 
- (added) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+188) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
 add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 39d24595ec1c446..a91ef569347bff1 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -49,6 +49,22 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
   let cppNamespace = "::mlir::mesh";
 }
 
+// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
+// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
+// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or
+// is partial.
+def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
+  I32EnumAttrCase<"Parallel", 1, "parallel">,
+  I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
+  I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
+  I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
+  I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
+  I32EnumAttrCase<"Invalid", 100, "invalid">
+]> {
+    let genSpecializedAttr = 0;
+    let cppNamespace = "::mlir::mesh";
+}
+
 //===----------------------------------------------------------------------===//
 // Mesh Attribute
 //===----------------------------------------------------------------------===//
@@ -122,6 +138,24 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
        $partial_axes^ `]`)? `>`
   }];
 
+  let builders = [
+    AttrBuilder<(ins "SymbolRefAttr":$cluster, 
+                     "ArrayRef<SmallVector<int32_t>>":$split_axes,
+                     "ArrayRef<int32_t>": $partial_axes,
+                     "mesh::Partial": $partial_type), [{
+      SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::map_to_vector(
+                  split_axes, [&](ArrayRef<int32_t> array) {
+          return DenseI32ArrayAttr::get($_ctxt, array);
+      });
+      return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
+                   partial_type);
+    }]>,
+    AttrBuilder<(ins "SymbolRefAttr":$cluster, 
+                     "ArrayRef<SmallVector<int32_t>>":$split_axes), [{
+      return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
+    }]>
+  ];
+
   let genVerifyDecl = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9dfeca84d012165..05eba66a89949b6 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -24,4 +24,22 @@
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
 
+namespace mlir {
+namespace mesh {
+
+bool isReductionLoop(IteratorType iType);
+
+bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
+
+template <typename T>
+void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
+  while (!array.empty() && array.back().empty())
+    array.pop_back();
+}
+
+Partial getPartialTypeFromReduction(IteratorType iType);
+
+} // namespace mesh
+} // namespace mlir
+
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
new file mode 100644
index 000000000000000..b3a44f3b0089abc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS ShardingInterface.td)
+mlir_tablegen(ShardingInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(ShardingInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRShardingInterfaceIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
new file mode 100644
index 000000000000000..d860628cf371aa9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -0,0 +1,68 @@
+//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+class Operation;
+
+namespace mesh {
+
+using ShardingArray = SmallVector<SmallVector<int32_t>>;
+using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
+
+struct ShardingOption {
+  // An array of int arrays. The sub-array at the i-th position signifies the
+  // mesh axes the i-th loop will be sharded on.
+  ShardingArray shardingArray;
+  SymbolRefAttr cluster;
+  // `empty` being true indicates that no sharding information can be inferred
+  // at present. Note that it is different from the case where an operation is
+  // not sharded.
+  bool empty = false;
+  ShardingOption() = default;
+  ShardingOption(ShardingArray shardingArray, SymbolRefAttr cluster)
+      : shardingArray(std::move(shardingArray)), cluster(cluster) {}
+};
+
+// This method retrieves the 'MeshShardingAttr' attribute from a given operation
+// result and includes the 'annotate_for_users' information.
+FailureOr<std::pair<bool, MeshShardingAttr>>
+getMeshShardingAttr(OpResult result);
+
+// This method retrieves the 'MeshShardingAttr' attribute from a given operation
+// operand and includes the 'annotate_for_users' information.
+FailureOr<std::pair<bool, MeshShardingAttr>>
+getMeshShardingAttr(OpOperand &opOperand);
+
+namespace detail {
+
+FailureOr<ShardingOption>
+defaultGetShardingOption(Operation *op,
+                         ArrayRef<MeshShardingAttr> operandShardings,
+                         ArrayRef<MeshShardingAttr> resultShardings);
+
+LogicalResult
+defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
+                              const ShardingOption &shardingOption);
+
+} // namespace detail
+
+} // namespace mesh
+
+} // namespace mlir
+
+/// Include the ODS generated interface header files.
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
new file mode 100644
index 000000000000000..21b6c8d4f599a8d
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -0,0 +1,102 @@
+//===- ShardingInterface.td --------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+
+include "mlir/IR/OpBase.td"
+
+def ShardingInterface : OpInterface<"ShardingInterface"> {
+    let description = [{
+        Interface for allowing operations to expose information needed to
+        shard them.
+    }];
+    let cppNamespace = "::mlir::mesh";
+
+    let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Returns a list of iterator types, one for each loop of the operation.
+          The iterator types determine how the operation traverses its input and
+          output tensors.
+
+          Example 1: A gemm op has 3 loops, M, N and K. Their loop iterator
+          types are parallel, parallel, reduction-sum. This indicates that M and
+          N are traversed in parallel, while the K dimension is used for
+          reduction.
+
+          Example 2: A softmax op's loop iterator types are parallel and
+          invalid. The second loop is considered invalid because it is
+          neither parallel nor any kind of reduction.
+        }],
+        /*retTy=*/"SmallVector<::mlir::mesh::IteratorType>",
+        /*methodName=*/"getLoopIteratorTypes",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return {};"
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return the indexing maps attribute within the current operation.
+          Indexing maps determine how indices in the iteration space map to
+          tensor indices. They are specified using `affine_map` in MLIR, which
+          provides an affine transformation of indices.
+        }],
+        /*retTy=*/"SmallVector<AffineMap>",
+        /*methodName=*/"getIndexingMaps",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return {};"
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Given that certain operands or results of the operation may have
+          sharding annotations, this method leverages this information to deduce
+          how the operation should be sharded.
+        }],
+        /*retTy=*/"FailureOr<ShardingOption>",
+        /*methodName=*/"getShardingOption",
+        /*args=*/(ins
+          "ArrayRef<MeshShardingAttr>": $operandShardings,
+          "ArrayRef<MeshShardingAttr>": $resultShardings
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return detail::defaultGetShardingOption(
+            $_op.getOperation(), operandShardings, resultShardings);
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Based on a given ShardingOption, this method adds `mesh.shard`
+          operations for the operands and results that previously lacked
+          sharding annotations.
+        }],
+        /*retTy=*/"LogicalResult",
+        /*methodName=*/"addShardingAnnotations",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "const ShardingOption &":$shardingOption
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return detail::defaultAddShardingAnnotations(
+            $_op.getOperation(), b, shardingOption);
+        }]
+      >
+    ];
+
+    let extraClassDeclaration = [{
+      LogicalResult verifyShardingInterfaceImpl();
+
+      void printLoopTypesAndIndexingMaps(raw_ostream &os);
+    }];
+}
+
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..8d768485103b65f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh)
+add_public_tablegen_target(MLIRMeshPassIncGen)
+add_dependencies(mlir-headers MLIRMeshPassIncGen)
+
+add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
new file mode 100644
index 000000000000000..83399d10beaae48
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -0,0 +1,39 @@
+//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace func {
+class FuncOp;
+} // namespace func
+
+namespace mesh {
+
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
new file mode 100644
index 000000000000000..c09cf3e710d4278
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -0,0 +1,32 @@
+//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+//===----------------------------------------------------------------------===//
+// ShardingPropagation
+//===----------------------------------------------------------------------===//
+
+def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
+  let summary = "sharding propagation";
+  let description = [{
+    Propagates sharding information throughout the graph. After this pass, each
+    of the operations' operands and results is annotated with a `mesh.shard`
+    operation, and sharding option attributes are added to the operations
+    themselves.
+  }];
+  let dependentDialects = [
+    "mesh::MeshDialect"
+  ];
+}
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
new file mode 100644
index 000000000000000..16427919dace5da
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_IR_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_TOSA_IR_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace tosa {
+/// Registers the mesh ShardingInterface external models for TOSA ops.
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_IR_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 5af7835258f6bd2..f691a3daf8889c5 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -104,6 +104,18 @@ class AffineMap {
   static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
                                      MLIRContext *context);
 
+  /// Returns an affine map with `numDims` input dimensions and results
+  /// specified by `targets`.
+  ///
+  /// Examples:
+  /// * getMultiDimMapWithTargets(3, [0, 2, 1])
+  ///       -> affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+  /// * getMultiDimMapWithTargets(3, [2, 1])
+  ///       -> affine_map<(d0, d1, d2) -> (d2, d1)>
+  static AffineMap getMultiDimMapWithTargets(unsigned numDims,
+                                             ArrayRef<unsigned> targets,
+                                             MLIRContext *context);
+
   /// Returns a vector of AffineMaps; each with as many results as
   /// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
   /// symbols as the largest symbol in `exprs`.
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 621110d130818d3..395d899f9ad84b0 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -79,6 +79,7 @@
 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
@@ -171,6 +172,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   tensor::registerSubsetOpInterfaceExternalModels(registry);
   tensor::registerTilingInterfaceExternalModels(registry);
   tensor::registerValueBoundsOpInterfaceExternalModels(registry);
+  tosa::registerShardingInterfaceExternalModels(registry);
   vector::registerBufferizableOpInterfaceExternalModels(registry);
   vector::registerSubsetOpInterfaceExternalModels(registry);
   NVVM::registerNVVMTargetInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 80894094484b999..f22980036ffcfa1 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -30,6 +30,7 @@
 #include "mlir/Dialect/MLProgram/Transforms/Passes.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Mesh/Transforms/Passes.h"
 #include "mlir/Dialect/NVGPU/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -74,6 +75,7 @@ inline void registerAllPasses() {
   LLVM::registerLLVMPasses();
   math::registerMathPasses();
   memref::registerMemRefPasses();
+  mesh::registerMeshPasses();
   ml_program::registerMLProgramPasses();
   registerSCFPasses();
   registerShapePasses();
diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
 add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index fc91fd994f12dc2..0521147ba2fdff9 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -41,6 +41,37 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
+//===----------------------------------------------------------------------===//
+// Mesh utilities
+//===----------------------------------------------------------------------===//
+
+bool mesh::isReductionLoop(IteratorType iType) {
+  return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
+}
+
+bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
+  return (partial == Partial::Generic &&
+          iType == IteratorType::ReductionGeneric) ||
+         (partial == Partial::Sum && iType == IteratorType::ReductionSum) ||
+         (partial == Partial::Max && iType == IteratorType::ReductionMax) ||
+         (partial == Partial::Min && iType == IteratorType::ReductionMin);
+}
+
+Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
+  switch (iType) {
+  case IteratorType::ReductionGeneric:
+    return Partial::Generic;
+  case IteratorType::ReductionSum:
+    return Partial::Sum;
+  case IteratorType::ReductionMax:
+    return Partial::Max;
+  case IteratorType::ReductionMin:
+    return Partial::Min;
+  default: // Parallel/Invalid loops have no partial counterpart.
+    llvm_unreachable("No corresponding partial type can be found");
+  }
+}
+
 //===--------------------------------------------...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/71261


More information about the Mlir-commits mailing list