[Mlir-commits] [mlir] [mlir][dataflow] Use SparseForwardDataFlowAnalysis to implement constant analysis (PR #156486)

lonely eagle llvmlistbot at llvm.org
Tue Sep 2 09:29:11 PDT 2025


https://github.com/linuxlonelyeagle created https://github.com/llvm/llvm-project/pull/156486

The constant analysis used by the dead-code-analysis test was previously implemented directly on the DataFlowAnalysis class. It is now implemented with SparseForwardDataFlowAnalysis.

>From 00fa95ca15508f74ce199b4c43abb1cf26094e2b Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Tue, 2 Sep 2025 16:25:06 +0000
Subject: [PATCH] use SparseForwardDataFlowAnalysis to implement constant
 analysis

---
 .../DataFlow/TestDeadCodeAnalysis.cpp         | 36 ++++++-------------
 1 file changed, 11 insertions(+), 25 deletions(-)

diff --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
index 2dc77c9705d35..0f94d95408f29 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
@@ -66,40 +66,36 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
 namespace {
 /// This is a simple analysis that implements a transfer function for constant
 /// operations.
-struct ConstantAnalysis : public DataFlowAnalysis {
-  using DataFlowAnalysis::DataFlowAnalysis;
+struct SparseConstantAnalysis
+    : public SparseForwardDataFlowAnalysis<Lattice<ConstantValue>> {
+  using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;

-  LogicalResult initialize(Operation *top) override {
-    WalkResult result = top->walk([&](Operation *op) {
-      if (failed(visit(getProgramPointAfter(op))))
-        return WalkResult::interrupt();
-      return WalkResult::advance();
-    });
-    return success(!result.wasInterrupted());
-  }
-
-  LogicalResult visit(ProgramPoint *point) override {
-    Operation *op = point->getPrevOp();
+  /// Transfer function: a constant-like operation propagates its folded
+  /// attribute to its result, while the results of any other visited
+  /// operation are marked as unknown constants.
+  LogicalResult
+  visitOperation(Operation *op,
+                 ArrayRef<const Lattice<ConstantValue> *> operands,
+                 ArrayRef<Lattice<ConstantValue> *> results) override {
     Attribute value;
     if (matchPattern(op, m_Constant(&value))) {
-      auto *constant = getOrCreate<Lattice<ConstantValue>>(op->getResult(0));
+      Lattice<ConstantValue> *constant = results.front();
       propagateIfChanged(
           constant, constant->join(ConstantValue(value, op->getDialect())));
       return success();
     }
-    setAllToUnknownConstants(op->getResults());
-    for (Region &region : op->getRegions())
-      setAllToUnknownConstants(region.getArguments());
+    // Leaving these results uninitialized would stop DeadCodeAnalysis from
+    // resolving any branch whose operands depend on them, so record them as
+    // unknown constants instead.
+    setAllToEntryStates(results);
     return success();
   }

-  /// Set all given values as not constants.
-  void setAllToUnknownConstants(ValueRange values) {
-    for (Value value : values) {
-      auto *constant = getOrCreate<Lattice<ConstantValue>>(value);
-      propagateIfChanged(constant,
-                         constant->join(ConstantValue::getUnknownConstant()));
-    }
+  /// Lattices placed in their entry state (by setAllToEntryStates or by the
+  /// framework) become unknown constants.
+  void setToEntryState(Lattice<ConstantValue> *lattice) override {
+    propagateIfChanged(lattice,
+                       lattice->join(ConstantValue::getUnknownConstant()));
   }
 };

@@ -116,7 +102,7 @@ struct TestDeadCodeAnalysisPass
 
     DataFlowSolver solver;
     solver.load<DeadCodeAnalysis>();
-    solver.load<ConstantAnalysis>();
+    solver.load<SparseConstantAnalysis>();
     if (failed(solver.initializeAndRun(op)))
       return signalPassFailure();
     printAnalysisResults(solver, op, llvm::errs());



More information about the Mlir-commits mailing list