[Mlir-commits] [mlir] 4e98d61 - [mlir] Implement backward dataflow.

Alex Zinenko llvmlistbot at llvm.org
Tue Dec 13 09:35:36 PST 2022


Author: Matthias Kramm
Date: 2022-12-13T18:35:27+01:00
New Revision: 4e98d611ef67e46e60743b8429bad5eb531a1e7c

URL: https://github.com/llvm/llvm-project/commit/4e98d611ef67e46e60743b8429bad5eb531a1e7c
DIFF: https://github.com/llvm/llvm-project/commit/4e98d611ef67e46e60743b8429bad5eb531a1e7c.diff

LOG: [mlir] Implement backward dataflow.

This enables interprocedural liveness analysis, very busy expression
analysis, etc.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D138935

Added: 
    mlir/test/Analysis/DataFlow/test-written-to.mlir
    mlir/test/lib/Analysis/DataFlow/TestBackwardDataFlowAnalysis.cpp

Modified: 
    mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
    mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
    mlir/test/lib/Analysis/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 62d32b21ac656..a178c6024fdee 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -16,6 +16,7 @@
 #define MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
 
 #include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "llvm/ADT/SmallPtrSet.h"
 
@@ -39,7 +40,15 @@ class AbstractSparseLattice : public AnalysisState {
 
   /// Join the information contained in 'rhs' into this lattice. Returns
   /// if the value of the lattice changed.
-  virtual ChangeResult join(const AbstractSparseLattice &rhs) = 0;
+  virtual ChangeResult join(const AbstractSparseLattice &rhs) {
+    return ChangeResult::NoChange;
+  }
+
+  /// Meet (intersect) the information in this lattice with 'rhs'. Returns if
+  /// it changed. Like `join` above, the default is a no-op (NoChange).
+  virtual ChangeResult meet(const AbstractSparseLattice &rhs) {
+    return ChangeResult::NoChange;
+  }
 
   /// When the lattice gets updated, propagate an update to users of the value
   /// using its use-def chain to subscribed analyses.
@@ -86,14 +95,18 @@ class Lattice : public AbstractSparseLattice {
     return const_cast<Lattice<ValueT> *>(this)->getValue();
   }
 
+  using LatticeT = Lattice<ValueT>;
+
   /// Join the information contained in the 'rhs' lattice into this
   /// lattice. Returns if the state of the current lattice changed.
   ChangeResult join(const AbstractSparseLattice &rhs) override {
-    const Lattice<ValueT> &rhsLattice =
-        static_cast<const Lattice<ValueT> &>(rhs);
+    return join(static_cast<const LatticeT &>(rhs).getValue());
+  }
 
-    // Join the rhs value into this lattice.
-    return join(rhsLattice.getValue());
+  /// Meet (intersect) the information contained in the 'rhs' lattice with
+  /// this lattice. Returns whether the state of the current lattice changed.
+  ChangeResult meet(const AbstractSparseLattice &rhs) override {
+    return meet(static_cast<const LatticeT &>(rhs).getValue());
   }
 
   /// Join the information contained in the 'rhs' value into this
@@ -114,6 +127,38 @@ class Lattice : public AbstractSparseLattice {
     return ChangeResult::Change;
   }
 
+  /// Trait to check if `T` provides a (static) `meet` method. Needed since for
+  /// forward analysis, lattices will only have a `join`, no `meet`, but we want
+  /// to use the same `Lattice` class for both directions.
+  template <typename T>
+  using has_meet = decltype(&T::meet);
+  template <typename T>
+  using lattice_has_meet = llvm::is_detected<has_meet, T>;
+
+  /// Meet (intersect) the information contained in the 'rhs' value with this
+  /// lattice. Returns if the state of the current lattice changed.
+  template <typename VT,
+            std::enable_if_t<lattice_has_meet<VT>::value> * = nullptr>
+  ChangeResult meet(const VT &rhs) {
+    ValueT newValue = ValueT::meet(value, rhs);
+    assert(ValueT::meet(newValue, value) == newValue &&
+           "expected `meet` to be monotonic");
+    assert(ValueT::meet(newValue, rhs) == newValue &&
+           "expected `meet` to be monotonic");
+
+    // Update the current optimistic value if something changed.
+    if (newValue == value)
+      return ChangeResult::NoChange;
+
+    value = newValue;
+    return ChangeResult::Change;
+  }
+
+  /// Fallback for value types without a `meet` (e.g. forward-only lattices).
+  template <typename VT,
+            std::enable_if_t<!lattice_has_meet<VT>::value> * = nullptr>
+  ChangeResult meet(const VT &) { return ChangeResult::NoChange; }
+
   /// Print the lattice element.
   void print(raw_ostream &os) const override { value.print(os); }
 
@@ -289,6 +334,135 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// AbstractSparseBackwardDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Base class for sparse (backward) data-flow analyses. Similar to
+/// AbstractSparseDataFlowAnalysis, but walks bottom to top.
+class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
+public:
+  /// Initialize the analysis by visiting the operation and everything nested
+  /// under it.
+  LogicalResult initialize(Operation *top) override;
+
+  /// Visit a program point. If this is a call operation or an operation with
+  /// block or region control-flow, then operand lattices are set accordingly.
+  /// Otherwise, invokes the operation transfer function (`visitOperationImpl`).
+  LogicalResult visit(ProgramPoint point) override;
+
+protected:
+  explicit AbstractSparseBackwardDataFlowAnalysis(
+      DataFlowSolver &solver, SymbolTableCollection &symbolTable);
+
+  /// The operation transfer function. Given the result lattices, this
+  /// function is expected to set the operand lattices.
+  virtual void visitOperationImpl(
+      Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
+      ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
+
+  /// Visit an operand of a branch op that is not forwarded to any successor.
+  virtual void visitBranchOperand(OpOperand &operand) = 0;
+
+  /// Set the given lattice element to its state at control-flow exit points.
+  virtual void setToExitState(AbstractSparseLattice *lattice) = 0;
+
+  /// Call `setToExitState` on each of the given lattice elements.
+  void setAllToExitStates(ArrayRef<AbstractSparseLattice *> lattices);
+
+  /// Get the lattice element for a value.
+  virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
+
+  /// Get the lattice elements for a range of values.
+  SmallVector<AbstractSparseLattice *> getLatticeElements(ValueRange values);
+
+  /// Meet 'rhs' into 'lhs' and propagate an update if 'lhs' changed.
+  void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
+
+private:
+  /// Recursively initialize the analysis on nested operations and blocks.
+  LogicalResult initializeRecursively(Operation *op);
+
+  /// Visit an operation. If this is a call operation or an operation with
+  /// region control-flow, then its operand lattices are set accordingly.
+  /// Otherwise, the operation transfer function is invoked.
+  void visitOperation(Operation *op);
+
+  /// Visit a block. NOTE(review): no definition appears in the .cpp; confirm.
+  void visitBlock(Block *block);
+
+  /// Visit an op with region control-flow (e.g. `scf.while`).
+  void visitRegionSuccessors(RegionBranchOpInterface branch,
+                             ArrayRef<AbstractSparseLattice *> operands);
+
+  /// Get the lattice element for a value, and also set up
+  /// dependencies so that the analysis on the given ProgramPoint is re-invoked
+  /// if the value changes.
+  const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point,
+                                                    Value value);
+
+  /// Get the lattice elements for a range of values, and also set up
+  /// dependencies so that the analysis on the given ProgramPoint is re-invoked
+  /// if any of the values change.
+  SmallVector<const AbstractSparseLattice *>
+  getLatticeElementsFor(ProgramPoint point, ValueRange values);
+
+  SymbolTableCollection &symbolTable;
+};
+
+//===----------------------------------------------------------------------===//
+// SparseBackwardDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// A sparse (backward) data-flow analysis for propagating SSA value lattices
+/// backwards across the IR by implementing transfer functions for operations.
+///
+/// `StateT` is expected to be a subclass of `AbstractSparseLattice`.
+template <typename StateT>
+class SparseBackwardDataFlowAnalysis
+    : public AbstractSparseBackwardDataFlowAnalysis {
+public:
+  explicit SparseBackwardDataFlowAnalysis(DataFlowSolver &solver,
+                                          SymbolTableCollection &symbolTable)
+      : AbstractSparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+
+  /// Visit an operation with the lattices of its results. This function is
+  /// expected to set the lattices of the operation's operands.
+  virtual void visitOperation(Operation *op, ArrayRef<StateT *> operands,
+                              ArrayRef<const StateT *> results) = 0;
+
+protected:
+  /// Get the lattice element for a value.
+  StateT *getLatticeElement(Value value) override {
+    return getOrCreate<StateT>(value);
+  }
+
+  /// Set the given lattice element to its state at control-flow exit points.
+  virtual void setToExitState(StateT *lattice) = 0;
+  void setToExitState(AbstractSparseLattice *lattice) override {
+    return setToExitState(static_cast<StateT *>(lattice));
+  }
+  void setAllToExitStates(ArrayRef<StateT *> lattices) {
+    AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
+        {reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()),
+         lattices.size()});
+  }
+
+private:
+  /// Type-erased wrappers that convert the abstract lattice operands to derived
+  /// lattices and invoke the virtual hooks operating on the derived lattices.
+  void visitOperationImpl(
+      Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
+      ArrayRef<const AbstractSparseLattice *> resultLattices) override {
+    visitOperation(
+        op,
+        {reinterpret_cast<StateT *const *>(operandLattices.begin()),
+         operandLattices.size()},
+        {reinterpret_cast<const StateT *const *>(resultLattices.begin()),
+         resultLattices.size()});
+  }
+};
+
 } // end namespace dataflow
 } // end namespace mlir
 

