[llvm] [WebAssembly] Implement lowering calls through funcref to call_ref when available (PR #162227)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 6 23:14:03 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-webassembly
Author: Demetrius Kanios (QuantumSegfault)
<details>
<summary>Changes</summary>
Allows calls through `funcref` (`ptr addrspace(20)`) to be lowered to a `ref.cast` + `call_ref` sequence when WasmGC is available, instead of the current workaround of storing the funcref into a special table and calling it with `call_indirect`.
Builds upon the framework provided by #147486.
__Example__
_Source IR_
```ll
define void @call_ref_void(%funcref %callee) {
call addrspace(20) void %callee()
ret void
}
```
_Result_
Before this PR, or with this PR but without GC:
```S
i32.const 0
local.get 0
table.set __funcref_call_table
i32.const 0
call_indirect __funcref_call_table, () -> ()
i32.const 0
ref.null_func
table.set __funcref_call_table
```
After this PR, when compiled with `-mattr=+gc`:
```S
local.get 0
ref.cast () -> ()
call_ref () -> ()
```
---
Full diff: https://github.com/llvm/llvm-project/pull/162227.diff
8 Files Affected:
- (modified) llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h (+12)
- (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp (+2-55)
- (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp (+128-25)
- (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h (+5)
- (modified) llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td (+20)
- (modified) llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td (+5)
- (modified) llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp (+9-7)
- (added) llvm/test/CodeGen/WebAssembly/call-ref.ll (+51)
``````````diff
diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h
index fe9a4bada2430..db4d9edb152ce 100644
--- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h
+++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCTargetDesc.h
@@ -435,6 +435,18 @@ inline bool isCallIndirect(unsigned Opc) {
}
}
+/// Returns true if \p Opc is a call_ref or return_call_ref instruction
+/// (including their _S variants).
+inline bool isCallRef(unsigned Opc) {
+ switch (Opc) {
+ case WebAssembly::CALL_REF:
+ case WebAssembly::CALL_REF_S:
+ case WebAssembly::RET_CALL_REF:
+ case WebAssembly::RET_CALL_REF_S:
+ return true;
+ default:
+ return false;
+ }
+}
+
inline bool isBrTable(unsigned Opc) {
switch (Opc) {
case WebAssembly::BR_TABLE_I32:
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp
index 2541b0433ab59..03c90c7160a68 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp
@@ -120,60 +120,6 @@ static SDValue getTagSymNode(int Tag, SelectionDAG *DAG) {
return DAG->getTargetExternalSymbol(SymName, PtrVT);
}
-static APInt encodeFunctionSignature(SelectionDAG *DAG, SDLoc &DL,
- SmallVector<MVT, 4> &Returns,
- SmallVector<MVT, 4> &Params) {
- auto toWasmValType = [](MVT VT) {
- if (VT == MVT::i32) {
- return wasm::ValType::I32;
- }
- if (VT == MVT::i64) {
- return wasm::ValType::I64;
- }
- if (VT == MVT::f32) {
- return wasm::ValType::F32;
- }
- if (VT == MVT::f64) {
- return wasm::ValType::F64;
- }
- if (VT == MVT::externref) {
- return wasm::ValType::EXTERNREF;
- }
- if (VT == MVT::funcref) {
- return wasm::ValType::FUNCREF;
- }
- if (VT == MVT::exnref) {
- return wasm::ValType::EXNREF;
- }
- LLVM_DEBUG(errs() << "Unhandled type for llvm.wasm.ref.test.func: " << VT
- << "\n");
- llvm_unreachable("Unhandled type for llvm.wasm.ref.test.func");
- };
- auto NParams = Params.size();
- auto NReturns = Returns.size();
- auto BitWidth = (NParams + NReturns + 2) * 64;
- auto Sig = APInt(BitWidth, 0);
-
- // Annoying special case: if getSignificantBits() <= 64 then InstrEmitter will
- // emit an Imm instead of a CImm. It simplifies WebAssemblyMCInstLower if we
- // always emit a CImm. So xor NParams with 0x7ffffff to ensure
- // getSignificantBits() > 64
- Sig |= NReturns ^ 0x7ffffff;
- for (auto &Return : Returns) {
- auto V = toWasmValType(Return);
- Sig <<= 64;
- Sig |= (int64_t)V;
- }
- Sig <<= 64;
- Sig |= NParams;
- for (auto &Param : Params) {
- auto V = toWasmValType(Param);
- Sig <<= 64;
- Sig |= (int64_t)V;
- }
- return Sig;
-}
-
void WebAssemblyDAGToDAGISel::Select(SDNode *Node) {
// If we have a custom node, we already have selected!
if (Node->isMachineOpcode()) {
@@ -288,7 +234,8 @@ void WebAssemblyDAGToDAGISel::Select(SDNode *Node) {
Returns.push_back(VT);
}
}
- auto Sig = encodeFunctionSignature(CurDAG, DL, Returns, Params);
+ auto Sig =
+ WebAssembly::encodeFunctionSignature(CurDAG, DL, Returns, Params);
auto SigOp = CurDAG->getTargetConstant(
Sig, DL, EVT::getIntegerVT(*CurDAG->getContext(), Sig.getBitWidth()));
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 163bf9ba5b089..bd0733c73f7ed 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -723,6 +723,7 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
bool IsIndirect =
CallParams.getOperand(0).isReg() || CallParams.getOperand(0).isFI();
bool IsRetCall = CallResults.getOpcode() == WebAssembly::RET_CALL_RESULTS;
+ bool IsCallRef = false;
bool IsFuncrefCall = false;
if (IsIndirect && CallParams.getOperand(0).isReg()) {
@@ -732,10 +733,19 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
const TargetRegisterClass *TRC = MRI.getRegClass(Reg);
IsFuncrefCall = (TRC == &WebAssembly::FUNCREFRegClass);
assert(!IsFuncrefCall || Subtarget->hasReferenceTypes());
+
+ if (IsFuncrefCall && Subtarget->hasGC()) {
+ IsIndirect = false;
+ IsCallRef = true;
+ }
}
unsigned CallOp;
- if (IsIndirect && IsRetCall) {
+ if (IsCallRef && IsRetCall) {
+ CallOp = WebAssembly::RET_CALL_REF;
+ } else if (IsCallRef) {
+ CallOp = WebAssembly::CALL_REF;
+ } else if (IsIndirect && IsRetCall) {
CallOp = WebAssembly::RET_CALL_INDIRECT;
} else if (IsIndirect) {
CallOp = WebAssembly::CALL_INDIRECT;
@@ -771,6 +781,14 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
CallParams.addOperand(FnPtr);
}
+ // Move the function pointer to the end of the arguments for funcref calls
+ if (IsCallRef) {
+ auto FnRef = CallParams.getOperand(0);
+ CallParams.removeOperand(0);
+
+ CallParams.addOperand(FnRef);
+ }
+
for (auto Def : CallResults.defs())
MIB.add(Def);
@@ -795,6 +813,12 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
}
}
+ if (IsCallRef) {
+ // Placeholder for the type index.
+ // This gets replaced with the correct value in WebAssemblyMCInstLower.cpp
+ MIB.addImm(0);
+ }
+
for (auto Use : CallParams.uses())
MIB.add(Use);
@@ -1173,6 +1197,60 @@ static bool callingConvSupported(CallingConv::ID CallConv) {
CallConv == CallingConv::Swift;
}
+/// Packs a function signature (result and parameter MVTs) into a wide APInt
+/// that is used as a type-index placeholder operand for ref.test and for the
+/// ref.cast emitted when lowering funcref calls to call_ref. The placeholder is
+/// emitted as a CImm and turned into a real type index in
+/// WebAssemblyMCInstLower.
+///
+/// Layout, most-significant 64-bit word first:
+///   [NReturns ^ 0x7ffffff] [Returns...] [NParams] [Params...]
+/// where each element word holds a wasm::ValType. \p DAG and \p DL are
+/// currently unused.
+APInt WebAssembly::encodeFunctionSignature(SelectionDAG *DAG, SDLoc &DL,
+ SmallVector<MVT, 4> &Returns,
+ SmallVector<MVT, 4> &Params) {
+ auto toWasmValType = [](MVT VT) {
+ if (VT == MVT::i32) {
+ return wasm::ValType::I32;
+ }
+ if (VT == MVT::i64) {
+ return wasm::ValType::I64;
+ }
+ if (VT == MVT::f32) {
+ return wasm::ValType::F32;
+ }
+ if (VT == MVT::f64) {
+ return wasm::ValType::F64;
+ }
+ if (VT == MVT::externref) {
+ return wasm::ValType::EXTERNREF;
+ }
+ if (VT == MVT::funcref) {
+ return wasm::ValType::FUNCREF;
+ }
+ if (VT == MVT::exnref) {
+ return wasm::ValType::EXNREF;
+ }
+ // This helper is shared by ref.test and by call_ref (ref.cast) lowering, so
+ // the diagnostic names the signature rather than a single intrinsic.
+ LLVM_DEBUG(errs() << "Unhandled type in WebAssembly function signature: "
+ << VT << "\n");
+ llvm_unreachable("Unhandled type in WebAssembly function signature");
+ };
+ auto NParams = Params.size();
+ auto NReturns = Returns.size();
+ // One 64-bit word per value type, plus one word for each of the two counts.
+ auto BitWidth = (NParams + NReturns + 2) * 64;
+ auto Sig = APInt(BitWidth, 0);
+
+ // Annoying special case: if getSignificantBits() <= 64 then InstrEmitter will
+ // emit an Imm instead of a CImm. It simplifies WebAssemblyMCInstLower if we
+ // always emit a CImm. So xor NReturns (which ends up in the most significant
+ // word) with 0x7ffffff to ensure getSignificantBits() > 64.
+ Sig |= NReturns ^ 0x7ffffff;
+ for (auto &Return : Returns) {
+ auto V = toWasmValType(Return);
+ Sig <<= 64;
+ Sig |= (int64_t)V;
+ }
+ Sig <<= 64;
+ Sig |= NParams;
+ for (auto &Param : Params) {
+ auto V = toWasmValType(Param);
+ Sig <<= 64;
+ Sig |= (int64_t)V;
+ }
+ return Sig;
+}
+
SDValue
WebAssemblyTargetLowering::LowerCall(CallLoweringInfo &CLI,
SmallVectorImpl<SDValue> &InVals) const {
@@ -1412,33 +1490,58 @@ WebAssemblyTargetLowering::LowerCall(CallLoweringInfo &CLI,
InTys.push_back(In.VT);
}
- // Lastly, if this is a call to a funcref we need to add an instruction
- // table.set to the chain and transform the call.
+ // Lastly, if this is a call to a funcref we need to insert an instruction
+ // to either cast the funcref to a typed funcref for call_ref, or place it
+ // into a table for call_indirect
if (CLI.CB && WebAssembly::isWebAssemblyFuncrefType(
CLI.CB->getCalledOperand()->getType())) {
- // In the absence of function references proposal where a funcref call is
- // lowered to call_ref, using reference types we generate a table.set to set
- // the funcref to a special table used solely for this purpose, followed by
- // a call_indirect. Here we just generate the table set, and return the
- // SDValue of the table.set so that LowerCall can finalize the lowering by
- // generating the call_indirect.
- SDValue Chain = Ops[0];
+ if (Subtarget->hasGC()) {
+ // Since LLVM doesn't directly support typed function references, we take
+ // the untyped funcref and ref.cast it into a typed funcref.
+ SmallVector<MVT, 4> Params;
+ SmallVector<MVT, 4> Returns;
+
+ for (const auto &Out : Outs) {
+ Params.push_back(Out.VT);
+ }
+ for (const auto &In : Ins) {
+ Returns.push_back(In.VT);
+ }
- MCSymbolWasm *Table = WebAssembly::getOrCreateFuncrefCallTableSymbol(
- MF.getContext(), Subtarget);
- SDValue Sym = DAG.getMCSymbol(Table, PtrVT);
- SDValue TableSlot = DAG.getConstant(0, DL, MVT::i32);
- SDValue TableSetOps[] = {Chain, Sym, TableSlot, Callee};
- SDValue TableSet = DAG.getMemIntrinsicNode(
- WebAssemblyISD::TABLE_SET, DL, DAG.getVTList(MVT::Other), TableSetOps,
- MVT::funcref,
- // Machine Mem Operand args
- MachinePointerInfo(
- WebAssembly::WasmAddressSpace::WASM_ADDRESS_SPACE_FUNCREF),
- CLI.CB->getCalledOperand()->getPointerAlignment(DAG.getDataLayout()),
- MachineMemOperand::MOStore);
-
- Ops[0] = TableSet; // The new chain is the TableSet itself
+ auto Sig =
+ WebAssembly::encodeFunctionSignature(&DAG, DL, Returns, Params);
+
+ auto SigOp = DAG.getTargetConstant(
+ Sig, DL, EVT::getIntegerVT(*DAG.getContext(), Sig.getBitWidth()));
+ MachineSDNode *RefCastNode = DAG.getMachineNode(
+ WebAssembly::REF_CAST_FUNCREF, DL, MVT::funcref, {SigOp, Callee});
+
+ Ops[1] = SDValue(RefCastNode, 0);
+ } else {
+ // In the absence of function references proposal where a funcref call is
+ // lowered to call_ref, using reference types we generate a table.set to
+ // set the funcref to a special table used solely for this purpose,
+ // followed by a call_indirect. Here we just generate the table set, and
+ // return the SDValue of the table.set so that LowerCall can finalize the
+ // lowering by generating the call_indirect.
+ SDValue Chain = Ops[0];
+
+ MCSymbolWasm *Table = WebAssembly::getOrCreateFuncrefCallTableSymbol(
+ MF.getContext(), Subtarget);
+ SDValue Sym = DAG.getMCSymbol(Table, PtrVT);
+ SDValue TableSlot = DAG.getConstant(0, DL, MVT::i32);
+ SDValue TableSetOps[] = {Chain, Sym, TableSlot, Callee};
+ SDValue TableSet = DAG.getMemIntrinsicNode(
+ WebAssemblyISD::TABLE_SET, DL, DAG.getVTList(MVT::Other), TableSetOps,
+ MVT::funcref,
+ // Machine Mem Operand args
+ MachinePointerInfo(
+ WebAssembly::WasmAddressSpace::WASM_ADDRESS_SPACE_FUNCREF),
+ CLI.CB->getCalledOperand()->getPointerAlignment(DAG.getDataLayout()),
+ MachineMemOperand::MOStore);
+
+ Ops[0] = TableSet; // The new chain is the TableSet itself
+ }
}
if (CLI.IsTailCall) {
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
index b33a8530310be..7d2194132f293 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
@@ -141,6 +141,11 @@ class WebAssemblyTargetLowering final : public TargetLowering {
namespace WebAssembly {
FastISel *createFastISel(FunctionLoweringInfo &funcInfo,
const TargetLibraryInfo *libInfo);
+
+/// Encodes a function signature (\p Returns and \p Params, as MVTs) into a
+/// wide APInt used as a type-index placeholder operand for ref.test / ref.cast.
+/// The definition lives in WebAssemblyISelLowering.cpp.
+APInt encodeFunctionSignature(SelectionDAG *DAG, SDLoc &DL,
+ SmallVector<MVT, 4> &Returns,
+ SmallVector<MVT, 4> &Params);
+
} // end namespace WebAssembly
} // end namespace llvm
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td
index ca9a5ef9dda1c..81b62f6a682ec 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrCall.td
@@ -66,6 +66,16 @@ defm CALL_INDIRECT :
[],
"call_indirect", "call_indirect\t$type, $table", 0x11>;
+// call_ref (opcode 0x14): call through a typed function reference. The callee
+// reference is the last operand; $type is the callee's signature type index.
+// Only available with GC.
+let variadicOpsAreDefs = 1 in
+defm CALL_REF :
+ I<(outs),
+ (ins TypeIndex:$type, variable_ops),
+ (outs),
+ (ins TypeIndex:$type),
+ [],
+ "call_ref", "call_ref\t$type", 0x14>,
+ Requires<[HasGC]>;
+
let isReturn = 1, isTerminator = 1, hasCtrlDep = 1, isBarrier = 1 in
defm RET_CALL :
I<(outs), (ins function32_op:$callee, variable_ops),
@@ -81,4 +91,14 @@ defm RET_CALL_INDIRECT :
0x13>,
Requires<[HasTailCall]>;
+// return_call_ref (opcode 0x15): tail-call form of call_ref. Requires both
+// tail-call and GC support.
+let isReturn = 1, isTerminator = 1, hasCtrlDep = 1, isBarrier = 1 in
+defm RET_CALL_REF :
+ I<(outs),
+ (ins TypeIndex:$type, variable_ops),
+ (outs),
+ (ins TypeIndex:$type),
+ [],
+ "return_call_ref", "return_call_ref\t$type", 0x15>,
+ Requires<[HasTailCall, HasGC]>;
+
} // Uses = [SP32,SP64], isCall = 1
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td
index fc82e5b4a61da..6fa6ed897d647 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td
@@ -41,6 +41,11 @@ defm REF_TEST_FUNCREF : I<(outs I32:$res), (ins TypeIndex:$type, FUNCREF:$ref),
"ref.test\t$type, $ref", "ref.test $type", 0xfb14>,
Requires<[HasGC]>;
+// ref.cast (opcode 0xfb16) of a funcref to the function type given by $type.
+// Used to turn an untyped funcref into a typed reference before call_ref.
+defm REF_CAST_FUNCREF : I<(outs FUNCREF:$res), (ins TypeIndex:$type, FUNCREF:$ref),
+ (outs), (ins TypeIndex:$type), [],
+ "ref.cast\t$type, $ref", "ref.cast $type", 0xfb16>,
+ Requires<[HasGC]>;
+
defm "" : REF_I<FUNCREF, funcref, "func">;
defm "" : REF_I<EXTERNREF, externref, "extern">;
defm "" : REF_I<EXNREF, exnref, "exn">;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
index e48283aadb437..1ed15967c01fe 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
@@ -230,7 +230,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
break;
}
case llvm::MachineOperand::MO_CImmediate: {
- // Lower type index placeholder for ref.test
+ // Lower type index placeholder for ref.test and ref.cast
// Currently this is the only way that CImmediates show up so panic if we
// get confused.
unsigned DescIndex = I - NumVariadicDefs;
@@ -266,14 +266,16 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
Params.push_back(WebAssembly::regClassToValType(
MRI.getRegClass(MO.getReg())->getID()));
- // call_indirect instructions have a callee operand at the end which
- // doesn't count as a param.
- if (WebAssembly::isCallIndirect(MI->getOpcode()))
+ // call_indirect and call_ref instructions have a callee operand at
+ // the end which doesn't count as a param.
+ if (WebAssembly::isCallIndirect(MI->getOpcode()) ||
+ WebAssembly::isCallRef(MI->getOpcode()))
Params.pop_back();
- // return_call_indirect instructions have the return type of the
- // caller
- if (MI->getOpcode() == WebAssembly::RET_CALL_INDIRECT)
+ // return_call_indirect and return_call_ref instructions have the
+ // return type of the caller
+ if (MI->getOpcode() == WebAssembly::RET_CALL_INDIRECT ||
+ MI->getOpcode() == WebAssembly::RET_CALL_REF)
getFunctionReturns(MI, Returns);
MCOp = lowerTypeIndexOperand(std::move(Returns), std::move(Params));
diff --git a/llvm/test/CodeGen/WebAssembly/call-ref.ll b/llvm/test/CodeGen/WebAssembly/call-ref.ll
new file mode 100644
index 0000000000000..25fc7440ac64c
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/call-ref.ll
@@ -0,0 +1,51 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mattr=+reference-types,-gc | FileCheck --check-prefixes=CHECK,NOGC %s
+; RUN: llc < %s -mattr=+reference-types,+gc | FileCheck --check-prefixes=CHECK,GC %s
+
+; Test that calls through funcref lower to call_ref when GC is available
+
+target triple = "wasm32-unknown-unknown"
+
+%funcref = type ptr addrspace(20);
+
+; Funcref call with no parameters and no results.
+define void @call_ref_void(%funcref %callee) {
+; CHECK-LABEL: call_ref_void:
+; CHECK: .functype call_ref_void (funcref) -> ()
+; CHECK-NEXT: # %bb.0:
+; NOGC-NEXT: i32.const 0
+; CHECK-NEXT: local.get 0
+; NOGC-NEXT: table.set __funcref_call_table
+; NOGC-NEXT: i32.const 0
+; NOGC-NEXT: call_indirect __funcref_call_table, () -> ()
+; NOGC-NEXT: i32.const 0
+; NOGC-NEXT: ref.null_func
+; NOGC-NEXT: table.set __funcref_call_table
+; GC-NEXT: ref.cast () -> ()
+; GC-NEXT: call_ref () -> ()
+; CHECK-NEXT: # fallthrough-return
+ call addrspace(20) void %callee()
+ ret void
+}
+
+; Funcref call with signature (i32, f64) -> i32; the result is unused and is
+; dropped.
+; NOTE(review): this file has no tail-call (return_call_ref) or vararg funcref
+; call; consider adding coverage for those lowering paths.
+define void @call_ref_with_args_and_ret(%funcref %callee) {
+; CHECK-LABEL: call_ref_with_args_and_ret:
+; CHECK: .functype call_ref_with_args_and_ret (funcref) -> ()
+; CHECK-NEXT: # %bb.0:
+; NOGC-NEXT: i32.const 0
+; NOGC-NEXT: local.get 0
+; NOGC-NEXT: table.set __funcref_call_table
+; CHECK-NEXT: i32.const 1
+; CHECK-NEXT: f64.const 0x1p1
+; NOGC-NEXT: i32.const 0
+; NOGC-NEXT: call_indirect __funcref_call_table, (i32, f64) -> (i32)
+; GC-NEXT: local.get 0
+; GC-NEXT: ref.cast (i32, f64) -> (i32)
+; GC-NEXT: call_ref (i32, f64) -> (i32)
+; CHECK-NEXT: drop
+; NOGC-NEXT: i32.const 0
+; NOGC-NEXT: ref.null_func
+; NOGC-NEXT: table.set __funcref_call_table
+; CHECK-NEXT: # fallthrough-return
+ %result = call addrspace(20) i32 %callee(i32 1, double 2.0)
+ ret void
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/162227
More information about the llvm-commits
mailing list