[llvm] [DX] Support pipeline state masks (PR #66425)

Chris B via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 14 16:44:56 PDT 2023


================
@@ -321,6 +321,68 @@ Error DirectX::PSVRuntimeInfo::parse(uint16_t ShaderKind) {
     Current += PSize;
   }
 
+  ArrayRef<uint8_t> OutputVectorCounts = getOutputVectorCounts();
+  uint8_t PatchConstOrPrimVectorCount = getPatchConstOrPrimVectorCount();
+  uint8_t InputVectorCount = getInputVectorCount();
+
+  auto maskDwordSize = [](uint8_t Vector) {
+    return (static_cast<uint32_t>(Vector) + 7) >> 3;
+  };
+
+  auto mapTableSize = [maskDwordSize](uint8_t X, uint8_t Y) {
+    return maskDwordSize(Y) * X * 4;
+  };
+
+  if (usesViewID()) {
+    for (uint32_t I = 0; I < 4; ++I) {
+      // The vector mask is one bit per component and 4 components per vector.
+      // We can compute the number of dwords required by rounding up to the next
+      // multiple of 8.
+      uint32_t NumDwords =
+          maskDwordSize(static_cast<uint32_t>(OutputVectorCounts[I]));
+      size_t NumBytes = NumDwords * sizeof(uint32_t);
+      OutputVectorMasks[I].Data = Data.substr(Current - Data.begin(), NumBytes);
+      Current += NumBytes;
+    }
+
+    if (ShaderStage == Triple::Hull && PatchConstOrPrimVectorCount > 0) {
+      uint32_t NumDwords = maskDwordSize(PatchConstOrPrimVectorCount);
+      size_t NumBytes = NumDwords * sizeof(uint32_t);
+      PatchOrPrimMasks.Data = Data.substr(Current - Data.begin(), NumBytes);
+      Current += NumBytes;
+    }
+  }
+
+  // Input/Output mapping table
+  for (uint32_t I = 0; I < 4; ++I) {
+    if (InputVectorCount == 0 || OutputVectorCounts[I] == 0)
+      continue;
+    uint32_t NumDwords = mapTableSize(InputVectorCount, OutputVectorCounts[I]);
+    size_t NumBytes = NumDwords * sizeof(uint32_t);
+    InputOutputMap[I].Data = Data.substr(Current - Data.begin(), NumBytes);
+    Current += NumBytes;
+  }
+
+  // Hull shader: Input/Patch mapping table
+  if (ShaderStage == Triple::Hull && PatchConstOrPrimVectorCount > 0 &&
+      InputVectorCount > 0) {
+    uint32_t NumDwords =
+        mapTableSize(InputVectorCount, PatchConstOrPrimVectorCount);
+    size_t NumBytes = NumDwords * sizeof(uint32_t);
+    InputPatchMap.Data = Data.substr(Current - Data.begin(), NumBytes);
+    Current += NumBytes;
+  }
+
+  // Domain Shader: Patch/Output mapping table
+  if (ShaderStage == Triple::Domain && PatchConstOrPrimVectorCount > 0 &&
+      OutputVectorCounts[0] > 0) {
----------------
llvm-beanz wrote:

I like the refactoring idea here. I could also add an accessor that performs the stage check and returns 0 for stages where the field doesn't apply. A non-zero value can't be treated as an error, because the field lives in a union that stores other data for other stages.

I would very much like to reconsider how all of this is encoded in the future, because it goes to great lengths to save a few bits here and there but is _really_ complex as a result.

https://github.com/llvm/llvm-project/pull/66425


More information about the llvm-commits mailing list