Extract the parts of PassState that apply to the base encoder

And rename a few things.
This commit is contained in:
Andy Leiserson 2025-08-29 17:02:07 -07:00
parent f4ea8642a5
commit 885845087f
5 changed files with 288 additions and 216 deletions

View File

@ -8,7 +8,8 @@ use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
use core::{fmt, str};
use crate::command::{
pass, CommandEncoder, DebugGroupError, EncoderStateError, PassStateError, TimestampWritesError,
encoder::EncodingState, pass, CommandEncoder, DebugGroupError, EncoderStateError,
PassStateError, TimestampWritesError,
};
use crate::resource::DestroyedResourceError;
use crate::{binding_model::BindError, resource::RawResourceAccess};
@ -256,7 +257,7 @@ impl WebGpuError for ComputePassError {
struct State<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> {
pipeline: Option<Arc<ComputePipeline>>,
general: pass::BaseState<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>,
pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>,
active_query: Option<(Arc<resource::QuerySet>, u32)>,
@ -270,8 +271,8 @@ impl<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
{
fn is_ready(&self) -> Result<(), DispatchError> {
if let Some(pipeline) = self.pipeline.as_ref() {
self.general.binder.check_compatibility(pipeline.as_ref())?;
self.general.binder.check_late_buffer_bindings()?;
self.pass.binder.check_compatibility(pipeline.as_ref())?;
self.pass.binder.check_late_buffer_bindings()?;
Ok(())
} else {
Err(DispatchError::MissingPipeline(pass::MissingPipeline))
@ -284,19 +285,16 @@ impl<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
&mut self,
indirect_buffer: Option<TrackerIndex>,
) -> Result<(), ResourceUsageCompatibilityError> {
for bind_group in self.general.binder.list_active() {
unsafe { self.general.scope.merge_bind_group(&bind_group.used)? };
for bind_group in self.pass.binder.list_active() {
unsafe { self.pass.scope.merge_bind_group(&bind_group.used)? };
// Note: stateless trackers are not merged: the lifetime reference
// is held to the bind group itself.
}
for bind_group in self.general.binder.list_active() {
for bind_group in self.pass.binder.list_active() {
unsafe {
self.intermediate_trackers
.set_and_remove_from_usage_scope_sparse(
&mut self.general.scope,
&bind_group.used,
)
.set_and_remove_from_usage_scope_sparse(&mut self.pass.scope, &bind_group.used)
}
}
@ -305,15 +303,15 @@ impl<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
self.intermediate_trackers
.buffers
.set_and_remove_from_usage_scope_sparse(
&mut self.general.scope.buffers,
&mut self.pass.scope.buffers,
indirect_buffer,
);
}
CommandEncoder::drain_barriers(
self.general.raw_encoder,
self.pass.base.raw_encoder,
&mut self.intermediate_trackers,
self.general.snatch_guard,
self.pass.base.snatch_guard,
);
Ok(())
}
@ -536,27 +534,32 @@ impl Global {
.map_pass_err(pass_scope)?;
let snatch_guard = device.snatchable_lock.read();
let mut debug_scope_depth = 0;
let mut state = State {
pipeline: None,
general: pass::BaseState {
device,
raw_encoder,
tracker: &mut cmd_buf_data.trackers,
buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
as_actions: &mut cmd_buf_data.as_actions,
pass: pass::PassState {
base: EncodingState {
device,
raw_encoder,
tracker: &mut cmd_buf_data.trackers,
buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
as_actions: &mut cmd_buf_data.as_actions,
indirect_draw_validation_resources: &mut cmd_buf_data
.indirect_draw_validation_resources,
snatch_guard: &snatch_guard,
debug_scope_depth: &mut debug_scope_depth,
},
binder: Binder::new(),
temp_offsets: Vec::new(),
dynamic_offset_count: 0,
pending_discard_init_fixups: SurfacesInDiscardState::new(),
snatch_guard: &snatch_guard,
scope: device.new_usage_scope(),
debug_scope_depth: 0,
string_offset: 0,
},
active_query: None,
@ -566,14 +569,16 @@ impl Global {
intermediate_trackers: Tracker::new(),
};
let indices = &state.general.device.tracker_indices;
let indices = &state.pass.base.device.tracker_indices;
state
.general
.pass
.base
.tracker
.buffers
.set_size(indices.buffers.size());
state
.general
.pass
.base
.tracker
.textures
.set_size(indices.textures.size());
@ -584,7 +589,12 @@ impl Global {
.same_device_as(cmd_enc.as_ref())
.map_pass_err(pass_scope)?;
let query_set = state.general.tracker.query_sets.insert_single(tw.query_set);
let query_set = state
.pass
.base
.tracker
.query_sets
.insert_single(tw.query_set);
// Unlike in render passes we can't delay resetting the query sets since
// there is no auxiliary pass.
@ -602,7 +612,8 @@ impl Global {
if let Some(range) = range {
unsafe {
state
.general
.pass
.base
.raw_encoder
.reset_queries(query_set.raw(), range);
}
@ -623,7 +634,7 @@ impl Global {
};
unsafe {
state.general.raw_encoder.begin_compute_pass(&hal_desc);
state.pass.base.raw_encoder.begin_compute_pass(&hal_desc);
}
for command in base.commands.drain(..) {
@ -635,7 +646,7 @@ impl Global {
} => {
let scope = PassErrorScope::SetBindGroup;
pass::set_bind_group::<ComputePassErrorInner>(
&mut state.general,
&mut state.pass,
cmd_enc.as_ref(),
&base.dynamic_offsets,
index,
@ -656,7 +667,7 @@ impl Global {
} => {
let scope = PassErrorScope::SetPushConstant;
pass::set_push_constant::<ComputePassErrorInner, _>(
&mut state.general,
&mut state.pass,
&base.push_constant_data,
wgt::ShaderStages::COMPUTE,
offset,
@ -683,15 +694,15 @@ impl Global {
.map_pass_err(scope)?;
}
ArcComputeCommand::PushDebugGroup { color: _, len } => {
pass::push_debug_group(&mut state.general, &base.string_data, len);
pass::push_debug_group(&mut state.pass, &base.string_data, len);
}
ArcComputeCommand::PopDebugGroup => {
let scope = PassErrorScope::PopDebugGroup;
pass::pop_debug_group::<ComputePassErrorInner>(&mut state.general)
pass::pop_debug_group::<ComputePassErrorInner>(&mut state.pass)
.map_pass_err(scope)?;
}
ArcComputeCommand::InsertDebugMarker { color: _, len } => {
pass::insert_debug_marker(&mut state.general, &base.string_data, len);
pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
}
ArcComputeCommand::WriteTimestamp {
query_set,
@ -699,7 +710,7 @@ impl Global {
} => {
let scope = PassErrorScope::WriteTimestamp;
pass::write_timestamp::<ComputePassErrorInner>(
&mut state.general,
&mut state.pass,
cmd_enc.as_ref(),
None,
query_set,
@ -714,8 +725,8 @@ impl Global {
let scope = PassErrorScope::BeginPipelineStatisticsQuery;
validate_and_begin_pipeline_statistics_query(
query_set,
state.general.raw_encoder,
&mut state.general.tracker.query_sets,
state.pass.base.raw_encoder,
&mut state.pass.base.tracker.query_sets,
cmd_enc.as_ref(),
query_index,
None,
@ -726,7 +737,7 @@ impl Global {
ArcComputeCommand::EndPipelineStatisticsQuery => {
let scope = PassErrorScope::EndPipelineStatisticsQuery;
end_pipeline_statistics_query(
state.general.raw_encoder,
state.pass.base.raw_encoder,
&mut state.active_query,
)
.map_pass_err(scope)?;
@ -734,7 +745,7 @@ impl Global {
}
}
if state.general.debug_scope_depth > 0 {
if *state.pass.base.debug_scope_depth > 0 {
Err(
ComputePassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
.map_pass_err(pass_scope),
@ -742,13 +753,13 @@ impl Global {
}
unsafe {
state.general.raw_encoder.end_compute_pass();
state.pass.base.raw_encoder.end_compute_pass();
}
let State {
general:
pass::BaseState {
tracker,
pass:
pass::PassState {
base: EncodingState { tracker, .. },
pending_discard_init_fixups,
..
},
@ -799,7 +810,8 @@ fn set_pipeline(
state.pipeline = Some(pipeline.clone());
let pipeline = state
.general
.pass
.base
.tracker
.compute_pipelines
.insert_single(pipeline)
@ -807,14 +819,15 @@ fn set_pipeline(
unsafe {
state
.general
.pass
.base
.raw_encoder
.set_compute_pipeline(pipeline.raw());
}
// Rebind resources
pass::rebind_resources::<ComputePassErrorInner, _>(
&mut state.general,
&mut state.pass,
&pipeline.layout,
&pipeline.late_sized_buffer_groups,
|| {
@ -843,7 +856,8 @@ fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorI
state.flush_states(None)?;
let groups_size_limit = state
.general
.pass
.base
.device
.limits
.max_compute_workgroups_per_dimension;
@ -861,7 +875,7 @@ fn dispatch(state: &mut State, groups: [u32; 3]) -> Result<(), ComputePassErrorI
}
unsafe {
state.general.raw_encoder.dispatch(groups);
state.pass.base.raw_encoder.dispatch(groups);
}
Ok(())
}
@ -877,12 +891,13 @@ fn dispatch_indirect(
state.is_ready()?;
state
.general
.pass
.base
.device
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;
buffer.check_usage(wgt::BufferUsages::INDIRECT)?;
buffer.check_destroyed(state.general.snatch_guard)?;
buffer.check_destroyed(state.pass.base.snatch_guard)?;
if offset % 4 != 0 {
return Err(ComputePassErrorInner::UnalignedIndirectBufferOffset(offset));
@ -898,7 +913,7 @@ fn dispatch_indirect(
}
let stride = 3 * 4; // 3 integers, x/y/z group size
state.general.buffer_memory_init_actions.extend(
state.pass.base.buffer_memory_init_actions.extend(
buffer.initialization_status.read().create_action(
&buffer,
offset..(offset + stride),
@ -906,21 +921,23 @@ fn dispatch_indirect(
),
);
if let Some(ref indirect_validation) = state.general.device.indirect_validation {
let params =
indirect_validation
.dispatch
.params(&state.general.device.limits, offset, buffer.size);
if let Some(ref indirect_validation) = state.pass.base.device.indirect_validation {
let params = indirect_validation.dispatch.params(
&state.pass.base.device.limits,
offset,
buffer.size,
);
unsafe {
state
.general
.pass
.base
.raw_encoder
.set_compute_pipeline(params.pipeline);
}
unsafe {
state.general.raw_encoder.set_push_constants(
state.pass.base.raw_encoder.set_push_constants(
params.pipeline_layout,
wgt::ShaderStages::COMPUTE,
0,
@ -929,7 +946,7 @@ fn dispatch_indirect(
}
unsafe {
state.general.raw_encoder.set_bind_group(
state.pass.base.raw_encoder.set_bind_group(
params.pipeline_layout,
0,
Some(params.dst_bind_group),
@ -937,13 +954,13 @@ fn dispatch_indirect(
);
}
unsafe {
state.general.raw_encoder.set_bind_group(
state.pass.base.raw_encoder.set_bind_group(
params.pipeline_layout,
1,
Some(
buffer
.indirect_validation_bind_groups
.get(state.general.snatch_guard)
.get(state.pass.base.snatch_guard)
.unwrap()
.dispatch
.as_ref(),
@ -957,17 +974,19 @@ fn dispatch_indirect(
.buffers
.set_single(&buffer, wgt::BufferUses::STORAGE_READ_ONLY);
let src_barrier = src_transition
.map(|transition| transition.into_hal(&buffer, state.general.snatch_guard));
.map(|transition| transition.into_hal(&buffer, state.pass.base.snatch_guard));
unsafe {
state
.general
.pass
.base
.raw_encoder
.transition_buffers(src_barrier.as_slice());
}
unsafe {
state
.general
.pass
.base
.raw_encoder
.transition_buffers(&[hal::BufferBarrier {
buffer: params.dst_buffer,
@ -979,7 +998,7 @@ fn dispatch_indirect(
}
unsafe {
state.general.raw_encoder.dispatch([1, 1, 1]);
state.pass.base.raw_encoder.dispatch([1, 1, 1]);
}
// reset state
@ -988,14 +1007,15 @@ fn dispatch_indirect(
unsafe {
state
.general
.pass
.base
.raw_encoder
.set_compute_pipeline(pipeline.raw());
}
if !state.push_constants.is_empty() {
unsafe {
state.general.raw_encoder.set_push_constants(
state.pass.base.raw_encoder.set_push_constants(
pipeline.layout.raw(),
wgt::ShaderStages::COMPUTE,
0,
@ -1004,11 +1024,11 @@ fn dispatch_indirect(
}
}
for (i, e) in state.general.binder.list_valid() {
for (i, e) in state.pass.binder.list_valid() {
let group = e.group.as_ref().unwrap();
let raw_bg = group.try_raw(state.general.snatch_guard)?;
let raw_bg = group.try_raw(state.pass.base.snatch_guard)?;
unsafe {
state.general.raw_encoder.set_bind_group(
state.pass.base.raw_encoder.set_bind_group(
pipeline.layout.raw(),
i as u32,
Some(raw_bg),
@ -1020,7 +1040,8 @@ fn dispatch_indirect(
unsafe {
state
.general
.pass
.base
.raw_encoder
.transition_buffers(&[hal::BufferBarrier {
buffer: params.dst_buffer,
@ -1034,13 +1055,14 @@ fn dispatch_indirect(
state.flush_states(None)?;
unsafe {
state
.general
.pass
.base
.raw_encoder
.dispatch_indirect(params.dst_buffer, 0);
}
} else {
state
.general
.pass
.scope
.buffers
.merge_single(&buffer, wgt::BufferUses::INDIRECT)?;
@ -1048,9 +1070,13 @@ fn dispatch_indirect(
use crate::resource::Trackable;
state.flush_states(Some(buffer.tracker_index()))?;
let buf_raw = buffer.try_raw(state.general.snatch_guard)?;
let buf_raw = buffer.try_raw(state.pass.base.snatch_guard)?;
unsafe {
state.general.raw_encoder.dispatch_indirect(buf_raw, offset);
state
.pass
.base
.raw_encoder
.dispatch_indirect(buf_raw, offset);
}
}

View File

@ -0,0 +1,30 @@
use alloc::{sync::Arc, vec::Vec};
use crate::{
command::memory_init::CommandBufferTextureMemoryActions, device::Device,
init_tracker::BufferInitTrackerAction, ray_tracing::AsAction, snatch::SnatchGuard,
track::Tracker,
};
/// State applicable when encoding commands onto a compute pass, or onto a
/// render pass, or directly with a command encoder.
pub(crate) struct EncodingState<'snatch_guard, 'cmd_enc, 'raw_encoder> {
    /// Device that owns the encoder. Consulted for `limits`, feature and
    /// downlevel-flag checks, and `instance_flags`.
    pub(crate) device: &'cmd_enc Arc<Device>,

    /// The HAL command encoder that commands are recorded onto.
    pub(crate) raw_encoder: &'raw_encoder mut dyn hal::DynCommandEncoder,

    /// Resource tracker of the parent command buffer; resources used while
    /// encoding (pipelines, bind groups, query sets, ...) are inserted here
    /// to keep them alive.
    pub(crate) tracker: &'cmd_enc mut Tracker,

    /// Pending buffer-initialization actions accumulated while encoding.
    pub(crate) buffer_memory_init_actions: &'cmd_enc mut Vec<BufferInitTrackerAction>,

    /// Pending texture-initialization actions accumulated while encoding.
    pub(crate) texture_memory_actions: &'cmd_enc mut CommandBufferTextureMemoryActions,

    /// Pending acceleration-structure actions (e.g. TLAS uses from bind
    /// groups) accumulated while encoding.
    pub(crate) as_actions: &'cmd_enc mut Vec<AsAction>,

    // Resources used by indirect draw validation; presumably only exercised
    // on the render-pass path — TODO(review) confirm.
    pub(crate) indirect_draw_validation_resources:
        &'cmd_enc mut crate::indirect_validation::DrawResources,

    /// Guard proving the snatchable lock is held; required to access raw
    /// resources (e.g. `try_raw` / `check_destroyed`).
    pub(crate) snatch_guard: &'snatch_guard SnatchGuard<'snatch_guard>,

    /// Current debug scope nesting depth.
    ///
    /// When encoding a compute or render pass, this is the depth of debug
    /// scopes in the pass, not the depth of debug scopes in the parent encoder.
    /// Held by mutable borrow so the counter itself lives in the caller's
    /// stack frame for the duration of the pass.
    pub(crate) debug_scope_depth: &'cmd_enc mut u32,
}

View File

@ -5,6 +5,7 @@ mod clear;
mod compute;
mod compute_command;
mod draw;
mod encoder;
mod memory_init;
mod pass;
mod query;

View File

@ -2,15 +2,14 @@
use crate::binding_model::{BindError, BindGroup, PushConstantUploadError};
use crate::command::bind::Binder;
use crate::command::memory_init::{CommandBufferTextureMemoryActions, SurfacesInDiscardState};
use crate::command::encoder::EncodingState;
use crate::command::memory_init::SurfacesInDiscardState;
use crate::command::{CommandEncoder, DebugGroupError, QueryResetMap, QueryUseError};
use crate::device::{Device, DeviceError, MissingFeatures};
use crate::init_tracker::BufferInitTrackerAction;
use crate::device::{DeviceError, MissingFeatures};
use crate::pipeline::LateSizedBufferGroup;
use crate::ray_tracing::AsAction;
use crate::resource::{DestroyedResourceError, Labeled, ParentDevice, QuerySet};
use crate::snatch::SnatchGuard;
use crate::track::{ResourceUsageCompatibilityError, Tracker, UsageScope};
use crate::track::{ResourceUsageCompatibilityError, UsageScope};
use crate::{api_log, binding_model};
use alloc::sync::Arc;
use alloc::vec::Vec;
@ -42,15 +41,8 @@ impl WebGpuError for InvalidValuesOffset {
}
}
pub(crate) struct BaseState<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> {
pub(crate) device: &'cmd_enc Arc<Device>,
pub(crate) raw_encoder: &'raw_encoder mut dyn hal::DynCommandEncoder,
pub(crate) tracker: &'cmd_enc mut Tracker,
pub(crate) buffer_memory_init_actions: &'cmd_enc mut Vec<BufferInitTrackerAction>,
pub(crate) texture_memory_actions: &'cmd_enc mut CommandBufferTextureMemoryActions,
pub(crate) as_actions: &'cmd_enc mut Vec<AsAction>,
pub(crate) struct PassState<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> {
pub(crate) base: EncodingState<'snatch_guard, 'cmd_enc, 'raw_encoder>,
/// Immediate texture inits required because of prior discards. Need to
/// be inserted before texture reads.
@ -64,14 +56,11 @@ pub(crate) struct BaseState<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> {
pub(crate) dynamic_offset_count: usize,
pub(crate) snatch_guard: &'snatch_guard SnatchGuard<'snatch_guard>,
pub(crate) debug_scope_depth: u32,
pub(crate) string_offset: usize,
}
pub(crate) fn set_bind_group<E>(
state: &mut BaseState,
state: &mut PassState,
cmd_enc: &CommandEncoder,
dynamic_offsets: &[DynamicOffset],
index: u32,
@ -95,7 +84,7 @@ where
);
}
let max_bind_groups = state.device.limits.max_bind_groups;
let max_bind_groups = state.base.device.limits.max_bind_groups;
if index >= max_bind_groups {
return Err(BindGroupIndexOutOfRange {
index,
@ -117,7 +106,7 @@ where
}
let bind_group = bind_group.unwrap();
let bind_group = state.tracker.bind_groups.insert_single(bind_group);
let bind_group = state.base.tracker.bind_groups.insert_single(bind_group);
bind_group.same_device_as(cmd_enc)?;
@ -133,6 +122,7 @@ where
// is held to the bind group itself.
state
.base
.buffer_memory_init_actions
.extend(bind_group.used_buffer_ranges.iter().filter_map(|action| {
action
@ -142,9 +132,12 @@ where
.check_action(action)
}));
for action in bind_group.used_texture_ranges.iter() {
state
.pending_discard_init_fixups
.extend(state.texture_memory_actions.register_init_action(action));
state.pending_discard_init_fixups.extend(
state
.base
.texture_memory_actions
.register_init_action(action),
);
}
let used_resource = bind_group
@ -153,7 +146,7 @@ where
.into_iter()
.map(|tlas| AsAction::UseTlas(tlas.clone()));
state.as_actions.extend(used_resource);
state.base.as_actions.extend(used_resource);
let pipeline_layout = state.binder.pipeline_layout.clone();
let entries = state
@ -163,9 +156,9 @@ where
let pipeline_layout = pipeline_layout.as_ref().unwrap().raw();
for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() {
let raw_bg = group.try_raw(state.snatch_guard)?;
let raw_bg = group.try_raw(state.base.snatch_guard)?;
unsafe {
state.raw_encoder.set_bind_group(
state.base.raw_encoder.set_bind_group(
pipeline_layout,
index + i as u32,
Some(raw_bg),
@ -180,7 +173,7 @@ where
/// After a pipeline has been changed, resources must be rebound
pub(crate) fn rebind_resources<E, F: FnOnce()>(
state: &mut BaseState,
state: &mut PassState,
pipeline_layout: &Arc<binding_model::PipelineLayout>,
late_sized_buffer_groups: &[LateSizedBufferGroup],
f: F,
@ -202,9 +195,9 @@ where
if !entries.is_empty() {
for (i, e) in entries.iter().enumerate() {
if let Some(group) = e.group.as_ref() {
let raw_bg = group.try_raw(state.snatch_guard)?;
let raw_bg = group.try_raw(state.base.snatch_guard)?;
unsafe {
state.raw_encoder.set_bind_group(
state.base.raw_encoder.set_bind_group(
pipeline_layout.raw(),
start_index as u32 + i as u32,
Some(raw_bg),
@ -225,7 +218,7 @@ where
let offset = range.range.start;
let size_bytes = range.range.end - offset;
super::push_constant_clear(offset, size_bytes, |clear_offset, clear_data| unsafe {
state.raw_encoder.set_push_constants(
state.base.raw_encoder.set_push_constants(
pipeline_layout.raw(),
range.stages,
clear_offset,
@ -238,7 +231,7 @@ where
}
pub(crate) fn set_push_constant<E, F: FnOnce(&[u32])>(
state: &mut BaseState,
state: &mut PassState,
push_constant_data: &[u32],
stages: wgt::ShaderStages,
offset: u32,
@ -269,6 +262,7 @@ where
unsafe {
state
.base
.raw_encoder
.set_push_constants(pipeline_layout.raw(), stages, offset, data_slice)
}
@ -276,7 +270,7 @@ where
}
pub(crate) fn write_timestamp<E>(
state: &mut BaseState,
state: &mut PassState,
cmd_enc: &CommandEncoder,
pending_query_resets: Option<&mut QueryResetMap>,
query_set: Arc<QuerySet>,
@ -293,18 +287,24 @@ where
query_set.same_device_as(cmd_enc)?;
state
.base
.device
.require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)?;
let query_set = state.tracker.query_sets.insert_single(query_set);
let query_set = state.base.tracker.query_sets.insert_single(query_set);
query_set.validate_and_write_timestamp(state.raw_encoder, query_index, pending_query_resets)?;
query_set.validate_and_write_timestamp(
state.base.raw_encoder,
query_index,
pending_query_resets,
)?;
Ok(())
}
pub(crate) fn push_debug_group(state: &mut BaseState, string_data: &[u8], len: usize) {
state.debug_scope_depth += 1;
pub(crate) fn push_debug_group(state: &mut PassState, string_data: &[u8], len: usize) {
*state.base.debug_scope_depth += 1;
if !state
.base
.device
.instance_flags
.contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
@ -314,36 +314,38 @@ pub(crate) fn push_debug_group(state: &mut BaseState, string_data: &[u8], len: u
api_log!("Pass::push_debug_group {label:?}");
unsafe {
state.raw_encoder.begin_debug_marker(label);
state.base.raw_encoder.begin_debug_marker(label);
}
}
state.string_offset += len;
}
pub(crate) fn pop_debug_group<E>(state: &mut BaseState) -> Result<(), E>
pub(crate) fn pop_debug_group<E>(state: &mut PassState) -> Result<(), E>
where
E: From<DebugGroupError>,
{
api_log!("Pass::pop_debug_group");
if state.debug_scope_depth == 0 {
if *state.base.debug_scope_depth == 0 {
return Err(DebugGroupError::InvalidPop.into());
}
state.debug_scope_depth -= 1;
*state.base.debug_scope_depth -= 1;
if !state
.base
.device
.instance_flags
.contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
{
unsafe {
state.raw_encoder.end_debug_marker();
state.base.raw_encoder.end_debug_marker();
}
}
Ok(())
}
pub(crate) fn insert_debug_marker(state: &mut BaseState, string_data: &[u8], len: usize) {
pub(crate) fn insert_debug_marker(state: &mut PassState, string_data: &[u8], len: usize) {
if !state
.base
.device
.instance_flags
.contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
@ -352,7 +354,7 @@ pub(crate) fn insert_debug_marker(state: &mut BaseState, string_data: &[u8], len
str::from_utf8(&string_data[state.string_offset..state.string_offset + len]).unwrap();
api_log!("Pass::insert_debug_marker {label:?}");
unsafe {
state.raw_encoder.insert_debug_marker(label);
state.base.raw_encoder.insert_debug_marker(label);
}
}
state.string_offset += len;

View File

@ -10,7 +10,7 @@ use wgt::{
};
use crate::command::{
pass, pass_base, pass_try, validate_and_begin_occlusion_query,
encoder::EncodingState, pass, pass_base, pass_try, validate_and_begin_occlusion_query,
validate_and_begin_pipeline_statistics_query, DebugGroupError, EncoderStateError,
InnerCommandEncoder, PassStateError, TimestampWritesError,
};
@ -504,7 +504,7 @@ struct State<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder> {
info: RenderPassInfo,
general: pass::BaseState<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>,
pass: pass::PassState<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>,
active_occlusion_query: Option<(Arc<QuerySet>, u32)>,
active_pipeline_statistics_query: Option<(Arc<QuerySet>, u32)>,
@ -515,8 +515,8 @@ impl<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
{
fn is_ready(&self, family: DrawCommandFamily) -> Result<(), DrawError> {
if let Some(pipeline) = self.pipeline.as_ref() {
self.general.binder.check_compatibility(pipeline.as_ref())?;
self.general.binder.check_late_buffer_bindings()?;
self.pass.binder.check_compatibility(pipeline.as_ref())?;
self.pass.binder.check_late_buffer_bindings()?;
if self.blend_constant == OptionalState::Required {
return Err(DrawError::MissingBlendConstant);
@ -569,7 +569,7 @@ impl<'scope, 'snatch_guard, 'cmd_enc, 'raw_encoder>
/// Reset the `RenderBundle`-related states.
fn reset_bundle(&mut self) {
self.general.binder.reset();
self.pass.binder.reset();
self.pipeline = None;
self.index.reset();
self.vertex = Default::default();
@ -1853,8 +1853,6 @@ impl Global {
let tracker = &mut cmd_buf_data.trackers;
let buffer_memory_init_actions = &mut cmd_buf_data.buffer_memory_init_actions;
let texture_memory_actions = &mut cmd_buf_data.texture_memory_actions;
let indirect_draw_validation_resources =
&mut cmd_buf_data.indirect_draw_validation_resources;
// We automatically keep extending command buffers over time, and because
// we want to insert a command buffer _before_ what we're about to record,
@ -1889,6 +1887,8 @@ impl Global {
tracker.buffers.set_size(indices.buffers.size());
tracker.textures.set_size(indices.textures.size());
let mut debug_scope_depth = 0;
let mut state = State {
pipeline_flags: PipelineFlags::empty(),
blend_constant: OptionalState::Unused,
@ -1899,23 +1899,26 @@ impl Global {
info,
general: pass::BaseState {
device,
raw_encoder: encoder.raw.as_mut(),
tracker,
buffer_memory_init_actions,
texture_memory_actions,
as_actions: &mut cmd_buf_data.as_actions,
pass: pass::PassState {
base: EncodingState {
device,
raw_encoder: encoder.raw.as_mut(),
tracker,
buffer_memory_init_actions,
texture_memory_actions,
as_actions: &mut cmd_buf_data.as_actions,
indirect_draw_validation_resources: &mut cmd_buf_data
.indirect_draw_validation_resources,
snatch_guard,
debug_scope_depth: &mut debug_scope_depth,
},
pending_discard_init_fixups,
scope: device.new_usage_scope(),
binder: Binder::new(),
snatch_guard,
temp_offsets: Vec::new(),
dynamic_offset_count: 0,
debug_scope_depth: 0,
string_offset: 0,
},
@ -1932,7 +1935,7 @@ impl Global {
} => {
let scope = PassErrorScope::SetBindGroup;
pass::set_bind_group::<RenderPassErrorInner>(
&mut state.general,
&mut state.pass,
cmd_enc.as_ref(),
&base.dynamic_offsets,
index,
@ -1996,7 +1999,7 @@ impl Global {
} => {
let scope = PassErrorScope::SetPushConstant;
pass::set_push_constant::<RenderPassErrorInner, _>(
&mut state.general,
&mut state.pass,
&base.push_constant_data,
stages,
offset,
@ -2086,7 +2089,6 @@ impl Global {
};
multi_draw_indirect(
&mut state,
indirect_draw_validation_resources,
&mut indirect_draw_validation_batcher,
&cmd_enc,
buffer,
@ -2121,15 +2123,15 @@ impl Global {
.map_pass_err(scope)?;
}
ArcRenderCommand::PushDebugGroup { color: _, len } => {
pass::push_debug_group(&mut state.general, &base.string_data, len);
pass::push_debug_group(&mut state.pass, &base.string_data, len);
}
ArcRenderCommand::PopDebugGroup => {
let scope = PassErrorScope::PopDebugGroup;
pass::pop_debug_group::<RenderPassErrorInner>(&mut state.general)
pass::pop_debug_group::<RenderPassErrorInner>(&mut state.pass)
.map_pass_err(scope)?;
}
ArcRenderCommand::InsertDebugMarker { color: _, len } => {
pass::insert_debug_marker(&mut state.general, &base.string_data, len);
pass::insert_debug_marker(&mut state.pass, &base.string_data, len);
}
ArcRenderCommand::WriteTimestamp {
query_set,
@ -2137,7 +2139,7 @@ impl Global {
} => {
let scope = PassErrorScope::WriteTimestamp;
pass::write_timestamp::<RenderPassErrorInner>(
&mut state.general,
&mut state.pass,
cmd_enc.as_ref(),
Some(&mut pending_query_resets),
query_set,
@ -2157,8 +2159,8 @@ impl Global {
validate_and_begin_occlusion_query(
query_set,
state.general.raw_encoder,
&mut state.general.tracker.query_sets,
state.pass.base.raw_encoder,
&mut state.pass.base.tracker.query_sets,
query_index,
Some(&mut pending_query_resets),
&mut state.active_occlusion_query,
@ -2170,7 +2172,7 @@ impl Global {
let scope = PassErrorScope::EndOcclusionQuery;
end_occlusion_query(
state.general.raw_encoder,
state.pass.base.raw_encoder,
&mut state.active_occlusion_query,
)
.map_pass_err(scope)?;
@ -2187,8 +2189,8 @@ impl Global {
validate_and_begin_pipeline_statistics_query(
query_set,
state.general.raw_encoder,
&mut state.general.tracker.query_sets,
state.pass.base.raw_encoder,
&mut state.pass.base.tracker.query_sets,
cmd_enc.as_ref(),
query_index,
Some(&mut pending_query_resets),
@ -2201,7 +2203,7 @@ impl Global {
let scope = PassErrorScope::EndPipelineStatisticsQuery;
end_pipeline_statistics_query(
state.general.raw_encoder,
state.pass.base.raw_encoder,
&mut state.active_pipeline_statistics_query,
)
.map_pass_err(scope)?;
@ -2210,7 +2212,6 @@ impl Global {
let scope = PassErrorScope::ExecuteBundle;
execute_bundle(
&mut state,
indirect_draw_validation_resources,
&mut indirect_draw_validation_batcher,
&cmd_enc,
bundle,
@ -2220,7 +2221,7 @@ impl Global {
}
}
if state.general.debug_scope_depth > 0 {
if *state.pass.base.debug_scope_depth > 0 {
Err(
RenderPassErrorInner::DebugGroupError(DebugGroupError::MissingPop)
.map_pass_err(pass_scope),
@ -2231,16 +2232,16 @@ impl Global {
.info
.finish(
device,
state.general.raw_encoder,
state.general.snatch_guard,
&mut state.general.scope,
state.pass.base.raw_encoder,
state.pass.base.snatch_guard,
&mut state.pass.scope,
self.instance.flags,
)
.map_pass_err(pass_scope)?;
let trackers = state.general.scope;
let trackers = state.pass.scope;
let pending_discard_init_fixups = state.general.pending_discard_init_fixups;
let pending_discard_init_fixups = state.pass.pending_discard_init_fixups;
encoder.close().map_pass_err(pass_scope)?;
(trackers, pending_discard_init_fixups, pending_query_resets)
@ -2301,7 +2302,8 @@ fn set_pipeline(
state.pipeline = Some(pipeline.clone());
let pipeline = state
.general
.pass
.base
.tracker
.render_pipelines
.insert_single(pipeline)
@ -2330,7 +2332,8 @@ fn set_pipeline(
unsafe {
state
.general
.pass
.base
.raw_encoder
.set_render_pipeline(pipeline.raw());
}
@ -2338,7 +2341,8 @@ fn set_pipeline(
if pipeline.flags.contains(PipelineFlags::STENCIL_REFERENCE) {
unsafe {
state
.general
.pass
.base
.raw_encoder
.set_stencil_reference(state.stencil_reference);
}
@ -2346,7 +2350,7 @@ fn set_pipeline(
// Rebind resource
pass::rebind_resources::<RenderPassErrorInner, _>(
&mut state.general,
&mut state.pass,
&pipeline.layout,
&pipeline.late_sized_buffer_groups,
|| {},
@ -2369,7 +2373,7 @@ fn set_index_buffer(
api_log!("RenderPass::set_index_buffer {}", buffer.error_ident());
state
.general
.pass
.scope
.buffers
.merge_single(&buffer, wgt::BufferUses::INDEX)?;
@ -2386,12 +2390,12 @@ fn set_index_buffer(
.into());
}
let (binding, resolved_size) = buffer
.binding(offset, size, state.general.snatch_guard)
.binding(offset, size, state.pass.base.snatch_guard)
.map_err(RenderCommandError::from)?;
let end = offset + resolved_size;
state.index.update_buffer(offset..end, index_format);
state.general.buffer_memory_init_actions.extend(
state.pass.base.buffer_memory_init_actions.extend(
buffer.initialization_status.read().create_action(
&buffer,
offset..end,
@ -2400,7 +2404,11 @@ fn set_index_buffer(
);
unsafe {
hal::DynCommandEncoder::set_index_buffer(state.general.raw_encoder, binding, index_format);
hal::DynCommandEncoder::set_index_buffer(
state.pass.base.raw_encoder,
binding,
index_format,
);
}
Ok(())
}
@ -2420,14 +2428,14 @@ fn set_vertex_buffer(
);
state
.general
.pass
.scope
.buffers
.merge_single(&buffer, wgt::BufferUses::VERTEX)?;
buffer.same_device_as(cmd_enc.as_ref())?;
let max_vertex_buffers = state.general.device.limits.max_vertex_buffers;
let max_vertex_buffers = state.pass.base.device.limits.max_vertex_buffers;
if slot >= max_vertex_buffers {
return Err(RenderCommandError::VertexBufferIndexOutOfRange {
index: slot,
@ -2442,11 +2450,11 @@ fn set_vertex_buffer(
return Err(RenderCommandError::UnalignedVertexBuffer { slot, offset }.into());
}
let (binding, buffer_size) = buffer
.binding(offset, size, state.general.snatch_guard)
.binding(offset, size, state.pass.base.snatch_guard)
.map_err(RenderCommandError::from)?;
state.vertex.buffer_sizes[slot as usize] = Some(buffer_size);
state.general.buffer_memory_init_actions.extend(
state.pass.base.buffer_memory_init_actions.extend(
buffer.initialization_status.read().create_action(
&buffer,
offset..(offset + buffer_size),
@ -2455,7 +2463,7 @@ fn set_vertex_buffer(
);
unsafe {
hal::DynCommandEncoder::set_vertex_buffer(state.general.raw_encoder, slot, binding);
hal::DynCommandEncoder::set_vertex_buffer(state.pass.base.raw_encoder, slot, binding);
}
if let Some(pipeline) = state.pipeline.as_ref() {
state.vertex.update_limits(&pipeline.vertex_steps);
@ -2474,7 +2482,7 @@ fn set_blend_constant(state: &mut State, color: &Color) {
color.a as f32,
];
unsafe {
state.general.raw_encoder.set_blend_constants(&array);
state.pass.base.raw_encoder.set_blend_constants(&array);
}
}
@ -2487,7 +2495,7 @@ fn set_stencil_reference(state: &mut State, value: u32) {
.contains(PipelineFlags::STENCIL_REFERENCE)
{
unsafe {
state.general.raw_encoder.set_stencil_reference(value);
state.pass.base.raw_encoder.set_stencil_reference(value);
}
}
}
@ -2502,18 +2510,18 @@ fn set_viewport(
if rect.w < 0.0
|| rect.h < 0.0
|| rect.w > state.general.device.limits.max_texture_dimension_2d as f32
|| rect.h > state.general.device.limits.max_texture_dimension_2d as f32
|| rect.w > state.pass.base.device.limits.max_texture_dimension_2d as f32
|| rect.h > state.pass.base.device.limits.max_texture_dimension_2d as f32
{
return Err(RenderCommandError::InvalidViewportRectSize {
w: rect.w,
h: rect.h,
max: state.general.device.limits.max_texture_dimension_2d,
max: state.pass.base.device.limits.max_texture_dimension_2d,
}
.into());
}
let max_viewport_range = state.general.device.limits.max_texture_dimension_2d as f32 * 2.0;
let max_viewport_range = state.pass.base.device.limits.max_texture_dimension_2d as f32 * 2.0;
if rect.x < -max_viewport_range
|| rect.y < -max_viewport_range
@ -2541,7 +2549,8 @@ fn set_viewport(
};
unsafe {
state
.general
.pass
.base
.raw_encoder
.set_viewport(&r, depth_min..depth_max);
}
@ -2563,7 +2572,7 @@ fn set_scissor(state: &mut State, rect: Rect<u32>) -> Result<(), RenderPassError
h: rect.h,
};
unsafe {
state.general.raw_encoder.set_scissor_rect(&r);
state.pass.base.raw_encoder.set_scissor_rect(&r);
}
Ok(())
}
@ -2590,7 +2599,7 @@ fn draw(
unsafe {
if instance_count > 0 && vertex_count > 0 {
state.general.raw_encoder.draw(
state.pass.base.raw_encoder.draw(
first_vertex,
vertex_count,
first_instance,
@ -2628,7 +2637,7 @@ fn draw_indexed(
unsafe {
if instance_count > 0 && index_count > 0 {
state.general.raw_encoder.draw_indexed(
state.pass.base.raw_encoder.draw_indexed(
first_index,
index_count,
base_vertex,
@ -2651,11 +2660,12 @@ fn draw_mesh_tasks(
state.is_ready(DrawCommandFamily::DrawMeshTasks)?;
let groups_size_limit = state
.general
.pass
.base
.device
.limits
.max_task_workgroups_per_dimension;
let max_groups = state.general.device.limits.max_task_workgroup_total_count;
let max_groups = state.pass.base.device.limits.max_task_workgroup_total_count;
if group_count_x > groups_size_limit
|| group_count_y > groups_size_limit
|| group_count_z > groups_size_limit
@ -2670,10 +2680,11 @@ fn draw_mesh_tasks(
unsafe {
if group_count_x > 0 && group_count_y > 0 && group_count_z > 0 {
state
.general
.raw_encoder
.draw_mesh_tasks(group_count_x, group_count_y, group_count_z);
state.pass.base.raw_encoder.draw_mesh_tasks(
group_count_x,
group_count_y,
group_count_z,
);
}
}
Ok(())
@ -2681,7 +2692,6 @@ fn draw_mesh_tasks(
fn multi_draw_indirect(
state: &mut State,
indirect_draw_validation_resources: &mut crate::indirect_validation::DrawResources,
indirect_draw_validation_batcher: &mut crate::indirect_validation::DrawBatcher,
cmd_enc: &Arc<CommandEncoder>,
indirect_buffer: Arc<crate::resource::Buffer>,
@ -2697,13 +2707,14 @@ fn multi_draw_indirect(
state.is_ready(family)?;
state
.general
.pass
.base
.device
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;
indirect_buffer.same_device_as(cmd_enc.as_ref())?;
indirect_buffer.check_usage(BufferUsages::INDIRECT)?;
indirect_buffer.check_destroyed(state.general.snatch_guard)?;
indirect_buffer.check_destroyed(state.pass.base.snatch_guard)?;
if offset % 4 != 0 {
return Err(RenderPassErrorInner::UnalignedIndirectBufferOffset(offset));
@ -2721,7 +2732,7 @@ fn multi_draw_indirect(
});
}
state.general.buffer_memory_init_actions.extend(
state.pass.base.buffer_memory_init_actions.extend(
indirect_buffer.initialization_status.read().create_action(
&indirect_buffer,
offset..end_offset,
@ -2749,9 +2760,9 @@ fn multi_draw_indirect(
}
}
if state.general.device.indirect_validation.is_some() {
if state.pass.base.device.indirect_validation.is_some() {
state
.general
.pass
.scope
.buffers
.merge_single(&indirect_buffer, wgt::BufferUses::STORAGE_READ_ONLY)?;
@ -2807,9 +2818,9 @@ fn multi_draw_indirect(
}
let mut draw_ctx = DrawContext {
raw_encoder: state.general.raw_encoder,
device: state.general.device,
indirect_draw_validation_resources,
raw_encoder: state.pass.base.raw_encoder,
device: state.pass.base.device,
indirect_draw_validation_resources: state.pass.base.indirect_draw_validation_resources,
indirect_draw_validation_batcher,
indirect_buffer,
family,
@ -2841,15 +2852,15 @@ fn multi_draw_indirect(
draw_ctx.draw(current_draw_data);
} else {
state
.general
.pass
.scope
.buffers
.merge_single(&indirect_buffer, wgt::BufferUses::INDIRECT)?;
draw(
state.general.raw_encoder,
state.pass.base.raw_encoder,
family,
indirect_buffer.try_raw(state.general.snatch_guard)?,
indirect_buffer.try_raw(state.pass.base.snatch_guard)?,
offset,
count,
);
@ -2879,11 +2890,13 @@ fn multi_draw_indirect_count(
let stride = get_stride_of_indirect_args(family);
state
.general
.pass
.base
.device
.require_features(wgt::Features::MULTI_DRAW_INDIRECT_COUNT)?;
state
.general
.pass
.base
.device
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)?;
@ -2891,22 +2904,22 @@ fn multi_draw_indirect_count(
count_buffer.same_device_as(cmd_enc.as_ref())?;
state
.general
.pass
.scope
.buffers
.merge_single(&indirect_buffer, wgt::BufferUses::INDIRECT)?;
indirect_buffer.check_usage(BufferUsages::INDIRECT)?;
let indirect_raw = indirect_buffer.try_raw(state.general.snatch_guard)?;
let indirect_raw = indirect_buffer.try_raw(state.pass.base.snatch_guard)?;
state
.general
.pass
.scope
.buffers
.merge_single(&count_buffer, wgt::BufferUses::INDIRECT)?;
count_buffer.check_usage(BufferUsages::INDIRECT)?;
let count_raw = count_buffer.try_raw(state.general.snatch_guard)?;
let count_raw = count_buffer.try_raw(state.pass.base.snatch_guard)?;
if offset % 4 != 0 {
return Err(RenderPassErrorInner::UnalignedIndirectBufferOffset(offset));
@ -2921,7 +2934,7 @@ fn multi_draw_indirect_count(
buffer_size: indirect_buffer.size,
});
}
state.general.buffer_memory_init_actions.extend(
state.pass.base.buffer_memory_init_actions.extend(
indirect_buffer.initialization_status.read().create_action(
&indirect_buffer,
offset..end_offset,
@ -2938,7 +2951,7 @@ fn multi_draw_indirect_count(
count_buffer_size: count_buffer.size,
});
}
state.general.buffer_memory_init_actions.extend(
state.pass.base.buffer_memory_init_actions.extend(
count_buffer.initialization_status.read().create_action(
&count_buffer,
count_buffer_offset..end_count_offset,
@ -2948,7 +2961,7 @@ fn multi_draw_indirect_count(
match family {
DrawCommandFamily::Draw => unsafe {
state.general.raw_encoder.draw_indirect_count(
state.pass.base.raw_encoder.draw_indirect_count(
indirect_raw,
offset,
count_raw,
@ -2957,7 +2970,7 @@ fn multi_draw_indirect_count(
);
},
DrawCommandFamily::DrawIndexed => unsafe {
state.general.raw_encoder.draw_indexed_indirect_count(
state.pass.base.raw_encoder.draw_indexed_indirect_count(
indirect_raw,
offset,
count_raw,
@ -2966,7 +2979,7 @@ fn multi_draw_indirect_count(
);
},
DrawCommandFamily::DrawMeshTasks => unsafe {
state.general.raw_encoder.draw_mesh_tasks_indirect_count(
state.pass.base.raw_encoder.draw_mesh_tasks_indirect_count(
indirect_raw,
offset,
count_raw,
@ -2980,14 +2993,13 @@ fn multi_draw_indirect_count(
fn execute_bundle(
state: &mut State,
indirect_draw_validation_resources: &mut crate::indirect_validation::DrawResources,
indirect_draw_validation_batcher: &mut crate::indirect_validation::DrawBatcher,
cmd_enc: &Arc<CommandEncoder>,
bundle: Arc<super::RenderBundle>,
) -> Result<(), RenderPassErrorInner> {
api_log!("RenderPass::execute_bundle {}", bundle.error_ident());
let bundle = state.general.tracker.bundles.insert_single(bundle);
let bundle = state.pass.base.tracker.bundles.insert_single(bundle);
bundle.same_device_as(cmd_enc.as_ref())?;
@ -3010,7 +3022,7 @@ fn execute_bundle(
);
}
state.general.buffer_memory_init_actions.extend(
state.pass.base.buffer_memory_init_actions.extend(
bundle
.buffer_memory_init_actions
.iter()
@ -3023,9 +3035,10 @@ fn execute_bundle(
}),
);
for action in bundle.texture_memory_init_actions.iter() {
state.general.pending_discard_init_fixups.extend(
state.pass.pending_discard_init_fixups.extend(
state
.general
.pass
.base
.texture_memory_actions
.register_init_action(action),
);
@ -3033,10 +3046,10 @@ fn execute_bundle(
unsafe {
bundle.execute(
state.general.raw_encoder,
indirect_draw_validation_resources,
state.pass.base.raw_encoder,
state.pass.base.indirect_draw_validation_resources,
indirect_draw_validation_batcher,
state.general.snatch_guard,
state.pass.base.snatch_guard,
)
}
.map_err(|e| match e {
@ -3050,7 +3063,7 @@ fn execute_bundle(
})?;
unsafe {
state.general.scope.merge_render_bundle(&bundle.used)?;
state.pass.scope.merge_render_bundle(&bundle.used)?;
};
state.reset_bundle();
Ok(())