[Mlir-commits] [mlir] [mlir][mesh] Add endomorphism simplification for all-reduce (PR #73150)

Boian Petkantchin llvmlistbot at llvm.org
Wed Nov 22 14:58:01 PST 2023


https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/73150

>From 363f6e55f719e5ccc5495ed96b321fe356821ebc Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Mon, 20 Nov 2023 17:42:37 -0800
Subject: [PATCH] [mlir][mesh] Add endomorphism simplification for all-reduce

Performs transformations such as
all_reduce(x) + all_reduce(y) -> all_reduce(x + y)
when the all_reduce element-wise op is sum, and

max(all_reduce(x), all_reduce(y)) -> all_reduce(max(x, y))
when the all_reduce element-wise op is max.
---
 .../mlir/Dialect/Mesh/Transforms/Passes.h     |   2 +
 .../Dialect/Mesh/Transforms/Simplifications.h |  89 ++++++++++
 .../Transforms/EndomorphismSimplification.h   | 161 ++++++++++++++++++
 .../Dialect/Mesh/Transforms/CMakeLists.txt    |   2 +
 .../Mesh/Transforms/Simplifications.cpp       |  37 ++++
 mlir/test/Dialect/Mesh/simplifications.mlir   | 131 ++++++++++++++
 mlir/test/lib/Dialect/CMakeLists.txt          |   1 +
 mlir/test/lib/Dialect/Mesh/CMakeLists.txt     |  11 ++
 .../lib/Dialect/Mesh/TestSimplifications.cpp  |  43 +++++
 mlir/tools/mlir-opt/CMakeLists.txt            |   1 +
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 11 files changed, 480 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
 create mode 100644 mlir/include/mlir/Transforms/EndomorphismSimplification.h
 create mode 100644 mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
 create mode 100644 mlir/test/Dialect/Mesh/simplifications.mlir
 create mode 100644 mlir/test/lib/Dialect/Mesh/CMakeLists.txt
 create mode 100644 mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp

diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index 83399d10beaae48..9b788d3f304c2c8 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -17,6 +17,8 @@ namespace func {
 class FuncOp;
 }
 
+class RewritePatternSet;
+
 namespace mesh {
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
new file mode 100644
index 000000000000000..1af0f52114f10e9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -0,0 +1,89 @@
+//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/EndomorphismSimplification.h"
+#include "llvm/Support/Casting.h"
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <utility>
+
+namespace mlir {
+namespace mesh {
+
+/// Adds a pattern that rewrites
+/// `AlgebraicOp(all_reduce(x1), ..., all_reduce(xn))` into
+/// `all_reduce(AlgebraicOp(x1, ..., xn))` for all-reduces using `reduction`.
+template <typename AlgebraicOp>
+void populateAllReduceEndomorphismSimplificationPatterns(
+    RewritePatternSet &patterns, Partial reduction) {
+  auto getEndomorphismOpOperand = [](Operation *op) {
+    auto allReduceOp = llvm::cast<AllReduceOp>(op);
+    return &allReduceOp.getInputMutable();
+  };
+  auto getEndomorphismOpResult = [](Operation *op) {
+    auto allReduceOp = llvm::cast<AllReduceOp>(op);
+    return allReduceOp->getResult(0);
+  };
+  auto getAlgebraicOpOperands = [](Operation *op,
+                                   SmallVector<OpOperand *> &operands) {
+    auto algebraicOp = llvm::cast<AlgebraicOp>(op);
+    std::transform(algebraicOp->getOpOperands().begin(),
+                   algebraicOp->getOpOperands().end(),
+                   std::back_inserter(operands),
+                   [](OpOperand &operand) { return &operand; });
+  };
+  auto getAlgebraicOpResult = [](Operation *op) {
+    auto algebraicOp = llvm::cast<AlgebraicOp>(op);
+    return algebraicOp->getResult(0);
+  };
+  auto isEndomorphismOp = [reduction](Operation *op,
+                                      std::optional<Operation *> referenceOp) {
+    auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
+    if (!allReduceOp ||
+        allReduceOp.getInput().getType().getElementType() !=
+            allReduceOp.getResult().getType().getElementType() ||
+        allReduceOp.getReduction() != reduction)
+      return false;
+    if (!referenceOp)
+      return true;
+    // A non-all-reduce reference is never compatible; testing the `dyn_cast`
+    // result also avoids dereferencing a null op.
+    auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
+    return refAllReduceOp &&
+           refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
+           allReduceOp.getInput().getType().getElementType() ==
+               refAllReduceOp.getInput().getType().getElementType();
+  };
+  auto isAlgebraicOp = [](Operation *op) { return llvm::isa<AlgebraicOp>(op); };
+
+  using ConcreteEndomorphismSimplification = EndomorphismSimplification<
+      std::decay_t<decltype(getEndomorphismOpOperand)>,
+      std::decay_t<decltype(getEndomorphismOpResult)>,
+      std::decay_t<decltype(getAlgebraicOpOperands)>,
+      std::decay_t<decltype(getAlgebraicOpResult)>,
+      std::decay_t<decltype(isEndomorphismOp)>,
+      std::decay_t<decltype(isAlgebraicOp)>>;
+  patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
+      std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
+      std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
+      std::move(isEndomorphismOp), std::move(isAlgebraicOp),
+      AlgebraicOp::getOperationName(), 1, patterns.getContext()));
+}
+
+void populateSimplificationPatterns(RewritePatternSet &patterns);
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
diff --git a/mlir/include/mlir/Transforms/EndomorphismSimplification.h b/mlir/include/mlir/Transforms/EndomorphismSimplification.h
new file mode 100644
index 000000000000000..36484667a20d875
--- /dev/null
+++ b/mlir/include/mlir/Transforms/EndomorphismSimplification.h
@@ -0,0 +1,161 @@
+//===- EndomorphismSimplification.h -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
+#define MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
+
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include <iterator>
+#include <optional>
+#include <type_traits>
+#include <utility>
+
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+
+// If `f` is an endomorphism with respect to the algebraic structure induced by
+// function `g`, transforms `g(f(x1), f(x2), ..., f(xn))` into
+// `f(g(x1, x2, ..., xn))`.
+// `g` is the algebraic operation and `f` is the endomorphism.
+//
+// Functors:
+// ---------
+// `GetEndomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
+// Returns the operand relevant to the endomorphism.
+// There may be other operands that are not relevant.
+//
+// `GetEndomorphismOpResultFn`: `(Operation*) -> OpResult`
+// Returns the result relevant to the endomorphism.
+//
+// `GetAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) -> void`
+// Populates into the vector the operands relevant to the endomorphism.
+//
+// `GetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
+// Returns the result relevant to the endomorphism.
+//
+// `IsEndomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
+// Checks if the operation is an endomorphism of the required type.
+// Additionally, if the optional is present, checks if the two operations are
+// compatible endomorphisms.
+//
+// `IsAlgebraicOpFn`: `(Operation*) -> bool`
+// Checks if the operation is an operation of the algebraic structure.
+template <typename GetEndomorphismOpOperandFn,
+          typename GetEndomorphismOpResultFn, typename GetAlgebraicOpOperandsFn,
+          typename GetAlgebraicOpResultFn, typename IsEndomorphismOpFn,
+          typename IsAlgebraicOpFn>
+struct EndomorphismSimplification : RewritePattern {
+  template <typename GetEndomorphismOpOperandFnArg,
+            typename GetEndomorphismOpResultFnArg,
+            typename GetAlgebraicOpOperandsFnArg,
+            typename GetAlgebraicOpResultFnArg, typename IsEndomorphismOpFnArg,
+            typename IsAlgebraicOpFnArg, typename... RewritePatternArgs>
+  EndomorphismSimplification(
+      GetEndomorphismOpOperandFnArg &&getEndomorphismOpOperand,
+      GetEndomorphismOpResultFnArg &&getEndomorphismOpResult,
+      GetAlgebraicOpOperandsFnArg &&getAlgebraicOpOperands,
+      GetAlgebraicOpResultFnArg &&getAlgebraicOpResult,
+      IsEndomorphismOpFnArg &&isEndomorphismOp,
+      IsAlgebraicOpFnArg &&isAlgebraicOp, RewritePatternArgs &&...args)
+      : RewritePattern(std::forward<RewritePatternArgs>(args)...),
+        getEndomorphismOpOperand(std::forward<GetEndomorphismOpOperandFnArg>(
+            getEndomorphismOpOperand)),
+        getEndomorphismOpResult(std::forward<GetEndomorphismOpResultFnArg>(
+            getEndomorphismOpResult)),
+        getAlgebraicOpOperands(
+            std::forward<GetAlgebraicOpOperandsFnArg>(getAlgebraicOpOperands)),
+        getAlgebraicOpResult(
+            std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult)),
+        isEndomorphismOp(std::forward<IsEndomorphismOpFnArg>(isEndomorphismOp)),
+        isAlgebraicOp(std::forward<IsAlgebraicOpFnArg>(isAlgebraicOp)) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Scratch storage is local so this const method mutates no pattern state.
+    SmallVector<OpOperand *> algebraicOpOperands;
+    if (failed(matchOp(op, algebraicOpOperands)))
+      return failure();
+    return rewriteOp(op, algebraicOpOperands, rewriter);
+  }
+
+private:
+  // Succeeds when `algebraicOp` is an algebraic op whose relevant operands are
+  // all produced by mutually compatible endomorphism ops. The operands are
+  // returned through `algebraicOpOperands`.
+  LogicalResult matchOp(Operation *algebraicOp,
+                        SmallVector<OpOperand *> &algebraicOpOperands) const {
+    if (!isAlgebraicOp(algebraicOp))
+      return failure();
+    algebraicOpOperands.clear();
+    getAlgebraicOpOperands(algebraicOp, algebraicOpOperands);
+    if (algebraicOpOperands.empty())
+      return failure();
+
+    Operation *firstEndomorphismOp =
+        algebraicOpOperands.front()->get().getDefiningOp();
+    if (!firstEndomorphismOp ||
+        !isEndomorphismOp(firstEndomorphismOp, std::nullopt))
+      return failure();
+
+    // Every operand, not only the first, must be the endomorphism-relevant
+    // result of its defining op, because `rewriteOp` remaps each of them.
+    for (OpOperand *operand : algebraicOpOperands) {
+      Operation *endomorphismOp = operand->get().getDefiningOp();
+      if (!endomorphismOp ||
+          !isEndomorphismOp(endomorphismOp, firstEndomorphismOp) ||
+          getEndomorphismOpResult(endomorphismOp) != operand->get())
+        return failure();
+    }
+    return success();
+  }
+
+  // Creates `f(g(x1, ..., xn))` through `rewriter` and redirects the uses of
+  // the relevant `algebraicOp` result to it. Requires a successful `matchOp`.
+  LogicalResult rewriteOp(Operation *algebraicOp,
+                          const SmallVector<OpOperand *> &algebraicOpOperands,
+                          PatternRewriter &rewriter) const {
+    assert(!algebraicOpOperands.empty());
+    // Clone the algebraic op so that it consumes the endomorphism inputs.
+    IRMapping operandMapping;
+    for (OpOperand *operand : algebraicOpOperands) {
+      Operation *endomorphismOp = operand->get().getDefiningOp();
+      operandMapping.map(operand->get(),
+                         getEndomorphismOpOperand(endomorphismOp)->get());
+    }
+    Operation *newAlgebraicOp = rewriter.clone(*algebraicOp, operandMapping);
+
+    // Clone the first endomorphism op so that it consumes the new result.
+    Operation *firstEndomorphismOp =
+        algebraicOpOperands[0]->get().getDefiningOp();
+    IRMapping resultMapping;
+    resultMapping.map(getEndomorphismOpOperand(firstEndomorphismOp)->get(),
+                      getAlgebraicOpResult(newAlgebraicOp));
+    Operation *newEndomorphismOp =
+        rewriter.clone(*firstEndomorphismOp, resultMapping);
+    rewriter.replaceAllUsesWith(getAlgebraicOpResult(algebraicOp),
+                                getEndomorphismOpResult(newEndomorphismOp));
+    return success();
+  }
+
+  GetEndomorphismOpOperandFn getEndomorphismOpOperand;
+  GetEndomorphismOpResultFn getEndomorphismOpResult;
+  GetAlgebraicOpOperandsFn getAlgebraicOpOperands;
+  GetAlgebraicOpResultFn getAlgebraicOpResult;
+  IsEndomorphismOpFn isEndomorphismOp;
+  IsAlgebraicOpFn isAlgebraicOp;
+};
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index bcf45c4ea276080..044b8672c8c60cf 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRMeshTransforms
+  Simplifications.cpp
   ShardingPropagation.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -9,6 +10,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
   MLIRShardingInterface
 
   LINK_LIBS PUBLIC
