[clang] [llvm] [HLSL] implement elementwise firstbithigh hlsl builtin (PR #111082)

Sarah Spall via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 30 11:12:36 PDT 2024


================
@@ -2717,82 +2717,82 @@ bool SPIRVInstructionSelector::selectFirstBitHigh64(Register ResVReg,
   Register FBHReg = MRI->createVirtualRegister(GR.getRegClass(postCastT));
   Result &= selectFirstBitHigh32(FBHReg, postCastT, I, bitcastReg, IsSigned);
 
-  // 3. check if result of each top 32 bits is == -1
-  // split result vector into vector of high bits and vector of low bits
-  // get high bits
-  // if ResType is a scalar we need a vector anyways because our code
-  // operates on vectors, even vectors of length one.
-  SPIRVType *VResType = ResType;
+  // 3. split result vector into high bits and low bits
+  Register HighReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+  Register LowReg = MRI->createVirtualRegister(GR.getRegClass(ResType));
+
+  bool ZeroAsNull = STI.isOpenCLEnv();
   bool isScalarRes = ResType->getOpcode() != SPIRV::OpTypeVector;
-  if (isScalarRes)
-    VResType = GR.getOrCreateSPIRVVectorType(ResType, count, MIRBuilder);
-  // count should be one.
+  if (isScalarRes) {
+    // if scalar do a vector extract
+    Result &= selectNAryOpWithSrcs(
+        HighReg, ResType, I,
+        {FBHReg, GR.getOrCreateConstInt(0, I, ResType, TII, ZeroAsNull)},
+        SPIRV::OpVectorExtractDynamic);
+    Result &= selectNAryOpWithSrcs(
+        LowReg, ResType, I,
+        {FBHReg, GR.getOrCreateConstInt(1, I, ResType, TII, ZeroAsNull)},
+        SPIRV::OpVectorExtractDynamic);
+  } else { // vector case do a shufflevector
+    auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+                       TII.get(SPIRV::OpVectorShuffle))
+                   .addDef(HighReg)
+                   .addUse(GR.getSPIRVTypeID(ResType))
+                   .addUse(FBHReg)
+                   .addUse(FBHReg);
+    // ^^ this vector will not be selected from; could be empty
+    unsigned j;
+    for (j = 0; j < count * 2; j += 2) {
+      MIB.addImm(j);
+    }
+    Result &= MIB.constrainAllUses(TII, TRI, RBI);
+
+    // get low bits
+    MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+                  TII.get(SPIRV::OpVectorShuffle))
+              .addDef(LowReg)
+              .addUse(GR.getSPIRVTypeID(ResType))
+              .addUse(FBHReg)
+              .addUse(FBHReg);
+    // ^^ this vector will not be selected from; could be empty
+    for (j = 1; j < count * 2; j += 2) {
+      MIB.addImm(j);
+    }
+    Result &= MIB.constrainAllUses(TII, TRI, RBI);
+  }
+
+  // 4. check if result of each top 32 bits is == -1
+  SPIRVType *BoolType = GR.getOrCreateSPIRVBoolType(I, TII);
+  if (!isScalarRes)
+    BoolType = GR.getOrCreateSPIRVVectorType(BoolType, count, MIRBuilder);
 
-  Register HighReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
-  auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
-                     TII.get(SPIRV::OpVectorShuffle))
-                 .addDef(HighReg)
-                 .addUse(GR.getSPIRVTypeID(VResType))
-                 .addUse(FBHReg)
-                 .addUse(FBHReg);
-  // ^^ this vector will not be selected from; could be empty
-  unsigned j;
-  for (j = 0; j < count * 2; j += 2) {
-    MIB.addImm(j);
-  }
-  Result &= MIB.constrainAllUses(TII, TRI, RBI);
-
-  // get low bits
-  Register LowReg = MRI->createVirtualRegister(GR.getRegClass(VResType));
-  MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
-                TII.get(SPIRV::OpVectorShuffle))
-            .addDef(LowReg)
-            .addUse(GR.getSPIRVTypeID(VResType))
-            .addUse(FBHReg)
-            .addUse(FBHReg);
-  // ^^ this vector will not be selected from; could be empty
-  for (j = 1; j < count * 2; j += 2) {
-    MIB.addImm(j);
-  }
-  Result &= MIB.constrainAllUses(TII, TRI, RBI);
-
-  SPIRVType *BoolType = GR.getOrCreateSPIRVVectorType(
-      GR.getOrCreateSPIRVBoolType(I, TII), count, MIRBuilder);
   // check if the high bits are == -1;
-  Register NegOneReg = GR.getOrCreateConstVector(-1, I, VResType, TII);
+  Register NegOneReg =
+      GR.getOrCreateConstScalarOrVector(-1, I, ResType, TII, ZeroAsNull);
----------------
spall wrote:

Sure. If I do that, I can probably delete the helper function I added in SPIRVGlobalRegistry. I thought this style would be less ugly than a big block.

https://github.com/llvm/llvm-project/pull/111082


More information about the llvm-commits mailing list