diff  --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 1ea60b45700cf..c5d2ac4ca01ff 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -41,7 +41,7 @@ LogicalResult AbstractSparseDataFlowAnalysis::initialize(Operation *top) {
     if (region.empty())
       continue;
     for (Value argument : region.front().getArguments())
-      setAllToEntryStates(getLatticeElement(argument));
+      setToEntryState(getLatticeElement(argument));
   }
 
   return initializeRecursively(top);
@@ -281,3 +281,271 @@ void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs,
                                           const AbstractSparseLattice &rhs) {
   propagateIfChanged(lhs, lhs->join(rhs));
 }
+
+//===----------------------------------------------------------------------===//
+// AbstractSparseBackwardDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis(
+    DataFlowSolver &solver, SymbolTableCollection &symbolTable)
+    : DataFlowAnalysis(solver), symbolTable(symbolTable) {
+  registerPointKind<CFGEdge>();
+}
+
+LogicalResult
+AbstractSparseBackwardDataFlowAnalysis::initialize(Operation *top) {
+  return initializeRecursively(top);
+}
+
+LogicalResult
+AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
+  visitOperation(op);
+  for (Region &region : op->getRegions()) {
+    for (Block &block : region) {
+      getOrCreate<Executable>(&block)->blockContentSubscribe(this);
+      // Visit ops bottom-up: uses typically follow their definitions, so this
+      // lets much of the initial propagation happen here rather than through
+      // the solver's worklist.
+      for (Operation &nestedOp : llvm::reverse(block))
+        if (failed(initializeRecursively(&nestedOp)))
+          return failure();
+    }
+  }
+  return success();
+}
+
+LogicalResult
+AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
+  if (Operation *op = point.dyn_cast<Operation *>())
+    visitOperation(op);
+  else if (point.is<Block *>())
+    // For backward dataflow, we don't have to do any work for the blocks
+    // themselves. CFG edges between blocks are processed by the BranchOp
+    // logic in `visitOperation`, and entry blocks for functions are tied
+    // to the CallOp arguments by visitOperation.
+    return success();
+  else
+    return failure();
+  return success();
+}
+
+SmallVector<AbstractSparseLattice *>
+AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) {
+  SmallVector<AbstractSparseLattice *> resultLattices;
+  resultLattices.reserve(values.size());
+  for (Value result : values) {
+    AbstractSparseLattice *resultLattice = getLatticeElement(result);
+    resultLattices.push_back(resultLattice);
+  }
+  return resultLattices;
+}
+
+SmallVector<const AbstractSparseLattice *>
+AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
+    ProgramPoint point, ValueRange values) {
+  SmallVector<const AbstractSparseLattice *> resultLattices;
+  resultLattices.reserve(values.size());
+  for (Value result : values) {
+    const AbstractSparseLattice *resultLattice =
+        getLatticeElementFor(point, result);
+    resultLattices.push_back(resultLattice);
+  }
+  return resultLattices;
+}
+
+static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
+  return MutableArrayRef<OpOperand>(operands.getBase(), operands.size());
+}
+
+void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
+  // Bail out if the op sits in a dead block (the top-level op may have none).
+  if (op->getBlock() && !getOrCreate<Executable>(op->getBlock())->isLive())
+    return;
+
+  SmallVector<AbstractSparseLattice *> operandLattices =
+      getLatticeElements(op->getOperands());
+  SmallVector<const AbstractSparseLattice *> resultLattices =
+      getLatticeElementsFor(op, op->getResults());
+
+  // Block arguments of region branch operations flow back into the operands
+  // of the parent op
+  if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+    visitRegionSuccessors(branch, operandLattices);
+    return;
+  }
+
+  if (auto branch = dyn_cast<BranchOpInterface>(op)) {
+    // Block arguments of successor blocks flow back into our operands.
+
+    // We remember all operands not forwarded to any block in a BitVector.
+    // We can't just cut out a range here, since the non-forwarded ops might
+    // be non-contiguous (if there's more than one successor).
+    BitVector unaccounted(op->getNumOperands(), true);
+
+    for (auto [index, block] : llvm::enumerate(op->getSuccessors())) {
+      SuccessorOperands successorOperands = branch.getSuccessorOperands(index);
+      OperandRange forwarded = successorOperands.getForwardedOperands();
+      if (!forwarded.empty()) {
+        MutableArrayRef<OpOperand> operands = op->getOpOperands().slice(
+            forwarded.getBeginOperandIndex(), forwarded.size());
+        for (OpOperand &operand : operands) {
+          unaccounted.reset(operand.getOperandNumber());
+          if (Optional<BlockArgument> blockArg =
+                  detail::getBranchSuccessorArgument(
+                      successorOperands, operand.getOperandNumber(), block)) {
+            meet(getLatticeElement(operand.get()),
+                 *getLatticeElementFor(op, *blockArg));
+          }
+        }
+      }
+    }
+    // Operands not forwarded to successor blocks are typically parameters
+    // of the branch operation itself (for example the boolean for if/else).
+    for (unsigned index : unaccounted.set_bits())
+      visitBranchOperand(op->getOpOperand(index));
+    return;
+  }
+
+  // For function calls, connect the arguments of the entry blocks
+  // to the operands of the call op.
+  if (auto call = dyn_cast<CallOpInterface>(op)) {
+    Operation *callableOp = call.resolveCallable(&symbolTable);
+    if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
+      Region *region = callable.getCallableRegion();
+      if (!region || region->empty()) {
+        // No callee body is visible (the region pointer may be null, e.g. for
+        // a declaration), so the arguments' uses are unknown: treat them like
+        // values flowing to an unknown destination, as for public returns.
+        setAllToExitStates(operandLattices);
+        return;
+      }
+      Block &block = region->front();
+      for (auto [blockArg, operand] :
+           llvm::zip(block.getArguments(), operandLattices))
+        meet(operand, *getLatticeElementFor(op, blockArg));
+      return;
+    }
+  }
+
+  // The block arguments of the branched to region flow back into the
+  // operands of the yield operation.
+  if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
+    if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
+      SmallVector<RegionSuccessor> successors;
+      SmallVector<Attribute> operandAttrs(op->getNumOperands(), nullptr);
+      branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
+                                 operandAttrs, successors);
+      // All operands not forwarded to any successor. This set can be
+      // non-contiguous in the presence of multiple successors.
+      BitVector unaccounted(op->getNumOperands(), true);
+
+      for (const RegionSuccessor &successor : successors) {
+        ValueRange inputs = successor.getSuccessorInputs();
+        Region *region = successor.getSuccessor();
+        OperandRange operands =
+            region ? terminator.getSuccessorOperands(region->getRegionNumber())
+                   : terminator.getSuccessorOperands({});
+        MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
+        for (auto [opoperand, input] : llvm::zip(opoperands, inputs)) {
+          meet(getLatticeElement(opoperand.get()),
+               *getLatticeElementFor(op, input));
+          unaccounted.reset(opoperand.getOperandNumber());
+        }
+      }
+      // Visit operands of the terminator not forwarded to the next region
+      // (e.g. the boolean of `scf.condition`).
+      for (unsigned index : unaccounted.set_bits())
+        visitBranchOperand(op->getOpOperand(index));
+      return;
+    }
+  }
+
+  // yield-like ops without `RegionBranchTerminatorOpInterface` are mapped 1:1
+  // onto the parent op's results. NOTE(review): this assumes they never branch
+  // back into a sibling region (e.g. a loop body's yield) - confirm.
+  if (op->hasTrait<OpTrait::ReturnLike>()) {
+    if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
+      OperandRange operands = op->getOperands();
+      ResultRange results = op->getParentOp()->getResults();
+      assert(results.size() == operands.size() &&
+             "Can't derive arg mapping for yield-like op.");
+      for (auto [operand, result] : llvm::zip(operands, results))
+        meet(getLatticeElement(operand), *getLatticeElementFor(op, result));
+      return;
+    }
+
+    // Going backwards, the operands of the return are derived from the
+    // results of all CallOps calling this CallableOp.
+    if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
+      const PredecessorState *callsites =
+          getOrCreateFor<PredecessorState>(op, callable);
+      if (callsites->allPredecessorsKnown()) {
+        for (Operation *call : callsites->getKnownPredecessors()) {
+          SmallVector<const AbstractSparseLattice *> callResultLattices =
+              getLatticeElementsFor(op, call->getResults());
+          for (auto [operandLattice, resultLattice] :
+               llvm::zip(operandLattices, callResultLattices))
+            meet(operandLattice, *resultLattice);
+        }
+      } else {
+        // If we don't know all the callers, we can't know where the
+        // returned values go. Note that, in particular, this will trigger
+        // for the return ops of any public functions.
+        setAllToExitStates(operandLattices);
+      }
+      return;
+    }
+  }
+
+  visitOperationImpl(op, operandLattices, resultLattices);
+}
+
+void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
+    RegionBranchOpInterface branch,
+    ArrayRef<AbstractSparseLattice *> operandLattices) {
+  // `operandLattices` is currently unused; lattices are re-fetched below.
+  Operation *op = branch.getOperation();
+  SmallVector<RegionSuccessor> successors;
+  SmallVector<Attribute> operandAttrs(op->getNumOperands(), nullptr);
+  branch.getSuccessorRegions(/*index=*/{}, operandAttrs, successors);
+
+  // All operands not forwarded to any successor. This set can be non-contiguous
+  // in the presence of multiple successors.
+  BitVector unaccounted(op->getNumOperands(), true);
+
+  for (RegionSuccessor &successor : successors) {
+    Region *region = successor.getSuccessor();
+    OperandRange operands =
+        region ? branch.getSuccessorEntryOperands(region->getRegionNumber())
+               : branch.getSuccessorEntryOperands({});
+    MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
+    ValueRange inputs = successor.getSuccessorInputs();
+    for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
+      meet(getLatticeElement(operand.get()), *getLatticeElementFor(op, input));
+      unaccounted.reset(operand.getOperandNumber());
+    }
+  }
+  // All operands not forwarded to regions are typically parameters of the
+  // branch operation itself (for example the boolean for if/else).
+  for (unsigned index : unaccounted.set_bits())
+    visitBranchOperand(op->getOpOperand(index));
+}
+
+const AbstractSparseLattice *
+AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
+                                                             Value value) {
+  AbstractSparseLattice *state = getLatticeElement(value);
+  addDependency(state, point);
+  return state;
+}
+
+void AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
+    ArrayRef<AbstractSparseLattice *> lattices) {
+  for (AbstractSparseLattice *lattice : lattices)
+    setToExitState(lattice);
+}
+
+void AbstractSparseBackwardDataFlowAnalysis::meet(
+    AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) {
+  propagateIfChanged(lhs, lhs->meet(rhs));
+}

