[Mlir-commits] [mlir] [mlir][mesh] Add endomorphism simplification for all-reduce (PR #73150)

Boian Petkantchin llvmlistbot at llvm.org
Fri Dec 1 16:20:35 PST 2023


https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/73150

>From 8df8ae3094a61bc07dddb786c2ebb856690a0560 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Mon, 20 Nov 2023 17:42:37 -0800
Subject: [PATCH 1/2] [mlir][mesh] Add endomorphism simplification for
 all-reduce

Performs transformations such as
all_reduce(x) + all_reduce(y) -> all_reduce(x + y)

max(all_reduce(x), all_reduce(y)) -> all_reduce(max(x, y))
when the all_reduce reduction kind is max.
---
 .../mlir/Dialect/Mesh/Transforms/Passes.h     |   2 +
 .../Dialect/Mesh/Transforms/Simplifications.h |  89 +++++++++
 .../Transforms/EndomorphismSimplification.h   |  93 +++++++++
 .../Transforms/HomomorphismSimplification.h   | 189 ++++++++++++++++++
 .../Dialect/Mesh/Transforms/CMakeLists.txt    |   2 +
 .../Mesh/Transforms/Simplifications.cpp       |  37 ++++
 mlir/test/Dialect/Mesh/simplifications.mlir   | 131 ++++++++++++
 mlir/test/lib/Dialect/CMakeLists.txt          |   1 +
 mlir/test/lib/Dialect/Mesh/CMakeLists.txt     |  11 +
 .../lib/Dialect/Mesh/TestSimplifications.cpp  |  43 ++++
 mlir/tools/mlir-opt/CMakeLists.txt            |   1 +
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 12 files changed, 601 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
 create mode 100644 mlir/include/mlir/Transforms/EndomorphismSimplification.h
 create mode 100644 mlir/include/mlir/Transforms/HomomorphismSimplification.h
 create mode 100644 mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
 create mode 100644 mlir/test/Dialect/Mesh/simplifications.mlir
 create mode 100644 mlir/test/lib/Dialect/Mesh/CMakeLists.txt
 create mode 100644 mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp

diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index 83399d10beaae48..9b788d3f304c2c8 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -17,6 +17,8 @@ namespace func {
 class FuncOp;
 }
 
+class RewritePatternSet;
+
 namespace mesh {
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
new file mode 100644
index 000000000000000..1af0f52114f10e9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -0,0 +1,89 @@
+//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/EndomorphismSimplification.h"
+#include "llvm/Support/Casting.h"
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <utility>
+
+namespace mlir {
+namespace mesh {
+
+template <typename AlgebraicOp>
+void populateAllReduceEndomorphismSimplificationPatterns(
+    RewritePatternSet &patterns, Partial reduction) {
+  auto getEndomorphismOpOperand = [](Operation *op) {
+    auto allReduceOp = llvm::cast<AllReduceOp>(op);
+    return &allReduceOp.getInputMutable();
+  };
+  auto getEndomorphismOpResult = [](Operation *op) {
+    auto allReduceOp = llvm::cast<AllReduceOp>(op);
+    return allReduceOp->getResult(0);
+  };
+  auto getAlgebraicOpOperands = [](Operation *op,
+                                   SmallVector<OpOperand *> &operands) {
+    auto algebraicOp = llvm::cast<AlgebraicOp>(op);
+    std::transform(algebraicOp->getOpOperands().begin(),
+                   algebraicOp->getOpOperands().end(),
+                   std::back_inserter(operands),
+                   [](OpOperand &operand) { return &operand; });
+  };
+  auto getAlgebraicOpResult = [](Operation *op) {
+    auto algebraicOp = llvm::cast<AlgebraicOp>(op);
+    return algebraicOp->getResult(0);
+  };
+  auto isEndomorphismOp = [reduction](Operation *op,
+                                      std::optional<Operation *> referenceOp) {
+    auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
+    if (!allReduceOp ||
+        allReduceOp.getInput().getType().getElementType() !=
+            allReduceOp.getResult().getType().getElementType() ||
+        allReduceOp.getReduction() != reduction) {
+      return false;
+    }
+
+    if (!referenceOp) {
+      return true;
+    }
+
+    auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
+    return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
+           allReduceOp.getInput().getType().getElementType() ==
+               refAllReduceOp.getInput().getType().getElementType();
+  };
+  auto isAlgebraicOp = [](Operation *op) {
+    return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
+  };
+
+  using ConcreteEndomorphismSimplification = EndomorphismSimplification<
+      std::decay_t<decltype(getEndomorphismOpOperand)>,
+      std::decay_t<decltype(getEndomorphismOpResult)>,
+      std::decay_t<decltype(getAlgebraicOpOperands)>,
+      std::decay_t<decltype(getAlgebraicOpResult)>,
+      std::decay_t<decltype(isEndomorphismOp)>,
+      std::decay_t<decltype(isAlgebraicOp)>>;
+  patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
+      std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
+      std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
+      std::move(isEndomorphismOp), std::move(isAlgebraicOp),
+      AlgebraicOp::getOperationName(), 1, patterns.getContext()));
+}
+
+void populateSimplificationPatterns(RewritePatternSet &patterns);
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
diff --git a/mlir/include/mlir/Transforms/EndomorphismSimplification.h b/mlir/include/mlir/Transforms/EndomorphismSimplification.h
new file mode 100644
index 000000000000000..1aa5da2346f776b
--- /dev/null
+++ b/mlir/include/mlir/Transforms/EndomorphismSimplification.h
@@ -0,0 +1,93 @@
+//===- EndomorphismSimplification.h -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
+#define MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
+
+#include "mlir/Transforms/HomomorphismSimplification.h"
+
+namespace mlir {
+
+namespace detail {
+struct CreateAlgebraicOpForEndomorphismSimplification {
+  Operation *operator()(Operation *op, IRMapping &operandsRemapping,
+                        PatternRewriter &rewriter) const {
+    return rewriter.clone(*op, operandsRemapping);
+  }
+};
+} // namespace detail
+
+// If `f` is an endomorphism with respect to the algebraic structure induced by
+// function `g`, transforms `g(f(x1), f(x2), ..., f(xn))` into
+// `f(g(x1, x2, ..., xn))`.
+// `g` is the algebraic operation and `f` is the endomorphism.
+//
+// Functors:
+// ---------
+// `GetEndomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
+// Returns the operand relevant to the endomorphism.
+// There may be other operands that are not relevant.
+//
+// `GetEndomorphismOpResultFn`: `(Operation*) -> OpResult`
+// Returns the result relevant to the endomorphism.
+//
+// `GetAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) -> void`
+// Populates the vector with the operands relevant to the endomorphism.
+//
+// `GetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
+// Returns the result of the algebraic operation relevant to the endomorphism.
+//
+// `IsEndomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
+// Checks if the operation is an endomorphism of the required type.
+// Additionally, if the optional is present, checks if the operations are
+// compatible endomorphisms.
+//
+// `IsAlgebraicOpFn`: `(Operation*) -> bool`
+// Checks if the operation is an operation of the algebraic structure.
+template <typename GetEndomorphismOpOperandFn,
+          typename GetEndomorphismOpResultFn, typename GetAlgebraicOpOperandsFn,
+          typename GetAlgebraicOpResultFn, typename IsEndomorphismOpFn,
+          typename IsAlgebraicOpFn>
+struct EndomorphismSimplification
+    : HomomorphismSimplification<
+          GetEndomorphismOpOperandFn, GetEndomorphismOpResultFn,
+          GetAlgebraicOpOperandsFn, GetAlgebraicOpResultFn,
+          GetAlgebraicOpResultFn, IsEndomorphismOpFn, IsAlgebraicOpFn,
+          detail::CreateAlgebraicOpForEndomorphismSimplification> {
+  template <typename GetEndomorphismOpOperandFnArg,
+            typename GetEndomorphismOpResultFnArg,
+            typename GetAlgebraicOpOperandsFnArg,
+            typename GetAlgebraicOpResultFnArg, typename IsEndomorphismOpFnArg,
+            typename IsAlgebraicOpFnArg, typename... RewritePatternArgs>
+  EndomorphismSimplification(
+      GetEndomorphismOpOperandFnArg &&getEndomorphismOpOperand,
+      GetEndomorphismOpResultFnArg &&getEndomorphismOpResult,
+      GetAlgebraicOpOperandsFnArg &&getAlgebraicOpOperands,
+      GetAlgebraicOpResultFnArg &&getAlgebraicOpResult,
+      IsEndomorphismOpFnArg &&isEndomorphismOp,
+      IsAlgebraicOpFnArg &&isAlgebraicOp, RewritePatternArgs &&...args)
+      : HomomorphismSimplification<
+            GetEndomorphismOpOperandFn, GetEndomorphismOpResultFn,
+            GetAlgebraicOpOperandsFn, GetAlgebraicOpResultFn,
+            GetAlgebraicOpResultFn, IsEndomorphismOpFn, IsAlgebraicOpFn,
+            detail::CreateAlgebraicOpForEndomorphismSimplification>(
+            std::forward<GetEndomorphismOpOperandFnArg>(
+                getEndomorphismOpOperand),
+            std::forward<GetEndomorphismOpResultFnArg>(getEndomorphismOpResult),
+            std::forward<GetAlgebraicOpOperandsFnArg>(getAlgebraicOpOperands),
+            std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult),
+            std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult),
+            std::forward<IsEndomorphismOpFnArg>(isEndomorphismOp),
+            std::forward<IsAlgebraicOpFnArg>(isAlgebraicOp),
+            detail::CreateAlgebraicOpForEndomorphismSimplification(),
+            std::forward<RewritePatternArgs>(args)...) {}
+};
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
diff --git a/mlir/include/mlir/Transforms/HomomorphismSimplification.h b/mlir/include/mlir/Transforms/HomomorphismSimplification.h
new file mode 100644
index 000000000000000..a340f0424ad8331
--- /dev/null
+++ b/mlir/include/mlir/Transforms/HomomorphismSimplification.h
@@ -0,0 +1,189 @@
+//===- HomomorphismSimplification.h -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
+#define MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
+
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include <iterator>
+#include <optional>
+#include <type_traits>
+#include <utility>
+
+namespace mlir {
+
+// If `h` is a homomorphism with respect to the source algebraic structure
+// induced by function `s` and the target algebraic structure induced by
+// function `t`, transforms `s(h(x1), h(x2), ..., h(xn))` into
+// `h(t(x1, x2, ..., xn))`.
+//
+// Functors:
+// ---------
+// `GetHomomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
+// Returns the operand relevant to the homomorphism.
+// There may be other operands that are not relevant.
+//
+// `GetHomomorphismOpResultFn`: `(Operation*) -> OpResult`
+// Returns the result relevant to the homomorphism.
+//
+// `GetSourceAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) ->
+// void`. Populates the vector with the operands relevant to the homomorphism.
+//
+// `GetSourceAlgebraicOpResultFn`: `(Operation*) -> OpResult`
+// Returns the result of the source algebraic operation relevant to the
+// homomorphism.
+//
+// `GetTargetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
+// Returns the result of the target algebraic operation relevant to the
+// homomorphism.
+//
+// `IsHomomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
+// Checks if the operation is a homomorphism of the required type.
+// Additionally, if the optional is present, checks if the operations are
+// compatible homomorphisms.
+//
+// `IsSourceAlgebraicOpFn`: `(Operation*) -> bool`
+// Checks if the operation is an operation of the source algebraic structure.
+//
+// `CreateTargetAlgebraicOpFn`: `(Operation*, IRMapping& operandsRemapping,
+// PatternRewriter &rewriter) -> Operation*`
+template <typename GetHomomorphismOpOperandFn,
+          typename GetHomomorphismOpResultFn,
+          typename GetSourceAlgebraicOpOperandsFn,
+          typename GetSourceAlgebraicOpResultFn,
+          typename GetTargetAlgebraicOpResultFn, typename IsHomomorphismOpFn,
+          typename IsSourceAlgebraicOpFn, typename CreateTargetAlgebraicOpFn>
+struct HomomorphismSimplification : public RewritePattern {
+  template <typename GetHomomorphismOpOperandFnArg,
+            typename GetHomomorphismOpResultFnArg,
+            typename GetSourceAlgebraicOpOperandsFnArg,
+            typename GetSourceAlgebraicOpResultFnArg,
+            typename GetTargetAlgebraicOpResultFnArg,
+            typename IsHomomorphismOpFnArg, typename IsSourceAlgebraicOpFnArg,
+            typename CreateTargetAlgebraicOpFnArg,
+            typename... RewritePatternArgs>
+  HomomorphismSimplification(
+      GetHomomorphismOpOperandFnArg &&getHomomorphismOpOperand,
+      GetHomomorphismOpResultFnArg &&getHomomorphismOpResult,
+      GetSourceAlgebraicOpOperandsFnArg &&getSourceAlgebraicOpOperands,
+      GetSourceAlgebraicOpResultFnArg &&getSourceAlgebraicOpResult,
+      GetTargetAlgebraicOpResultFnArg &&getTargetAlgebraicOpResult,
+      IsHomomorphismOpFnArg &&isHomomorphismOp,
+      IsSourceAlgebraicOpFnArg &&isSourceAlgebraicOp,
+      CreateTargetAlgebraicOpFnArg &&createTargetAlgebraicOpFn,
+      RewritePatternArgs &&...args)
+      : RewritePattern(std::forward<RewritePatternArgs>(args)...),
+        getHomomorphismOpOperand(std::forward<GetHomomorphismOpOperandFnArg>(
+            getHomomorphismOpOperand)),
+        getHomomorphismOpResult(std::forward<GetHomomorphismOpResultFnArg>(
+            getHomomorphismOpResult)),
+        getSourceAlgebraicOpOperands(
+            std::forward<GetSourceAlgebraicOpOperandsFnArg>(
+                getSourceAlgebraicOpOperands)),
+        getSourceAlgebraicOpResult(
+            std::forward<GetSourceAlgebraicOpResultFnArg>(
+                getSourceAlgebraicOpResult)),
+        getTargetAlgebraicOpResult(
+            std::forward<GetTargetAlgebraicOpResultFnArg>(
+                getTargetAlgebraicOpResult)),
+        isHomomorphismOp(std::forward<IsHomomorphismOpFnArg>(isHomomorphismOp)),
+        isSourceAlgebraicOp(
+            std::forward<IsSourceAlgebraicOpFnArg>(isSourceAlgebraicOp)),
+        createTargetAlgebraicOpFn(std::forward<CreateTargetAlgebraicOpFnArg>(
+            createTargetAlgebraicOpFn)) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(matchOp(op, algebraicOpOperands))) {
+      return failure();
+    }
+    return rewriteOp(op, algebraicOpOperands, rewriter);
+  }
+
+private:
+  LogicalResult
+  matchOp(Operation *sourceAlgebraicOp,
+          SmallVector<OpOperand *> &sourceAlgebraicOpOperands) const {
+    if (!isSourceAlgebraicOp(sourceAlgebraicOp)) {
+      return failure();
+    }
+    sourceAlgebraicOpOperands.clear();
+    getSourceAlgebraicOpOperands(sourceAlgebraicOp, sourceAlgebraicOpOperands);
+    if (sourceAlgebraicOpOperands.empty()) {
+      return failure();
+    }
+
+    Operation *firstHomomorphismOp =
+        sourceAlgebraicOpOperands.front()->get().getDefiningOp();
+    if (!firstHomomorphismOp ||
+        !isHomomorphismOp(firstHomomorphismOp, std::nullopt)) {
+      return failure();
+    }
+    OpResult firstHomomorphismOpResult =
+        getHomomorphismOpResult(firstHomomorphismOp);
+    if (firstHomomorphismOpResult != sourceAlgebraicOpOperands.front()->get()) {
+      return failure();
+    }
+
+    for (auto operand : sourceAlgebraicOpOperands) {
+      Operation *homomorphismOp = operand->get().getDefiningOp();
+      if (!homomorphismOp ||
+          !isHomomorphismOp(homomorphismOp, firstHomomorphismOp)) {
+        return failure();
+      }
+    }
+    return success();
+  }
+
+  LogicalResult
+  rewriteOp(Operation *sourceAlgebraicOp,
+            const SmallVector<OpOperand *> &sourceAlgebraicOpOperands,
+            PatternRewriter &rewriter) const {
+    irMapping.clear();
+    for (auto operand : sourceAlgebraicOpOperands) {
+      Operation *homomorphismOp = operand->get().getDefiningOp();
+      irMapping.map(operand->get(),
+                    getHomomorphismOpOperand(homomorphismOp)->get());
+    }
+    Operation *targetAlgebraicOp =
+        createTargetAlgebraicOpFn(sourceAlgebraicOp, irMapping, rewriter);
+
+    irMapping.clear();
+    assert(!sourceAlgebraicOpOperands.empty());
+    Operation *firstHomomorphismOp =
+        sourceAlgebraicOpOperands[0]->get().getDefiningOp();
+    irMapping.map(getHomomorphismOpOperand(firstHomomorphismOp)->get(),
+                  getTargetAlgebraicOpResult(targetAlgebraicOp));
+    Operation *newHomomorphismOp =
+        rewriter.clone(*firstHomomorphismOp, irMapping);
+    rewriter.replaceAllUsesWith(getSourceAlgebraicOpResult(sourceAlgebraicOp),
+                                getHomomorphismOpResult(newHomomorphismOp));
+    return success();
+  }
+
+  GetHomomorphismOpOperandFn getHomomorphismOpOperand;
+  GetHomomorphismOpResultFn getHomomorphismOpResult;
+  GetSourceAlgebraicOpOperandsFn getSourceAlgebraicOpOperands;
+  GetSourceAlgebraicOpResultFn getSourceAlgebraicOpResult;
+  GetTargetAlgebraicOpResultFn getTargetAlgebraicOpResult;
+  IsHomomorphismOpFn isHomomorphismOp;
+  IsSourceAlgebraicOpFn isSourceAlgebraicOp;
+  CreateTargetAlgebraicOpFn createTargetAlgebraicOpFn;
+  mutable SmallVector<OpOperand *> algebraicOpOperands;
+  mutable IRMapping irMapping;
+};
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_SIMPLIFY_HOMOMORPHISM_H_
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index bcf45c4ea276080..044b8672c8c60cf 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRMeshTransforms
+  Simplifications.cpp
   ShardingPropagation.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -9,6 +10,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
   MLIRShardingInterface
 
   LINK_LIBS PUBLIC
