[flang-commits] [flang] [flang] Fix fir.call setCalleeFromCallable (PR #187124)

Razvan Lupusoru via flang-commits flang-commits at lists.llvm.org
Wed Mar 18 09:39:16 PDT 2026


https://github.com/razvanlupusoru updated https://github.com/llvm/llvm-project/pull/187124

>From f8261cfb0160f6254298cd97258a5946fa1cfdbf Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Tue, 17 Mar 2026 13:47:28 -0700
Subject: [PATCH 1/3] [flang] Fix fir.call setCalleeFromCallable

The CallOpInterface method setCalleeFromCallable accepts
either a Value or a SymbolRefAttr. However, the old
implementation only updated the callee attribute when the
op already had one, and then always fell through to
llvm::cast<mlir::Value>(callee) and set operand 0, even
when a SymbolRefAttr had been passed in.

This PR fixes the implementation so that it also handles
switching between the direct and indirect call forms, and
adds unit tests for this API.
---
 .../include/flang/Optimizer/Dialect/FIROps.td |  22 ++-
 flang/unittests/Optimizer/CMakeLists.txt      |   1 +
 .../Optimizer/FIRCallInterfaceTest.cpp        | 170 ++++++++++++++++++
 3 files changed, 188 insertions(+), 5 deletions(-)
 create mode 100644 flang/unittests/Optimizer/FIRCallInterfaceTest.cpp

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 2a849a98903e6..1bf27a6e1fe43 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2701,11 +2701,23 @@ def fir_CallOp : fir_Op<"call",
 
     /// Set the callee for this operation.
     void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
-      if (auto calling =
-          (*this)->getAttrOfType<mlir::SymbolRefAttr>(getCalleeAttrName()))
-        (*this)->setAttr(getCalleeAttrName(),
-                         llvm::cast<mlir::SymbolRefAttr>(callee));
-      setOperand(0, llvm::cast<mlir::Value>(callee));
+      if (auto symbolRef = llvm::dyn_cast<mlir::SymbolRefAttr>(callee)) {
+        // Switching to direct call: set attribute and remove callee operand
+        // if the op was in indirect form (operand 0 was the callable value).
+        bool wasIndirect = llvm::isa<mlir::Value>(getCallableForCallee());
+        (*this)->setAttr(getCalleeAttrName(), symbolRef);
+        if (wasIndirect && getNumOperands() > 0)
+          (*this)->eraseOperand(0);
+        return;
+      }
+      // Switching to indirect call: unset attribute, then either insert
+      // operand 0 (was direct, had no operands) or set it (was already indirect).
+      (*this)->removeAttr(getCalleeAttrNameStr());
+      mlir::Value calleeVal = llvm::cast<mlir::Value>(callee);
+      if (getNumOperands() == 0)
+        (*this)->insertOperands(0, calleeVal);
+      else
+        setOperand(0, calleeVal);
     }
   }];
 }
