Unify rebinding resources.

This commit is contained in:
Vecvec 2025-06-22 18:51:15 +12:00 committed by Teodor Tanasoaia
parent 4d36a02691
commit acd8cb18bb
3 changed files with 97 additions and 113 deletions

View File

@ -13,9 +13,7 @@ use crate::{
bind::{Binder, BinderError},
compute_command::ArcComputeCommand,
end_pipeline_statistics_query,
memory_init::{
fixup_discarded_surfaces, SurfacesInDiscardState,
},
memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
pass_base, pass_try, validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites,
BasePass, BindGroupStateChange, CommandBuffer, CommandEncoderError, MapPassErr,
PassErrorScope, PassTimestampWrites, QueryUseError, StateChange,
@ -723,7 +721,8 @@ fn set_pipeline(
.general
.tracker
.compute_pipelines
.insert_single(pipeline);
.insert_single(pipeline)
.clone();
unsafe {
state
@ -733,67 +732,29 @@ fn set_pipeline(
}
// Rebind resources
if state.general.binder.pipeline_layout.is_none()
|| !state
.general
.binder
.pipeline_layout
.as_ref()
.unwrap()
.is_equal(&pipeline.layout)
{
let (start_index, entries) = state
.general
.binder
.change_pipeline_layout(&pipeline.layout, &pipeline.late_sized_buffer_groups);
if !entries.is_empty() {
for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() {
let raw_bg = group.try_raw(state.general.snatch_guard)?;
unsafe {
state.general.raw_encoder.set_bind_group(
pipeline.layout.raw(),
start_index as u32 + i as u32,
Some(raw_bg),
&e.dynamic_offsets,
);
}
}
pass::rebind_resources::<ComputePassErrorInner, _>(
&mut state.general,
&pipeline.layout,
wgt::ShaderStages::COMPUTE,
&pipeline.late_sized_buffer_groups,
|| {
// This only needs to be here for compute pipelines because they use push constants for
// validating indirect draws.
state.push_constants.clear();
// Note that can only be one range for each stage. See the `MoreThanOnePushConstantRangePerStage` error.
if let Some(push_constant_range) =
pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
pcr.stages
.contains(wgt::ShaderStages::COMPUTE)
.then_some(pcr.range.clone())
})
{
// Note that non-0 range start doesn't work anyway https://github.com/gfx-rs/wgpu/issues/4502
let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
state.push_constants.extend(core::iter::repeat_n(0, len));
}
}
// TODO: integrate this in the code below once we simplify push constants
state.push_constants.clear();
// Note that can only be one range for each stage. See the `MoreThanOnePushConstantRangePerStage` error.
if let Some(push_constant_range) =
pipeline.layout.push_constant_ranges.iter().find_map(|pcr| {
pcr.stages
.contains(wgt::ShaderStages::COMPUTE)
.then_some(pcr.range.clone())
})
{
// Note that non-0 range start doesn't work anyway https://github.com/gfx-rs/wgpu/issues/4502
let len = push_constant_range.len() / wgt::PUSH_CONSTANT_ALIGNMENT as usize;
state.push_constants.extend(core::iter::repeat_n(0, len));
}
// Clear push constant ranges
let non_overlapping =
super::bind::compute_nonoverlapping_ranges(&pipeline.layout.push_constant_ranges);
for range in non_overlapping {
let offset = range.range.start;
let size_bytes = range.range.end - offset;
super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
state.general.raw_encoder.set_push_constants(
pipeline.layout.raw(),
wgt::ShaderStages::COMPUTE,
clear_offset,
clear_data,
);
});
}
}
Ok(())
},
)
}
fn set_push_constant(
@ -1047,8 +1008,7 @@ fn dispatch_indirect(
.merge_single(&buffer, wgt::BufferUses::INDIRECT)?;
use crate::resource::Trackable;
state
.flush_states(Some(buffer.tracker_index()))?;
state.flush_states(Some(buffer.tracker_index()))?;
let buf_raw = buffer.try_raw(state.general.snatch_guard)?;
unsafe {

View File

@ -1,16 +1,17 @@
//! Generic pass functions that both compute and render passes need.
use crate::api_log;
use crate::binding_model::{BindError, BindGroup};
use crate::command::bind::Binder;
use crate::command::memory_init::{CommandBufferTextureMemoryActions, SurfacesInDiscardState};
use crate::command::CommandBuffer;
use crate::device::{Device, DeviceError};
use crate::init_tracker::BufferInitTrackerAction;
use crate::pipeline::LateSizedBufferGroup;
use crate::ray_tracing::AsAction;
use crate::resource::{DestroyedResourceError, Labeled, ParentDevice};
use crate::snatch::SnatchGuard;
use crate::track::{ResourceUsageCompatibilityError, Tracker, UsageScope};
use crate::{api_log, binding_model};
use alloc::sync::Arc;
use alloc::vec::Vec;
use thiserror::Error;
@ -157,3 +158,63 @@ where
}
Ok(())
}
/// After a pipeline has been changed, resources must be rebound
pub(crate) fn rebind_resources<E, F: FnOnce()>(
state: &mut BaseState,
pipeline_layout: &Arc<binding_model::PipelineLayout>,
stages: wgt::ShaderStages,
late_sized_buffer_groups: &[LateSizedBufferGroup],
f: F,
) -> Result<(), E>
where
E: From<DestroyedResourceError>,
{
if state.binder.pipeline_layout.is_none()
|| !state
.binder
.pipeline_layout
.as_ref()
.unwrap()
.is_equal(pipeline_layout)
{
let (start_index, entries) = state
.binder
.change_pipeline_layout(pipeline_layout, late_sized_buffer_groups);
if !entries.is_empty() {
for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() {
let raw_bg = group.try_raw(state.snatch_guard)?;
unsafe {
state.raw_encoder.set_bind_group(
pipeline_layout.raw(),
start_index as u32 + i as u32,
Some(raw_bg),
&e.dynamic_offsets,
);
}
}
}
}
f();
let non_overlapping =
super::bind::compute_nonoverlapping_ranges(&pipeline_layout.push_constant_ranges);
// Clear push constant ranges
for range in non_overlapping {
let offset = range.range.start;
let size_bytes = range.range.end - offset;
super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
state.raw_encoder.set_push_constants(
pipeline_layout.raw(),
stages,
clear_offset,
clear_data,
);
});
}
}
Ok(())
}

