[Mlir-commits] [mlir] d80c271 - [mlir] An implementation of dense data-flow analysis

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 7 15:12:52 PDT 2022


Author: Mogball
Date: 2022-07-07T15:12:46-07:00
New Revision: d80c271c8ac0703b5fe9ba40e55121d6dd25b389

URL: https://github.com/llvm/llvm-project/commit/d80c271c8ac0703b5fe9ba40e55121d6dd25b389
DIFF: https://github.com/llvm/llvm-project/commit/d80c271c8ac0703b5fe9ba40e55121d6dd25b389.diff

LOG: [mlir] An implementation of dense data-flow analysis

This patch introduces an implementation of dense data-flow analysis. Dense
data-flow analysis attaches a lattice before and after the execution of every
operation. The lattice state is propagated across operations by a user-defined
transfer function. The state is joined across control-flow and callgraph edges.

The patch provides an example pass that uses both a dense and a sparse analysis
together.

Depends on D127139

Reviewed By: rriddle, phisiart

Differential Revision: https://reviews.llvm.org/D127173

Added: 
    mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
    mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
    mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir
    mlir/test/Analysis/DataFlow/test-last-modified.mlir
    mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp

Modified: 
    mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
    mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
    mlir/lib/Analysis/CMakeLists.txt
    mlir/test/lib/Analysis/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
index b07994755c03..145170f74e54 100644
--- a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
@@ -43,6 +43,12 @@ class ConstantValue {
     return constant == rhs.constant;
   }
 
