[llvm] 05093e2 - [Spirv][HLSL] Add OpAll lowering and float vec support (#87952)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 10 13:27:48 PDT 2024
Author: Farzon Lotfi
Date: 2024-04-10T16:27:44-04:00
New Revision: 05093e243859a371f96ffa1c320a4b51579c3da7
URL: https://github.com/llvm/llvm-project/commit/05093e243859a371f96ffa1c320a4b51579c3da7
DIFF: https://github.com/llvm/llvm-project/commit/05093e243859a371f96ffa1c320a4b51579c3da7.diff
LOG: [Spirv][HLSL] Add OpAll lowering and float vec support (#87952)
The main point of this change was to add support for HLSL's `all`
intrinsic.
In the process of doing that I found a few issues around creating an
`OpConstantComposite` via `buildZerosVal`.
First, the existing code didn't support floats, so adding
`buildZerosValF` meant I needed a
float version of `getOrCreateConsIntVector`. After doing so I renamed
both versions to `getOrCreateConstVector`. That meant I needed to create
a float version of `getOrCreateIntCompositeOrNull`. Fortunately, this
function contained very little type-specific logic, so I was able to split
that logic out into a helper (`getOrCreateBaseRegister`) and rename
`getOrCreateIntCompositeOrNull` to `getOrCreateCompositeOrNull`. Apart
from their type handling and the choice between the Null and 0 constant
opcodes, the integer and float versions of these functions should be
identical.
To handle scalar floats I could not use `buildConstantFP` the way this PR
did:
https://github.com/llvm/llvm-project/commit/0a2aaab5aba46#diff-733a189c5a8c3211f3a04fd6e719952a3fa231eadd8a7f11e6ecf1e584d57411R1603
because that would create too many superfluous registers (which causes
problems in the validator). Instead, I had to create a float version of
`getOrCreateConstInt`, which I called `getOrCreateConstFP`. Doing it like
this would cause similar problems:
https://github.com/llvm/llvm-project/blob/main/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp#L1540.
`buildZerosValF` also uses a helper function, `getZeroFP`. This is
because half, float, and double scalar zeros would collide in
`SPIRVDuplicatesTracker<Constant> CT` if `APFloat(0.0f)` were used for all of them.
`getOrCreateConstFP` needed its own version of `getOrCreateConstIntReg`,
which I called `getOrCreateConstFloatReg`. The one difference in this
function is that `getOrCreateConstFloatReg` also returns the bit width, so we
don't have to call `getScalarOrVectorBitWidth` twice, i.e. again in
`getOrCreateConstFP` for the `OpConstantF` `addNumImm` call.
`getOrCreateConstFloatReg` needed an `assignFloatTypeToVReg` helper,
which calls a `getOrCreateSPIRVFloatType` helper. Because there is no
floating-point equivalent of `IntegerType::get`, I handled this with a switch
statement on the bit width to get the right LLVM float type.
Finally, there is the use of `bool ZeroAsNull = STI.isOpenCLEnv();`. This
is partly a cosmetic change. When zeros are treated as nulls, we don't
create `OpConstantComposite` vectors, whereas DXC's SPIR-V backend does
create them. DXC's SPIR-V backend also does not use `OpConstantNull`.
In addition, I needed a way to test both the `OpConstantNull` and the
`OpConstantComposite` behavior, and this let me do that with the same tests.
Added:
llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll
Modified:
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 9592f3e81b4026..70197e948c6582 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -20,7 +20,12 @@
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Type.h"
#include "llvm/IR/TypedPointerType.h"
+#include "llvm/Support/Casting.h"
+#include <cassert>
using namespace llvm;
SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
@@ -35,6 +40,15 @@ SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
return SpirvType;
}
+/// Get (or create) the SPIR-V OpTypeFloat of width \p BitWidth and record it
+/// as the SPIR-V type of the virtual register \p VReg.
+/// \returns the SPIR-V float type that was assigned to \p VReg.
+SPIRVType *SPIRVGlobalRegistry::assignFloatTypeToVReg(
+    unsigned BitWidth, Register VReg, MachineInstr &I,
+    const SPIRVInstrInfo &TII) {
+  SPIRVType *FloatTy = getOrCreateSPIRVFloatType(BitWidth, I, TII);
+  assignSPIRVTypeToVReg(FloatTy, VReg, *CurMF);
+  return FloatTy;
+}
+
SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
const SPIRVInstrInfo &TII) {
@@ -151,6 +165,8 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
Register Res = DT.find(CI, CurMF);
if (!Res.isValid()) {
unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
+ // TODO: handle cases where the type is not 32bit wide
+ // TODO: https://github.com/llvm/llvm-project/issues/88129
LLT LLTy = LLT::scalar(32);
Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
@@ -164,9 +180,83 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
return std::make_tuple(Res, CI, NewInstr);
}
+/// Find or create the virtual register that will hold the floating-point
+/// constant \p Val in the current machine function.
+///
+/// \p SpvType is the SPIR-V type of the constant. When it is null, an LLVM
+/// `float` type is assumed, and a SPIR-V type for it is created through
+/// \p MIRBuilder if one is given. Either \p MIRBuilder, or both \p I and
+/// \p TII, must be non-null so a newly created register can be typed.
+///
+/// \returns a tuple of (register, uniqued ConstantFP, whether the register
+/// was newly created, the bit width reported by getScalarOrVectorBitWidth for
+/// the resolved SPIR-V type, or 32 when no SPIR-V type is known).
+std::tuple<Register, ConstantFP *, bool, unsigned>
+SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
+                                              MachineIRBuilder *MIRBuilder,
+                                              MachineInstr *I,
+                                              const SPIRVInstrInfo *TII) {
+  LLVMContext &Ctx = CurMF->getFunction().getContext();
+  const Type *LLVMFloatTy =
+      SpvType ? getTypeForSPIRVType(SpvType) : Type::getFloatTy(Ctx);
+  if (!SpvType && MIRBuilder)
+    SpvType = getOrCreateSPIRVType(LLVMFloatTy, *MIRBuilder);
+  // Derive the width on every path, not only when a new register is made:
+  // callers use it to encode the literal even when the constant was found.
+  const unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
+  bool NewInstr = false;
+  // Find a constant in DT or build a new one.
+  auto *const CI = ConstantFP::get(Ctx, Val);
+  Register Res = DT.find(CI, CurMF);
+  if (!Res.isValid()) {
+    // TODO: handle cases where the type is not 32bit wide
+    // TODO: https://github.com/llvm/llvm-project/issues/88129
+    LLT LLTy = LLT::scalar(32);
+    Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+    CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
+    if (MIRBuilder)
+      assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder);
+    else
+      assignFloatTypeToVReg(BitWidth, Res, *I, *TII);
+    DT.add(CI, CurMF, Res);
+    NewInstr = true;
+  }
+  return std::make_tuple(Res, CI, NewInstr, BitWidth);
+}
+
+/// Return the register holding the scalar floating-point constant \p Val of
+/// SPIR-V type \p SpvType. The defining instruction is built before \p I when
+/// the register is new in this function, or when the register that was found
+/// is the one defined by \p I itself.
+///
+/// If \p ZeroAsNull is set, +0.0 is emitted as OpConstantNull; otherwise the
+/// value (including zero) is emitted as OpConstantF.
+Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
+                                                 SPIRVType *SpvType,
+                                                 const SPIRVInstrInfo &TII,
+                                                 bool ZeroAsNull) {
+  assert(SpvType);
+  ConstantFP *CI;
+  Register Res;
+  bool New;
+  // The literal width is taken from SpvType below, so the width element of
+  // the tuple is not needed here.
+  std::tie(Res, CI, New, std::ignore) =
+      getOrCreateConstFloatReg(Val, SpvType, nullptr, &I, &TII);
+  // If we have found Res register which is defined by the passed G_CONSTANT
+  // machine instruction, a new constant instruction should be created.
+  if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
+    return Res;
+  MachineInstrBuilder MIB;
+  MachineBasicBlock &BB = *I.getParent();
+  // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
+  if (Val.isPosZero() && ZeroAsNull) {
+    MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
+              .addDef(Res)
+              .addUse(getSPIRVTypeID(SpvType));
+  } else {
+    MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantF))
+              .addDef(Res)
+              .addUse(getSPIRVTypeID(SpvType));
+    // Encode the literal with the width of the declared SPIR-V type. The
+    // bit pattern is resized directly, with no 64-bit getZExtValue()
+    // round-trip, so every path uses the correct width.
+    const unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
+    addNumImm(CI->getValueAPF().bitcastToAPInt().zextOrTrunc(BitWidth), MIB);
+  }
+  const auto &ST = CurMF->getSubtarget();
+  constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
+                                   *ST.getRegisterInfo(), *ST.getRegBankInfo());
+  return Res;
+}
+
Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType,
- const SPIRVInstrInfo &TII) {
+ const SPIRVInstrInfo &TII,
+ bool ZeroAsNull) {
assert(SpvType);
ConstantInt *CI;
Register Res;
@@ -179,7 +269,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
return Res;
MachineInstrBuilder MIB;
MachineBasicBlock &BB = *I.getParent();
- if (Val) {
+ if (Val || !ZeroAsNull) {
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
@@ -270,21 +360,46 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
return Res;
}
-Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
- uint64_t Val, MachineInstr &I, SPIRVType *SpvType,
+/// Materialize the scalar element constant \p Val for a composite constant of
+/// SPIR-V type \p SpvType. The element kind comes from the component type of
+/// an OpTypeVector / OpTypeArray (operand 1), or from \p SpvType itself for a
+/// scalar. \p BitWidth is the element width in bits.
+///
+/// NOTE(review): the element is requested with the default ZeroAsNull=true of
+/// getOrCreateConstFP / getOrCreateConstInt. A zero element therefore takes
+/// their null-constant path when it is materialized, even if the caller of
+/// getOrCreateCompositeOrNull passed ZeroAsNull=false. Forwarding the flag
+/// needs a new parameter here and in SPIRVGlobalRegistry.h.
+Register SPIRVGlobalRegistry::getOrCreateBaseRegister(Constant *Val,
+                                                      MachineInstr &I,
+                                                      SPIRVType *SpvType,
+                                                      const SPIRVInstrInfo &TII,
+                                                      unsigned BitWidth) {
+  // Named ElemType (not Type) so it does not shadow llvm::Type.
+  SPIRVType *ElemType = SpvType;
+  if (SpvType->getOpcode() == SPIRV::OpTypeVector ||
+      SpvType->getOpcode() == SPIRV::OpTypeArray) {
+    Register EleTypeReg = SpvType->getOperand(1).getReg();
+    ElemType = getSPIRVTypeForVReg(EleTypeReg);
+  }
+  if (ElemType->getOpcode() == SPIRV::OpTypeFloat) {
+    SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
+    // A float element type implies a ConstantFP value. cast<> asserts that
+    // instead of dereferencing a possibly-null dyn_cast<> result.
+    return getOrCreateConstFP(cast<ConstantFP>(Val)->getValue(), I,
+                              SpvBaseType, TII);
+  }
+  assert(ElemType->getOpcode() == SPIRV::OpTypeInt);
+  SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
+  return getOrCreateConstInt(Val->getUniqueInteger().getSExtValue(), I,
+                             SpvBaseType, TII);
+}
+
+Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
+ Constant *Val, MachineInstr &I, SPIRVType *SpvType,
const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
- unsigned ElemCnt) {
+ unsigned ElemCnt, bool ZeroAsNull) {
// Find a constant vector in DT or build a new one.
Register Res = DT.find(CA, CurMF);
+ // If no values are attached, the composite is null constant.
+ bool IsNull = Val->isNullValue() && ZeroAsNull;
if (!Res.isValid()) {
- SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
// SpvScalConst should be created before SpvVecConst to avoid undefined ID
// error on validation.
// TODO: can moved below once sorting of types/consts/defs is implemented.
Register SpvScalConst;
- if (Val)
- SpvScalConst = getOrCreateConstInt(Val, I, SpvBaseType, TII);
- // TODO: maybe use bitwidth of base type.
+ if (!IsNull)
+ SpvScalConst = getOrCreateBaseRegister(Val, I, SpvType, TII, BitWidth);
+
+ // TODO: handle cases where the type is not 32bit wide
+ // TODO: https://github.com/llvm/llvm-project/issues/88129
LLT LLTy = LLT::scalar(32);
Register SpvVecConst =
CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
@@ -293,7 +408,7 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
DT.add(CA, CurMF, SpvVecConst);
MachineInstrBuilder MIB;
MachineBasicBlock &BB = *I.getParent();
- if (Val) {
+ if (!IsNull) {
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
.addDef(SpvVecConst)
.addUse(getSPIRVTypeID(SpvType));
@@ -313,20 +428,42 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
return Res;
}
-Register
-SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
- SPIRVType *SpvType,
- const SPIRVInstrInfo &TII) {
+/// Return the register of a vector constant of SPIR-V type \p SpvType whose
+/// elements all equal the integer \p Val. The splat is looked up or built by
+/// getOrCreateCompositeOrNull. With \p ZeroAsNull, an all-zero splat takes
+/// that function's null-constant path instead of OpConstantComposite.
+Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
+                                                     MachineInstr &I,
+                                                     SPIRVType *SpvType,
+                                                     const SPIRVInstrInfo &TII,
+                                                     bool ZeroAsNull) {
+  const Type *VecTy = getTypeForSPIRVType(SpvType);
+  assert(VecTy->isVectorTy());
+  const auto *FixedVecTy = cast<FixedVectorType>(VecTy);
+  Type *ElemTy = FixedVecTy->getElementType();
+  assert(ElemTy->isIntegerTy());
+  auto *Elem = ConstantInt::get(ElemTy, Val);
+  auto *Splat = ConstantVector::getSplat(FixedVecTy->getElementCount(), Elem);
+  const unsigned ElemBits = getScalarOrVectorBitWidth(SpvType);
+  const unsigned NumElts = SpvType->getOperand(2).getImm();
+  return getOrCreateCompositeOrNull(Elem, I, SpvType, TII, Splat, ElemBits,
+                                    NumElts, ZeroAsNull);
+}
+
+/// Floating-point counterpart of the overload above: every element equals
+/// \p Val, whose semantics must match the vector's element type.
+Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
+                                                     MachineInstr &I,
+                                                     SPIRVType *SpvType,
+                                                     const SPIRVInstrInfo &TII,
+                                                     bool ZeroAsNull) {
+  const Type *VecTy = getTypeForSPIRVType(SpvType);
+  assert(VecTy->isVectorTy());
+  const auto *FixedVecTy = cast<FixedVectorType>(VecTy);
+  Type *ElemTy = FixedVecTy->getElementType();
+  assert(ElemTy->isFloatingPointTy());
+  auto *Elem = ConstantFP::get(ElemTy, Val);
+  auto *Splat = ConstantVector::getSplat(FixedVecTy->getElementCount(), Elem);
+  const unsigned ElemBits = getScalarOrVectorBitWidth(SpvType);
+  const unsigned NumElts = SpvType->getOperand(2).getImm();
+  return getOrCreateCompositeOrNull(Elem, I, SpvType, TII, Splat, ElemBits,
+                                    NumElts, ZeroAsNull);
+}
Register
@@ -337,13 +474,13 @@ SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
assert(LLVMTy->isArrayTy());
const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
Type *LLVMBaseTy = LLVMArrTy->getElementType();
- const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
- auto ConstArr =
+ auto *ConstInt = ConstantInt::get(LLVMBaseTy, Val);
+ auto *ConstArr =
ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
- return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstArr, BW,
- LLVMArrTy->getNumElements());
+ return getOrCreateCompositeOrNull(ConstInt, I, SpvType, TII, ConstArr, BW,
+ LLVMArrTy->getNumElements());
}
Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
@@ -1093,14 +1230,16 @@ SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
return SpirvType;
}
-SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
- unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
- Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
+ MachineInstr &I,
+ const SPIRVInstrInfo &TII,
+ unsigned SPIRVOPcode,
+ Type *LLVMTy) {
Register Reg = DT.find(LLVMTy, CurMF);
if (Reg.isValid())
return getSPIRVTypeForVReg(Reg);
MachineBasicBlock &BB = *I.getParent();
- auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt))
+ auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRVOPcode))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(BitWidth)
.addImm(0);
@@ -1108,6 +1247,31 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
return finishCreatingSPIRVType(LLVMTy, MIB);
}
+/// Get or create the SPIR-V OpTypeInt of width \p BitWidth. Any new type
+/// instruction is inserted before \p I.
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
+    unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
+  LLVMContext &Ctx = CurMF->getFunction().getContext();
+  return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt,
+                              IntegerType::get(Ctx, BitWidth));
+}
+/// Get or create the SPIR-V OpTypeFloat of width \p BitWidth. Widths 16, 32
+/// and 64 map to LLVM half, float and double. Any other width reaches
+/// llvm_unreachable. Any new type instruction is inserted before \p I.
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
+    unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
+  LLVMContext &Ctx = CurMF->getFunction().getContext();
+  Type *FloatTy = nullptr;
+  if (BitWidth == 16)
+    FloatTy = Type::getHalfTy(Ctx);
+  else if (BitWidth == 32)
+    FloatTy = Type::getFloatTy(Ctx);
+  else if (BitWidth == 64)
+    FloatTy = Type::getDoubleTy(Ctx);
+  else
+    llvm_unreachable("Bit width is of unexpected size.");
+  return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, FloatTy);
+}
+
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
return getOrCreateSPIRVType(
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 37f575e884ef48..2e3e69456ac260 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -20,6 +20,7 @@
#include "SPIRVDuplicatesTracker.h"
#include "SPIRVInstrInfo.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+#include "llvm/IR/Constant.h"
namespace llvm {
using SPIRVType = const MachineInstr;
@@ -234,6 +235,8 @@ class SPIRVGlobalRegistry {
bool EmitIR = true);
SPIRVType *assignIntTypeToVReg(unsigned BitWidth, Register VReg,
MachineInstr &I, const SPIRVInstrInfo &TII);
+ SPIRVType *assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
+ MachineInstr &I, const SPIRVInstrInfo &TII);
SPIRVType *assignVectTypeToVReg(SPIRVType *BaseType, unsigned NumElements,
Register VReg, MachineInstr &I,
const SPIRVInstrInfo &TII);
@@ -372,12 +375,20 @@ class SPIRVGlobalRegistry {
std::tuple<Register, ConstantInt *, bool> getOrCreateConstIntReg(
uint64_t Val, SPIRVType *SpvType, MachineIRBuilder *MIRBuilder,
MachineInstr *I = nullptr, const SPIRVInstrInfo *TII = nullptr);
+ std::tuple<Register, ConstantFP *, bool, unsigned> getOrCreateConstFloatReg(
+ APFloat Val, SPIRVType *SpvType, MachineIRBuilder *MIRBuilder,
+ MachineInstr *I = nullptr, const SPIRVInstrInfo *TII = nullptr);
SPIRVType *finishCreatingSPIRVType(const Type *LLVMTy, SPIRVType *SpirvType);
- Register getOrCreateIntCompositeOrNull(uint64_t Val, MachineInstr &I,
- SPIRVType *SpvType,
- const SPIRVInstrInfo &TII,
- Constant *CA, unsigned BitWidth,
- unsigned ElemCnt);
+ Register getOrCreateBaseRegister(Constant *Val, MachineInstr &I,
+ SPIRVType *SpvType,
+ const SPIRVInstrInfo &TII,
+ unsigned BitWidth);
+ Register getOrCreateCompositeOrNull(Constant *Val, MachineInstr &I,
+ SPIRVType *SpvType,
+ const SPIRVInstrInfo &TII, Constant *CA,
+ unsigned BitWidth, unsigned ElemCnt,
+ bool ZeroAsNull = true);
+
Register getOrCreateIntCompositeOrNull(uint64_t Val,
MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType, bool EmitIR,
@@ -388,12 +399,20 @@ class SPIRVGlobalRegistry {
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType = nullptr, bool EmitIR = true);
Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
- SPIRVType *SpvType, const SPIRVInstrInfo &TII);
+ SPIRVType *SpvType, const SPIRVInstrInfo &TII,
+ bool ZeroAsNull = true);
+ Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType,
+ const SPIRVInstrInfo &TII,
+ bool ZeroAsNull = true);
Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType = nullptr);
- Register getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
- SPIRVType *SpvType,
- const SPIRVInstrInfo &TII);
+
+ Register getOrCreateConstVector(uint64_t Val, MachineInstr &I,
+ SPIRVType *SpvType, const SPIRVInstrInfo &TII,
+ bool ZeroAsNull = true);
+ Register getOrCreateConstVector(APFloat Val, MachineInstr &I,
+ SPIRVType *SpvType, const SPIRVInstrInfo &TII,
+ bool ZeroAsNull = true);
Register getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII);
@@ -423,6 +442,11 @@ class SPIRVGlobalRegistry {
MachineIRBuilder &MIRBuilder);
SPIRVType *getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineInstr &I,
const SPIRVInstrInfo &TII);
+ SPIRVType *getOrCreateSPIRVType(unsigned BitWidth, MachineInstr &I,
+ const SPIRVInstrInfo &TII,
+ unsigned SPIRVOPcode, Type *LLVMTy);
+ SPIRVType *getOrCreateSPIRVFloatType(unsigned BitWidth, MachineInstr &I,
+ const SPIRVInstrInfo &TII);
SPIRVType *getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder);
SPIRVType *getOrCreateSPIRVBoolType(MachineInstr &I,
const SPIRVInstrInfo &TII);
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 45a70da7f86902..c1c0fc4b7dd489 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -28,6 +28,7 @@
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineModuleInfoImpls.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/Support/Debug.h"
@@ -144,6 +145,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectAddrSpaceCast(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectAll(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectBitreverse(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
@@ -229,6 +233,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
const SPIRVType *ResType = nullptr) const;
Register buildZerosVal(const SPIRVType *ResType, MachineInstr &I) const;
+ Register buildZerosValF(const SPIRVType *ResType, MachineInstr &I) const;
Register buildOnesVal(bool AllOnes, const SPIRVType *ResType,
MachineInstr &I) const;
@@ -1155,6 +1160,65 @@ static unsigned getBoolCmpOpcode(unsigned PredNum) {
}
}
+/// Select the spv_all intrinsic. The result \p ResVReg is a scalar bool that
+/// is true iff every component of the input (operand 2) is "true":
+///  - scalar bool input: the input is copied to the result;
+///  - scalar int/float input: result = (input != 0), using OpINotEqual or
+///    OpFOrdNotEqual (ordered, so a NaN component compares false);
+///  - vector input: the component-wise comparison against a zero vector
+///    (skipped for bool vectors) is reduced with OpAll.
+/// \returns false if any emitted instruction fails to constrain.
+bool SPIRVInstructionSelector::selectAll(Register ResVReg,
+                                         const SPIRVType *ResType,
+                                         MachineInstr &I) const {
+  assert(I.getNumOperands() == 3);
+  assert(I.getOperand(2).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+  Register InputRegister = I.getOperand(2).getReg();
+  SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister);
+
+  if (!InputType)
+    report_fatal_error("Input Type could not be determined.");
+
+  bool IsBoolTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeBool);
+  bool IsVectorTy = InputType->getOpcode() == SPIRV::OpTypeVector;
+  if (IsBoolTy && !IsVectorTy) {
+    assert(ResVReg == I.getOperand(0).getReg());
+    return BuildMI(BB, I, I.getDebugLoc(), TII.get(TargetOpcode::COPY))
+        .addDef(ResVReg)
+        .addUse(InputRegister)
+        .constrainAllUses(TII, TRI, RBI);
+  }
+
+  bool IsFloatTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeFloat);
+  unsigned SpirvNotEqualId =
+      IsFloatTy ? SPIRV::OpFOrdNotEqual : SPIRV::OpINotEqual;
+  SPIRVType *SpvBoolScalarTy = GR.getOrCreateSPIRVBoolType(I, TII);
+  SPIRVType *SpvBoolTy = SpvBoolScalarTy;
+  // For scalars the comparison writes the result directly.
+  Register NotEqualReg = ResVReg;
+
+  if (IsVectorTy) {
+    // A bool vector is already the mask OpAll needs; otherwise compare into a
+    // fresh bool-vector register.
+    NotEqualReg = IsBoolTy ? InputRegister
+                           : MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    const unsigned NumElts = InputType->getOperand(2).getImm();
+    SpvBoolTy = GR.getOrCreateSPIRVVectorType(SpvBoolTy, NumElts, I, TII);
+  }
+
+  if (!IsBoolTy) {
+    Register ConstZeroReg =
+        IsFloatTy ? buildZerosValF(InputType, I) : buildZerosVal(InputType, I);
+
+    // Propagate a constraint failure instead of silently reporting success.
+    bool Constrained =
+        BuildMI(BB, I, I.getDebugLoc(), TII.get(SpirvNotEqualId))
+            .addDef(NotEqualReg)
+            .addUse(GR.getSPIRVTypeID(SpvBoolTy))
+            .addUse(InputRegister)
+            .addUse(ConstZeroReg)
+            .constrainAllUses(TII, TRI, RBI);
+    if (!Constrained)
+      return false;
+  }
+
+  if (!IsVectorTy)
+    return true;
+
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpAll))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(SpvBoolScalarTy))
+      .addUse(NotEqualReg)
+      .constrainAllUses(TII, TRI, RBI);
+}
+
bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -1391,9 +1455,35 @@ bool SPIRVInstructionSelector::selectFCmp(Register ResVReg,
+/// Return the register of an integer zero constant of type \p ResType. For a
+/// vector type the splat comes from the integer overload of
+/// getOrCreateConstVector; otherwise a scalar comes from getOrCreateConstInt.
+/// Zero is requested as a null constant only when STI.isOpenCLEnv() is true.
Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
MachineInstr &I) const {
+ // OpenCL uses nulls for Zero. In HLSL we don't use null constants.
+ bool ZeroAsNull = STI.isOpenCLEnv();
+ if (ResType->getOpcode() == SPIRV::OpTypeVector)
+ return GR.getOrCreateConstVector(0UL, I, ResType, TII, ZeroAsNull);
+ return GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
}
+
+/// Return +0.0 in the floating-point semantics of the scalar element type of
+/// \p LLVMFloatTy: half maps to IEEE half, double to IEEE double, and a null
+/// type or any other element type falls back to IEEE single. Matching the
+/// semantics keeps half/float/double zeros distinct as ConstantFP keys.
+static APFloat getZeroFP(const Type *LLVMFloatTy) {
+  const fltSemantics *Sem = &APFloat::IEEEsingle();
+  if (LLVMFloatTy) {
+    const Type *ScalarTy = LLVMFloatTy->getScalarType();
+    if (ScalarTy->isHalfTy())
+      Sem = &APFloat::IEEEhalf();
+    else if (ScalarTy->isDoubleTy())
+      Sem = &APFloat::IEEEdouble();
+  }
+  return APFloat::getZero(*Sem);
+}
+
+/// Return the register of a floating-point zero constant of type \p ResType
+/// (a scalar or a vector). The zero is built with getZeroFP from ResType's LLVM
+/// type, so its semantics follow the element type (single precision is the
+/// fallback). Zero is requested as a null constant only when
+/// STI.isOpenCLEnv() is true.
+Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
+ MachineInstr &I) const {
+ // OpenCL uses nulls for Zero. In HLSL we don't use null constants.
+ bool ZeroAsNull = STI.isOpenCLEnv();
+ APFloat VZero = getZeroFP(GR.getTypeForSPIRVType(ResType));
if (ResType->getOpcode() == SPIRV::OpTypeVector)
- return GR.getOrCreateConsIntVector(0, I, ResType, TII);
- return GR.getOrCreateConstInt(0, I, ResType, TII);
+ return GR.getOrCreateConstVector(VZero, I, ResType, TII, ZeroAsNull);
+ return GR.getOrCreateConstFP(VZero, I, ResType, TII, ZeroAsNull);
}
Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
@@ -1403,7 +1493,7 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
APInt One =
AllOnes ? APInt::getAllOnes(BitWidth) : APInt::getOneBitSet(BitWidth, 0);
if (ResType->getOpcode() == SPIRV::OpTypeVector)
- return GR.getOrCreateConsIntVector(One.getZExtValue(), I, ResType, TII);
+ return GR.getOrCreateConstVector(One.getZExtValue(), I, ResType, TII);
return GR.getOrCreateConstInt(One.getZExtValue(), I, ResType, TII);
}
@@ -1785,6 +1875,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
break;
case Intrinsic::spv_thread_id:
return selectSpvThreadId(ResVReg, ResType, I);
+ case Intrinsic::spv_all:
+ return selectAll(ResVReg, ResType, I);
case Intrinsic::spv_lifetime_start:
case Intrinsic::spv_lifetime_end: {
unsigned Op = IID == Intrinsic::spv_lifetime_start ? SPIRV::OpLifetimeStart
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll
new file mode 100644
index 00000000000000..ef8d463cbd815e
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll
@@ -0,0 +1,187 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-HLSL
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-OCL
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; Make sure the SPIR-V instructions for the `all` intrinsic are generated.
+; The Logical (HLSL) run expects zeros as OpConstant/OpConstantComposite; the
+; OpenCL run expects them as OpConstantNull.
+; Note: [[#X:]] (re)defines X, while [[#X]] checks a use of X.
+
+; CHECK-HLSL-DAG: OpMemoryModel Logical GLSL450
+; CHECK-OCL-DAG: OpMemoryModel Physical32 OpenCL
+; CHECK-DAG: OpName %[[#all_bool_arg:]] "a"
+; CHECK-DAG: %[[#int_64:]] = OpTypeInt 64 0
+; CHECK-DAG: %[[#bool:]] = OpTypeBool
+; CHECK-DAG: %[[#int_32:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#int_16:]] = OpTypeInt 16 0
+; CHECK-DAG: %[[#float_64:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#vec4_bool:]] = OpTypeVector %[[#bool]] 4
+; CHECK-DAG: %[[#vec4_16:]] = OpTypeVector %[[#int_16]] 4
+; CHECK-DAG: %[[#vec4_32:]] = OpTypeVector %[[#int_32]] 4
+; CHECK-DAG: %[[#vec4_64:]] = OpTypeVector %[[#int_64]] 4
+; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4
+; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4
+; CHECK-DAG: %[[#vec4_float_64:]] = OpTypeVector %[[#float_64]] 4
+
+; CHECK-HLSL-DAG: %[[#const_i64_0:]] = OpConstant %[[#int_64]] 0
+; CHECK-HLSL-DAG: %[[#const_i32_0:]] = OpConstant %[[#int_32]] 0
+; CHECK-HLSL-DAG: %[[#const_i16_0:]] = OpConstant %[[#int_16]] 0
+; CHECK-HLSL-DAG: %[[#const_f64_0:]] = OpConstant %[[#float_64]] 0
+; CHECK-HLSL-DAG: %[[#const_f32_0:]] = OpConstant %[[#float_32]] 0
+; CHECK-HLSL-DAG: %[[#const_f16_0:]] = OpConstant %[[#float_16]] 0
+; CHECK-HLSL-DAG: %[[#vec4_const_zeros_i16:]] = OpConstantComposite %[[#vec4_16]] %[[#const_i16_0]] %[[#const_i16_0]] %[[#const_i16_0]] %[[#const_i16_0]]
+; CHECK-HLSL-DAG: %[[#vec4_const_zeros_i32:]] = OpConstantComposite %[[#vec4_32]] %[[#const_i32_0]] %[[#const_i32_0]] %[[#const_i32_0]] %[[#const_i32_0]]
+; CHECK-HLSL-DAG: %[[#vec4_const_zeros_i64:]] = OpConstantComposite %[[#vec4_64]] %[[#const_i64_0]] %[[#const_i64_0]] %[[#const_i64_0]] %[[#const_i64_0]]
+; CHECK-HLSL-DAG: %[[#vec4_const_zeros_f16:]] = OpConstantComposite %[[#vec4_float_16]] %[[#const_f16_0]] %[[#const_f16_0]] %[[#const_f16_0]] %[[#const_f16_0]]
+; CHECK-HLSL-DAG: %[[#vec4_const_zeros_f32:]] = OpConstantComposite %[[#vec4_float_32]] %[[#const_f32_0]] %[[#const_f32_0]] %[[#const_f32_0]] %[[#const_f32_0]]
+; CHECK-HLSL-DAG: %[[#vec4_const_zeros_f64:]] = OpConstantComposite %[[#vec4_float_64]] %[[#const_f64_0]] %[[#const_f64_0]] %[[#const_f64_0]] %[[#const_f64_0]]
+
+; CHECK-OCL-DAG: %[[#const_i64_0:]] = OpConstantNull %[[#int_64]]
+; CHECK-OCL-DAG: %[[#const_i32_0:]] = OpConstantNull %[[#int_32]]
+; CHECK-OCL-DAG: %[[#const_i16_0:]] = OpConstantNull %[[#int_16]]
+; CHECK-OCL-DAG: %[[#const_f64_0:]] = OpConstantNull %[[#float_64]]
+; CHECK-OCL-DAG: %[[#const_f32_0:]] = OpConstantNull %[[#float_32]]
+; CHECK-OCL-DAG: %[[#const_f16_0:]] = OpConstantNull %[[#float_16]]
+; CHECK-OCL-DAG: %[[#vec4_const_zeros_i16:]] = OpConstantNull %[[#vec4_16]]
+; CHECK-OCL-DAG: %[[#vec4_const_zeros_i32:]] = OpConstantNull %[[#vec4_32]]
+; CHECK-OCL-DAG: %[[#vec4_const_zeros_i64:]] = OpConstantNull %[[#vec4_64]]
+; CHECK-OCL-DAG: %[[#vec4_const_zeros_f16:]] = OpConstantNull %[[#vec4_float_16]]
+; CHECK-OCL-DAG: %[[#vec4_const_zeros_f32:]] = OpConstantNull %[[#vec4_float_32]]
+; CHECK-OCL-DAG: %[[#vec4_const_zeros_f64:]] = OpConstantNull %[[#vec4_float_64]]
+
+; Operands below are written as uses ([[#X]]), not re-definitions ([[#X:]]),
+; so each check actually ties the instruction to the expected type/constant.
+
+define noundef i1 @all_int64_t(i64 noundef %p0) {
+entry:
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#]]
+  ; CHECK: %[[#]] = OpINotEqual %[[#bool]] %[[#arg0]] %[[#const_i64_0]]
+  %hlsl.all = call i1 @llvm.spv.all.i64(i64 %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_int(i32 noundef %p0) {
+entry:
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#]]
+  ; CHECK: %[[#]] = OpINotEqual %[[#bool]] %[[#arg0]] %[[#const_i32_0]]
+  %hlsl.all = call i1 @llvm.spv.all.i32(i32 %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_int16_t(i16 noundef %p0) {
+entry:
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#]]
+  ; CHECK: %[[#]] = OpINotEqual %[[#bool]] %[[#arg0]] %[[#const_i16_0]]
+  %hlsl.all = call i1 @llvm.spv.all.i16(i16 %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_double(double noundef %p0) {
+entry:
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#]]
+  ; CHECK: %[[#]] = OpFOrdNotEqual %[[#bool]] %[[#arg0]] %[[#const_f64_0]]
+  %hlsl.all = call i1 @llvm.spv.all.f64(double %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_float(float noundef %p0) {
+entry:
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#]]
+  ; CHECK: %[[#]] = OpFOrdNotEqual %[[#bool]] %[[#arg0]] %[[#const_f32_0]]
+  %hlsl.all = call i1 @llvm.spv.all.f32(float %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_half(half noundef %p0) {
+entry:
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#]]
+  ; CHECK: %[[#]] = OpFOrdNotEqual %[[#bool]] %[[#arg0]] %[[#const_f16_0]]
+  %hlsl.all = call i1 @llvm.spv.all.f16(half %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_bool4(<4 x i1> noundef %p0) {
+entry:
+  ; A bool vector needs no comparison; OpAll's result type is a scalar bool.
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#]]
+  ; CHECK: %[[#]] = OpAll %[[#bool]] %[[#arg0]]
+  %hlsl.all = call i1 @llvm.spv.all.v4i1(<4 x i1> %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_short4(<4 x i16> noundef %p0) {
+entry:
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#]]
+  ; CHECK: %[[#shortVecNotEq:]] = OpINotEqual %[[#vec4_bool]] %[[#arg0]] %[[#vec4_const_zeros_i16]]
+  ; CHECK: %[[#]] = OpAll %[[#bool]] %[[#shortVecNotEq]]
+  %hlsl.all = call i1 @llvm.spv.all.v4i16(<4 x i16> %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_int4(<4 x i32> noundef %p0) {
+entry:
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#]]
+  ; CHECK: %[[#i32VecNotEq:]] = OpINotEqual %[[#vec4_bool]] %[[#arg0]] %[[#vec4_const_zeros_i32]]
+  ; CHECK: %[[#]] = OpAll %[[#bool]] %[[#i32VecNotEq]]
+  %hlsl.all = call i1 @llvm.spv.all.v4i32(<4 x i32> %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_int64_t4(<4 x i64> noundef %p0) {
+entry:
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#]]
+  ; CHECK: %[[#i64VecNotEq:]] = OpINotEqual %[[#vec4_bool]] %[[#arg0]] %[[#vec4_const_zeros_i64]]
+  ; CHECK: %[[#]] = OpAll %[[#bool]] %[[#i64VecNotEq]]
+  %hlsl.all = call i1 @llvm.spv.all.v4i64(<4 x i64> %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_half4(<4 x half> noundef %p0) {
+entry:
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#]]
+  ; CHECK: %[[#f16VecNotEq:]] = OpFOrdNotEqual %[[#vec4_bool]] %[[#arg0]] %[[#vec4_const_zeros_f16]]
+  ; CHECK: %[[#]] = OpAll %[[#bool]] %[[#f16VecNotEq]]
+  %hlsl.all = call i1 @llvm.spv.all.v4f16(<4 x half> %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_float4(<4 x float> noundef %p0) {
+entry:
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#]]
+  ; CHECK: %[[#f32VecNotEq:]] = OpFOrdNotEqual %[[#vec4_bool]] %[[#arg0]] %[[#vec4_const_zeros_f32]]
+  ; CHECK: %[[#]] = OpAll %[[#bool]] %[[#f32VecNotEq]]
+  %hlsl.all = call i1 @llvm.spv.all.v4f32(<4 x float> %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_double4(<4 x double> noundef %p0) {
+entry:
+  ; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#]]
+  ; CHECK: %[[#f64VecNotEq:]] = OpFOrdNotEqual %[[#vec4_bool]] %[[#arg0]] %[[#vec4_const_zeros_f64]]
+  ; CHECK: %[[#]] = OpAll %[[#bool]] %[[#f64VecNotEq]]
+  %hlsl.all = call i1 @llvm.spv.all.v4f64(<4 x double> %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_bool(i1 noundef %a) {
+entry:
+  ; A scalar bool is returned as-is: the parameter named "a" is the result.
+  ; CHECK: %[[#all_bool_arg]] = OpFunctionParameter %[[#bool]]
+  ; CHECK: OpReturnValue %[[#all_bool_arg]]
+  %hlsl.all = call i1 @llvm.spv.all.i1(i1 %a)
+  ret i1 %hlsl.all
+}
+
+declare i1 @llvm.spv.all.i64(i64)
+declare i1 @llvm.spv.all.i32(i32)
+declare i1 @llvm.spv.all.i16(i16)
+declare i1 @llvm.spv.all.f64(double)
+declare i1 @llvm.spv.all.f32(float)
+declare i1 @llvm.spv.all.f16(half)
+declare i1 @llvm.spv.all.v4i1(<4 x i1>)
+declare i1 @llvm.spv.all.v4i16(<4 x i16>)
+declare i1 @llvm.spv.all.v4i32(<4 x i32>)
+declare i1 @llvm.spv.all.v4i64(<4 x i64>)
+declare i1 @llvm.spv.all.v4f16(<4 x half>)
+declare i1 @llvm.spv.all.v4f32(<4 x float>)
+declare i1 @llvm.spv.all.v4f64(<4 x double>)
+declare i1 @llvm.spv.all.i1(i1)
More information about the llvm-commits
mailing list