[flang-commits] [flang] f8db5db - [flang] Fix fir.call setCalleeFromCallable (#187124)

via flang-commits flang-commits at lists.llvm.org
Wed Mar 18 10:49:12 PDT 2026


Author: Razvan Lupusoru
Date: 2026-03-18T17:49:06Z
New Revision: f8db5db9586c399d6c134468475632fa7f726b27

URL: https://github.com/llvm/llvm-project/commit/f8db5db9586c399d6c134468475632fa7f726b27
DIFF: https://github.com/llvm/llvm-project/commit/f8db5db9586c399d6c134468475632fa7f726b27.diff

LOG: [flang] Fix fir.call setCalleeFromCallable (#187124)

The CallOpInterface setCalleeFromCallable allows either value or
SymbolRef to be passed in. However, the implementation showed an issue
because while it was able to set attribute, it would fall-through and
also try to set value.

This PR improves the implementation to handle updating the callee even
when switching modes (direct vs indirect) and adds testing for these
APIs.

Added: 
    flang/unittests/Optimizer/FIRCallInterfaceTest.cpp

Modified: 
    flang/include/flang/Optimizer/Dialect/FIROps.td
    flang/lib/Optimizer/Dialect/FIROps.cpp
    flang/unittests/Optimizer/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 2a849a98903e6..776914bb9bbe8 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2700,13 +2700,7 @@ def fir_CallOp : fir_Op<"call",
     }
 
     /// Set the callee for this operation.
-    void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
-      if (auto calling =
-          (*this)->getAttrOfType<mlir::SymbolRefAttr>(getCalleeAttrName()))
-        (*this)->setAttr(getCalleeAttrName(),
-                         llvm::cast<mlir::SymbolRefAttr>(callee));
-      setOperand(0, llvm::cast<mlir::Value>(callee));
-    }
+    void setCalleeFromCallable(mlir::CallInterfaceCallable callee);
   }];
 }
 

diff  --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 17aa042bda60b..0543a624034ca 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1350,6 +1350,56 @@ void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
   result.addTypes(results);
 }
 
