[Mlir-commits] [mlir] [mlir] [dataflow] Add a loadAnalysis method to the dataflow analysis (PR #102808)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Aug 11 02:21:05 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: donald chen (cxy-1993)

<details>
<summary>Changes</summary>

Dataflow analyses often depend on each other, so callers must manually choose which analyses to load in order to obtain the desired results. As the number of analyses in a downstream repository grows, repeating these inter-analysis dependencies at every call site becomes increasingly redundant. This patch adds a static `loadAnalysis` method to dataflow analyses so that each analysis loads the analyses it depends on itself. For example, `IntegerRangeAnalysis::loadAnalysis(solver)` loads `SparseConstantPropagation`, `DeadCodeAnalysis`, and `IntegerRangeAnalysis` in a single call.

---
Full diff: https://github.com/llvm/llvm-project/pull/102808.diff


7 Files Affected:

- (modified) mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h (+4) 
- (modified) mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h (+6) 
- (modified) mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h (+8) 
- (modified) mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h (+1-1) 
- (modified) mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp (+1-3) 
- (modified) mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h (+8) 
- (modified) mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp (+1-2) 


``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
index 1bf991dc193874..c7e60998dd8a1c 100644
--- a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
@@ -106,6 +106,10 @@ class SparseConstantPropagation
                       ArrayRef<Lattice<ConstantValue> *> results) override;
 
   void setToEntryState(Lattice<ConstantValue> *lattice) override;
+
+  static void loadAnalysis(DataFlowSolver &solver) {
+    solver.load<SparseConstantPropagation>();
+  }
 };
 
 } // end namespace dataflow
diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 10ef8b6ba5843a..84a27026b3d6cc 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -15,6 +15,7 @@
 #ifndef MLIR_ANALYSIS_DATAFLOW_DEADCODEANALYSIS_H
 #define MLIR_ANALYSIS_DATAFLOW_DEADCODEANALYSIS_H
 
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/IR/SymbolTable.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -184,6 +185,11 @@ class DeadCodeAnalysis : public DataFlowAnalysis {
   /// successors are live.
   LogicalResult visit(ProgramPoint point) override;
 
+  static void loadAnalysis(DataFlowSolver &solver) {
+    solver.load<SparseConstantPropagation>();
+    solver.load<DeadCodeAnalysis>();
+  }
+
 private:
   /// Find and mark symbol callables with potentially unknown callsites as
   /// having overdefined predecessors. `top` is the top-level operation that the
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 191c023fb642cb..dc7069c398031a 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -18,6 +18,8 @@
 #ifndef MLIR_ANALYSIS_DATAFLOW_INTEGERANGEANALYSIS_H
 #define MLIR_ANALYSIS_DATAFLOW_INTEGERANGEANALYSIS_H
 
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 
@@ -68,6 +70,12 @@ class IntegerRangeAnalysis
   visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor,
                                ArrayRef<IntegerValueRangeLattice *> argLattices,
                                unsigned firstIndex) override;
+
+  static void loadAnalysis(DataFlowSolver &solver) {
+    solver.load<SparseConstantPropagation>();
+    solver.load<DeadCodeAnalysis>();
+    solver.load<IntegerRangeAnalysis>();
+  }
 };
 
 } // end namespace dataflow
diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index caa03e26a3a423..6189159f0fa294 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -23,7 +23,7 @@
 #ifndef MLIR_ANALYSIS_DATAFLOW_LIVENESSANALYSIS_H
 #define MLIR_ANALYSIS_DATAFLOW_LIVENESSANALYSIS_H
 
-#include <mlir/Analysis/DataFlow/SparseAnalysis.h>
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include <optional>
 
 namespace mlir::dataflow {
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
index 65592a5c5d698b..847f8850ecdf00 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -262,10 +262,8 @@ struct TestNextAccessPass
 
     auto config = DataFlowConfig().setInterprocedural(interprocedural);
     DataFlowSolver solver(config);
-    solver.load<DeadCodeAnalysis>();
+    UnderlyingValueAnalysis::loadAnalysis(solver);
     solver.load<NextAccessAnalysis>(symbolTable, assumeFuncReads);
-    solver.load<SparseConstantPropagation>();
-    solver.load<UnderlyingValueAnalysis>();
     if (failed(solver.initializeAndRun(op))) {
       emitError(op->getLoc(), "dataflow solver failed");
       return signalPassFailure();
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h
index 61ddc13f8a3d4a..a89ad975091439 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h
@@ -6,6 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/IR/Value.h"
@@ -220,6 +222,12 @@ class UnderlyingValueAnalysis
     } while (true);
     return value;
   }
+
+  static void loadAnalysis(DataFlowSolver &solver) {
+    solver.load<SparseConstantPropagation>();
+    solver.load<DeadCodeAnalysis>();
+    solver.load<UnderlyingValueAnalysis>();
+  }
 };
 
 } // namespace test
diff --git a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
index 30297380466442..49accc07f0cf86 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
@@ -181,8 +181,7 @@ struct TestWrittenToPass
     SymbolTableCollection symbolTable;
 
     DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural));
-    solver.load<DeadCodeAnalysis>();
-    solver.load<SparseConstantPropagation>();
+    DeadCodeAnalysis::loadAnalysis(solver);
     solver.load<WrittenToAnalysis>(symbolTable, assumeFuncWrites);
     if (failed(solver.initializeAndRun(op)))
       return signalPassFailure();

``````````

</details>


https://github.com/llvm/llvm-project/pull/102808


More information about the Mlir-commits mailing list