[llvm] [WebAssembly,llvm] Add llvm.wasm.ref.test.func intrinsic, option 2 (PR #147486)
Hood Chatham via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 8 07:09:59 PDT 2025
https://github.com/hoodmane updated https://github.com/llvm/llvm-project/pull/147486
>From 5eb65fd57bfeb6cfca7f59936d786ec3f01917fd Mon Sep 17 00:00:00 2001
From: Hood Chatham <roberthoodchatham at gmail.com>
Date: Wed, 2 Jul 2025 20:53:56 +0200
Subject: [PATCH 1/3] [WebAssembly,llvm] Add llvm.wasm.ref.test.func intrinsic,
option 2
This intrinsic tests whether a function pointer has the expected signature.
It is intended to back a future clang builtin,
`__builtin_wasm_test_function_pointer_signature`, so that code can check
whether calling a function pointer would fail with a function signature
mismatch.
This is an alternative to #147076: instead of using a ref.test.pseudo
instruction with a custom inserter, we teach SelectionDAG a new
TargetConstantAP node type that gets converted to a CImm in the MCInst layer.
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 1 +
llvm/include/llvm/CodeGen/SelectionDAG.h | 10 ++-
llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 12 +--
llvm/include/llvm/IR/IntrinsicsWebAssembly.td | 4 +
.../lib/CodeGen/SelectionDAG/InstrEmitter.cpp | 7 +-
llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 6 +-
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 13 ++--
.../CodeGen/SelectionDAG/SelectionDAGISel.cpp | 1 +
.../WebAssembly/WebAssemblyISelLowering.cpp | 67 +++++++++++++++++
.../WebAssembly/WebAssemblyMCInstLower.cpp | 74 +++++++++++++++++++
.../test/CodeGen/WebAssembly/ref-test-func.ll | 42 +++++++++++
11 files changed, 221 insertions(+), 16 deletions(-)
create mode 100644 llvm/test/CodeGen/WebAssembly/ref-test-func.ll
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 465e4a0a9d0d8..a9d4e5ade0ba8 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -173,6 +173,7 @@ enum NodeType {
/// materialized in registers.
TargetConstant,
TargetConstantFP,
+ TargetConstantAP,
/// TargetGlobalAddress - Like GlobalAddress, but the DAG does no folding or
/// anything else with this node, and this is valid in the target-specific
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 7d8a0c4ce8e45..0a7e3ec24be23 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -683,7 +683,8 @@ class SelectionDAG {
LLVM_ABI SDValue getConstant(uint64_t Val, const SDLoc &DL, EVT VT,
bool isTarget = false, bool isOpaque = false);
LLVM_ABI SDValue getConstant(const APInt &Val, const SDLoc &DL, EVT VT,
- bool isTarget = false, bool isOpaque = false);
+ bool isTarget = false, bool isOpaque = false,
+ bool isArbitraryPrecision = false);
LLVM_ABI SDValue getSignedConstant(int64_t Val, const SDLoc &DL, EVT VT,
bool isTarget = false,
@@ -694,7 +695,8 @@ class SelectionDAG {
bool IsOpaque = false);
LLVM_ABI SDValue getConstant(const ConstantInt &Val, const SDLoc &DL, EVT VT,
- bool isTarget = false, bool isOpaque = false);
+ bool isTarget = false, bool isOpaque = false,
+ bool isArbitraryPrecision = false);
LLVM_ABI SDValue getIntPtrConstant(uint64_t Val, const SDLoc &DL,
bool isTarget = false);
LLVM_ABI SDValue getShiftAmountConstant(uint64_t Val, EVT VT,
@@ -712,6 +714,10 @@ class SelectionDAG {
bool isOpaque = false) {
return getConstant(Val, DL, VT, true, isOpaque);
}
+ SDValue getTargetConstantAP(const APInt &Val, const SDLoc &DL, EVT VT,
+ bool isOpaque = false) {
+ return getConstant(Val, DL, VT, true, isOpaque, true);
+ }
SDValue getTargetConstant(const ConstantInt &Val, const SDLoc &DL, EVT VT,
bool isOpaque = false) {
return getConstant(Val, DL, VT, true, isOpaque);
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 5d9937f832396..45e57c491181b 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1742,10 +1742,11 @@ class ConstantSDNode : public SDNode {
const ConstantInt *Value;
- ConstantSDNode(bool isTarget, bool isOpaque, const ConstantInt *val,
- SDVTList VTs)
- : SDNode(isTarget ? ISD::TargetConstant : ISD::Constant, 0, DebugLoc(),
- VTs),
+ ConstantSDNode(bool isTarget, bool isOpaque, bool isAPTarget,
+ const ConstantInt *val, SDVTList VTs)
+ : SDNode(isAPTarget ? ISD::TargetConstantAP
+ : (isTarget ? ISD::TargetConstant : ISD::Constant),
+ 0, DebugLoc(), VTs),
Value(val) {
assert(!isa<VectorType>(val->getType()) && "Unexpected vector type!");
ConstantSDNodeBits.IsOpaque = isOpaque;
@@ -1772,7 +1773,8 @@ class ConstantSDNode : public SDNode {
static bool classof(const SDNode *N) {
return N->getOpcode() == ISD::Constant ||
- N->getOpcode() == ISD::TargetConstant;
+ N->getOpcode() == ISD::TargetConstant ||
+ N->getOpcode() == ISD::TargetConstantAP;
}
};
diff --git a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
index f592ff287a0e3..fb61d8a11e5c0 100644
--- a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
+++ b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
@@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn :
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem],
"llvm.wasm.ref.is_null.exn">;
+def int_wasm_ref_test_func
+ : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty],
+ [IntrNoMem], "llvm.wasm.ref.test.func">;
+
//===----------------------------------------------------------------------===//
// Table intrinsics
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
index 03d3e8eab35d0..95a93d0cefff9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
@@ -402,7 +402,12 @@ void InstrEmitter::AddOperand(MachineInstrBuilder &MIB, SDValue Op,
AddRegisterOperand(MIB, Op, IIOpNum, II, VRBaseMap,
IsDebug, IsClone, IsCloned);
} else if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op)) {
- MIB.addImm(C->getSExtValue());
+ if (C->getOpcode() == ISD::TargetConstantAP) {
+ MIB.addCImm(
+ ConstantInt::get(MF->getFunction().getContext(), C->getAPIntValue()));
+ } else {
+ MIB.addImm(C->getSExtValue());
+ }
} else if (ConstantFPSDNode *F = dyn_cast<ConstantFPSDNode>(Op)) {
MIB.addFPImm(F->getConstantFPValue());
} else if (RegisterSDNode *R = dyn_cast<RegisterSDNode>(Op)) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index f5f4d71236fee..ef5c74610f887 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -968,6 +968,7 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
// Allow illegal target nodes and illegal registers.
if (Node->getOpcode() == ISD::TargetConstant ||
+ Node->getOpcode() == ISD::TargetConstantAP ||
Node->getOpcode() == ISD::Register)
return;
@@ -979,10 +980,11 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
for (const SDValue &Op : Node->op_values())
assert((TLI.getTypeAction(*DAG.getContext(), Op.getValueType()) ==
- TargetLowering::TypeLegal ||
+ TargetLowering::TypeLegal ||
Op.getOpcode() == ISD::TargetConstant ||
+ Op->getOpcode() == ISD::TargetConstantAP ||
Op.getOpcode() == ISD::Register) &&
- "Unexpected illegal type!");
+ "Unexpected illegal type!");
#endif
// Figure out the correct action; the way to query this varies by opcode
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 2a8bda55fef04..2c1bd246c147a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1664,14 +1664,14 @@ SDValue SelectionDAG::getConstant(uint64_t Val, const SDLoc &DL, EVT VT,
}
SDValue SelectionDAG::getConstant(const APInt &Val, const SDLoc &DL, EVT VT,
- bool isT, bool isO) {
- return getConstant(*ConstantInt::get(*Context, Val), DL, VT, isT, isO);
+ bool isT, bool isO, bool isAP) {
+ return getConstant(*ConstantInt::get(*Context, Val), DL, VT, isT, isO, isAP);
}
SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
- EVT VT, bool isT, bool isO) {
+ EVT VT, bool isT, bool isO, bool isAP) {
assert(VT.isInteger() && "Cannot create FP integer constant!");
-
+ isT |= isAP;
EVT EltVT = VT.getScalarType();
const ConstantInt *Elt = &Val;
@@ -1760,7 +1760,8 @@ SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
assert(Elt->getBitWidth() == EltVT.getSizeInBits() &&
"APInt size does not match type size!");
- unsigned Opc = isT ? ISD::TargetConstant : ISD::Constant;
+ unsigned Opc = isAP ? ISD::TargetConstantAP
+ : (isT ? ISD::TargetConstant : ISD::Constant);
SDVTList VTs = getVTList(EltVT);
FoldingSetNodeID ID;
AddNodeIDNode(ID, Opc, VTs, {});
@@ -1773,7 +1774,7 @@ SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
return SDValue(N, 0);
if (!N) {
- N = newSDNode<ConstantSDNode>(isT, isO, Elt, VTs);
+ N = newSDNode<ConstantSDNode>(isT, isO, isAP, Elt, VTs);
CSEMap.InsertNode(N, IP);
InsertNode(N);
NewSDValueDbgMsg(SDValue(N, 0), "Creating constant: ", this);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
index d9b9cf6bcc772..5a3b96743b0ef 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -3255,6 +3255,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
case ISD::HANDLENODE:
case ISD::MDNODE_SDNODE:
case ISD::TargetConstant:
+ case ISD::TargetConstantAP:
case ISD::TargetConstantFP:
case ISD::TargetConstantPool:
case ISD::TargetFrameIndex:
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index bf2e04caa0a61..ec369eaeae0a5 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -18,6 +18,7 @@
#include "WebAssemblySubtarget.h"
#include "WebAssemblyTargetMachine.h"
#include "WebAssemblyUtilities.h"
+#include "llvm/BinaryFormat/Wasm.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -794,6 +795,7 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB,
if (IsIndirect) {
// Placeholder for the type index.
+ // This gets replaced with the correct value in WebAssemblyMCInstLower.cpp
MIB.addImm(0);
// The table into which this call_indirect indexes.
MCSymbolWasm *Table = IsFuncrefCall
@@ -2253,6 +2255,71 @@ SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
DAG.getTargetExternalSymbol(TlsBase, PtrVT)),
0);
}
+ case Intrinsic::wasm_ref_test_func: {
+ // First emit the TABLE_GET instruction to convert function pointer ==>
+ // funcref
+ MachineFunction &MF = DAG.getMachineFunction();
+ auto PtrVT = getPointerTy(MF.getDataLayout());
+ MCSymbol *Table =
+ WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(), Subtarget);
+ SDValue TableSym = DAG.getMCSymbol(Table, PtrVT);
+ SDValue FuncRef =
+ SDValue(DAG.getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL,
+ MVT::funcref, TableSym, Op.getOperand(1)),
+ 0);
+
+ // Encode the signature information into the type index placeholder.
+ // This gets decoded and converted into the actual type signature in
+ // WebAssemblyMCInstLower.cpp.
+ auto NParams = Op.getNumOperands() - 2;
+ auto Sig = APInt(NParams * 64, 0);
+ // The return type has to be a BlockType since it can be void.
+ {
+ SDValue Operand = Op.getOperand(2);
+ MVT VT = Operand.getValueType().getSimpleVT();
+ WebAssembly::BlockType V;
+ if (VT == MVT::Untyped) {
+ V = WebAssembly::BlockType::Void;
+ } else if (VT == MVT::i32) {
+ V = WebAssembly::BlockType::I32;
+ } else if (VT == MVT::i64) {
+ V = WebAssembly::BlockType::I64;
+ } else if (VT == MVT::f32) {
+ V = WebAssembly::BlockType::F32;
+ } else if (VT == MVT::f64) {
+ V = WebAssembly::BlockType::F64;
+ } else {
+ llvm_unreachable("Unhandled type!");
+ }
+ Sig |= (int64_t)V;
+ }
+ for (unsigned i = 3; i < Op.getNumOperands(); ++i) {
+ SDValue Operand = Op.getOperand(i);
+ MVT VT = Operand.getValueType().getSimpleVT();
+ wasm::ValType V;
+ if (VT == MVT::i32) {
+ V = wasm::ValType::I32;
+ } else if (VT == MVT::i64) {
+ V = wasm::ValType::I64;
+ } else if (VT == MVT::f32) {
+ V = wasm::ValType::F32;
+ } else if (VT == MVT::f64) {
+ V = wasm::ValType::F64;
+ } else {
+ llvm_unreachable("Unhandled type!");
+ }
+ Sig <<= 64;
+ Sig |= (int64_t)V;
+ }
+
+ SmallVector<SDValue, 4> Ops;
+ Ops.push_back(DAG.getTargetConstantAP(
+ Sig, DL, EVT::getIntegerVT(*DAG.getContext(), NParams * 64)));
+ Ops.push_back(FuncRef);
+ return SDValue(
+ DAG.getMachineNode(WebAssembly::REF_TEST_FUNCREF, DL, MVT::i32, Ops),
+ 0);
+ }
}
}
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
index cc36244e63ff5..f725ec344d922 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
@@ -15,13 +15,17 @@
#include "WebAssemblyMCInstLower.h"
#include "MCTargetDesc/WebAssemblyMCAsmInfo.h"
#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
+#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h"
#include "TargetInfo/WebAssemblyTargetInfo.h"
#include "Utils/WebAssemblyTypeUtilities.h"
#include "WebAssemblyAsmPrinter.h"
#include "WebAssemblyMachineFunctionInfo.h"
#include "WebAssemblyUtilities.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/BinaryFormat/Wasm.h"
#include "llvm/CodeGen/AsmPrinter.h"
#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/IR/Constants.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/MC/MCContext.h"
@@ -196,11 +200,80 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
MCOp = MCOperand::createReg(WAReg);
break;
}
+ case llvm::MachineOperand::MO_CImmediate: {
+ // Lower type index placeholder for ref.test
+ // Currently this is the only way that CImmediates show up so panic if we
+ // get confused.
+ unsigned DescIndex = I - NumVariadicDefs;
+ if (DescIndex >= Desc.NumOperands) {
+ llvm_unreachable("unexpected CImmediate operand");
+ }
+ const MCOperandInfo &Info = Desc.operands()[DescIndex];
+ if (Info.OperandType != WebAssembly::OPERAND_TYPEINDEX) {
+ llvm_unreachable("unexpected CImmediate operand");
+ }
+ auto CImm = MO.getCImm()->getValue();
+ auto NumWords = CImm.getNumWords();
+ // Extract the type data we packed into the CImm in LowerRefTestFuncRef.
+ // We need to load the words from most significant to least significant
+ // order because of the way we bitshifted them in from the right.
+ // The return type needs special handling because it could be void.
+ auto ReturnType = static_cast<WebAssembly::BlockType>(
+ CImm.extractBitsAsZExtValue(64, (NumWords - 1) * 64));
+ SmallVector<wasm::ValType, 2> Returns;
+ switch (ReturnType) {
+ case WebAssembly::BlockType::Invalid:
+ llvm_unreachable("Invalid return type");
+ case WebAssembly::BlockType::I32:
+ Returns = {wasm::ValType::I32};
+ break;
+ case WebAssembly::BlockType::I64:
+ Returns = {wasm::ValType::I64};
+ break;
+ case WebAssembly::BlockType::F32:
+ Returns = {wasm::ValType::F32};
+ break;
+ case WebAssembly::BlockType::F64:
+ Returns = {wasm::ValType::F64};
+ break;
+ case WebAssembly::BlockType::Void:
+ Returns = {};
+ break;
+ case WebAssembly::BlockType::Exnref:
+ Returns = {wasm::ValType::EXNREF};
+ break;
+ case WebAssembly::BlockType::Externref:
+ Returns = {wasm::ValType::EXTERNREF};
+ break;
+ case WebAssembly::BlockType::Funcref:
+ Returns = {wasm::ValType::FUNCREF};
+ break;
+ case WebAssembly::BlockType::V128:
+ Returns = {wasm::ValType::V128};
+ break;
+ case WebAssembly::BlockType::Multivalue: {
+ llvm_unreachable("Invalid return type");
+ }
+ }
+ SmallVector<wasm::ValType, 4> Params;
+
+ for (int I = NumWords - 2; I >= 0; I--) {
+ auto Val = CImm.extractBitsAsZExtValue(64, 64 * I);
+ auto ParamType = static_cast<wasm::ValType>(Val);
+ Params.push_back(ParamType);
+ }
+ MCOp = lowerTypeIndexOperand(std::move(Returns), std::move(Params));
+ break;
+ }
case MachineOperand::MO_Immediate: {
unsigned DescIndex = I - NumVariadicDefs;
if (DescIndex < Desc.NumOperands) {
const MCOperandInfo &Info = Desc.operands()[DescIndex];
+ // Replace type index placeholder with actual type index. The type index
+ // placeholders are Immediates and have an operand type of
+ // OPERAND_TYPEINDEX or OPERAND_SIGNATURE.
if (Info.OperandType == WebAssembly::OPERAND_TYPEINDEX) {
+ // Lower type index placeholder for a CALL_INDIRECT instruction
SmallVector<wasm::ValType, 4> Returns;
SmallVector<wasm::ValType, 4> Params;
@@ -228,6 +301,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
break;
}
if (Info.OperandType == WebAssembly::OPERAND_SIGNATURE) {
+ // Lower type index placeholder for blocks
auto BT = static_cast<WebAssembly::BlockType>(MO.getImm());
assert(BT != WebAssembly::BlockType::Invalid);
if (BT == WebAssembly::BlockType::Multivalue) {
diff --git a/llvm/test/CodeGen/WebAssembly/ref-test-func.ll b/llvm/test/CodeGen/WebAssembly/ref-test-func.ll
new file mode 100644
index 0000000000000..3fc848cd167f9
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/ref-test-func.ll
@@ -0,0 +1,42 @@
+; RUN: llc < %s -mcpu=mvp -mattr=+reference-types | FileCheck %s
+
+target triple = "wasm32-unknown-unknown"
+
+; CHECK-LABEL: test_function_pointer_signature_void:
+; CHECK-NEXT: .functype test_function_pointer_signature_void (i32) -> ()
+; CHECK-NEXT: .local funcref
+; CHECK: local.get 0
+; CHECK-NEXT: table.get __indirect_function_table
+; CHECK-NEXT: local.tee 1
+; CHECK-NEXT: ref.test (f32, f64, i32) -> (f32)
+; CHECK-NEXT: call use
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: ref.test (f32, f64, i32) -> (i32)
+; CHECK-NEXT: call use
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: ref.test (i32, i32, i32) -> (i32)
+; CHECK-NEXT: call use
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: ref.test (i32, i32, i32) -> ()
+; CHECK-NEXT: call use
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: ref.test () -> ()
+; CHECK-NEXT: call use
+
+; Each call passes the expected return type first (token poison means void), followed by the parameter types.
+define void @test_function_pointer_signature_void(ptr noundef %func) local_unnamed_addr #0 {
+entry:
+ %0 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, float 0.000000e+00, float 0.000000e+00, double 0.000000e+00, i32 0)
+ tail call void @use(i32 noundef %0) #3
+ %1 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, float 0.000000e+00, double 0.000000e+00, i32 0)
+ tail call void @use(i32 noundef %1) #3
+ %2 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, i32 0, i32 0, i32 0)
+ tail call void @use(i32 noundef %2) #3
+ %3 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison, i32 0, i32 0, i32 0)
+ tail call void @use(i32 noundef %3) #3
+ %4 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison)
+ tail call void @use(i32 noundef %4) #3
+ ret void
+}
+
+declare void @use(i32 noundef) local_unnamed_addr #1
>From ce05ea5e2efff83a1ab0aafc0d294909e36f4169 Mon Sep 17 00:00:00 2001
From: Hood Chatham <roberthoodchatham at gmail.com>
Date: Tue, 8 Jul 2025 12:10:58 +0200
Subject: [PATCH 2/3] Use bit width to decide whether to emit CImm or Imm
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 1 -
llvm/include/llvm/CodeGen/SelectionDAG.h | 10 ++--------
llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 12 +++++-------
llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp | 6 +++---
llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 6 ++----
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 13 ++++++-------
llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp | 1 -
.../Target/WebAssembly/WebAssemblyISelLowering.cpp | 7 ++++---
.../Target/WebAssembly/WebAssemblyMCInstLower.cpp | 2 +-
9 files changed, 23 insertions(+), 35 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index a9d4e5ade0ba8..465e4a0a9d0d8 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -173,7 +173,6 @@ enum NodeType {
/// materialized in registers.
TargetConstant,
TargetConstantFP,
- TargetConstantAP,
/// TargetGlobalAddress - Like GlobalAddress, but the DAG does no folding or
/// anything else with this node, and this is valid in the target-specific
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 0a7e3ec24be23..7d8a0c4ce8e45 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -683,8 +683,7 @@ class SelectionDAG {
LLVM_ABI SDValue getConstant(uint64_t Val, const SDLoc &DL, EVT VT,
bool isTarget = false, bool isOpaque = false);
LLVM_ABI SDValue getConstant(const APInt &Val, const SDLoc &DL, EVT VT,
- bool isTarget = false, bool isOpaque = false,
- bool isArbitraryPrecision = false);
+ bool isTarget = false, bool isOpaque = false);
LLVM_ABI SDValue getSignedConstant(int64_t Val, const SDLoc &DL, EVT VT,
bool isTarget = false,
@@ -695,8 +694,7 @@ class SelectionDAG {
bool IsOpaque = false);
LLVM_ABI SDValue getConstant(const ConstantInt &Val, const SDLoc &DL, EVT VT,
- bool isTarget = false, bool isOpaque = false,
- bool isArbitraryPrecision = false);
+ bool isTarget = false, bool isOpaque = false);
LLVM_ABI SDValue getIntPtrConstant(uint64_t Val, const SDLoc &DL,
bool isTarget = false);
LLVM_ABI SDValue getShiftAmountConstant(uint64_t Val, EVT VT,
@@ -714,10 +712,6 @@ class SelectionDAG {
bool isOpaque = false) {
return getConstant(Val, DL, VT, true, isOpaque);
}
- SDValue getTargetConstantAP(const APInt &Val, const SDLoc &DL, EVT VT,
- bool isOpaque = false) {
- return getConstant(Val, DL, VT, true, isOpaque, true);
- }
SDValue getTargetConstant(const ConstantInt &Val, const SDLoc &DL, EVT VT,
bool isOpaque = false) {
return getConstant(Val, DL, VT, true, isOpaque);
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 45e57c491181b..5d9937f832396 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1742,11 +1742,10 @@ class ConstantSDNode : public SDNode {
const ConstantInt *Value;
- ConstantSDNode(bool isTarget, bool isOpaque, bool isAPTarget,
- const ConstantInt *val, SDVTList VTs)
- : SDNode(isAPTarget ? ISD::TargetConstantAP
- : (isTarget ? ISD::TargetConstant : ISD::Constant),
- 0, DebugLoc(), VTs),
+ ConstantSDNode(bool isTarget, bool isOpaque, const ConstantInt *val,
+ SDVTList VTs)
+ : SDNode(isTarget ? ISD::TargetConstant : ISD::Constant, 0, DebugLoc(),
+ VTs),
Value(val) {
assert(!isa<VectorType>(val->getType()) && "Unexpected vector type!");
ConstantSDNodeBits.IsOpaque = isOpaque;
@@ -1773,8 +1772,7 @@ class ConstantSDNode : public SDNode {
static bool classof(const SDNode *N) {
return N->getOpcode() == ISD::Constant ||
- N->getOpcode() == ISD::TargetConstant ||
- N->getOpcode() == ISD::TargetConstantAP;
+ N->getOpcode() == ISD::TargetConstant;
}
};
diff --git a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
index 95a93d0cefff9..4d56803ba492a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
@@ -402,11 +402,11 @@ void InstrEmitter::AddOperand(MachineInstrBuilder &MIB, SDValue Op,
AddRegisterOperand(MIB, Op, IIOpNum, II, VRBaseMap,
IsDebug, IsClone, IsCloned);
} else if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op)) {
- if (C->getOpcode() == ISD::TargetConstantAP) {
+ if (C->getAPIntValue().getBitWidth() <= 64) {
+ MIB.addImm(C->getSExtValue());
+ } else {
MIB.addCImm(
ConstantInt::get(MF->getFunction().getContext(), C->getAPIntValue()));
- } else {
- MIB.addImm(C->getSExtValue());
}
} else if (ConstantFPSDNode *F = dyn_cast<ConstantFPSDNode>(Op)) {
MIB.addFPImm(F->getConstantFPValue());
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index ef5c74610f887..f5f4d71236fee 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -968,7 +968,6 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
// Allow illegal target nodes and illegal registers.
if (Node->getOpcode() == ISD::TargetConstant ||
- Node->getOpcode() == ISD::TargetConstantAP ||
Node->getOpcode() == ISD::Register)
return;
@@ -980,11 +979,10 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
for (const SDValue &Op : Node->op_values())
assert((TLI.getTypeAction(*DAG.getContext(), Op.getValueType()) ==
- TargetLowering::TypeLegal ||
+ TargetLowering::TypeLegal ||
Op.getOpcode() == ISD::TargetConstant ||
- Op->getOpcode() == ISD::TargetConstantAP ||
Op.getOpcode() == ISD::Register) &&
- "Unexpected illegal type!");
+ "Unexpected illegal type!");
#endif
// Figure out the correct action; the way to query this varies by opcode
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 2c1bd246c147a..2a8bda55fef04 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1664,14 +1664,14 @@ SDValue SelectionDAG::getConstant(uint64_t Val, const SDLoc &DL, EVT VT,
}
SDValue SelectionDAG::getConstant(const APInt &Val, const SDLoc &DL, EVT VT,
- bool isT, bool isO, bool isAP) {
- return getConstant(*ConstantInt::get(*Context, Val), DL, VT, isT, isO, isAP);
+ bool isT, bool isO) {
+ return getConstant(*ConstantInt::get(*Context, Val), DL, VT, isT, isO);
}
SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
- EVT VT, bool isT, bool isO, bool isAP) {
+ EVT VT, bool isT, bool isO) {
assert(VT.isInteger() && "Cannot create FP integer constant!");
- isT |= isAP;
+
EVT EltVT = VT.getScalarType();
const ConstantInt *Elt = &Val;
@@ -1760,8 +1760,7 @@ SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
assert(Elt->getBitWidth() == EltVT.getSizeInBits() &&
"APInt size does not match type size!");
- unsigned Opc = isAP ? ISD::TargetConstantAP
- : (isT ? ISD::TargetConstant : ISD::Constant);
+ unsigned Opc = isT ? ISD::TargetConstant : ISD::Constant;
SDVTList VTs = getVTList(EltVT);
FoldingSetNodeID ID;
AddNodeIDNode(ID, Opc, VTs, {});
@@ -1774,7 +1773,7 @@ SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
return SDValue(N, 0);
if (!N) {
- N = newSDNode<ConstantSDNode>(isT, isO, isAP, Elt, VTs);
+ N = newSDNode<ConstantSDNode>(isT, isO, Elt, VTs);
CSEMap.InsertNode(N, IP);
InsertNode(N);
NewSDValueDbgMsg(SDValue(N, 0), "Creating constant: ", this);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
index 5a3b96743b0ef..d9b9cf6bcc772 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -3255,7 +3255,6 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch,
case ISD::HANDLENODE:
case ISD::MDNODE_SDNODE:
case ISD::TargetConstant:
- case ISD::TargetConstantAP:
case ISD::TargetConstantFP:
case ISD::TargetConstantPool:
case ISD::TargetFrameIndex:
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index ec369eaeae0a5..2359d6b04aa97 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -2272,7 +2272,8 @@ SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
// This gets decoded and converted into the actual type signature in
// WebAssemblyMCInstLower.cpp.
auto NParams = Op.getNumOperands() - 2;
- auto Sig = APInt(NParams * 64, 0);
+ auto BitWidth = (NParams + 1) * 64;
+ auto Sig = APInt(BitWidth, 0);
// The return type has to be a BlockType since it can be void.
{
SDValue Operand = Op.getOperand(2);
@@ -2313,8 +2314,8 @@ SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
}
SmallVector<SDValue, 4> Ops;
- Ops.push_back(DAG.getTargetConstantAP(
- Sig, DL, EVT::getIntegerVT(*DAG.getContext(), NParams * 64)));
+ Ops.push_back(DAG.getTargetConstant(
+ Sig, DL, EVT::getIntegerVT(*DAG.getContext(), BitWidth)));
Ops.push_back(FuncRef);
return SDValue(
DAG.getMachineNode(WebAssembly::REF_TEST_FUNCREF, DL, MVT::i32, Ops),
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
index f725ec344d922..4e224c0766146 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
@@ -213,7 +213,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
llvm_unreachable("unexpected CImmediate operand");
}
auto CImm = MO.getCImm()->getValue();
- auto NumWords = CImm.getNumWords();
+ auto NumWords = CImm.getNumWords() - 1;
// Extract the type data we packed into the CImm in LowerRefTestFuncRef.
// We need to load the words from most significant to least significant
// order because of the way we bitshifted them in from the right.
>From f049ff60419c909200e033b5eeb4f62bd1a5c4dc Mon Sep 17 00:00:00 2001
From: Hood Chatham <roberthoodchatham at gmail.com>
Date: Tue, 8 Jul 2025 16:09:32 +0200
Subject: [PATCH 3/3] Try moving ref.test.func selection to WebAssemblyISelDAGToDAG
---
llvm/include/llvm/MC/MCSymbolWasm.h | 1 +
.../WebAssembly/WebAssemblyISelDAGToDAG.cpp | 38 ++++++++++
.../WebAssembly/WebAssemblyISelLowering.cpp | 66 ----------------
.../WebAssembly/WebAssemblyMCInstLower.cpp | 76 ++++++-------------
.../WebAssembly/WebAssemblyMCInstLower.h | 1 +
.../WebAssembly/WebAssemblyUtilities.cpp | 37 +++++++++
.../Target/WebAssembly/WebAssemblyUtilities.h | 7 ++
7 files changed, 109 insertions(+), 117 deletions(-)
diff --git a/llvm/include/llvm/MC/MCSymbolWasm.h b/llvm/include/llvm/MC/MCSymbolWasm.h
index beb6b975a4cc3..b58d6138cd045 100644
--- a/llvm/include/llvm/MC/MCSymbolWasm.h
+++ b/llvm/include/llvm/MC/MCSymbolWasm.h
@@ -12,6 +12,7 @@
#include "llvm/MC/MCSymbol.h"
#include "llvm/MC/MCSymbolTableEntry.h"
+
namespace llvm {
class MCSymbolWasm : public MCSymbol {
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp
index ac819cf5c1801..c54787f1f71f7 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelDAGToDAG.cpp
@@ -15,12 +15,14 @@
#include "WebAssembly.h"
#include "WebAssemblyISelLowering.h"
#include "WebAssemblyTargetMachine.h"
+#include "WebAssemblyUtilities.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/SelectionDAGISel.h"
#include "llvm/CodeGen/WasmEHFuncInfo.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Function.h" // To access function attributes.
#include "llvm/IR/IntrinsicsWebAssembly.h"
+#include "llvm/MC/MCSymbolWasm.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/raw_ostream.h"
@@ -189,6 +191,42 @@ void WebAssemblyDAGToDAGISel::Select(SDNode *Node) {
ReplaceNode(Node, TLSAlign);
return;
}
+  case Intrinsic::wasm_ref_test_func: {
+    // First emit the TABLE_GET instruction to convert function pointer ==>
+    // funcref
+    MachineFunction &MF = CurDAG->getMachineFunction();
+    auto PtrVT = MVT::getIntegerVT(MF.getDataLayout().getPointerSizeInBits());
+    MCSymbol *Table = WebAssembly::getOrCreateFunctionTableSymbol(
+        MF.getContext(), Subtarget);
+    SDValue TableSym = CurDAG->getMCSymbol(Table, PtrVT);
+    SDValue FuncRef = SDValue(
+        CurDAG->getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL,
+                               MVT::funcref, TableSym, Node->getOperand(1)),
+        0);
+
+    // Encode the signature into the type index placeholder (decoded in
+    // WebAssemblyMCInstLower.cpp). Operand 2 is the return type (Untyped for
+    // void), operands 3+ are the params. The return type is passed as the
+    // FIRST group: the decoder forwards that group to lowerTypeIndexOperand's
+    // first (Returns) argument.
+    SmallVector<MVT, 4> RetTys;
+    SmallVector<MVT, 1> ParamTys;
+    MVT RetVT = Node->getOperand(2).getValueType().getSimpleVT();
+    if (RetVT != MVT::Untyped)
+      RetTys.push_back(RetVT);
+    for (unsigned I = 3, E = Node->getNumOperands(); I < E; ++I)
+      ParamTys.push_back(Node->getOperand(I).getValueType().getSimpleVT());
+    APInt Sig = WebAssembly::encodeFunctionSignature(RetTys, ParamTys);
+
+    auto SigOp = CurDAG->getTargetConstant(
+        Sig, DL, EVT::getIntegerVT(*CurDAG->getContext(), Sig.getBitWidth()));
+    MachineSDNode *RefTestNode = CurDAG->getMachineNode(
+        WebAssembly::REF_TEST_FUNCREF, DL, MVT::i32, {SigOp, FuncRef});
+    // Like the cases above, return right after ReplaceNode instead of
+    // breaking out and continuing to process the already-replaced Node.
+    ReplaceNode(Node, RefTestNode);
+    return;
+  }
}
break;
}
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 2359d6b04aa97..d664bb0d5b17a 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -2255,72 +2255,6 @@ SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
DAG.getTargetExternalSymbol(TlsBase, PtrVT)),
0);
}
- case Intrinsic::wasm_ref_test_func: {
- // First emit the TABLE_GET instruction to convert function pointer ==>
- // funcref
- MachineFunction &MF = DAG.getMachineFunction();
- auto PtrVT = getPointerTy(MF.getDataLayout());
- MCSymbol *Table =
- WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(), Subtarget);
- SDValue TableSym = DAG.getMCSymbol(Table, PtrVT);
- SDValue FuncRef =
- SDValue(DAG.getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL,
- MVT::funcref, TableSym, Op.getOperand(1)),
- 0);
-
- // Encode the signature information into the type index placeholder.
- // This gets decoded and converted into the actual type signature in
- // WebAssemblyMCInstLower.cpp.
- auto NParams = Op.getNumOperands() - 2;
- auto BitWidth = (NParams + 1) * 64;
- auto Sig = APInt(BitWidth, 0);
- // The return type has to be a BlockType since it can be void.
- {
- SDValue Operand = Op.getOperand(2);
- MVT VT = Operand.getValueType().getSimpleVT();
- WebAssembly::BlockType V;
- if (VT == MVT::Untyped) {
- V = WebAssembly::BlockType::Void;
- } else if (VT == MVT::i32) {
- V = WebAssembly::BlockType::I32;
- } else if (VT == MVT::i64) {
- V = WebAssembly::BlockType::I64;
- } else if (VT == MVT::f32) {
- V = WebAssembly::BlockType::F32;
- } else if (VT == MVT::f64) {
- V = WebAssembly::BlockType::F64;
- } else {
- llvm_unreachable("Unhandled type!");
- }
- Sig |= (int64_t)V;
- }
- for (unsigned i = 3; i < Op.getNumOperands(); ++i) {
- SDValue Operand = Op.getOperand(i);
- MVT VT = Operand.getValueType().getSimpleVT();
- wasm::ValType V;
- if (VT == MVT::i32) {
- V = wasm::ValType::I32;
- } else if (VT == MVT::i64) {
- V = wasm::ValType::I64;
- } else if (VT == MVT::f32) {
- V = wasm::ValType::F32;
- } else if (VT == MVT::f64) {
- V = wasm::ValType::F64;
- } else {
- llvm_unreachable("Unhandled type!");
- }
- Sig <<= 64;
- Sig |= (int64_t)V;
- }
-
- SmallVector<SDValue, 4> Ops;
- Ops.push_back(DAG.getTargetConstant(
- Sig, DL, EVT::getIntegerVT(*DAG.getContext(), BitWidth)));
- Ops.push_back(FuncRef);
- return SDValue(
- DAG.getMachineNode(WebAssembly::REF_TEST_FUNCREF, DL, MVT::i32, Ops),
- 0);
- }
}
}
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
index 4e224c0766146..6ca046b22f503 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
@@ -21,6 +21,7 @@
#include "WebAssemblyAsmPrinter.h"
#include "WebAssemblyMachineFunctionInfo.h"
#include "WebAssemblyUtilities.h"
+#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/BinaryFormat/Wasm.h"
#include "llvm/CodeGen/AsmPrinter.h"
@@ -156,6 +157,29 @@ MCOperand WebAssemblyMCInstLower::lowerTypeIndexOperand(
return MCOperand::createExpr(Expr);
}
+MCOperand
+WebAssemblyMCInstLower::lowerEncodedFunctionSignature(const APInt &Sig) const {
+  // Inverse of WebAssembly::encodeFunctionSignature. 64-bit words are consumed
+  // from the most significant end: [N1][group 1...][N2][group 2...]. Group 1
+  // goes to lowerTypeIndexOperand's first argument and group 2 to its second;
+  // that function takes (Returns, Params), so return type(s) are encoded first.
+  unsigned WordIdx = Sig.getNumWords();
+  auto NextWord = [&]() -> uint64_t {
+    --WordIdx;
+    return Sig.extractBitsAsZExtValue(64, 64 * WordIdx);
+  };
+  auto ReadGroup = [&](SmallVectorImpl<wasm::ValType> &Out) {
+    int Count = NextWord();
+    for (int I = 0; I < Count; ++I)
+      Out.push_back(static_cast<wasm::ValType>(NextWord()));
+  };
+  SmallVector<wasm::ValType, 4> First;
+  SmallVector<wasm::ValType, 2> Second;
+  ReadGroup(First);
+  ReadGroup(Second);
+  return lowerTypeIndexOperand(std::move(First), std::move(Second));
+}
+
static void getFunctionReturns(const MachineInstr *MI,
SmallVectorImpl<wasm::ValType> &Returns) {
const Function &F = MI->getMF()->getFunction();
@@ -212,57 +236,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
if (Info.OperandType != WebAssembly::OPERAND_TYPEINDEX) {
llvm_unreachable("unexpected CImmediate operand");
}
- auto CImm = MO.getCImm()->getValue();
- auto NumWords = CImm.getNumWords() - 1;
- // Extract the type data we packed into the CImm in LowerRefTestFuncRef.
- // We need to load the words from most significant to least significant
- // order because of the way we bitshifted them in from the right.
- // The return type needs special handling because it could be void.
- auto ReturnType = static_cast<WebAssembly::BlockType>(
- CImm.extractBitsAsZExtValue(64, (NumWords - 1) * 64));
- SmallVector<wasm::ValType, 2> Returns;
- switch (ReturnType) {
- case WebAssembly::BlockType::Invalid:
- llvm_unreachable("Invalid return type");
- case WebAssembly::BlockType::I32:
- Returns = {wasm::ValType::I32};
- break;
- case WebAssembly::BlockType::I64:
- Returns = {wasm::ValType::I64};
- break;
- case WebAssembly::BlockType::F32:
- Returns = {wasm::ValType::F32};
- break;
- case WebAssembly::BlockType::F64:
- Returns = {wasm::ValType::F64};
- break;
- case WebAssembly::BlockType::Void:
- Returns = {};
- break;
- case WebAssembly::BlockType::Exnref:
- Returns = {wasm::ValType::EXNREF};
- break;
- case WebAssembly::BlockType::Externref:
- Returns = {wasm::ValType::EXTERNREF};
- break;
- case WebAssembly::BlockType::Funcref:
- Returns = {wasm::ValType::FUNCREF};
- break;
- case WebAssembly::BlockType::V128:
- Returns = {wasm::ValType::V128};
- break;
- case WebAssembly::BlockType::Multivalue: {
- llvm_unreachable("Invalid return type");
- }
- }
- SmallVector<wasm::ValType, 4> Params;
-
- for (int I = NumWords - 2; I >= 0; I--) {
- auto Val = CImm.extractBitsAsZExtValue(64, 64 * I);
- auto ParamType = static_cast<wasm::ValType>(Val);
- Params.push_back(ParamType);
- }
- MCOp = lowerTypeIndexOperand(std::move(Returns), std::move(Params));
+ MCOp = lowerEncodedFunctionSignature(MO.getCImm()->getValue());
break;
}
case MachineOperand::MO_Immediate: {
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.h b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.h
index 9f08499e5cde1..34404d93434bb 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.h
@@ -36,6 +36,7 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyMCInstLower {
MCOperand lowerSymbolOperand(const MachineOperand &MO, MCSymbol *Sym) const;
MCOperand lowerTypeIndexOperand(SmallVectorImpl<wasm::ValType> &&,
SmallVectorImpl<wasm::ValType> &&) const;
+ MCOperand lowerEncodedFunctionSignature(const APInt &Sig) const;
public:
WebAssemblyMCInstLower(MCContext &ctx, WebAssemblyAsmPrinter &printer)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyUtilities.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyUtilities.cpp
index 747ef18df8d65..991c80df57bf3 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyUtilities.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyUtilities.cpp
@@ -195,3 +195,40 @@ bool WebAssembly::canLowerReturn(size_t ResultSize,
const WebAssemblySubtarget *Subtarget) {
return ResultSize <= 1 || canLowerMultivalueReturn(Subtarget);
}
+
+APInt WebAssembly::encodeFunctionSignature(SmallVector<MVT, 4> &Params,
+                                           SmallVector<MVT, 1> &Returns) {
+  // Layout, from the most significant 64-bit word down:
+  //   [NParams][Params...][NReturns][Returns...]
+  // This must stay in sync with lowerEncodedFunctionSignature in
+  // WebAssemblyMCInstLower.cpp, which reads the words back in that order.
+  auto toWasmValType = [](MVT VT) {
+    if (VT == MVT::i32)
+      return wasm::ValType::I32;
+    if (VT == MVT::i64)
+      return wasm::ValType::I64;
+    if (VT == MVT::f32)
+      return wasm::ValType::F32;
+    if (VT == MVT::f64)
+      return wasm::ValType::F64;
+    llvm_unreachable("Unhandled type!");
+  };
+  // One word per type plus one count word per group.
+  uint64_t NParams = Params.size();
+  uint64_t NReturns = Returns.size();
+  APInt Sig(/*numBits=*/(NParams + NReturns + 2) * 64, 0);
+  // Shift the accumulated words up by one and place Word in the low 64 bits.
+  auto Append = [&Sig](uint64_t Word) {
+    Sig <<= 64;
+    Sig |= Word;
+  };
+
+  Sig |= NParams; // First word needs no shift; it ends up most significant.
+  for (MVT Param : Params)
+    Append(static_cast<uint64_t>(toWasmValType(Param)));
+  // Count word for the returns, shifted into place like every other word.
+  Append(NReturns);
+  for (MVT Return : Returns)
+    Append(static_cast<uint64_t>(toWasmValType(Return)));
+  return Sig;
+}
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyUtilities.h b/llvm/lib/Target/WebAssembly/WebAssemblyUtilities.h
index 046b1b5db2a79..d1e696ff59a69 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyUtilities.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyUtilities.h
@@ -15,6 +15,10 @@
#ifndef LLVM_LIB_TARGET_WEBASSEMBLY_UTILS_WEBASSEMBLYUTILITIES_H
#define LLVM_LIB_TARGET_WEBASSEMBLY_UTILS_WEBASSEMBLYUTILITIES_H
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/BinaryFormat/Wasm.h"
+#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/Support/CommandLine.h"
namespace llvm {
@@ -73,6 +77,9 @@ bool canLowerMultivalueReturn(const WebAssemblySubtarget *Subtarget);
/// memory.
bool canLowerReturn(size_t ResultSize, const WebAssemblySubtarget *Subtarget);
+APInt encodeFunctionSignature(SmallVector<MVT, 4> &Params,
+ SmallVector<MVT, 1> &Returns);
+
} // end namespace WebAssembly
} // end namespace llvm
More information about the llvm-commits
mailing list