[llvm] [DX] Support pipeline state masks (PR #66425)
Chris B via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 14 16:44:56 PDT 2023
================
@@ -321,6 +321,68 @@ Error DirectX::PSVRuntimeInfo::parse(uint16_t ShaderKind) {
Current += PSize;
}
+ ArrayRef<uint8_t> OutputVectorCounts = getOutputVectorCounts();
+ uint8_t PatchConstOrPrimVectorCount = getPatchConstOrPrimVectorCount();
+ uint8_t InputVectorCount = getInputVectorCount();
+
+ auto maskDwordSize = [](uint8_t Vector) {
+ return (static_cast<uint32_t>(Vector) + 7) >> 3;
+ };
+
+ auto mapTableSize = [maskDwordSize](uint8_t X, uint8_t Y) {
+ return maskDwordSize(Y) * X * 4;
+ };
+
+ if (usesViewID()) {
+ for (uint32_t I = 0; I < 4; ++I) {
+ // The vector mask is one bit per component and 4 components per vector.
+ // We can compute the number of dwords required by rounding up to the next
+ // multiple of 8.
+ uint32_t NumDwords =
+ maskDwordSize(static_cast<uint32_t>(OutputVectorCounts[I]));
+ size_t NumBytes = NumDwords * sizeof(uint32_t);
+ OutputVectorMasks[I].Data = Data.substr(Current - Data.begin(), NumBytes);
+ Current += NumBytes;
+ }
+
+ if (ShaderStage == Triple::Hull && PatchConstOrPrimVectorCount > 0) {
+ uint32_t NumDwords = maskDwordSize(PatchConstOrPrimVectorCount);
+ size_t NumBytes = NumDwords * sizeof(uint32_t);
+ PatchOrPrimMasks.Data = Data.substr(Current - Data.begin(), NumBytes);
+ Current += NumBytes;
+ }
+ }
+
+ // Input/Output mapping table
+ for (uint32_t I = 0; I < 4; ++I) {
+ if (InputVectorCount == 0 || OutputVectorCounts[I] == 0)
+ continue;
+ uint32_t NumDwords = mapTableSize(InputVectorCount, OutputVectorCounts[I]);
+ size_t NumBytes = NumDwords * sizeof(uint32_t);
+ InputOutputMap[I].Data = Data.substr(Current - Data.begin(), NumBytes);
+ Current += NumBytes;
+ }
+
+ // Hull shader: Input/Patch mapping table
+ if (ShaderStage == Triple::Hull && PatchConstOrPrimVectorCount > 0 &&
+ InputVectorCount > 0) {
+ uint32_t NumDwords =
+ mapTableSize(InputVectorCount, PatchConstOrPrimVectorCount);
+ size_t NumBytes = NumDwords * sizeof(uint32_t);
+ InputPatchMap.Data = Data.substr(Current - Data.begin(), NumBytes);
+ Current += NumBytes;
+ }
+
+ // Domain Shader: Patch/Output mapping table
+ if (ShaderStage == Triple::Domain && PatchConstOrPrimVectorCount > 0 &&
+ OutputVectorCounts[0] > 0) {
----------------
llvm-beanz wrote:
I like the refactoring idea here. I could also add an accessor that performs the stage check and returns 0 for stages where the field does not apply. A non-zero value can't be treated as an error, because the field lives in a union that stores other data for other stages.
In the future I would very much like to reconsider how all of this is encoded: the format goes to great lengths to save a few bits here and there, but it is _really_ complex as a result.
https://github.com/llvm/llvm-project/pull/66425
More information about the llvm-commits mailing list