Defer some bind group processing until draw/dispatch

This allows us to save some work when a bind group is replaced,
although WebGPU still requires us to fail a submission if any
resource in any bind group that was ever referenced during
encoding is destroyed.

Fixes #8399
Andy Leiserson 2025-10-20 16:39:29 -07:00 committed by Teodor Tanasoaia
parent e8dfc72d04
commit 3ca6a1c0df
9 changed files with 251 additions and 124 deletions
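
The core idea, as a minimal self-contained sketch (the struct and field names below are illustrative stand-ins, not wgpu's actual types): setting a bind group only records the group and widens a dirty range, and the raw backend set_bind_group calls are issued lazily when a draw or dispatch flushes that range.

use std::ops::Range;

#[derive(Default)]
struct Binder {
    slots: Vec<Option<u32>>, // u32 stands in for Arc<BindGroup>
    rebind_start: usize,     // first slot needing a raw rebind
}

impl Binder {
    fn set_bind_group(&mut self, index: usize, group: u32) {
        if self.slots.len() <= index {
            self.slots.resize(index + 1, None);
        }
        self.slots[index] = Some(group);
        // Deferred: no raw encoding here, just widen the dirty range.
        self.rebind_start = self.rebind_start.min(index);
    }

    /// Returns the range of slots to rebind and marks them clean.
    fn take_rebind_range(&mut self) -> Range<usize> {
        let start = self.rebind_start;
        self.rebind_start = self.slots.len();
        start..self.slots.len()
    }
}

fn main() {
    let mut binder = Binder::default();
    binder.set_bind_group(0, 1);
    binder.set_bind_group(0, 2); // replaced before any draw: group 1
                                 // never costs a raw set_bind_group
    let flushed: Vec<usize> = binder.take_rebind_range().collect();
    assert_eq!(flushed, vec![0]); // a single raw rebind at draw time
    assert!(binder.take_rebind_range().is_empty()); // nothing left dirty
}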


@@ -125,7 +125,13 @@ webgpu:api,validation,render_pass,render_pass_descriptor:color_attachments,depth
// FAIL: webgpu:api,validation,render_pass,render_pass_descriptor:color_attachments,depthSlice,overlaps,diff_miplevel:*
webgpu:api,validation,render_pass,render_pass_descriptor:resolveTarget,*
webgpu:api,validation,render_pass,resolve:resolve_attachment:*
webgpu:api,validation,resource_usages,buffer,in_pass_encoder:*
// FAIL: 8 other cases in resource_usages,texture,in_pass_encoder. https://github.com/gfx-rs/wgpu/issues/3126
webgpu:api,validation,resource_usages,texture,in_pass_encoder:scope,*
webgpu:api,validation,resource_usages,texture,in_pass_encoder:shader_stages_and_visibility,*
webgpu:api,validation,resource_usages,texture,in_pass_encoder:subresources_and_binding_types_combination_for_aspect:*
webgpu:api,validation,resource_usages,texture,in_pass_encoder:subresources_and_binding_types_combination_for_color:compute=false;type0="render-target";type1="render-target"
webgpu:api,validation,resource_usages,texture,in_pass_encoder:unused_bindings_in_pipeline:*
webgpu:api,validation,texture,rg11b10ufloat_renderable:*
webgpu:api,operation,render_pipeline,overrides:*
webgpu:api,operation,rendering,basic:clear:*


@@ -1,3 +1,5 @@
use core::{iter::zip, ops::Range};
use alloc::{boxed::Box, sync::Arc, vec::Vec};
use arrayvec::ArrayVec;
@@ -192,14 +194,16 @@ mod compat {
}
#[derive(Debug, Default)]
pub(crate) struct BoundBindGroupLayouts {
pub(super) struct BoundBindGroupLayouts {
entries: ArrayVec<Entry, { hal::MAX_BIND_GROUPS }>,
rebind_start: usize,
}
impl BoundBindGroupLayouts {
pub fn new() -> Self {
Self {
entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(),
rebind_start: 0,
}
}
@@ -211,15 +215,19 @@ mod compat {
.unwrap_or(self.entries.len())
}
fn make_range(&self, start_index: usize) -> Range<usize> {
/// Returns the range of entries that need to be rebound, and clears it.
pub fn take_rebind_range(&mut self) -> Range<usize> {
let end = self.num_valid_entries();
start_index..end.max(start_index)
let start = self.rebind_start;
self.rebind_start = end;
start..end.max(start)
}
pub fn update_expectations(
&mut self,
expectations: &[Arc<BindGroupLayout>],
) -> Range<usize> {
pub fn update_start_index(&mut self, start_index: usize) {
self.rebind_start = self.rebind_start.min(start_index);
}
pub fn update_expectations(&mut self, expectations: &[Arc<BindGroupLayout>]) {
let start_index = self
.entries
.iter()
@@ -237,12 +245,12 @@ mod compat {
for e in self.entries[expectations.len()..].iter_mut() {
e.expected = None;
}
self.make_range(start_index)
self.update_start_index(start_index);
}
pub fn assign(&mut self, index: usize, value: Arc<BindGroupLayout>) -> Range<usize> {
pub fn assign(&mut self, index: usize, value: Arc<BindGroupLayout>) {
self.entries[index].assigned = Some(value);
self.make_range(index)
self.update_start_index(index);
}
pub fn list_active(&self) -> impl Iterator<Item = usize> + '_ {
@@ -333,10 +341,10 @@ impl Binder {
&'a mut self,
new: &Arc<PipelineLayout>,
late_sized_buffer_groups: &[LateSizedBufferGroup],
) -> (usize, &'a [EntryPayload]) {
) {
let old_id_opt = self.pipeline_layout.replace(new.clone());
let mut bind_range = self.manager.update_expectations(&new.bind_group_layouts);
self.manager.update_expectations(&new.bind_group_layouts);
// Update the buffer binding sizes that are required by shaders.
for (payload, late_group) in self.payloads.iter_mut().zip(late_sized_buffer_groups) {
@@ -363,11 +371,9 @@ impl Binder {
if let Some(old) = old_id_opt {
// root constants are the base compatibility property
if old.push_constant_ranges != new.push_constant_ranges {
bind_range.start = 0;
self.manager.update_start_index(0);
}
}
(bind_range.start, &self.payloads[bind_range])
}
pub(super) fn assign_group<'a>(
@@ -375,7 +381,7 @@ impl Binder {
index: usize,
bind_group: &Arc<BindGroup>,
offsets: &[wgt::DynamicOffset],
) -> &'a [EntryPayload] {
) {
let payload = &mut self.payloads[index];
payload.group = Some(bind_group.clone());
payload.dynamic_offsets.clear();
@@ -401,8 +407,20 @@ impl Binder {
}
}
let bind_range = self.manager.assign(index, bind_group.layout.clone());
&self.payloads[bind_range]
self.manager.assign(index, bind_group.layout.clone());
}
/// Returns the range of entries that need to be rebound, and clears it.
pub(super) fn take_rebind_range(&mut self) -> Range<usize> {
self.manager.take_rebind_range()
}
pub(super) fn entries(
&self,
range: Range<usize>,
) -> impl ExactSizeIterator<Item = (usize, &'_ EntryPayload)> + '_ {
let payloads = &self.payloads[range.clone()];
zip(range, payloads)
}
pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup>> + 'a {
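
A rough sketch of how the rebind start is computed when the pipeline layout changes (toy function and values; the real code compares Arc<BindGroupLayout> identity inside update_expectations and change_pipeline_layout): slots stay valid up to the first entry whose assigned layout no longer matches the new pipeline's expectations, and differing push-constant ranges force a rebind from slot 0.

fn first_incompatible(assigned: &[u32], expected: &[u32]) -> usize {
    assigned
        .iter()
        .zip(expected.iter())
        .position(|(a, e)| a != e)
        .unwrap_or(assigned.len().min(expected.len()))
}

fn main() {
    // Layout "ids" stand in for Arc<BindGroupLayout> identity.
    let assigned = [10u32, 11, 12, 13];
    let mut rebind_start = 4usize; // everything currently clean

    // The new pipeline expects the same first two layouts, then diverges.
    let expected = [10u32, 11, 22, 23];
    rebind_start = rebind_start.min(first_incompatible(&assigned, &expected));
    assert_eq!(rebind_start, 2); // only slots 2.. need rebinding

    // Differing push constant ranges are "the base compatibility
    // property", so they invalidate everything from slot 0.
    let push_constants_differ = true;
    if push_constants_differ {
        rebind_start = 0;
    }
    assert_eq!(rebind_start, 0);
}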


@@ -7,7 +7,10 @@ use wgt::{
use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
use core::{convert::Infallible, fmt, str};
use crate::{api_log, binding_model::BindError, resource::RawResourceAccess};
use crate::{
api_log, binding_model::BindError, command::pass::flush_bindings_helper,
resource::RawResourceAccess,
};
use crate::{
binding_model::{LateMinBufferBindingSizeMismatch, PushConstantUploadError},
command::{
@@ -280,27 +283,68 @@ impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
}
}
// `extra_buffer` is there to represent the indirect buffer that is also
// part of the usage scope.
fn flush_states(
/// Flush binding state in preparation for a dispatch.
///
/// # Differences between render and compute passes
///
/// There are differences between the `flush_bindings` implementations for
/// render and compute passes, because render passes have a single usage
/// scope for the entire pass, and compute passes have a separate usage
/// scope for each dispatch.
///
/// For compute passes, bind groups are merged into a fresh usage scope
/// here, not into the pass usage scope within calls to `set_bind_group`. As
/// specified by WebGPU, for compute passes, we merge only the bind groups
/// that are actually used by the pipeline, unlike render passes, which
/// merge every bind group that is ever set, even if it is not ultimately
/// used by the pipeline.
///
/// For compute passes, we call `drain_barriers` here, because barriers may
/// be needed before each dispatch if a previous dispatch had a conflicting
/// usage. For render passes, barriers are emitted once at the start of the
/// render pass.
///
/// # Indirect buffer handling
///
/// For indirect dispatches without validation, pass both `indirect_buffer`
/// and `indirect_buffer_index_if_not_validating`. The indirect buffer will
/// be added to the usage scope and the tracker.
///
/// For indirect dispatches with validation, pass only `indirect_buffer`.
/// The indirect buffer will be added to the usage scope to detect usage
/// conflicts. The indirect buffer does not need to be added to the tracker;
/// the indirect validation code handles transitions manually.
fn flush_bindings(
&mut self,
indirect_buffer: Option<TrackerIndex>,
) -> Result<(), ResourceUsageCompatibilityError> {
for bind_group in self.pass.binder.list_active() {
unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
// Note: stateless trackers are not merged: the lifetime reference
// is held to the bind group itself.
}
indirect_buffer: Option<&Arc<Buffer>>,
indirect_buffer_index_if_not_validating: Option<TrackerIndex>,
) -> Result<(), ComputePassErrorInner> {
let mut scope = self.pass.base.device.new_usage_scope();
for bind_group in self.pass.binder.list_active() {
self.intermediate_trackers
.set_from_bind_group(&mut self.pass.scope, &bind_group.used);
unsafe { scope.merge_bind_group(&bind_group.used)? };
}
// Add the state of the indirect buffer if it hasn't been hit before.
// When indirect validation is turned on, our actual use of the buffer
// is `STORAGE_READ_ONLY`, but for usage scope validation, we still want
// to treat it as indirect so we can detect the conflicts prescribed by
// WebGPU. The usage scope we construct here never leaves this function
// (and is not used to populate a tracker), so it's fine to do this.
if let Some(buffer) = indirect_buffer {
scope
.buffers
.merge_single(buffer, wgt::BufferUses::INDIRECT)?;
}
// Add the state of the indirect buffer, if needed (see above).
self.intermediate_trackers
.buffers
.set_multiple(&mut self.pass.scope.buffers, indirect_buffer);
.set_multiple(&mut scope.buffers, indirect_buffer_index_if_not_validating);
flush_bindings_helper(&mut self.pass, |bind_group| {
self.intermediate_trackers
.set_from_bind_group(&mut scope, &bind_group.used)
})?;
CommandEncoder::drain_barriers(
self.pass.base.raw_encoder,
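
A toy model of why compute passes build a fresh usage scope per dispatch (the types and the compatibility rule below are simplified stand-ins for wgpu's trackers): usages of one resource conflict only within a single dispatch, and the indirect buffer is merged as INDIRECT purely to surface the conflicts WebGPU prescribes.

use std::collections::HashMap;

#[derive(Clone, Copy, PartialEq, Debug)]
enum Use { StorageRead, StorageWrite, Indirect }

fn compatible(a: Use, b: Use) -> bool {
    // Read-only usages may alias; anything involving a write may not.
    a == b && a != Use::StorageWrite
}

#[derive(Default)]
struct UsageScope(HashMap<&'static str, Use>);

impl UsageScope {
    fn merge(&mut self, res: &'static str, u: Use) -> Result<(), String> {
        match self.0.insert(res, u) {
            Some(prev) if !compatible(prev, u) => {
                Err(format!("{res}: {prev:?} conflicts with {u:?}"))
            }
            _ => Ok(()),
        }
    }
}

fn main() {
    // Dispatch 1 writes buffer A; dispatch 2 reads it. No conflict,
    // because each dispatch gets a fresh scope.
    for usage in [Use::StorageWrite, Use::StorageRead] {
        let mut scope = UsageScope::default();
        scope.merge("A", usage).unwrap();
    }

    // Within a single dispatch, write + indirect on one buffer is the
    // kind of conflict the per-dispatch scope exists to catch.
    let mut scope = UsageScope::default();
    scope.merge("A", Use::StorageWrite).unwrap();
    assert!(scope.merge("A", Use::Indirect).is_err());
}
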
@@ -821,7 +865,7 @@ fn set_pipeline(
}
// Rebind resources
pass::rebind_resources::<ComputePassErrorInner, _>(
pass::change_pipeline_layout::<ComputePassErrorInner, _>(
&mut state.pass,
&pipeline.layout,
&pipeline.late_sized_buffer_groups,
@@ -850,7 +894,7 @@ fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorI
state.is_ready()?;
state.flush_states(None)?;
state.flush_bindings(None, None)?;
let groups_size_limit = state
.pass
@@ -1051,7 +1095,7 @@ fn dispatch_indirect(
}]);
}
state.flush_states(None)?;
state.flush_bindings(Some(&buffer), None)?;
unsafe {
state
.pass
@@ -1060,14 +1104,8 @@ fn dispatch_indirect(
.dispatch_indirect(params.dst_buffer, 0);
}
} else {
state
.pass
.scope
.buffers
.merge_single(&buffer, wgt::BufferUses::INDIRECT)?;
use crate::resource::Trackable;
state.flush_states(Some(buffer.tracker_index()))?;
state.flush_bindings(Some(&buffer), Some(buffer.tracker_index()))?;
let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
unsafe {


@@ -7,7 +7,6 @@ use crate::command::memory_init::SurfacesInDiscardState;
use crate::command::{DebugGroupError, QueryResetMap, QueryUseError};
use crate::device::{Device, DeviceError, MissingFeatures};
use crate::pipeline::LateSizedBufferGroup;
use crate::ray_tracing::AsAction;
use crate::resource::{DestroyedResourceError, Labeled, ParentDevice, QuerySet};
use crate::track::{ResourceUsageCompatibilityError, UsageScope};
use crate::{api_log, binding_model};
@@ -100,12 +99,15 @@ where
);
state.dynamic_offset_count += num_dynamic_offsets;
if bind_group.is_none() {
let Some(bind_group) = bind_group else {
// TODO: Handle bind_group None.
return Ok(());
}
};
let bind_group = bind_group.unwrap();
// Add the bind group to the tracker. This is done for both compute and
// render passes, and is used to fail submission of the command buffer if
// any resource in any of the bind groups has been destroyed, whether or
// not the bind group is actually used by the pipeline.
let bind_group = state.base.tracker.bind_groups.insert_single(bind_group);
bind_group.same_device(device)?;
@@ -113,7 +115,9 @@ where
bind_group.validate_dynamic_bindings(index, &state.temp_offsets)?;
if merge_bind_groups {
// merge the resource tracker in
// Merge the bind group's resources into the tracker. We only do this
// for render passes. For compute passes it is done per dispatch in
// [`flush_bindings`].
unsafe {
state.scope.merge_bind_group(&bind_group.used)?;
}
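
The two bookkeeping steps above, in miniature (illustrative types): the bind group is always recorded for the submit-time destroyed-resource check, while merging its usages into the pass-wide scope happens only for render passes; compute passes defer that merge to each dispatch.

struct BindGroup { name: &'static str }

#[derive(Default)]
struct PassState {
    tracked_groups: Vec<&'static str>, // submit-time destroyed-check
    pass_scope: Vec<&'static str>,     // render-pass-wide usage scope
}

fn set_bind_group(state: &mut PassState, group: &BindGroup, merge_bind_groups: bool) {
    // Step 1: unconditional. Destroying any resource in this group must
    // fail submission even if the group is later replaced.
    state.tracked_groups.push(group.name);

    // Step 2: render passes only. One usage scope covers the whole pass.
    if merge_bind_groups {
        state.pass_scope.push(group.name);
    }
}

fn main() {
    let mut render = PassState::default();
    let mut compute = PassState::default();
    let g = BindGroup { name: "g0" };
    set_bind_group(&mut render, &g, true);
    set_bind_group(&mut compute, &g, false);
    assert_eq!(render.pass_scope.len(), 1);
    assert!(compute.pass_scope.is_empty()); // merged later, per dispatch
    assert_eq!(compute.tracked_groups.len(), 1); // but still tracked
}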
@@ -122,57 +126,76 @@ where
// is held to the bind group itself.
state
.base
.buffer_memory_init_actions
.extend(bind_group.used_buffer_ranges.iter().filter_map(|action| {
action
.buffer
.initialization_status
.read()
.check_action(action)
}));
for action in bind_group.used_texture_ranges.iter() {
state.pending_discard_init_fixups.extend(
state
.base
.texture_memory_actions
.register_init_action(action),
);
}
let used_resource = bind_group
.used
.acceleration_structures
.into_iter()
.map(|tlas| AsAction::UseTlas(tlas.clone()));
state.base.as_actions.extend(used_resource);
let pipeline_layout = state.binder.pipeline_layout.clone();
let entries = state
.binder
.assign_group(index as usize, bind_group, &state.temp_offsets);
if !entries.is_empty() && pipeline_layout.is_some() {
let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() {
let raw_bg = group.try_raw(state.base.snatch_guard)?;
unsafe {
state.base.raw_encoder.set_bind_group(
pipeline_layout,
index + i as u32,
Some(raw_bg),
&e.dynamic_offsets,
);
}
}
}
}
Ok(())
}
/// After a pipeline has been changed, resources must be rebound
pub(crate) fn rebind_resources<E, F: FnOnce()>(
/// Helper for `flush_bindings` implementing the portions that are the same for
/// compute and render passes.
pub(super) fn flush_bindings_helper<F>(
state: &mut PassState,
mut f: F,
) -> Result<(), DestroyedResourceError>
where
F: FnMut(&Arc<BindGroup>),
{
for bind_group in state.binder.list_active() {
f(bind_group);
state.base.buffer_memory_init_actions.extend(
bind_group.used_buffer_ranges.iter().filter_map(|action| {
action
.buffer
.initialization_status
.read()
.check_action(action)
}),
);
for action in bind_group.used_texture_ranges.iter() {
state.pending_discard_init_fixups.extend(
state
.base
.texture_memory_actions
.register_init_action(action),
);
}
let used_resource = bind_group
.used
.acceleration_structures
.into_iter()
.map(|tlas| crate::ray_tracing::AsAction::UseTlas(tlas.clone()));
state.base.as_actions.extend(used_resource);
}
let range = state.binder.take_rebind_range();
let entries = state.binder.entries(range);
match state.binder.pipeline_layout.as_ref() {
Some(pipeline_layout) if entries.len() != 0 => {
for (i, e) in entries {
if let Some(group) = e.group.as_ref() {
let raw_bg = group.try_raw(state.base.snatch_guard)?;
unsafe {
state.base.raw_encoder.set_bind_group(
pipeline_layout.raw(),
i as u32,
Some(raw_bg),
&e.dynamic_offsets,
);
}
}
}
}
_ => {}
}
Ok(())
}
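
The rough shape of how the two passes share this helper (signatures heavily simplified): the closure is the per-pass hook. Compute passes use it to merge each active bind group into the per-dispatch scope, while render passes pass a no-op because merging already happened in set_bind_group.

struct BindGroup;

fn flush_bindings_helper<F: FnMut(&BindGroup)>(active: &[BindGroup], mut per_group: F) {
    for group in active {
        per_group(group);
        // ...then the shared work: memory-init actions, TLAS actions,
        // and issuing raw set_bind_group for the dirty rebind range.
    }
}

fn main() {
    let active = [BindGroup, BindGroup];

    // Render pass: nothing extra per group.
    flush_bindings_helper(&active, |_| {});

    // Compute pass: merge each group into a fresh scope for this dispatch.
    let mut merged = 0;
    flush_bindings_helper(&active, |_group| merged += 1);
    assert_eq!(merged, 2);
}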
pub(super) fn change_pipeline_layout<E, F: FnOnce()>(
state: &mut PassState,
pipeline_layout: &Arc<binding_model::PipelineLayout>,
late_sized_buffer_groups: &[LateSizedBufferGroup],
@@ -189,24 +212,9 @@ where
.unwrap()
.is_equal(pipeline_layout)
{
let (start_index, entries) = state
state
.binder
.change_pipeline_layout(pipeline_layout, late_sized_buffer_groups);
if !entries.is_empty() {
for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() {
let raw_bg = group.try_raw(state.base.snatch_guard)?;
unsafe {
state.base.raw_encoder.set_bind_group(
pipeline_layout.raw(),
start_index as u32 + i as u32,
Some(raw_bg),
&e.dynamic_offsets,
);
}
}
}
}
f();


@@ -11,7 +11,9 @@ use wgt::{
};
use crate::command::{
encoder::EncodingState, pass, pass_base, pass_try, validate_and_begin_occlusion_query,
encoder::EncodingState,
pass::{self, flush_bindings_helper},
pass_base, pass_try, validate_and_begin_occlusion_query,
validate_and_begin_pipeline_statistics_query, ArcCommand, DebugGroupError, EncoderStateError,
InnerCommandEncoder, PassStateError, TimestampWritesError,
};
@@ -571,6 +573,15 @@ impl<'scope, 'snatch_guard, 'cmd_enc> State<'scope, 'snatch_guard, 'cmd_enc> {
}
}
/// Flush binding state in preparation for a draw call.
///
/// See the compute pass version for an explanation of some ways that
/// `flush_bindings` differs between the two types of passes.
fn flush_bindings(&mut self) -> Result<(), RenderPassErrorInner> {
flush_bindings_helper(&mut self.pass, |_| {})?;
Ok(())
}
/// Reset the `RenderBundle`-related states.
fn reset_bundle(&mut self) {
self.pass.binder.reset();
@@ -2376,7 +2387,7 @@ fn set_pipeline(
}
// Rebind resource
pass::rebind_resources::<RenderPassErrorInner, _>(
pass::change_pipeline_layout::<RenderPassErrorInner, _>(
&mut state.pass,
&pipeline.layout,
&pipeline.late_sized_buffer_groups,
@@ -2610,10 +2621,11 @@ fn draw(
instance_count: u32,
first_vertex: u32,
first_instance: u32,
) -> Result<(), DrawError> {
) -> Result<(), RenderPassErrorInner> {
api_log!("RenderPass::draw {vertex_count} {instance_count} {first_vertex} {first_instance}");
state.is_ready(DrawCommandFamily::Draw)?;
state.flush_bindings()?;
state
.vertex
@@ -2644,10 +2656,11 @@ fn draw_indexed(
first_index: u32,
base_vertex: i32,
first_instance: u32,
) -> Result<(), DrawError> {
) -> Result<(), RenderPassErrorInner> {
api_log!("RenderPass::draw_indexed {index_count} {instance_count} {first_index} {base_vertex} {first_instance}");
state.is_ready(DrawCommandFamily::DrawIndexed)?;
state.flush_bindings()?;
let last_index = first_index as u64 + index_count as u64;
let index_limit = state.index.limit;
@@ -2655,7 +2668,8 @@
return Err(DrawError::IndexBeyondLimit {
last_index,
index_limit,
});
}
.into());
}
state
.vertex
@@ -2681,10 +2695,11 @@ fn draw_mesh_tasks(
group_count_x: u32,
group_count_y: u32,
group_count_z: u32,
) -> Result<(), DrawError> {
) -> Result<(), RenderPassErrorInner> {
api_log!("RenderPass::draw_mesh_tasks {group_count_x} {group_count_y} {group_count_z}");
state.is_ready(DrawCommandFamily::DrawMeshTasks)?;
state.flush_bindings()?;
let groups_size_limit = state
.pass
@@ -2702,7 +2717,8 @@
current: [group_count_x, group_count_y, group_count_z],
limit: groups_size_limit,
max_total: max_groups,
});
}
.into());
}
unsafe {
@@ -2732,6 +2748,7 @@ fn multi_draw_indirect(
);
state.is_ready(family)?;
state.flush_bindings()?;
state
.pass
@@ -2913,6 +2930,7 @@ fn multi_draw_indirect_count(
);
state.is_ready(family)?;
state.flush_bindings()?;
let stride = get_stride_of_indirect_args(family);


@@ -1799,6 +1799,12 @@ fn validate_command_buffer(
}
}
}
{
profiling::scope!("bind groups");
for bind_group in &cmd_buf_data.trackers.bind_groups {
bind_group.try_raw(snatch_guard)?;
}
}
if let Err(e) =
cmd_buf_data.validate_acceleration_structure_actions(snatch_guard, command_index_guard)


@@ -303,7 +303,7 @@ impl BufferTracker {
}
/// Extend the vectors so that the given index is valid.
fn allow_index(&mut self, index: usize) {
pub fn allow_index(&mut self, index: usize) {
if index >= self.start.len() {
self.set_size(index + 1);
}
@@ -442,6 +442,11 @@ impl BufferTracker {
/// over all elements in the usage scope. We use the given iterator
/// of ids as a source of which IDs to look at.
/// All the IDs must have first been added to the usage scope.
///
/// # Panics
///
/// If a resource identified by `index_source` is not found in the usage
/// scope.
pub fn set_multiple(
&mut self,
scope: &mut BufferUsageScope,
@@ -456,9 +461,8 @@ impl BufferTracker {
let index = index.as_usize();
scope.tracker_assert_in_bounds(index);
if unsafe { !scope.metadata.contains_unchecked(index) } {
continue;
unsafe {
assert!(scope.metadata.contains_unchecked(index));
}
// SAFETY: we checked that the index is in bounds for the scope, and
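
The contract change in this hunk, in miniature (toy types): set_multiple now asserts that every index it is handed was already merged into the usage scope, instead of silently skipping it; callers such as set_from_bind_group are expected to merge first.

use std::collections::HashSet;

struct Scope(HashSet<usize>);

fn set_multiple(scope: &Scope, indices: &[usize]) {
    for &i in indices {
        // Previously: `if !scope.0.contains(&i) { continue; }`
        assert!(scope.0.contains(&i), "index {i} not in usage scope");
        // ...adopt the scope's state for this resource into the tracker.
    }
}

fn main() {
    let mut scope = Scope(HashSet::new());
    scope.0.insert(7); // merge_bind_group / merge_single happened first
    set_multiple(&scope, &[7]); // ok; &[8] would panic
}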


@@ -612,12 +612,32 @@ impl DeviceTracker {
/// A full double sided tracker used by CommandBuffers.
pub(crate) struct Tracker {
/// Buffers used within this command buffer.
///
/// For compute passes, this only includes buffers actually used by the
/// pipeline (contrast with the `bind_groups` member).
pub buffers: BufferTracker,
/// Textures used within this command buffer.
///
/// For compute passes, this only includes textures actually used by the
/// pipeline (contrast with the `bind_groups` member).
pub textures: TextureTracker,
pub blas_s: BlasTracker,
pub tlas_s: StatelessTracker<resource::Tlas>,
pub views: StatelessTracker<resource::TextureView>,
/// Contains all bind groups that were passed in any call to
/// `set_bind_group` on the encoder.
///
/// WebGPU requires that submission fails if any resource in any of these
/// bind groups is destroyed, even if the resource is not actually used by
/// the pipeline (e.g. because the pipeline does not use the bound slot, or
/// because the bind group was replaced by a subsequent call to
/// `set_bind_group`).
pub bind_groups: StatelessTracker<binding_model::BindGroup>,
pub compute_pipelines: StatelessTracker<pipeline::ComputePipeline>,
pub render_pipelines: StatelessTracker<pipeline::RenderPipeline>,
pub bundles: StatelessTracker<command::RenderBundle>,
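
The submit-time check this field feeds, in miniature (toy types; the real check is the try_raw loop added to validate_command_buffer earlier in this commit): every bind group recorded by any set_bind_group call is re-validated at submission, so a destroyed resource fails the submit even if its group was replaced and never used by a pipeline.

struct Resource { destroyed: bool }
struct BindGroup<'a> { resources: Vec<&'a Resource> }

fn validate_command_buffer(bind_groups: &[BindGroup<'_>]) -> Result<(), &'static str> {
    for group in bind_groups {
        for res in &group.resources {
            if res.destroyed {
                return Err("DestroyedResourceError"); // like try_raw failing
            }
        }
    }
    Ok(())
}

fn main() {
    let a = Resource { destroyed: true };
    let b = Resource { destroyed: false };
    // The group holding `a` was replaced before any dispatch, but it
    // was still set at some point during encoding.
    let groups = vec![
        BindGroup { resources: vec![&a] },
        BindGroup { resources: vec![&b] },
    ];
    assert!(validate_command_buffer(&groups).is_err());
}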
@@ -654,6 +674,10 @@ impl Tracker {
///
/// Only stateful things are merged in here, all other resources are owned
/// indirectly by the bind group.
///
/// # Panics
///
/// If a resource in the `bind_group` is not found in the usage scope.
pub fn set_from_bind_group(&mut self, scope: &mut UsageScope, bind_group: &BindGroupStates) {
self.buffers.set_multiple(
&mut scope.buffers,


@@ -611,6 +611,10 @@ impl TextureTracker {
/// over all elements in the usage scope. We use the bind group as a
/// source of which IDs to look at. The bind groups
/// must have first been added to the usage scope.
///
/// # Panics
///
/// If a resource in `bind_group_state` is not found in the usage scope.
pub fn set_multiple(
&mut self,
scope: &mut TextureUsageScope,
@@ -623,11 +627,12 @@ impl TextureTracker {
for (view, _) in bind_group_state.views.iter() {
let index = view.parent.tracker_index().as_usize();
scope.tracker_assert_in_bounds(index);
if unsafe { !scope.metadata.contains_unchecked(index) } {
continue;
unsafe {
assert!(scope.metadata.contains_unchecked(index));
}
let texture_selector = &view.parent.full_range;
// SAFETY: we checked that the index is in bounds for the scope, and
// called `set_size` to ensure it is valid for `self`.