[flang-commits] [flang] [flang] Fix fir.call setCalleeFromCallable (PR #187124)

via flang-commits flang-commits at lists.llvm.org
Tue Mar 17 13:49:23 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Razvan Lupusoru (razvanlupusoru)

<details>
<summary>Changes</summary>

The CallOpInterface method setCalleeFromCallable accepts either a value or a SymbolRef as the new callee. However, the previous implementation had a bug: after setting the callee attribute, it would fall through and also try to set the callee as a value operand.

This PR improves the implementation to handle updating the callee even when switching modes (direct vs indirect) and adds testing for these APIs.

---
Full diff: https://github.com/llvm/llvm-project/pull/187124.diff


3 Files Affected:

- (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+17-5) 
- (modified) flang/unittests/Optimizer/CMakeLists.txt (+1) 
- (added) flang/unittests/Optimizer/FIRCallInterfaceTest.cpp (+170) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 2a849a98903e6..1bf27a6e1fe43 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2701,11 +2701,23 @@ def fir_CallOp : fir_Op<"call",
 
     /// Set the callee for this operation.
     void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
-      if (auto calling =
-          (*this)->getAttrOfType<mlir::SymbolRefAttr>(getCalleeAttrName()))
-        (*this)->setAttr(getCalleeAttrName(),
-                         llvm::cast<mlir::SymbolRefAttr>(callee));
-      setOperand(0, llvm::cast<mlir::Value>(callee));
+      if (auto symbolRef = llvm::dyn_cast<mlir::SymbolRefAttr>(callee)) {
+        // Switching to direct call: set attribute and remove callee operand
+        // if the op was in indirect form (operand 0 was the callable value).
+        bool wasIndirect = llvm::isa<mlir::Value>(getCallableForCallee());
+        (*this)->setAttr(getCalleeAttrName(), symbolRef);
+        if (wasIndirect && getNumOperands() > 0)
+          (*this)->eraseOperand(0);
+        return;
+      }
+      // Switching to indirect call: unset attribute, then either insert
+      // operand 0 (was direct, had no operands) or set it (was already indirect).
+      (*this)->removeAttr(getCalleeAttrNameStr());
+      mlir::Value calleeVal = llvm::cast<mlir::Value>(callee);
+      if (getNumOperands() == 0)
+        (*this)->insertOperands(0, calleeVal);
+      else
+        setOperand(0, calleeVal);
     }
   }];
 }
