[Mlir-commits] [mlir] bf3861b - [mlir][Linalg] Change LinalgDependenceGraph to use LinalgOp.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 13 12:35:00 PST 2020


Author: MaheshRavishankar
Date: 2020-11-13T12:34:38-08:00
New Revision: bf3861bf71b61024e5988d15655aee826fc12313

URL: https://github.com/llvm/llvm-project/commit/bf3861bf71b61024e5988d15655aee826fc12313
DIFF: https://github.com/llvm/llvm-project/commit/bf3861bf71b61024e5988d15655aee826fc12313.diff

LOG: [mlir][Linalg] Change LinalgDependenceGraph to use LinalgOp.

Using LinalgOp reduces the repeated conversions between Operation and
LinalgOp.

Differential Revision: https://reviews.llvm.org/D91101

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
    mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
    mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
index 372f6c4e01a1..f27b929f2fc0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
+++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
 #define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_
 
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpDefinition.h"
 
@@ -67,7 +68,7 @@ class LinalgDependenceGraph {
 
   // Builds a linalg dependence graph for the ops of type LinalgOp under `f`.
   static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f);
-  LinalgDependenceGraph(Aliases &aliases, ArrayRef<Operation *> ops);
+  LinalgDependenceGraph(Aliases &aliases, ArrayRef<LinalgOp> ops);
 
   /// Returns the X such that op -> X is a dependence of type dt.
   dependence_range getDependencesFrom(Operation *src, DependenceType dt) const;
@@ -168,7 +169,7 @@ class LinalgDependenceGraph {
                                         ArrayRef<DependenceType> types) const;
 
   Aliases &aliases;
-  SmallVector<Operation *, 8> linalgOps;
+  SmallVector<LinalgOp, 8> linalgOps;
   DenseMap<Operation *, unsigned> linalgOpPositions;
 };
 } // namespace linalg

diff  --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index 01e167d1f0aa..96da933888f2 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -86,21 +86,21 @@ StringRef LinalgDependenceGraph::getDependenceTypeStr(DependenceType depType) {
 
 LinalgDependenceGraph
 LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) {
-  SmallVector<Operation *, 8> linalgOps;
+  SmallVector<LinalgOp, 8> linalgOps;
   f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
   return LinalgDependenceGraph(aliases, linalgOps);
 }
 
 LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
-                                             ArrayRef<Operation *> ops)
+                                             ArrayRef<LinalgOp> ops)
     : aliases(aliases), linalgOps(ops.begin(), ops.end()) {
   for (auto en : llvm::enumerate(linalgOps)) {
-    assert(isa<LinalgOp>(en.value()) && "Expected value for LinalgOp");
-    linalgOpPositions.insert(std::make_pair(en.value(), en.index()));
+    linalgOpPositions.insert(
+        std::make_pair(en.value().getOperation(), en.index()));
   }
   for (unsigned i = 0, e = ops.size(); i < e; ++i) {
     for (unsigned j = i + 1; j < e; ++j) {
-      addDependencesBetween(cast<LinalgOp>(ops[i]), cast<LinalgOp>(ops[j]));
+      addDependencesBetween(ops[i], ops[j]);
     }
   }
 }

diff  --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index e6e150b7bf47..eb9e3a533138 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -124,7 +124,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
   DenseSet<Operation *> eraseSet;
 
   // Save original Linalg ops, we only want to make a pass over those.
-  SmallVector<Operation *, 8> linalgOps;
+  SmallVector<LinalgOp, 8> linalgOps;
   f.walk([&](LinalgOp op) {
     // TODO: support multi-results.
     if (op.getOperation()->getNumResults() <= 1)
@@ -133,8 +133,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
 
   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
   bool changed = false;
-  for (auto *op : llvm::reverse(linalgOps)) {
-    LinalgOp linalgOp = cast<LinalgOp>(op);
+  for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
     for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
       if (en.value().getType().isa<MemRefType>()) {
         // TODO: LinalgDependenceGraph should be able to update itself.
@@ -142,7 +141,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
         // removed.
         linalg::Aliases aliases;
         linalg::LinalgDependenceGraph graph(aliases, linalgOps);
-        if (auto info = fuseProducerOfBuffer(b, op, en.index(), graph)) {
+        if (auto info = fuseProducerOfBuffer(b, linalgOp, en.index(), graph)) {
           auto *originalOp = info->originalProducer.getOperation();
           eraseSet.insert(originalOp);
           auto *originalOpInLinalgOpsVector =
@@ -155,7 +154,7 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
         // Tile and Fuse tensor input (TODO: init_tensors too).
         if (en.index() >= linalgOp.getNumInputs())
           continue;
-        if (auto info = fuseProducerOfTensor(b, op, en.index())) {
+        if (auto info = fuseProducerOfTensor(b, linalgOp, en.index())) {
           auto *originalOp = info->originalProducer.getOperation();
           auto *originalOpInLinalgOpsVector =
               std::find(linalgOps.begin(), linalgOps.end(), originalOp);


        


More information about the Mlir-commits mailing list