[llvm-branch-commits] [clang] [mlir] [CIR] Honor Direct coercion offset in callconv (PR #203640)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jun 26 07:27:09 PDT 2026


https://github.com/adams381 updated https://github.com/llvm/llvm-project/pull/203640

>From c1be624a7f13bf3750c0a57280e3b34d63dc0e9b Mon Sep 17 00:00:00 2001
From: Adam Smith <adams at nvidia.com>
Date: Fri, 12 Jun 2026 12:45:57 -0700
Subject: [PATCH] [CIR] Honor Direct coercion offset in callconv

Lowering of a Direct classification with a coerced type assumed the coerced
value started at byte 0 of the original aggregate.  On x86-64 SysV a 16-byte
record whose low eightbyte is NO_CLASS carries its live value in the high
eightbyte and is classified as getDirect(coerceType, offset=8); the
coercion path read and wrote the wrong eightbyte for that shape.

Add a directOffset to ArgClassification (with a getDirect(coerced, offset)
overload).  emitCoercionToMemory now applies the offset to the coerced
(scalar) side of the slot via a u8 ptr_stride before the typed view, so the
aggregate side stays at offset 0 while the scalar is read from / written to
the right bytes.  The offset is threaded through both emitCoercion overloads,
insertReturnCoercion, and the call-site and entry-block Direct arms.  A zero
offset takes the original plain-bitcast path and emits the same IR as before.

The Test target parser gains an optional direct_offset key so cir-opt can
inject this classification; coerce-direct-offset.cir covers the offset-8
return and argument plus an offset-0 negative case.
---
 .../TargetLowering/CIRABIRewriteContext.cpp   | 100 ++++++++++++------
 .../abi-lowering/coerce-direct-offset.cir     |  75 +++++++++++++
 mlir/include/mlir/ABI/ABIRewriteContext.h     |  13 +++
 .../mlir/ABI/Targets/Test/TestTarget.h        |   7 +-
 mlir/lib/ABI/Targets/Test/TestTarget.cpp      |  18 +++-
 5 files changed, 174 insertions(+), 39 deletions(-)
 create mode 100644 clang/test/CIR/Transforms/abi-lowering/coerce-direct-offset.cir

diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
index 2fe54e6d6c0e2..af4d58862bf92 100644
--- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRABIRewriteContext.cpp
@@ -299,12 +299,12 @@ mlir::ArrayAttr updateResAttrs(mlir::MLIRContext *ctx,
 /// Any operations the helper creates are appended to \p createdOps so the
 /// caller can pass them to replaceAllUsesExcept and avoid clobbering the
 /// store's value operand when later rewiring the source value.
-mlir::Value
-emitCoercionToMemory(mlir::OpBuilder &builder, mlir::Location loc,
-                     mlir::Type dstTy, mlir::Value src,
-                     mlir::FunctionOpInterface funcOp,
-                     const mlir::DataLayout &dl,
-                     SmallPtrSetImpl<mlir::Operation *> &createdOps) {
+mlir::Value emitCoercionToMemory(mlir::OpBuilder &builder, mlir::Location loc,
+                                 mlir::Type dstTy, mlir::Value src,
+                                 mlir::FunctionOpInterface funcOp,
+                                 const mlir::DataLayout &dl,
+                                 SmallPtrSetImpl<mlir::Operation *> &createdOps,
+                                 unsigned offset = 0) {
   mlir::Type srcTy = src.getType();
   assert(srcTy != dstTy &&
          "emitCoercion callers must pre-check that the types differ");
@@ -315,6 +315,15 @@ emitCoercionToMemory(mlir::OpBuilder &builder, mlir::Location loc,
   mlir::Type slotTy =
       dl.getTypeSize(srcTy) >= dl.getTypeSize(dstTy) ? srcTy : dstTy;
 
+  // The offset applies to the coerced/scalar side, taken to be the operand
+  // whose type is not the slot type (so this assumes the aggregate is the
+  // larger of the two).  The aggregate side sits at slot offset 0, and the
+  // slot must be large enough to hold the coerced value at the offset.
+  [[maybe_unused]] mlir::Type scalarTy = slotTy == srcTy ? dstTy : srcTy;
+  assert((offset == 0 ||
+          offset + dl.getTypeSize(scalarTy) <= dl.getTypeSize(slotTy)) &&
+         "coerce slot too small for offset access");
+
   auto slotPtrTy = cir::PointerType::get(slotTy);
   auto srcPtrTy = cir::PointerType::get(srcTy);
   auto dstPtrTy = cir::PointerType::get(dstTy);
@@ -330,25 +339,46 @@ emitCoercionToMemory(mlir::OpBuilder &builder, mlir::Location loc,
   }
   createdOps.insert(alloca);
 
+  // Return a pointer to the slot typed as \p wantPtrTy.  For the scalar side
+  // (wantTy != slotTy) with a non-zero offset, first step \p offset bytes via
+  // a u8 ptr_stride and bitcast that; otherwise bitcast the alloca directly,
+  // or return it as-is when wantTy is the slot type (the pre-offset path).
+  auto slotView = [&](mlir::Type wantTy,
+                      cir::PointerType wantPtrTy) -> mlir::Value {
+    bool applyOffset = wantTy != slotTy;
+    mlir::Value base = alloca;
+    if (applyOffset && offset != 0) {
+      auto u8Ty =
+          cir::IntType::get(builder.getContext(), 8, /*isSigned=*/false);
+      auto u8PtrTy = cir::PointerType::get(u8Ty);
+      auto u8Base = cir::CastOp::create(builder, loc, u8PtrTy,
+                                        cir::CastKind::bitcast, alloca);
+      createdOps.insert(u8Base);
+      auto strideTy =
+          cir::IntType::get(builder.getContext(), 64, /*isSigned=*/true);
+      auto strideVal = cir::ConstantOp::create(
+          builder, loc, cir::IntAttr::get(strideTy, offset));
+      createdOps.insert(strideVal);
+      auto gep =
+          cir::PtrStrideOp::create(builder, loc, u8PtrTy, u8Base, strideVal);
+      createdOps.insert(gep);
+      base = gep;
+    } else if (wantTy == slotTy) {
+      return alloca;
+    }
+    auto cast = cir::CastOp::create(builder, loc, wantPtrTy,
+                                    cir::CastKind::bitcast, base);
+    createdOps.insert(cast);
+    return cast;
+  };
+
   // Store through a source-typed view of the slot.
-  mlir::Value srcSlot = alloca;
-  if (slotTy != srcTy) {
-    auto srcCast = cir::CastOp::create(builder, loc, srcPtrTy,
-                                       cir::CastKind::bitcast, alloca);
-    createdOps.insert(srcCast);
-    srcSlot = srcCast;
-  }
+  mlir::Value srcSlot = slotView(srcTy, srcPtrTy);
   auto store = cir::StoreOp::create(builder, loc, src, srcSlot);
   createdOps.insert(store);
 
   // Return a destination-typed view of the slot.
-  if (slotTy != dstTy) {
-    auto dstCast = cir::CastOp::create(builder, loc, dstPtrTy,
-                                       cir::CastKind::bitcast, alloca);
-    createdOps.insert(dstCast);
-    return dstCast;
-  }
-  return alloca;
+  return slotView(dstTy, dstPtrTy);
 }
 
 /// Coerce \p src to type \p dstTy by going through memory and load the whole
@@ -358,9 +388,10 @@ mlir::Value emitCoercion(mlir::OpBuilder &builder, mlir::Location loc,
                          mlir::Type dstTy, mlir::Value src,
                          mlir::FunctionOpInterface funcOp,
                          const mlir::DataLayout &dl,
-                         SmallPtrSetImpl<mlir::Operation *> &createdOps) {
-  mlir::Value dstSlot =
-      emitCoercionToMemory(builder, loc, dstTy, src, funcOp, dl, createdOps);
+                         SmallPtrSetImpl<mlir::Operation *> &createdOps,
+                         unsigned offset = 0) {
+  mlir::Value dstSlot = emitCoercionToMemory(builder, loc, dstTy, src, funcOp,
+                                             dl, createdOps, offset);
   auto load = cir::LoadOp::create(builder, loc, dstSlot);
   createdOps.insert(load);
   return load;
@@ -371,17 +402,18 @@ mlir::Value emitCoercion(mlir::OpBuilder &builder, mlir::Location loc,
 mlir::Value emitCoercion(mlir::OpBuilder &builder, mlir::Location loc,
                          mlir::Type dstTy, mlir::Value src,
                          mlir::FunctionOpInterface funcOp,
-                         const mlir::DataLayout &dl) {
+                         const mlir::DataLayout &dl, unsigned offset = 0) {
   SmallPtrSet<mlir::Operation *, 4> ignored;
-  return emitCoercion(builder, loc, dstTy, src, funcOp, dl, ignored);
+  return emitCoercion(builder, loc, dstTy, src, funcOp, dl, ignored, offset);
 }
 
 /// Insert coercion before each cir.return so the returned value matches the
 /// new (coerced) return type.
+/// \p offset is the byte offset of the coerced value within \p origRetTy.
 void insertReturnCoercion(mlir::FunctionOpInterface funcOp,
                           mlir::Type origRetTy, mlir::Type coercedRetTy,
-                          mlir::OpBuilder &builder,
-                          const mlir::DataLayout &dl) {
+                          mlir::OpBuilder &builder, const mlir::DataLayout &dl,
+                          unsigned offset) {
   SmallVector<cir::ReturnOp> returns;
   funcOp.walk([&](cir::ReturnOp r) { returns.push_back(r); });
   for (cir::ReturnOp r : returns) {
@@ -391,8 +423,8 @@ void insertReturnCoercion(mlir::FunctionOpInterface funcOp,
     if (origVal.getType() == coercedRetTy)
       continue;
     builder.setInsertionPoint(r);
-    mlir::Value coerced =
-        emitCoercion(builder, r.getLoc(), coercedRetTy, origVal, funcOp, dl);
+    mlir::Value coerced = emitCoercion(builder, r.getLoc(), coercedRetTy,
+                                       origVal, funcOp, dl, offset);
     r->setOperand(0, coerced);
   }
 }
@@ -572,8 +604,9 @@ void insertArgCoercion(mlir::FunctionOpInterface funcOp,
 
       builder.setInsertionPointToStart(&entry);
       SmallPtrSet<mlir::Operation *, 4> coercionOps;
-      mlir::Value adapted = emitCoercion(builder, funcOp.getLoc(), oldArgTy,
-                                         blockArg, funcOp, dl, coercionOps);
+      mlir::Value adapted =
+          emitCoercion(builder, funcOp.getLoc(), oldArgTy, blockArg, funcOp, dl,
+                       coercionOps, ac.directOffset);
 
       // Replace blockArg uses with the adapted value, except inside the
       // helper ops we just created.  This is critical: the StoreOp's value
@@ -889,7 +922,7 @@ mlir::LogicalResult CIRABIRewriteContext::rewriteFunctionDefinition(
       if (fc.returnInfo.kind == ArgKind::Direct && fc.returnInfo.coercedType &&
           !oldResultTypes.empty() && fc.returnInfo.coercedType != origRetTy)
         insertReturnCoercion(funcOp, origRetTy, fc.returnInfo.coercedType,
-                             builder, dl);
+                             builder, dl, fc.returnInfo.directOffset);
 
       mlir::Block &entry = body.front();
 
@@ -1097,7 +1130,7 @@ CIRABIRewriteContext::rewriteCallSite(mlir::Operation *callOp,
     } else if (ac.kind == ArgKind::Direct && ac.coercedType &&
                arg.getType() != ac.coercedType) {
       arg = emitCoercion(builder, call.getLoc(), ac.coercedType, arg,
-                         enclosingFunc, dl);
+                         enclosingFunc, dl, ac.directOffset);
       newArgs.push_back(arg);
     } else if (ac.kind == ArgKind::Indirect) {
       // byval and byref: allocate a stack slot, copy the value in, and pass
@@ -1156,7 +1189,7 @@ CIRABIRewriteContext::rewriteCallSite(mlir::Operation *callOp,
     builder.setInsertionPointAfter(newCall);
     mlir::Value coercedBack =
         emitCoercion(builder, call.getLoc(), origRetTy, newCall.getResult(),
-                     enclosingFunc, dl);
+                     enclosingFunc, dl, fc.returnInfo.directOffset);
     call.getResult().replaceAllUsesWith(coercedBack);
   }
 
diff --git a/clang/test/CIR/Transforms/abi-lowering/coerce-direct-offset.cir b/clang/test/CIR/Transforms/abi-lowering/coerce-direct-offset.cir
new file mode 100644
index 0000000000000..b51209fa241b6
--- /dev/null
+++ b/clang/test/CIR/Transforms/abi-lowering/coerce-direct-offset.cir
@@ -0,0 +1,75 @@
+// RUN: cir-opt %s -cir-call-conv-lowering="classification-attr=test_classify" \
+// RUN:   | FileCheck %s
+
+!u8i = !cir.int<u, 8>
+!s32i = !cir.int<s, 32>
+!s64i = !cir.int<s, 64>
+!rec_HiPair = !cir.struct<"HiPair" {!cir.array<!u8i x 8>, !s64i}>
+
+// 16-byte record whose live value is the high eightbyte: coerced to !s64i
+// taken at byte offset 8.
+#hi_return = {
+  return = { kind = "direct", coerced_type = !s64i, direct_offset = 8 : i64 },
+  args   = [ ]
+}
+
+#hi_arg = {
+  return = { kind = "direct" },
+  args   = [ { kind = "direct", coerced_type = !s64i, direct_offset = 8 : i64 } ]
+}
+
+// Offset 0 (no direct_offset): unchanged offset-free lowering, no ptr_stride.
+#zero_return = {
+  return = { kind = "direct", coerced_type = !s64i },
+  args   = [ ]
+}
+
+module attributes {
+  dlti.dl_spec = #dlti.dl_spec<
+    #dlti.dl_entry<i32, dense<32>: vector<2xi64>>,
+    #dlti.dl_entry<i64, dense<64>: vector<2xi64>>>
+} {
+
+  cir.func @returns_hi() -> !rec_HiPair
+      attributes { test_classify = #hi_return } {
+    %0 = cir.const #cir.zero : !rec_HiPair
+    cir.return %0 : !rec_HiPair
+  }
+
+  // Return coerced to !s64i read from byte offset 8 of the record slot.
+  // CHECK:      cir.func{{.*}} @returns_hi() -> !s64i
+  // CHECK:        %[[SLOT:.*]] = cir.alloca "coerce" {{.*}} : !cir.ptr<!rec_HiPair>
+  // CHECK:        cir.store %{{.*}}, %[[SLOT]] : !rec_HiPair, !cir.ptr<!rec_HiPair>
+  // CHECK:        %[[U8:.*]] = cir.cast bitcast %[[SLOT]] : !cir.ptr<!rec_HiPair> -> !cir.ptr<!u8i>
+  // CHECK:        %[[OFF:.*]] = cir.const #cir.int<8> : !s64i
+  // CHECK:        %[[GEP:.*]] = cir.ptr_stride %[[U8]], %[[OFF]] : (!cir.ptr<!u8i>, !s64i) -> !cir.ptr<!u8i>
+  // CHECK:        %[[VCAST:.*]] = cir.cast bitcast %[[GEP]] : !cir.ptr<!u8i> -> !cir.ptr<!s64i>
+  // CHECK:        %[[COERCED:.*]] = cir.load %[[VCAST]] : !cir.ptr<!s64i>, !s64i
+  // CHECK:        cir.return %[[COERCED]] : !s64i
+
+  cir.func @takes_hi(%arg0: !rec_HiPair) -> !s32i
+      attributes { test_classify = #hi_arg } {
+    %0 = cir.const #cir.int<0> : !s32i
+    cir.return %0 : !s32i
+  }
+
+  // Arg arrives as !s64i; reconstructed into the record slot at byte offset 8.
+  // CHECK:      cir.func{{.*}} @takes_hi(%[[A:.*]]: !s64i) -> !s32i
+  // CHECK:        %[[ASLOT:.*]] = cir.alloca "coerce" {{.*}} : !cir.ptr<!rec_HiPair>
+  // CHECK:        %[[AU8:.*]] = cir.cast bitcast %[[ASLOT]] : !cir.ptr<!rec_HiPair> -> !cir.ptr<!u8i>
+  // CHECK:        %[[AOFF:.*]] = cir.const #cir.int<8> : !s64i
+  // CHECK:        %[[AGEP:.*]] = cir.ptr_stride %[[AU8]], %[[AOFF]] : (!cir.ptr<!u8i>, !s64i) -> !cir.ptr<!u8i>
+  // CHECK:        %[[ACAST:.*]] = cir.cast bitcast %[[AGEP]] : !cir.ptr<!u8i> -> !cir.ptr<!s64i>
+  // CHECK:        cir.store %[[A]], %[[ACAST]] : !s64i, !cir.ptr<!s64i>
+
+  cir.func @returns_zero() -> !rec_HiPair
+      attributes { test_classify = #zero_return } {
+    %0 = cir.const #cir.zero : !rec_HiPair
+    cir.return %0 : !rec_HiPair
+  }
+
+  // Offset 0: coercion goes straight through the bitcast view, no ptr_stride.
+  // CHECK:      cir.func{{.*}} @returns_zero() -> !s64i
+  // CHECK-NOT:    cir.ptr_stride
+  // CHECK:        cir.return %{{.*}} : !s64i
+}
diff --git a/mlir/include/mlir/ABI/ABIRewriteContext.h b/mlir/include/mlir/ABI/ABIRewriteContext.h
index 1982110c5ad6e..44d0022cd0459 100644
--- a/mlir/include/mlir/ABI/ABIRewriteContext.h
+++ b/mlir/include/mlir/ABI/ABIRewriteContext.h
@@ -74,6 +74,11 @@ struct ArgClassification {
   /// For Indirect: whether the callee gets ownership (byval).
   bool byVal = false;
 
+  /// For Direct with coercion: the byte offset within the original aggregate
+  /// at which the coerced value lives.  E.g. non-zero on x86-64 SysV when the
+  /// low eightbyte is NO_CLASS and the value is carried in a later eightbyte.
+  unsigned directOffset = 0;
+
   static ArgClassification getDirect(Type coerced = nullptr) {
     ArgClassification c;
     c.kind = ArgKind::Direct;
@@ -81,6 +86,14 @@ struct ArgClassification {
     return c;
   }
 
+  /// Like getDirect(coerced), but the coerced value is read from / written
+  /// to byte \p offset of the original aggregate instead of byte 0.
+  static ArgClassification getDirect(Type coerced, unsigned offset) {
+    ArgClassification c = getDirect(coerced);
+    c.directOffset = offset;
+    return c;
+  }
+
   static ArgClassification getIgnore() {
     ArgClassification c;
     c.kind = ArgKind::Ignore;
diff --git a/mlir/include/mlir/ABI/Targets/Test/TestTarget.h b/mlir/include/mlir/ABI/Targets/Test/TestTarget.h
index 4404d47f8df45..4717c843422e7 100644
--- a/mlir/include/mlir/ABI/Targets/Test/TestTarget.h
+++ b/mlir/include/mlir/ABI/Targets/Test/TestTarget.h
@@ -64,6 +64,9 @@ FunctionClassification classify(ArrayRef<Type> argTypes, Type returnType,
 ///   coerced_type:  TypeAttr.  ABI-coerced type, if different from the
 ///                  original.
 ///   can_flatten:   BoolAttr.  Defaults to true.
+///   direct_offset: IntegerAttr.  Non-negative byte offset within the
+///                  original aggregate at which the coerced value lives.
+///                  Requires coerced_type.  Defaults to 0.
 ///
 /// For kind = "extend" (coerced_type required, sign_extend optional):
 ///   coerced_type:  TypeAttr.  Required; the extended integer type.
@@ -78,8 +81,8 @@ FunctionClassification classify(ArrayRef<Type> argTypes, Type returnType,
 ///
 /// Future schema additions tracked in projects/daily_log.md (Step 0c
 /// field-mapping table).  When we add new fields to ArgClassification
-/// (e.g. direct_offset, extend_kind tristate, indirect_addr_space,
-/// indirect_realign), the corresponding optional keys go here.
+/// (e.g. extend_kind tristate, indirect_addr_space, indirect_realign),
+/// the corresponding optional keys go here.
 ///
 /// Unknown keys cause a parse error (no silent ignore — keeps schema
 /// honest as it grows).
diff --git a/mlir/lib/ABI/Targets/Test/TestTarget.cpp b/mlir/lib/ABI/Targets/Test/TestTarget.cpp
index 51510b0c18009..1c84c2ece696f 100644
--- a/mlir/lib/ABI/Targets/Test/TestTarget.cpp
+++ b/mlir/lib/ABI/Targets/Test/TestTarget.cpp
@@ -125,8 +125,8 @@ namespace {
 /// set causes a parse error (no silent ignore).  Updated when new
 /// optional keys are added to the schema.
 constexpr StringRef knownArgKeys[] = {
-    "kind",        "coerced_type",   "sign_extend",
-    "can_flatten", "indirect_align", "byval",
+    "kind",           "coerced_type", "sign_extend",   "can_flatten",
+    "indirect_align", "byval",        "direct_offset",
 };
 
 bool isKnownArgKey(StringRef key) {
@@ -151,7 +151,7 @@ parseOne(DictionaryAttr argDict, function_ref<InFlightDiagnostic()> emitError) {
       emitError() << "unknown key '" << na.getName().getValue()
                   << "' in classification dictionary; allowed keys are "
                   << "kind, coerced_type, sign_extend, can_flatten, "
-                  << "indirect_align, byval";
+                  << "indirect_align, byval, direct_offset";
       return std::nullopt;
     }
 
@@ -164,6 +164,22 @@ parseOne(DictionaryAttr argDict, function_ref<InFlightDiagnostic()> emitError) {
     auto c = ArgClassification::getDirect(coerced);
     if (auto cf = argDict.getAs<BoolAttr>("can_flatten"))
       c.canFlatten = cf.getValue();
+    if (auto off = argDict.getAs<IntegerAttr>("direct_offset")) {
+      if (!coerced) {
+        emitError() << "'direct_offset' requires 'coerced_type'";
+        return std::nullopt;
+      }
+      // Read through the APInt rather than getInt(), which asserts on si/ui
+      // typed attributes, and reject values that would not survive the
+      // narrowing into the unsigned ArgClassification::directOffset field.
+      int64_t offVal = off.getValue().getSExtValue();
+      if (offVal < 0 || static_cast<unsigned>(offVal) != offVal) {
+        emitError() << "'direct_offset' must be a non-negative 32-bit value; "
+                    << "got " << off;
+        return std::nullopt;
+      }
+      c.directOffset = static_cast<unsigned>(offVal);
+    }
     return c;
   }
 



More information about the llvm-branch-commits mailing list