diff --git a/flang/unittests/Optimizer/CMakeLists.txt b/flang/unittests/Optimizer/CMakeLists.txt
index c390add457632..8c2fb1c4dc850 100644
--- a/flang/unittests/Optimizer/CMakeLists.txt
+++ b/flang/unittests/Optimizer/CMakeLists.txt
@@ -34,6 +34,7 @@ add_flang_unittest(FlangOptimizerTests
   Builder/Runtime/ReductionTest.cpp
   Builder/Runtime/StopTest.cpp
   Builder/Runtime/TransformationalTest.cpp
+  FIRCallInterfaceTest.cpp
   FIRContextTest.cpp
   FIRTypesTest.cpp
   FortranVariableTest.cpp
diff --git a/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp b/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp
new file mode 100644
index 0000000000000..19374df5cc194
--- /dev/null
+++ b/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp
@@ -0,0 +1,170 @@
+//===- FIRCallInterfaceTest.cpp - fir::CallOp setCalleeFromCallable tests -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Tests for CallOpInterface on fir::CallOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "gtest/gtest.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Support/InitFIR.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+
+static bool isSymbolRef(mlir::CallInterfaceCallable callable) {
+  return llvm::isa<SymbolRefAttr>(callable);
+}
+static bool isValue(mlir::CallInterfaceCallable callable) {
+  return llvm::isa<Value>(callable);
+}
+
+struct FIRCallInterfaceTest : public testing::Test {
+  void SetUp() override { fir::support::loadDialects(context); }
+
+  MLIRContext context;
+};
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToDirect) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  ModuleOp module = ModuleOp::create(builder, loc);
+  builder.setInsertionPointToStart(module.getBody());
+
+  auto funcType = builder.getFunctionType({}, {});
+  auto func = func::FuncOp::create(builder, loc, "target", funcType);
+  func.setPrivate();
+  func.getBody().push_back(new Block);
+  builder.setInsertionPointToStart(&func.getBody().front());
+  func::ReturnOp::create(builder, loc);
+  builder.setInsertionPointToStart(module.getBody());
+
+  // Direct call: fir.call @target()
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
+  auto callOp = fir::CallOp::create(
+      builder, loc, callTargetRef, llvm::ArrayRef<mlir::Type>{}, ValueRange{});
+  ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 0u);
+
+  // Change to another symbol; should remain direct with no extra operand.
+  auto newCallTargetRef = FlatSymbolRefAttr::get(&context, "other");
+  callOp.setCalleeFromCallable(newCallTargetRef);
+
+  EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
+                .getRootReference()
+                .getValue(),
+      "other");
+  EXPECT_EQ(callOp.getNumOperands(), 0u);
+  EXPECT_TRUE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToDirect) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  ModuleOp module = ModuleOp::create(builder, loc);
+  builder.setInsertionPointToStart(module.getBody());
+
+  auto funcType = builder.getFunctionType({}, {});
+  // Container has one argument: procedure pointer () -> ()
+  auto containerType = builder.getFunctionType({funcType}, {});
+  auto func = func::FuncOp::create(builder, loc, "container", containerType);
+  func.setPrivate();
+  Block *block = func.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+
+  // Indirect call: fir.call %arg0()
+  Value callTargetValue = block->getArgument(0);
+  auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{callTargetValue});
+  ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+
+  // Switch to direct call; operand 0 must be removed.
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "direct_target");
+  callOp.setCalleeFromCallable(callTargetRef);
+
+  EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
+                .getRootReference()
+                .getValue(),
+      "direct_target");
+  EXPECT_EQ(callOp.getNumOperands(), 0u);
+  EXPECT_TRUE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToIndirect) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  ModuleOp module = ModuleOp::create(builder, loc);
+  builder.setInsertionPointToStart(module.getBody());
+
+  auto funcType = builder.getFunctionType({}, {});
+  auto containerType = builder.getFunctionType({funcType}, {});
+  auto func = func::FuncOp::create(builder, loc, "container", containerType);
+  func.setPrivate();
+  Block *block = func.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+
+  // Direct call first
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
+  auto callOp = fir::CallOp::create(
+      builder, loc, callTargetRef, llvm::ArrayRef<mlir::Type>{}, ValueRange{});
+  ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 0u);
+
+  // Switch to indirect; attribute must be unset, operand 0 set.
+  Value callTargetValue = block->getArgument(0);
+  callOp.setCalleeFromCallable(callTargetValue);
+
+  EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), callTargetValue);
+  EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToIndirect) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  ModuleOp module = ModuleOp::create(builder, loc);
+  builder.setInsertionPointToStart(module.getBody());
+
+  auto funcType = builder.getFunctionType({}, {});
+  // Container has two arguments: procedure pointers () -> ()
+  auto containerType = builder.getFunctionType({funcType, funcType}, {});
+  auto func = func::FuncOp::create(builder, loc, "container", containerType);
+  func.setPrivate();
+  Block *block = func.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+
+  Value callTarget0 = block->getArgument(0);
+  Value callTarget1 = block->getArgument(1);
+
+  // Indirect call: fir.call %arg0()
+  auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{callTarget0});
+  ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), callTarget0);
+
+  // Switch to other indirect call target; should remain indirect, operand 0
+  // updated.
+  callOp.setCalleeFromCallable(callTarget1);
+
+  EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), callTarget1);
+  EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}