+  MLIRArithDialect
   MLIRFuncDialect
   MLIRIR
   MLIRMeshDialect
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
new file mode 100644
index 000000000000000..1d241fe03a127b8
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -0,0 +1,37 @@
+//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+
+namespace mlir {
+namespace mesh {
+
+void populateSimplificationPatterns(RewritePatternSet &patterns) {
+  populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
+      patterns, Partial::Sum);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
+      patterns, Partial::Sum);
+
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
+      patterns, Partial::Min);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
+      patterns, Partial::Min);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
+      patterns, Partial::Min);
+
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
+      patterns, Partial::Max);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
+      patterns, Partial::Max);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
+      patterns, Partial::Max);
+}
+
+} // namespace mesh
+} // namespace mlir
diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir
new file mode 100644
index 000000000000000..2b305df6e0a97f1
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/simplifications.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [4, 2])
+mesh.cluster @mesh1(rank = 1, dim_sizes = [4])
+
+// Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
+// `all_reduce(x + y)`.
+// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism
+func.func @all_reduce_arith_addf_endomorphism(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh
+func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh1
+  %1 = mesh.all_reduce %arg1 on @mesh1 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes
+func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1]
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [1]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind
+func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = <max>
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <max>
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types
+func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf64> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf64>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf64>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf64>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf64>
+}
+
+// Checks that `min(all_reduce(x), all_reduce(y))` gets transformed to
+// `all_reduce(min(x, y))`.
+// CHECK-LABEL: func.func @all_reduce_arith_minimumf_endomorphism
+func.func @all_reduce_arith_minimumf_endomorphism(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+    : tensor<5xf32> -> tensor<5xf32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]]
+  %2 = arith.minimumf %0, %1 : tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_minsi_endomorphism
+func.func @all_reduce_arith_minsi_endomorphism(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xi32>
+    %arg0: tensor<5xi32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32>
+    %arg1: tensor<5xi32>) -> tensor<5xi32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+    : tensor<5xi32> -> tensor<5xi32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+    : tensor<5xi32> -> tensor<5xi32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]]
+  %2 = arith.minsi %0, %1 : tensor<5xi32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : tensor<5xi32>
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 48bde69e0170041..30a17c201ff7635 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -9,6 +9,7 @@ add_subdirectory(Linalg)
 add_subdirectory(LLVM)
 add_subdirectory(Math)
 add_subdirectory(MemRef)
