[Mlir-commits] [mlir] SparseAnalysis: support ReturnLike terminators (PR #140797)

Jeremy Kun llvmlistbot at llvm.org
Tue May 20 13:43:09 PDT 2025


https://github.com/j2kun created https://github.com/llvm/llvm-project/pull/140797

This PR adds support in sparse analysis for non-control-flow region-bearing ops that have return-like terminators. By default, the analysis propagates the terminator's operand lattices to the containing op's result lattices; analysis subclasses can override this behavior.

Cf. https://discourse.llvm.org/t/how-should-non-control-flow-region-bearing-ops-be-used-with-by-dataflow-analyses/ for context

>From a1de01197f357cdffeab313e2ea086381e494415 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Tue, 20 May 2025 13:11:56 -0700
Subject: [PATCH] SparseAnalysis: support ReturnLike terminators

This PR adds support in sparse analysis for non-control flow
region-bearing ops that have return-like terminators. By default it
propagates the terminator's operand lattices to the containing op's
result lattices, and also allows the analysis subclass to override this
behavior.
---
 .../mlir/Analysis/DataFlow/SparseAnalysis.h   |  65 ++++++++
 mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp |  45 +++---
 .../Analysis/DataFlow/test-return-like.mlir   |  19 +++
 mlir/test/lib/Analysis/CMakeLists.txt         |   1 +
 .../TestSparseForwardDataFlowAnalysis.cpp     | 141 ++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         |  10 ++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 7 files changed, 263 insertions(+), 20 deletions(-)
 create mode 100644 mlir/test/Analysis/DataFlow/test-return-like.mlir
 create mode 100644 mlir/test/lib/Analysis/DataFlow/TestSparseForwardDataFlowAnalysis.cpp

diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 1b2c679176107..4a6d0eb132828 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -220,6 +220,14 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
       Operation *op, const RegionSuccessor &successor,
       ArrayRef<AbstractSparseLattice *> argLattices, unsigned firstIndex) = 0;
 
+  /// Visit a region terminator. This is intended for non-control-flow
+  /// region-bearing ops whose terminators determine the lattice values of the
+  /// parent op's results.
+  virtual LogicalResult visitNonControlFlowTerminatorImpl(
+      Operation *terminatorOp,
+      ArrayRef<const AbstractSparseLattice *> terminatorOperandLattices,
+      ArrayRef<AbstractSparseLattice *> parentResultLattices) = 0;
+
   /// Get the lattice element of a value.
   virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
 
@@ -235,6 +243,29 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
   /// Join the lattice element and propagate and update if it changed.
   void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
 
+  /// Get the operand lattices of `op`, subscribing this analysis to each one.
+  SmallVector<const AbstractSparseLattice *> getOperandLattices(Operation *op) {
+    SmallVector<const AbstractSparseLattice *> operandLattices;
+    operandLattices.reserve(op->getNumOperands());
+    for (Value operand : op->getOperands()) {
+      AbstractSparseLattice *operandLattice = getLatticeElement(operand);
+      operandLattice->useDefSubscribe(this);
+      operandLattices.push_back(operandLattice);
+    }
+    return operandLattices;
+  }
+
+  /// Get the lattice elements of the results of `op`.
+  SmallVector<AbstractSparseLattice *> getResultLattices(Operation *op) {
+    SmallVector<AbstractSparseLattice *> resultLattices;
+    resultLattices.reserve(op->getNumResults());
+    for (Value result : op->getResults()) {
+      AbstractSparseLattice *resultLattice = getLatticeElement(result);
+      resultLattices.push_back(resultLattice);
+    }
+    return resultLattices;
+  }
+
 private:
   /// Recursively initialize the analysis on nested operations and blocks.
   LogicalResult initializeRecursively(Operation *op);
@@ -299,6 +330,28 @@ class SparseForwardDataFlowAnalysis
     setAllToEntryStates(resultLattices);
   }
 
