[Mlir-commits] [mlir] bf3861b - [mlir][Linalg] Change LinalgDependenceGraph to use LinalgOp.
llvmlistbot at llvm.org
Fri Nov 13 12:35:00 PST 2020
Author: MaheshRavishankar
Date: 2020-11-13T12:34:38-08:00
New Revision: bf3861bf71b61024e5988d15655aee826fc12313
URL: https://github.com/llvm/llvm-project/commit/bf3861bf71b61024e5988d15655aee826fc12313
DIFF: https://github.com/llvm/llvm-project/commit/bf3861bf71b61024e5988d15655aee826fc12313.diff
LOG: [mlir][Linalg] Change LinalgDependenceGraph to use LinalgOp.
Storing LinalgOp directly avoids the repeated conversions between
Operation * and LinalgOp.
Differential Revision: https://reviews.llvm.org/D91101
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
index 372f6c4e01a1..f27b929f2fc0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
+++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
#define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
@@ -67,7 +68,7 @@ class LinalgDependenceGraph {
// Builds a linalg dependence graph for the ops of type LinalgOp under `f`.
static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f);
- LinalgDependenceGraph(Aliases &aliases, ArrayRef<Operation *> ops);
+ LinalgDependenceGraph(Aliases &aliases, ArrayRef<LinalgOp> ops);
/// Returns the X such that op -> X is a dependence of type dt.
dependence_range getDependencesFrom(Operation *src, DependenceType dt) const;
@@ -168,7 +169,7 @@ class LinalgDependenceGraph {
ArrayRef<DependenceType> types) const;
Aliases &aliases;
- SmallVector<Operation *, 8> linalgOps;
+ SmallVector<LinalgOp, 8> linalgOps;
DenseMap<Operation *, unsigned> linalgOpPositions;
};
} // namespace linalg
diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index 01e167d1f0aa..96da933888f2 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -86,21 +86,21 @@ StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
LinalgDependenceGraph
LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) {
- SmallVector<Operation *, 8> linalgOps;
+ SmallVector<LinalgOp, 8> linalgOps;
f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
return LinalgDependenceGraph(aliases, linalgOps);
}
LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
- ArrayRef<Operation *> ops)
+ ArrayRef<LinalgOp> ops)
: aliases(aliases), linalgOps(ops.begin(), ops.end()) {
for (auto en : llvm::enumerate(linalgOps)) {
- assert(isa<LinalgOp>(en.value()) && "Expected value for LinalgOp");
- linalgOpPositions.insert(std::make_pair(en.value(), en.index()));
+ linalgOpPositions.insert(
+ std::make_pair(en.value().getOperation(), en.index()));
}
for (unsigned i = 0, e = ops.size(); i < e; ++i) {
for (unsigned j = i + 1; j < e; ++j) {
- addDependencesBetween(cast<LinalgOp>(ops[i]), cast<LinalgOp>(ops[j]));
+ addDependencesBetween(ops[i], ops[j]);
}
}
}
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index e6e150b7bf47..eb9e3a533138 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -124,7 +124,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
DenseSet<Operation *> eraseSet;
// Save original Linalg ops, we only want to make a pass over those.
- SmallVector<Operation *, 8> linalgOps;
+ SmallVector<LinalgOp, 8> linalgOps;
f.walk([&](LinalgOp op) {
// TODO: support multi-results.
if (op.getOperation()->getNumResults() <= 1)
@@ -133,8 +133,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
// Tile and Fuse for tensors inputs (TODO: all tensor operands).
bool changed = false;
- for (auto *op : llvm::reverse(linalgOps)) {
- LinalgOp linalgOp = cast<LinalgOp>(op);
+ for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
if (en.value().getType().isa<MemRefType>()) {
// TODO: LinalgDependenceGraph should be able to update itself.
@@ -142,7 +141,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
// removed.
linalg::Aliases aliases;
linalg::LinalgDependenceGraph graph(aliases, linalgOps);
- if (auto info = fuseProducerOfBuffer(b, op, en.index(), graph)) {
+ if (auto info = fuseProducerOfBuffer(b, linalgOp, en.index(), graph)) {
auto *originalOp = info->originalProducer.getOperation();
eraseSet.insert(originalOp);
auto *originalOpInLinalgOpsVector =
@@ -155,7 +154,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
// Tile and Fuse tensor input (TODO: init_tensors too).
if (en.index() >= linalgOp.getNumInputs())
continue;
- if (auto info = fuseProducerOfTensor(b, op, en.index())) {
+ if (auto info = fuseProducerOfTensor(b, linalgOp, en.index())) {
auto *originalOp = info->originalProducer.getOperation();
auto *originalOpInLinalgOpsVector =
std::find(linalgOps.begin(), linalgOps.end(), originalOp);
More information about the Mlir-commits mailing list