diff  --git a/mlir/test/Analysis/DataFlow/test-written-to.mlir b/mlir/test/Analysis/DataFlow/test-written-to.mlir
new file mode 100644
index 0000000000000..09c9212fa216c
--- /dev/null
+++ b/mlir/test/Analysis/DataFlow/test-written-to.mlir
@@ -0,0 +1,248 @@
+// RUN: mlir-opt -split-input-file -test-written-to %s 2>&1 | FileCheck %s
+
+// CHECK-LABEL: test_tag: constant0
+// CHECK: result #0: [a]
+// CHECK-LABEL: test_tag: constant1
+// CHECK: result #0: [b]
+func.func @test_two_writes(%m0: memref<i32>, %m1: memref<i32>) -> (memref<i32>, memref<i32>) {
+  %c0 = arith.constant {tag = "constant0"} 0 : i32
+  %c1 = arith.constant {tag = "constant1"} 1 : i32
+  memref.store %c0, %m0[] {tag_name = "a"} : memref<i32>
+  memref.store %c1, %m1[] {tag_name = "b"} : memref<i32>
+  return %m0, %m1 : memref<i32>, memref<i32>
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: c0
+// CHECK: result #0: [b]
+// CHECK-LABEL: test_tag: c1
+// CHECK: result #0: [b]
+// CHECK-LABEL: test_tag: condition
+// CHECK: result #0: [brancharg0]
+// CHECK-LABEL: test_tag: c2
+// CHECK: result #0: [a]
+// CHECK-LABEL: test_tag: c3
+// CHECK: result #0: [a]
+func.func @test_if(%m0: memref<i32>, %m1: memref<i32>, %condition: i1) {
+  %c0 = arith.constant {tag = "c0"} 2 : i32
+  %c1 = arith.constant {tag = "c1"} 3 : i32
+  %condition2 = arith.addi %condition, %condition {tag = "condition"} : i1
+  %0, %1 = scf.if %condition2 -> (i32, i32) {
+    %c2 = arith.constant {tag = "c2"} 0 : i32
+    scf.yield %c2, %c0: i32, i32
+  } else {
+    %c3 = arith.constant {tag = "c3"} 1 : i32
+    scf.yield %c3, %c1: i32, i32
+  }
+  memref.store %0, %m0[] {tag_name = "a"} : memref<i32>
+  memref.store %1, %m1[] {tag_name = "b"} : memref<i32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: c0
+// CHECK: result #0: [a c]
+// CHECK-LABEL: test_tag: c1
+// CHECK: result #0: [b c]
+// CHECK-LABEL: test_tag: br
+// CHECK: operand #0: [brancharg0]
+func.func @test_blocks(%m0: memref<i32>,
+                       %m1: memref<i32>,
+                       %m2: memref<i32>, %cond : i1) {
+  %0 = arith.constant {tag = "c0"} 0 : i32
+  %1 = arith.constant {tag = "c1"} 1 : i32
+  cf.cond_br %cond, ^a(%0: i32), ^b(%1: i32) {tag = "br"}
+^a(%a0: i32):
+  memref.store %a0, %m0[] {tag_name = "a"} : memref<i32>
+  cf.br ^c(%a0 : i32)
+^b(%b0: i32):
+  memref.store %b0, %m1[] {tag_name = "b"} : memref<i32>
+  cf.br ^c(%b0 : i32)
+^c(%c0 : i32):
+  memref.store %c0, %m2[] {tag_name = "c"} : memref<i32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: two
+// CHECK: result #0: [a]
+func.func @test_infinite_loop(%m0: memref<i32>) {
+  %0 = arith.constant 0 : i32
+  %1 = arith.constant 1 : i32
+  %2 = arith.constant {tag = "two"} 2 : i32
+  %3 = arith.constant -1 : i32
+  cf.br ^loop(%0, %1, %2: i32, i32, i32)
+^loop(%a: i32, %b: i32, %c: i32):
+  memref.store %a, %m0[] {tag_name = "a"} : memref<i32>
+  cf.br ^loop(%b, %c, %3 : i32, i32, i32)
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: c0
+// CHECK: result #0: [a b c]
+func.func @test_switch(%flag: i32, %m0: memref<i32>) {
+  %0 = arith.constant {tag = "c0"} 0 : i32
+  cf.switch %flag : i32, [
+      default: ^a(%0 : i32),
+      42: ^b(%0 : i32),
+      43: ^c(%0 : i32)
+  ]
+^a(%a0: i32):
+  memref.store %a0, %m0[] {tag_name = "a"} : memref<i32>
+  cf.br ^c(%a0 : i32)
+^b(%b0: i32):
+  memref.store %b0, %m0[] {tag_name = "b"} : memref<i32>
+  cf.br ^c(%b0 : i32)
+^c(%c0 : i32):
+  memref.store %c0, %m0[] {tag_name = "c"} : memref<i32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: add
+// CHECK: result #0: [a]
+func.func @test_caller(%m0: memref<f32>, %arg: f32) {
+  %0 = arith.addf %arg, %arg {tag = "add"} : f32
+  %1 = func.call @callee(%0) : (f32) -> f32
+  %2 = arith.mulf %1, %1 : f32
+  %3 = arith.mulf %2, %2 : f32
+  %4 = arith.mulf %3, %3 : f32
+  memref.store %4, %m0[] {tag_name = "a"} : memref<f32>
+  return
+}
+
+func.func private @callee(%0 : f32) -> f32 {
+  %1 = arith.mulf %0, %0 : f32
+  %2 = arith.mulf %1, %1 : f32
+  func.return %2 : f32
+}
+
+// -----
+
+func.func private @callee(%0 : f32) -> f32 {
+  %1 = arith.mulf %0, %0 : f32
+  func.return %1 : f32
+}
+
+// CHECK-LABEL: test_tag: sub
+// CHECK: result #0: [a]
+func.func @test_caller_below_callee(%m0: memref<f32>, %arg: f32) {
+  %0 = arith.subf %arg, %arg {tag = "sub"} : f32
+  %1 = func.call @callee(%0) : (f32) -> f32
+  memref.store %1, %m0[] {tag_name = "a"} : memref<f32>
+  return
+}
+
+// -----
+
+func.func private @callee1(%0 : f32) -> f32 {
+  %1 = func.call @callee2(%0) : (f32) -> f32
+  func.return %1 : f32
+}
+
+func.func private @callee2(%0 : f32) -> f32 {
+  %1 = func.call @callee3(%0) : (f32) -> f32
+  func.return %1 : f32
+}
+
+func.func private @callee3(%0 : f32) -> f32 {
+  func.return %0 : f32
+}
+
+// CHECK-LABEL: test_tag: mul
+// CHECK: result #0: [a]
+func.func @test_callchain(%m0: memref<f32>, %arg: f32) {
+  %0 = arith.mulf %arg, %arg {tag = "mul"} : f32
+  %1 = func.call @callee1(%0) : (f32) -> f32
+  memref.store %1, %m0[] {tag_name = "a"} : memref<f32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: zero
+// CHECK: result #0: [c]
+// CHECK-LABEL: test_tag: init
+// CHECK: result #0: [a b]
+// CHECK-LABEL: test_tag: condition
+// CHECK: operand #0: [brancharg0]
+func.func @test_while(%m0: memref<i32>, %init : i32, %cond: i1) {
+  %zero = arith.constant {tag = "zero"} 0 : i32
+  %init2 = arith.addi %init, %init {tag = "init"} : i32
+  %0, %1 = scf.while (%arg1 = %zero, %arg2 = %init2) : (i32, i32) -> (i32, i32) {
+    memref.store %arg2, %m0[] {tag_name = "a"} : memref<i32>
+    scf.condition(%cond) {tag = "condition"} %arg1, %arg2 : i32, i32
+  } do {
+   ^bb0(%arg1: i32, %arg2: i32):
+    memref.store %arg1, %m0[] {tag_name = "c"} : memref<i32>
+    %res = arith.addi %arg2, %arg2 : i32
+    scf.yield %arg1, %res: i32, i32
+  }
+  memref.store %1, %m0[] {tag_name = "b"} : memref<i32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: zero
+// CHECK: result #0: [brancharg0]
+// CHECK-LABEL: test_tag: ten
+// CHECK: result #0: [brancharg1]
+// CHECK-LABEL: test_tag: one
+// CHECK: result #0: [brancharg2]
+// CHECK-LABEL: test_tag: x
+// CHECK: result #0: [a]
+func.func @test_for(%m0: memref<i32>) {
+  %zero = arith.constant {tag = "zero"} 0 : index
+  %ten = arith.constant {tag = "ten"} 10 : index
+  %one = arith.constant {tag = "one"} 1 : index
+  %x = arith.constant {tag = "x"} 0 : i32
+  %0 = scf.for %i = %zero to %ten step %one iter_args(%ix = %x) -> (i32) {
+    scf.yield %ix : i32
+  }
+  memref.store %0, %m0[] {tag_name = "a"} : memref<i32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_tag: default_a
+// CHECK: result #0: [a]
+// CHECK-LABEL: test_tag: default_b
+// CHECK: result #0: [b]
+// CHECK-LABEL: test_tag: 1a
+// CHECK: result #0: [a]
+// CHECK-LABEL: test_tag: 1b
+// CHECK: result #0: [b]
+// CHECK-LABEL: test_tag: 2a
+// CHECK: result #0: [a]
+// CHECK-LABEL: test_tag: 2b
+// CHECK: result #0: [b]
+// CHECK-LABEL: test_tag: switch
+// CHECK: operand #0: [brancharg0]
+func.func @test_switch(%arg0 : index, %m0: memref<i32>) {
+  %0, %1 = scf.index_switch %arg0 {tag="switch"} -> i32, i32
+  case 1 {
+    %2 = arith.constant {tag="1a"} 10 : i32
+    %3 = arith.constant {tag="1b"} 100 : i32
+    scf.yield %2, %3 : i32, i32
+  }
+  case 2 {
+    %4 = arith.constant {tag="2a"} 20 : i32
+    %5 = arith.constant {tag="2b"} 200 : i32
+    scf.yield %4, %5 : i32, i32
+  }
+  default {
+    %6 = arith.constant {tag="default_a"} 30 : i32
+    %7 = arith.constant {tag="default_b"} 300 : i32
+    scf.yield %6, %7 : i32, i32
+  }
+  memref.store %0, %m0[] {tag_name = "a"} : memref<i32>
+  memref.store %1, %m0[] {tag_name = "b"} : memref<i32>
+  return
+}

diff  --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index 9d3761dec08f2..d83a8d5c070bc 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_library(MLIRTestAnalysis
 
   DataFlow/TestDeadCodeAnalysis.cpp
   DataFlow/TestDenseDataFlowAnalysis.cpp
+  DataFlow/TestBackwardDataFlowAnalysis.cpp
 
   EXCLUDE_FROM_LIBMLIR
 

diff  --git a/mlir/test/lib/Analysis/DataFlow/TestBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestBackwardDataFlowAnalysis.cpp
new file mode 100644
index 0000000000000..9579052c9940a
--- /dev/null
+++ b/mlir/test/lib/Analysis/DataFlow/TestBackwardDataFlowAnalysis.cpp
@@ -0,0 +1,142 @@
+//===- TestBackwardDataFlowAnalysis.cpp - Test backward dataflow ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+namespace {
+
+/// This lattice represents, for a given value, the set of memory resources that
+/// this value, or anything derived from this value, is potentially written to.
+struct WrittenTo : public AbstractSparseLattice {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
+  using AbstractSparseLattice::AbstractSparseLattice;
+
+  /// Print the resources as "[r0 r1 ...]" in insertion order.
+  void print(raw_ostream &os) const override {
+    os << "[";
+    llvm::interleave(
+        writes, os, [&](StringAttr a) { os << a.getValue(); }, " ");
+    os << "]";
+  }
+
+  /// Union `newWrites` into this lattice. Returns `ChangeResult::Change` if at
+  /// least one resource was not already present, `NoChange` otherwise.
+  ChangeResult addWrites(const SetVector<StringAttr> &newWrites) {
+    size_t sizeBefore = writes.size();
+    writes.insert(newWrites.begin(), newWrites.end());
+    return writes.size() == sizeBefore ? ChangeResult::NoChange
+                                       : ChangeResult::Change;
+  }
+
+  /// Meet for this backward analysis: the union of both resource sets.
+  ChangeResult meet(const AbstractSparseLattice &other) override {
+    // `WrittenTo` publicly derives from `AbstractSparseLattice`, so a checked
+    // `static_cast` (not `reinterpret_cast`) is the correct downcast.
+    const auto *rhs = static_cast<const WrittenTo *>(&other);
+    return addWrites(rhs->writes);
+  }
+
+  /// Resources this value (or anything derived from it) is written to.
+  SetVector<StringAttr> writes;
+};
+
+/// Backward sparse analysis: walking the dataflow graph from uses to
+/// definitions, it tags every value with each memory resource that the value,
+/// or anything computed from it, eventually gets stored to.
+class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
+public:
+  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
+
+  /// Transfer function: fold the lattices of `op`'s results into the lattices
+  /// of its operands.
+  void visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
+                      ArrayRef<const WrittenTo *> results) override;
+
+  /// Hook for branch operands (in the test, e.g. the `scf.index_switch`
+  /// selector).
+  void visitBranchOperand(OpOperand &operand) override;
+
+  /// Exit state: the empty set of resources.
+  void setToExitState(WrittenTo *lattice) override {
+    lattice->writes.clear();
+  }
+};
+
+/// Transfer function. A `memref.store` marks its stored value (operand #0) as
+/// written to the resource named by the store's `tag_name` attribute. Any
+/// other op conservatively meets every result lattice into every operand
+/// lattice.
+void WrittenToAnalysis::visitOperation(Operation *op,
+                                       ArrayRef<WrittenTo *> operands,
+                                       ArrayRef<const WrittenTo *> results) {
+  if (isa<memref::StoreOp>(op)) {
+    // A store without `tag_name` has no resource name to record; inserting a
+    // null attribute would crash later when the lattice is printed.
+    auto tagName = op->getAttrOfType<StringAttr>("tag_name");
+    if (!tagName)
+      return;
+    SetVector<StringAttr> newWrites;
+    newWrites.insert(tagName);
+    propagateIfChanged(operands[0], operands[0]->addWrites(newWrites));
+    return;
+  }
+
+  // By default, every result of an op depends on every operand.
+  for (const WrittenTo *r : results) {
+    for (WrittenTo *operand : operands)
+      meet(operand, *r);
+    addDependency(const_cast<WrittenTo *>(r), op);
+  }
+}
+
+/// Tag a branch operand with the pseudo-resource "brancharg<N>", where <N> is
+/// the operand's index within its owning op.
+void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
+  MLIRContext *ctx = operand.getOwner()->getContext();
+  StringAttr pseudoResource =
+      StringAttr::get(ctx, "brancharg" + Twine(operand.getOperandNumber()));
+
+  SetVector<StringAttr> branchWrites;
+  branchWrites.insert(pseudoResource);
+
+  WrittenTo *lattice = getLatticeElement(operand.get());
+  propagateIfChanged(lattice, lattice->addWrites(branchWrites));
+}
+
+} // end anonymous namespace
+
+namespace {
+struct TestWrittenToPass
+    : public PassWrapper<TestWrittenToPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass)
+
+  StringRef getArgument() const override { return "test-written-to"; }
+
+  /// Run the written-to analysis (loaded together with dead code analysis and
+  /// sparse constant propagation). Then, for every op carrying a `tag`
+  /// attribute, print the lattice of each of its operands and results.
+  void runOnOperation() override {
+    Operation *rootOp = getOperation();
+
+    SymbolTableCollection symbolTable;
+
+    DataFlowSolver solver;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<SparseConstantPropagation>();
+    solver.load<WrittenToAnalysis>(symbolTable);
+    if (failed(solver.initializeAndRun(rootOp)))
+      return signalPassFailure();
+
+    raw_ostream &os = llvm::outs();
+
+    // Print one line per value, formatted as " <kind> #<index>: [<writes>]".
+    auto printValues = [&](StringRef kind, ValueRange values) {
+      for (auto [index, value] : llvm::enumerate(values)) {
+        const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(value);
+        assert(writtenTo && "expected a sparse lattice");
+        os << " " << kind << " #" << index << ": ";
+        writtenTo->print(os);
+        os << "\n";
+      }
+    };
+
+    rootOp->walk([&](Operation *op) {
+      auto tag = op->getAttrOfType<StringAttr>("tag");
+      if (!tag)
+        return;
+      os << "test_tag: " << tag.getValue() << ":\n";
+      printValues("operand", op->getOperands());
+      printValues("result", op->getResults());
+    });
+  }
+};
+} // end anonymous namespace
+
+namespace mlir::test {
+/// Make `-test-written-to` available through the global pass registry.
+void registerTestWrittenToPass() { PassRegistration<TestWrittenToPass>(); }
+} // namespace mlir::test

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e9200b7bf9724..622b9c945b6d5 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -120,6 +120,7 @@ void registerTestTilingInterface();
 void registerTestTopologicalSortAnalysisPass();
 void registerTestTransformDialectEraseSchedulePass();
 void registerTestTransformDialectInterpreterPass();
+void registerTestWrittenToPass();
 void registerTestVectorLowerings();
 void registerTestNvgpuLowerings();
 } // namespace test
@@ -227,6 +228,7 @@ void registerTestPasses() {
   mlir::test::registerTestTransformDialectInterpreterPass();
   mlir::test::registerTestVectorLowerings();
   mlir::test::registerTestNvgpuLowerings();
+  mlir::test::registerTestWrittenToPass();
 }
 #endif
 


        


More information about the Mlir-commits mailing list