[Mlir-commits] [mlir] d35f7f2 - [mlir] Allow data flow analysis of non-control flow branch arguments

Krzysztof Drewniak llvmlistbot at llvm.org
Mon Apr 25 13:19:40 PDT 2022


Author: Krzysztof Drewniak
Date: 2022-04-25T20:19:34Z
New Revision: d35f7f254f6a5c666262ad41f9c36330c4728651

URL: https://github.com/llvm/llvm-project/commit/d35f7f254f6a5c666262ad41f9c36330c4728651
DIFF: https://github.com/llvm/llvm-project/commit/d35f7f254f6a5c666262ad41f9c36330c4728651.diff

LOG: [mlir] Allow data flow analysis of non-control flow branch arguments

This commit adds the visitNonControlFlowArguments method to
DataFlowAnalysis, allowing analyses to provide lattice values for the
arguments to a RegionSuccessor block that aren't directly tied to an
op's inputs. For example, the integer range interface can use this method
to infer bounds for loop induction variables.

This method has a default implementation that keeps the old behavior
of assigning a pessimistic fixpoint state to all such arguments.

Reviewed By: Mogball, rriddle

Differential Revision: https://reviews.llvm.org/D124021

Added: 
    mlir/test/Analysis/test-data-flow.mlir
    mlir/test/lib/Analysis/TestDataFlow.cpp

Modified: 
    mlir/include/mlir/Analysis/DataFlowAnalysis.h
    mlir/lib/Analysis/DataFlowAnalysis.cpp
    mlir/test/lib/Analysis/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/DataFlowAnalysis.h b/mlir/include/mlir/Analysis/DataFlowAnalysis.h
index 29a610e7aa10a..6beb64ba0c2e8 100644
--- a/mlir/include/mlir/Analysis/DataFlowAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlowAnalysis.h
@@ -250,6 +250,15 @@ class ForwardDataFlowAnalysisBase {
                            ArrayRef<AbstractLatticeElement *> operands,
                            SmallVectorImpl<RegionSuccessor> &successors) = 0;
 
+  /// Given an operation with successor regions, one of those regions, and the
+  /// lattice elements corresponding to the operation's operands, compute the
+  /// lattice values for block arguments that are not accounted for by the
+  /// branching control flow (e.g. a loop's induction variable). Returns
+  /// ChangeResult::Change if any of those lattice values changed.
+  virtual ChangeResult
+  visitNonControlFlowArguments(Operation *op, const RegionSuccessor &region,
+                               ArrayRef<AbstractLatticeElement *> operands) = 0;
+
   /// Create a new uninitialized lattice element. An optional value is provided
   /// which, if valid, should be used to initialize the known conservative state
   /// of the lattice.
@@ -347,6 +356,33 @@ class ForwardDataFlowAnalysis : public detail::ForwardDataFlowAnalysisBase {
     branch.getSuccessorRegions(sourceIndex, constantOperands, successors);
   }
 
+  /// Given an operation with successor regions, one of those regions, and the
+  /// lattice elements corresponding to the operation's operands, compute the
+  /// lattice values for block arguments that are not accounted for by the
+  /// branching control flow (e.g. a loop's induction variable). By default,
+  /// this marks all such lattice elements as having reached a pessimistic
+  /// fixpoint, treating the successor inputs as a contiguous run of the entry
+  /// block's arguments. The region in the RegionSuccessor and the operand
+  /// lattice elements are guaranteed to be non-null.
+  virtual ChangeResult
+  visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor,
+                               ArrayRef<LatticeElement<ValueT> *> operands) {
+    ChangeResult result = ChangeResult::NoChange;
+    Region *region = successor.getSuccessor();
+    ValueRange succArgs = successor.getSuccessorInputs();
+    Block *block = &region->front();
+    Block::BlockArgListType arguments = block->getArguments();
+    if (arguments.size() != succArgs.size()) {
+      unsigned firstArgIdx =
+          succArgs.empty() ? 0
+                           : succArgs[0].cast<BlockArgument>().getArgNumber();
+      result |= markAllPessimisticFixpoint(arguments.take_front(firstArgIdx));
+      result |= markAllPessimisticFixpoint(
+          arguments.drop_front(firstArgIdx + succArgs.size()));
+    }
+    return result;
+  }
+
 private:
   /// Type-erased wrappers that convert the abstract lattice operands to derived
   /// lattices and invoke the virtual hooks operating on the derived lattices.
@@ -379,6 +415,14 @@ class ForwardDataFlowAnalysis : public detail::ForwardDataFlowAnalysisBase {
         branch, sourceIndex,
         llvm::makeArrayRef(derivedOperandBase, operands.size()), successors);
   }