diff --git a/flang/unittests/Optimizer/CMakeLists.txt b/flang/unittests/Optimizer/CMakeLists.txt
index c390add457632..8c2fb1c4dc850 100644
--- a/flang/unittests/Optimizer/CMakeLists.txt
+++ b/flang/unittests/Optimizer/CMakeLists.txt
@@ -34,6 +34,7 @@ add_flang_unittest(FlangOptimizerTests
   Builder/Runtime/ReductionTest.cpp
   Builder/Runtime/StopTest.cpp
   Builder/Runtime/TransformationalTest.cpp
+  FIRCallInterfaceTest.cpp
   FIRContextTest.cpp
   FIRTypesTest.cpp
   FortranVariableTest.cpp
diff --git a/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp b/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp
new file mode 100644
index 0000000000000..19374df5cc194
--- /dev/null
+++ b/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp
@@ -0,0 +1,170 @@
+//===- FIRCallInterfaceTest.cpp - fir::CallOp setCalleeFromCallable tests -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Tests for CallOpInterface on fir::CallOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "gtest/gtest.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Support/InitFIR.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+
+static bool isSymbolRef(mlir::CallInterfaceCallable callable) {
+  return llvm::isa<SymbolRefAttr>(callable);
+}
+static bool isValue(mlir::CallInterfaceCallable callable) {
+  return llvm::isa<Value>(callable);
+}
+
+struct FIRCallInterfaceTest : public testing::Test {
+  void SetUp() override { fir::support::loadDialects(context); }
+
+  MLIRContext context;
+};
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToDirect) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  ModuleOp module = ModuleOp::create(builder, loc);
+  builder.setInsertionPointToStart(module.getBody());
+
+  auto funcType = builder.getFunctionType({}, {});
+  auto func = func::FuncOp::create(builder, loc, "target", funcType);
+  func.setPrivate();
+  func.getBody().push_back(new Block);
+  builder.setInsertionPointToStart(&func.getBody().front());
+  func::ReturnOp::create(builder, loc);
+  builder.setInsertionPointToStart(module.getBody());
+
+  // Direct call: fir.call @target()
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
+  auto callOp = fir::CallOp::create(
+      builder, loc, callTargetRef, llvm::ArrayRef<mlir::Type>{}, ValueRange{});
+  ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 0u);
+
+  // Change to another symbol; should remain direct with no extra operand.
+  auto newCallTargetRef = FlatSymbolRefAttr::get(&context, "other");
+  callOp.setCalleeFromCallable(newCallTargetRef);
+
+  EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
+                .getRootReference()
+                .getValue(),
+      "other");
+  EXPECT_EQ(callOp.getNumOperands(), 0u);
+  EXPECT_TRUE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToDirect) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  ModuleOp module = ModuleOp::create(builder, loc);
+  builder.setInsertionPointToStart(module.getBody());
+
+  auto funcType = builder.getFunctionType({}, {});
+  // Container has one argument: procedure pointer () -> ()
+  auto containerType = builder.getFunctionType({funcType}, {});
+  auto func = func::FuncOp::create(builder, loc, "container", containerType);
+  func.setPrivate();
+  Block *block = func.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+
+  // Indirect call: fir.call %arg0()
+  Value callTargetValue = block->getArgument(0);
+  auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{callTargetValue});
+  ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+
+  // Switch to direct call; operand 0 must be removed.
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "direct_target");
+  callOp.setCalleeFromCallable(callTargetRef);
+
+  EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
+                .getRootReference()
+                .getValue(),
+      "direct_target");
+  EXPECT_EQ(callOp.getNumOperands(), 0u);
+  EXPECT_TRUE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToIndirect) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  ModuleOp module = ModuleOp::create(builder, loc);
+  builder.setInsertionPointToStart(module.getBody());
+
+  auto funcType = builder.getFunctionType({}, {});
+  auto containerType = builder.getFunctionType({funcType}, {});
+  auto func = func::FuncOp::create(builder, loc, "container", containerType);
+  func.setPrivate();
+  Block *block = func.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+
+  // Direct call first
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
+  auto callOp = fir::CallOp::create(
+      builder, loc, callTargetRef, llvm::ArrayRef<mlir::Type>{}, ValueRange{});
+  ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 0u);
+
+  // Switch to indirect; attribute must be unset, operand 0 set.
+  Value callTargetValue = block->getArgument(0);
+  callOp.setCalleeFromCallable(callTargetValue);
+
+  EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), callTargetValue);
+  EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToIndirect) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  ModuleOp module = ModuleOp::create(builder, loc);
+  builder.setInsertionPointToStart(module.getBody());
+
+  auto funcType = builder.getFunctionType({}, {});
+  // Container has two arguments: procedure pointers () -> ()
+  auto containerType = builder.getFunctionType({funcType, funcType}, {});
+  auto func = func::FuncOp::create(builder, loc, "container", containerType);
+  func.setPrivate();
+  Block *block = func.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+
+  Value callTarget0 = block->getArgument(0);
+  Value callTarget1 = block->getArgument(1);
+
+  // Indirect call: fir.call %arg0()
+  auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{callTarget0});
+  ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), callTarget0);
+
+  // Switch to other indirect call target; should remain indirect, operand 0
+  // updated.
+  callOp.setCalleeFromCallable(callTarget1);
+
+  EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), callTarget1);
+  EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/187124


More information about the flang-commits mailing list