+  MLIRArithDialect
   MLIRFuncDialect
   MLIRIR
   MLIRMeshDialect
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
new file mode 100644
index 000000000000000..1d241fe03a127b8
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -0,0 +1,37 @@
+//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+
+namespace mlir {
+namespace mesh {
+
+void populateSimplificationPatterns(RewritePatternSet &patterns) {
+  populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
+      patterns, Partial::Sum);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
+      patterns, Partial::Sum);
+
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
+      patterns, Partial::Min);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
+      patterns, Partial::Min);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
+      patterns, Partial::Min);
+
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
+      patterns, Partial::Max);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
+      patterns, Partial::Max);
+  populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
+      patterns, Partial::Max);
+}
+
+} // namespace mesh
+} // namespace mlir
diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir
new file mode 100644
index 000000000000000..2b305df6e0a97f1
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/simplifications.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [4, 2])
+mesh.cluster @mesh1(rank = 1, dim_sizes = [4])
+
+// Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
+// `all_reduce(x + y)`.
+// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism
+func.func @all_reduce_arith_addf_endomorphism(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]]
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh
+func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh1
+  %1 = mesh.all_reduce %arg1 on @mesh1 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes
+func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1]
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [1]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind
+func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = <max>
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <max>
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types
+func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf64> {
+  // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf64>
+  // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+    : tensor<5xf32> -> tensor<5xf64>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+  %2 = arith.addf %0, %1 : tensor<5xf64>
+  // CHECK: return %[[ADD_RES]]
+  return %2 : tensor<5xf64>
+}
+
+// Checks that `min(all_reduce(x), all_reduce(y))` gets transformed to
+// `all_reduce(min(x, y))`.
+// CHECK-LABEL: func.func @all_reduce_arith_minimumf_endomorphism
+func.func @all_reduce_arith_minimumf_endomorphism(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg0: tensor<5xf32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+    %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+    : tensor<5xf32> -> tensor<5xf32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+    : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]]
+  %2 = arith.minimumf %0, %1 : tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @all_reduce_arith_minsi_endomorphism
+func.func @all_reduce_arith_minsi_endomorphism(
+    // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xi32>
+    %arg0: tensor<5xi32>,
+    // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32>
+    %arg1: tensor<5xi32>) -> tensor<5xi32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+    : tensor<5xi32> -> tensor<5xi32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+    : tensor<5xi32> -> tensor<5xi32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]]
+  %2 = arith.minsi %0, %1 : tensor<5xi32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : tensor<5xi32>
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 48bde69e0170041..30a17c201ff7635 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -9,6 +9,7 @@ add_subdirectory(Linalg)
 add_subdirectory(LLVM)
 add_subdirectory(Math)
 add_subdirectory(MemRef)