+add_subdirectory(Mesh)
 add_subdirectory(NVGPU)
 add_subdirectory(SCF)
 add_subdirectory(Shape)
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
new file mode 100644
index 000000000000000..102bd59da61b992
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -0,0 +1,11 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRMeshTestSimplifications
+  TestSimplifications.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRMeshDialect
+  MLIRMeshTransforms
+  MLIRPass
+  )
diff --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
new file mode 100644
index 000000000000000..93b1da52d46b4ef
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
@@ -0,0 +1,43 @@
+//===- TestSimplifications.cpp - Test simplifications ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+struct TestMeshSimplificationsPass
+    : public PassWrapper<TestMeshSimplificationsPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshSimplificationsPass)
+
+  void runOnOperation() override;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, mesh::MeshDialect>();
+  }
+  StringRef getArgument() const final { return "test-mesh-simplifications"; }
+  StringRef getDescription() const final { return "Test mesh simplifications"; }
+};
+} // namespace
+
+void TestMeshSimplificationsPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  mesh::populateSimplificationPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+namespace mlir {
+namespace test {
+void registerTestMeshSimplificationsPass() {
+  PassRegistration<TestMeshSimplificationsPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 88a0562cb6e7207..bc8eed18215525e 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -26,6 +26,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRLoopLikeInterfaceTestPasses
     MLIRMathTestPasses
     MLIRMemRefTestPasses
+    MLIRMeshTestSimplifications
     MLIRNVGPUTestPasses
     MLIRSCFTestPasses
     MLIRShapeTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 3e3223b48505601..c7cf1e55a556f4a 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -118,6 +118,7 @@ void registerTestMathAlgebraicSimplificationPass();
 void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
+void registerTestMeshSimplificationsPass();
 void registerTestNextAccessPass();
 void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
@@ -238,6 +239,7 @@ void registerTestPasses() {
   mlir::test::registerTestMathPolynomialApproximationPass();
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
+  mlir::test::registerTestMeshSimplificationsPass();
   mlir::test::registerTestNextAccessPass();
   mlir::test::registerTestOneToNTypeConversionPass();
   mlir::test::registerTestOpaqueLoc();



More information about the Mlir-commits mailing list