[Mlir-commits] [mlir] [mlir][mesh] Add endomorphism simplification for all-reduce (PR #73150)

Boian Petkantchin llvmlistbot at llvm.org
Wed Dec 6 10:12:47 PST 2023


================
@@ -0,0 +1,89 @@
+//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/EndomorphismSimplification.h"
+#include "llvm/Support/Casting.h"
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <utility>
+
+namespace mlir {
+namespace mesh {
+
+template <typename AlgebraicOp>
+void populateAllReduceEndomorphismSimplificationPatterns(
+    RewritePatternSet &patterns, Partial reduction) {
+  auto getEndomorphismOpOperand = [](Operation *op) {
+    auto allReduceOp = llvm::cast<AllReduceOp>(op);
+    return &allReduceOp.getInputMutable();
+  };
+  auto getEndomorphismOpResult = [](Operation *op) {
+    auto allReduceOp = llvm::cast<AllReduceOp>(op);
+    return allReduceOp->getResult(0);
+  };
+  auto getAlgebraicOpOperands = [](Operation *op,
+                                   SmallVector<OpOperand *> &operands) {
+    auto algebraicOp = llvm::cast<AlgebraicOp>(op);
----------------
sogartar wrote:

Do you mean a scenario like the following, where `h` is the endomorphism and `a` is the algebraic structure op?
Then the identity
```
a(h(x), h(y), z) = h(a(x, y, z))
```
does not actually hold.

If that is the case, you may be able to use homomorphism simplification instead, where the target algebraic structure op `b` differs from `a`, so that
```
a(h(x), h(y), z) = h(b(x, y, z))
```


https://github.com/llvm/llvm-project/pull/73150