+add_subdirectory(Mesh)
 add_subdirectory(NVGPU)
 add_subdirectory(SCF)
 add_subdirectory(Shape)
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
new file mode 100644
index 000000000000000..102bd59da61b992
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -0,0 +1,11 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRMeshTestSimplifications
+  TestSimplifications.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRMeshDialect
+  MLIRMeshTransforms
+  MLIRPass
+  )
diff --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
new file mode 100644
index 000000000000000..93b1da52d46b4ef
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
@@ -0,0 +1,43 @@
+//===- TestSimplifications.cpp - Test simplifications ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+struct TestMeshSimplificationsPass
+    : public PassWrapper<TestMeshSimplificationsPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshSimplificationsPass)
+
+  void runOnOperation() override;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, mesh::MeshDialect>();
+  }
+  StringRef getArgument() const final { return "test-mesh-simplifications"; }
+  StringRef getDescription() const final { return "Test mesh simplifications"; }
+};
+} // namespace
+
+void TestMeshSimplificationsPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  mesh::populateSimplificationPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+}
+
+namespace mlir {
+namespace test {
+void registerTestMeshSimplificationsPass() {
+  PassRegistration<TestMeshSimplificationsPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 88a0562cb6e7207..bc8eed18215525e 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -26,6 +26,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRLoopLikeInterfaceTestPasses
     MLIRMathTestPasses
     MLIRMemRefTestPasses
+    MLIRMeshTestSimplifications
     MLIRNVGPUTestPasses
     MLIRSCFTestPasses
     MLIRShapeTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 3e3223b48505601..c7cf1e55a556f4a 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -118,6 +118,7 @@ void registerTestMathAlgebraicSimplificationPass();
 void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
+void registerTestMeshSimplificationsPass();
 void registerTestNextAccessPass();
 void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
@@ -238,6 +239,7 @@ void registerTestPasses() {
   mlir::test::registerTestMathPolynomialApproximationPass();
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
+  mlir::test::registerTestMeshSimplificationsPass();
   mlir::test::registerTestNextAccessPass();
   mlir::test::registerTestOneToNTypeConversionPass();
   mlir::test::registerTestOpaqueLoc();

>From 2fcb5914e3ed1852ad5c928b3b1e43a0269289da Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian.petkantchin at amd.com>
Date: Fri, 1 Dec 2023 16:19:58 -0800
Subject: [PATCH 2/2] Address some PR comments

---
 mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h   | 2 --
 mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp | 2 ++
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index 9b788d3f304c2c8..83399d10beaae48 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -17,8 +17,6 @@ namespace func {
 class FuncOp;
 }
 
-class RewritePatternSet;
-
 namespace mesh {
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
index 1d241fe03a127b8..643bd7b8e77c938 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -31,6 +31,8 @@ void populateSimplificationPatterns(RewritePatternSet &patterns) {
       patterns, Partial::Max);
   populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
       patterns, Partial::Max);
+
+  // TODO: add simplifications for all-gather and other collectives.
 }
 
 } // namespace mesh



More information about the Mlir-commits mailing list