[Mlir-commits] [mlir] 4b34467 - [mlir][mesh] Add endomorphism simplification for all-reduce (#73150)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 12 10:21:57 PST 2023


Author: Boian Petkantchin
Date: 2023-12-12T10:21:52-08:00
New Revision: 4b3446771f745bb5169354ad9027c0a1c9fca394

URL: https://github.com/llvm/llvm-project/commit/4b3446771f745bb5169354ad9027c0a1c9fca394
DIFF: https://github.com/llvm/llvm-project/commit/4b3446771f745bb5169354ad9027c0a1c9fca394.diff

LOG: [mlir][mesh] Add endomorphism simplification for all-reduce (#73150)

Does transformations like
all_reduce(x) + all_reduce(y) -> all_reduce(x + y)

max(all_reduce(x), all_reduce(y)) -> all_reduce(max(x, y))
when the all_reduce element-wise op is max.

Added general rewrite pattern HomomorphismSimplification and
EndomorphismSimplification that encapsulate the general algorithm.
Made specialization for all-reduce with respect to
addf, addi, minsi, maxsi, minimumf and maximumf
in the Arithmetic dialect.

Added: 
    mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
    mlir/include/mlir/Transforms/EndomorphismSimplification.h
    mlir/include/mlir/Transforms/HomomorphismSimplification.h
    mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
    mlir/test/Dialect/Mesh/simplifications.mlir
    mlir/test/lib/Dialect/Mesh/CMakeLists.txt
    mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp

Modified: 
    mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/CMakeLists.txt
    mlir/tools/mlir-opt/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
new file mode 100644
index 0000000000000..f70bdaa9de0a0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -0,0 +1,110 @@
+//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/EndomorphismSimplification.h"
+#include "llvm/Support/Casting.h"
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <utility>
+
+namespace mlir {
+namespace mesh {
+
+// If we have an algebraic op like "+" and a summing all-reduce,
+// `all_reduce_sum(x) + all_reduce_sum(y)` will be transformed to
+// `all_reduce_sum(x + y)`.
+//
+// Another example with `min`.
+// `min(all_reduce_min(x), all_reduce_min(y))` will be transformed to
+// `all_reduce_min(min(x, y))`.
+//
+// Works only with algebraic ops that have all their operands relevant
+// to the all-reduce endomorphism.
+// Will not work with some op `f(x, y, z)` where only `x` and `y` form
+// the algebraic structure.
+template <typename AlgebraicOp>
+void populateAllReduceEndomorphismSimplificationPatterns(
+    RewritePatternSet &patterns, Partial reduction) {
+  auto getEndomorphismOpOperand = [](Operation *op) {
+    auto allReduceOp = llvm::cast<AllReduceOp>(op);
+    return &allReduceOp.getInputMutable();
+  };
+  auto getEndomorphismOpResult = [](Operation *op) {
+    auto allReduceOp = llvm::cast<AllReduceOp>(op);
+    return allReduceOp->getResult(0);
+  };
+  auto getAlgebraicOpOperands = [](Operation *op,
+                                   SmallVector<OpOperand *> &operands) {
+    auto algebraicOp = llvm::cast<AlgebraicOp>(op);
+    std::transform(algebraicOp->getOpOperands().begin(),
+                   algebraicOp->getOpOperands().end(),
+                   std::back_inserter(operands),
+                   [](OpOperand &operand) { return &operand; });
+  };
+  auto getAlgebraicOpResult = [](Operation *op) {
+    auto algebraicOp = llvm::cast<AlgebraicOp>(op);
+    return algebraicOp->getResult(0);
+  };
+  auto isEndomorphismOp = [reduction](Operation *op,
+                                      std::optional<Operation *> referenceOp) {
+    auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
+    if (!allReduceOp ||
+        allReduceOp.getInput().getType().getElementType() !=
+            allReduceOp.getResult().getType().getElementType() ||
+        allReduceOp.getReduction() != reduction) {
+      return false;
+    }
+
+    // Don't simplify if the all-reduce is used other than by the
+    // algebraic op.
+    // TODO: maybe handle this by an additional pass that later reverses the
+    // simplification if there are other uses left after other optimizations
+    // have been done.
+    if (!allReduceOp->hasOneUse()) {
+      return false;
+    }
+
+    if (!referenceOp)
+      return true;
+
+    // The reference op is expected to be an AllReduceOp; cast<> asserts that.
+    auto refAllReduceOp = llvm::cast<AllReduceOp>(referenceOp.value());
+    return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
+           allReduceOp.getInput().getType().getElementType() ==
+               refAllReduceOp.getInput().getType().getElementType();
+  };
+  auto isAlgebraicOp = [](Operation *op) {
+    return llvm::isa<AlgebraicOp>(op);
+  };
+
+  using ConcreteEndomorphismSimplification = EndomorphismSimplification<
+      std::decay_t<decltype(getEndomorphismOpOperand)>,
+      std::decay_t<decltype(getEndomorphismOpResult)>,
+      std::decay_t<decltype(getAlgebraicOpOperands)>,
+      std::decay_t<decltype(getAlgebraicOpResult)>,
+      std::decay_t<decltype(isEndomorphismOp)>,
+      std::decay_t<decltype(isAlgebraicOp)>>;
+  patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
+      std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
+      std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
+      std::move(isEndomorphismOp), std::move(isAlgebraicOp),
+      AlgebraicOp::getOperationName(), 1, patterns.getContext()));
+}
+
+void populateSimplificationPatterns(RewritePatternSet &patterns);
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H

diff  --git a/mlir/include/mlir/Transforms/EndomorphismSimplification.h b/mlir/include/mlir/Transforms/EndomorphismSimplification.h
new file mode 100644
index 0000000000000..1aa5da2346f77
--- /dev/null
+++ b/mlir/include/mlir/Transforms/EndomorphismSimplification.h
@@ -0,0 +1,93 @@
+//===- EndomorphismSimplification.h -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
+#define MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
+
+#include "mlir/Transforms/HomomorphismSimplification.h"
+
+namespace mlir {
+
+namespace detail {
+struct CreateAlgebraicOpForEndomorphismSimplification {
+  Operation *operator()(Operation *op, IRMapping &operandsRemapping,
+                        PatternRewriter &rewriter) const {
+    return rewriter.clone(*op, operandsRemapping);
+  }
+};
+} // namespace detail
+
+// If `f` is an endomorphism with respect to the algebraic structure induced by
+// function `g`, transforms `g(f(x1), f(x2) ..., f(xn))` into
+// `f(g(x1, x2, ..., xn))`.
+// `g` is the algebraic operation and `f` is the endomorphism.
+//
+// Functors:
+// ---------
+// `GetEndomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
+// Returns the operand relevant to the endomorphism.
+// There may be other operands that are not relevant.
+//
+// `GetEndomorphismOpResultFn`: `(Operation*) -> OpResult`
+// Returns the result relevant to the endomorphism.
+//
+// `GetAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) -> void`
+// Populates into the vector the operands relevant to the endomorphism.
+//
+// `GetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
+//  Return the result relevant to the endomorphism.
+//
+// `IsEndomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
+// Check if the operation is an endomorphism of the required type.
+// Additionally if the optional is present checks if the operations are
+// compatible endomorphisms.
+//
+// `IsAlgebraicOpFn`: `(Operation*) -> bool`
+// Check if the operation is an operation of the algebraic structure.
+template <typename GetEndomorphismOpOperandFn,
+          typename GetEndomorphismOpResultFn, typename GetAlgebraicOpOperandsFn,
+          typename GetAlgebraicOpResultFn, typename IsEndomorphismOpFn,
+          typename IsAlgebraicOpFn>
+struct EndomorphismSimplification
+    : HomomorphismSimplification<
+          GetEndomorphismOpOperandFn, GetEndomorphismOpResultFn,
+          GetAlgebraicOpOperandsFn, GetAlgebraicOpResultFn,
+          GetAlgebraicOpResultFn, IsEndomorphismOpFn, IsAlgebraicOpFn,
+          detail::CreateAlgebraicOpForEndomorphismSimplification> {
+  template <typename GetEndomorphismOpOperandFnArg,
+            typename GetEndomorphismOpResultFnArg,
+            typename GetAlgebraicOpOperandsFnArg,
+            typename GetAlgebraicOpResultFnArg, typename IsEndomorphismOpFnArg,
+            typename IsAlgebraicOpFnArg, typename... RewritePatternArgs>
+  EndomorphismSimplification(
+      GetEndomorphismOpOperandFnArg &&getEndomorphismOpOperand,
+      GetEndomorphismOpResultFnArg &&getEndomorphismOpResult,
+      GetAlgebraicOpOperandsFnArg &&getAlgebraicOpOperands,
+      GetAlgebraicOpResultFnArg &&getAlgebraicOpResult,
+      IsEndomorphismOpFnArg &&isEndomorphismOp,
+      IsAlgebraicOpFnArg &&isAlgebraicOp, RewritePatternArgs &&...args)
+      : HomomorphismSimplification<
+            GetEndomorphismOpOperandFn, GetEndomorphismOpResultFn,
+            GetAlgebraicOpOperandsFn, GetAlgebraicOpResultFn,
+            GetAlgebraicOpResultFn, IsEndomorphismOpFn, IsAlgebraicOpFn,
+            detail::CreateAlgebraicOpForEndomorphismSimplification>(
+            std::forward<GetEndomorphismOpOperandFnArg>(
+                getEndomorphismOpOperand),
+            std::forward<GetEndomorphismOpResultFnArg>(getEndomorphismOpResult),
+            std::forward<GetAlgebraicOpOperandsFnArg>(getAlgebraicOpOperands),
+            getAlgebraicOpResult, // Copied here; forwarded (maybe moved) below.
+            std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult),
+            std::forward<IsEndomorphismOpFnArg>(isEndomorphismOp),
+            std::forward<IsAlgebraicOpFnArg>(isAlgebraicOp),
+            detail::CreateAlgebraicOpForEndomorphismSimplification(),
+            std::forward<RewritePatternArgs>(args)...) {}
+};
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_

diff  --git a/mlir/include/mlir/Transforms/HomomorphismSimplification.h b/mlir/include/mlir/Transforms/HomomorphismSimplification.h
new file mode 100644
index 0000000000000..d2732602a0d8e
--- /dev/null
+++ b/mlir/include/mlir/Transforms/HomomorphismSimplification.h
@@ -0,0 +1,188 @@
+//===- HomomorphismSimplification.h -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
+#define MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
+
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include <iterator>
+#include <optional>
+#include <type_traits>
+#include <utility>
+
+namespace mlir {
+
+// If `h` is an homomorphism with respect to the source algebraic structure
+// induced by function `s` and the target algebraic structure induced by
+// function `t`, transforms `s(h(x1), h(x2) ..., h(xn))` into
+// `h(t(x1, x2, ..., xn))`.
+//
+// Functors:
+// ---------
+// `GetHomomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
+// Returns the operand relevant to the homomorphism.
+// There may be other operands that are not relevant.
+//
+// `GetHomomorphismOpResultFn`: `(Operation*) -> OpResult`
+// Returns the result relevant to the homomorphism.
+//
+// `GetSourceAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) ->
+// void` Populates into the vector the operands relevant to the homomorphism.
+//
+// `GetSourceAlgebraicOpResultFn`: `(Operation*) -> OpResult`
+//  Return the result of the source algebraic operation relevant to the
+//  homomorphism.
+//
+// `GetTargetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
+//  Return the result of the target algebraic operation relevant to the
+//  homomorphism.
+//
+// `IsHomomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
+// Check if the operation is an homomorphism of the required type.
+// Additionally if the optional is present checks if the operations are
+// compatible homomorphisms.
+//
+// `IsSourceAlgebraicOpFn`: `(Operation*) -> bool`
+// Check if the operation is an operation of the algebraic structure.
+//
+// `CreateTargetAlgebraicOpFn`: `(Operation*, IRMapping& operandsRemapping,
+// PatternRewriter &rewriter) -> Operation*`
+template <typename GetHomomorphismOpOperandFn,
+          typename GetHomomorphismOpResultFn,
+          typename GetSourceAlgebraicOpOperandsFn,
+          typename GetSourceAlgebraicOpResultFn,
+          typename GetTargetAlgebraicOpResultFn, typename IsHomomorphismOpFn,
+          typename IsSourceAlgebraicOpFn, typename CreateTargetAlgebraicOpFn>
+struct HomomorphismSimplification : public RewritePattern {
+  template <typename GetHomomorphismOpOperandFnArg,
+            typename GetHomomorphismOpResultFnArg,
+            typename GetSourceAlgebraicOpOperandsFnArg,
+            typename GetSourceAlgebraicOpResultFnArg,
+            typename GetTargetAlgebraicOpResultFnArg,
+            typename IsHomomorphismOpFnArg, typename IsSourceAlgebraicOpFnArg,
+            typename CreateTargetAlgebraicOpFnArg,
+            typename... RewritePatternArgs>
+  HomomorphismSimplification(
+      GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand,
+      GetHomomorphismOpResultFnArg &&getHomomorphismOpResult,
+      GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands,
+      GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult,
+      GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult,
+      IsHomomorphismOpFnArg &&isHomomorphismOp,
+      IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp,
+      CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn,
+      RewritePatternArgs &&...args)
+      : RewritePattern(std::forward<RewritePatternArgs>(args)...),
+        getHomomorphismOpOperand(std::forward<GetHomomorphismOpOperandFnArg>(
+            getHomomorphismOpOperand)),
+        getHomomorphismOpResult(std::forward<GetHomomorphismOpResultFnArg>(
+            getHomomorphismOpResult)),
+        getSourceAlgebraicOpOperands(
+            std::forward<GetSourceAlgebraicOpOperandsFnArg>(
+                getSourceAlgebraicOpOperands)),
+        getSourceAlgebraicOpResult(
+            std::forward<GetSourceAlgebraicOpResultFnArg>(
+                getSourceAlgebraicOpResult)),
+        getTargetAlgebraicOpResult(
+            std::forward<GetTargetAlgebraicOpResultFnArg>(
+                getTargetAlgebraicOpResult)),
+        isHomomorphismOp(std::forward<IsHomomorphismOpFnArg>(isHomomorphismOp)),
+        isSourceAlgebraicOp(
+            std::forward<IsSourceAlgebraicOpFnArg>(isSourceAlgebraicOp)),
+        createTargetAlgebraicOpFn(std::forward<CreateTargetAlgebraicOpFnArg>(
+            createTargetAlgebraicOpFn)) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<OpOperand *> algebraicOpOperands;
+    if (failed(matchOp(op, algebraicOpOperands))) {
+      return failure();
+    }
+    return rewriteOp(op, algebraicOpOperands, rewriter);
+  }
+
+private:
+  LogicalResult
+  matchOp(Operation *sourceAlgebraicOp,
+          SmallVector<OpOperand *> &sourceAlgebraicOpOperands) const {
+    if (!isSourceAlgebraicOp(sourceAlgebraicOp)) {
+      return failure();
+    }
+    sourceAlgebraicOpOperands.clear();
+    getSourceAlgebraicOpOperands(sourceAlgebraicOp, sourceAlgebraicOpOperands);
+    if (sourceAlgebraicOpOperands.empty()) {
+      return failure();
+    }
+
+    Operation *firstHomomorphismOp =
+        sourceAlgebraicOpOperands.front()->get().getDefiningOp();
+    if (!firstHomomorphismOp ||
+        !isHomomorphismOp(firstHomomorphismOp, std::nullopt)) {
+      return failure();
+    }
+
+    // Every operand, not only the first, must be exactly the result that
+    // `getHomomorphismOpResult` designates. Otherwise `rewriteOp` would
+    // substitute the homomorphism's input for an unrelated result of a
+    // multi-result homomorphism op.
+    for (auto operand : sourceAlgebraicOpOperands) {
+      Operation *homomorphismOp = operand->get().getDefiningOp();
+      if (!homomorphismOp ||
+          !isHomomorphismOp(homomorphismOp, firstHomomorphismOp) ||
+          getHomomorphismOpResult(homomorphismOp) != operand->get()) {
+        return failure();
+      }
+    }
+    return success();
+  }
+
+  LogicalResult
+  rewriteOp(Operation *sourceAlgebraicOp,
+            const SmallVector<OpOperand *> &sourceAlgebraicOpOperands,
+            PatternRewriter &rewriter) const {
+    IRMapping irMapping;
+    for (auto operand : sourceAlgebraicOpOperands) {
+      Operation *homomorphismOp = operand->get().getDefiningOp();
+      irMapping.map(operand->get(),
+                    getHomomorphismOpOperand(homomorphismOp)->get());
+    }
+    Operation *targetAlgebraicOp =
+        createTargetAlgebraicOpFn(sourceAlgebraicOp, irMapping, rewriter);
+
+    irMapping.clear();
+    assert(!sourceAlgebraicOpOperands.empty());
+    Operation *firstHomomorphismOp =
+        sourceAlgebraicOpOperands[0]->get().getDefiningOp();
+    irMapping.map(getHomomorphismOpOperand(firstHomomorphismOp)->get(),
+                  getTargetAlgebraicOpResult(targetAlgebraicOp));
+    Operation *newHomomorphismOp =
+        rewriter.clone(*firstHomomorphismOp, irMapping);
+    rewriter.replaceAllUsesWith(getSourceAlgebraicOpResult(sourceAlgebraicOp),
+                                getHomomorphismOpResult(newHomomorphismOp));
+    return success();
+  }
+
+  GetHomomorphismOpOperandFn getHomomorphismOpOperand;
+  GetHomomorphismOpResultFn getHomomorphismOpResult;
+  GetSourceAlgebraicOpOperandsFn getSourceAlgebraicOpOperands;
+  GetSourceAlgebraicOpResultFn getSourceAlgebraicOpResult;
+  GetTargetAlgebraicOpResultFn getTargetAlgebraicOpResult;
+  IsHomomorphismOpFn isHomomorphismOp;
+  IsSourceAlgebraicOpFn isSourceAlgebraicOp;
+  CreateTargetAlgebraicOpFn createTargetAlgebraicOpFn;
+};
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_

diff  --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index bcf45c4ea2760..044b8672c8c60 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRMeshTransforms
+  Simplifications.cpp
   ShardingPropagation.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -9,6 +10,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
   MLIRShardingInterface
 
   LINK_LIBS PUBLIC
+  MLIRArithDialect
   MLIRFuncDialect
   MLIRIR
   MLIRMeshDialect

diff  --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
new file mode 100644
index 0000000000000..643bd7b8e77c9
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -0,0 +1,39 @@
+//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+
+namespace mlir {
+namespace mesh {
+
+void populateSimplificationPatterns(RewritePatternSet &patterns) {
+  populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
+      patterns, Partial::Sum);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
+      patterns, Partial::Sum);
+
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
+      patterns, Partial::Min);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
+      patterns, Partial::Min);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
+      patterns, Partial::Min);
+
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
+      patterns, Partial::Max);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
+      patterns, Partial::Max);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
+      patterns, Partial::Max);
+
+  // TODO: add simplifications for all-gather and other collectives.
+}
+
+} // namespace mesh
+} // namespace mlir