+  /// Print the constant value.
+  void print(raw_ostream &os) const;
+
+  /// The pessimistic value state of the constant value is unknown.
+  static ConstantValue getPessimisticValueState(Value value) { return {}; }
+
   /// The union with another constant value is null if they are different, and
   /// the same if they are the same.
   static ConstantValue join(const ConstantValue &lhs,
@@ -50,9 +56,6 @@ class ConstantValue {
     return lhs == rhs ? lhs : ConstantValue();
   }
 
-  /// Print the constant value.
-  void print(raw_ostream &os) const;
-
 private:
   /// The constant value.
   Attribute constant;

diff  --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
new file mode 100644
index 000000000000..43432c758824
--- /dev/null
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -0,0 +1,167 @@
+//===- DenseAnalysis.h - Dense data-flow analysis -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements dense data-flow analysis using the data-flow analysis
+// framework. The analysis is forward and conditional and uses the results of
+// dead code analysis to prune dead code during the analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DENSEDATAFLOWANALYSIS_H
+#define MLIR_ANALYSIS_DENSEDATAFLOWANALYSIS_H
+
+#include "mlir/Analysis/DataFlowFramework.h"
+
+namespace mlir {
+
+class RegionBranchOpInterface;
+
+namespace dataflow {
+
+//===----------------------------------------------------------------------===//
+// AbstractDenseLattice
+//===----------------------------------------------------------------------===//
+
+/// This class represents a dense lattice. A dense lattice is attached to
+/// operations to represent the program state after their execution or to blocks
+/// to represent the program state at the beginning of the block. A dense
+/// lattice is propagated through the IR by dense data-flow analysis.
+class AbstractDenseLattice : public AnalysisState {
+public:
+  /// A dense lattice can only be created for operations and blocks.
+  using AnalysisState::AnalysisState;
+
+  /// Join `rhs` into this lattice (across control-flow or callgraph edges).
+  virtual ChangeResult join(const AbstractDenseLattice &rhs) = 0;
+
+  /// Reset the dense lattice to its pessimistic value. This occurs when the
+  /// analysis cannot reason about the data-flow.
+  virtual ChangeResult reset() = 0;
+
+  /// Returns true if the lattice state has reached a pessimistic fixpoint. That
+  /// is, no further modifications to the lattice can occur.
+  virtual bool isAtFixpoint() const = 0;
+};
+
+//===----------------------------------------------------------------------===//
+// AbstractDenseDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Base class for dense data-flow analyses. Dense data-flow analysis attaches a
+/// lattice between the execution of operations and implements a transfer
+/// function from the lattice before each operation to the lattice after. The
+/// lattice contains information about the state of the program at that point.
+///
+/// In this implementation, a lattice attached to an operation represents the
+/// state of the program after its execution, and a lattice attached to a block
+/// represents the state of the program right before it starts executing its
+/// body.
+class AbstractDenseDataFlowAnalysis : public DataFlowAnalysis {
+public:
+  using DataFlowAnalysis::DataFlowAnalysis;
+
+  /// Initialize the analysis by visiting every program point whose execution
+  /// may modify the program state; that is, every operation and block.
+  LogicalResult initialize(Operation *top) override;
+
+  /// Visit a program point that modifies the state of the program. If this is a
+  /// block, then the state is propagated from control-flow predecessors or
+  /// callsites. If this is a call operation or region control-flow operation,
+  /// then the state after the execution of the operation is set by control-flow
+  /// or the callgraph. Otherwise, this function invokes the operation transfer
+  /// function.
+  LogicalResult visit(ProgramPoint point) override;
+
+protected:
+  /// Propagate the dense lattice before the execution of an operation to the
+  /// lattice after its execution.
+  virtual void visitOperationImpl(Operation *op,
+                                  const AbstractDenseLattice &before,
+                                  AbstractDenseLattice *after) = 0;
+
+  /// Get (or create) the dense lattice after the given program point.
+  virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
+
+  /// Get the dense lattice after the given program point and register
+  /// `dependent` as a dependency of that lattice.
+  const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
+                                            ProgramPoint point);
+
+  /// Mark the dense lattice as having reached its pessimistic fixpoint and
+  /// propagate an update if it changed.
+  void reset(AbstractDenseLattice *lattice) {
+    propagateIfChanged(lattice, lattice->reset());
+  }
+
+  /// Join a lattice with another and propagate an update if it changed.
+  void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs) {
+    propagateIfChanged(lhs, lhs->join(rhs));
+  }
+
+private:
+  /// Visit an operation. If this is a call operation or region control-flow
+  /// operation, then the state after the execution of the operation is set by
+  /// control-flow or the callgraph. Otherwise, this function invokes the
+  /// operation transfer function.
+  void visitOperation(Operation *op);
+
+  /// Visit a block. The state at the start of the block is propagated from
+  /// control-flow predecessors or callsites.
+  void visitBlock(Block *block);
+
+  /// Visit a program point within a region branch operation with predecessors
+  /// in it. This can either be an entry block of one of the regions or the
+  /// parent operation itself.
+  void visitRegionBranchOperation(ProgramPoint point,
+                                  RegionBranchOpInterface branch,
+                                  AbstractDenseLattice *after);
+};
+
+//===----------------------------------------------------------------------===//
+// DenseDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// A dense (forward) data-flow analysis for propagating lattices before and
+/// after the execution of every operation across the IR by implementing
+/// transfer functions for operations.
+///
+/// `LatticeT` is expected to be a subclass of `AbstractDenseLattice`.
+template <typename LatticeT>
+class DenseDataFlowAnalysis : public AbstractDenseDataFlowAnalysis {
+  static_assert(
+      std::is_base_of<AbstractDenseLattice, LatticeT>::value,
+      "analysis state class expected to subclass AbstractDenseLattice");
+
+public:
+  using AbstractDenseDataFlowAnalysis::AbstractDenseDataFlowAnalysis;
+
+  /// Visit an operation with the dense lattice before its execution. This
+  /// function is expected to set the dense lattice after its execution.
+  virtual void visitOperation(Operation *op, const LatticeT &before,
+                              LatticeT *after) = 0;
+
+protected:
+  /// Get (or create) the typed dense lattice after this program point.
+  LatticeT *getLattice(ProgramPoint point) override {
+    return getOrCreate<LatticeT>(point);
+  }
+
+private:
+  /// Type-erased wrappers that convert the abstract dense lattice to a derived
+  /// lattice and invoke the virtual hooks operating on the derived lattice.
+  void visitOperationImpl(Operation *op, const AbstractDenseLattice &before,
+                          AbstractDenseLattice *after) override {
+    visitOperation(op, static_cast<const LatticeT &>(before),
+                   static_cast<LatticeT *>(after));
+  }
+};
+
+} // end namespace dataflow
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_DENSEDATAFLOWANALYSIS_H

diff  --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 5907da0ef8da..6456f7d6cec2 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -16,10 +16,12 @@
 #define MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
 
 #include "mlir/Analysis/DataFlowFramework.h"
