[clang] [CIR] cir.call with scalar return type (PR #135552)

Sirui Mu via cfe-commits cfe-commits at lists.llvm.org
Sun Apr 13 08:40:46 PDT 2025


https://github.com/Lancern created https://github.com/llvm/llvm-project/pull/135552

This PR introduces support for calling functions with a scalar return type to the upstream. This PR also includes an initial version of `CIRGenTargetInfo` and related definitions which are essential for the CIRGen of call ops.

Related to #132487.

>From fcd100485e1a589be20ddd6b9050cdd5e5281fa6 Mon Sep 17 00:00:00 2001
From: Sirui Mu <msrlancern at gmail.com>
Date: Sun, 13 Apr 2025 23:34:21 +0800
Subject: [PATCH] [CIR] cir.call with scalar return type

This patch introduces support for calling functions with a scalar return type to
the upstream. This patch also includes an initial version of CIRGenTargetInfo
and related definitions which are essential for the CIRGen of call ops.
---
 clang/include/clang/CIR/ABIArgInfo.h          | 89 +++++++++++++++++++
 .../CIR/Dialect/Builder/CIRBaseBuilder.h      |  8 +-
 clang/include/clang/CIR/Dialect/IR/CIROps.td  |  6 +-
 clang/include/clang/CIR/MissingFeatures.h     |  4 +
 clang/lib/CIR/CodeGen/ABIInfo.h               | 32 +++++++
 clang/lib/CIR/CodeGen/CIRGenCall.cpp          | 82 ++++++++++++++---
 clang/lib/CIR/CodeGen/CIRGenCall.h            |  4 +
 clang/lib/CIR/CodeGen/CIRGenExpr.cpp          | 24 +++--
 clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp    |  7 +-
 clang/lib/CIR/CodeGen/CIRGenFunction.h        | 15 +++-
 clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h    | 39 +++++++-
 clang/lib/CIR/CodeGen/CIRGenModule.cpp        | 30 +++++++
 clang/lib/CIR/CodeGen/CIRGenModule.h          |  6 ++
 clang/lib/CIR/CodeGen/CIRGenTypes.cpp         | 19 +++-
 clang/lib/CIR/CodeGen/CIRGenTypes.h           |  9 +-
 clang/lib/CIR/CodeGen/CMakeLists.txt          |  1 +
 clang/lib/CIR/CodeGen/TargetInfo.cpp          | 50 +++++++++++
 clang/lib/CIR/CodeGen/TargetInfo.h            | 41 +++++++++
 clang/lib/CIR/Dialect/IR/CIRDialect.cpp       | 29 +++++-
 clang/test/CIR/CodeGen/call.cpp               | 10 +++
 clang/test/CIR/IR/call.cir                    | 14 +++
 21 files changed, 480 insertions(+), 39 deletions(-)
 create mode 100644 clang/include/clang/CIR/ABIArgInfo.h
 create mode 100644 clang/lib/CIR/CodeGen/ABIInfo.h
 create mode 100644 clang/lib/CIR/CodeGen/TargetInfo.cpp
 create mode 100644 clang/lib/CIR/CodeGen/TargetInfo.h

diff --git a/clang/include/clang/CIR/ABIArgInfo.h b/clang/include/clang/CIR/ABIArgInfo.h
new file mode 100644
index 0000000000000..0c2cd85915aa7
--- /dev/null
+++ b/clang/include/clang/CIR/ABIArgInfo.h
@@ -0,0 +1,89 @@
+//==-- ABIArgInfo.h - Abstract info regarding ABI-specific arguments -------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines ABIArgInfo and associated types used by CIR to track information
+// regarding ABI-coerced types for function arguments and return values. This
+// was moved to the common library as it might be used by both CIRGen and
+// passes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_CIR_ABIARGINFO_H
+#define LLVM_CLANG_CIR_ABIARGINFO_H
+
+#include "MissingFeatures.h"
+#include "mlir/IR/Types.h"
+
+namespace cir {
+
+class ABIArgInfo {
+public:
+  enum Kind : uint8_t {
+    /// Pass the argument directly using the normal converted CIR type, or by
+    /// coercing to another specified type (stored in 'typeData', see
+    /// getCoerceToType()). A byte offset into the memory representation is
+    /// tracked in 'directAttr.offset'. Padding types and non-zero offsets are
+    /// not modeled yet: the constructor zero-initializes the offset and
+    /// alignment, and nothing in this class changes them afterwards.
+    Direct,
+
+    /// Ignore the argument (treat as void). Useful for void and empty
+    /// structs.
+    Ignore,
+  };
+
+private:
+  mlir::Type typeData;
+  struct DirectAttrInfo {
+    unsigned offset;
+    unsigned align;
+  };
+  union {
+    DirectAttrInfo directAttr;
+  };
+  Kind theKind;
+
+public:
+  ABIArgInfo(Kind k = Direct) : directAttr{0, 0}, theKind(k) {}
+
+  static ABIArgInfo getDirect(mlir::Type ty = nullptr) {
+    ABIArgInfo info(Direct);
+    info.setCoerceToType(ty);
+    return info;
+  }
+
+  static ABIArgInfo getIgnore() { return ABIArgInfo(Ignore); }
+
+  Kind getKind() const { return theKind; }
+  bool isDirect() const { return theKind == Direct; }
+  bool isIgnore() const { return theKind == Ignore; }
+
+  bool canHaveCoerceToType() const {
+    assert(!cir::MissingFeatures::abiArgInfo());
+    return isDirect();
+  }
+
+  unsigned getDirectOffset() const {
+    assert(!cir::MissingFeatures::abiArgInfo());
+    return directAttr.offset;
+  }
+
+  mlir::Type getCoerceToType() const {
+    assert(canHaveCoerceToType() && "invalid kind!");
+    return typeData;
+  }
+
+  void setCoerceToType(mlir::Type ty) {
+    assert(canHaveCoerceToType() && "invalid kind!");
+    typeData = ty;
+  }
+};
+
+} // namespace cir
+
+#endif // LLVM_CLANG_CIR_ABIARGINFO_H
diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
index 68a4505ca7a5a..a24006810c1f5 100644
--- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
+++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
@@ -205,13 +205,15 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
   // Call operators
   //===--------------------------------------------------------------------===//
 
-  cir::CallOp createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee) {
-    auto op = create<cir::CallOp>(loc, callee);
+  cir::CallOp createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee,
+                           mlir::Type returnType = cir::VoidType()) {
+    auto op = create<cir::CallOp>(loc, callee, /*resType=*/returnType);
     return op;
   }
 
   cir::CallOp createCallOp(mlir::Location loc, cir::FuncOp callee) {
-    return createCallOp(loc, mlir::SymbolRefAttr::get(callee));
+    return createCallOp(loc, mlir::SymbolRefAttr::get(callee),
+                        callee.getFunctionType().getReturnType());
   }
 
   //===--------------------------------------------------------------------===//
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 0d3c2065cd58c..5ba4b33dc1a12 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1408,10 +1408,14 @@ def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
     ```
   }];
 
+  let results = (outs Optional<CIR_AnyType>:$result);
   let arguments = commonArgs;
 
-  let builders = [OpBuilder<(ins "mlir::SymbolRefAttr":$callee), [{
+  let builders = [OpBuilder<(ins "mlir::SymbolRefAttr":$callee,
+                                 "mlir::Type":$resType), [{
       $_state.addAttribute("callee", callee);
+      if (resType && !isa<VoidType>(resType))
+        $_state.addTypes(resType);
     }]>];
 }
 
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index d6a28d4324b32..f692dc661e9d5 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -103,6 +103,8 @@ struct MissingFeatures {
 
   // Misc
   static bool cxxABI() { return false; }
+  static bool cirgenABIInfo() { return false; }
+  static bool abiArgInfo() { return false; }
   static bool tryEmitAsConstant() { return false; }
   static bool constructABIArgDirectExtend() { return false; }
   static bool opGlobalViewAttr() { return false; }
@@ -121,6 +123,8 @@ struct MissingFeatures {
   static bool fpConstraints() { return false; }
   static bool sanitizers() { return false; }
   static bool addHeapAllocSiteMetadata() { return false; }
+  static bool targetCIRGenInfoArch() { return false; }
+  static bool targetCIRGenInfoOS() { return false; }
   static bool targetCodeGenInfoGetNullPointer() { return false; }
   static bool loopInfoStack() { return false; }
   static bool requiresCleanups() { return false; }
diff --git a/clang/lib/CIR/CodeGen/ABIInfo.h b/clang/lib/CIR/CodeGen/ABIInfo.h
new file mode 100644
index 0000000000000..157e80f67a67c
--- /dev/null
+++ b/clang/lib/CIR/CodeGen/ABIInfo.h
@@ -0,0 +1,32 @@
+//===----- ABIInfo.h - ABI information access & encapsulation ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_CIR_ABIINFO_H
+#define LLVM_CLANG_LIB_CIR_ABIINFO_H
+
+namespace clang::CIRGen {
+
+class CIRGenFunctionInfo;
+class CIRGenTypes;
+
+class ABIInfo {
+  ABIInfo() = delete;
+
+public:
+  CIRGenTypes &cgt;
+
+  ABIInfo(CIRGenTypes &cgt) : cgt(cgt) {}
+
+  virtual ~ABIInfo();
+
+  virtual void computeInfo(CIRGenFunctionInfo &funcInfo) const = 0;
+};
+
+} // namespace clang::CIRGen
+
+#endif // LLVM_CLANG_LIB_CIR_ABIINFO_H
diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
index 1a936458782ea..811750ebfc8b4 100644
--- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
@@ -18,9 +18,12 @@
 using namespace clang;
 using namespace clang::CIRGen;
 
-CIRGenFunctionInfo *CIRGenFunctionInfo::create() {
-  // For now we just create an empty CIRGenFunctionInfo.
-  CIRGenFunctionInfo *fi = new CIRGenFunctionInfo();
+CIRGenFunctionInfo *CIRGenFunctionInfo::create(CanQualType resultType) {
+  void *buffer = operator new(totalSizeToAlloc<ArgInfo>(1));
+
+  CIRGenFunctionInfo *fi = new (buffer) CIRGenFunctionInfo();
+  fi->getArgsBuffer()[0].type = resultType;
+
   return fi;
 }
 
@@ -29,13 +32,29 @@ CIRGenCallee CIRGenCallee::prepareConcreteCallee(CIRGenFunction &cgf) const {
   return *this;
 }
 
-static const CIRGenFunctionInfo &arrangeFreeFunctionLikeCall(CIRGenTypes &cgt) {
+static const CIRGenFunctionInfo &
+arrangeFreeFunctionLikeCall(CIRGenTypes &cgt, CIRGenModule &cgm,
+                            const FunctionType *fnType) {
+  if (const auto *proto = dyn_cast<FunctionProtoType>(fnType)) {
+    if (proto->isVariadic())
+      cgm.errorNYI("call to variadic function");
+    if (proto->hasExtParameterInfos())
+      cgm.errorNYI("call to functions with extra parameter info");
+  } else if (isa<FunctionNoProtoType>(fnType)) {
+    cgm.errorNYI("call to function without a prototype");
+  }
+
   assert(!cir::MissingFeatures::opCallArgs());
-  return cgt.arrangeCIRFunctionInfo();
+
+  CanQualType retType = fnType->getReturnType()
+                            ->getCanonicalTypeUnqualified()
+                            .getUnqualifiedType();
+  return cgt.arrangeCIRFunctionInfo(retType);
 }
 
-const CIRGenFunctionInfo &CIRGenTypes::arrangeFreeFunctionCall() {
-  return arrangeFreeFunctionLikeCall(*this);
+const CIRGenFunctionInfo &
+CIRGenTypes::arrangeFreeFunctionCall(const FunctionType *fnType) {
+  return arrangeFreeFunctionLikeCall(*this, cgm, fnType);
 }
 
 static cir::CIRCallOpInterface emitCallLikeOp(CIRGenFunction &cgf,
@@ -54,8 +73,12 @@ static cir::CIRCallOpInterface emitCallLikeOp(CIRGenFunction &cgf,
 
 RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
                                 const CIRGenCallee &callee,
+                                ReturnValueSlot returnValue,
                                 cir::CIRCallOpInterface *callOp,
                                 mlir::Location loc) {
+  QualType retTy = funcInfo.getReturnType();
+  const cir::ABIArgInfo &retInfo = funcInfo.getReturnInfo();
+
   assert(!cir::MissingFeatures::opCallArgs());
   assert(!cir::MissingFeatures::emitLifetimeMarkers());
 
@@ -87,9 +110,48 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
   assert(!cir::MissingFeatures::opCallMustTail());
   assert(!cir::MissingFeatures::opCallReturn());
 
-  // For now we just return nothing because we don't have support for return
-  // values yet.
-  RValue ret = RValue::get(nullptr);
+  RValue ret;
+  switch (retInfo.getKind()) {
+  case cir::ABIArgInfo::Direct: {
+    mlir::Type retCIRTy = convertType(retTy);
+    if (retInfo.getCoerceToType() == retCIRTy &&
+        retInfo.getDirectOffset() == 0) {
+      switch (getEvaluationKind(retTy)) {
+      case cir::TEK_Scalar: {
+        mlir::ResultRange results = theCall->getOpResults();
+        assert(results.size() == 1 && "unexpected number of returns");
+
+        // If the argument doesn't match, perform a bitcast to coerce it. This
+        // can happen due to trivial type mismatches.
+        if (results[0].getType() != retCIRTy) {
+          cgm.errorNYI(loc, "bitcast on function return value");
+        }
+
+        mlir::Region *region = builder.getBlock()->getParent();
+        if (region != theCall->getParentRegion()) {
+          cgm.errorNYI(loc, "function calls with cleanup");
+        }
+
+        return RValue::get(results[0]);
+      }
+      default:
+        cgm.errorNYI(loc,
+                     "unsupported evaluation kind of function call result");
+      }
+    } else {
+      cgm.errorNYI(loc, "unsupported function call form");
+    }
+
+    break;
+  }
+  case cir::ABIArgInfo::Ignore:
+    // If we are ignoring an argument that had a result, make sure to construct
+    // the appropriate return value for our caller.
+    ret = getUndefRValue(retTy);
+    break;
+  default:
+    cgm.errorNYI(loc, "unsupported return value information");
+  }
 
   return ret;
 }
diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.h b/clang/lib/CIR/CodeGen/CIRGenCall.h
index 76fefdca9e45e..4427fda863d7e 100644
--- a/clang/lib/CIR/CodeGen/CIRGenCall.h
+++ b/clang/lib/CIR/CodeGen/CIRGenCall.h
@@ -81,6 +81,10 @@ struct CallArg {};
 
 class CallArgList : public llvm::SmallVector<CallArg, 8> {};
 
+/// Contains the address where the return value of a function can be stored, and
+/// whether the address is volatile or not.
+class ReturnValueSlot {};
+
 } // namespace clang::CIRGen
 
 #endif // CLANG_LIB_CODEGEN_CIRGENCALL_H
diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
index f0732a8ea60af..550231132ab53 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
@@ -662,23 +662,36 @@ static CIRGenCallee emitDirectCallee(CIRGenModule &cgm, GlobalDecl gd) {
   return CIRGenCallee::forDirect(callee, gd);
 }
 
+RValue CIRGenFunction::getUndefRValue(QualType ty) {
+  if (ty->isVoidType())
+    return RValue::get(nullptr);
+
+  cgm.errorNYI("unsupported type for undef rvalue");
+  return RValue::get(nullptr);
+}
+
 RValue CIRGenFunction::emitCall(clang::QualType calleeTy,
                                 const CIRGenCallee &callee,
-                                const clang::CallExpr *e) {
+                                const clang::CallExpr *e,
+                                ReturnValueSlot returnValue) {
   // Get the actual function type. The callee type will always be a pointer to
   // function type or a block pointer type.
   assert(calleeTy->isFunctionPointerType() &&
          "Callee must have function pointer type!");
 
   calleeTy = getContext().getCanonicalType(calleeTy);
+  auto pointeeTy = cast<PointerType>(calleeTy)->getPointeeType();
 
   if (getLangOpts().CPlusPlus)
     assert(!cir::MissingFeatures::sanitizers());
 
+  const auto *fnType = cast<FunctionType>(pointeeTy);
+
   assert(!cir::MissingFeatures::sanitizers());
   assert(!cir::MissingFeatures::opCallArgs());
 
-  const CIRGenFunctionInfo &funcInfo = cgm.getTypes().arrangeFreeFunctionCall();
+  const CIRGenFunctionInfo &funcInfo =
+      cgm.getTypes().arrangeFreeFunctionCall(fnType);
 
   assert(!cir::MissingFeatures::opCallNoPrototypeFunc());
   assert(!cir::MissingFeatures::opCallChainCall());
@@ -687,7 +700,7 @@ RValue CIRGenFunction::emitCall(clang::QualType calleeTy,
 
   cir::CIRCallOpInterface callOp;
   RValue callResult =
-      emitCall(funcInfo, callee, &callOp, getLoc(e->getExprLoc()));
+      emitCall(funcInfo, callee, returnValue, &callOp, getLoc(e->getExprLoc()));
 
   assert(!cir::MissingFeatures::generateDebugInfo());
 
@@ -713,7 +726,8 @@ CIRGenCallee CIRGenFunction::emitCallee(const clang::Expr *e) {
   return {};
 }
 
-RValue CIRGenFunction::emitCallExpr(const clang::CallExpr *e) {
+RValue CIRGenFunction::emitCallExpr(const clang::CallExpr *e,
+                                    ReturnValueSlot returnValue) {
   assert(!cir::MissingFeatures::objCBlocks());
 
   if (isa<CXXMemberCallExpr>(e)) {
@@ -745,7 +759,7 @@ RValue CIRGenFunction::emitCallExpr(const clang::CallExpr *e) {
   }
   assert(!cir::MissingFeatures::opCallPseudoDtor());
 
-  return emitCall(e->getCallee()->getType(), callee, e);
+  return emitCall(e->getCallee()->getType(), callee, e, returnValue);
 }
 
 /// Emit code to compute the specified expression, ignoring the result.
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index 38104f8533c7d..3dae26dc86f85 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -1519,11 +1519,8 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *ce) {
 }
 
 mlir::Value ScalarExprEmitter::VisitCallExpr(const CallExpr *e) {
-  if (e->getCallReturnType(cgf.getContext())->isReferenceType()) {
-    cgf.getCIRGenModule().errorNYI(
-        e->getSourceRange(), "call to function with non-void return type");
-    return {};
-  }
+  if (e->getCallReturnType(cgf.getContext())->isReferenceType())
+    return emitLoadOfLValue(e);
 
   auto v = cgf.emitCallExpr(e).getScalarVal();
   assert(!cir::MissingFeatures::emitLValueAlignmentAssumption());
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h
index a96d277d0bc0b..01abd84ce1c85 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -269,6 +269,12 @@ class CIRGenFunction : public CIRGenTypeCache {
     return LValue::makeAddr(addr, ty, baseInfo);
   }
 
+  /// Get an appropriate 'undef' rvalue for the given type.
+  /// TODO: What's the equivalent for MLIR? Currently this is only used for
+  /// void types, so it just returns RValue::get(nullptr); other types will
+  /// need to be addressed later.
+  RValue getUndefRValue(clang::QualType ty);
+
   cir::FuncOp generateCode(clang::GlobalDecl gd, cir::FuncOp fn,
                            cir::FuncType funcType);
 
@@ -451,11 +457,12 @@ class CIRGenFunction : public CIRGenTypeCache {
   mlir::LogicalResult emitBreakStmt(const clang::BreakStmt &s);
 
   RValue emitCall(const CIRGenFunctionInfo &funcInfo,
-                  const CIRGenCallee &callee, cir::CIRCallOpInterface *callOp,
-                  mlir::Location loc);
+                  const CIRGenCallee &callee, ReturnValueSlot returnValue,
+                  cir::CIRCallOpInterface *callOp, mlir::Location loc);
   RValue emitCall(clang::QualType calleeTy, const CIRGenCallee &callee,
-                  const clang::CallExpr *e);
-  RValue emitCallExpr(const clang::CallExpr *e);
+                  const clang::CallExpr *e, ReturnValueSlot returnValue);
+  RValue emitCallExpr(const clang::CallExpr *e,
+                      ReturnValueSlot returnValue = ReturnValueSlot());
   CIRGenCallee emitCallee(const clang::Expr *e);
 
   mlir::LogicalResult emitContinueStmt(const clang::ContinueStmt &s);
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h b/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h
index da73e7a7a9059..c4a2b238c96ae 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h
@@ -15,18 +15,49 @@
 #ifndef LLVM_CLANG_CIR_CIRGENFUNCTIONINFO_H
 #define LLVM_CLANG_CIR_CIRGENFUNCTIONINFO_H
 
+#include "clang/AST/CanonicalType.h"
+#include "clang/CIR/ABIArgInfo.h"
 #include "llvm/ADT/FoldingSet.h"
+#include "llvm/Support/TrailingObjects.h"
 
 namespace clang::CIRGen {
 
-class CIRGenFunctionInfo final : public llvm::FoldingSetNode {
+struct CIRGenFunctionInfoArgInfo {
+  CanQualType type;
+  cir::ABIArgInfo info;
+};
+
+class CIRGenFunctionInfo final
+    : public llvm::FoldingSetNode,
+      private llvm::TrailingObjects<CIRGenFunctionInfo,
+                                    CIRGenFunctionInfoArgInfo> {
+  using ArgInfo = CIRGenFunctionInfoArgInfo;
+
+  ArgInfo *getArgsBuffer() { return getTrailingObjects<ArgInfo>(); }
+  const ArgInfo *getArgsBuffer() const { return getTrailingObjects<ArgInfo>(); }
+
 public:
-  static CIRGenFunctionInfo *create();
+  static CIRGenFunctionInfo *create(CanQualType resultType);
+
+  void operator delete(void *p) { ::operator delete(p); }
+
+  // llvm::TrailingObjects is a private base of this class, so it is befriended
+  // to let it reach the co-allocated ArgInfo storage it manages.
+  friend class TrailingObjects;
 
   // This function has to be CamelCase because llvm::FoldingSet requires so.
   // NOLINTNEXTLINE(readability-identifier-naming)
-  static void Profile(llvm::FoldingSetNodeID &id) {
-    // We don't have anything to profile yet.
+  static void Profile(llvm::FoldingSetNodeID &id, CanQualType resultType) {
+    resultType.Profile(id);
+  }
+
+  void Profile(llvm::FoldingSetNodeID &id) { getReturnType().Profile(id); }
+
+  CanQualType getReturnType() const { return getArgsBuffer()[0].type; }
+
+  cir::ABIArgInfo &getReturnInfo() { return getArgsBuffer()[0].info; }
+  const cir::ABIArgInfo &getReturnInfo() const {
+    return getArgsBuffer()[0].info;
   }
 };
 
diff --git a/clang/lib/CIR/CodeGen/CIRGenModule.cpp b/clang/lib/CIR/CodeGen/CIRGenModule.cpp
index fd11523ebba61..cbfaa3d89836b 100644
--- a/clang/lib/CIR/CodeGen/CIRGenModule.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenModule.cpp
@@ -129,6 +129,36 @@ CharUnits CIRGenModule::getNaturalTypeAlignment(QualType t,
   return alignment;
 }
 
+const TargetCIRGenInfo &CIRGenModule::getTargetCIRGenInfo() {
+  if (theTargetCIRGenInfo)
+    return *theTargetCIRGenInfo;
+
+  const llvm::Triple &triple = getTarget().getTriple();
+  switch (triple.getArch()) {
+  default:
+    assert(!cir::MissingFeatures::targetCIRGenInfoArch());
+    errorNYI("unsupported target arch");
+
+    // Unsupported architectures currently fall through to x86_64.
+    [[fallthrough]];
+
+  case llvm::Triple::x86_64: {
+    switch (triple.getOS()) {
+    default:
+      assert(!cir::MissingFeatures::targetCIRGenInfoOS());
+      errorNYI("unsupported target OS");
+
+      // Unsupported OSes currently fall through to the x86_64 Linux case.
+      [[fallthrough]];
+
+    case llvm::Triple::Linux:
+      theTargetCIRGenInfo = createX8664TargetCIRGenInfo(genTypes);
+      return *theTargetCIRGenInfo;
+    }
+  }
+  }
+}
+
 mlir::Location CIRGenModule::getLoc(SourceLocation cLoc) {
   assert(cLoc.isValid() && "expected valid source location");
   const SourceManager &sm = astContext.getSourceManager();
diff --git a/clang/lib/CIR/CodeGen/CIRGenModule.h b/clang/lib/CIR/CodeGen/CIRGenModule.h
index 764ad1d7592aa..1e0d6623c4f40 100644
--- a/clang/lib/CIR/CodeGen/CIRGenModule.h
+++ b/clang/lib/CIR/CodeGen/CIRGenModule.h
@@ -21,6 +21,7 @@
 #include "clang/AST/CharUnits.h"
 #include "clang/CIR/Dialect/IR/CIRDialect.h"
 
+#include "TargetInfo.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/MLIRContext.h"
@@ -60,6 +61,8 @@ class CIRGenModule : public CIRGenTypeCache {
   ~CIRGenModule() = default;
 
 private:
+  mutable std::unique_ptr<TargetCIRGenInfo> theTargetCIRGenInfo;
+
   CIRGenBuilderTy builder;
 
   /// Hold Clang AST information.
@@ -86,6 +89,7 @@ class CIRGenModule : public CIRGenTypeCache {
   mlir::ModuleOp getModule() const { return theModule; }
   CIRGenBuilderTy &getBuilder() { return builder; }
   clang::ASTContext &getASTContext() const { return astContext; }
+  const clang::TargetInfo &getTarget() const { return target; }
   const clang::CodeGenOptions &getCodeGenOpts() const { return codeGenOpts; }
   CIRGenTypes &getTypes() { return genTypes; }
   const clang::LangOptions &getLangOpts() const { return langOpts; }
@@ -116,6 +120,8 @@ class CIRGenModule : public CIRGenTypeCache {
   getAddrOfGlobalVar(const VarDecl *d, mlir::Type ty = {},
                      ForDefinition_t isForDefinition = NotForDefinition);
 
+  const TargetCIRGenInfo &getTargetCIRGenInfo();
+
   /// Helpers to convert the presumed location of Clang's SourceLocation to an
   /// MLIR Location.
   mlir::Location getLoc(clang::SourceLocation cLoc);
diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
index a5978a4ad9085..ccc5b5fc070d2 100644
--- a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
@@ -14,7 +14,8 @@ using namespace clang::CIRGen;
 
 CIRGenTypes::CIRGenTypes(CIRGenModule &genModule)
     : cgm(genModule), astContext(genModule.getASTContext()),
-      builder(cgm.getBuilder()) {}
+      builder(cgm.getBuilder()),
+      theABIInfo(cgm.getTargetCIRGenInfo().getABIInfo()) {}
 
 CIRGenTypes::~CIRGenTypes() {}
 
@@ -290,10 +291,11 @@ bool CIRGenTypes::isZeroInitializable(clang::QualType t) {
   return true;
 }
 
-const CIRGenFunctionInfo &CIRGenTypes::arrangeCIRFunctionInfo() {
+const CIRGenFunctionInfo &
+CIRGenTypes::arrangeCIRFunctionInfo(CanQualType returnType) {
   // Lookup or create unique function info.
   llvm::FoldingSetNodeID id;
-  CIRGenFunctionInfo::Profile(id);
+  CIRGenFunctionInfo::Profile(id, returnType);
 
   void *insertPos = nullptr;
   CIRGenFunctionInfo *fi = functionInfos.FindNodeOrInsertPos(id, insertPos);
@@ -303,7 +305,7 @@ const CIRGenFunctionInfo &CIRGenTypes::arrangeCIRFunctionInfo() {
   assert(!cir::MissingFeatures::opCallCallConv());
 
   // Construction the function info. We co-allocate the ArgInfos.
-  fi = CIRGenFunctionInfo::create();
+  fi = CIRGenFunctionInfo::create(returnType);
   functionInfos.InsertNode(fi, insertPos);
 
   bool inserted = functionsBeingProcessed.insert(fi).second;
@@ -311,6 +313,15 @@ const CIRGenFunctionInfo &CIRGenTypes::arrangeCIRFunctionInfo() {
   assert(inserted && "Are functions being processed recursively?");
 
   assert(!cir::MissingFeatures::opCallCallConv());
+  getABIInfo().computeInfo(*fi);
+
+  // If the computed return value info can carry a coerce-to type but none was
+  // specified, default it to the converted CIR return type. Argument infos are
+  // not computed yet (see the opCallArgs missing-feature assert below).
+  cir::ABIArgInfo &retInfo = fi->getReturnInfo();
+  if (retInfo.canHaveCoerceToType() && retInfo.getCoerceToType() == nullptr)
+    retInfo.setCoerceToType(convertType(fi->getReturnType()));
+
   assert(!cir::MissingFeatures::opCallArgs());
 
   bool erased = functionsBeingProcessed.erase(fi);
diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.h b/clang/lib/CIR/CodeGen/CIRGenTypes.h
index 60661ba0a3beb..59548f5c1f0b3 100644
--- a/clang/lib/CIR/CodeGen/CIRGenTypes.h
+++ b/clang/lib/CIR/CodeGen/CIRGenTypes.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_CLANG_LIB_CODEGEN_CODEGENTYPES_H
 #define LLVM_CLANG_LIB_CODEGEN_CODEGENTYPES_H
 
+#include "ABIInfo.h"
 #include "CIRGenFunctionInfo.h"
 #include "clang/CIR/Dialect/IR/CIRTypes.h"
 
@@ -45,6 +46,8 @@ class CIRGenTypes {
   clang::ASTContext &astContext;
   CIRGenBuilderTy &builder;
 
+  const ABIInfo &theABIInfo;
+
   /// Hold memoized CIRGenFunctionInfo results
   llvm::FoldingSet<CIRGenFunctionInfo> functionInfos;
 
@@ -69,6 +72,8 @@ class CIRGenTypes {
 
   mlir::MLIRContext &getMLIRContext() const;
 
+  const ABIInfo &getABIInfo() const { return theABIInfo; }
+
   /// Convert a Clang type into a mlir::Type.
   mlir::Type convertType(clang::QualType type);
 
@@ -83,9 +88,9 @@ class CIRGenTypes {
   /// LLVM zeroinitializer.
   bool isZeroInitializable(clang::QualType ty);
 
-  const CIRGenFunctionInfo &arrangeFreeFunctionCall();
+  const CIRGenFunctionInfo &arrangeFreeFunctionCall(const FunctionType *fnType);
 
-  const CIRGenFunctionInfo &arrangeCIRFunctionInfo();
+  const CIRGenFunctionInfo &arrangeCIRFunctionInfo(CanQualType returnType);
 };
 
 } // namespace clang::CIRGen
diff --git a/clang/lib/CIR/CodeGen/CMakeLists.txt b/clang/lib/CIR/CodeGen/CMakeLists.txt
index dc18f7f2af160..59834eac3049f 100644
--- a/clang/lib/CIR/CodeGen/CMakeLists.txt
+++ b/clang/lib/CIR/CodeGen/CMakeLists.txt
@@ -21,6 +21,7 @@ add_clang_library(clangCIR
   CIRGenStmt.cpp
   CIRGenStmtOpenACC.cpp
   CIRGenTypes.cpp
+  TargetInfo.cpp
 
   DEPENDS
   MLIRCIR
diff --git a/clang/lib/CIR/CodeGen/TargetInfo.cpp b/clang/lib/CIR/CodeGen/TargetInfo.cpp
new file mode 100644
index 0000000000000..8b89f8dc8b431
--- /dev/null
+++ b/clang/lib/CIR/CodeGen/TargetInfo.cpp
@@ -0,0 +1,50 @@
+#include "TargetInfo.h"
+#include "ABIInfo.h"
+#include "CIRGenFunctionInfo.h"
+#include "clang/CIR/MissingFeatures.h"
+
+using namespace clang;
+using namespace clang::CIRGen;
+
+static bool testIfIsVoidTy(QualType ty) {
+  const auto *builtinTy = ty->getAs<BuiltinType>();
+  return builtinTy && builtinTy->getKind() == BuiltinType::Void;
+}
+
+namespace {
+
+class X8664ABIInfo : public ABIInfo {
+public:
+  X8664ABIInfo(CIRGenTypes &cgt) : ABIInfo(cgt) {}
+
+  void computeInfo(CIRGenFunctionInfo &funcInfo) const override;
+};
+
+class X8664TargetCIRGenInfo : public TargetCIRGenInfo {
+public:
+  X8664TargetCIRGenInfo(CIRGenTypes &cgt)
+      : TargetCIRGenInfo(std::make_unique<X8664ABIInfo>(cgt)) {}
+};
+
+} // namespace
+
+void X8664ABIInfo::computeInfo(CIRGenFunctionInfo &funcInfo) const {
+  // Top level CIR has unlimited arguments and return types. Lowering for ABI
+  // specific concerns should happen during a lowering phase. Assume everything
+  // is direct for now.
+  assert(!cir::MissingFeatures::opCallArgs());
+
+  CanQualType retTy = funcInfo.getReturnType();
+  if (testIfIsVoidTy(retTy))
+    funcInfo.getReturnInfo() = cir::ABIArgInfo::getIgnore();
+  else
+    funcInfo.getReturnInfo() =
+        cir::ABIArgInfo::getDirect(cgt.convertType(retTy));
+}
+
+std::unique_ptr<TargetCIRGenInfo>
+clang::CIRGen::createX8664TargetCIRGenInfo(CIRGenTypes &cgt) {
+  return std::make_unique<X8664TargetCIRGenInfo>(cgt);
+}
+
+ABIInfo::~ABIInfo() noexcept = default;
diff --git a/clang/lib/CIR/CodeGen/TargetInfo.h b/clang/lib/CIR/CodeGen/TargetInfo.h
new file mode 100644
index 0000000000000..70590c3c65ebb
--- /dev/null
+++ b/clang/lib/CIR/CodeGen/TargetInfo.h
@@ -0,0 +1,41 @@
+//===---- TargetInfo.h - Encapsulate target details -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These classes wrap the information about a call or function definition used
+// to handle ABI compliancy.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_CIR_TARGETINFO_H
+#define LLVM_CLANG_LIB_CIR_TARGETINFO_H
+
+#include "ABIInfo.h"
+#include "CIRGenTypes.h"
+
+#include <memory>
+#include <utility>
+
+namespace clang::CIRGen {
+
+class TargetCIRGenInfo {
+  std::unique_ptr<ABIInfo> info;
+
+public:
+  TargetCIRGenInfo(std::unique_ptr<ABIInfo> info) : info(std::move(info)) {}
+
+  virtual ~TargetCIRGenInfo() = default;
+
+  /// Returns ABI info helper for the target.
+  const ABIInfo &getABIInfo() const { return *info; }
+};
+
+std::unique_ptr<TargetCIRGenInfo> createX8664TargetCIRGenInfo(CIRGenTypes &cgt);
+
+} // namespace clang::CIRGen
+
+#endif // LLVM_CLANG_LIB_CIR_TARGETINFO_H
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index f3e5e572653da..5e910f92002d0 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -449,6 +449,7 @@ OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
 static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
                                          mlir::OperationState &result) {
   mlir::FlatSymbolRefAttr calleeAttr;
+  llvm::ArrayRef<mlir::Type> allResultTypes;
 
   if (!parser.parseOptionalAttribute(calleeAttr, "callee", result.attributes)
            .has_value())
@@ -473,6 +474,9 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
   if (parser.parseType(opsFnTy))
     return mlir::failure();
 
+  allResultTypes = opsFnTy.getResults();
+  result.addTypes(allResultTypes);
+
   return mlir::success();
 }
 
@@ -515,9 +519,32 @@ verifyCallCommInSymbolUses(mlir::Operation *op,
     return op->emitOpError() << "'" << fnAttr.getValue()
                              << "' does not reference a valid function";
 
-  // TODO(cir): verify function arguments and return type
+  assert(isa<cir::CIRCallOpInterface>(op) &&
+         "expected CIR call interface to be always available");
+
+  // Verify that the operand and result types match the callee. Argument
+  // checking is not implemented yet (see the TODO below).
+  auto fnType = fn.getFunctionType();
+
+  // TODO(cir): verify function arguments
   assert(!cir::MissingFeatures::opCallArgs());
 
+  // Void function must not return any results.
+  if (fnType.hasVoidReturn() && op->getNumResults() != 0)
+    return op->emitOpError("callee returns void but call has results");
+
+  // Non-void function calls must return exactly one result.
+  if (!fnType.hasVoidReturn() && op->getNumResults() != 1)
+    return op->emitOpError("incorrect number of results for callee");
+
+  // The call's result type must match the callee's return type.
+  if (!fnType.hasVoidReturn() &&
+      op->getResultTypes().front() != fnType.getReturnType()) {
+    return op->emitOpError("result type mismatch: expected ")
+           << fnType.getReturnType() << ", but provided "
+           << op->getResult(0).getType();
+  }
+
   return mlir::success();
 }
 
diff --git a/clang/test/CIR/CodeGen/call.cpp b/clang/test/CIR/CodeGen/call.cpp
index e69b347c2ca99..9082fbc9f6860 100644
--- a/clang/test/CIR/CodeGen/call.cpp
+++ b/clang/test/CIR/CodeGen/call.cpp
@@ -7,3 +7,13 @@ void f2() {
 
 // CHECK-LABEL: cir.func @f2
 // CHECK:         cir.call @f1() : () -> ()
+
+int f3();
+int f4() {
+  int x = f3();
+  return x;
+}
+
+// CHECK-LABEL: cir.func @f4() -> !s32i
+// CHECK:         %[[#x:]] = cir.call @f3() : () -> !s32i
+// CHECK-NEXT:    cir.store %[[#x]], %{{.+}} : !s32i, !cir.ptr<!s32i>
diff --git a/clang/test/CIR/IR/call.cir b/clang/test/CIR/IR/call.cir
index 8630bb80eb14a..3c3fbf3d4d987 100644
--- a/clang/test/CIR/IR/call.cir
+++ b/clang/test/CIR/IR/call.cir
@@ -1,5 +1,7 @@
 // RUN: cir-opt %s | FileCheck %s
 
+!s32i = !cir.int<s, 32>
+
 module {
 
 cir.func @f1()
@@ -14,4 +16,16 @@ cir.func @f2() {
 // CHECK-NEXT:   cir.return
 // CHECK-NEXT: }
 
+cir.func @f3() -> !s32i
+
+cir.func @f4() -> !s32i {
+  %0 = cir.call @f3() : () -> !s32i
+  cir.return %0 : !s32i
+}
+
+// CHECK:      cir.func @f4() -> !s32i {
+// CHECK-NEXT:   %[[#x:]] = cir.call @f3() : () -> !s32i
+// CHECK-NEXT:   cir.return %[[#x]] : !s32i
+// CHECK-NEXT: }
+
 }



More information about the cfe-commits mailing list