View File

@ -2194,7 +2194,8 @@ fn set_pipeline(
.general
.tracker
.render_pipelines
.insert_single(pipeline);
.insert_single(pipeline)
.clone();
pipeline.same_device_as(cmd_buf.as_ref())?;
@ -2234,51 +2235,13 @@ fn set_pipeline(
}
// Rebind resource
if state.general.binder.pipeline_layout.is_none()
|| !state
.general
.binder
.pipeline_layout
.as_ref()
.unwrap()
.is_equal(&pipeline.layout)
{
let (start_index, entries) = state
.general
.binder
.change_pipeline_layout(&pipeline.layout, &pipeline.late_sized_buffer_groups);
if !entries.is_empty() {
for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() {
let raw_bg = group.try_raw(state.general.snatch_guard)?;
unsafe {
state.general.raw_encoder.set_bind_group(
pipeline.layout.raw(),
start_index as u32 + i as u32,
Some(raw_bg),
&e.dynamic_offsets,
);
}
}
}
}
// Clear push constant ranges
let non_overlapping =
super::bind::compute_nonoverlapping_ranges(&pipeline.layout.push_constant_ranges);
for range in non_overlapping {
let offset = range.range.start;
let size_bytes = range.range.end - offset;
super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
state.general.raw_encoder.set_push_constants(
pipeline.layout.raw(),
range.stages,
clear_offset,
clear_data,
);
});
}
}
pass::rebind_resources::<RenderPassErrorInner, _>(
&mut state.general,
&pipeline.layout,
ShaderStages::VERTEX_FRAGMENT,
&pipeline.late_sized_buffer_groups,
|| {},
)?;
// Update vertex buffer limits.
state.vertex.update_limits(&pipeline.vertex_steps);