[Mlir-commits] [mlir] [mlir][ABI] Add ABITypeMapper and ABIRewriteContext (PR #190661)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 6 12:39:17 PDT 2026
https://github.com/adams381 created https://github.com/llvm/llvm-project/pull/190661
Add ABITypeMapper and ABIRewriteContext as the dialect-agnostic
bridge between MLIR dialects and the LLVM ABI Lowering Library.
ABITypeMapper maps MLIR built-in types (integer, float, vector,
index, memref) to abi::Type* using DataLayout for sizes and
alignment. Dialect-specific types fall back to integer mapping
via DataLayoutTypeInterface.
ABIRewriteContext defines the abstract interface that each dialect
(CIR, FIR) implements to rewrite function definitions and call
sites after ABI classification. See the CIR ABI lowering design
document (clang/docs/ClangIRABILowering.md, Section 4) for the
architectural context.
Unit tests for both components (18 test cases).
>From e0ca38e8d5ecdfe0161c0f73af08d65e264953bf Mon Sep 17 00:00:00 2001
From: Adam Smith <adams at nvidia.com>
Date: Mon, 6 Apr 2026 12:14:06 -0700
Subject: [PATCH] [mlir][ABI] Add ABITypeMapper and ABIRewriteContext
Dialect-agnostic layer bridging MLIR and the LLVM ABI library.
ABITypeMapper handles built-in MLIR types; ABIRewriteContext is
the interface dialects implement for ABI rewrites (see
clang/docs/ClangIRABILowering.md Section 4).
Made-with: Cursor
---
mlir/include/mlir/ABI/ABIRewriteContext.h | 160 +++++++++++++++++
mlir/include/mlir/ABI/ABITypeMapper.h | 66 +++++++
mlir/lib/ABI/ABITypeMapper.cpp | 103 +++++++++++
mlir/lib/ABI/CMakeLists.txt | 14 ++
mlir/lib/CMakeLists.txt | 1 +
mlir/unittests/ABI/ABIRewriteContextTest.cpp | 101 +++++++++++
mlir/unittests/ABI/ABITypeMapperTest.cpp | 173 +++++++++++++++++++
mlir/unittests/ABI/CMakeLists.txt | 12 ++
mlir/unittests/CMakeLists.txt | 1 +
9 files changed, 631 insertions(+)
create mode 100644 mlir/include/mlir/ABI/ABIRewriteContext.h
create mode 100644 mlir/include/mlir/ABI/ABITypeMapper.h
create mode 100644 mlir/lib/ABI/ABITypeMapper.cpp
create mode 100644 mlir/lib/ABI/CMakeLists.txt
create mode 100644 mlir/unittests/ABI/ABIRewriteContextTest.cpp
create mode 100644 mlir/unittests/ABI/ABITypeMapperTest.cpp
create mode 100644 mlir/unittests/ABI/CMakeLists.txt
diff --git a/mlir/include/mlir/ABI/ABIRewriteContext.h b/mlir/include/mlir/ABI/ABIRewriteContext.h
new file mode 100644
index 0000000000000..ff589d0d0bd1b
--- /dev/null
+++ b/mlir/include/mlir/ABI/ABIRewriteContext.h
@@ -0,0 +1,160 @@
+//===- ABIRewriteContext.h - Dialect-specific ABI rewriting -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines ABIRewriteContext, the abstract interface for dialect-
+// specific ABI lowering rewrites. Each MLIR dialect that wants ABI lowering
+// (CIR, FIR, etc.) provides a concrete subclass.
+//
+// ABIRewriteContext consumes ABI classification results and drives the
+// creation of lowered function signatures, argument coercions, and call
+// site rewrites using dialect-specific operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ABI_ABIREWRITECONTEXT_H
+#define MLIR_ABI_ABIREWRITECONTEXT_H
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "llvm/Support/Alignment.h"
+
+namespace mlir {
+namespace abi {
+
+/// Classification of how a single argument or return value should be
+/// passed at the ABI level.
+///
+/// This is a dialect-agnostic representation. It mirrors the kinds
+/// found in the LLVM ABI library and in CIR's ABIArgInfo, but does
+/// not depend on either.
+enum class ArgKind : uint8_t {
+ /// Pass directly in registers; ArgClassification::CoercedType may coerce it.
+ Direct,
+
+ /// Like Direct, plus a sign/zero extension (ArgClassification::SignExtend).
+ Extend,
+
+ /// Pass via pointer (sret for returns, byval for args); see IndirectAlign.
+ Indirect,
+
+ /// Ignore (void return, empty struct).
+ Ignore,
+
+ /// Expand an aggregate into its constituent scalar fields.
+ Expand,
+};
+
+/// Describes how a single argument or return value is passed after ABI
+/// lowering.
+///
+/// Each "For <Kind>:" field below is only meaningful for that kind. The
+/// factory helpers set the fields that belong to the kind they create and
+/// leave every other field at its default.
+struct ArgClassification {
+  /// Which passing convention applies.
+  ArgKind Kind = ArgKind::Direct;
+
+  /// The ABI-coerced type, if different from the original. Null means
+  /// use the original type.
+  Type CoercedType = nullptr;
+
+  /// For Indirect: alignment of the pointed-to object.
+  llvm::Align IndirectAlign = llvm::Align(1);
+
+  /// For Extend: whether to sign-extend (true) or zero-extend (false).
+  bool SignExtend = false;
+
+  /// For Direct: whether a struct coercion can be flattened into
+  /// individual register-width arguments.
+  bool CanFlatten = true;
+
+  /// For Indirect: whether the callee gets ownership (byval).
+  bool ByVal = false;
+
+  /// Create a Direct classification.
+  /// \param coerced    Type to coerce to; null keeps the original type.
+  /// \param canFlatten Value stored in CanFlatten. Defaults to true, which
+  ///                   is what the field was always left at previously.
+  static ArgClassification getDirect(Type coerced = nullptr,
+                                     bool canFlatten = true) {
+    ArgClassification c;
+    c.Kind = ArgKind::Direct;
+    c.CoercedType = coerced;
+    c.CanFlatten = canFlatten;
+    return c;
+  }
+
+  /// Create an Ignore classification (all other fields keep their defaults).
+  static ArgClassification getIgnore() {
+    ArgClassification c;
+    c.Kind = ArgKind::Ignore;
+    return c;
+  }
+
+  /// Create an Indirect classification.
+  /// \param align Alignment stored in IndirectAlign.
+  /// \param byVal Value stored in ByVal; note that it defaults to true here
+  ///              although the field itself defaults to false.
+  static ArgClassification getIndirect(llvm::Align align,
+                                       bool byVal = true) {
+    ArgClassification c;
+    c.Kind = ArgKind::Indirect;
+    c.IndirectAlign = align;
+    c.ByVal = byVal;
+    return c;
+  }
+
+  /// Create an Extend classification.
+  /// \param coerced Type stored in CoercedType.
+  /// \param signExt true for sign extension, false for zero extension.
+  static ArgClassification getExtend(Type coerced, bool signExt) {
+    ArgClassification c;
+    c.Kind = ArgKind::Extend;
+    c.CoercedType = coerced;
+    c.SignExtend = signExt;
+    return c;
+  }
+
+  /// Create an Expand classification. Provided so that every ArgKind has a
+  /// matching factory; all other fields keep their defaults.
+  static ArgClassification getExpand() {
+    ArgClassification c;
+    c.Kind = ArgKind::Expand;
+    return c;
+  }
+};
+
+/// Holds the full ABI classification for a function: one entry for the
+/// return value plus one entry per argument.
+struct FunctionClassification {
+ ArgClassification ReturnInfo; ///< Classification of the return value.
+ SmallVector<ArgClassification> ArgInfos; ///< Classifications of the arguments.
+};
+
+/// ABIRewriteContext is the abstract interface that each dialect
+/// implements to perform ABI-specific rewrites on its operations.
+///
+/// The pass orchestrator calls these methods after ABI classification
+/// to rewrite function definitions and call sites.
+class ABIRewriteContext {
+public:
+ virtual ~ABIRewriteContext() = default;
+
+ /// Rewrite a function definition to use ABI-lowered types.
+ ///
+ /// Implementations are expected to create a function with the lowered
+ /// signature, adapt the body between the ABI types and the original
+ /// high-level types, and replace the original function.
+ ///
+ /// \param funcOp The function to rewrite (via FunctionOpInterface).
+ /// \param fc The ABI classification for this function.
+ /// \param rewriter The OpBuilder to use for modifications.
+ /// \returns success() if the function was rewritten.
+ virtual LogicalResult
+ rewriteFunctionDefinition(FunctionOpInterface funcOp,
+ const FunctionClassification &fc,
+ OpBuilder &rewriter) = 0;
+
+ /// Rewrite a call operation to match the callee's ABI-lowered
+ /// signature.
+ ///
+ /// Implementations are expected to coerce arguments, handle indirect
+ /// returns (sret), and adapt the result back to the high-level type.
+ ///
+ /// \param callOp The call operation to rewrite.
+ /// \param fc The ABI classification for the callee.
+ /// \param rewriter The OpBuilder to use for modifications.
+ /// \returns success() if the call was rewritten.
+ virtual LogicalResult rewriteCallSite(Operation *callOp,
+ const FunctionClassification &fc,
+ OpBuilder &rewriter) = 0;
+
+ /// Return the dialect namespace this context handles (e.g. "cir").
+ virtual StringRef getDialectNamespace() const = 0;
+};
+
+} // namespace abi
+} // namespace mlir
+
+#endif // MLIR_ABI_ABIREWRITECONTEXT_H
diff --git a/mlir/include/mlir/ABI/ABITypeMapper.h b/mlir/include/mlir/ABI/ABITypeMapper.h
new file mode 100644
index 0000000000000..2180c9c8a918d
--- /dev/null
+++ b/mlir/include/mlir/ABI/ABITypeMapper.h
@@ -0,0 +1,66 @@
+//===- ABITypeMapper.h - Map MLIR types to ABI types -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines ABITypeMapper, which translates mlir::Type instances into
+// the llvm::abi::Type hierarchy defined in llvm/ABI/Types.h. Dialect-specific
+// types are handled via MLIR's DataLayoutTypeInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ABI_ABITYPEMAPPER_H
+#define MLIR_ABI_ABITYPEMAPPER_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "llvm/ABI/Types.h"
+#include "llvm/Support/Allocator.h"
+
+namespace mlir {
+namespace abi {
+
+/// ABITypeMapper translates mlir::Type values into the llvm::abi::Type
+/// hierarchy used by the LLVM ABI Lowering Library.
+///
+/// Standard MLIR types (IntegerType, FloatType, IndexType, VectorType,
+/// MemRefType, NoneType) are mapped directly. Dialect-specific types are
+/// mapped by querying the MLIR DataLayout for size and alignment.
+///
+/// Callers must supply a DataLayout (typically from the enclosing module)
+/// so the mapper can determine sizes and alignments. The DataLayout is held
+/// by reference and must therefore outlive the mapper.
+///
+/// The mapper owns a BumpPtrAllocator; all returned abi::Type pointers
+/// are valid for the lifetime of the mapper.
+class ABITypeMapper {
+public:
+  /// \param dl Data layout used for every size/alignment query. Only a
+  ///           reference is stored; do not pass a temporary.
+  explicit ABITypeMapper(const DataLayout &dl);
+
+  // Builder is constructed from this object's Allocator, and the types handed
+  // out are documented to live exactly as long as the mapper. An implicit
+  // copy or move would leave the new object's Builder tied to the old
+  // object's Allocator, so both are disabled (deleting the copy operations
+  // also suppresses the implicit move operations).
+  ABITypeMapper(const ABITypeMapper &) = delete;
+  ABITypeMapper &operator=(const ABITypeMapper &) = delete;
+
+  /// Map an MLIR type to its ABI type representation. Returns nullptr
+  /// if the type cannot be mapped.
+  const llvm::abi::Type *map(mlir::Type type);
+
+  /// Access the underlying TypeBuilder for advanced use.
+  llvm::abi::TypeBuilder &getTypeBuilder() { return Builder; }
+
+private:
+  // One helper per directly supported builtin type; map() dispatches here.
+  const llvm::abi::Type *mapIntegerType(mlir::IntegerType type);
+  const llvm::abi::Type *mapFloatType(mlir::FloatType type);
+  const llvm::abi::Type *mapIndexType(mlir::IndexType type);
+  const llvm::abi::Type *mapVectorType(mlir::VectorType type);
+  const llvm::abi::Type *mapMemRefType(mlir::MemRefType type);
+  const llvm::abi::Type *mapNoneType(mlir::NoneType type);
+
+  /// Caller-owned data layout (not copied).
+  const DataLayout &DL;
+  /// Must stay declared before Builder: the constructor passes it to Builder,
+  /// and members are initialized in declaration order.
+  llvm::BumpPtrAllocator Allocator;
+  llvm::abi::TypeBuilder Builder;
+};
+
+} // namespace abi
+} // namespace mlir
+
+#endif // MLIR_ABI_ABITYPEMAPPER_H
diff --git a/mlir/lib/ABI/ABITypeMapper.cpp b/mlir/lib/ABI/ABITypeMapper.cpp
new file mode 100644
index 0000000000000..cfffae7b746f8
--- /dev/null
+++ b/mlir/lib/ABI/ABITypeMapper.cpp
@@ -0,0 +1,103 @@
+//===- ABITypeMapper.cpp - Map MLIR types to ABI types --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ABI/ABITypeMapper.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/Support/Alignment.h"
+
+using namespace mlir;
+using namespace mlir::abi;
+
+/// Bind the mapper to \p dl (kept by reference) and hand the mapper's own
+/// allocator to the type builder. Allocator is declared before Builder in
+/// the class, so it is already constructed when Builder receives it.
+ABITypeMapper::ABITypeMapper(const DataLayout &dl)
+    : DL(dl), Allocator(), Builder(Allocator) {}
+
+/// Map \p type to its llvm::abi::Type.
+///
+/// Builtin integer, float, index, vector, memref and none types are sent to
+/// their dedicated helper. Any other type is mapped to an unsigned integer
+/// of the size/alignment reported by the data layout, provided the type
+/// implements DataLayoutTypeInterface and has a fixed size.
+///
+/// \returns the mapped type, or nullptr if \p type is null or cannot be
+///          mapped.
+const llvm::abi::Type *ABITypeMapper::map(mlir::Type type) {
+  // dyn_cast<> must not be handed a null type.
+  if (!type)
+    return nullptr;
+
+  if (auto intTy = dyn_cast<mlir::IntegerType>(type))
+    return mapIntegerType(intTy);
+
+  if (auto floatTy = dyn_cast<mlir::FloatType>(type))
+    return mapFloatType(floatTy);
+
+  if (auto indexTy = dyn_cast<mlir::IndexType>(type))
+    return mapIndexType(indexTy);
+
+  if (auto vecTy = dyn_cast<mlir::VectorType>(type))
+    return mapVectorType(vecTy);
+
+  if (auto memRefTy = dyn_cast<mlir::MemRefType>(type))
+    return mapMemRefType(memRefTy);
+
+  if (auto noneTy = dyn_cast<mlir::NoneType>(type))
+    return mapNoneType(noneTy);
+
+  // For dialect-specific types, fall back to DataLayout queries. The type
+  // must implement DataLayoutTypeInterface for this to work, so honour the
+  // documented "nullptr if the type cannot be mapped" contract instead of
+  // issuing a query the data layout cannot answer.
+  if (!isa<DataLayoutTypeInterface>(type))
+    return nullptr;
+
+  // getFixedValue() is only valid for fixed sizes; a scalable size has no
+  // integer representation here.
+  llvm::TypeSize sizeInBits = DL.getTypeSizeInBits(type);
+  if (sizeInBits.isScalable())
+    return nullptr;
+
+  uint64_t abiAlign = DL.getTypeABIAlignment(type);
+  return Builder.getIntegerType(sizeInBits.getFixedValue(),
+                                llvm::Align(abiAlign),
+                                /*Signed=*/false);
+}
+
+/// Map a builtin integer to an ABI integer of the same bit width, using the
+/// ABI alignment reported by the data layout.
+const llvm::abi::Type *
+ABITypeMapper::mapIntegerType(mlir::IntegerType type) {
+  // Signless integers are classified as signed: most C/C++ integer types
+  // are signless in MLIR yet behave as signed for ABI purposes (sign
+  // extension, etc.). Hence only explicitly unsigned integers are unsigned.
+  const bool treatAsSigned = !type.isUnsigned();
+  const uint64_t bitWidth = type.getWidth();
+  const llvm::Align alignment(DL.getTypeABIAlignment(type));
+  return Builder.getIntegerType(bitWidth, alignment, treatAsSigned);
+}
+
+/// Map a builtin float to an ABI float with the same float semantics and
+/// the ABI alignment reported by the data layout.
+const llvm::abi::Type *ABITypeMapper::mapFloatType(mlir::FloatType type) {
+  const llvm::Align alignment(DL.getTypeABIAlignment(type));
+  return Builder.getFloatType(type.getFloatSemantics(), alignment);
+}
+
+/// Map `index` to an unsigned ABI integer whose width and alignment are
+/// whatever the data layout reports for the index type.
+const llvm::abi::Type *ABITypeMapper::mapIndexType(mlir::IndexType type) {
+  const uint64_t widthInBits = DL.getTypeSizeInBits(type).getFixedValue();
+  const llvm::Align alignment(DL.getTypeABIAlignment(type));
+  return Builder.getIntegerType(widthInBits, alignment, /*Signed=*/false);
+}
+
+/// Map a builtin vector to a one-dimensional ABI vector.
+///
+/// The element type is mapped recursively through map(). The shape is
+/// flattened by multiplying all dimensions; if one dimension is scalable the
+/// resulting element count is marked scalable as well.
+///
+/// \returns the mapped vector type, or nullptr if the element type cannot
+///          be mapped or the shape has more than one scalable dimension.
+const llvm::abi::Type *ABITypeMapper::mapVectorType(mlir::VectorType type) {
+  const llvm::abi::Type *elementTy = map(type.getElementType());
+  if (!elementTy)
+    return nullptr;
+
+  // An ElementCount carries a single "scalable" flag, i.e. at most one
+  // vscale factor. Flattening a shape with several scalable dimensions would
+  // need vscale raised to a power, which cannot be expressed.
+  const auto numScalableDims = llvm::count(type.getScalableDims(), true);
+  if (numScalableDims > 1)
+    return nullptr;
+
+  // Flatten the (possibly multi-dimensional) shape to a single dimension for
+  // ABI purposes. For a scalable vector this is the known-minimum count.
+  uint64_t totalElements = 1;
+  for (int64_t dim : type.getShape())
+    totalElements *= dim;
+
+  llvm::ElementCount ec = llvm::ElementCount::get(
+      totalElements, /*Scalable=*/numScalableDims == 1);
+  uint64_t abiAlign = DL.getTypeABIAlignment(type);
+  return Builder.getVectorType(elementTy, ec, llvm::Align(abiAlign));
+}
+
+/// Map a memref to an ABI pointer type (memrefs are treated as pointer-like
+/// for ABI purposes). Size and alignment come from the data layout.
+///
+/// NOTE(review): this relies on the DataLayout being able to report a size
+/// and alignment for MemRefType -- confirm that holds for the layouts used.
+const llvm::abi::Type *ABITypeMapper::mapMemRefType(mlir::MemRefType type) {
+  // Address space 0 unless the memory space is present and is an integer
+  // attribute; any other kind of memory-space attribute is ignored.
+  unsigned addressSpace = 0;
+  if (auto spaceAttr = dyn_cast_if_present<IntegerAttr>(type.getMemorySpace()))
+    addressSpace = spaceAttr.getInt();
+
+  const uint64_t widthInBits = DL.getTypeSizeInBits(type).getFixedValue();
+  const llvm::Align alignment(DL.getTypeABIAlignment(type));
+  return Builder.getPointerType(widthInBits, alignment, addressSpace);
+}
+
+/// Map MLIR's NoneType to the ABI void type. The argument only selects this
+/// overload in map(); its value is not inspected.
+const llvm::abi::Type *ABITypeMapper::mapNoneType(mlir::NoneType /*type*/) {
+  return Builder.getVoidType();
+}
diff --git a/mlir/lib/ABI/CMakeLists.txt b/mlir/lib/ABI/CMakeLists.txt
new file mode 100644
index 0000000000000..eb434d25dd390
--- /dev/null
+++ b/mlir/lib/ABI/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_library(MLIRABI
+ ABITypeMapper.cpp # Only source file; ABIRewriteContext.h is header-only.
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/ABI
+
+ LINK_COMPONENTS
+ ABI # Presumably the LLVM component behind llvm/ABI/Types.h -- confirm.
+ Support
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRDataLayoutInterfaces # ABITypeMapper queries mlir::DataLayout.
+ )
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
index 91ed05f6548d7..d7a6e28d98586 100644
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -1,6 +1,7 @@
# Enable errors for any global constructors.
add_flag_if_supported("-Werror=global-constructors" WERROR_GLOBAL_CONSTRUCTOR)
+add_subdirectory(ABI)
add_subdirectory(Analysis)
add_subdirectory(AsmParser)
add_subdirectory(Bytecode)
diff --git a/mlir/unittests/ABI/ABIRewriteContextTest.cpp b/mlir/unittests/ABI/ABIRewriteContextTest.cpp
new file mode 100644
index 0000000000000..7c035036be108
--- /dev/null
+++ b/mlir/unittests/ABI/ABIRewriteContextTest.cpp
@@ -0,0 +1,101 @@
+//===- ABIRewriteContextTest.cpp - Unit tests for ABIRewriteContext -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ABI/ABIRewriteContext.h"
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace mlir::abi;
+
+namespace {
+
+/// Trivial ABIRewriteContext implementation: both rewrite hooks report
+/// success without doing anything, and the dialect namespace is "mock".
+class MockRewriteContext : public ABIRewriteContext {
+public:
+  LogicalResult rewriteFunctionDefinition(FunctionOpInterface,
+                                          const FunctionClassification &,
+                                          OpBuilder &) override {
+    return success();
+  }
+
+  LogicalResult rewriteCallSite(Operation *, const FunctionClassification &,
+                                OpBuilder &) override {
+    return success();
+  }
+
+  StringRef getDialectNamespace() const override { return "mock"; }
+};
+
+// The interface can be subclassed, instantiated and destroyed.
+TEST(ABIRewriteContextTest, MockCanBeConstructedAndDestroyed) {
+  MockRewriteContext mock;
+  EXPECT_EQ(mock.getDialectNamespace(), "mock");
+}
+
+// getDirect() without arguments: no coerced type, flattening allowed.
+TEST(ABIRewriteContextTest, ArgClassificationDirect) {
+  ArgClassification info = ArgClassification::getDirect();
+  EXPECT_EQ(info.Kind, ArgKind::Direct);
+  EXPECT_EQ(info.CoercedType, nullptr);
+  EXPECT_TRUE(info.CanFlatten);
+}
+
+// getDirect(type) records the coerced type.
+TEST(ABIRewriteContextTest, ArgClassificationDirectWithType) {
+  MLIRContext mlirCtx;
+  auto i32 = IntegerType::get(&mlirCtx, 32);
+  ArgClassification info = ArgClassification::getDirect(i32);
+  EXPECT_EQ(info.Kind, ArgKind::Direct);
+  EXPECT_EQ(info.CoercedType, i32);
+}
+
+TEST(ABIRewriteContextTest, ArgClassificationIgnore) {
+  ArgClassification info = ArgClassification::getIgnore();
+  EXPECT_EQ(info.Kind, ArgKind::Ignore);
+}
+
+// getIndirect() stores the alignment and the byval flag it is given.
+TEST(ABIRewriteContextTest, ArgClassificationIndirect) {
+  ArgClassification info =
+      ArgClassification::getIndirect(llvm::Align(8), /*byVal=*/true);
+  EXPECT_EQ(info.Kind, ArgKind::Indirect);
+  EXPECT_EQ(info.IndirectAlign, llvm::Align(8));
+  EXPECT_TRUE(info.ByVal);
+}
+
+TEST(ABIRewriteContextTest, ArgClassificationIndirectNoByVal) {
+  ArgClassification info =
+      ArgClassification::getIndirect(llvm::Align(16), /*byVal=*/false);
+  EXPECT_EQ(info.Kind, ArgKind::Indirect);
+  EXPECT_EQ(info.IndirectAlign, llvm::Align(16));
+  EXPECT_FALSE(info.ByVal);
+}
+
+// getExtend() distinguishes sign extension from zero extension.
+TEST(ABIRewriteContextTest, ArgClassificationExtend) {
+  MLIRContext mlirCtx;
+  auto i8 = IntegerType::get(&mlirCtx, 8);
+
+  ArgClassification sext = ArgClassification::getExtend(i8, /*signExt=*/true);
+  EXPECT_EQ(sext.Kind, ArgKind::Extend);
+  EXPECT_TRUE(sext.SignExtend);
+
+  ArgClassification zext = ArgClassification::getExtend(i8, /*signExt=*/false);
+  EXPECT_EQ(zext.Kind, ArgKind::Extend);
+  EXPECT_FALSE(zext.SignExtend);
+}
+
+// A FunctionClassification keeps the return info and the pushed arguments.
+TEST(ABIRewriteContextTest, FunctionClassificationHoldsReturnAndArgs) {
+  FunctionClassification classification;
+  classification.ReturnInfo = ArgClassification::getDirect();
+  classification.ArgInfos.push_back(ArgClassification::getDirect());
+  classification.ArgInfos.push_back(
+      ArgClassification::getIndirect(llvm::Align(8), /*byVal=*/true));
+  classification.ArgInfos.push_back(ArgClassification::getIgnore());
+
+  EXPECT_EQ(classification.ReturnInfo.Kind, ArgKind::Direct);
+  EXPECT_EQ(classification.ArgInfos.size(), 3u);
+  EXPECT_EQ(classification.ArgInfos[0].Kind, ArgKind::Direct);
+  EXPECT_EQ(classification.ArgInfos[1].Kind, ArgKind::Indirect);
+  EXPECT_EQ(classification.ArgInfos[2].Kind, ArgKind::Ignore);
+}
+
+} // namespace
diff --git a/mlir/unittests/ABI/ABITypeMapperTest.cpp b/mlir/unittests/ABI/ABITypeMapperTest.cpp
new file mode 100644
index 0000000000000..4a7989298a149
--- /dev/null
+++ b/mlir/unittests/ABI/ABITypeMapperTest.cpp
@@ -0,0 +1,173 @@
+//===- ABITypeMapperTest.cpp - Unit tests for ABITypeMapper ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/ABI/ABITypeMapper.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "llvm/ABI/Types.h"
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace mlir::abi;
+
+namespace {
+
+/// Fixture providing an MLIRContext with the DLTI dialect loaded and an
+/// empty module. Each test builds its DataLayout from that module; no layout
+/// spec is attached, so the expectations below rely on the default layout.
+class ABITypeMapperTest : public ::testing::Test {
+protected:
+  void SetUp() override {
+    ctx.loadDialect<DLTIDialect>();
+    module = ModuleOp::create(UnknownLoc::get(&ctx));
+  }
+
+  // The module is created detached in SetUp(), so it is destroyed manually.
+  void TearDown() override { module->destroy(); }
+
+  MLIRContext ctx;
+  ModuleOp module;
+};
+
+TEST_F(ABITypeMapperTest, MapI32) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto i32 = IntegerType::get(&ctx, 32);
+  const llvm::abi::Type *result = mapper.map(i32);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isInteger());
+
+  auto *intTy = llvm::cast<llvm::abi::IntegerType>(result);
+  EXPECT_EQ(intTy->getSizeInBits().getFixedValue(), 32u);
+  // mapIntegerType() classifies signless integers as signed.
+  EXPECT_TRUE(intTy->isSigned());
+}
+
+TEST_F(ABITypeMapperTest, MapI1) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto i1 = IntegerType::get(&ctx, 1);
+  const llvm::abi::Type *result = mapper.map(i1);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isInteger());
+
+  auto *intTy = llvm::cast<llvm::abi::IntegerType>(result);
+  EXPECT_EQ(intTy->getSizeInBits().getFixedValue(), 1u);
+}
+
+TEST_F(ABITypeMapperTest, MapI64) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto i64 = IntegerType::get(&ctx, 64);
+  const llvm::abi::Type *result = mapper.map(i64);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isInteger());
+
+  auto *intTy = llvm::cast<llvm::abi::IntegerType>(result);
+  EXPECT_EQ(intTy->getSizeInBits().getFixedValue(), 64u);
+}
+
+TEST_F(ABITypeMapperTest, MapF32) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto f32 = Float32Type::get(&ctx);
+  const llvm::abi::Type *result = mapper.map(f32);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isFloat());
+
+  auto *floatTy = llvm::cast<llvm::abi::FloatType>(result);
+  EXPECT_EQ(floatTy->getSizeInBits().getFixedValue(), 32u);
+}
+
+TEST_F(ABITypeMapperTest, MapF64) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto f64 = Float64Type::get(&ctx);
+  const llvm::abi::Type *result = mapper.map(f64);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isFloat());
+
+  auto *floatTy = llvm::cast<llvm::abi::FloatType>(result);
+  EXPECT_EQ(floatTy->getSizeInBits().getFixedValue(), 64u);
+}
+
+TEST_F(ABITypeMapperTest, MapF16) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto f16 = Float16Type::get(&ctx);
+  const llvm::abi::Type *result = mapper.map(f16);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isFloat());
+
+  auto *floatTy = llvm::cast<llvm::abi::FloatType>(result);
+  EXPECT_EQ(floatTy->getSizeInBits().getFixedValue(), 16u);
+}
+
+TEST_F(ABITypeMapperTest, MapNoneType) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto none = NoneType::get(&ctx);
+  const llvm::abi::Type *result = mapper.map(none);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isVoid());
+}
+
+TEST_F(ABITypeMapperTest, MapVectorOf4xF32) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto f32 = Float32Type::get(&ctx);
+  auto vec = VectorType::get({4}, f32);
+  const llvm::abi::Type *result = mapper.map(vec);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isVector());
+
+  auto *vecTy = llvm::cast<llvm::abi::VectorType>(result);
+  EXPECT_EQ(vecTy->getNumElements().getFixedValue(), 4u);
+  EXPECT_TRUE(vecTy->getElementType()->isFloat());
+}
+
+// A multi-dimensional vector is flattened: 2 x 4 elements become 8.
+TEST_F(ABITypeMapperTest, MapMultiDimVectorIsFlattened) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto f32 = Float32Type::get(&ctx);
+  auto vec = VectorType::get({2, 4}, f32);
+  const llvm::abi::Type *result = mapper.map(vec);
+
+  ASSERT_NE(result, nullptr);
+  EXPECT_TRUE(result->isVector());
+
+  auto *vecTy = llvm::cast<llvm::abi::VectorType>(result);
+  EXPECT_EQ(vecTy->getNumElements().getFixedValue(), 8u);
+  EXPECT_TRUE(vecTy->getElementType()->isFloat());
+}
+
+TEST_F(ABITypeMapperTest, MapSignedI32) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto si32 = IntegerType::get(&ctx, 32, IntegerType::Signed);
+  const llvm::abi::Type *result = mapper.map(si32);
+
+  ASSERT_NE(result, nullptr);
+  auto *intTy = llvm::cast<llvm::abi::IntegerType>(result);
+  EXPECT_TRUE(intTy->isSigned());
+}
+
+TEST_F(ABITypeMapperTest, MapUnsignedI32) {
+  DataLayout dl(module);
+  ABITypeMapper mapper(dl);
+
+  auto ui32 = IntegerType::get(&ctx, 32, IntegerType::Unsigned);
+  const llvm::abi::Type *result = mapper.map(ui32);
+
+  ASSERT_NE(result, nullptr);
+  auto *intTy = llvm::cast<llvm::abi::IntegerType>(result);
+  EXPECT_FALSE(intTy->isSigned());
+}
+
+} // namespace
diff --git a/mlir/unittests/ABI/CMakeLists.txt b/mlir/unittests/ABI/CMakeLists.txt
new file mode 100644
index 0000000000000..39f955a8efea6
--- /dev/null
+++ b/mlir/unittests/ABI/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_unittest(MLIRABITests
+ ABIRewriteContextTest.cpp
+ ABITypeMapperTest.cpp
+)
+
+mlir_target_link_libraries(MLIRABITests
+ PRIVATE
+ MLIRABI # Library under test.
+ MLIRDataLayoutInterfaces # The tests construct mlir::DataLayout.
+ MLIRDLTIDialect # The ABITypeMapper fixture loads DLTIDialect.
+ MLIRIR
+)
diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt
index 89332bce5fe05..654ec44d90b04 100644
--- a/mlir/unittests/CMakeLists.txt
+++ b/mlir/unittests/CMakeLists.txt
@@ -5,6 +5,7 @@ function(add_mlir_unittest test_dirname)
add_unittest(MLIRUnitTests ${test_dirname} ${ARGN})
endfunction()
+add_subdirectory(ABI)
add_subdirectory(Analysis)
add_subdirectory(Bytecode)
add_subdirectory(Conversion)
More information about the Mlir-commits
mailing list