Switch to pipeline stream desc for dx12 pipeline creation (#8377)

Co-authored-by: Inner Daemons <magnus.larsson.mn@gmail.com>
Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
This commit is contained in:
Inner Daemons 2025-11-14 16:35:28 -05:00 committed by GitHub
parent 71820eef20
commit 92fa99af1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 426 additions and 198 deletions

View File

@ -26,8 +26,9 @@ use crate::{
dxgi::{name::ObjectExt as _, result::HResult as _},
},
dx12::{
borrow_optional_interface_temporarily, shader_compilation, suballocation, DCompLib,
DynamicStorageBufferOffsets, Event, ShaderCacheKey, ShaderCacheValue,
borrow_optional_interface_temporarily, pipeline_desc::RenderPipelineStateStreamDesc,
shader_compilation, suballocation, DCompLib, DynamicStorageBufferOffsets, Event,
ShaderCacheKey, ShaderCacheValue,
},
AccelerationStructureEntries, TlasInstance,
};
@ -1866,8 +1867,6 @@ impl crate::Device for super::Device {
>,
) -> Result<super::RenderPipeline, crate::PipelineError> {
let mut shader_stages = wgt::ShaderStages::empty();
let root_signature =
unsafe { borrow_optional_interface_temporarily(&desc.layout.shared.signature) };
let (topology_class, topology) = conv::map_topology(desc.primitive.topology);
let mut rtv_formats = [Dxgi::Common::DXGI_FORMAT_UNKNOWN;
Direct3D12::D3D12_SIMULTANEOUS_RENDER_TARGET_COUNT as usize];
@ -1907,6 +1906,7 @@ impl crate::Device for super::Device {
Direct3D12::D3D12_CONSERVATIVE_RASTERIZATION_MODE_OFF
},
};
let blob_fs = match desc.fragment_stage {
Some(ref stage) => {
shader_stages |= wgt::ShaderStages::FRAGMENT;
@ -1918,7 +1918,6 @@ impl crate::Device for super::Device {
Some(shader) => shader.create_native_shader(),
None => Direct3D12::D3D12_SHADER_BYTECODE::default(),
};
let mut vertex_strides = [None; crate::MAX_VERTEX_BUFFERS];
let stream_output = Direct3D12::D3D12_STREAM_OUTPUT_DESC {
pSODeclaration: ptr::null(),
NumEntries: 0,
@ -1953,20 +1952,51 @@ impl crate::Device for super::Device {
};
let flags = Direct3D12::D3D12_PIPELINE_STATE_FLAG_NONE;
let raw: Direct3D12::ID3D12PipelineState = match &desc.vertex_processor {
let mut stream_desc = RenderPipelineStateStreamDesc {
// Shared by vertex and mesh pipelines
root_signature: desc.layout.shared.signature.as_ref(),
pixel_shader,
blend_state,
sample_mask: desc.multisample.mask as u32,
rasterizer_state,
depth_stencil_state,
primitive_topology_type: topology_class,
rtv_formats: Direct3D12::D3D12_RT_FORMAT_ARRAY {
RTFormats: rtv_formats,
NumRenderTargets: desc.color_targets.len() as u32,
},
dsv_format,
sample_desc,
node_mask: 0,
cached_pso,
flags,
// Optional data that depends on the pipeline type (vertex vs mesh).
vertex_shader: Default::default(),
input_layout: Default::default(),
index_buffer_strip_cut_value: Default::default(),
stream_output,
task_shader: Default::default(),
mesh_shader: Default::default(),
};
let mut input_element_descs = Vec::new();
let blob_vs;
let blob_ts;
let blob_ms;
let mut vertex_strides = [None; crate::MAX_VERTEX_BUFFERS];
match &desc.vertex_processor {
&crate::VertexProcessor::Standard {
vertex_buffers,
ref vertex_stage,
} => {
shader_stages |= wgt::ShaderStages::VERTEX;
let blob_vs = self.load_shader(
blob_vs = Some(self.load_shader(
vertex_stage,
desc.layout,
naga::ShaderStage::Vertex,
desc.fragment_stage.as_ref(),
)?;
)?);
let mut input_element_descs = Vec::new();
for (i, (stride, vbuf)) in vertex_strides.iter_mut().zip(vertex_buffers).enumerate()
{
*stride = Some(vbuf.array_stride as u32);
@ -1990,54 +2020,37 @@ impl crate::Device for super::Device {
});
}
}
let raw_desc = Direct3D12::D3D12_GRAPHICS_PIPELINE_STATE_DESC {
pRootSignature: root_signature,
VS: blob_vs.create_native_shader(),
PS: pixel_shader,
GS: Direct3D12::D3D12_SHADER_BYTECODE::default(),
DS: Direct3D12::D3D12_SHADER_BYTECODE::default(),
HS: Direct3D12::D3D12_SHADER_BYTECODE::default(),
StreamOutput: stream_output,
BlendState: blend_state,
SampleMask: desc.multisample.mask as u32,
RasterizerState: rasterizer_state,
DepthStencilState: depth_stencil_state,
InputLayout: Direct3D12::D3D12_INPUT_LAYOUT_DESC {
pInputElementDescs: if input_element_descs.is_empty() {
ptr::null()
} else {
input_element_descs.as_ptr()
},
NumElements: input_element_descs.len() as u32,
stream_desc.vertex_shader = blob_vs.as_ref().unwrap().create_native_shader();
stream_desc.input_layout = Direct3D12::D3D12_INPUT_LAYOUT_DESC {
pInputElementDescs: if input_element_descs.is_empty() {
ptr::null()
} else {
input_element_descs.as_ptr()
},
IBStripCutValue: match desc.primitive.strip_index_format {
Some(wgt::IndexFormat::Uint16) => {
Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_0xFFFF
}
Some(wgt::IndexFormat::Uint32) => {
Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_0xFFFFFFFF
}
None => Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_DISABLED,
},
PrimitiveTopologyType: topology_class,
NumRenderTargets: desc.color_targets.len() as u32,
RTVFormats: rtv_formats,
DSVFormat: dsv_format,
SampleDesc: sample_desc,
NodeMask: 0,
CachedPSO: cached_pso,
Flags: flags,
NumElements: input_element_descs.len() as u32,
};
stream_desc.index_buffer_strip_cut_value = match desc.primitive.strip_index_format {
Some(wgt::IndexFormat::Uint16) => {
Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_0xFFFF
}
Some(wgt::IndexFormat::Uint32) => {
Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_0xFFFFFFFF
}
None => Direct3D12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_DISABLED,
};
stream_desc.stream_output = Direct3D12::D3D12_STREAM_OUTPUT_DESC {
pSODeclaration: ptr::null(),
NumEntries: 0,
pBufferStrides: ptr::null(),
NumStrides: 0,
RasterizedStream: 0,
};
unsafe {
profiling::scope!("ID3D12Device::CreateGraphicsPipelineState");
self.raw.CreateGraphicsPipelineState(&raw_desc)
}
}
crate::VertexProcessor::Mesh {
task_stage,
mesh_stage,
} => {
let blob_ts = if let Some(ts) = task_stage {
blob_ts = if let Some(ts) = task_stage {
shader_stages |= wgt::ShaderStages::TASK;
Some(self.load_shader(
ts,
@ -2054,48 +2067,36 @@ impl crate::Device for super::Device {
Default::default()
};
shader_stages |= wgt::ShaderStages::MESH;
let blob_ms = self.load_shader(
blob_ms = Some(self.load_shader(
mesh_stage,
desc.layout,
naga::ShaderStage::Mesh,
desc.fragment_stage.as_ref(),
)?;
let desc = super::MeshShaderPipelineStateStream {
root_signature: root_signature
.as_ref()
.map(|a| a.as_raw().cast())
.unwrap_or(ptr::null_mut()),
task_shader,
pixel_shader,
mesh_shader: blob_ms.create_native_shader(),
blend_state,
sample_mask: desc.multisample.mask as u32,
rasterizer_state,
depth_stencil_state,
primitive_topology_type: topology_class,
rtv_formats: Direct3D12::D3D12_RT_FORMAT_ARRAY {
RTFormats: rtv_formats,
NumRenderTargets: desc.color_targets.len() as u32,
},
dsv_format,
sample_desc,
node_mask: 0,
cached_pso,
flags,
};
let mut raw_desc = unsafe { desc.to_bytes() };
let stream_desc = Direct3D12::D3D12_PIPELINE_STATE_STREAM_DESC {
SizeInBytes: raw_desc.len(),
pPipelineStateSubobjectStream: raw_desc.as_mut_ptr().cast(),
};
let device: Direct3D12::ID3D12Device2 = self.raw.cast().unwrap();
)?);
stream_desc.task_shader = task_shader;
stream_desc.mesh_shader = blob_ms.as_ref().unwrap().create_native_shader();
}
};
let raw: Direct3D12::ID3D12PipelineState =
// If stream descriptors are available, use them as they are more flexible.
if let Ok(device) = self.raw.cast::<Direct3D12::ID3D12Device2>() {
// Prefer stream descs where possible
let mut stream = stream_desc.to_stream();
unsafe {
profiling::scope!("ID3D12Device2::CreatePipelineState");
device.CreatePipelineState(&stream_desc)
stream.create_pipeline_state(&device).map_err(|err| {
crate::PipelineError::Linkage(shader_stages, err.to_string())
})?
}
}
}
.map_err(|err| crate::PipelineError::Linkage(shader_stages, err.to_string()))?;
} else {
unsafe {
// Safety: `stream_desc` entirely outlives the `desc`.
let desc = stream_desc.to_graphics_pipeline_descriptor();
self.raw.CreateGraphicsPipelineState(&desc).map_err(|err| {
crate::PipelineError::Linkage(shader_stages, err.to_string())
})?
}
};
if let Some(label) = desc.label {
raw.set_name(label)?;

View File

@ -79,6 +79,7 @@ mod dcomp;
mod descriptor;
mod device;
mod instance;
mod pipeline_desc;
mod sampler;
mod shader_compilation;
mod suballocation;
@ -1615,116 +1616,3 @@ pub enum ShaderModuleSource {
DxilPassthrough(DxilPassthroughShader),
HlslPassthrough(HlslPassthroughShader),
}
#[repr(C)]
#[derive(Debug)]
struct MeshShaderPipelineStateStream {
root_signature: *mut Direct3D12::ID3D12RootSignature,
task_shader: Direct3D12::D3D12_SHADER_BYTECODE,
mesh_shader: Direct3D12::D3D12_SHADER_BYTECODE,
pixel_shader: Direct3D12::D3D12_SHADER_BYTECODE,
blend_state: Direct3D12::D3D12_BLEND_DESC,
sample_mask: u32,
rasterizer_state: Direct3D12::D3D12_RASTERIZER_DESC,
depth_stencil_state: Direct3D12::D3D12_DEPTH_STENCIL_DESC,
primitive_topology_type: Direct3D12::D3D12_PRIMITIVE_TOPOLOGY_TYPE,
rtv_formats: Direct3D12::D3D12_RT_FORMAT_ARRAY,
dsv_format: Dxgi::Common::DXGI_FORMAT,
sample_desc: Dxgi::Common::DXGI_SAMPLE_DESC,
node_mask: u32,
cached_pso: Direct3D12::D3D12_CACHED_PIPELINE_STATE,
flags: Direct3D12::D3D12_PIPELINE_STATE_FLAGS,
}
impl MeshShaderPipelineStateStream {
/// # Safety
///
/// Returned bytes contain pointers into this struct, for them to be valid,
/// this struct may be at the same location. As if `as_bytes<'a>(&'a self) -> Vec<u8> + 'a`
pub unsafe fn to_bytes(&self) -> Vec<u8> {
use Direct3D12::*;
let mut bytes = Vec::new();
macro_rules! push_subobject {
($subobject_type:expr, $data:expr) => {{
// Ensure 8-byte alignment for the subobject start
let alignment = 8;
let aligned_length = bytes.len().next_multiple_of(alignment);
bytes.resize(aligned_length, 0);
// Append the type tag (u32)
let tag: u32 = $subobject_type.0 as u32;
bytes.extend_from_slice(&tag.to_ne_bytes());
// Align the data
let obj_align = align_of_val(&$data);
let data_start = bytes.len().next_multiple_of(obj_align);
bytes.resize(data_start, 0);
// Append the data itself
#[allow(clippy::ptr_as_ptr, trivial_casts)]
let data_ptr = &$data as *const _ as *const u8;
let data_size = size_of_val(&$data);
let slice = unsafe { core::slice::from_raw_parts(data_ptr, data_size) };
bytes.extend_from_slice(slice);
}};
}
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_ROOT_SIGNATURE,
self.root_signature
);
if !self.task_shader.pShaderBytecode.is_null() {
push_subobject!(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_AS, self.task_shader);
}
push_subobject!(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_MS, self.mesh_shader);
if !self.pixel_shader.pShaderBytecode.is_null() {
push_subobject!(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PS, self.pixel_shader);
}
push_subobject!(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_BLEND, self.blend_state);
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_MASK,
self.sample_mask
);
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RASTERIZER,
self.rasterizer_state
);
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL,
self.depth_stencil_state
);
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PRIMITIVE_TOPOLOGY,
self.primitive_topology_type
);
if self.rtv_formats.NumRenderTargets != 0 {
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RENDER_TARGET_FORMATS,
self.rtv_formats
);
}
if self.dsv_format != Dxgi::Common::DXGI_FORMAT_UNKNOWN {
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL_FORMAT,
self.dsv_format
);
}
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_DESC,
self.sample_desc
);
if self.node_mask != 0 {
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_NODE_MASK,
self.node_mask
);
}
if !self.cached_pso.pCachedBlob.is_null() {
push_subobject!(
D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_CACHED_PSO,
self.cached_pso
);
}
push_subobject!(D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_FLAGS, self.flags);
bytes
}
}

