Deferred error reporting for compute passes

This commit is contained in:
Andy Leiserson 2025-06-06 15:06:22 -07:00
parent 09c08037c9
commit 460f073a67
4 changed files with 546 additions and 337 deletions

View File

@ -19,6 +19,7 @@ use wgpu_core::command::CommandEncoderError;
use wgpu_core::command::ComputePassError;
use wgpu_core::command::CreateRenderBundleError;
use wgpu_core::command::EncoderStateError;
use wgpu_core::command::PassStateError;
use wgpu_core::command::QueryError;
use wgpu_core::command::RenderBundleError;
use wgpu_core::command::RenderPassError;
@ -204,6 +205,12 @@ impl From<EncoderStateError> for GPUError {
}
}
impl From<PassStateError> for GPUError {
fn from(err: PassStateError) -> Self {
GPUError::Validation(fmt_err(&err))
}
}
impl From<CreateBufferError> for GPUError {
fn from(err: CreateBufferError) -> Self {
match err {

View File

@ -4,7 +4,7 @@ use wgt::{BufferAddress, DynamicOffset};
use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
use core::{fmt, str};
use crate::command::EncoderStateError;
use crate::command::{EncoderStateError, PassStateError, TimestampWritesError};
use crate::ray_tracing::AsAction;
use crate::{
binding_model::{
@ -17,9 +17,9 @@ use crate::{
memory_init::{
fixup_discarded_surfaces, CommandBufferTextureMemoryActions, SurfacesInDiscardState,
},
validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites, BasePass,
BindGroupStateChange, CommandBuffer, CommandEncoderError, MapPassErr, PassErrorScope,
PassTimestampWrites, QueryUseError, StateChange,
pass_base, pass_try, validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites,
BasePass, BindGroupStateChange, CommandBuffer, CommandEncoderError, MapPassErr,
PassErrorScope, PassTimestampWrites, QueryUseError, StateChange,
},
device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures},
global::Global,
@ -35,16 +35,24 @@ use crate::{
Label,
};
pub type ComputeBasePass = BasePass<ArcComputeCommand, ComputePassError>;
/// A pass's [encoder state](https://www.w3.org/TR/webgpu/#encoder-state) and
/// its validity are two distinct conditions, i.e., the full matrix of
/// (open, ended) x (valid, invalid) is possible.
///
/// The presence or absence of the `parent` `Option` indicates the pass's state.
/// The presence or absence of an error in `base.error` indicates the pass's
/// validity.
pub struct ComputePass {
/// All pass data & records is stored here.
///
/// If this is `None`, the pass is in the 'ended' state and can no longer be used.
/// Any attempt to record more commands will result in a validation error.
base: Option<BasePass<ArcComputeCommand, Infallible>>,
base: ComputeBasePass,
/// Parent command buffer that this pass records commands into.
///
/// If it is none, this pass is invalid and any operation on it will return an error.
/// If this is `Some`, then the pass is in WebGPU's "open" state. If it is
/// `None`, then the pass is in the "ended" state.
/// See <https://www.w3.org/TR/webgpu/#encoder-state>
parent: Option<Arc<CommandBuffer>>,
timestamp_writes: Option<ArcPassTimestampWrites>,
@ -56,15 +64,15 @@ pub struct ComputePass {
impl ComputePass {
/// If the parent command buffer is invalid, the returned pass will be invalid.
fn new(parent: Option<Arc<CommandBuffer>>, desc: ArcComputePassDescriptor) -> Self {
fn new(parent: Arc<CommandBuffer>, desc: ArcComputePassDescriptor) -> Self {
let ArcComputePassDescriptor {
label,
timestamp_writes,
} = desc;
Self {
base: Some(BasePass::new(&label)),
parent,
base: BasePass::new(&label),
parent: Some(parent),
timestamp_writes,
current_bind_groups: BindGroupStateChange::new(),
@ -72,19 +80,19 @@ impl ComputePass {
}
}
#[inline]
pub fn label(&self) -> Option<&str> {
self.base.as_ref().and_then(|base| base.label.as_deref())
fn new_invalid(parent: Arc<CommandBuffer>, label: &Label, err: ComputePassError) -> Self {
Self {
base: BasePass::new_invalid(label, err),
parent: Some(parent),
timestamp_writes: None,
current_bind_groups: BindGroupStateChange::new(),
current_pipeline: StateChange::new(),
}
}
fn base_mut<'a>(
&'a mut self,
scope: PassErrorScope,
) -> Result<&'a mut BasePass<ArcComputeCommand, Infallible>, ComputePassError> {
self.base
.as_mut()
.ok_or(ComputePassErrorInner::PassEnded)
.map_pass_err(scope)
#[inline]
pub fn label(&self) -> Option<&str> {
self.base.label.as_deref()
}
}
@ -171,9 +179,12 @@ pub enum ComputePassErrorInner {
PassEnded,
#[error(transparent)]
InvalidResource(#[from] InvalidResourceError),
#[error(transparent)]
TimestampWrites(#[from] TimestampWritesError),
}
/// Error encountered when performing a compute pass.
/// Error encountered when performing a compute pass, stored for later reporting
/// when encoding ends.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
@ -278,8 +289,10 @@ impl<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder>
impl Global {
/// Creates a compute pass.
///
/// If creation fails, an invalid pass is returned.
/// Any operation on an invalid pass will return an error.
/// If creation fails, an invalid pass is returned. Attempting to record
/// commands into an invalid pass is permitted, but a validation error will
/// ultimately be generated when the parent encoder is finished, and it is
/// not possible to run any commands from the invalid pass.
///
/// If successful, puts the encoder into the [`Locked`] state.
///
@ -289,35 +302,80 @@ impl Global {
encoder_id: id::CommandEncoderId,
desc: &ComputePassDescriptor<'_>,
) -> (ComputePass, Option<CommandEncoderError>) {
use EncoderStateError as SErr;
let scope = PassErrorScope::Pass;
let hub = &self.hub;
let mut arc_desc = ArcComputePassDescriptor {
label: desc.label.as_deref().map(Cow::Borrowed),
timestamp_writes: None, // Handle only once we resolved the encoder.
};
let make_err = |e, arc_desc| (ComputePass::new(None, arc_desc), Some(e));
let label = desc.label.as_deref().map(Cow::Borrowed);
let cmd_buf = hub.command_buffers.get(encoder_id.into_command_buffer_id());
let mut cmd_buf_data = cmd_buf.data.lock();
match cmd_buf.data.lock().lock_encoder() {
Ok(_) => {}
Err(e) => return make_err(e.into(), arc_desc),
};
arc_desc.timestamp_writes = match desc
.timestamp_writes
.as_ref()
.map(|tw| {
Self::validate_pass_timestamp_writes(&cmd_buf.device, &hub.query_sets.read(), tw)
})
.transpose()
{
Ok(ok) => ok,
Err(e) => return make_err(e, arc_desc),
};
(ComputePass::new(Some(cmd_buf), arc_desc), None)
match cmd_buf_data.lock_encoder() {
Ok(()) => {
drop(cmd_buf_data);
match desc
.timestamp_writes
.as_ref()
.map(|tw| {
Self::validate_pass_timestamp_writes::<ComputePassErrorInner>(
&cmd_buf.device,
&hub.query_sets.read(),
tw,
)
})
.transpose()
{
Ok(timestamp_writes) => {
let arc_desc = ArcComputePassDescriptor {
label,
timestamp_writes,
};
(ComputePass::new(cmd_buf, arc_desc), None)
}
Err(err) => (
ComputePass::new_invalid(cmd_buf, &label, err.map_pass_err(scope)),
None,
),
}
}
Err(err @ SErr::Locked) => {
// Attempting to open a new pass while the encoder is locked
// invalidates the encoder, but does not generate a validation
// error.
cmd_buf_data.invalidate(err.clone());
drop(cmd_buf_data);
(
ComputePass::new_invalid(cmd_buf, &label, err.map_pass_err(scope)),
None,
)
}
Err(err @ (SErr::Ended | SErr::Submitted)) => {
// Attempting to open a new pass after the encode has ended
// generates an immediate validation error.
drop(cmd_buf_data);
(
ComputePass::new_invalid(cmd_buf, &label, err.clone().map_pass_err(scope)),
Some(err.into()),
)
}
Err(err @ SErr::Invalid) => {
// Passes can be opened even on an invalid encoder. Such passes
// are even valid, but since there's no visible side-effect of
// the pass being valid and there's no point in storing recorded
// commands that will ultimately be discarded, we open an
// invalid pass to save that work.
drop(cmd_buf_data);
(
ComputePass::new_invalid(cmd_buf, &label, err.map_pass_err(scope)),
None,
)
}
Err(SErr::Unlocked) => {
unreachable!("lock_encoder cannot fail due to the encoder being unlocked")
}
}
}
/// Note that this differs from [`Self::compute_pass_end`], it will
@ -377,7 +435,7 @@ impl Global {
panic!("{:?}", err);
};
compute_pass.base = Some(BasePass {
compute_pass.base = BasePass {
label,
error: None,
commands: super::ComputeCommand::resolve_compute_command_ids(&self.hub, &commands)
@ -385,246 +443,259 @@ impl Global {
dynamic_offsets,
string_data,
push_constant_data,
});
};
self.compute_pass_end(&mut compute_pass).unwrap();
}
pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), ComputePassError> {
pub fn compute_pass_end(&self, pass: &mut ComputePass) -> Result<(), EncoderStateError> {
profiling::scope!("CommandEncoder::run_compute_pass");
let pass_scope = PassErrorScope::Pass;
let cmd_buf = pass
.parent
.as_ref()
.ok_or(ComputePassErrorInner::InvalidParentEncoder)
.map_pass_err(pass_scope)?;
let base = pass
.base
.take()
.ok_or(ComputePassErrorInner::PassEnded)
.map_pass_err(pass_scope)?;
let device = &cmd_buf.device;
device.check_is_valid().map_pass_err(pass_scope)?;
let cmd_buf = pass.parent.take().ok_or(EncoderStateError::Ended)?;
let mut cmd_buf_data = cmd_buf.data.lock();
let mut cmd_buf_data_guard = cmd_buf_data.unlock_encoder().map_pass_err(pass_scope)?;
let cmd_buf_data = &mut *cmd_buf_data_guard;
let encoder = &mut cmd_buf_data.encoder;
// We automatically keep extending command buffers over time, and because
// we want to insert a command buffer _before_ what we're about to record,
// we need to make sure to close the previous one.
encoder.close_if_open().map_pass_err(pass_scope)?;
let raw_encoder = encoder
.open_pass(base.label.as_deref())
.map_pass_err(pass_scope)?;
let mut state = State {
binder: Binder::new(),
pipeline: None,
scope: device.new_usage_scope(),
debug_scope_depth: 0,
snatch_guard: device.snatchable_lock.read(),
device,
raw_encoder,
tracker: &mut cmd_buf_data.trackers,
buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
as_actions: &mut cmd_buf_data.as_actions,
temp_offsets: Vec::new(),
dynamic_offset_count: 0,
string_offset: 0,
active_query: None,
push_constants: Vec::new(),
intermediate_trackers: Tracker::new(),
pending_discard_init_fixups: SurfacesInDiscardState::new(),
};
let indices = &state.device.tracker_indices;
state.tracker.buffers.set_size(indices.buffers.size());
state.tracker.textures.set_size(indices.textures.size());
let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
if let Some(tw) = pass.timestamp_writes.take() {
tw.query_set
.same_device_as(cmd_buf.as_ref())
.map_pass_err(pass_scope)?;
let query_set = state.tracker.query_sets.insert_single(tw.query_set);
// Unlike in render passes we can't delay resetting the query sets since
// there is no auxiliary pass.
let range = if let (Some(index_a), Some(index_b)) =
(tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
{
Some(index_a.min(index_b)..index_a.max(index_b) + 1)
} else {
tw.beginning_of_pass_write_index
.or(tw.end_of_pass_write_index)
.map(|i| i..i + 1)
};
// Range should always be Some, both values being None should lead to a validation error.
// But no point in erroring over that nuance here!
if let Some(range) = range {
unsafe {
state.raw_encoder.reset_queries(query_set.raw(), range);
}
if let Some(err) = pass.base.error.take() {
if matches!(
err,
ComputePassError {
inner: ComputePassErrorInner::EncoderState(EncoderStateError::Ended),
scope: _,
}
Some(hal::PassTimestampWrites {
query_set: query_set.raw(),
beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
end_of_pass_write_index: tw.end_of_pass_write_index,
})
) {
// If the encoder was already finished at time of pass creation,
// then it was not put in the locked state, so we need to
// generate a validation error here due to the encoder not being
// locked. The encoder already has a copy of the error.
return Err(EncoderStateError::Ended);
} else {
None
};
let hal_desc = hal::ComputePassDescriptor {
label: hal_label(base.label.as_deref(), device.instance_flags),
timestamp_writes,
};
unsafe {
state.raw_encoder.begin_compute_pass(&hal_desc);
}
for command in base.commands {
match command {
ArcComputeCommand::SetBindGroup {
index,
num_dynamic_offsets,
bind_group,
} => {
let scope = PassErrorScope::SetBindGroup;
set_bind_group(
&mut state,
cmd_buf,
&base.dynamic_offsets,
index,
num_dynamic_offsets,
bind_group,
)
.map_pass_err(scope)?;
}
ArcComputeCommand::SetPipeline(pipeline) => {
let scope = PassErrorScope::SetPipelineCompute;
set_pipeline(&mut state, cmd_buf, pipeline).map_pass_err(scope)?;
}
ArcComputeCommand::SetPushConstant {
offset,
size_bytes,
values_offset,
} => {
let scope = PassErrorScope::SetPushConstant;
set_push_constant(
&mut state,
&base.push_constant_data,
offset,
size_bytes,
values_offset,
)
.map_pass_err(scope)?;
}
ArcComputeCommand::Dispatch(groups) => {
let scope = PassErrorScope::Dispatch { indirect: false };
dispatch(&mut state, groups).map_pass_err(scope)?;
}
ArcComputeCommand::DispatchIndirect { buffer, offset } => {
let scope = PassErrorScope::Dispatch { indirect: true };
dispatch_indirect(&mut state, cmd_buf, buffer, offset).map_pass_err(scope)?;
}
ArcComputeCommand::PushDebugGroup { color: _, len } => {
push_debug_group(&mut state, &base.string_data, len);
}
ArcComputeCommand::PopDebugGroup => {
let scope = PassErrorScope::PopDebugGroup;
pop_debug_group(&mut state).map_pass_err(scope)?;
}
ArcComputeCommand::InsertDebugMarker { color: _, len } => {
insert_debug_marker(&mut state, &base.string_data, len);
}
ArcComputeCommand::WriteTimestamp {
query_set,
query_index,
} => {
let scope = PassErrorScope::WriteTimestamp;
write_timestamp(&mut state, cmd_buf, query_set, query_index)
.map_pass_err(scope)?;
}
ArcComputeCommand::BeginPipelineStatisticsQuery {
query_set,
query_index,
} => {
let scope = PassErrorScope::BeginPipelineStatisticsQuery;
validate_and_begin_pipeline_statistics_query(
query_set,
state.raw_encoder,
&mut state.tracker.query_sets,
cmd_buf,
query_index,
None,
&mut state.active_query,
)
.map_pass_err(scope)?;
}
ArcComputeCommand::EndPipelineStatisticsQuery => {
let scope = PassErrorScope::EndPipelineStatisticsQuery;
end_pipeline_statistics_query(state.raw_encoder, &mut state.active_query)
.map_pass_err(scope)?;
}
// If the pass is invalid, invalidate the parent encoder and return.
// Since we do not track the state of an invalid encoder, it is not
// necessary to unlock it.
cmd_buf_data.invalidate(err);
return Ok(());
}
}
unsafe {
state.raw_encoder.end_compute_pass();
}
cmd_buf_data.unlock_and_record(|cmd_buf_data| -> Result<(), ComputePassError> {
let device = &cmd_buf.device;
device.check_is_valid().map_pass_err(pass_scope)?;
let State {
snatch_guard,
tracker,
intermediate_trackers,
pending_discard_init_fixups,
..
} = state;
let base = &mut pass.base;
// Stop the current command buffer.
encoder.close().map_pass_err(pass_scope)?;
let encoder = &mut cmd_buf_data.encoder;
// Create a new command buffer, which we will insert _before_ the body of the compute pass.
//
// Use that buffer to insert barriers and clear discarded images.
let transit = encoder
.open_pass(Some("(wgpu internal) Pre Pass"))
.map_pass_err(pass_scope)?;
fixup_discarded_surfaces(
pending_discard_init_fixups.into_iter(),
transit,
&mut tracker.textures,
device,
&snatch_guard,
);
CommandBuffer::insert_barriers_from_tracker(
transit,
tracker,
&intermediate_trackers,
&snatch_guard,
);
// Close the command buffer, and swap it with the previous.
encoder.close_and_swap().map_pass_err(pass_scope)?;
cmd_buf_data_guard.mark_successful();
// We automatically keep extending command buffers over time, and because
// we want to insert a command buffer _before_ what we're about to record,
// we need to make sure to close the previous one.
encoder.close_if_open().map_pass_err(pass_scope)?;
let raw_encoder = encoder
.open_pass(base.label.as_deref())
.map_pass_err(pass_scope)?;
Ok(())
let mut state = State {
binder: Binder::new(),
pipeline: None,
scope: device.new_usage_scope(),
debug_scope_depth: 0,
snatch_guard: device.snatchable_lock.read(),
device,
raw_encoder,
tracker: &mut cmd_buf_data.trackers,
buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
as_actions: &mut cmd_buf_data.as_actions,
temp_offsets: Vec::new(),
dynamic_offset_count: 0,
string_offset: 0,
active_query: None,
push_constants: Vec::new(),
intermediate_trackers: Tracker::new(),
pending_discard_init_fixups: SurfacesInDiscardState::new(),
};
let indices = &state.device.tracker_indices;
state.tracker.buffers.set_size(indices.buffers.size());
state.tracker.textures.set_size(indices.textures.size());
let timestamp_writes: Option<hal::PassTimestampWrites<'_, dyn hal::DynQuerySet>> =
if let Some(tw) = pass.timestamp_writes.take() {
tw.query_set
.same_device_as(cmd_buf.as_ref())
.map_pass_err(pass_scope)?;
let query_set = state.tracker.query_sets.insert_single(tw.query_set);
// Unlike in render passes we can't delay resetting the query sets since
// there is no auxiliary pass.
let range = if let (Some(index_a), Some(index_b)) =
(tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
{
Some(index_a.min(index_b)..index_a.max(index_b) + 1)
} else {
tw.beginning_of_pass_write_index
.or(tw.end_of_pass_write_index)
.map(|i| i..i + 1)
};
// Range should always be Some, both values being None should lead to a validation error.
// But no point in erroring over that nuance here!
if let Some(range) = range {
unsafe {
state.raw_encoder.reset_queries(query_set.raw(), range);
}
}
Some(hal::PassTimestampWrites {
query_set: query_set.raw(),
beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
end_of_pass_write_index: tw.end_of_pass_write_index,
})
} else {
None
};
let hal_desc = hal::ComputePassDescriptor {
label: hal_label(base.label.as_deref(), device.instance_flags),
timestamp_writes,
};
unsafe {
state.raw_encoder.begin_compute_pass(&hal_desc);
}
for command in base.commands.drain(..) {
match command {
ArcComputeCommand::SetBindGroup {
index,
num_dynamic_offsets,
bind_group,
} => {
let scope = PassErrorScope::SetBindGroup;
set_bind_group(
&mut state,
cmd_buf.as_ref(),
&base.dynamic_offsets,
index,
num_dynamic_offsets,
bind_group,
)
.map_pass_err(scope)?;
}
ArcComputeCommand::SetPipeline(pipeline) => {
let scope = PassErrorScope::SetPipelineCompute;
set_pipeline(&mut state, cmd_buf.as_ref(), pipeline).map_pass_err(scope)?;
}
ArcComputeCommand::SetPushConstant {
offset,
size_bytes,
values_offset,
} => {
let scope = PassErrorScope::SetPushConstant;
set_push_constant(
&mut state,
&base.push_constant_data,
offset,
size_bytes,
values_offset,
)
.map_pass_err(scope)?;
}
ArcComputeCommand::Dispatch(groups) => {
let scope = PassErrorScope::Dispatch { indirect: false };
dispatch(&mut state, groups).map_pass_err(scope)?;
}
ArcComputeCommand::DispatchIndirect { buffer, offset } => {
let scope = PassErrorScope::Dispatch { indirect: true };
dispatch_indirect(&mut state, cmd_buf.as_ref(), buffer, offset)
.map_pass_err(scope)?;
}
ArcComputeCommand::PushDebugGroup { color: _, len } => {
push_debug_group(&mut state, &base.string_data, len);
}
ArcComputeCommand::PopDebugGroup => {
let scope = PassErrorScope::PopDebugGroup;
pop_debug_group(&mut state).map_pass_err(scope)?;
}
ArcComputeCommand::InsertDebugMarker { color: _, len } => {
insert_debug_marker(&mut state, &base.string_data, len);
}
ArcComputeCommand::WriteTimestamp {
query_set,
query_index,
} => {
let scope = PassErrorScope::WriteTimestamp;
write_timestamp(&mut state, cmd_buf.as_ref(), query_set, query_index)
.map_pass_err(scope)?;
}
ArcComputeCommand::BeginPipelineStatisticsQuery {
query_set,
query_index,
} => {
let scope = PassErrorScope::BeginPipelineStatisticsQuery;
validate_and_begin_pipeline_statistics_query(
query_set,
state.raw_encoder,
&mut state.tracker.query_sets,
cmd_buf.as_ref(),
query_index,
None,
&mut state.active_query,
)
.map_pass_err(scope)?;
}
ArcComputeCommand::EndPipelineStatisticsQuery => {
let scope = PassErrorScope::EndPipelineStatisticsQuery;
end_pipeline_statistics_query(state.raw_encoder, &mut state.active_query)
.map_pass_err(scope)?;
}
}
}
unsafe {
state.raw_encoder.end_compute_pass();
}
let State {
snatch_guard,
tracker,
intermediate_trackers,
pending_discard_init_fixups,
..
} = state;
// Stop the current command buffer.
encoder.close().map_pass_err(pass_scope)?;
// Create a new command buffer, which we will insert _before_ the body of the compute pass.
//
// Use that buffer to insert barriers and clear discarded images.
let transit = encoder
.open_pass(Some("(wgpu internal) Pre Pass"))
.map_pass_err(pass_scope)?;
fixup_discarded_surfaces(
pending_discard_init_fixups.into_iter(),
transit,
&mut tracker.textures,
device,
&snatch_guard,
);
CommandBuffer::insert_barriers_from_tracker(
transit,
tracker,
&intermediate_trackers,
&snatch_guard,
);
// Close the command buffer, and swap it with the previous.
encoder.close_and_swap().map_pass_err(pass_scope)?;
Ok(())
})
}
}
@ -1095,21 +1166,20 @@ impl Global {
index: u32,
bind_group_id: Option<id::BindGroupId>,
offsets: &[DynamicOffset],
) -> Result<(), ComputePassError> {
) -> Result<(), PassStateError> {
let scope = PassErrorScope::SetBindGroup;
let base = pass
.base
.as_mut()
.ok_or(ComputePassErrorInner::PassEnded)
.map_pass_err(scope)?; // Can't use base_mut() utility here because of borrow checker.
let redundant = pass.current_bind_groups.set_and_check_redundant(
bind_group_id,
index,
&mut base.dynamic_offsets,
&mut pass.base.dynamic_offsets,
offsets,
);
// This statement will return an error if the pass is ended.
// Its important the error check comes before the early-out for `redundant`.
let base = pass_base!(pass, scope);
if redundant {
return Ok(());
}
@ -1119,12 +1189,11 @@ impl Global {
let bind_group_id = bind_group_id.unwrap();
let hub = &self.hub;
let bg = hub
.bind_groups
.get(bind_group_id)
.get()
.map_pass_err(scope)?;
bind_group = Some(bg);
bind_group = Some(pass_try!(
base,
scope,
hub.bind_groups.get(bind_group_id).get()
));
}
base.commands.push(ArcComputeCommand::SetBindGroup {
@ -1140,23 +1209,21 @@ impl Global {
&self,
pass: &mut ComputePass,
pipeline_id: id::ComputePipelineId,
) -> Result<(), ComputePassError> {
) -> Result<(), PassStateError> {
let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id);
let scope = PassErrorScope::SetPipelineCompute;
let base = pass.base_mut(scope)?;
// This statement will return an error if the pass is ended.
// Its important the error check comes before the early-out for `redundant`.
let base = pass_base!(pass, scope);
if redundant {
// Do redundant early-out **after** checking whether the pass is ended or not.
return Ok(());
}
let hub = &self.hub;
let pipeline = hub
.compute_pipelines
.get(pipeline_id)
.get()
.map_pass_err(scope)?;
let pipeline = pass_try!(base, scope, hub.compute_pipelines.get(pipeline_id).get());
base.commands.push(ArcComputeCommand::SetPipeline(pipeline));
@ -1168,23 +1235,33 @@ impl Global {
pass: &mut ComputePass,
offset: u32,
data: &[u8],
) -> Result<(), ComputePassError> {
) -> Result<(), PassStateError> {
let scope = PassErrorScope::SetPushConstant;
let base = pass.base_mut(scope)?;
let base = pass_base!(pass, scope);
if offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
return Err(ComputePassErrorInner::PushConstantOffsetAlignment).map_pass_err(scope);
pass_try!(
base,
scope,
Err(ComputePassErrorInner::PushConstantOffsetAlignment),
);
}
if data.len() as u32 & (wgt::PUSH_CONSTANT_ALIGNMENT - 1) != 0 {
return Err(ComputePassErrorInner::PushConstantSizeAlignment).map_pass_err(scope);
pass_try!(
base,
scope,
Err(ComputePassErrorInner::PushConstantSizeAlignment),
)
}
let value_offset = base
.push_constant_data
.len()
.try_into()
.map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
.map_pass_err(scope)?;
let value_offset = pass_try!(
base,
scope,
base.push_constant_data
.len()
.try_into()
.map_err(|_| ComputePassErrorInner::PushConstantOutOfMemory)
);
base.push_constant_data.extend(
data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
@ -1206,11 +1283,11 @@ impl Global {
groups_x: u32,
groups_y: u32,
groups_z: u32,
) -> Result<(), ComputePassError> {
) -> Result<(), PassStateError> {
let scope = PassErrorScope::Dispatch { indirect: false };
let base = pass.base_mut(scope)?;
base.commands
pass_base!(pass, scope)
.commands
.push(ArcComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
Ok(())
@ -1221,12 +1298,12 @@ impl Global {
pass: &mut ComputePass,
buffer_id: id::BufferId,
offset: BufferAddress,
) -> Result<(), ComputePassError> {
) -> Result<(), PassStateError> {
let hub = &self.hub;
let scope = PassErrorScope::Dispatch { indirect: true };
let base = pass.base_mut(scope)?;
let base = pass_base!(pass, scope);
let buffer = hub.buffers.get(buffer_id).get().map_pass_err(scope)?;
let buffer = pass_try!(base, scope, hub.buffers.get(buffer_id).get());
base.commands
.push(ArcComputeCommand::DispatchIndirect { buffer, offset });
@ -1239,8 +1316,8 @@ impl Global {
pass: &mut ComputePass,
label: &str,
color: u32,
) -> Result<(), ComputePassError> {
let base = pass.base_mut(PassErrorScope::PushDebugGroup)?;
) -> Result<(), PassStateError> {
let base = pass_base!(pass, PassErrorScope::PushDebugGroup);
let bytes = label.as_bytes();
base.string_data.extend_from_slice(bytes);
@ -1256,8 +1333,8 @@ impl Global {
pub fn compute_pass_pop_debug_group(
&self,
pass: &mut ComputePass,
) -> Result<(), ComputePassError> {
let base = pass.base_mut(PassErrorScope::PopDebugGroup)?;
) -> Result<(), PassStateError> {
let base = pass_base!(pass, PassErrorScope::PopDebugGroup);
base.commands.push(ArcComputeCommand::PopDebugGroup);
@ -1269,8 +1346,8 @@ impl Global {
pass: &mut ComputePass,
label: &str,
color: u32,
) -> Result<(), ComputePassError> {
let base = pass.base_mut(PassErrorScope::InsertDebugMarker)?;
) -> Result<(), PassStateError> {
let base = pass_base!(pass, PassErrorScope::InsertDebugMarker);
let bytes = label.as_bytes();
base.string_data.extend_from_slice(bytes);
@ -1288,12 +1365,12 @@ impl Global {
pass: &mut ComputePass,
query_set_id: id::QuerySetId,
query_index: u32,
) -> Result<(), ComputePassError> {
) -> Result<(), PassStateError> {
let scope = PassErrorScope::WriteTimestamp;
let base = pass.base_mut(scope)?;
let base = pass_base!(pass, scope);
let hub = &self.hub;
let query_set = hub.query_sets.get(query_set_id).get().map_pass_err(scope)?;
let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());
base.commands.push(ArcComputeCommand::WriteTimestamp {
query_set,
@ -1308,12 +1385,12 @@ impl Global {
pass: &mut ComputePass,
query_set_id: id::QuerySetId,
query_index: u32,
) -> Result<(), ComputePassError> {
) -> Result<(), PassStateError> {
let scope = PassErrorScope::BeginPipelineStatisticsQuery;
let base = pass.base_mut(scope)?;
let base = pass_base!(pass, scope);
let hub = &self.hub;
let query_set = hub.query_sets.get(query_set_id).get().map_pass_err(scope)?;
let query_set = pass_try!(base, scope, hub.query_sets.get(query_set_id).get());
base.commands
.push(ArcComputeCommand::BeginPipelineStatisticsQuery {
@ -1327,10 +1404,9 @@ impl Global {
pub fn compute_pass_end_pipeline_statistics_query(
&self,
pass: &mut ComputePass,
) -> Result<(), ComputePassError> {
let scope = PassErrorScope::EndPipelineStatisticsQuery;
let base = pass.base_mut(scope)?;
base.commands
) -> Result<(), PassStateError> {
pass_base!(pass, PassErrorScope::EndPipelineStatisticsQuery)
.commands
.push(ArcComputeCommand::EndPipelineStatisticsQuery);
Ok(())

View File

@ -189,8 +189,9 @@ impl CommandEncoderStatus {
/// Locks the encoder by putting it in the [`Self::Locked`] state.
///
/// Call [`Self::unlock_encoder`] to put the [`CommandBuffer`] back into the
/// [`Self::Recording`] state.
/// Render or compute passes call this on start. At the end of the pass,
/// they call [`Self::unlock_and_record`] to put the [`CommandBuffer`] back
/// into the [`Self::Recording`] state.
fn lock_encoder(&mut self) -> Result<(), EncoderStateError> {
match mem::replace(self, Self::Transitioning) {
Self::Recording(inner) => {
@ -240,6 +241,52 @@ impl CommandEncoderStatus {
}
}
/// Unlocks the [`CommandBuffer`] and puts it back into the
/// [`Self::Recording`] state, then records commands using the supplied
/// closure.
///
/// This function is the unlocking counterpart to [`Self::lock_encoder`]. It
/// is only valid to call this function if the encoder is in the
/// [`Self::Locked`] state.
///
/// If the closure returns an error, stores that error in the encoder for
/// later reporting when `finish()` is called. Returns `Ok(())` even if the
/// closure returned an error.
///
/// If the encoder is not in the [`Self::Locked`] state, the closure will
/// not be called and nothing will be recorded. If a validation error should
/// be raised immediately, returns it in `Err`, otherwise, returns `Ok(())`.
fn unlock_and_record<
F: FnOnce(&mut CommandBufferMutable) -> Result<(), E>,
E: Clone + Into<CommandEncoderError>,
>(
&mut self,
f: F,
) -> Result<(), EncoderStateError> {
match mem::replace(self, Self::Transitioning) {
Self::Locked(inner) => {
*self = Self::Recording(inner);
RecordingGuard { inner: self }.record(f);
Ok(())
}
st @ Self::Finished(_) => {
*self = st;
Err(EncoderStateError::Ended)
}
Self::Recording(_) => {
*self = Self::Error(EncoderStateError::Unlocked.into());
Err(EncoderStateError::Unlocked)
}
st @ Self::Error(_) => {
// Encoder is invalid. Do not record anything, but do not
// return an immediate validation error.
*self = st;
Ok(())
}
Self::Transitioning => unreachable!(),
}
}
fn finish(&mut self) -> Result<(), CommandEncoderError> {
match mem::replace(self, Self::Transitioning) {
Self::Recording(mut inner) => {
@ -863,8 +910,53 @@ impl<C: Clone, E: Clone> BasePass<C, E> {
push_constant_data: Vec::new(),
}
}
fn new_invalid(label: &Label, err: E) -> Self {
Self {
label: label.as_deref().map(str::to_owned),
error: Some(err),
commands: Vec::new(),
dynamic_offsets: Vec::new(),
string_data: Vec::new(),
push_constant_data: Vec::new(),
}
}
}
macro_rules! pass_base {
($pass:expr, $scope:expr $(,)?) => {
match (&$pass.parent, &$pass.base.error) {
// Attempting to record a command on a finished encoder raises a
// validation error.
(&None, _) => return Err(EncoderStateError::Ended).map_pass_err($scope),
// Attempting to record a command on an open but invalid pass (i.e.
// a pass with a stored error) fails silently. (The stored error
// will be transferred to the parent encoder when the pass is ended,
// and then raised as a validation error when `finish()` is called
// for the parent).
(&Some(_), &Some(_)) => return Ok(()),
// Happy path
(&Some(_), &None) => &mut $pass.base,
}
};
}
pub(crate) use pass_base;
macro_rules! pass_try {
($base:expr, $scope:expr, $res:expr $(,)?) => {
match $res.map_pass_err($scope) {
Ok(val) => val,
Err(err) => {
$base.error.get_or_insert(err);
return Ok(());
}
}
};
}
pub(crate) use pass_try;
/// Errors related to the state of a command or pass encoder.
///
/// The exact behavior of these errors may change based on the resolution of
@ -936,14 +1028,26 @@ pub enum CommandEncoderError {
BuildAccelerationStructure(#[from] BuildAccelerationStructureError),
#[error(transparent)]
TransitionResources(#[from] TransitionResourcesError),
#[error(transparent)]
ComputePass(#[from] ComputePassError),
// TODO: The following are temporary and can be removed once error handling
// is updated for render passes. (They will report via RenderPassError
// instead.)
#[error(transparent)]
QueryUse(#[from] QueryUseError),
#[error(transparent)]
TimestampWrites(#[from] TimestampWritesError),
}
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum TimestampWritesError {
#[error(
"begin and end indices of pass timestamp writes are both set to {idx}, which is not allowed"
)]
TimestampWriteIndicesEqual { idx: u32 },
#[error(transparent)]
TimestampWritesInvalid(#[from] QueryUseError),
IndicesEqual { idx: u32 },
#[error("no begin or end indices were specified for pass timestamp writes, expected at least one to be set")]
TimestampWriteIndicesMissing,
IndicesMissing,
}
impl Global {
@ -1064,11 +1168,18 @@ impl Global {
})
}
fn validate_pass_timestamp_writes(
fn validate_pass_timestamp_writes<E>(
device: &Device,
query_sets: &Storage<Fallible<QuerySet>>,
timestamp_writes: &PassTimestampWrites,
) -> Result<ArcPassTimestampWrites, CommandEncoderError> {
) -> Result<ArcPassTimestampWrites, E>
where
E: From<TimestampWritesError>
+ From<QueryUseError>
+ From<DeviceError>
+ From<MissingFeatures>
+ From<InvalidResourceError>,
{
let &PassTimestampWrites {
query_set,
beginning_of_pass_write_index,
@ -1090,7 +1201,7 @@ impl Global {
if let Some((begin, end)) = beginning_of_pass_write_index.zip(end_of_pass_write_index) {
if begin == end {
return Err(CommandEncoderError::TimestampWriteIndicesEqual { idx: begin });
return Err(TimestampWritesError::IndicesEqual { idx: begin }.into());
}
}
@ -1098,7 +1209,7 @@ impl Global {
.or(end_of_pass_write_index)
.is_none()
{
return Err(CommandEncoderError::TimestampWriteIndicesMissing);
return Err(TimestampWritesError::IndicesMissing.into());
}
Ok(ArcPassTimestampWrites {
@ -1218,6 +1329,12 @@ where
}
}
impl MapPassErr<PassStateError> for EncoderStateError {
fn map_pass_err(self, scope: PassErrorScope) -> PassStateError {
PassStateError { scope, inner: self }
}
}
#[derive(Clone, Copy, Debug)]
pub enum DrawKind {
Draw,
@ -1277,3 +1394,12 @@ pub enum PassErrorScope {
#[error("In a insert_debug_marker command")]
InsertDebugMarker,
}
/// Variant of `EncoderStateError` that includes the pass scope.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct PassStateError {
pub scope: PassErrorScope,
#[source]
pub(super) inner: EncoderStateError,
}

View File

@ -1577,7 +1577,7 @@ impl Global {
arc_desc.timestamp_writes = desc
.timestamp_writes
.map(|tw| Global::validate_pass_timestamp_writes(device, &query_sets, tw))
.map(|tw| Global::validate_pass_timestamp_writes::<CommandEncoderError>(device, &query_sets, tw))
.transpose()?;
arc_desc.occlusion_query_set =