>From 54aa52ff442f3e99ccc940795b33af959b94b0a2 Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Tue, 17 Mar 2026 15:21:05 -0700
Subject: [PATCH 2/3] Add support and test for cases with additional arguments

---
 .../include/flang/Optimizer/Dialect/FIROps.td |  10 +-
 .../Optimizer/FIRCallInterfaceTest.cpp        | 146 ++++++++++++++----
 2 files changed, 120 insertions(+), 36 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 1bf27a6e1fe43..7292acdca538f 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2711,13 +2711,15 @@ def fir_CallOp : fir_Op<"call",
         return;
       }
       // Switching to indirect call: unset attribute, then either insert
-      // operand 0 (was direct, had no operands) or set it (was already indirect).
+      // operand 0 (was direct: rest of operands are arguments) or set it
+      // (was already indirect: operand 0 is the callee).
+      bool wasIndirect = llvm::isa<mlir::Value>(getCallableForCallee());
       (*this)->removeAttr(getCalleeAttrNameStr());
       mlir::Value calleeVal = llvm::cast<mlir::Value>(callee);
-      if (getNumOperands() == 0)
-        (*this)->insertOperands(0, calleeVal);
-      else
+      if (wasIndirect)
         setOperand(0, calleeVal);
+      else
+        (*this)->insertOperands(0, calleeVal);
     }
   }];
 }
diff --git a/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp b/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp
index 19374df5cc194..ea785e7126338 100644
--- a/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp
+++ b/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp
@@ -23,6 +23,8 @@
 
 using namespace mlir;
 
