[llvm] [Spirv][HLSL] Add OpAll lowering and float vec support (PR #87952)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 8 07:16:58 PDT 2024


https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/87952

>From 7c5ec09a0ac0046d637e4ed2cba8f5e180a92739 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Fri, 5 Apr 2024 12:38:52 -0400
Subject: [PATCH 1/8] make changes to allow OpConstantComposite to work on
 floats and OpConstantNull to be configurable based on the environment.

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 159 +++++++++++++++---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  33 +++-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  16 +-
 3 files changed, 172 insertions(+), 36 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 9592f3e81b4026..a3f4eb8c9582a8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -20,7 +20,9 @@
 #include "SPIRVSubtarget.h"
 #include "SPIRVTargetMachine.h"
 #include "SPIRVUtils.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/TypedPointerType.h"
+#include <cassert>
 
 using namespace llvm;
 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
@@ -164,9 +166,37 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
   return std::make_tuple(Res, CI, NewInstr);
 }
 
+Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
+                                                 SPIRVType *SpvType,
+                                                 const SPIRVInstrInfo &TII) {
+  assert(SpvType);
+  auto *MF = I.getMF();
+  assert(MF);
+  LLVMContext &Ctx = MF->getFunction().getContext();
+  // Find a constant in DT or build a new one.
+  auto *const ConstFP = ConstantFP::get(Ctx, Val);
+  Register Res = DT.find(ConstFP, MF);
+  if (!Res.isValid()) {
+    Res = MF->getRegInfo().createGenericVirtualRegister(LLT::scalar(32));
+    MF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
+    assignSPIRVTypeToVReg(SpvType, Res, *MF);
+    DT.add(ConstFP, MF, Res);
+
+    MachineInstrBuilder MIB;
+    MachineBasicBlock &BB = *I.getParent();
+    MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantF))
+              .addDef(Res)
+              .addUse(getSPIRVTypeID(SpvType));
+    addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB);
+  }
+
+  return Res;
+}
+
 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
                                                   SPIRVType *SpvType,
-                                                  const SPIRVInstrInfo &TII) {
+                                                  const SPIRVInstrInfo &TII,
+                                                  bool ZeroAsNull) {
   assert(SpvType);
   ConstantInt *CI;
   Register Res;
@@ -179,7 +209,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
     return Res;
   MachineInstrBuilder MIB;
   MachineBasicBlock &BB = *I.getParent();
-  if (Val) {
+  if (Val || !ZeroAsNull) {
     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
               .addDef(Res)
               .addUse(getSPIRVTypeID(SpvType));
@@ -269,21 +299,69 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
 
   return Res;
 }
+template <>
+Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
+    ConstantFP *Val, MachineInstr &I, SPIRVType *SpvType,
+    const SPIRVInstrInfo &TII, ConstantFP *CA, unsigned BitWidth,
+    unsigned ElemCnt, bool ZeroAsNull) {
+  // Find a constant vector in DT or build a new one.
+  Register Res = DT.find(CA, CurMF);
+  bool IsNull = Val->isNullValue() && ZeroAsNull;
+  if (!Res.isValid()) {
+    SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
+    // SpvScalConst should be created before SpvVecConst to avoid undefined ID
+    // error on validation.
+    // TODO: can moved below once sorting of types/consts/defs is implemented.
+    Register SpvScalConst;
+    if (!IsNull)
+      SpvScalConst = getOrCreateConstFP(Val->getValue(), I, SpvBaseType, TII);
+    // TODO: maybe use bitwidth of base type.
+    LLT LLTy = LLT::scalar(32);
+    Register SpvVecConst =
+        CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+    CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
+    assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
+    DT.add(CA, CurMF, SpvVecConst);
+    MachineInstrBuilder MIB;
+    MachineBasicBlock &BB = *I.getParent();
+    if (!IsNull) {
+      MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
+                .addDef(SpvVecConst)
+                .addUse(getSPIRVTypeID(SpvType));
+      for (unsigned i = 0; i < ElemCnt; ++i)
+        MIB.addUse(SpvScalConst);
+    } else {
+      MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
+                .addDef(SpvVecConst)
+                .addUse(getSPIRVTypeID(SpvType));
+    }
+    const auto &Subtarget = CurMF->getSubtarget();
+    constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
+                                     *Subtarget.getRegisterInfo(),
+                                     *Subtarget.getRegBankInfo());
+    return SpvVecConst;
+  }
+  return Res;
+}
 
-Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
-    uint64_t Val, MachineInstr &I, SPIRVType *SpvType,
+template <>
+Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
+    Constant *Val, MachineInstr &I, SPIRVType *SpvType,
     const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
-    unsigned ElemCnt) {
+    unsigned ElemCnt, bool ZeroAsNull) {
   // Find a constant vector in DT or build a new one.
   Register Res = DT.find(CA, CurMF);
+  // If no values are attached, the composite is null constant.
+  bool IsNull = Val->isNullValue() && ZeroAsNull;
   if (!Res.isValid()) {
     SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
     // SpvScalConst should be created before SpvVecConst to avoid undefined ID
     // error on validation.
     // TODO: can moved below once sorting of types/consts/defs is implemented.
     Register SpvScalConst;
-    if (Val)
-      SpvScalConst = getOrCreateConstInt(Val, I, SpvBaseType, TII);
+    if (!IsNull)
+      SpvScalConst = getOrCreateConstInt(Val->getUniqueInteger().getSExtValue(),
+                                         I, SpvBaseType, TII);
     // TODO: maybe use bitwidth of base type.
     LLT LLTy = LLT::scalar(32);
     Register SpvVecConst =
@@ -293,7 +371,7 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
     DT.add(CA, CurMF, SpvVecConst);
     MachineInstrBuilder MIB;
     MachineBasicBlock &BB = *I.getParent();
-    if (Val) {
+    if (!IsNull) {
       MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
                 .addDef(SpvVecConst)
                 .addUse(getSPIRVTypeID(SpvType));
@@ -313,20 +391,42 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
   return Res;
 }
 
-Register
-SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
-                                              SPIRVType *SpvType,
-                                              const SPIRVInstrInfo &TII) {
+Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
+                                                     MachineInstr &I,
+                                                     SPIRVType *SpvType,
+                                                     const SPIRVInstrInfo &TII,
+                                                     bool ZeroAsNull) {
   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
   assert(LLVMTy->isVectorTy());
   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
   Type *LLVMBaseTy = LLVMVecTy->getElementType();
-  const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
-  auto ConstVec =
-      ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
+  assert(LLVMBaseTy->isIntegerTy());
+  auto *ConstVal = ConstantInt::get(LLVMBaseTy, Val);
+  auto *ConstVec =
+      ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
   unsigned BW = getScalarOrVectorBitWidth(SpvType);
-  return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstVec, BW,
-                                       SpvType->getOperand(2).getImm());
+  return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
+                                    SpvType->getOperand(2).getImm(),
+                                    ZeroAsNull);
+}
+
+Register SPIRVGlobalRegistry::getOrCreateConstVector(double Val,
+                                                     MachineInstr &I,
+                                                     SPIRVType *SpvType,
+                                                     const SPIRVInstrInfo &TII,
+                                                     bool ZeroAsNull) {
+  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
+  assert(LLVMTy->isVectorTy());
+  const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
+  Type *LLVMBaseTy = LLVMVecTy->getElementType();
+  assert(LLVMBaseTy->isFloatingPointTy());
+  auto *ConstVal = ConstantFP::get(LLVMBaseTy, Val);
+  auto *ConstVec =
+      ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
+  unsigned BW = getScalarOrVectorBitWidth(SpvType);
+  return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
+                                    SpvType->getOperand(2).getImm(),
+                                    ZeroAsNull);
 }
 
 Register
@@ -337,13 +437,13 @@ SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
   assert(LLVMTy->isArrayTy());
   const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
   Type *LLVMBaseTy = LLVMArrTy->getElementType();
-  const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
-  auto ConstArr =
+  auto *ConstInt = ConstantInt::get(LLVMBaseTy, Val);
+  auto *ConstArr =
       ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
   SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
   unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
-  return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstArr, BW,
-                                       LLVMArrTy->getNumElements());
+  return getOrCreateCompositeOrNull(ConstInt, I, SpvType, TII, ConstArr, BW,
+                                    LLVMArrTy->getNumElements());
 }
 
 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
@@ -1093,14 +1193,16 @@ SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
   return SpirvType;
 }
 
-SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
-    unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
+                                                     MachineInstr &I,
+                                                     const SPIRVInstrInfo &TII,
+                                                     unsigned SPIRVOPcode) {
   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
   Register Reg = DT.find(LLVMTy, CurMF);
   if (Reg.isValid())
     return getSPIRVTypeForVReg(Reg);
   MachineBasicBlock &BB = *I.getParent();
-  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt))
+  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRVOPcode))
                  .addDef(createTypeVReg(CurMF->getRegInfo()))
                  .addImm(BitWidth)
                  .addImm(0);
@@ -1108,6 +1210,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
   return finishCreatingSPIRVType(LLVMTy, MIB);
 }
 
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
+    unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
+  return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt);
+}
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
+    unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
+  return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat);
+}
+
 SPIRVType *
 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
   return getOrCreateSPIRVType(
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index ac799374adce8c..b691cd917d9140 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -368,11 +368,14 @@ class SPIRVGlobalRegistry {
       uint64_t Val, SPIRVType *SpvType, MachineIRBuilder *MIRBuilder,
       MachineInstr *I = nullptr, const SPIRVInstrInfo *TII = nullptr);
   SPIRVType *finishCreatingSPIRVType(const Type *LLVMTy, SPIRVType *SpirvType);
-  Register getOrCreateIntCompositeOrNull(uint64_t Val, MachineInstr &I,
-                                         SPIRVType *SpvType,
-                                         const SPIRVInstrInfo &TII,
-                                         Constant *CA, unsigned BitWidth,
-                                         unsigned ElemCnt);
+
+  template <typename T>
+  Register getOrCreateCompositeOrNull(T *Val, MachineInstr &I,
+                                      SPIRVType *SpvType,
+                                      const SPIRVInstrInfo &TII, T *CA,
+                                      unsigned BitWidth, unsigned ElemCnt,
+                                      bool ZeroAsNull = true);
+
   Register getOrCreateIntCompositeOrNull(uint64_t Val,
                                          MachineIRBuilder &MIRBuilder,
                                          SPIRVType *SpvType, bool EmitIR,
@@ -383,12 +386,19 @@ class SPIRVGlobalRegistry {
   Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
                             SPIRVType *SpvType = nullptr, bool EmitIR = true);
   Register getOrCreateConstInt(uint64_t Val, MachineInstr &I,
-                               SPIRVType *SpvType, const SPIRVInstrInfo &TII);
+                               SPIRVType *SpvType, const SPIRVInstrInfo &TII,
+                               bool ZeroAsNull = true);
+  Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType,
+                              const SPIRVInstrInfo &TII);
   Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder,
                            SPIRVType *SpvType = nullptr);
-  Register getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
-                                    SPIRVType *SpvType,
-                                    const SPIRVInstrInfo &TII);
+
+  Register getOrCreateConstVector(uint64_t Val, MachineInstr &I,
+                                  SPIRVType *SpvType, const SPIRVInstrInfo &TII,
+                                  bool ZeroAsNull = true);
+  Register getOrCreateConstVector(double Val, MachineInstr &I,
+                                  SPIRVType *SpvType, const SPIRVInstrInfo &TII,
+                                  bool ZeroAsNull = true);
   Register getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
                                    SPIRVType *SpvType,
                                    const SPIRVInstrInfo &TII);