+  /// Visit a region terminator. This is intended for non-control-flow
+  /// region-bearing ops whose terminators determine the lattice values of the
+  /// parent op's results. By default the terminator's operand lattices are
+  /// forwarded to the parent result lattices, if there is a 1-1
+  /// correspondence.
+  virtual LogicalResult visitNonControlFlowTerminator(
+      Operation *terminatorOp,
+      ArrayRef<const StateT *> terminatorOperandLattices,
+      ArrayRef<StateT *> parentResultLattices) {
+    // ReturnLike terminators forward their lattice values to the results of the
+    // parent op.
+    if (terminatorOp->hasTrait<OpTrait::ReturnLike>() &&
+        terminatorOperandLattices.size() == parentResultLattices.size()) {
+      for (const auto &[operandLattice, resultLattice] :
+           llvm::zip(terminatorOperandLattices, parentResultLattices)) {
+        propagateIfChanged(resultLattice, resultLattice->join(*operandLattice));
+      }
+    }
+
+    return success();
+  }
+
   /// Given an operation with possible region control-flow, the lattices of the
   /// operands, and a region successor, compute the lattice values for block
   /// arguments that are not accounted for by the branching control flow (ex.
@@ -370,6 +423,18 @@ class SparseForwardDataFlowAnalysis
          argLattices.size()},
         firstIndex);
   }
+  LogicalResult visitNonControlFlowTerminatorImpl(
+      Operation *terminatorOp,
+      ArrayRef<const AbstractSparseLattice *> terminatorOperandLattices,
+      ArrayRef<AbstractSparseLattice *> parentResultLattices) override {
+    return visitNonControlFlowTerminator(
+        terminatorOp,
+        {reinterpret_cast<const StateT *const *>(
+             terminatorOperandLattices.begin()),
+         terminatorOperandLattices.size()},
+        {reinterpret_cast<StateT *const *>(parentResultLattices.begin()),
+         parentResultLattices.size()});
+  }
   void setToEntryState(AbstractSparseLattice *lattice) override {
     return setToEntryState(reinterpret_cast<StateT *>(lattice));
   }
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 0b39d14042493..0e4a75f4fa25c 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -7,6 +7,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+
+#include <cassert>
+#include <optional>
+
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/IR/Attributes.h"
@@ -20,8 +24,6 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
-#include <cassert>
-#include <optional>
 
 using namespace mlir;
 using namespace mlir::dataflow;
