[Mlir-commits] [clang] [mlir] [CIR] Add Direct coerce-in-registers + cir.reinterpret_cast op (PR #195879)
llvmlistbot at llvm.org
Tue May 5 09:29:14 PDT 2026
https://github.com/adams381 created https://github.com/llvm/llvm-project/pull/195879
Fourth PR in the split of #192119/#192124. Implements the
Direct-with-coercion path in CallConvLowering and addresses
andykaylor's five inline review comments from the original PR.
The new cir.reinterpret_cast op performs same-bit-width, in-register
reinterpretation (e.g. vector<2 x float> <-> complex<float>).
emitCoercion uses it, instead of going through memory, when source
and destination differ only in vector-vs-non-vector shape and have
identical bit width. For everything else (records, or mismatched
shapes) the helper still emits alloca/store/ptr-cast/load.
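As a rough illustration of the path choice described above, the
following self-contained C++ sketch models the decision with a toy type
descriptor. Shape, TypeDesc, CoercionPath and chooseCoercionPath are
hypothetical names made up for this example; they are not the CIR API,
and the real emitCoercion may check additional conditions.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a CIR type; not the real cir::Type API.
enum class Shape { Scalar, Vector, Complex, Record };

struct TypeDesc {
  Shape shape;
  uint64_t bitWidth;
};

enum class CoercionPath { ReinterpretCast, ThroughMemory };

// Assumed rule, paraphrasing the description: take the in-register path
// only when neither side is a record, exactly one side is a vector, and
// both sides have the same bit width. Everything else goes through a
// temporary in memory (alloca/store/ptr-cast/load).
CoercionPath chooseCoercionPath(const TypeDesc &src, const TypeDesc &dst) {
  bool anyRecord = src.shape == Shape::Record || dst.shape == Shape::Record;
  bool oneVector =
      (src.shape == Shape::Vector) != (dst.shape == Shape::Vector);
  if (!anyRecord && oneVector && src.bitWidth == dst.bitWidth)
    return CoercionPath::ReinterpretCast;
  return CoercionPath::ThroughMemory;
}
```

Under these assumptions, a 64-bit vector coerced to a 64-bit complex
takes the in-register path, while any record, or any pair with differing
widths, falls back to memory.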
Andy's comments, in order:
- Temporary alloca alignment is now max(srcAlign, dstAlign) from
DataLayout instead of being hardcoded.
- The alloca lives in the entry block via InsertionGuard so it
composes with HoistAllocas regardless of pipeline order.
- isVolatile is kept as UnitAttr absence, with an inline comment.
- vector<->complex now uses cir.reinterpret_cast.
- Three new .cir tests cover the memory path.
CallConvLowering is now split into three phases
(function-def coercion / call-site rewriting / Ignore cleanup)
because block-argument type changes from Direct-with-coerce broke
the earlier ordering: Ignored args were getting alloca/load
chains synthesized for call-site uses that were about to be
dropped anyway.
LowerToLLVM gets a stub for the new op: a bitcast for same-shape
converted types, and an error with a message for aggregates.
CallConvLowering does not produce aggregates today, so the error
path is only reachable from hand-written IR; a follow-up patch can
add an extract/insert lowering if needed.
Co-authored-by: Cursor <cursoragent at cursor.com>
>From f8757c25ba8c8b7370fa42a66500abfb7bc45b6f Mon Sep 17 00:00:00 2001
From: Adam Smith <adams at nvidia.com>
Date: Mon, 4 May 2026 12:32:10 -0700
Subject: [PATCH 1/4] [mlir][ABI] Add Test target + classification injection
helper
Adds a dialect-agnostic test ABI target and a DictionaryAttr-based
classification injection helper for testing the MLIR ABIRewriteContext
infrastructure without depending on the in-progress LLVM ABI library
targets (e.g., the x86_64 classifier currently being upstreamed).
- mlir::abi::test::classify(): predictable rules approximating
x86_64 SysV thresholds (Direct / Extend / Indirect / Ignore) for
reviewer familiarity, but explicitly NOT a real ABI target.
Types that don't implement DataLayoutTypeInterface (e.g.
dialect-specific void / unit-style sentinel return types) are
treated as Ignore so the classifier degrades gracefully rather
than crashing on unknown types.
- mlir::abi::test::parseClassificationAttr(): reads a plain
DictionaryAttr from a function and returns a FunctionClassification.
Lets tests inject any classification (including shapes the test
target itself does not produce) so the rewriter can be validated
against arbitrary ABI shapes.
- 20 unit tests covering classifier rules and parser behavior
(well-formed inputs, parse errors, unknown-key rejection,
graceful handling of types without DataLayoutTypeInterface).
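To make the classifier rules concrete, here is a minimal standalone
sketch of the same decision order, assuming the 32-bit and 16-byte
thresholds from the constants in TestTarget.cpp. Kind, TypeClass,
ToyType and classifyToy are hypothetical names for this example only
and do not mirror the real mlir::abi types; TypeClass::Float stands in
for the FloatType / VectorType / MemRefType rule, and TypeClass::Other
for any type whose size comes from DataLayout.

```cpp
#include <cassert>
#include <cstdint>

// Toy model of the test target's rules; not the real mlir::abi API.
enum class Kind { Direct, Extend, Indirect, Ignore };
enum class TypeClass { None, Integer, Float, Other };

struct ToyType {
  TypeClass cls;
  uint64_t sizeInBits; // integer width, or DataLayout size for Other
};

constexpr uint64_t kExtendBelowBits = 32;     // from ExtendBelowBits
constexpr uint64_t kIndirectCutoffBytes = 16; // from IndirectCutoffBytes

Kind classifyToy(const ToyType &t) {
  switch (t.cls) {
  case TypeClass::None:
    return Kind::Ignore;
  case TypeClass::Integer:
    return t.sizeInBits < kExtendBelowBits ? Kind::Extend : Kind::Direct;
  case TypeClass::Float:
    return Kind::Direct;
  case TypeClass::Other:
    break;
  }
  // Remaining types are classified purely by size.
  if (t.sizeInBits == 0)
    return Kind::Ignore;
  uint64_t sizeInBytes = (t.sizeInBits + 7) / 8; // round up to whole bytes
  return sizeInBytes <= kIndirectCutoffBytes ? Kind::Direct : Kind::Indirect;
}
```

In this model an 8-bit integer is extended, a 16-byte aggregate stays
Direct, and a 17-byte aggregate becomes Indirect.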
This is the foundation for upcoming CIR CallConvLowering pass PRs, which
will use either driver mode (test target or classification injection)
to write .cir tests of ABI rewriting.
The future schema additions (direct_offset, extend_kind tristate,
indirect_addr_space, indirect_realign) tracked in the parser source
comment will be added by subsequent PRs as each ArgClassification field
is needed.
Co-authored-by: Cursor <cursoragent at cursor.com>
---
.../mlir/ABI/Targets/Test/TestTarget.h | 98 ++++++
mlir/lib/ABI/CMakeLists.txt | 1 +
mlir/lib/ABI/Targets/Test/TestTarget.cpp | 251 ++++++++++++++
mlir/unittests/ABI/CMakeLists.txt | 1 +
mlir/unittests/ABI/TestTargetTest.cpp | 317 ++++++++++++++++++
5 files changed, 668 insertions(+)
create mode 100644 mlir/include/mlir/ABI/Targets/Test/TestTarget.h
create mode 100644 mlir/lib/ABI/Targets/Test/TestTarget.cpp
create mode 100644 mlir/unittests/ABI/TestTargetTest.cpp
diff --git a/mlir/include/mlir/ABI/Targets/Test/TestTarget.h b/mlir/include/mlir/ABI/Targets/Test/TestTarget.h
new file mode 100644
index 0000000000000..4404d47f8df45
--- /dev/null
+++ b/mlir/include/mlir/ABI/Targets/Test/TestTarget.h
@@ -0,0 +1,98 @@
+//===- TestTarget.h - Predictable test ABI target --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the test ABI target, a predictable, dialect-agnostic
+// classifier used to exercise the MLIR ABIRewriteContext infrastructure
+// without depending on any real ABI. See TestTarget.cpp for the rules
+// and the rationale.
+//
+// It also declares parseClassificationAttr, the helper used by the
+// classification-injection driver: tests can attach an arbitrary
+// FunctionClassification to a function via a plain mlir::DictionaryAttr,
+// and the rewriter pass reads it back through this parser. This lets
+// tests verify rewriter output against any classification (including
+// shapes the test target itself doesn't produce) without needing a real
+// ABIInfo.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ABI_TARGETS_TEST_TESTTARGET_H
+#define MLIR_ABI_TARGETS_TEST_TESTTARGET_H
+
+#include "mlir/ABI/ABIRewriteContext.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "llvm/Support/Error.h"
+
+namespace mlir {
+namespace abi {
+namespace test {
+
+/// Classify a function signature using the test target's predictable rules.
+///
+/// The rules approximate x86_64 SysV thresholds for reviewer familiarity
+/// (see TestTarget.cpp for the full list) but are not a substitute for
+/// testing against a real ABIInfo. Real-ABI-shaped tests should use the
+/// classification-injection driver via `parseClassificationAttr` below.
+///
+/// \param argTypes Argument types of the function.
+/// \param returnType Return type of the function.
+/// \param dl DataLayout used for size and alignment queries.
+FunctionClassification classify(ArrayRef<Type> argTypes, Type returnType,
+ const DataLayout &dl);
+
+/// Parse a `FunctionClassification` from a plain MLIR DictionaryAttr.
+///
+/// Schema (all keys are required unless marked optional):
+///
+/// {
+/// return = { kind = "<kind>", ...per-kind keys... },
+/// args = [ { kind = "<kind>", ...per-kind keys... }, ... ]
+/// }
+///
+/// Per-arg/return dictionary keys:
+/// kind: StringAttr. One of "direct", "extend", "indirect",
+/// "ignore", "expand".
+///
+/// For kind = "direct" (all optional):
+/// coerced_type: TypeAttr. ABI-coerced type, if different from the
+/// original.
+/// can_flatten: BoolAttr. Defaults to true.
+///
+/// For kind = "extend" (coerced_type required, sign_extend optional):
+/// coerced_type: TypeAttr. Required; the extended integer type.
+/// sign_extend: BoolAttr. Defaults to false (zero-extend).
+///
+/// For kind = "indirect" (indirect_align required, byval optional):
+/// indirect_align: IntegerAttr. Required; alignment of the pointed-to
+/// object in bytes.
+/// byval: BoolAttr. Defaults to true.
+///
+/// For kind = "ignore" / "expand": no extra keys.
+///
+/// Future schema additions: when new fields are added to
+/// ArgClassification (e.g. direct_offset, extend_kind tristate,
+/// indirect_addr_space, indirect_realign), the corresponding optional
+/// keys go here.
+///
+/// Unknown keys cause a parse error (no silent ignore — keeps schema
+/// honest as it grows).
+///
+/// \param attr The dictionary attribute to parse.
+/// \param emitError Diagnostic sink for parse errors.
+/// \returns The parsed classification, or std::nullopt on error.
+std::optional<FunctionClassification>
+parseClassificationAttr(DictionaryAttr attr,
+ function_ref<InFlightDiagnostic()> emitError);
+
+} // namespace test
+} // namespace abi
+} // namespace mlir
+
+#endif // MLIR_ABI_TARGETS_TEST_TESTTARGET_H
diff --git a/mlir/lib/ABI/CMakeLists.txt b/mlir/lib/ABI/CMakeLists.txt
index eb434d25dd390..caa353d26ece3 100644
--- a/mlir/lib/ABI/CMakeLists.txt
+++ b/mlir/lib/ABI/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_library(MLIRABI
ABITypeMapper.cpp
+ Targets/Test/TestTarget.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/ABI
diff --git a/mlir/lib/ABI/Targets/Test/TestTarget.cpp b/mlir/lib/ABI/Targets/Test/TestTarget.cpp
new file mode 100644
index 0000000000000..faee9ecc59882
--- /dev/null
+++ b/mlir/lib/ABI/Targets/Test/TestTarget.cpp
@@ -0,0 +1,251 @@
+//===- TestTarget.cpp - Predictable test ABI target ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// **NOT A REAL ABI TARGET.**
+//
+// This file implements a predictable, dialect-agnostic ABI classifier for
+// testing the MLIR ABIRewriteContext infrastructure. The rules approximate
+// x86_64 SysV thresholds (Direct / Extend / Indirect / Ignore) so
+// the generated classifications are familiar to reviewers, but they are
+// NOT a substitute for testing against the real x86_64 ABIInfo. Real
+// ABI targets live alongside the LLVM ABI library in `llvm/lib/ABI/Targets/`.
+//
+// Real-ABI-shaped tests use the classification-injection driver via
+// `parseClassificationAttr`, which lets tests construct any
+// FunctionClassification (including shapes the test target itself does
+// not produce) by attaching a DictionaryAttr to the function.
+//
+// Rules:
+// - mlir::NoneType → Ignore
+// - IntegerType with width < 32 → Extend (sign-extend for
+// signed types, zero-extend
+// otherwise; the injection
+// driver can override)
+// - IntegerType with width >= 32 → Direct
+// - FloatType, VectorType, MemRefType → Direct
+// - Anything else with DataLayout size 0 → Ignore
+// - Anything else with DataLayout size <= 16 → Direct (coerced to the
+// same type — no actual
+// coercion in the test
+// target; PR C handles
+// non-trivial coercion)
+// - Anything else with DataLayout size > 16 → Indirect with byval=true
+// (sret on returns) and
+// alignment from
+// DataLayout
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ABI/Targets/Test/TestTarget.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "llvm/Support/Alignment.h"
+
+using namespace mlir;
+using namespace mlir::abi;
+using namespace mlir::abi::test;
+
+namespace {
+
+/// Indirect-vs-direct cutoff in bytes. Chosen to match x86_64 SysV's
+/// 16-byte register-passing window for reviewer familiarity.
+constexpr uint64_t IndirectCutoffBytes = 16;
+
+/// Below this width (in bits) integers get an extension attribute.
+/// Chosen to match the 32-bit width that narrow integers are extended
+/// to under x86_64 SysV, for reviewer familiarity.
+constexpr unsigned ExtendBelowBits = 32;
+
+ArgClassification classifyOne(Type type, const DataLayout &dl) {
+ if (isa<NoneType>(type))
+ return ArgClassification::getIgnore();
+
+ if (auto intTy = dyn_cast<IntegerType>(type)) {
+ if (intTy.getWidth() < ExtendBelowBits) {
+ Type i32 = IntegerType::get(type.getContext(), ExtendBelowBits);
+ return ArgClassification::getExtend(i32, /*signExt=*/intTy.isSigned());
+ }
+ return ArgClassification::getDirect();
+ }
+
+ if (isa<FloatType, VectorType, MemRefType>(type))
+ return ArgClassification::getDirect();
+
+ // For dialect-specific types: query DataLayout via
+ // DataLayoutTypeInterface. Types that don't implement the interface
+ // (e.g. dialect-specific void / unit-style sentinel types used as a
+ // function's "no return value" marker) are treated as Ignore so that
+ // the test target degrades gracefully rather than crashing on unknown
+ // types.
+ if (!isa<DataLayoutTypeInterface>(type))
+ return ArgClassification::getIgnore();
+
+ llvm::TypeSize sizeInBits = dl.getTypeSizeInBits(type);
+ if (sizeInBits.isZero())
+ return ArgClassification::getIgnore();
+
+ uint64_t sizeInBytes = (sizeInBits.getFixedValue() + 7) / 8;
+ if (sizeInBytes <= IndirectCutoffBytes)
+ return ArgClassification::getDirect();
+
+ uint64_t alignBytes = dl.getTypeABIAlignment(type);
+ return ArgClassification::getIndirect(llvm::Align(alignBytes),
+ /*byVal=*/true);
+}
+
+} // namespace
+
+FunctionClassification mlir::abi::test::classify(ArrayRef<Type> argTypes,
+ Type returnType,
+ const DataLayout &dl) {
+ FunctionClassification fc;
+ fc.returnInfo = classifyOne(returnType, dl);
+ fc.argInfos.reserve(argTypes.size());
+ for (Type t : argTypes)
+ fc.argInfos.push_back(classifyOne(t, dl));
+ return fc;
+}
+
+namespace {
+
+/// Set of dictionary keys this parser knows about. Any key not in this
+/// set causes a parse error (no silent ignore). Updated when new
+/// optional keys are added to the schema.
+constexpr StringRef KnownArgKeys[] = {
+ "kind", "coerced_type", "sign_extend",
+ "can_flatten", "indirect_align", "byval",
+};
+
+bool isKnownArgKey(StringRef key) {
+ for (StringRef k : KnownArgKeys)
+ if (k == key)
+ return true;
+ return false;
+}
+
+/// Parse a single ArgClassification dictionary. Returns std::nullopt on
+/// any error (with the diagnostic emitted via \p emitError).
+std::optional<ArgClassification>
+parseOne(DictionaryAttr argDict, function_ref<InFlightDiagnostic()> emitError) {
+ StringAttr kindAttr = argDict.getAs<StringAttr>("kind");
+ if (!kindAttr) {
+ emitError() << "missing required 'kind' StringAttr";
+ return std::nullopt;
+ }
+
+ for (NamedAttribute na : argDict)
+ if (!isKnownArgKey(na.getName().getValue())) {
+ emitError() << "unknown key '" << na.getName().getValue()
+ << "' in classification dictionary; allowed keys are "
+ << "kind, coerced_type, sign_extend, can_flatten, "
+ << "indirect_align, byval";
+ return std::nullopt;
+ }
+
+ StringRef kind = kindAttr.getValue();
+
+ if (kind == "direct") {
+ Type coerced;
+ if (auto t = argDict.getAs<TypeAttr>("coerced_type"))
+ coerced = t.getValue();
+ auto c = ArgClassification::getDirect(coerced);
+ if (auto cf = argDict.getAs<BoolAttr>("can_flatten"))
+ c.canFlatten = cf.getValue();
+ return c;
+ }
+
+ if (kind == "extend") {
+ auto coerced = argDict.getAs<TypeAttr>("coerced_type");
+ if (!coerced) {
+ emitError() << "kind='extend' requires 'coerced_type' TypeAttr";
+ return std::nullopt;
+ }
+ bool signExt = false;
+ if (auto se = argDict.getAs<BoolAttr>("sign_extend"))
+ signExt = se.getValue();
+ return ArgClassification::getExtend(coerced.getValue(), signExt);
+ }
+
+ if (kind == "indirect") {
+ auto align = argDict.getAs<IntegerAttr>("indirect_align");
+ if (!align) {
+ emitError() << "kind='indirect' requires 'indirect_align' IntegerAttr";
+ return std::nullopt;
+ }
+ if (align.getInt() <= 0 || !llvm::isPowerOf2_64(align.getInt())) {
+ emitError() << "'indirect_align' must be a positive power of 2; got "
+ << align.getInt();
+ return std::nullopt;
+ }
+ bool byVal = true;
+ if (auto bv = argDict.getAs<BoolAttr>("byval"))
+ byVal = bv.getValue();
+ return ArgClassification::getIndirect(llvm::Align(align.getInt()), byVal);
+ }
+
+ if (kind == "ignore") {
+ return ArgClassification::getIgnore();
+ }
+
+ if (kind == "expand") {
+ ArgClassification c;
+ c.kind = ArgKind::Expand;
+ return c;
+ }
+
+ emitError() << "unknown kind='" << kind
+ << "'; expected one of direct, extend, indirect, ignore, expand";
+ return std::nullopt;
+}
+
+} // namespace
+
+std::optional<FunctionClassification> mlir::abi::test::parseClassificationAttr(
+ DictionaryAttr attr, function_ref<InFlightDiagnostic()> emitError) {
+ auto returnDict = attr.getAs<DictionaryAttr>("return");
+ if (!returnDict) {
+ emitError() << "missing required 'return' DictionaryAttr";
+ return std::nullopt;
+ }
+
+ auto argsArr = attr.getAs<ArrayAttr>("args");
+ if (!argsArr) {
+ emitError() << "missing required 'args' ArrayAttr";
+ return std::nullopt;
+ }
+
+ for (NamedAttribute na : attr) {
+ StringRef k = na.getName().getValue();
+ if (k != "return" && k != "args") {
+ emitError() << "unknown top-level key '" << k
+ << "'; only 'return' and 'args' are allowed";
+ return std::nullopt;
+ }
+ }
+
+ FunctionClassification fc;
+
+ std::optional<ArgClassification> ret = parseOne(returnDict, emitError);
+ if (!ret)
+ return std::nullopt;
+ fc.returnInfo = *ret;
+
+ fc.argInfos.reserve(argsArr.size());
+ for (Attribute a : argsArr) {
+ auto d = dyn_cast<DictionaryAttr>(a);
+ if (!d) {
+ emitError() << "'args' entries must be DictionaryAttrs";
+ return std::nullopt;
+ }
+ std::optional<ArgClassification> ac = parseOne(d, emitError);
+ if (!ac)
+ return std::nullopt;
+ fc.argInfos.push_back(*ac);
+ }
+
+ return fc;
+}
diff --git a/mlir/unittests/ABI/CMakeLists.txt b/mlir/unittests/ABI/CMakeLists.txt
index 39f955a8efea6..1113ed9516f9c 100644
--- a/mlir/unittests/ABI/CMakeLists.txt
+++ b/mlir/unittests/ABI/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_unittest(MLIRABITests
ABIRewriteContextTest.cpp
ABITypeMapperTest.cpp
+ TestTargetTest.cpp
)
mlir_target_link_libraries(MLIRABITests
diff --git a/mlir/unittests/ABI/TestTargetTest.cpp b/mlir/unittests/ABI/TestTargetTest.cpp
new file mode 100644
index 0000000000000..64d5bbf0866bc
--- /dev/null
+++ b/mlir/unittests/ABI/TestTargetTest.cpp
@@ -0,0 +1,317 @@
+//===- TestTargetTest.cpp - Unit tests for the test ABI target -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ABI/Targets/Test/TestTarget.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace mlir::abi;
+using namespace mlir::abi::test;
+
+namespace {
+
+class TestTargetClassifyTest : public ::testing::Test {
+protected:
+ TestTargetClassifyTest()
+ : module(ModuleOp::create(UnknownLoc::get(&context))), dl(module) {
+ context.loadDialect<DLTIDialect>();
+ }
+
+ MLIRContext context;
+ ModuleOp module;
+ DataLayout dl;
+};
+
+TEST_F(TestTargetClassifyTest, IgnoresNoneType) {
+ auto noneTy = NoneType::get(&context);
+ FunctionClassification fc = classify({}, noneTy, dl);
+ EXPECT_EQ(fc.returnInfo.kind, ArgKind::Ignore);
+}
+
+TEST_F(TestTargetClassifyTest, ExtendsNarrowSignedInteger) {
+ auto i8 = IntegerType::get(&context, 8, IntegerType::Signed);
+ FunctionClassification fc = classify({i8}, NoneType::get(&context), dl);
+ ASSERT_EQ(fc.argInfos.size(), 1u);
+ EXPECT_EQ(fc.argInfos[0].kind, ArgKind::Extend);
+ EXPECT_TRUE(fc.argInfos[0].signExtend);
+ auto coerced = dyn_cast<IntegerType>(fc.argInfos[0].coercedType);
+ ASSERT_TRUE(coerced);
+ EXPECT_EQ(coerced.getWidth(), 32u);
+}
+
+TEST_F(TestTargetClassifyTest, ExtendsNarrowSignlessIntegerAsZeroExt) {
+ auto i8 = IntegerType::get(&context, 8);
+ FunctionClassification fc = classify({i8}, NoneType::get(&context), dl);
+ ASSERT_EQ(fc.argInfos.size(), 1u);
+ EXPECT_EQ(fc.argInfos[0].kind, ArgKind::Extend);
+ EXPECT_FALSE(fc.argInfos[0].signExtend);
+}
+
+TEST_F(TestTargetClassifyTest, RegisterSizedIntegerIsDirect) {
+ auto i32 = IntegerType::get(&context, 32);
+ auto i64 = IntegerType::get(&context, 64);
+ FunctionClassification fc = classify({i32, i64}, NoneType::get(&context), dl);
+ ASSERT_EQ(fc.argInfos.size(), 2u);
+ EXPECT_EQ(fc.argInfos[0].kind, ArgKind::Direct);
+ EXPECT_EQ(fc.argInfos[1].kind, ArgKind::Direct);
+}
+
+TEST_F(TestTargetClassifyTest, FloatIsDirect) {
+ auto f32 = Float32Type::get(&context);
+ FunctionClassification fc = classify({f32}, f32, dl);
+ EXPECT_EQ(fc.returnInfo.kind, ArgKind::Direct);
+ EXPECT_EQ(fc.argInfos[0].kind, ArgKind::Direct);
+}
+
+TEST_F(TestTargetClassifyTest, FunctionLevelReturnAndArgsClassifiedTogether) {
+ auto i32 = IntegerType::get(&context, 32);
+ auto f64 = Float64Type::get(&context);
+ FunctionClassification fc = classify({i32, f64}, i32, dl);
+ EXPECT_EQ(fc.returnInfo.kind, ArgKind::Direct);
+ ASSERT_EQ(fc.argInfos.size(), 2u);
+ EXPECT_EQ(fc.argInfos[0].kind, ArgKind::Direct);
+ EXPECT_EQ(fc.argInfos[1].kind, ArgKind::Direct);
+}
+
+TEST_F(TestTargetClassifyTest,
+ TypeWithoutDataLayoutInterfaceClassifiedAsIgnore) {
+ // FunctionType does not implement DataLayoutTypeInterface. The classifier
+ // must treat it as Ignore rather than crashing in dl.getTypeSizeInBits().
+ // This guards against the same crash for dialect-specific void / sentinel
+ // types (e.g. cir::VoidType) used as a function's "no return value" marker.
+ auto i32 = IntegerType::get(&context, 32);
+ auto fnTy = FunctionType::get(&context, {i32}, {i32});
+ FunctionClassification fc = classify({}, fnTy, dl);
+ EXPECT_EQ(fc.returnInfo.kind, ArgKind::Ignore);
+}
+
+class TestTargetParseTest : public ::testing::Test {
+protected:
+ TestTargetParseTest() : builder(&context) {
+ // Suppress diagnostic printing during tests; capture into lastError
+ // for assertions instead.
+ context.getDiagEngine().registerHandler([this](Diagnostic &diag) {
+ lastError = diag.str();
+ return success();
+ });
+ }
+
+ /// Convenience: parse and assert success, returning the result.
+ FunctionClassification parseOk(DictionaryAttr attr) {
+ auto loc = UnknownLoc::get(&context);
+ auto result =
+ parseClassificationAttr(attr, [&]() { return mlir::emitError(loc); });
+ EXPECT_TRUE(result.has_value())
+ << "parseClassificationAttr failed: " << lastError;
+ return result.value_or(FunctionClassification{});
+ }
+
+ /// Convenience: parse and assert failure with a substring match.
+ void parseError(DictionaryAttr attr, StringRef expectedSubstr) {
+ auto loc = UnknownLoc::get(&context);
+ lastError.clear();
+ auto result =
+ parseClassificationAttr(attr, [&]() { return mlir::emitError(loc); });
+ EXPECT_FALSE(result.has_value());
+ EXPECT_NE(lastError.find(expectedSubstr.str()), std::string::npos)
+ << "expected error containing '" << expectedSubstr << "' but got '"
+ << lastError << "'";
+ }
+
+ DictionaryAttr makeArg(ArrayRef<NamedAttribute> entries) {
+ return DictionaryAttr::get(&context, entries);
+ }
+
+ MLIRContext context;
+ OpBuilder builder;
+ std::string lastError;
+};
+
+TEST_F(TestTargetParseTest, ParsesDirectReturnAndOneDirectArg) {
+ auto direct =
+ makeArg({builder.getNamedAttr("kind", builder.getStringAttr("direct"))});
+ auto attr = builder.getDictionaryAttr({
+ builder.getNamedAttr("return", direct),
+ builder.getNamedAttr("args", builder.getArrayAttr({direct})),
+ });
+
+ auto fc = parseOk(attr);
+ EXPECT_EQ(fc.returnInfo.kind, ArgKind::Direct);
+ ASSERT_EQ(fc.argInfos.size(), 1u);
+ EXPECT_EQ(fc.argInfos[0].kind, ArgKind::Direct);
+}
+
+TEST_F(TestTargetParseTest, ParsesExtendWithCoercedTypeAndSignExtend) {
+ auto i32 = IntegerType::get(&context, 32);
+ auto extend = makeArg({
+ builder.getNamedAttr("kind", builder.getStringAttr("extend")),
+ builder.getNamedAttr("coerced_type", TypeAttr::get(i32)),
+ builder.getNamedAttr("sign_extend", builder.getBoolAttr(true)),
+ });
+ auto direct =
+ makeArg({builder.getNamedAttr("kind", builder.getStringAttr("direct"))});
+ auto attr = builder.getDictionaryAttr({
+ builder.getNamedAttr("return", direct),
+ builder.getNamedAttr("args", builder.getArrayAttr({extend})),
+ });
+
+ auto fc = parseOk(attr);
+ ASSERT_EQ(fc.argInfos.size(), 1u);
+ EXPECT_EQ(fc.argInfos[0].kind, ArgKind::Extend);
+ EXPECT_TRUE(fc.argInfos[0].signExtend);
+ EXPECT_EQ(fc.argInfos[0].coercedType, i32);
+}
+
+TEST_F(TestTargetParseTest, ParsesIndirectWithAlignAndByval) {
+ auto direct =
+ makeArg({builder.getNamedAttr("kind", builder.getStringAttr("direct"))});
+ auto indirect = makeArg({
+ builder.getNamedAttr("kind", builder.getStringAttr("indirect")),
+ builder.getNamedAttr("indirect_align", builder.getI64IntegerAttr(16)),
+ builder.getNamedAttr("byval", builder.getBoolAttr(false)),
+ });
+ auto attr = builder.getDictionaryAttr({
+ builder.getNamedAttr("return", direct),
+ builder.getNamedAttr("args", builder.getArrayAttr({indirect})),
+ });
+
+ auto fc = parseOk(attr);
+ ASSERT_EQ(fc.argInfos.size(), 1u);
+ EXPECT_EQ(fc.argInfos[0].kind, ArgKind::Indirect);
+ EXPECT_EQ(fc.argInfos[0].indirectAlign, llvm::Align(16));
+ EXPECT_FALSE(fc.argInfos[0].byVal);
+}
+
+TEST_F(TestTargetParseTest, ParsesIgnoreAndExpand) {
+ auto ignore =
+ makeArg({builder.getNamedAttr("kind", builder.getStringAttr("ignore"))});
+ auto expand =
+ makeArg({builder.getNamedAttr("kind", builder.getStringAttr("expand"))});
+ auto attr = builder.getDictionaryAttr({
+ builder.getNamedAttr("return", ignore),
+ builder.getNamedAttr("args", builder.getArrayAttr({expand, ignore})),
+ });
+
+ auto fc = parseOk(attr);
+ EXPECT_EQ(fc.returnInfo.kind, ArgKind::Ignore);
+ ASSERT_EQ(fc.argInfos.size(), 2u);
+ EXPECT_EQ(fc.argInfos[0].kind, ArgKind::Expand);
+ EXPECT_EQ(fc.argInfos[1].kind, ArgKind::Ignore);
+}
+
+TEST_F(TestTargetParseTest, RejectsMissingReturn) {
+ auto direct =
+ makeArg({builder.getNamedAttr("kind", builder.getStringAttr("direct"))});
+ auto attr = builder.getDictionaryAttr({
+ builder.getNamedAttr("args", builder.getArrayAttr({direct})),
+ });
+ parseError(attr, "missing required 'return'");
+}
+
+TEST_F(TestTargetParseTest, RejectsMissingArgs) {
+ auto direct =
+ makeArg({builder.getNamedAttr("kind", builder.getStringAttr("direct"))});
+ auto attr = builder.getDictionaryAttr({
+ builder.getNamedAttr("return", direct),
+ });
+ parseError(attr, "missing required 'args'");
+}
+
+TEST_F(TestTargetParseTest, RejectsUnknownTopLevelKey) {
+ auto direct =
+ makeArg({builder.getNamedAttr("kind", builder.getStringAttr("direct"))});
+ auto attr = builder.getDictionaryAttr({
+ builder.getNamedAttr("return", direct),
+ builder.getNamedAttr("args", builder.getArrayAttr({})),
+ builder.getNamedAttr("future_field", builder.getStringAttr("hello")),
+ });
+ parseError(attr, "unknown top-level key 'future_field'");
+}
+
+TEST_F(TestTargetParseTest, RejectsUnknownArgKey) {
+ auto badArg = makeArg({
+ builder.getNamedAttr("kind", builder.getStringAttr("direct")),
+ builder.getNamedAttr("future_field", builder.getBoolAttr(true)),
+ });
+ auto direct =
+ makeArg({builder.getNamedAttr("kind", builder.getStringAttr("direct"))});
+ auto attr = builder.getDictionaryAttr({
+ builder.getNamedAttr("return", direct),
+ builder.getNamedAttr("args", builder.getArrayAttr({badArg})),
+ });
+ parseError(attr, "unknown key 'future_field'");
+}
+
+TEST_F(TestTargetParseTest, RejectsExtendWithoutCoercedType) {
+ auto badExtend =
+ makeArg({builder.getNamedAttr("kind", builder.getStringAttr("extend"))});
+ auto direct =
+ makeArg({builder.getNamedAttr("kind", builder.getStringAttr("direct"))});
+ auto attr = builder.getDictionaryAttr({
+ builder.getNamedAttr("return", direct),
+ builder.getNamedAttr("args", builder.getArrayAttr({badExtend})),
+ });
+ parseError(attr, "kind='extend' requires 'coerced_type'");
+}
+
+TEST_F(TestTargetParseTest, RejectsIndirectWithoutAlign) {
+ auto badIndirect = makeArg(
+ {builder.getNamedAttr("kind", builder.getStringAttr("indirect"))});
+ auto direct =
+ makeArg({builder.getNamedAttr("kind", builder.getStringAttr("direct"))});
+ auto attr = builder.getDictionaryAttr({
+ builder.getNamedAttr("return", direct),
+ builder.getNamedAttr("args", builder.getArrayAttr({badIndirect})),
+ });
+ parseError(attr, "kind='indirect' requires 'indirect_align'");
+}
+
+TEST_F(TestTargetParseTest, RejectsIndirectWithNonPowerOfTwoAlign) {
+ auto badIndirect = makeArg({
+ builder.getNamedAttr("kind", builder.getStringAttr("indirect")),
+ builder.getNamedAttr("indirect_align", builder.getI64IntegerAttr(7)),
+ });
+ auto direct =
+ makeArg({builder.getNamedAttr("kind", builder.getStringAttr("direct"))});
+ auto attr = builder.getDictionaryAttr({
+ builder.getNamedAttr("return", direct),
+ builder.getNamedAttr("args", builder.getArrayAttr({badIndirect})),
+ });
+ parseError(attr, "must be a positive power of 2");
+}
+
+TEST_F(TestTargetParseTest, RejectsUnknownKind) {
+ auto bad = makeArg(
+ {builder.getNamedAttr("kind", builder.getStringAttr("invalid_kind"))});
+ auto direct =
+ makeArg({builder.getNamedAttr("kind", builder.getStringAttr("direct"))});
+ auto attr = builder.getDictionaryAttr({
+ builder.getNamedAttr("return", direct),
+ builder.getNamedAttr("args", builder.getArrayAttr({bad})),
+ });
+ parseError(attr, "unknown kind='invalid_kind'");
+}
+
+TEST_F(TestTargetParseTest, RejectsMissingKind) {
+ auto bad = makeArg({});
+ auto direct =
+ makeArg({builder.getNamedAttr("kind", builder.getStringAttr("direct"))});
+ auto attr = builder.getDictionaryAttr({
+ builder.getNamedAttr("return", direct),
+ builder.getNamedAttr("args", builder.getArrayAttr({bad})),
+ });
+ parseError(attr, "missing required 'kind'");
+}
+
+} // namespace
>From 7f34f379aa0871474bf092efca5f2815caa1ad1b Mon Sep 17 00:00:00 2001
From: Adam Smith <adams at nvidia.com>
Date: Mon, 4 May 2026 13:34:39 -0700
Subject: [PATCH 2/4] [CIR] Add CallConvLowering pass + Direct/Ignore ABI
rewriting
Implements the cir-call-conv-lowering pass scaffolding and the
Direct (pass-through) + Ignore handlers for cir.func signatures
and cir.call sites. This is the second PR in the series splitting
#192119 / #192124 per andykaylor's review feedback.
Two driver modes select how each function's classification is
computed:
- target=test selects the MLIR test ABI target
(mlir/lib/ABI/Targets/Test/) added in PR A1.
- classification-attr=<name> reads a DictionaryAttr from each
cir.func and parses it via
mlir::abi::test::parseClassificationAttr (also from PR A1).
Exactly one of the two options must be set. The pass requires a
dlti.dl_spec attribute on the module and emits a clear diagnostic
otherwise.
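The "exactly one of the two options" rule can be sketched as a small
standalone check. This is only an illustration under stated
assumptions: checkDriverOptions is a hypothetical helper, an empty
string is assumed to mean "option not set", and the message text is
invented. The real pass options are declared in Passes.td and may be
validated differently.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical validation helper; returns an error message when the
// option combination is invalid, or std::nullopt when it is acceptable.
std::optional<std::string>
checkDriverOptions(const std::string &target,
                   const std::string &classificationAttr) {
  bool hasTarget = !target.empty();           // e.g. target=test
  bool hasAttr = !classificationAttr.empty(); // e.g. classification-attr=<name>
  // Both set or neither set are equally invalid.
  if (hasTarget == hasAttr)
    return "exactly one of 'target' and 'classification-attr' must be set";
  return std::nullopt;
}
```

Comparing the two booleans for equality rejects both the "neither" and
the "both" case with a single branch.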
Subsequent PRs in the series add the remaining ArgKind handlers:
PR B: Extend (signext / zeroext)
PR C: Direct with coercion (in-register coercion + cir.reinterpret_cast op)
PR D: Indirect return (sret)
PR E: Indirect byval argument
PR F: Expand (struct flattening)
CIRABIRewriteContext currently dispatches on argClass.kind via a
switch with explicit "not yet implemented" diagnostics for the
unhandled kinds, so subsequent PRs are purely additive.
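The dispatch shape described above might look roughly like the
following standalone sketch. The ArgKind enum here is a local copy for
illustration, rewriteArg is a hypothetical function, and the diagnostic
strings are invented; the actual CIRABIRewriteContext code emits MLIR
diagnostics rather than returning strings.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Local illustration enum; mirrors the kind names used in this series.
enum class ArgKind { Direct, Extend, Indirect, Ignore, Expand };

// Returns std::nullopt for kinds handled at this stage, or an
// illustrative "not yet implemented" message for the others.
std::optional<std::string> rewriteArg(ArgKind kind) {
  switch (kind) {
  case ArgKind::Direct: // pass-through: nothing to rewrite
  case ArgKind::Ignore: // argument or return value is dropped
    return std::nullopt;
  case ArgKind::Extend:
    return "Extend is not yet implemented";
  case ArgKind::Indirect:
    return "Indirect is not yet implemented";
  case ArgKind::Expand:
    return "Expand is not yet implemented";
  }
  return "unknown ArgKind";
}
```

With every kind listed explicitly, a later PR that implements one of
them only has to replace that kind's case body, which is what keeps the
follow-ups additive.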
Updates clang/docs/CIR/ABILowering.rst with a "CIR Pass Pipeline
Position" section describing where cir-call-conv-lowering sits
in the pipeline, the DataLayout requirement, the alloca-placement
invariant, and the two driver modes.
6 new .cir tests using both driver modes for Direct passthrough,
Ignore arg, Ignore return, declaration rewriting, and DataLayout-
missing error reporting. check-clang-cir / check-clang-cir-codegen
both pass with no regressions.
Co-authored-by: Cursor <cursoragent at cursor.com>
---
clang/docs/CIR/ABILowering.rst | 40 +++
clang/include/clang/CIR/Dialect/Passes.h | 1 +
clang/include/clang/CIR/Dialect/Passes.td | 52 +++-
.../lib/CIR/Dialect/Transforms/CMakeLists.txt | 3 +
.../Transforms/CallConvLoweringPass.cpp | 175 ++++++++++++
.../TargetLowering/CIRABIRewriteContext.cpp | 257 ++++++++++++++++++
.../TargetLowering/CIRABIRewriteContext.h | 56 ++++
.../Transforms/TargetLowering/CMakeLists.txt | 2 +
.../abi-lowering/Inputs/test-datalayout.cir | 17 ++
.../abi-lowering/datalayout-missing-error.cir | 17 ++
.../abi-lowering/declaration-rewrite.cir | 34 +++
.../direct-passthrough-injection.cir | 42 +++
.../direct-passthrough-test-target.cir | 35 +++
.../Transforms/abi-lowering/ignore-arg.cir | 39 +++
.../Transforms/abi-lowering/ignore-return.cir | 48 ++++
clang/tools/cir-opt/cir-opt.cpp | 4 +
16 files changed, 816 insertions(+), 6 deletions(-)
create mode 100644 clang/lib/CIR/Dialect/Transforms/CallConvLoweringPass.cpp
create mode 100644 clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
create mode 100644 clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h
create mode 100644 clang/test/CIR/Transforms/abi-lowering/Inputs/test-datalayout.cir
create mode 100644 clang/test/CIR/Transforms/abi-lowering/datalayout-missing-error.cir
create mode 100644 clang/test/CIR/Transforms/abi-lowering/declaration-rewrite.cir
create mode 100644 clang/test/CIR/Transforms/abi-lowering/direct-passthrough-injection.cir
create mode 100644 clang/test/CIR/Transforms/abi-lowering/direct-passthrough-test-target.cir
create mode 100644 clang/test/CIR/Transforms/abi-lowering/ignore-arg.cir
create mode 100644 clang/test/CIR/Transforms/abi-lowering/ignore-return.cir
diff --git a/clang/docs/CIR/ABILowering.rst b/clang/docs/CIR/ABILowering.rst
index 59f0bb4f9a646..74457be5c8e6d 100644
--- a/clang/docs/CIR/ABILowering.rst
+++ b/clang/docs/CIR/ABILowering.rst
@@ -531,6 +531,46 @@ options or configuration. The dependency direction is: the MLIR ABI pass
depends on ``llvm/lib/ABI``; there is no reverse dependency from the ABI library
to MLIR dialects.
+CIR Pass Pipeline Position
+--------------------------
+
+For the CIR dialect, the calling-convention lowering pass is named
+``cir-call-conv-lowering`` and runs late in the CIR-to-LLVM pipeline:
+
+1. ``cir-target-lowering`` legalizes target-specific operations (e.g.
+ atomic synchronization scopes).
+2. ``cir-cxxabi-lowering`` lowers C++-specific high-level types (member
+ pointers, vtable lookups, etc.) to ABI-specific representations.
+3. ``cir-call-conv-lowering`` rewrites function signatures and call sites
+ to match the target ABI's calling convention rules.
+
+``cir-call-conv-lowering`` requires a ``dlti.dl_spec`` attribute on the
+module so it can query type sizes and alignments through MLIR's
+``DataLayout``. When the attribute is missing, the pass emits a
+diagnostic and fails rather than silently using a default layout.
+
+The pass places any temporary allocas it needs (for argument coercion,
+sret slots, etc.) directly in the function entry block, so it does not
+rely on a subsequent ``cir-hoist-allocas`` run to position them
+correctly. This invariant means ``cir-hoist-allocas`` may run either
+before or after ``cir-call-conv-lowering`` without changing observable
+behavior.
+
+The pass takes one of two driver modes via pass options:
+
+- ``target=<name>`` selects an ABI target. The first supported value
+ is ``test`` (the MLIR test target in ``mlir/lib/ABI/Targets/Test/``,
+ used for testing the rewriter without depending on the in-progress
+ LLVM ABI library targets). Real targets (``x86_64``, ``aarch64``,
+ ...) will be added as the LLVM ABI library ships them.
+- ``classification-attr=<name>`` reads a pre-built
+ ``FunctionClassification`` from a ``DictionaryAttr`` named ``<name>`` on
+ each ``cir.func`` and rewrites accordingly. This driver is for tests that
+ need to exercise rewriter behavior against arbitrary classifications
+ without routing through any real classifier.
+
+Exactly one of the two options must be set.
+
Open Questions
==============
diff --git a/clang/include/clang/CIR/Dialect/Passes.h b/clang/include/clang/CIR/Dialect/Passes.h
index d441dfcbc6c14..a68f7b621f5d8 100644
--- a/clang/include/clang/CIR/Dialect/Passes.h
+++ b/clang/include/clang/CIR/Dialect/Passes.h
@@ -27,6 +27,7 @@ std::unique_ptr<Pass> createCIRSimplifyPass();
std::unique_ptr<Pass> createCIREHABILoweringPass();
std::unique_ptr<Pass> createCXXABILoweringPass();
std::unique_ptr<Pass> createTargetLoweringPass();
+std::unique_ptr<Pass> createCallConvLoweringPass();
std::unique_ptr<Pass> createHoistAllocasPass();
std::unique_ptr<Pass> createLoweringPreparePass();
std::unique_ptr<Pass> createLoweringPreparePass(clang::ASTContext *astCtx);
diff --git a/clang/include/clang/CIR/Dialect/Passes.td b/clang/include/clang/CIR/Dialect/Passes.td
index 32cd182aacec7..cb3c78d590a42 100644
--- a/clang/include/clang/CIR/Dialect/Passes.td
+++ b/clang/include/clang/CIR/Dialect/Passes.td
@@ -108,7 +108,7 @@ def TargetLowering : Pass<"cir-target-lowering", "mlir::ModuleOp"> {
1. The `TargetLowering` pass.
2. The `CXXABILowering` pass.
- 3. The `CallConvLowering` pass (not implemented yet).
+ 3. The `CallConvLowering` pass.
The `TargetLowering` pass acts more like a legalization pass. It ensures
every operation in CIR conforms to the target's constraints. An example
@@ -117,11 +117,11 @@ def TargetLowering : Pass<"cir-target-lowering", "mlir::ModuleOp"> {
any atomic operations with a different synchronization scope would be
transformed to use the system-wide scope in this pass.
- The `CXXABILowering` pass and the (not yet implemented) `CallConvLowering`
- pass transform the CIR according to the target's ABI requirements. The
- former handles all ABI-related lowering except for calling convention
- handling, which is handled specifically in the latter. Example
- transformations that the `CXXABILowering` pass could make include:
+ The `CXXABILowering` pass and the `CallConvLowering` pass transform the
+ CIR according to the target's ABI requirements. The former handles all
+ ABI-related lowering except for calling convention handling, which is
+ handled specifically in the latter. Example transformations that the
+ `CXXABILowering` pass could make include:
- Replace C/C++ types that have an ABI-defined layout with more
fundamental types corresponding to the ABI requirements. For example,
@@ -195,4 +195,44 @@ def IdiomRecognizer : Pass<"cir-idiom-recognizer", "mlir::ModuleOp"> {
let dependentDialects = ["cir::CIRDialect"];
}
+def CallConvLowering : Pass<"cir-call-conv-lowering", "mlir::ModuleOp"> {
+ let summary = "Lower CIR function signatures and call sites to match target ABI";
+ let description = [{
+ This pass rewrites `cir.func` signatures and `cir.call` sites to match the
+ target ABI's calling convention requirements (extension attributes, struct
+ coercion, sret/byval indirect passing, struct flattening, etc.).
+
+ The pass requires a `dlti.dl_spec` attribute on the module so it can query
+ the data layout for sizes and alignments.
+
+ Two driver modes select how each function's classification is computed:
+
+ - `target=<name>` selects an ABI target. Currently only `"test"` (the
+ MLIR test target in `mlir/lib/ABI/Targets/Test/`) is supported. Real
+ targets (x86_64, AArch64, ...) will be added once the LLVM ABI library
+ ships them.
+ - `classification-attr=<name>` reads a `DictionaryAttr` named `<name>`
+ from each `cir.func` and parses it via the test-target injection
+ helper. Used by tests to inject arbitrary classifications without
+ depending on a real ABI target.
+
+ Exactly one of the two options must be set.
+
+ Pipeline position: this pass runs after `CXXABILowering`, so member
+ pointers and similar C++ types are already lowered. It may run either
+ before or after `HoistAllocas`, because this pass places its own
+ temporary allocas in the entry block.
+ }];
+ let constructor = "mlir::createCallConvLoweringPass()";
+ let dependentDialects = ["cir::CIRDialect"];
+ let options = [
+ Option<"target", "target", "std::string", /*default=*/"\"\"",
+ "Target whose ABI rules drive classification (currently: test)">,
+ Option<"classificationAttr", "classification-attr", "std::string",
+ /*default=*/"\"\"",
+ "Function attribute name carrying a pre-built FunctionClassification "
+ "DictionaryAttr (alternative to target=, used by tests)">,
+ ];
+}
+
#endif // CLANG_CIR_DIALECT_PASSES_TD
diff --git a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
index 092ccfac7ddb7..b3fa7375eaafa 100644
--- a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
+++ b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_subdirectory(TargetLowering)
add_clang_library(MLIRCIRTransforms
+ CallConvLoweringPass.cpp
CIRCanonicalize.cpp
CIRSimplify.cpp
CXXABILowering.cpp
@@ -19,7 +20,9 @@ add_clang_library(MLIRCIRTransforms
clangAST
clangBasic
+ MLIRABI
MLIRAnalysis
+ MLIRDLTIDialect
MLIRIR
MLIRPass
MLIRTransformUtils
diff --git a/clang/lib/CIR/Dialect/Transforms/CallConvLoweringPass.cpp b/clang/lib/CIR/Dialect/Transforms/CallConvLoweringPass.cpp
new file mode 100644
index 0000000000000..e50aeca1791e9
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/CallConvLoweringPass.cpp
@@ -0,0 +1,175 @@
+//===- CallConvLoweringPass.cpp - Lower CIR to ABI calling convention ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass walks every cir.func and cir.call in the module, computes a
+// FunctionClassification for it (via either an ABI target or a pre-built
+// classification injected as a function attribute), and dispatches to
+// CIRABIRewriteContext to perform the actual IR rewriting.
+//
+// Two driver modes (mutually exclusive):
+//
+// target=test
+// Use the MLIR test ABI target (mlir/lib/ABI/Targets/Test/) to classify
+// each function. Predictable rules that approximate x86_64 SysV. Real
+// targets (x86_64, AArch64) will be added once the LLVM ABI library
+// ships them.
+//
+// classification-attr=<name>
+// Read a DictionaryAttr named <name> from each cir.func and parse it via
+// mlir::abi::test::parseClassificationAttr. Used by tests to inject any
+// classification (including shapes the test target itself does not
+// produce) without depending on a real ABI target.
+//
+// The pass requires a `dlti.dl_spec` attribute on the module so the
+// classifier can query type sizes and alignments.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "TargetLowering/CIRABIRewriteContext.h"
+
+#include "mlir/ABI/ABIRewriteContext.h"
+#include "mlir/ABI/Targets/Test/TestTarget.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/Passes.h"
+
+using namespace mlir;
+using namespace mlir::abi;
+using namespace cir;
+
+namespace mlir {
+#define GEN_PASS_DEF_CALLCONVLOWERING
+#include "clang/CIR/Dialect/Passes.h.inc"
+} // namespace mlir
+
+namespace {
+
+struct CallConvLoweringPass
+ : public impl::CallConvLoweringBase<CallConvLoweringPass> {
+ using CallConvLoweringBase::CallConvLoweringBase;
+ void runOnOperation() override;
+};
+
+/// Classify \p func using whichever driver mode is configured. Returns
+/// std::nullopt and emits an error on the function if classification fails
+/// (e.g. injection-driver mode but the function is missing the attribute,
+/// or the attribute is malformed).
+std::optional<FunctionClassification>
+classifyFunction(cir::FuncOp func, const DataLayout &dl, StringRef target,
+ StringRef classificationAttrName) {
+ ArrayRef<Type> argTypes = func.getFunctionType().getInputs();
+ Type returnType = func.getFunctionType().getReturnType();
+
+ if (!classificationAttrName.empty()) {
+ auto attr = func->getAttrOfType<DictionaryAttr>(classificationAttrName);
+ if (!attr) {
+ func.emitOpError()
+ << "missing classification attribute '" << classificationAttrName
+ << "' (CallConvLowering driver mode 'classification-attr')";
+ return std::nullopt;
+ }
+ return mlir::abi::test::parseClassificationAttr(
+ attr, [&]() { return func.emitOpError(); });
+ }
+
+ if (target == "test")
+ return mlir::abi::test::classify(argTypes, returnType, dl);
+
+ func.emitOpError() << "unknown target '" << target << "' (supported: test)";
+ return std::nullopt;
+}
+
+/// Find the cir.func declaration matching a cir.call's callee, if any.
+/// Returns nullptr if the callee is indirect or the symbol cannot be
+/// resolved (in which case the call is left alone).
+cir::FuncOp lookupCallee(cir::CallOp call, ModuleOp module) {
+ FlatSymbolRefAttr callee = call.getCalleeAttr();
+ if (!callee)
+ return nullptr;
+ return module.lookupSymbol<cir::FuncOp>(callee.getValue());
+}
+
+void CallConvLoweringPass::runOnOperation() {
+ ModuleOp module = getOperation();
+ MLIRContext *ctx = &getContext();
+
+ if (target.empty() == classificationAttr.empty()) {
+ module.emitOpError() << "CallConvLowering requires exactly one of "
+ "'target' or 'classification-attr' pass options";
+ signalPassFailure();
+ return;
+ }
+
+ if (!module->hasAttr(DLTIDialect::kDataLayoutAttrName)) {
+ module.emitOpError()
+ << "CallConvLowering requires a DataLayout (dlti.dl_spec attribute "
+ "on the module)";
+ signalPassFailure();
+ return;
+ }
+
+ DataLayout dl(module);
+ CIRABIRewriteContext rewriteCtx(module);
+
+ // Pre-compute classifications for every cir.func so that call-site
+ // rewriting can find them (call site uses callee's classification).
+ llvm::MapVector<cir::FuncOp, FunctionClassification> classifications;
+ bool anyFailed = false;
+ module.walk([&](cir::FuncOp f) {
+ auto fc = classifyFunction(f, dl, target, classificationAttr);
+ if (!fc) {
+ anyFailed = true;
+ return;
+ }
+ classifications.insert({f, std::move(*fc)});
+ });
+ if (anyFailed) {
+ signalPassFailure();
+ return;
+ }
+
+ OpBuilder rewriter(ctx);
+
+ // Rewrite call sites first, while functions still have their original
+ // signatures. This avoids any chance of us reading a partially-rewritten
+ // signature and matching args against the wrong classification.
+ SmallVector<cir::CallOp> calls;
+ module.walk([&](cir::CallOp c) { calls.push_back(c); });
+ for (cir::CallOp call : calls) {
+ cir::FuncOp callee = lookupCallee(call, module);
+ if (!callee)
+ continue;
+ auto it = classifications.find(callee);
+ if (it == classifications.end())
+ continue;
+ if (failed(rewriteCtx.rewriteCallSite(call, it->second, rewriter))) {
+ signalPassFailure();
+ return;
+ }
+ }
+
+ // Now rewrite each function definition.
+ for (auto &kv : classifications) {
+ if (failed(rewriteCtx.rewriteFunctionDefinition(kv.first, kv.second,
+ rewriter))) {
+ signalPassFailure();
+ return;
+ }
+ }
+}
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createCallConvLoweringPass() {
+ return std::make_unique<CallConvLoweringPass>();
+}
diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
new file mode 100644
index 0000000000000..43d0b7aeca386
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
@@ -0,0 +1,257 @@
+//===- CIRABIRewriteContext.cpp - CIR ABI rewrite context ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "CIRABIRewriteContext.h"
+#include "mlir/IR/Builders.h"
+#include "clang/CIR/Dialect/IR/CIRTypes.h"
+
+using namespace cir;
+using namespace mlir;
+using namespace mlir::abi;
+
+namespace {
+
+bool needsRewrite(const FunctionClassification &fc) {
+ if (fc.returnInfo.kind != ArgKind::Direct || fc.returnInfo.coercedType)
+ return true;
+ for (const ArgClassification &ac : fc.argInfos)
+ if (ac.kind != ArgKind::Direct || ac.coercedType)
+ return true;
+ return false;
+}
+
+SmallVector<unsigned> ignoredArgIndices(const FunctionClassification &fc) {
+ SmallVector<unsigned> v;
+ for (auto [idx, ac] : llvm::enumerate(fc.argInfos))
+ if (ac.kind == ArgKind::Ignore)
+ v.push_back(idx);
+ return v;
+}
+
+LogicalResult buildNewArgTypes(ArrayRef<Type> oldArgTypes,
+ const FunctionClassification &fc,
+ SmallVectorImpl<Type> &newArgTypes,
+ function_ref<InFlightDiagnostic()> emitError) {
+ newArgTypes.reserve(oldArgTypes.size());
+ for (auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
+ Type origTy = oldArgTypes[idx];
+ switch (ac.kind) {
+ case ArgKind::Direct:
+ if (ac.coercedType) {
+ emitError() << "Direct with coerced type at arg " << idx
+ << " not yet implemented in CallConvLowering";
+ return failure();
+ }
+ newArgTypes.push_back(origTy);
+ break;
+ case ArgKind::Ignore:
+ break;
+ case ArgKind::Expand:
+ newArgTypes.push_back(origTy);
+ break;
+ case ArgKind::Extend:
+ emitError() << "Extend at arg " << idx
+ << " not yet implemented in CallConvLowering";
+ return failure();
+ case ArgKind::Indirect:
+ emitError() << "Indirect at arg " << idx
+ << " not yet implemented in CallConvLowering";
+ return failure();
+ }
+ }
+ return success();
+}
+
+Type computeNewReturnType(Type origRetTy, const ArgClassification &retInfo,
+ MLIRContext *ctx,
+ function_ref<InFlightDiagnostic()> emitError) {
+ switch (retInfo.kind) {
+ case ArgKind::Direct:
+ if (retInfo.coercedType) {
+ emitError() << "Direct return with coerced type not yet implemented "
+ << "in CallConvLowering";
+ return nullptr;
+ }
+ return origRetTy;
+ case ArgKind::Ignore:
+ return cir::VoidType::get(ctx);
+ case ArgKind::Expand:
+ return origRetTy;
+ case ArgKind::Extend:
+ emitError() << "Extend return not yet implemented in CallConvLowering";
+ return nullptr;
+ case ArgKind::Indirect:
+ emitError() << "Indirect return (sret) not yet implemented in "
+ << "CallConvLowering";
+ return nullptr;
+ }
+ llvm_unreachable("all ArgKind cases handled");
+}
+
+} // namespace
+
+LogicalResult CIRABIRewriteContext::rewriteFunctionDefinition(
+ FunctionOpInterface funcOp, const FunctionClassification &fc,
+ OpBuilder &rewriter) {
+ if (!needsRewrite(fc))
+ return success();
+
+ ArrayRef<Type> oldArgTypes = funcOp.getArgumentTypes();
+ ArrayRef<Type> oldResultTypes = funcOp.getResultTypes();
+ MLIRContext *ctx = funcOp->getContext();
+
+ SmallVector<Type> newArgTypes;
+ if (failed(buildNewArgTypes(oldArgTypes, fc, newArgTypes,
+ [&]() { return funcOp.emitOpError(); })))
+ return failure();
+
+ Type voidTy = cir::VoidType::get(ctx);
+ Type origRetTy = oldResultTypes.empty() ? voidTy : oldResultTypes[0];
+ Type newRetTy = computeNewReturnType(origRetTy, fc.returnInfo, ctx,
+ [&]() { return funcOp.emitOpError(); });
+ if (!newRetTy)
+ return failure();
+ SmallVector<Type> newResultTypes = {newRetTy};
+
+ if (!funcOp.isDeclaration()) {
+ Region &body = funcOp->getRegion(0);
+ if (!body.empty()) {
+ Block &entry = body.front();
+
+ SmallVector<unsigned> ignored = ignoredArgIndices(fc);
+ for (int i = static_cast<int>(ignored.size()) - 1; i >= 0; --i) {
+ unsigned blockIdx = ignored[i];
+ if (blockIdx >= entry.getNumArguments())
+ continue;
+ BlockArgument arg = entry.getArgument(blockIdx);
+ if (!arg.use_empty()) {
+ rewriter.setInsertionPointToStart(&entry);
+ auto ptrTy = cir::PointerType::get(arg.getType());
+ auto alloca = cir::AllocaOp::create(
+ rewriter, funcOp.getLoc(), ptrTy, arg.getType(),
+ rewriter.getStringAttr("ignored"), rewriter.getI64IntegerAttr(1));
+ auto load = cir::LoadOp::create(
+ rewriter, funcOp.getLoc(), arg.getType(), alloca, UnitAttr(),
+ UnitAttr(), IntegerAttr(), cir::SyncScopeKindAttr(),
+ cir::MemOrderAttr());
+ arg.replaceAllUsesWith(load);
+ }
+ entry.eraseArgument(blockIdx);
+ }
+ }
+
+ if (fc.returnInfo.kind == ArgKind::Ignore && !oldResultTypes.empty()) {
+ SmallVector<cir::ReturnOp> returns;
+ funcOp.walk([&](cir::ReturnOp r) { returns.push_back(r); });
+ for (cir::ReturnOp r : returns) {
+ if (r.getNumOperands() == 0)
+ continue;
+ rewriter.setInsertionPoint(r);
+ cir::ReturnOp::create(rewriter, r.getLoc());
+ r.erase();
+ }
+ }
+ }
+
+ Type newFnTy = funcOp.cloneTypeWith(newArgTypes, newResultTypes);
+ funcOp.setFunctionTypeAttr(TypeAttr::get(newFnTy));
+
+ SmallVector<unsigned> ignored = ignoredArgIndices(fc);
+ if (!ignored.empty())
+ if (auto existing = funcOp->getAttrOfType<ArrayAttr>("arg_attrs")) {
+ SmallVector<Attribute> kept;
+ kept.reserve(newArgTypes.size());
+ for (auto [oldIdx, attr] : llvm::enumerate(existing.getValue()))
+ if (oldIdx >= fc.argInfos.size() ||
+ fc.argInfos[oldIdx].kind != ArgKind::Ignore)
+ kept.push_back(attr);
+ funcOp->setAttr("arg_attrs", ArrayAttr::get(ctx, kept));
+ }
+
+ return success();
+}
+
+LogicalResult CIRABIRewriteContext::rewriteCallSite(
+ Operation *callOp, const FunctionClassification &fc, OpBuilder &rewriter) {
+ if (!needsRewrite(fc))
+ return success();
+
+ auto call = cast<cir::CallOp>(callOp);
+
+ for (auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
+ switch (ac.kind) {
+ case ArgKind::Direct:
+ if (ac.coercedType)
+ return call.emitOpError()
+ << "Direct with coerced type at call-site arg " << idx
+ << " not yet implemented in CallConvLowering";
+ break;
+ case ArgKind::Ignore:
+ case ArgKind::Expand:
+ break;
+ case ArgKind::Extend:
+ return call.emitOpError() << "Extend at call-site arg " << idx
+ << " not yet implemented in CallConvLowering";
+ case ArgKind::Indirect:
+ return call.emitOpError() << "Indirect at call-site arg " << idx
+ << " not yet implemented in CallConvLowering";
+ }
+ }
+
+ SmallVector<Value> newArgs;
+ ValueRange argOperands = call.getArgOperands();
+ newArgs.reserve(argOperands.size());
+ for (auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
+ if (idx >= argOperands.size())
+ break;
+ if (ac.kind == ArgKind::Ignore)
+ continue;
+ newArgs.push_back(argOperands[idx]);
+ }
+ for (unsigned i = fc.argInfos.size(); i < argOperands.size(); ++i)
+ newArgs.push_back(argOperands[i]);
+
+ bool hasResult = call.getNumResults() > 0;
+ Type origRetTy = hasResult ? call.getResult().getType()
+ : cir::VoidType::get(callOp->getContext());
+ Type callRetTy = origRetTy;
+ if (fc.returnInfo.kind == ArgKind::Ignore && hasResult)
+ callRetTy = cir::VoidType::get(callOp->getContext());
+ if ((fc.returnInfo.kind == ArgKind::Direct ||
+ fc.returnInfo.kind == ArgKind::Extend) &&
+ fc.returnInfo.coercedType)
+ return call.emitOpError() << "Direct/Extend return with coerced type at "
+ << "call-site not yet implemented in "
+ << "CallConvLowering";
+
+ rewriter.setInsertionPoint(call);
+ auto newCall = cir::CallOp::create(rewriter, call.getLoc(),
+ call.getCalleeAttr(), callRetTy, newArgs);
+ for (NamedAttribute attr : call->getAttrs())
+ if (!newCall->hasAttr(attr.getName()))
+ newCall->setAttr(attr.getName(), attr.getValue());
+
+ if (hasResult && fc.returnInfo.kind == ArgKind::Ignore) {
+ if (!call.getResult().use_empty()) {
+ rewriter.setInsertionPointAfter(newCall);
+ auto ptrTy = cir::PointerType::get(origRetTy);
+ auto alloca = cir::AllocaOp::create(
+ rewriter, call.getLoc(), ptrTy, origRetTy,
+ rewriter.getStringAttr("ignored"), rewriter.getI64IntegerAttr(1));
+ auto load = cir::LoadOp::create(
+ rewriter, call.getLoc(), origRetTy, alloca, UnitAttr(), UnitAttr(),
+ IntegerAttr(), cir::SyncScopeKindAttr(), cir::MemOrderAttr());
+ call.getResult().replaceAllUsesWith(load);
+ }
+ } else if (hasResult) {
+ call.getResult().replaceAllUsesWith(newCall.getResult());
+ }
+
+ call->erase();
+ return success();
+}
diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h
new file mode 100644
index 0000000000000..cf8635e9afdd6
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h
@@ -0,0 +1,56 @@
+//===- CIRABIRewriteContext.h - CIR ABI rewrite context ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines CIRABIRewriteContext, the CIR dialect's implementation of the
+// generic mlir::abi::ABIRewriteContext. Given a FunctionClassification it
+// rewrites a cir.func signature, the function body, and call sites to match
+// the ABI-lowered shape.
+//
+// This file currently handles only Direct (pass-through) and Ignore. Other
+// ArgKind handlers (Extend, Direct-with-coercion, Indirect, Expand) are
+// added by subsequent PRs in the calling-convention-lowering split series.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_CIRABIREWRITECONTEXT_H
+#define CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_CIRABIREWRITECONTEXT_H
+
+#include "mlir/ABI/ABIRewriteContext.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+
+namespace cir {
+
+/// CIR-specific implementation of mlir::abi::ABIRewriteContext.
+///
+/// The driver pass (CallConvLoweringPass) computes a FunctionClassification
+/// for each cir.func / cir.call and dispatches to this class to perform the
+/// actual IR rewriting using cir dialect operations.
+class CIRABIRewriteContext : public mlir::abi::ABIRewriteContext {
+public:
+ explicit CIRABIRewriteContext(mlir::ModuleOp module) : module(module) {}
+
+ mlir::LogicalResult
+ rewriteFunctionDefinition(mlir::FunctionOpInterface funcOp,
+ const mlir::abi::FunctionClassification &fc,
+ mlir::OpBuilder &rewriter) override;
+
+ mlir::LogicalResult
+ rewriteCallSite(mlir::Operation *callOp,
+ const mlir::abi::FunctionClassification &fc,
+ mlir::OpBuilder &rewriter) override;
+
+ mlir::StringRef getDialectNamespace() const override { return "cir"; }
+
+private:
+ mlir::ModuleOp module;
+};
+
+} // namespace cir
+
+#endif // CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_CIRABIREWRITECONTEXT_H
diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt
index 86502b7f5dd4e..9833952623708 100644
--- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt
@@ -1,4 +1,5 @@
add_clang_library(MLIRCIRTargetLowering
+ CIRABIRewriteContext.cpp
CIRCXXABI.cpp
LowerModule.cpp
LowerItaniumCXXABI.cpp
@@ -15,6 +16,7 @@ add_clang_library(MLIRCIRTargetLowering
LINK_LIBS PUBLIC
clangBasic
+ MLIRABI
MLIRIR
MLIRPass
MLIRDLTIDialect
diff --git a/clang/test/CIR/Transforms/abi-lowering/Inputs/test-datalayout.cir b/clang/test/CIR/Transforms/abi-lowering/Inputs/test-datalayout.cir
new file mode 100644
index 0000000000000..17db09584e873
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/Inputs/test-datalayout.cir
@@ -0,0 +1,17 @@
+// Shared fixture: minimal x86_64 dlti.dl_spec used by every abi-lowering
+// test below. Test files include this verbatim inside their `module
+// attributes { ... }` declaration so that CallConvLowering's DataLayout
+// queries return sensible sizes and alignments.
+//
+// Copy this attribute into a test like so:
+//
+// module attributes {
+// dlti.dl_spec = #dlti.dl_spec<
+// #dlti.dl_entry<i1, dense<8>: vector<2xi64>>,
+// #dlti.dl_entry<i8, dense<8>: vector<2xi64>>,
+// #dlti.dl_entry<i16, dense<16>: vector<2xi64>>,
+// #dlti.dl_entry<i32, dense<32>: vector<2xi64>>,
+// #dlti.dl_entry<i64, dense<64>: vector<2xi64>>,
+// #dlti.dl_entry<f32, dense<32>: vector<2xi64>>,
+// #dlti.dl_entry<f64, dense<64>: vector<2xi64>>>
+// } { ... }
diff --git a/clang/test/CIR/Transforms/abi-lowering/datalayout-missing-error.cir b/clang/test/CIR/Transforms/abi-lowering/datalayout-missing-error.cir
new file mode 100644
index 0000000000000..644f9ec714e49
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/datalayout-missing-error.cir
@@ -0,0 +1,17 @@
+// RUN: not cir-opt %s -cir-call-conv-lowering=target=test 2>&1 | FileCheck %s
+
+// The pass requires a `dlti.dl_spec` attribute on the module. Without it,
+// classification cannot query type sizes / alignments, so the pass emits a
+// diagnostic and fails.
+
+!s32i = !cir.int<s, 32>
+
+module {
+
+ cir.func @passthrough(%arg0: !s32i) -> !s32i {
+ cir.return %arg0 : !s32i
+ }
+
+}
+
+// CHECK: error: 'builtin.module' op CallConvLowering requires a DataLayout (dlti.dl_spec attribute on the module)
diff --git a/clang/test/CIR/Transforms/abi-lowering/declaration-rewrite.cir b/clang/test/CIR/Transforms/abi-lowering/declaration-rewrite.cir
new file mode 100644
index 0000000000000..2242ed65d74bc
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/declaration-rewrite.cir
@@ -0,0 +1,34 @@
+// RUN: cir-opt %s -cir-call-conv-lowering="classification-attr=test_classify" \
+// RUN: | FileCheck %s
+
+// Function declarations (no body) get their signature rewritten just like
+// definitions: arg list is updated, return type is updated, but no body
+// adaptation runs.
+
+!s32i = !cir.int<s, 32>
+
+#ignore_first_arg = {
+ return = { kind = "direct" },
+ args = [ { kind = "ignore" }, { kind = "direct" } ]
+}
+
+module attributes {
+ dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i32, dense<32>: vector<2xi64>>>
+} {
+
+ cir.func private @ext_decl(!s32i, !s32i) -> !s32i
+ attributes { test_classify = #ignore_first_arg }
+
+ // The first argument type is dropped from the declaration's signature.
+ // CHECK: cir.func{{.*}} @ext_decl(!s32i) -> !s32i
+
+ cir.func @caller(%arg0: !s32i, %arg1: !s32i) -> !s32i
+ attributes { test_classify = #ignore_first_arg } {
+ %0 = cir.call @ext_decl(%arg0, %arg1) : (!s32i, !s32i) -> !s32i
+ cir.return %0 : !s32i
+ }
+
+ // CHECK: cir.func{{.*}} @caller(%arg0: !s32i) -> !s32i
+ // CHECK: %[[R:.*]] = cir.call @ext_decl(%arg0) : (!s32i) -> !s32i
+
+}
diff --git a/clang/test/CIR/Transforms/abi-lowering/direct-passthrough-injection.cir b/clang/test/CIR/Transforms/abi-lowering/direct-passthrough-injection.cir
new file mode 100644
index 0000000000000..859398eff7d25
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/direct-passthrough-injection.cir
@@ -0,0 +1,42 @@
+// RUN: cir-opt %s -cir-call-conv-lowering="classification-attr=test_classify" \
+// RUN: | FileCheck %s
+
+// Same passthrough behavior as direct-passthrough-test-target.cir, but the
+// classification is injected via a function attribute rather than computed
+// by the test target. This is the driver mode that lets tests verify
+// rewriter behavior against arbitrary classifications without depending on
+// any real ABI target.
+
+!s32i = !cir.int<s, 32>
+!s64i = !cir.int<s, 64>
+
+#all_direct_two_args = {
+ return = { kind = "direct" },
+ args = [ { kind = "direct" }, { kind = "direct" } ]
+}
+
+module attributes {
+ dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<i32, dense<32>: vector<2xi64>>,
+ #dlti.dl_entry<i64, dense<64>: vector<2xi64>>>
+} {
+
+ cir.func @passthrough(%arg0: !s32i, %arg1: !s64i) -> !s32i
+ attributes { test_classify = #all_direct_two_args } {
+ cir.return %arg0 : !s32i
+ }
+
+ // CHECK: cir.func{{.*}} @passthrough(%arg0: !s32i, %arg1: !s64i) -> !s32i
+ // CHECK: cir.return %arg0 : !s32i
+
+ cir.func @caller(%arg0: !s32i, %arg1: !s64i) -> !s32i
+ attributes { test_classify = #all_direct_two_args } {
+ %0 = cir.call @passthrough(%arg0, %arg1) : (!s32i, !s64i) -> !s32i
+ cir.return %0 : !s32i
+ }
+
+ // CHECK: cir.func{{.*}} @caller(%arg0: !s32i, %arg1: !s64i) -> !s32i
+ // CHECK: %[[RES:.*]] = cir.call @passthrough(%arg0, %arg1) : (!s32i, !s64i) -> !s32i
+ // CHECK: cir.return %[[RES]] : !s32i
+
+}
diff --git a/clang/test/CIR/Transforms/abi-lowering/direct-passthrough-test-target.cir b/clang/test/CIR/Transforms/abi-lowering/direct-passthrough-test-target.cir
new file mode 100644
index 0000000000000..7b363cd84f9fd
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/direct-passthrough-test-target.cir
@@ -0,0 +1,35 @@
+// RUN: cir-opt %s -cir-call-conv-lowering=target=test | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+!s64i = !cir.int<s, 64>
+
+module attributes {
+ dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<i1, dense<8>: vector<2xi64>>,
+ #dlti.dl_entry<i8, dense<8>: vector<2xi64>>,
+ #dlti.dl_entry<i16, dense<16>: vector<2xi64>>,
+ #dlti.dl_entry<i32, dense<32>: vector<2xi64>>,
+ #dlti.dl_entry<i64, dense<64>: vector<2xi64>>>
+} {
+
+ // Register-sized integer args/returns are Direct under the test target.
+ // Direct (without coercion) is a true pass-through — function signature
+ // and call sites are unchanged.
+
+ cir.func @passthrough(%arg0: !s32i, %arg1: !s64i) -> !s32i {
+ cir.return %arg0 : !s32i
+ }
+
+ // CHECK: cir.func{{.*}} @passthrough(%arg0: !s32i, %arg1: !s64i) -> !s32i
+ // CHECK-NEXT: cir.return %arg0 : !s32i
+
+ cir.func @caller(%arg0: !s32i, %arg1: !s64i) -> !s32i {
+ %0 = cir.call @passthrough(%arg0, %arg1) : (!s32i, !s64i) -> !s32i
+ cir.return %0 : !s32i
+ }
+
+ // CHECK: cir.func{{.*}} @caller(%arg0: !s32i, %arg1: !s64i) -> !s32i
+ // CHECK-NEXT: %[[RES:.*]] = cir.call @passthrough(%arg0, %arg1) : (!s32i, !s64i) -> !s32i
+ // CHECK-NEXT: cir.return %[[RES]] : !s32i
+
+}
diff --git a/clang/test/CIR/Transforms/abi-lowering/ignore-arg.cir b/clang/test/CIR/Transforms/abi-lowering/ignore-arg.cir
new file mode 100644
index 0000000000000..548cf2940ac9f
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/ignore-arg.cir
@@ -0,0 +1,39 @@
+// RUN: cir-opt %s -cir-call-conv-lowering="classification-attr=test_classify" \
+// RUN: | FileCheck %s
+
+// Ignore'd argument: the rewriter drops it from both the function signature
+// and call sites. Body uses of the arg get rewritten to load from a fresh
+// alloca so the IR stays well-typed.
+
+!s32i = !cir.int<s, 32>
+
+#ignore_first_arg = {
+ return = { kind = "direct" },
+ args = [ { kind = "ignore" }, { kind = "direct" } ]
+}
+
+module attributes {
+ dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i32, dense<32>: vector<2xi64>>>
+} {
+
+ cir.func @callee(%arg0: !s32i, %arg1: !s32i) -> !s32i
+ attributes { test_classify = #ignore_first_arg } {
+ cir.return %arg1 : !s32i
+ }
+
+ // The first arg is dropped; second arg is preserved.
+ // CHECK: cir.func{{.*}} @callee(%arg0: !s32i) -> !s32i
+ // CHECK-NEXT: cir.return %arg0 : !s32i
+
+ cir.func @caller(%arg0: !s32i, %arg1: !s32i) -> !s32i
+ attributes { test_classify = #ignore_first_arg } {
+ %0 = cir.call @callee(%arg0, %arg1) : (!s32i, !s32i) -> !s32i
+ cir.return %0 : !s32i
+ }
+
+ // The call site drops the first arg too.
+ // CHECK: cir.func{{.*}} @caller(%arg0: !s32i) -> !s32i
+ // CHECK-NEXT: %[[R:.*]] = cir.call @callee(%arg0) : (!s32i) -> !s32i
+ // CHECK-NEXT: cir.return %[[R]] : !s32i
+
+}
diff --git a/clang/test/CIR/Transforms/abi-lowering/ignore-return.cir b/clang/test/CIR/Transforms/abi-lowering/ignore-return.cir
new file mode 100644
index 0000000000000..5f04e3ecbc6fd
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/ignore-return.cir
@@ -0,0 +1,48 @@
+// RUN: cir-opt %s -cir-call-conv-lowering="classification-attr=test_classify" \
+// RUN: | FileCheck %s
+
+// Ignore'd return: the rewriter changes the function's return type to void,
+// drops the operand from every cir.return, and at call sites issues a void
+// call. Any remaining uses of the original (non-void) call result get
+// rewritten to load from a fresh alloca so the IR stays well-typed.
+
+!s32i = !cir.int<s, 32>
+
+#ignore_return = {
+ return = { kind = "ignore" },
+ args = [ ]
+}
+
+module attributes {
+ dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i32, dense<32>: vector<2xi64>>>
+} {
+
+ cir.func @callee() -> !s32i
+ attributes { test_classify = #ignore_return } {
+ %0 = cir.const #cir.int<42> : !s32i
+ cir.return %0 : !s32i
+ }
+
+ // Return type becomes void (dropped from the func type), and the cir.return
+ // becomes operand-free.
+ // CHECK: cir.func{{.*}} @callee()
+ // CHECK-NOT: -> !s32i
+ // CHECK: cir.return
+ // CHECK-NOT: : !s32i
+
+ cir.func @caller_uses_result() -> !s32i
+ attributes { test_classify = #ignore_return } {
+ %0 = cir.call @callee() : () -> !s32i
+ cir.return %0 : !s32i
+ }
+
+ // Caller's own return is also classified Ignore, so its return type is
+ // also void. The use of @callee's result becomes a load from a freshly
+ // allocated slot (the returned-but-ignored value is never observable).
+ // CHECK: cir.func{{.*}} @caller_uses_result()
+ // CHECK: cir.call @callee() : () -> ()
+ // CHECK: cir.alloca !s32i
+ // CHECK: cir.load
+ // CHECK: cir.return
+
+}
diff --git a/clang/tools/cir-opt/cir-opt.cpp b/clang/tools/cir-opt/cir-opt.cpp
index 05e3b9ec7e964..b6d2438812070 100644
--- a/clang/tools/cir-opt/cir-opt.cpp
+++ b/clang/tools/cir-opt/cir-opt.cpp
@@ -74,6 +74,10 @@ int main(int argc, char **argv) {
return mlir::createCXXABILoweringPass();
});
+ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
+ return mlir::createCallConvLoweringPass();
+ });
+
mlir::omp::registerOpenMPPasses();
mlir::registerTransformsPasses();
>From e3484746b4ba649d62bff209b4e61b4e540396b6 Mon Sep 17 00:00:00 2001
From: Adam Smith <adams at nvidia.com>
Date: Mon, 4 May 2026 14:09:42 -0700
Subject: [PATCH 3/4] [CIR] Add Extend (signext/zeroext) handling to
CallConvLowering
Third PR in the series splitting #192119 / #192124. Adds handlers
for the Extend ArgKind in both rewriteFunctionDefinition and
rewriteCallSite.
The CIR signature keeps the original (narrow) integer type;
sign- or zero-extension is communicated to LLVM via the
llvm.signext / llvm.zeroext arg_attrs and res_attrs entries
attached by the rewriter. This matches Classic Clang's LLVM IR
convention (e.g. `define void @f(i8 signext %x)`, not `define
void @f(i32 signext %x)` with an entry-block truncation).
The coercedType field on an Extend ArgClassification is treated
as informational only -- it carries the LLVM ABI library's
register-width type for downstream consumers, but the rewriter
does not use it to change the CIR signature.
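
As a rough illustration of the attribute placement described above, the
following standalone C++ sketch models how the extension attribute ends up
on the right argument slot once Ignore'd arguments are dropped. It is not
the code in this patch: ArgKind and ArgClass here are simplified,
hypothetical stand-ins for the real classification types, and extendAttrs
is an invented helper that returns plain strings instead of MLIR
attributes.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-ins for the patch's ArgKind / ArgClassification.
enum class ArgKind { Direct, Extend, Ignore };

struct ArgClass {
  ArgKind kind;
  bool signExtend;
};

// Returns one attribute name per surviving argument ("" means none).
// Ignore'd args produce no slot, so every later arg shifts down by one.
// The argument types themselves are never widened in this model, which
// mirrors the "keep the narrow type, attach an attribute" convention.
std::vector<std::string> extendAttrs(const std::vector<ArgClass> &args) {
  std::vector<std::string> out;
  for (const ArgClass &ac : args) {
    if (ac.kind == ArgKind::Ignore)
      continue;
    if (ac.kind == ArgKind::Extend)
      out.push_back(ac.signExtend ? "llvm.signext" : "llvm.zeroext");
    else
      out.push_back("");
  }
  return out;
}
```

For a classification list of (Ignore, Extend signed, Extend unsigned,
Direct), the sketch yields three slots, with the sign-extension attribute
on the first surviving slot rather than on original index 1.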
Three .cir tests cover the three Extend code paths
(narrow-signed-arg, narrow-unsigned-arg, narrow-signed-return).
The test target's narrow-int Extend rule fires only on MLIR
builtin IntegerType -- since CIR functions take cir::IntType,
these tests exercise the rewriter via the classification-
injection driver added in PR A1.
check-clang-cir-codegen / check-clang-cir both pass with no
regressions.
Co-authored-by: Cursor <cursoragent at cursor.com>
---
.../TargetLowering/CIRABIRewriteContext.cpp | 149 +++++++++++++++---
.../Transforms/abi-lowering/extend-return.cir | 41 +++++
.../abi-lowering/extend-signed-arg.cir | 40 +++++
.../abi-lowering/extend-unsigned-arg.cir | 36 +++++
4 files changed, 240 insertions(+), 26 deletions(-)
create mode 100644 clang/test/CIR/Transforms/abi-lowering/extend-return.cir
create mode 100644 clang/test/CIR/Transforms/abi-lowering/extend-signed-arg.cir
create mode 100644 clang/test/CIR/Transforms/abi-lowering/extend-unsigned-arg.cir
diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
index 43d0b7aeca386..db0a9bcfcbb39 100644
--- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
@@ -17,6 +17,10 @@ using namespace mlir::abi;
namespace {
bool needsRewrite(const FunctionClassification &fc) {
+ // Direct without coercion is a true pass-through; any other kind (or a
+ // coerced Direct) means the rewriter must touch the IR. Extend is
+ // technically attribute-only at the IR level but still counts because the
+ // attribute attachment changes observable behavior.
if (fc.returnInfo.kind != ArgKind::Direct || fc.returnInfo.coercedType)
return true;
for (const ArgClassification &ac : fc.argInfos)
@@ -55,9 +59,14 @@ LogicalResult buildNewArgTypes(ArrayRef<Type> oldArgTypes,
newArgTypes.push_back(origTy);
break;
case ArgKind::Extend:
- emitError() << "Extend at arg " << idx
- << " not yet implemented in CallConvLowering";
- return failure();
+ // Extend keeps the original (narrow) type in the signature; the
+ // sign/zero extension is communicated to LLVM via the llvm.signext /
+ // llvm.zeroext arg attribute, attached separately below. Any
+ // coercedType the classifier set on the Extend ArgClassification is
+ // informational (typically the register-width type the value gets
+ // extended to in registers) but does not change the CIR signature.
+ newArgTypes.push_back(origTy);
+ break;
case ArgKind::Indirect:
emitError() << "Indirect at arg " << idx
<< " not yet implemented in CallConvLowering";
@@ -83,8 +92,10 @@ Type computeNewReturnType(Type origRetTy, const ArgClassification &retInfo,
case ArgKind::Expand:
return origRetTy;
case ArgKind::Extend:
- emitError() << "Extend return not yet implemented in CallConvLowering";
- return nullptr;
+ // Same convention as Extend args: keep the original return type in the
+ // signature; the sign/zero extension is communicated via the
+ // llvm.signext / llvm.zeroext res attribute attached separately below.
+ return origRetTy;
case ArgKind::Indirect:
emitError() << "Indirect return (sret) not yet implemented in "
<< "CallConvLowering";
@@ -93,6 +104,67 @@ Type computeNewReturnType(Type origRetTy, const ArgClassification &retInfo,
llvm_unreachable("all ArgKind cases handled");
}
+/// Build an updated arg_attrs ArrayAttr that drops Ignore'd args and adds
+/// llvm.signext / llvm.zeroext on Extend args. Preserves any existing arg
+/// attributes on retained arg slots.
+ArrayAttr updateArgAttrs(MLIRContext *ctx, unsigned numNewArgs,
+ ArrayAttr existingArgAttrs,
+ const FunctionClassification &fc) {
+ SmallVector<Attribute> newArgAttrs(numNewArgs, DictionaryAttr::get(ctx));
+
+ // Step 1: copy existing arg attrs over to their new positions, skipping
+ // Ignore'd args.
+ if (existingArgAttrs) {
+ unsigned newIdx = 0;
+ for (unsigned oldIdx = 0; oldIdx < existingArgAttrs.size(); ++oldIdx) {
+ if (oldIdx < fc.argInfos.size() &&
+ fc.argInfos[oldIdx].kind == ArgKind::Ignore)
+ continue;
+ if (newIdx < numNewArgs)
+ newArgAttrs[newIdx] = existingArgAttrs[oldIdx];
+ ++newIdx;
+ }
+ }
+
+ // Step 2: layer llvm.signext / llvm.zeroext onto each Extend arg. The new
+ // arg index for Extend at original index `oldIdx` is `oldIdx` minus the
+ // number of Ignore'd args that came before it.
+ unsigned newIdx = 0;
+ for (auto [oldIdx, ac] : llvm::enumerate(fc.argInfos)) {
+ if (ac.kind == ArgKind::Ignore)
+ continue;
+ if (ac.kind == ArgKind::Extend && newIdx < numNewArgs) {
+ auto existing = cast<DictionaryAttr>(newArgAttrs[newIdx]);
+ SmallVector<NamedAttribute> attrs(existing.begin(), existing.end());
+ StringRef attrName = ac.signExtend ? "llvm.signext" : "llvm.zeroext";
+ attrs.push_back(
+ NamedAttribute(StringAttr::get(ctx, attrName), UnitAttr::get(ctx)));
+ newArgAttrs[newIdx] = DictionaryAttr::get(ctx, attrs);
+ }
+ ++newIdx;
+ }
+
+ return ArrayAttr::get(ctx, newArgAttrs);
+}
+
+/// Build an updated res_attrs ArrayAttr (single entry, since CIR funcs have
+/// at most one result) that adds llvm.signext / llvm.zeroext on an Extend
+/// return. Preserves any existing res attributes.
+ArrayAttr updateResAttrs(MLIRContext *ctx, ArrayAttr existingResAttrs,
+ const ArgClassification &retInfo) {
+ if (retInfo.kind != ArgKind::Extend)
+ return existingResAttrs;
+
+ SmallVector<NamedAttribute> attrs;
+ if (existingResAttrs && existingResAttrs.size() > 0)
+ for (NamedAttribute na : cast<DictionaryAttr>(existingResAttrs[0]))
+ attrs.push_back(na);
+ StringRef attrName = retInfo.signExtend ? "llvm.signext" : "llvm.zeroext";
+ attrs.push_back(
+ NamedAttribute(StringAttr::get(ctx, attrName), UnitAttr::get(ctx)));
+ return ArrayAttr::get(ctx, {DictionaryAttr::get(ctx, attrs)});
+}
+
} // namespace
LogicalResult CIRABIRewriteContext::rewriteFunctionDefinition(
@@ -161,17 +233,24 @@ LogicalResult CIRABIRewriteContext::rewriteFunctionDefinition(
Type newFnTy = funcOp.cloneTypeWith(newArgTypes, newResultTypes);
funcOp.setFunctionTypeAttr(TypeAttr::get(newFnTy));
- SmallVector<unsigned> ignored = ignoredArgIndices(fc);
- if (!ignored.empty())
- if (auto existing = funcOp->getAttrOfType<ArrayAttr>("arg_attrs")) {
- SmallVector<Attribute> kept;
- kept.reserve(newArgTypes.size());
- for (auto [oldIdx, attr] : llvm::enumerate(existing.getValue()))
- if (oldIdx >= fc.argInfos.size() ||
- fc.argInfos[oldIdx].kind != ArgKind::Ignore)
- kept.push_back(attr);
- funcOp->setAttr("arg_attrs", ArrayAttr::get(ctx, kept));
- }
+ // Rebuild arg_attrs: drop entries for Ignore'd args, layer
+ // llvm.signext / llvm.zeroext onto Extend args, preserve everything else.
+ bool needsArgAttrUpdate = !ignoredArgIndices(fc).empty();
+ for (const ArgClassification &ac : fc.argInfos)
+ if (ac.kind == ArgKind::Extend)
+ needsArgAttrUpdate = true;
+ if (needsArgAttrUpdate) {
+ auto existing = funcOp->getAttrOfType<ArrayAttr>("arg_attrs");
+ funcOp->setAttr("arg_attrs",
+ updateArgAttrs(ctx, newArgTypes.size(), existing, fc));
+ }
+
+ // Rebuild res_attrs: layer llvm.signext / llvm.zeroext onto an Extend
+ // return.
+ if (fc.returnInfo.kind == ArgKind::Extend) {
+ auto existing = funcOp->getAttrOfType<ArrayAttr>("res_attrs");
+ funcOp->setAttr("res_attrs", updateResAttrs(ctx, existing, fc.returnInfo));
+ }
return success();
}
@@ -183,6 +262,8 @@ LogicalResult CIRABIRewriteContext::rewriteCallSite(
auto call = cast<cir::CallOp>(callOp);
+ MLIRContext *ctx = callOp->getContext();
+
for (auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
switch (ac.kind) {
case ArgKind::Direct:
@@ -193,10 +274,10 @@ LogicalResult CIRABIRewriteContext::rewriteCallSite(
break;
case ArgKind::Ignore:
case ArgKind::Expand:
- break;
case ArgKind::Extend:
- return call.emitOpError() << "Extend at call-site arg " << idx
- << " not yet implemented in CallConvLowering";
+ // Extend at the call site is just an attribute change (llvm.signext /
+ // llvm.zeroext on the call's arg_attrs); no IR-level cast.
+ break;
case ArgKind::Indirect:
return call.emitOpError() << "Indirect at call-site arg " << idx
<< " not yet implemented in CallConvLowering";
@@ -217,15 +298,13 @@ LogicalResult CIRABIRewriteContext::rewriteCallSite(
newArgs.push_back(argOperands[i]);
bool hasResult = call.getNumResults() > 0;
- Type origRetTy = hasResult ? call.getResult().getType()
- : cir::VoidType::get(callOp->getContext());
+ Type origRetTy =
+ hasResult ? call.getResult().getType() : cir::VoidType::get(ctx);
Type callRetTy = origRetTy;
if (fc.returnInfo.kind == ArgKind::Ignore && hasResult)
- callRetTy = cir::VoidType::get(callOp->getContext());
- if ((fc.returnInfo.kind == ArgKind::Direct ||
- fc.returnInfo.kind == ArgKind::Extend) &&
- fc.returnInfo.coercedType)
- return call.emitOpError() << "Direct/Extend return with coerced type at "
+ callRetTy = cir::VoidType::get(ctx);
+ if (fc.returnInfo.kind == ArgKind::Direct && fc.returnInfo.coercedType)
+ return call.emitOpError() << "Direct return with coerced type at "
<< "call-site not yet implemented in "
<< "CallConvLowering";
@@ -236,6 +315,24 @@ LogicalResult CIRABIRewriteContext::rewriteCallSite(
if (!newCall->hasAttr(attr.getName()))
newCall->setAttr(attr.getName(), attr.getValue());
+ // Rebuild the new call's arg_attrs (dropping entries for Ignore'd args) and
+ // layer llvm.signext / llvm.zeroext onto Extend args and an Extend return.
+ bool needsArgAttrUpdate = false;
+ for (const ArgClassification &ac : fc.argInfos)
+ if (ac.kind == ArgKind::Extend || ac.kind == ArgKind::Ignore) {
+ needsArgAttrUpdate = true;
+ break;
+ }
+ if (needsArgAttrUpdate) {
+ auto existing = call->getAttrOfType<ArrayAttr>("arg_attrs");
+ newCall->setAttr("arg_attrs",
+ updateArgAttrs(ctx, newArgs.size(), existing, fc));
+ }
+ if (fc.returnInfo.kind == ArgKind::Extend) {
+ auto existing = call->getAttrOfType<ArrayAttr>("res_attrs");
+ newCall->setAttr("res_attrs", updateResAttrs(ctx, existing, fc.returnInfo));
+ }
+
if (hasResult && fc.returnInfo.kind == ArgKind::Ignore) {
if (!call.getResult().use_empty()) {
rewriter.setInsertionPointAfter(newCall);
diff --git a/clang/test/CIR/Transforms/abi-lowering/extend-return.cir b/clang/test/CIR/Transforms/abi-lowering/extend-return.cir
new file mode 100644
index 0000000000000..dd2c40fc5305e
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/extend-return.cir
@@ -0,0 +1,41 @@
+// Extend on a narrow return attaches llvm.signext (or llvm.zeroext) to
+// res_attrs while keeping the narrow type in the signature. Body and
+// call-site IR are unchanged otherwise.
+// RUN: cir-opt %s -cir-call-conv-lowering="classification-attr=test_classify" \
+// RUN: | FileCheck %s
+
+!s8i = !cir.int<s, 8>
+
+#extend_signed_return = {
+ return = { kind = "extend",
+ coerced_type = !cir.int<s, 32>,
+ sign_extend = true },
+ args = [ ]
+}
+
+module attributes {
+ dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<i8, dense<8>: vector<2xi64>>,
+ #dlti.dl_entry<i32, dense<32>: vector<2xi64>>>
+} {
+
+ cir.func @returns_s8() -> !s8i
+ attributes { test_classify = #extend_signed_return } {
+ %0 = cir.const #cir.int<7> : !s8i
+ cir.return %0 : !s8i
+ }
+
+ // CHECK: cir.func{{.*}} @returns_s8() -> (!s8i {llvm.signext})
+ // CHECK: cir.return %{{.*}} : !s8i
+
+ cir.func @caller() -> !s8i
+ attributes { test_classify = #extend_signed_return } {
+ %0 = cir.call @returns_s8() : () -> !s8i
+ cir.return %0 : !s8i
+ }
+
+ // CHECK: cir.func{{.*}} @caller() -> (!s8i {llvm.signext})
+ // CHECK: %[[R:.*]] = cir.call @returns_s8() : () -> (!s8i {llvm.signext})
+ // CHECK: cir.return %[[R]] : !s8i
+
+}
diff --git a/clang/test/CIR/Transforms/abi-lowering/extend-signed-arg.cir b/clang/test/CIR/Transforms/abi-lowering/extend-signed-arg.cir
new file mode 100644
index 0000000000000..600acdc94d6b3
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/extend-signed-arg.cir
@@ -0,0 +1,40 @@
+// Extend with sign_extend = true on a narrow signed integer arg. The pass
+// keeps the narrow type in the signature (matching Classic Clang's LLVM IR
+// convention) and attaches llvm.signext to the corresponding arg_attrs.
+//
+// The test target's narrow-int Extend rule fires only on MLIR builtin
+// IntegerType, not cir::IntType, so this test uses the injection driver.
+// RUN: cir-opt %s -cir-call-conv-lowering="classification-attr=test_classify" \
+// RUN: | FileCheck %s
+
+!s8i = !cir.int<s, 8>
+
+#extend_signed_arg = {
+ return = { kind = "direct" },
+ args = [ { kind = "extend",
+ coerced_type = !cir.int<s, 32>,
+ sign_extend = true } ]
+}
+
+module attributes {
+ dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<i8, dense<8>: vector<2xi64>>,
+ #dlti.dl_entry<i32, dense<32>: vector<2xi64>>>
+} {
+
+ cir.func @takes_s8(%arg0: !s8i)
+ attributes { test_classify = #extend_signed_arg } {
+ cir.return
+ }
+
+ // CHECK: cir.func{{.*}} @takes_s8(%arg0: !s8i {llvm.signext})
+
+ cir.func @caller(%arg0: !s8i)
+ attributes { test_classify = #extend_signed_arg } {
+ cir.call @takes_s8(%arg0) : (!s8i) -> ()
+ cir.return
+ }
+
+ // CHECK: cir.call @takes_s8(%arg0) : (!s8i {llvm.signext}) -> ()
+
+}
diff --git a/clang/test/CIR/Transforms/abi-lowering/extend-unsigned-arg.cir b/clang/test/CIR/Transforms/abi-lowering/extend-unsigned-arg.cir
new file mode 100644
index 0000000000000..6c5c8207bb93b
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/extend-unsigned-arg.cir
@@ -0,0 +1,36 @@
+// Extend with sign_extend = false on a narrow unsigned integer arg attaches
+// llvm.zeroext (instead of llvm.signext) to the corresponding arg_attrs.
+// RUN: cir-opt %s -cir-call-conv-lowering="classification-attr=test_classify" \
+// RUN: | FileCheck %s
+
+!u16i = !cir.int<u, 16>
+
+#extend_unsigned_arg = {
+ return = { kind = "direct" },
+ args = [ { kind = "extend",
+ coerced_type = !cir.int<u, 32>,
+ sign_extend = false } ]
+}
+
+module attributes {
+ dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<i16, dense<16>: vector<2xi64>>,
+ #dlti.dl_entry<i32, dense<32>: vector<2xi64>>>
+} {
+
+ cir.func @takes_u16(%arg0: !u16i)
+ attributes { test_classify = #extend_unsigned_arg } {
+ cir.return
+ }
+
+ // CHECK: cir.func{{.*}} @takes_u16(%arg0: !u16i {llvm.zeroext})
+
+ cir.func @caller(%arg0: !u16i)
+ attributes { test_classify = #extend_unsigned_arg } {
+ cir.call @takes_u16(%arg0) : (!u16i) -> ()
+ cir.return
+ }
+
+ // CHECK: cir.call @takes_u16(%arg0) : (!u16i {llvm.zeroext}) -> ()
+
+}
>From 70674177f7fff56884f7cef34f78da3069e69033 Mon Sep 17 00:00:00 2001
From: Adam Smith <adams at nvidia.com>
Date: Tue, 5 May 2026 09:26:10 -0700
Subject: [PATCH 4/4] [CIR] Add Direct coerce-in-registers +
cir.reinterpret_cast op
Fourth PR in the split of #192119/#192124. Implements the
Direct-with-coercion path in CallConvLowering and picks off
andykaylor's five inline review comments from the original PR.
The new cir.reinterpret_cast op is for same-bit-width in-register
reinterpretation (vector<2 x float> <-> complex<float>).
emitCoercion uses it when source and destination differ only in
vector-vs-non-vector shape and have identical bit width, instead
of going through memory. For everything else (records, or shape
doesn't match) the helper still does alloca/store/ptr-cast/load.
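
The choice between the two paths can be sketched as a small standalone C++
model. This is only an illustration under simplifying assumptions, not the
emitCoercion implementation: Shape, Ty, Strategy and chooseCoercion are
hypothetical names, a type is reduced to a shape tag plus a bit width, and
type equality is approximated by comparing those two fields.

```cpp
#include <cstdint>

// Hypothetical, simplified model of a CIR type: a shape tag and a width.
enum class Shape { Scalar, Vector, Complex, Record };

struct Ty {
  Shape shape;
  std::uint64_t bits;
};

enum class Strategy { Identity, Reinterpret, Memory };

// Mirrors the decision order described above:
//   1. identical types need no coercion;
//   2. non-record types of equal bit width where exactly one side is a
//      vector can be reinterpreted in registers;
//   3. everything else goes through memory (alloca/store/ptr-cast/load).
inline Strategy chooseCoercion(Ty src, Ty dst) {
  if (src.shape == dst.shape && src.bits == dst.bits)
    return Strategy::Identity;
  bool isAggregate = src.shape == Shape::Record || dst.shape == Shape::Record;
  bool vectorMismatch =
      (src.shape == Shape::Vector) != (dst.shape == Shape::Vector);
  if (!isAggregate && vectorMismatch && src.bits == dst.bits)
    return Strategy::Reinterpret;
  return Strategy::Memory;
}
```

In this model a 64-bit vector coerced to a 64-bit complex takes the
reinterpret path, while any coercion that involves a record, or whose bit
widths differ, falls back to the memory path.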
Andy's comments, in order:
- Temporary alloca alignment is now max(srcAlign, dstAlign) from
DataLayout instead of hardcoded.
- The alloca lives in the entry block via InsertionGuard so it
composes with HoistAllocas regardless of pipeline order.
- isVolatile kept as UnitAttr-absence with an inline comment.
- vector<->complex now uses cir.reinterpret_cast.
- Memory path has three new .cir tests covering it.
CallConvLowering needed splitting into three phases
(function-def coercion / call-site rewriting / Ignore cleanup)
because block-arg type changes from Direct-with-coerce confused
the earlier ordering: Ignore'd args were getting alloca/load
chains synthesized for call-site uses that were about to be
dropped anyway.
LowerToLLVM gets a stub for the new op: bitcast for same-shape
converted types, error-with-message for aggregates. We don't
produce aggregates from CallConvLowering today, so the error
path is only reachable from hand-written IR; a follow-up patch can
add an extract/insert lowering if needed.
Co-authored-by: Cursor <cursoragent at cursor.com>
---
clang/include/clang/CIR/Dialect/IR/CIROps.td | 48 ++++
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 22 ++
.../Transforms/CallConvLoweringPass.cpp | 37 ++-
.../TargetLowering/CIRABIRewriteContext.cpp | 268 ++++++++++++++++--
.../TargetLowering/CIRABIRewriteContext.h | 29 +-
.../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 39 +++
clang/test/CIR/IR/reinterpret-cast.cir | 28 ++
.../abi-lowering/coerce-int-to-record.cir | 59 ++++
.../abi-lowering/coerce-record-to-int.cir | 50 ++++
.../coerce-record-to-record-via-memory.cir | 34 +++
.../coerce-vector-to-complex-reinterpret.cir | 42 +++
11 files changed, 619 insertions(+), 37 deletions(-)
create mode 100644 clang/test/CIR/IR/reinterpret-cast.cir
create mode 100644 clang/test/CIR/Transforms/abi-lowering/coerce-int-to-record.cir
create mode 100644 clang/test/CIR/Transforms/abi-lowering/coerce-record-to-int.cir
create mode 100644 clang/test/CIR/Transforms/abi-lowering/coerce-record-to-record-via-memory.cir
create mode 100644 clang/test/CIR/Transforms/abi-lowering/coerce-vector-to-complex-reinterpret.cir
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 97d623ba5e6d9..1f0cb759864f8 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -266,6 +266,54 @@ def CIR_CastOp : CIR_Op<"cast", [
}];
}
+//===----------------------------------------------------------------------===//
+// ReinterpretCastOp
+//===----------------------------------------------------------------------===//
+
+def CIR_ReinterpretCastOp : CIR_Op<"reinterpret_cast", [Pure]> {
+ let summary = "Reinterpret a value as a different same-bit-width type";
+ let description = [{
+ The `cir.reinterpret_cast` operation reinterprets the bits of its source
+ value as a different type, with no IR-level cost. It is used by the
+ calling-convention lowering pass to coerce between same-bit-width types
+ that have an LLVM-IR-level shape mismatch but identical in-register
+ representation -- for example, between `!cir.vector<2 x !cir.float>` and
+ `!cir.complex<!cir.float>`, both of which lower to the same LLVM IR
+ representation but have distinct CIR types.
+
+ Unlike `cir.cast bitcast`, which is overloaded for pointer-to-pointer
+ bitcasts and several other use cases, `cir.reinterpret_cast` is reserved
+ for in-register value reinterpretation only. The result type should
+ differ from the source type; a same-type cast is accepted by the
+ verifier but is meaningless, and the folder removes it.
+
+ **Invariant** (not currently enforced by the verifier): the source and
+ destination types must have the same bit width per the module's
+ DataLayout, and they must use the same in-register lane order on the
+ target. Producers (e.g. CallConvLowering's coerce-in-registers path)
+ are responsible for ensuring this; a follow-up patch will move the
+ bit-width check into the verifier once the design question of
+ DataLayout-aware op verifiers is resolved.
+
+ Example:
+
+ ```
+ %c = cir.reinterpret_cast %v
+ : !cir.vector<2 x !cir.float> -> !cir.complex<!cir.float>
+ ```
+ }];
+
+ let arguments = (ins CIR_AnyType:$src);
+ let results = (outs CIR_AnyType:$result);
+
+ let assemblyFormat = [{
+ $src `:` type($src) `->` type($result) attr-dict
+ }];
+
+ let hasVerifier = 1;
+ let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// DynamicCastOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 7386819d8fce9..80c3a3ecaea4a 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -821,6 +821,28 @@ static Value tryFoldCastChain(cir::CastOp op) {
return {};
}
+//===----------------------------------------------------------------------===//
+// ReinterpretCastOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult cir::ReinterpretCastOp::verify() {
+ // The op is meaningless for identical types -- the folder is the right
+ // way to remove it -- but we accept it at the verifier level so that
+ // peephole code (e.g. pattern rewriters that round-trip values) doesn't
+ // need a type-equality guard. Producers should still avoid emitting
+ // it for matching types.
+ //
+ // The same-bit-width invariant is documented on the op but not yet
+ // checked here; see the op description for the rationale.
+ return success();
+}
+
+OpFoldResult cir::ReinterpretCastOp::fold(FoldAdaptor adaptor) {
+ if (getSrc().getType() == getType())
+ return getSrc();
+ return {};
+}
+
OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
if (mlir::isa_and_present<cir::PoisonAttr>(adaptor.getSrc())) {
// Propagate poison value
diff --git a/clang/lib/CIR/Dialect/Transforms/CallConvLoweringPass.cpp b/clang/lib/CIR/Dialect/Transforms/CallConvLoweringPass.cpp
index e50aeca1791e9..5131bcfa4316a 100644
--- a/clang/lib/CIR/Dialect/Transforms/CallConvLoweringPass.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CallConvLoweringPass.cpp
@@ -119,7 +119,7 @@ void CallConvLoweringPass::runOnOperation() {
}
DataLayout dl(module);
- CIRABIRewriteContext rewriteCtx(module);
+ CIRABIRewriteContext rewriteCtx(module, dl);
// Pre-compute classifications for every cir.func so that call-site
// rewriting can find them (call site uses callee's classification).
@@ -140,9 +140,33 @@ void CallConvLoweringPass::runOnOperation() {
OpBuilder rewriter(ctx);
- // Rewrite call sites first, while functions still have their original
- // signatures. This avoids any chance of us reading a partially-rewritten
- // signature and matching args against the wrong classification.
+ // Three-phase rewrite. Each phase needs the previous one to be complete
+ // across every function before it can run, so they're three separate
+ // sweeps over the module:
+ //
+ // 1. rewriteFunctionDefinition: in-body coercion only. Block-arg
+ // types for Direct-with-coerce / Extend args change here, and
+ // replaceAllUsesExcept routes existing uses (including in-body
+ // cir.call operands) over to the adapted (original-type) value.
+ // Ignore handling and signature finalization are deferred.
+ // 2. rewriteCallSite: each call site coerces args, drops Ignore'd
+ // args, and swaps the call to the lowered signature. Now Ignore'd
+ // block args have no remaining uses.
+ // 3. finalizeFunctionDefinition: erase the now-use-empty Ignore'd
+ // block args, drop Ignore'd return operands, finalize the function
+ // signature, and attach the llvm.signext / llvm.zeroext attrs.
+ //
+ // Splitting (1) and (3) avoids synthesizing dead alloca/load chains
+ // for Ignore'd args whose uses were going to be dropped by (2) anyway.
+
+ for (auto &kv : classifications) {
+ if (failed(rewriteCtx.rewriteFunctionDefinition(kv.first, kv.second,
+ rewriter))) {
+ signalPassFailure();
+ return;
+ }
+ }
+
SmallVector<cir::CallOp> calls;
module.walk([&](cir::CallOp c) { calls.push_back(c); });
for (cir::CallOp call : calls) {
@@ -158,10 +182,9 @@ void CallConvLoweringPass::runOnOperation() {
}
}
- // Now rewrite each function definition.
for (auto &kv : classifications) {
- if (failed(rewriteCtx.rewriteFunctionDefinition(kv.first, kv.second,
- rewriter))) {
+ if (failed(rewriteCtx.finalizeFunctionDefinition(kv.first, kv.second,
+ rewriter))) {
signalPassFailure();
return;
}
diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
index db0a9bcfcbb39..18937d694260b 100644
--- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
@@ -46,12 +46,11 @@ LogicalResult buildNewArgTypes(ArrayRef<Type> oldArgTypes,
Type origTy = oldArgTypes[idx];
switch (ac.kind) {
case ArgKind::Direct:
- if (ac.coercedType) {
- emitError() << "Direct with coerced type at arg " << idx
- << " not yet implemented in CallConvLowering";
- return failure();
- }
- newArgTypes.push_back(origTy);
+ // Direct with a coerced type means the wire signature uses the
+ // coerced type; the body still expects origTy and we'll insert a
+ // reinterpret/coercion at the entry block. Direct without a
+ // coerced type is a true pass-through.
+ newArgTypes.push_back(ac.coercedType ? ac.coercedType : origTy);
break;
case ArgKind::Ignore:
break;
@@ -81,12 +80,9 @@ Type computeNewReturnType(Type origRetTy, const ArgClassification &retInfo,
function_ref<InFlightDiagnostic()> emitError) {
switch (retInfo.kind) {
case ArgKind::Direct:
- if (retInfo.coercedType) {
- emitError() << "Direct return with coerced type not yet implemented "
- << "in CallConvLowering";
- return nullptr;
- }
- return origRetTy;
+ // Direct return with a coerced type uses the coerced type on the wire;
+ // the rewriter inserts a coercion before each cir.return.
+ return retInfo.coercedType ? retInfo.coercedType : origRetTy;
case ArgKind::Ignore:
return cir::VoidType::get(ctx);
case ArgKind::Expand:
@@ -165,11 +161,212 @@ ArrayAttr updateResAttrs(MLIRContext *ctx, ArrayAttr existingResAttrs,
return ArrayAttr::get(ctx, {DictionaryAttr::get(ctx, attrs)});
}
+/// Coerce \p src to type \p dstTy at the current builder insertion point.
+///
+/// Three strategies, in order of preference:
+/// - If src and dst are the same type, return src unchanged and leave
+/// \p createdOps empty.
+/// - If neither is a record, both have the same bit width, and exactly one
+/// of them is a vector (e.g. !cir.vector<2 x !cir.float> ↔
+/// !cir.complex<!cir.float>), use cir.reinterpret_cast, which is free at
+/// the IR level.
+/// - Otherwise go through memory: allocate a slot of the source type
+/// (using max(srcAlign, dstAlign) for the alloca alignment), store
+/// the source, bitcast the pointer to the destination type, load the
+/// destination type back.
+///
+/// The temporary alloca is placed at the start of the enclosing function's
+/// entry block so that it composes correctly with the HoistAllocas pass
+/// regardless of pipeline ordering.
+///
+/// Any operations the helper creates are appended to \p createdOps so the
+/// caller can pass them to replaceAllUsesExcept and avoid clobbering the
+/// store's value operand when later rewiring the source value.
+Value emitCoercion(OpBuilder &rewriter, Location loc, Type dstTy, Value src,
+ FunctionOpInterface funcOp, const DataLayout &dl,
+ SmallPtrSetImpl<Operation *> &createdOps) {
+ Type srcTy = src.getType();
+ if (srcTy == dstTy)
+ return src;
+
+ // Reinterpret path: same total bit width, neither side is a record, and
+ // the shapes differ only in vector-vs-non-vector. Going through memory
+ // is wasteful for these — they have the same in-register representation.
+ bool isAggregate = isa<cir::RecordType>(srcTy) || isa<cir::RecordType>(dstTy);
+ bool vectorMismatch =
+ isa<cir::VectorType>(srcTy) != isa<cir::VectorType>(dstTy);
+ if (!isAggregate && vectorMismatch &&
+ dl.getTypeSizeInBits(srcTy) == dl.getTypeSizeInBits(dstTy)) {
+ auto reinterpret =
+ cir::ReinterpretCastOp::create(rewriter, loc, dstTy, src);
+ createdOps.insert(reinterpret);
+ return reinterpret;
+ }
+
+ // Memory path: alloca + store + ptr-cast + load. The alloca goes in the
+ // entry block (Andy's review comment #2 on the original PR), with
+ // alignment = max(srcAlign, dstAlign) to satisfy both the store and the
+ // load (review comment #1).
+ uint64_t srcAlign = dl.getTypeABIAlignment(srcTy);
+ uint64_t dstAlign = dl.getTypeABIAlignment(dstTy);
+ uint64_t allocaAlign = std::max(srcAlign, dstAlign);
+
+ auto srcPtrTy = cir::PointerType::get(srcTy);
+ auto dstPtrTy = cir::PointerType::get(dstTy);
+
+ cir::AllocaOp alloca;
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ Block &entry = funcOp->getRegion(0).front();
+ rewriter.setInsertionPointToStart(&entry);
+ alloca = cir::AllocaOp::create(rewriter, loc, srcPtrTy, srcTy,
+ rewriter.getStringAttr("coerce"),
+ rewriter.getI64IntegerAttr(allocaAlign));
+ }
+ createdOps.insert(alloca);
+
+ auto store = cir::StoreOp::create(rewriter, loc, src, alloca,
+ /*isVolatile=*/UnitAttr(),
+ /*alignment=*/IntegerAttr(),
+ /*sync_scope=*/cir::SyncScopeKindAttr(),
+ /*mem_order=*/cir::MemOrderAttr());
+ createdOps.insert(store);
+
+ auto ptrCast = cir::CastOp::create(rewriter, loc, dstPtrTy,
+ cir::CastKind::bitcast, alloca);
+ createdOps.insert(ptrCast);
+
+ auto load = cir::LoadOp::create(rewriter, loc, dstTy, ptrCast,
+ /*isDeref=*/UnitAttr(),
+ /*isVolatile=*/UnitAttr(),
+ /*alignment=*/IntegerAttr(),
+ /*sync_scope=*/cir::SyncScopeKindAttr(),
+ /*mem_order=*/cir::MemOrderAttr());
+ createdOps.insert(load);
+ return load;
+}
+
+/// Convenience overload for callers that don't need the createdOps set
+/// (e.g. call-site coercion where we don't replaceAllUsesExcept).
+Value emitCoercion(OpBuilder &rewriter, Location loc, Type dstTy, Value src,
+ FunctionOpInterface funcOp, const DataLayout &dl) {
+ SmallPtrSet<Operation *, 4> ignored;
+ return emitCoercion(rewriter, loc, dstTy, src, funcOp, dl, ignored);
+}
+
+/// Insert coercion before each cir.return so the returned value matches the
+/// new (coerced) return type.
+void insertReturnCoercion(FunctionOpInterface funcOp, Type origRetTy,
+ Type coercedRetTy, OpBuilder &rewriter,
+ const DataLayout &dl) {
+ SmallVector<cir::ReturnOp> returns;
+ funcOp.walk([&](cir::ReturnOp r) { returns.push_back(r); });
+ for (cir::ReturnOp r : returns) {
+ if (r.getInput().empty())
+ continue;
+ Value origVal = r.getInput()[0];
+ if (origVal.getType() == coercedRetTy)
+ continue;
+ rewriter.setInsertionPoint(r);
+ Value coerced =
+ emitCoercion(rewriter, r.getLoc(), coercedRetTy, origVal, funcOp, dl);
+ r->setOperand(0, coerced);
+ }
+}
+
+/// For each Direct arg with a coerced type, change the block argument's type
+/// to the coerced type and insert a coercion at function entry that maps it
+/// back to the original type for body uses.
+void insertArgCoercion(FunctionOpInterface funcOp,
+ const FunctionClassification &fc, OpBuilder &rewriter,
+ const DataLayout &dl) {
+ Region &body = funcOp->getRegion(0);
+ if (body.empty())
+ return;
+ Block &entry = body.front();
+
+ for (auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
+ if (ac.kind != ArgKind::Direct || !ac.coercedType)
+ continue;
+ if (idx >= entry.getNumArguments())
+ continue;
+
+ BlockArgument blockArg = entry.getArgument(idx);
+ Type oldArgTy = blockArg.getType();
+ Type newArgTy = ac.coercedType;
+ if (oldArgTy == newArgTy)
+ continue;
+
+ blockArg.setType(newArgTy);
+
+ rewriter.setInsertionPointToStart(&entry);
+ SmallPtrSet<Operation *, 4> coercionOps;
+ Value adapted = emitCoercion(rewriter, funcOp.getLoc(), oldArgTy, blockArg,
+ funcOp, dl, coercionOps);
+
+ // Replace blockArg uses with the adapted value, except inside the helper
+ // ops we just created. This is critical: the StoreOp's value operand is
+ // blockArg, and if we naively replaceAllUses it gets swapped to adapted
+ // (now of the original type != the alloca's pointee type).
+ blockArg.replaceAllUsesExcept(adapted, coercionOps);
+ }
+}
+
} // namespace
LogicalResult CIRABIRewriteContext::rewriteFunctionDefinition(
FunctionOpInterface funcOp, const FunctionClassification &fc,
OpBuilder &rewriter) {
+ // Phase 1: in-body coercion only. Argument-list / return-list shrinking
+ // for Ignore'd entries and final signature update happen in
+ // finalizeFunctionDefinition AFTER all call sites are rewritten, so that
+ // call sites have already stopped passing the to-be-dropped arguments
+ // (otherwise we would synthesize dead alloca/load chains for soon-to-be-
+ // dropped uses).
+ if (!needsRewrite(fc))
+ return success();
+
+ // We still need to detect the no-coercion path early so we don't error on
+ // unimplemented kinds when there's nothing for us to do. If
+ // computeNewReturnType emits a diagnostic for an unsupported kind, fail.
+ ArrayRef<Type> oldArgTypes = funcOp.getArgumentTypes();
+ ArrayRef<Type> oldResultTypes = funcOp.getResultTypes();
+ MLIRContext *ctx = funcOp->getContext();
+
+ // Validate arg classifications (errors out for unimplemented kinds like
+ // Indirect; the resulting newArgTypes is unused in this phase).
+ SmallVector<Type> newArgTypes;
+ if (failed(buildNewArgTypes(oldArgTypes, fc, newArgTypes,
+ [&]() { return funcOp.emitOpError(); })))
+ return failure();
+
+ Type voidTy = cir::VoidType::get(ctx);
+ Type origRetTy = oldResultTypes.empty() ? voidTy : oldResultTypes[0];
+ Type newRetTy = computeNewReturnType(origRetTy, fc.returnInfo, ctx,
+ [&]() { return funcOp.emitOpError(); });
+ if (!newRetTy)
+ return failure();
+
+ if (!funcOp.isDeclaration()) {
+ Region &body = funcOp->getRegion(0);
+ if (!body.empty())
+ insertArgCoercion(funcOp, fc, rewriter, dl);
+
+ // Direct return with coerced type: insert coercion at every cir.return
+ // so the returned value matches the new return type. Done in phase 1
+ // because the operand swap doesn't depend on call sites.
+ if (fc.returnInfo.kind == ArgKind::Direct && fc.returnInfo.coercedType &&
+ !oldResultTypes.empty() && fc.returnInfo.coercedType != origRetTy)
+ insertReturnCoercion(funcOp, origRetTy, fc.returnInfo.coercedType,
+ rewriter, dl);
+ }
+
+ return success();
+}
+
+LogicalResult CIRABIRewriteContext::finalizeFunctionDefinition(
+ FunctionOpInterface funcOp, const FunctionClassification &fc,
+ OpBuilder &rewriter) {
if (!needsRewrite(fc))
return success();
@@ -194,7 +391,6 @@ LogicalResult CIRABIRewriteContext::rewriteFunctionDefinition(
Region &body = funcOp->getRegion(0);
if (!body.empty()) {
Block &entry = body.front();
-
SmallVector<unsigned> ignored = ignoredArgIndices(fc);
for (int i = static_cast<int>(ignored.size()) - 1; i >= 0; --i) {
unsigned blockIdx = ignored[i];
@@ -202,6 +398,12 @@ LogicalResult CIRABIRewriteContext::rewriteFunctionDefinition(
continue;
BlockArgument arg = entry.getArgument(blockIdx);
if (!arg.use_empty()) {
+ // Defensive: any non-call use of an Ignore'd arg gets a stub
+ // alloca/load chain so the IR stays well-typed. In practice
+ // call sites should be the only users (cir-call-conv-lowering
+ // is the only producer of Ignore classifications today) and
+ // they'll have already dropped the operand by the time we get
+ // here.
rewriter.setInsertionPointToStart(&entry);
auto ptrTy = cir::PointerType::get(arg.getType());
auto alloca = cir::AllocaOp::create(
@@ -263,20 +465,16 @@ LogicalResult CIRABIRewriteContext::rewriteCallSite(
auto call = cast<cir::CallOp>(callOp);
MLIRContext *ctx = callOp->getContext();
+ auto enclosingFunc = call->getParentOfType<FunctionOpInterface>();
for (auto [idx, ac] : llvm::enumerate(fc.argInfos)) {
switch (ac.kind) {
case ArgKind::Direct:
- if (ac.coercedType)
- return call.emitOpError()
- << "Direct with coerced type at call-site arg " << idx
- << " not yet implemented in CallConvLowering";
- break;
case ArgKind::Ignore:
case ArgKind::Expand:
case ArgKind::Extend:
- // Extend at the call site is just an attribute change (llvm.signext /
- // llvm.zeroext on the call's arg_attrs); no IR-level cast.
+ // Direct (with or without coercion), Ignore, Expand, and Extend are
+ // all handled below. Extend is attribute-only at the IR level.
break;
case ArgKind::Indirect:
return call.emitOpError() << "Indirect at call-site arg " << idx
@@ -284,6 +482,8 @@ LogicalResult CIRABIRewriteContext::rewriteCallSite(
}
}
+ rewriter.setInsertionPoint(call);
+
SmallVector<Value> newArgs;
ValueRange argOperands = call.getArgOperands();
newArgs.reserve(argOperands.size());
@@ -292,7 +492,12 @@ LogicalResult CIRABIRewriteContext::rewriteCallSite(
break;
if (ac.kind == ArgKind::Ignore)
continue;
- newArgs.push_back(argOperands[idx]);
+ Value arg = argOperands[idx];
+ if (ac.kind == ArgKind::Direct && ac.coercedType &&
+ arg.getType() != ac.coercedType)
+ arg = emitCoercion(rewriter, call.getLoc(), ac.coercedType, arg,
+ enclosingFunc, dl);
+ newArgs.push_back(arg);
}
for (unsigned i = fc.argInfos.size(); i < argOperands.size(); ++i)
newArgs.push_back(argOperands[i]);
@@ -303,10 +508,11 @@ LogicalResult CIRABIRewriteContext::rewriteCallSite(
Type callRetTy = origRetTy;
if (fc.returnInfo.kind == ArgKind::Ignore && hasResult)
callRetTy = cir::VoidType::get(ctx);
- if (fc.returnInfo.kind == ArgKind::Direct && fc.returnInfo.coercedType)
- return call.emitOpError() << "Direct return with coerced type at "
- << "call-site not yet implemented in "
- << "CallConvLowering";
+ bool returnNeedsCoercion =
+ hasResult && fc.returnInfo.kind == ArgKind::Direct &&
+ fc.returnInfo.coercedType && fc.returnInfo.coercedType != origRetTy;
+ if (returnNeedsCoercion)
+ callRetTy = fc.returnInfo.coercedType;
rewriter.setInsertionPoint(call);
auto newCall = cir::CallOp::create(rewriter, call.getLoc(),
@@ -315,6 +521,15 @@ LogicalResult CIRABIRewriteContext::rewriteCallSite(
if (!newCall->hasAttr(attr.getName()))
newCall->setAttr(attr.getName(), attr.getValue());
+ // Direct return with coercion: the new call returns the coerced type;
+ // emit a coercion back to the original type for the call's existing uses.
+ if (returnNeedsCoercion) {
+ rewriter.setInsertionPointAfter(newCall);
+ Value coercedBack = emitCoercion(rewriter, call.getLoc(), origRetTy,
+ newCall.getResult(), enclosingFunc, dl);
+ call.getResult().replaceAllUsesWith(coercedBack);
+ }
+
// Layer llvm.signext / llvm.zeroext onto the new call's arg_attrs and
// res_attrs for Extend args/return.
bool needsArgAttrUpdate = false;
@@ -345,7 +560,8 @@ LogicalResult CIRABIRewriteContext::rewriteCallSite(
IntegerAttr(), cir::SyncScopeKindAttr(), cir::MemOrderAttr());
call.getResult().replaceAllUsesWith(load);
}
- } else if (hasResult) {
+ } else if (hasResult && !returnNeedsCoercion) {
+ // returnNeedsCoercion already wired up the coerced result above.
call.getResult().replaceAllUsesWith(newCall.getResult());
}
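The strategy selection in emitCoercion can be sketched outside of MLIR as a small decision function. This is a minimal standalone model, not the patch's code: ToyType, Shape, Strategy, pickStrategy, and slotAlignment are hypothetical names, and sizes and alignments are supplied by hand instead of being queried from a DataLayout.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a CIR type: only the properties the helper
// inspects. `id` models type identity (two types are equal iff ids match).
enum class Shape { Scalar, Complex, Vector, Record };

struct ToyType {
  int id;
  Shape shape;
  std::uint64_t sizeInBits;
  std::uint64_t abiAlign; // in bytes
};

enum class Strategy { Identity, Reinterpret, Memory };

// Mirrors the three cases in the doc comment: same type, same-width
// vector-vs-non-vector reinterpret, or a roundtrip through memory.
inline Strategy pickStrategy(const ToyType &src, const ToyType &dst) {
  if (src.id == dst.id)
    return Strategy::Identity;
  bool isAggregate = src.shape == Shape::Record || dst.shape == Shape::Record;
  bool vectorMismatch =
      (src.shape == Shape::Vector) != (dst.shape == Shape::Vector);
  if (!isAggregate && vectorMismatch && src.sizeInBits == dst.sizeInBits)
    return Strategy::Reinterpret;
  return Strategy::Memory;
}

// The temporary slot must satisfy both the store (source type) and the
// load (destination type), so it takes the stricter alignment.
inline std::uint64_t slotAlignment(const ToyType &src, const ToyType &dst) {
  assert(src.abiAlign > 0 && dst.abiAlign > 0);
  return std::max(src.abiAlign, dst.abiAlign);
}
```

Under this model a vector/complex pair of equal width selects Reinterpret in either direction, while any pair involving a record, or two non-vector types, falls back to Memory.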
diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h
index cf8635e9afdd6..ba5300719a345 100644
--- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h
@@ -11,9 +11,10 @@
// rewrites a cir.func signature, the function body, and call sites to match
// the ABI-lowered shape.
//
-// This file currently handles only Direct (pass-through) and Ignore. Other
-// ArgKind handlers (Extend, Direct-with-coercion, Indirect, Expand) are
-// added by subsequent PRs in the calling-convention-lowering split series.
+// This file currently handles Direct (pass-through and coerce-in-registers),
+// Extend, and Ignore. The remaining ArgKind handlers (Indirect, Expand)
+// are added by subsequent PRs in the calling-convention-lowering split
+// series.
//
//===----------------------------------------------------------------------===//
@@ -22,6 +23,7 @@
#include "mlir/ABI/ABIRewriteContext.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
namespace cir {
@@ -31,9 +33,13 @@ namespace cir {
/// The driver pass (CallConvLoweringPass) computes a FunctionClassification
/// for each cir.func / cir.call and dispatches to this class to perform the
/// actual IR rewriting using cir dialect operations.
+///
+/// Holds a reference to the module's DataLayout for coercion alignment
+/// queries. The DataLayout must outlive the rewrite context.
class CIRABIRewriteContext : public mlir::abi::ABIRewriteContext {
public:
- explicit CIRABIRewriteContext(mlir::ModuleOp module) : module(module) {}
+ CIRABIRewriteContext(mlir::ModuleOp module, const mlir::DataLayout &dl)
+ : module(module), dl(dl) {}
mlir::LogicalResult
rewriteFunctionDefinition(mlir::FunctionOpInterface funcOp,
@@ -45,10 +51,25 @@ class CIRABIRewriteContext : public mlir::abi::ABIRewriteContext {
const mlir::abi::FunctionClassification &fc,
mlir::OpBuilder &rewriter) override;
+ /// Phase 2 of function-definition rewriting, called by the pass driver
+ /// AFTER all call sites have been rewritten. Drops Ignore'd arguments
+ /// from the block argument list (now use-empty since call sites no
+ /// longer pass them), drops Ignore'd return operands, and finalizes the
+ /// function signature with the new (possibly shorter) argument list.
+ ///
+ /// This must run after rewriteCallSite so that the body's calls have
+ /// already stopped passing the to-be-dropped arguments. Otherwise the
+ /// drops here would leave dangling uses.
+ mlir::LogicalResult
+ finalizeFunctionDefinition(mlir::FunctionOpInterface funcOp,
+ const mlir::abi::FunctionClassification &fc,
+ mlir::OpBuilder &rewriter);
+
mlir::StringRef getDialectNamespace() const override { return "cir"; }
private:
mlir::ModuleOp module;
+ const mlir::DataLayout &dl;
};
} // namespace cir
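The ordering constraint described for finalizeFunctionDefinition can be illustrated with a toy model. This is only a sketch under simplifying assumptions: ToyFunc and the helper names are hypothetical, uses of an Ignore'd argument are tracked as plain counters, and a "stub chain" stands for the defensive alloca/load the cleanup synthesizes when the argument still has users.

```cpp
#include <cassert>

// Hypothetical model of one function with a single Ignore'd argument.
struct ToyFunc {
  int callSiteUses; // uses of the argument as a call operand
  int otherUses;    // any other uses inside the body
  int stubChains;   // alloca/load stubs synthesized during cleanup
};

// Call-site rewriting drops the Ignore'd operand from every call.
inline void rewriteCallSites(ToyFunc &f) { f.callSiteUses = 0; }

// Ignore cleanup: if the argument still has users, one stub chain is
// created and all remaining uses are redirected to it.
inline void finalizeDefinition(ToyFunc &f) {
  if (f.callSiteUses + f.otherUses > 0) {
    f.stubChains += 1;
    f.callSiteUses = 0;
    f.otherUses = 0;
  }
}

// Earlier ordering: cleanup sees call-site uses that are about to vanish.
inline int stubsWhenCleanupRunsFirst(ToyFunc f) {
  finalizeDefinition(f);
  rewriteCallSites(f);
  return f.stubChains;
}

// Split ordering: call sites are rewritten before cleanup runs.
inline int stubsWhenCallSitesRunFirst(ToyFunc f) {
  rewriteCallSites(f);
  finalizeDefinition(f);
  assert(f.callSiteUses == 0);
  return f.stubChains;
}
```

With only call-site uses, the cleanup-first order creates a dead stub and the call-sites-first order creates none; a genuine non-call use still gets its stub in either order.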
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index e17c7a209db6b..e072e9ba41c1e 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1684,6 +1684,45 @@ mlir::LogicalResult CIRToLLVMReturnOpLowering::matchAndRewrite(
return mlir::LogicalResult::success();
}
+mlir::LogicalResult CIRToLLVMReinterpretCastOpLowering::matchAndRewrite(
+ cir::ReinterpretCastOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ // After type conversion, source and destination LLVM types may be:
+ // (a) Identical: trivially replace uses with the source value (the
+ // op was a CIR-level type rename only; LLVM sees no change).
+ // (b) Same scalar / vector category, same bit width: emit
+ // LLVM::BitcastOp.
+ // (c) Aggregate vs scalar / aggregate vs vector: LLVM::BitcastOp
+ // does not allow aggregate types. We currently emit an error
+ // directing the producer to go through memory. A future patch
+ // will add an extract/insert lowering for the aggregate case so
+ // the LLVM IR avoids the memory roundtrip too.
+ mlir::Type llvmDstTy = getTypeConverter()->convertType(op.getType());
+ mlir::Value llvmSrc = adaptor.getSrc();
+ mlir::Type llvmSrcTy = llvmSrc.getType();
+
+ if (llvmSrcTy == llvmDstTy) {
+ rewriter.replaceOp(op, llvmSrc);
+ return mlir::success();
+ }
+
+ bool srcIsAggregate =
+ mlir::isa<mlir::LLVM::LLVMStructType, mlir::LLVM::LLVMArrayType>(
+ llvmSrcTy);
+ bool dstIsAggregate =
+ mlir::isa<mlir::LLVM::LLVMStructType, mlir::LLVM::LLVMArrayType>(
+ llvmDstTy);
+ if (srcIsAggregate || dstIsAggregate)
+ return op.emitOpError()
+ << "lowering cir.reinterpret_cast to LLVM with aggregate type "
+ << "not yet implemented; producer should fall back to memory "
+ << "coercion until a follow-up patch adds extract/insert "
+ << "lowering";
+
+ rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(op, llvmDstTy, llvmSrc);
+ return mlir::success();
+}
+
mlir::LogicalResult CIRToLLVMRotateOpLowering::matchAndRewrite(
cir::RotateOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
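The three-way decision in the cir.reinterpret_cast lowering, cases (a) through (c) in the comment above, can be sketched as a standalone function. This is a hedged model only: ToyLlvmType, LlvmKind, Lowering, and lowerReinterpretCast are hypothetical names, and type identity is modeled with an integer id instead of real converted LLVM types.

```cpp
#include <cassert>

// Hypothetical stand-in for a converted LLVM type.
enum class LlvmKind { Scalar, Vector, Struct, Array };

struct ToyLlvmType {
  int id; // models type identity
  LlvmKind kind;
};

enum class Lowering { ForwardSource, Bitcast, Unsupported };

inline bool isAggregate(const ToyLlvmType &t) {
  return t.kind == LlvmKind::Struct || t.kind == LlvmKind::Array;
}

inline Lowering lowerReinterpretCast(const ToyLlvmType &src,
                                     const ToyLlvmType &dst) {
  // (a) Identical converted types: forward the source value unchanged.
  if (src.id == dst.id)
    return Lowering::ForwardSource;
  // (c) Any aggregate on either side: bitcast is not allowed, so the
  // lowering reports an error for now.
  if (isAggregate(src) || isAggregate(dst))
    return Lowering::Unsupported;
  // (b) Remaining scalar/vector cases: emit a bitcast.
  return Lowering::Bitcast;
}

// Small self-check of the identity case.
inline void demoLowering() {
  ToyLlvmType i64Ty{0, LlvmKind::Scalar};
  assert(lowerReinterpretCast(i64Ty, i64Ty) == Lowering::ForwardSource);
}
```

The identity check comes first so that a CIR-level rename whose two sides convert to the same aggregate type is still forwarded and does not hit the error path.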
diff --git a/clang/test/CIR/IR/reinterpret-cast.cir b/clang/test/CIR/IR/reinterpret-cast.cir
new file mode 100644
index 0000000000000..94742e15cda42
--- /dev/null
+++ b/clang/test/CIR/IR/reinterpret-cast.cir
@@ -0,0 +1,28 @@
+// RUN: cir-opt %s --verify-roundtrip | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+ // Vector ↔ complex same-bit-width reinterpret (the canonical use case
+ // from cir-call-conv-lowering's coerce-in-registers path).
+ cir.func @vec_to_complex(%v : !cir.vector<2 x !cir.float>)
+ -> !cir.complex<!cir.float> {
+ %c = cir.reinterpret_cast %v
+ : !cir.vector<2 x !cir.float> -> !cir.complex<!cir.float>
+ cir.return %c : !cir.complex<!cir.float>
+ }
+
+ // Reverse direction.
+ cir.func @complex_to_vec(%c : !cir.complex<!cir.float>)
+ -> !cir.vector<2 x !cir.float> {
+ %v = cir.reinterpret_cast %c
+ : !cir.complex<!cir.float> -> !cir.vector<2 x !cir.float>
+ cir.return %v : !cir.vector<2 x !cir.float>
+ }
+}
+
+// CHECK: cir.func{{.*}} @vec_to_complex
+// CHECK: cir.reinterpret_cast %{{.*}} : !cir.vector<2 x !cir.float> -> !cir.complex<!cir.float>
+
+// CHECK: cir.func{{.*}} @complex_to_vec
+// CHECK: cir.reinterpret_cast %{{.*}} : !cir.complex<!cir.float> -> !cir.vector<2 x !cir.float>
diff --git a/clang/test/CIR/Transforms/abi-lowering/coerce-int-to-record.cir b/clang/test/CIR/Transforms/abi-lowering/coerce-int-to-record.cir
new file mode 100644
index 0000000000000..f90427bf68b4c
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/coerce-int-to-record.cir
@@ -0,0 +1,59 @@
+// Direct return with coerced type going from a small record to a same-bit-
+// width integer. Mirror of coerce-record-to-int.cir but exercising the
+// return-side coercion code path: every cir.return gets the original
+// record value coerced to the integer type before being returned.
+// RUN: cir-opt %s -cir-call-conv-lowering="classification-attr=test_classify" \
+// RUN: | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+!s64i = !cir.int<s, 64>
+!rec_Pair = !cir.record<struct "Pair" {!s32i, !s32i}>
+
+#coerce_pair_return_to_i64 = {
+ return = { kind = "direct", coerced_type = !s64i },
+ args = [ ]
+}
+
+#all_direct_no_args = {
+ return = { kind = "direct" },
+ args = [ ]
+}
+
+module attributes {
+ dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<i32, dense<32>: vector<2xi64>>,
+ #dlti.dl_entry<i64, dense<64>: vector<2xi64>>>
+} {
+
+ cir.func @returns_pair() -> !rec_Pair
+ attributes { test_classify = #coerce_pair_return_to_i64 } {
+ %0 = cir.const #cir.zero : !rec_Pair
+ cir.return %0 : !rec_Pair
+ }
+
+ // Signature changes to !s64i return; the cir.return's record operand
+ // gets coerced via memory roundtrip before being returned. The alloca
+ // is hoisted to the entry-block start (Andy's review comment #2 from the
+ // original PR) so it sits ahead of the const that produces the value.
+ // CHECK: cir.func{{.*}} @returns_pair() -> !s64i
+ // CHECK: %[[SLOT:.*]] = cir.alloca !rec_Pair, !cir.ptr<!rec_Pair>, ["coerce"]
+ // CHECK: %[[VAL:.*]] = cir.const #cir.zero : !rec_Pair
+ // CHECK: cir.store %[[VAL]], %[[SLOT]] : !rec_Pair, !cir.ptr<!rec_Pair>
+ // CHECK: %[[CAST:.*]] = cir.cast bitcast %[[SLOT]] : !cir.ptr<!rec_Pair> -> !cir.ptr<!s64i>
+ // CHECK: %[[COERCED:.*]] = cir.load %[[CAST]] : !cir.ptr<!s64i>, !s64i
+ // CHECK: cir.return %[[COERCED]] : !s64i
+
+ cir.func @caller() -> !rec_Pair
+ attributes { test_classify = #coerce_pair_return_to_i64 } {
+ %0 = cir.call @returns_pair() : () -> !rec_Pair
+ cir.return %0 : !rec_Pair
+ }
+
+ // At the call site the lowered call returns !s64i; the rewriter coerces
+ // it back to !rec_Pair for downstream uses (the caller's own return
+ // also needs the coerce-back-then-coerce-forward chain since caller's
+ // return is also Direct-with-coerce).
+ // CHECK: cir.func{{.*}} @caller() -> !s64i
+ // CHECK: %{{.*}} = cir.call @returns_pair() : () -> !s64i
+
+}
diff --git a/clang/test/CIR/Transforms/abi-lowering/coerce-record-to-int.cir b/clang/test/CIR/Transforms/abi-lowering/coerce-record-to-int.cir
new file mode 100644
index 0000000000000..f31f09181710e
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/coerce-record-to-int.cir
@@ -0,0 +1,50 @@
+// Direct with coerced type going from a small record to a same-bit-width
+// integer. The shapes don't match (record vs scalar) so the rewriter
+// emits a memory roundtrip: alloca in the entry block + store + ptr-cast +
+// load.
+// RUN: cir-opt %s -cir-call-conv-lowering="classification-attr=test_classify" \
+// RUN: | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+!s64i = !cir.int<s, 64>
+!rec_Pair = !cir.record<struct "Pair" {!s32i, !s32i}>
+
+#coerce_pair_to_i64 = {
+ return = { kind = "direct" },
+ args = [ { kind = "direct", coerced_type = !s64i } ]
+}
+
+module attributes {
+ dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<i32, dense<32>: vector<2xi64>>,
+ #dlti.dl_entry<i64, dense<64>: vector<2xi64>>>
+} {
+
+ cir.func @takes_pair(%arg0: !rec_Pair)
+ attributes { test_classify = #coerce_pair_to_i64 } {
+ cir.return
+ }
+
+ // Signature changes to !s64i; entry block grows an alloca + store + cast
+ // + load chain that recovers the original !rec_Pair value. The alloca
+ // lands at the very start of the entry block so this composes correctly
+ // with cir-hoist-allocas regardless of pipeline ordering.
+ // CHECK: cir.func{{.*}} @takes_pair(%[[ARG:.*]]: !s64i)
+ // CHECK: %[[SLOT:.*]] = cir.alloca !s64i, !cir.ptr<!s64i>, ["coerce"]
+ // CHECK: cir.store %[[ARG]], %[[SLOT]] : !s64i, !cir.ptr<!s64i>
+ // CHECK: %[[CAST:.*]] = cir.cast bitcast %[[SLOT]] : !cir.ptr<!s64i> -> !cir.ptr<!rec_Pair>
+ // CHECK: %{{.*}} = cir.load %[[CAST]] : !cir.ptr<!rec_Pair>, !rec_Pair
+
+ cir.func @caller(%arg0: !rec_Pair)
+ attributes { test_classify = #coerce_pair_to_i64 } {
+ cir.call @takes_pair(%arg0) : (!rec_Pair) -> ()
+ cir.return
+ }
+
+ // At the call site, the original !rec_Pair gets coerced to !s64i via the
+ // same memory roundtrip before being passed. Caller's own arg coercion
+ // chain runs first (it shares the pattern), then the call.
+ // CHECK: cir.func{{.*}} @caller(%[[ARG:.*]]: !s64i)
+ // CHECK: cir.call @takes_pair(%{{.*}}) : (!s64i) -> ()
+
+}
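For readers less familiar with the memory-roundtrip idiom the test above checks, the same store / pointer-cast / load sequence has a plain C++ analogue using std::memcpy. This is a minimal sketch under assumptions: Pair, coercePairToI64, and coerceI64ToPair are illustrative names, and the two alignas specifiers stand in for the max(srcAlign, dstAlign) slot alignment.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

struct Pair {
  std::int32_t first;
  std::int32_t second;
};
static_assert(sizeof(Pair) == sizeof(std::int64_t),
              "sketch assumes Pair is exactly 64 bits wide");

// Caller side: store the record into a slot, then load it back as an
// integer of the same width.
inline std::int64_t coercePairToI64(Pair p) {
  // Two alignas specifiers select the stricter of the two alignments.
  alignas(Pair) alignas(std::int64_t) unsigned char slot[sizeof(Pair)];
  std::memcpy(slot, &p, sizeof(p));
  std::int64_t out;
  std::memcpy(&out, slot, sizeof(out));
  return out;
}

// Callee side: recover the original record from the coerced integer.
inline Pair coerceI64ToPair(std::int64_t v) {
  alignas(Pair) alignas(std::int64_t) unsigned char slot[sizeof(v)];
  std::memcpy(slot, &v, sizeof(v));
  Pair out;
  std::memcpy(&out, slot, sizeof(out));
  return out;
}

// Coerce-forward followed by coerce-back preserves both fields.
inline void demoRoundTrip() {
  Pair p{7, -3};
  Pair q = coerceI64ToPair(coercePairToI64(p));
  assert(q.first == 7 && q.second == -3);
}
```

The intermediate integer value depends on the target's byte order, so the sketch only relies on the round trip and not on any particular bit pattern.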
diff --git a/clang/test/CIR/Transforms/abi-lowering/coerce-record-to-record-via-memory.cir b/clang/test/CIR/Transforms/abi-lowering/coerce-record-to-record-via-memory.cir
new file mode 100644
index 0000000000000..1669bf1232d28
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/coerce-record-to-record-via-memory.cir
@@ -0,0 +1,34 @@
+// Direct with a coerced type that's a different record (record-to-record):
+// neither side is a vector and both sides are records, so the rewriter
+// uses the memory-roundtrip path rather than cir.reinterpret_cast.
+// RUN: cir-opt %s -cir-call-conv-lowering="classification-attr=test_classify" \
+// RUN: | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+!s64i = !cir.int<s, 64>
+!rec_Pair = !cir.record<struct "Pair" {!s32i, !s32i}>
+!rec_Single = !cir.record<struct "Single" {!s64i}>
+
+#coerce_pair_to_single = {
+ return = { kind = "direct" },
+ args = [ { kind = "direct", coerced_type = !rec_Single } ]
+}
+
+module attributes {
+ dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<i32, dense<32>: vector<2xi64>>,
+ #dlti.dl_entry<i64, dense<64>: vector<2xi64>>>
+} {
+
+ cir.func @takes_pair(%arg0: !rec_Pair)
+ attributes { test_classify = #coerce_pair_to_single } {
+ cir.return
+ }
+
+ // CHECK: cir.func{{.*}} @takes_pair(%[[ARG:.*]]: !rec_Single)
+ // CHECK: %[[SLOT:.*]] = cir.alloca !rec_Single, !cir.ptr<!rec_Single>, ["coerce"]
+ // CHECK: cir.store %[[ARG]], %[[SLOT]] : !rec_Single, !cir.ptr<!rec_Single>
+ // CHECK: %[[CAST:.*]] = cir.cast bitcast %[[SLOT]] : !cir.ptr<!rec_Single> -> !cir.ptr<!rec_Pair>
+ // CHECK: %{{.*}} = cir.load %[[CAST]] : !cir.ptr<!rec_Pair>, !rec_Pair
+
+}
diff --git a/clang/test/CIR/Transforms/abi-lowering/coerce-vector-to-complex-reinterpret.cir b/clang/test/CIR/Transforms/abi-lowering/coerce-vector-to-complex-reinterpret.cir
new file mode 100644
index 0000000000000..ceb1f9e364466
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/coerce-vector-to-complex-reinterpret.cir
@@ -0,0 +1,42 @@
+// Direct with coerced type that differs from the original only in
+// vector-vs-non-vector shape (same total bit width, neither side a record):
+// the rewriter emits cir.reinterpret_cast instead of going through memory.
+// RUN: cir-opt %s -cir-call-conv-lowering="classification-attr=test_classify" \
+// RUN: | FileCheck %s
+
+#coerce_complex_to_vec2 = {
+ return = { kind = "direct" },
+ args = [ { kind = "direct",
+ coerced_type = !cir.vector<2 x !cir.float> } ]
+}
+
+module attributes {
+ dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<f32, dense<32>: vector<2xi64>>>
+} {
+
+ cir.func @takes_complex(%arg0: !cir.complex<!cir.float>)
+ attributes { test_classify = #coerce_complex_to_vec2 } {
+ cir.return
+ }
+
+ // The signature changes to the coerced (vector) type; the body still
+ // expects the complex, so a reinterpret_cast lands at function entry to
+ // adapt the new block argument back to the original type.
+ // CHECK: cir.func{{.*}} @takes_complex(%[[ARG:.*]]: !cir.vector<2 x !cir.float>)
+ // CHECK: %{{.*}} = cir.reinterpret_cast %[[ARG]] : !cir.vector<2 x !cir.float> -> !cir.complex<!cir.float>
+
+ cir.func @caller(%arg0: !cir.complex<!cir.float>)
+ attributes { test_classify = #coerce_complex_to_vec2 } {
+ cir.call @takes_complex(%arg0) : (!cir.complex<!cir.float>) -> ()
+ cir.return
+ }
+
+ // At the call site the rewriter coerces the original (complex) value to
+ // the vector type before passing it through.
+ // CHECK: cir.func{{.*}} @caller(%[[ARG:.*]]: !cir.vector<2 x !cir.float>)
+ // CHECK: %[[COMPLEX:.*]] = cir.reinterpret_cast %[[ARG]] : !cir.vector<2 x !cir.float> -> !cir.complex<!cir.float>
+ // CHECK: %[[COERCED:.*]] = cir.reinterpret_cast %[[COMPLEX]] : !cir.complex<!cir.float> -> !cir.vector<2 x !cir.float>
+ // CHECK: cir.call @takes_complex(%[[COERCED]]) : (!cir.vector<2 x !cir.float>) -> ()
+
+}