-#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "llvm/ADT/SmallPtrSet.h"
 
 namespace mlir {
+
+class RegionBranchOpInterface;
+
 namespace dataflow {
 
 //===----------------------------------------------------------------------===//
@@ -80,11 +82,10 @@ class AbstractSparseLattice : public AnalysisState {
 template <typename ValueT>
 class Lattice : public AbstractSparseLattice {
 public:
-  using AbstractSparseLattice::AbstractSparseLattice;
-
-  /// Get a lattice element with a known value.
-  Lattice(const ValueT &knownValue = ValueT())
-      : AbstractSparseLattice(Value()), knownValue(knownValue) {}
+  /// Construct a lattice with a known value.
+  explicit Lattice(Value value)
+      : AbstractSparseLattice(value),
+        knownValue(ValueT::getPessimisticValueState(value)) {}
 
   /// Return the value held by this lattice. This requires that the value is
   /// initialized.

diff  --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 2e37135adac4..efac97d665e7 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -12,6 +12,7 @@ set(LLVM_OPTIONAL_SOURCES
 
   DataFlow/ConstantPropagationAnalysis.cpp
   DataFlow/DeadCodeAnalysis.cpp
+  DataFlow/DenseAnalysis.cpp
   DataFlow/SparseAnalysis.cpp
   )
 
@@ -30,6 +31,7 @@ add_mlir_library(MLIRAnalysis
 
   DataFlow/ConstantPropagationAnalysis.cpp
   DataFlow/DeadCodeAnalysis.cpp
+  DataFlow/DenseAnalysis.cpp
   DataFlow/SparseAnalysis.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
new file mode 100644
index 000000000000..4c0a16684ced
--- /dev/null
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -0,0 +1,172 @@
+//===- DenseAnalysis.cpp - Dense data-flow analysis -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+//===----------------------------------------------------------------------===//
+// AbstractDenseDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+LogicalResult AbstractDenseDataFlowAnalysis::initialize(Operation *top) {
+  // Seed the analysis: visit `top`, then each block and, recursively, its ops.
+  visitOperation(top);
+  for (Region &r : top->getRegions())
+    for (Block &b : r) {
+      visitBlock(&b);
+      for (Operation &nested : b)
+        if (failed(initialize(&nested)))
+          return failure();
+    }
+
+  return success();
+}
+
+LogicalResult AbstractDenseDataFlowAnalysis::visit(ProgramPoint point) {
+  if (Block *block = point.dyn_cast<Block *>())
+    visitBlock(block);
+  else if (Operation *op = point.dyn_cast<Operation *>())
+    visitOperation(op);
+  else
+    return failure();
+  return success();
+}
+
+void AbstractDenseDataFlowAnalysis::visitOperation(Operation *op) {
+  // Bail out if the op has no parent block or its block is not executable.
+  Block *block = op->getBlock();
+  if (!block || !getOrCreateFor<Executable>(op, block)->isLive())
+    return;
+
+  // Get the dense lattice to update.
+  AbstractDenseLattice *after = getLattice(op);
+  if (after->isAtFixpoint())
+    return;
+
+  // If this op implements region control-flow, then control-flow dictates its
+  // transfer function.
+  if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
+    return visitRegionBranchOperation(op, branch, after);
+
+  // If this is a call operation, then join its lattices across known return
+  // sites.
+  if (auto call = dyn_cast<CallOpInterface>(op)) {
+    const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
+    // If not all return sites are known, then conservatively assume we can't
+    // reason about the data-flow.
+    if (!predecessors->allPredecessorsKnown())
+      return reset(after);
+    for (Operation *predecessor : predecessors->getKnownPredecessors())
+      join(after, *getLatticeFor(op, predecessor));
+    return;
+  }
+
+  // Get the dense state before the op: the state after the previous op, or the
+  // state at the start of the block if this is the block's first op.
+  Operation *prev = op->getPrevNode();
+  const AbstractDenseLattice *before =
+      prev ? getLatticeFor(op, prev) : getLatticeFor(op, block);
+  // If the incoming lattice is uninitialized, bail out.
+  if (before->isUninitialized())
+    return;
+
+  // Invoke the operation transfer function.
+  visitOperationImpl(op, *before, after);
+}
+
+void AbstractDenseDataFlowAnalysis::visitBlock(Block *block) {
+  // If the block is not executable, bail out.
+  if (!getOrCreateFor<Executable>(block, block)->isLive())
+    return;
+
+  // Get the dense lattice to update.
+  AbstractDenseLattice *after = getLattice(block);
+  if (after->isAtFixpoint())
+    return;
+
+  // The dense lattices of entry blocks are set by region control-flow or the
+  // callgraph.
+  if (block->isEntryBlock()) {
+    // Check if this block is the entry block of a callable region.
+    auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
+    if (callable && callable.getCallableRegion() == block->getParent()) {
+      const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
+      // If not all callsites are known, conservatively reset this block's
+      // lattice to its pessimistic state.
+      if (!callsites->allPredecessorsKnown())
+        return reset(after);
+      for (Operation *callsite : callsites->getKnownPredecessors()) {
+        // Join the dense lattice right before the callsite.
+        if (Operation *prev = callsite->getPrevNode())
+          join(after, *getLatticeFor(block, prev));
+        else
+          join(after, *getLatticeFor(block, callsite->getBlock()));
+      }
+      return;
+    }
+
+    // Entry blocks of region branch ops take their state from region flow.
+    if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp()))
+      return visitRegionBranchOperation(block, branch, after);
+
+    // Otherwise, we can't reason about the data-flow.
+    return reset(after);
+  }
+
+  // Join in the state after the terminator of each live predecessor block.
+  for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
+       it != e; ++it) {
+    // Skip control edges that aren't executable.
+    Block *predecessor = *it;
+    if (!getOrCreateFor<Executable>(
+             block, getProgramPoint<CFGEdge>(predecessor, block))
+             ->isLive())
+      continue;
+
+    // Merge in the state from the predecessor's terminator.
+    join(after, *getLatticeFor(block, predecessor->getTerminator()));
+  }
+}
+
+void AbstractDenseDataFlowAnalysis::visitRegionBranchOperation(
+    ProgramPoint point, RegionBranchOpInterface branch,
+    AbstractDenseLattice *after) {
+  // Get the predecessors of `point`: the parent op and/or region terminators.
+  const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
+  assert(predecessors->allPredecessorsKnown() &&
+         "unexpected unresolved region successors");
+
+  for (Operation *op : predecessors->getKnownPredecessors()) {
+    const AbstractDenseLattice *before;
+    // If the predecessor is the parent, get the state before the parent.
+    if (op == branch) {
+      if (Operation *prev = op->getPrevNode())
+        before = getLatticeFor(point, prev);
+      else
+        before = getLatticeFor(point, op->getBlock());
+
+      // Otherwise, the predecessor is a terminator: use the state after it.
+    } else {
+      before = getLatticeFor(point, op);
+    }
+    join(after, *before);
+  }
+}
+
+const AbstractDenseLattice *
+AbstractDenseDataFlowAnalysis::getLatticeFor(ProgramPoint dependent,
+                                             ProgramPoint point) {
+  AbstractDenseLattice *lattice = getLattice(point);
+  addDependency(lattice, dependent);
+  return lattice;
+}

diff  --git a/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir b/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir
new file mode 100644
index 000000000000..c1fdf82e4bc7
--- /dev/null
+++ b/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt -test-last-modified %s 2>&1 | FileCheck %s
+
+// CHECK-LABEL: test_tag: test_callsite
+// CHECK: operand #0
+// CHECK-NEXT: - a
+func.func private @single_callsite_fn(%ptr: memref<i32>) -> memref<i32> {
+  return {tag = "test_callsite"} %ptr : memref<i32>
+}
+
+func.func @test_callsite() {
+  %ptr = memref.alloc() : memref<i32>
+  %c0 = arith.constant 0 : i32
+  memref.store %c0, %ptr[] {tag_name = "a"} : memref<i32>
+  %0 = func.call @single_callsite_fn(%ptr) : (memref<i32>) -> memref<i32>
+  return
+}
+
+// CHECK-LABEL: test_tag: test_return_site
+// CHECK: operand #0
+// CHECK-NEXT: - b
+func.func private @single_return_site_fn(%ptr: memref<i32>) -> memref<i32> {
+  %c0 = arith.constant 0 : i32
+  memref.store %c0, %ptr[] {tag_name = "b"} : memref<i32>
+  return %ptr : memref<i32>
+}
+
+// CHECK-LABEL: test_tag: test_multiple_callsites
+// CHECK: operand #0
+// CHECK-NEXT: write0
+// CHECK-NEXT: write1
+func.func @test_return_site(%ptr: memref<i32>) -> memref<i32> {
+  %0 = func.call @single_return_site_fn(%ptr) : (memref<i32>) -> memref<i32>
+  return {tag = "test_return_site"} %0 : memref<i32>
+}
+
+func.func private @multiple_callsite_fn(%ptr: memref<i32>) -> memref<i32> {
+  return {tag = "test_multiple_callsites"} %ptr : memref<i32>
+}
+
+func.func @test_multiple_callsites(%a: i32, %ptr: memref<i32>) -> memref<i32> {
+  memref.store %a, %ptr[] {tag_name = "write0"} : memref<i32>
+  %0 = func.call @multiple_callsite_fn(%ptr) : (memref<i32>) -> memref<i32>
+  memref.store %a, %ptr[] {tag_name = "write1"} : memref<i32>
+  %1 = func.call @multiple_callsite_fn(%ptr) : (memref<i32>) -> memref<i32>
+  return %ptr : memref<i32>
+}
+
+// CHECK-LABEL: test_tag: test_multiple_return_sites
+// CHECK: operand #0
+// CHECK-NEXT: return0
+// CHECK-NEXT: return1
+func.func private @multiple_return_site_fn(%cond: i1, %a: i32, %ptr: memref<i32>) -> memref<i32> {
+  cf.cond_br %cond, ^a, ^b
+
+^a:
+  memref.store %a, %ptr[] {tag_name = "return0"} : memref<i32>
+  return %ptr : memref<i32>
+
+^b:
+  memref.store %a, %ptr[] {tag_name = "return1"} : memref<i32>
+  return %ptr : memref<i32>
+}
+
+func.func @test_multiple_return_sites(%cond: i1, %a: i32, %ptr: memref<i32>) -> memref<i32> {
+  %0 = func.call @multiple_return_site_fn(%cond, %a, %ptr) : (i1, i32, memref<i32>) -> memref<i32>
+  return {tag = "test_multiple_return_sites"} %0 : memref<i32>
+}
\ No newline at end of file

diff  --git a/mlir/test/Analysis/DataFlow/test-last-modified.mlir b/mlir/test/Analysis/DataFlow/test-last-modified.mlir
new file mode 100644
index 000000000000..69fb7125f0c5
--- /dev/null
+++ b/mlir/test/Analysis/DataFlow/test-last-modified.mlir
@@ -0,0 +1,115 @@
+// RUN: mlir-opt -test-last-modified %s 2>&1 | FileCheck %s
+
+// CHECK-LABEL: test_tag: test_simple_mod
+// CHECK: operand #0
+// CHECK-NEXT: - a
+// CHECK: operand #1
+// CHECK-NEXT: - b
+func.func @test_simple_mod(%arg0: memref<i32>, %arg1: memref<i32>) -> (memref<i32>, memref<i32>) {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  memref.store %c0, %arg0[] {tag_name = "a"} : memref<i32>
+  memref.store %c1, %arg1[] {tag_name = "b"} : memref<i32>
+  return {tag = "test_simple_mod"} %arg0, %arg1 : memref<i32>, memref<i32>
+}
+
+// CHECK-LABEL: test_tag: test_simple_mod_overwrite_a
+// CHECK: operand #1
+// CHECK-NEXT: - a
+// CHECK-LABEL: test_tag: test_simple_mod_overwrite_b
+// CHECK: operand #0
+// CHECK-NEXT: - b
+func.func @test_simple_mod_overwrite(%arg0: memref<i32>) -> memref<i32> {
+  %c0 = arith.constant 0 : i32
+  memref.store %c0, %arg0[] {tag = "test_simple_mod_overwrite_a", tag_name = "a"} : memref<i32>
+  %c1 = arith.constant 1 : i32
+  memref.store %c1, %arg0[] {tag_name = "b"} : memref<i32>
+  return {tag = "test_simple_mod_overwrite_b"} %arg0 : memref<i32>
+}
+
+// CHECK-LABEL: test_tag: test_mod_control_flow
+// CHECK: operand #0
+// CHECK-NEXT: - b
+// CHECK-NEXT: - a
+func.func @test_mod_control_flow(%cond: i1, %ptr: memref<i32>) -> memref<i32> {
+  cf.cond_br %cond, ^a, ^b
+
+^a:
+  %c0 = arith.constant 0 : i32
+  memref.store %c0, %ptr[] {tag_name = "a"} : memref<i32>
+  cf.br ^c
+
+^b:
+  %c1 = arith.constant 1 : i32
+  memref.store %c1, %ptr[] {tag_name = "b"} : memref<i32>
+  cf.br ^c
+
+^c:
+  return {tag = "test_mod_control_flow"} %ptr : memref<i32>
+}
+
+// CHECK-LABEL: test_tag: test_mod_dead_branch
+// CHECK: operand #0
+// CHECK-NEXT: - a
+func.func @test_mod_dead_branch(%arg: i32, %ptr: memref<i32>) -> memref<i32> {
+  %0 = arith.subi %arg, %arg : i32
+  %1 = arith.constant -1 : i32
+  %2 = arith.cmpi sgt, %0, %1 : i32
+  cf.cond_br %2, ^a, ^b
+
+^a:
+  %c0 = arith.constant 0 : i32
+  memref.store %c0, %ptr[] {tag_name = "a"} : memref<i32>
+  cf.br ^c
+
+^b:
+  %c1 = arith.constant 1 : i32
+  memref.store %c1, %ptr[] {tag_name = "b"} : memref<i32>
+  cf.br ^c
+
+^c:
+  return {tag = "test_mod_dead_branch"} %ptr : memref<i32>
+}
+
+// CHECK-LABEL: test_tag: test_mod_region_control_flow
+// CHECK: operand #0
+// CHECK-NEXT: then
+// CHECK-NEXT: else
+func.func @test_mod_region_control_flow(%cond: i1, %ptr: memref<i32>) -> memref<i32> {
+  scf.if %cond {
+    %c0 = arith.constant 0 : i32
+    memref.store %c0, %ptr[] {tag_name = "then"}: memref<i32>
+  } else {
+    %c1 = arith.constant 1 : i32
+    memref.store %c1, %ptr[] {tag_name = "else"} : memref<i32>
+  }
+  return {tag = "test_mod_region_control_flow"} %ptr : memref<i32>
+}
+
+// CHECK-LABEL: test_tag: test_mod_dead_region
+// CHECK: operand #0
+// CHECK-NEXT: else
+func.func @test_mod_dead_region(%ptr: memref<i32>) -> memref<i32> {
+  %false = arith.constant false
+  scf.if %false {
+    %c0 = arith.constant 0 : i32
+    memref.store %c0, %ptr[] {tag_name = "then"}: memref<i32>
+  } else {
+    %c1 = arith.constant 1 : i32
+    memref.store %c1, %ptr[] {tag_name = "else"} : memref<i32>
+  }
+  return {tag = "test_mod_dead_region"} %ptr : memref<i32>
+}
+
+// CHECK-LABEL: test_tag: unknown_memory_effects_a
+// CHECK: operand #1
+// CHECK-NEXT: - a
+// CHECK-LABEL: test_tag: unknown_memory_effects_b
+// CHECK: operand #0
+// CHECK-NEXT: - <unknown>
+func.func @unknown_memory_effects(%ptr: memref<i32>) -> memref<i32> {
+  %c0 = arith.constant 0 : i32
+  memref.store %c0, %ptr[] {tag = "unknown_memory_effects_a", tag_name = "a"} : memref<i32>
+  "test.unknown_effects"() : () -> ()
+  return {tag = "unknown_memory_effects_b"} %ptr : memref<i32>
+}

diff  --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index 8e9446c78bfc..80572f9501fe 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_library(MLIRTestAnalysis
   TestSlice.cpp
 
   DataFlow/TestDeadCodeAnalysis.cpp
+  DataFlow/TestDenseDataFlowAnalysis.cpp
 
   EXCLUDE_FROM_LIBMLIR
 

diff  --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
new file mode 100644
index 000000000000..669e683aca1d
--- /dev/null
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
@@ -0,0 +1,278 @@
+//===- TestDenseDataFlowAnalysis.cpp - Test dense data-flow analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+namespace {
+/// This lattice represents a single underlying value for an SSA value.
+class UnderlyingValue {
+public:
+  /// The pessimistic underlying value of a value is itself.
+  static UnderlyingValue getPessimisticValueState(Value value) {
+    return {value};
+  }
+
+  /// Create an underlying value state with a known underlying value.
+  UnderlyingValue(Value underlyingValue = {})
+      : underlyingValue(underlyingValue) {}
+
+  /// Returns the underlying value.
+  Value getUnderlyingValue() const { return underlyingValue; }
+
+  /// Join two underlying values. If they conflict, the result has a null
+  /// underlying value, which callers must treat as unknown.
+  static UnderlyingValue join(const UnderlyingValue &lhs,
+                              const UnderlyingValue &rhs) {
+    return lhs.underlyingValue == rhs.underlyingValue ? lhs : UnderlyingValue();
+  }
+
+  /// Compare underlying values.
+  bool operator==(const UnderlyingValue &rhs) const {
+    return underlyingValue == rhs.underlyingValue;
+  }
+
+  void print(raw_ostream &os) const { os << underlyingValue; }
+
+private:
+  Value underlyingValue;
+};
+
+/// This lattice represents, for a given memory resource, the potential last
+/// operations that modified the resource.
+class LastModification : public AbstractDenseLattice {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification)
+
+  using AbstractDenseLattice::AbstractDenseLattice;
+
+  /// The lattice is always initialized.
+  bool isUninitialized() const override { return false; }
+
+  /// Initialize the lattice. Does nothing.
+  ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
+
+  /// Mark the lattice as having reached its pessimistic fixpoint. That is, the
+  /// last modifications of all memory resources are unknown.
+  ChangeResult reset() override {
+    if (lastMods.empty())
+      return ChangeResult::NoChange;
+    lastMods.clear();
+    return ChangeResult::Change;
+  }
+
+  /// The lattice is never at a fixpoint.
+  bool isAtFixpoint() const override { return false; }
+
+  /// Join the last modifications. Reports a change only if a new modifier was
+  /// added; if this lattice already contains `rhs`, nothing changes.
+  ChangeResult join(const AbstractDenseLattice &lattice) override {
+    const auto &rhs = static_cast<const LastModification &>(lattice);
+    ChangeResult result = ChangeResult::NoChange;
+    for (const auto &mod : rhs.lastMods) {
+      auto &lhsMod = lastMods[mod.first];
+      for (Operation *op : mod.second)
+        if (lhsMod.insert(op))
+          result |= ChangeResult::Change;
+    }
+    return result;
+  }
+
+  /// Set the last modification of a value.
+  ChangeResult set(Value value, Operation *op) {
+    auto &lastMod = lastMods[value];
+    ChangeResult result = ChangeResult::NoChange;
+    if (lastMod.size() != 1 || *lastMod.begin() != op) {
+      result = ChangeResult::Change;
+      lastMod.clear();
+      lastMod.insert(op);
+    }
+    return result;
+  }
+
+  /// Get the last modifications of a value. Returns none if the last
+  /// modifications are not known.
+  Optional<ArrayRef<Operation *>> getLastModifiers(Value value) const {
+    auto it = lastMods.find(value);
+    if (it == lastMods.end())
+      return {};
+    return it->second.getArrayRef();
+  }
+
+  void print(raw_ostream &os) const override {
+    for (const auto &lastMod : lastMods) {
+      os << lastMod.first << ":\n";
+      for (Operation *op : lastMod.second)
+        os << "  " << *op << "\n";
+    }
+  }
+
+private:
+  /// The potential last modifications of a memory resource. Use a set vector to
+  /// keep the results deterministic.
+  DenseMap<Value, SetVector<Operation *, SmallVector<Operation *, 2>,
+                            SmallPtrSet<Operation *, 2>>>
+      lastMods;
+};
+
+class LastModifiedAnalysis : public DenseDataFlowAnalysis<LastModification> {
+public:
+  using DenseDataFlowAnalysis::DenseDataFlowAnalysis;
+
+  /// Visit an operation. If the operation has no memory effects, then the state
+  /// is propagated with no change. If the operation allocates a resource, then
+  /// its last modifications are set to empty. If the operation writes to a
+  /// resource, then its last modification is set to the operation itself.
+  void visitOperation(Operation *op, const LastModification &before,
+                      LastModification *after) override;
+};
+
+/// Sparse lattice of `UnderlyingValue`, named explicitly to give it a TypeID.
+struct UnderlyingValueLattice : public Lattice<UnderlyingValue> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice)
+  using Lattice::Lattice;
+};
+
+/// Sparse analysis that forwards values along control-flow and callgraph
+/// edges to find a single underlying value for each block argument. It exists
+/// only so that the test pass can exercise the dense data-flow analysis across
+/// the callgraph.
+class UnderlyingValueAnalysis
+    : public SparseDataFlowAnalysis<UnderlyingValueLattice> {
+public:
+  using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
+
+  /// Operation results have no known underlying value other than themselves.
+  void visitOperation(Operation *op,
+                      ArrayRef<const UnderlyingValueLattice *> /*operands*/,
+                      ArrayRef<UnderlyingValueLattice *> outputs) override {
+    markAllPessimisticFixpoint(outputs);
+  }
+};
+} // end anonymous namespace
+
+/// Follow the chain of underlying values starting at `value` until reaching a
+/// value that is its own underlying value, and return that value. Returns a
+/// null Value if `getUnderlyingValueFn` yields no lattice, or an uninitialized
+/// one, for any value along the chain.
+static Value getMostUnderlyingValue(
+    Value value,
+    function_ref<const UnderlyingValueLattice *(Value)> getUnderlyingValueFn) {
+  for (;;) {
+    const UnderlyingValueLattice *lattice = getUnderlyingValueFn(value);
+    if (!lattice || lattice->isUninitialized())
+      return {};
+
+    Value next = lattice->getValue().getUnderlyingValue();
+    // A value that maps to itself is the end of the chain.
+    if (next == value)
+      return value;
+    value = next;
+  }
+}
+
+/// Transfer function of the last-modified analysis.
+///
+/// Operations without a MemoryEffectOpInterface, or with a memory effect that
+/// is not attached to a value, conservatively reset `after`. Otherwise every
+/// effect is first mapped to the most underlying value of the resource it
+/// affects. Only once all of them are known is `before` joined into `after`
+/// and each non-read effect recorded with `op` as the last modifier; the
+/// combined change is then propagated.
+void LastModifiedAnalysis::visitOperation(Operation *op,
+                                          const LastModification &before,
+                                          LastModification *after) {
+  auto memory = dyn_cast<MemoryEffectOpInterface>(op);
+  // If we can't reason about the memory effects, then conservatively assume we
+  // can't deduce anything about the last modifications.
+  if (!memory)
+    return reset(after);
+
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  memory.getEffects(effects);
+
+  // Resolve the underlying value of every effect *before* touching `after`.
+  // Mutating `after` first would let the early exit below leave it modified
+  // without a matching propagateIfChanged call; a later revisit could then
+  // observe no change and never notify the dependents of `after`.
+  SmallVector<Value> underlyingValues;
+  underlyingValues.reserve(effects.size());
+  for (const auto &effect : effects) {
+    Value value = effect.getValue();
+
+    // If we see an effect on anything other than a value, assume we can't
+    // deduce anything about the last modifications.
+    if (!value)
+      return reset(after);
+
+    value = getMostUnderlyingValue(value, [&](Value candidate) {
+      return getOrCreateFor<UnderlyingValueLattice>(op, candidate);
+    });
+    // The underlying value is not known yet: leave `after` untouched. The
+    // getOrCreateFor lookups above register `op` as depending on those
+    // lattices, so the operation is expected to be revisited once they change.
+    if (!value)
+      return;
+
+    underlyingValues.push_back(value);
+  }
+
+  ChangeResult result = after->join(before);
+  for (unsigned i = 0, e = effects.size(); i != e; ++i) {
+    // Nothing to do for reads.
+    if (isa<MemoryEffects::Read>(effects[i].getEffect()))
+      continue;
+
+    result |= after->set(underlyingValues[i], op);
+  }
+  propagateIfChanged(after, result);
+}
+
+namespace {
+/// Test pass that runs the last-modified analysis, together with the dead
+/// code, sparse constant propagation and underlying-value analyses, and then
+/// prints, for every operation carrying a `tag` string attribute, the
+/// potential last modifiers of each operand's underlying value.
+struct TestLastModifiedPass
+    : public PassWrapper<TestLastModifiedPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass)
+
+  StringRef getArgument() const override { return "test-last-modified"; }
+
+  void runOnOperation() override {
+    Operation *root = getOperation();
+
+    DataFlowSolver solver;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<SparseConstantPropagation>();
+    solver.load<LastModifiedAnalysis>();
+    solver.load<UnderlyingValueAnalysis>();
+    if (failed(solver.initializeAndRun(root)))
+      return signalPassFailure();
+
+    raw_ostream &os = llvm::errs();
+    root->walk([&](Operation *taggedOp) {
+      StringAttr tag = taggedOp->getAttrOfType<StringAttr>("tag");
+      if (!tag)
+        return;
+      os << "test_tag: " << tag.getValue() << ":\n";
+
+      const LastModification *state =
+          solver.lookupState<LastModification>(taggedOp);
+      assert(state && "expected a dense lattice");
+
+      unsigned operandNo = 0;
+      for (Value operand : taggedOp->getOperands()) {
+        os << " operand #" << operandNo++ << "\n";
+
+        // Trace the operand back to the value whose modifiers are tracked.
+        Value resource = getMostUnderlyingValue(operand, [&](Value v) {
+          return solver.lookupState<UnderlyingValueLattice>(v);
+        });
+        assert(resource && "expected an underlying value");
+
+        Optional<ArrayRef<Operation *>> modifiers =
+            state->getLastModifiers(resource);
+        if (!modifiers) {
+          os << "  - <unknown>\n";
+          continue;
+        }
+        for (Operation *modifier : *modifiers) {
+          os << "  - ";
+          // Prefer the user-provided `tag_name`; fall back to the op name.
+          if (auto tagName = modifier->getAttrOfType<StringAttr>("tag_name"))
+            os << tagName.getValue();
+          else
+            os << modifier->getName();
+          os << "\n";
+        }
+      }
+    });
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+/// Register the `test-last-modified` pass; mlir-opt calls this from
+/// registerTestPasses().
+void registerTestLastModifiedPass() {
+  // Constructing the registration object performs the registration.
+  PassRegistration<TestLastModifiedPass>();
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index ecd85a7f3aaa..1b81fee9f86e 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -86,6 +86,7 @@ void registerTestIRVisitorsPass();
 void registerTestGenericIRVisitorsPass();
 void registerTestGenericIRVisitorsInterruptPass();
 void registerTestInterfaces();
+void registerTestLastModifiedPass();
 void registerTestLinalgCodegenStrategy();
 void registerTestLinalgElementwiseFusion();
 void registerTestLinalgFusionTransforms();
@@ -185,6 +186,7 @@ void registerTestPasses() {
   mlir::test::registerTestIRVisitorsPass();
   mlir::test::registerTestGenericIRVisitorsPass();
   mlir::test::registerTestInterfaces();
+  mlir::test::registerTestLastModifiedPass();
   mlir::test::registerTestLinalgCodegenStrategy();
   mlir::test::registerTestLinalgElementwiseFusion();
   mlir::test::registerTestLinalgFusionTransforms();


        


More information about the Mlir-commits mailing list