@@ -418,6 +428,11 @@ class SPIRVGlobalRegistry {
                                          MachineIRBuilder &MIRBuilder);
   SPIRVType *getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineInstr &I,
                                          const SPIRVInstrInfo &TII);
+  SPIRVType *getOrCreateSPIRVType(unsigned BitWidth, MachineInstr &I,
+                                  const SPIRVInstrInfo &TII,
+                                  unsigned SPIRVOPcode);
+  SPIRVType *getOrCreateSPIRVFloatType(unsigned BitWidth, MachineInstr &I,
+                                       const SPIRVInstrInfo &TII);
   SPIRVType *getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder);
   SPIRVType *getOrCreateSPIRVBoolType(MachineInstr &I,
                                       const SPIRVInstrInfo &TII);
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 49749b56345306..cedded1c3f884d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -229,6 +229,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
                             const SPIRVType *ResType = nullptr) const;
 
   Register buildZerosVal(const SPIRVType *ResType, MachineInstr &I) const;
+  Register buildZerosValF(const SPIRVType *ResType, MachineInstr &I) const;
   Register buildOnesVal(bool AllOnes, const SPIRVType *ResType,
                         MachineInstr &I) const;
 
@@ -1391,9 +1392,18 @@ bool SPIRVInstructionSelector::selectFCmp(Register ResVReg,
 
 Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
                                                  MachineInstr &I) const {
+  // OpenCL uses nulls for Zero. In HLSL we don't use null constants.
+  bool ZeroAsNull = STI.isOpenCLEnv();
   if (ResType->getOpcode() == SPIRV::OpTypeVector)
-    return GR.getOrCreateConsIntVector(0, I, ResType, TII);
-  return GR.getOrCreateConstInt(0, I, ResType, TII);
+    return GR.getOrCreateConstVector(0UL, I, ResType, TII, ZeroAsNull);
+  return GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
+}
+
+Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
+                                                  MachineInstr &I) const {
+  if (ResType->getOpcode() == SPIRV::OpTypeVector)
+    return GR.getOrCreateConstVector(0.0, I, ResType, TII);
+  return GR.getOrCreateConstFP(APFloat(0.0), I, ResType, TII);
 }
 
 Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
@@ -1403,7 +1413,7 @@ Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
   APInt One =
       AllOnes ? APInt::getAllOnes(BitWidth) : APInt::getOneBitSet(BitWidth, 0);
   if (ResType->getOpcode() == SPIRV::OpTypeVector)
-    return GR.getOrCreateConsIntVector(One.getZExtValue(), I, ResType, TII);
+    return GR.getOrCreateConstVector(One.getZExtValue(), I, ResType, TII);
   return GR.getOrCreateConstInt(One.getZExtValue(), I, ResType, TII);
 }
 

>From cb62fee15c1e7a207bb194674fe31dde114f01e1 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Fri, 5 Apr 2024 20:32:30 -0400
Subject: [PATCH 2/8] add all lowering; one issue with the validator remains

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 109 +++++++----
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  27 ++-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  61 +++++-
 .../test/CodeGen/SPIRV/hlsl-intrinsics/all.ll | 173 ++++++++++++++++++
 .../CodeGen/SPIRV/hlsl-intrinsics/all_next.ll |  17 ++
 5 files changed, 347 insertions(+), 40 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all_next.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index a3f4eb8c9582a8..6360ce8b004de5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -22,6 +22,7 @@
 #include "SPIRVUtils.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/TypedPointerType.h"
+#include "llvm/Support/Casting.h"
 #include <cassert>
 
 using namespace llvm;
@@ -37,6 +38,15 @@ SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
   return SpirvType;
 }
 
+SPIRVType *
+SPIRVGlobalRegistry::assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
+                                           MachineInstr &I,
+                                           const SPIRVInstrInfo &TII) {
+  SPIRVType *SpirvType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
+  assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
+  return SpirvType;
+}
+
 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
     SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
     const SPIRVInstrInfo &TII) {
@@ -166,28 +176,65 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
   return std::make_tuple(Res, CI, NewInstr);
 }
 
-Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
-                                                 SPIRVType *SpvType,
-                                                 const SPIRVInstrInfo &TII) {
-  assert(SpvType);
-  auto *MF = I.getMF();
-  assert(MF);
-  LLVMContext &Ctx = MF->getFunction().getContext();
+std::tuple<Register, ConstantFP *, bool>
+SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
+                                              MachineIRBuilder *MIRBuilder,
+                                              MachineInstr *I,
+                                              const SPIRVInstrInfo *TII) {
+  const Type *LLVMFloatTy;
+  LLVMContext &Ctx = CurMF->getFunction().getContext();
+  if (SpvType)
+    LLVMFloatTy = getTypeForSPIRVType(SpvType);
+  else {
+    LLVMFloatTy = Type::getFloatTy(Ctx);
+    if (MIRBuilder)
+      SpvType = getOrCreateSPIRVType(LLVMFloatTy, *MIRBuilder);
+  }
+  bool NewInstr = false;
   // Find a constant in DT or build a new one.
-  auto *const ConstFP = ConstantFP::get(Ctx, Val);
-  Register Res = DT.find(ConstFP, MF);
+  auto *const CI = ConstantFP::get(Ctx, Val);
+  Register Res = DT.find(CI, CurMF);
   if (!Res.isValid()) {
-    Res = MF->getRegInfo().createGenericVirtualRegister(LLT::scalar(32));
-    MF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
-    assignSPIRVTypeToVReg(SpvType, Res, *MF);
-    DT.add(ConstFP, MF, Res);
+    unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
+    LLT LLTy = LLT::scalar(32);
+    Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+    CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
+    if (MIRBuilder)
+      assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder);
+    else
+      assignFloatTypeToVReg(BitWidth, Res, *I, *TII);
+    DT.add(CI, CurMF, Res);
+    NewInstr = true;
+  }
+  return std::make_tuple(Res, CI, NewInstr);
+}
 
-    MachineInstrBuilder MIB;
-    MachineBasicBlock &BB = *I.getParent();
+Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
+                                                 SPIRVType *SpvType,
+                                                 const SPIRVInstrInfo &TII,
+                                                 bool ZeroAsNull) {
+  assert(SpvType);
+  ConstantFP *CI;
+  Register Res;
+  bool New;
+  std::tie(Res, CI, New) =
+      getOrCreateConstFloatReg(Val, SpvType, nullptr, &I, &TII);
+  // If we have found Res register which is defined by the passed G_CONSTANT
+  // machine instruction, a new constant instruction should be created.
+  if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
+    return Res;
+  MachineInstrBuilder MIB;
+  MachineBasicBlock &BB = *I.getParent();
+  // In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
+  if (Val.isPosZero() && ZeroAsNull) {
+    MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
+              .addDef(Res)
+              .addUse(getSPIRVTypeID(SpvType));
+  } else {
     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantF))
               .addDef(Res)
               .addUse(getSPIRVTypeID(SpvType));
-    addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB);
+    addNumImm(CI->getValueAPF().bitcastToAPInt(), MIB);
   }
 
   return Res;
@@ -299,10 +346,10 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
 
   return Res;
 }
-template <>
-Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
-    ConstantFP *Val, MachineInstr &I, SPIRVType *SpvType,
-    const SPIRVInstrInfo &TII, ConstantFP *CA, unsigned BitWidth,
+
+Register SPIRVGlobalRegistry::getOrCreateFloatCompositeOrNull(
+    Constant *Val, MachineInstr &I, SPIRVType *SpvType,
+    const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
     unsigned ElemCnt, bool ZeroAsNull) {
   // Find a constant vector in DT or build a new one.
   Register Res = DT.find(CA, CurMF);
@@ -314,7 +361,8 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
     // TODO: can moved below once sorting of types/consts/defs is implemented.
     Register SpvScalConst;
     if (!IsNull)
-      SpvScalConst = getOrCreateConstFP(Val->getValue(), I, SpvBaseType, TII);
+      SpvScalConst = getOrCreateConstFP(dyn_cast<ConstantFP>(Val)->getValue(),
+                                        I, SpvBaseType, TII, ZeroAsNull);
     // TODO: maybe use bitwidth of base type.
     LLT LLTy = LLT::scalar(32);
     Register SpvVecConst =
@@ -344,8 +392,7 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
   return Res;
 }
 
-template <>
-Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
+Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
     Constant *Val, MachineInstr &I, SPIRVType *SpvType,
     const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
     unsigned ElemCnt, bool ZeroAsNull) {
@@ -405,9 +452,9 @@ Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
   auto *ConstVec =
       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
   unsigned BW = getScalarOrVectorBitWidth(SpvType);
-  return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
-                                    SpvType->getOperand(2).getImm(),
-                                    ZeroAsNull);
+  return getOrCreateIntCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
+                                       SpvType->getOperand(2).getImm(),
+                                       ZeroAsNull);
 }
 
 Register SPIRVGlobalRegistry::getOrCreateConstVector(double Val,
@@ -424,9 +471,9 @@ Register SPIRVGlobalRegistry::getOrCreateConstVector(double Val,
   auto *ConstVec =
       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstVal);
   unsigned BW = getScalarOrVectorBitWidth(SpvType);
-  return getOrCreateCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec, BW,
-                                    SpvType->getOperand(2).getImm(),
-                                    ZeroAsNull);
+  return getOrCreateFloatCompositeOrNull(ConstVal, I, SpvType, TII, ConstVec,
+                                         BW, SpvType->getOperand(2).getImm(),
+                                         ZeroAsNull);
 }
 
 Register
@@ -442,8 +489,8 @@ SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
       ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
   SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
   unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
-  return getOrCreateCompositeOrNull(ConstInt, I, SpvType, TII, ConstArr, BW,
-                                    LLVMArrTy->getNumElements());
+  return getOrCreateIntCompositeOrNull(ConstInt, I, SpvType, TII, ConstArr, BW,
+                                       LLVMArrTy->getNumElements());
 }
 
 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index b691cd917d9140..4a9670b80ed42f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -20,6 +20,7 @@
 #include "SPIRVDuplicatesTracker.h"
 #include "SPIRVInstrInfo.h"
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+#include "llvm/IR/Constant.h"
 
 namespace llvm {
 using SPIRVType = const MachineInstr;
@@ -234,6 +235,8 @@ class SPIRVGlobalRegistry {
                               bool EmitIR = true);
   SPIRVType *assignIntTypeToVReg(unsigned BitWidth, Register VReg,
                                  MachineInstr &I, const SPIRVInstrInfo &TII);
+  SPIRVType *assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
+                                   MachineInstr &I, const SPIRVInstrInfo &TII);
   SPIRVType *assignVectTypeToVReg(SPIRVType *BaseType, unsigned NumElements,
                                   Register VReg, MachineInstr &I,
                                   const SPIRVInstrInfo &TII);
