[Mlir-commits] [mlir] [mlir][mesh] Add endomorphism simplification for all-reduce (PR #73150)
Chengji Yao
llvmlistbot at llvm.org
Thu Dec 7 22:08:35 PST 2023
================
@@ -0,0 +1,89 @@
+//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_SIMPLIFICATIONS_H
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/EndomorphismSimplification.h"
+#include "llvm/Support/Casting.h"
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <utility>
+
+namespace mlir {
+namespace mesh {
+
+template <typename AlgebraicOp>
+void populateAllReduceEndomorphismSimplificationPatterns(
+ RewritePatternSet &patterns, Partial reduction) {
+ auto getEndomorphismOpOperand = [](Operation *op) {
+ auto allReduceOp = llvm::cast<AllReduceOp>(op);
+ return &allReduceOp.getInputMutable();
+ };
+ auto getEndomorphismOpResult = [](Operation *op) {
+ auto allReduceOp = llvm::cast<AllReduceOp>(op);
+ return allReduceOp->getResult(0);
+ };
+ auto getAlgebraicOpOperands = [](Operation *op,
+ SmallVector<OpOperand *> &operands) {
+ auto algebraicOp = llvm::cast<AlgebraicOp>(op);
----------------
yaochengji wrote:
No, what I meant are the ops in the linalg dialect that implement `DestinationStyleOpInterface`. For example, take this `linalg.matmul` op:
```
%matmul = linalg.matmul ins(%0, %1 : tensor<1x1xi8>, tensor<1x?xi8>) outs(%2 : tensor<1x?xi32>) -> tensor<1x?xi32>
```
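Here the `%2` operand (passed via `outs`, i.e. the destination operand) plays a different role from `%0` and `%1` (passed via `ins`), so the algebraic op's operands should not all be treated uniformly.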
https://github.com/llvm/llvm-project/pull/73150
More information about the Mlir-commits mailing list