+namespace {
+
 static bool isSymbolRef(mlir::CallInterfaceCallable callable) {
   return llvm::isa<SymbolRefAttr>(callable);
 }
@@ -30,6 +32,22 @@ static bool isValue(mlir::CallInterfaceCallable callable) {
   return llvm::isa<Value>(callable);
 }
 
+/// Creates a module holding a private function `name` with an entry block and
+/// moves the builder to the block start; returns (func, block). NOTE(review):
+/// nothing erases the module, so it leaks -- consider OwningOpRef<ModuleOp>.
+std::pair<func::FuncOp, Block *> createModuleWithFunction(
+    OpBuilder &builder, Location loc, StringRef name, FunctionType funcType) {
+  ModuleOp module = ModuleOp::create(builder, loc);
+  builder.setInsertionPointToStart(module.getBody());
+  auto func = func::FuncOp::create(builder, loc, name, funcType);
+  func.setPrivate();
+  Block *block = func.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+  return {func, block};
+}
+
+} // namespace
+
 struct FIRCallInterfaceTest : public testing::Test {
   void SetUp() override { fir::support::loadDialects(context); }
 
@@ -39,16 +57,8 @@ struct FIRCallInterfaceTest : public testing::Test {
 TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToDirect) {
   OpBuilder builder(&context);
   auto loc = builder.getUnknownLoc();
-  ModuleOp module = ModuleOp::create(builder, loc);
-  builder.setInsertionPointToStart(module.getBody());
-
   auto funcType = builder.getFunctionType({}, {});
-  auto func = func::FuncOp::create(builder, loc, "target", funcType);
-  func.setPrivate();
-  func.getBody().push_back(new Block);
-  builder.setInsertionPointToStart(&func.getBody().front());
-  func::ReturnOp::create(builder, loc);
-  builder.setInsertionPointToStart(module.getBody());
+  (void)createModuleWithFunction(builder, loc, "target", funcType);
 
   // Direct call: fir.call @target()
   auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
@@ -73,16 +83,10 @@ TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToDirect) {
 TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToDirect) {
   OpBuilder builder(&context);
   auto loc = builder.getUnknownLoc();
-  ModuleOp module = ModuleOp::create(builder, loc);
-  builder.setInsertionPointToStart(module.getBody());
-
   auto funcType = builder.getFunctionType({}, {});
-  // Container has one argument: procedure pointer () -> ()
   auto containerType = builder.getFunctionType({funcType}, {});
-  auto func = func::FuncOp::create(builder, loc, "container", containerType);
-  func.setPrivate();
-  Block *block = func.addEntryBlock();
-  builder.setInsertionPointToStart(block);
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
 
   // Indirect call: fir.call %arg0()
   Value callTargetValue = block->getArgument(0);
@@ -108,15 +112,10 @@ TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToDirect) {
 TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToIndirect) {
   OpBuilder builder(&context);
   auto loc = builder.getUnknownLoc();
-  ModuleOp module = ModuleOp::create(builder, loc);
-  builder.setInsertionPointToStart(module.getBody());
-
   auto funcType = builder.getFunctionType({}, {});
   auto containerType = builder.getFunctionType({funcType}, {});
-  auto func = func::FuncOp::create(builder, loc, "container", containerType);
-  func.setPrivate();
-  Block *block = func.addEntryBlock();
-  builder.setInsertionPointToStart(block);
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
 
   // Direct call first
   auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
@@ -138,16 +137,10 @@ TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToIndirect) {
 TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToIndirect) {
   OpBuilder builder(&context);
   auto loc = builder.getUnknownLoc();
-  ModuleOp module = ModuleOp::create(builder, loc);
-  builder.setInsertionPointToStart(module.getBody());
-
   auto funcType = builder.getFunctionType({}, {});
-  // Container has two arguments: procedure pointers () -> ()
   auto containerType = builder.getFunctionType({funcType, funcType}, {});
-  auto func = func::FuncOp::create(builder, loc, "container", containerType);
-  func.setPrivate();
-  Block *block = func.addEntryBlock();
-  builder.setInsertionPointToStart(block);
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
 
   Value callTarget0 = block->getArgument(0);
   Value callTarget1 = block->getArgument(1);
@@ -168,3 +161,92 @@ TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToIndirect) {
   EXPECT_EQ(callOp.getOperand(0), callTarget1);
   EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
 }
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToIndirect_withArgs) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto i32Ty = builder.getI32Type();
+  auto funcType = builder.getFunctionType({i32Ty}, {});
+  auto containerType = builder.getFunctionType({funcType, i32Ty}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+  Value calleeVal = block->getArgument(0);
+  Value argVal = block->getArgument(1);
+
+  // Direct call with one argument: fir.call @target(%arg)
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
+  auto callOp = fir::CallOp::create(builder, loc, callTargetRef,
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{argVal});
+  ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), argVal);
+
+  // Switch to indirect; callee must be inserted at 0, arg preserved.
+  callOp.setCalleeFromCallable(calleeVal);
+
+  EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 2u);
+  EXPECT_EQ(callOp.getOperand(0), calleeVal);
+  EXPECT_EQ(callOp.getOperand(1), argVal);
+  EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToDirect_withArgs) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto i32Ty = builder.getI32Type();
+  auto funcType = builder.getFunctionType({i32Ty}, {});
+  auto containerType = builder.getFunctionType({funcType, i32Ty}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+  Value calleeVal = block->getArgument(0);
+  Value argVal = block->getArgument(1);
+
+  // Indirect call with one argument: fir.call %callee(%arg)
+  auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{calleeVal, argVal});
+  ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 2u);
+  EXPECT_EQ(callOp.getOperand(0), calleeVal);
+  EXPECT_EQ(callOp.getOperand(1), argVal);
+
+  // Switch to direct; callee operand must be removed, arg preserved.
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "direct_target");
+  callOp.setCalleeFromCallable(callTargetRef);
+
+  EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), argVal);
+  EXPECT_TRUE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(
+    FIRCallInterfaceTest, setCalleeFromCallable_indirectToIndirect_withArgs) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto i32Ty = builder.getI32Type();
+  auto funcType = builder.getFunctionType({i32Ty}, {});
+  auto containerType = builder.getFunctionType({funcType, funcType, i32Ty}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+  Value callee0 = block->getArgument(0);
+  Value callee1 = block->getArgument(1);
+  Value argVal = block->getArgument(2);
+
+  // Indirect call with one argument: fir.call %callee0(%arg)
+  auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{callee0, argVal});
+  ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 2u);
+  EXPECT_EQ(callOp.getOperand(0), callee0);
+  EXPECT_EQ(callOp.getOperand(1), argVal);
+
+  // Switch to other indirect callee; operand 0 updated, arg preserved.
+  callOp.setCalleeFromCallable(callee1);
+
+  EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 2u);
+  EXPECT_EQ(callOp.getOperand(0), callee1);
+  EXPECT_EQ(callOp.getOperand(1), argVal);
+  EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}

