[Mlir-commits] [mlir] 232f8ea - [MLIR][analysis] Fix call op handling in sparse backward dataflow
Srishti Srivastava
llvmlistbot at llvm.org
Fri Aug 11 10:27:04 PDT 2023
Author: Srishti Srivastava
Date: 2023-08-11T17:26:58Z
New Revision: 232f8eadae18889627bdca75c45e98b0c1460086
URL: https://github.com/llvm/llvm-project/commit/232f8eadae18889627bdca75c45e98b0c1460086
DIFF: https://github.com/llvm/llvm-project/commit/232f8eadae18889627bdca75c45e98b0c1460086.diff
LOG: [MLIR][analysis] Fix call op handling in sparse backward dataflow
Currently, data in `AbstractSparseBackwardDataFlowAnalysis` is
considered to flow one-to-one, in order, from the operands of an op
implementing `CallOpInterface` to the arguments of the function it is
calling.
This understanding of the data flow is inaccurate. The operands of such
an op that forward to the function arguments are obtained using a
method provided by `CallOpInterface` called `getArgOperands()`.
This commit fixes this bug by using `getArgOperands()` instead of
`getOperands()` to get the mapping from operands to function arguments,
because not all operands necessarily forward to the function arguments,
and even when they do, the i-th operand of the op does not necessarily
map to the i-th function argument. The operands that don't get forwarded
are handled by the newly introduced `visitCallOperand()` function, which
works analogously to the `visitBranchOperand()` function.
This fix is also propagated to the liveness analysis, which previously
relied on this incorrect implementation of the sparse backward dataflow
analysis framework; some incorrect assumptions made in the liveness
analysis are corrected as well.
Extra cleanup: Improved a comment and removed an unnecessary code line.
Signed-off-by: Srishti Srivastava <srishtisrivastava.ai at gmail.com>
Reviewed By: matthiaskramm, jcai19
Differential Revision: https://reviews.llvm.org/D157261
Added:
Modified:
mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
mlir/test/Analysis/DataFlow/test-written-to.mlir
mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index c27b9beb68dbe5..caa03e26a3a423 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -43,8 +43,9 @@ namespace mlir::dataflow {
///
/// A value "has memory effects" iff it:
/// (1.a) is an operand of an op with memory effects OR
-/// (1.b) is a non-forwarded branch operand and a block where its op could
-/// take the control has an op with memory effects.
+/// (1.b) is a non-forwarded branch operand and its branch op could take the
+/// control to a block that has an op with memory effects OR
+/// (1.c) is a non-forwarded call operand.
///
/// A value `A` is said to be "used to compute" value `B` iff `B` cannot be
/// computed in the absence of `A`. Thus, in this implementation, we say that
@@ -83,6 +84,8 @@ class LivenessAnalysis : public SparseBackwardDataFlowAnalysis<Liveness> {
void visitBranchOperand(OpOperand &operand) override;
+ void visitCallOperand(OpOperand &operand) override;
+
void setToExitState(Liveness *lattice) override;
};
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index b8514481a044c0..13dacff3aa0422 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -363,9 +363,12 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
- // Visit operands on branch instructions that are not forwarded
+ // Visit operands on branch instructions that are not forwarded.
virtual void visitBranchOperand(OpOperand &operand) = 0;
+ // Visit operands on call instructions that are not forwarded.
+ virtual void visitCallOperand(OpOperand &operand) = 0;
+
/// Set the given lattice element(s) at control flow exit point(s).
virtual void setToExitState(AbstractSparseLattice *lattice) = 0;
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 0bcfb332207742..fd65a564f3c4e1 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -14,6 +14,7 @@
#include <mlir/Analysis/DataFlowFramework.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/Value.h>
+#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Support/LLVM.h>
@@ -54,8 +55,9 @@ ChangeResult Liveness::meet(const AbstractSparseLattice &other) {
///
/// A value "has memory effects" iff it:
/// (1.a) is an operand of an op with memory effects OR
-/// (1.b) is a non-forwarded branch operand and a block where its op could
-/// take the control has an op with memory effects.
+/// (1.b) is a non-forwarded branch operand and its branch op could take the
+/// control to a block that has an op with memory effects OR
+/// (1.c) is a non-forwarded call operand.
///
/// A value `A` is said to be "used to compute" value `B` iff `B` cannot be
/// computed in the absence of `A`. Thus, in this implementation, we say that
@@ -149,8 +151,6 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
// Now that we have checked for memory-effecting ops in the blocks of concern,
// we will simply visit the op with this non-forwarded operand to potentially
// mark it "live" due to type (1.a/3) liveness.
- if (operand.getOperandNumber() > 0)
- return;
SmallVector<Liveness *, 4> operandLiveness;
operandLiveness.push_back(getLatticeElement(operand.get()));
SmallVector<const Liveness *, 4> resultsLiveness;
@@ -171,6 +171,22 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
visitOperation(parentOp, operandLiveness, parentResultsLiveness);
}
+void LivenessAnalysis::visitCallOperand(OpOperand &operand) {
+ // We know (at the moment) and assume (for the future) that `operand` is a
+ // non-forwarded call operand of an op implementing `CallOpInterface`.
+ assert(isa<CallOpInterface>(operand.getOwner()) &&
+ "expected the op to implement `CallOpInterface`");
+
+ // The lattices of the non-forwarded call operands don't get updated like the
+ // forwarded call operands or the non-call operands. Thus they need to be
+ // handled separately. This is where we handle them.
+
+ // This marks values of type (1.c) liveness as "live". A non-forwarded
+ // call operand is live.
+ Liveness *operandLiveness = getLatticeElement(operand.get());
+ propagateIfChanged(operandLiveness, operandLiveness->markLive());
+}
+
void LivenessAnalysis::setToExitState(Liveness *lattice) {
// This marks values of type (2) liveness as "live".
lattice->markLive();
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index f8bd754092023d..4708cdb042f126 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -412,19 +412,34 @@ void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
return;
}
- // For function calls, connect the arguments of the entry blocks
- // to the operands of the call op.
+ // For function calls, connect the arguments of the entry blocks to the
+ // operands of the call op that are forwarded to these arguments.
if (auto call = dyn_cast<CallOpInterface>(op)) {
Operation *callableOp = call.resolveCallable(&symbolTable);
if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
+ // Not all operands of a call op forward to arguments. Such operands are
+ // stored in `unaccounted`.
+ BitVector unaccounted(op->getNumOperands(), true);
+
+ OperandRange argOperands = call.getArgOperands();
+ MutableArrayRef<OpOperand> argOpOperands =
+ operandsToOpOperands(argOperands);
Region *region = callable.getCallableRegion();
if (region && !region->empty()) {
Block &block = region->front();
- for (auto [blockArg, operand] :
- llvm::zip(block.getArguments(), operandLattices)) {
- meet(operand, *getLatticeElementFor(op, blockArg));
+ for (auto [blockArg, argOpOperand] :
+ llvm::zip(block.getArguments(), argOpOperands)) {
+ meet(getLatticeElement(argOpOperand.get()),
+ *getLatticeElementFor(op, blockArg));
+ unaccounted.reset(argOpOperand.getOperandNumber());
}
}
+ // Handle the operands of the call op that aren't forwarded to any
+ // arguments.
+ for (int index : unaccounted.set_bits()) {
+ OpOperand &opOperand = op->getOpOperand(index);
+ visitCallOperand(opOperand);
+ }
return;
}
}
diff --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
index a040fb3961a9d3..b6aed1c0b054eb 100644
--- a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
+++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
@@ -59,11 +59,27 @@ func.func @test_3_BranchOpInterface_type_1.b(%arg0: i32, %arg1: memref<i32>, %ar
// -----
+func.func private @private(%arg0 : i32, %arg1 : i32) {
+ func.return
+}
+
+// Positive test: Type (1.c) "is a non-forwarded call operand"
+// CHECK-LABEL: test_tag: call
+// CHECK-LABEL: operand #0: not live
+// CHECK-LABEL: operand #1: not live
+// CHECK-LABEL: operand #2: live
+func.func @test_4_type_1.c(%arg0: i32, %arg1: i32, %device: i32, %m0: memref<i32>) {
+ test.call_on_device @private(%arg0, %arg1), %device {tag = "call"} : (i32, i32, i32) -> ()
+ return
+}
+
+// -----
+
// Positive test: Type (2) "is returned by a public function"
// zero is live because it is returned by a public function.
// CHECK-LABEL: test_tag: zero:
// CHECK-NEXT: result #0: live
-func.func @test_4_type_2() -> (f32){
+func.func @test_5_type_2() -> (f32){
%0 = arith.constant {tag = "zero"} 0.0 : f32
return %0 : f32
}
@@ -90,7 +106,7 @@ func.func @test_4_type_2() -> (f32){
// CHECK-NEXT: operand #3: live
// CHECK-LABEL: test_tag: add:
// CHECK-NEXT: operand #0: live
-func.func @test_5_RegionBranchTerminatorOpInterface_type_3(%arg0: memref<i32>, %arg1: i1) -> (i32) {
+func.func @test_6_RegionBranchTerminatorOpInterface_type_3(%arg0: memref<i32>, %arg1: i1) -> (i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
@@ -135,7 +151,7 @@ func.func private @private0(%0 : i32) -> i32 {
// CHECK-NEXT: result #0: live
// CHECK-LABEL: test_tag: y:
// CHECK-NEXT: result #0: not live
-func.func @test_6_type_3(%arg0: memref<i32>) {
+func.func @test_7_type_3(%arg0: memref<i32>) {
%c0 = arith.constant {tag = "zero"} 0 : index
%c10 = arith.constant {tag = "ten"} 10 : index
%c1 = arith.constant {tag = "one"} 1 : index
@@ -190,7 +206,7 @@ func.func private @private2(%0 : i32) -> i32 {
// CHECK-NEXT: operand #0: live
// CHECK-NEXT: operand #1: live
// CHECK-NEXT: result #0: live
-func.func @test_7_type_3(%arg: i32) -> (i32) {
+func.func @test_8_type_3(%arg: i32) -> (i32) {
%0 = func.call @private1(%arg) : (i32) -> i32
%final = arith.muli %0, %arg {tag = "final"} : i32
return %final : i32
@@ -205,7 +221,7 @@ func.func @test_7_type_3(%arg: i32) -> (i32) {
// CHECK-NEXT: result #0: not live
// CHECK-LABEL: test_tag: one:
// CHECK-NEXT: result #0: live
-func.func @test_8_negative() -> (f32){
+func.func @test_9_negative() -> (f32){
%0 = arith.constant {tag = "zero"} 0.0 : f32
%1 = arith.constant {tag = "one"} 1.0 : f32
return %1 : f32
@@ -230,7 +246,7 @@ func.func private @private_1() -> (i32, i32) {
%1 = arith.addi %0, %0 {tag = "one"} : i32
return %0, %1 : i32, i32
}
-func.func @test_9_negative() -> (i32) {
+func.func @test_10_negative() -> (i32) {
%0:2 = func.call @private_1() : () -> (i32, i32)
return %0#0 : i32
}
diff --git a/mlir/test/Analysis/DataFlow/test-written-to.mlir b/mlir/test/Analysis/DataFlow/test-written-to.mlir
index 1ff92f56a1a80c..82fe755aaf5d46 100644
--- a/mlir/test/Analysis/DataFlow/test-written-to.mlir
+++ b/mlir/test/Analysis/DataFlow/test-written-to.mlir
@@ -286,4 +286,21 @@ llvm.func @decl(i64)
llvm.func @func(%lb : i64) -> () {
llvm.call @decl(%lb) : (i64) -> ()
llvm.return
-}
+}
+
+// -----
+
+func.func private @callee(%arg0 : i32, %arg1 : i32) -> i32 {
+ func.return %arg0 : i32
+}
+
+// CHECK-LABEL: test_tag: a
+// CHECK-LABEL: operand #0: [b]
+// CHECK-LABEL: operand #1: []
+// CHECK-LABEL: operand #2: [callarg2]
+// CHECK-LABEL: result #0: [b]
+func.func @test_call_on_device(%arg0: i32, %arg1: i32, %device: i32, %m0: memref<i32>) {
+ %0 = test.call_on_device @callee(%arg0, %arg1), %device {tag = "a"} : (i32, i32, i32) -> (i32)
+ memref.store %0, %m0[] {tag_name = "b"} : memref<i32>
+ return
+}
diff --git a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
index 8af04a1c38559c..f97a4c8bc5eb3e 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
@@ -57,6 +57,8 @@ class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
void visitBranchOperand(OpOperand &operand) override;
+ void visitCallOperand(OpOperand &operand) override;
+
void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); }
};
@@ -87,6 +89,16 @@ void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
propagateIfChanged(lattice, lattice->addWrites(newWrites));
}
+void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
+ // Mark call operands as "callarg%d", with %d the operand number.
+ WrittenTo *lattice = getLatticeElement(operand.get());
+ SetVector<StringAttr> newWrites;
+ newWrites.insert(
+ StringAttr::get(operand.getOwner()->getContext(),
+ "callarg" + Twine(operand.getOperandNumber())));
+ propagateIfChanged(lattice, lattice->addWrites(newWrites));
+}
+
} // end anonymous namespace
namespace {
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 6f3e33052372e8..57a6ab387281dc 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1296,6 +1296,22 @@ MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
return getCalleeOperandsMutable();
}
+CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() {
+ return getCallee();
+}
+
+void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+ setCalleeAttr(callee.get<SymbolRefAttr>());
+}
+
+Operation::operand_range TestCallOnDeviceOp::getArgOperands() {
+ return getForwardedOperands();
+}
+
+MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
+ return getForwardedOperandsMutable();
+}
+
void TestStoreWithARegion::getSuccessorRegions(
std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
if (!index) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 0b121d7a185c7d..6993631fc818d8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2834,6 +2834,21 @@ def TestCallAndStoreOp : TEST_Op<"call_and_store",
"`:` functional-type(operands, results)";
}
+def TestCallOnDeviceOp : TEST_Op<"call_on_device",
+ [DeclareOpInterfaceMethods<CallOpInterface>]> {
+ let arguments = (ins
+ SymbolRefAttr:$callee,
+ Variadic<AnyType>:$forwarded_operands,
+ AnyType:$non_forwarded_device_operand
+ );
+ let results = (outs
+ Variadic<AnyType>:$results
+ );
+ let assemblyFormat =
+ "$callee `(` $forwarded_operands `)` `,` $non_forwarded_device_operand "
+ "attr-dict `:` functional-type(operands, results)";
+}
+
def TestStoreWithARegion : TEST_Op<"store_with_a_region",
[DeclareOpInterfaceMethods<RegionBranchOpInterface>,
SingleBlock]> {
More information about the Mlir-commits
mailing list