[llvm] [SPIR-V] Rework usage of virtual registers' types and classes (PR #104104)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 14 10:59:21 PDT 2024
https://github.com/VyacheslavLevytskyy created https://github.com/llvm/llvm-project/pull/104104
This PR continues the virtual-register processing changes started in https://github.com/llvm/llvm-project/pull/101732, aimed at improving the correctness of the MIR emitted between passes from the perspective of MachineVerifier.
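For context, one representative piece of the rework is the renamed constant-building helper in SPIRVBuiltins.cpp, where the 32-bit integer type is now requested explicitly instead of being an implicit default. This is a condensed sketch of the hunk in the patch below; the exact signatures and call sites are in the diff:

  // Sketch of the reworked helper (see SPIRVBuiltins.cpp hunk below).
  // The SPIR-V integer type is created explicitly at 32 bits rather than
  // relying on a default BitWidth parameter.
  static Register buildConstantIntReg32(uint64_t Val,
                                        MachineIRBuilder &MIRBuilder,
                                        SPIRVGlobalRegistry *GR) {
    return GR->buildConstantInt(Val, MIRBuilder,
                                GR->getOrCreateSPIRVIntegerType(32, MIRBuilder));
  }

The same direction shows up in the TableGen changes, which collapse the split 32/64-bit register classes (iID/iID64, pID32/pID64, vpID32/vpID64, ANYID) into single 64-bit-typed classes.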
From 6e959d3767f60f929f647555b485ce73e38995a8 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 14 Aug 2024 10:58:03 -0700
Subject: [PATCH] Rework usage of virtual registers' types and classes
---
.../SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp | 3 +-
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 60 ++++++------
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 4 +-
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 97 ++++++++-----------
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 4 +-
llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp | 10 +-
llvm/lib/Target/SPIRV/SPIRVISelLowering.h | 2 +-
llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp | 9 +-
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 64 ++++++------
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 50 +++++-----
llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 14 +--
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 27 +++---
llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td | 46 +++------
llvm/lib/Target/SPIRV/SPIRVRegularizer.cpp | 2 +-
14 files changed, 171 insertions(+), 221 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
index 6dd0df2a104c0f..42567f695395ef 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
@@ -67,7 +67,8 @@ static bool hasType(const MCInst &MI, const MCInstrInfo &MII) {
// Check if we define an ID, and take a type as operand 1.
auto &DefOpInfo = MCDesc.operands()[0];
auto &FirstArgOpInfo = MCDesc.operands()[1];
- return DefOpInfo.RegClass != SPIRV::TYPERegClassID &&
+ return DefOpInfo.RegClass >= 0 && FirstArgOpInfo.RegClass >= 0 &&
+ DefOpInfo.RegClass != SPIRV::TYPERegClassID &&
FirstArgOpInfo.RegClass == SPIRV::TYPERegClassID;
}
return false;
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 8fa5106cef32e9..e98067ed408201 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -443,7 +443,7 @@ static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
if (!DestinationReg.isValid()) {
DestinationReg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
- MRI->setType(DestinationReg, LLT::scalar(32));
+ MRI->setType(DestinationReg, LLT::scalar(64));
GR->assignSPIRVTypeToVReg(BaseType, DestinationReg, MIRBuilder.getMF());
}
// TODO: consider using correct address space and alignment (p0 is canonical
@@ -526,11 +526,11 @@ static SPIRV::Scope::Scope getSPIRVScope(SPIRV::CLMemoryScope ClScope) {
report_fatal_error("Unknown CL memory scope");
}
-static Register buildConstantIntReg(uint64_t Val, MachineIRBuilder &MIRBuilder,
- SPIRVGlobalRegistry *GR,
- unsigned BitWidth = 32) {
- SPIRVType *IntType = GR->getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
- return GR->buildConstantInt(Val, MIRBuilder, IntType);
+static Register buildConstantIntReg32(uint64_t Val,
+ MachineIRBuilder &MIRBuilder,
+ SPIRVGlobalRegistry *GR) {
+ return GR->buildConstantInt(Val, MIRBuilder,
+ GR->getOrCreateSPIRVIntegerType(32, MIRBuilder));
}
static Register buildScopeReg(Register CLScopeRegister,
@@ -548,7 +548,7 @@ static Register buildScopeReg(Register CLScopeRegister,
return CLScopeRegister;
}
}
- return buildConstantIntReg(Scope, MIRBuilder, GR);
+ return buildConstantIntReg32(Scope, MIRBuilder, GR);
}
static Register buildMemSemanticsReg(Register SemanticsRegister,
@@ -568,7 +568,7 @@ static Register buildMemSemanticsReg(Register SemanticsRegister,
return SemanticsRegister;
}
}
- return buildConstantIntReg(Semantics, MIRBuilder, GR);
+ return buildConstantIntReg32(Semantics, MIRBuilder, GR);
}
static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode,
@@ -625,7 +625,7 @@ static bool buildAtomicLoadInst(const SPIRV::IncomingCall *Call,
ScopeRegister = Call->Arguments[1];
MIRBuilder.getMRI()->setRegClass(ScopeRegister, &SPIRV::iIDRegClass);
} else
- ScopeRegister = buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
+ ScopeRegister = buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR);
Register MemSemanticsReg;
if (Call->Arguments.size() > 2) {
@@ -636,7 +636,7 @@ static bool buildAtomicLoadInst(const SPIRV::IncomingCall *Call,
int Semantics =
SPIRV::MemorySemantics::SequentiallyConsistent |
getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
- MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
+ MemSemanticsReg = buildConstantIntReg32(Semantics, MIRBuilder, GR);
}
MIRBuilder.buildInstr(SPIRV::OpAtomicLoad)
@@ -656,13 +656,13 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
Register ScopeRegister =
- buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
+ buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR);
Register PtrRegister = Call->Arguments[0];
MIRBuilder.getMRI()->setRegClass(PtrRegister, &SPIRV::iIDRegClass);
int Semantics =
SPIRV::MemorySemantics::SequentiallyConsistent |
getMemSemanticsForStorageClass(GR->getPointerStorageClass(PtrRegister));
- Register MemSemanticsReg = buildConstantIntReg(Semantics, MIRBuilder, GR);
+ Register MemSemanticsReg = buildConstantIntReg32(Semantics, MIRBuilder, GR);
MIRBuilder.getMRI()->setRegClass(Call->Arguments[1], &SPIRV::iIDRegClass);
MIRBuilder.buildInstr(SPIRV::OpAtomicStore)
.addUse(PtrRegister)
@@ -733,9 +733,9 @@ static bool buildAtomicCompareExchangeInst(
MRI->setRegClass(Call->Arguments[4], &SPIRV::iIDRegClass);
}
if (!MemSemEqualReg.isValid())
- MemSemEqualReg = buildConstantIntReg(MemSemEqual, MIRBuilder, GR);
+ MemSemEqualReg = buildConstantIntReg32(MemSemEqual, MIRBuilder, GR);
if (!MemSemUnequalReg.isValid())
- MemSemUnequalReg = buildConstantIntReg(MemSemUnequal, MIRBuilder, GR);
+ MemSemUnequalReg = buildConstantIntReg32(MemSemUnequal, MIRBuilder, GR);
Register ScopeReg;
auto Scope = IsCmpxchg ? SPIRV::Scope::Workgroup : SPIRV::Scope::Device;
@@ -750,12 +750,12 @@ static bool buildAtomicCompareExchangeInst(
MRI->setRegClass(Call->Arguments[5], &SPIRV::iIDRegClass);
}
if (!ScopeReg.isValid())
- ScopeReg = buildConstantIntReg(Scope, MIRBuilder, GR);
+ ScopeReg = buildConstantIntReg32(Scope, MIRBuilder, GR);
Register Expected = IsCmpxchg
? ExpectedArg
: buildLoadInst(SpvDesiredTy, ExpectedArg, MIRBuilder,
- GR, LLT::scalar(32));
+ GR, LLT::scalar(64));
MRI->setType(Expected, DesiredLLT);
Register Tmp = !IsCmpxchg ? MRI->createGenericVirtualRegister(DesiredLLT)
: Call->ReturnRegister;
@@ -941,7 +941,7 @@ static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
MemSemanticsReg = Call->Arguments[0];
MRI->setRegClass(MemSemanticsReg, &SPIRV::iIDRegClass);
} else
- MemSemanticsReg = buildConstantIntReg(MemSemantics, MIRBuilder, GR);
+ MemSemanticsReg = buildConstantIntReg32(MemSemantics, MIRBuilder, GR);
Register ScopeReg;
SPIRV::Scope::Scope Scope = SPIRV::Scope::Workgroup;
@@ -967,11 +967,11 @@ static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
}
if (!ScopeReg.isValid())
- ScopeReg = buildConstantIntReg(Scope, MIRBuilder, GR);
+ ScopeReg = buildConstantIntReg32(Scope, MIRBuilder, GR);
auto MIB = MIRBuilder.buildInstr(Opcode).addUse(ScopeReg);
if (Opcode != SPIRV::OpMemoryBarrier)
- MIB.addUse(buildConstantIntReg(MemScope, MIRBuilder, GR));
+ MIB.addUse(buildConstantIntReg32(MemScope, MIRBuilder, GR));
MIB.addUse(MemSemanticsReg);
return true;
}
@@ -1133,7 +1133,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
auto Scope = Builtin->Name.starts_with("sub_group") ? SPIRV::Scope::Subgroup
: SPIRV::Scope::Workgroup;
- Register ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);
+ Register ScopeRegister = buildConstantIntReg32(Scope, MIRBuilder, GR);
// Build work/sub group instruction.
auto MIB = MIRBuilder.buildInstr(GroupBuiltin->Opcode)
@@ -1302,7 +1302,7 @@ static bool generateKernelClockInst(const SPIRV::IncomingCall *Call,
.EndsWith("device", SPIRV::Scope::Scope::Device)
.EndsWith("work_group", SPIRV::Scope::Scope::Workgroup)
.EndsWith("sub_group", SPIRV::Scope::Scope::Subgroup);
- Register ScopeReg = buildConstantIntReg(ScopeArg, MIRBuilder, GR);
+ Register ScopeReg = buildConstantIntReg32(ScopeArg, MIRBuilder, GR);
MIRBuilder.buildInstr(SPIRV::OpReadClockKHR)
.addDef(ResultReg)
@@ -1617,7 +1617,7 @@ static bool generateImageSizeQueryInst(const SPIRV::IncomingCall *Call,
.addUse(GR->getSPIRVTypeID(QueryResultType))
.addUse(Call->Arguments[0]);
if (!IsDimBuf)
- MIB.addUse(buildConstantIntReg(0, MIRBuilder, GR)); // Lod id.
+ MIB.addUse(buildConstantIntReg32(0, MIRBuilder, GR)); // Lod id.
if (NumExpectedRetComponents == NumActualRetComponents)
return true;
if (NumExpectedRetComponents == 1) {
@@ -2110,8 +2110,8 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
GEPInst
.addImm(GepMI->getOperand(2).getImm()) // In bound.
.addUse(ArrayMI->getOperand(0).getReg()) // Alloca.
- .addUse(buildConstantIntReg(0, MIRBuilder, GR)) // Indices.
- .addUse(buildConstantIntReg(I, MIRBuilder, GR));
+ .addUse(buildConstantIntReg32(0, MIRBuilder, GR)) // Indices.
+ .addUse(buildConstantIntReg32(I, MIRBuilder, GR));
LocalSizes.push_back(Reg);
}
}
@@ -2128,7 +2128,7 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
// If there are no event arguments in the original call, add dummy ones.
if (!HasEvents) {
- MIB.addUse(buildConstantIntReg(0, MIRBuilder, GR)); // Dummy num events.
+ MIB.addUse(buildConstantIntReg32(0, MIRBuilder, GR)); // Dummy num events.
Register NullPtr = GR->getOrCreateConstNullPtr(
MIRBuilder, getOrCreateSPIRVDeviceEventPointer(MIRBuilder, GR));
MIB.addUse(NullPtr); // Dummy wait events.
@@ -2147,10 +2147,10 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
// TODO: these numbers should be obtained from block literal structure.
// Param Size: Size of block literal structure.
- MIB.addUse(buildConstantIntReg(DL.getTypeStoreSize(PType), MIRBuilder, GR));
+ MIB.addUse(buildConstantIntReg32(DL.getTypeStoreSize(PType), MIRBuilder, GR));
// Param Aligment: Aligment of block literal structure.
MIB.addUse(
- buildConstantIntReg(DL.getPrefTypeAlign(PType).value(), MIRBuilder, GR));
+ buildConstantIntReg32(DL.getPrefTypeAlign(PType).value(), MIRBuilder, GR));
for (unsigned i = 0; i < LocalSizes.size(); i++)
MIB.addUse(LocalSizes[i]);
@@ -2218,7 +2218,7 @@ static bool generateAsyncCopy(const SPIRV::IncomingCall *Call,
return buildOpFromWrapper(MIRBuilder, Opcode, Call,
IsSet ? TypeReg : Register(0));
- auto Scope = buildConstantIntReg(SPIRV::Scope::Workgroup, MIRBuilder, GR);
+ auto Scope = buildConstantIntReg32(SPIRV::Scope::Workgroup, MIRBuilder, GR);
switch (Opcode) {
case SPIRV::OpGroupAsyncCopy: {
@@ -2238,7 +2238,7 @@ static bool generateAsyncCopy(const SPIRV::IncomingCall *Call,
.addUse(Call->Arguments[2])
.addUse(Call->Arguments.size() > 4
? Call->Arguments[3]
- : buildConstantIntReg(1, MIRBuilder, GR))
+ : buildConstantIntReg32(1, MIRBuilder, GR))
.addUse(EventReg);
if (NewType != nullptr)
insertAssignInstr(Call->ReturnRegister, nullptr, NewType, GR, MIRBuilder,
@@ -2513,7 +2513,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
MIRBuilder.getMRI()->setRegClass(ReturnRegister, &SPIRV::iIDRegClass);
} else if (OrigRetTy && OrigRetTy->isVoidTy()) {
ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
- MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(32));
+ MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(64));
ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 316abe866a163c..26d84c8efb10ba 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -371,7 +371,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
}
auto MRI = MIRBuilder.getMRI();
- Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+ Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
MRI->setRegClass(FuncVReg, &SPIRV::iIDRegClass);
if (F.isDeclaration())
GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
@@ -557,7 +557,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
for (const Argument &Arg : CF->args()) {
if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
continue; // Don't handle zero sized types.
- Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+ Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(64));
MRI->setRegClass(Reg, &SPIRV::iIDRegClass);
ToInsert.push_back({Reg});
VRegArgs.push_back(ToInsert.back());
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 6702a0efc638ae..621b6bdf71cd06 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -72,17 +72,14 @@ void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
VRegToTypeMap[&MF][VReg] = SpirvType;
}
-static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
- auto &MRI = MIRBuilder.getMF().getRegInfo();
- auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
+static Register createTypeVReg(MachineRegisterInfo &MRI) {
+ auto Res = MRI.createGenericVirtualRegister(LLT::scalar(64));
MRI.setRegClass(Res, &SPIRV::TYPERegClass);
return Res;
}
-static Register createTypeVReg(MachineRegisterInfo &MRI) {
- auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
- MRI.setRegClass(Res, &SPIRV::TYPERegClass);
- return Res;
+inline Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
+ return createTypeVReg(MIRBuilder.getMF().getRegInfo());
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
@@ -157,26 +154,24 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
return MIB;
}
-std::tuple<Register, ConstantInt *, bool>
+std::tuple<Register, ConstantInt *, bool, unsigned>
SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
MachineIRBuilder *MIRBuilder,
MachineInstr *I,
const SPIRVInstrInfo *TII) {
- const IntegerType *LLVMIntTy;
- if (SpvType)
- LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
- else
- LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext());
+ assert(SpvType);
+ const IntegerType *LLVMIntTy =
+ cast<IntegerType>(getTypeForSPIRVType(SpvType));
+ unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
bool NewInstr = false;
// Find a constant in DT or build a new one.
ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
Register Res = DT.find(CI, CurMF);
if (!Res.isValid()) {
- unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
// TODO: handle cases where the type is not 32bit wide
// TODO: https://github.com/llvm/llvm-project/issues/88129
- LLT LLTy = LLT::scalar(32);
- Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+ Res =
+ CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
if (MIRBuilder)
assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
@@ -185,7 +180,7 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
DT.add(CI, CurMF, Res);
NewInstr = true;
}
- return std::make_tuple(Res, CI, NewInstr);
+ return std::make_tuple(Res, CI, NewInstr, BitWidth);
}
std::tuple<Register, ConstantFP *, bool, unsigned>
@@ -193,27 +188,19 @@ SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
MachineIRBuilder *MIRBuilder,
MachineInstr *I,
const SPIRVInstrInfo *TII) {
- const Type *LLVMFloatTy;
+ assert(SpvType);
LLVMContext &Ctx = CurMF->getFunction().getContext();
- unsigned BitWidth = 32;
- if (SpvType)
- LLVMFloatTy = getTypeForSPIRVType(SpvType);
- else {
- LLVMFloatTy = Type::getFloatTy(Ctx);
- if (MIRBuilder)
- SpvType = getOrCreateSPIRVType(LLVMFloatTy, *MIRBuilder);
- }
+ const Type *LLVMFloatTy = getTypeForSPIRVType(SpvType);
+ unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
bool NewInstr = false;
// Find a constant in DT or build a new one.
auto *const CI = ConstantFP::get(Ctx, Val);
Register Res = DT.find(CI, CurMF);
if (!Res.isValid()) {
- if (SpvType)
- BitWidth = getScalarOrVectorBitWidth(SpvType);
// TODO: handle cases where the type is not 32bit wide
// TODO: https://github.com/llvm/llvm-project/issues/88129
- LLT LLTy = LLT::scalar(32);
- Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+ Res =
+ CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
CurMF->getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
if (MIRBuilder)
assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder);
@@ -269,7 +256,8 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
ConstantInt *CI;
Register Res;
bool New;
- std::tie(Res, CI, New) =
+ unsigned BitWidth;
+ std::tie(Res, CI, New, BitWidth) =
getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);
// If we have found Res register which is defined by the passed G_CONSTANT
// machine instruction, a new constant instruction should be created.
@@ -281,7 +269,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
- addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB);
+ addNumImm(APInt(BitWidth, Val), MIB);
} else {
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
.addDef(Res)
@@ -297,19 +285,16 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType,
bool EmitIR) {
+ assert(SpvType);
auto &MF = MIRBuilder.getMF();
- const IntegerType *LLVMIntTy;
- if (SpvType)
- LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
- else
- LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
+ const IntegerType *LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
// Find a constant in DT or build a new one.
const auto ConstInt =
ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
Register Res = DT.find(ConstInt, &MF);
if (!Res.isValid()) {
- unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
- LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
+ unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
+ LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32); // lev
Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
MF.getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
@@ -318,18 +303,17 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
if (EmitIR) {
MIRBuilder.buildConstant(Res, *ConstInt);
} else {
- if (!SpvType)
- SpvType = getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
+ Register SpvTypeReg = getSPIRVTypeID(SpvType);
MachineInstrBuilder MIB;
if (Val) {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
.addDef(Res)
- .addUse(getSPIRVTypeID(SpvType));
+ .addUse(SpvTypeReg);
addNumImm(APInt(BitWidth, Val), MIB);
} else {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(Res)
- .addUse(getSPIRVTypeID(SpvType));
+ .addUse(SpvTypeReg);
}
const auto &Subtarget = CurMF->getSubtarget();
constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
@@ -353,7 +337,8 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
const auto ConstFP = ConstantFP::get(Ctx, Val);
Register Res = DT.find(ConstFP, &MF);
if (!Res.isValid()) {
- Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(32));
+ Res = MF.getRegInfo().createGenericVirtualRegister(
+ LLT::scalar(getScalarOrVectorBitWidth(SpvType)));
MF.getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
assignSPIRVTypeToVReg(SpvType, Res, MF);
DT.add(ConstFP, &MF, Res);
@@ -407,7 +392,7 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
// TODO: handle cases where the type is not 32bit wide
// TODO: https://github.com/llvm/llvm-project/issues/88129
- LLT LLTy = LLT::scalar(32);
+ LLT LLTy = LLT::scalar(64);
Register SpvVecConst =
CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::iIDRegClass);
@@ -509,7 +494,7 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);
}
- LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32);
+ LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(64);
Register SpvVecConst =
CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::iIDRegClass);
@@ -650,7 +635,7 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
// Set to Reg the same type as ResVReg has.
auto MRI = MIRBuilder.getMRI();
- assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
+// assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
if (Reg != ResVReg) {
LLT RegLLTy =
LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), getPointerSize());
@@ -706,8 +691,9 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
bool EmitIR) {
assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
"Invalid array element type");
+ SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
Register NumElementsVReg =
- buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);
+ buildConstantInt(NumElems, MIRBuilder, SpvTypeInt32, EmitIR);
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
.addDef(createTypeVReg(MIRBuilder))
.addUse(getSPIRVTypeID(ElemType))
@@ -1188,14 +1174,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
if (ResVReg.isValid())
return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
ResVReg = createTypeVReg(MIRBuilder);
+ SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
SPIRVType *SpirvTy =
MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
.addDef(ResVReg)
.addUse(getSPIRVTypeID(ElemType))
- .addUse(buildConstantInt(Scope, MIRBuilder, nullptr, true))
- .addUse(buildConstantInt(Rows, MIRBuilder, nullptr, true))
- .addUse(buildConstantInt(Columns, MIRBuilder, nullptr, true))
- .addUse(buildConstantInt(Use, MIRBuilder, nullptr, true));
+ .addUse(buildConstantInt(Scope, MIRBuilder, SpvTypeInt32, true))
+ .addUse(buildConstantInt(Rows, MIRBuilder, SpvTypeInt32, true))
+ .addUse(buildConstantInt(Columns, MIRBuilder, SpvTypeInt32, true))
+ .addUse(buildConstantInt(Use, MIRBuilder, SpvTypeInt32, true));
DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg);
return SpirvTy;
}
@@ -1386,8 +1373,8 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
if (Reg.isValid())
return getSPIRVTypeForVReg(Reg);
MachineBasicBlock &BB = *I.getParent();
- SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(32, I, TII);
- Register Len = getOrCreateConstInt(NumElements, I, SpirvType, TII);
+ SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, I, TII);
+ Register Len = getOrCreateConstInt(NumElements, I, SpvTypeInt32, TII);
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addUse(getSPIRVTypeID(BaseType))
@@ -1436,7 +1423,7 @@ Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
Register Res = DT.find(UV, CurMF);
if (Res.isValid())
return Res;
- LLT LLTy = LLT::scalar(32);
+ LLT LLTy = LLT::scalar(64);
Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 821c1218fcb7f0..290c7c25831ddb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -430,7 +430,7 @@ class SPIRVGlobalRegistry {
getOrCreateSpecialType(const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual);
- std::tuple<Register, ConstantInt *, bool> getOrCreateConstIntReg(
+ std::tuple<Register, ConstantInt *, bool, unsigned> getOrCreateConstIntReg(
uint64_t Val, SPIRVType *SpvType, MachineIRBuilder *MIRBuilder,
MachineInstr *I = nullptr, const SPIRVInstrInfo *TII = nullptr);
std::tuple<Register, ConstantFP *, bool, unsigned> getOrCreateConstFloatReg(
@@ -455,7 +455,7 @@ class SPIRVGlobalRegistry {
public:
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
- SPIRVType *SpvType = nullptr, bool EmitIR = true);
+ SPIRVType *SpvType, bool EmitIR = true);
Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
bool ZeroAsNull = true);
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 76419faa38e090..8db9808bb87e1d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -91,13 +91,9 @@ SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
return std::make_pair(0u, RC);
if (VT.isFloatingPoint())
- RC = VT.isVector() ? &SPIRV::vfIDRegClass
- : (VT.getScalarSizeInBits() > 32 ? &SPIRV::fID64RegClass
- : &SPIRV::fIDRegClass);
+ RC = VT.isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass;
else if (VT.isInteger())
- RC = VT.isVector() ? &SPIRV::vIDRegClass
- : (VT.getScalarSizeInBits() > 32 ? &SPIRV::iID64RegClass
- : &SPIRV::iIDRegClass);
+ RC = VT.isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass;
else
RC = &SPIRV::iIDRegClass;
@@ -115,7 +111,7 @@ static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
SPIRVGlobalRegistry &GR, MachineInstr &I,
Register OpReg, unsigned OpIdx,
SPIRVType *NewPtrType) {
- Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+ Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
MachineIRBuilder MIB(I);
bool Res = MIB.buildInstr(SPIRV::OpBitcast)
.addDef(NewReg)
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
index 77356b7512a739..37a2f4aa99c034 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
@@ -44,7 +44,7 @@ class SPIRVTargetLowering : public TargetLowering {
// This is to prevent sexts of non-i64 vector indices which are generated
// within general IRTranslator hence type generation for it is omitted.
MVT getVectorIdxTy(const DataLayout &DL) const override {
- return MVT::getIntegerVT(32);
+ return MVT::getIntegerVT(32); // lev
}
unsigned getNumRegistersForCallingConv(LLVMContext &Context,
CallingConv::ID CC,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
index 12cf7613a45cf3..dac7640cdddd69 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
@@ -256,12 +256,9 @@ void SPIRVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
}
bool SPIRVInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
- if (MI.getOpcode() == SPIRV::GET_ID || MI.getOpcode() == SPIRV::GET_ID64 ||
- MI.getOpcode() == SPIRV::GET_fID || MI.getOpcode() == SPIRV::GET_fID64 ||
- MI.getOpcode() == SPIRV::GET_pID32 ||
- MI.getOpcode() == SPIRV::GET_pID64 || MI.getOpcode() == SPIRV::GET_vfID ||
- MI.getOpcode() == SPIRV::GET_vID || MI.getOpcode() == SPIRV::GET_vpID32 ||
- MI.getOpcode() == SPIRV::GET_vpID64) {
+ if (MI.getOpcode() == SPIRV::GET_ID || MI.getOpcode() == SPIRV::GET_fID ||
+ MI.getOpcode() == SPIRV::GET_pID || MI.getOpcode() == SPIRV::GET_vfID ||
+ MI.getOpcode() == SPIRV::GET_vID || MI.getOpcode() == SPIRV::GET_vpID) {
auto &MRI = MI.getMF()->getRegInfo();
MRI.replaceRegWith(MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
MI.eraseFromParent();
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index c4b09dd6bfe430..2298011d0d656a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -15,18 +15,14 @@ include "SPIRVSymbolicOperands.td"
// Codegen only metadata instructions
let isCodeGenOnly=1 in {
- def ASSIGN_TYPE: Pseudo<(outs ANYID:$dst_id), (ins ANYID:$src_id, TYPE:$src_ty)>;
- def DECL_TYPE: Pseudo<(outs ANYID:$dst_id), (ins ANYID:$src_id, TYPE:$src_ty)>;
- def GET_ID: Pseudo<(outs iID:$dst_id), (ins ANYID:$src)>;
- def GET_ID64: Pseudo<(outs iID64:$dst_id), (ins ANYID:$src)>;
- def GET_fID: Pseudo<(outs fID:$dst_id), (ins ANYID:$src)>;
- def GET_fID64: Pseudo<(outs fID64:$dst_id), (ins ANYID:$src)>;
- def GET_pID32: Pseudo<(outs pID32:$dst_id), (ins ANYID:$src)>;
- def GET_pID64: Pseudo<(outs pID64:$dst_id), (ins ANYID:$src)>;
- def GET_vID: Pseudo<(outs vID:$dst_id), (ins ANYID:$src)>;
- def GET_vfID: Pseudo<(outs vfID:$dst_id), (ins ANYID:$src)>;
- def GET_vpID32: Pseudo<(outs vpID32:$dst_id), (ins ANYID:$src)>;
- def GET_vpID64: Pseudo<(outs vpID64:$dst_id), (ins ANYID:$src)>;
+ def ASSIGN_TYPE: Pseudo<(outs ID:$dst_id), (ins ID:$src_id, TYPE:$src_ty)>;
+ def DECL_TYPE: Pseudo<(outs ID:$dst_id), (ins ID:$src_id, TYPE:$src_ty)>;
+ def GET_ID: Pseudo<(outs iID:$dst_id), (ins ID:$src)>;
+ def GET_fID: Pseudo<(outs fID:$dst_id), (ins ID:$src)>;
+ def GET_pID: Pseudo<(outs pID:$dst_id), (ins ID:$src)>;
+ def GET_vID: Pseudo<(outs vID:$dst_id), (ins ID:$src)>;
+ def GET_vfID: Pseudo<(outs vfID:$dst_id), (ins ID:$src)>;
+ def GET_vpID: Pseudo<(outs vpID:$dst_id), (ins ID:$src)>;
}
def SPVTypeBin : SDTypeProfile<1, 2, []>;
@@ -36,16 +32,18 @@ def assigntype : SDNode<"SPIRVISD::AssignType", SPVTypeBin>;
def : GINodeEquiv<ASSIGN_TYPE, assigntype>;
class BinOp<string name, bits<16> opCode, list<dag> pattern=[]>
- : Op<opCode, (outs ANYID:$dst), (ins TYPE:$src_ty, ANYID:$src, ANYID:$src2),
+ : Op<opCode, (outs ID:$dst), (ins TYPE:$src_ty, ID:$src, ID:$src2),
"$dst = "#name#" $src_ty $src $src2", pattern>;
class BinOpTyped<string name, bits<16> opCode, RegisterClass CID, SDNode node>
: Op<opCode, (outs iID:$dst), (ins TYPE:$src_ty, CID:$src, CID:$src2),
- "$dst = "#name#" $src_ty $src $src2", [(set iID:$dst, (assigntype (node CID:$src, CID:$src2), TYPE:$src_ty))]>;
+ "$dst = "#name#" $src_ty $src $src2",
+ [(set iID:$dst, (assigntype (node CID:$src, CID:$src2), TYPE:$src_ty))]>;
class TernOpTyped<string name, bits<16> opCode, RegisterClass CCond, RegisterClass CID, SDNode node>
: Op<opCode, (outs iID:$dst), (ins TYPE:$src_ty, CCond:$cond, CID:$src1, CID:$src2),
- "$dst = "#name#" $src_ty $cond $src1 $src2", [(set iID:$dst, (assigntype (node CCond:$cond, CID:$src1, CID:$src2), TYPE:$src_ty))]>;
+ "$dst = "#name#" $src_ty $cond $src1 $src2",
+ [(set iID:$dst, (assigntype (node CCond:$cond, CID:$src1, CID:$src2), TYPE:$src_ty))]>;
multiclass BinOpTypedGen<string name, bits<16> opCode, SDNode node, bit genF = 0, bit genV = 0> {
if genF then
@@ -70,10 +68,8 @@ multiclass TernOpTypedGen<string name, bits<16> opCode, SDNode node, bit genP =
def SIVCond: TernOpTyped<name, opCode, vID, iID, node>;
}
if genP then {
- def SPSCond32: TernOpTyped<name, opCode, iID, pID32, node>;
- def SPVCond32: TernOpTyped<name, opCode, vID, pID32, node>;
- def SPSCond64: TernOpTyped<name, opCode, iID, pID64, node>;
- def SPVCond64: TernOpTyped<name, opCode, vID, pID64, node>;
+ def SPSCond: TernOpTyped<name, opCode, iID, pID, node>;
+ def SPVCond: TernOpTyped<name, opCode, vID, pID, node>;
}
if genV then {
if genF then {
@@ -85,16 +81,14 @@ multiclass TernOpTypedGen<string name, bits<16> opCode, SDNode node, bit genP =
def VIVCond: TernOpTyped<name, opCode, vID, vID, node>;
}
if genP then {
- def VPSCond32: TernOpTyped<name, opCode, iID, vpID32, node>;
- def VPVCond32: TernOpTyped<name, opCode, vID, vpID32, node>;
- def VPSCond64: TernOpTyped<name, opCode, iID, vpID64, node>;
- def VPVCond64: TernOpTyped<name, opCode, vID, vpID64, node>;
+ def VPSCond: TernOpTyped<name, opCode, iID, vpID, node>;
+ def VPVCond: TernOpTyped<name, opCode, vID, vpID, node>;
}
}
}
class UnOp<string name, bits<16> opCode, list<dag> pattern=[]>
- : Op<opCode, (outs ANYID:$dst), (ins TYPE:$type, ANYID:$src),
+ : Op<opCode, (outs ID:$dst), (ins TYPE:$type, ID:$src),
"$dst = "#name#" $type $src", pattern>;
class UnOpTyped<string name, bits<16> opCode, RegisterClass CID, SDNode node>
: Op<opCode, (outs iID:$dst), (ins TYPE:$src_ty, CID:$src),
@@ -222,21 +216,21 @@ return CurDAG->getTargetConstant(
N->getValueAP().bitcastToAPInt().getZExtValue(), SDLoc(N), MVT::i32);
}]>;
-def fimm_to_i32 : SDNodeXForm<imm, [{
+def fimm_to_i64 : SDNodeXForm<imm, [{
return CurDAG->getTargetConstant(
- N->getValueAPF().bitcastToAPInt().getZExtValue(), SDLoc(N), MVT::i32);
+ N->getValueAPF().bitcastToAPInt().getZExtValue(), SDLoc(N), MVT::i64);
}]>;
-def gi_bitcast_fimm_to_i32 : GICustomOperandRenderer<"renderFImm32">,
- GISDNodeXFormEquiv<fimm_to_i32>;
+def gi_bitcast_fimm_to_i64 : GICustomOperandRenderer<"renderFImm64">,
+ GISDNodeXFormEquiv<fimm_to_i64>;
def gi_bitcast_imm_to_i32 : GICustomOperandRenderer<"renderImm32">,
GISDNodeXFormEquiv<imm_to_i32>;
-def PseudoConstI: IntImmLeaf<i32, [{ return Imm.getBitWidth() <= 32; }], imm_to_i32>;
-def PseudoConstF: FPImmLeaf<f32, [{ return true; }], fimm_to_i32>;
-def ConstPseudoTrue: IntImmLeaf<i32, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 1; }]>;
-def ConstPseudoFalse: IntImmLeaf<i32, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 0; }]>;
+def PseudoConstI: IntImmLeaf<i64, [{ return Imm.getBitWidth() <= 32; }], imm_to_i32>;
+def PseudoConstF: FPImmLeaf<f64, [{ return true; }], fimm_to_i64>;
+def ConstPseudoTrue: IntImmLeaf<i64, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 1; }]>;
+def ConstPseudoFalse: IntImmLeaf<i64, [{ return Imm.getBitWidth() == 1 && Imm.getZExtValue() == 0; }]>;
def ConstPseudoNull: IntImmLeaf<i64, [{ return Imm.isZero(); }]>;
multiclass IntFPImm<bits<16> opCode, string name> {
@@ -634,7 +628,7 @@ let isTerminator=1 in {
let isReturn = 1, hasDelaySlot=0, isBarrier = 0, isTerminator=1, isNotDuplicable = 1 in {
def OpKill: SimpleOp<"OpKill", 252>;
def OpReturn: SimpleOp<"OpReturn", 253>;
- def OpReturnValue: Op<254, (outs), (ins ANYID:$ret), "OpReturnValue $ret">;
+ def OpReturnValue: Op<254, (outs), (ins ID:$ret), "OpReturnValue $ret">;
def OpUnreachable: SimpleOp<"OpUnreachable", 255>;
}
def OpLifetimeStart: Op<256, (outs), (ins ID:$ptr, i32imm:$sz), "OpLifetimeStart $ptr, $sz">;
@@ -862,9 +856,9 @@ def OpGroupLogicalXorKHR: Op<6408, (outs ID:$res), (ins TYPE:$type, ID:$scope, i
"$res = OpGroupLogicalXorKHR $type $scope $groupOp $value">;
// Inline Assembly Instructions
-def OpAsmTargetINTEL: Op<5609, (outs ID:$res), (ins StringImm:$str), "$res = OpAsmTargetINTEL $str">;
+def OpAsmTargetINTEL: Op<5609, (outs ID:$res), (ins StringImm:$str, variable_ops), "$res = OpAsmTargetINTEL $str">;
def OpAsmINTEL: Op<5610, (outs ID:$res), (ins TYPE:$type, TYPE:$asm_type, ID:$target,
- StringImm:$asm, StringImm:$constraints),
+ StringImm:$asm, StringImm:$constraints, variable_ops),
"$res = OpAsmINTEL $type $asm_type $target $asm">;
def OpAsmCallINTEL: Op<5611, (outs ID:$res), (ins TYPE:$type, ID:$asm, variable_ops),
"$res = OpAsmCallINTEL $type $asm">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index c55235a04a607f..ee8bbbdf203615 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -183,7 +183,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
int OpIdx) const;
- void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
+ void renderFImm64(MachineInstrBuilder &MIB, const MachineInstr &I,
int OpIdx) const;
bool selectConst(Register ResVReg, const SPIRVType *ResType, const APInt &Imm,
@@ -295,14 +295,8 @@ void SPIRVInstructionSelector::resetVRegsType(MachineFunction &MF) {
HasVRegsReset = &MF;
MachineRegisterInfo &MRI = MF.getRegInfo();
- for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
- Register Reg = Register::index2VirtReg(I);
- LLT Ty = MRI.getType(Reg);
- if (Ty.isScalar())
- MRI.setType(Reg, LLT::scalar(32));
- else if (Ty.isVector() && !Ty.isPointer())
- MRI.setType(Reg, LLT::scalar(32));
- }
+ for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I)
+ MRI.setType(Register::index2VirtReg(I), LLT::scalar(64));
for (const auto &MBB : MF) {
for (const auto &MI : MBB) {
if (MI.getOpcode() != SPIRV::ASSIGN_TYPE)
@@ -341,9 +335,13 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
Register SrcReg = I.getOperand(1).getReg();
auto *Def = MRI->getVRegDef(SrcReg);
if (isTypeFoldingSupported(Def->getOpcode())) {
- if (MRI->getType(DstReg).isPointer())
- MRI->setType(DstReg, LLT::scalar(32));
bool Res = selectImpl(I, *CoverageInfo);
+ LLVM_DEBUG({
+ if (!Res && Def->getOpcode() != TargetOpcode::G_CONSTANT) {
+ dbgs() << "Unexpected pattern in ASSIGN_TYPE.\nInstruction: ";
+ I.print(dbgs());
+ }
+ });
assert(Res || Def->getOpcode() == TargetOpcode::G_CONSTANT);
if (Res)
return Res;
@@ -353,8 +351,8 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
I.removeFromParent();
return true;
} else if (I.getNumDefs() == 1) {
- // Make all vregs 32 bits (for SPIR-V IDs).
- MRI->setType(I.getOperand(0).getReg(), LLT::scalar(32));
+ // Make all vregs 64 bits (for SPIR-V IDs).
+ MRI->setType(I.getOperand(0).getReg(), LLT::scalar(64));
}
return constrainSelectedInstRegOperands(I, TII, TRI, RBI);
}
@@ -371,9 +369,9 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
SPIRVType *ResType = HasDefs ? GR.getSPIRVTypeForVReg(ResVReg) : nullptr;
assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
if (spvSelect(ResVReg, ResType, I)) {
- if (HasDefs) // Make all vregs 32 bits (for SPIR-V IDs).
+ if (HasDefs) // Make all vregs 64 bits (for SPIR-V IDs).
for (unsigned i = 0; i < I.getNumDefs(); ++i)
- MRI->setType(I.getOperand(i).getReg(), LLT::scalar(32));
+ MRI->setType(I.getOperand(i).getReg(), LLT::scalar(64));
I.removeFromParent();
return true;
}
@@ -899,7 +897,7 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
GlobalVariable *GV = new GlobalVariable(*CurFunction.getParent(), LLVMArrTy,
true, GlobalValue::InternalLinkage,
Constant::getNullValue(LLVMArrTy));
- Register VarReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+ Register VarReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
GR.add(GV, GR.CurMF, VarReg);
buildOpDecorate(VarReg, I, TII, SPIRV::Decoration::Constant, {});
@@ -911,7 +909,7 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
SPIRVType *SourceTy = GR.getOrCreateSPIRVPointerType(
ValTy, I, TII, SPIRV::StorageClass::UniformConstant);
- SrcReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+ SrcReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
selectUnOpWithSrc(SrcReg, SourceTy, I, VarReg, SPIRV::OpBitcast);
}
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCopyMemorySized))
@@ -1617,7 +1615,7 @@ bool SPIRVInstructionSelector::selectICmp(Register ResVReg,
return selectCmp(ResVReg, ResType, CmpOpc, I);
}
-void SPIRVInstructionSelector::renderFImm32(MachineInstrBuilder &MIB,
+void SPIRVInstructionSelector::renderFImm64(MachineInstrBuilder &MIB,
const MachineInstr &I,
int OpIdx) const {
assert(I.getOpcode() == TargetOpcode::G_FCONSTANT && OpIdx == -1 &&
@@ -1637,14 +1635,14 @@ void SPIRVInstructionSelector::renderImm32(MachineInstrBuilder &MIB,
Register
SPIRVInstructionSelector::buildI32Constant(uint32_t Val, MachineInstr &I,
const SPIRVType *ResType) const {
- Type *LLVMTy = IntegerType::get(GR.CurMF->getFunction().getContext(), 32);
+ Type *LLVMTy = IntegerType::get(GR.CurMF->getFunction().getContext(), 32); // lev
const SPIRVType *SpvI32Ty =
ResType ? ResType : GR.getOrCreateSPIRVIntegerType(32, I, TII);
// Find a constant in DT or build a new one.
auto ConstInt = ConstantInt::get(LLVMTy, Val);
Register NewReg = GR.find(ConstInt, GR.CurMF);
if (!NewReg.isValid()) {
- NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+ NewReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
GR.add(ConstInt, GR.CurMF, NewReg);
MachineInstr *MI;
MachineBasicBlock &BB = *I.getParent();
@@ -1844,7 +1842,7 @@ bool SPIRVInstructionSelector::selectConst(Register ResVReg,
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType));
// <=32-bit integers should be caught by the sdag pattern.
- assert(Imm.getBitWidth() > 32);
+ assert(Imm.getBitWidth() > 32); // lev
addNumImm(Imm, MIB);
return MIB.constrainAllUses(TII, TRI, RBI);
}
@@ -1994,7 +1992,7 @@ bool SPIRVInstructionSelector::wrapIntoSpecConstantOp(
GR.add(OpDefine, MF, WrapReg);
CompositeArgs.push_back(WrapReg);
// Decorate the wrapper register and generate a new instruction
- MRI->setType(WrapReg, LLT::pointer(0, 32));
+ MRI->setType(WrapReg, LLT::pointer(0, 64));
GR.assignSPIRVTypeToVReg(OpType, WrapReg, *MF);
MachineBasicBlock &BB = *I.getParent();
Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSpecConstantOp))
@@ -2301,7 +2299,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
// registers without a definition. We will resolve it later, during
// module analysis stage.
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
- Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+ Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
MRI->setRegClass(FuncVReg, &SPIRV::iIDRegClass);
MachineInstrBuilder MB =
BuildMI(BB, I, I.getDebugLoc(),
@@ -2412,7 +2410,7 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
// 93 ThreadId reads the thread ID
MachineIRBuilder MIRBuilder(I);
- const SPIRVType *U32Type = GR.getOrCreateSPIRVIntegerType(32, MIRBuilder);
+ const SPIRVType *U32Type = GR.getOrCreateSPIRVIntegerType(32, MIRBuilder); // lev
const SPIRVType *Vec3Ty =
GR.getOrCreateSPIRVVectorType(U32Type, 3, MIRBuilder);
const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType(
@@ -2421,7 +2419,7 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
// Create new register for GlobalInvocationID builtin variable.
Register NewRegister =
MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
- MIRBuilder.getMRI()->setType(NewRegister, LLT::pointer(0, 32));
+ MIRBuilder.getMRI()->setType(NewRegister, LLT::pointer(0, 64));
GR.assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF());
// Build GlobalInvocationID global variable with the necessary decorations.
@@ -2434,7 +2432,7 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
// Create new register for loading value.
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
Register LoadedRegister = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
- MIRBuilder.getMRI()->setType(LoadedRegister, LLT::pointer(0, 32));
+ MIRBuilder.getMRI()->setType(LoadedRegister, LLT::pointer(0, 64));
GR.assignSPIRVTypeToVReg(Vec3Ty, LoadedRegister, MIRBuilder.getMF());
// Load v3uint value from the global variable.
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index 44685be3d68ad4..7516d61a67724d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -54,11 +54,9 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
} // namespace llvm
static bool isMetaInstrGET(unsigned Opcode) {
- return Opcode == SPIRV::GET_ID || Opcode == SPIRV::GET_ID64 ||
- Opcode == SPIRV::GET_fID || Opcode == SPIRV::GET_fID64 ||
- Opcode == SPIRV::GET_pID32 || Opcode == SPIRV::GET_pID64 ||
- Opcode == SPIRV::GET_vID || Opcode == SPIRV::GET_vfID ||
- Opcode == SPIRV::GET_vpID32 || Opcode == SPIRV::GET_vpID64;
+ return Opcode == SPIRV::GET_ID || Opcode == SPIRV::GET_fID ||
+ Opcode == SPIRV::GET_pID || Opcode == SPIRV::GET_vID ||
+ Opcode == SPIRV::GET_vfID || Opcode == SPIRV::GET_vpID;
}
static bool mayBeInserted(unsigned Opcode) {
@@ -138,10 +136,8 @@ static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
continue;
// Restore usual instructions pattern for the newly inserted
// instruction
- MRI.setRegClass(ResVReg, MRI.getType(ResVReg).isVector()
- ? &SPIRV::iIDRegClass
- : &SPIRV::ANYIDRegClass);
- MRI.setType(ResVReg, LLT::scalar(32));
+ MRI.setRegClass(ResVReg, &SPIRV::iIDRegClass);
+ MRI.setType(ResVReg, LLT::scalar(64));
insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI);
processInstr(I, MIB, MRI, GR);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 6838f4bf9410f0..bcfcf0b21c1606 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -323,8 +323,7 @@ static const TargetRegisterClass *getRegClass(SPIRVType *SpvType,
case SPIRV::OpTypeFloat:
return &SPIRV::fIDRegClass;
case SPIRV::OpTypePointer:
- return GR.getPointerSize() == 64 ? &SPIRV::pID64RegClass
- : &SPIRV::pID32RegClass;
+ return &SPIRV::pIDRegClass;
case SPIRV::OpTypeVector: {
SPIRVType *ElemType =
GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
@@ -332,8 +331,7 @@ static const TargetRegisterClass *getRegClass(SPIRVType *SpvType,
if (ElemOpcode == SPIRV::OpTypeFloat)
return &SPIRV::vfIDRegClass;
if (ElemOpcode == SPIRV::OpTypePointer)
- return GR.getPointerSize() == 64 ? &SPIRV::vpID64RegClass
- : &SPIRV::vpID32RegClass;
+ return &SPIRV::vpIDRegClass;
return &SPIRV::vIDRegClass;
}
}
@@ -356,10 +354,7 @@ createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
bool IsVec = SrcLLT.isVector();
if (IsVec)
NewT = LLT::fixed_vector(2, NewT);
- if (PtrSz == 64)
- GetIdOp = IsVec ? SPIRV::GET_vpID64 : SPIRV::GET_pID64;
- else
- GetIdOp = IsVec ? SPIRV::GET_vpID32 : SPIRV::GET_pID32;
+ GetIdOp = IsVec ? SPIRV::GET_vpID : SPIRV::GET_pID;
} else if (SrcLLT.isVector()) {
NewT = LLT::scalar(GR.getScalarOrVectorBitWidth(SpvType));
NewT = LLT::fixed_vector(2, NewT);
@@ -501,6 +496,8 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
if (isSpvIntrinsic(UseMI, Intrinsic::spv_assign_type) ||
isSpvIntrinsic(UseMI, Intrinsic::spv_assign_name))
continue;
+ if (UseMI.getOpcode() == SPIRV::ASSIGN_TYPE)
+ NeedAssignType = false;
}
Type *Ty = nullptr;
if (MIOp == TargetOpcode::G_CONSTANT) {
@@ -611,8 +608,8 @@ static void processInstrsWithTypeFolding(MachineFunction &MF,
if (UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST)
continue;
}
- if (MRI.getType(DstReg).isPointer())
- MRI.setType(DstReg, LLT::pointer(0, GR->getPointerSize()));
+// if (MRI.getType(DstReg).isPointer())
+// MRI.setType(DstReg, LLT::pointer(0, GR->getPointerSize()));
}
}
}
@@ -626,7 +623,7 @@ insertInlineAsmProcess(MachineFunction &MF, SPIRVGlobalRegistry *GR,
for (unsigned i = 0, Sz = ToProcess.size(); i + 1 < Sz; i += 2) {
MachineInstr *I1 = ToProcess[i], *I2 = ToProcess[i + 1];
assert(isSpvIntrinsic(*I1, Intrinsic::spv_inline_asm) && I2->isInlineAsm());
- MIRBuilder.setInsertPt(*I1->getParent(), *I1);
+ MIRBuilder.setInsertPt(*I2->getParent(), *I2);
if (!AsmTargetReg.isValid()) {
// define vendor specific assembly target or dialect
@@ -706,10 +703,10 @@ insertInlineAsmProcess(MachineFunction &MF, SPIRVGlobalRegistry *GR,
unsigned IntrIdx = 2;
for (unsigned Idx : Ops) {
++IntrIdx;
- const MachineOperand &MO = I2->getOperand(Idx);
- if (MO.isReg())
- AsmCall.addUse(MO.getReg());
- else
+ //const MachineOperand &MO = I2->getOperand(Idx);
+ //if (MO.isReg())
+ // AsmCall.addUse(MO.getReg());
+ //else
AsmCall.addUse(I1->getOperand(IntrIdx).getReg());
}
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td
index 936ad8e684b3e2..1ef42b79f1a8ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td
@@ -12,7 +12,6 @@
let Namespace = "SPIRV" in {
// Pointer types for patterns with the GlobalISelEmitter
- def p32 : PtrValueType <i32, 0>;
def p64 : PtrValueType <i64, 0>;
class VTPtrVec<int nelem, PtrValueType ptr>
@@ -20,50 +19,35 @@ let Namespace = "SPIRV" in {
int isPointer = true;
}
- def v2p32 : VTPtrVec<2, p32>;
def v2p64 : VTPtrVec<2, p64>;
// Class for type registers
def TYPE0 : Register<"TYPE0">;
- def TYPE : RegisterClass<"SPIRV", [i32], 32, (add TYPE0)>;
+ def TYPE : RegisterClass<"SPIRV", [i64], 64, (add TYPE0)>;
// Class for non-type registers
def ID0 : Register<"ID0">;
- def ID640 : Register<"ID640">;
def fID0 : Register<"fID0">;
- def fID640 : Register<"fID640">;
- def pID320 : Register<"pID320">;
- def pID640 : Register<"pID640">;
+ def pID0 : Register<"pID0">;
def vID0 : Register<"vID0">;
def vfID0 : Register<"vfID0">;
- def vpID320 : Register<"vpID320">;
- def vpID640 : Register<"vpID640">;
-
- def iID : RegisterClass<"SPIRV", [i32], 32, (add ID0)>;
- def iID64 : RegisterClass<"SPIRV", [i64], 32, (add ID640)>;
- def fID : RegisterClass<"SPIRV", [f32], 32, (add fID0)>;
- def fID64 : RegisterClass<"SPIRV", [f64], 32, (add fID640)>;
- def pID32 : RegisterClass<"SPIRV", [p32], 32, (add pID320)>;
- def pID64 : RegisterClass<"SPIRV", [p64], 32, (add pID640)>;
- def vID : RegisterClass<"SPIRV", [v2i32], 32, (add vID0)>;
- def vfID : RegisterClass<"SPIRV", [v2f32], 32, (add vfID0)>;
- def vpID32 : RegisterClass<"SPIRV", [v2p32], 32, (add vpID320)>;
- def vpID64 : RegisterClass<"SPIRV", [v2p64], 32, (add vpID640)>;
-
+ def vpID0 : Register<"vpID0">;
+
+ def iID : RegisterClass<"SPIRV", [i64], 64, (add ID0)>;
+ def fID : RegisterClass<"SPIRV", [f64], 64, (add fID0)>;
+ def pID : RegisterClass<"SPIRV", [p64], 64, (add pID0)>;
+ def vID : RegisterClass<"SPIRV", [v2i64], 64, (add vID0)>;
+ def vfID : RegisterClass<"SPIRV", [v2f64], 64, (add vfID0)>;
+ def vpID : RegisterClass<"SPIRV", [v2p64], 64, (add vpID0)>;
+
def ID : RegisterClass<
"SPIRV",
- [i32, i64, f32, f64, p32, p64, v2i32, v2f32, v2p32, v2p64],
- 32,
- (add iID, iID64, fID, fID64, pID32, pID64, vID, vfID, vpID32, vpID64)>;
-
- def ANYID : RegisterClass<
- "SPIRV",
- [i32, i64, f32, f64, p32, p64, v2i32, v2f32, v2p32, v2p64],
- 32,
- (add ID0, ID640, fID0, fID640, pID320, pID640, vID0, vfID0, vpID320, vpID640)>;
+ [i64, f64, p64, v2i64, v2f64, v2p64],
+ 64,
+ (add iID, fID, pID, vID, vfID, vpID)>;
// A few instructions like OpName can take ids from both type and non-type
// instructions, so we need a super-class to allow for both to count as valid
// arguments for these instructions.
- def ANY : RegisterClass<"SPIRV", [i32], 32, (add TYPE, ID)>;
+ def ANY : RegisterClass<"SPIRV", [i64], 64, (add TYPE, ID)>;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegularizer.cpp b/llvm/lib/Target/SPIRV/SPIRVRegularizer.cpp
index 322e051a87db1a..003e0ee2f240a0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVRegularizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVRegularizer.cpp
@@ -232,7 +232,7 @@ void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName,
// %10 = OpCompositeInsert %v2uint %uint_5 %8 0
// %11 = OpVectorShuffle %v2uint %10 %8 0 0
// %call = OpExtInst %v2uint %1 s_min %14 %11
- auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0);
+ auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0); // lev
PoisonValue *PVal = PoisonValue::get(Arg0Ty);
Instruction *Inst =
InsertElementInst::Create(PVal, CI->getOperand(1), ConstInt, "", CI);