[Mlir-commits] [mlir] [mlir][mesh] Add endomorphism simplification for all-reduce (PR #73150)
Boian Petkantchin
llvmlistbot at llvm.org
Wed Nov 22 14:58:01 PST 2023
https://github.com/sogartar updated https://github.com/llvm/llvm-project/pull/73150
>From 363f6e55f719e5ccc5495ed96b321fe356821ebc Mon Sep 17 00:00:00 2001
From: Boian Petkantchin <boian at nod-labs.com>
Date: Mon, 20 Nov 2023 17:42:37 -0800
Subject: [PATCH] [mlir][mesh] Add endomorphism simplification for all-reduce
Performs transformations such as
all_reduce(x) + all_reduce(y) -> all_reduce(x + y)
when the all_reduce reduction is sum, and
max(all_reduce(x), all_reduce(y)) -> all_reduce(max(x, y))
when the all_reduce reduction is max.
---
.../mlir/Dialect/Mesh/Transforms/Passes.h | 2 +
.../Dialect/Mesh/Transforms/Simplifications.h | 89 ++++++++++
.../Transforms/EndomorphismSimplification.h | 161 ++++++++++++++++++
.../Dialect/Mesh/Transforms/CMakeLists.txt | 2 +
.../Mesh/Transforms/Simplifications.cpp | 37 ++++
mlir/test/Dialect/Mesh/simplifications.mlir | 131 ++++++++++++++
mlir/test/lib/Dialect/CMakeLists.txt | 1 +
mlir/test/lib/Dialect/Mesh/CMakeLists.txt | 11 ++
.../lib/Dialect/Mesh/TestSimplifications.cpp | 43 +++++
mlir/tools/mlir-opt/CMakeLists.txt | 1 +
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
11 files changed, 480 insertions(+)
create mode 100644 mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
create mode 100644 mlir/include/mlir/Transforms/EndomorphismSimplification.h
create mode 100644 mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
create mode 100644 mlir/test/Dialect/Mesh/simplifications.mlir
create mode 100644 mlir/test/lib/Dialect/Mesh/CMakeLists.txt
create mode 100644 mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index 83399d10beaae48..9b788d3f304c2c8 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -17,6 +17,8 @@ namespace func {
class FuncOp;
}
+class RewritePatternSet;
+
namespace mesh {
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
new file mode 100644
index 000000000000000..1af0f52114f10e9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -0,0 +1,89 @@
+//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/EndomorphismSimplification.h"
+#include "llvm/Support/Casting.h"
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <utility>
+
+namespace mlir {
+namespace mesh {
+
+/// Adds to `patterns` a rewrite of
+///   `AlgebraicOp(all_reduce(x1), ..., all_reduce(xn))`
+/// into
+///   `all_reduce(AlgebraicOp(x1, ..., xn))`.
+///
+/// The rewrite only fires when every `mesh.all_reduce` operand uses the
+/// `reduction` kind, does not change the type of its input, and carries
+/// exactly the same attributes as the first one. The caller is responsible
+/// for pairing `AlgebraicOp` with a `reduction` over which it distributes,
+/// e.g. `arith.addf` with `Partial::Sum`.
+template <typename AlgebraicOp>
+void populateAllReduceEndomorphismSimplificationPatterns(
+    RewritePatternSet &patterns, Partial reduction) {
+  auto getEndomorphismOpOperand = [](Operation *op) {
+    auto allReduceOp = llvm::cast<AllReduceOp>(op);
+    return &allReduceOp.getInputMutable();
+  };
+  auto getEndomorphismOpResult = [](Operation *op) {
+    auto allReduceOp = llvm::cast<AllReduceOp>(op);
+    return allReduceOp->getResult(0);
+  };
+  auto getAlgebraicOpOperands = [](Operation *op,
+                                   SmallVector<OpOperand *> &operands) {
+    for (OpOperand &operand : llvm::cast<AlgebraicOp>(op)->getOpOperands())
+      operands.push_back(&operand);
+  };
+  auto getAlgebraicOpResult = [](Operation *op) {
+    return llvm::cast<AlgebraicOp>(op)->getResult(0);
+  };
+  auto isEndomorphismOp = [reduction](Operation *op,
+                                      std::optional<Operation *> referenceOp) {
+    auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
+    // The rewrite clones `AlgebraicOp` with the all-reduce inputs as operands
+    // while keeping its original result type, so the all-reduce must leave
+    // the type of its input unchanged. Comparing only element types would
+    // also accept an all-reduce whose shape differs between input and result
+    // (if the op verifier admits that) and produce an inconsistent op.
+    if (!allReduceOp ||
+        allReduceOp.getInput().getType() != allReduceOp.getResult().getType() ||
+        allReduceOp.getReduction() != reduction)
+      return false;
+
+    if (!referenceOp)
+      return true;
+
+    // `referenceOp` is always an op that already satisfied this predicate,
+    // so it is an all-reduce; `cast` asserts that instead of risking a null
+    // dereference.
+    auto refAllReduceOp = llvm::cast<AllReduceOp>(*referenceOp);
+    return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
+           allReduceOp.getInput().getType() ==
+               refAllReduceOp.getInput().getType();
+  };
+  auto isAlgebraicOp = [](Operation *op) { return llvm::isa<AlgebraicOp>(op); };
+
+  using ConcreteEndomorphismSimplification = EndomorphismSimplification<
+      std::decay_t<decltype(getEndomorphismOpOperand)>,
+      std::decay_t<decltype(getEndomorphismOpResult)>,
+      std::decay_t<decltype(getAlgebraicOpOperands)>,
+      std::decay_t<decltype(getAlgebraicOpResult)>,
+      std::decay_t<decltype(isEndomorphismOp)>,
+      std::decay_t<decltype(isAlgebraicOp)>>;
+  patterns.add(std::make_unique<ConcreteEndomorphismSimplification>(
+      std::move(getEndomorphismOpOperand), std::move(getEndomorphismOpResult),
+      std::move(getAlgebraicOpOperands), std::move(getAlgebraicOpResult),
+      std::move(isEndomorphismOp), std::move(isAlgebraicOp),
+      AlgebraicOp::getOperationName(), 1, patterns.getContext()));
+}
+
+/// Populates `patterns` with the all-reduce endomorphism simplifications for
+/// the element-wise `arith` add, min and max operations.
+void populateSimplificationPatterns(RewritePatternSet &patterns);
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
diff --git a/mlir/include/mlir/Transforms/EndomorphismSimplification.h b/mlir/include/mlir/Transforms/EndomorphismSimplification.h
new file mode 100644
index 000000000000000..36484667a20d875
--- /dev/null
+++ b/mlir/include/mlir/Transforms/EndomorphismSimplification.h
@@ -0,0 +1,161 @@
+//===- EndomorphismSimplification.h -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
+#define MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
+
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include <iterator>
+#include <optional>
+#include <type_traits>
+#include <utility>
+
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+
+// If `f` is an endomorphism with respect to the algebraic structure induced by
+// function `g`, transforms `g(f(x1), f(x2), ..., f(xn))` into
+// `f(g(x1, x2, ..., xn))`.
+// `g` is the algebraic operation and `f` is the endomorphism.
+//
+// Functors:
+// ---------
+// `GetEndomorphismOpOperandFn`: `(Operation*) -> OpOperand*`
+// Returns the operand relevant to the endomorphism.
+// There may be other operands that are not relevant.
+//
+// `GetEndomorphismOpResultFn`: `(Operation*) -> OpResult`
+// Returns the result relevant to the endomorphism.
+//
+// `GetAlgebraicOpOperandsFn`: `(Operation*, SmallVector<OpOperand*>&) -> void`
+// Appends to the vector the operands relevant to the endomorphism.
+//
+// `GetAlgebraicOpResultFn`: `(Operation*) -> OpResult`
+// Returns the result relevant to the endomorphism.
+//
+// `IsEndomorphismOpFn`: `(Operation*, std::optional<Operation*>) -> bool`
+// Checks if the operation is an endomorphism of the required type.
+// Additionally, if the optional is present, checks if the two operations are
+// compatible endomorphisms.
+//
+// `IsAlgebraicOpFn`: `(Operation*) -> bool`
+// Checks if the operation is an operation of the algebraic structure.
+//
+// The rewrite clones `g` with each `xi` in place of `f(xi)`, then clones the
+// `f` that feeds `g`'s first relevant operand and applies it to the new `g`'s
+// result. Uses of the old `g`'s relevant result are redirected to the new
+// `f`. The original ops are not erased here; presumably dead-code
+// elimination (e.g. by the greedy driver) removes them.
+//
+// The pattern keeps no mutable state of its own; all scratch data lives in
+// locals of `matchAndRewrite`. A single instance can therefore be shared
+// between threads (e.g. through a `FrozenRewritePatternSet`), provided the
+// functors are themselves thread-safe.
+template <typename GetEndomorphismOpOperandFn,
+          typename GetEndomorphismOpResultFn, typename GetAlgebraicOpOperandsFn,
+          typename GetAlgebraicOpResultFn, typename IsEndomorphismOpFn,
+          typename IsAlgebraicOpFn>
+struct EndomorphismSimplification : RewritePattern {
+  template <typename GetEndomorphismOpOperandFnArg,
+            typename GetEndomorphismOpResultFnArg,
+            typename GetAlgebraicOpOperandsFnArg,
+            typename GetAlgebraicOpResultFnArg, typename IsEndomorphismOpFnArg,
+            typename IsAlgebraicOpFnArg, typename... RewritePatternArgs>
+  EndomorphismSimplification(
+      GetEndomorphismOpOperandFnArg &&getEndomorphismOpOperand,
+      GetEndomorphismOpResultFnArg &&getEndomorphismOpResult,
+      GetAlgebraicOpOperandsFnArg &&getAlgebraicOpOperands,
+      GetAlgebraicOpResultFnArg &&getAlgebraicOpResult,
+      IsEndomorphismOpFnArg &&isEndomorphismOp,
+      IsAlgebraicOpFnArg &&isAlgebraicOp, RewritePatternArgs &&...args)
+      : RewritePattern(std::forward<RewritePatternArgs>(args)...),
+        getEndomorphismOpOperand(std::forward<GetEndomorphismOpOperandFnArg>(
+            getEndomorphismOpOperand)),
+        getEndomorphismOpResult(std::forward<GetEndomorphismOpResultFnArg>(
+            getEndomorphismOpResult)),
+        getAlgebraicOpOperands(
+            std::forward<GetAlgebraicOpOperandsFnArg>(getAlgebraicOpOperands)),
+        getAlgebraicOpResult(
+            std::forward<GetAlgebraicOpResultFnArg>(getAlgebraicOpResult)),
+        isEndomorphismOp(std::forward<IsEndomorphismOpFnArg>(isEndomorphismOp)),
+        isAlgebraicOp(std::forward<IsAlgebraicOpFnArg>(isAlgebraicOp)) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // A local, not a `mutable` member: the pattern instance may be shared, so
+    // per-match state must not live in it.
+    SmallVector<OpOperand *> algebraicOpOperands;
+    if (failed(matchOp(op, algebraicOpOperands)))
+      return failure();
+    return rewriteOp(op, algebraicOpOperands, rewriter);
+  }
+
+private:
+  // Succeeds if `algebraicOp` is `g(f(x1), ..., f(xn))`, where every relevant
+  // operand is the relevant result of an endomorphism compatible with the
+  // first one. On success `algebraicOpOperands` holds `g`'s relevant operands.
+  LogicalResult matchOp(Operation *algebraicOp,
+                        SmallVector<OpOperand *> &algebraicOpOperands) const {
+    if (!isAlgebraicOp(algebraicOp))
+      return failure();
+    algebraicOpOperands.clear();
+    getAlgebraicOpOperands(algebraicOp, algebraicOpOperands);
+    if (algebraicOpOperands.empty())
+      return failure();
+
+    Operation *firstEndomorphismOp =
+        algebraicOpOperands.front()->get().getDefiningOp();
+    if (!firstEndomorphismOp ||
+        !isEndomorphismOp(firstEndomorphismOp, std::nullopt))
+      return failure();
+
+    for (OpOperand *operand : algebraicOpOperands) {
+      Value operandValue = operand->get();
+      Operation *endomorphismOp = operandValue.getDefiningOp();
+      if (!endomorphismOp ||
+          !isEndomorphismOp(endomorphismOp, firstEndomorphismOp))
+        return failure();
+      // Every operand, not only the first, must be the endomorphism's
+      // relevant result; another result of a multi-result op is not `f(x)`.
+      if (getEndomorphismOpResult(endomorphismOp) != operandValue)
+        return failure();
+    }
+    return success();
+  }
+
+  // Builds `f(g(x1, ..., xn))` and redirects the uses of `g`'s relevant
+  // result to it. Requires a prior successful `matchOp`.
+  LogicalResult rewriteOp(Operation *algebraicOp,
+                          ArrayRef<OpOperand *> algebraicOpOperands,
+                          PatternRewriter &rewriter) const {
+    assert(!algebraicOpOperands.empty());
+
+    // Clone `g`, replacing each `f(xi)` operand by `xi`.
+    IRMapping algebraicOpMapping;
+    for (OpOperand *operand : algebraicOpOperands) {
+      Operation *endomorphismOp = operand->get().getDefiningOp();
+      algebraicOpMapping.map(operand->get(),
+                             getEndomorphismOpOperand(endomorphismOp)->get());
+    }
+    Operation *newAlgebraicOp =
+        rewriter.clone(*algebraicOp, algebraicOpMapping);
+
+    // Clone the first `f` (all were checked to be compatible) and feed it
+    // the result of the new `g`.
+    Operation *firstEndomorphismOp =
+        algebraicOpOperands.front()->get().getDefiningOp();
+    IRMapping endomorphismOpMapping;
+    endomorphismOpMapping.map(
+        getEndomorphismOpOperand(firstEndomorphismOp)->get(),
+        getAlgebraicOpResult(newAlgebraicOp));
+    Operation *newEndomorphismOp =
+        rewriter.clone(*firstEndomorphismOp, endomorphismOpMapping);
+
+    rewriter.replaceAllUsesWith(getAlgebraicOpResult(algebraicOp),
+                                getEndomorphismOpResult(newEndomorphismOp));
+    return success();
+  }
+
+  GetEndomorphismOpOperandFn getEndomorphismOpOperand;
+  GetEndomorphismOpResultFn getEndomorphismOpResult;
+  GetAlgebraicOpOperandsFn getAlgebraicOpOperands;
+  GetAlgebraicOpResultFn getAlgebraicOpResult;
+  IsEndomorphismOpFn isEndomorphismOp;
+  IsAlgebraicOpFn isAlgebraicOp;
+};
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_SIMPLIFY_ENDOMORPHISM_H_
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index bcf45c4ea276080..044b8672c8c60cf 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRMeshTransforms
+ Simplifications.cpp
ShardingPropagation.cpp
ADDITIONAL_HEADER_DIRS
@@ -9,6 +10,7 @@ add_mlir_dialect_library(MLIRMeshTransforms
MLIRShardingInterface
LINK_LIBS PUBLIC
+ MLIRArithDialect
MLIRFuncDialect
MLIRIR
MLIRMeshDialect
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
new file mode 100644
index 000000000000000..1d241fe03a127b8
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp
@@ -0,0 +1,37 @@
+//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+
+namespace mlir {
+namespace mesh {
+
+// Registers the all-reduce endomorphism simplification for each op in
+// `AlgebraicOps`, in the listed order, all paired with the same `reduction`.
+template <typename... AlgebraicOps>
+static void populateForEachAlgebraicOp(RewritePatternSet &patterns,
+                                       Partial reduction) {
+  (populateAllReduceEndomorphismSimplificationPatterns<AlgebraicOps>(
+       patterns, reduction),
+   ...);
+}
+
+void populateSimplificationPatterns(RewritePatternSet &patterns) {
+  // NOTE(review): `Partial::Min`/`Partial::Max` carry no signedness, yet both
+  // the signed (`minsi`/`maxsi`) and unsigned (`minui`/`maxui`) variants are
+  // registered below. For integers, in general at most one of them can match
+  // the comparison the all-reduce actually performs; confirm the integer
+  // semantics of `mesh.all_reduce` min/max.
+
+  // add(all_reduce<sum>(x), all_reduce<sum>(y)) -> all_reduce<sum>(add(x, y))
+  populateForEachAlgebraicOp<arith::AddFOp, arith::AddIOp>(patterns,
+                                                           Partial::Sum);
+
+  // min(all_reduce<min>(x), all_reduce<min>(y)) -> all_reduce<min>(min(x, y))
+  populateForEachAlgebraicOp<arith::MinimumFOp, arith::MinSIOp,
+                             arith::MinUIOp>(patterns, Partial::Min);
+
+  // max(all_reduce<max>(x), all_reduce<max>(y)) -> all_reduce<max>(max(x, y))
+  populateForEachAlgebraicOp<arith::MaximumFOp, arith::MaxSIOp,
+                             arith::MaxUIOp>(patterns, Partial::Max);
+}
+
+} // namespace mesh
+} // namespace mlir
diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir
new file mode 100644
index 000000000000000..2b305df6e0a97f1
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/simplifications.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt -test-mesh-simplifications %s | FileCheck %s
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = [4, 2])
+mesh.cluster @mesh1(rank = 1, dim_sizes = [4])
+
+// Checks that `all_reduce(x) + all_reduce(y)` gets transformed to
+// `all_reduce(x + y)`, and that the rebuilt all_reduce keeps the mesh and
+// mesh axes of the original ones.
+// CHECK-LABEL: func.func @all_reduce_arith_addf_endomorphism
+func.func @all_reduce_arith_addf_endomorphism(
+  // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+  %arg0: tensor<5xf32>,
+  // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+  %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+      : tensor<5xf32> -> tensor<5xf32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+      : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ARG0]], %[[ARG1]]
+  %2 = arith.addf %0, %1 : tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ADD_RES]] on @mesh0 mesh_axes = [0]
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// Checks that `all_reduce(x) + all_reduce(y)` is left unchanged when the two
+// all_reduce ops run on different meshes: the rewrite requires all of them to
+// carry identical attributes.
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh
+func.func @all_reduce_arith_addf_no_endomorphism_different_mesh(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> tensor<5xf32> {
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh1
+ %1 = mesh.all_reduce %arg1 on @mesh1 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+ %2 = arith.addf %0, %1 : tensor<5xf32>
+ // CHECK: return %[[ADD_RES]]
+ return %2 : tensor<5xf32>
+}
+
+// Checks that `all_reduce(x) + all_reduce(y)` is left unchanged when the two
+// all_reduce ops reduce over different mesh axes of the same mesh.
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes
+func.func @all_reduce_arith_addf_no_endomorphism_different_mesh_axes(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> tensor<5xf32> {
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [1]
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [1]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+ %2 = arith.addf %0, %1 : tensor<5xf32>
+ // CHECK: return %[[ADD_RES]]
+ return %2 : tensor<5xf32>
+}
+
+// Checks that `arith.addf` is not pushed through the all_reduce ops when one
+// of them is a max reduction: addition is only rewritten over sum reductions.
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind
+func.func @all_reduce_arith_addf_no_endomorphism_wrong_reduction_kind(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> tensor<5xf32> {
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0] reduction = <max>
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <max>
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf32>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+ %2 = arith.addf %0, %1 : tensor<5xf32>
+ // CHECK: return %[[ADD_RES]]
+ return %2 : tensor<5xf32>
+}
+
+// Checks that the rewrite is skipped when the all_reduce ops change the
+// element type of their input (f32 to f64), since the cloned `arith.addf`
+// would then see operands of a different type than its result.
+// CHECK-LABEL: func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types
+func.func @all_reduce_arith_addf_no_endomorphism_different_operand_result_element_types(
+ // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg0: tensor<5xf32>,
+ // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+ %arg1: tensor<5xf32>) -> tensor<5xf64> {
+ // CHECK: %[[ALL_REDUCE0:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG0]] on @mesh0 mesh_axes = [0]
+ %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf64>
+ // CHECK: %[[ALL_REDUCE1:[A-Za-z0-9_]*]] = mesh.all_reduce %[[ARG1]] on @mesh0 mesh_axes = [0]
+ %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0]
+ : tensor<5xf32> -> tensor<5xf64>
+ // CHECK: %[[ADD_RES:[A-Za-z0-9_]*]] = arith.addf %[[ALL_REDUCE0]], %[[ALL_REDUCE1]]
+ %2 = arith.addf %0, %1 : tensor<5xf64>
+ // CHECK: return %[[ADD_RES]]
+ return %2 : tensor<5xf64>
+}
+
+// Checks that `min(all_reduce(x), all_reduce(y))` gets transformed to
+// `all_reduce(min(x, y))`.
+// CHECK-LABEL: func.func @all_reduce_arith_minimumf_endomorphism
+func.func @all_reduce_arith_minimumf_endomorphism(
+  // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+  %arg0: tensor<5xf32>,
+  // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+  %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+      : tensor<5xf32> -> tensor<5xf32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+      : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[MIN_RES:[A-Za-z0-9_]*]] = arith.minimumf %[[ARG0]], %[[ARG1]]
+  %2 = arith.minimumf %0, %1 : tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[MIN_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : tensor<5xf32>
+}
+
+// Checks the same min rewrite for the signed integer `arith.minsi`.
+// CHECK-LABEL: func.func @all_reduce_arith_minsi_endomorphism
+func.func @all_reduce_arith_minsi_endomorphism(
+  // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xi32>
+  %arg0: tensor<5xi32>,
+  // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xi32>
+  %arg1: tensor<5xi32>) -> tensor<5xi32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <min>
+      : tensor<5xi32> -> tensor<5xi32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <min>
+      : tensor<5xi32> -> tensor<5xi32>
+  // CHECK: %[[MIN_RES:[A-Za-z0-9_]*]] = arith.minsi %[[ARG0]], %[[ARG1]]
+  %2 = arith.minsi %0, %1 : tensor<5xi32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[MIN_RES]] on @mesh0 mesh_axes = [0] reduction = <min>
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : tensor<5xi32>
+}
+
+// Checks that `max(all_reduce(x), all_reduce(y))` gets transformed to
+// `all_reduce(max(x, y))`.
+// CHECK-LABEL: func.func @all_reduce_arith_maximumf_endomorphism
+func.func @all_reduce_arith_maximumf_endomorphism(
+  // CHECK-SAME: %[[ARG0:[A-Za-z0-9_]*]]: tensor<5xf32>
+  %arg0: tensor<5xf32>,
+  // CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: tensor<5xf32>
+  %arg1: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [0] reduction = <max>
+      : tensor<5xf32> -> tensor<5xf32>
+  %1 = mesh.all_reduce %arg1 on @mesh0 mesh_axes = [0] reduction = <max>
+      : tensor<5xf32> -> tensor<5xf32>
+  // CHECK: %[[MAX_RES:[A-Za-z0-9_]*]] = arith.maximumf %[[ARG0]], %[[ARG1]]
+  %2 = arith.maximumf %0, %1 : tensor<5xf32>
+  // CHECK: %[[ALL_REDUCE_RES:[A-Za-z0-9_]*]] = mesh.all_reduce %[[MAX_RES]] on @mesh0 mesh_axes = [0] reduction = <max>
+  // CHECK: return %[[ALL_REDUCE_RES]]
+  return %2 : tensor<5xf32>
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 48bde69e0170041..30a17c201ff7635 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -9,6 +9,7 @@ add_subdirectory(Linalg)
add_subdirectory(LLVM)
add_subdirectory(Math)
add_subdirectory(MemRef)
+add_subdirectory(Mesh)
add_subdirectory(NVGPU)
add_subdirectory(SCF)
add_subdirectory(Shape)
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
new file mode 100644
index 000000000000000..102bd59da61b992
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -0,0 +1,11 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRMeshTestSimplifications
+ TestSimplifications.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ LINK_LIBS PUBLIC
+ MLIRMeshDialect
+ MLIRMeshTransforms
+ MLIRPass
+ )
diff --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
new file mode 100644
index 000000000000000..93b1da52d46b4ef
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp
@@ -0,0 +1,43 @@
+//===- TestSimplifications.cpp - Test simplifications ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass exposing `mesh::populateSimplificationPatterns` to mlir-opt as
+/// `-test-mesh-simplifications`.
+struct TestMeshSimplificationsPass
+    : public PassWrapper<TestMeshSimplificationsPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshSimplificationsPass)
+
+  void runOnOperation() override;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, mesh::MeshDialect>();
+  }
+  StringRef getArgument() const final { return "test-mesh-simplifications"; }
+  StringRef getDescription() const final { return "Test mesh simplifications"; }
+};
+} // namespace
+
+void TestMeshSimplificationsPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  RewritePatternSet simplifications(context);
+  mesh::populateSimplificationPatterns(simplifications);
+  // Whether the greedy driver converged does not matter for this test pass;
+  // only the resulting IR is inspected.
+  (void)applyPatternsAndFoldGreedily(getOperation(),
+                                     std::move(simplifications));
+}
+
+namespace mlir::test {
+/// Makes `-test-mesh-simplifications` available to the tool; invoked from
+/// `registerTestPasses` in mlir-opt.
+void registerTestMeshSimplificationsPass() {
+  PassRegistration<TestMeshSimplificationsPass>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 88a0562cb6e7207..bc8eed18215525e 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -26,6 +26,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRLoopLikeInterfaceTestPasses
MLIRMathTestPasses
MLIRMemRefTestPasses
+ MLIRMeshTestSimplifications
MLIRNVGPUTestPasses
MLIRSCFTestPasses
MLIRShapeTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 3e3223b48505601..c7cf1e55a556f4a 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -118,6 +118,7 @@ void registerTestMathAlgebraicSimplificationPass();
void registerTestMathPolynomialApproximationPass();
void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
+void registerTestMeshSimplificationsPass();
void registerTestNextAccessPass();
void registerTestOneToNTypeConversionPass();
void registerTestOpaqueLoc();
@@ -238,6 +239,7 @@ void registerTestPasses() {
mlir::test::registerTestMathPolynomialApproximationPass();
mlir::test::registerTestMemRefDependenceCheck();
mlir::test::registerTestMemRefStrideCalculation();
+ mlir::test::registerTestMeshSimplificationsPass();
mlir::test::registerTestNextAccessPass();
mlir::test::registerTestOneToNTypeConversionPass();
mlir::test::registerTestOpaqueLoc();
More information about the Mlir-commits
mailing list