@@ -367,14 +370,24 @@ class SPIRVGlobalRegistry {
   std::tuple<Register, ConstantInt *, bool> getOrCreateConstIntReg(
       uint64_t Val, SPIRVType *SpvType, MachineIRBuilder *MIRBuilder,
       MachineInstr *I = nullptr, const SPIRVInstrInfo *TII = nullptr);
+  std::tuple<Register, ConstantFP *, bool> getOrCreateConstFloatReg(
+      APFloat Val, SPIRVType *SpvType, MachineIRBuilder *MIRBuilder,
+      MachineInstr *I = nullptr, const SPIRVInstrInfo *TII = nullptr);
   SPIRVType *finishCreatingSPIRVType(const Type *LLVMTy, SPIRVType *SpirvType);
 
-  template <typename T>
-  Register getOrCreateCompositeOrNull(T *Val, MachineInstr &I,
-                                      SPIRVType *SpvType,
-                                      const SPIRVInstrInfo &TII, T *CA,
-                                      unsigned BitWidth, unsigned ElemCnt,
-                                      bool ZeroAsNull = true);
+  Register getOrCreateIntCompositeOrNull(Constant *Val, MachineInstr &I,
+                                         SPIRVType *SpvType,
+                                         const SPIRVInstrInfo &TII,
+                                         Constant *CA, unsigned BitWidth,
+                                         unsigned ElemCnt,
+                                         bool ZeroAsNull = true);
+
+  Register getOrCreateFloatCompositeOrNull(Constant *Val, MachineInstr &I,
+                                           SPIRVType *SpvType,
+                                           const SPIRVInstrInfo &TII,
+                                           Constant *CA, unsigned BitWidth,
+                                           unsigned ElemCnt,
+                                           bool ZeroAsNull = true);
 
   Register getOrCreateIntCompositeOrNull(uint64_t Val,
                                          MachineIRBuilder &MIRBuilder,
@@ -389,7 +402,7 @@ class SPIRVGlobalRegistry {
                                SPIRVType *SpvType, const SPIRVInstrInfo &TII,
                                bool ZeroAsNull = true);
   Register getOrCreateConstFP(APFloat Val, MachineInstr &I, SPIRVType *SpvType,
-                              const SPIRVInstrInfo &TII);
+                              const SPIRVInstrInfo &TII, bool ZeroAsNull);
   Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder,
                            SPIRVType *SpvType = nullptr);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index cedded1c3f884d..b7eea2bdba0bab 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -144,6 +144,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectAddrSpaceCast(Register ResVReg, const SPIRVType *ResType,
                            MachineInstr &I) const;
 
+  bool selectAll(Register ResVReg, const SPIRVType *ResType,
+                 MachineInstr &I) const;
+
   bool selectBitreverse(Register ResVReg, const SPIRVType *ResType,
                         MachineInstr &I) const;
 
@@ -1156,6 +1159,56 @@ static unsigned getBoolCmpOpcode(unsigned PredNum) {
   }
 }
 
+bool SPIRVInstructionSelector::selectAll(Register ResVReg,
+                                         const SPIRVType *ResType,
+                                         MachineInstr &I) const {
+  assert(I.getNumOperands() == 3);
+  assert(I.getOperand(2).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+  Register InputRegister = I.getOperand(2).getReg();
+  SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister);
+
+  if (InputType->getOpcode() == SPIRV::OpTypeBool) {
+    Register LoadReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad))
+        .addDef(ResVReg)
+        .addUse(GR.getSPIRVTypeID(InputType))
+        .addUse(InputRegister);
+
+    return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpStore))
+        .addUse(LoadReg)
+        .addUse(InputRegister);
+  }
+
+  bool IsFloatTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeFloat);
+  unsigned SpirvNotEqualId =
+      IsFloatTy ? SPIRV::OpFOrdNotEqual : SPIRV::OpINotEqual;
+
+  bool IsVectorTy = InputType->getOpcode() == SPIRV::OpTypeVector;
+  Register ConstCompositeZeroReg =
+      IsFloatTy ? buildZerosValF(InputType, I) : buildZerosVal(InputType, I);
+  Register NotEqualReg =
+      IsVectorTy ? MRI->createVirtualRegister(&SPIRV::IDRegClass) : ResVReg;
+  BuildMI(BB, I, I.getDebugLoc(), TII.get(SpirvNotEqualId))
+      .addDef(NotEqualReg)
+      .addUse(GR.getSPIRVTypeID(InputType))
+      .addUse(I.getOperand(2).getReg())
+      .addUse(ConstCompositeZeroReg)
+      .constrainAllUses(TII, TRI, RBI);
+
+  if (!IsVectorTy)
+    return true;
+
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpAll))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(InputType))
+      .addUse(NotEqualReg)
+      .constrainAllUses(TII, TRI, RBI);
+
+  // bool IsSigned = GR.isScalarOrVectorSigned(InputType);
+  // return selectSelect(ResVReg, ResType, OpAllReg, I, IsSigned);
+}
+
 bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
                                                 const SPIRVType *ResType,
                                                 MachineInstr &I) const {
@@ -1401,9 +1454,11 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
 
 Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
                                                   MachineInstr &I) const {
+  // OpenCL uses nulls for Zero. In HLSL we don't use null constants.
+  bool ZeroAsNull = STI.isOpenCLEnv();
   if (ResType->getOpcode() == SPIRV::OpTypeVector)
-    return GR.getOrCreateConstVector(0.0, I, ResType, TII);
-  return GR.getOrCreateConstFP(APFloat(0.0), I, ResType, TII);
+    return GR.getOrCreateConstVector(0.0, I, ResType, TII, ZeroAsNull);
+  return GR.getOrCreateConstFP(APFloat(0.0), I, ResType, TII, ZeroAsNull);
 }
 
 Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
@@ -1795,6 +1850,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     break;
   case Intrinsic::spv_thread_id:
     return selectSpvThreadId(ResVReg, ResType, I);
