[Mlir-commits] [clang] [mlir] [CIR] Add CIRABIRewriteContext for ABI function/call rewriting (PR #192119)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 14 12:55:07 PDT 2026
https://github.com/adams381 created https://github.com/llvm/llvm-project/pull/192119
This PR adds `CIRABIRewriteContext`, a CIR-specific concrete subclass of the `ABIRewriteContext` interface introduced in #190661. It rewrites CIR FuncOps and CallOps to match ABI-lowered signatures.
This first PR handles the scalar cases:
- Direct passthrough and scalar coercion (bitcast)
- Extend (integer widening with signext/zeroext attrs)
- Ignore (void returns, empty-struct arg erasure)
- Call-site rewrites for all of the above
Struct coercion (sret, byval, multi-register flattening) comes next.
The PR adds 11 C++ unit tests; each constructs a `FunctionClassification` by hand and verifies the rewritten IR, so the tests do not depend on an ABI classifier.
Depends on #190661.
Made with [Cursor](https://cursor.com)
>From 23ef4674ee8ac04d7777d5442071a781b8a09b7b Mon Sep 17 00:00:00 2001
From: Adam Smith <adams at nvidia.com>
Date: Mon, 6 Apr 2026 12:14:06 -0700
Subject: [PATCH 1/3] [mlir][ABI] Add ABITypeMapper and ABIRewriteContext
Dialect-agnostic layer bridging MLIR and the LLVM ABI library.
ABITypeMapper handles built-in MLIR types; ABIRewriteContext is
the interface dialects implement for ABI rewrites (see
clang/docs/ClangIRABILowering.md Section 4).
Made-with: Cursor
---
mlir/include/mlir/ABI/ABIRewriteContext.h | 159 +++++++++++++++++
mlir/include/mlir/ABI/ABITypeMapper.h | 66 +++++++
mlir/lib/ABI/ABITypeMapper.cpp | 102 +++++++++++
mlir/lib/ABI/CMakeLists.txt | 14 ++
mlir/lib/CMakeLists.txt | 1 +
mlir/unittests/ABI/ABIRewriteContextTest.cpp | 99 +++++++++++
mlir/unittests/ABI/ABITypeMapperTest.cpp | 173 +++++++++++++++++++
mlir/unittests/ABI/CMakeLists.txt | 12 ++
mlir/unittests/CMakeLists.txt | 1 +
9 files changed, 627 insertions(+)
create mode 100644 mlir/include/mlir/ABI/ABIRewriteContext.h
create mode 100644 mlir/include/mlir/ABI/ABITypeMapper.h
create mode 100644 mlir/lib/ABI/ABITypeMapper.cpp
create mode 100644 mlir/lib/ABI/CMakeLists.txt
create mode 100644 mlir/unittests/ABI/ABIRewriteContextTest.cpp
create mode 100644 mlir/unittests/ABI/ABITypeMapperTest.cpp
create mode 100644 mlir/unittests/ABI/CMakeLists.txt
diff --git a/mlir/include/mlir/ABI/ABIRewriteContext.h b/mlir/include/mlir/ABI/ABIRewriteContext.h
new file mode 100644
index 0000000000000..71d5a56b599b7
--- /dev/null
+++ b/mlir/include/mlir/ABI/ABIRewriteContext.h
@@ -0,0 +1,159 @@
+//===- ABIRewriteContext.h - Dialect-specific ABI rewriting -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines ABIRewriteContext, the abstract interface for dialect-
+// specific ABI lowering rewrites. Each MLIR dialect that wants ABI lowering
+// (CIR, FIR, etc.) provides a concrete subclass.
+//
+// ABIRewriteContext consumes ABI classification results and drives the
+// creation of lowered function signatures, argument coercions, and call
+// site rewrites using dialect-specific operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ABI_ABIREWRITECONTEXT_H
+#define MLIR_ABI_ABIREWRITECONTEXT_H
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "llvm/Support/Alignment.h"
+
+namespace mlir {
+namespace abi {
+
+/// Classification of how a single argument or return value should be
+/// passed at the ABI level.
+///
+/// This is a dialect-agnostic representation. It mirrors the kinds
+/// found in the LLVM ABI library and in CIR's ABIArgInfo, but does
+/// not depend on either.
+enum class ArgKind : uint8_t {
+ /// Pass directly in registers, possibly coerced to a different type.
+ Direct,
+
+ /// Like Direct, but with a sign/zero extension attribute.
+ Extend,
+
+ /// Pass indirectly via a pointer (sret for returns, byval for args).
+ Indirect,
+
+ /// Ignore (void return, empty struct).
+ Ignore,
+
+ /// Expand an aggregate into its constituent scalar fields.
+ Expand,
+};
+
+/// Describes how a single argument or return value is passed after ABI
+/// lowering.
+struct ArgClassification {
+ ArgKind Kind = ArgKind::Direct;
+
+ /// The ABI-coerced type, if different from the original. Null means
+ /// use the original type.
+ Type CoercedType = nullptr;
+
+ /// For Indirect: alignment of the pointed-to object.
+ llvm::Align IndirectAlign = llvm::Align(1);
+
+ /// For Extend: whether to sign-extend (true) or zero-extend (false).
+ bool SignExtend = false;
+
+ /// For Direct: whether a struct coercion can be flattened into
+ /// individual register-width arguments.
+ bool CanFlatten = true;
+
+ /// For Indirect: whether the callee gets ownership (byval).
+ bool ByVal = false;
+
+ static ArgClassification getDirect(Type coerced = nullptr) {
+ ArgClassification c;
+ c.Kind = ArgKind::Direct;
+ c.CoercedType = coerced;
+ return c;
+ }
+
+ static ArgClassification getIgnore() {
+ ArgClassification c;
+ c.Kind = ArgKind::Ignore;
+ return c;
+ }
+
+ static ArgClassification getIndirect(llvm::Align align, bool byVal = true) {
+ ArgClassification c;
+ c.Kind = ArgKind::Indirect;
+ c.IndirectAlign = align;
+ c.ByVal = byVal;
+ return c;
+ }
+
+ static ArgClassification getExtend(Type coerced, bool signExt) {
+ ArgClassification c;
+ c.Kind = ArgKind::Extend;
+ c.CoercedType = coerced;
+ c.SignExtend = signExt;
+ return c;
+ }
+};
+
+/// Holds the full ABI classification for a function: return type and
+/// all arguments.
+struct FunctionClassification {
+ ArgClassification ReturnInfo;
+ SmallVector<ArgClassification> ArgInfos;
+};
+
+/// ABIRewriteContext is the abstract interface that each dialect
+/// implements to perform ABI-specific rewrites on its operations.
+///
+/// The pass orchestrator calls these methods after ABI classification
+/// to rewrite function definitions and call sites.
+class ABIRewriteContext {
+public:
+ virtual ~ABIRewriteContext() = default;
+
+ /// Rewrite a function definition to use ABI-lowered types.
+ ///
+ /// This creates a new function with the lowered signature, rewrites
+ /// the function body to adapt between the ABI types and the
+ /// original high-level types, and replaces the original function.
+ ///
+ /// \param funcOp The function to rewrite (via FunctionOpInterface).
+ /// \param fc The ABI classification for this function.
+ /// \param rewriter The OpBuilder to use for modifications.
+ /// \returns success() if the function was rewritten.
+ virtual LogicalResult
+ rewriteFunctionDefinition(FunctionOpInterface funcOp,
+ const FunctionClassification &fc,
+ OpBuilder &rewriter) = 0;
+
+ /// Rewrite a call operation to match the callee's ABI-lowered
+ /// signature.
+ ///
+ /// This coerces arguments, handles indirect returns (sret), and
+ /// adapts the call result back to the original high-level type.
+ ///
+ /// \param callOp The call operation to rewrite.
+ /// \param fc The ABI classification for the callee.
+ /// \param rewriter The OpBuilder to use for modifications.
+ /// \returns success() if the call was rewritten.
+ virtual LogicalResult rewriteCallSite(Operation *callOp,
+ const FunctionClassification &fc,
+ OpBuilder &rewriter) = 0;
+
+ /// Return the dialect namespace this context handles (e.g. "cir").
+ virtual StringRef getDialectNamespace() const = 0;
+};
+
+} // namespace abi
+} // namespace mlir
+
+#endif // MLIR_ABI_ABIREWRITECONTEXT_H
diff --git a/mlir/include/mlir/ABI/ABITypeMapper.h b/mlir/include/mlir/ABI/ABITypeMapper.h
new file mode 100644
index 0000000000000..2180c9c8a918d
--- /dev/null
+++ b/mlir/include/mlir/ABI/ABITypeMapper.h
@@ -0,0 +1,66 @@
+//===- ABITypeMapper.h - Map MLIR types to ABI types -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines ABITypeMapper, which translates mlir::Type instances into
+// the llvm::abi::Type hierarchy defined in llvm/ABI/Types.h. Dialect-specific
+// types are handled via MLIR's DataLayoutTypeInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ABI_ABITYPEMAPPER_H
+#define MLIR_ABI_ABITYPEMAPPER_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "llvm/ABI/Types.h"
+#include "llvm/Support/Allocator.h"
+
+namespace mlir {
+namespace abi {
+
+/// ABITypeMapper translates mlir::Type values into the llvm::abi::Type
+/// hierarchy used by the LLVM ABI Lowering Library.
+///
+/// Standard MLIR types (IntegerType, FloatType, IndexType, VectorType,
+/// MemRefType, NoneType) are mapped directly. Any other type is modelled as
+/// an unsigned integer with the size and alignment the DataLayout reports.
+///
+/// Callers must supply a DataLayout (typically from the enclosing module);
+/// it is held by reference and must outlive the mapper.
+///
+/// The mapper owns a BumpPtrAllocator; all returned abi::Type pointers
+/// are valid for the lifetime of the mapper.
+class ABITypeMapper {
+public:
+ explicit ABITypeMapper(const DataLayout &dl);
+
+ /// Map an MLIR type to its ABI type representation; nullptr on failure.
+ /// NOTE(review): unmatched types take the DataLayout fallback in map().
+ const llvm::abi::Type *map(mlir::Type type);
+
+ /// Access the underlying TypeBuilder for advanced use.
+ llvm::abi::TypeBuilder &getTypeBuilder() { return Builder; }
+
+private:
+ const llvm::abi::Type *mapIntegerType(mlir::IntegerType type);
+ const llvm::abi::Type *mapFloatType(mlir::FloatType type);
+ const llvm::abi::Type *mapIndexType(mlir::IndexType type);
+ const llvm::abi::Type *mapVectorType(mlir::VectorType type);
+ const llvm::abi::Type *mapMemRefType(mlir::MemRefType type);
+ const llvm::abi::Type *mapNoneType(mlir::NoneType type);
+
+ const DataLayout &DL;
+ llvm::BumpPtrAllocator Allocator;
+ llvm::abi::TypeBuilder Builder;
+};
+
+} // namespace abi
+} // namespace mlir
+
+#endif // MLIR_ABI_ABITYPEMAPPER_H
diff --git a/mlir/lib/ABI/ABITypeMapper.cpp b/mlir/lib/ABI/ABITypeMapper.cpp
new file mode 100644
index 0000000000000..c7a69780bbe64
--- /dev/null
+++ b/mlir/lib/ABI/ABITypeMapper.cpp
@@ -0,0 +1,102 @@
+//===- ABITypeMapper.cpp - Map MLIR types to ABI types --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ABI/ABITypeMapper.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/Support/Alignment.h"
+
+using namespace mlir;
+using namespace mlir::abi;
+
+ABITypeMapper::ABITypeMapper(const DataLayout &dl)
+ : DL(dl), Builder(Allocator) {}
+
+const llvm::abi::Type *ABITypeMapper::map(mlir::Type type) {
+ if (auto intTy = dyn_cast<mlir::IntegerType>(type))
+ return mapIntegerType(intTy);
+
+ if (auto floatTy = dyn_cast<mlir::FloatType>(type))
+ return mapFloatType(floatTy);
+
+ if (auto indexTy = dyn_cast<mlir::IndexType>(type))
+ return mapIndexType(indexTy);
+
+ if (auto vecTy = dyn_cast<mlir::VectorType>(type))
+ return mapVectorType(vecTy);
+
+ if (auto memRefTy = dyn_cast<mlir::MemRefType>(type))
+ return mapMemRefType(memRefTy);
+
+ if (auto noneTy = dyn_cast<mlir::NoneType>(type))
+ return mapNoneType(noneTy);
+
+ // For dialect-specific types, fall back to DataLayout queries.
+ // The type must implement DataLayoutTypeInterface for this to work.
+ llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type);
+ uint64_t abiAlign = DL.getTypeABIAlignment(type);
+ return Builder.getIntegerType(sizeInBits.getFixedValue(),
+ llvm::Align(abiAlign),
+ /*Signed=*/false);
+}
+
+const llvm::abi::Type *ABITypeMapper::mapIntegerType(mlir::IntegerType type) {
+ uint64_t width = type.getWidth();
+ uint64_t abiAlign = DL.getTypeABIAlignment(type);
+ // MLIR signless integers are treated as signed for ABI purposes.
+ // Most C/C++ integer types are signless in MLIR but behave as
+ // signed for ABI classification (sign extension, etc.).
+ bool isSigned = type.isSigned() || type.isSignless();
+ return Builder.getIntegerType(width, llvm::Align(abiAlign), isSigned);
+}
+
+const llvm::abi::Type *ABITypeMapper::mapFloatType(mlir::FloatType type) {
+ uint64_t abiAlign = DL.getTypeABIAlignment(type);
+ const llvm::fltSemantics &semantics = type.getFloatSemantics();
+ return Builder.getFloatType(semantics, llvm::Align(abiAlign));
+}
+
+const llvm::abi::Type *ABITypeMapper::mapIndexType(mlir::IndexType type) {
+ llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type);
+ uint64_t abiAlign = DL.getTypeABIAlignment(type);
+ return Builder.getIntegerType(sizeInBits.getFixedValue(),
+ llvm::Align(abiAlign),
+ /*Signed=*/false);
+}
+
+const llvm::abi::Type *ABITypeMapper::mapVectorType(mlir::VectorType type) {
+ const llvm::abi::Type *elementTy = map(type.getElementType());
+ if (!elementTy)
+ return nullptr;
+
+ auto shape = type.getShape();
+ // MLIR VectorType is always fixed-length and can be multi-dimensional.
+ // Flatten to a single dimension for ABI purposes.
+ uint64_t totalElements = 1;
+ for (int64_t dim : shape)
+ totalElements *= dim;
+
+ llvm::ElementCount ec = llvm::ElementCount::getFixed(totalElements);
+ uint64_t abiAlign = DL.getTypeABIAlignment(type);
+ return Builder.getVectorType(elementTy, ec, llvm::Align(abiAlign));
+}
+
+const llvm::abi::Type *ABITypeMapper::mapMemRefType(mlir::MemRefType type) {
+ // MemRef is pointer-like for ABI purposes.
+ llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type);
+ uint64_t abiAlign = DL.getTypeABIAlignment(type);
+ unsigned addrSpace = 0;
+ if (auto as = type.getMemorySpace())
+ if (auto intAttr = dyn_cast<IntegerAttr>(as))
+ addrSpace = intAttr.getInt();
+ return Builder.getPointerType(sizeInBits.getFixedValue(),
+ llvm::Align(abiAlign), addrSpace);
+}
+
+const llvm::abi::Type *ABITypeMapper::mapNoneType(mlir::NoneType type) {
+ return Builder.getVoidType();
+}
diff --git a/mlir/lib/ABI/CMakeLists.txt b/mlir/lib/ABI/CMakeLists.txt
new file mode 100644
index 0000000000000..eb434d25dd390
--- /dev/null
+++ b/mlir/lib/ABI/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_library(MLIRABI
+ ABITypeMapper.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/ABI
+
+ LINK_COMPONENTS
+ ABI
+ Support
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRDataLayoutInterfaces
+ )
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
index 91ed05f6548d7..d7a6e28d98586 100644
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -1,6 +1,7 @@
# Enable errors for any global constructors.
add_flag_if_supported("-Werror=global-constructors" WERROR_GLOBAL_CONSTRUCTOR)
+add_subdirectory(ABI)
add_subdirectory(Analysis)
add_subdirectory(AsmParser)
add_subdirectory(Bytecode)
diff --git a/mlir/unittests/ABI/ABIRewriteContextTest.cpp b/mlir/unittests/ABI/ABIRewriteContextTest.cpp
new file mode 100644
index 0000000000000..04c28991cc752
--- /dev/null
+++ b/mlir/unittests/ABI/ABIRewriteContextTest.cpp
@@ -0,0 +1,99 @@
+//===- ABIRewriteContextTest.cpp - Unit tests for ABIRewriteContext -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ABI/ABIRewriteContext.h"
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace mlir::abi;
+
+namespace {
+
+class MockRewriteContext : public ABIRewriteContext {
+public:
+ LogicalResult rewriteFunctionDefinition(FunctionOpInterface,
+ const FunctionClassification &,
+ OpBuilder &) override {
+ return success();
+ }
+
+ LogicalResult rewriteCallSite(Operation *, const FunctionClassification &,
+ OpBuilder &) override {
+ return success();
+ }
+
+ StringRef getDialectNamespace() const override { return "mock"; }
+};
+
+TEST(ABIRewriteContextTest, MockCanBeConstructedAndDestroyed) {
+ MockRewriteContext ctx;
+ EXPECT_EQ(ctx.getDialectNamespace(), "mock");
+}
+
+TEST(ABIRewriteContextTest, ArgClassificationDirect) {
+ auto c = ArgClassification::getDirect();
+ EXPECT_EQ(c.Kind, ArgKind::Direct);
+ EXPECT_EQ(c.CoercedType, nullptr);
+ EXPECT_TRUE(c.CanFlatten);
+}
+
+TEST(ABIRewriteContextTest, ArgClassificationDirectWithType) {
+ MLIRContext mlirCtx;
+ auto i32 = IntegerType::get(&mlirCtx, 32);
+ auto c = ArgClassification::getDirect(i32);
+ EXPECT_EQ(c.Kind, ArgKind::Direct);
+ EXPECT_EQ(c.CoercedType, i32);
+}
+
+TEST(ABIRewriteContextTest, ArgClassificationIgnore) {
+ auto c = ArgClassification::getIgnore();
+ EXPECT_EQ(c.Kind, ArgKind::Ignore);
+}
+
+TEST(ABIRewriteContextTest, ArgClassificationIndirect) {
+ auto c = ArgClassification::getIndirect(llvm::Align(8), true);
+ EXPECT_EQ(c.Kind, ArgKind::Indirect);
+ EXPECT_EQ(c.IndirectAlign, llvm::Align(8));
+ EXPECT_TRUE(c.ByVal);
+}
+
+TEST(ABIRewriteContextTest, ArgClassificationIndirectNoByVal) {
+ auto c = ArgClassification::getIndirect(llvm::Align(16), false);
+ EXPECT_EQ(c.Kind, ArgKind::Indirect);
+ EXPECT_EQ(c.IndirectAlign, llvm::Align(16));
+ EXPECT_FALSE(c.ByVal);
+}
+
+TEST(ABIRewriteContextTest, ArgClassificationExtend) {
+ MLIRContext mlirCtx;
+ auto i8 = IntegerType::get(&mlirCtx, 8);
+
+ auto signExt = ArgClassification::getExtend(i8, true);
+ EXPECT_EQ(signExt.Kind, ArgKind::Extend);
+ EXPECT_TRUE(signExt.SignExtend);
+
+ auto zeroExt = ArgClassification::getExtend(i8, false);
+ EXPECT_EQ(zeroExt.Kind, ArgKind::Extend);
+ EXPECT_FALSE(zeroExt.SignExtend);
+}
+
+TEST(ABIRewriteContextTest, FunctionClassificationHoldsReturnAndArgs) {
+ FunctionClassification fc;
+ fc.ReturnInfo = ArgClassification::getDirect();
+ fc.ArgInfos.push_back(ArgClassification::getDirect());
+ fc.ArgInfos.push_back(ArgClassification::getIndirect(llvm::Align(8), true));
+ fc.ArgInfos.push_back(ArgClassification::getIgnore());
+
+ EXPECT_EQ(fc.ReturnInfo.Kind, ArgKind::Direct);
+ EXPECT_EQ(fc.ArgInfos.size(), 3u);
+ EXPECT_EQ(fc.ArgInfos[0].Kind, ArgKind::Direct);
+ EXPECT_EQ(fc.ArgInfos[1].Kind, ArgKind::Indirect);
+ EXPECT_EQ(fc.ArgInfos[2].Kind, ArgKind::Ignore);
+}
+
+} // namespace
diff --git a/mlir/unittests/ABI/ABITypeMapperTest.cpp b/mlir/unittests/ABI/ABITypeMapperTest.cpp
new file mode 100644
index 0000000000000..4a7989298a149
--- /dev/null
+++ b/mlir/unittests/ABI/ABITypeMapperTest.cpp
@@ -0,0 +1,173 @@
+//===- ABITypeMapperTest.cpp - Unit tests for ABITypeMapper ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ABI/ABITypeMapper.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "llvm/ABI/Types.h"
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace mlir::abi;
+
+namespace {
+
+class ABITypeMapperTest : public ::testing::Test {
+protected:
+ void SetUp() override {
+ ctx.loadDialect<DLTIDialect>();
+ module = ModuleOp::create(UnknownLoc::get(&ctx));
+ }
+
+ void TearDown() override { module->destroy(); }
+
+ MLIRContext ctx;
+ ModuleOp module;
+};
+
+TEST_F(ABITypeMapperTest, MapI32) {
+ DataLayout dl(module);
+ ABITypeMapper mapper(dl);
+
+ auto i32 = IntegerType::get(&ctx, 32);
+ const llvm::abi::Type *result = mapper.map(i32);
+
+ ASSERT_NE(result, nullptr);
+ EXPECT_TRUE(result->isInteger());
+
+ auto *intTy = llvm::cast<llvm::abi::IntegerType>(result);
+ EXPECT_EQ(intTy->getSizeInBits().getFixedValue(), 32u);
+}
+
+TEST_F(ABITypeMapperTest, MapI1) {
+ DataLayout dl(module);
+ ABITypeMapper mapper(dl);
+
+ auto i1 = IntegerType::get(&ctx, 1);
+ const llvm::abi::Type *result = mapper.map(i1);
+
+ ASSERT_NE(result, nullptr);
+ EXPECT_TRUE(result->isInteger());
+
+ auto *intTy = llvm::cast<llvm::abi::IntegerType>(result);
+ EXPECT_EQ(intTy->getSizeInBits().getFixedValue(), 1u);
+}
+
+TEST_F(ABITypeMapperTest, MapI64) {
+ DataLayout dl(module);
+ ABITypeMapper mapper(dl);
+
+ auto i64 = IntegerType::get(&ctx, 64);
+ const llvm::abi::Type *result = mapper.map(i64);
+
+ ASSERT_NE(result, nullptr);
+ EXPECT_TRUE(result->isInteger());
+
+ auto *intTy = llvm::cast<llvm::abi::IntegerType>(result);
+ EXPECT_EQ(intTy->getSizeInBits().getFixedValue(), 64u);
+}
+
+TEST_F(ABITypeMapperTest, MapF32) {
+ DataLayout dl(module);
+ ABITypeMapper mapper(dl);
+
+ auto f32 = Float32Type::get(&ctx);
+ const llvm::abi::Type *result = mapper.map(f32);
+
+ ASSERT_NE(result, nullptr);
+ EXPECT_TRUE(result->isFloat());
+
+ auto *floatTy = llvm::cast<llvm::abi::FloatType>(result);
+ EXPECT_EQ(floatTy->getSizeInBits().getFixedValue(), 32u);
+}
+
+TEST_F(ABITypeMapperTest, MapF64) {
+ DataLayout dl(module);
+ ABITypeMapper mapper(dl);
+
+ auto f64 = Float64Type::get(&ctx);
+ const llvm::abi::Type *result = mapper.map(f64);
+
+ ASSERT_NE(result, nullptr);
+ EXPECT_TRUE(result->isFloat());
+
+ auto *floatTy = llvm::cast<llvm::abi::FloatType>(result);
+ EXPECT_EQ(floatTy->getSizeInBits().getFixedValue(), 64u);
+}
+
+TEST_F(ABITypeMapperTest, MapF16) {
+ DataLayout dl(module);
+ ABITypeMapper mapper(dl);
+
+ auto f16 = Float16Type::get(&ctx);
+ const llvm::abi::Type *result = mapper.map(f16);
+
+ ASSERT_NE(result, nullptr);
+ EXPECT_TRUE(result->isFloat());
+
+ auto *floatTy = llvm::cast<llvm::abi::FloatType>(result);
+ EXPECT_EQ(floatTy->getSizeInBits().getFixedValue(), 16u);
+}
+
+TEST_F(ABITypeMapperTest, MapNoneType) {
+ DataLayout dl(module);
+ ABITypeMapper mapper(dl);
+
+ auto none = NoneType::get(&ctx);
+ const llvm::abi::Type *result = mapper.map(none);
+
+ ASSERT_NE(result, nullptr);
+ EXPECT_TRUE(result->isVoid());
+}
+
+TEST_F(ABITypeMapperTest, MapVectorOf4xF32) {
+ DataLayout dl(module);
+ ABITypeMapper mapper(dl);
+
+ auto f32 = Float32Type::get(&ctx);
+ auto vec = VectorType::get({4}, f32);
+ const llvm::abi::Type *result = mapper.map(vec);
+
+ ASSERT_NE(result, nullptr);
+ EXPECT_TRUE(result->isVector());
+
+ auto *vecTy = llvm::cast<llvm::abi::VectorType>(result);
+ EXPECT_EQ(vecTy->getNumElements().getFixedValue(), 4u);
+ EXPECT_TRUE(vecTy->getElementType()->isFloat());
+}
+
+TEST_F(ABITypeMapperTest, MapSignedI32) {
+ DataLayout dl(module);
+ ABITypeMapper mapper(dl);
+
+ auto si32 = IntegerType::get(&ctx, 32, IntegerType::Signed);
+ const llvm::abi::Type *result = mapper.map(si32);
+
+ ASSERT_NE(result, nullptr);
+ auto *intTy = llvm::cast<llvm::abi::IntegerType>(result);
+ EXPECT_TRUE(intTy->isSigned());
+}
+
+TEST_F(ABITypeMapperTest, MapUnsignedI32) {
+ DataLayout dl(module);
+ ABITypeMapper mapper(dl);
+
+ auto ui32 = IntegerType::get(&ctx, 32, IntegerType::Unsigned);
+ const llvm::abi::Type *result = mapper.map(ui32);
+
+ ASSERT_NE(result, nullptr);
+ auto *intTy = llvm::cast<llvm::abi::IntegerType>(result);
+ EXPECT_FALSE(intTy->isSigned());
+}
+
+} // namespace
diff --git a/mlir/unittests/ABI/CMakeLists.txt b/mlir/unittests/ABI/CMakeLists.txt
new file mode 100644
index 0000000000000..39f955a8efea6
--- /dev/null
+++ b/mlir/unittests/ABI/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_unittest(MLIRABITests
+ ABIRewriteContextTest.cpp
+ ABITypeMapperTest.cpp
+)
+
+mlir_target_link_libraries(MLIRABITests
+ PRIVATE
+ MLIRABI
+ MLIRDataLayoutInterfaces
+ MLIRDLTIDialect
+ MLIRIR
+)
diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
index 89332bce5fe05..654ec44d90b04 100644
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -5,6 +5,7 @@ function(add_mlir_unittest test_dirname)
add_unittest(MLIRUnitTests ${test_dirname} ${ARGN})
endfunction()
+add_subdirectory(ABI)
add_subdirectory(Analysis)
add_subdirectory(Bytecode)
add_subdirectory(Conversion)
>From ca2c452c7fd71bda991e72c00d7248f17e3a7a62 Mon Sep 17 00:00:00 2001
From: Adam Smith <adams at nvidia.com>
Date: Mon, 13 Apr 2026 15:07:34 -0700
Subject: [PATCH 2/3] [mlir][ABI][NFC] Rename member variables to camelCase
Rename struct and class members from PascalCase to camelCase
per the MLIR style guide. Applies to ArgClassification,
FunctionClassification, and ABITypeMapper members.
Addresses review feedback from @andykaylor.
Made-with: Cursor
---
mlir/include/mlir/ABI/ABIRewriteContext.h | 34 ++++++-------
mlir/include/mlir/ABI/ABITypeMapper.h | 8 +--
mlir/lib/ABI/ABITypeMapper.cpp | 42 +++++++---------
mlir/unittests/ABI/ABIRewriteContextTest.cpp | 52 ++++++++++----------
4 files changed, 65 insertions(+), 71 deletions(-)
diff --git a/mlir/include/mlir/ABI/ABIRewriteContext.h b/mlir/include/mlir/ABI/ABIRewriteContext.h
index 71d5a56b599b7..7c48c626207bb 100644
--- a/mlir/include/mlir/ABI/ABIRewriteContext.h
+++ b/mlir/include/mlir/ABI/ABIRewriteContext.h
@@ -55,51 +55,51 @@ enum class ArgKind : uint8_t {
/// Describes how a single argument or return value is passed after ABI
/// lowering.
struct ArgClassification {
- ArgKind Kind = ArgKind::Direct;
+ ArgKind kind = ArgKind::Direct;
/// The ABI-coerced type, if different from the original. Null means
/// use the original type.
- Type CoercedType = nullptr;
+ Type coercedType = nullptr;
/// For Indirect: alignment of the pointed-to object.
- llvm::Align IndirectAlign = llvm::Align(1);
+ llvm::Align indirectAlign = llvm::Align(1);
/// For Extend: whether to sign-extend (true) or zero-extend (false).
- bool SignExtend = false;
+ bool signExtend = false;
/// For Direct: whether a struct coercion can be flattened into
/// individual register-width arguments.
- bool CanFlatten = true;
+ bool canFlatten = true;
/// For Indirect: whether the callee gets ownership (byval).
- bool ByVal = false;
+ bool byVal = false;
static ArgClassification getDirect(Type coerced = nullptr) {
ArgClassification c;
- c.Kind = ArgKind::Direct;
- c.CoercedType = coerced;
+ c.kind = ArgKind::Direct;
+ c.coercedType = coerced;
return c;
}
static ArgClassification getIgnore() {
ArgClassification c;
- c.Kind = ArgKind::Ignore;
+ c.kind = ArgKind::Ignore;
return c;
}
static ArgClassification getIndirect(llvm::Align align, bool byVal = true) {
ArgClassification c;
- c.Kind = ArgKind::Indirect;
- c.IndirectAlign = align;
- c.ByVal = byVal;
+ c.kind = ArgKind::Indirect;
+ c.indirectAlign = align;
+ c.byVal = byVal;
return c;
}
static ArgClassification getExtend(Type coerced, bool signExt) {
ArgClassification c;
- c.Kind = ArgKind::Extend;
- c.CoercedType = coerced;
- c.SignExtend = signExt;
+ c.kind = ArgKind::Extend;
+ c.coercedType = coerced;
+ c.signExtend = signExt;
return c;
}
};
@@ -107,8 +107,8 @@ struct ArgClassification {
/// Holds the full ABI classification for a function: return type and
/// all arguments.
struct FunctionClassification {
- ArgClassification ReturnInfo;
- SmallVector<ArgClassification> ArgInfos;
+ ArgClassification returnInfo;
+ SmallVector<ArgClassification> argInfos;
};
/// ABIRewriteContext is the abstract interface that each dialect
diff --git a/mlir/include/mlir/ABI/ABITypeMapper.h b/mlir/include/mlir/ABI/ABITypeMapper.h
index 2180c9c8a918d..2499e910ca797 100644
--- a/mlir/include/mlir/ABI/ABITypeMapper.h
+++ b/mlir/include/mlir/ABI/ABITypeMapper.h
@@ -45,7 +45,7 @@ class ABITypeMapper {
const llvm::abi::Type *map(mlir::Type type);
/// Access the underlying TypeBuilder for advanced use.
- llvm::abi::TypeBuilder &getTypeBuilder() { return Builder; }
+ llvm::abi::TypeBuilder &getTypeBuilder() { return builder; }
private:
const llvm::abi::Type *mapIntegerType(mlir::IntegerType type);
@@ -55,9 +55,9 @@ class ABITypeMapper {
const llvm::abi::Type *mapMemRefType(mlir::MemRefType type);
const llvm::abi::Type *mapNoneType(mlir::NoneType type);
- const DataLayout &DL;
- llvm::BumpPtrAllocator Allocator;
- llvm::abi::TypeBuilder Builder;
+ const DataLayout &dl;
+ llvm::BumpPtrAllocator allocator;
+ llvm::abi::TypeBuilder builder;
};
} // namespace abi
diff --git a/mlir/lib/ABI/ABITypeMapper.cpp b/mlir/lib/ABI/ABITypeMapper.cpp
index c7a69780bbe64..83dc6990ec5bb 100644
--- a/mlir/lib/ABI/ABITypeMapper.cpp
+++ b/mlir/lib/ABI/ABITypeMapper.cpp
@@ -13,8 +13,8 @@
using namespace mlir;
using namespace mlir::abi;
-ABITypeMapper::ABITypeMapper(const DataLayout &dl)
- : DL(dl), Builder(Allocator) {}
+ABITypeMapper::ABITypeMapper(const DataLayout &dataLayout)
+ : dl(dataLayout), builder(allocator) {}
const llvm::abi::Type *ABITypeMapper::map(mlir::Type type) {
if (auto intTy = dyn_cast<mlir::IntegerType>(type))
@@ -37,33 +37,30 @@ const llvm::abi::Type *ABITypeMapper::map(mlir::Type type) {
// For dialect-specific types, fall back to DataLayout queries.
// The type must implement DataLayoutTypeInterface for this to work.
- llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type);
- uint64_t abiAlign = DL.getTypeABIAlignment(type);
- return Builder.getIntegerType(sizeInBits.getFixedValue(),
+ llvm::TypeSize sizeInBits = dl.getTypeSizeInBits(type);
+ uint64_t abiAlign = dl.getTypeABIAlignment(type);
+ return builder.getIntegerType(sizeInBits.getFixedValue(),
llvm::Align(abiAlign),
/*Signed=*/false);
}
const llvm::abi::Type *ABITypeMapper::mapIntegerType(mlir::IntegerType type) {
uint64_t width = type.getWidth();
- uint64_t abiAlign = DL.getTypeABIAlignment(type);
- // MLIR signless integers are treated as signed for ABI purposes.
- // Most C/C++ integer types are signless in MLIR but behave as
- // signed for ABI classification (sign extension, etc.).
+ uint64_t abiAlign = dl.getTypeABIAlignment(type);
bool isSigned = type.isSigned() || type.isSignless();
- return Builder.getIntegerType(width, llvm::Align(abiAlign), isSigned);
+ return builder.getIntegerType(width, llvm::Align(abiAlign), isSigned);
}
const llvm::abi::Type *ABITypeMapper::mapFloatType(mlir::FloatType type) {
- uint64_t abiAlign = DL.getTypeABIAlignment(type);
+ uint64_t abiAlign = dl.getTypeABIAlignment(type);
const llvm::fltSemantics &semantics = type.getFloatSemantics();
- return Builder.getFloatType(semantics, llvm::Align(abiAlign));
+ return builder.getFloatType(semantics, llvm::Align(abiAlign));
}
const llvm::abi::Type *ABITypeMapper::mapIndexType(mlir::IndexType type) {
- llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type);
- uint64_t abiAlign = DL.getTypeABIAlignment(type);
- return Builder.getIntegerType(sizeInBits.getFixedValue(),
+ llvm::TypeSize sizeInBits = dl.getTypeSizeInBits(type);
+ uint64_t abiAlign = dl.getTypeABIAlignment(type);
+ return builder.getIntegerType(sizeInBits.getFixedValue(),
llvm::Align(abiAlign),
/*Signed=*/false);
}
@@ -74,29 +71,26 @@ const llvm::abi::Type *ABITypeMapper::mapVectorType(mlir::VectorType type) {
return nullptr;
auto shape = type.getShape();
- // MLIR VectorType is always fixed-length and can be multi-dimensional.
- // Flatten to a single dimension for ABI purposes.
uint64_t totalElements = 1;
for (int64_t dim : shape)
totalElements *= dim;
llvm::ElementCount ec = llvm::ElementCount::getFixed(totalElements);
- uint64_t abiAlign = DL.getTypeABIAlignment(type);
- return Builder.getVectorType(elementTy, ec, llvm::Align(abiAlign));
+ uint64_t abiAlign = dl.getTypeABIAlignment(type);
+ return builder.getVectorType(elementTy, ec, llvm::Align(abiAlign));
}
const llvm::abi::Type *ABITypeMapper::mapMemRefType(mlir::MemRefType type) {
- // MemRef is pointer-like for ABI purposes.
- llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type);
- uint64_t abiAlign = DL.getTypeABIAlignment(type);
+ llvm::TypeSize sizeInBits = dl.getTypeSizeInBits(type);
+ uint64_t abiAlign = dl.getTypeABIAlignment(type);
unsigned addrSpace = 0;
if (auto as = type.getMemorySpace())
if (auto intAttr = dyn_cast<IntegerAttr>(as))
addrSpace = intAttr.getInt();
- return Builder.getPointerType(sizeInBits.getFixedValue(),
+ return builder.getPointerType(sizeInBits.getFixedValue(),
llvm::Align(abiAlign), addrSpace);
}
const llvm::abi::Type *ABITypeMapper::mapNoneType(mlir::NoneType type) {
- return Builder.getVoidType();
+ return builder.getVoidType();
}
diff --git a/mlir/unittests/ABI/ABIRewriteContextTest.cpp b/mlir/unittests/ABI/ABIRewriteContextTest.cpp
index 04c28991cc752..59a5307225e0d 100644
--- a/mlir/unittests/ABI/ABIRewriteContextTest.cpp
+++ b/mlir/unittests/ABI/ABIRewriteContextTest.cpp
@@ -37,36 +37,36 @@ TEST(ABIRewriteContextTest, MockCanBeConstructedAndDestroyed) {
TEST(ABIRewriteContextTest, ArgClassificationDirect) {
auto c = ArgClassification::getDirect();
- EXPECT_EQ(c.Kind, ArgKind::Direct);
- EXPECT_EQ(c.CoercedType, nullptr);
- EXPECT_TRUE(c.CanFlatten);
+ EXPECT_EQ(c.kind, ArgKind::Direct);
+ EXPECT_EQ(c.coercedType, nullptr);
+ EXPECT_TRUE(c.canFlatten);
}
TEST(ABIRewriteContextTest, ArgClassificationDirectWithType) {
MLIRContext mlirCtx;
auto i32 = IntegerType::get(&mlirCtx, 32);
auto c = ArgClassification::getDirect(i32);
- EXPECT_EQ(c.Kind, ArgKind::Direct);
- EXPECT_EQ(c.CoercedType, i32);
+ EXPECT_EQ(c.kind, ArgKind::Direct);
+ EXPECT_EQ(c.coercedType, i32);
}
TEST(ABIRewriteContextTest, ArgClassificationIgnore) {
auto c = ArgClassification::getIgnore();
- EXPECT_EQ(c.Kind, ArgKind::Ignore);
+ EXPECT_EQ(c.kind, ArgKind::Ignore);
}
TEST(ABIRewriteContextTest, ArgClassificationIndirect) {
auto c = ArgClassification::getIndirect(llvm::Align(8), true);
- EXPECT_EQ(c.Kind, ArgKind::Indirect);
- EXPECT_EQ(c.IndirectAlign, llvm::Align(8));
- EXPECT_TRUE(c.ByVal);
+ EXPECT_EQ(c.kind, ArgKind::Indirect);
+ EXPECT_EQ(c.indirectAlign, llvm::Align(8));
+ EXPECT_TRUE(c.byVal);
}
TEST(ABIRewriteContextTest, ArgClassificationIndirectNoByVal) {
auto c = ArgClassification::getIndirect(llvm::Align(16), false);
- EXPECT_EQ(c.Kind, ArgKind::Indirect);
- EXPECT_EQ(c.IndirectAlign, llvm::Align(16));
- EXPECT_FALSE(c.ByVal);
+ EXPECT_EQ(c.kind, ArgKind::Indirect);
+ EXPECT_EQ(c.indirectAlign, llvm::Align(16));
+ EXPECT_FALSE(c.byVal);
}
TEST(ABIRewriteContextTest, ArgClassificationExtend) {
@@ -74,26 +74,26 @@ TEST(ABIRewriteContextTest, ArgClassificationExtend) {
auto i8 = IntegerType::get(&mlirCtx, 8);
auto signExt = ArgClassification::getExtend(i8, true);
- EXPECT_EQ(signExt.Kind, ArgKind::Extend);
- EXPECT_TRUE(signExt.SignExtend);
+ EXPECT_EQ(signExt.kind, ArgKind::Extend);
+ EXPECT_TRUE(signExt.signExtend);
auto zeroExt = ArgClassification::getExtend(i8, false);
- EXPECT_EQ(zeroExt.Kind, ArgKind::Extend);
- EXPECT_FALSE(zeroExt.SignExtend);
+ EXPECT_EQ(zeroExt.kind, ArgKind::Extend);
+ EXPECT_FALSE(zeroExt.signExtend);
}
TEST(ABIRewriteContextTest, FunctionClassificationHoldsReturnAndArgs) {
FunctionClassification fc;
- fc.ReturnInfo = ArgClassification::getDirect();
- fc.ArgInfos.push_back(ArgClassification::getDirect());
- fc.ArgInfos.push_back(ArgClassification::getIndirect(llvm::Align(8), true));
- fc.ArgInfos.push_back(ArgClassification::getIgnore());
-
- EXPECT_EQ(fc.ReturnInfo.Kind, ArgKind::Direct);
- EXPECT_EQ(fc.ArgInfos.size(), 3u);
- EXPECT_EQ(fc.ArgInfos[0].Kind, ArgKind::Direct);
- EXPECT_EQ(fc.ArgInfos[1].Kind, ArgKind::Indirect);
- EXPECT_EQ(fc.ArgInfos[2].Kind, ArgKind::Ignore);
+ fc.returnInfo = ArgClassification::getDirect();
+ fc.argInfos.push_back(ArgClassification::getDirect());
+ fc.argInfos.push_back(ArgClassification::getIndirect(llvm::Align(8), true));
+ fc.argInfos.push_back(ArgClassification::getIgnore());
+
+ EXPECT_EQ(fc.returnInfo.kind, ArgKind::Direct);
+ EXPECT_EQ(fc.argInfos.size(), 3u);
+ EXPECT_EQ(fc.argInfos[0].kind, ArgKind::Direct);
+ EXPECT_EQ(fc.argInfos[1].kind, ArgKind::Indirect);
+ EXPECT_EQ(fc.argInfos[2].kind, ArgKind::Ignore);
}
} // namespace
>From 8ce99db064c122d3b5026e540d4345bb8ea932b6 Mon Sep 17 00:00:00 2001
From: Adam Smith <adams at nvidia.com>
Date: Tue, 14 Apr 2026 12:50:29 -0700
Subject: [PATCH 3/3] [CIR] Add CIRABIRewriteContext for ABI function/call
rewriting
Add CIRABIRewriteContext, the CIR dialect's concrete implementation
of the shared ABIRewriteContext interface from #190661. This class
rewrites CIR FuncOps and CallOps to match ABI-lowered signatures.
This initial PR covers Direct, Extend, and Ignore argument/return
kinds with 11 unit tests. Struct coercion (sret, byval, multi-
register flattening) will follow in a subsequent PR.
Depends on #190661.
Made-with: Cursor
---
.../TargetLowering/CIRABIRewriteContext.cpp | 469 ++++++++++++++++++
.../TargetLowering/CIRABIRewriteContext.h | 50 ++
.../Transforms/TargetLowering/CMakeLists.txt | 2 +
.../CIR/CIRABIRewriteContextTest.cpp | 406 +++++++++++++++
clang/unittests/CIR/CMakeLists.txt | 5 +
5 files changed, 932 insertions(+)
create mode 100644 clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
create mode 100644 clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h
create mode 100644 clang/unittests/CIR/CIRABIRewriteContextTest.cpp
diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
new file mode 100644
index 0000000000000..cab6faf44eda4
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
@@ -0,0 +1,469 @@
+//===- CIRABIRewriteContext.cpp - CIR-specific ABI rewriting --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "CIRABIRewriteContext.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Types.h"
+#include "clang/CIR/Dialect/IR/CIRAttrs.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
+
+using namespace cir;
+using namespace mlir;
+using namespace mlir::abi;
+
+/// Coerce \p src to \p dstTy at the rewriter's current insertion point and
+/// return the coerced value.
+///
+/// Strategy, in order:
+///  - identical types: return \p src and emit nothing;
+///  - integer to integer: emit a cir.cast of kind `integral`, the same cast
+///    kind the Extend argument paths in this file use for width changes;
+///  - a record or complex type on either side, or a vector on exactly one
+///    side: go through memory (alloca srcTy -> store src -> bitcast ptr ->
+///    load dstTy), because a value-level bitcast cannot reinterpret between
+///    aggregate and scalar types;
+///  - anything else: emit a cir.cast of kind `bitcast`.
+static Value emitCoercion(OpBuilder &rewriter, Location loc, Type dstTy,
+                          Value src) {
+  Type srcTy = src.getType();
+  if (srcTy == dstTy)
+    return src;
+
+  // Changing the width or signedness of an integer is a value conversion,
+  // not a reinterpretation, so it needs an integral cast, not a bitcast.
+  if (mlir::isa<cir::IntType>(srcTy) && mlir::isa<cir::IntType>(dstTy))
+    return cir::CastOp::create(rewriter, loc, dstTy, cir::CastKind::integral,
+                               src);
+
+  bool needsMemory =
+      mlir::isa<cir::RecordType, cir::ComplexType>(srcTy) ||
+      mlir::isa<cir::RecordType, cir::ComplexType>(dstTy) ||
+      (mlir::isa<cir::VectorType>(srcTy) != mlir::isa<cir::VectorType>(dstTy));
+
+  if (!needsMemory)
+    return cir::CastOp::create(rewriter, loc, dstTy, cir::CastKind::bitcast,
+                               src);
+
+  auto srcPtrTy = cir::PointerType::get(srcTy);
+  auto dstPtrTy = cir::PointerType::get(dstTy);
+
+  // NOTE(review): the temporary is sized for srcTy and aligned to a fixed
+  // 8 bytes; confirm this is sufficient when dstTy is larger or more
+  // strictly aligned than srcTy.
+  auto alloca =
+      cir::AllocaOp::create(rewriter, loc, srcPtrTy, srcTy,
+                            /*name=*/rewriter.getStringAttr("coerce"),
+                            /*alignment=*/rewriter.getI64IntegerAttr(8));
+
+  cir::StoreOp::create(rewriter, loc, src, alloca,
+                       /*isVolatile=*/mlir::UnitAttr(),
+                       /*alignment=*/mlir::IntegerAttr(),
+                       /*sync_scope=*/cir::SyncScopeKindAttr(),
+                       /*mem_order=*/cir::MemOrderAttr());
+
+  // Reinterpret the slot as dstTy and read the value back out.
+  auto ptrCast = cir::CastOp::create(rewriter, loc, dstPtrTy,
+                                     cir::CastKind::bitcast, alloca);
+
+  return cir::LoadOp::create(rewriter, loc, dstTy, ptrCast,
+                             /*isDeref=*/mlir::UnitAttr(),
+                             /*isVolatile=*/mlir::UnitAttr(),
+                             /*alignment=*/mlir::IntegerAttr(),
+                             /*sync_scope=*/cir::SyncScopeKindAttr(),
+                             /*mem_order=*/cir::MemOrderAttr());
+}
+
+/// Make every value-carrying cir.return inside \p funcOp yield a value of
+/// type \p coercedRetTy. Returns without an operand, and returns whose
+/// operand already has that type, are left alone. \p origRetTy is accepted
+/// for symmetry with the caller but is not consulted.
+static void insertReturnCoercion(FunctionOpInterface funcOp, Type origRetTy,
+                                 Type coercedRetTy, OpBuilder &rewriter) {
+  // Snapshot the returns first; new ops are created while rewriting.
+  SmallVector<cir::ReturnOp> pending;
+  funcOp->walk([&](cir::ReturnOp ret) { pending.push_back(ret); });
+
+  for (cir::ReturnOp ret : pending) {
+    auto operands = ret.getInput();
+    const bool needsCoercion =
+        !operands.empty() && operands[0].getType() != coercedRetTy;
+    if (!needsCoercion)
+      continue;
+
+    Value original = operands[0];
+    rewriter.setInsertionPoint(ret);
+    ret->setOperand(
+        0, emitCoercion(rewriter, ret.getLoc(), coercedRetTy, original));
+  }
+}
+
+/// Adapt the entry-block arguments of \p funcOp to their ABI types.
+///
+/// For each Extend or Direct classification in \p fc that carries a
+/// coercedType different from the block argument's current type, the block
+/// argument is retyped to the coerced type and a conversion back to the
+/// original type is inserted at the top of the entry block; every other use
+/// of the block argument is redirected to the converted value. Extend
+/// arguments use a single integral cast; Direct arguments go through a
+/// stack temporary (alloca / store / pointer bitcast / load).
+///
+/// Does nothing when the function has no body. Classification indices are
+/// used directly as entry-block argument indices, so this must run before
+/// any block argument is erased.
+static void insertArgAdaptation(FunctionOpInterface funcOp,
+ const FunctionClassification &fc,
+ OpBuilder &rewriter) {
+ Region &body = funcOp->getRegion(0);
+ if (body.empty())
+ return;
+
+ Block &entryBlock = body.front();
+ // Last op of the previous argument's adaptation sequence; later sequences
+ // are inserted after it so the sequences appear in argument order.
+ Operation *lastInserted = nullptr;
+
+ for (auto [idx, argClass] : llvm::enumerate(fc.argInfos)) {
+ if (!argClass.coercedType)
+ continue;
+
+ if (argClass.kind != ArgKind::Extend && argClass.kind != ArgKind::Direct)
+ continue;
+
+ BlockArgument blockArg = entryBlock.getArgument(idx);
+ Type oldArgTy = blockArg.getType();
+ Type newArgTy = argClass.coercedType;
+
+ if (oldArgTy == newArgTy)
+ continue;
+
+ // Retype in place; existing uses still expect oldArgTy until they are
+ // redirected to the adapted value below.
+ blockArg.setType(newArgTy);
+
+ if (lastInserted)
+ rewriter.setInsertionPointAfter(lastInserted);
+ else
+ rewriter.setInsertionPointToStart(&entryBlock);
+
+ Value adapted;
+ // Ops that must keep using the raw block argument (the adaptation
+ // sequence itself); excluded from the use replacement below.
+ SmallPtrSet<Operation *, 4> coercionOps;
+
+ if (argClass.kind == ArgKind::Extend) {
+ // Convert the ABI-typed integer back to the original integer type.
+ auto cast = cir::CastOp::create(rewriter, funcOp.getLoc(), oldArgTy,
+ cir::CastKind::integral, blockArg);
+ adapted = cast;
+ coercionOps.insert(cast.getOperation());
+ } else {
+ // Direct: spill the ABI-typed value to a temporary and reload it as the
+ // original type.
+ auto srcPtrTy = cir::PointerType::get(newArgTy);
+ auto dstPtrTy = cir::PointerType::get(oldArgTy);
+ Location loc = funcOp.getLoc();
+
+ // NOTE(review): alignment is a fixed 8 bytes -- confirm it is adequate
+ // for every coerced type that can reach this path.
+ auto alloca =
+ cir::AllocaOp::create(rewriter, loc, srcPtrTy, newArgTy,
+ /*name=*/rewriter.getStringAttr("coerce"),
+ /*alignment=*/rewriter.getI64IntegerAttr(8));
+
+ auto store = cir::StoreOp::create(rewriter, loc, blockArg, alloca,
+ /*isVolatile=*/mlir::UnitAttr(),
+ /*alignment=*/mlir::IntegerAttr(),
+ /*sync_scope=*/cir::SyncScopeKindAttr(),
+ /*mem_order=*/cir::MemOrderAttr());
+
+ auto ptrCast = cir::CastOp::create(rewriter, loc, dstPtrTy,
+ cir::CastKind::bitcast, alloca);
+
+ auto load = cir::LoadOp::create(rewriter, loc, oldArgTy, ptrCast,
+ /*isDeref=*/mlir::UnitAttr(),
+ /*isVolatile=*/mlir::UnitAttr(),
+ /*alignment=*/mlir::IntegerAttr(),
+ /*sync_scope=*/cir::SyncScopeKindAttr(),
+ /*mem_order=*/cir::MemOrderAttr());
+
+ adapted = load;
+ coercionOps.insert(alloca.getOperation());
+ coercionOps.insert(store.getOperation());
+ coercionOps.insert(ptrCast.getOperation());
+ coercionOps.insert(load.getOperation());
+ }
+ lastInserted = adapted.getDefiningOp();
+
+ // Point every pre-existing use at the original-typed value.
+ blockArg.replaceAllUsesExcept(adapted, coercionOps);
+ }
+}
+
+/// Rewrite \p funcOp (definition or declaration) so that its signature
+/// matches the ABI classification \p fc.
+///
+/// Per argument kind:
+///  - Direct / Extend: the argument takes `coercedType` when one is set;
+///    definitions get an adaptation sequence at entry so the body keeps
+///    seeing the original type.
+///  - Ignore: the argument is removed from the signature and, for
+///    definitions, from the entry block.
+///  - Indirect / Expand: the original type is kept (no rewrite here).
+/// The return value is coerced (Direct/Extend with a differing
+/// `coercedType`) or dropped (Ignore). Extend arguments and returns also
+/// receive an `llvm.signext` / `llvm.zeroext` unit attribute.
+///
+/// Assumes \p fc.argInfos has at most one entry per function argument; the
+/// entries are indexed against the function's current argument list.
+///
+/// \returns success(); no failure path exists in this implementation.
+LogicalResult CIRABIRewriteContext::rewriteFunctionDefinition(
+    FunctionOpInterface funcOp, const FunctionClassification &fc,
+    OpBuilder &rewriter) {
+  ArrayRef<Type> oldArgTypes = funcOp.getArgumentTypes();
+  ArrayRef<Type> oldResultTypes = funcOp.getResultTypes();
+  bool isDecl = funcOp.isDeclaration();
+
+  bool returnCoerced = false;
+  bool hasArgChanges = false;
+  SmallVector<unsigned> ignoredArgIndices;
+
+  // Compute new argument types.
+  SmallVector<Type> newArgTypes;
+
+  for (auto [idx, argClass] : llvm::enumerate(fc.argInfos)) {
+    Type origTy = oldArgTypes[idx];
+    switch (argClass.kind) {
+    case ArgKind::Direct:
+    case ArgKind::Extend:
+      newArgTypes.push_back(argClass.coercedType ? argClass.coercedType
+                                                 : origTy);
+      if (argClass.coercedType && argClass.coercedType != origTy)
+        hasArgChanges = true;
+      break;
+    case ArgKind::Ignore:
+      // Dropped from the signature: contributes no new argument type.
+      ignoredArgIndices.push_back(idx);
+      hasArgChanges = true;
+      break;
+    case ArgKind::Indirect:
+    case ArgKind::Expand:
+      newArgTypes.push_back(origTy);
+      break;
+    }
+  }
+
+  // Compute new result type. CIR's FuncType::clone expects exactly
+  // one result type (VoidType for void-returning functions).
+  auto voidTy = cir::VoidType::get(funcOp->getContext());
+  Type origRetTy = oldResultTypes.empty() ? voidTy : oldResultTypes[0];
+  Type newRetTy = origRetTy;
+
+  if (fc.returnInfo.kind == ArgKind::Direct ||
+      fc.returnInfo.kind == ArgKind::Extend) {
+    if (fc.returnInfo.coercedType && !oldResultTypes.empty() &&
+        fc.returnInfo.coercedType != oldResultTypes[0]) {
+      newRetTy = fc.returnInfo.coercedType;
+      returnCoerced = true;
+    }
+  } else if (fc.returnInfo.kind == ArgKind::Ignore) {
+    newRetTy = voidTy;
+  }
+
+  SmallVector<Type> newResultTypes = {newRetTy};
+
+  // If nothing changed, skip the rewrite -- unless we have
+  // Extend args/returns that need signext/zeroext attrs.
+  bool hasExtend = fc.returnInfo.kind == ArgKind::Extend;
+  for (auto &argClass : fc.argInfos)
+    if (argClass.kind == ArgKind::Extend)
+      hasExtend = true;
+  if (!hasArgChanges && !hasExtend && !returnCoerced && newRetTy == origRetTy &&
+      newArgTypes == SmallVector<Type>(oldArgTypes))
+    return success();
+
+  // Body modifications only apply to definitions.
+  if (!isDecl) {
+    // Must run before any block argument is erased: it indexes the entry
+    // block with classification indices.
+    if (hasArgChanges)
+      insertArgAdaptation(funcOp, fc, rewriter);
+
+    // Erase block arguments for Ignore'd args (in reverse to keep
+    // indices valid). Any remaining uses are first redirected to a load
+    // from a fresh, never-stored-to alloca of the same type.
+    if (!ignoredArgIndices.empty()) {
+      Region &body = funcOp->getRegion(0);
+      if (!body.empty()) {
+        Block &entry = body.front();
+        for (int i = ignoredArgIndices.size() - 1; i >= 0; --i) {
+          unsigned blockIdx = ignoredArgIndices[i];
+          if (blockIdx < entry.getNumArguments()) {
+            BlockArgument arg = entry.getArgument(blockIdx);
+            if (!arg.use_empty()) {
+              rewriter.setInsertionPointToStart(&entry);
+              auto ptrTy = cir::PointerType::get(arg.getType());
+              auto alloca = cir::AllocaOp::create(
+                  rewriter, funcOp.getLoc(), ptrTy, arg.getType(),
+                  /*name=*/rewriter.getStringAttr("ignored"),
+                  /*alignment=*/rewriter.getI64IntegerAttr(1));
+              auto load = cir::LoadOp::create(
+                  rewriter, funcOp.getLoc(), arg.getType(), alloca,
+                  /*isDeref=*/mlir::UnitAttr(),
+                  /*isVolatile=*/mlir::UnitAttr(),
+                  /*alignment=*/mlir::IntegerAttr(),
+                  /*sync_scope=*/cir::SyncScopeKindAttr(),
+                  /*mem_order=*/cir::MemOrderAttr());
+              arg.replaceAllUsesWith(load);
+            }
+            entry.eraseArgument(blockIdx);
+          }
+        }
+      }
+    }
+
+    if (returnCoerced)
+      insertReturnCoercion(funcOp, origRetTy, fc.returnInfo.coercedType,
+                           rewriter);
+
+    // When the return type is Ignore (empty struct), rewrite all
+    // return ops to drop their operand so they return void.
+    if (fc.returnInfo.kind == ArgKind::Ignore && !oldResultTypes.empty()) {
+      funcOp.walk([&](cir::ReturnOp retOp) {
+        if (retOp.getNumOperands() > 0) {
+          rewriter.setInsertionPoint(retOp);
+          cir::ReturnOp::create(rewriter, retOp.getLoc());
+          retOp->erase();
+        }
+      });
+    }
+  }
+
+  Type newFnTy = funcOp.cloneTypeWith(newArgTypes, newResultTypes);
+  funcOp.setFunctionTypeAttr(TypeAttr::get(newFnTy));
+
+  // Attach signext/zeroext attributes for Extend args and returns.
+  {
+    MLIRContext *ctx = funcOp->getContext();
+    unsigned numArgs = newArgTypes.size();
+    bool needsArgAttrs = false;
+    bool hasIgnoredArgs = !ignoredArgIndices.empty();
+    for (auto &argClass : fc.argInfos)
+      if (argClass.kind == ArgKind::Extend)
+        needsArgAttrs = true;
+    // Existing arg_attrs must be re-indexed when arguments were dropped.
+    if (hasIgnoredArgs && funcOp->hasAttr("arg_attrs"))
+      needsArgAttrs = true;
+
+    if (needsArgAttrs) {
+      SmallVector<Attribute> argAttrDicts(numArgs, DictionaryAttr::get(ctx));
+
+      // Preserve existing arg_attrs, skipping Ignore'd args.
+      if (auto existingAttrs = funcOp->getAttrOfType<ArrayAttr>("arg_attrs")) {
+        unsigned newIdx = 0;
+        for (unsigned oldIdx = 0; oldIdx < existingAttrs.size(); ++oldIdx) {
+          if (oldIdx < fc.argInfos.size() &&
+              fc.argInfos[oldIdx].kind == ArgKind::Ignore)
+            continue;
+          if (newIdx < numArgs)
+            argAttrDicts[newIdx] = existingAttrs[oldIdx];
+          ++newIdx;
+        }
+      }
+
+      // argAttrDicts is indexed by post-rewrite argument position, while
+      // fc.argInfos is indexed by pre-rewrite position. Ignore'd arguments
+      // occupy no post-rewrite slot, so the two indices diverge after the
+      // first Ignore; track the post-rewrite slot explicitly.
+      unsigned nextSlot = 0;
+      for (const auto &argClass : fc.argInfos) {
+        if (argClass.kind == ArgKind::Ignore)
+          continue;
+        unsigned slot = nextSlot++;
+        if (argClass.kind != ArgKind::Extend)
+          continue;
+        if (slot >= numArgs)
+          continue;
+        auto existing = mlir::cast<DictionaryAttr>(argAttrDicts[slot]);
+        SmallVector<NamedAttribute> attrs(existing.begin(), existing.end());
+        StringRef attrName =
+            argClass.signExtend ? "llvm.signext" : "llvm.zeroext";
+        attrs.push_back(
+            rewriter.getNamedAttr(attrName, rewriter.getUnitAttr()));
+        argAttrDicts[slot] = DictionaryAttr::get(ctx, attrs);
+      }
+
+      funcOp->setAttr("arg_attrs", ArrayAttr::get(ctx, argAttrDicts));
+    }
+
+    // Add signext/zeroext to return value for Extend returns.
+    if (fc.returnInfo.kind == ArgKind::Extend) {
+      SmallVector<NamedAttribute> retAttrs;
+      if (auto existing = funcOp->getAttrOfType<ArrayAttr>("res_attrs"))
+        if (existing.size() > 0)
+          for (auto attr : mlir::cast<DictionaryAttr>(existing[0]))
+            retAttrs.push_back(attr);
+      StringRef attrName =
+          fc.returnInfo.signExtend ? "llvm.signext" : "llvm.zeroext";
+      retAttrs.push_back(
+          rewriter.getNamedAttr(attrName, rewriter.getUnitAttr()));
+      SmallVector<Attribute> resAttrDicts;
+      resAttrDicts.push_back(DictionaryAttr::get(ctx, retAttrs));
+      funcOp->setAttr("res_attrs", ArrayAttr::get(ctx, resAttrDicts));
+    }
+  }
+
+  return success();
+}
+
+/// Rewrite the cir.call \p callOp so that its operands and result match the
+/// ABI classification \p fc.
+///
+/// Ignore'd arguments are dropped; Extend arguments are widened with an
+/// integral cast; Direct arguments with a differing coercedType go through
+/// emitCoercion. Operands beyond the classified arguments are forwarded
+/// unchanged. An Ignore'd return turns the call into a void call; a coerced
+/// return changes the call's result type and converts the result back to
+/// the original type for existing users. When anything changes, a new
+/// cir.call replaces the original, which is erased.
+///
+/// \p callOp must be a cir::CallOp (enforced by cast<>).
+/// \returns success(); no failure path exists in this implementation.
+LogicalResult CIRABIRewriteContext::rewriteCallSite(
+ Operation *callOp, const FunctionClassification &fc, OpBuilder &rewriter) {
+ auto call = cast<cir::CallOp>(callOp);
+
+ SmallVector<Value> newArgs;
+ bool argsChanged = false;
+ auto argOperands = call.getArgOperands();
+
+ for (auto [idx, argClass] : llvm::enumerate(fc.argInfos)) {
+ // Tolerate a classification with more entries than the call has operands.
+ if (idx >= argOperands.size())
+ break;
+
+ Value arg = argOperands[idx];
+
+ if (argClass.kind == ArgKind::Ignore) {
+ // Dropped from the call entirely.
+ argsChanged = true;
+ continue;
+ }
+
+ if ((argClass.kind == ArgKind::Extend ||
+ argClass.kind == ArgKind::Direct) &&
+ argClass.coercedType && arg.getType() != argClass.coercedType) {
+ rewriter.setInsertionPoint(call);
+ Value coerced;
+ if (argClass.kind == ArgKind::Extend)
+ coerced =
+ cir::CastOp::create(rewriter, call.getLoc(), argClass.coercedType,
+ cir::CastKind::integral, arg);
+ else
+ coerced =
+ emitCoercion(rewriter, call.getLoc(), argClass.coercedType, arg);
+ newArgs.push_back(coerced);
+ argsChanged = true;
+ } else {
+ newArgs.push_back(arg);
+ }
+ }
+
+ // Pass through any extra operands beyond classified args.
+ for (unsigned i = fc.argInfos.size(); i < argOperands.size(); ++i)
+ newArgs.push_back(argOperands[i]);
+
+ // Handle direct return coercion.
+ // Note: this is set whenever a coercedType is present, even if it equals
+ // the call's current result type; the type comparison happens further down.
+ bool returnCoerced = false;
+ Type coercedRetTy;
+ if ((fc.returnInfo.kind == ArgKind::Direct ||
+ fc.returnInfo.kind == ArgKind::Extend) &&
+ fc.returnInfo.coercedType) {
+ returnCoerced = true;
+ coercedRetTy = fc.returnInfo.coercedType;
+ }
+
+ // Handle Ignore return: replace with void call.
+ if (fc.returnInfo.kind == ArgKind::Ignore && call.getNumResults() > 0) {
+ rewriter.setInsertionPoint(call);
+ auto voidTy = cir::VoidType::get(call.getContext());
+ // NOTE(review): the new call is built from getCalleeAttr() only --
+ // confirm whether indirect calls can reach this rewrite.
+ auto newCall = cir::CallOp::create(rewriter, call.getLoc(),
+ call.getCalleeAttr(), voidTy, newArgs);
+ // Carry over attributes the builder did not already set on the new call.
+ for (NamedAttribute attr : call->getAttrs())
+ if (!newCall->hasAttr(attr.getName()))
+ newCall->setAttr(attr.getName(), attr.getValue());
+
+ // Users of the dropped result get a load from a fresh, never-stored-to
+ // alloca of the original result type.
+ if (!call.getResult().use_empty()) {
+ rewriter.setInsertionPointAfter(newCall);
+ Type origRetTy = call.getResult().getType();
+ auto ptrTy = cir::PointerType::get(origRetTy);
+ auto alloca =
+ cir::AllocaOp::create(rewriter, call.getLoc(), ptrTy, origRetTy,
+ /*name=*/rewriter.getStringAttr("ignored"),
+ /*alignment=*/rewriter.getI64IntegerAttr(1));
+ auto load =
+ cir::LoadOp::create(rewriter, call.getLoc(), origRetTy, alloca,
+ /*isDeref=*/mlir::UnitAttr(),
+ /*isVolatile=*/mlir::UnitAttr(),
+ /*alignment=*/mlir::IntegerAttr(),
+ /*sync_scope=*/cir::SyncScopeKindAttr(),
+ /*mem_order=*/cir::MemOrderAttr());
+ call.getResult().replaceAllUsesWith(load);
+ }
+ call->erase();
+ return success();
+ }
+
+ // Nothing to rewrite: leave the original call in place.
+ if (!returnCoerced && !argsChanged)
+ return success();
+
+ Type callRetTy;
+ Type origRetTy;
+ bool hasResult = call.getNumResults() > 0;
+
+ if (hasResult) {
+ origRetTy = call.getResult().getType();
+ callRetTy = returnCoerced ? coercedRetTy : origRetTy;
+ } else {
+ callRetTy = cir::VoidType::get(call.getContext());
+ }
+
+ rewriter.setInsertionPoint(call);
+ auto newCall = cir::CallOp::create(rewriter, call.getLoc(),
+ call.getCalleeAttr(), callRetTy, newArgs);
+ // Carry over attributes the builder did not already set on the new call.
+ for (NamedAttribute attr : call->getAttrs())
+ if (!newCall->hasAttr(attr.getName()))
+ newCall->setAttr(attr.getName(), attr.getValue());
+
+ if (hasResult && returnCoerced && origRetTy != coercedRetTy) {
+ // Convert the ABI-typed result back to the type existing users expect.
+ rewriter.setInsertionPointAfter(newCall);
+ Value castBack =
+ emitCoercion(rewriter, call.getLoc(), origRetTy, newCall.getResult());
+ call.getResult().replaceAllUsesWith(castBack);
+ } else if (hasResult) {
+ call.getResult().replaceAllUsesWith(newCall.getResult());
+ }
+
+ call->erase();
+ return success();
+}
diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h
new file mode 100644
index 0000000000000..93d968c9db123
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h
@@ -0,0 +1,50 @@
+//===- CIRABIRewriteContext.h - CIR-specific ABI rewriting ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines CIRABIRewriteContext, the CIR dialect's implementation of
+// the shared ABIRewriteContext interface. It rewrites CIR function definitions
+// and call sites to match ABI-lowered signatures.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_CIRABIREWRITECONTEXT_H
+#define CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_CIRABIREWRITECONTEXT_H
+
+#include "mlir/ABI/ABIRewriteContext.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+
+namespace cir {
+
+/// ABIRewriteContext implementation for the CIR dialect.
+///
+/// Rewrites CIR FuncOps and CallOps so they match ABI-lowered signatures,
+/// expressing every coercion with CIR operations (alloca, load, store,
+/// cast, ...).
+class CIRABIRewriteContext : public mlir::abi::ABIRewriteContext {
+public:
+  /// \param module the module this context is created for.
+  explicit CIRABIRewriteContext(mlir::ModuleOp module) : module(module) {}
+
+  /// Dialect namespace handled by this context.
+  mlir::StringRef getDialectNamespace() const override { return "cir"; }
+
+  /// Rewrite the signature (and, for definitions, the body) of \p funcOp
+  /// according to \p fc.
+  mlir::LogicalResult
+  rewriteFunctionDefinition(mlir::FunctionOpInterface funcOp,
+                            const mlir::abi::FunctionClassification &fc,
+                            mlir::OpBuilder &rewriter) override;
+
+  /// Rewrite the call \p callOp according to \p fc.
+  mlir::LogicalResult
+  rewriteCallSite(mlir::Operation *callOp,
+                  const mlir::abi::FunctionClassification &fc,
+                  mlir::OpBuilder &rewriter) override;
+
+private:
+  mlir::ModuleOp module;
+};
+
+} // namespace cir
+
+#endif // CLANG_LIB_CIR_DIALECT_TRANSFORMS_TARGETLOWERING_CIRABIREWRITECONTEXT_H
diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt
index 86502b7f5dd4e..9833952623708 100644
--- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CMakeLists.txt
@@ -1,4 +1,5 @@
add_clang_library(MLIRCIRTargetLowering
+ CIRABIRewriteContext.cpp
CIRCXXABI.cpp
LowerModule.cpp
LowerItaniumCXXABI.cpp
@@ -15,6 +16,7 @@ add_clang_library(MLIRCIRTargetLowering
LINK_LIBS PUBLIC
clangBasic
+ MLIRABI
MLIRIR
MLIRPass
MLIRDLTIDialect
diff --git a/clang/unittests/CIR/CIRABIRewriteContextTest.cpp b/clang/unittests/CIR/CIRABIRewriteContextTest.cpp
new file mode 100644
index 0000000000000..5cf8637346ae3
--- /dev/null
+++ b/clang/unittests/CIR/CIRABIRewriteContextTest.cpp
@@ -0,0 +1,406 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Unit tests for CIRABIRewriteContext, the CIR dialect's concrete
+// implementation of the shared ABIRewriteContext interface. Each test
+// constructs a FunctionClassification manually (no ABI library needed)
+// and verifies the resulting IR after rewriting.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ABI/ABIRewriteContext.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "clang/CIR/Dialect/IR/CIRAttrs.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/IR/CIRTypes.h"
+#include "gtest/gtest.h"
+
+// The header is private to the Transforms library, so we include it
+// via the path relative to the source tree. The CMakeLists arranges
+// the include directories.
+#include "../../lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.h"
+
+using namespace mlir;
+using namespace mlir::abi;
+
+namespace {
+
+/// Shared fixture: owns an MLIRContext with the CIR dialect loaded, a
+/// builder, and an unknown location, plus helpers that build small modules
+/// for the rewrite tests.
+class CIRABIRewriteTest : public ::testing::Test {
+protected:
+ CIRABIRewriteTest() : builder(&context), loc(UnknownLoc::get(&context)) {
+ context.loadDialect<cir::CIRDialect>();
+ }
+
+ MLIRContext context;
+ OpBuilder builder;
+ Location loc;
+
+ /// Create a ModuleOp containing a single CIR FuncOp with the given
+ /// argument types and return type. If \p addBody is true, the
+ /// function gets an entry block ending in a cir.return: with no operand
+ /// when \p retType is void, otherwise returning block argument 0 (so
+ /// non-void bodies need at least one argument, and it should have the
+ /// return type).
+ std::pair<ModuleOp, cir::FuncOp> createFunc(StringRef name,
+ ArrayRef<Type> argTypes,
+ Type retType,
+ bool addBody = true) {
+ auto module = ModuleOp::create(loc);
+ builder.setInsertionPointToEnd(module.getBody());
+
+ auto funcTy = cir::FuncType::get(argTypes, retType);
+ auto funcOp = cir::FuncOp::create(builder, loc, name, funcTy);
+
+ if (addBody) {
+ Block *entry = funcOp.addEntryBlock();
+ builder.setInsertionPointToEnd(entry);
+ if (isa<cir::VoidType>(retType))
+ cir::ReturnOp::create(builder, loc);
+ else
+ cir::ReturnOp::create(builder, loc,
+ mlir::ValueRange{entry->getArgument(0)});
+ }
+
+ return {module, funcOp};
+ }
+
+ /// Create a ModuleOp containing a caller function that calls a
+ /// callee. The caller passes its own block arguments to the callee.
+ struct CallFixture {
+ ModuleOp module;
+ cir::FuncOp callee;
+ cir::FuncOp caller;
+ cir::CallOp callOp;
+ };
+
+ /// Build the module described by CallFixture. The callee is a body-less
+ /// declaration; the caller (named "caller") has the same function type,
+ /// forwards its block arguments to the call, and returns the call's
+ /// result (or nothing when \p retType is void).
+ CallFixture createCallPair(StringRef calleeName, ArrayRef<Type> argTypes,
+ Type retType) {
+ auto module = ModuleOp::create(loc);
+ builder.setInsertionPointToEnd(module.getBody());
+
+ auto funcTy = cir::FuncType::get(argTypes, retType);
+
+ // Callee (declaration only).
+ auto callee = cir::FuncOp::create(builder, loc, calleeName, funcTy);
+
+ // Caller with a body that calls the callee.
+ auto caller = cir::FuncOp::create(builder, loc, "caller", funcTy);
+ Block *entry = caller.addEntryBlock();
+ builder.setInsertionPointToEnd(entry);
+
+ SmallVector<Value> args;
+ for (unsigned i = 0; i < argTypes.size(); ++i)
+ args.push_back(entry->getArgument(i));
+
+ cir::CallOp call;
+ if (isa<cir::VoidType>(retType)) {
+ auto voidTy = cir::VoidType::get(&context);
+ call = cir::CallOp::create(
+ builder, loc, mlir::FlatSymbolRefAttr::get(&context, calleeName),
+ voidTy, args);
+ cir::ReturnOp::create(builder, loc);
+ } else {
+ call = cir::CallOp::create(
+ builder, loc, mlir::FlatSymbolRefAttr::get(&context, calleeName),
+ retType, args);
+ cir::ReturnOp::create(builder, loc, mlir::ValueRange{call.getResult()});
+ }
+
+ return {module, callee, caller, call};
+ }
+};
+
+// ---- rewriteFunctionDefinition tests ----
+
+// Direct classification without coerced types must leave the signature as is.
+TEST_F(CIRABIRewriteTest, DirectPassthrough) {
+  auto sInt32 = cir::IntType::get(&context, 32, true);
+  auto [module, funcOp] = createFunc("f", {sInt32}, sInt32);
+
+  FunctionClassification classification;
+  classification.returnInfo = ArgClassification::getDirect();
+  classification.argInfos.push_back(ArgClassification::getDirect());
+
+  cir::CIRABIRewriteContext abiCtx(module);
+  OpBuilder rewriter(funcOp);
+  LogicalResult status =
+      abiCtx.rewriteFunctionDefinition(funcOp, classification, rewriter);
+  ASSERT_TRUE(succeeded(status));
+
+  auto rewrittenTy = cast<cir::FuncType>(funcOp.getFunctionType());
+  auto inputs = rewrittenTy.getInputs();
+  EXPECT_EQ(inputs.size(), 1u);
+  EXPECT_EQ(inputs[0], sInt32);
+  EXPECT_EQ(rewrittenTy.getReturnType(), sInt32);
+
+  module->erase();
+}
+
+// A Direct return with a coerced type must change the function's return type.
+TEST_F(CIRABIRewriteTest, DirectReturnCoercion) {
+  auto sInt32 = cir::IntType::get(&context, 32, true);
+  auto uInt64 = cir::IntType::get(&context, 64, false);
+  auto [module, funcOp] = createFunc("f", {sInt32}, sInt32);
+
+  FunctionClassification classification;
+  classification.returnInfo = ArgClassification::getDirect(uInt64);
+  classification.argInfos.push_back(ArgClassification::getDirect());
+
+  cir::CIRABIRewriteContext abiCtx(module);
+  OpBuilder rewriter(funcOp);
+  LogicalResult status =
+      abiCtx.rewriteFunctionDefinition(funcOp, classification, rewriter);
+  ASSERT_TRUE(succeeded(status));
+
+  auto rewrittenTy = cast<cir::FuncType>(funcOp.getFunctionType());
+  EXPECT_EQ(rewrittenTy.getReturnType(), uInt64);
+
+  module->erase();
+}
+
+// An Extend argument must be widened in the signature, tagged with
+// llvm.signext, and narrowed back inside the body with an integral cast.
+TEST_F(CIRABIRewriteTest, ExtendArg) {
+  auto sInt8 = cir::IntType::get(&context, 8, true);
+  auto sInt32 = cir::IntType::get(&context, 32, true);
+  auto voidTy = cir::VoidType::get(&context);
+  auto [module, funcOp] = createFunc("f", {sInt8}, voidTy);
+
+  FunctionClassification classification;
+  classification.returnInfo = ArgClassification::getDirect();
+  classification.argInfos.push_back(
+      ArgClassification::getExtend(sInt32, true));
+
+  cir::CIRABIRewriteContext abiCtx(module);
+  OpBuilder rewriter(funcOp);
+  LogicalResult status =
+      abiCtx.rewriteFunctionDefinition(funcOp, classification, rewriter);
+  ASSERT_TRUE(succeeded(status));
+
+  // Signature: one argument, now 32 bits wide.
+  auto rewrittenTy = cast<cir::FuncType>(funcOp.getFunctionType());
+  auto inputs = rewrittenTy.getInputs();
+  EXPECT_EQ(inputs.size(), 1u);
+  EXPECT_EQ(inputs[0], sInt32);
+
+  // Attribute: the single argument carries llvm.signext.
+  auto argAttrs = funcOp->getAttrOfType<ArrayAttr>("arg_attrs");
+  ASSERT_TRUE(static_cast<bool>(argAttrs));
+  ASSERT_EQ(argAttrs.size(), 1u);
+  auto firstArgDict = cast<DictionaryAttr>(argAttrs[0]);
+  EXPECT_TRUE(static_cast<bool>(firstArgDict.get("llvm.signext")));
+
+  // Body: an integral cir.cast brings the i32 back to i8 for body uses.
+  Block &entry = funcOp->getRegion(0).front();
+  bool sawIntegralCast = false;
+  for (cir::CastOp castOp : entry.getOps<cir::CastOp>()) {
+    if (castOp.getKind() != cir::CastKind::integral)
+      continue;
+    EXPECT_EQ(castOp.getResult().getType(), sInt8);
+    sawIntegralCast = true;
+  }
+  EXPECT_TRUE(sawIntegralCast);
+
+  module->erase();
+}
+
+// An Extend return must widen the return type and attach llvm.zeroext when
+// sign extension is not requested.
+TEST_F(CIRABIRewriteTest, ExtendReturn) {
+  auto sInt8 = cir::IntType::get(&context, 8, true);
+  auto sInt32 = cir::IntType::get(&context, 32, true);
+  auto [module, funcOp] = createFunc("f", {sInt8}, sInt8);
+
+  FunctionClassification classification;
+  classification.returnInfo = ArgClassification::getExtend(sInt32, false);
+  classification.argInfos.push_back(ArgClassification::getDirect());
+
+  cir::CIRABIRewriteContext abiCtx(module);
+  OpBuilder rewriter(funcOp);
+  LogicalResult status =
+      abiCtx.rewriteFunctionDefinition(funcOp, classification, rewriter);
+  ASSERT_TRUE(succeeded(status));
+
+  auto rewrittenTy = cast<cir::FuncType>(funcOp.getFunctionType());
+  EXPECT_EQ(rewrittenTy.getReturnType(), sInt32);
+
+  // The single result carries llvm.zeroext.
+  auto resAttrs = funcOp->getAttrOfType<ArrayAttr>("res_attrs");
+  ASSERT_TRUE(static_cast<bool>(resAttrs));
+  ASSERT_EQ(resAttrs.size(), 1u);
+  auto resultDict = cast<DictionaryAttr>(resAttrs[0]);
+  EXPECT_TRUE(static_cast<bool>(resultDict.get("llvm.zeroext")));
+
+  module->erase();
+}
+
+// Return value classified as Ignore: the rewritten function type must have
+// cir::VoidType as its return type.
+TEST_F(CIRABIRewriteTest, IgnoreReturn) {
+  auto intTy = cir::IntType::get(&context, 32, true);
+  auto [module, funcOp] = createFunc("f", {intTy}, intTy);
+
+  // The argument stays Direct; only the return value is dropped.
+  FunctionClassification fc;
+  fc.argInfos.push_back(ArgClassification::getDirect());
+  fc.returnInfo = ArgClassification::getIgnore();
+
+  cir::CIRABIRewriteContext rewriteCtx(module);
+  OpBuilder rewriter(funcOp);
+  auto status = rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter);
+  ASSERT_TRUE(succeeded(status));
+
+  auto rewrittenTy = cast<cir::FuncType>(funcOp.getFunctionType());
+  EXPECT_EQ(rewrittenTy.getReturnType(), cir::VoidType::get(&context));
+
+  module->erase();
+}
+
+// Argument classified as Ignore: after the rewrite neither the function type
+// nor the entry block may list any argument.
+TEST_F(CIRABIRewriteTest, IgnoreArg) {
+  auto intTy = cir::IntType::get(&context, 32, true);
+  auto [module, funcOp] =
+      createFunc("f", {intTy}, cir::VoidType::get(&context));
+
+  // The sole argument is ignored; the return is classified Direct.
+  FunctionClassification fc;
+  fc.argInfos.push_back(ArgClassification::getIgnore());
+  fc.returnInfo = ArgClassification::getDirect();
+
+  cir::CIRABIRewriteContext rewriteCtx(module);
+  OpBuilder rewriter(funcOp);
+  auto status = rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter);
+  ASSERT_TRUE(succeeded(status));
+
+  // Signature side: no inputs remain.
+  auto rewrittenTy = cast<cir::FuncType>(funcOp.getFunctionType());
+  EXPECT_EQ(rewrittenTy.getInputs().size(), 0u);
+
+  // Body side: the entry block lost its block argument as well.
+  Block &entryBlock = funcOp->getRegion(0).front();
+  EXPECT_EQ(entryBlock.getNumArguments(), 0u);
+
+  module->erase();
+}
+
+// Rewrites a body-less function (createFunc with addBody=false) whose
+// argument and return value are both classified Extend(i32Ty, true), then
+// checks the widened signature and the llvm.signext attribute on both the
+// argument and the result.
+TEST_F(CIRABIRewriteTest, DeclarationRewrite) {
+  auto i8Ty = cir::IntType::get(&context, 8, true);
+  auto i32Ty = cir::IntType::get(&context, 32, true);
+  auto [module, funcOp] = createFunc("f", {i8Ty}, i8Ty, /*addBody=*/false);
+
+  FunctionClassification fc;
+  fc.returnInfo = ArgClassification::getExtend(i32Ty, true);
+  fc.argInfos.push_back(ArgClassification::getExtend(i32Ty, true));
+
+  cir::CIRABIRewriteContext rewriteCtx(module);
+  OpBuilder rewriter(funcOp);
+  ASSERT_TRUE(
+      succeeded(rewriteCtx.rewriteFunctionDefinition(funcOp, fc, rewriter)));
+
+  auto fnTy = cast<cir::FuncType>(funcOp.getFunctionType());
+  // Abort before indexing if the input list does not hold exactly one entry;
+  // an out-of-range access would crash the test instead of failing it.
+  ASSERT_EQ(fnTy.getInputs().size(), 1u);
+  EXPECT_EQ(fnTy.getInputs()[0], i32Ty);
+  EXPECT_EQ(fnTy.getReturnType(), i32Ty);
+
+  // Verify both signext attributes. Each attribute array is size-checked
+  // before element 0 is read, matching the sibling Extend tests.
+  auto argAttrs = funcOp->getAttrOfType<ArrayAttr>("arg_attrs");
+  ASSERT_TRUE(argAttrs != nullptr);
+  ASSERT_EQ(argAttrs.size(), 1u);
+  auto dict = cast<DictionaryAttr>(argAttrs[0]);
+  EXPECT_TRUE(dict.get("llvm.signext") != nullptr);
+
+  auto resAttrs = funcOp->getAttrOfType<ArrayAttr>("res_attrs");
+  ASSERT_TRUE(resAttrs != nullptr);
+  ASSERT_EQ(resAttrs.size(), 1u);
+  auto rdict = cast<DictionaryAttr>(resAttrs[0]);
+  EXPECT_TRUE(rdict.get("llvm.signext") != nullptr);
+
+  module->erase();
+}
+
+// ---- rewriteCallSite tests ----
+
+// Return and argument both classified Direct: the call-site rewrite must
+// succeed and the call held by the fixture must still expose one result of
+// the original integer type.
+TEST_F(CIRABIRewriteTest, CallSiteDirectPassthrough) {
+  auto intTy = cir::IntType::get(&context, 32, true);
+  auto fixture = createCallPair("callee", {intTy}, intTy);
+
+  FunctionClassification fc;
+  fc.argInfos.push_back(ArgClassification::getDirect());
+  fc.returnInfo = ArgClassification::getDirect();
+
+  cir::CIRABIRewriteContext rewriteCtx(fixture.module);
+  OpBuilder rewriter(fixture.callOp);
+  auto status = rewriteCtx.rewriteCallSite(fixture.callOp, fc, rewriter);
+  ASSERT_TRUE(succeeded(status));
+
+  // Nothing needed rewriting, so the fixture's call is inspected directly
+  // rather than searching the caller for a replacement.
+  auto call = fixture.callOp;
+  EXPECT_EQ(call->getNumResults(), 1u);
+  EXPECT_EQ(call->getResult(0).getType(), intTy);
+
+  fixture.module->erase();
+}
+
+// Argument classified as Extend(i32Ty, true) at a call site: the call found
+// in the caller after the rewrite must pass exactly one operand, and that
+// operand must have the extended type.
+TEST_F(CIRABIRewriteTest, CallSiteExtendArg) {
+  auto i8Ty = cir::IntType::get(&context, 8, true);
+  auto i32Ty = cir::IntType::get(&context, 32, true);
+  auto voidTy = cir::VoidType::get(&context);
+  auto fixture = createCallPair("callee", {i8Ty}, voidTy);
+
+  FunctionClassification fc;
+  fc.returnInfo = ArgClassification::getDirect();
+  fc.argInfos.push_back(ArgClassification::getExtend(i32Ty, true));
+
+  cir::CIRABIRewriteContext rewriteCtx(fixture.module);
+  OpBuilder rewriter(fixture.callOp);
+  ASSERT_TRUE(
+      succeeded(rewriteCtx.rewriteCallSite(fixture.callOp, fc, rewriter)));
+
+  // The old call was erased and replaced. Look for a CallOp whose
+  // argument is i32 (the extended type). The loop keeps the last CallOp in
+  // the caller's entry block.
+  Block &callerEntry = fixture.caller->getRegion(0).front();
+  cir::CallOp newCall;
+  for (Operation &op : callerEntry)
+    if (auto c = dyn_cast<cir::CallOp>(op))
+      newCall = c;
+  ASSERT_TRUE(newCall != nullptr);
+  // Abort before indexing if the operand count is wrong; reading operand 0
+  // of an argument-less call would crash the test instead of failing it.
+  ASSERT_EQ(newCall.getArgOperands().size(), 1u);
+  EXPECT_EQ(newCall.getArgOperands()[0].getType(), i32Ty);
+
+  fixture.module->erase();
+}
+
+// Return classified as Ignore at a call site: the call found in the caller
+// after the rewrite must produce no results.
+TEST_F(CIRABIRewriteTest, CallSiteIgnoreReturn) {
+  auto intTy = cir::IntType::get(&context, 32, true);
+  auto fixture = createCallPair("callee", {intTy}, intTy);
+
+  FunctionClassification fc;
+  fc.argInfos.push_back(ArgClassification::getDirect());
+  fc.returnInfo = ArgClassification::getIgnore();
+
+  cir::CIRABIRewriteContext rewriteCtx(fixture.module);
+  OpBuilder rewriter(fixture.callOp);
+  auto status = rewriteCtx.rewriteCallSite(fixture.callOp, fc, rewriter);
+  ASSERT_TRUE(succeeded(status));
+
+  // Keep the last cir::CallOp in the caller's entry block; that is the
+  // replacement call the assertion below inspects.
+  cir::CallOp rewrittenCall;
+  Block &callerBody = fixture.caller->getRegion(0).front();
+  for (cir::CallOp call : callerBody.getOps<cir::CallOp>())
+    rewrittenCall = call;
+  ASSERT_TRUE(rewrittenCall != nullptr);
+  EXPECT_EQ(rewrittenCall.getNumResults(), 0u);
+
+  fixture.module->erase();
+}
+
+// Argument classified as Ignore at a call site: the call found in the caller
+// after the rewrite must pass no operands.
+TEST_F(CIRABIRewriteTest, CallSiteIgnoreArg) {
+  auto intTy = cir::IntType::get(&context, 32, true);
+  auto fixture =
+      createCallPair("callee", {intTy}, cir::VoidType::get(&context));
+
+  FunctionClassification fc;
+  fc.argInfos.push_back(ArgClassification::getIgnore());
+  fc.returnInfo = ArgClassification::getDirect();
+
+  cir::CIRABIRewriteContext rewriteCtx(fixture.module);
+  OpBuilder rewriter(fixture.callOp);
+  auto status = rewriteCtx.rewriteCallSite(fixture.callOp, fc, rewriter);
+  ASSERT_TRUE(succeeded(status));
+
+  // Keep the last cir::CallOp in the caller's entry block; that is the
+  // replacement call the assertion below inspects.
+  cir::CallOp rewrittenCall;
+  Block &callerBody = fixture.caller->getRegion(0).front();
+  for (cir::CallOp call : callerBody.getOps<cir::CallOp>())
+    rewrittenCall = call;
+  ASSERT_TRUE(rewrittenCall != nullptr);
+  EXPECT_EQ(rewrittenCall.getArgOperands().size(), 0u);
+
+  fixture.module->erase();
+}
+
+} // namespace
diff --git a/clang/unittests/CIR/CMakeLists.txt b/clang/unittests/CIR/CMakeLists.txt
index 650fde38c48a9..d318810d33fe5 100644
--- a/clang/unittests/CIR/CMakeLists.txt
+++ b/clang/unittests/CIR/CMakeLists.txt
@@ -1,15 +1,20 @@
set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include )
set(MLIR_TABLEGEN_OUTPUT_DIR ${CMAKE_BINARY_DIR}/tools/mlir/include)
+set(CLANG_CIR_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../lib/CIR)
include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
include_directories(${MLIR_TABLEGEN_OUTPUT_DIR})
+include_directories(${CLANG_CIR_SRC_DIR})
add_distinct_clang_unittest(CIRUnitTests
+ CIRABIRewriteContextTest.cpp
PointerLikeTest.cpp
LLVM_COMPONENTS
Core
LINK_LIBS
+ MLIRABI
MLIRCIR
+ MLIRCIRTargetLowering
CIROpenACCSupport
MLIRIR
MLIROpenACCDialect
More information about the Mlir-commits
mailing list