[llvm] [WebAssembly] Implement GlobalISel (PR #157161)
Demetrius Kanios via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 29 11:25:05 PDT 2025
https://github.com/QuantumSegfault updated https://github.com/llvm/llvm-project/pull/157161
>From 44131ba8a7344c1011e4fa5b0cffa3923119bca5 Mon Sep 17 00:00:00 2001
From: Demetrius Kanios <demetrius at kanios.net>
Date: Sun, 28 Sep 2025 22:40:31 -0700
Subject: [PATCH 1/9] Prepare basic GlobalISel setup and implement
CallLowering::lowerFormalArguments and CallLowering::lowerReturn
---
llvm/lib/Target/WebAssembly/CMakeLists.txt | 4 +
.../GISel/WebAssemblyCallLowering.cpp | 687 ++++++++++++++++++
.../GISel/WebAssemblyCallLowering.h | 43 ++
.../GISel/WebAssemblyInstructionSelector.cpp | 0
.../GISel/WebAssemblyInstructionSelector.h | 0
.../GISel/WebAssemblyLegalizerInfo.cpp | 23 +
.../GISel/WebAssemblyLegalizerInfo.h | 29 +
.../GISel/WebAssemblyRegisterBankInfo.cpp | 0
.../GISel/WebAssemblyRegisterBankInfo.h | 0
.../WebAssembly/WebAssemblySubtarget.cpp | 30 +-
.../Target/WebAssembly/WebAssemblySubtarget.h | 14 +
.../WebAssembly/WebAssemblyTargetMachine.cpp | 30 +
12 files changed, 859 insertions(+), 1 deletion(-)
create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.h
create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp
create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.h
create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp
create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h
create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp
create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h
diff --git a/llvm/lib/Target/WebAssembly/CMakeLists.txt b/llvm/lib/Target/WebAssembly/CMakeLists.txt
index 1e83cbeac50d6..371d224efc1c5 100644
--- a/llvm/lib/Target/WebAssembly/CMakeLists.txt
+++ b/llvm/lib/Target/WebAssembly/CMakeLists.txt
@@ -15,6 +15,10 @@ tablegen(LLVM WebAssemblyGenSubtargetInfo.inc -gen-subtarget)
add_public_tablegen_target(WebAssemblyCommonTableGen)
add_llvm_target(WebAssemblyCodeGen
+ GISel/WebAssemblyCallLowering.cpp
+ GISel/WebAssemblyInstructionSelector.cpp
+ GISel/WebAssemblyRegisterBankInfo.cpp
+ GISel/WebAssemblyLegalizerInfo.cpp
WebAssemblyAddMissingPrototypes.cpp
WebAssemblyArgumentMove.cpp
WebAssemblyAsmPrinter.cpp
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
new file mode 100644
index 0000000000000..5949d26a83840
--- /dev/null
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
@@ -0,0 +1,687 @@
+//===-- WebAssemblyCallLowering.cpp - Call lowering for GlobalISel -*- C++ -*-//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file implements the lowering of LLVM calls to machine code calls for
+/// GlobalISel.
+///
+//===----------------------------------------------------------------------===//
+
+#include "WebAssemblyCallLowering.h"
+#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
+#include "WebAssemblyISelLowering.h"
+#include "WebAssemblyMachineFunctionInfo.h"
+#include "WebAssemblySubtarget.h"
+#include "WebAssemblyUtilities.h"
+#include "llvm/CodeGen/Analysis.h"
+#include "llvm/CodeGen/FunctionLoweringInfo.h"
+#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+#include "llvm/CodeGen/LowLevelTypeUtils.h"
+#include "llvm/CodeGenTypes/LowLevelType.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DebugLoc.h"
+
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/DiagnosticPrinter.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "wasm-call-lowering"
+
+using namespace llvm;
+
+// Several of the following methods are internal utilities defined in
+// CodeGen/GlobalISel/CallLowering.cpp
+// TODO: Find a better solution?
+
+// Internal utility from CallLowering.cpp
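+// For example, an `i8 signext` parameter carries the SExt flag and is
+// widened with G_SEXT; `zeroext` maps to G_ZEXT; with neither attribute the
+// upper bits are unspecified and G_ANYEXT is used.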
+static unsigned extendOpFromFlags(ISD::ArgFlagsTy Flags) {
+ if (Flags.isSExt())
+ return TargetOpcode::G_SEXT;
+ if (Flags.isZExt())
+ return TargetOpcode::G_ZEXT;
+ return TargetOpcode::G_ANYEXT;
+}
+
+// Internal utility from CallLowering.cpp
+/// Pack values \p SrcRegs to cover the vector type result \p DstRegs.
+static MachineInstrBuilder
+mergeVectorRegsToResultRegs(MachineIRBuilder &B, ArrayRef<Register> DstRegs,
+ ArrayRef<Register> SrcRegs) {
+ MachineRegisterInfo &MRI = *B.getMRI();
+ LLT LLTy = MRI.getType(DstRegs[0]);
+ LLT PartLLT = MRI.getType(SrcRegs[0]);
+
+ // Deal with v3s16 split into v2s16
+ LLT LCMTy = getCoverTy(LLTy, PartLLT);
+ if (LCMTy == LLTy) {
+ // Common case where no padding is needed.
+ assert(DstRegs.size() == 1);
+ return B.buildConcatVectors(DstRegs[0], SrcRegs);
+ }
+
+ // We need to create an unmerge to the result registers, which may require
+ // widening the original value.
+ Register UnmergeSrcReg;
+ if (LCMTy != PartLLT) {
+ assert(DstRegs.size() == 1);
+ return B.buildDeleteTrailingVectorElements(
+ DstRegs[0], B.buildMergeLikeInstr(LCMTy, SrcRegs));
+ } else {
+ // We don't need to widen anything if we're extracting a scalar which was
+ // promoted to a vector e.g. s8 -> v4s8 -> s8
+ assert(SrcRegs.size() == 1);
+ UnmergeSrcReg = SrcRegs[0];
+ }
+
+ int NumDst = LCMTy.getSizeInBits() / LLTy.getSizeInBits();
+
+ SmallVector<Register, 8> PadDstRegs(NumDst);
+ llvm::copy(DstRegs, PadDstRegs.begin());
+
+ // Create the excess dead defs for the unmerge.
+ for (int I = DstRegs.size(); I != NumDst; ++I)
+ PadDstRegs[I] = MRI.createGenericVirtualRegister(LLTy);
+
+ if (PadDstRegs.size() == 1)
+ return B.buildDeleteTrailingVectorElements(DstRegs[0], UnmergeSrcReg);
+ return B.buildUnmerge(PadDstRegs, UnmergeSrcReg);
+}
+
+// Internal utility from CallLowering.cpp
+/// Create a sequence of instructions to combine pieces split into register
+/// typed values to the original IR value. \p OrigRegs contains the destination
+/// value registers of type \p LLTy, and \p Regs contains the legalized pieces
+/// with type \p PartLLT. This is used for incoming values (physregs to vregs).
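+/// For instance (a minimal sketch): an i128 IR value legalized into two i64
+/// parts would be recombined here as
+///   %val:_(s128) = G_MERGE_VALUES %part0:_(s64), %part1:_(s64)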
+static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
+ ArrayRef<Register> Regs, LLT LLTy, LLT PartLLT,
+ const ISD::ArgFlagsTy Flags) {
+ MachineRegisterInfo &MRI = *B.getMRI();
+
+ if (PartLLT == LLTy) {
+ // We should have avoided introducing a new virtual register, and just
+ // directly assigned here.
+ assert(OrigRegs[0] == Regs[0]);
+ return;
+ }
+
+ if (PartLLT.getSizeInBits() == LLTy.getSizeInBits() && OrigRegs.size() == 1 &&
+ Regs.size() == 1) {
+ B.buildBitcast(OrigRegs[0], Regs[0]);
+ return;
+ }
+
+ // A vector PartLLT needs extending to LLTy's element size.
+ // E.g. <2 x s64> = G_SEXT <2 x s32>.
+ if (PartLLT.isVector() == LLTy.isVector() &&
+ PartLLT.getScalarSizeInBits() > LLTy.getScalarSizeInBits() &&
+ (!PartLLT.isVector() ||
+ PartLLT.getElementCount() == LLTy.getElementCount()) &&
+ OrigRegs.size() == 1 && Regs.size() == 1) {
+ Register SrcReg = Regs[0];
+
+ LLT LocTy = MRI.getType(SrcReg);
+
+ if (Flags.isSExt()) {
+ SrcReg = B.buildAssertSExt(LocTy, SrcReg, LLTy.getScalarSizeInBits())
+ .getReg(0);
+ } else if (Flags.isZExt()) {
+ SrcReg = B.buildAssertZExt(LocTy, SrcReg, LLTy.getScalarSizeInBits())
+ .getReg(0);
+ }
+
+ // Sometimes pointers are passed zero extended.
+ LLT OrigTy = MRI.getType(OrigRegs[0]);
+ if (OrigTy.isPointer()) {
+ LLT IntPtrTy = LLT::scalar(OrigTy.getSizeInBits());
+ B.buildIntToPtr(OrigRegs[0], B.buildTrunc(IntPtrTy, SrcReg));
+ return;
+ }
+
+ B.buildTrunc(OrigRegs[0], SrcReg);
+ return;
+ }
+
+ if (!LLTy.isVector() && !PartLLT.isVector()) {
+ assert(OrigRegs.size() == 1);
+ LLT OrigTy = MRI.getType(OrigRegs[0]);
+
+ unsigned SrcSize = PartLLT.getSizeInBits().getFixedValue() * Regs.size();
+ if (SrcSize == OrigTy.getSizeInBits())
+ B.buildMergeValues(OrigRegs[0], Regs);
+ else {
+ auto Widened = B.buildMergeLikeInstr(LLT::scalar(SrcSize), Regs);
+ B.buildTrunc(OrigRegs[0], Widened);
+ }
+
+ return;
+ }
+
+ if (PartLLT.isVector()) {
+ assert(OrigRegs.size() == 1);
+ SmallVector<Register> CastRegs(Regs);
+
+ // If PartLLT is a mismatched vector in both number of elements and element
+ // size, e.g. PartLLT == v2s64 and LLTy is v3s32, then first coerce it to
+ // have the same elt type, i.e. v4s32.
+    // TODO: Extend this coercion to element multiples other than just 2.
+ if (TypeSize::isKnownGT(PartLLT.getSizeInBits(), LLTy.getSizeInBits()) &&
+ PartLLT.getScalarSizeInBits() == LLTy.getScalarSizeInBits() * 2 &&
+ Regs.size() == 1) {
+ LLT NewTy = PartLLT.changeElementType(LLTy.getElementType())
+ .changeElementCount(PartLLT.getElementCount() * 2);
+ CastRegs[0] = B.buildBitcast(NewTy, Regs[0]).getReg(0);
+ PartLLT = NewTy;
+ }
+
+ if (LLTy.getScalarType() == PartLLT.getElementType()) {
+ mergeVectorRegsToResultRegs(B, OrigRegs, CastRegs);
+ } else {
+ unsigned I = 0;
+ LLT GCDTy = getGCDType(LLTy, PartLLT);
+
+ // We are both splitting a vector, and bitcasting its element types. Cast
+ // the source pieces into the appropriate number of pieces with the result
+ // element type.
+ for (Register SrcReg : CastRegs)
+ CastRegs[I++] = B.buildBitcast(GCDTy, SrcReg).getReg(0);
+ mergeVectorRegsToResultRegs(B, OrigRegs, CastRegs);
+ }
+
+ return;
+ }
+
+ assert(LLTy.isVector() && !PartLLT.isVector());
+
+ LLT DstEltTy = LLTy.getElementType();
+
+ // Pointer information was discarded. We'll need to coerce some register types
+ // to avoid violating type constraints.
+ LLT RealDstEltTy = MRI.getType(OrigRegs[0]).getElementType();
+
+ assert(DstEltTy.getSizeInBits() == RealDstEltTy.getSizeInBits());
+
+ if (DstEltTy == PartLLT) {
+ // Vector was trivially scalarized.
+
+ if (RealDstEltTy.isPointer()) {
+ for (Register Reg : Regs)
+ MRI.setType(Reg, RealDstEltTy);
+ }
+
+ B.buildBuildVector(OrigRegs[0], Regs);
+ } else if (DstEltTy.getSizeInBits() > PartLLT.getSizeInBits()) {
+ // Deal with vector with 64-bit elements decomposed to 32-bit
+ // registers. Need to create intermediate 64-bit elements.
+ SmallVector<Register, 8> EltMerges;
+ int PartsPerElt =
+ divideCeil(DstEltTy.getSizeInBits(), PartLLT.getSizeInBits());
+ LLT ExtendedPartTy = LLT::scalar(PartLLT.getSizeInBits() * PartsPerElt);
+
+ for (int I = 0, NumElts = LLTy.getNumElements(); I != NumElts; ++I) {
+ auto Merge =
+ B.buildMergeLikeInstr(ExtendedPartTy, Regs.take_front(PartsPerElt));
+ if (ExtendedPartTy.getSizeInBits() > RealDstEltTy.getSizeInBits())
+ Merge = B.buildTrunc(RealDstEltTy, Merge);
+ // Fix the type in case this is really a vector of pointers.
+ MRI.setType(Merge.getReg(0), RealDstEltTy);
+ EltMerges.push_back(Merge.getReg(0));
+ Regs = Regs.drop_front(PartsPerElt);
+ }
+
+ B.buildBuildVector(OrigRegs[0], EltMerges);
+ } else {
+ // Vector was split, and elements promoted to a wider type.
+ // FIXME: Should handle floating point promotions.
+ unsigned NumElts = LLTy.getNumElements();
+ LLT BVType = LLT::fixed_vector(NumElts, PartLLT);
+
+ Register BuildVec;
+ if (NumElts == Regs.size())
+ BuildVec = B.buildBuildVector(BVType, Regs).getReg(0);
+ else {
+ // Vector elements are packed in the inputs.
+ // e.g. we have a <4 x s16> but 2 x s32 in regs.
+ assert(NumElts > Regs.size());
+ LLT SrcEltTy = MRI.getType(Regs[0]);
+
+ LLT OriginalEltTy = MRI.getType(OrigRegs[0]).getElementType();
+
+ // Input registers contain packed elements.
+ // Determine how many elements per reg.
+ assert((SrcEltTy.getSizeInBits() % OriginalEltTy.getSizeInBits()) == 0);
+ unsigned EltPerReg =
+ (SrcEltTy.getSizeInBits() / OriginalEltTy.getSizeInBits());
+
+ SmallVector<Register, 0> BVRegs;
+ BVRegs.reserve(Regs.size() * EltPerReg);
+ for (Register R : Regs) {
+ auto Unmerge = B.buildUnmerge(OriginalEltTy, R);
+ for (unsigned K = 0; K < EltPerReg; ++K)
+ BVRegs.push_back(B.buildAnyExt(PartLLT, Unmerge.getReg(K)).getReg(0));
+ }
+
+ // We may have some more elements in BVRegs, e.g. if we have 2 s32 pieces
+    // for a <3 x s16> vector. We should have fewer than EltPerReg extra items.
+ if (BVRegs.size() > NumElts) {
+ assert((BVRegs.size() - NumElts) < EltPerReg);
+ BVRegs.truncate(NumElts);
+ }
+ BuildVec = B.buildBuildVector(BVType, BVRegs).getReg(0);
+ }
+ B.buildTrunc(OrigRegs[0], BuildVec);
+ }
+}
+
+// Internal utility from CallLowering.cpp
+/// Create a sequence of instructions to expand the value in \p SrcReg (of type
+/// \p SrcTy) to the types in \p DstRegs (of type \p PartTy). \p ExtendOp should
+/// contain the type of scalar value extension if necessary.
+///
+/// This is used for outgoing values (vregs to physregs)
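+///
+/// For instance (a minimal sketch), the inverse of the incoming case: an
+/// i128 value leaving as two i64 parts is split here as
+///   %part0:_(s64), %part1:_(s64) = G_UNMERGE_VALUES %val:_(s128)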
+static void buildCopyToRegs(MachineIRBuilder &B, ArrayRef<Register> DstRegs,
+ Register SrcReg, LLT SrcTy, LLT PartTy,
+ unsigned ExtendOp = TargetOpcode::G_ANYEXT) {
+ // We could just insert a regular copy, but this is unreachable at the moment.
+ assert(SrcTy != PartTy && "identical part types shouldn't reach here");
+
+ const TypeSize PartSize = PartTy.getSizeInBits();
+
+ if (PartTy.isVector() == SrcTy.isVector() &&
+ PartTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits()) {
+ assert(DstRegs.size() == 1);
+ B.buildInstr(ExtendOp, {DstRegs[0]}, {SrcReg});
+ return;
+ }
+
+ if (SrcTy.isVector() && !PartTy.isVector() &&
+ TypeSize::isKnownGT(PartSize, SrcTy.getElementType().getSizeInBits())) {
+ // Vector was scalarized, and the elements extended.
+ auto UnmergeToEltTy = B.buildUnmerge(SrcTy.getElementType(), SrcReg);
+ for (int i = 0, e = DstRegs.size(); i != e; ++i)
+ B.buildAnyExt(DstRegs[i], UnmergeToEltTy.getReg(i));
+ return;
+ }
+
+ if (SrcTy.isVector() && PartTy.isVector() &&
+ PartTy.getSizeInBits() == SrcTy.getSizeInBits() &&
+ ElementCount::isKnownLT(SrcTy.getElementCount(),
+ PartTy.getElementCount())) {
+ // A coercion like: v2f32 -> v4f32 or nxv2f32 -> nxv4f32
+ Register DstReg = DstRegs.front();
+ B.buildPadVectorWithUndefElements(DstReg, SrcReg);
+ return;
+ }
+
+ LLT GCDTy = getGCDType(SrcTy, PartTy);
+ if (GCDTy == PartTy) {
+    // If this is already evenly divisible, we can create a simple unmerge.
+ B.buildUnmerge(DstRegs, SrcReg);
+ return;
+ }
+
+ if (SrcTy.isVector() && !PartTy.isVector() &&
+ SrcTy.getScalarSizeInBits() > PartTy.getSizeInBits()) {
+ LLT ExtTy =
+ LLT::vector(SrcTy.getElementCount(),
+ LLT::scalar(PartTy.getScalarSizeInBits() * DstRegs.size() /
+ SrcTy.getNumElements()));
+ auto Ext = B.buildAnyExt(ExtTy, SrcReg);
+ B.buildUnmerge(DstRegs, Ext);
+ return;
+ }
+
+ MachineRegisterInfo &MRI = *B.getMRI();
+ LLT DstTy = MRI.getType(DstRegs[0]);
+ LLT LCMTy = getCoverTy(SrcTy, PartTy);
+
+ if (PartTy.isVector() && LCMTy == PartTy) {
+ assert(DstRegs.size() == 1);
+ B.buildPadVectorWithUndefElements(DstRegs[0], SrcReg);
+ return;
+ }
+
+ const unsigned DstSize = DstTy.getSizeInBits();
+ const unsigned SrcSize = SrcTy.getSizeInBits();
+ unsigned CoveringSize = LCMTy.getSizeInBits();
+
+ Register UnmergeSrc = SrcReg;
+
+ if (!LCMTy.isVector() && CoveringSize != SrcSize) {
+ // For scalars, it's common to be able to use a simple extension.
+ if (SrcTy.isScalar() && DstTy.isScalar()) {
+ CoveringSize = alignTo(SrcSize, DstSize);
+ LLT CoverTy = LLT::scalar(CoveringSize);
+ UnmergeSrc = B.buildInstr(ExtendOp, {CoverTy}, {SrcReg}).getReg(0);
+ } else {
+ // Widen to the common type.
+ // FIXME: This should respect the extend type
+ Register Undef = B.buildUndef(SrcTy).getReg(0);
+ SmallVector<Register, 8> MergeParts(1, SrcReg);
+ for (unsigned Size = SrcSize; Size != CoveringSize; Size += SrcSize)
+ MergeParts.push_back(Undef);
+ UnmergeSrc = B.buildMergeLikeInstr(LCMTy, MergeParts).getReg(0);
+ }
+ }
+
+ if (LCMTy.isVector() && CoveringSize != SrcSize)
+ UnmergeSrc = B.buildPadVectorWithUndefElements(LCMTy, SrcReg).getReg(0);
+
+ B.buildUnmerge(DstRegs, UnmergeSrc);
+}
+
+// Test whether the given calling convention is supported.
+static bool callingConvSupported(CallingConv::ID CallConv) {
+ // We currently support the language-independent target-independent
+ // conventions. We don't yet have a way to annotate calls with properties like
+ // "cold", and we don't have any call-clobbered registers, so these are mostly
+ // all handled the same.
+ return CallConv == CallingConv::C || CallConv == CallingConv::Fast ||
+ CallConv == CallingConv::Cold ||
+ CallConv == CallingConv::PreserveMost ||
+ CallConv == CallingConv::PreserveAll ||
+ CallConv == CallingConv::CXX_FAST_TLS ||
+ CallConv == CallingConv::WASM_EmscriptenInvoke ||
+ CallConv == CallingConv::Swift;
+}
+
+static void fail(MachineIRBuilder &MIRBuilder, const char *Msg) {
+ MachineFunction &MF = MIRBuilder.getMF();
+ MIRBuilder.getContext().diagnose(
+ DiagnosticInfoUnsupported(MF.getFunction(), Msg, MIRBuilder.getDL()));
+}
+
+WebAssemblyCallLowering::WebAssemblyCallLowering(
+ const WebAssemblyTargetLowering &TLI)
+ : CallLowering(&TLI) {}
+
+bool WebAssemblyCallLowering::canLowerReturn(MachineFunction &MF,
+ CallingConv::ID CallConv,
+ SmallVectorImpl<BaseArgInfo> &Outs,
+ bool IsVarArg) const {
+ return WebAssembly::canLowerReturn(Outs.size(),
+ &MF.getSubtarget<WebAssemblySubtarget>());
+}
+
+bool WebAssemblyCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
+ const Value *Val,
+ ArrayRef<Register> VRegs,
+ FunctionLoweringInfo &FLI,
+ Register SwiftErrorVReg) const {
+ auto MIB = MIRBuilder.buildInstrNoInsert(WebAssembly::RETURN);
+
+ assert(((Val && !VRegs.empty()) || (!Val && VRegs.empty())) &&
+ "Return value without a vreg");
+
+ if (Val && !FLI.CanLowerReturn) {
+ insertSRetStores(MIRBuilder, Val->getType(), VRegs, FLI.DemoteRegister);
+ } else if (!VRegs.empty()) {
+ MachineFunction &MF = MIRBuilder.getMF();
+ const Function &F = MF.getFunction();
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ const WebAssemblyTargetLowering &TLI = *getTLI<WebAssemblyTargetLowering>();
+ auto &DL = F.getDataLayout();
+ LLVMContext &Ctx = Val->getType()->getContext();
+
+ SmallVector<EVT, 4> SplitEVTs;
+ ComputeValueVTs(TLI, DL, Val->getType(), SplitEVTs);
+ assert(VRegs.size() == SplitEVTs.size() &&
+ "For each split Type there should be exactly one VReg.");
+
+ SmallVector<ArgInfo, 8> SplitArgs;
+ CallingConv::ID CallConv = F.getCallingConv();
+
+ unsigned i = 0;
+ for (auto SplitEVT : SplitEVTs) {
+ Register CurVReg = VRegs[i];
+ ArgInfo CurArgInfo = ArgInfo{CurVReg, SplitEVT.getTypeForEVT(Ctx), 0};
+ setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, F);
+
+ splitToValueTypes(CurArgInfo, SplitArgs, DL, CallConv);
+ ++i;
+ }
+
+ for (auto &Arg : SplitArgs) {
+ EVT OrigVT = TLI.getValueType(DL, Arg.Ty);
+ MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT);
+ LLT OrigLLT = getLLTForType(*Arg.Ty, DL);
+ LLT NewLLT = getLLTForMVT(NewVT);
+
+    // Determine how many registers this value is split across for this
+    // calling convention.
+ unsigned NumParts =
+ TLI.getNumRegistersForCallingConv(Ctx, CallConv, OrigVT);
+
+ ISD::ArgFlagsTy OrigFlags = Arg.Flags[0];
+ Arg.Flags.clear();
+
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ ISD::ArgFlagsTy Flags = OrigFlags;
+ if (Part == 0) {
+ Flags.setSplit();
+ } else {
+ Flags.setOrigAlign(Align(1));
+ if (Part == NumParts - 1)
+ Flags.setSplitEnd();
+ }
+
+ Arg.Flags.push_back(Flags);
+ }
+
+ Arg.OrigRegs.assign(Arg.Regs.begin(), Arg.Regs.end());
+ if (NumParts != 1 || OrigVT != NewVT) {
+ // If we can't directly assign the register, we need one or more
+ // intermediate values.
+ Arg.Regs.resize(NumParts);
+
+        // For each part, create a vreg to hold its component of the outgoing
+        // return value. buildCopyToRegs below fills these by splitting the
+        // original vreg.
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ Arg.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT);
+ }
+ buildCopyToRegs(MIRBuilder, Arg.Regs, Arg.OrigRegs[0], OrigLLT, NewLLT,
+ extendOpFromFlags(Arg.Flags[0]));
+ }
+
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ MIB.addUse(Arg.Regs[Part]);
+ }
+ }
+ }
+
+ if (SwiftErrorVReg) {
+ llvm_unreachable("WASM does not `supportSwiftError`, yet SwiftErrorVReg is "
+ "improperly valid.");
+ }
+
+ MIRBuilder.insertInstr(MIB);
+ return true;
+}
+
+static unsigned getWASMArgOpcode(MVT ArgType) {
+#define MVT_CASE(type) \
+ case MVT::type: \
+ return WebAssembly::ARGUMENT_##type;
+
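+  // Each MVT_CASE(type) expands to
+  //   case MVT::type: return WebAssembly::ARGUMENT_type;
+  // mapping every legal wasm value type to the ARGUMENT_* pseudo that
+  // materializes the parameter at a given index.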
+ switch (ArgType.SimpleTy) {
+ MVT_CASE(i32)
+ MVT_CASE(i64)
+ MVT_CASE(f32)
+ MVT_CASE(f64)
+ MVT_CASE(funcref)
+ MVT_CASE(externref)
+ MVT_CASE(exnref)
+ MVT_CASE(v16i8)
+ MVT_CASE(v8i16)
+ MVT_CASE(v4i32)
+ MVT_CASE(v2i64)
+ MVT_CASE(v4f32)
+ MVT_CASE(v2f64)
+ MVT_CASE(v8f16)
+ default:
+ break;
+ }
+ llvm_unreachable("Found unexpected type for WASM argument");
+
+#undef MVT_CASE
+}
+
+bool WebAssemblyCallLowering::lowerFormalArguments(
+ MachineIRBuilder &MIRBuilder, const Function &F,
+ ArrayRef<ArrayRef<Register>> VRegs, FunctionLoweringInfo &FLI) const {
+
+ MachineFunction &MF = MIRBuilder.getMF();
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+ WebAssemblyFunctionInfo *MFI = MF.getInfo<WebAssemblyFunctionInfo>();
+ const DataLayout &DL = F.getDataLayout();
+ auto &TLI = *getTLI<WebAssemblyTargetLowering>();
+ LLVMContext &Ctx = MIRBuilder.getContext();
+ const CallingConv::ID CallConv = F.getCallingConv();
+
+ if (!callingConvSupported(F.getCallingConv())) {
+ fail(MIRBuilder, "WebAssembly doesn't support non-C calling conventions");
+ return false;
+ }
+
+ // Set up the live-in for the incoming ARGUMENTS.
+ MF.getRegInfo().addLiveIn(WebAssembly::ARGUMENTS);
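+  // Each incoming value is then modeled by an ARGUMENT_* pseudo carrying its
+  // parameter index; roughly, `int f(int x)` yields
+  //   %x:_(s32) = ARGUMENT_i32 0
+  // with the pseudo implicitly reading ARGUMENTS.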
+
+ SmallVector<ArgInfo, 8> SplitArgs;
+
+ if (!FLI.CanLowerReturn) {
+ dbgs() << "grath\n";
+ insertSRetIncomingArgument(F, SplitArgs, FLI.DemoteRegister, MRI, DL);
+ }
+ unsigned i = 0;
+
+ bool HasSwiftErrorArg = false;
+ bool HasSwiftSelfArg = false;
+ for (const auto &Arg : F.args()) {
+ ArgInfo OrigArg{VRegs[i], Arg.getType(), i};
+ setArgFlags(OrigArg, i + AttributeList::FirstArgIndex, DL, F);
+
+ HasSwiftSelfArg |= Arg.hasSwiftSelfAttr();
+ HasSwiftErrorArg |= Arg.hasSwiftErrorAttr();
+ if (Arg.hasInAllocaAttr()) {
+ fail(MIRBuilder, "WebAssembly hasn't implemented inalloca arguments");
+ return false;
+ }
+ if (Arg.hasNestAttr()) {
+ fail(MIRBuilder, "WebAssembly hasn't implemented nest arguments");
+ return false;
+ }
+ splitToValueTypes(OrigArg, SplitArgs, DL, F.getCallingConv());
+ ++i;
+ }
+
+ unsigned FinalArgIdx = 0;
+ for (auto &Arg : SplitArgs) {
+ EVT OrigVT = TLI.getValueType(DL, Arg.Ty);
+ MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT);
+ LLT OrigLLT = getLLTForType(*Arg.Ty, DL);
+ LLT NewLLT = getLLTForMVT(NewVT);
+
+    // Determine how many registers this value is split across for this
+    // calling convention.
+ unsigned NumParts =
+ TLI.getNumRegistersForCallingConv(Ctx, CallConv, OrigVT);
+
+ ISD::ArgFlagsTy OrigFlags = Arg.Flags[0];
+ Arg.Flags.clear();
+
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ ISD::ArgFlagsTy Flags = OrigFlags;
+ if (Part == 0) {
+ Flags.setSplit();
+ } else {
+ Flags.setOrigAlign(Align(1));
+ if (Part == NumParts - 1)
+ Flags.setSplitEnd();
+ }
+
+ Arg.Flags.push_back(Flags);
+ }
+
+ Arg.OrigRegs.assign(Arg.Regs.begin(), Arg.Regs.end());
+ if (NumParts != 1 || OrigVT != NewVT) {
+ // If we can't directly assign the register, we need one or more
+ // intermediate values.
+ Arg.Regs.resize(NumParts);
+
+ // For each split register, create and assign a vreg that will store
+ // the incoming component of the larger value. These will later be
+ // merged to form the final vreg.
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ Arg.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT);
+ }
+ buildCopyFromRegs(MIRBuilder, Arg.OrigRegs, Arg.Regs, OrigLLT, NewLLT,
+ Arg.Flags[0]);
+ }
+
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ MIRBuilder.buildInstr(getWASMArgOpcode(NewVT))
+ .addDef(Arg.Regs[Part])
+ .addImm(FinalArgIdx);
+ MFI->addParam(NewVT);
+ ++FinalArgIdx;
+ }
+ }
+
+
+  // For swiftcc, emit additional swiftself and swifterror arguments if they
+  // are not already present. These extra arguments are also added to the
+  // callee's signature; they are needed so that caller and callee signatures
+  // match for indirect calls.
+ auto PtrVT = TLI.getPointerTy(DL);
+ if (CallConv == CallingConv::Swift) {
+ if (!HasSwiftSelfArg) {
+ MFI->addParam(PtrVT);
+ }
+ if (!HasSwiftErrorArg) {
+ MFI->addParam(PtrVT);
+ }
+ }
+
+ // Varargs are copied into a buffer allocated by the caller, and a pointer to
+ // the buffer is passed as an argument.
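+  // Roughly: for `int sum(int n, ...)`, the caller will have stored the
+  // variadic values into a stack buffer, so the callee sees one extra
+  // trailing pointer parameter that va_start/va_arg lowering reads from.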
+ if (F.isVarArg()) {
+ auto PtrVT = TLI.getPointerTy(DL);
+ Register VarargVreg = MF.getRegInfo().createGenericVirtualRegister(
+ getLLTForType(*PointerType::get(Ctx, 0), DL));
+ MFI->setVarargBufferVreg(VarargVreg);
+
+ MIRBuilder.buildInstr(getWASMArgOpcode(PtrVT))
+ .addDef(VarargVreg)
+ .addImm(FinalArgIdx);
+
+ MFI->addParam(PtrVT);
+ ++FinalArgIdx;
+ }
+
+ // Record the number and types of arguments and results.
+ SmallVector<MVT, 4> Params;
+ SmallVector<MVT, 4> Results;
+ computeSignatureVTs(MF.getFunction().getFunctionType(), &MF.getFunction(),
+ MF.getFunction(), MF.getTarget(), Params, Results);
+ for (MVT VT : Results)
+ MFI->addResult(VT);
+
+ // TODO: Use signatures in WebAssemblyMachineFunctionInfo too and unify
+  // the param logic here with computeSignatureVTs.
+ assert(MFI->getParams().size() == Params.size() &&
+ std::equal(MFI->getParams().begin(), MFI->getParams().end(),
+ Params.begin()));
+ return true;
+}
+
+bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
+ CallLoweringInfo &Info) const {
+ return false;
+}
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.h
new file mode 100644
index 0000000000000..d22f7cbd17eb3
--- /dev/null
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.h
@@ -0,0 +1,43 @@
+//===-- WebAssemblyCallLowering.h - Call lowering for GlobalISel -*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file describes how to lower LLVM calls to machine code calls.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_WEBASSEMBLY_GISEL_WEBASSEMBLYCALLLOWERING_H
+#define LLVM_LIB_TARGET_WEBASSEMBLY_GISEL_WEBASSEMBLYCALLLOWERING_H
+
+#include "WebAssemblyISelLowering.h"
+#include "llvm/CodeGen/GlobalISel/CallLowering.h"
+#include "llvm/IR/CallingConv.h"
+
+namespace llvm {
+
+class WebAssemblyTargetLowering;
+
+class WebAssemblyCallLowering : public CallLowering {
+public:
+ WebAssemblyCallLowering(const WebAssemblyTargetLowering &TLI);
+
+ bool canLowerReturn(MachineFunction &MF, CallingConv::ID CallConv,
+ SmallVectorImpl<BaseArgInfo> &Outs,
+ bool IsVarArg) const override;
+ bool lowerReturn(MachineIRBuilder &MIRBuilder, const Value *Val,
+ ArrayRef<Register> VRegs, FunctionLoweringInfo &FLI,
+ Register SwiftErrorVReg) const override;
+ bool lowerFormalArguments(MachineIRBuilder &MIRBuilder, const Function &F,
+ ArrayRef<ArrayRef<Register>> VRegs,
+ FunctionLoweringInfo &FLI) const override;
+ bool lowerCall(MachineIRBuilder &MIRBuilder,
+ CallLoweringInfo &Info) const override;
+};
+} // namespace llvm
+
+#endif
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.h
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp
new file mode 100644
index 0000000000000..3acdabb5612cc
--- /dev/null
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp
@@ -0,0 +1,23 @@
+//===- WebAssemblyLegalizerInfo.cpp ------------------------------*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file implements the targeting of the MachineLegalizer class for
+/// WebAssembly.
+//===----------------------------------------------------------------------===//
+
+#include "WebAssemblyLegalizerInfo.h"
+
+#define DEBUG_TYPE "wasm-legalinfo"
+
+using namespace llvm;
+using namespace LegalizeActions;
+
+WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo(
+ const WebAssemblySubtarget &ST) {
+ getLegacyLegalizerInfo().computeTables();
+}
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h
new file mode 100644
index 0000000000000..c02205fc7ae0d
--- /dev/null
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h
@@ -0,0 +1,29 @@
+//===- WebAssemblyLegalizerInfo.h --------------------------------*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file declares the targeting of the MachineLegalizer class for
+/// WebAssembly.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_WEBASSEMBLY_GISEL_WEBASSEMBLYMACHINELEGALIZER_H
+#define LLVM_LIB_TARGET_WEBASSEMBLY_GISEL_WEBASSEMBLYMACHINELEGALIZER_H
+
+#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
+
+namespace llvm {
+
+class WebAssemblySubtarget;
+
+/// This class provides the information for the WebAssembly target legalizer
+/// for GlobalISel.
+class WebAssemblyLegalizerInfo : public LegalizerInfo {
+public:
+ WebAssemblyLegalizerInfo(const WebAssemblySubtarget &ST);
+};
+} // namespace llvm
+#endif
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp
index a3ce40f0297ec..3ea8b9f85819f 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp
@@ -13,8 +13,12 @@
//===----------------------------------------------------------------------===//
#include "WebAssemblySubtarget.h"
+#include "GISel/WebAssemblyCallLowering.h"
+#include "GISel/WebAssemblyLegalizerInfo.h"
+#include "GISel/WebAssemblyRegisterBankInfo.h"
#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
#include "WebAssemblyInstrInfo.h"
+#include "WebAssemblyTargetMachine.h"
#include "llvm/MC/TargetRegistry.h"
using namespace llvm;
@@ -66,7 +70,15 @@ WebAssemblySubtarget::WebAssemblySubtarget(const Triple &TT,
const TargetMachine &TM)
: WebAssemblyGenSubtargetInfo(TT, CPU, /*TuneCPU*/ CPU, FS),
TargetTriple(TT), InstrInfo(initializeSubtargetDependencies(CPU, FS)),
- TLInfo(TM, *this) {}
+ TLInfo(TM, *this) {
+ CallLoweringInfo.reset(new WebAssemblyCallLowering(*getTargetLowering()));
+ Legalizer.reset(new WebAssemblyLegalizerInfo(*this));
+ /*auto *RBI = new WebAssemblyRegisterBankInfo(*getRegisterInfo());
+ RegBankInfo.reset(RBI);
+
+ InstSelector.reset(createWebAssemblyInstructionSelector(
+ *static_cast<const WebAssemblyTargetMachine *>(&TM), *this, *RBI));*/
+}
bool WebAssemblySubtarget::enableAtomicExpand() const {
// If atomics are disabled, atomic ops are lowered instead of expanded
@@ -81,3 +93,19 @@ bool WebAssemblySubtarget::enableMachineScheduler() const {
}
bool WebAssemblySubtarget::useAA() const { return true; }
+
+const CallLowering *WebAssemblySubtarget::getCallLowering() const {
+ return CallLoweringInfo.get();
+}
+
+InstructionSelector *WebAssemblySubtarget::getInstructionSelector() const {
+ return InstSelector.get();
+}
+
+const LegalizerInfo *WebAssemblySubtarget::getLegalizerInfo() const {
+ return Legalizer.get();
+}
+
+const RegisterBankInfo *WebAssemblySubtarget::getRegBankInfo() const {
+ return RegBankInfo.get();
+}
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h
index 2f88bbba05d00..c195f995009b1 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h
@@ -20,6 +20,10 @@
#include "WebAssemblyISelLowering.h"
#include "WebAssemblyInstrInfo.h"
#include "WebAssemblySelectionDAGInfo.h"
+#include "llvm/CodeGen/GlobalISel/CallLowering.h"
+#include "llvm/CodeGen/GlobalISel/InstructionSelector.h"
+#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
+#include "llvm/CodeGen/RegisterBankInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include <string>
@@ -64,6 +68,11 @@ class WebAssemblySubtarget final : public WebAssemblyGenSubtargetInfo {
WebAssemblySelectionDAGInfo TSInfo;
WebAssemblyTargetLowering TLInfo;
+ std::unique_ptr<CallLowering> CallLoweringInfo;
+ std::unique_ptr<InstructionSelector> InstSelector;
+ std::unique_ptr<LegalizerInfo> Legalizer;
+ std::unique_ptr<RegisterBankInfo> RegBankInfo;
+
WebAssemblySubtarget &initializeSubtargetDependencies(StringRef CPU,
StringRef FS);
@@ -118,6 +127,11 @@ class WebAssemblySubtarget final : public WebAssemblyGenSubtargetInfo {
/// Parses features string setting specified subtarget options. Definition of
/// function is auto generated by tblgen.
void ParseSubtargetFeatures(StringRef CPU, StringRef TuneCPU, StringRef FS);
+
+ const CallLowering *getCallLowering() const override;
+ InstructionSelector *getInstructionSelector() const override;
+ const LegalizerInfo *getLegalizerInfo() const override;
+ const RegisterBankInfo *getRegBankInfo() const override;
};
} // end namespace llvm
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp
index 6827ee6527947..84d4315ca9fc0 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp
@@ -20,6 +20,10 @@
#include "WebAssemblyTargetObjectFile.h"
#include "WebAssemblyTargetTransformInfo.h"
#include "WebAssemblyUtilities.h"
+#include "llvm/CodeGen/GlobalISel/IRTranslator.h"
+#include "llvm/CodeGen/GlobalISel/InstructionSelect.h"
+#include "llvm/CodeGen/GlobalISel/Legalizer.h"
+#include "llvm/CodeGen/GlobalISel/RegBankSelect.h"
#include "llvm/CodeGen/MIRParser/MIParser.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/RegAllocRegistry.h"
@@ -92,6 +96,7 @@ LLVMInitializeWebAssemblyTarget() {
// Register backend passes
auto &PR = *PassRegistry::getPassRegistry();
+ initializeGlobalISel(PR);
initializeWebAssemblyAddMissingPrototypesPass(PR);
initializeWebAssemblyLowerEmscriptenEHSjLjPass(PR);
initializeLowerGlobalDtorsLegacyPassPass(PR);
@@ -455,6 +460,11 @@ class WebAssemblyPassConfig final : public TargetPassConfig {
// No reg alloc
bool addRegAssignAndRewriteOptimized() override { return false; }
+
+ bool addIRTranslator() override;
+ bool addLegalizeMachineIR() override;
+ bool addRegBankSelect() override;
+ bool addGlobalInstructionSelect() override;
};
} // end anonymous namespace
@@ -675,6 +685,26 @@ bool WebAssemblyPassConfig::addPreISel() {
return false;
}
+bool WebAssemblyPassConfig::addIRTranslator() {
+ addPass(new IRTranslator());
+ return false;
+}
+
+bool WebAssemblyPassConfig::addLegalizeMachineIR() {
+ addPass(new Legalizer());
+ return false;
+}
+
+bool WebAssemblyPassConfig::addRegBankSelect() {
+ addPass(new RegBankSelect());
+ return false;
+}
+
+bool WebAssemblyPassConfig::addGlobalInstructionSelect() {
+ addPass(new InstructionSelect(getOptLevel()));
+ return false;
+}
+
yaml::MachineFunctionInfo *
WebAssemblyTargetMachine::createDefaultFuncInfoYAML() const {
return new yaml::WebAssemblyFunctionInfo();
>From ea32de61794d3d4c37a84d72166dcb66091b86f1 Mon Sep 17 00:00:00 2001
From: Demetrius Kanios <demetrius at kanios.net>
Date: Sun, 28 Sep 2025 22:41:02 -0700
Subject: [PATCH 2/9] Implement WebAssemblyCallLowering::lowerCall
---
.../GISel/WebAssemblyCallLowering.cpp | 453 +++++++++++++++++-
1 file changed, 451 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
index 5949d26a83840..8956932b403ef 100644
--- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
@@ -14,14 +14,21 @@
#include "WebAssemblyCallLowering.h"
#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
+#include "Utils/WasmAddressSpaces.h"
#include "WebAssemblyISelLowering.h"
#include "WebAssemblyMachineFunctionInfo.h"
#include "WebAssemblySubtarget.h"
#include "WebAssemblyUtilities.h"
+#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/CodeGen/Analysis.h"
+#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/FunctionLoweringInfo.h"
+#include "llvm/CodeGen/GlobalISel/CallLowering.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/LowLevelTypeUtils.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGenTypes/LowLevelType.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/DataLayout.h"
@@ -29,7 +36,10 @@
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
+#include "llvm/MC/MCSymbolWasm.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
+#include <cassert>
#define DEBUG_TYPE "wasm-call-lowering"
@@ -555,7 +565,6 @@ bool WebAssemblyCallLowering::lowerFormalArguments(
SmallVector<ArgInfo, 8> SplitArgs;
if (!FLI.CanLowerReturn) {
- dbgs() << "grath\n";
insertSRetIncomingArgument(F, SplitArgs, FLI.DemoteRegister, MRI, DL);
}
unsigned i = 0;
@@ -683,5 +692,445 @@ bool WebAssemblyCallLowering::lowerFormalArguments(
bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
CallLoweringInfo &Info) const {
- return false;
+ MachineFunction &MF = MIRBuilder.getMF();
+  const DataLayout &DL = MIRBuilder.getDataLayout();
+ LLVMContext &Ctx = MIRBuilder.getContext();
+ const WebAssemblyTargetLowering &TLI = *getTLI<WebAssemblyTargetLowering>();
+ MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
+ const WebAssemblySubtarget &Subtarget = MF.getSubtarget<WebAssemblySubtarget>();
+
+ CallingConv::ID CallConv = Info.CallConv;
+ if (!callingConvSupported(CallConv)) {
+ fail(MIRBuilder,
+ "WebAssembly doesn't support language-specific or target-specific "
+ "calling conventions yet");
+ return false;
+ }
+
+ // TODO: investigate "PatchPoint"
+ /*
+ if (Info.IsPatchPoint) {
+ fail(MIRBuilder, "WebAssembly doesn't support patch point yet");
+ return false;
+ }
+ */
+
+ if (Info.IsTailCall) {
+ Info.LoweredTailCall = true;
+ auto NoTail = [&](const char *Msg) {
+ if (Info.CB && Info.CB->isMustTailCall())
+ fail(MIRBuilder, Msg);
+ Info.LoweredTailCall = false;
+ };
+
+ if (!Subtarget.hasTailCall())
+ NoTail("WebAssembly 'tail-call' feature not enabled");
+
+ // Varargs calls cannot be tail calls because the buffer is on the stack
+ if (Info.IsVarArg)
+ NoTail("WebAssembly does not support varargs tail calls");
+
+ // Do not tail call unless caller and callee return types match
+ const Function &F = MF.getFunction();
+ const TargetMachine &TM = TLI.getTargetMachine();
+ Type *RetTy = F.getReturnType();
+ SmallVector<MVT, 4> CallerRetTys;
+ SmallVector<MVT, 4> CalleeRetTys;
+ computeLegalValueVTs(F, TM, RetTy, CallerRetTys);
+ computeLegalValueVTs(F, TM, Info.OrigRet.Ty, CalleeRetTys);
+ bool TypesMatch = CallerRetTys.size() == CalleeRetTys.size() &&
+ std::equal(CallerRetTys.begin(), CallerRetTys.end(),
+ CalleeRetTys.begin());
+ if (!TypesMatch)
+ NoTail("WebAssembly tail call requires caller and callee return types to "
+ "match");
+
+ // If pointers to local stack values are passed, we cannot tail call
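+    // (E.g. `return g(&local);` where `local` is an alloca in the caller
+    // must not become a return_call: the caller's frame is deallocated
+    // before the callee dereferences the pointer.)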
+ if (Info.CB) {
+ for (auto &Arg : Info.CB->args()) {
+ Value *Val = Arg.get();
+ // Trace the value back through pointer operations
+ while (true) {
+ Value *Src = Val->stripPointerCastsAndAliases();
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(Src))
+ Src = GEP->getPointerOperand();
+ if (Val == Src)
+ break;
+ Val = Src;
+ }
+ if (isa<AllocaInst>(Val)) {
+ NoTail(
+ "WebAssembly does not support tail calling with stack arguments");
+ break;
+ }
+ }
+ }
+ }
+
+ MachineInstrBuilder CallInst;
+
+ bool IsIndirect = false;
+ Register IndirectIdx;
+
+ if (Info.Callee.isReg()) {
+ LLT CalleeType = MRI.getType(Info.Callee.getReg());
+ assert(CalleeType.isPointer() && "Trying to lower a call with a Callee other than a pointer???");
+
+ IsIndirect = true;
+ CallInst = MIRBuilder.buildInstrNoInsert(Info.LoweredTailCall ? WebAssembly::RET_CALL_INDIRECT : WebAssembly::CALL_INDIRECT);
+
+ // Placeholder for the type index.
+ // This gets replaced with the correct value in WebAssemblyMCInstLower.cpp
+ CallInst.addImm(0);
+
+ MCSymbolWasm *Table;
+ if (CalleeType.getAddressSpace() == WebAssembly::WASM_ADDRESS_SPACE_DEFAULT) {
+ Table = WebAssembly::getOrCreateFunctionTableSymbol(
+ MF.getContext(), &Subtarget);
+ IndirectIdx = Info.Callee.getReg();
+
+ auto PtrSize = CalleeType.getSizeInBits();
+ auto PtrIntLLT = LLT::scalar(PtrSize);
+
+ IndirectIdx = MIRBuilder.buildPtrToInt(PtrIntLLT, IndirectIdx).getReg(0);
+ if (PtrSize > 32) {
+ IndirectIdx = MIRBuilder.buildTrunc(LLT::scalar(32), IndirectIdx).getReg(0);
+ }
+ } else if (CalleeType.getAddressSpace() == WebAssembly::WASM_ADDRESS_SPACE_FUNCREF) {
+ Table = WebAssembly::getOrCreateFuncrefCallTableSymbol(
+ MF.getContext(), &Subtarget);
+
+ auto TableSetInstr = MIRBuilder.buildInstr(WebAssembly::TABLE_SET_FUNCREF);
+ TableSetInstr.addSym(Table);
+ TableSetInstr.addUse(Info.Callee.getReg());
+ IndirectIdx = MIRBuilder.buildConstant(LLT::scalar(32), 0).getReg(0);
+ } else {
+ fail(MIRBuilder, "Invalid address space for indirect call");
+ return false;
+ }
+
+ if (Subtarget.hasCallIndirectOverlong()) {
+ CallInst.addSym(Table);
+ } else {
+ // For the MVP there is at most one table whose number is 0, but we can't
+ // write a table symbol or issue relocations. Instead we just ensure the
+ // table is live and write a zero.
+ Table->setNoStrip();
+ CallInst.addImm(0);
+ }
+ } else {
+ CallInst = MIRBuilder.buildInstrNoInsert(Info.LoweredTailCall ? WebAssembly::RET_CALL : WebAssembly::CALL);
+
+ if (Info.Callee.isGlobal()) {
+ CallInst.addGlobalAddress(Info.Callee.getGlobal());
+ } else if (Info.Callee.isSymbol()) {
+ // TODO: figure out how to trigger/test this
+ CallInst.addSym(Info.Callee.getMCSymbol());
+ } else {
+ llvm_unreachable("Trying to lower call with a callee other than reg, global, or a symbol.");
+ }
+ }
+
+
+ SmallVector<ArgInfo, 8> SplitArgs;
+
+ bool HasSwiftErrorArg = false;
+ bool HasSwiftSelfArg = false;
+
+ for (const auto &Arg : Info.OrigArgs) {
+ HasSwiftSelfArg |= Arg.Flags[0].isSwiftSelf();
+ HasSwiftErrorArg |= Arg.Flags[0].isSwiftError();
+ if (Arg.Flags[0].isNest()) {
+ fail(MIRBuilder, "WebAssembly hasn't implemented nest arguments");
+ return false;
+ }
+ if (Arg.Flags[0].isInAlloca()) {
+ fail(MIRBuilder, "WebAssembly hasn't implemented inalloca arguments");
+ return false;
+ }
+ if (Arg.Flags[0].isInConsecutiveRegs()) {
+ fail(MIRBuilder, "WebAssembly hasn't implemented cons regs arguments");
+ return false;
+ }
+ if (Arg.Flags[0].isInConsecutiveRegsLast()) {
+ fail(MIRBuilder,
+ "WebAssembly hasn't implemented cons regs last arguments");
+ return false;
+ }
+
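+    // Byval arguments get a fresh copy in the caller's frame: the block
+    // below allocates a stack object of getByValSize() bytes, emits a
+    // G_MEMCPY from the incoming pointer, and passes the copy's address in
+    // place of the original (e.g. `void f(struct S s)` with S passed byval).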
+ if (Arg.Flags[0].isByVal() && Arg.Flags[0].getByValSize() != 0) {
+ MachineFrameInfo &MFI = MF.getFrameInfo();
+
+ unsigned MemSize = Arg.Flags[0].getByValSize();
+ Align MemAlign = Arg.Flags[0].getNonZeroByValAlign();
+ int FI = MFI.CreateStackObject(Arg.Flags[0].getByValSize(), MemAlign,
+ /*isSS=*/false);
+
+ auto StackAddrSpace = DL.getAllocaAddrSpace();
+ auto PtrLLT = getLLTForType(*PointerType::get(Ctx, StackAddrSpace), DL);
+ Register StackObjPtrVreg =
+ MF.getRegInfo().createGenericVirtualRegister(PtrLLT);
+
+ MIRBuilder.buildFrameIndex(StackObjPtrVreg, FI);
+
+ MachinePointerInfo DstPtrInfo = MachinePointerInfo::getFixedStack(MF, FI);
+
+ MachinePointerInfo SrcPtrInfo(Arg.OrigValue);
+ if (!Arg.OrigValue) {
+ // We still need to accurately track the stack address space if we
+ // don't know the underlying value.
+ SrcPtrInfo = MachinePointerInfo::getUnknownStack(MF);
+ }
+
+ Align DstAlign =
+ std::max(MemAlign, inferAlignFromPtrInfo(MF, DstPtrInfo));
+
+ Align SrcAlign =
+ std::max(MemAlign, inferAlignFromPtrInfo(MF, SrcPtrInfo));
+
+ MachineMemOperand *SrcMMO = MF.getMachineMemOperand(
+ SrcPtrInfo,
+ MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable,
+ MemSize, SrcAlign);
+
+ MachineMemOperand *DstMMO = MF.getMachineMemOperand(
+ DstPtrInfo,
+ MachineMemOperand::MOStore | MachineMemOperand::MODereferenceable,
+ MemSize, DstAlign);
+
+ const LLT SizeTy = LLT::scalar(PtrLLT.getSizeInBits());
+
+ auto SizeConst = MIRBuilder.buildConstant(SizeTy, MemSize);
+ MIRBuilder.buildMemCpy(StackObjPtrVreg, Arg.Regs[0], SizeConst, *DstMMO,
+ *SrcMMO);
+ }
+
+ splitToValueTypes(Arg, SplitArgs, DL, CallConv);
+ }
+
+ unsigned NumFixedArgs = 0;
+
+ for (auto &Arg : SplitArgs) {
+ EVT OrigVT = TLI.getValueType(DL, Arg.Ty);
+ MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT);
+ LLT OrigLLT = getLLTForType(*Arg.Ty, DL);
+ LLT NewLLT = getLLTForMVT(NewVT);
+
+    // Determine how many registers this value is split across for this
+    // calling convention.
+ unsigned NumParts =
+ TLI.getNumRegistersForCallingConv(Ctx, CallConv, OrigVT);
+
+ ISD::ArgFlagsTy OrigFlags = Arg.Flags[0];
+ Arg.Flags.clear();
+
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ ISD::ArgFlagsTy Flags = OrigFlags;
+ if (Part == 0) {
+ Flags.setSplit();
+ } else {
+ Flags.setOrigAlign(Align(1));
+ if (Part == NumParts - 1)
+ Flags.setSplitEnd();
+ }
+
+ Arg.Flags.push_back(Flags);
+ }
+
+ Arg.OrigRegs.assign(Arg.Regs.begin(), Arg.Regs.end());
+ if (NumParts != 1 || OrigVT != NewVT) {
+ // If we can't directly assign the register, we need one or more
+ // intermediate values.
+ Arg.Regs.resize(NumParts);
+
+      // For each part, create a vreg to hold its component of the outgoing
+      // argument. buildCopyToRegs below fills these by splitting the
+      // original vreg.
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ Arg.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT);
+ }
+
+ buildCopyToRegs(MIRBuilder, Arg.Regs, Arg.OrigRegs[0], OrigLLT, NewLLT,
+ extendOpFromFlags(Arg.Flags[0]));
+ }
+
+ if (!Arg.Flags[0].isVarArg()) {
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ CallInst.addUse(Arg.Regs[Part]);
+ ++NumFixedArgs;
+ }
+ }
+ }
+
+ if (CallConv == CallingConv::Swift) {
+ Type *PtrTy = PointerType::getUnqual(Ctx);
+ LLT PtrLLT = getLLTForType(*PtrTy, DL);
+
+ if (!HasSwiftSelfArg) {
+ CallInst.addUse(MIRBuilder.buildUndef(PtrLLT).getReg(0));
+ }
+ if (!HasSwiftErrorArg) {
+ CallInst.addUse(MIRBuilder.buildUndef(PtrLLT).getReg(0));
+ }
+ }
+
+ // Analyze operands of the call, assigning locations to each operand.
+ SmallVector<CCValAssign, 16> ArgLocs;
+ CCState CCInfo(CallConv, Info.IsVarArg, MF, ArgLocs, Ctx);
+
+ if (Info.IsVarArg) {
+ // Outgoing non-fixed arguments are placed in a buffer. First
+ // compute their offsets and the total amount of buffer space needed.
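+    // Roughly: for sum(2, a, b) with two i32 variadic values, a is placed
+    // at buffer offset 0 and b at offset 4, so NumBytes below becomes 8.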
+ for (ArgInfo &Arg : drop_begin(SplitArgs, NumFixedArgs)) {
+ EVT OrigVT = TLI.getValueType(DL, Arg.Ty);
+ MVT PartVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT);
+ Type *Ty = EVT(PartVT).getTypeForEVT(Ctx);
+
+ for (unsigned Part = 0; Part < Arg.Regs.size(); ++Part) {
+ Align Alignment = std::max(Arg.Flags[Part].getNonZeroOrigAlign(),
+ DL.getABITypeAlign(Ty));
+ unsigned Offset =
+ CCInfo.AllocateStack(DL.getTypeAllocSize(Ty), Alignment);
+ CCInfo.addLoc(CCValAssign::getMem(ArgLocs.size(), PartVT, Offset,
+ PartVT, CCValAssign::Full));
+ }
+ }
+ }
+
+ unsigned NumBytes = CCInfo.getAlignedCallFrameSize();
+
+ auto StackAddrSpace = DL.getAllocaAddrSpace();
+ auto PtrLLT = getLLTForType(*PointerType::get(Ctx, StackAddrSpace), DL);
+ auto SizeLLT = LLT::scalar(PtrLLT.getSizeInBits());
+
+ if (Info.IsVarArg && NumBytes) {
+ Register VarArgStackPtr =
+ MF.getRegInfo().createGenericVirtualRegister(PtrLLT);
+
+ MaybeAlign StackAlign = DL.getStackAlignment();
+ assert(StackAlign && "data layout string is missing stack alignment");
+ int FI = MF.getFrameInfo().CreateStackObject(NumBytes, *StackAlign,
+ /*isSS=*/false);
+
+ MIRBuilder.buildFrameIndex(VarArgStackPtr, FI);
+
+ unsigned ValNo = 0;
+ for (ArgInfo &Arg : drop_begin(SplitArgs, NumFixedArgs)) {
+ EVT OrigVT = TLI.getValueType(DL, Arg.Ty);
+ MVT PartVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT);
+ Type *Ty = EVT(PartVT).getTypeForEVT(Ctx);
+
+ for (unsigned Part = 0; Part < Arg.Regs.size(); ++Part) {
+ Align Alignment = std::max(Arg.Flags[Part].getNonZeroOrigAlign(),
+ DL.getABITypeAlign(Ty));
+
+ unsigned Offset = ArgLocs[ValNo++].getLocMemOffset();
+
+ Register DstPtr =
+ MIRBuilder
+ .buildPtrAdd(PtrLLT, VarArgStackPtr,
+ MIRBuilder.buildConstant(SizeLLT, Offset).getReg(0))
+ .getReg(0);
+
+ MachineMemOperand *DstMMO = MF.getMachineMemOperand(
+ MachinePointerInfo::getFixedStack(MF, FI, Offset),
+ MachineMemOperand::MOStore | MachineMemOperand::MODereferenceable,
+ PartVT.getStoreSize(), Alignment);
+
+ MIRBuilder.buildStore(Arg.Regs[Part], DstPtr, *DstMMO);
+ }
+ }
+
+ CallInst.addUse(VarArgStackPtr);
+ } else if (Info.IsVarArg) {
+ CallInst.addUse(MIRBuilder.buildConstant(PtrLLT, 0).getReg(0));
+ }
+
+ if (IsIndirect) {
+ CallInst.addUse(IndirectIdx);
+ }
+
+ MIRBuilder.insertInstr(CallInst);
+
+ if (Info.LoweredTailCall) {
+ return true;
+ }
+
+ if (Info.CanLowerReturn && !Info.OrigRet.Ty->isVoidTy()) {
+ SmallVector<EVT, 4> SplitEVTs;
+ ComputeValueVTs(TLI, DL, Info.OrigRet.Ty, SplitEVTs);
+ assert(Info.OrigRet.Regs.size() == SplitEVTs.size() &&
+ "For each split Type there should be exactly one VReg.");
+
+ SmallVector<ArgInfo, 8> SplitReturns;
+
+ unsigned i = 0;
+ for (auto SplitEVT : SplitEVTs) {
+ Register CurVReg = Info.OrigRet.Regs[i];
+ ArgInfo CurArgInfo = ArgInfo{CurVReg, SplitEVT.getTypeForEVT(Ctx), 0};
+ setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, *Info.CB);
+
+ splitToValueTypes(CurArgInfo, SplitReturns, DL, CallConv);
+ ++i;
+ }
+
+ for (auto &Ret : SplitReturns) {
+ EVT OrigVT = TLI.getValueType(DL, Ret.Ty);
+ MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT);
+ LLT OrigLLT = getLLTForType(*Ret.Ty, DL);
+ LLT NewLLT = getLLTForMVT(NewVT);
+
+      // Determine how many registers this value is split across for this
+      // calling convention.
+ unsigned NumParts =
+ TLI.getNumRegistersForCallingConv(Ctx, CallConv, OrigVT);
+
+ ISD::ArgFlagsTy OrigFlags = Ret.Flags[0];
+ Ret.Flags.clear();
+
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ ISD::ArgFlagsTy Flags = OrigFlags;
+ if (Part == 0) {
+ Flags.setSplit();
+ } else {
+ Flags.setOrigAlign(Align(1));
+ if (Part == NumParts - 1)
+ Flags.setSplitEnd();
+ }
+
+ Ret.Flags.push_back(Flags);
+ }
+
+ Ret.OrigRegs.assign(Ret.Regs.begin(), Ret.Regs.end());
+ if (NumParts != 1 || OrigVT != NewVT) {
+ // If we can't directly assign the register, we need one or more
+ // intermediate values.
+ Ret.Regs.resize(NumParts);
+
+ // For each split register, create and assign a vreg that will store
+ // the incoming component of the larger value. These will later be
+ // merged to form the final vreg.
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ Ret.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT);
+ }
+ buildCopyFromRegs(MIRBuilder, Ret.OrigRegs, Ret.Regs, OrigLLT, NewLLT,
+ Ret.Flags[0]);
+ }
+
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ CallInst.addDef(Ret.Regs[Part]);
+ }
+ }
+ }
+
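+  // If the result could not be returned directly (see canLowerReturn), the
+  // call was given a hidden sret pointer instead; reload the value from the
+  // demoted stack slot into the vregs the caller expects.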
+ if (!Info.CanLowerReturn) {
+ insertSRetLoads(MIRBuilder, Info.OrigRet.Ty, Info.OrigRet.Regs,
+ Info.DemoteRegister, Info.DemoteStackIndex);
+
+ for (auto Reg : Info.OrigRet.Regs) {
+ CallInst.addDef(Reg);
+ }
+ }
+
+ return true;
}
>From 5a03714459bc0c0a17b16488e5fe864270e05a95 Mon Sep 17 00:00:00 2001
From: Demetrius Kanios <demetrius at kanios.net>
Date: Sun, 28 Sep 2025 22:41:11 -0700
Subject: [PATCH 3/9] Fix formatting
---
.../GISel/WebAssemblyCallLowering.cpp | 128 ++++++++++--------
1 file changed, 69 insertions(+), 59 deletions(-)
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
index 8956932b403ef..23a6274e66661 100644
--- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
@@ -697,7 +697,8 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
LLVMContext &Ctx = MIRBuilder.getContext();
const WebAssemblyTargetLowering &TLI = *getTLI<WebAssemblyTargetLowering>();
MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
- const WebAssemblySubtarget &Subtarget = MF.getSubtarget<WebAssemblySubtarget>();
+ const WebAssemblySubtarget &Subtarget =
+ MF.getSubtarget<WebAssemblySubtarget>();
CallingConv::ID CallConv = Info.CallConv;
if (!callingConvSupported(CallConv)) {
@@ -716,7 +717,7 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
*/
if (Info.IsTailCall) {
- Info.LoweredTailCall = true;
+ Info.LoweredTailCall = true;
auto NoTail = [&](const char *Msg) {
if (Info.CB && Info.CB->isMustTailCall())
fail(MIRBuilder, Msg);
@@ -773,65 +774,73 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
Register IndirectIdx;
if (Info.Callee.isReg()) {
- LLT CalleeType = MRI.getType(Info.Callee.getReg());
- assert(CalleeType.isPointer() && "Trying to lower a call with a Callee other than a pointer???");
-
- IsIndirect = true;
- CallInst = MIRBuilder.buildInstrNoInsert(Info.LoweredTailCall ? WebAssembly::RET_CALL_INDIRECT : WebAssembly::CALL_INDIRECT);
-
- // Placeholder for the type index.
- // This gets replaced with the correct value in WebAssemblyMCInstLower.cpp
- CallInst.addImm(0);
-
- MCSymbolWasm *Table;
- if (CalleeType.getAddressSpace() == WebAssembly::WASM_ADDRESS_SPACE_DEFAULT) {
- Table = WebAssembly::getOrCreateFunctionTableSymbol(
- MF.getContext(), &Subtarget);
- IndirectIdx = Info.Callee.getReg();
-
- auto PtrSize = CalleeType.getSizeInBits();
- auto PtrIntLLT = LLT::scalar(PtrSize);
-
- IndirectIdx = MIRBuilder.buildPtrToInt(PtrIntLLT, IndirectIdx).getReg(0);
- if (PtrSize > 32) {
- IndirectIdx = MIRBuilder.buildTrunc(LLT::scalar(32), IndirectIdx).getReg(0);
- }
- } else if (CalleeType.getAddressSpace() == WebAssembly::WASM_ADDRESS_SPACE_FUNCREF) {
- Table = WebAssembly::getOrCreateFuncrefCallTableSymbol(
- MF.getContext(), &Subtarget);
-
- auto TableSetInstr = MIRBuilder.buildInstr(WebAssembly::TABLE_SET_FUNCREF);
- TableSetInstr.addSym(Table);
- TableSetInstr.addUse(Info.Callee.getReg());
- IndirectIdx = MIRBuilder.buildConstant(LLT::scalar(32), 0).getReg(0);
- } else {
- fail(MIRBuilder, "Invalid address space for indirect call");
- return false;
+ LLT CalleeType = MRI.getType(Info.Callee.getReg());
+ assert(CalleeType.isPointer() &&
+ "Trying to lower a call with a Callee other than a pointer???");
+
+ IsIndirect = true;
+ CallInst = MIRBuilder.buildInstrNoInsert(
+ Info.LoweredTailCall ? WebAssembly::RET_CALL_INDIRECT
+ : WebAssembly::CALL_INDIRECT);
+
+ // Placeholder for the type index.
+ // This gets replaced with the correct value in WebAssemblyMCInstLower.cpp
+ CallInst.addImm(0);
+
+ MCSymbolWasm *Table;
+ if (CalleeType.getAddressSpace() ==
+ WebAssembly::WASM_ADDRESS_SPACE_DEFAULT) {
+ Table = WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(),
+ &Subtarget);
+ IndirectIdx = Info.Callee.getReg();
+
+ auto PtrSize = CalleeType.getSizeInBits();
+ auto PtrIntLLT = LLT::scalar(PtrSize);
+
+ IndirectIdx = MIRBuilder.buildPtrToInt(PtrIntLLT, IndirectIdx).getReg(0);
+ if (PtrSize > 32) {
+ IndirectIdx =
+ MIRBuilder.buildTrunc(LLT::scalar(32), IndirectIdx).getReg(0);
}
+ } else if (CalleeType.getAddressSpace() ==
+ WebAssembly::WASM_ADDRESS_SPACE_FUNCREF) {
+ Table = WebAssembly::getOrCreateFuncrefCallTableSymbol(MF.getContext(),
+ &Subtarget);
+
+ auto TableSetInstr =
+ MIRBuilder.buildInstr(WebAssembly::TABLE_SET_FUNCREF);
+ TableSetInstr.addSym(Table);
+ TableSetInstr.addUse(Info.Callee.getReg());
+ IndirectIdx = MIRBuilder.buildConstant(LLT::scalar(32), 0).getReg(0);
+ } else {
+ fail(MIRBuilder, "Invalid address space for indirect call");
+ return false;
+ }
- if (Subtarget.hasCallIndirectOverlong()) {
- CallInst.addSym(Table);
- } else {
- // For the MVP there is at most one table whose number is 0, but we can't
- // write a table symbol or issue relocations. Instead we just ensure the
- // table is live and write a zero.
- Table->setNoStrip();
- CallInst.addImm(0);
- }
+ if (Subtarget.hasCallIndirectOverlong()) {
+ CallInst.addSym(Table);
+ } else {
+ // For the MVP there is at most one table whose number is 0, but we can't
+ // write a table symbol or issue relocations. Instead we just ensure the
+ // table is live and write a zero.
+ Table->setNoStrip();
+ CallInst.addImm(0);
+ }
} else {
- CallInst = MIRBuilder.buildInstrNoInsert(Info.LoweredTailCall ? WebAssembly::RET_CALL : WebAssembly::CALL);
-
- if (Info.Callee.isGlobal()) {
- CallInst.addGlobalAddress(Info.Callee.getGlobal());
- } else if (Info.Callee.isSymbol()) {
- // TODO: figure out how to trigger/test this
- CallInst.addSym(Info.Callee.getMCSymbol());
- } else {
- llvm_unreachable("Trying to lower call with a callee other than reg, global, or a symbol.");
- }
+ CallInst = MIRBuilder.buildInstrNoInsert(
+ Info.LoweredTailCall ? WebAssembly::RET_CALL : WebAssembly::CALL);
+
+ if (Info.Callee.isGlobal()) {
+ CallInst.addGlobalAddress(Info.Callee.getGlobal());
+ } else if (Info.Callee.isSymbol()) {
+ // TODO: figure out how to trigger/test this
+ CallInst.addSym(Info.Callee.getMCSymbol());
+ } else {
+ llvm_unreachable("Trying to lower call with a callee other than reg, "
+ "global, or a symbol.");
+ }
}
-
SmallVector<ArgInfo, 8> SplitArgs;
bool HasSwiftErrorArg = false;
@@ -1028,8 +1037,9 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
Register DstPtr =
MIRBuilder
- .buildPtrAdd(PtrLLT, VarArgStackPtr,
- MIRBuilder.buildConstant(SizeLLT, Offset).getReg(0))
+ .buildPtrAdd(
+ PtrLLT, VarArgStackPtr,
+ MIRBuilder.buildConstant(SizeLLT, Offset).getReg(0))
.getReg(0);
MachineMemOperand *DstMMO = MF.getMachineMemOperand(
@@ -1053,7 +1063,7 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
MIRBuilder.insertInstr(CallInst);
if (Info.LoweredTailCall) {
- return true;
+ return true;
}
if (Info.CanLowerReturn && !Info.OrigRet.Ty->isVoidTy()) {
>From 8c0dd7d9d3c62eae888498c1fee7b2670e26f2ce Mon Sep 17 00:00:00 2001
From: Demetrius Kanios <demetrius at kanios.net>
Date: Sun, 28 Sep 2025 22:41:18 -0700
Subject: [PATCH 4/9] Fix some issues with WebAssemblyCallLowering::lowerCall
---
.../GISel/WebAssemblyCallLowering.cpp | 27 +++++++++++++------
1 file changed, 19 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
index 23a6274e66661..23533c5ad1c75 100644
--- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
@@ -798,10 +798,6 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
auto PtrIntLLT = LLT::scalar(PtrSize);
IndirectIdx = MIRBuilder.buildPtrToInt(PtrIntLLT, IndirectIdx).getReg(0);
- if (PtrSize > 32) {
- IndirectIdx =
- MIRBuilder.buildTrunc(LLT::scalar(32), IndirectIdx).getReg(0);
- }
} else if (CalleeType.getAddressSpace() ==
WebAssembly::WASM_ADDRESS_SPACE_FUNCREF) {
Table = WebAssembly::getOrCreateFuncrefCallTableSymbol(MF.getContext(),
@@ -833,8 +829,7 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
if (Info.Callee.isGlobal()) {
CallInst.addGlobalAddress(Info.Callee.getGlobal());
} else if (Info.Callee.isSymbol()) {
- // TODO: figure out how to trigger/test this
- CallInst.addSym(Info.Callee.getMCSymbol());
+ CallInst.addExternalSymbol(Info.Callee.getSymbolName());
} else {
llvm_unreachable("Trying to lower call with a callee other than reg, "
"global, or a symbol.");
@@ -1078,8 +1073,24 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
for (auto SplitEVT : SplitEVTs) {
Register CurVReg = Info.OrigRet.Regs[i];
ArgInfo CurArgInfo = ArgInfo{CurVReg, SplitEVT.getTypeForEVT(Ctx), 0};
- setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, *Info.CB);
-
+ if (Info.CB) {
+ setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, *Info.CB);
+ } else {
+      // We don't have a call base, so chances are we're looking at a libcall
+ // (external symbol).
+
+ // TODO: figure out how to get ALL the correct attributes
+ auto &Flags = CurArgInfo.Flags[0];
+ PointerType *PtrTy =
+ dyn_cast<PointerType>(CurArgInfo.Ty->getScalarType());
+ if (PtrTy) {
+ Flags.setPointer();
+ Flags.setPointerAddrSpace(PtrTy->getPointerAddressSpace());
+ }
+ Align MemAlign = DL.getABITypeAlign(CurArgInfo.Ty);
+ Flags.setMemAlign(MemAlign);
+ Flags.setOrigAlign(MemAlign);
+ }
splitToValueTypes(CurArgInfo, SplitReturns, DL, CallConv);
++i;
}
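
[Sketch, not part of the patch: the no-CallBase fallback above in isolation.
The struct and names below are hypothetical stand-ins for the real ArgFlagsTy
fields; the point is that without call-site attributes, pointer-ness and
alignment can only be derived from the return type itself.]

  #include <cstdint>

  // Hypothetical stand-ins for the flag fields the patch fills in.
  struct SketchFlags {
    bool IsPointer = false;
    unsigned AddrSpace = 0;
    uint64_t MemAlign = 1;
    uint64_t OrigAlign = 1;
  };

  SketchFlags flagsForLibcallReturn(bool TypeIsPointer, unsigned TypeAddrSpace,
                                    uint64_t ABITypeAlign) {
    SketchFlags Flags;
    if (TypeIsPointer) {
      Flags.IsPointer = true;
      Flags.AddrSpace = TypeAddrSpace;
    }
    // No call-site attributes to consult, so the ABI alignment of the type
    // is the best available answer for both alignments.
    Flags.MemAlign = ABITypeAlign;
    Flags.OrigAlign = ABITypeAlign;
    return Flags;
  }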
>From 233c7a3bcca753a58a2e1a9cdf4cf2787ab465b3 Mon Sep 17 00:00:00 2001
From: Demetrius Kanios <demetrius at kanios.net>
Date: Sun, 28 Sep 2025 22:41:44 -0700
Subject: [PATCH 5/9] Attempt to make CallLowering floating-point aware (use
FPEXT and FPTRUNC instead of integer ANYEXT/TRUNC)
---
.../GISel/WebAssemblyCallLowering.cpp | 29 ++++++++++++++-----
1 file changed, 22 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
index 23533c5ad1c75..3dd928a825995 100644
--- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
@@ -29,6 +29,7 @@
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineMemOperand.h"
+#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGenTypes/LowLevelType.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/DataLayout.h"
@@ -108,9 +109,12 @@ mergeVectorRegsToResultRegs(MachineIRBuilder &B, ArrayRef<Register> DstRegs,
/// typed values to the original IR value. \p OrigRegs contains the destination
/// value registers of type \p LLTy, and \p Regs contains the legalized pieces
/// with type \p PartLLT. This is used for incoming values (physregs to vregs).
+
+// Modified to account for floating-point extends/truncations
static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
ArrayRef<Register> Regs, LLT LLTy, LLT PartLLT,
- const ISD::ArgFlagsTy Flags) {
+ const ISD::ArgFlagsTy Flags,
+ bool IsFloatingPoint) {
MachineRegisterInfo &MRI = *B.getMRI();
if (PartLLT == LLTy) {
@@ -153,7 +157,10 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
return;
}
- B.buildTrunc(OrigRegs[0], SrcReg);
+ if (IsFloatingPoint)
+ B.buildFPTrunc(OrigRegs[0], SrcReg);
+ else
+ B.buildTrunc(OrigRegs[0], SrcReg);
return;
}
@@ -166,7 +173,11 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
B.buildMergeValues(OrigRegs[0], Regs);
else {
auto Widened = B.buildMergeLikeInstr(LLT::scalar(SrcSize), Regs);
- B.buildTrunc(OrigRegs[0], Widened);
+
+ if (IsFloatingPoint)
+ B.buildFPTrunc(OrigRegs[0], Widened);
+ else
+ B.buildTrunc(OrigRegs[0], Widened);
}
return;
@@ -496,7 +507,9 @@ bool WebAssemblyCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
Arg.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT);
}
buildCopyToRegs(MIRBuilder, Arg.Regs, Arg.OrigRegs[0], OrigLLT, NewLLT,
- extendOpFromFlags(Arg.Flags[0]));
+ Arg.Ty->isFloatingPointTy()
+ ? TargetOpcode::G_FPEXT
+ : extendOpFromFlags(Arg.Flags[0]));
}
for (unsigned Part = 0; Part < NumParts; ++Part) {
@@ -630,7 +643,7 @@ bool WebAssemblyCallLowering::lowerFormalArguments(
Arg.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT);
}
buildCopyFromRegs(MIRBuilder, Arg.OrigRegs, Arg.Regs, OrigLLT, NewLLT,
- Arg.Flags[0]);
+ Arg.Flags[0], Arg.Ty->isFloatingPointTy());
}
for (unsigned Part = 0; Part < NumParts; ++Part) {
@@ -955,7 +968,9 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
}
buildCopyToRegs(MIRBuilder, Arg.Regs, Arg.OrigRegs[0], OrigLLT, NewLLT,
- extendOpFromFlags(Arg.Flags[0]));
+ Arg.Ty->isFloatingPointTy()
+ ? TargetOpcode::G_FPEXT
+ : extendOpFromFlags(Arg.Flags[0]));
}
if (!Arg.Flags[0].isVarArg()) {
@@ -1135,7 +1150,7 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
Ret.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT);
}
buildCopyFromRegs(MIRBuilder, Ret.OrigRegs, Ret.Regs, OrigLLT, NewLLT,
- Ret.Flags[0]);
+ Ret.Flags[0], Ret.Ty->isFloatingPointTy());
}
for (unsigned Part = 0; Part < NumParts; ++Part) {
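
[Sketch, not part of the patch: the opcode choice this commit introduces,
with plain enums standing in for the real TargetOpcode values.]

  enum class WidenOp { AnyExt, SExt, ZExt, FPExt };
  enum class NarrowOp { Trunc, FPTrunc };

  // Floating-point values widen with FPEXT and narrow with FPTRUNC; integers
  // keep using the extend chosen from the argument flags, and plain TRUNC.
  WidenOp chooseWidenOp(bool IsFloatingPoint, bool IsSExt, bool IsZExt) {
    if (IsFloatingPoint)
      return WidenOp::FPExt; // previously always the integer ANYEXT path
    if (IsSExt)
      return WidenOp::SExt;
    if (IsZExt)
      return WidenOp::ZExt;
    return WidenOp::AnyExt;
  }

  NarrowOp chooseNarrowOp(bool IsFloatingPoint) {
    return IsFloatingPoint ? NarrowOp::FPTrunc : NarrowOp::Trunc;
  }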
>From c1e00ae433acdb2567b0e3ba7faf09a643f4457c Mon Sep 17 00:00:00 2001
From: Demetrius Kanios <demetrius at kanios.net>
Date: Sun, 28 Sep 2025 22:41:53 -0700
Subject: [PATCH 6/9] Fix lowerCall vararg crash.
---
llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
index 3dd928a825995..7ee118387bfe4 100644
--- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
@@ -976,8 +976,8 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
if (!Arg.Flags[0].isVarArg()) {
for (unsigned Part = 0; Part < NumParts; ++Part) {
CallInst.addUse(Arg.Regs[Part]);
- ++NumFixedArgs;
}
+ ++NumFixedArgs;
}
}
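
[Sketch, not part of the patch: the crash fix in isolation. NumFixedArgs must
count IR-level fixed arguments, not register parts; the struct below is a
hypothetical stand-in for the real ArgInfo.]

  #include <vector>

  struct SketchArg {
    bool IsVarArg;
    unsigned NumParts; // how many registers the argument is split into
  };

  unsigned countFixedArgs(const std::vector<SketchArg> &Args) {
    unsigned NumFixedArgs = 0;
    for (const SketchArg &Arg : Args)
      if (!Arg.IsVarArg)
        ++NumFixedArgs; // once per argument, regardless of NumParts
    return NumFixedArgs;
  }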
>From 57da891e80bcb5154f8f66dd940b12e322e2c1fd Mon Sep 17 00:00:00 2001
From: Demetrius Kanios <demetrius at kanios.net>
Date: Sun, 28 Sep 2025 22:42:03 -0700
Subject: [PATCH 7/9] Set up basic legalization (scalar only, limited support
for FP, p0 only)
---
.../GISel/WebAssemblyLegalizerInfo.cpp | 256 ++++++++++++++++++
.../GISel/WebAssemblyLegalizerInfo.h | 2 +
2 files changed, 258 insertions(+)
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp
index 3acdabb5612cc..c6cd1c5b371e9 100644
--- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp
@@ -11,6 +11,12 @@
//===----------------------------------------------------------------------===//
#include "WebAssemblyLegalizerInfo.h"
+#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
+#include "WebAssemblySubtarget.h"
+#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/IR/DerivedTypes.h"
#define DEBUG_TYPE "wasm-legalinfo"
@@ -19,5 +25,255 @@ using namespace LegalizeActions;
WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo(
const WebAssemblySubtarget &ST) {
+ using namespace TargetOpcode;
+ const LLT s8 = LLT::scalar(8);
+ const LLT s16 = LLT::scalar(16);
+ const LLT s32 = LLT::scalar(32);
+ const LLT s64 = LLT::scalar(64);
+
+ const LLT p0 = LLT::pointer(0, ST.hasAddr64() ? 64 : 32);
+ const LLT p0s = LLT::scalar(ST.hasAddr64() ? 64 : 32);
+
+ getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor({p0});
+
+ getActionDefinitionsBuilder(G_PHI)
+ .legalFor({p0, s32, s64})
+ .widenScalarToNextPow2(0)
+ .clampScalar(0, s32, s64);
+ getActionDefinitionsBuilder(G_BR).alwaysLegal();
+ getActionDefinitionsBuilder(G_BRCOND).legalFor({s32}).clampScalar(0, s32,
+ s32);
+ getActionDefinitionsBuilder(G_BRJT)
+ .legalFor({{p0, s32}})
+ .clampScalar(1, s32, s32);
+
+ getActionDefinitionsBuilder(G_SELECT)
+ .legalFor({{s32, s32}, {s64, s32}, {p0, s32}})
+ .widenScalarToNextPow2(0)
+ .clampScalar(0, s32, s64)
+ .clampScalar(1, s32, s32);
+
+ getActionDefinitionsBuilder(G_JUMP_TABLE).legalFor({p0});
+
+ getActionDefinitionsBuilder(G_ICMP)
+ .legalFor({{s32, s32}, {s32, s64}, {s32, p0}})
+ .widenScalarToNextPow2(1)
+ .clampScalar(1, s32, s64)
+ .clampScalar(0, s32, s32);
+
+ getActionDefinitionsBuilder(G_FCMP)
+ .legalFor({{s32, s32}, {s32, s64}})
+ .clampScalar(0, s32, s32)
+ .libcall();
+
+ getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
+
+ getActionDefinitionsBuilder(G_CONSTANT)
+ .legalFor({s32, s64, p0})
+ .widenScalarToNextPow2(0)
+ .clampScalar(0, s32, s64);
+
+ getActionDefinitionsBuilder(G_FCONSTANT)
+ .legalFor({s32, s64})
+ .clampScalar(0, s32, s64);
+
+ getActionDefinitionsBuilder(G_IMPLICIT_DEF)
+ .legalFor({s32, s64, p0})
+ .widenScalarToNextPow2(0)
+ .clampScalar(0, s32, s64);
+
+ getActionDefinitionsBuilder(
+ {G_ADD, G_SUB, G_MUL, G_UDIV, G_SDIV, G_UREM, G_SREM})
+ .legalFor({s32, s64})
+ .widenScalarToNextPow2(0)
+ .clampScalar(0, s32, s64);
+
+ getActionDefinitionsBuilder({G_ASHR, G_LSHR, G_SHL, G_CTLZ, G_CTLZ_ZERO_UNDEF,
+ G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTPOP, G_FSHL,
+ G_FSHR})
+ .legalFor({{s32, s32}, {s64, s64}})
+ .widenScalarToNextPow2(0)
+ .clampScalar(0, s32, s64)
+ .minScalarSameAs(1, 0)
+ .maxScalarSameAs(1, 0);
+
+ getActionDefinitionsBuilder({G_SCMP, G_UCMP}).lower();
+
+ getActionDefinitionsBuilder({G_AND, G_OR, G_XOR})
+ .legalFor({s32, s64})
+ .widenScalarToNextPow2(0)
+ .clampScalar(0, s32, s64);
+
+ getActionDefinitionsBuilder({G_UMIN, G_UMAX, G_SMIN, G_SMAX}).lower();
+
+ getActionDefinitionsBuilder({G_FADD, G_FSUB, G_FDIV, G_FMUL, G_FNEG, G_FABS,
+ G_FCEIL, G_FFLOOR, G_FSQRT, G_INTRINSIC_TRUNC,
+ G_FNEARBYINT, G_FRINT, G_INTRINSIC_ROUNDEVEN,
+ G_FMINIMUM, G_FMAXIMUM})
+ .legalFor({s32, s64})
+ .minScalar(0, s32);
+
+ // TODO: _IEEE not lowering correctly?
+ getActionDefinitionsBuilder(
+ {G_FMINNUM, G_FMAXNUM, G_FMINNUM_IEEE, G_FMAXNUM_IEEE})
+ .lowerFor({s32, s64})
+ .minScalar(0, s32);
+
+ getActionDefinitionsBuilder({G_FMA, G_FREM})
+ .libcallFor({s32, s64})
+ .minScalar(0, s32);
+
+ getActionDefinitionsBuilder(G_FCOPYSIGN)
+ .legalFor({s32, s64})
+ .minScalar(0, s32)
+ .minScalarSameAs(1, 0)
+ .maxScalarSameAs(1, 0);
+
+ getActionDefinitionsBuilder({G_FPTOUI, G_FPTOUI_SAT, G_FPTOSI, G_FPTOSI_SAT})
+ .legalForCartesianProduct({s32, s64}, {s32, s64})
+ .minScalar(1, s32)
+ .widenScalarToNextPow2(0)
+ .clampScalar(0, s32, s64);
+
+ getActionDefinitionsBuilder({G_UITOFP, G_SITOFP})
+ .legalForCartesianProduct({s32, s64}, {s32, s64})
+ .minScalar(1, s32)
+ .widenScalarToNextPow2(1)
+ .clampScalar(1, s32, s64);
+
+ getActionDefinitionsBuilder(G_PTRTOINT).legalFor({{p0s, p0}});
+ getActionDefinitionsBuilder(G_INTTOPTR).legalFor({{p0, p0s}});
+ getActionDefinitionsBuilder(G_PTR_ADD).legalFor({{p0, p0s}});
+
+ getActionDefinitionsBuilder(G_LOAD)
+ .legalForTypesWithMemDesc(
+ {{s32, p0, s32, 1}, {s64, p0, s64, 1}, {p0, p0, p0, 1}})
+ .legalForTypesWithMemDesc({{s32, p0, s8, 1},
+ {s32, p0, s16, 1},
+
+ {s64, p0, s8, 1},
+ {s64, p0, s16, 1},
+ {s64, p0, s32, 1}})
+ .widenScalarToNextPow2(0)
+ .lowerIfMemSizeNotByteSizePow2()
+ .clampScalar(0, s32, s64);
+
+ getActionDefinitionsBuilder(G_STORE)
+ .legalForTypesWithMemDesc(
+ {{s32, p0, s32, 1}, {s64, p0, s64, 1}, {p0, p0, p0, 1}})
+ .legalForTypesWithMemDesc({{s32, p0, s8, 1},
+ {s32, p0, s16, 1},
+
+ {s64, p0, s8, 1},
+ {s64, p0, s16, 1},
+ {s64, p0, s32, 1}})
+ .widenScalarToNextPow2(0)
+ .lowerIfMemSizeNotByteSizePow2()
+ .clampScalar(0, s32, s64);
+
+ getActionDefinitionsBuilder({G_ZEXTLOAD, G_SEXTLOAD})
+ .legalForTypesWithMemDesc({{s32, p0, s8, 1},
+ {s32, p0, s16, 1},
+
+ {s64, p0, s8, 1},
+ {s64, p0, s16, 1},
+ {s64, p0, s32, 1}})
+ .widenScalarToNextPow2(0)
+ .lowerIfMemSizeNotByteSizePow2()
+ .clampScalar(0, s32, s64)
+ .lower();
+
+ if (ST.hasBulkMemoryOpt()) {
+ getActionDefinitionsBuilder(G_BZERO).unsupported();
+
+ getActionDefinitionsBuilder(G_MEMSET)
+ .legalForCartesianProduct({p0}, {s32}, {p0s})
+ .customForCartesianProduct({p0}, {s8}, {p0s})
+ .immIdx(0);
+
+ getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
+ .legalForCartesianProduct({p0}, {p0}, {p0s})
+ .immIdx(0);
+
+ getActionDefinitionsBuilder(G_MEMCPY_INLINE)
+ .legalForCartesianProduct({p0}, {p0}, {p0s});
+ } else {
+ getActionDefinitionsBuilder({G_BZERO, G_MEMCPY, G_MEMMOVE, G_MEMSET})
+ .libcall();
+ }
+
+ // TODO: figure out how to combine G_ANYEXT of G_ASSERT_{S|Z}EXT (or
+ // appropriate G_AND and G_SEXT_IN_REG?) to a G_{S|Z}EXT + G_ASSERT_{S|Z}EXT
+ // for better optimization (since G_ANYEXT lowers to a ZEXT or SEXT
+ // instruction anyway).
+
+ getActionDefinitionsBuilder(G_ANYEXT)
+ .legalFor({{s64, s32}})
+ .clampScalar(0, s32, s64)
+ .clampScalar(1, s32, s64);
+
+ getActionDefinitionsBuilder({G_SEXT, G_ZEXT})
+ .legalFor({{s64, s32}})
+ .clampScalar(0, s32, s64)
+ .clampScalar(1, s32, s64)
+ .lower();
+
+ if (ST.hasSignExt()) {
+ getActionDefinitionsBuilder(G_SEXT_INREG)
+ .clampScalar(0, s32, s64)
+ .customFor({s32, s64})
+ .lower();
+ } else {
+ getActionDefinitionsBuilder(G_SEXT_INREG).lower();
+ }
+
+ getActionDefinitionsBuilder(G_TRUNC)
+ .legalFor({{s32, s64}})
+ .clampScalar(0, s32, s64)
+ .clampScalar(1, s32, s64)
+ .lower();
+
+ getActionDefinitionsBuilder(G_FPEXT).legalFor({{s64, s32}});
+
+ getActionDefinitionsBuilder(G_FPTRUNC).legalFor({{s32, s64}});
+
+ getActionDefinitionsBuilder(G_VASTART).legalFor({p0});
+ getActionDefinitionsBuilder(G_VAARG)
+ .legalForCartesianProduct({s32, s64}, {p0})
+ .clampScalar(0, s32, s64);
+
getLegacyLegalizerInfo().computeTables();
}
+
+bool WebAssemblyLegalizerInfo::legalizeCustom(
+ LegalizerHelper &Helper, MachineInstr &MI,
+ LostDebugLocObserver &LocObserver) const {
+ switch (MI.getOpcode()) {
+ case TargetOpcode::G_SEXT_INREG: {
+ // Mark only 8/16/32-bit SEXT_INREG as legal
+ auto [DstType, SrcType] = MI.getFirst2LLTs();
+ auto ExtFromWidth = MI.getOperand(2).getImm();
+
+ if (ExtFromWidth == 8 || ExtFromWidth == 16 ||
+ (DstType.getScalarSizeInBits() == 64 && ExtFromWidth == 32)) {
+ return true;
+ }
+ return false;
+ }
+ case TargetOpcode::G_MEMSET: {
+    // Anyext the value being set to 32 bits (only the bottom 8 bits are read
+ // the instruction).
+ Helper.Observer.changingInstr(MI);
+ auto &Value = MI.getOperand(1);
+
+ Register ExtValueReg =
+ Helper.MIRBuilder.buildAnyExt(LLT::scalar(32), Value).getReg(0);
+ Value.setReg(ExtValueReg);
+ Helper.Observer.changedInstr(MI);
+ return true;
+ }
+ default:
+ break;
+ }
+ return false;
+}
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h
index c02205fc7ae0d..5aca23c9514e1 100644
--- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h
@@ -24,6 +24,8 @@ class WebAssemblySubtarget;
class WebAssemblyLegalizerInfo : public LegalizerInfo {
public:
WebAssemblyLegalizerInfo(const WebAssemblySubtarget &ST);
+
+ bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const override;
};
} // namespace llvm
#endif
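
[Sketch, not part of the patch: a rough model in plain C++ of what the
widenScalarToNextPow2(0) + clampScalar(0, s32, s64) chains above do to a
scalar bit width. The real legalizer rewrites instructions rather than
computing widths, so this is only a mental model.]

  #include <bit> // C++20, for std::bit_ceil

  unsigned legalizedScalarWidth(unsigned Bits) {
    unsigned Widened = std::bit_ceil(Bits); // next power of two, e.g. 48 -> 64
    if (Widened < 32)
      return 32; // s1..s16 are widened to s32
    if (Widened > 64)
      return 64; // wider scalars are brought back down (split, in reality)
    return Widened;
  }
  // legalizedScalarWidth(1) == 32, (8) == 32, (33) == 64, (64) == 64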
>From 332f108979a98e66f34ff5a09475659a94b26478 Mon Sep 17 00:00:00 2001
From: Demetrius Kanios <demetrius at kanios.net>
Date: Sun, 28 Sep 2025 22:42:32 -0700
Subject: [PATCH 8/9] Start on regbankselect
---
llvm/lib/Target/WebAssembly/CMakeLists.txt | 1 +
.../GISel/WebAssemblyCallLowering.cpp | 299 ++++++++++-------
.../GISel/WebAssemblyRegisterBankInfo.cpp | 302 ++++++++++++++++++
.../GISel/WebAssemblyRegisterBankInfo.h | 40 +++
llvm/lib/Target/WebAssembly/WebAssembly.td | 1 +
.../WebAssembly/WebAssemblyRegisterBanks.td | 20 ++
.../WebAssembly/WebAssemblySubtarget.cpp | 4 +-
7 files changed, 549 insertions(+), 118 deletions(-)
create mode 100644 llvm/lib/Target/WebAssembly/WebAssemblyRegisterBanks.td
diff --git a/llvm/lib/Target/WebAssembly/CMakeLists.txt b/llvm/lib/Target/WebAssembly/CMakeLists.txt
index 371d224efc1c5..e573582509263 100644
--- a/llvm/lib/Target/WebAssembly/CMakeLists.txt
+++ b/llvm/lib/Target/WebAssembly/CMakeLists.txt
@@ -9,6 +9,7 @@ tablegen(LLVM WebAssemblyGenDisassemblerTables.inc -gen-disassembler)
tablegen(LLVM WebAssemblyGenFastISel.inc -gen-fast-isel)
tablegen(LLVM WebAssemblyGenInstrInfo.inc -gen-instr-info)
tablegen(LLVM WebAssemblyGenMCCodeEmitter.inc -gen-emitter)
+tablegen(LLVM WebAssemblyGenRegisterBank.inc -gen-register-bank)
tablegen(LLVM WebAssemblyGenRegisterInfo.inc -gen-register-info)
tablegen(LLVM WebAssemblyGenSubtargetInfo.inc -gen-subtarget)
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
index 7ee118387bfe4..7f3c7e62ce02d 100644
--- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
@@ -13,10 +13,12 @@
//===----------------------------------------------------------------------===//
#include "WebAssemblyCallLowering.h"
+#include "GISel/WebAssemblyRegisterBankInfo.h"
#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
#include "Utils/WasmAddressSpaces.h"
#include "WebAssemblyISelLowering.h"
#include "WebAssemblyMachineFunctionInfo.h"
+#include "WebAssemblyRegisterInfo.h"
#include "WebAssemblySubtarget.h"
#include "WebAssemblyUtilities.h"
#include "llvm/Analysis/MemoryLocation.h"
@@ -25,6 +27,7 @@
#include "llvm/CodeGen/FunctionLoweringInfo.h"
#include "llvm/CodeGen/GlobalISel/CallLowering.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/LowLevelTypeUtils.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -435,6 +438,12 @@ bool WebAssemblyCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
FunctionLoweringInfo &FLI,
Register SwiftErrorVReg) const {
auto MIB = MIRBuilder.buildInstrNoInsert(WebAssembly::RETURN);
+ MachineFunction &MF = MIRBuilder.getMF();
+ auto &TLI = *getTLI<WebAssemblyTargetLowering>();
+ auto &Subtarget = MF.getSubtarget<WebAssemblySubtarget>();
+ auto &TRI = *Subtarget.getRegisterInfo();
+ auto &TII = *Subtarget.getInstrInfo();
+ auto &RBI = *Subtarget.getRegBankInfo();
assert(((Val && !VRegs.empty()) || (!Val && VRegs.empty())) &&
"Return value without a vreg");
@@ -513,7 +522,11 @@ bool WebAssemblyCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
}
for (unsigned Part = 0; Part < NumParts; ++Part) {
- MIB.addUse(Arg.Regs[Part]);
+ auto NewOutReg = constrainRegToClass(MRI, TII, RBI, Arg.Regs[Part],
+ *TLI.getRegClassFor(NewVT));
+ if (NewOutReg != Arg.Regs[Part])
+ MIRBuilder.buildCopy(NewOutReg, Arg.Regs[Part]);
+ MIB.addUse(NewOutReg);
}
}
}
@@ -564,6 +577,11 @@ bool WebAssemblyCallLowering::lowerFormalArguments(
WebAssemblyFunctionInfo *MFI = MF.getInfo<WebAssemblyFunctionInfo>();
const DataLayout &DL = F.getDataLayout();
auto &TLI = *getTLI<WebAssemblyTargetLowering>();
+ auto &Subtarget = MF.getSubtarget<WebAssemblySubtarget>();
+ auto &TRI = *Subtarget.getRegisterInfo();
+ auto &TII = *Subtarget.getInstrInfo();
+ auto &RBI = *Subtarget.getRegBankInfo();
+
LLVMContext &Ctx = MIRBuilder.getContext();
const CallingConv::ID CallConv = F.getCallingConv();
@@ -647,9 +665,12 @@ bool WebAssemblyCallLowering::lowerFormalArguments(
}
for (unsigned Part = 0; Part < NumParts; ++Part) {
- MIRBuilder.buildInstr(getWASMArgOpcode(NewVT))
- .addDef(Arg.Regs[Part])
- .addImm(FinalArgIdx);
+ auto ArgInst = MIRBuilder.buildInstr(getWASMArgOpcode(NewVT))
+ .addDef(Arg.Regs[Part])
+ .addImm(FinalArgIdx);
+
+ constrainOperandRegClass(MF, TRI, MRI, TII, RBI, *ArgInst,
+ ArgInst->getDesc(), ArgInst->getOperand(0), 0);
MFI->addParam(NewVT);
++FinalArgIdx;
}
@@ -712,6 +733,9 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
const WebAssemblySubtarget &Subtarget =
MF.getSubtarget<WebAssemblySubtarget>();
+ auto &TRI = *Subtarget.getRegisterInfo();
+ auto &TII = *Subtarget.getInstrInfo();
+ auto &RBI = *Subtarget.getRegBankInfo();
CallingConv::ID CallConv = Info.CallConv;
if (!callingConvSupported(CallConv)) {
@@ -781,21 +805,128 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
}
}
+ if (Info.LoweredTailCall) {
+ MF.getFrameInfo().setHasTailCall();
+ }
+
MachineInstrBuilder CallInst;
bool IsIndirect = false;
Register IndirectIdx;
+ if (Info.Callee.isReg()) {
+ IsIndirect = true;
+ CallInst = MIRBuilder.buildInstr(Info.LoweredTailCall
+ ? WebAssembly::RET_CALL_INDIRECT
+ : WebAssembly::CALL_INDIRECT);
+ } else {
+ CallInst = MIRBuilder.buildInstr(
+ Info.LoweredTailCall ? WebAssembly::RET_CALL : WebAssembly::CALL);
+ }
+
+ if (!Info.LoweredTailCall) {
+ if (Info.CanLowerReturn && !Info.OrigRet.Ty->isVoidTy()) {
+ SmallVector<EVT, 4> SplitEVTs;
+ ComputeValueVTs(TLI, DL, Info.OrigRet.Ty, SplitEVTs);
+ assert(Info.OrigRet.Regs.size() == SplitEVTs.size() &&
+ "For each split Type there should be exactly one VReg.");
+
+ SmallVector<ArgInfo, 8> SplitReturns;
+
+ unsigned i = 0;
+ for (auto SplitEVT : SplitEVTs) {
+ Register CurVReg = Info.OrigRet.Regs[i];
+ ArgInfo CurArgInfo = ArgInfo{CurVReg, SplitEVT.getTypeForEVT(Ctx), 0};
+ if (Info.CB) {
+ setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, *Info.CB);
+ } else {
+          // We don't have a call base, so chances are we're looking at a
+ // libcall (external symbol).
+
+ // TODO: figure out how to get ALL the correct attributes
+ auto &Flags = CurArgInfo.Flags[0];
+ PointerType *PtrTy =
+ dyn_cast<PointerType>(CurArgInfo.Ty->getScalarType());
+ if (PtrTy) {
+ Flags.setPointer();
+ Flags.setPointerAddrSpace(PtrTy->getPointerAddressSpace());
+ }
+ Align MemAlign = DL.getABITypeAlign(CurArgInfo.Ty);
+ Flags.setMemAlign(MemAlign);
+ Flags.setOrigAlign(MemAlign);
+ }
+ splitToValueTypes(CurArgInfo, SplitReturns, DL, CallConv);
+ ++i;
+ }
+
+ for (auto &Ret : SplitReturns) {
+ EVT OrigVT = TLI.getValueType(DL, Ret.Ty);
+ MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT);
+ LLT OrigLLT = getLLTForType(*Ret.Ty, DL);
+ LLT NewLLT = getLLTForMVT(NewVT);
+
+ // If we need to split the type over multiple regs, check it's a
+ // scenario we currently support.
+ unsigned NumParts =
+ TLI.getNumRegistersForCallingConv(Ctx, CallConv, OrigVT);
+
+ ISD::ArgFlagsTy OrigFlags = Ret.Flags[0];
+ Ret.Flags.clear();
+
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ ISD::ArgFlagsTy Flags = OrigFlags;
+ if (Part == 0) {
+ Flags.setSplit();
+ } else {
+ Flags.setOrigAlign(Align(1));
+ if (Part == NumParts - 1)
+ Flags.setSplitEnd();
+ }
+
+ Ret.Flags.push_back(Flags);
+ }
+
+ Ret.OrigRegs.assign(Ret.Regs.begin(), Ret.Regs.end());
+ if (NumParts != 1 || OrigVT != NewVT) {
+ // If we can't directly assign the register, we need one or more
+ // intermediate values.
+ Ret.Regs.resize(NumParts);
+
+ // For each split register, create and assign a vreg that will store
+ // the incoming component of the larger value. These will later be
+ // merged to form the final vreg.
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ Ret.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT);
+ }
+ buildCopyFromRegs(MIRBuilder, Ret.OrigRegs, Ret.Regs, OrigLLT, NewLLT,
+ Ret.Flags[0], Ret.Ty->isFloatingPointTy());
+ }
+
+ for (unsigned Part = 0; Part < NumParts; ++Part) {
+ // MRI.setRegClass(Ret.Regs[Part], TLI.getRegClassFor(NewVT));
+ auto NewRetReg = constrainRegToClass(MRI, TII, RBI, Ret.Regs[Part],
+ *TLI.getRegClassFor(NewVT));
+ if (Ret.Regs[Part] != NewRetReg)
+ MIRBuilder.buildCopy(NewRetReg, Ret.Regs[Part]);
+
+ CallInst.addDef(Ret.Regs[Part]);
+ }
+ }
+ }
+
+ if (!Info.CanLowerReturn) {
+ insertSRetLoads(MIRBuilder, Info.OrigRet.Ty, Info.OrigRet.Regs,
+ Info.DemoteRegister, Info.DemoteStackIndex);
+ }
+ }
+ auto SavedInsertPt = MIRBuilder.getInsertPt();
+ MIRBuilder.setInstr(*CallInst);
+
if (Info.Callee.isReg()) {
LLT CalleeType = MRI.getType(Info.Callee.getReg());
assert(CalleeType.isPointer() &&
"Trying to lower a call with a Callee other than a pointer???");
- IsIndirect = true;
- CallInst = MIRBuilder.buildInstrNoInsert(
- Info.LoweredTailCall ? WebAssembly::RET_CALL_INDIRECT
- : WebAssembly::CALL_INDIRECT);
-
// Placeholder for the type index.
// This gets replaced with the correct value in WebAssemblyMCInstLower.cpp
CallInst.addImm(0);
@@ -816,11 +947,25 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
Table = WebAssembly::getOrCreateFuncrefCallTableSymbol(MF.getContext(),
&Subtarget);
+ Type *PtrTy = PointerType::getUnqual(Ctx);
+ LLT PtrLLT = getLLTForType(*PtrTy, DL);
+ auto PtrIntLLT = LLT::scalar(PtrLLT.getSizeInBits());
+
+ IndirectIdx = MIRBuilder.buildConstant(PtrIntLLT, 0).getReg(0);
+
auto TableSetInstr =
MIRBuilder.buildInstr(WebAssembly::TABLE_SET_FUNCREF);
TableSetInstr.addSym(Table);
+ TableSetInstr.addUse(IndirectIdx);
TableSetInstr.addUse(Info.Callee.getReg());
- IndirectIdx = MIRBuilder.buildConstant(LLT::scalar(32), 0).getReg(0);
+
+ constrainOperandRegClass(MF, TRI, MRI, TII, RBI, *TableSetInstr,
+ TableSetInstr->getDesc(),
+ TableSetInstr->getOperand(1), 1);
+ constrainOperandRegClass(MF, TRI, MRI, TII, RBI, *TableSetInstr,
+ TableSetInstr->getDesc(),
+ TableSetInstr->getOperand(2), 2);
+
} else {
fail(MIRBuilder, "Invalid address space for indirect call");
return false;
@@ -836,9 +981,6 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
CallInst.addImm(0);
}
} else {
- CallInst = MIRBuilder.buildInstrNoInsert(
- Info.LoweredTailCall ? WebAssembly::RET_CALL : WebAssembly::CALL);
-
if (Info.Callee.isGlobal()) {
CallInst.addGlobalAddress(Info.Callee.getGlobal());
} else if (Info.Callee.isSymbol()) {
@@ -884,9 +1026,13 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
/*isSS=*/false);
auto StackAddrSpace = DL.getAllocaAddrSpace();
- auto PtrLLT = getLLTForType(*PointerType::get(Ctx, StackAddrSpace), DL);
+ auto PtrLLT =
+ LLT::pointer(StackAddrSpace, DL.getPointerSizeInBits(StackAddrSpace));
+
Register StackObjPtrVreg =
MF.getRegInfo().createGenericVirtualRegister(PtrLLT);
+ MRI.setRegClass(StackObjPtrVreg, TLI.getRepRegClassFor(TLI.getPointerTy(
+ DL, StackAddrSpace)));
MIRBuilder.buildFrameIndex(StackObjPtrVreg, FI);
@@ -975,6 +1121,10 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
if (!Arg.Flags[0].isVarArg()) {
for (unsigned Part = 0; Part < NumParts; ++Part) {
+ auto NewArgReg = constrainRegToClass(MRI, TII, RBI, Arg.Regs[Part],
+ *TLI.getRegClassFor(NewVT));
+ if (Arg.Regs[Part] != NewArgReg)
+ MIRBuilder.buildCopy(NewArgReg, Arg.Regs[Part]);
CallInst.addUse(Arg.Regs[Part]);
}
++NumFixedArgs;
@@ -984,12 +1134,17 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
if (CallConv == CallingConv::Swift) {
Type *PtrTy = PointerType::getUnqual(Ctx);
LLT PtrLLT = getLLTForType(*PtrTy, DL);
+ auto &PtrRegClass = *TLI.getRegClassFor(TLI.getSimpleValueType(DL, PtrTy));
if (!HasSwiftSelfArg) {
- CallInst.addUse(MIRBuilder.buildUndef(PtrLLT).getReg(0));
+ auto NewUndefReg = MIRBuilder.buildUndef(PtrLLT).getReg(0);
+ MRI.setRegClass(NewUndefReg, &PtrRegClass);
+ CallInst.addUse(NewUndefReg);
}
if (!HasSwiftErrorArg) {
- CallInst.addUse(MIRBuilder.buildUndef(PtrLLT).getReg(0));
+ auto NewUndefReg = MIRBuilder.buildUndef(PtrLLT).getReg(0);
+ MRI.setRegClass(NewUndefReg, &PtrRegClass);
+ CallInst.addUse(NewUndefReg);
}
}
@@ -1019,12 +1174,14 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
unsigned NumBytes = CCInfo.getAlignedCallFrameSize();
auto StackAddrSpace = DL.getAllocaAddrSpace();
- auto PtrLLT = getLLTForType(*PointerType::get(Ctx, StackAddrSpace), DL);
- auto SizeLLT = LLT::scalar(PtrLLT.getSizeInBits());
+  auto PtrLLT = LLT::pointer(StackAddrSpace, DL.getPointerSizeInBits(StackAddrSpace));
+ auto SizeLLT = LLT::scalar(DL.getPointerSizeInBits(StackAddrSpace));
+ auto *PtrRegClass = TLI.getRegClassFor(TLI.getPointerTy(DL, StackAddrSpace));
if (Info.IsVarArg && NumBytes) {
Register VarArgStackPtr =
MF.getRegInfo().createGenericVirtualRegister(PtrLLT);
+ MRI.setRegClass(VarArgStackPtr, PtrRegClass);
MaybeAlign StackAlign = DL.getStackAlignment();
assert(StackAlign && "data layout string is missing stack alignment");
@@ -1063,110 +1220,20 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
CallInst.addUse(VarArgStackPtr);
} else if (Info.IsVarArg) {
- CallInst.addUse(MIRBuilder.buildConstant(PtrLLT, 0).getReg(0));
+ auto NewArgReg = MIRBuilder.buildConstant(PtrLLT, 0).getReg(0);
+ MRI.setRegClass(NewArgReg, PtrRegClass);
+ CallInst.addUse(NewArgReg);
}
if (IsIndirect) {
+ auto NewArgReg =
+ constrainRegToClass(MRI, TII, RBI, IndirectIdx, *PtrRegClass);
+ if (IndirectIdx != NewArgReg)
+ MIRBuilder.buildCopy(NewArgReg, IndirectIdx);
CallInst.addUse(IndirectIdx);
}
- MIRBuilder.insertInstr(CallInst);
-
- if (Info.LoweredTailCall) {
- return true;
- }
-
- if (Info.CanLowerReturn && !Info.OrigRet.Ty->isVoidTy()) {
- SmallVector<EVT, 4> SplitEVTs;
- ComputeValueVTs(TLI, DL, Info.OrigRet.Ty, SplitEVTs);
- assert(Info.OrigRet.Regs.size() == SplitEVTs.size() &&
- "For each split Type there should be exactly one VReg.");
-
- SmallVector<ArgInfo, 8> SplitReturns;
-
- unsigned i = 0;
- for (auto SplitEVT : SplitEVTs) {
- Register CurVReg = Info.OrigRet.Regs[i];
- ArgInfo CurArgInfo = ArgInfo{CurVReg, SplitEVT.getTypeForEVT(Ctx), 0};
- if (Info.CB) {
- setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, *Info.CB);
- } else {
-      // We don't have a call base, so chances are we're looking at a libcall
- // (external symbol).
-
- // TODO: figure out how to get ALL the correct attributes
- auto &Flags = CurArgInfo.Flags[0];
- PointerType *PtrTy =
- dyn_cast<PointerType>(CurArgInfo.Ty->getScalarType());
- if (PtrTy) {
- Flags.setPointer();
- Flags.setPointerAddrSpace(PtrTy->getPointerAddressSpace());
- }
- Align MemAlign = DL.getABITypeAlign(CurArgInfo.Ty);
- Flags.setMemAlign(MemAlign);
- Flags.setOrigAlign(MemAlign);
- }
- splitToValueTypes(CurArgInfo, SplitReturns, DL, CallConv);
- ++i;
- }
-
- for (auto &Ret : SplitReturns) {
- EVT OrigVT = TLI.getValueType(DL, Ret.Ty);
- MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT);
- LLT OrigLLT = getLLTForType(*Ret.Ty, DL);
- LLT NewLLT = getLLTForMVT(NewVT);
-
- // If we need to split the type over multiple regs, check it's a scenario
- // we currently support.
- unsigned NumParts =
- TLI.getNumRegistersForCallingConv(Ctx, CallConv, OrigVT);
-
- ISD::ArgFlagsTy OrigFlags = Ret.Flags[0];
- Ret.Flags.clear();
-
- for (unsigned Part = 0; Part < NumParts; ++Part) {
- ISD::ArgFlagsTy Flags = OrigFlags;
- if (Part == 0) {
- Flags.setSplit();
- } else {
- Flags.setOrigAlign(Align(1));
- if (Part == NumParts - 1)
- Flags.setSplitEnd();
- }
-
- Ret.Flags.push_back(Flags);
- }
-
- Ret.OrigRegs.assign(Ret.Regs.begin(), Ret.Regs.end());
- if (NumParts != 1 || OrigVT != NewVT) {
- // If we can't directly assign the register, we need one or more
- // intermediate values.
- Ret.Regs.resize(NumParts);
-
- // For each split register, create and assign a vreg that will store
- // the incoming component of the larger value. These will later be
- // merged to form the final vreg.
- for (unsigned Part = 0; Part < NumParts; ++Part) {
- Ret.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT);
- }
- buildCopyFromRegs(MIRBuilder, Ret.OrigRegs, Ret.Regs, OrigLLT, NewLLT,
- Ret.Flags[0], Ret.Ty->isFloatingPointTy());
- }
-
- for (unsigned Part = 0; Part < NumParts; ++Part) {
- CallInst.addDef(Ret.Regs[Part]);
- }
- }
- }
-
- if (!Info.CanLowerReturn) {
- insertSRetLoads(MIRBuilder, Info.OrigRet.Ty, Info.OrigRet.Regs,
- Info.DemoteRegister, Info.DemoteStackIndex);
-
- for (auto Reg : Info.OrigRet.Regs) {
- CallInst.addDef(Reg);
- }
- }
+ MIRBuilder.setInsertPt(MIRBuilder.getMBB(), SavedInsertPt);
return true;
}
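
[Sketch, not part of the patch: the insert-point dance lowerCall now performs.
The call is built first, return/argument handling is then emitted *before* it,
and the saved point is restored afterwards. Toy types, not MIRBuilder.]

  #include <list>
  #include <string>

  struct ToyBuilder {
    std::list<std::string> MBB;
    std::list<std::string>::iterator InsertPt = MBB.end();

    std::list<std::string>::iterator insert(const std::string &I) {
      return MBB.insert(InsertPt, I); // inserts *before* InsertPt
    }
  };

  void buildCallLikeSequence(ToyBuilder &B) {
    auto Call = B.insert("CALL"); // call built up front
    auto Saved = B.InsertPt;      // remember where new code was going
    B.InsertPt = Call;            // now emit setup before the call
    B.insert("copy arg into ABI register class");
    B.InsertPt = Saved;           // restore for whatever follows the call
  }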
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp
index e69de29bb2d1d..e605c46aece85 100644
--- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp
@@ -0,0 +1,302 @@
+#include "WebAssemblyRegisterBankInfo.h"
+#include "WebAssemblySubtarget.h"
+#include "WebAssemblyTargetMachine.h"
+#include "llvm/CodeGen/TargetOpcodes.h"
+
+#define GET_TARGET_REGBANK_IMPL
+
+#include "WebAssemblyGenRegisterBank.inc"
+
+namespace llvm {
+namespace WebAssembly {
+enum PartialMappingIdx {
+ PMI_None = -1,
+ PMI_I32 = 1,
+ PMI_I64,
+ PMI_F32,
+ PMI_F64,
+ PMI_Min = PMI_I32,
+};
+
+enum ValueMappingIdx {
+ InvalidIdx = 0,
+ I32Idx = 1,
+ I64Idx = 4,
+ F32Idx = 7,
+ F64Idx = 10
+};
+
+const RegisterBankInfo::PartialMapping PartMappings[]{{0, 32, I32RegBank},
+ {0, 64, I64RegBank},
+ {0, 32, F32RegBank},
+ {0, 64, F64RegBank}};
+
+const RegisterBankInfo::ValueMapping ValueMappings[] = {
+ // invalid
+ {nullptr, 0},
+ // up to 3 operands as I32
+ {&PartMappings[PMI_I32 - PMI_Min], 1},
+ {&PartMappings[PMI_I32 - PMI_Min], 1},
+ {&PartMappings[PMI_I32 - PMI_Min], 1},
+ // up to 3 operands as I64
+ {&PartMappings[PMI_I64 - PMI_Min], 1},
+ {&PartMappings[PMI_I64 - PMI_Min], 1},
+ {&PartMappings[PMI_I64 - PMI_Min], 1},
+ // up to 3 operands as F32
+ {&PartMappings[PMI_F32 - PMI_Min], 1},
+ {&PartMappings[PMI_F32 - PMI_Min], 1},
+ {&PartMappings[PMI_F32 - PMI_Min], 1},
+ // up to 3 operands as F64
+ {&PartMappings[PMI_F64 - PMI_Min], 1},
+ {&PartMappings[PMI_F64 - PMI_Min], 1},
+ {&PartMappings[PMI_F64 - PMI_Min], 1}};
+
+} // namespace WebAssembly
+} // namespace llvm
+
+using namespace llvm;
+
+WebAssemblyRegisterBankInfo::WebAssemblyRegisterBankInfo(
+ const TargetRegisterInfo &TRI) {}
+
+// Instructions where use operands are floating point registers.
+// Def operands are general purpose.
+static bool isFloatingPointOpcodeUse(unsigned Opc) {
+ switch (Opc) {
+ case TargetOpcode::G_FPTOSI:
+ case TargetOpcode::G_FPTOUI:
+ case TargetOpcode::G_FCMP:
+ return true;
+ default:
+ return isPreISelGenericFloatingPointOpcode(Opc);
+ }
+}
+
+// Instructions where def operands are floating point registers.
+// Use operands are general purpose.
+static bool isFloatingPointOpcodeDef(unsigned Opc) {
+ switch (Opc) {
+ case TargetOpcode::G_SITOFP:
+ case TargetOpcode::G_UITOFP:
+ return true;
+ default:
+ return isPreISelGenericFloatingPointOpcode(Opc);
+ }
+}
+
+static bool isAmbiguous(unsigned Opc) {
+ switch (Opc) {
+ case TargetOpcode::G_LOAD:
+ case TargetOpcode::G_STORE:
+ case TargetOpcode::G_PHI:
+ case TargetOpcode::G_SELECT:
+ case TargetOpcode::G_IMPLICIT_DEF:
+ case TargetOpcode::G_UNMERGE_VALUES:
+ case TargetOpcode::G_MERGE_VALUES:
+ return true;
+ default:
+ return false;
+ }
+}
+
+const RegisterBankInfo::InstructionMapping &
+WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
+
+ unsigned Opc = MI.getOpcode();
+ const MachineFunction &MF = *MI.getParent()->getParent();
+ const MachineRegisterInfo &MRI = MF.getRegInfo();
+ const TargetSubtargetInfo &STI = MF.getSubtarget();
+ const TargetRegisterInfo &TRI = *STI.getRegisterInfo();
+
+ if ((Opc != TargetOpcode::COPY && !isPreISelGenericOpcode(Opc)) ||
+ Opc == TargetOpcode::G_PHI) {
+ const RegisterBankInfo::InstructionMapping &Mapping =
+ getInstrMappingImpl(MI);
+ if (Mapping.isValid())
+ return Mapping;
+ }
+
+ using namespace TargetOpcode;
+
+ unsigned NumOperands = MI.getNumOperands();
+ const ValueMapping *OperandsMapping = nullptr;
+ unsigned MappingID = DefaultMappingID;
+
+ // Check if LLT sizes match sizes of available register banks.
+ for (const MachineOperand &Op : MI.operands()) {
+ if (Op.isReg()) {
+ LLT RegTy = MRI.getType(Op.getReg());
+
+ if (RegTy.isScalar() &&
+ (RegTy.getSizeInBits() != 32 && RegTy.getSizeInBits() != 64))
+ return getInvalidInstructionMapping();
+
+ if (RegTy.isVector() && RegTy.getSizeInBits() != 128)
+ return getInvalidInstructionMapping();
+ }
+ }
+
+ switch (Opc) {
+ case G_BR:
+ return getInstructionMapping(MappingID, /*Cost=*/1,
+ getOperandsMapping({nullptr}), NumOperands);
+ case G_TRAP:
+ return getInstructionMapping(MappingID, /*Cost=*/1, nullptr, 0);
+ }
+
+ const LLT Op0Ty = MRI.getType(MI.getOperand(0).getReg());
+ unsigned Op0Size = Op0Ty.getSizeInBits();
+
+ auto &Op0IntValueMapping =
+ WebAssembly::ValueMappings[Op0Size == 64 ? WebAssembly::I64Idx
+ : WebAssembly::I32Idx];
+ auto &Op0FloatValueMapping =
+ WebAssembly::ValueMappings[Op0Size == 64 ? WebAssembly::F64Idx
+ : WebAssembly::F32Idx];
+ auto &Pointer0ValueMapping =
+ WebAssembly::ValueMappings[MI.getMF()->getDataLayout()
+ .getPointerSizeInBits(0) == 64
+ ? WebAssembly::I64Idx
+ : WebAssembly::I32Idx];
+
+ switch (Opc) {
+ case G_AND:
+ case G_OR:
+ case G_XOR:
+ case G_SHL:
+ case G_ASHR:
+ case G_LSHR:
+ case G_PTR_ADD:
+ case G_INTTOPTR:
+ case G_PTRTOINT:
+ case G_ADD:
+ case G_SUB:
+ case G_MUL:
+ case G_SDIV:
+ case G_SREM:
+ case G_UDIV:
+ case G_UREM:
+ OperandsMapping = &Op0IntValueMapping;
+ break;
+ case G_SEXT_INREG:
+ OperandsMapping =
+ getOperandsMapping({&Op0IntValueMapping, &Op0IntValueMapping, nullptr});
+ break;
+ case G_FRAME_INDEX:
+ OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr});
+ break;
+ case G_ZEXT:
+ case G_ANYEXT:
+ case G_SEXT:
+ case G_TRUNC: {
+ const LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg());
+ unsigned Op1Size = Op1Ty.getSizeInBits();
+
+ auto &Op1IntValueMapping =
+ WebAssembly::ValueMappings[Op1Size == 64 ? WebAssembly::I64Idx
+ : WebAssembly::I32Idx];
+ OperandsMapping =
+ getOperandsMapping({&Op0IntValueMapping, &Op1IntValueMapping});
+ break;
+ }
+ case G_LOAD:
+ case G_STORE:
+ if (MRI.getType(MI.getOperand(1).getReg()).getAddressSpace() != 0)
+ break;
+ OperandsMapping =
+ getOperandsMapping({&Op0IntValueMapping, &Pointer0ValueMapping});
+ break;
+ case G_MEMCPY:
+ case G_MEMMOVE: {
+ if (MRI.getType(MI.getOperand(0).getReg()).getAddressSpace() != 0)
+ break;
+ if (MRI.getType(MI.getOperand(1).getReg()).getAddressSpace() != 0)
+ break;
+
+ const LLT Op2Ty = MRI.getType(MI.getOperand(2).getReg());
+ unsigned Op2Size = Op2Ty.getSizeInBits();
+ auto &Op2IntValueMapping =
+ WebAssembly::ValueMappings[Op2Size == 64 ? WebAssembly::I64Idx
+ : WebAssembly::I32Idx];
+ OperandsMapping =
+ getOperandsMapping({&Pointer0ValueMapping, &Pointer0ValueMapping,
+ &Op2IntValueMapping, nullptr});
+ break;
+ }
+ case G_MEMSET: {
+ if (MRI.getType(MI.getOperand(0).getReg()).getAddressSpace() != 0)
+ break;
+ const LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg());
+ unsigned Op1Size = Op1Ty.getSizeInBits();
+ auto &Op1IntValueMapping =
+ WebAssembly::ValueMappings[Op1Size == 64 ? WebAssembly::I64Idx
+ : WebAssembly::I32Idx];
+
+ const LLT Op2Ty = MRI.getType(MI.getOperand(2).getReg());
+    unsigned Op2Size = Op2Ty.getSizeInBits();
+ auto &Op2IntValueMapping =
+ WebAssembly::ValueMappings[Op2Size == 64 ? WebAssembly::I64Idx
+ : WebAssembly::I32Idx];
+
+ OperandsMapping =
+ getOperandsMapping({&Pointer0ValueMapping, &Op1IntValueMapping,
+ &Op2IntValueMapping, nullptr});
+ break;
+ }
+ case G_GLOBAL_VALUE:
+ case G_CONSTANT:
+ OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr});
+ break;
+ case G_IMPLICIT_DEF:
+ OperandsMapping = &Op0IntValueMapping;
+ break;
+ case G_ICMP: {
+ const LLT Op2Ty = MRI.getType(MI.getOperand(2).getReg());
+ unsigned Op2Size = Op2Ty.getSizeInBits();
+
+ auto &Op2IntValueMapping =
+ WebAssembly::ValueMappings[Op2Size == 64 ? WebAssembly::I64Idx
+ : WebAssembly::I32Idx];
+
+ OperandsMapping =
+ getOperandsMapping({&Op0IntValueMapping, nullptr, &Op2IntValueMapping,
+ &Op2IntValueMapping});
+ break;
+ }
+ case G_BRCOND:
+ OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr});
+ break;
+ case COPY: {
+ Register DstReg = MI.getOperand(0).getReg();
+ Register SrcReg = MI.getOperand(1).getReg();
+ // Check if one of the register is not a generic register.
+ if ((DstReg.isPhysical() || !MRI.getType(DstReg).isValid()) ||
+ (SrcReg.isPhysical() || !MRI.getType(SrcReg).isValid())) {
+ const RegisterBank *DstRB = getRegBank(DstReg, MRI, TRI);
+ const RegisterBank *SrcRB = getRegBank(SrcReg, MRI, TRI);
+ if (!DstRB)
+ DstRB = SrcRB;
+ else if (!SrcRB)
+ SrcRB = DstRB;
+ // If both RB are null that means both registers are generic.
+ // We shouldn't be here.
+ assert(DstRB && SrcRB && "Both RegBank were nullptr");
+ TypeSize DstSize = getSizeInBits(DstReg, MRI, TRI);
+ TypeSize SrcSize = getSizeInBits(SrcReg, MRI, TRI);
+ assert(DstSize == SrcSize &&
+ "Trying to copy between different sized regbanks? Why?");
+
+ return getInstructionMapping(
+ DefaultMappingID, copyCost(*DstRB, *SrcRB, DstSize),
+          getCopyMapping(DstRB->getID(), SrcRB->getID(), DstSize),
+ // We only care about the mapping of the destination.
+ /*NumOperands*/ 1);
+ }
+ }
+ }
+ if (!OperandsMapping)
+ return getInvalidInstructionMapping();
+
+ return getInstructionMapping(MappingID, /*Cost=*/1, OperandsMapping,
+ NumOperands);
+}
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h
index e69de29bb2d1d..f0d95b56ef861 100644
--- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h
@@ -0,0 +1,40 @@
+//===- WebAssemblyRegisterBankInfo.h ----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file declares the targeting of the RegisterBankInfo class for WASM.
+/// \todo This should be generated by TableGen.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_WEBASSEMBLY_WEBASSEMBLYREGISTERBANKINFO_H
+#define LLVM_LIB_TARGET_WEBASSEMBLY_WEBASSEMBLYREGISTERBANKINFO_H
+
+#include "llvm/CodeGen/RegisterBankInfo.h"
+
+#define GET_REGBANK_DECLARATIONS
+#include "WebAssemblyGenRegisterBank.inc"
+
+namespace llvm {
+
+class TargetRegisterInfo;
+
+class WebAssemblyGenRegisterBankInfo : public RegisterBankInfo {
+#define GET_TARGET_REGBANK_CLASS
+#include "WebAssemblyGenRegisterBank.inc"
+};
+
+/// This class provides the information for the target register banks.
+class WebAssemblyRegisterBankInfo final
+ : public WebAssemblyGenRegisterBankInfo {
+public:
+ WebAssemblyRegisterBankInfo(const TargetRegisterInfo &TRI);
+
+ const InstructionMapping &
+ getInstrMapping(const MachineInstr &MI) const override;
+};
+} // end namespace llvm
+#endif
diff --git a/llvm/lib/Target/WebAssembly/WebAssembly.td b/llvm/lib/Target/WebAssembly/WebAssembly.td
index 089be5f1dc70e..3705a42fd21c9 100644
--- a/llvm/lib/Target/WebAssembly/WebAssembly.td
+++ b/llvm/lib/Target/WebAssembly/WebAssembly.td
@@ -101,6 +101,7 @@ def FeatureWideArithmetic :
//===----------------------------------------------------------------------===//
include "WebAssemblyRegisterInfo.td"
+include "WebAssemblyRegisterBanks.td"
//===----------------------------------------------------------------------===//
// Instruction Descriptions
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyRegisterBanks.td b/llvm/lib/Target/WebAssembly/WebAssemblyRegisterBanks.td
new file mode 100644
index 0000000000000..9ebece0e0bf09
--- /dev/null
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyRegisterBanks.td
@@ -0,0 +1,20 @@
+//=- WebAssemblyRegisterBanks.td - Describe the WASM Banks ---*- tablegen -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//
+//===----------------------------------------------------------------------===//
+
+
+def I32RegBank : RegisterBank<"I32RegBank", [I32]>;
+def I64RegBank : RegisterBank<"I64RegBank", [I64]>;
+def F32RegBank : RegisterBank<"F64RegBank", [F32]>;
+def F64RegBank : RegisterBank<"F64RegBank", [F64]>;
+
+def EXTERNREFRegBank : RegisterBank<"EXTERNREFRegBank", [EXTERNREF]>;
+def FUNCREFRegBank : RegisterBank<"FUNCREFRegBank", [FUNCREF]>;
+def EXNREFRegBank : RegisterBank<"EXNREFRegBank", [EXNREF]>;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp
index 3ea8b9f85819f..b99c35acabef6 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp
@@ -73,9 +73,9 @@ WebAssemblySubtarget::WebAssemblySubtarget(const Triple &TT,
TLInfo(TM, *this) {
CallLoweringInfo.reset(new WebAssemblyCallLowering(*getTargetLowering()));
Legalizer.reset(new WebAssemblyLegalizerInfo(*this));
- /*auto *RBI = new WebAssemblyRegisterBankInfo(*getRegisterInfo());
+ auto *RBI = new WebAssemblyRegisterBankInfo(*getRegisterInfo());
RegBankInfo.reset(RBI);
-
+/*
InstSelector.reset(createWebAssemblyInstructionSelector(
*static_cast<const WebAssemblyTargetMachine *>(&TM), *this, *RBI));*/
}
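
[Sketch, not part of the patch: how getInstrMapping indexes the ValueMappings
table above. The operand size picks the 32- or 64-bit row and the opcode's
float-ness picks the bank; indices mirror the patch's ValueMappingIdx enum.]

  enum SketchValueMappingIdx {
    InvalidIdx = 0,
    I32Idx = 1,
    I64Idx = 4,
    F32Idx = 7,
    F64Idx = 10,
  };

  SketchValueMappingIdx pickValueMapping(unsigned SizeInBits, bool IsFloat) {
    if (SizeInBits == 32)
      return IsFloat ? F32Idx : I32Idx;
    if (SizeInBits == 64)
      return IsFloat ? F64Idx : I64Idx;
    return InvalidIdx; // other scalar sizes have no bank mapping here yet
  }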
>From aef4448351d858db64f5e203c71f44a3edca30a7 Mon Sep 17 00:00:00 2001
From: Demetrius Kanios <demetrius at kanios.net>
Date: Mon, 29 Sep 2025 11:23:22 -0700
Subject: [PATCH 9/9] Finish initial pass over regbankselect
---
.../GISel/WebAssemblyCallLowering.cpp | 21 +-
.../GISel/WebAssemblyLegalizerInfo.cpp | 55 ++--
.../GISel/WebAssemblyRegisterBankInfo.cpp | 254 +++++++++++++-----
3 files changed, 241 insertions(+), 89 deletions(-)
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
index 7f3c7e62ce02d..733d676ac988a 100644
--- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp
@@ -33,6 +33,7 @@
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGenTypes/LowLevelType.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/DataLayout.h"
@@ -481,6 +482,7 @@ bool WebAssemblyCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT);
LLT OrigLLT = getLLTForType(*Arg.Ty, DL);
LLT NewLLT = getLLTForMVT(NewVT);
+ const TargetRegisterClass &NewRegClass = *TLI.getRegClassFor(NewVT);
// If we need to split the type over multiple regs, check it's a scenario
// we currently support.
@@ -522,10 +524,14 @@
}
for (unsigned Part = 0; Part < NumParts; ++Part) {
- auto NewOutReg = constrainRegToClass(MRI, TII, RBI, Arg.Regs[Part],
- *TLI.getRegClassFor(NewVT));
- if (NewOutReg != Arg.Regs[Part])
+ auto NewOutReg = Arg.Regs[Part];
+ if (!RBI.constrainGenericRegister(NewOutReg, NewRegClass, MRI)) {
+ NewOutReg = MRI.createGenericVirtualRegister(NewLLT);
+          [[maybe_unused]] bool Constrained =
+              RBI.constrainGenericRegister(NewOutReg, NewRegClass, MRI);
+          assert(Constrained && "Couldn't constrain brand-new register?");
MIRBuilder.buildCopy(NewOutReg, Arg.Regs[Part]);
+ }
MIB.addUse(NewOutReg);
}
}
@@ -864,6 +868,7 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT);
LLT OrigLLT = getLLTForType(*Ret.Ty, DL);
LLT NewLLT = getLLTForMVT(NewVT);
+ const TargetRegisterClass &NewRegClass = *TLI.getRegClassFor(NewVT);
// If we need to split the type over multiple regs, check it's a
// scenario we currently support.
@@ -903,12 +908,14 @@
}
for (unsigned Part = 0; Part < NumParts; ++Part) {
- // MRI.setRegClass(Ret.Regs[Part], TLI.getRegClassFor(NewVT));
- auto NewRetReg = constrainRegToClass(MRI, TII, RBI, Ret.Regs[Part],
- *TLI.getRegClassFor(NewVT));
- if (Ret.Regs[Part] != NewRetReg)
+ auto NewRetReg = Ret.Regs[Part];
+ if (!RBI.constrainGenericRegister(NewRetReg, NewRegClass, MRI)) {
+ NewRetReg = MRI.createGenericVirtualRegister(NewLLT);
+            [[maybe_unused]] bool Constrained =
+                RBI.constrainGenericRegister(NewRetReg, NewRegClass, MRI);
+            assert(Constrained && "Couldn't constrain brand-new register?");
MIRBuilder.buildCopy(NewRetReg, Ret.Regs[Part]);
-
+ }
CallInst.addDef(Ret.Regs[Part]);
}
}
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp
index c6cd1c5b371e9..ae2ac0a512427 100644
--- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp
@@ -89,14 +89,17 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo(
.clampScalar(0, s32, s64);
getActionDefinitionsBuilder({G_ASHR, G_LSHR, G_SHL, G_CTLZ, G_CTLZ_ZERO_UNDEF,
- G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTPOP, G_FSHL,
- G_FSHR})
+ G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTPOP})
.legalFor({{s32, s32}, {s64, s64}})
.widenScalarToNextPow2(0)
.clampScalar(0, s32, s64)
.minScalarSameAs(1, 0)
.maxScalarSameAs(1, 0);
+ getActionDefinitionsBuilder({G_FSHL, G_FSHR})
+ .legalFor({{s32, s32}, {s64, s64}})
+ .lower();
+
getActionDefinitionsBuilder({G_SCMP, G_UCMP}).lower();
getActionDefinitionsBuilder({G_AND, G_OR, G_XOR})
@@ -123,6 +126,12 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo(
.libcallFor({s32, s64})
.minScalar(0, s32);
+ getActionDefinitionsBuilder(G_LROUND).libcallForCartesianProduct({s32},
+ {s32, s64});
+
+ getActionDefinitionsBuilder(G_LLROUND).libcallForCartesianProduct({s64},
+ {s32, s64});
+
getActionDefinitionsBuilder(G_FCOPYSIGN)
.legalFor({s32, s64})
.minScalar(0, s32)
@@ -154,9 +163,8 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo(
{s64, p0, s8, 1},
{s64, p0, s16, 1},
{s64, p0, s32, 1}})
- .widenScalarToNextPow2(0)
- .lowerIfMemSizeNotByteSizePow2()
- .clampScalar(0, s32, s64);
+ .clampScalar(0, s32, s64)
+ .lowerIfMemSizeNotByteSizePow2();
getActionDefinitionsBuilder(G_STORE)
.legalForTypesWithMemDesc(
@@ -167,9 +175,8 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo(
{s64, p0, s8, 1},
{s64, p0, s16, 1},
{s64, p0, s32, 1}})
- .widenScalarToNextPow2(0)
- .lowerIfMemSizeNotByteSizePow2()
- .clampScalar(0, s32, s64);
+ .clampScalar(0, s32, s64)
+ .lowerIfMemSizeNotByteSizePow2();
getActionDefinitionsBuilder({G_ZEXTLOAD, G_SEXTLOAD})
.legalForTypesWithMemDesc({{s32, p0, s8, 1},
@@ -178,10 +185,8 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo(
{s64, p0, s8, 1},
{s64, p0, s16, 1},
{s64, p0, s32, 1}})
- .widenScalarToNextPow2(0)
- .lowerIfMemSizeNotByteSizePow2()
.clampScalar(0, s32, s64)
- .lower();
+ .lowerIfMemSizeNotByteSizePow2();
if (ST.hasBulkMemoryOpt()) {
getActionDefinitionsBuilder(G_BZERO).unsupported();
@@ -204,7 +209,7 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo(
// TODO: figure out how to combine G_ANYEXT of G_ASSERT_{S|Z}EXT (or
// appropriate G_AND and G_SEXT_IN_REG?) to a G_{S|Z}EXT + G_ASSERT_{S|Z}EXT
- // for better optimization (since G_ANYEXT lowers to a ZEXT or SEXT
+ // for better optimization (since G_ANYEXT will lower to a ZEXT or SEXT
// instruction anyway).
getActionDefinitionsBuilder(G_ANYEXT)
@@ -221,8 +226,7 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo(
if (ST.hasSignExt()) {
getActionDefinitionsBuilder(G_SEXT_INREG)
.clampScalar(0, s32, s64)
- .customFor({s32, s64})
- .lower();
+ .customFor({s32, s64});
} else {
getActionDefinitionsBuilder(G_SEXT_INREG).lower();
}
@@ -242,23 +246,42 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo(
.legalForCartesianProduct({s32, s64}, {p0})
.clampScalar(0, s32, s64);
+ getActionDefinitionsBuilder(G_DYN_STACKALLOC).lowerFor({{p0, p0s}});
+
+ getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).lower();
+
getLegacyLegalizerInfo().computeTables();
}
bool WebAssemblyLegalizerInfo::legalizeCustom(
LegalizerHelper &Helper, MachineInstr &MI,
LostDebugLocObserver &LocObserver) const {
+ auto &MRI = *Helper.MIRBuilder.getMRI();
+ auto &MIRBuilder = Helper.MIRBuilder;
+
switch (MI.getOpcode()) {
case TargetOpcode::G_SEXT_INREG: {
+ assert(MI.getOperand(2).isImm() && "Expected immediate");
+
// Mark only 8/16/32-bit SEXT_INREG as legal
- auto [DstType, SrcType] = MI.getFirst2LLTs();
+ auto [DstReg, SrcReg] = MI.getFirst2Regs();
+ auto DstType = MRI.getType(DstReg);
auto ExtFromWidth = MI.getOperand(2).getImm();
if (ExtFromWidth == 8 || ExtFromWidth == 16 ||
(DstType.getScalarSizeInBits() == 64 && ExtFromWidth == 32)) {
return true;
}
- return false;
+
+ Register TmpRes = MRI.createGenericVirtualRegister(DstType);
+
+ auto MIBSz = MIRBuilder.buildConstant(
+ DstType, DstType.getScalarSizeInBits() - ExtFromWidth);
+ MIRBuilder.buildShl(TmpRes, SrcReg, MIBSz->getOperand(0));
+ MIRBuilder.buildAShr(DstReg, TmpRes, MIBSz->getOperand(0));
+ MI.eraseFromParent();
+
+ return true;
}
case TargetOpcode::G_MEMSET: {
// Anyext the value being set to 32 bits (only the bottom 8 bits are read by
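
A note on the custom G_SEXT_INREG lowering above: for widths without a
native sign-extension instruction it emits the classic shl+ashr pair. A
minimal standalone sketch of the identity it relies on, in plain C++
rather than MIR (sextInReg64 is a hypothetical helper for illustration
only, with the register width fixed at 64 to mirror the s64 case):

#include <cassert>
#include <cstdint>

// sext_inreg(x, N) == ashr(shl(x, W - N), W - N) on a W-bit register.
// Relies on arithmetic right shift of signed values, which every
// LLVM-supported host compiler provides.
static int64_t sextInReg64(int64_t X, unsigned FromWidth) {
  assert(FromWidth >= 1 && FromWidth <= 64 && "invalid source width");
  unsigned Shift = 64 - FromWidth;
  return (int64_t)((uint64_t)X << Shift) >> Shift; // shl, then ashr
}

int main() {
  assert(sextInReg64(0xFF, 8) == -1);         // all-ones i8 -> -1
  assert(sextInReg64(0x7F, 8) == 127);        // positive values unchanged
  assert(sextInReg64(0x10001, 17) == -65535); // non-byte widths work too
  return 0;
}

This is exactly what the buildShl/buildAShr pair in legalizeCustom emits,
with the shift amount materialized by buildConstant.
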
diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp
index e605c46aece85..fa4103a8b1b31 100644
--- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp
@@ -1,7 +1,9 @@
#include "WebAssemblyRegisterBankInfo.h"
+#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
#include "WebAssemblySubtarget.h"
#include "WebAssemblyTargetMachine.h"
#include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/Support/ErrorHandling.h"
#define GET_TARGET_REGBANK_IMPL
@@ -59,46 +61,6 @@ using namespace llvm;
WebAssemblyRegisterBankInfo::WebAssemblyRegisterBankInfo(
const TargetRegisterInfo &TRI) {}
-// Instructions where use operands are floating point registers.
-// Def operands are general purpose.
-static bool isFloatingPointOpcodeUse(unsigned Opc) {
- switch (Opc) {
- case TargetOpcode::G_FPTOSI:
- case TargetOpcode::G_FPTOUI:
- case TargetOpcode::G_FCMP:
- return true;
- default:
- return isPreISelGenericFloatingPointOpcode(Opc);
- }
-}
-
-// Instructions where def operands are floating point registers.
-// Use operands are general purpose.
-static bool isFloatingPointOpcodeDef(unsigned Opc) {
- switch (Opc) {
- case TargetOpcode::G_SITOFP:
- case TargetOpcode::G_UITOFP:
- return true;
- default:
- return isPreISelGenericFloatingPointOpcode(Opc);
- }
-}
-
-static bool isAmbiguous(unsigned Opc) {
- switch (Opc) {
- case TargetOpcode::G_LOAD:
- case TargetOpcode::G_STORE:
- case TargetOpcode::G_PHI:
- case TargetOpcode::G_SELECT:
- case TargetOpcode::G_IMPLICIT_DEF:
- case TargetOpcode::G_UNMERGE_VALUES:
- case TargetOpcode::G_MERGE_VALUES:
- return true;
- default:
- return false;
- }
-}
-
const RegisterBankInfo::InstructionMapping &
WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
@@ -135,13 +97,35 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
return getInvalidInstructionMapping();
}
}
-
switch (Opc) {
case G_BR:
return getInstructionMapping(MappingID, /*Cost=*/1,
getOperandsMapping({nullptr}), NumOperands);
case G_TRAP:
- return getInstructionMapping(MappingID, /*Cost=*/1, nullptr, 0);
+ case G_DEBUGTRAP:
+ return getInstructionMapping(MappingID, /*Cost=*/1, getOperandsMapping({}),
+ 0);
+ case COPY:
+ Register DstReg = MI.getOperand(0).getReg();
+ if (DstReg.isPhysical()) {
+ if (DstReg.id() == WebAssembly::SP32) {
+ return getInstructionMapping(
+ MappingID, /*Cost=*/1,
+ getOperandsMapping(
+ {&WebAssembly::ValueMappings[WebAssembly::I32Idx]}),
+ 1);
+ } else if (DstReg.id() == WebAssembly::SP64) {
+ return getInstructionMapping(
+ MappingID, /*Cost=*/1,
+ getOperandsMapping(
+ {&WebAssembly::ValueMappings[WebAssembly::I64Idx]}),
+ 1);
+ } else {
+ llvm_unreachable("Trying to copy into WASM physical register other "
+ "than sp32 or sp64?");
+ }
+ }
+ break;
}
const LLT Op0Ty = MRI.getType(MI.getOperand(0).getReg());
@@ -176,8 +160,39 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
case G_SREM:
case G_UDIV:
case G_UREM:
+ case G_CTLZ:
+ case G_CTLZ_ZERO_UNDEF:
+ case G_CTTZ:
+ case G_CTTZ_ZERO_UNDEF:
+ case G_CTPOP:
+ case G_FSHL:
+ case G_FSHR:
OperandsMapping = &Op0IntValueMapping;
break;
+ case G_FADD:
+ case G_FSUB:
+ case G_FDIV:
+ case G_FMUL:
+ case G_FNEG:
+ case G_FABS:
+ case G_FCEIL:
+ case G_FFLOOR:
+ case G_FSQRT:
+ case G_INTRINSIC_TRUNC:
+ case G_FNEARBYINT:
+ case G_FRINT:
+ case G_INTRINSIC_ROUNDEVEN:
+ case G_FMINIMUM:
+ case G_FMAXIMUM:
+ case G_FMINNUM:
+ case G_FMAXNUM:
+ case G_FMINNUM_IEEE:
+ case G_FMAXNUM_IEEE:
+ case G_FMA:
+ case G_FREM:
+ case G_FCOPYSIGN:
+ OperandsMapping = &Op0FloatValueMapping;
+ break;
case G_SEXT_INREG:
OperandsMapping =
getOperandsMapping({&Op0IntValueMapping, &Op0IntValueMapping, nullptr});
@@ -185,6 +200,9 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
case G_FRAME_INDEX:
OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr});
break;
+ case G_VASTART:
+ OperandsMapping = &Op0IntValueMapping;
+ break;
case G_ZEXT:
case G_ANYEXT:
case G_SEXT:
@@ -233,7 +251,7 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
: WebAssembly::I32Idx];
const LLT Op2Ty = MRI.getType(MI.getOperand(2).getReg());
- unsigned Op2Size = Op1Ty.getSizeInBits();
+ unsigned Op2Size = Op2Ty.getSizeInBits();
auto &Op2IntValueMapping =
WebAssembly::ValueMappings[Op2Size == 64 ? WebAssembly::I64Idx
: WebAssembly::I32Idx];
@@ -247,6 +265,9 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
case G_CONSTANT:
OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr});
break;
+ case G_FCONSTANT:
+ OperandsMapping = getOperandsMapping({&Op0FloatValueMapping, nullptr});
+ break;
case G_IMPLICIT_DEF:
OperandsMapping = &Op0IntValueMapping;
break;
@@ -263,37 +284,140 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
&Op2IntValueMapping});
break;
}
+ case G_FCMP: {
+ const LLT Op2Ty = MRI.getType(MI.getOperand(2).getReg());
+ unsigned Op2Size = Op2Ty.getSizeInBits();
+
+ auto &Op2FloatValueMapping =
+ WebAssembly::ValueMappings[Op2Size == 64 ? WebAssembly::F64Idx
+ : WebAssembly::F32Idx];
+
+ OperandsMapping =
+ getOperandsMapping({&Op0IntValueMapping, nullptr, &Op2FloatValueMapping,
+ &Op2FloatValueMapping});
+ break;
+ }
case G_BRCOND:
OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr});
break;
+ case G_JUMP_TABLE:
+ OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr});
+ break;
+ case G_BRJT:
+ OperandsMapping =
+ getOperandsMapping({&Op0IntValueMapping, nullptr,
+ &WebAssembly::ValueMappings[WebAssembly::I32Idx]});
+ break;
case COPY: {
Register DstReg = MI.getOperand(0).getReg();
Register SrcReg = MI.getOperand(1).getReg();
- // Check if one of the register is not a generic register.
- if ((DstReg.isPhysical() || !MRI.getType(DstReg).isValid()) ||
- (SrcReg.isPhysical() || !MRI.getType(SrcReg).isValid())) {
- const RegisterBank *DstRB = getRegBank(DstReg, MRI, TRI);
- const RegisterBank *SrcRB = getRegBank(SrcReg, MRI, TRI);
- if (!DstRB)
- DstRB = SrcRB;
- else if (!SrcRB)
- SrcRB = DstRB;
- // If both RB are null that means both registers are generic.
- // We shouldn't be here.
- assert(DstRB && SrcRB && "Both RegBank were nullptr");
- TypeSize DstSize = getSizeInBits(DstReg, MRI, TRI);
- TypeSize SrcSize = getSizeInBits(SrcReg, MRI, TRI);
- assert(DstSize == SrcSize &&
- "Trying to copy between different sized regbanks? Why?");
-
- return getInstructionMapping(
- DefaultMappingID, copyCost(*DstRB, *SrcRB, DstSize),
- getCopyMapping(DstRB->getID(), SrcRB->getID(), Size),
- // We only care about the mapping of the destination.
- /*NumOperands*/ 1);
+
+ const RegisterBank *DstRB = getRegBank(DstReg, MRI, TRI);
+ const RegisterBank *SrcRB = getRegBank(SrcReg, MRI, TRI);
+
+ if (!DstRB)
+ DstRB = SrcRB;
+ else if (!SrcRB)
+ SrcRB = DstRB;
+
+ assert(DstRB && SrcRB && "Both RegBanks were nullptr");
+ TypeSize DstSize = getSizeInBits(DstReg, MRI, TRI);
+ TypeSize SrcSize = getSizeInBits(SrcReg, MRI, TRI);
+ assert(DstSize == SrcSize &&
+ "Trying to copy between different sized regbanks? Why?");
+
+ WebAssembly::ValueMappingIdx DstValMappingIdx = WebAssembly::InvalidIdx;
+ switch (DstRB->getID()) {
+ case WebAssembly::I32RegBankID:
+ DstValMappingIdx = WebAssembly::I32Idx;
+ break;
+ case WebAssembly::I64RegBankID:
+ DstValMappingIdx = WebAssembly::I64Idx;
+ break;
+ case WebAssembly::F32RegBankID:
+ DstValMappingIdx = WebAssembly::F32Idx;
+ break;
+ case WebAssembly::F64RegBankID:
+ DstValMappingIdx = WebAssembly::F64Idx;
+ break;
+ default:
+ break;
+ }
+
+ WebAssembly::ValueMappingIdx SrcValMappingIdx = WebAssembly::InvalidIdx;
+ switch (SrcRB->getID()) {
+ case WebAssembly::I32RegBankID:
+ SrcValMappingIdx = WebAssembly::I32Idx;
+ break;
+ case WebAssembly::I64RegBankID:
+ SrcValMappingIdx = WebAssembly::I64Idx;
+ break;
+ case WebAssembly::F32RegBankID:
+ SrcValMappingIdx = WebAssembly::F32Idx;
+ break;
+ case WebAssembly::F64RegBankID:
+ SrcValMappingIdx = WebAssembly::F64Idx;
+ break;
+ default:
+ break;
}
+
+ OperandsMapping =
+ getOperandsMapping({&WebAssembly::ValueMappings[DstValMappingIdx],
+ &WebAssembly::ValueMappings[SrcValMappingIdx]});
+ return getInstructionMapping(
+ MappingID, /*Cost=*/copyCost(*DstRB, *SrcRB, DstSize), OperandsMapping,
+ // We only care about the mapping of the destination for COPY.
+ 1);
+ }
+ case G_SELECT:
+ OperandsMapping = getOperandsMapping(
+ {&Op0IntValueMapping, &WebAssembly::ValueMappings[WebAssembly::I32Idx],
+ &Op0IntValueMapping, &Op0IntValueMapping});
+ break;
+ case G_FPTOSI:
+ case G_FPTOSI_SAT:
+ case G_FPTOUI:
+ case G_FPTOUI_SAT: {
+ const LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg());
+ unsigned Op1Size = Op1Ty.getSizeInBits();
+
+ auto &Op1FloatValueMapping =
+ WebAssembly::ValueMappings[Op1Size == 64 ? WebAssembly::F64Idx
+ : WebAssembly::F32Idx];
+
+ OperandsMapping =
+ getOperandsMapping({&Op0IntValueMapping, &Op1FloatValueMapping});
+ break;
}
+ case G_SITOFP:
+ case G_UITOFP: {
+ const LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg());
+ unsigned Op1Size = Op1Ty.getSizeInBits();
+
+ auto &Op1IntValueMapping =
+ WebAssembly::ValueMappings[Op1Size == 64 ? WebAssembly::I64Idx
+ : WebAssembly::I32Idx];
+
+ OperandsMapping =
+ getOperandsMapping({&Op0FloatValueMapping, &Op1IntValueMapping});
+ break;
}
+ case G_FPEXT:
+ case G_FPTRUNC: {
+ const LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg());
+ unsigned Op1Size = Op1Ty.getSizeInBits();
+
+ auto &Op1FloatValueMapping =
+ WebAssembly::ValueMappings[Op1Size == 64 ? WebAssembly::F64Idx
+ : WebAssembly::F32Idx];
+
+ OperandsMapping =
+ getOperandsMapping({&Op0FloatValueMapping, &Op1FloatValueMapping});
+ break;
+ }
+ }
+
if (!OperandsMapping)
return getInvalidInstructionMapping();
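
A note on the COPY mapping above: the bank-ID to ValueMappings-index
switch is duplicated for DstRB and SrcRB. A shared helper would keep the
two in sync; the following is only a sketch against the enum names
visible in this diff, untested:

static WebAssembly::ValueMappingIdx bankIdToValueMappingIdx(unsigned BankID) {
  switch (BankID) {
  case WebAssembly::I32RegBankID:
    return WebAssembly::I32Idx;
  case WebAssembly::I64RegBankID:
    return WebAssembly::I64Idx;
  case WebAssembly::F32RegBankID:
    return WebAssembly::F32Idx;
  case WebAssembly::F64RegBankID:
    return WebAssembly::F64Idx;
  default:
    return WebAssembly::InvalidIdx; // unknown bank; caller should bail
  }
}

With that, the COPY case becomes two calls plus an InvalidIdx check before
indexing ValueMappings; as written, the code would index with InvalidIdx
if getRegBank ever returns a bank outside the four handled IDs.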