diff  --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir
new file mode 100644
index 0000000000000..e716940f2301e
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/simplifications.mlir
@@ -0,0 +1,167 @@
+// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 4x2)
+mesh.cluster @mesh1(rank = 1, dim_sizes = 4)
+
+// Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
+// `all_reduce(x + y)`.
+// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism
+func.func @all_reduce_arith_addf_endomorphism(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result
+func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_result(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
+  // CHECK: return %[[ALL_REDUCE_RES]], %[[ALL_REDUCE_RES]]
+  return %2, %2 : tensor<5xf32>, tensor<5xf32>
+}
+
+// Do not simplify if there is another use of one of the all-reduces.
+// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result
+func.func @all_reduce_arith_addf_endomorphism_multiple_uses_of_all_reduce_result(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
+  // CHECK: %[[ALL_REDUCE_0_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]]
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE_1_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]]
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE_0_RES]], %[[ALL_REDUCE_1_RES]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: return %[[ALL_REDUCE_0_RES]], %[[ADD_RES]]
+  return %0, %2 : tensor<5xf32>, tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh
+func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh1
+  %1 = mesh.all_reduce %arg1 on @mesh1 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes
+func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1]
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [1]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind
+func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = <max>
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <max>
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types
+func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf64> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf64>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf64>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf64>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf64>
+}
+
+// Checks that `min(all_reduce(x), all_reduce(y))` gets transformed to
+// `all_reduce(min(x, y))`.
+// CHECK-LABEL: func.func @all_reduce_arith_minimumf_endomorphism
+func.func @all_reduce_arith_minimumf_endomorphism(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+    : tensor<5xf32> -> tensor<5xf32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]]
+  %2 = arith.minimumf %0, %1 : tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_minsi_endomorphism
+func.func @all_reduce_arith_minsi_endomorphism(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xi32>
+    %arg0: tensor<5xi32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32>
+    %arg1: tensor<5xi32>) -> tensor<5xi32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+    : tensor<5xi32> -> tensor<5xi32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+    : tensor<5xi32> -> tensor<5xi32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]]
+  %2 = arith.minsi %0, %1 : tensor<5xi32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : tensor<5xi32>
+}

