[llvm] Add vector reduction instructions (PR #82786)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Fri Feb 23 08:27:39 PST 2024
https://github.com/VyacheslavLevytskyy created https://github.com/llvm/llvm-project/pull/82786
This PR adds vector reduction instructions according to https://llvm.org/docs/GlobalISel/GenericOpcode.html#vector-reduction-operations and thereby widens the range of successfully supported conversions, covering new cases of vector reduction instructions which IRTranslator is unable to resolve.
Legalizing vector reduction instructions introduces new instruction patterns that must be addressed, including patterns whose handling is delegated to the pre-legalize step. To address this problem, a new pass is added that brings the instructions newly generated by legalization into the form required by instruction selection.
>From df1312ce79e3b44bb76f628fb55a05ca522c7032 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 23 Feb 2024 08:20:24 -0800
Subject: [PATCH] add vector reduction instructions
---
llvm/lib/Target/SPIRV/CMakeLists.txt | 1 +
llvm/lib/Target/SPIRV/SPIRV.h | 2 +
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 49 +++++-
llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 25 +++
llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp | 146 ++++++++++++++++++
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 98 ++++++------
llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp | 1 +
7 files changed, 269 insertions(+), 53 deletions(-)
create mode 100644 llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index d1ada45d17a5bc..afc26dda4c68bd 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -29,6 +29,7 @@ add_llvm_target(SPIRVCodeGen
SPIRVMetadata.cpp
SPIRVModuleAnalysis.cpp
SPIRVPreLegalizer.cpp
+ SPIRVPostLegalizer.cpp
SPIRVPrepareFunctions.cpp
SPIRVRegisterBankInfo.cpp
SPIRVRegisterInfo.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index 9460b0808cae89..6979107349d968 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -23,6 +23,7 @@ ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
FunctionPass *createSPIRVRegularizerPass();
FunctionPass *createSPIRVPreLegalizerPass();
+FunctionPass *createSPIRVPostLegalizerPass();
FunctionPass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
InstructionSelector *
createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
@@ -32,6 +33,7 @@ createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);
void initializeSPIRVPreLegalizerPass(PassRegistry &);
+void initializeSPIRVPostLegalizerPass(PassRegistry &);
void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
} // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 7258d3b4d88ed3..6987d54e2b176d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -183,6 +183,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectLog10(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectUnmergeValues(MachineInstr &I) const;
+
Register buildI32Constant(uint32_t Val, MachineInstr &I,
const SPIRVType *ResType = nullptr) const;
@@ -235,7 +237,7 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
if (Opcode == SPIRV::ASSIGN_TYPE) { // These pseudos aren't needed any more.
auto *Def = MRI->getVRegDef(I.getOperand(1).getReg());
if (isTypeFoldingSupported(Def->getOpcode())) {
- auto Res = selectImpl(I, *CoverageInfo);
+ bool Res = selectImpl(I, *CoverageInfo);
assert(Res || Def->getOpcode() == TargetOpcode::G_CONSTANT);
if (Res)
return Res;
@@ -263,7 +265,8 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
if (spvSelect(ResVReg, ResType, I)) {
if (HasDefs) // Make all vregs 32 bits (for SPIR-V IDs).
- MRI->setType(ResVReg, LLT::scalar(32));
+ for (unsigned i = 0; i < I.getNumDefs(); ++i)
+ MRI->setType(I.getOperand(i).getReg(), LLT::scalar(32));
I.removeFromParent();
return true;
}
@@ -273,9 +276,9 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
- assert(!isTypeFoldingSupported(I.getOpcode()) ||
- I.getOpcode() == TargetOpcode::G_CONSTANT);
const unsigned Opcode = I.getOpcode();
+ if (isTypeFoldingSupported(Opcode) && Opcode != TargetOpcode::G_CONSTANT)
+ return selectImpl(I, *CoverageInfo);
switch (Opcode) {
case TargetOpcode::G_CONSTANT:
return selectConst(ResVReg, ResType, I.getOperand(1).getCImm()->getValue(),
@@ -504,6 +507,9 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_FENCE:
return selectFence(I);
+ case TargetOpcode::G_UNMERGE_VALUES:
+ return selectUnmergeValues(I);
+
default:
return false;
}
@@ -733,6 +739,41 @@ bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg,
return Result;
}
+bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
+ unsigned ArgI = I.getNumOperands() - 1; // the last operand is the source
+ Register SrcReg =
+ I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0);
+ SPIRVType *DefType =
+ SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr;
+ if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
+ report_fatal_error(
+ "cannot select G_UNMERGE_VALUES with a non-vector argument");
+
+ SPIRVType *ScalarType =
+ GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+ MachineBasicBlock &BB = *I.getParent();
+ bool Res = true; // stays true only if every extract below is constrained
+ for (unsigned i = 0; i < I.getNumDefs(); ++i) {
+ Register ResVReg = I.getOperand(i).getReg();
+ SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg);
+ if (!ResType) {
+ // No "assign type" action covered this def; use the vector element type.
+ ResType = ScalarType;
+ MRI->setRegClass(ResVReg, &SPIRV::IDRegClass);
+ MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType)));
+ GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF);
+ }
+ auto MIB =
+ BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(SrcReg)
+ .addImm(static_cast<int64_t>(i));
+ Res &= MIB.constrainAllUses(TII, TRI, RBI); // one failure fails selection
+ }
+ return Res;
+}
+
bool SPIRVInstructionSelector::selectFence(MachineInstr &I) const {
AtomicOrdering AO = AtomicOrdering(I.getOperand(0).getImm());
uint32_t MemSem = static_cast<uint32_t>(getMemSemantics(AO));
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 4f2e7a240fc2cc..b6fd4fd2d8b800 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -113,6 +113,13 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16,
v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
+ auto allVectors = {
+ v2s1, v2s8, v2s16, v2s32, v2s64,
+ v3s1, v3s8, v3s16, v3s32, v3s64,
+ v4s1, v4s8, v4s16, v4s32, v4s64,
+ v8s1, v8s8, v8s16, v8s32, v8s64,
+ v16s1, v16s8, v16s16, v16s32, v16s64};
+
auto allScalarsAndVectors = {
s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
@@ -146,6 +153,24 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
// TODO: add proper rules for vectors legalization.
getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
+ // Vector Reduction Operations
+ getActionDefinitionsBuilder(
+ {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
+ G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
+ G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
+ G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
+ .legalFor(allVectors)
+ .scalarize(1)
+ .lower();
+
+ getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
+ .scalarize(2)
+ .lower();
+
+ // Merge/Unmerge
+ // TODO: add proper legalization rules.
+ getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
+
getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
.legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
new file mode 100644
index 00000000000000..186dc3441327f0
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -0,0 +1,146 @@
+//===-- SPIRVPostLegalizer.cpp - amend info after legalization --*- C++ -*-===//
+//
+// Amends information for instructions which may appear after the legalizer pass
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The pass partially applies pre-legalization logic to new instructions
+// inserted as a result of legalization:
+// - assigns SPIR-V types to registers for new instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVUtils.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/Target/TargetIntrinsicInfo.h"
+
+#define DEBUG_TYPE "spirv-postlegalizer"
+
+using namespace llvm;
+
+namespace {
+class SPIRVPostLegalizer : public MachineFunctionPass {
+public:
+ static char ID; // handed to the MachineFunctionPass base constructor
+ SPIRVPostLegalizer() : MachineFunctionPass(ID) {
+ initializeSPIRVPostLegalizerPass(*PassRegistry::getPassRegistry()); // run the pass initializer on the global registry
+ }
+ bool runOnMachineFunction(MachineFunction &MF) override; // defined later in this file
+};
+} // namespace
+
+// Defined in SPIRVLegalizerInfo.cpp.
+extern bool isTypeFoldingSupported(unsigned Opcode);
+
+namespace llvm {
+// Defined in SPIRVPreLegalizer.cpp.
+extern Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
+ SPIRVGlobalRegistry *GR,
+ MachineIRBuilder &MIB,
+ MachineRegisterInfo &MRI);
+extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
+ MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR);
+} // namespace llvm
+
+static bool isMetaInstrGET(unsigned Opcode) { // true iff Opcode is one of the five SPIRV::GET_* opcodes listed here
+ return Opcode == SPIRV::GET_ID || Opcode == SPIRV::GET_fID ||
+ Opcode == SPIRV::GET_pID || Opcode == SPIRV::GET_vID ||
+ Opcode == SPIRV::GET_vfID;
+}
+
+static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
+ MachineIRBuilder MIB) {
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+
+ for (MachineBasicBlock &MBB : MF) {
+ for (MachineInstr &I : MBB) {
+ const unsigned Opcode = I.getOpcode();
+ if (Opcode == TargetOpcode::G_UNMERGE_VALUES) {
+ unsigned ArgI = I.getNumOperands() - 1; // the last operand is the source
+ Register SrcReg = I.getOperand(ArgI).isReg()
+ ? I.getOperand(ArgI).getReg()
+ : Register(0);
+ SPIRVType *DefType =
+ SrcReg.isValid() ? GR->getSPIRVTypeForVReg(SrcReg) : nullptr;
+ if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
+ report_fatal_error(
+ "cannot select G_UNMERGE_VALUES with a non-vector argument");
+ SPIRVType *ScalarType =
+ GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
+ for (unsigned i = 0; i < I.getNumDefs(); ++i) {
+ Register ResVReg = I.getOperand(i).getReg();
+ SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg);
+ if (!ResType) {
+ // No "assign type" action covered this def; use the vector element type.
+ ResType = ScalarType;
+ MRI.setRegClass(ResVReg, &SPIRV::IDRegClass);
+ MRI.setType(ResVReg,
+ LLT::scalar(GR->getScalarOrVectorBitWidth(ResType)));
+ GR->assignSPIRVTypeToVReg(ResType, ResVReg, *GR->CurMF);
+ }
+ }
+ } else if (isTypeFoldingSupported(Opcode) && I.getNumDefs() == 1 &&
+ I.getNumOperands() > 1 && I.getOperand(1).isReg()) {
+ Register ResVReg = I.getOperand(0).getReg();
+ SPIRVType *ResVType = GR->getSPIRVTypeForVReg(ResVReg);
+ // Check whether the register defined by the instruction is newly
+ // generated (no SPIR-V type yet) or was already processed.
+ if (!ResVType) {
+ // Take the result type from the first source operand.
+ ResVType = GR->getSPIRVTypeForVReg(I.getOperand(1).getReg());
+ // Skip the instruction if that operand has no SPIR-V type either.
+ if (!ResVType)
+ continue;
+ // Set type & class
+ MRI.setRegClass(ResVReg, &SPIRV::IDRegClass);
+ MRI.setType(ResVReg,
+ LLT::scalar(GR->getScalarOrVectorBitWidth(ResVType)));
+ GR->assignSPIRVTypeToVReg(ResVType, ResVReg, *GR->CurMF);
+ }
+ // A GET_* pseudo right after I is taken to mean I was already processed.
+ MachineInstr *NextMI = I.getNextNode();
+ if (NextMI && isMetaInstrGET(NextMI->getOpcode()))
+ continue;
+ // Restore the usual instruction pattern for the newly inserted instruction
+ MRI.setRegClass(ResVReg, MRI.getType(ResVReg).isVector()
+ ? &SPIRV::IDRegClass
+ : &SPIRV::ANYIDRegClass);
+ MRI.setType(ResVReg, LLT::scalar(32));
+ insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI);
+ processInstr(I, MIB, MRI, GR);
+ }
+ }
+ }
+}
+
+bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
+ // Fetch the subtarget's global type registry and bind it to this function.
+ const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
+ SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+ GR->setCurrentFunc(MF);
+ MachineIRBuilder MIB(MF);
+
+ processNewInstrs(MF, GR, MIB);
+
+ return true; // always returns true, whether or not anything was rewritten
+}
+
+INITIALIZE_PASS(SPIRVPostLegalizer, DEBUG_TYPE, "SPIRV post legalizer", false,
+ false)
+
+char SPIRVPostLegalizer::ID = 0;
+
+FunctionPass *llvm::createSPIRVPostLegalizerPass() {
+ return new SPIRVPostLegalizer(); // heap-allocated; presumably owned by the pass manager it is added to
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 144216896eb68c..1e92e5ce264f04 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -212,6 +212,34 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
return SpirvTy;
}
+static std::pair<Register, unsigned>
+createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI, // NOTE: Opcode is not read in this body
+ const SPIRVGlobalRegistry &GR) {
+ LLT NewT = LLT::scalar(32); // default: a 32-bit scalar
+ SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
+ assert(SpvType && "VReg is expected to have SPIRV type");
+ bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
+ bool IsVectorFloat =
+ SpvType->getOpcode() == SPIRV::OpTypeVector &&
+ GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
+ SPIRV::OpTypeFloat;
+ IsFloat |= IsVectorFloat; // a vector of floats counts as float below
+ auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
+ auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
+ if (MRI.getType(ValReg).isPointer()) { // the pointer check takes precedence over the vector check
+ NewT = LLT::pointer(0, 32);
+ GetIdOp = SPIRV::GET_pID;
+ DstClass = &SPIRV::pIDRegClass;
+ } else if (MRI.getType(ValReg).isVector()) {
+ NewT = LLT::fixed_vector(2, NewT); // always 2 x s32, whatever ValReg's element count
+ GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
+ DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
+ }
+ Register IdReg = MRI.createGenericVirtualRegister(NewT);
+ MRI.setRegClass(IdReg, DstClass);
+ return {IdReg, GetIdOp}; // the new register and the matching GET_* opcode
+}
+
// Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
// a dst of the definition, assign SPIRVType to both registers. If SpirvTy is
// provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
@@ -249,6 +277,27 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
Def->getOperand(0).setReg(NewReg);
return NewReg;
}
+
+void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
+ MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
+ unsigned Opc = MI.getOpcode();
+ assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg())); // precondition: the result has exactly one use
+ MachineInstr &AssignTypeInst =
+ *(MRI.use_instr_begin(MI.getOperand(0).getReg())); // that single user (presumably the ASSIGN_TYPE pseudo)
+ auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
+ AssignTypeInst.getOperand(1).setReg(NewReg); // the user now reads the fresh ID register
+ MI.getOperand(0).setReg(NewReg); // and MI now defines it
+ MIB.setInsertPt(*MI.getParent(), // insert before MI's next node, or at block end if there is none
+ (MI.getNextNode() ? MI.getNextNode()->getIterator()
+ : MI.getParent()->end()));
+ for (auto &Op : MI.operands()) {
+ if (!Op.isReg() || Op.isDef()) // only register uses are rewritten
+ continue;
+ auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
+ MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg()); // GET_* pseudo: new ID reg defined from the old operand
+ Op.setReg(IdOpInfo.first);
+ }
+}
} // namespace llvm
static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
@@ -345,55 +394,6 @@ static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MI->eraseFromParent();
}
-static std::pair<Register, unsigned>
-createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
- const SPIRVGlobalRegistry &GR) {
- LLT NewT = LLT::scalar(32);
- SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
- assert(SpvType && "VReg is expected to have SPIRV type");
- bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
- bool IsVectorFloat =
- SpvType->getOpcode() == SPIRV::OpTypeVector &&
- GR.getSPIRVTypeForVReg(SpvType->getOperand(1).getReg())->getOpcode() ==
- SPIRV::OpTypeFloat;
- IsFloat |= IsVectorFloat;
- auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
- auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
- if (MRI.getType(ValReg).isPointer()) {
- NewT = LLT::pointer(0, 32);
- GetIdOp = SPIRV::GET_pID;
- DstClass = &SPIRV::pIDRegClass;
- } else if (MRI.getType(ValReg).isVector()) {
- NewT = LLT::fixed_vector(2, NewT);
- GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
- DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
- }
- Register IdReg = MRI.createGenericVirtualRegister(NewT);
- MRI.setRegClass(IdReg, DstClass);
- return {IdReg, GetIdOp};
-}
-
-static void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
- MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
- unsigned Opc = MI.getOpcode();
- assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
- MachineInstr &AssignTypeInst =
- *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
- auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
- AssignTypeInst.getOperand(1).setReg(NewReg);
- MI.getOperand(0).setReg(NewReg);
- MIB.setInsertPt(*MI.getParent(),
- (MI.getNextNode() ? MI.getNextNode()->getIterator()
- : MI.getParent()->end()));
- for (auto &Op : MI.operands()) {
- if (!Op.isReg() || Op.isDef())
- continue;
- auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
- MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
- Op.setReg(IdOpInfo.first);
- }
-}
-
// Defined in SPIRVLegalizerInfo.cpp.
extern bool isTypeFoldingSupported(unsigned Opcode);
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index e1b7bdd3140dbe..fbf64f2b1dfb13 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -189,6 +189,7 @@ void SPIRVPassConfig::addPreLegalizeMachineIR() {
// Use the default legalizer.
bool SPIRVPassConfig::addLegalizeMachineIR() {
addPass(new Legalizer());
+ addPass(createSPIRVPostLegalizerPass());
return false;
}
More information about the llvm-commits
mailing list