@@ -94,23 +96,32 @@ AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint *point) {
 
 LogicalResult
 AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
-  // Exit early on operations with no results.
-  if (op->getNumResults() == 0)
-    return success();
-
   // If the containing block is not executable, bail out.
   if (op->getBlock() != nullptr &&
       !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
     return success();
 
-  // Get the result lattices.
-  SmallVector<AbstractSparseLattice *> resultLattices;
-  resultLattices.reserve(op->getNumResults());
-  for (Value result : op->getResults()) {
-    AbstractSparseLattice *resultLattice = getLatticeElement(result);
-    resultLattices.push_back(resultLattice);
+  // Region terminators which are not part of control flow have a special
+  // transfer function.
+  if (op->hasTrait<OpTrait::IsTerminator>()) {
+    Operation *parentOp = op->getParentOp();
+    if (parentOp && !isa<RegionBranchOpInterface>(parentOp) &&
+        !isa<RegionBranchTerminatorOpInterface>(op) &&
+        parentOp->getNumResults() > 0) {
+      SmallVector<const AbstractSparseLattice *> operandLattices =
+          getOperandLattices(op);
+      SmallVector<AbstractSparseLattice *> parentResultLattices =
+          getResultLattices(parentOp);
+      return visitNonControlFlowTerminatorImpl(op, operandLattices,
+                                               parentResultLattices);
+    }
   }
 
+  if (op->getNumResults() == 0)
+    return success();
+
+  SmallVector<AbstractSparseLattice *> resultLattices = getResultLattices(op);
+
   // The results of a region branch operation are determined by control-flow.
   if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
     visitRegionSuccessors(getProgramPointAfter(branch), branch,
@@ -119,14 +130,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
     return success();
   }
 
-  // Grab the lattice elements of the operands.
-  SmallVector<const AbstractSparseLattice *> operandLattices;
-  operandLattices.reserve(op->getNumOperands());
-  for (Value operand : op->getOperands()) {
-    AbstractSparseLattice *operandLattice = getLatticeElement(operand);
-    operandLattice->useDefSubscribe(this);
-    operandLattices.push_back(operandLattice);
-  }
+  SmallVector<const AbstractSparseLattice *> operandLattices =
+      getOperandLattices(op);
 
   if (auto call = dyn_cast<CallOpInterface>(op)) {
     // If the call operation is to an external function, attempt to infer the
diff --git a/mlir/test/Analysis/DataFlow/test-return-like.mlir b/mlir/test/Analysis/DataFlow/test-return-like.mlir
new file mode 100644
index 0000000000000..24adb59f079d8
--- /dev/null
+++ b/mlir/test/Analysis/DataFlow/test-return-like.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt --test-integer-lattice %s | FileCheck %s
+
+// CHECK-LABEL: @test_returnlike
+// CHECK: analysis_return_like_region_op
+// CHECK-NEXT: arith.constant {test.operand_lattices = [], test.result_lattices = [0 : index]} 1 : i32
+// CHECK-NEXT: region_yield
+// CHECK-SAME: {test.operand_lattices = [0 : index], test.result_lattices = []}
+
+// The core of the return-like test: the operand lattices of the yield forward
+// to the result lattices of the enclosing region-holding op
+
+// CHECK-NEXT: }) {test.operand_lattices = [], test.result_lattices = [0 : index]} : () -> i32
+func.func @test_returnlike() {
+  %0 = "test.analysis_return_like_region_op"() ({
+    %0 = arith.constant 1 : i32
+    "test.region_yield" (%0) : (i32) -> ()
+  }) : () -> i32
+  return
+}
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index 91879981bffd2..e19025084bb8f 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_library(MLIRTestAnalysis
   DataFlow/TestDenseForwardDataFlowAnalysis.cpp
   DataFlow/TestLivenessAnalysis.cpp
   DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
+  DataFlow/TestSparseForwardDataFlowAnalysis.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/test/lib/Analysis/DataFlow/TestSparseForwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseForwardDataFlowAnalysis.cpp
new file mode 100644
index 0000000000000..8f3794da19ac3
--- /dev/null
+++ b/mlir/test/lib/Analysis/DataFlow/TestSparseForwardDataFlowAnalysis.cpp
@@ -0,0 +1,141 @@
+//===- TestSparseForwardDataFlowAnalysis.cpp - Sparse forward tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+namespace {
+
+class IntegerState {
+public:
+  IntegerState() : value(0) {}
+  explicit IntegerState(int value) : value(value) {}
+  ~IntegerState() = default;
+
+  int get() const { return value; }
+
+  bool operator==(const IntegerState &rhs) const { return value == rhs.value; }
+
+  static IntegerState join(const IntegerState &lhs, const IntegerState &rhs) {
+    return IntegerState{std::max(lhs.get(), rhs.get())};
+  }
+
+  void print(llvm::raw_ostream &os) const {
+    os << "IntegerState(" << value << ")";
+  }
+
+  friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                                       const IntegerState &state) {
+    state.print(os);
+    return os;
+  }
+
+private:
+  int value;
+};
+
+/// This lattice attaches an IntegerState to a value; joining two states keeps
+/// the larger integer, and the default state holds 0.
+struct IntegerLattice : public Lattice<IntegerState> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IntegerLattice)
+  using Lattice::Lattice;
+};
+
+/// A forward analysis over IntegerLattice used to test sparse dataflow. Each
+/// visited operation joins every operand lattice into every result lattice,
+/// and entry states are joined with the default IntegerState.
+class IntegerLatticeAnalysis
+    : public SparseForwardDataFlowAnalysis<IntegerLattice> {
+public:
+  using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
+
+  LogicalResult visitOperation(Operation *op,
+                               ArrayRef<const IntegerLattice *> operands,
+                               ArrayRef<IntegerLattice *> results) override;
+
+  void setToEntryState(IntegerLattice *lattice) override {
+    propagateIfChanged(lattice, lattice->join(IntegerState()));
+  }
+};
+
+LogicalResult IntegerLatticeAnalysis::visitOperation(
+    Operation *op, ArrayRef<const IntegerLattice *> operands,
+    ArrayRef<IntegerLattice *> results) {
+  for (auto *operand : operands) {
+    for (auto *result : results) {
+      propagateIfChanged(result, result->join(*operand));
+    }
+  }
+  return success();
+}
+
+} // end anonymous namespace
+
+namespace {
+struct TestIntegerLatticePass
+    : public PassWrapper<TestIntegerLatticePass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntegerLatticePass)
+
+  TestIntegerLatticePass() = default;
+  TestIntegerLatticePass(const TestIntegerLatticePass &other)
+      : PassWrapper(other) {}
+
+  StringRef getArgument() const override { return "test-integer-lattice"; }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = &getContext();
+
+    DataFlowSolver solver;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<SparseConstantPropagation>();
+    solver.load<IntegerLatticeAnalysis>();
+    if (failed(solver.initializeAndRun(op)))
+      return signalPassFailure();
+
+    // Walk the IR and attach operand and result lattices as attributes to each
+    // operation.
+    op->walk([&](Operation *op) {
+      SmallVector<Attribute> operandAttrs;
+      SmallVector<Attribute> resultAttrs;
+      for (auto [index, operand] : llvm::enumerate(op->getOperands())) {
+        const IntegerLattice *lattice =
+            solver.lookupState<IntegerLattice>(operand);
+        assert(lattice && "expected a sparse lattice");
+        operandAttrs.push_back(
+            IntegerAttr::get(IndexType::get(ctx), lattice->getValue().get()));
+      }
+      for (auto [index, result] : llvm::enumerate(op->getResults())) {
+        const IntegerLattice *lattice =
+            solver.lookupState<IntegerLattice>(result);
+        assert(lattice && "expected a sparse lattice");
+        resultAttrs.push_back(
+            IntegerAttr::get(IndexType::get(ctx), lattice->getValue().get()));
+      }
+
+      op->setAttr("test.operand_lattices", ArrayAttr::get(ctx, operandAttrs));
+      op->setAttr("test.result_lattices", ArrayAttr::get(ctx, resultAttrs));
+    });
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestIntegerLatticePass() {
+  PassRegistration<TestIntegerLatticePass>();
+}
+} // end namespace test
+} // end namespace mlir
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 43a0bdaf86cf3..dc594c622c67f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3507,4 +3507,14 @@ def TestAllocWithMultipleResults : TEST_Op<"alloc_with_multiple_results"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test Analysis ReturnLike
+//===----------------------------------------------------------------------===//
+def AnalysisReturnLikeRegionOp : TEST_Op<"analysis_return_like_region_op",
+      [SingleBlockImplicitTerminator<"RegionYieldOp">]> {
+  let regions = (region AnyRegion:$region);
+  let results = (outs AnyType:$result);
+}
+
+
 #endif // TEST_OPS
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index cdcf59b2add13..0b67c9ce0ffe4 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -105,6 +105,7 @@ void registerTestComposeSubView();
 void registerTestMultiBuffering();
 void registerTestIRVisitorsPass();
 void registerTestGenericIRVisitorsPass();
+void registerTestIntegerLatticePass();
 void registerTestInterfaces();
 void registerTestIRVisitorsPass();
 void registerTestLastModifiedPass();
@@ -249,6 +250,7 @@ void registerTestPasses() {
   mlir::test::registerTestMultiBuffering();
   mlir::test::registerTestIRVisitorsPass();
   mlir::test::registerTestGenericIRVisitorsPass();
+  mlir::test::registerTestIntegerLatticePass();
   mlir::test::registerTestInterfaces();
   mlir::test::registerTestIRVisitorsPass();
   mlir::test::registerTestLastModifiedPass();



More information about the Mlir-commits mailing list