[llvm] [SPIR-V] Fix illegal OpConstantComposite instruction with non-const constituents in SPIR-V Backend (PR #86352)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 22 15:44:13 PDT 2024
https://github.com/VyacheslavLevytskyy created https://github.com/llvm/llvm-project/pull/86352
This PR fixes an illegal use of OpConstantComposite with non-constant constituents. The test attached to the PR now passes the `spirv-val` check. Before the fix, the SPIR-V Backend produced the following pattern for the attached test case:
```
%a = OpVariable %_ptr_CrossWorkgroup_uint CrossWorkgroup %uint_123
%11 = OpConstantComposite %_struct_6 %a %a
```
which `spirv-val` rejected with:
```
error: line 25: OpConstantComposite Constituent <id> '10[%a]' is not a constant or undef.
%11 = OpConstantComposite %_struct_6 %a %a
```
From 78feeefd9726fc1ec0bd16b59973be1d13cd4b08 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 22 Mar 2024 15:35:06 -0700
Subject: [PATCH] fix OpConstantComposite for non-const arguments; add
SpecConstantOp into Duplicate Tracker
---
.../Target/SPIRV/SPIRVDuplicatesTracker.cpp | 1 +
.../lib/Target/SPIRV/SPIRVDuplicatesTracker.h | 9 ++
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 1 +
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 8 ++
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 91 +++++++++++++++----
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 1 +
.../lib/Target/SPIRV/SPIRVSymbolicOperands.td | 1 +
.../SPIRV/pointers/struct-opaque-pointers.ll | 2 +-
8 files changed, 97 insertions(+), 17 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
index d82fb2df4539a3..7c32bb1968ef58 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
@@ -39,6 +39,7 @@ void SPIRVGeneralDuplicatesTracker::buildDepsGraph(
prebuildReg2Entry(GT, Reg2Entry);
prebuildReg2Entry(FT, Reg2Entry);
prebuildReg2Entry(AT, Reg2Entry);
+ prebuildReg2Entry(MT, Reg2Entry);
prebuildReg2Entry(ST, Reg2Entry);
for (auto &Op2E : Reg2Entry) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
index 96cc621791e972..2ec3fb35ca0451 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -262,6 +262,7 @@ class SPIRVGeneralDuplicatesTracker {
SPIRVDuplicatesTracker<GlobalVariable> GT;
SPIRVDuplicatesTracker<Function> FT;
SPIRVDuplicatesTracker<Argument> AT;
+ SPIRVDuplicatesTracker<MachineInstr> MT;
SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
// NOTE: using MOs instead of regs to get rid of MF dependency to be able
@@ -306,6 +307,10 @@ class SPIRVGeneralDuplicatesTracker {
AT.add(Arg, MF, R);
}
+ // Record R as the register associated with machine instruction MI in MF
+ // (forwards to the MachineInstr tracker MT). Presumably reached through
+ // SPIRVGlobalRegistry::add to remember the OpSpecConstantOp wrapper built for
+ // a value's defining instruction.
+ // NOTE(review): the key is a raw MachineInstr pointer; confirm MI outlives
+ // every lookup so that a reused address cannot match a stale entry.
+ void add(const MachineInstr *MI, const MachineFunction *MF, Register R) {
+ MT.add(MI, MF, R);
+ }
+
void add(const SPIRV::SpecialTypeDescriptor &TD, const MachineFunction *MF,
Register R) {
ST.add(TD, MF, R);
@@ -337,6 +342,10 @@ class SPIRVGeneralDuplicatesTracker {
return AT.find(const_cast<Argument *>(Arg), MF);
}
+ // Return the register recorded for MI in MF via add(); the caller in the
+ // instruction selector treats an invalid Register as "not recorded yet".
+ // The const_cast mirrors the Argument overload above (presumably MT is keyed
+ // on non-const pointers).
+ Register find(const MachineInstr *MI, const MachineFunction *MF) {
+ return MT.find(const_cast<MachineInstr *>(MI), MF);
+ }
+
Register find(const SPIRV::SpecialTypeDescriptor &TD,
const MachineFunction *MF) {
return ST.find(TD, MF);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 42f8397a3023b1..0e0ca07fc7f86b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -123,6 +123,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
SPIRVType *ElemType,
MachineIRBuilder &MIRBuilder) {
auto EleOpc = ElemType->getOpcode();
+ (void)EleOpc;
assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
EleOpc == SPIRV::OpTypeBool) &&
"Invalid vector element type");
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index da480b22a525f2..ed0f90ff89ce6e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -94,6 +94,14 @@ class SPIRVGlobalRegistry {
DT.add(Arg, MF, R);
}
+ // Associate R with the machine instruction MI in MF by forwarding to the
+ // duplicates tracker DT. Used to remember the OpSpecConstantOp wrapper
+ // register created for a value's defining instruction.
+ void add(const MachineInstr *MI, MachineFunction *MF, Register R) {
+ DT.add(MI, MF, R);
+ }
+
+ // Look up the register previously associated with MI in MF via add();
+ // callers check the result with Register::isValid().
+ Register find(const MachineInstr *MI, MachineFunction *MF) {
+ return DT.find(MI, MF);
+ }
+
Register find(const Constant *C, MachineFunction *MF) {
return DT.find(C, MF);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 5bb8f6084f9671..f905ee6de17ba6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -231,6 +231,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
Register buildZerosVal(const SPIRVType *ResType, MachineInstr &I) const;
Register buildOnesVal(bool AllOnes, const SPIRVType *ResType,
MachineInstr &I) const;
+
+ bool wrapIntoSpecConstantOp(MachineInstr &I,
+ SmallVector<Register> &CompositeArgs) const;
};
} // end anonymous namespace
@@ -1245,6 +1248,24 @@ static unsigned getArrayComponentCount(MachineRegisterInfo *MRI,
return N;
}
+// Return true if the defining instruction OpDef produces a constant, i.e. it
+// is a G_CONSTANT or G_FCONSTANT. When OpDef is an ASSIGN_TYPE whose operand 1
+// is a register with a known definition, that definition is inspected instead.
+static bool isConstReg(MachineRegisterInfo *MRI, SPIRVType *OpDef) {
+  // Look through a single ASSIGN_TYPE to the value it annotates.
+  const bool IsAssignTypeOfReg = OpDef->getOpcode() == SPIRV::ASSIGN_TYPE &&
+                                 OpDef->getOperand(1).isReg();
+  if (IsAssignTypeOfReg) {
+    SPIRVType *RefDef = MRI->getVRegDef(OpDef->getOperand(1).getReg());
+    if (RefDef)
+      OpDef = RefDef;
+  }
+  switch (OpDef->getOpcode()) {
+  case TargetOpcode::G_CONSTANT:
+  case TargetOpcode::G_FCONSTANT:
+    return true;
+  default:
+    return false;
+  }
+}
+
+// Return true if the virtual register OpReg has a defining instruction and
+// that instruction produces a constant (G_CONSTANT / G_FCONSTANT, possibly
+// behind an ASSIGN_TYPE).
+static bool isConstReg(MachineRegisterInfo *MRI, Register OpReg) {
+  SPIRVType *OpDef = MRI->getVRegDef(OpReg);
+  return OpDef != nullptr && isConstReg(MRI, OpDef);
+}
+
bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -1262,16 +1283,7 @@ bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg,
// check if we may construct a constant vector
Register OpReg = I.getOperand(OpIdx).getReg();
- bool IsConst = false;
- if (SPIRVType *OpDef = MRI->getVRegDef(OpReg)) {
- if (OpDef->getOpcode() == SPIRV::ASSIGN_TYPE &&
- OpDef->getOperand(1).isReg()) {
- if (SPIRVType *RefDef = MRI->getVRegDef(OpDef->getOperand(1).getReg()))
- OpDef = RefDef;
- }
- IsConst = OpDef->getOpcode() == TargetOpcode::G_CONSTANT ||
- OpDef->getOpcode() == TargetOpcode::G_FCONSTANT;
- }
+ bool IsConst = isConstReg(MRI, OpReg);
if (!IsConst && N < 2)
report_fatal_error(
@@ -1624,6 +1636,49 @@ bool SPIRVInstructionSelector::selectGEP(Register ResVReg,
return Res.constrainAllUses(TII, TRI, RBI);
}
+// Collect the constituents of the spv_const_composite intrinsic I into
+// CompositeArgs. Each non-constant constituent is wrapped into an
+// OpSpecConstantOp Bitcast, so the OpConstantComposite built from
+// CompositeArgs references only constant (or spec-constant) ids.
+//
+// Some constituents are pushed unchanged:
+//  - those lacking a defining instruction or a SPIR-V type,
+//  - those that are already constants (see isConstReg),
+//  - those defined by G_ADDRSPACE_CAST.
+//
+// At most one wrapper is created per defining instruction and function. Later
+// uses reuse the register recorded in the global registry.
+//
+// Returns false if constraining a newly built OpSpecConstantOp fails. In that
+// case CompositeArgs holds only the arguments collected so far.
+bool SPIRVInstructionSelector::wrapIntoSpecConstantOp(
+    MachineInstr &I, SmallVector<Register> &CompositeArgs) const {
+  // Loop-invariant context for any wrappers we create.
+  MachineFunction *MF = I.getMF();
+  MachineBasicBlock &BB = *I.getParent();
+  // Skip the explicit defs and the type MD node operand that follows them.
+  const unsigned Lim = I.getNumExplicitOperands();
+  for (unsigned i = I.getNumExplicitDefs() + 1; i < Lim; ++i) {
+    Register OpReg = I.getOperand(i).getReg();
+    SPIRVType *OpDefine = MRI->getVRegDef(OpReg);
+    SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpReg);
+    if (!OpDefine || !OpType || isConstReg(MRI, OpDefine) ||
+        OpDefine->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
+      // The case of G_ADDRSPACE_CAST inside spv_const_composite() is processed
+      // by selectAddrSpaceCast()
+      CompositeArgs.push_back(OpReg);
+      continue;
+    }
+    // Reuse an existing wrapper for this defining instruction, if any.
+    Register WrapReg = GR.find(OpDefine, MF);
+    if (WrapReg.isValid()) {
+      CompositeArgs.push_back(WrapReg);
+      continue;
+    }
+    // Create a new register for the wrapper and remember it for reuse.
+    WrapReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    GR.add(OpDefine, MF, WrapReg);
+    CompositeArgs.push_back(WrapReg);
+    // Give the wrapper the same low-level type as the value it wraps. A
+    // hard-coded p0/32-bit pointer does not describe operands in other address
+    // spaces (e.g. CrossWorkgroup) or with other pointer widths. The old
+    // default is kept only as a fallback for operands without an LLT.
+    const LLT OpLLT = MRI->getType(OpReg);
+    MRI->setType(WrapReg, OpLLT.isValid() ? OpLLT : LLT::pointer(0, 32));
+    GR.assignSPIRVTypeToVReg(OpType, WrapReg, *MF);
+    if (!BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSpecConstantOp))
+             .addDef(WrapReg)
+             .addUse(GR.getSPIRVTypeID(OpType))
+             .addImm(static_cast<uint32_t>(SPIRV::Opcode::Bitcast))
+             .addUse(OpReg)
+             .constrainAllUses(TII, TRI, RBI))
+      return false;
+  }
+  return true;
+}
+
bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -1662,17 +1717,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
case Intrinsic::spv_const_composite: {
// If no values are attached, the composite is null constant.
bool IsNull = I.getNumExplicitDefs() + 1 == I.getNumExplicitOperands();
- unsigned Opcode =
- IsNull ? SPIRV::OpConstantNull : SPIRV::OpConstantComposite;
+ // Select a proper instruction.
+ unsigned Opcode = SPIRV::OpConstantNull;
+ SmallVector<Register> CompositeArgs;
+ if (!IsNull) {
+ Opcode = SPIRV::OpConstantComposite;
+ if (!wrapIntoSpecConstantOp(I, CompositeArgs))
+ return false;
+ }
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType));
// skip type MD node we already used when generated assign.type for this
if (!IsNull) {
- for (unsigned i = I.getNumExplicitDefs() + 1;
- i < I.getNumExplicitOperands(); ++i) {
- MIB.addUse(I.getOperand(i).getReg());
- }
+ for (Register OpReg : CompositeArgs)
+ MIB.addUse(OpReg);
}
return MIB.constrainAllUses(TII, TRI, RBI);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index d547f91ba4a565..1f0d8d8cd43a8f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -543,6 +543,7 @@ static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
Register Dst = ICMP->getOperand(0).getReg();
MachineOperand &PredOp = ICMP->getOperand(1);
const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
+ (void)CC;
assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) &&
MRI.hasOneUse(Dst) && MRI.hasOneDef(CompareReg));
uint64_t Value = getIConstVal(ICMP->getOperand(3).getReg(), &MRI);
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 8dbbd9049844c8..ff102e318469f4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -1611,3 +1611,4 @@ multiclass OpcodeOperand<bits<32> value> {
// TODO: implement other mnemonics.
defm InBoundsPtrAccessChain : OpcodeOperand<70>;
defm PtrCastToGeneric : OpcodeOperand<121>;
+// OpBitcast (SPIR-V opcode 124): used as the operation of the OpSpecConstantOp
+// wrappers built for non-constant OpConstantComposite constituents.
+defm Bitcast : OpcodeOperand<124>;
diff --git a/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll b/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
index d426fc4dfd4eec..ce3ab8895a5948 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
@@ -1,5 +1,5 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
-; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; CHECK: %[[TyInt8:.*]] = OpTypeInt 8 0
; CHECK: %[[TyInt8Ptr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt8]]
More information about the llvm-commits mailing list