[Mlir-commits] [mlir] [MLIR][Mesh] Add sharding propagation pass (PR #69665)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 19 17:42:41 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Chengji Yao (yaochengji)

<details>
<summary>Changes</summary>

- add `ShardingInterface` and default implementations of its methods
- add `ShardingInterface` implementation for element-wise and matmul ops in TOSA dialect
- add sharding propagation pass

---

Patch is 66.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69665.diff


27 Files Affected:

- (modified) mlir/include/mlir/Dialect/Mesh/CMakeLists.txt (+2) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td (+34) 
- (modified) mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h (+22) 
- (added) mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt (+4) 
- (added) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h (+58) 
- (added) mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td (+87) 
- (added) mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt (+6) 
- (added) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h (+41) 
- (added) mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td (+33) 
- (added) mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h (+23) 
- (modified) mlir/include/mlir/Dialect/Utils/IndexingUtils.h (+3) 
- (modified) mlir/include/mlir/IR/AffineMap.h (+12) 
- (modified) mlir/include/mlir/IR/Builders.h (+1) 
- (modified) mlir/include/mlir/InitAllDialects.h (+2) 
- (modified) mlir/include/mlir/InitAllPasses.h (+2) 
- (modified) mlir/lib/Dialect/Mesh/CMakeLists.txt (+2) 
- (modified) mlir/lib/Dialect/Mesh/IR/MeshOps.cpp (+31-1) 
- (added) mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt (+15) 
- (added) mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp (+544) 
- (added) mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt (+17) 
- (added) mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp (+155) 
- (modified) mlir/lib/Dialect/Tosa/CMakeLists.txt (+14) 
- (added) mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp (+111) 
- (modified) mlir/lib/Dialect/Utils/IndexingUtils.cpp (+13) 
- (modified) mlir/lib/IR/AffineMap.cpp (+13) 
- (modified) mlir/lib/IR/Builders.cpp (+9) 
- (added) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+167) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
 add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 39d24595ec1c446..b6623ed818f0770 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -49,6 +49,22 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
   let cppNamespace = "::mlir::mesh";
 }
 
+// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
+// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
+// Mesh_Partial indicates the kind of pending reduction (sum, max, min or
+// generic) of a distributed tensor that holds partial values.
+def Mesh_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
+  I32EnumAttrCase<"Parallel", 1, "parallel">,
+  I32EnumAttrCase<"ReductionSum", 2, "reduction_sum">,
+  I32EnumAttrCase<"ReductionMax", 3, "reduction_max">,
+  I32EnumAttrCase<"ReductionMin", 4, "reduction_min">,
+  I32EnumAttrCase<"ReductionGeneric", 5, "reduction_generic">,
+  I32EnumAttrCase<"Invalid", 100, "invalid">
+]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::mesh";
+}
+
 //===----------------------------------------------------------------------===//
 // Mesh Attribute
 //===----------------------------------------------------------------------===//
@@ -122,6 +138,24 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
        $partial_axes^ `]`)? `>`
   }];
 
+  let builders = [
+    AttrBuilder<(ins "SymbolRefAttr":$cluster,
+                     "ArrayRef<SmallVector<int32_t>>":$split_axes,
+                     "ArrayRef<int32_t>":$partial_axes,
+                     "mesh::Partial":$partial_type), [{
+      SmallVector<DenseI32ArrayAttr> splitAxesAttr = llvm::to_vector(
+        llvm::map_range(split_axes, [&](ArrayRef<int32_t> array) {
+          return DenseI32ArrayAttr::get($_ctxt, array);
+      }));
+      return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
+                   partial_type);
+    }]>,
+    AttrBuilder<(ins "SymbolRefAttr":$cluster,
+                     "ArrayRef<SmallVector<int32_t>>":$split_axes), [{
+      return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
+    }]>
+  ];
+
   let genVerifyDecl = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9dfeca84d012165..cb86887091330c8 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -24,4 +24,26 @@
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
 
+namespace mlir {
+namespace mesh {
+
+bool isReductionLoop(IteratorType iType);
+
+bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
+
+template <typename T>
+void removeTrailingEmptySubArray(SmallVector<SmallVector<T>> &array) {
+  // Pop empty sub-arrays off the back until the last remaining sub-array is
+  // non-empty. Empty sub-arrays at earlier positions are kept, since a
+  // sub-array's index is significant (presumably it identifies a loop or
+  // tensor dimension). An array of only empty sub-arrays ends up empty.
+  while (!array.empty() && array.back().empty())
+    array.pop_back();
+}
+
+Partial getPartialTypeFromReduction(IteratorType iType);
+
+} // namespace mesh
+} // namespace mlir
+
 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
