[Mlir-commits] [mlir] [mlir][dataflow] Update dataflow doc and add dataflow example code (PR #149296)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 17 05:24:23 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: lonely eagle (linuxlonelyeagle)

<details>
<summary>Changes</summary>

The dataflow tutorial is broken. This PR updates the documentation and adds code examples.

---

Patch is 29.26 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149296.diff


13 Files Affected:

- (modified) mlir/docs/Tutorials/DataFlowAnalysis.md (+135-113) 
- (modified) mlir/examples/CMakeLists.txt (+1) 
- (added) mlir/examples/dataflow/CMakeLists.txt (+7) 
- (added) mlir/examples/dataflow/dataflow-opt/CMakeLists.txt (+11) 
- (added) mlir/examples/dataflow/dataflow-opt/dataflow-opt.cpp (+36) 
- (added) mlir/examples/dataflow/include/MetadataAnalysis.h (+65) 
- (added) mlir/examples/dataflow/lib/Analysis/CMakeLists.txt (+6) 
- (added) mlir/examples/dataflow/lib/Analysis/MetadataAnalysis.cpp (+114) 
- (added) mlir/examples/dataflow/lib/CMakeLists.txt (+2) 
- (added) mlir/examples/dataflow/lib/Pass/CMakeLists.txt (+6) 
- (added) mlir/examples/dataflow/lib/Pass/TestMetadataAnalsys.cpp (+57) 
- (modified) mlir/test/CMakeLists.txt (+3) 
- (added) mlir/test/Examples/dataflow/join.mlir (+33) 


``````````diff
diff --git a/mlir/docs/Tutorials/DataFlowAnalysis.md b/mlir/docs/Tutorials/DataFlowAnalysis.md
index ea7158fb7391d..fec2a25bbad7c 100644
--- a/mlir/docs/Tutorials/DataFlowAnalysis.md
+++ b/mlir/docs/Tutorials/DataFlowAnalysis.md
@@ -7,7 +7,7 @@ constructs, of which MLIR has many (Block-based branches, Region-based branches,
 CallGraph, etc), and it isn't always clear how best to go about performing the
 propagation. To help writing these types of analyses in MLIR, this document
 details several utilities that simplify the process and make it a bit more
-approachable.
+approachable. The code for this tutorial can be found in `mlir/examples/dataflow`.
 
 ## Forward Dataflow Analysis
 
@@ -72,31 +72,11 @@ held by an element of the lattice used by our dataflow analysis:
 struct MetadataLatticeValue {
   MetadataLatticeValue() = default;
   /// Compute a lattice value from the provided dictionary.
-  MetadataLatticeValue(DictionaryAttr attr)
-      : metadata(attr.begin(), attr.end()) {}
-
-  /// Return a pessimistic value state, i.e. the `top`/`overdefined`/`unknown`
-  /// state, for our value type. The resultant state should not assume any
-  /// information about the state of the IR.
-  static MetadataLatticeValue getPessimisticValueState(MLIRContext *context) {
-    // The `top`/`overdefined`/`unknown` state is when we know nothing about any
-    // metadata, i.e. an empty dictionary.
-    return MetadataLatticeValue();
-  }
-  /// Return a pessimistic value state for our value type using only information
-  /// about the state of the provided IR. This is similar to the above method,
-  /// but may produce a slightly more refined result. This is okay, as the
-  /// information is already encoded as fact in the IR.
-  static MetadataLatticeValue getPessimisticValueState(Value value) {
-    // Check to see if the parent operation has metadata.
-    if (Operation *parentOp = value.getDefiningOp()) {
-      if (auto metadata = parentOp->getAttrOfType<DictionaryAttr>("metadata"))
-        return MetadataLatticeValue(metadata);
-
-      // If no metadata is present, fallback to the
-      // `top`/`overdefined`/`unknown` state.
+  MetadataLatticeValue(DictionaryAttr attr) {
+    for (NamedAttribute pair : attr) {
+      metadata.insert(
+          std::pair<StringAttr, Attribute>(pair.getName(), pair.getValue()));
     }
-    return MetadataLatticeValue();
   }
 
   /// This method conservatively joins the information held by `lhs` and `rhs`
@@ -111,15 +91,17 @@ struct MetadataLatticeValue {
   static MetadataLatticeValue join(const MetadataLatticeValue &lhs,
                                    const MetadataLatticeValue &rhs) {
     // To join `lhs` and `rhs` we will define a simple policy, which is that we
-    // only keep information that is the same. This means that we only keep
-    // facts that are true in both.
+    // directly insert the metadata of rhs into that of lhs. If lhs and rhs have
+    // overlapping attributes, keep the attribute value in lhs unchanged.
     MetadataLatticeValue result;
-    for (const auto &lhsIt : lhs.metadata) {
-      // As noted above, we only merge if the values are the same.
-      auto it = rhs.metadata.find(lhsIt.first);
-      if (it == rhs.metadata.end() || it.second != lhsIt.second)
-        continue;
-      result.insert(lhsIt);
+    for (auto &&lhsIt : lhs.metadata) {
+      result.metadata.insert(
+          std::pair<StringRef, Attribute>(lhsIt.getKey(), lhsIt.getValue()));
+    }
+
+    for (auto &&rhsIt : rhs.metadata) {
+      result.metadata.insert(
+          std::pair<StringRef, Attribute>(rhsIt.getKey(), rhsIt.getValue()));
     }
     return result;
   }
@@ -129,18 +111,28 @@ struct MetadataLatticeValue {
   bool operator==(const MetadataLatticeValue &rhs) const {
     if (metadata.size() != rhs.metadata.size())
       return false;
+
     // Check that `rhs` contains the same metadata.
-    for (const auto &it : metadata) {
-      auto rhsIt = rhs.metadata.find(it.first);
-      if (rhsIt == rhs.metadata.end() || it.second != rhsIt.second)
+    for (auto &&it : metadata) {
+      auto rhsIt = rhs.metadata.find(it.getKey());
+      if (rhsIt == rhs.metadata.end() || it.second != rhsIt->second)
         return false;
     }
     return true;
   }
 
+  /// Print data in metadata.
+  void print(llvm::raw_ostream &os) const {
+    os << "{";
+    for (auto&& iter : metadata) {
+      os << iter.getKey() << ": " << iter.getValue() << ", ";
+    }
+    os << "\b\b}\n";
+  }
+
   /// Our value represents the combined metadata, which is originally a
   /// DictionaryAttr, so we use a map.
-  DenseMap<StringAttr, Attribute> metadata;
+  llvm::StringMap<Attribute> metadata;
 };
 ```
 
@@ -154,7 +146,7 @@ shown below:
 /// This class represents a lattice element holding a specific value of type
 /// `ValueT`.
 template <typename ValueT>
-class LatticeElement ... {
+class Lattice ... {
 public:
   /// Return the value held by this element. This requires that a value is
   /// known, i.e. not `uninitialized`.
@@ -168,20 +160,25 @@ public:
   /// Join the information contained in the 'rhs' value into this
   /// lattice. Returns if the state of the current lattice changed.
   ChangeResult join(const ValueT &rhs);
-
-  /// Mark the lattice element as having reached a pessimistic fixpoint. This
-  /// means that the lattice may potentially have conflicting value states, and
-  /// only the conservatively known value state should be relied on.
-  ChangeResult markPessimisticFixPoint();
+  
+  ...
 };
 ```
 
 With our lattice defined, we can now define the driver that will compute and
-propagate our lattice across the IR.
+propagate our lattice across the IR. The following is the definition of our
+metadata lattice.
 
-### ForwardDataflowAnalysis Driver
+```c++
+class MetadataLatticeValueLattice : public Lattice<MetadataLatticeValue> {
+public:
+  using Lattice::Lattice;
+};
+```
+
+### SparseForwardDataFlowAnalysis Driver
 
-The `ForwardDataFlowAnalysis` class represents the driver of the dataflow
+The `SparseForwardDataFlowAnalysis` class represents the driver of the dataflow
 analysis, and performs all of the related analysis computation. When defining
 our analysis, we will inherit from this class and implement some of its hooks.
 Before that, let's look at a quick overview of this class and some of the
@@ -190,42 +187,36 @@ important API for our analysis:
 ```c++
 /// This class represents the main driver of the forward dataflow analysis. It
 /// takes as a template parameter the value type of lattice being computed.
-template <typename ValueT>
-class ForwardDataFlowAnalysis : ... {
+template <typename StateT>
+class SparseForwardDataFlowAnalysis : ... {
 public:
-  ForwardDataFlowAnalysis(MLIRContext *context);
-
-  /// Compute the analysis on operations rooted under the given top-level
-  /// operation. Note that the top-level operation is not visited.
-  void run(Operation *topLevelOp);
+  explicit SparseForwardDataFlowAnalysis(DataFlowSolver &solver)
+      : AbstractSparseForwardDataFlowAnalysis(solver) {}
+
+  /// Visit an operation with the lattices of its operands. This function is
+  /// expected to set the lattices of the operation's results.
+  virtual LogicalResult visitOperation(Operation *op,
+                                       ArrayRef<const StateT *> operands,
+                                       ArrayRef<StateT *> results) = 0;
+  ...
 
+protected:
   /// Return the lattice element attached to the given value. If a lattice has
   /// not been added for the given value, a new 'uninitialized' value is
   /// inserted and returned.
-  LatticeElement<ValueT> &getLatticeElement(Value value);
+  StateT *getLatticeElement(Value value);
 
-  /// Return the lattice element attached to the given value, or nullptr if no
-  /// lattice element for the value has yet been created.
-  LatticeElement<ValueT> *lookupLatticeElement(Value value);
+  /// Get the lattice element for a value and create a dependency on the
+  /// provided program point.
+  const StateT *getLatticeElementFor(ProgramPoint *point, Value value);
 
-  /// Mark all of the lattice elements for the given range of Values as having
-  /// reached a pessimistic fixpoint.
-  ChangeResult markAllPessimisticFixPoint(ValueRange values);
-
-protected:
-  /// Visit the given operation, and join any necessary analysis state
-  /// into the lattice elements for the results and block arguments owned by
-  /// this operation using the provided set of operand lattice elements
-  /// (all pointer values are guaranteed to be non-null). Returns if any result
-  /// or block argument value lattice elements changed during the visit. The
-  /// lattice element for a result or block argument value can be obtained, and
-  /// join'ed into, by using `getLatticeElement`.
-  virtual ChangeResult visitOperation(
-      Operation *op, ArrayRef<LatticeElement<ValueT> *> operands) = 0;
+  /// Set the given lattice element(s) at control flow entry point(s).
+  virtual void setToEntryState(StateT *lattice) = 0;
+  ...
 };
 ```
 
-NOTE: Some API has been redacted for our example. The `ForwardDataFlowAnalysis`
+NOTE: Some API has been redacted for our example. The `SparseForwardDataFlowAnalysis`
 contains various other hooks that allow for injecting custom behavior when
 applicable.
 
@@ -237,60 +228,91 @@ function for the operation, that is specific to our analysis. A simple
 implementation for our example is shown below:
 
 ```c++
-class MetadataAnalysis : public ForwardDataFlowAnalysis<MetadataLatticeValue> {
+class MetadataAnalysis
+    : public SparseForwardDataFlowAnalysis<MetadataLatticeValueLattice> {
 public:
-  using ForwardDataFlowAnalysis<MetadataLatticeValue>::ForwardDataFlowAnalysis;
-
-  ChangeResult visitOperation(
-      Operation *op, ArrayRef<LatticeElement<ValueT> *> operands) override {
-    DictionaryAttr metadata = op->getAttrOfType<DictionaryAttr>("metadata");
-
-    // If we have no metadata for this operation, we will conservatively mark
-    // all of the results as having reached a pessimistic fixpoint.
-    if (!metadata)
-      return markAllPessimisticFixPoint(op->getResults());
-
-    // Otherwise, we will compute a lattice value for the metadata and join it
-    // into the current lattice element for all of our results.
-    MetadataLatticeValue latticeValue(metadata);
-    ChangeResult result = ChangeResult::NoChange;
-    for (Value value : op->getResults()) {
-      // We grab the lattice element for `value` via `getLatticeElement` and
-      // then join it with the lattice value for this operation's metadata. Note
-      // that during the analysis phase, it is fine to freely create a new
-      // lattice element for a value. This is why we don't use the
-      // `lookupLatticeElement` method here.
-      result |= getLatticeElement(value).join(latticeValue);
+  using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
+  LogicalResult
+  visitOperation(Operation *op,
+                 ArrayRef<const MetadataLatticeValueLattice *> operands,
+                 ArrayRef<MetadataLatticeValueLattice *> results) override {
+  DictionaryAttr metadata = op->getAttrOfType<DictionaryAttr>("metadata");
+  // If we have no metadata for this operation and it has no operands, we
+  // will conservatively mark all of the results as having reached a pessimistic
+  // fixpoint.
+  if (!metadata && operands.empty()) {
+    setAllToEntryStates(results);
+    return success();
+  }
+
+  MetadataLatticeValue latticeValue;
+  if (metadata)
+    latticeValue = MetadataLatticeValue(metadata);
+
+  // Otherwise, we will compute a lattice value for the metadata and join it
+  // into the current lattice element for all of our results. `results` stores
+  // the lattices corresponding to the results of op; we use a loop to traverse
+  // them.
+  for (int i = 0, e = results.size(); i < e; ++i) {
+
+    // `isChanged` records whether the result has been changed.
+    ChangeResult isChanged = ChangeResult::NoChange;
+
+    // Join the op's metadata into the result's lattice.
+    isChanged |= results[i]->join(latticeValue);
+
+    // Join the lattices of all of the op's operands into the result's lattice.
+    for (int j = 0, m = operands.size(); j < m; ++j) {
+      isChanged |= results[i]->join(*operands[j]);
     }
-    return result;
+    propagateIfChanged(results[i], isChanged);
+  }
+  return success();
   }
 };
 ```
 
 With that, we have all of the necessary components to compute our analysis.
-After the analysis has been computed, we can grab any computed information for
-values by using `lookupLatticeElement`. We use this function over
-`getLatticeElement` as the analysis is not guaranteed to visit all values, e.g.
-if the value is in a unreachable block, and we don't want to create a new
-uninitialized lattice element in this case. See below for a quick example:
+To compute the analysis, we load it into a `DataFlowSolver` and run the solver.
+Once the solver has finished, we can grab any computed information for values
+by using `lookupState`. See below for a quick example: after the pass runs the
+analysis, we print the metadata of each op's results.
 
 ```c++
 void MyPass::runOnOperation() {
-  MetadataAnalysis analysis(&getContext());
-  analysis.run(getOperation());
+  Operation *op = getOperation();
+  DataFlowSolver solver;
+  solver.load<DeadCodeAnalysis>();
+  solver.load<MetadataAnalysis>();
+  if (failed(solver.initializeAndRun(op)))
+    return signalPassFailure();
+
+  // If an op has more than one result, then the lattice is the same for each
+  // result, and we just print one of the results.
+  op->walk([&](Operation *op) {
+    if (op->getNumResults()) {
+      Value result = op->getResult(0);
+      auto lattice = solver.lookupState<MetadataLatticeValueLattice>(result);
+      lattice->print(llvm::outs());
+    }
+  });
   ...
 }
+```
 
-void MyPass::useAnalysisOn(MetadataAnalysis &analysis, Value value) {
-  LatticeElement<MetadataLatticeValue> *latticeElement = analysis.lookupLatticeElement(value);
-
-  // If we don't have an element, the `value` wasn't visited during our analysis
-  // meaning that it could be dead. We need to treat this conservatively.
-  if (!lattice)
-    return;
+The following is a simple example. More tests can be found in `mlir/test/Examples/dataflow`.
 
-  // Our lattice element has a value, use it:
-  MetadataLatticeValue &value = lattice->getValue();
-  ...
+```mlir
+func.func @single_join(%arg0 : index, %arg1 : index) -> index {
+  %1 = arith.addi %arg0, %arg1 {metadata = { likes_pizza = true }} : index
+  %2 = arith.addi %1, %arg1 : index
+  return %2 : index
 }
 ```
+
+Running the pass on the above IR prints the following:
+
+```
+{likes_pizza: true}
+{likes_pizza: true}
+```
\ No newline at end of file
diff --git a/mlir/examples/CMakeLists.txt b/mlir/examples/CMakeLists.txt
index 2a1cac34d8c29..6ea7c20188eb6 100644
--- a/mlir/examples/CMakeLists.txt
+++ b/mlir/examples/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(dataflow)
 add_subdirectory(toy)
 add_subdirectory(transform)
 add_subdirectory(transform-opt)
diff --git a/mlir/examples/dataflow/CMakeLists.txt b/mlir/examples/dataflow/CMakeLists.txt
new file mode 100644
index 0000000000000..2393d0abbe9c8
--- /dev/null
+++ b/mlir/examples/dataflow/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_custom_target(DataFlowExample)
+set_target_properties(DataFlowExample PROPERTIES FOLDER "MLIR/Examples")
+
+include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
+
+add_subdirectory(lib)
+add_subdirectory(dataflow-opt)
diff --git a/mlir/examples/dataflow/dataflow-opt/CMakeLists.txt b/mlir/examples/dataflow/dataflow-opt/CMakeLists.txt
new file mode 100644
index 0000000000000..54499e8a42041
--- /dev/null
+++ b/mlir/examples/dataflow/dataflow-opt/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_dependencies(DataFlowExample dataflow-opt)
+add_llvm_example(dataflow-opt
+  dataflow-opt.cpp
+)
+
+target_link_libraries(dataflow-opt
+  PRIVATE
+  MLIRIR
+  MLIRMlirOptMain
+  MLIRTestMetadataAnalysisPass
+)
diff --git a/mlir/examples/dataflow/dataflow-opt/dataflow-opt.cpp b/mlir/examples/dataflow/dataflow-opt/dataflow-opt.cpp
new file mode 100644
index 0000000000000..5fd9464563646
--- /dev/null
+++ b/mlir/examples/dataflow/dataflow-opt/dataflow-opt.cpp
@@ -0,0 +1,36 @@
+//===-- dataflow-opt.cpp - dataflow tutorial entry point ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the top-level file for the dataflow tutorial.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/InitAllExtensions.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace test {
+void registerTestMetadataAnalysisPass();
+};
+} // namespace mlir
+
+int main(int argc, char *argv[]) {
+  // Register all MLIR core dialects.
+  mlir::DialectRegistry registry;
+  registerAllDialects(registry);
+  registerAllExtensions(registry);
+
+  // Register test-metadata-analysis pass.
+  mlir::test::registerTestMetadataAnalysisPass();
+  return mlir::failed(
+      mlir::MlirOptMain(argc, argv, "dataflow-opt optimizer driver", registry));
+}
diff --git a/mlir/examples/dataflow/include/MetadataAnalysis.h b/mlir/examples/dataflow/include/MetadataAnalysis.h
new file mode 100644
index 0000000000000..f5e8f566edea3
--- /dev/null
+++ b/mlir/examples/dataflow/include/MetadataAnalysis.h
@@ -0,0 +1,65 @@
+//===-- MetadataAnalysis.h - dataflow tutorial ------------------*- c++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the dataflow tutorial's classes related to metadata.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace llvm;
+
+namespace mlir {
+/// The value of our lattice represents the inner structure of a DictionaryAttr,
+/// for the `metadata`.
+struct MetadataLatticeValue {
+  MetadataLatticeValue() = default;
+  /// Compute a lattice value from the provided dictionary.
+  MetadataLatticeValue(DictionaryAttr attr) {
+    for (NamedAttribute pair : attr) {
+      metadata.insert(
+          std::pair<StringAttr, Attribute>(pair.getName(), pair.getValue()));
+    }
+  }
+
+  static MetadataLatticeValue join(const MetadataLatticeValue &lhs,
+                                   const MetadataLatticeValue &rhs);
+
+  /// A simple comparator that checks to see if this value is equal to the one
+  /// provided.
+  bool operator==(const MetadataLatticeValue &rhs) const;
+
+  /// Print data in metadata.
+  void print(llvm::raw_ostream &os) const;
+
+  /// Our value represents the combined metadata, which is originally a
+  /// DictionaryAttr, so we use a map.
+  llvm::StringMap<Attribute> metadata;
+};
+
+namespace dataflow {
+class MetadataLatticeValueLattice : public Lattice<MetadataLatticeValue> {
+public:
+  using Lattice::Lattice;
+};
+
+class MetadataAnalysis
+    : public SparseForwardDataFlowAnalysis<MetadataLatticeValueLattice> {
+public:
+  using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
+  LogicalResult
+  visitOperation(Operation *op,
+                 ArrayRef<const MetadataLatticeValueLattice *> operands,
+                 ArrayRef<MetadataLatticeValueLattice *> results) override;
+  void setToEntryState(MetadataLatticeValueLattice *lattice) override;
+};
+
+} // namespace dataflow
+} // namespace mlir
diff --git a/mlir/examples/dataflow/lib/Analysis/CMakeLists.txt b/mlir/examples/dataflow/lib/Analysis/CMakeLists.txt
new file mode 100644
index 0000000000000..c9e07f520e0be
--- /dev/null
+++ b/mlir/examples/dataflow/lib/Analysis/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_example_library(MLIRMetadataAnalysis
+  ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/149296


More information about the Mlir-commits mailing list