+  ChangeResult visitNonControlFlowArguments(
+      Operation *op, const RegionSuccessor &region,
+      ArrayRef<detail::AbstractLatticeElement *> operands) final {
+    LatticeElement<ValueT> *const *derivedOperandBase =
+        reinterpret_cast<LatticeElement<ValueT> *const *>(operands.data());
+    return visitNonControlFlowArguments(
+        op, region, llvm::makeArrayRef(derivedOperandBase, operands.size()));
+  }
 
   /// Create a new uninitialized lattice element. An optional value is provided,
   /// which if valid, should be used to initialize the known conservative state

diff  --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp
index 6718dee107fe5..9c10595dbb00a 100644
--- a/mlir/lib/Analysis/DataFlowAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp
@@ -10,6 +10,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallPtrSet.h"
 
 #include <queue>
@@ -113,6 +114,7 @@ class ForwardDataFlowSolver {
   /// the parent operation results.
   void visitRegionSuccessors(
       Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
+      ArrayRef<AbstractLatticeElement *> operandLattices,
       function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion);
 
   /// Visit the given terminator operation and compute any necessary lattice
@@ -460,7 +462,7 @@ void ForwardDataFlowSolver::visitRegionBranchOperation(
   if (successors.empty())
     return markAllPessimisticFixpoint(branch, branch->getResults());
   return visitRegionSuccessors(
-      branch, successors, [&](Optional<unsigned> index) {
+      branch, successors, operandLattices, [&](Optional<unsigned> index) {
         assert(index && "expected valid region index");
         return branch.getSuccessorEntryOperands(*index);
       });
@@ -468,6 +470,7 @@ void ForwardDataFlowSolver::visitRegionBranchOperation(
 
 void ForwardDataFlowSolver::visitRegionSuccessors(
     Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
+    ArrayRef<AbstractLatticeElement *> operandLattices,
     function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) {
   for (const RegionSuccessor &it : regionSuccessors) {
     Region *region = it.getSuccessor();
@@ -514,22 +517,28 @@ void ForwardDataFlowSolver::visitRegionSuccessors(
     if (llvm::all_of(arguments, [&](Value arg) { return isAtFixpoint(arg); }))
       continue;
 
-    // Mark any arguments that do not receive inputs as having reached a
-    // pessimistic fixpoint, we won't be able to discern if they are constant.
-    // TODO: This isn't exactly ideal. There may be situations in which a
-    // region operation can provide information for certain results that
-    // aren't part of the control flow.
+    // Block arguments that receive no input from the predecessor are handed to
+    // the analysis via visitNonControlFlowArguments. If that changes any of
+    // their lattices, the users of those arguments must be revisited.
     if (succArgs.size() != arguments.size()) {
-      if (succArgs.empty()) {
-        markAllPessimisticFixpoint(arguments);
-        continue;
+      if (analysis.visitNonControlFlowArguments(
+              parentOp, it, operandLattices) == ChangeResult::Change) {
+        unsigned firstArgIdx =
+            succArgs.empty() ? 0
+                             : succArgs[0].cast<BlockArgument>().getArgNumber();
+        for (Value v : arguments.take_front(firstArgIdx)) {
+          assert(!analysis.getLatticeElement(v).isUninitialized() &&
+                 "Non-control flow block arg has no lattice value after "
+                 "analysis callback");
+          visitUsers(v);
+        }
+        for (Value v : arguments.drop_front(firstArgIdx + succArgs.size())) {
+          assert(!analysis.getLatticeElement(v).isUninitialized() &&
+                 "Non-control flow block arg has no lattice value after "
+                 "analysis callback");
+          visitUsers(v);
+        }
       }
-
-      unsigned firstArgIdx = succArgs[0].cast<BlockArgument>().getArgNumber();
-      markAllPessimisticFixpointAndVisitUsers(
-          arguments.take_front(firstArgIdx));
-      markAllPessimisticFixpointAndVisitUsers(
-          arguments.drop_front(firstArgIdx + succArgs.size()));
     }
 
     // Update the lattice of arguments that have inputs from the predecessor.
@@ -573,12 +582,13 @@ void ForwardDataFlowSolver::visitTerminatorOperation(
     // Try to get "region-like" successor operands if possible in order to
     // propagate the operand states to the successors.
     if (isRegionReturnLike(op)) {
-      return visitRegionSuccessors(
-          parentOp, regionSuccessors, [&](Optional<unsigned> regionIndex) {
-            // Determine the individual region successor operands for the given
-            // region index (if any).
-            return *getRegionBranchSuccessorOperands(op, regionIndex);
-          });
+      auto getOperands = [&](Optional<unsigned> regionIndex) {
+        // Determine the individual region successor operands for the given
+        // region index (if any).
+        return *getRegionBranchSuccessorOperands(op, regionIndex);
+      };
+      return visitRegionSuccessors(parentOp, regionSuccessors, operandLattices,
+                                   getOperands);
     }
 
     // If this terminator is not "region-like", conservatively mark all of the

diff  --git a/mlir/test/Analysis/test-data-flow.mlir b/mlir/test/Analysis/test-data-flow.mlir
new file mode 100644
index 0000000000000..3d13f394dbd20
--- /dev/null
+++ b/mlir/test/Analysis/test-data-flow.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -test-data-flow --allow-unregistered-dialect %s 2>&1 | FileCheck %s
+
+// Verifies that the non-control-flow argument hook runs for scf.for and gives
+// block argument 0 (the induction variable) a lattice value before its user in
+// the loop body, the arith.addi, is visited.
+// CHECK-LABEL: Testing : "loop-arg-pessimistic"
+module attributes {test.name = "loop-arg-pessimistic"} {
+  func @f() -> index {
+    // CHECK: Visiting : %{{.*}} = arith.constant 0
+    // CHECK-NEXT: Result 0 moved from uninitialized to 1
+    %c0 = arith.constant 0 : index
+    // CHECK: Visiting : %{{.*}} = arith.constant 1
+    // CHECK-NEXT: Result 0 moved from uninitialized to 1
+    %c1 = arith.constant 1 : index
+    // CHECK: Visiting region branch op : %{{.*}} = scf.for
+    // CHECK: Block argument 0 moved from uninitialized to 1
+    %0 = scf.for %arg1 = %c0 to %c1 step %c1 iter_args(%arg2 = %c0) -> index {
+      // CHECK: Visiting : %{{.*}} = arith.addi %{{.*}}, %{{.*}}
+      // CHECK-NEXT: Arg 0 : 1
+      // CHECK-NEXT: Arg 1 : 1
+      // CHECK-NEXT: Result 0 moved from uninitialized to 1
+      %10 = arith.addi %arg1, %arg2 : index
+      scf.yield %10 : index
+    }
+    return %0 : index
+  }
+}

diff  --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index fc378b8af84f3..f3f5f81f3ef77 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -2,6 +2,7 @@
 add_mlir_library(MLIRTestAnalysis
   TestAliasAnalysis.cpp
   TestCallGraph.cpp
+  TestDataFlow.cpp
   TestLiveness.cpp
   TestMatchReduction.cpp
   TestMemRefBoundCheck.cpp

diff  --git a/mlir/test/lib/Analysis/TestDataFlow.cpp b/mlir/test/lib/Analysis/TestDataFlow.cpp
new file mode 100644
index 0000000000000..84664873d2ff2
--- /dev/null
+++ b/mlir/test/lib/Analysis/TestDataFlow.cpp
@@ -0,0 +1,142 @@
+//===- TestDataFlow.cpp - Test data flow analysis system ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for defining and running a dataflow analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+
+namespace {
+/// Boolean lattice value for the test analysis. The analysis hooks below
+/// assign `true`; the pessimistic state is `false`, and joining is a logical
+/// AND, so a pessimistic input always yields a pessimistic join.
+struct WasAnalyzed {
+  /// Intentionally implicit so that `bool` values convert to the lattice type.
+  WasAnalyzed(bool wasAnalyzed) : wasAnalyzed(wasAnalyzed) {}
+
+  static WasAnalyzed join(const WasAnalyzed &a, const WasAnalyzed &b) {
+    return a.wasAnalyzed && b.wasAnalyzed;
+  }
+
+  static WasAnalyzed getPessimisticValueState(MLIRContext *context) {
+    return false;
+  }
+
+  static WasAnalyzed getPessimisticValueState(Value v) {
+    return getPessimisticValueState(v.getContext());
+  }
+
+  bool operator==(const WasAnalyzed &other) const {
+    return wasAnalyzed == other.wasAnalyzed;
+  }
+
+  bool wasAnalyzed;
+};
+
+/// Forward analysis over WasAnalyzed that logs every hook invocation and
+/// lattice transition to stderr, so tests can check the visitation order.
+struct TestAnalysis : public ForwardDataFlowAnalysis<WasAnalyzed> {
+  using ForwardDataFlowAnalysis<WasAnalyzed>::ForwardDataFlowAnalysis;
+
+  /// Logs `op` and its operand lattices, then joins every result lattice with
+  /// the AND of all initialized operand values (`true` when there are none).
+  ChangeResult
+  visitOperation(Operation *op,
+                 ArrayRef<LatticeElement<WasAnalyzed> *> operands) final {
+    ChangeResult ret = ChangeResult::NoChange;
+    llvm::errs() << "Visiting : ";
+    op->print(llvm::errs());
+    llvm::errs() << "\n";
+
+    WasAnalyzed result(true);
+    // Operands whose lattice is still uninitialized are logged but skipped.
+    for (const auto &pair : llvm::enumerate(operands)) {
+      LatticeElement<WasAnalyzed> *elem = pair.value();
+      llvm::errs() << "Arg " << pair.index();
+      if (!elem->isUninitialized()) {
+        llvm::errs() << " : " << elem->getValue().wasAnalyzed << "\n";
+        result = WasAnalyzed::join(result, elem->getValue());
+      } else {
+        llvm::errs() << " uninitialized\n";
+      }
+    }
+    for (const auto &pair : llvm::enumerate(op->getResults())) {
+      LatticeElement<WasAnalyzed> &lattice = getLatticeElement(pair.value());
+      llvm::errs() << "Result " << pair.index() << " moved from ";
+      if (lattice.isUninitialized())
+        llvm::errs() << "uninitialized";
+      else
+        llvm::errs() << lattice.getValue().wasAnalyzed;
+      ret |= lattice.join({result});
+      llvm::errs() << " to " << lattice.getValue().wasAnalyzed << "\n";
+    }
+    return ret;
+  }
+
+  /// Logs the region branch `op`, then joins `true` into the lattice of every
+  /// argument of the entry block of `successor`'s region.
+  ChangeResult visitNonControlFlowArguments(
+      Operation *op, const RegionSuccessor &successor,
+      ArrayRef<LatticeElement<WasAnalyzed> *> operands) final {
+    ChangeResult ret = ChangeResult::NoChange;
+    llvm::errs() << "Visiting region branch op : ";
+    op->print(llvm::errs());
+    llvm::errs() << "\n";
+
+    Region *region = successor.getSuccessor();
+    Block *block = &region->front();
+    Block::BlockArgListType arguments = block->getArguments();
+    // Mark every entry-block argument as analyzed, including ones that also
+    // receive control-flow inputs. Joining `true` leaves an existing `false`
+    // (pessimistic) state unchanged.
+    for (const auto &pair : llvm::enumerate(arguments)) {
+      LatticeElement<WasAnalyzed> &lattice = getLatticeElement(pair.value());
+      llvm::errs() << "Block argument " << pair.index() << " moved from ";
+      if (lattice.isUninitialized())
+        llvm::errs() << "uninitialized";
+      else
+        llvm::errs() << lattice.getValue().wasAnalyzed;
+      ret |= lattice.join({true});
+      llvm::errs() << " to " << lattice.getValue().wasAnalyzed << "\n";
+    }
+    return ret;
+  }
+};
+
+/// Module pass that prints the module's `test.name` attribute and then runs
+/// TestAnalysis over the module.
+struct TestDataFlowPass
+    : public PassWrapper<TestDataFlowPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDataFlowPass)
+
+  StringRef getArgument() const final { return "test-data-flow"; }
+  StringRef getDescription() const final {
+    return "Print the actions taken during a dataflow analysis.";
+  }
+  void runOnOperation() override {
+    llvm::errs() << "Testing : " << getOperation()->getAttr("test.name")
+                 << "\n";
+    TestAnalysis analysis(getOperation().getContext());
+    analysis.run(getOperation());
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Registers TestDataFlowPass under the `test-data-flow` pass argument.
+void registerTestDataFlowPass() { PassRegistration<TestDataFlowPass>(); }
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 83a1c30af56de..7e26b7fc578ac 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -70,6 +70,7 @@ void registerTestConstantFold();
 void registerTestControlFlowSink();
 void registerTestGpuSerializeToCubinPass();
 void registerTestGpuSerializeToHsacoPass();
+void registerTestDataFlowPass();
 void registerTestDataLayoutQuery();
 void registerTestDecomposeCallGraphTypes();
 void registerTestDiagnosticsPass();
@@ -167,6 +168,7 @@ void registerTestPasses() {
   mlir::test::registerTestGpuSerializeToHsacoPass();
 #endif
   mlir::test::registerTestDecomposeCallGraphTypes();
+  mlir::test::registerTestDataFlowPass();
   mlir::test::registerTestDataLayoutQuery();
   mlir::test::registerTestDominancePass();
   mlir::test::registerTestDynamicPipelinePass();


        


More information about the Mlir-commits mailing list