new file mode 100644
index 000000000000000..b3a44f3b0089abc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS ShardingInterface.td)
+mlir_tablegen(ShardingInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(ShardingInterface.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRShardingInterfaceIncGen)
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
new file mode 100644
index 000000000000000..1d19e41ac1fc555
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -0,0 +1,58 @@
+//===- ShardingInterface.h --------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+class Operation;
+
+namespace mesh {
+
+using ShardingArray = SmallVector<SmallVector<int32_t>>;
+using ShardingArrayRef = ArrayRef<SmallVector<int32_t>>;
+
+struct ShardingOption {
+  // An array of int arrays. The sub-array at the i-th position lists the mesh
+  // axes that the i-th loop will be sharded on.
+  ShardingArray shardingArray;
+  SymbolRefAttr cluster;
+  // `empty` being true indicates that no sharding information can be inferred
+  // at present. Note that this differs from an operation that is not sharded.
+  bool empty = false;
+  ShardingOption() = default;
+  ShardingOption(const ShardingArray &shardingArray, SymbolRefAttr cluster)
+      : shardingArray(shardingArray), cluster(cluster) {}
+};
+
+constexpr StringRef getShardingArrayName() { return "sharding_array"; }
+
+constexpr StringRef getMeshClusterName() { return "mesh_cluster"; }
+
+namespace detail {
+
+FailureOr<ShardingOption> defaultGetShardingOption(Operation *op, OpBuilder &b);
+
+LogicalResult
+defaultAddShardingAnnotations(Operation *op, OpBuilder &b,
+                              const ShardingOption &shardingOption);
+
+} // namespace detail
+
+} // namespace mesh
+
+} // namespace mlir
+
+/// Include the ODS generated interface header files.
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h.inc"
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
new file mode 100644
index 000000000000000..c98b9f081492997
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.td
@@ -0,0 +1,87 @@
+//===- ShardingInterface.td --------------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+#define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
+
+include "mlir/IR/OpBase.td"
+
+def ShardingInterface : OpInterface<"ShardingInterface"> {
+    let description = [{
+        Interface for allowing operations to expose information needed to
+        shard them.
+    }];
+    let cppNamespace = "::mlir::mesh";
+
+    let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Returns a list of iterator types, one for each loop of the operation.
+        }],
+        /*retTy=*/"SmallVector<::mlir::mesh::IteratorType>",
+        /*methodName=*/"getLoopIteratorTypes",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return {};"
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Returns the indexing maps of the current operation.
+        }],
+        /*retTy=*/"SmallVector<AffineMap>",
+        /*methodName=*/"getIndexingMaps",
+        /*args=*/(ins),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return {};"
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Given that certain operands or results of the operation may have
+          sharding annotations, this method leverages this information to deduce
+          how the operation should be sharded.
+        }],
+        /*retTy=*/"FailureOr<ShardingOption>",
+        /*methodName=*/"getShardingOption",
+        /*args=*/(ins
+          "OpBuilder &":$b
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return detail::defaultGetShardingOption(
+            $_op.getOperation(), b);
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Based on a given ShardingOption, this method adds `mesh.shard`
+          operations for the operands and results that previously lacked
+          sharding annotations.
+        }],
+        /*retTy=*/"LogicalResult",
+        /*methodName=*/"addShardingAnnotations",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "const ShardingOption &":$shardingOption
+        ),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return detail::defaultAddShardingAnnotations(
+            $_op.getOperation(), b, shardingOption);
+        }]
+      >
+    ];
+
+    let extraClassDeclaration = [{
+      LogicalResult verifyShardingInterfaceImpl();
+
+      void printLoopTypesAndIndexingMaps(raw_ostream &os);
+    }];
+}
+
+
+#endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACE_TD
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..8d768485103b65f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name Mesh)
+add_public_tablegen_target(MLIRMeshPassIncGen)
+add_dependencies(mlir-headers MLIRMeshPassIncGen)
+
+add_mlir_doc(Passes MeshPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
new file mode 100644
index 000000000000000..aa3555f7f186f24
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -0,0 +1,41 @@
+//===- Passes.h - Mesh Passes -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace func {
+class FuncOp;
+}
+
+namespace mesh {
+
+//===----------------------------------------------------------------------===//
+// Passes
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+std::unique_ptr<OperationPass<func::FuncOp>> createShardingPropagationPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
new file mode 100644
index 000000000000000..d36adfe476a72ac
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -0,0 +1,33 @@
+//===-- Passes.td - Mesh transformation definition file ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+//===----------------------------------------------------------------------===//
+// ShardingPropagation
+//===----------------------------------------------------------------------===//
+
+def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
+  let summary = "sharding propagation";
+  let description = [{
+    Propagates sharding information throughout the graph. After this pass, each
+    of the operations' operands and results is annotated with a `mesh.shard`
+    operation, and sharding option attributes are attached to the operations
+    themselves.
+  }];
+  let constructor = "mlir::mesh::createShardingPropagationPass()";
+  let dependentDialects = [
+    "mesh::MeshDialect"
+  ];
+}
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
new file mode 100644
index 000000000000000..16427919dace5da
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- ShardingInterfaceImpl.h - Tosa ShardingInterface impl ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+#define MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace tosa {
+
+void registerShardingInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_TRANSFORMS_SHARDINGINTERFACEIMPL_H_
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index f51a8b28b7548ed..b24164cfb552b4f 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -245,6 +245,9 @@ computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
                                     unsigned dropBack = 0);
 