>From d2741c81ec5352dece57d8d7b16d422310c459b8 Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Wed, 18 Mar 2026 09:39:03 -0700
Subject: [PATCH 3/3] Fix up arg_attrs when switching between direct and indirect calls

---
 .../include/flang/Optimizer/Dialect/FIROps.td |  22 +--
 flang/lib/Optimizer/Dialect/FIROps.cpp        |  51 ++++++
 .../Optimizer/FIRCallInterfaceTest.cpp        | 162 ++++++++++++++++++
 3 files changed, 214 insertions(+), 21 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 7292acdca538f..776914bb9bbe8 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2700,27 +2700,7 @@ def fir_CallOp : fir_Op<"call",
     }
 
     /// Set the callee for this operation.
-    void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
-      if (auto symbolRef = llvm::dyn_cast<mlir::SymbolRefAttr>(callee)) {
-        // Switching to direct call: set attribute and remove callee operand
-        // if the op was in indirect form (operand 0 was the callable value).
-        bool wasIndirect = llvm::isa<mlir::Value>(getCallableForCallee());
-        (*this)->setAttr(getCalleeAttrName(), symbolRef);
-        if (wasIndirect && getNumOperands() > 0)
-          (*this)->eraseOperand(0);
-        return;
-      }
-      // Switching to indirect call: unset attribute, then either insert
-      // operand 0 (was direct: rest of operands are arguments) or set it
-      // (was already indirect: operand 0 is the callee).
-      bool wasIndirect = llvm::isa<mlir::Value>(getCallableForCallee());
-      (*this)->removeAttr(getCalleeAttrNameStr());
-      mlir::Value calleeVal = llvm::cast<mlir::Value>(callee);
-      if (wasIndirect)
-        setOperand(0, calleeVal);
-      else
-        (*this)->insertOperands(0, calleeVal);
-    }
+    void setCalleeFromCallable(mlir::CallInterfaceCallable callee);
   }];
 }
 
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 17aa042bda60b..b82a1bfb6b17e 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1350,6 +1350,57 @@ void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
   result.addTypes(results);
 }
 