View File

@ -0,0 +1,339 @@
//! We try to use pipeline stream descriptors where possible, but this isn't allowed
//! on some older windows 10 versions. Therefore, we also must have some logic to
//! convert such descriptors to the "traditional" equivalent,
//! `D3D12_GRAPHICS_PIPELINE_STATE_DESC`.
//!
//! Stream descriptors allow extending the pipeline, enabling more advanced features,
//! including mesh shaders and multiview/view instancing. Using a stream descriptor
//! is like using a vulkan descriptor with a `pNext` chain. It doesn't have direct
//! benefits to all use cases, but allows new use cases.
//!
//! The code for pipeline stream descriptors is very complicated, and can have bad
//! consequences if it is written incorrectly. It has been isolated to this file for
//! that reason.
use core::{ffi::c_void, mem::ManuallyDrop, ptr::NonNull};
use alloc::vec::Vec;
use windows::Win32::Graphics::Direct3D12::*;
use windows::Win32::Graphics::Dxgi::Common::*;
use windows_core::Interface;
use crate::dx12::borrow_interface_temporarily;
// Wrapper newtypes for various pipeline subobjects which
// use complicated or non-unique representations.
#[repr(transparent)]
#[derive(Copy, Clone)]
// Option<NonNull<c_void>> is guaranteed to have the same representation as a raw pointer.
struct RootSignature(Option<NonNull<c_void>>);
#[repr(transparent)]
#[derive(Copy, Clone)]
struct VertexShader(D3D12_SHADER_BYTECODE);
#[repr(transparent)]
#[derive(Copy, Clone)]
struct PixelShader(D3D12_SHADER_BYTECODE);
#[repr(transparent)]
#[derive(Copy, Clone)]
struct MeshShader(D3D12_SHADER_BYTECODE);
#[repr(transparent)]
#[derive(Copy, Clone)]
struct TaskShader(D3D12_SHADER_BYTECODE);
#[repr(transparent)]
#[derive(Copy, Clone)]
struct SampleMask(u32);
#[repr(transparent)]
#[derive(Copy, Clone)]
struct NodeMask(u32);
/// Trait for types that can be used as subobjects in a pipeline state stream.
///
/// Safety:
/// - The type must be the correct alignment and size for the subobject it represents.
/// - The type must map to exactly one `D3D12_PIPELINE_STATE_SUBOBJECT_TYPE` variant.
/// - The variant must correctly represent the type's role in the pipeline state stream.
/// - The type must be `Copy` to ensure safe duplication in the stream.
/// - The type must be valid to memcpy into the pipeline state stream.
unsafe trait RenderPipelineStreamObject: Copy {
const SUBOBJECT_TYPE: D3D12_PIPELINE_STATE_SUBOBJECT_TYPE;
}
macro_rules! implement_stream_object {
(unsafe $ty:ty => $variant:expr) => {
unsafe impl RenderPipelineStreamObject for $ty {
const SUBOBJECT_TYPE: D3D12_PIPELINE_STATE_SUBOBJECT_TYPE = $variant;
}
};
}
implement_stream_object! { unsafe RootSignature => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_ROOT_SIGNATURE }
implement_stream_object! { unsafe VertexShader => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_VS }
implement_stream_object! { unsafe PixelShader => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PS }
implement_stream_object! { unsafe MeshShader => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_MS }
implement_stream_object! { unsafe TaskShader => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_AS }
implement_stream_object! { unsafe D3D12_BLEND_DESC => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_BLEND }
implement_stream_object! { unsafe SampleMask => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_MASK }
implement_stream_object! { unsafe D3D12_RASTERIZER_DESC => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RASTERIZER }
implement_stream_object! { unsafe D3D12_DEPTH_STENCIL_DESC => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL }
implement_stream_object! { unsafe D3D12_PRIMITIVE_TOPOLOGY_TYPE => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_PRIMITIVE_TOPOLOGY }
implement_stream_object! { unsafe D3D12_RT_FORMAT_ARRAY => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_RENDER_TARGET_FORMATS }
implement_stream_object! { unsafe DXGI_FORMAT => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_DEPTH_STENCIL_FORMAT }
implement_stream_object! { unsafe DXGI_SAMPLE_DESC => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_SAMPLE_DESC }
implement_stream_object! { unsafe NodeMask => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_NODE_MASK }
implement_stream_object! { unsafe D3D12_CACHED_PIPELINE_STATE => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_CACHED_PSO }
implement_stream_object! { unsafe D3D12_PIPELINE_STATE_FLAGS => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_FLAGS }
implement_stream_object! { unsafe D3D12_INPUT_LAYOUT_DESC => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_INPUT_LAYOUT }
implement_stream_object! { unsafe D3D12_INDEX_BUFFER_STRIP_CUT_VALUE => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_IB_STRIP_CUT_VALUE }
implement_stream_object! { unsafe D3D12_STREAM_OUTPUT_DESC => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE_STREAM_OUTPUT }
/// Implementaation of a pipeline state stream, which is a sequence of subobjects put into
/// a byte array according to some basic alignment rules.
///
/// Each subobject must start on an 8 byte boundary. Each subobject contains a 32 bit
/// type identifier, followed by the actual subobject data, aligned as required by the
/// subobject's structure.
///
/// See <https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ns-d3d12-d3d12_pipeline_state_stream_desc>
/// for more information.
pub(super) struct RenderPipelineStateStream<'a> {
bytes: Vec<u8>,
_marker: core::marker::PhantomData<&'a ()>,
}
impl<'a> RenderPipelineStateStream<'a> {
fn new() -> Self {
// Dynamic allocation is used here because the resulting stream can become very large.
// We pre-allocate the size based on an estimate of the size of the struct plus some extra space
// per member for tags and alignment padding. In practice this will always be too big, as not
// all members will be used.
let size_of_stream_desc = size_of::<RenderPipelineStateStreamDesc>();
let members = 20; // Approximate number of members we might push
let capacity = size_of_stream_desc + members * 8; // Extra space for tags and alignment
Self {
bytes: Vec::with_capacity(capacity),
_marker: core::marker::PhantomData,
}
}
/// Align the internal byte buffer to the given alignment,
/// padding with zeros as necessary.
fn align_to(&mut self, alignment: usize) {
let aligned_length = self.bytes.len().next_multiple_of(alignment);
self.bytes.resize(aligned_length, 0);
}
/// Adds a subobject to the pipeline state stream.
fn add_object<T: RenderPipelineStreamObject>(&mut self, object: T) {
// Ensure 8-byte alignment for the subobject start.
self.align_to(8);
// Append the type tag (u32)
let tag: u32 = T::SUBOBJECT_TYPE.0 as u32;
self.bytes.extend_from_slice(&tag.to_ne_bytes());
// Align the data to its natural alignment.
self.align_to(align_of_val::<T>(&object));
// Append the data itself, as raw bytes
let data_ptr: *const T = &object;
let data_u8_ptr: *const u8 = data_ptr.cast::<u8>();
let data_size = size_of_val::<T>(&object);
let slice = unsafe { core::slice::from_raw_parts::<u8>(data_u8_ptr, data_size) };
self.bytes.extend_from_slice(slice);
}
/// Creates a pipeline state object from the stream.
///
/// Safety:
/// - All unsafety invariants required by [`ID3D12Device2::CreatePipelineState`] must be upheld by the caller.
pub unsafe fn create_pipeline_state(
&mut self,
device: &ID3D12Device2,
) -> windows::core::Result<ID3D12PipelineState> {
let stream_desc = D3D12_PIPELINE_STATE_STREAM_DESC {
SizeInBytes: self.bytes.len(),
pPipelineStateSubobjectStream: self.bytes.as_mut_ptr().cast(),
};
// Safety: lifetime on Self preserved the contents
// of the stream. Other unsafety invariants are upheld by the caller.
unsafe { device.CreatePipelineState(&stream_desc) }
}
}
#[repr(C)]
#[derive(Debug)]
pub struct RenderPipelineStateStreamDesc<'a> {
pub root_signature: Option<&'a ID3D12RootSignature>,
pub pixel_shader: D3D12_SHADER_BYTECODE,
pub blend_state: D3D12_BLEND_DESC,
pub sample_mask: u32,
pub rasterizer_state: D3D12_RASTERIZER_DESC,
pub depth_stencil_state: D3D12_DEPTH_STENCIL_DESC,
pub primitive_topology_type: D3D12_PRIMITIVE_TOPOLOGY_TYPE,
pub rtv_formats: D3D12_RT_FORMAT_ARRAY,
pub dsv_format: DXGI_FORMAT,
pub sample_desc: DXGI_SAMPLE_DESC,
pub node_mask: u32,
pub cached_pso: D3D12_CACHED_PIPELINE_STATE,
pub flags: D3D12_PIPELINE_STATE_FLAGS,
// Vertex pipeline specific
pub vertex_shader: D3D12_SHADER_BYTECODE,
pub input_layout: D3D12_INPUT_LAYOUT_DESC,
pub index_buffer_strip_cut_value: D3D12_INDEX_BUFFER_STRIP_CUT_VALUE,
pub stream_output: D3D12_STREAM_OUTPUT_DESC,
// Mesh pipeline specific
pub task_shader: D3D12_SHADER_BYTECODE,
pub mesh_shader: D3D12_SHADER_BYTECODE,
}
impl RenderPipelineStateStreamDesc<'_> {
pub fn to_stream(&self) -> RenderPipelineStateStream<'_> {
let mut stream = RenderPipelineStateStream::new();
// Importantly here, the ID3D12RootSignature _itself_ is the pointer we're
// trying to serialize into the stream, not a pointer to the pointer.
//
// This is correct because as_raw() returns turns that smart object into the raw
// pointer that _is_ the com object handle.
let root_sig_pointer = self
.root_signature
.map(|a| NonNull::new(a.as_raw()).unwrap());
// Because the stream object borrows from self for its entire lifetime,
// it is safe to store the pointer into it.
stream.add_object(RootSignature(root_sig_pointer));
stream.add_object(self.blend_state);
stream.add_object(SampleMask(self.sample_mask));
stream.add_object(self.rasterizer_state);
stream.add_object(self.depth_stencil_state);
stream.add_object(self.primitive_topology_type);
if self.rtv_formats.NumRenderTargets != 0 {
stream.add_object(self.rtv_formats);
}
if self.dsv_format != DXGI_FORMAT_UNKNOWN {
stream.add_object(self.dsv_format);
}
stream.add_object(self.sample_desc);
if self.node_mask != 0 {
stream.add_object(NodeMask(self.node_mask));
}
if !self.cached_pso.pCachedBlob.is_null() {
stream.add_object(self.cached_pso);
}
stream.add_object(self.flags);
if !self.pixel_shader.pShaderBytecode.is_null() {
stream.add_object(PixelShader(self.pixel_shader));
}
if !self.vertex_shader.pShaderBytecode.is_null() {
stream.add_object(VertexShader(self.vertex_shader));
stream.add_object(self.input_layout);
stream.add_object(self.index_buffer_strip_cut_value);
stream.add_object(self.stream_output);
}
if !self.task_shader.pShaderBytecode.is_null() {
stream.add_object(TaskShader(self.task_shader));
}
if !self.mesh_shader.pShaderBytecode.is_null() {
stream.add_object(MeshShader(self.mesh_shader));
}
stream
}
/// Returns a traditional D3D12_GRAPHICS_PIPELINE_STATE_DESC.
///
/// Safety:
/// - This returned struct must not outlive self.
pub unsafe fn to_graphics_pipeline_descriptor(&self) -> D3D12_GRAPHICS_PIPELINE_STATE_DESC {
D3D12_GRAPHICS_PIPELINE_STATE_DESC {
pRootSignature: if let Some(rsig) = self.root_signature {
unsafe { borrow_interface_temporarily(rsig) }
} else {
ManuallyDrop::new(None)
},
VS: self.vertex_shader,
PS: self.pixel_shader,
DS: D3D12_SHADER_BYTECODE::default(),
HS: D3D12_SHADER_BYTECODE::default(),
GS: D3D12_SHADER_BYTECODE::default(),
StreamOutput: self.stream_output,
BlendState: self.blend_state,
SampleMask: self.sample_mask,
RasterizerState: self.rasterizer_state,
DepthStencilState: self.depth_stencil_state,
InputLayout: self.input_layout,
IBStripCutValue: self.index_buffer_strip_cut_value,
PrimitiveTopologyType: self.primitive_topology_type,
NumRenderTargets: self.rtv_formats.NumRenderTargets,
RTVFormats: self.rtv_formats.RTFormats,
DSVFormat: self.dsv_format,
SampleDesc: self.sample_desc,
NodeMask: self.node_mask,
CachedPSO: self.cached_pso,
Flags: self.flags,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn wrappers() {
assert_eq!(size_of::<RootSignature>(), size_of::<ID3D12RootSignature>());
assert_eq!(
align_of::<RootSignature>(),
align_of::<ID3D12RootSignature>()
)
}
implement_stream_object!(unsafe u16 => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE(1));
implement_stream_object!(unsafe u32 => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE(2));
implement_stream_object!(unsafe u64 => D3D12_PIPELINE_STATE_SUBOBJECT_TYPE(3));
#[test]
fn stream() {
let mut stream = RenderPipelineStateStream::new();
stream.add_object(42u16);
stream.add_object(84u32);
stream.add_object(168u64);
assert_eq!(stream.bytes.len(), 32);
// Object 1: u16
// Tag at the beginning
assert_eq!(&stream.bytes[0..4], &1u32.to_ne_bytes());
// Data tucked in, aligned to the natural alignment of u16
assert_eq!(&stream.bytes[4..6], &42u16.to_ne_bytes());
// Padding to align the next subobject to an 8 byte boundary.
assert_eq!(&stream.bytes[6..8], &[0, 0]);
// Object 2: u32
// Tag at the beginning
assert_eq!(&stream.bytes[8..12], &2u32.to_ne_bytes());
// Data tucked in, aligned to the natural alignment of u32
assert_eq!(&stream.bytes[12..16], &84u32.to_ne_bytes());
// Object 3: u64
// Tag at the beginning
assert_eq!(&stream.bytes[16..20], &3u32.to_ne_bytes());
// Padding to align the u64 to an 8 byte boundary.
assert_eq!(&stream.bytes[20..24], &[0, 0, 0, 0]);
// Data tucked in, aligned to the natural alignment of u64
assert_eq!(&stream.bytes[24..32], &168u64.to_ne_bytes());
}
}