diff  --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 48bde69e01700..30a17c201ff76 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -9,6 +9,7 @@ add_subdirectory(Linalg)
 add_subdirectory(LLVM)
 add_subdirectory(Math)
 add_subdirectory(MemRef)
+add_subdirectory(Mesh)
 add_subdirectory(NVGPU)
 add_subdirectory(SCF)
 add_subdirectory(Shape)

diff  --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
new file mode 100644
index 0000000000000..16b50bb878a07
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -0,0 +1,13 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRMeshTestSimplifications
+  TestSimplifications.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRMeshDialect
+  MLIRMeshTransforms
+  MLIRPass
+  MLIRRewrite
+  MLIRTransformUtils
+  )

diff  --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
new file mode 100644
index 0000000000000..93b1da52d46b4
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
@@ -0,0 +1,43 @@
+//===- TestSimplifications.cpp - Test simplification ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+struct TestMeshSimplificationsPass
+    : public PassWrapper<TestMeshSimplificationsPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshSimplificationsPass)
+
+  void runOnOperation() override;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, mesh::MeshDialect>();
+  }
+  StringRef getArgument() const final { return "test-mesh-simplifications"; }
+  StringRef getDescription() const final { return "Test mesh simplifications"; }
+};
+} // namespace
+
+void TestMeshSimplificationsPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  mesh::populateSimplificationPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+namespace mlir {
+namespace test {
+void registerTestMeshSimplificationsPass() {
+  PassRegistration<TestMeshSimplificationsPass>();
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 88a0562cb6e72..bc8eed1821552 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -26,6 +26,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRLoopLikeInterfaceTestPasses
     MLIRMathTestPasses
     MLIRMemRefTestPasses
+    MLIRMeshTestSimplifications
     MLIRNVGPUTestPasses
     MLIRSCFTestPasses
     MLIRShapeTestPasses

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 3e3223b485056..c7cf1e55a556f 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -118,6 +118,7 @@ void registerTestMathAlgebraicSimplificationPass();
 void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
+void registerTestMeshSimplificationsPass();
 void registerTestNextAccessPass();
 void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
@@ -238,6 +239,7 @@ void registerTestPasses() {
   mlir::test::registerTestMathPolynomialApproximationPass();
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
+  mlir::test::registerTestMeshSimplificationsPass();
   mlir::test::registerTestNextAccessPass();
   mlir::test::registerTestOneToNTypeConversionPass();
   mlir::test::registerTestOpaqueLoc();


        


More information about the Mlir-commits mailing list