+void fir::CallOp::setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
+  // Record the current form before mutating anything: an indirect call holds
+  // its callee in operand 0 and has no callee attribute.
+  const bool wasIndirect = llvm::isa<mlir::Value>(getCallableForCallee());
+  if (auto symbolRef = llvm::dyn_cast<mlir::SymbolRefAttr>(callee)) {
+    // New callee is a symbol: this becomes (or stays) a direct call.
+    (*this)->setAttr(getCalleeAttrName(), symbolRef);
+    if (!wasIndirect)
+      return;
+    // Switching indirect -> direct: drop the callee operand and the
+    // matching leading entry of arg_attrs (when present).
+    assert(getNumOperands() > 0 && "indirect call must have callee operand");
+    (*this)->eraseOperand(0);
+    if (auto argAttrs = getArgAttrsAttr()) {
+      // Operand 0 is already gone, so arg_attrs must still have exactly one
+      // more entry than the remaining operands.
+      assert(argAttrs.size() == getNumOperands() + 1 &&
+             "arg_attrs must be one-per-operand");
+      llvm::ArrayRef<mlir::Attribute> remaining =
+          argAttrs.getValue().drop_front();
+      if (remaining.empty())
+        (*this)->removeAttr(getArgAttrsAttrName());
+      else
+        (*this)->setAttr(getArgAttrsAttrName(),
+                         mlir::ArrayAttr::get(getContext(), remaining));
+    }
+    return;
+  }
+  // New callee is an SSA value: this becomes (or stays) an indirect call.
+  (*this)->removeAttr(getCalleeAttrName());
+  mlir::Value calleeVal = llvm::cast<mlir::Value>(callee);
+  if (wasIndirect) {
+    // Already indirect: operand 0 is the callee, just replace it.
+    setOperand(0, calleeVal);
+    return;
+  }
+  // Switching direct -> indirect: every existing operand is an argument, so
+  // the callee is inserted in front of them.
+  (*this)->insertOperands(0, calleeVal);
+  if (auto argAttrs = getArgAttrsAttr()) {
+    assert(argAttrs.size() == getNumOperands() - 1 &&
+           "arg_attrs must be one-per-operand");
+    // Prepend an empty dictionary so arg_attrs stays aligned with operands.
+    llvm::SmallVector<mlir::Attribute> newAttrs;
+    newAttrs.push_back(mlir::DictionaryAttr::get(getContext(), {}));
+    newAttrs.append(argAttrs.begin(), argAttrs.end());
+    (*this)->setAttr(getArgAttrsAttrName(),
+                     mlir::ArrayAttr::get(getContext(), newAttrs));
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // CharConvertOp
 //===----------------------------------------------------------------------===//
diff --git a/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp b/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp
index ea785e7126338..9151976d70875 100644
--- a/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp
+++ b/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp
@@ -250,3 +250,165 @@ TEST_F(
   EXPECT_EQ(callOp.getOperand(1), argVal);
   EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
 }
+
+static ArrayAttr makeArgAttrs(
+    MLIRContext *ctx, llvm::ArrayRef<DictionaryAttr> dicts) {
+  llvm::SmallVector<Attribute> attrs(dicts.begin(), dicts.end());
+  return ArrayAttr::get(ctx, attrs);
+}
+
+static DictionaryAttr makeTestArgDict(MLIRContext *ctx, StringRef value) {
+  return DictionaryAttr::get(ctx,
+      {NamedAttribute(
+          StringAttr::get(ctx, "test.attr"), StringAttr::get(ctx, value))});
+}
+
+TEST_F(
+    FIRCallInterfaceTest, setCalleeFromCallable_directToDirect_withArgAttrs) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto i32Ty = builder.getI32Type();
+  auto funcType = builder.getFunctionType({i32Ty}, {});
+  auto containerType = builder.getFunctionType({funcType, i32Ty}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+  Value argVal = block->getArgument(1);
+
+  // Direct call with one argument and arg_attrs.
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
+  auto callOp = fir::CallOp::create(builder, loc, callTargetRef,
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{argVal});
+  callOp->setAttr(callOp.getArgAttrsAttrName(),
+      makeArgAttrs(&context, {makeTestArgDict(&context, "arg0")}));
+  ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+
+  // Switch to another direct callee
+  auto newCallTargetRef = FlatSymbolRefAttr::get(&context, "other_target");
+  callOp.setCalleeFromCallable(newCallTargetRef);
+
+  EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
+                .getRootReference()
+                .getValue(),
+      "other_target");
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), argVal);
+  ArrayAttr argAttrs = callOp.getArgAttrsAttr();
+  ASSERT_TRUE(argAttrs);
+  ASSERT_EQ(argAttrs.size(), 1u);
+  EXPECT_EQ(llvm::cast<DictionaryAttr>(argAttrs[0]).get("test.attr"),
+      StringAttr::get(&context, "arg0"));
+}
+
+TEST_F(
+    FIRCallInterfaceTest, setCalleeFromCallable_directToIndirect_withArgAttrs) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto i32Ty = builder.getI32Type();
+  auto funcType = builder.getFunctionType({i32Ty}, {});
+  auto containerType = builder.getFunctionType({funcType, i32Ty}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+  Value calleeVal = block->getArgument(0);
+  Value argVal = block->getArgument(1);
+
+  // Direct call with one argument and arg_attrs for that argument.
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
+  auto callOp = fir::CallOp::create(builder, loc, callTargetRef,
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{argVal});
+  callOp->setAttr(callOp.getArgAttrsAttrName(),
+      makeArgAttrs(&context, {makeTestArgDict(&context, "arg0")}));
+  ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+
+  // Switch to indirect
+  callOp.setCalleeFromCallable(calleeVal);
+
+  EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 2u);
+  EXPECT_EQ(callOp.getOperand(0), calleeVal);
+  EXPECT_EQ(callOp.getOperand(1), argVal);
+  ArrayAttr argAttrs = callOp.getArgAttrsAttr();
+  ASSERT_TRUE(argAttrs);
+  ASSERT_EQ(argAttrs.size(), 2u);
+  // First entry is empty dict for callee.
+  EXPECT_TRUE(llvm::cast<DictionaryAttr>(argAttrs[0]).empty());
+  // Second entry preserves the argument's attribute.
+  auto argDict = llvm::cast<DictionaryAttr>(argAttrs[1]);
+  EXPECT_EQ(argDict.get("test.attr"), StringAttr::get(&context, "arg0"));
+}
+
+TEST_F(
+    FIRCallInterfaceTest, setCalleeFromCallable_indirectToDirect_withArgAttrs) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto i32Ty = builder.getI32Type();
+  auto funcType = builder.getFunctionType({i32Ty}, {});
+  auto containerType = builder.getFunctionType({funcType, i32Ty}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+  Value calleeVal = block->getArgument(0);
+  Value argVal = block->getArgument(1);
+
+  // Indirect call with callee + one argument
+  auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{calleeVal, argVal});
+  callOp->setAttr(callOp.getArgAttrsAttrName(),
+      makeArgAttrs(&context,
+          {DictionaryAttr::get(&context, {}),
+              makeTestArgDict(&context, "arg0")}));
+  ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 2u);
+
+  // Switch to direct
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "direct_target");
+  callOp.setCalleeFromCallable(callTargetRef);
+
+  EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), argVal);
+  ArrayAttr argAttrs = callOp.getArgAttrsAttr();
+  ASSERT_TRUE(argAttrs);
+  ASSERT_EQ(argAttrs.size(), 1u);
+  EXPECT_EQ(llvm::cast<DictionaryAttr>(argAttrs[0]).get("test.attr"),
+      StringAttr::get(&context, "arg0"));
+}
+
+TEST_F(FIRCallInterfaceTest,
+    setCalleeFromCallable_indirectToIndirect_withArgAttrs) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto i32Ty = builder.getI32Type();
+  auto funcType = builder.getFunctionType({i32Ty}, {});
+  auto containerType = builder.getFunctionType({funcType, funcType, i32Ty}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+  Value callee0 = block->getArgument(0);
+  Value callee1 = block->getArgument(1);
+  Value argVal = block->getArgument(2);
+
+  // Indirect call with one argument and arg_attrs
+  auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{callee0, argVal});
+  callOp->setAttr(callOp.getArgAttrsAttrName(),
+      makeArgAttrs(&context,
+          {DictionaryAttr::get(&context, {}),
+              makeTestArgDict(&context, "arg0")}));
+  ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 2u);
+
+  // Switch to other indirect callee
+  callOp.setCalleeFromCallable(callee1);
+
+  EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 2u);
+  EXPECT_EQ(callOp.getOperand(0), callee1);
+  EXPECT_EQ(callOp.getOperand(1), argVal);
+  ArrayAttr argAttrs = callOp.getArgAttrsAttr();
+  ASSERT_TRUE(argAttrs);
+  ASSERT_EQ(argAttrs.size(), 2u);
+  EXPECT_TRUE(llvm::cast<DictionaryAttr>(argAttrs[0]).empty());
+  EXPECT_EQ(llvm::cast<DictionaryAttr>(argAttrs[1]).get("test.attr"),
+      StringAttr::get(&context, "arg0"));
+}



More information about the flang-commits mailing list