[flang-commits] [flang] [flang] Fix fir.call setCalleeFromCallable (PR #187124)
Razvan Lupusoru via flang-commits
flang-commits at lists.llvm.org
Tue Mar 17 13:48:45 PDT 2026
https://github.com/razvanlupusoru created https://github.com/llvm/llvm-project/pull/187124
The CallOpInterface method setCalleeFromCallable accepts either a Value or a SymbolRefAttr as the new callee. However, the existing fir.call implementation was broken: after setting the callee attribute, it fell through and also tried to set operand 0 from the callee as a Value.
This PR improves the implementation to handle updating the callee even when switching modes (direct vs indirect) and adds testing for these APIs.
>From f8261cfb0160f6254298cd97258a5946fa1cfdbf Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Tue, 17 Mar 2026 13:47:28 -0700
Subject: [PATCH] [flang] Fix fir.call setCalleeFromCallable
The CallOpInterface method setCalleeFromCallable accepts
either a Value or a SymbolRefAttr as the new callee. However,
the existing fir.call implementation was broken: after setting
the callee attribute, it fell through and also tried to set
operand 0 from the callee as a Value.
This PR improves the implementation to handle updating
the callee even when switching modes (direct vs indirect)
and adds testing for these APIs.
---
.../include/flang/Optimizer/Dialect/FIROps.td | 22 ++-
flang/unittests/Optimizer/CMakeLists.txt | 1 +
.../Optimizer/FIRCallInterfaceTest.cpp | 170 ++++++++++++++++++
3 files changed, 188 insertions(+), 5 deletions(-)
create mode 100644 flang/unittests/Optimizer/FIRCallInterfaceTest.cpp
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 2a849a98903e6..1bf27a6e1fe43 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2701,11 +2701,23 @@ def fir_CallOp : fir_Op<"call",
/// Set the callee for this operation.
void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
- if (auto calling =
- (*this)->getAttrOfType<mlir::SymbolRefAttr>(getCalleeAttrName()))
- (*this)->setAttr(getCalleeAttrName(),
- llvm::cast<mlir::SymbolRefAttr>(callee));
- setOperand(0, llvm::cast<mlir::Value>(callee));
+ bool wasDirect = static_cast<bool>(
+ (*this)->getAttrOfType<mlir::SymbolRefAttr>(getCalleeAttrName()));
+ if (auto symbolRef = llvm::dyn_cast<mlir::SymbolRefAttr>(callee)) {
+ // Direct target: set the attribute and drop the callee value that
+ // occupies operand 0 if the op was in indirect form.
+ (*this)->setAttr(getCalleeAttrName(), symbolRef);
+ if (!wasDirect && getNumOperands() > 0)
+ (*this)->eraseOperand(0);
+ return;
+ }
+ // Indirect target. A direct call may still have arguments, so the operand
+ // count cannot tell the forms apart: insert only when it was direct.
+ (*this)->removeAttr(getCalleeAttrName());
+ if (wasDirect)
+ (*this)->insertOperands(0, llvm::cast<mlir::Value>(callee));
+ else
+ setOperand(0, llvm::cast<mlir::Value>(callee));
}
}];
}
diff --git a/flang/unittests/Optimizer/CMakeLists.txt b/flang/unittests/Optimizer/CMakeLists.txt
index c390add457632..8c2fb1c4dc850 100644
--- a/flang/unittests/Optimizer/CMakeLists.txt
+++ b/flang/unittests/Optimizer/CMakeLists.txt
@@ -34,6 +34,7 @@ add_flang_unittest(FlangOptimizerTests
Builder/Runtime/ReductionTest.cpp
Builder/Runtime/StopTest.cpp
Builder/Runtime/TransformationalTest.cpp
+ FIRCallInterfaceTest.cpp
FIRContextTest.cpp
FIRTypesTest.cpp
FortranVariableTest.cpp
diff --git a/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp b/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp
new file mode 100644
index 0000000000000..19374df5cc194
--- /dev/null
+++ b/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp
@@ -0,0 +1,170 @@
+//===- FIRCallInterfaceTest.cpp - fir::CallOp setCalleeFromCallable tests -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Tests for CallOpInterface on fir::CallOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "gtest/gtest.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Support/InitFIR.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+
+// Classify a callee: a symbol reference means the call is in direct form,
+// an SSA value means it is in indirect form.
+static bool isSymbolRef(mlir::CallInterfaceCallable c) {
+ return llvm::isa<SymbolRefAttr>(c);
+}
+static bool isValue(mlir::CallInterfaceCallable c) { return !isSymbolRef(c); }
+
+/// Fixture owning an MLIRContext; SetUp loads fir::support::loadDialects.
+struct FIRCallInterfaceTest : public testing::Test {
+ MLIRContext context;
+ void SetUp() override { fir::support::loadDialects(context); }
+};
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToDirect) {
+ OpBuilder builder(&context);
+ auto loc = builder.getUnknownLoc();
+ // OwningOpRef erases the detached module (and every op in it) on exit.
+ OwningOpRef<ModuleOp> module = ModuleOp::create(builder, loc);
+ builder.setInsertionPointToStart(module->getBody());
+ auto funcType = builder.getFunctionType({}, {});
+ auto func = func::FuncOp::create(builder, loc, "target", funcType);
+ func.setPrivate();
+ func.getBody().push_back(new Block);
+ builder.setInsertionPointToStart(&func.getBody().front());
+ func::ReturnOp::create(builder, loc);
+ builder.setInsertionPointToStart(module->getBody());
+
+ // Direct call: fir.call @target()
+ auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
+ auto callOp = fir::CallOp::create(
+ builder, loc, callTargetRef, llvm::ArrayRef<mlir::Type>{}, ValueRange{});
+ ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+ EXPECT_EQ(callOp.getNumOperands(), 0u);
+
+ // Change to another symbol; should remain direct with no extra operand.
+ auto newCallTargetRef = FlatSymbolRefAttr::get(&context, "other");
+ callOp.setCalleeFromCallable(newCallTargetRef);
+
+ EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+ EXPECT_EQ(llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
+ .getRootReference()
+ .getValue(),
+ "other");
+ EXPECT_EQ(callOp.getNumOperands(), 0u);
+ EXPECT_TRUE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToDirect) {
+ OpBuilder builder(&context);
+ auto loc = builder.getUnknownLoc();
+ // OwningOpRef erases the detached module (and every op in it) on exit.
+ OwningOpRef<ModuleOp> module = ModuleOp::create(builder, loc);
+ builder.setInsertionPointToStart(module->getBody());
+ auto funcType = builder.getFunctionType({}, {});
+ // Container has one argument: procedure pointer () -> ()
+ auto containerType = builder.getFunctionType({funcType}, {});
+ auto func = func::FuncOp::create(builder, loc, "container", containerType);
+ func.setPrivate();
+ Block *block = func.addEntryBlock();
+ builder.setInsertionPointToStart(block);
+
+ // Indirect call: fir.call %arg0()
+ Value callTargetValue = block->getArgument(0);
+ auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
+ llvm::ArrayRef<mlir::Type>{}, ValueRange{callTargetValue});
+ ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
+ EXPECT_EQ(callOp.getNumOperands(), 1u);
+ EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+
+ // Switch to direct call; operand 0 must be removed.
+ auto callTargetRef = FlatSymbolRefAttr::get(&context, "direct_target");
+ callOp.setCalleeFromCallable(callTargetRef);
+
+ EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+ EXPECT_EQ(llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
+ .getRootReference()
+ .getValue(),
+ "direct_target");
+ EXPECT_EQ(callOp.getNumOperands(), 0u);
+ EXPECT_TRUE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToIndirect) {
+ OpBuilder builder(&context);
+ auto loc = builder.getUnknownLoc();
+ // OwningOpRef erases the detached module (and every op in it) on exit.
+ OwningOpRef<ModuleOp> module = ModuleOp::create(builder, loc);
+ builder.setInsertionPointToStart(module->getBody());
+ // Container args: procedure pointer (i32) -> () and an i32 call argument.
+ auto i32Ty = builder.getI32Type();
+ auto funcType = builder.getFunctionType({i32Ty}, {});
+ auto containerType = builder.getFunctionType({funcType, i32Ty}, {});
+ auto func = func::FuncOp::create(builder, loc, "container", containerType);
+ func.setPrivate();
+ Block *block = func.addEntryBlock();
+ builder.setInsertionPointToStart(block);
+ Value callArg = block->getArgument(1);
+ auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
+ auto callOp = fir::CallOp::create(builder, loc, callTargetRef,
+ llvm::ArrayRef<mlir::Type>{}, ValueRange{callArg});
+ ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+ EXPECT_EQ(callOp.getNumOperands(), 1u);
+ // To indirect: the callee must be inserted before the argument, not over it.
+ Value callTargetValue = block->getArgument(0);
+ callOp.setCalleeFromCallable(callTargetValue);
+ EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
+ EXPECT_EQ(callOp.getNumOperands(), 2u);
+ EXPECT_EQ(callOp.getOperand(0), callTargetValue);
+ EXPECT_EQ(callOp.getOperand(1), callArg);
+ EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToIndirect) {
+ OpBuilder builder(&context);
+ auto loc = builder.getUnknownLoc();
+ // OwningOpRef erases the detached module (and every op in it) on exit.
+ OwningOpRef<ModuleOp> module = ModuleOp::create(builder, loc);
+ builder.setInsertionPointToStart(module->getBody());
+ auto funcType = builder.getFunctionType({}, {});
+ // Container has two arguments: procedure pointers () -> ()
+ auto containerType = builder.getFunctionType({funcType, funcType}, {});
+ auto func = func::FuncOp::create(builder, loc, "container", containerType);
+ func.setPrivate();
+ Block *block = func.addEntryBlock();
+ builder.setInsertionPointToStart(block);
+
+ Value callTarget0 = block->getArgument(0);
+ Value callTarget1 = block->getArgument(1);
+
+ // Indirect call: fir.call %arg0()
+ auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
+ llvm::ArrayRef<mlir::Type>{}, ValueRange{callTarget0});
+ ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
+ EXPECT_EQ(callOp.getNumOperands(), 1u);
+ EXPECT_EQ(callOp.getOperand(0), callTarget0);
+
+ // Switch to other indirect call target; should remain indirect, operand 0
+ // updated.
+ callOp.setCalleeFromCallable(callTarget1);
+
+ EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
+ EXPECT_EQ(callOp.getNumOperands(), 1u);
+ EXPECT_EQ(callOp.getOperand(0), callTarget1);
+ EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
More information about the flang-commits
mailing list