+void fir::CallOp::setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
+  if (auto symbolRef = llvm::dyn_cast<mlir::SymbolRefAttr>(callee)) {
+    // Handling a direct call.
+    bool wasIndirect = llvm::isa<mlir::Value>(getCallableForCallee());
+    (*this)->setAttr(getCalleeAttrName(), symbolRef);
+    // If it was indirect before, the operand list and associated attributes
+    // needs to be fixed up.
+    if (wasIndirect) {
+      assert(getNumOperands() > 0 && "indirect call must have callee operand");
+      (*this)->eraseOperand(0);
+      // Fix arg_attrs to remove the first (callee) operand if needed.
+      if (auto argAttrs = getArgAttrsAttr()) {
+        // Since we already removed the first operand, check that number
+        // of attributes is one more than number of operands.
+        assert(argAttrs.size() == getNumOperands() + 1 &&
+               "arg_attrs must be one-per-operand");
+        llvm::SmallVector<mlir::Attribute> newAttrs(argAttrs.begin() + 1,
+                                                    argAttrs.end());
+        if (newAttrs.empty())
+          (*this)->removeAttr(getArgAttrsAttrName());
+        else
+          (*this)->setAttr(getArgAttrsAttrName(),
+                           mlir::ArrayAttr::get(getContext(), newAttrs));
+      }
+    }
+    return;
+  }
+  // The provided callee makes this an indirect call now.
+  bool wasIndirect = llvm::isa<mlir::Value>(getCallableForCallee());
+  (*this)->removeAttr(getCalleeAttrNameStr());
+  mlir::Value calleeVal = llvm::cast<mlir::Value>(callee);
+  if (wasIndirect) {
+    setOperand(0, calleeVal);
+  } else {
+    (*this)->insertOperands(0, calleeVal);
+    // Make arg_attrs consistent in size with operands by adding an empty dict
+    // for the callee.
+    if (auto argAttrs = getArgAttrsAttr()) {
+      assert(argAttrs.size() == getNumOperands() - 1 &&
+             "arg_attrs must be one-per-operand");
+      llvm::SmallVector<mlir::Attribute> newAttrs;
+      newAttrs.reserve(1 + argAttrs.size());
+      newAttrs.push_back(mlir::DictionaryAttr::get(getContext(), {}));
+      newAttrs.append(argAttrs.begin(), argAttrs.end());
+      (*this)->setAttr(getArgAttrsAttrName(),
+                       mlir::ArrayAttr::get(getContext(), newAttrs));
+    }
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // CharConvertOp
 //===----------------------------------------------------------------------===//

diff  --git a/flang/unittests/Optimizer/CMakeLists.txt b/flang/unittests/Optimizer/CMakeLists.txt
index c390add457632..8c2fb1c4dc850 100644
--- a/flang/unittests/Optimizer/CMakeLists.txt
+++ b/flang/unittests/Optimizer/CMakeLists.txt
@@ -34,6 +34,7 @@ add_flang_unittest(FlangOptimizerTests
   Builder/Runtime/ReductionTest.cpp
   Builder/Runtime/StopTest.cpp
   Builder/Runtime/TransformationalTest.cpp
+  FIRCallInterfaceTest.cpp
   FIRContextTest.cpp
   FIRTypesTest.cpp
   FortranVariableTest.cpp

diff  --git a/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp b/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp
new file mode 100644
index 0000000000000..9151976d70875
--- /dev/null
+++ b/flang/unittests/Optimizer/FIRCallInterfaceTest.cpp
@@ -0,0 +1,414 @@
+//===- FIRCallInterfaceTest.cpp - fir::CallOp setCalleeFromCallable tests -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Tests for CallOpInterface on fir::CallOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "gtest/gtest.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Support/InitFIR.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/Casting.h"
+
+using namespace mlir;
+
+namespace {
+
+static bool isSymbolRef(mlir::CallInterfaceCallable callable) {
+  return llvm::isa<SymbolRefAttr>(callable);
+}
+static bool isValue(mlir::CallInterfaceCallable callable) {
+  return llvm::isa<Value>(callable);
+}
+
+/// Creates a module and a function with entry block. Builder insertion point is
+/// set to the block start. Returns (func, block) so tests can create calls in
+/// the block and use block arguments as callee/args.
+std::pair<func::FuncOp, Block *> createModuleWithFunction(
+    OpBuilder &builder, Location loc, StringRef name, FunctionType funcType) {
+  ModuleOp module = ModuleOp::create(builder, loc);
+  builder.setInsertionPointToStart(module.getBody());
+  auto func = func::FuncOp::create(builder, loc, name, funcType);
+  func.setPrivate();
+  Block *block = func.addEntryBlock();
+  builder.setInsertionPointToStart(block);
+  return {func, block};
+}
+
+} // namespace
+
+struct FIRCallInterfaceTest : public testing::Test {
+  void SetUp() override { fir::support::loadDialects(context); }
+
+  MLIRContext context;
+};
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToDirect) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto funcType = builder.getFunctionType({}, {});
+  (void)createModuleWithFunction(builder, loc, "target", funcType);
+
+  // Direct call: fir.call @target()
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
+  auto callOp = fir::CallOp::create(
+      builder, loc, callTargetRef, llvm::ArrayRef<mlir::Type>{}, ValueRange{});
+  ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 0u);
+
+  // Change to another symbol; should remain direct with no extra operand.
+  auto newCallTargetRef = FlatSymbolRefAttr::get(&context, "other");
+  callOp.setCalleeFromCallable(newCallTargetRef);
+
+  EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
+                .getRootReference()
+                .getValue(),
+      "other");
+  EXPECT_EQ(callOp.getNumOperands(), 0u);
+  EXPECT_TRUE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToDirect) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto funcType = builder.getFunctionType({}, {});
+  auto containerType = builder.getFunctionType({funcType}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+
+  // Indirect call: fir.call %arg0()
+  Value callTargetValue = block->getArgument(0);
+  auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{callTargetValue});
+  ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+
+  // Switch to direct call; operand 0 must be removed.
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "direct_target");
+  callOp.setCalleeFromCallable(callTargetRef);
+
+  EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
+                .getRootReference()
+                .getValue(),
+      "direct_target");
+  EXPECT_EQ(callOp.getNumOperands(), 0u);
+  EXPECT_TRUE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToIndirect) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto funcType = builder.getFunctionType({}, {});
+  auto containerType = builder.getFunctionType({funcType}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+
+  // Direct call first
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
+  auto callOp = fir::CallOp::create(
+      builder, loc, callTargetRef, llvm::ArrayRef<mlir::Type>{}, ValueRange{});
+  ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 0u);
+
+  // Switch to indirect; attribute must be unset, operand 0 set.
+  Value callTargetValue = block->getArgument(0);
+  callOp.setCalleeFromCallable(callTargetValue);
+
+  EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), callTargetValue);
+  EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToIndirect) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto funcType = builder.getFunctionType({}, {});
+  auto containerType = builder.getFunctionType({funcType, funcType}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+
+  Value callTarget0 = block->getArgument(0);
+  Value callTarget1 = block->getArgument(1);
+
+  // Indirect call: fir.call %arg0()
+  auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{callTarget0});
+  ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), callTarget0);
+
+  // Switch to other indirect call target; should remain indirect, operand 0
+  // updated.
+  callOp.setCalleeFromCallable(callTarget1);
+
+  EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), callTarget1);
+  EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_directToIndirect_withArgs) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto i32Ty = builder.getI32Type();
+  auto funcType = builder.getFunctionType({i32Ty}, {});
+  auto containerType = builder.getFunctionType({funcType, i32Ty}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+  Value calleeVal = block->getArgument(0);
+  Value argVal = block->getArgument(1);
+
+  // Direct call with one argument: fir.call @target(%arg)
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
+  auto callOp = fir::CallOp::create(builder, loc, callTargetRef,
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{argVal});
+  ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), argVal);
+
+  // Switch to indirect; callee must be inserted at 0, arg preserved.
+  callOp.setCalleeFromCallable(calleeVal);
+
+  EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 2u);
+  EXPECT_EQ(callOp.getOperand(0), calleeVal);
+  EXPECT_EQ(callOp.getOperand(1), argVal);
+  EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(FIRCallInterfaceTest, setCalleeFromCallable_indirectToDirect_withArgs) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto i32Ty = builder.getI32Type();
+  auto funcType = builder.getFunctionType({i32Ty}, {});
+  auto containerType = builder.getFunctionType({funcType, i32Ty}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+  Value calleeVal = block->getArgument(0);
+  Value argVal = block->getArgument(1);
+
+  // Indirect call with one argument: fir.call %callee(%arg)
+  auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{calleeVal, argVal});
+  ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 2u);
+  EXPECT_EQ(callOp.getOperand(0), calleeVal);
+  EXPECT_EQ(callOp.getOperand(1), argVal);
+
+  // Switch to direct; callee operand must be removed, arg preserved.
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "direct_target");
+  callOp.setCalleeFromCallable(callTargetRef);
+
+  EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), argVal);
+  EXPECT_TRUE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+TEST_F(
+    FIRCallInterfaceTest, setCalleeFromCallable_indirectToIndirect_withArgs) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto i32Ty = builder.getI32Type();
+  auto funcType = builder.getFunctionType({i32Ty}, {});
+  auto containerType = builder.getFunctionType({funcType, funcType, i32Ty}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+  Value callee0 = block->getArgument(0);
+  Value callee1 = block->getArgument(1);
+  Value argVal = block->getArgument(2);
+
+  // Indirect call with one argument: fir.call %callee0(%arg)
+  auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{callee0, argVal});
+  ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 2u);
+  EXPECT_EQ(callOp.getOperand(0), callee0);
+  EXPECT_EQ(callOp.getOperand(1), argVal);
+
+  // Switch to other indirect callee; operand 0 updated, arg preserved.
+  callOp.setCalleeFromCallable(callee1);
+
+  EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 2u);
+  EXPECT_EQ(callOp.getOperand(0), callee1);
+  EXPECT_EQ(callOp.getOperand(1), argVal);
+  EXPECT_FALSE(callOp->getAttr(fir::CallOp::getCalleeAttrNameStr()));
+}
+
+static ArrayAttr makeArgAttrs(
+    MLIRContext *ctx, llvm::ArrayRef<DictionaryAttr> dicts) {
+  llvm::SmallVector<Attribute> attrs(dicts.begin(), dicts.end());
+  return ArrayAttr::get(ctx, attrs);
+}
+
+static DictionaryAttr makeTestArgDict(MLIRContext *ctx, StringRef value) {
+  return DictionaryAttr::get(ctx,
+      {NamedAttribute(
+          StringAttr::get(ctx, "test.attr"), StringAttr::get(ctx, value))});
+}
+
+TEST_F(
+    FIRCallInterfaceTest, setCalleeFromCallable_directToDirect_withArgAttrs) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto i32Ty = builder.getI32Type();
+  auto funcType = builder.getFunctionType({i32Ty}, {});
+  auto containerType = builder.getFunctionType({funcType, i32Ty}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+  Value argVal = block->getArgument(1);
+
+  // Direct call with one argument and arg_attrs.
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
+  auto callOp = fir::CallOp::create(builder, loc, callTargetRef,
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{argVal});
+  callOp->setAttr(callOp.getArgAttrsAttrName(),
+      makeArgAttrs(&context, {makeTestArgDict(&context, "arg0")}));
+  ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+
+  // Switch to another direct callee
+  auto newCallTargetRef = FlatSymbolRefAttr::get(&context, "other_target");
+  callOp.setCalleeFromCallable(newCallTargetRef);
+
+  EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
+                .getRootReference()
+                .getValue(),
+      "other_target");
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), argVal);
+  ArrayAttr argAttrs = callOp.getArgAttrsAttr();
+  ASSERT_TRUE(argAttrs);
+  ASSERT_EQ(argAttrs.size(), 1u);
+  EXPECT_EQ(llvm::cast<DictionaryAttr>(argAttrs[0]).get("test.attr"),
+      StringAttr::get(&context, "arg0"));
+}
+
+TEST_F(
+    FIRCallInterfaceTest, setCalleeFromCallable_directToIndirect_withArgAttrs) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto i32Ty = builder.getI32Type();
+  auto funcType = builder.getFunctionType({i32Ty}, {});
+  auto containerType = builder.getFunctionType({funcType, i32Ty}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+  Value calleeVal = block->getArgument(0);
+  Value argVal = block->getArgument(1);
+
+  // Direct call with one argument and arg_attrs for that argument.
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "target");
+  auto callOp = fir::CallOp::create(builder, loc, callTargetRef,
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{argVal});
+  callOp->setAttr(callOp.getArgAttrsAttrName(),
+      makeArgAttrs(&context, {makeTestArgDict(&context, "arg0")}));
+  ASSERT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+
+  // Switch to indirect
+  callOp.setCalleeFromCallable(calleeVal);
+
+  EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 2u);
+  EXPECT_EQ(callOp.getOperand(0), calleeVal);
+  EXPECT_EQ(callOp.getOperand(1), argVal);
+  ArrayAttr argAttrs = callOp.getArgAttrsAttr();
+  ASSERT_TRUE(argAttrs);
+  ASSERT_EQ(argAttrs.size(), 2u);
+  // First entry is empty dict for callee.
+  EXPECT_TRUE(llvm::cast<DictionaryAttr>(argAttrs[0]).empty());
+  // Second entry preserves the argument's attribute.
+  auto argDict = llvm::cast<DictionaryAttr>(argAttrs[1]);
+  EXPECT_EQ(argDict.get("test.attr"), StringAttr::get(&context, "arg0"));
+}
+
+TEST_F(
+    FIRCallInterfaceTest, setCalleeFromCallable_indirectToDirect_withArgAttrs) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto i32Ty = builder.getI32Type();
+  auto funcType = builder.getFunctionType({i32Ty}, {});
+  auto containerType = builder.getFunctionType({funcType, i32Ty}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+  Value calleeVal = block->getArgument(0);
+  Value argVal = block->getArgument(1);
+
+  // Indirect call with callee + one argument
+  auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{calleeVal, argVal});
+  callOp->setAttr(callOp.getArgAttrsAttrName(),
+      makeArgAttrs(&context,
+          {DictionaryAttr::get(&context, {}),
+              makeTestArgDict(&context, "arg0")}));
+  ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 2u);
+
+  // Switch to direct
+  auto callTargetRef = FlatSymbolRefAttr::get(&context, "direct_target");
+  callOp.setCalleeFromCallable(callTargetRef);
+
+  EXPECT_TRUE(isSymbolRef(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 1u);
+  EXPECT_EQ(callOp.getOperand(0), argVal);
+  ArrayAttr argAttrs = callOp.getArgAttrsAttr();
+  ASSERT_TRUE(argAttrs);
+  ASSERT_EQ(argAttrs.size(), 1u);
+  EXPECT_EQ(llvm::cast<DictionaryAttr>(argAttrs[0]).get("test.attr"),
+      StringAttr::get(&context, "arg0"));
+}
+
+TEST_F(FIRCallInterfaceTest,
+    setCalleeFromCallable_indirectToIndirect_withArgAttrs) {
+  OpBuilder builder(&context);
+  auto loc = builder.getUnknownLoc();
+  auto i32Ty = builder.getI32Type();
+  auto funcType = builder.getFunctionType({i32Ty}, {});
+  auto containerType = builder.getFunctionType({funcType, funcType, i32Ty}, {});
+  auto [func, block] =
+      createModuleWithFunction(builder, loc, "container", containerType);
+  Value callee0 = block->getArgument(0);
+  Value callee1 = block->getArgument(1);
+  Value argVal = block->getArgument(2);
+
+  // Indirect call with one argument and arg_attrs
+  auto callOp = fir::CallOp::create(builder, loc, SymbolRefAttr{},
+      llvm::ArrayRef<mlir::Type>{}, ValueRange{callee0, argVal});
+  callOp->setAttr(callOp.getArgAttrsAttrName(),
+      makeArgAttrs(&context,
+          {DictionaryAttr::get(&context, {}),
+              makeTestArgDict(&context, "arg0")}));
+  ASSERT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 2u);
+
+  // Switch to other indirect callee
+  callOp.setCalleeFromCallable(callee1);
+
+  EXPECT_TRUE(isValue(callOp.getCallableForCallee()));
+  EXPECT_EQ(callOp.getNumOperands(), 2u);
+  EXPECT_EQ(callOp.getOperand(0), callee1);
+  EXPECT_EQ(callOp.getOperand(1), argVal);
+  ArrayAttr argAttrs = callOp.getArgAttrsAttr();
+  ASSERT_TRUE(argAttrs);
+  ASSERT_EQ(argAttrs.size(), 2u);
+  EXPECT_TRUE(llvm::cast<DictionaryAttr>(argAttrs[0]).empty());
+  EXPECT_EQ(llvm::cast<DictionaryAttr>(argAttrs[1]).get("test.attr"),
+      StringAttr::get(&context, "arg0"));
+}


        


More information about the flang-commits mailing list