+/// Helper to return a vector of sub-vectors of int32_t from `arrayAttr`.
+SmallVector<SmallVector<int32_t>> getArrayOfI32Array(ArrayAttr arrayAttr);
+
 /// Compute linear index from provided strides and indices, assuming strided
 /// layout.
 /// Returns AffineExpr and list of values to apply to it, e.g.:
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index 3430db2b99c3f2e..18e2313ef2b446b 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -101,6 +101,18 @@ class AffineMap {
   static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
                                      MLIRContext *context);
 
+  /// Returns an affine map with `numDims` input dimensions and results
+  /// specified by `targets`.
+  ///
+  /// Examples:
+  /// * getMultiDimMapWithTargets(3, [0, 2, 1])
+  ///       -> affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+  /// * getMultiDimMapWithTargets(3, [2, 1])
+  ///       -> affine_map<(d0, d1, d2) -> (d2, d1)>
+  static AffineMap getMultiDimMapWithTargets(unsigned numDims,
+                                      ArrayRef<int64_t> targets,
+                                      MLIRContext *context);
+
   /// Returns a vector of AffineMaps; each with as many results as
   /// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
   /// symbols as the largest symbol in `exprs`.
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 5e54d4ea49e8251..3988835622b7629 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -168,6 +168,7 @@ class Builder {
   ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
   ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
   ArrayAttr getTypeArrayAttr(TypeRange values);
+  ArrayAttr getArrayOfI32ArrayAttr(ArrayRef<SmallVector<int32_t>> values);
 
   // Affine expressions and affine maps.
   AffineExpr getAffineDimExpr(unsigned position);
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 00f400aab5d50a0..3556f82023828b2 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -79,6 +79,7 @@
 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
@@ -170,6 +171,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   tensor::registerSubsetInsertionOpInterfaceExternalModels(registry);
   tensor::registerTilingInterfaceExternalModels(registry);
   tensor::registerValueBoundsOpInterfaceExternalModels(registry);
+  tosa::registerShardingInterfaceExternalModels(registry);
   vector::registerBufferizableOpInterfaceExternalModels(registry);
   NVVM::registerNVVMTargetInterfaceExternalModels(registry);
   ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 5489a13a8040bdb..27711417ed91a8c 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -29,6 +29,7 @@
 #include "mlir/Dialect/MLProgram/Transforms/Passes.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Mesh/Transforms/Passes.h"
 #include "mlir/Dialect/NVGPU/Transforms/Passes.h"
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -73,6 +74,7 @@ inline void registerAllPasses() {
   LLVM::registerLLVMPasses();
   math::registerMathPasses();
   memref::registerMemRefPasses();
+  mesh::registerMeshPasses();
   ml_program::registerMLProgramPasses();
   registerSCFPasses();
   registerShapePasses();
diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Mesh/CMakeLists.txt
index f33061b2d87cffc..fa8842fb04fd721 100644
--- a/mlir/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/CMakeLists.txt
@@ -1 +1,3 @@
+add_subdirectory(Interfaces)
 add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index fc91fd994f12dc2..0521147ba2fdff9 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -41,6 +41,37 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
   return arith::ConstantOp::materialize(builder, value, type, loc);
 }
 
+//===----------------------------------------------------------------------===//
+// Mesh utilities
+//===----------------------------------------------------------------------===//
+
+bool mesh::isReductionLoop(IteratorType iType) {
+  return !(iType == IteratorType::Parallel || iType == IteratorType::Invalid);
+}
+
+bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
+  return (iType == IteratorType::ReductionSum && partial == Partial::Sum) ||
+         (iType == IteratorType::ReductionMax && partial == Partial::Max) ||
+         (iType == IteratorType::ReductionMin && partial == Partial::Min) ||
+         (iType == IteratorType::ReductionGeneric &&
+          partial == Partial::Generic);
+}
+
+Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
+  switch (iType) {
+  case IteratorType::ReductionGeneric:
+    return Partial::Generic;
+  case IteratorType::ReductionSum:
+    return Partial::Sum;
+  case IteratorType::ReductionMax:
+    return Partial::Max;
+  case Iterator...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/69665


More information about the Mlir-commits mailing list