[Mlir-commits] [mlir] SparseAnalysis: support ReturnLike terminators (PR #140797)
Jeremy Kun
llvmlistbot at llvm.org
Tue May 20 13:43:09 PDT 2025
https://github.com/j2kun created https://github.com/llvm/llvm-project/pull/140797
This PR adds support in the sparse forward dataflow analysis for non-control-flow, region-bearing ops that have ReturnLike terminators. By default, it propagates the terminator's operand lattices to the containing op's result lattices, and it lets analysis subclasses override this behavior.
Cf. https://discourse.llvm.org/t/how-should-non-control-flow-region-bearing-ops-be-used-with-by-dataflow-analyses/ for context
>From a1de01197f357cdffeab313e2ea086381e494415 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <jkun at google.com>
Date: Tue, 20 May 2025 13:11:56 -0700
Subject: [PATCH] SparseAnalysis: support ReturnLike terminators
This PR adds support in the sparse forward dataflow analysis for
non-control-flow, region-bearing ops that have ReturnLike terminators.
By default, it propagates the terminator's operand lattices to the
containing op's result lattices, and it lets analysis subclasses
override this behavior.
---
.../mlir/Analysis/DataFlow/SparseAnalysis.h | 65 ++++++++
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 45 +++---
.../Analysis/DataFlow/test-return-like.mlir | 19 +++
mlir/test/lib/Analysis/CMakeLists.txt | 1 +
.../TestSparseForwardDataFlowAnalysis.cpp | 141 ++++++++++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 10 ++
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
7 files changed, 263 insertions(+), 20 deletions(-)
create mode 100644 mlir/test/Analysis/DataFlow/test-return-like.mlir
create mode 100644 mlir/test/lib/Analysis/DataFlow/TestSparseForwardDataFlowAnalysis.cpp
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 1b2c679176107..4a6d0eb132828 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -220,6 +220,14 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
Operation *op, const RegionSuccessor &successor,
ArrayRef<AbstractSparseLattice *> argLattices, unsigned firstIndex) = 0;
+ /// Visit the terminator of a region whose parent op has results but is not a
+ /// RegionBranchOpInterface (and the terminator is not a region-branch
+ /// terminator), given its operand lattices and the parent's result lattices.
+ virtual LogicalResult visitNonControlFlowTerminatorImpl(
+ Operation *terminatorOp,
+ ArrayRef<const AbstractSparseLattice *> terminatorOperandLattices,
+ ArrayRef<AbstractSparseLattice *> parentResultLattices) = 0;
+
/// Get the lattice element of a value.
virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
@@ -235,6 +243,29 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// Join the lattice element and propagate and update if it changed.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
+ /// Get the operand lattices of `op`; as a side effect, subscribes to them.
+ SmallVector<const AbstractSparseLattice *> getOperandLattices(Operation *op) {
+ SmallVector<const AbstractSparseLattice *> operandLattices;
+ operandLattices.reserve(op->getNumOperands());
+ for (Value operand : op->getOperands()) {
+ AbstractSparseLattice *operandLattice = getLatticeElement(operand);
+ operandLattice->useDefSubscribe(this); // Re-visit users on change.
+ operandLattices.push_back(operandLattice);
+ }
+ return operandLattices;
+ }
+
+ /// Collect the lattice element of every result of `op`, in result order.
+ /// Unlike getOperandLattices, this does not subscribe to the lattices.
+ SmallVector<AbstractSparseLattice *> getResultLattices(Operation *op) {
+ SmallVector<AbstractSparseLattice *> lattices;
+ lattices.reserve(op->getNumResults());
+ for (Value result : op->getResults())
+ lattices.push_back(getLatticeElement(result));
+
+ return lattices;
+ }
+
private:
/// Recursively initialize the analysis on nested operations and blocks.
LogicalResult initializeRecursively(Operation *op);
@@ -299,6 +330,28 @@ class SparseForwardDataFlowAnalysis
setAllToEntryStates(resultLattices);
}
+ /// Visit the terminator of a non-control-flow region-bearing op, given the
+ /// terminator's operand lattices and the parent op's result lattices. By
+ /// default, a ReturnLike terminator with exactly one operand per parent
+ /// result joins each operand lattice into the positionally matching result
+ /// lattice; for any other terminator this hook does nothing.
+ virtual LogicalResult visitNonControlFlowTerminator(
+ Operation *terminatorOp,
+ ArrayRef<const StateT *> terminatorOperandLattices,
+ ArrayRef<StateT *> parentResultLattices) {
+ // ReturnLike terminators forward their lattice values to the results of
+ // the parent op, pairing operands with results by position.
+ if (terminatorOp->hasTrait<OpTrait::ReturnLike>() &&
+ terminatorOperandLattices.size() == parentResultLattices.size()) {
+ for (const auto &[operandLattice, resultLattice] :
+ llvm::zip(terminatorOperandLattices, parentResultLattices)) {
+ join(resultLattice, *operandLattice);
+ }
+ }
+
+ return success();
+ }
+
/// Given an operation with possible region control-flow, the lattices of the
/// operands, and a region successor, compute the lattice values for block
/// arguments that are not accounted for by the branching control flow (ex.
@@ -370,6 +423,18 @@ class SparseForwardDataFlowAnalysis
argLattices.size()},
firstIndex);
}
+ /// Reinterpret the type-erased lattices as `StateT` and forward them to the
+ /// typed, overridable hook.
+ LogicalResult visitNonControlFlowTerminatorImpl(
+ Operation *terminatorOp,
+ ArrayRef<const AbstractSparseLattice *> terminatorOperandLattices,
+ ArrayRef<AbstractSparseLattice *> parentResultLattices) override {
+ ArrayRef<const StateT *> operands(reinterpret_cast<const StateT *const *>(
+ terminatorOperandLattices.data()), terminatorOperandLattices.size());
+ ArrayRef<StateT *> results(reinterpret_cast<StateT *const *>(
+ parentResultLattices.data()), parentResultLattices.size());
+ return visitNonControlFlowTerminator(terminatorOp, operands, results);
+ }
void setToEntryState(AbstractSparseLattice *lattice) override {
return setToEntryState(reinterpret_cast<StateT *>(lattice));
}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 0b39d14042493..0e4a75f4fa25c 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -7,6 +7,10 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+
+#include <cassert>
+#include <optional>
+
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/Attributes.h"
@@ -20,8 +24,6 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
-#include <cassert>
-#include <optional>
using namespace mlir;
using namespace mlir::dataflow;
@@ -94,23 +96,32 @@ AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint *point) {
LogicalResult
AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
- // Exit early on operations with no results.
- if (op->getNumResults() == 0)
- return success();
-
// If the containing block is not executable, bail out.
if (op->getBlock() != nullptr &&
!getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
return success();
- // Get the result lattices.
- SmallVector<AbstractSparseLattice *> resultLattices;
- resultLattices.reserve(op->getNumResults());
- for (Value result : op->getResults()) {
- AbstractSparseLattice *resultLattice = getLatticeElement(result);
- resultLattices.push_back(resultLattice);
+ // Region terminators which are not part of control flow have a special
+ // transfer function.
+ if (op->hasTrait<OpTrait::IsTerminator>()) {
+ Operation *parentOp = op->getParentOp();
+ if (parentOp && !isa<RegionBranchOpInterface>(parentOp) &&
+ !isa<RegionBranchTerminatorOpInterface>(op) &&
+ parentOp->getNumResults() > 0) {
+ SmallVector<const AbstractSparseLattice *> operandLattices =
+ getOperandLattices(op);
+ SmallVector<AbstractSparseLattice *> parentResultLattices =
+ getResultLattices(parentOp);
+ return visitNonControlFlowTerminatorImpl(op, operandLattices,
+ parentResultLattices);
+ }
}
+ if (op->getNumResults() == 0)
+ return success();
+
+ SmallVector<AbstractSparseLattice *> resultLattices = getResultLattices(op);
+
// The results of a region branch operation are determined by control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
visitRegionSuccessors(getProgramPointAfter(branch), branch,
@@ -119,14 +130,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
return success();
}
- // Grab the lattice elements of the operands.
- SmallVector<const AbstractSparseLattice *> operandLattices;
- operandLattices.reserve(op->getNumOperands());
- for (Value operand : op->getOperands()) {
- AbstractSparseLattice *operandLattice = getLatticeElement(operand);
- operandLattice->useDefSubscribe(this);
- operandLattices.push_back(operandLattice);
- }
+ SmallVector<const AbstractSparseLattice *> operandLattices =
+ getOperandLattices(op);
if (auto call = dyn_cast<CallOpInterface>(op)) {
// If the call operation is to an external function, attempt to infer the
diff --git a/mlir/test/Analysis/DataFlow/test-return-like.mlir b/mlir/test/Analysis/DataFlow/test-return-like.mlir
new file mode 100644
index 0000000000000..24adb59f079d8
--- /dev/null
+++ b/mlir/test/Analysis/DataFlow/test-return-like.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt --test-integer-lattice %s | FileCheck %s
+
+// CHECK-LABEL: @test_returnlike
+// CHECK: analysis_return_like_region_op
+// CHECK-NEXT: arith.constant {test.operand_lattices = [], test.result_lattices = [0 : index]} 1 : i32
+// CHECK-NEXT: region_yield
+// CHECK-SAME: {test.operand_lattices = [0 : index], test.result_lattices = []}
+
+// Core check: the yield's operand lattice flows to the parent result lattice.
+// NOTE(review): every lattice is 0 (the default), so forwarding isn't observed.
+
+// CHECK-NEXT: }) {test.operand_lattices = [], test.result_lattices = [0 : index]} : () -> i32
+func.func @test_returnlike() {
+ %result = "test.analysis_return_like_region_op"() ({
+ %c1 = arith.constant 1 : i32
+ "test.region_yield"(%c1) : (i32) -> ()
+ }) : () -> i32
+ return
+}
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index 91879981bffd2..e19025084bb8f 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_library(MLIRTestAnalysis
DataFlow/TestDenseForwardDataFlowAnalysis.cpp
DataFlow/TestLivenessAnalysis.cpp
DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
+ DataFlow/TestSparseForwardDataFlowAnalysis.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Analysis/DataFlow/TestSparseForwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseForwardDataFlowAnalysis.cpp
new file mode 100644
index 0000000000000..8f3794da19ac3
--- /dev/null
+++ b/mlir/test/lib/Analysis/DataFlow/TestSparseForwardDataFlowAnalysis.cpp
@@ -0,0 +1,141 @@
+//===- TestSparseForwardDataFlowAnalysis.cpp - Test integer lattice -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+namespace {
+
+/// Value of the test lattice: a single integer whose join is the maximum.
+class IntegerState {
+public:
+ IntegerState() = default;
+ explicit IntegerState(int value) : value(value) {}
+
+ int get() const { return value; }
+
+ bool operator==(const IntegerState &rhs) const { return value == rhs.value; }
+
+ static IntegerState join(const IntegerState &lhs, const IntegerState &rhs) {
+ return IntegerState{std::max(lhs.get(), rhs.get())};
+ }
+
+ void print(llvm::raw_ostream &os) const {
+ os << "IntegerState(" << value << ")";
+ }
+
+ friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ const IntegerState &state) {
+ state.print(os);
+ return os;
+ }
+
+private:
+ int value = 0;
+};
+
+/// Sparse lattice holding an IntegerState for each value; joining two
+/// elements keeps the larger integer (see IntegerState::join).
+struct IntegerLattice : public Lattice<IntegerState> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IntegerLattice)
+ using Lattice::Lattice;
+};
+
+/// Forward sparse analysis used to test SparseForwardDataFlowAnalysis: each
+/// op's result lattices are joined with all of its operand lattices, and
+/// entry states are the default IntegerState (i.e. 0).
+class IntegerLatticeAnalysis
+ : public SparseForwardDataFlowAnalysis<IntegerLattice> {
+public:
+ using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
+
+ LogicalResult visitOperation(Operation *op,
+ ArrayRef<const IntegerLattice *> operands,
+ ArrayRef<IntegerLattice *> results) override;
+
+ void setToEntryState(IntegerLattice *lattice) override {
+ propagateIfChanged(lattice, lattice->join(IntegerState()));
+ }
+};
+
+LogicalResult IntegerLatticeAnalysis::visitOperation(
+ Operation *op, ArrayRef<const IntegerLattice *> operands,
+ ArrayRef<IntegerLattice *> results) {
+ // Join (i.e. take the max of) every operand lattice into every result
+ // lattice, propagating an update whenever a result changes.
+ for (const IntegerLattice *operand : operands)
+ for (IntegerLattice *result : results)
+ join(result, *operand);
+ return success();
+}
+
+} // end anonymous namespace
+
+namespace {
+/// Runs IntegerLatticeAnalysis and records each op's operand and result
+/// lattice values as the `test.operand_lattices`/`test.result_lattices` attrs.
+struct TestIntegerLatticePass
+ : public PassWrapper<TestIntegerLatticePass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntegerLatticePass)
+
+ StringRef getArgument() const override { return "test-integer-lattice"; }
+ StringRef getDescription() const override {
+ return "Annotate ops with the lattices computed by IntegerLatticeAnalysis";
+ }
+
+ void runOnOperation() override {
+ Operation *root = getOperation();
+ MLIRContext *ctx = &getContext();
+
+ // DeadCodeAnalysis provides the block liveness that sparse analyses check;
+ // it consults SparseConstantPropagation's results for branch conditions.
+ DataFlowSolver solver;
+ solver.load<DeadCodeAnalysis>();
+ solver.load<SparseConstantPropagation>();
+ solver.load<IntegerLatticeAnalysis>();
+ if (failed(solver.initializeAndRun(root)))
+ return signalPassFailure();
+
+ // Encode the lattice value computed for `value` as an index attribute.
+ auto latticeAttr = [&](Value value) -> Attribute {
+ const IntegerLattice *lattice = solver.lookupState<IntegerLattice>(value);
+ assert(lattice && "expected a sparse lattice");
+ return IntegerAttr::get(IndexType::get(ctx), lattice->getValue().get());
+ };
+
+ // Walk the IR and attach operand and result lattices as attributes to each
+ // operation.
+ root->walk([&](Operation *op) {
+ SmallVector<Attribute> operandAttrs;
+ for (Value operand : op->getOperands())
+ operandAttrs.push_back(latticeAttr(operand));
+ SmallVector<Attribute> resultAttrs;
+ for (Value result : op->getResults())
+ resultAttrs.push_back(latticeAttr(result));
+
+ op->setAttr("test.operand_lattices", ArrayAttr::get(ctx, operandAttrs));
+ op->setAttr("test.result_lattices", ArrayAttr::get(ctx, resultAttrs));
+ });
+ }
+};
+} // end anonymous namespace
+
+namespace mlir::test {
+/// Register TestIntegerLatticePass (`--test-integer-lattice`); called from
+/// mlir-opt's registerTestPasses().
+void registerTestIntegerLatticePass() {
+ PassRegistration<TestIntegerLatticePass>();
+}
+} // namespace mlir::test
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 43a0bdaf86cf3..dc594c622c67f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3507,4 +3507,14 @@ def TestAllocWithMultipleResults : TEST_Op<"alloc_with_multiple_results"> {
}];
}
+//===----------------------------------------------------------------------===//
+// Test op for analyses of ReturnLike terminators in non-control-flow regions
+//===----------------------------------------------------------------------===//
+
+def AnalysisReturnLikeRegionOp : TEST_Op<"analysis_return_like_region_op",
+ [SingleBlockImplicitTerminator<"RegionYieldOp">]> {
+ let regions = (region AnyRegion:$region);
+ let results = (outs AnyType:$result);
+}
+
#endif // TEST_OPS
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index cdcf59b2add13..0b67c9ce0ffe4 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -105,6 +105,7 @@ void registerTestComposeSubView();
void registerTestMultiBuffering();
void registerTestIRVisitorsPass();
void registerTestGenericIRVisitorsPass();
+void registerTestIntegerLatticePass();
void registerTestInterfaces();
void registerTestIRVisitorsPass();
void registerTestLastModifiedPass();
@@ -249,6 +250,7 @@ void registerTestPasses() {
mlir::test::registerTestMultiBuffering();
mlir::test::registerTestIRVisitorsPass();
mlir::test::registerTestGenericIRVisitorsPass();
+ mlir::test::registerTestIntegerLatticePass();
mlir::test::registerTestInterfaces();
mlir::test::registerTestIRVisitorsPass();
mlir::test::registerTestLastModifiedPass();
More information about the Mlir-commits
mailing list