+  case Intrinsic::spv_all:
+    return selectAll(ResVReg, ResType, I);
   case Intrinsic::spv_lifetime_start:
   case Intrinsic::spv_lifetime_end: {
     unsigned Op = IID == Intrinsic::spv_lifetime_start ? SPIRV::OpLifetimeStart
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll
new file mode 100644
index 00000000000000..b85fafa95d9b21
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll
@@ -0,0 +1,173 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Make sure spirv operation function calls for all are generated.
+
+; CHECK: OpMemoryModel Logical GLSL450
+; CHECK: %[[#int_64:]] = OpTypeInt 64 0
+; CHECK: %[[#bool:]] = OpTypeBool
+; CHECK: %[[#int_32:]] = OpTypeInt 32 0
+; CHECK: %[[#int_16:]] = OpTypeInt 16 0
+; CHECK: %[[#float_64:]] = OpTypeFloat 64
+; CHECK: %[[#float_32:]] = OpTypeFloat 32
+; CHECK: %[[#float_16:]] = OpTypeFloat 16
+; CHECK: %[[#vec_1_4:]] = OpTypeVector %[[#bool]] 4
+; CHECK: %[[#vec_16_4:]] = OpTypeVector %[[#int_16]] 4
+; CHECK: %[[#vec_32_4:]] = OpTypeVector %[[#int_32]] 4
+; CHECK: %[[#vec_64_4:]] = OpTypeVector %[[#int_64]] 4
+; CHECK: %[[#vec_float_16_4:]] = OpTypeVector %[[#float_16]] 4
+; CHECK: %[[#vec_float_32_4:]] = OpTypeVector %[[#float_32]] 4
+; CHECK: %[[#vec_float_64_4:]] = OpTypeVector %[[#float_64]] 4
+
+; CHECK: %[[#const_int_64:]] = OpConstant %[[#int_64]] 0
+; CHECK: %[[#const_int_32:]] = OpConstant %[[#int_32]] 0
+; CHECK: %[[#const_int_16:]] = OpConstant %[[#int_16]] 0
+; CHECK: %[[#const_float_64:]] = OpConstant %[[#float_64]] 0
+; CHECK: %[[#const_bool:]] = OpConstantNull %[[#bool]]
+
+; CHECK: %[[#vec_zero_const_i1_4:]] = OpConstantComposite %[[#vec_1_4:]] %[[#const_bool:]] %[[#const_bool:]] %[[#const_bool:]] %[[#const_bool:]]
+; CHECK: %[[#vec_zero_const_i16_4:]] = OpConstantComposite %[[#vec_16_4:]] %[[#const_int_16:]] %[[#const_int_16:]] %[[#const_int_16:]] %[[#const_int_16:]]
+; CHECK: %[[#vec_zero_const_i32_4:]] = OpConstantComposite %[[#vec_32_4:]] %[[#const_int_32:]] %[[#const_int_32:]] %[[#const_int_32:]] %[[#const_int_32:]]
+; CHECK: %[[#vec_zero_const_i64_4:]] = OpConstantComposite %[[#vec_64_4:]] %[[#const_int_64:]] %[[#const_int_64:]] %[[#const_int_64:]] %[[#const_int_64:]]
+
+; CHECK: %[[#const_float_16:]] = OpConstant %[[#float_16:]] 0
+; CHECK: %[[#vec_zero_const_f16_4:]] = OpConstantComposite %[[#vec_float_16_4:]] %[[#const_float_16:]] %[[#const_float_16:]] %[[#const_float_16:]] %[[#const_float_16:]]
+; CHECK: %[[#const_float_32:]] =  OpConstant %[[#float_32:]] 0
+; CHECK: %[[#vec_zero_const_f32_4:]] = OpConstantComposite %[[#vec_float_32_4:]] %[[#const_float_32:]] %[[#const_float_32:]] %[[#const_float_32:]] %[[#const_float_32:]]
+; CHECK: %[[#vec_zero_const_f64_4:]] = OpConstantComposite %[[#vec_float_64_4:]] %[[#const_float_64:]] %[[#const_float_64:]] %[[#const_float_64:]] %[[#const_float_64:]]
+
+define noundef i1 @all_int64_t(i64 noundef %p0) {
+entry:
+  %p0.addr = alloca i64, align 8
+  store i64 %p0, ptr %p0.addr, align 8
+  %0 = load i64, ptr %p0.addr, align 8
+  ; CHECK: %[[#]] = OpINotEqual %[[#int_64:]] %[[#]] %[[#const_int_64:]]
+  %hlsl.all = call i1 @llvm.spv.all.i64(i64 %0)
+  ret i1 %hlsl.all
+}
+
+
+define noundef i1 @all_int(i32 noundef %p0) {
+entry:
+  %p0.addr = alloca i32, align 4
+  store i32 %p0, ptr %p0.addr, align 4
+  %0 = load i32, ptr %p0.addr, align 4
+  ; CHECK: %[[#]] = OpINotEqual %[[#int_32:]] %[[#]] %[[#const_int_32:]]
+  %hlsl.all = call i1 @llvm.spv.all.i32(i32 %0)
+  ret i1 %hlsl.all
+}
+
+
+define noundef i1 @all_int16_t(i16 noundef %p0) {
+entry:
+  %p0.addr = alloca i16, align 2
+  store i16 %p0, ptr %p0.addr, align 2
+  %0 = load i16, ptr %p0.addr, align 2
+  ; CHECK: %[[#]] = OpINotEqual %[[#int_16:]] %[[#]] %[[#const_int_16:]]
+  %hlsl.all = call i1 @llvm.spv.all.i16(i16 %0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_double(double noundef %p0) {
+entry:
+  %p0.addr = alloca double, align 8
+  store double %p0, ptr %p0.addr, align 8
+  %0 = load double, ptr %p0.addr, align 8
+  ; CHECK: %[[#]] = OpFOrdNotEqual %[[#float_64:]] %[[#]] %[[#const_float_64:]]
+  %hlsl.all = call i1 @llvm.spv.all.f64(double %0)
+  ret i1 %hlsl.all
+}
+
+
+define noundef i1 @all_float(float noundef %p0) {
+entry:
+  %p0.addr = alloca float, align 4
+  store float %p0, ptr %p0.addr, align 4
+  %0 = load float, ptr %p0.addr, align 4
+  ; CHECK: %[[#]] = OpFOrdNotEqual %[[#float_32:]] %[[#]] %[[#const_float_32:]]
+  %hlsl.all = call i1 @llvm.spv.all.f32(float %0)
+  ret i1 %hlsl.all
+}
+
+
+define noundef i1 @all_half(half noundef %p0) {
+entry:
+  %p0.addr = alloca half, align 2
+  store half %p0, ptr %p0.addr, align 2
+  %0 = load half, ptr %p0.addr, align 2
+  ; CHECK: %[[#]] = OpFOrdNotEqual %[[#float_16:]] %[[#]] %[[#const_float_16:]]
+  %hlsl.all = call i1 @llvm.spv.all.f16(half %0)
+  ret i1 %hlsl.all
+}
+
+
+define noundef i1 @all_bool4(<4 x i1> noundef %p0) {
+entry:
+  ; CHECK: %[[#boolVecNotEq:]] = OpINotEqual %[[#vec_1_4:]] %[[#]] %[[#vec_zero_const_i1_4:]]
+  ; CHECK: %[[#]] = OpAll %[[#vec_1_4:]] %[[#boolVecNotEq:]]
+  %hlsl.all = call i1 @llvm.spv.all.v4i1(<4 x i1> %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_short4(<4 x i16> noundef %p0) {
+entry:
+  ; CHECK: %[[#shortVecNotEq:]] = OpINotEqual %[[#vec_16_4:]] %[[#]] %[[#vec_zero_const_i16_4:]]
+  ; CHECK: %[[#]] = OpAll %[[#vec_16_4:]] %[[#shortVecNotEq:]]
+  %hlsl.all = call i1 @llvm.spv.all.v4i16(<4 x i16> %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_int4(<4 x i32> noundef %p0) {
+entry:
+  ; CHECK: %[[#i32VecNotEq:]] = OpINotEqual %[[#vec_32_4:]] %[[#]] %[[#vec_zero_const_i32_4:]]
+  ; CHECK: %[[#]] = OpAll %[[#vec_32_4:]] %[[#i32VecNotEq:]]
+  %hlsl.all = call i1 @llvm.spv.all.v4i32(<4 x i32> %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_int64_t4(<4 x i64> noundef %p0) {
+entry:
+  ; CHECK: %[[#i64VecNotEq:]] = OpINotEqual %[[#vec_64_4:]] %[[#]] %[[#vec_zero_const_i64_4:]]
+  ; CHECK: %[[#]] = OpAll %[[#vec_64_4:]] %[[#i64VecNotEq]]
+  %hlsl.all = call i1 @llvm.spv.all.v4i64(<4 x i64> %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_half4(<4 x half> noundef %p0) {
+entry:
+  ; CHECK: %[[#f16VecNotEq:]] = OpFOrdNotEqual %[[#vec_float_16_4:]] %[[#]] %[[#vec_zero_const_f16_4:]]
+  ; CHECK: %[[#]] = OpAll %[[#vec_float_16_4]] %[[#f16VecNotEq:]]
+  %hlsl.all = call i1 @llvm.spv.all.v4f16(<4 x half> %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_float4(<4 x float> noundef %p0) {
+entry:
+  ; CHECK: %[[#f32VecNotEq:]] = OpFOrdNotEqual %[[#vec_float_32_4:]] %[[#]] %[[#vec_zero_const_f32_4:]]
+  ; CHECK: %[[#]] = OpAll %[[#vec_float_32_4:]] %[[#f32VecNotEq:]]
+  %hlsl.all = call i1 @llvm.spv.all.v4f32(<4 x float> %p0)
+  ret i1 %hlsl.all
+}
+
+define noundef i1 @all_double4(<4 x double> noundef %p0) {
+entry:
+  ; CHECK: %[[#f64VecNotEq:]] = OpFOrdNotEqual %[[#vec_float_64_4:]] %[[#]] %[[#vec_zero_const_f64_4:]]
+  ; CHECK: %[[#]] = OpAll %[[#vec_float_64_4:]] %[[#f64VecNotEq:]]
+  %hlsl.all = call i1 @llvm.spv.all.v4f64(<4 x double> %p0)
+  ret i1 %hlsl.all
+}
+
+declare i1 @llvm.spv.all.v4f16(<4 x half>)
+declare i1 @llvm.spv.all.v4f32(<4 x float>)
+declare i1 @llvm.spv.all.v4f64(<4 x double>)
+declare i1 @llvm.spv.all.v4i1(<4 x i1>)
+declare i1 @llvm.spv.all.v4i16(<4 x i16>)
+declare i1 @llvm.spv.all.v4i32(<4 x i32>)
+declare i1 @llvm.spv.all.v4i64(<4 x i64>)
+declare i1 @llvm.spv.all.i1(i1)
+declare i1 @llvm.spv.all.i16(i16)
+declare i1 @llvm.spv.all.i32(i32)
+declare i1 @llvm.spv.all.i64(i64)
+declare i1 @llvm.spv.all.f16(half)
+declare i1 @llvm.spv.all.f32(float)
+declare i1 @llvm.spv.all.f64(double)
\ No newline at end of file
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all_next.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all_next.ll
new file mode 100644
index 00000000000000..b8c39c7aa5f31f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all_next.ll
@@ -0,0 +1,17 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: OpMemoryModel Logical GLSL450
+; CHECK: OpName %[[#all_bool_arg:]] "a"
+; CHECK: OpName %[[#all_bool_ret:]] "hlsl.all"
+; CHECK: %[[#bool:]] = OpTypeBool
+
+define noundef i1 @all_bool(i1 noundef %a) {
+entry:
+  ; CHECK: %[[#all_bool_arg:]] = OpFunctionParameter %[[#bool:]]
+  ; CHECK: OpLoad %[[#bool:]] %[[#all_bool_arg:]]
+  ; CHECK: OpStore %[[#all_bool_ret:]] %[[#all_bool_arg:]]
+  ; CHECK: OpReturnValue %[[#all_bool_ret:]]
+  %hlsl.all = call i1 @llvm.spv.all.i1(i1 %a)
+  ret i1 %hlsl.all
+}
\ No newline at end of file

>From bbee0ba2eaaa6bd1d7203f105da25d6ac2f19fb7 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Sat, 6 Apr 2024 12:58:31 -0400
Subject: [PATCH 3/8] fix bug that was making float32/float16 registers into
 int32/int16

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 20 +++++++++++--------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  2 +-
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 6360ce8b004de5..ed1474d25c65b6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -355,14 +355,15 @@ Register SPIRVGlobalRegistry::getOrCreateFloatCompositeOrNull(
   Register Res = DT.find(CA, CurMF);
   bool IsNull = Val->isNullValue() && ZeroAsNull;
   if (!Res.isValid()) {
-    SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
     // SpvScalConst should be created before SpvVecConst to avoid undefined ID
     // error on validation.
     // TODO: can moved below once sorting of types/consts/defs is implemented.
     Register SpvScalConst;
-    if (!IsNull)
+    if (!IsNull) {
+      SPIRVType *SpvBaseType = getOrCreateSPIRVFloatType(BitWidth, I, TII);
       SpvScalConst = getOrCreateConstFP(dyn_cast<ConstantFP>(Val)->getValue(),
                                         I, SpvBaseType, TII, ZeroAsNull);
+    }
     // TODO: maybe use bitwidth of base type.
     LLT LLTy = LLT::scalar(32);
     Register SpvVecConst =
@@ -401,14 +402,15 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
   // If no values are attached, the composite is null constant.
   bool IsNull = Val->isNullValue() && ZeroAsNull;
   if (!Res.isValid()) {
-    SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
     // SpvScalConst should be created before SpvVecConst to avoid undefined ID
     // error on validation.
     // TODO: can moved below once sorting of types/consts/defs is implemented.
     Register SpvScalConst;
-    if (!IsNull)
+    if (!IsNull) {
+      SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
       SpvScalConst = getOrCreateConstInt(Val->getUniqueInteger().getSExtValue(),
                                          I, SpvBaseType, TII);
+    }
     // TODO: maybe use bitwidth of base type.
     LLT LLTy = LLT::scalar(32);
     Register SpvVecConst =
@@ -1243,8 +1245,7 @@ SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
                                                      MachineInstr &I,
                                                      const SPIRVInstrInfo &TII,
-                                                     unsigned SPIRVOPcode) {
-  Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
+                                                     unsigned SPIRVOPcode, Type *LLVMTy) {
   Register Reg = DT.find(LLVMTy, CurMF);
   if (Reg.isValid())
     return getSPIRVTypeForVReg(Reg);
@@ -1259,11 +1260,14 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
 
 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
-  return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt);
+  Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
+  return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeInt, LLVMTy);
 }
 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
-  return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat);
+  LLVMContext &Ctx = CurMF->getFunction().getContext();
+  Type *LLVMTy = Type::getFloatTy(Ctx);
+  return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
 }
 
 SPIRVType *
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 4a9670b80ed42f..2f181c30537f80 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -443,7 +443,7 @@ class SPIRVGlobalRegistry {
                                          const SPIRVInstrInfo &TII);
   SPIRVType *getOrCreateSPIRVType(unsigned BitWidth, MachineInstr &I,
                                   const SPIRVInstrInfo &TII,
-                                  unsigned SPIRVOPcode);
+                                  unsigned SPIRVOPcode, Type *LLVMTy);
   SPIRVType *getOrCreateSPIRVFloatType(unsigned BitWidth, MachineInstr &I,
                                        const SPIRVInstrInfo &TII);
   SPIRVType *getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder);

>From 5d3f7b5d9655a08f7e4ab5d88bf3c9fafab2ea6e Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Sun, 7 Apr 2024 00:51:12 -0400
Subject: [PATCH 4/8] first version that passes the SPIR-V validator

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 42 ++++++++++---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  4 +-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 59 ++++++++++++++-----
 .../test/CodeGen/SPIRV/hlsl-intrinsics/all.ll | 44 +++++++-------
 4 files changed, 98 insertions(+), 51 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index ed1474d25c65b6..15d087bf3dd36b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -20,7 +20,9 @@
 #include "SPIRVSubtarget.h"
 #include "SPIRVTargetMachine.h"
 #include "SPIRVUtils.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/Type.h"
 #include "llvm/IR/TypedPointerType.h"
 #include "llvm/Support/Casting.h"
 #include <cassert>
@@ -176,13 +178,14 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
   return std::make_tuple(Res, CI, NewInstr);
 }
 
-std::tuple<Register, ConstantFP *, bool>
+std::tuple<Register, ConstantFP *, bool, unsigned>
 SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
                                               MachineIRBuilder *MIRBuilder,
                                               MachineInstr *I,
                                               const SPIRVInstrInfo *TII) {
   const Type *LLVMFloatTy;
   LLVMContext &Ctx = CurMF->getFunction().getContext();
+  unsigned BitWidth = 32;
   if (SpvType)
     LLVMFloatTy = getTypeForSPIRVType(SpvType);
   else {
@@ -195,7 +198,9 @@ SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
   auto *const CI = ConstantFP::get(Ctx, Val);
   Register Res = DT.find(CI, CurMF);
   if (!Res.isValid()) {
-    unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
+    if (SpvType)
+      BitWidth = getScalarOrVectorBitWidth(SpvType);
+
     LLT LLTy = LLT::scalar(32);
     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
     CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
@@ -206,7 +211,7 @@ SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
     DT.add(CI, CurMF, Res);
     NewInstr = true;
   }
-  return std::make_tuple(Res, CI, NewInstr);
+  return std::make_tuple(Res, CI, NewInstr, BitWidth);
 }
 
 Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
@@ -217,7 +222,8 @@ Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
   ConstantFP *CI;
   Register Res;
   bool New;
-  std::tie(Res, CI, New) =
+  unsigned BitWidth;
+  std::tie(Res, CI, New, BitWidth) =
       getOrCreateConstFloatReg(Val, SpvType, nullptr, &I, &TII);
   // If we have found Res register which is defined by the passed G_CONSTANT
   // machine instruction, a new constant instruction should be created.
@@ -234,9 +240,13 @@ Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantF))
               .addDef(Res)
               .addUse(getSPIRVTypeID(SpvType));
-    addNumImm(CI->getValueAPF().bitcastToAPInt(), MIB);
+    addNumImm(
+        APInt(BitWidth, CI->getValueAPF().bitcastToAPInt().getZExtValue()),
+        MIB);
   }
-
+  const auto &ST = CurMF->getSubtarget();
+  constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
+                                   *ST.getRegisterInfo(), *ST.getRegBankInfo());
   return Res;
 }
 
@@ -459,7 +469,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstVector(uint64_t Val,
                                        ZeroAsNull);
 }
 
-Register SPIRVGlobalRegistry::getOrCreateConstVector(double Val,
+Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
                                                      MachineInstr &I,
                                                      SPIRVType *SpvType,
                                                      const SPIRVInstrInfo &TII,
@@ -1245,7 +1255,8 @@ SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
                                                      MachineInstr &I,
                                                      const SPIRVInstrInfo &TII,
-                                                     unsigned SPIRVOPcode, Type *LLVMTy) {
+                                                     unsigned SPIRVOPcode,
+                                                     Type *LLVMTy) {
   Register Reg = DT.find(LLVMTy, CurMF);
   if (Reg.isValid())
     return getSPIRVTypeForVReg(Reg);
@@ -1266,7 +1277,20 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVFloatType(
     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
   LLVMContext &Ctx = CurMF->getFunction().getContext();
-  Type *LLVMTy = Type::getFloatTy(Ctx);
+  Type *LLVMTy;
+  switch (BitWidth) {
+  case 16:
+    LLVMTy = Type::getHalfTy(Ctx);
+    break;
+  case 32:
+    LLVMTy = Type::getFloatTy(Ctx);
+    break;
+  case 64:
+    LLVMTy = Type::getDoubleTy(Ctx);
+    break;
+  default:
+    llvm_unreachable("Bit width is of unexpected size.");
+  }
   return getOrCreateSPIRVType(BitWidth, I, TII, SPIRV::OpTypeFloat, LLVMTy);
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 2f181c30537f80..63feda2338d6b2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -370,7 +370,7 @@ class SPIRVGlobalRegistry {
   std::tuple<Register, ConstantInt *, bool> getOrCreateConstIntReg(
       uint64_t Val, SPIRVType *SpvType, MachineIRBuilder *MIRBuilder,
       MachineInstr *I = nullptr, const SPIRVInstrInfo *TII = nullptr);
-  std::tuple<Register, ConstantFP *, bool> getOrCreateConstFloatReg(
+  std::tuple<Register, ConstantFP *, bool, unsigned> getOrCreateConstFloatReg(
       APFloat Val, SPIRVType *SpvType, MachineIRBuilder *MIRBuilder,
       MachineInstr *I = nullptr, const SPIRVInstrInfo *TII = nullptr);
   SPIRVType *finishCreatingSPIRVType(const Type *LLVMTy, SPIRVType *SpirvType);
@@ -409,7 +409,7 @@ class SPIRVGlobalRegistry {
   Register getOrCreateConstVector(uint64_t Val, MachineInstr &I,
                                   SPIRVType *SpvType, const SPIRVInstrInfo &TII,
                                   bool ZeroAsNull = true);
-  Register getOrCreateConstVector(double Val, MachineInstr &I,
+  Register getOrCreateConstVector(APFloat Val, MachineInstr &I,
                                   SPIRVType *SpvType, const SPIRVInstrInfo &TII,
                                   bool ZeroAsNull = true);
   Register getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b7eea2bdba0bab..72c9a6c4fcd936 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1167,8 +1167,9 @@ bool SPIRVInstructionSelector::selectAll(Register ResVReg,
   MachineBasicBlock &BB = *I.getParent();
   Register InputRegister = I.getOperand(2).getReg();
   SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister);
-
-  if (InputType->getOpcode() == SPIRV::OpTypeBool) {
+  bool IsBoolTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeBool);
+  bool IsVectorTy = InputType->getOpcode() == SPIRV::OpTypeVector;
+  if (IsBoolTy && !IsVectorTy) {
     Register LoadReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
     BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad))
         .addDef(ResVReg)
@@ -1183,25 +1184,36 @@ bool SPIRVInstructionSelector::selectAll(Register ResVReg,
   bool IsFloatTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeFloat);
   unsigned SpirvNotEqualId =
       IsFloatTy ? SPIRV::OpFOrdNotEqual : SPIRV::OpINotEqual;
+  SPIRVType *SpvBoolScalarTy = GR.getOrCreateSPIRVBoolType(I, TII);
+  SPIRVType *SpvBoolTy = SpvBoolScalarTy;
+  Register NotEqualReg = ResVReg;
+
+  if (IsVectorTy) {
+    NotEqualReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    const unsigned NumElts = InputType->getOperand(2).getImm();
+    SpvBoolTy = GR.getOrCreateSPIRVVectorType(SpvBoolTy, NumElts, I, TII);
+  }
 
-  bool IsVectorTy = InputType->getOpcode() == SPIRV::OpTypeVector;
-  Register ConstCompositeZeroReg =
-      IsFloatTy ? buildZerosValF(InputType, I) : buildZerosVal(InputType, I);
-  Register NotEqualReg =
-      IsVectorTy ? MRI->createVirtualRegister(&SPIRV::IDRegClass) : ResVReg;
-  BuildMI(BB, I, I.getDebugLoc(), TII.get(SpirvNotEqualId))
-      .addDef(NotEqualReg)
-      .addUse(GR.getSPIRVTypeID(InputType))
-      .addUse(I.getOperand(2).getReg())
-      .addUse(ConstCompositeZeroReg)
-      .constrainAllUses(TII, TRI, RBI);
+  if (!IsBoolTy) {
+    Register ConstCompositeZeroReg =
+        IsFloatTy ? buildZerosValF(InputType, I) : buildZerosVal(InputType, I);
+
+    BuildMI(BB, I, I.getDebugLoc(), TII.get(SpirvNotEqualId))
+        .addDef(NotEqualReg)
+        .addUse(GR.getSPIRVTypeID(SpvBoolTy))
+        .addUse(InputRegister)
+        .addUse(ConstCompositeZeroReg)
+        .constrainAllUses(TII, TRI, RBI);
+  } else {
+    NotEqualReg = InputRegister;
+  }
 
   if (!IsVectorTy)
     return true;
 
   return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpAll))
       .addDef(ResVReg)
-      .addUse(GR.getSPIRVTypeID(InputType))
+      .addUse(GR.getSPIRVTypeID(SpvBoolScalarTy))
       .addUse(NotEqualReg)
       .constrainAllUses(TII, TRI, RBI);
 
@@ -1452,13 +1464,28 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
   return GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
 }
 
+APFloat getZeroFP(const Type *LLVMFloatTy) {
+  if (!LLVMFloatTy)
+    return APFloat::getZero(APFloat::IEEEsingle());
+  switch (LLVMFloatTy->getScalarType()->getTypeID()) {
+  case Type::HalfTyID:
+    return APFloat::getZero(APFloat::IEEEhalf());
+  default:
+  case Type::FloatTyID:
+    return APFloat::getZero(APFloat::IEEEsingle());
+  case Type::DoubleTyID:
+    return APFloat::getZero(APFloat::IEEEdouble());
+  }
+}
+
 Register SPIRVInstructionSelector::buildZerosValF(const SPIRVType *ResType,
                                                   MachineInstr &I) const {
   // OpenCL uses nulls for Zero. In HLSL we don't use null constants.
   bool ZeroAsNull = STI.isOpenCLEnv();
+  APFloat VZero = getZeroFP(GR.getTypeForSPIRVType(ResType));
   if (ResType->getOpcode() == SPIRV::OpTypeVector)
-    return GR.getOrCreateConstVector(0.0, I, ResType, TII, ZeroAsNull);
-  return GR.getOrCreateConstFP(APFloat(0.0), I, ResType, TII, ZeroAsNull);
+    return GR.getOrCreateConstVector(VZero, I, ResType, TII, ZeroAsNull);
+  return GR.getOrCreateConstFP(VZero, I, ResType, TII, ZeroAsNull);
 }
 
 Register SPIRVInstructionSelector::buildOnesVal(bool AllOnes,
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll
index b85fafa95d9b21..9b79ce8c6aaca1 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll
@@ -23,16 +23,13 @@
 ; CHECK: %[[#const_int_32:]] = OpConstant %[[#int_32]] 0
 ; CHECK: %[[#const_int_16:]] = OpConstant %[[#int_16]] 0
 ; CHECK: %[[#const_float_64:]] = OpConstant %[[#float_64]] 0
-; CHECK: %[[#const_bool:]] = OpConstantNull %[[#bool]]
+; CHECK: %[[#const_float_32:]] =  OpConstant %[[#float_32:]] 0
+; CHECK: %[[#const_float_16:]] = OpConstant %[[#float_16:]] 0
 
-; CHECK: %[[#vec_zero_const_i1_4:]] = OpConstantComposite %[[#vec_1_4:]] %[[#const_bool:]] %[[#const_bool:]] %[[#const_bool:]] %[[#const_bool:]]
 ; CHECK: %[[#vec_zero_const_i16_4:]] = OpConstantComposite %[[#vec_16_4:]] %[[#const_int_16:]] %[[#const_int_16:]] %[[#const_int_16:]] %[[#const_int_16:]]
 ; CHECK: %[[#vec_zero_const_i32_4:]] = OpConstantComposite %[[#vec_32_4:]] %[[#const_int_32:]] %[[#const_int_32:]] %[[#const_int_32:]] %[[#const_int_32:]]
 ; CHECK: %[[#vec_zero_const_i64_4:]] = OpConstantComposite %[[#vec_64_4:]] %[[#const_int_64:]] %[[#const_int_64:]] %[[#const_int_64:]] %[[#const_int_64:]]
-
-; CHECK: %[[#const_float_16:]] = OpConstant %[[#float_16:]] 0
 ; CHECK: %[[#vec_zero_const_f16_4:]] = OpConstantComposite %[[#vec_float_16_4:]] %[[#const_float_16:]] %[[#const_float_16:]] %[[#const_float_16:]] %[[#const_float_16:]]
-; CHECK: %[[#const_float_32:]] =  OpConstant %[[#float_32:]] 0
 ; CHECK: %[[#vec_zero_const_f32_4:]] = OpConstantComposite %[[#vec_float_32_4:]] %[[#const_float_32:]] %[[#const_float_32:]] %[[#const_float_32:]] %[[#const_float_32:]]
 ; CHECK: %[[#vec_zero_const_f64_4:]] = OpConstantComposite %[[#vec_float_64_4:]] %[[#const_float_64:]] %[[#const_float_64:]] %[[#const_float_64:]] %[[#const_float_64:]]
 
@@ -41,7 +38,7 @@ entry:
   %p0.addr = alloca i64, align 8
   store i64 %p0, ptr %p0.addr, align 8
   %0 = load i64, ptr %p0.addr, align 8
-  ; CHECK: %[[#]] = OpINotEqual %[[#int_64:]] %[[#]] %[[#const_int_64:]]
+  ; CHECK: %[[#]] = OpINotEqual %[[#bool:]] %[[#]] %[[#const_int_64:]]
   %hlsl.all = call i1 @llvm.spv.all.i64(i64 %0)
   ret i1 %hlsl.all
 }
@@ -52,7 +49,7 @@ entry:
   %p0.addr = alloca i32, align 4
   store i32 %p0, ptr %p0.addr, align 4
   %0 = load i32, ptr %p0.addr, align 4
-  ; CHECK: %[[#]] = OpINotEqual %[[#int_32:]] %[[#]] %[[#const_int_32:]]
+  ; CHECK: %[[#]] = OpINotEqual %[[#bool:]] %[[#]] %[[#const_int_32:]]
   %hlsl.all = call i1 @llvm.spv.all.i32(i32 %0)
   ret i1 %hlsl.all
 }
@@ -63,7 +60,7 @@ entry:
   %p0.addr = alloca i16, align 2
   store i16 %p0, ptr %p0.addr, align 2
   %0 = load i16, ptr %p0.addr, align 2
-  ; CHECK: %[[#]] = OpINotEqual %[[#int_16:]] %[[#]] %[[#const_int_16:]]
+  ; CHECK: %[[#]] = OpINotEqual %[[#bool:]] %[[#]] %[[#const_int_16:]]
   %hlsl.all = call i1 @llvm.spv.all.i16(i16 %0)
   ret i1 %hlsl.all
 }
@@ -73,7 +70,7 @@ entry:
   %p0.addr = alloca double, align 8
   store double %p0, ptr %p0.addr, align 8
   %0 = load double, ptr %p0.addr, align 8
-  ; CHECK: %[[#]] = OpFOrdNotEqual %[[#float_64:]] %[[#]] %[[#const_float_64:]]
+  ; CHECK: %[[#]] = OpFOrdNotEqual %[[#bool:]] %[[#]] %[[#const_float_64:]]
   %hlsl.all = call i1 @llvm.spv.all.f64(double %0)
   ret i1 %hlsl.all
 }
@@ -84,7 +81,7 @@ entry:
   %p0.addr = alloca float, align 4
   store float %p0, ptr %p0.addr, align 4
   %0 = load float, ptr %p0.addr, align 4
-  ; CHECK: %[[#]] = OpFOrdNotEqual %[[#float_32:]] %[[#]] %[[#const_float_32:]]
+  ; CHECK: %[[#]] = OpFOrdNotEqual %[[#bool:]] %[[#]] %[[#const_float_32:]]
   %hlsl.all = call i1 @llvm.spv.all.f32(float %0)
   ret i1 %hlsl.all
 }
@@ -95,7 +92,7 @@ entry:
   %p0.addr = alloca half, align 2
   store half %p0, ptr %p0.addr, align 2
   %0 = load half, ptr %p0.addr, align 2
-  ; CHECK: %[[#]] = OpFOrdNotEqual %[[#float_16:]] %[[#]] %[[#const_float_16:]]
+  ; CHECK: %[[#]] = OpFOrdNotEqual %[[#bool:]] %[[#]] %[[#const_float_16:]]
   %hlsl.all = call i1 @llvm.spv.all.f16(half %0)
   ret i1 %hlsl.all
 }
@@ -103,7 +100,6 @@ entry:
 
 define noundef i1 @all_bool4(<4 x i1> noundef %p0) {
 entry:
-  ; CHECK: %[[#boolVecNotEq:]] = OpINotEqual %[[#vec_1_4:]] %[[#]] %[[#vec_zero_const_i1_4:]]
   ; CHECK: %[[#]] = OpAll %[[#vec_1_4:]] %[[#boolVecNotEq:]]
   %hlsl.all = call i1 @llvm.spv.all.v4i1(<4 x i1> %p0)
   ret i1 %hlsl.all
@@ -111,48 +107,48 @@ entry:
 
 define noundef i1 @all_short4(<4 x i16> noundef %p0) {
 entry:
-  ; CHECK: %[[#shortVecNotEq:]] = OpINotEqual %[[#vec_16_4:]] %[[#]] %[[#vec_zero_const_i16_4:]]
-  ; CHECK: %[[#]] = OpAll %[[#vec_16_4:]] %[[#shortVecNotEq:]]
+  ; CHECK: %[[#shortVecNotEq:]] = OpINotEqual %[[#vec_1_4:]] %[[#]] %[[#vec_zero_const_i16_4:]]
+  ; CHECK: %[[#]] = OpAll %[[#bool:]] %[[#shortVecNotEq:]]
   %hlsl.all = call i1 @llvm.spv.all.v4i16(<4 x i16> %p0)
   ret i1 %hlsl.all
 }
 
 define noundef i1 @all_int4(<4 x i32> noundef %p0) {
 entry:
-  ; CHECK: %[[#i32VecNotEq:]] = OpINotEqual %[[#vec_32_4:]] %[[#]] %[[#vec_zero_const_i32_4:]]
-  ; CHECK: %[[#]] = OpAll %[[#vec_32_4:]] %[[#i32VecNotEq:]]
+  ; CHECK: %[[#i32VecNotEq:]] = OpINotEqual %[[#vec_1_4:]] %[[#]] %[[#vec_zero_const_i32_4:]]
+  ; CHECK: %[[#]] = OpAll %[[#bool:]] %[[#i32VecNotEq:]]
   %hlsl.all = call i1 @llvm.spv.all.v4i32(<4 x i32> %p0)
   ret i1 %hlsl.all
 }
 
 define noundef i1 @all_int64_t4(<4 x i64> noundef %p0) {
 entry:
-  ; CHECK: %[[#i64VecNotEq:]] = OpINotEqual %[[#vec_64_4:]] %[[#]] %[[#vec_zero_const_i64_4:]]
-  ; CHECK: %[[#]] = OpAll %[[#vec_64_4:]] %[[#i64VecNotEq]]
+  ; CHECK: %[[#i64VecNotEq:]] = OpINotEqual %[[#vec_1_4:]] %[[#]] %[[#vec_zero_const_i64_4:]]
+  ; CHECK: %[[#]] = OpAll %[[#bool:]] %[[#i64VecNotEq]]
   %hlsl.all = call i1 @llvm.spv.all.v4i64(<4 x i64> %p0)
   ret i1 %hlsl.all
 }
 
 define noundef i1 @all_half4(<4 x half> noundef %p0) {
 entry:
-  ; CHECK: %[[#f16VecNotEq:]] = OpFOrdNotEqual %[[#vec_float_16_4:]] %[[#]] %[[#vec_zero_const_f16_4:]]
-  ; CHECK: %[[#]] = OpAll %[[#vec_float_16_4]] %[[#f16VecNotEq:]]
+  ; CHECK: %[[#f16VecNotEq:]] = OpFOrdNotEqual %[[#vec_1_4:]] %[[#]] %[[#vec_zero_const_f16_4:]]
+  ; CHECK: %[[#]] = OpAll %[[#bool]] %[[#f16VecNotEq:]]
   %hlsl.all = call i1 @llvm.spv.all.v4f16(<4 x half> %p0)
   ret i1 %hlsl.all
 }
 
 define noundef i1 @all_float4(<4 x float> noundef %p0) {
 entry:
-  ; CHECK: %[[#f32VecNotEq:]] = OpFOrdNotEqual %[[#vec_float_32_4:]] %[[#]] %[[#vec_zero_const_f32_4:]]
-  ; CHECK: %[[#]] = OpAll %[[#vec_float_32_4:]] %[[#f32VecNotEq:]]
+  ; CHECK: %[[#f32VecNotEq:]] = OpFOrdNotEqual %[[#vec_1_4:]] %[[#]] %[[#vec_zero_const_f32_4:]]
+  ; CHECK: %[[#]] = OpAll %[[#bool:]] %[[#f32VecNotEq:]]
   %hlsl.all = call i1 @llvm.spv.all.v4f32(<4 x float> %p0)
   ret i1 %hlsl.all
 }
 
 define noundef i1 @all_double4(<4 x double> noundef %p0) {
 entry:
-  ; CHECK: %[[#f64VecNotEq:]] = OpFOrdNotEqual %[[#vec_float_64_4:]] %[[#]] %[[#vec_zero_const_f64_4:]]
-  ; CHECK: %[[#]] = OpAll %[[#vec_float_64_4:]] %[[#f64VecNotEq:]]
+  ; CHECK: %[[#f64VecNotEq:]] = OpFOrdNotEqual %[[#vec_1_4:]] %[[#]] %[[#vec_zero_const_f64_4:]]
+  ; CHECK: %[[#]] = OpAll %[[#bool:]] %[[#f64VecNotEq:]]
   %hlsl.all = call i1 @llvm.spv.all.v4f64(<4 x double> %p0)
   ret i1 %hlsl.all
 }

>From 40d9a06ca67c752de5ff7698127d6f3dcebaee73 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Sun, 7 Apr 2024 18:37:03 -0400
Subject: [PATCH 5/8] address the bool scalar case with a TargetOpcode::COPY;
 add OpenCL tests to make sure both paths produce valid SPIR-V

---
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 14 ++---
 .../test/CodeGen/SPIRV/hlsl-intrinsics/all.ll | 57 +++++++++++++------
 .../CodeGen/SPIRV/hlsl-intrinsics/all_next.ll | 17 ------
 3 files changed, 46 insertions(+), 42 deletions(-)
 delete mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all_next.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 72c9a6c4fcd936..68f339b305b4fe 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -28,6 +28,7 @@
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineModuleInfoImpls.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetOpcodes.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
 #include "llvm/Support/Debug.h"
 
@@ -1170,15 +1171,12 @@ bool SPIRVInstructionSelector::selectAll(Register ResVReg,
   bool IsBoolTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeBool);
   bool IsVectorTy = InputType->getOpcode() == SPIRV::OpTypeVector;
   if (IsBoolTy && !IsVectorTy) {
-    Register LoadReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
-    BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad))
+    assert(ResVReg == I.getOperand(0).getReg());
+    return BuildMI(*I.getParent(), I, I.getDebugLoc(),
+                   TII.get(TargetOpcode::COPY))
         .addDef(ResVReg)
-        .addUse(GR.getSPIRVTypeID(InputType))
-        .addUse(InputRegister);
-
-    return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpStore))
-        .addUse(LoadReg)
-        .addUse(InputRegister);
+        .addUse(InputRegister)
+        .constrainAllUses(TII, TRI, RBI);
   }
 
   bool IsFloatTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeFloat);
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll
index 9b79ce8c6aaca1..c6d02c29709930 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll
@@ -1,9 +1,12 @@
-; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-HLSL
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-OCL
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
-
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 ; Make sure spirv operation function calls for all are generated.
 
-; CHECK: OpMemoryModel Logical GLSL450
+; CHECK-HLSL: OpMemoryModel Logical GLSL450
+; CHECK-OCL: OpMemoryModel Physical32 OpenCL
+; CHECK: OpName %[[#all_bool_arg:]] "a"
 ; CHECK: %[[#int_64:]] = OpTypeInt 64 0
 ; CHECK: %[[#bool:]] = OpTypeBool
 ; CHECK: %[[#int_32:]] = OpTypeInt 32 0
@@ -19,19 +22,31 @@
 ; CHECK: %[[#vec_float_32_4:]] = OpTypeVector %[[#float_32]] 4
 ; CHECK: %[[#vec_float_64_4:]] = OpTypeVector %[[#float_64]] 4
 
-; CHECK: %[[#const_int_64:]] = OpConstant %[[#int_64]] 0
-; CHECK: %[[#const_int_32:]] = OpConstant %[[#int_32]] 0
-; CHECK: %[[#const_int_16:]] = OpConstant %[[#int_16]] 0
-; CHECK: %[[#const_float_64:]] = OpConstant %[[#float_64]] 0
-; CHECK: %[[#const_float_32:]] =  OpConstant %[[#float_32:]] 0
-; CHECK: %[[#const_float_16:]] = OpConstant %[[#float_16:]] 0
-
-; CHECK: %[[#vec_zero_const_i16_4:]] = OpConstantComposite %[[#vec_16_4:]] %[[#const_int_16:]] %[[#const_int_16:]] %[[#const_int_16:]] %[[#const_int_16:]]
-; CHECK: %[[#vec_zero_const_i32_4:]] = OpConstantComposite %[[#vec_32_4:]] %[[#const_int_32:]] %[[#const_int_32:]] %[[#const_int_32:]] %[[#const_int_32:]]
-; CHECK: %[[#vec_zero_const_i64_4:]] = OpConstantComposite %[[#vec_64_4:]] %[[#const_int_64:]] %[[#const_int_64:]] %[[#const_int_64:]] %[[#const_int_64:]]
-; CHECK: %[[#vec_zero_const_f16_4:]] = OpConstantComposite %[[#vec_float_16_4:]] %[[#const_float_16:]] %[[#const_float_16:]] %[[#const_float_16:]] %[[#const_float_16:]]
-; CHECK: %[[#vec_zero_const_f32_4:]] = OpConstantComposite %[[#vec_float_32_4:]] %[[#const_float_32:]] %[[#const_float_32:]] %[[#const_float_32:]] %[[#const_float_32:]]
-; CHECK: %[[#vec_zero_const_f64_4:]] = OpConstantComposite %[[#vec_float_64_4:]] %[[#const_float_64:]] %[[#const_float_64:]] %[[#const_float_64:]] %[[#const_float_64:]]
+; CHECK-HLSL: %[[#const_int_64:]] = OpConstant %[[#int_64]] 0
+; CHECK-HLSL: %[[#const_int_32:]] = OpConstant %[[#int_32]] 0
+; CHECK-HLSL: %[[#const_int_16:]] = OpConstant %[[#int_16]] 0
+; CHECK-HLSL: %[[#const_float_64:]] = OpConstant %[[#float_64]] 0
+; CHECK-HLSL: %[[#const_float_32:]] = OpConstant %[[#float_32:]] 0
+; CHECK-HLSL: %[[#const_float_16:]] = OpConstant %[[#float_16:]] 0
+; CHECK-HLSL: %[[#vec_zero_const_i16_4:]] = OpConstantComposite %[[#vec_16_4:]] %[[#const_int_16:]] %[[#const_int_16:]] %[[#const_int_16:]] %[[#const_int_16:]]
+; CHECK-HLSL: %[[#vec_zero_const_i32_4:]] = OpConstantComposite %[[#vec_32_4:]] %[[#const_int_32:]] %[[#const_int_32:]] %[[#const_int_32:]] %[[#const_int_32:]]
+; CHECK-HLSL: %[[#vec_zero_const_i64_4:]] = OpConstantComposite %[[#vec_64_4:]] %[[#const_int_64:]] %[[#const_int_64:]] %[[#const_int_64:]] %[[#const_int_64:]]
+; CHECK-HLSL: %[[#vec_zero_const_f16_4:]] = OpConstantComposite %[[#vec_float_16_4:]] %[[#const_float_16:]] %[[#const_float_16:]] %[[#const_float_16:]] %[[#const_float_16:]]
+; CHECK-HLSL: %[[#vec_zero_const_f32_4:]] = OpConstantComposite %[[#vec_float_32_4:]] %[[#const_float_32:]] %[[#const_float_32:]] %[[#const_float_32:]] %[[#const_float_32:]]
+; CHECK-HLSL: %[[#vec_zero_const_f64_4:]] = OpConstantComposite %[[#vec_float_64_4:]] %[[#const_float_64:]] %[[#const_float_64:]] %[[#const_float_64:]] %[[#const_float_64:]]
+
+; CHECK-OCL: %[[#const_int_64:]] = OpConstantNull %[[#int_64]]
+; CHECK-OCL: %[[#const_int_32:]] = OpConstantNull %[[#int_32]]
+; CHECK-OCL: %[[#const_int_16:]] = OpConstantNull %[[#int_16]]
+; CHECK-OCL: %[[#const_float_64:]] = OpConstantNull %[[#float_64]] 
+; CHECK-OCL: %[[#const_float_32:]] = OpConstantNull %[[#float_32:]]
+; CHECK-OCL: %[[#const_float_16:]] = OpConstantNull %[[#float_16:]]
+; CHECK-OCL: %[[#vec_zero_const_i16_4:]] = OpConstantNull %[[#vec_16_4:]]
+; CHECK-OCL: %[[#vec_zero_const_i32_4:]] = OpConstantNull %[[#vec_32_4:]]
+; CHECK-OCL: %[[#vec_zero_const_i64_4:]] = OpConstantNull %[[#vec_64_4:]]
+; CHECK-OCL: %[[#vec_zero_const_f16_4:]] = OpConstantNull %[[#vec_float_16_4:]]
+; CHECK-OCL: %[[#vec_zero_const_f32_4:]] = OpConstantNull %[[#vec_float_32_4:]]
+; CHECK-OCL: %[[#vec_zero_const_f64_4:]] = OpConstantNull %[[#vec_float_64_4:]]
 
 define noundef i1 @all_int64_t(i64 noundef %p0) {
 entry:
@@ -153,6 +168,14 @@ entry:
   ret i1 %hlsl.all
 }
 
+define noundef i1 @all_bool(i1 noundef %a) {
+entry:
+  ; CHECK: %[[#all_bool_arg:]] = OpFunctionParameter %[[#bool:]]
+  ; CHECK: OpReturnValue %[[#all_bool_arg:]]
+  %hlsl.all = call i1 @llvm.spv.all.i1(i1 %a)
+  ret i1 %hlsl.all
+}
+
 declare i1 @llvm.spv.all.v4f16(<4 x half>)
 declare i1 @llvm.spv.all.v4f32(<4 x float>)
 declare i1 @llvm.spv.all.v4f64(<4 x double>)
@@ -166,4 +189,4 @@ declare i1 @llvm.spv.all.i32(i32)
 declare i1 @llvm.spv.all.i64(i64)
 declare i1 @llvm.spv.all.f16(half)
 declare i1 @llvm.spv.all.f32(float)
-declare i1 @llvm.spv.all.f64(double)
\ No newline at end of file
+declare i1 @llvm.spv.all.f64(double)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all_next.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all_next.ll
deleted file mode 100644
index b8c39c7aa5f31f..00000000000000
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all_next.ll
+++ /dev/null
@@ -1,17 +0,0 @@
-; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
-; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
-
-; CHECK: OpMemoryModel Logical GLSL450
-; CHECK: OpName %[[#all_bool_arg:]] "a"
-; CHECK: OpName %[[#all_bool_ret:]] "hlsl.all"
-; CHECK: %[[#bool:]] = OpTypeBool
-
-define noundef i1 @all_bool(i1 noundef %a) {
-entry:
-  ; CHECK: %[[#all_bool_arg:]] = OpFunctionParameter %[[#bool:]]
-  ; CHECK: OpLoad %[[#bool:]] %[[#all_bool_arg:]]
-  ; CHECK: OpStore %[[#all_bool_ret:]] %[[#all_bool_arg:]]
-  ; CHECK: OpReturnValue %[[#all_bool_ret:]]
-  %hlsl.all = call i1 @llvm.spv.all.i1(i1 %a)
-  ret i1 %hlsl.all
-}
\ No newline at end of file

>From 0ee52263fad8b1ce6edbd79387d47991851aa24e Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 8 Apr 2024 01:26:41 -0400
Subject: [PATCH 6/8] remove commented out code

---
 llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 68f339b305b4fe..548c154b1677c8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1214,9 +1214,6 @@ bool SPIRVInstructionSelector::selectAll(Register ResVReg,
       .addUse(GR.getSPIRVTypeID(SpvBoolScalarTy))
       .addUse(NotEqualReg)
       .constrainAllUses(TII, TRI, RBI);
-
-  // bool IsSigned = GR.isScalarOrVectorSigned(InputType);
-  // return selectSelect(ResVReg, ResType, OpAllReg, I, IsSigned);
 }
 
 bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,

>From acb203560357241df14dd9fe1a1b8211c4a2ee5f Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 8 Apr 2024 10:13:08 -0400
Subject: [PATCH 7/8] address pr comments

---
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 15 ++--
 .../test/CodeGen/SPIRV/hlsl-intrinsics/all.ll | 86 +++++++++----------
 2 files changed, 52 insertions(+), 49 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 548c154b1677c8..930fa896ac8e82 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -234,6 +234,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
 
   Register buildZerosVal(const SPIRVType *ResType, MachineInstr &I) const;
   Register buildZerosValF(const SPIRVType *ResType, MachineInstr &I) const;
+  static APFloat getZeroFP(const Type *LLVMFloatTy);
   Register buildOnesVal(bool AllOnes, const SPIRVType *ResType,
                         MachineInstr &I) const;
 
@@ -1168,6 +1169,10 @@ bool SPIRVInstructionSelector::selectAll(Register ResVReg,
   MachineBasicBlock &BB = *I.getParent();
   Register InputRegister = I.getOperand(2).getReg();
   SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister);
+  
+  if(!InputType)
+    report_fatal_error("Input Type could not be determined.");
+
   bool IsBoolTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeBool);
   bool IsVectorTy = InputType->getOpcode() == SPIRV::OpTypeVector;
   if (IsBoolTy && !IsVectorTy) {
@@ -1187,23 +1192,21 @@ bool SPIRVInstructionSelector::selectAll(Register ResVReg,
   Register NotEqualReg = ResVReg;
 
   if (IsVectorTy) {
-    NotEqualReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    NotEqualReg = IsBoolTy ? InputRegister : MRI->createVirtualRegister(&SPIRV::IDRegClass);
     const unsigned NumElts = InputType->getOperand(2).getImm();
     SpvBoolTy = GR.getOrCreateSPIRVVectorType(SpvBoolTy, NumElts, I, TII);
   }
 
   if (!IsBoolTy) {
-    Register ConstCompositeZeroReg =
+    Register ConstZeroReg =
         IsFloatTy ? buildZerosValF(InputType, I) : buildZerosVal(InputType, I);
 
     BuildMI(BB, I, I.getDebugLoc(), TII.get(SpirvNotEqualId))
         .addDef(NotEqualReg)
         .addUse(GR.getSPIRVTypeID(SpvBoolTy))
         .addUse(InputRegister)
-        .addUse(ConstCompositeZeroReg)
+        .addUse(ConstZeroReg)
         .constrainAllUses(TII, TRI, RBI);
-  } else {
-    NotEqualReg = InputRegister;
   }
 
   if (!IsVectorTy)
@@ -1459,7 +1462,7 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
   return GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
 }
 
-APFloat getZeroFP(const Type *LLVMFloatTy) {
+APFloat SPIRVInstructionSelector::getZeroFP(const Type *LLVMFloatTy) {
   if (!LLVMFloatTy)
     return APFloat::getZero(APFloat::IEEEsingle());
   switch (LLVMFloatTy->getScalarType()->getTypeID()) {
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll
index c6d02c29709930..fc3e0e87a941bf 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll
@@ -4,49 +4,49 @@
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 ; Make sure spirv operation function calls for all are generated.
 
-; CHECK-HLSL: OpMemoryModel Logical GLSL450
-; CHECK-OCL: OpMemoryModel Physical32 OpenCL
-; CHECK: OpName %[[#all_bool_arg:]] "a"
-; CHECK: %[[#int_64:]] = OpTypeInt 64 0
-; CHECK: %[[#bool:]] = OpTypeBool
-; CHECK: %[[#int_32:]] = OpTypeInt 32 0
-; CHECK: %[[#int_16:]] = OpTypeInt 16 0
-; CHECK: %[[#float_64:]] = OpTypeFloat 64
-; CHECK: %[[#float_32:]] = OpTypeFloat 32
-; CHECK: %[[#float_16:]] = OpTypeFloat 16
-; CHECK: %[[#vec_1_4:]] = OpTypeVector %[[#bool]] 4
-; CHECK: %[[#vec_16_4:]] = OpTypeVector %[[#int_16]] 4
-; CHECK: %[[#vec_32_4:]] = OpTypeVector %[[#int_32]] 4
-; CHECK: %[[#vec_64_4:]] = OpTypeVector %[[#int_64]] 4
-; CHECK: %[[#vec_float_16_4:]] = OpTypeVector %[[#float_16]] 4
-; CHECK: %[[#vec_float_32_4:]] = OpTypeVector %[[#float_32]] 4
-; CHECK: %[[#vec_float_64_4:]] = OpTypeVector %[[#float_64]] 4
-
-; CHECK-HLSL: %[[#const_int_64:]] = OpConstant %[[#int_64]] 0
-; CHECK-HLSL: %[[#const_int_32:]] = OpConstant %[[#int_32]] 0
-; CHECK-HLSL: %[[#const_int_16:]] = OpConstant %[[#int_16]] 0
-; CHECK-HLSL: %[[#const_float_64:]] = OpConstant %[[#float_64]] 0
-; CHECK-HLSL: %[[#const_float_32:]] = OpConstant %[[#float_32:]] 0
-; CHECK-HLSL: %[[#const_float_16:]] = OpConstant %[[#float_16:]] 0
-; CHECK-HLSL: %[[#vec_zero_const_i16_4:]] = OpConstantComposite %[[#vec_16_4:]] %[[#const_int_16:]] %[[#const_int_16:]] %[[#const_int_16:]] %[[#const_int_16:]]
-; CHECK-HLSL: %[[#vec_zero_const_i32_4:]] = OpConstantComposite %[[#vec_32_4:]] %[[#const_int_32:]] %[[#const_int_32:]] %[[#const_int_32:]] %[[#const_int_32:]]
-; CHECK-HLSL: %[[#vec_zero_const_i64_4:]] = OpConstantComposite %[[#vec_64_4:]] %[[#const_int_64:]] %[[#const_int_64:]] %[[#const_int_64:]] %[[#const_int_64:]]
-; CHECK-HLSL: %[[#vec_zero_const_f16_4:]] = OpConstantComposite %[[#vec_float_16_4:]] %[[#const_float_16:]] %[[#const_float_16:]] %[[#const_float_16:]] %[[#const_float_16:]]
-; CHECK-HLSL: %[[#vec_zero_const_f32_4:]] = OpConstantComposite %[[#vec_float_32_4:]] %[[#const_float_32:]] %[[#const_float_32:]] %[[#const_float_32:]] %[[#const_float_32:]]
-; CHECK-HLSL: %[[#vec_zero_const_f64_4:]] = OpConstantComposite %[[#vec_float_64_4:]] %[[#const_float_64:]] %[[#const_float_64:]] %[[#const_float_64:]] %[[#const_float_64:]]
-
-; CHECK-OCL: %[[#const_int_64:]] = OpConstantNull %[[#int_64]]
-; CHECK-OCL: %[[#const_int_32:]] = OpConstantNull %[[#int_32]]
-; CHECK-OCL: %[[#const_int_16:]] = OpConstantNull %[[#int_16]]
-; CHECK-OCL: %[[#const_float_64:]] = OpConstantNull %[[#float_64]] 
-; CHECK-OCL: %[[#const_float_32:]] = OpConstantNull %[[#float_32:]]
-; CHECK-OCL: %[[#const_float_16:]] = OpConstantNull %[[#float_16:]]
-; CHECK-OCL: %[[#vec_zero_const_i16_4:]] = OpConstantNull %[[#vec_16_4:]]
-; CHECK-OCL: %[[#vec_zero_const_i32_4:]] = OpConstantNull %[[#vec_32_4:]]
-; CHECK-OCL: %[[#vec_zero_const_i64_4:]] = OpConstantNull %[[#vec_64_4:]]
-; CHECK-OCL: %[[#vec_zero_const_f16_4:]] = OpConstantNull %[[#vec_float_16_4:]]
-; CHECK-OCL: %[[#vec_zero_const_f32_4:]] = OpConstantNull %[[#vec_float_32_4:]]
-; CHECK-OCL: %[[#vec_zero_const_f64_4:]] = OpConstantNull %[[#vec_float_64_4:]]
+; CHECK-HLSL-DAG: OpMemoryModel Logical GLSL450
+; CHECK-OCL-DAG: OpMemoryModel Physical32 OpenCL
+; CHECK-DAG: OpName %[[#all_bool_arg:]] "a"
+; CHECK-DAG: %[[#int_64:]] = OpTypeInt 64 0
+; CHECK-DAG: %[[#bool:]] = OpTypeBool
+; CHECK-DAG: %[[#int_32:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#int_16:]] = OpTypeInt 16 0
+; CHECK-DAG: %[[#float_64:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#vec_1_4:]] = OpTypeVector %[[#bool]] 4
+; CHECK-DAG: %[[#vec_16_4:]] = OpTypeVector %[[#int_16]] 4
+; CHECK-DAG: %[[#vec_32_4:]] = OpTypeVector %[[#int_32]] 4
+; CHECK-DAG: %[[#vec_64_4:]] = OpTypeVector %[[#int_64]] 4
+; CHECK-DAG: %[[#vec_float_16_4:]] = OpTypeVector %[[#float_16]] 4
+; CHECK-DAG: %[[#vec_float_32_4:]] = OpTypeVector %[[#float_32]] 4
+; CHECK-DAG: %[[#vec_float_64_4:]] = OpTypeVector %[[#float_64]] 4
+
+; CHECK-HLSL-DAG: %[[#const_int_64:]] = OpConstant %[[#int_64]] 0
+; CHECK-HLSL-DAG: %[[#const_int_32:]] = OpConstant %[[#int_32]] 0
+; CHECK-HLSL-DAG: %[[#const_int_16:]] = OpConstant %[[#int_16]] 0
+; CHECK-HLSL-DAG: %[[#const_float_64:]] = OpConstant %[[#float_64]] 0
+; CHECK-HLSL-DAG: %[[#const_float_32:]] = OpConstant %[[#float_32:]] 0
+; CHECK-HLSL-DAG: %[[#const_float_16:]] = OpConstant %[[#float_16:]] 0
+; CHECK-HLSL-DAG: %[[#vec_zero_const_i16_4:]] = OpConstantComposite %[[#vec_16_4:]] %[[#const_int_16:]] %[[#const_int_16:]] %[[#const_int_16:]] %[[#const_int_16:]]
+; CHECK-HLSL-DAG: %[[#vec_zero_const_i32_4:]] = OpConstantComposite %[[#vec_32_4:]] %[[#const_int_32:]] %[[#const_int_32:]] %[[#const_int_32:]] %[[#const_int_32:]]
+; CHECK-HLSL-DAG: %[[#vec_zero_const_i64_4:]] = OpConstantComposite %[[#vec_64_4:]] %[[#const_int_64:]] %[[#const_int_64:]] %[[#const_int_64:]] %[[#const_int_64:]]
+; CHECK-HLSL-DAG: %[[#vec_zero_const_f16_4:]] = OpConstantComposite %[[#vec_float_16_4:]] %[[#const_float_16:]] %[[#const_float_16:]] %[[#const_float_16:]] %[[#const_float_16:]]
+; CHECK-HLSL-DAG: %[[#vec_zero_const_f32_4:]] = OpConstantComposite %[[#vec_float_32_4:]] %[[#const_float_32:]] %[[#const_float_32:]] %[[#const_float_32:]] %[[#const_float_32:]]
+; CHECK-HLSL-DAG: %[[#vec_zero_const_f64_4:]] = OpConstantComposite %[[#vec_float_64_4:]] %[[#const_float_64:]] %[[#const_float_64:]] %[[#const_float_64:]] %[[#const_float_64:]]
+
+; CHECK-OCL-DAG: %[[#const_int_64:]] = OpConstantNull %[[#int_64]]
+; CHECK-OCL-DAG: %[[#const_int_32:]] = OpConstantNull %[[#int_32]]
+; CHECK-OCL-DAG: %[[#const_int_16:]] = OpConstantNull %[[#int_16]]
+; CHECK-OCL-DAG: %[[#const_float_64:]] = OpConstantNull %[[#float_64]] 
+; CHECK-OCL-DAG: %[[#const_float_32:]] = OpConstantNull %[[#float_32:]]
+; CHECK-OCL-DAG: %[[#const_float_16:]] = OpConstantNull %[[#float_16:]]
+; CHECK-OCL-DAG: %[[#vec_zero_const_i16_4:]] = OpConstantNull %[[#vec_16_4:]]
+; CHECK-OCL-DAG: %[[#vec_zero_const_i32_4:]] = OpConstantNull %[[#vec_32_4:]]
+; CHECK-OCL-DAG: %[[#vec_zero_const_i64_4:]] = OpConstantNull %[[#vec_64_4:]]
+; CHECK-OCL-DAG: %[[#vec_zero_const_f16_4:]] = OpConstantNull %[[#vec_float_16_4:]]
+; CHECK-OCL-DAG: %[[#vec_zero_const_f32_4:]] = OpConstantNull %[[#vec_float_32_4:]]
+; CHECK-OCL-DAG: %[[#vec_zero_const_f64_4:]] = OpConstantNull %[[#vec_float_64_4:]]
 
 define noundef i1 @all_int64_t(i64 noundef %p0) {
 entry:

>From 52c45570d07aecb6a6eb6346c83e1c461e53fe2f Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 8 Apr 2024 10:16:10 -0400
Subject: [PATCH 8/8] change getZeroFP from static class func to static func

---
 llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 930fa896ac8e82..24882732e941e6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -234,7 +234,6 @@ class SPIRVInstructionSelector : public InstructionSelector {
 
   Register buildZerosVal(const SPIRVType *ResType, MachineInstr &I) const;
   Register buildZerosValF(const SPIRVType *ResType, MachineInstr &I) const;
-  static APFloat getZeroFP(const Type *LLVMFloatTy);
   Register buildOnesVal(bool AllOnes, const SPIRVType *ResType,
                         MachineInstr &I) const;
 
@@ -1169,8 +1168,8 @@ bool SPIRVInstructionSelector::selectAll(Register ResVReg,
   MachineBasicBlock &BB = *I.getParent();
   Register InputRegister = I.getOperand(2).getReg();
   SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister);
-  
-  if(!InputType)
+
+  if (!InputType)
     report_fatal_error("Input Type could not be determined.");
 
   bool IsBoolTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeBool);
@@ -1192,7 +1191,8 @@ bool SPIRVInstructionSelector::selectAll(Register ResVReg,
   Register NotEqualReg = ResVReg;
 
   if (IsVectorTy) {
-    NotEqualReg = IsBoolTy ? InputRegister : MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    NotEqualReg = IsBoolTy ? InputRegister
+                           : MRI->createVirtualRegister(&SPIRV::IDRegClass);
     const unsigned NumElts = InputType->getOperand(2).getImm();
     SpvBoolTy = GR.getOrCreateSPIRVVectorType(SpvBoolTy, NumElts, I, TII);
   }
@@ -1462,7 +1462,7 @@ Register SPIRVInstructionSelector::buildZerosVal(const SPIRVType *ResType,
   return GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull);
 }
 
-APFloat SPIRVInstructionSelector::getZeroFP(const Type *LLVMFloatTy) {
+static APFloat getZeroFP(const Type *LLVMFloatTy) {
   if (!LLVMFloatTy)
     return APFloat::getZero(APFloat::IEEEsingle());
   switch (LLVMFloatTy->getScalarType()->getTypeID()) {



More information about the llvm-commits mailing list