Deferred error reporting for other command encoder operations

* clear commands
* query set functions
* command_encoder_as_hal_mut
* ray_tracing
Andy Leiserson 2025-06-06 12:31:14 -07:00 committed by Jim Blandy
parent e702d1c116
commit 3a5d0f2747
10 changed files with 706 additions and 708 deletions

View File

@@ -27,6 +27,13 @@ pub use run::{execute_test, TestingContext};
 pub use wgpu_macros::gpu_test;
 
 /// Run some code in an error scope and assert that validation fails.
+///
+/// Note that errors related to commands for the GPU (i.e. raised by methods on
+/// GPUCommandEncoder, GPURenderPassEncoder, GPUComputePassEncoder,
+/// GPURenderBundleEncoder) are usually not raised immediately. They are raised
+/// only when `finish()` is called on the command encoder. Tests of such error
+/// cases should call `fail` with a closure that calls `finish()`, not with a
+/// closure that encodes the actual command.
 pub fn fail<T>(
     device: &wgpu::Device,
     callback: impl FnOnce() -> T,
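A minimal sketch of the pattern this note describes, modeled on the clear-buffer tests changed below (`buffer` and `out_of_bounds` stand in for setup done earlier in such a test):

    // Encode an invalid command; no error is raised at this point.
    let mut encoder = ctx.device.create_command_encoder(&Default::default());
    encoder.clear_buffer(&buffer, out_of_bounds, None);
    // The validation error surfaces only when the encoder is finished, so the
    // closure passed to `fail` calls `finish()` rather than encoding the command.
    wgpu_test::fail(
        &ctx.device,
        || encoder.finish(),
        Some("would end up overrunning the bounds of the buffer"),
    );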

View File

@@ -344,13 +344,12 @@ static CLEAR_OFFSET_OUTSIDE_RESOURCE_BOUNDS: GpuTestConfiguration = GpuTestConfi
         let out_of_bounds = size.checked_add(wgpu::COPY_BUFFER_ALIGNMENT).unwrap();

+        let mut encoder = ctx.device.create_command_encoder(&Default::default());
+        encoder.clear_buffer(&buffer, out_of_bounds, None);
+
         wgpu_test::fail(
             &ctx.device,
-            || {
-                ctx.device
-                    .create_command_encoder(&Default::default())
-                    .clear_buffer(&buffer, out_of_bounds, None)
-            },
+            || encoder.finish(),
             Some("Clear of 20..20 would end up overrunning the bounds of the buffer of size 16"),
         );
     });
@@ -370,17 +369,16 @@ static CLEAR_OFFSET_PLUS_SIZE_OUTSIDE_U64_BOUNDS: GpuTestConfiguration =
         let max_valid_offset = u64::MAX - (u64::MAX % wgpu::COPY_BUFFER_ALIGNMENT);
         let smallest_aligned_invalid_size = wgpu::COPY_BUFFER_ALIGNMENT;

+        let mut encoder = ctx.device.create_command_encoder(&Default::default());
+        encoder.clear_buffer(
+            &buffer,
+            max_valid_offset,
+            Some(smallest_aligned_invalid_size),
+        );
+
         wgpu_test::fail(
             &ctx.device,
-            || {
-                ctx.device
-                    .create_command_encoder(&Default::default())
-                    .clear_buffer(
-                        &buffer,
-                        max_valid_offset,
-                        Some(smallest_aligned_invalid_size),
-                    )
-            },
+            || encoder.finish(),
             Some(concat!(
                 "Clear starts at offset 18446744073709551612 with size of 4, ",
                 "but these added together exceed `u64::MAX`"

View File

@@ -330,20 +330,19 @@ static DEVICE_DESTROY_THEN_MORE: GpuTestConfiguration = GpuTestConfiguration::ne
     );

     // Texture clear should fail.
+    encoder_for_clear.clear_texture(
+        &texture_for_write,
+        &wgpu::ImageSubresourceRange {
+            aspect: wgpu::TextureAspect::All,
+            base_mip_level: 0,
+            mip_level_count: None,
+            base_array_layer: 0,
+            array_layer_count: None,
+        },
+    );
     fail(
         &ctx.device,
-        || {
-            encoder_for_clear.clear_texture(
-                &texture_for_write,
-                &wgpu::ImageSubresourceRange {
-                    aspect: wgpu::TextureAspect::All,
-                    base_mip_level: 0,
-                    mip_level_count: None,
-                    base_array_layer: 0,
-                    array_layer_count: None,
-                },
-            );
-        },
+        || encoder_for_clear.finish(),
         Some("device with '' label is invalid"),
     );

View File

@@ -188,11 +188,8 @@ fn blas_compaction(ctx: TestingContext) {
     let mut build_entry = as_ctx.blas_build_entry();
     build_entry.blas = &compacted;

-    fail(
-        &ctx.device,
-        || fail_encoder.build_acceleration_structures([&build_entry], []),
-        None,
-    );
+    fail_encoder.build_acceleration_structures([&build_entry], []);
+    fail(&ctx.device, || fail_encoder.finish(), None);
 }

 #[gpu_test]
@@ -733,13 +730,8 @@ fn only_tlas_vertex_return(ctx: TestingContext) {
         label: Some("TLAS 1"),
     });

-    fail(
-        &ctx.device,
-        || {
-            encoder_tlas.build_acceleration_structures([], [&as_ctx.tlas]);
-        },
-        None,
-    );
+    encoder_tlas.build_acceleration_structures([], [&as_ctx.tlas]);
+    fail(&ctx.device, || encoder_tlas.finish(), None);
 }

 #[gpu_test]
@@ -817,30 +809,29 @@ fn test_as_build_format_stride(
         .create_command_encoder(&CommandEncoderDescriptor {
             label: Some("BLAS_1"),
         });

-    fail_if(
+    command_encoder.build_acceleration_structures(
+        &[BlasBuildEntry {
+            blas: &blas,
+            geometry: BlasGeometries::TriangleGeometries(vec![BlasTriangleGeometry {
+                size: &blas_size,
+                vertex_buffer: &vertices,
+                first_vertex: 0,
+                vertex_stride: stride,
+                index_buffer: None,
+                first_index: None,
+                transform_buffer: None,
+                transform_buffer_offset: None,
+            }]),
+        }],
+        &[],
+    );
+
+    let command_buffer = fail_if(
         &ctx.device,
         invalid_combination,
-        || {
-            command_encoder.build_acceleration_structures(
-                &[BlasBuildEntry {
-                    blas: &blas,
-                    geometry: BlasGeometries::TriangleGeometries(vec![BlasTriangleGeometry {
-                        size: &blas_size,
-                        vertex_buffer: &vertices,
-                        first_vertex: 0,
-                        vertex_stride: stride,
-                        index_buffer: None,
-                        first_index: None,
-                        transform_buffer: None,
-                        transform_buffer_offset: None,
-                    }]),
-                }],
-                &[],
-            )
-        },
+        || command_encoder.finish(),
         None,
     );

     if !invalid_combination {
-        ctx.queue.submit([command_encoder.finish()]);
+        ctx.queue.submit([command_buffer]);
     }
 }

View File

@@ -86,7 +86,7 @@ impl Global {
         dst: BufferId,
         offset: BufferAddress,
         size: Option<BufferAddress>,
-    ) -> Result<(), ClearError> {
+    ) -> Result<(), EncoderStateError> {
         profiling::scope!("CommandEncoder::clear_buffer");
         api_log!("CommandEncoder::clear_buffer {dst:?}");
@@ -96,77 +96,74 @@ impl Global {
             .command_buffers
             .get(command_encoder_id.into_command_buffer_id());
         let mut cmd_buf_data = cmd_buf.data.lock();
-        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
-        let cmd_buf_data = &mut *cmd_buf_data_guard;
+        cmd_buf_data.record_with(|cmd_buf_data| -> Result<(), ClearError> {
             #[cfg(feature = "trace")]
             if let Some(ref mut list) = cmd_buf_data.commands {
                 list.push(TraceCommand::ClearBuffer { dst, offset, size });
             }

             let dst_buffer = hub.buffers.get(dst).get()?;

             dst_buffer.same_device_as(cmd_buf.as_ref())?;

             let dst_pending = cmd_buf_data
                 .trackers
                 .buffers
                 .set_single(&dst_buffer, wgt::BufferUses::COPY_DST);

             let snatch_guard = dst_buffer.device.snatchable_lock.read();
             let dst_raw = dst_buffer.try_raw(&snatch_guard)?;
             dst_buffer.check_usage(BufferUsages::COPY_DST)?;

             // Check if offset & size are valid.
             if offset % wgt::COPY_BUFFER_ALIGNMENT != 0 {
                 return Err(ClearError::UnalignedBufferOffset(offset));
             }

             let size = size.unwrap_or(dst_buffer.size.saturating_sub(offset));
             if size % wgt::COPY_BUFFER_ALIGNMENT != 0 {
                 return Err(ClearError::UnalignedFillSize(size));
             }
             let end_offset =
                 offset
                     .checked_add(size)
                     .ok_or(ClearError::OffsetPlusSizeExceeds64BitBounds {
                         start_offset: offset,
                         requested_size: size,
                     })?;
             if end_offset > dst_buffer.size {
                 return Err(ClearError::BufferOverrun {
                     start_offset: offset,
                     end_offset,
                     buffer_size: dst_buffer.size,
                 });
             }

             if offset == end_offset {
                 log::trace!("Ignoring fill_buffer of size 0");
-                cmd_buf_data_guard.mark_successful();
                 return Ok(());
             }

             // Mark dest as initialized.
             cmd_buf_data.buffer_memory_init_actions.extend(
                 dst_buffer.initialization_status.read().create_action(
                     &dst_buffer,
                     offset..end_offset,
                     MemoryInitKind::ImplicitlyInitialized,
                 ),
             );

             // actual hal barrier & operation
             let dst_barrier =
                 dst_pending.map(|pending| pending.into_hal(&dst_buffer, &snatch_guard));
             let cmd_buf_raw = cmd_buf_data.encoder.open()?;
             unsafe {
                 cmd_buf_raw.transition_buffers(dst_barrier.as_slice());
                 cmd_buf_raw.clear_buffer(dst_raw, offset..end_offset);
             }
-
-            cmd_buf_data_guard.mark_successful();
             Ok(())
+        })
     }

     pub fn command_encoder_clear_texture(
@@ -174,7 +171,7 @@ impl Global {
         command_encoder_id: CommandEncoderId,
         dst: TextureId,
         subresource_range: &ImageSubresourceRange,
-    ) -> Result<(), ClearError> {
+    ) -> Result<(), EncoderStateError> {
         profiling::scope!("CommandEncoder::clear_texture");
         api_log!("CommandEncoder::clear_texture {dst:?}");
@@ -184,79 +181,78 @@ impl Global {
             .command_buffers
             .get(command_encoder_id.into_command_buffer_id());
         let mut cmd_buf_data = cmd_buf.data.lock();
-        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
-        let cmd_buf_data = &mut *cmd_buf_data_guard;
+        cmd_buf_data.record_with(|cmd_buf_data| -> Result<(), ClearError> {
             #[cfg(feature = "trace")]
             if let Some(ref mut list) = cmd_buf_data.commands {
                 list.push(TraceCommand::ClearTexture {
                     dst,
                     subresource_range: *subresource_range,
                 });
             }

             if !cmd_buf.support_clear_texture {
                 return Err(ClearError::MissingClearTextureFeature);
             }

             let dst_texture = hub.textures.get(dst).get()?;

             dst_texture.same_device_as(cmd_buf.as_ref())?;

             // Check if subresource aspects are valid.
             let clear_aspects =
                 hal::FormatAspects::new(dst_texture.desc.format, subresource_range.aspect);
             if clear_aspects.is_empty() {
                 return Err(ClearError::MissingTextureAspect {
                     texture_format: dst_texture.desc.format,
                     subresource_range_aspects: subresource_range.aspect,
                 });
             };

             // Check if subresource level range is valid
             let subresource_mip_range =
                 subresource_range.mip_range(dst_texture.full_range.mips.end);
             if dst_texture.full_range.mips.start > subresource_mip_range.start
                 || dst_texture.full_range.mips.end < subresource_mip_range.end
             {
                 return Err(ClearError::InvalidTextureLevelRange {
                     texture_level_range: dst_texture.full_range.mips.clone(),
                     subresource_base_mip_level: subresource_range.base_mip_level,
                     subresource_mip_level_count: subresource_range.mip_level_count,
                 });
             }
             // Check if subresource layer range is valid
             let subresource_layer_range =
                 subresource_range.layer_range(dst_texture.full_range.layers.end);
             if dst_texture.full_range.layers.start > subresource_layer_range.start
                 || dst_texture.full_range.layers.end < subresource_layer_range.end
             {
                 return Err(ClearError::InvalidTextureLayerRange {
                     texture_layer_range: dst_texture.full_range.layers.clone(),
                     subresource_base_array_layer: subresource_range.base_array_layer,
                     subresource_array_layer_count: subresource_range.array_layer_count,
                 });
             }

             let device = &cmd_buf.device;
             device.check_is_valid()?;
             let (encoder, tracker) = cmd_buf_data.open_encoder_and_tracker()?;

             let snatch_guard = device.snatchable_lock.read();
             clear_texture(
                 &dst_texture,
                 TextureInitRange {
                     mip_range: subresource_mip_range,
                     layer_range: subresource_layer_range,
                 },
                 encoder,
                 &mut tracker.textures,
                 &device.alignments,
                 device.zero_buffer.as_ref(),
                 &snatch_guard,
             )?;

-            cmd_buf_data_guard.mark_successful();
             Ok(())
+        })
     }
 }

View File

@@ -37,7 +37,7 @@ use crate::lock::{rank, Mutex};
 use crate::snatch::SnatchGuard;

 use crate::init_tracker::BufferInitTrackerAction;
-use crate::ray_tracing::AsAction;
+use crate::ray_tracing::{AsAction, BuildAccelerationStructureError};
 use crate::resource::{
     DestroyedResourceError, Fallible, InvalidResourceError, Labeled, ParentDevice as _, QuerySet,
 };
@@ -106,17 +106,6 @@ pub(crate) enum CommandEncoderStatus {
 }

 impl CommandEncoderStatus {
-    /// Checks that the encoder is in the [`Self::Recording`] state.
-    pub(crate) fn record(&mut self) -> Result<RecordingGuard<'_>, EncoderStateError> {
-        match self {
-            Self::Recording(_) => Ok(RecordingGuard { inner: self }),
-            Self::Locked(_) => Err(self.invalidate(EncoderStateError::Locked)),
-            Self::Finished(_) => Err(EncoderStateError::Ended),
-            Self::Error(_) => Err(EncoderStateError::Invalid),
-            Self::Transitioning => unreachable!(),
-        }
-    }
-
     /// Record commands using the supplied closure.
     ///
     /// If the encoder is in the [`Self::Recording`] state, calls the closure to
@@ -138,29 +127,50 @@
         &mut self,
         f: F,
     ) -> Result<(), EncoderStateError> {
-        let err = match self.record() {
-            Ok(guard) => {
-                guard.record(f);
-                return Ok(());
-            }
-            Err(err) => err,
-        };
-        match err {
-            err @ EncoderStateError::Locked => {
-                // Invalidate the encoder and do not record anything, but do not
-                // return an immediate validation error.
-                self.invalidate(err);
-                Ok(())
-            }
-            err @ EncoderStateError::Ended => {
-                // Invalidate the encoder, do not record anything, and return an
-                // immediate validation error.
-                Err(self.invalidate(err))
-            }
-            // Encoder is already invalid. Do not record anything, but do not
-            // return an immediate validation error.
-            EncoderStateError::Invalid => Ok(()),
-            EncoderStateError::Unlocked | EncoderStateError::Submitted => unreachable!(),
-        }
-    }
+        match self {
+            Self::Recording(_) => {
+                RecordingGuard { inner: self }.record(f);
+                Ok(())
+            }
+            Self::Locked(_) => {
+                // Invalidate the encoder and do not record anything, but do not
+                // return an immediate validation error.
+                self.invalidate(EncoderStateError::Locked);
+                Ok(())
+            }
+            // Encoder is ended. Invalidate the encoder, do not record anything,
+            // and return an immediate validation error.
+            Self::Finished(_) => Err(self.invalidate(EncoderStateError::Ended)),
+            // Encoder is already invalid. Do not record anything, but do not
+            // return an immediate validation error.
+            Self::Error(_) => Ok(()),
+            Self::Transitioning => unreachable!(),
+        }
+    }
+
+    /// Special version of record used by `command_encoder_as_hal_mut`. This
+    /// differs from the regular version in two ways:
+    ///
+    /// 1. The recording closure is infallible.
+    /// 2. The recording closure takes `Option<&mut CommandBufferMutable>`, and
+    /// in the case that the encoder is not in a valid state for recording, the
+    /// closure is still called, with `None` as its argument.
+    pub(crate) fn record_as_hal_mut<T, F: FnOnce(Option<&mut CommandBufferMutable>) -> T>(
+        &mut self,
+        f: F,
+    ) -> T {
+        match self {
+            Self::Recording(_) => RecordingGuard { inner: self }.record_as_hal_mut(f),
+            Self::Locked(_) => {
+                self.invalidate(EncoderStateError::Locked);
+                f(None)
+            }
+            Self::Finished(_) => {
+                self.invalidate(EncoderStateError::Ended);
+                f(None)
+            }
+            Self::Error(_) => f(None),
+            Self::Transitioning => unreachable!(),
+        }
+    }
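As a sketch of the call-site pattern this enables (simplified from the `clear_buffer` and `write_timestamp` changes later in this commit; the closure's error type varies per command):

    // Inside a `Global::command_encoder_*` method, after looking up `cmd_buf`.
    // The method itself now returns `Result<(), EncoderStateError>`; errors
    // returned by the closure invalidate the encoder and are reported later,
    // when `finish()` is called.
    let mut cmd_buf_data = cmd_buf.data.lock();
    cmd_buf_data.record_with(|cmd_buf_data| -> Result<(), ClearError> {
        // ... validate and encode the command using `cmd_buf_data` ...
        Ok(())
    })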
@@ -295,6 +305,17 @@
             }
         }
     }
+
+    /// Special version of record used by `command_encoder_as_hal_mut`. This
+    /// version takes an infallible recording closure.
+    pub(crate) fn record_as_hal_mut<T, F: FnOnce(Option<&mut CommandBufferMutable>) -> T>(
+        mut self,
+        f: F,
+    ) -> T {
+        let res = f(Some(&mut self));
+        self.mark_successful();
+        res
+    }
 }

 impl<'a> Drop for RecordingGuard<'a> {
@@ -899,6 +920,10 @@ pub enum CommandEncoderError {
     #[error(transparent)]
     Clear(#[from] ClearError),
     #[error(transparent)]
+    Query(#[from] QueryError),
+    #[error(transparent)]
+    BuildAccelerationStructure(#[from] BuildAccelerationStructureError),
+    #[error(transparent)]
     TransitionResources(#[from] TransitionResourcesError),
     #[error(
         "begin and end indices of pass timestamp writes are both set to {idx}, which is not allowed"

View File

@@ -317,38 +317,36 @@ impl Global {
         command_encoder_id: id::CommandEncoderId,
         query_set_id: id::QuerySetId,
         query_index: u32,
-    ) -> Result<(), QueryError> {
+    ) -> Result<(), EncoderStateError> {
         let hub = &self.hub;

         let cmd_buf = hub
             .command_buffers
             .get(command_encoder_id.into_command_buffer_id());
         let mut cmd_buf_data = cmd_buf.data.lock();
-        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
-        let cmd_buf_data = &mut *cmd_buf_data_guard;
+        cmd_buf_data.record_with(|cmd_buf_data| -> Result<(), QueryError> {
             cmd_buf
                 .device
                 .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_ENCODERS)?;

             #[cfg(feature = "trace")]
             if let Some(ref mut list) = cmd_buf_data.commands {
                 list.push(TraceCommand::WriteTimestamp {
                     query_set_id,
                     query_index,
                 });
             }

             let raw_encoder = cmd_buf_data.encoder.open()?;

             let query_set = hub.query_sets.get(query_set_id).get()?;
             query_set.validate_and_write_timestamp(raw_encoder, query_index, None)?;

             cmd_buf_data.trackers.query_sets.insert_single(query_set);
-            cmd_buf_data_guard.mark_successful();
             Ok(())
+        })
     }

     pub fn command_encoder_resolve_query_set(
@@ -359,136 +357,135 @@ impl Global {
         query_count: u32,
         destination: id::BufferId,
         destination_offset: BufferAddress,
-    ) -> Result<(), QueryError> {
+    ) -> Result<(), EncoderStateError> {
         let hub = &self.hub;

         let cmd_buf = hub
             .command_buffers
             .get(command_encoder_id.into_command_buffer_id());
         let mut cmd_buf_data = cmd_buf.data.lock();
-        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
-        let cmd_buf_data = &mut *cmd_buf_data_guard;
+        cmd_buf_data.record_with(|cmd_buf_data| -> Result<(), QueryError> {
             #[cfg(feature = "trace")]
             if let Some(ref mut list) = cmd_buf_data.commands {
                 list.push(TraceCommand::ResolveQuerySet {
                     query_set_id,
                     start_query,
                     query_count,
                     destination,
                     destination_offset,
                 });
             }

             if destination_offset % wgt::QUERY_RESOLVE_BUFFER_ALIGNMENT != 0 {
                 return Err(QueryError::Resolve(ResolveError::BufferOffsetAlignment));
             }

             let query_set = hub.query_sets.get(query_set_id).get()?;

             query_set.same_device_as(cmd_buf.as_ref())?;

             let dst_buffer = hub.buffers.get(destination).get()?;

             dst_buffer.same_device_as(cmd_buf.as_ref())?;

             let snatch_guard = dst_buffer.device.snatchable_lock.read();

             dst_buffer.check_destroyed(&snatch_guard)?;

             let dst_pending = cmd_buf_data
                 .trackers
                 .buffers
                 .set_single(&dst_buffer, wgt::BufferUses::COPY_DST);

-            let dst_barrier = dst_pending.map(|pending| pending.into_hal(&dst_buffer, &snatch_guard));
+            let dst_barrier =
+                dst_pending.map(|pending| pending.into_hal(&dst_buffer, &snatch_guard));

             dst_buffer
                 .check_usage(wgt::BufferUsages::QUERY_RESOLVE)
                 .map_err(ResolveError::MissingBufferUsage)?;

             let end_query = u64::from(start_query)
                 .checked_add(u64::from(query_count))
                 .expect("`u64` overflow from adding two `u32`s, should be unreachable");
             if end_query > u64::from(query_set.desc.count) {
                 return Err(ResolveError::QueryOverrun {
                     start_query,
                     end_query,
                     query_set_size: query_set.desc.count,
                 }
                 .into());
             }
             let end_query = u32::try_from(end_query)
                 .expect("`u32` overflow for `end_query`, which should be `u32`");

             let elements_per_query = match query_set.desc.ty {
                 wgt::QueryType::Occlusion => 1,
                 wgt::QueryType::PipelineStatistics(ps) => ps.bits().count_ones(),
                 wgt::QueryType::Timestamp => 1,
             };
             let stride = elements_per_query * wgt::QUERY_SIZE;
             let bytes_used: BufferAddress = u64::from(stride)
                 .checked_mul(u64::from(query_count))
                 .expect("`stride` * `query_count` overflowed `u32`, should be unreachable");

             let buffer_start_offset = destination_offset;
             let buffer_end_offset = buffer_start_offset
                 .checked_add(bytes_used)
                 .filter(|buffer_end_offset| *buffer_end_offset <= dst_buffer.size)
                 .ok_or(ResolveError::BufferOverrun {
                     start_query,
                     end_query,
                     stride,
                     buffer_size: dst_buffer.size,
                     buffer_start_offset,
                     bytes_used,
                 })?;

             // TODO(https://github.com/gfx-rs/wgpu/issues/3993): Need to track initialization state.
             cmd_buf_data.buffer_memory_init_actions.extend(
                 dst_buffer.initialization_status.read().create_action(
                     &dst_buffer,
                     buffer_start_offset..buffer_end_offset,
                     MemoryInitKind::ImplicitlyInitialized,
                 ),
             );

             let raw_dst_buffer = dst_buffer.try_raw(&snatch_guard)?;

             let raw_encoder = cmd_buf_data.encoder.open()?;
             unsafe {
                 raw_encoder.transition_buffers(dst_barrier.as_slice());
                 raw_encoder.copy_query_results(
                     query_set.raw(),
                     start_query..end_query,
                     raw_dst_buffer,
                     destination_offset,
                     wgt::BufferSize::new_unchecked(stride as u64),
                 );
             }

             if matches!(query_set.desc.ty, wgt::QueryType::Timestamp) {
                 // Timestamp normalization is only needed for timestamps.
                 cmd_buf
                     .device
                     .timestamp_normalizer
                     .get()
                     .unwrap()
                     .normalize(
                         &snatch_guard,
                         raw_encoder,
                         &mut cmd_buf_data.trackers.buffers,
                         dst_buffer
                             .timestamp_normalization_bind_group
                             .get(&snatch_guard)
                             .unwrap(),
                         &dst_buffer,
                         destination_offset,
                         query_count,
                     );
             }

             cmd_buf_data.trackers.query_sets.insert_single(query_set);

-            cmd_buf_data_guard.mark_successful();
             Ok(())
+        })
     }
 }

View File

@@ -7,7 +7,6 @@ use core::{
 use wgt::{math::align_to, BufferUsages, BufferUses, Features};

-use crate::device::resource::CommandIndices;
 use crate::lock::RwLockWriteGuard;
 use crate::ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError};
 use crate::{
@@ -29,6 +28,7 @@ use crate::{
     snatch::SnatchGuard,
     track::PendingTransition,
 };
+use crate::{command::EncoderStateError, device::resource::CommandIndices};

 use crate::id::{BlasId, TlasId};
@@ -64,7 +64,7 @@ impl Global {
         command_encoder_id: CommandEncoderId,
         blas_ids: &[BlasId],
         tlas_ids: &[TlasId],
-    ) -> Result<(), BuildAccelerationStructureError> {
+    ) -> Result<(), EncoderStateError> {
         profiling::scope!("CommandEncoder::mark_acceleration_structures_built");

         let hub = &self.hub;
@@ -73,34 +73,32 @@ impl Global {
             .command_buffers
             .get(command_encoder_id.into_command_buffer_id());

-        let device = &cmd_buf.device;
-        device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
-
-        let mut build_command = AsBuild::default();
-        for blas in blas_ids {
-            let blas = hub.blas_s.get(*blas).get()?;
-            build_command.blas_s_built.push(blas);
-        }
-
-        for tlas in tlas_ids {
-            let tlas = hub.tlas_s.get(*tlas).get()?;
-            build_command.tlas_s_built.push(TlasBuild {
-                tlas,
-                dependencies: Vec::new(),
-            });
-        }
-
         let mut cmd_buf_data = cmd_buf.data.lock();
-        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
-        let cmd_buf_data = &mut *cmd_buf_data_guard;
-
-        cmd_buf_data.as_actions.push(AsAction::Build(build_command));
-        cmd_buf_data_guard.mark_successful();
-
-        Ok(())
+        cmd_buf_data.record_with(
+            |cmd_buf_data| -> Result<(), BuildAccelerationStructureError> {
+                let device = &cmd_buf.device;
+                device
+                    .require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
+
+                let mut build_command = AsBuild::default();
+                for blas in blas_ids {
+                    let blas = hub.blas_s.get(*blas).get()?;
+                    build_command.blas_s_built.push(blas);
+                }
+
+                for tlas in tlas_ids {
+                    let tlas = hub.tlas_s.get(*tlas).get()?;
+                    build_command.tlas_s_built.push(TlasBuild {
+                        tlas,
+                        dependencies: Vec::new(),
+                    });
+                }
+
+                cmd_buf_data.as_actions.push(AsAction::Build(build_command));
+                Ok(())
+            },
+        )
     }
     pub fn command_encoder_build_acceleration_structures<'a>(
@@ -108,7 +106,7 @@ impl Global {
         command_encoder_id: CommandEncoderId,
         blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
         tlas_iter: impl Iterator<Item = TlasPackage<'a>>,
-    ) -> Result<(), BuildAccelerationStructureError> {
+    ) -> Result<(), EncoderStateError> {
         profiling::scope!("CommandEncoder::build_acceleration_structures");

         let hub = &self.hub;
@@ -117,10 +115,6 @@ impl Global {
             .command_buffers
             .get(command_encoder_id.into_command_buffer_id());

-        let device = &cmd_buf.device;
-        device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
-
         let mut build_command = AsBuild::default();

         let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
@@ -171,14 +165,6 @@ impl Global {
             })
             .collect();

-        #[cfg(feature = "trace")]
-        if let Some(ref mut list) = cmd_buf.data.lock().get_inner().commands {
-            list.push(crate::device::trace::Command::BuildAccelerationStructures {
-                blas: trace_blas.clone(),
-                tlas: trace_tlas.clone(),
-            });
-        }
-
         let blas_iter = trace_blas.iter().map(|blas_entry| {
             let geometries = match &blas_entry.geometries {
                 TraceBlasGeometries::TriangleGeometries(triangle_geometries) => {
@@ -217,298 +203,305 @@ impl Global {
             }
         });

-        let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
-        let mut buf_storage = Vec::new();
-        let mut scratch_buffer_blas_size = 0;
-        let mut blas_storage = Vec::new();
-
         let mut cmd_buf_data = cmd_buf.data.lock();
-        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
-        let cmd_buf_data = &mut *cmd_buf_data_guard;
+        cmd_buf_data.record_with(|cmd_buf_data| {
+            #[cfg(feature = "trace")]
+            if let Some(ref mut list) = cmd_buf_data.commands {
+                list.push(crate::device::trace::Command::BuildAccelerationStructures {
+                    blas: trace_blas.clone(),
+                    tlas: trace_tlas.clone(),
+                });
+            }
+
+            let device = &cmd_buf.device;
+            device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
+
+            let mut buf_storage = Vec::new();
             iter_blas(
                 blas_iter,
                 cmd_buf_data,
                 &mut build_command,
                 &mut buf_storage,
                 hub,
             )?;

             let snatch_guard = device.snatchable_lock.read();
+            let mut input_barriers = Vec::<hal::BufferBarrier<dyn hal::DynBuffer>>::new();
+            let mut scratch_buffer_blas_size = 0;
+            let mut blas_storage = Vec::new();
             iter_buffers(
                 &mut buf_storage,
                 &snatch_guard,
                 &mut input_barriers,
                 cmd_buf_data,
                 &mut scratch_buffer_blas_size,
                 &mut blas_storage,
                 hub,
                 device.alignments.ray_tracing_scratch_buffer_alignment,
             )?;
             let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();

             for package in tlas_iter {
                 let tlas = hub.tlas_s.get(package.tlas_id).get()?;
                 cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());
                 tlas_lock_store.push((Some(package), tlas))
             }

             let mut scratch_buffer_tlas_size = 0;
             let mut tlas_storage = Vec::<TlasStore>::new();
             let mut instance_buffer_staging_source = Vec::<u8>::new();

             for (package, tlas) in &mut tlas_lock_store {
                 let package = package.take().unwrap();

                 let scratch_buffer_offset = scratch_buffer_tlas_size;
                 scratch_buffer_tlas_size += align_to(
                     tlas.size_info.build_scratch_size as u32,
                     device.alignments.ray_tracing_scratch_buffer_alignment,
                 ) as u64;

                 let first_byte_index = instance_buffer_staging_source.len();

                 let mut dependencies = Vec::new();

                 let mut instance_count = 0;
                 for instance in package.instances.flatten() {
                     if instance.custom_data >= (1u32 << 24u32) {
                         return Err(BuildAccelerationStructureError::TlasInvalidCustomIndex(
                             tlas.error_ident(),
                         ));
                     }
                     let blas = hub.blas_s.get(instance.blas_id).get()?;

                     cmd_buf_data.trackers.blas_s.insert_single(blas.clone());

                     instance_buffer_staging_source.extend(device.raw().tlas_instance_to_bytes(
                         hal::TlasInstance {
                             transform: *instance.transform,
                             custom_data: instance.custom_data,
                             mask: instance.mask,
                             blas_address: blas.handle,
                         },
                     ));

                     if tlas.flags.contains(
                         wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN,
                     ) && !blas.flags.contains(
                         wgpu_types::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN,
                     ) {
                         return Err(
                             BuildAccelerationStructureError::TlasDependentMissingVertexReturn(
                                 tlas.error_ident(),
                                 blas.error_ident(),
                             ),
                         );
                     }

                     instance_count += 1;
                     dependencies.push(blas.clone());
                 }

                 build_command.tlas_s_built.push(TlasBuild {
                     tlas: tlas.clone(),
                     dependencies,
                 });

                 if instance_count > tlas.max_instance_count {
                     return Err(BuildAccelerationStructureError::TlasInstanceCountExceeded(
                         tlas.error_ident(),
                         instance_count,
                         tlas.max_instance_count,
                     ));
                 }

                 tlas_storage.push(TlasStore {
                     internal: UnsafeTlasStore {
                         tlas: tlas.clone(),
                         entries: hal::AccelerationStructureEntries::Instances(
                             hal::AccelerationStructureInstances {
                                 buffer: Some(tlas.instance_buffer.as_ref()),
                                 offset: 0,
                                 count: instance_count,
                             },
                         ),
                         scratch_buffer_offset,
                     },
                     range: first_byte_index..instance_buffer_staging_source.len(),
                 });
             }

-            let scratch_size =
-                match wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size)) {
-                    // if the size is zero there is nothing to build
-                    None => {
-                        cmd_buf_data_guard.mark_successful();
-                        return Ok(());
-                    }
-                    Some(size) => size,
-                };
+            let Some(scratch_size) =
+                wgt::BufferSize::new(max(scratch_buffer_blas_size, scratch_buffer_tlas_size))
+            else {
+                // if the size is zero there is nothing to build
+                return Ok(());
+            };

             let scratch_buffer = ScratchBuffer::new(device, scratch_size)?;

             let scratch_buffer_barrier = hal::BufferBarrier::<dyn hal::DynBuffer> {
                 buffer: scratch_buffer.raw(),
                 usage: hal::StateTransition {
                     from: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
                     to: BufferUses::ACCELERATION_STRUCTURE_SCRATCH,
                 },
             };

             let mut tlas_descriptors = Vec::with_capacity(tlas_storage.len());

             for &TlasStore {
                 internal:
                     UnsafeTlasStore {
                         ref tlas,
                         ref entries,
                         ref scratch_buffer_offset,
                     },
                 ..
             } in &tlas_storage
             {
                 if tlas.update_mode == wgt::AccelerationStructureUpdateMode::PreferUpdate {
                     log::info!("only rebuild implemented")
                 }
                 tlas_descriptors.push(hal::BuildAccelerationStructureDescriptor {
                     entries,
                     mode: hal::AccelerationStructureBuildMode::Build,
                     flags: tlas.flags,
                     source_acceleration_structure: None,
                     destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
                     scratch_buffer: scratch_buffer.raw(),
                     scratch_buffer_offset: *scratch_buffer_offset,
                 })
             }

             let blas_present = !blas_storage.is_empty();
             let tlas_present = !tlas_storage.is_empty();

             let cmd_buf_raw = cmd_buf_data.encoder.open()?;

             let mut blas_s_compactable = Vec::new();
             let mut descriptors = Vec::new();

             for storage in &blas_storage {
                 descriptors.push(map_blas(
                     storage,
                     scratch_buffer.raw(),
                     &snatch_guard,
                     &mut blas_s_compactable,
                 )?);
             }

             build_blas(
                 cmd_buf_raw,
                 blas_present,
                 tlas_present,
                 input_barriers,
                 &descriptors,
                 scratch_buffer_barrier,
                 blas_s_compactable,
             );

             if tlas_present {
                 let staging_buffer = if !instance_buffer_staging_source.is_empty() {
                     let mut staging_buffer = StagingBuffer::new(
                         device,
                         wgt::BufferSize::new(instance_buffer_staging_source.len() as u64).unwrap(),
                     )?;
                     staging_buffer.write(&instance_buffer_staging_source);
                     let flushed = staging_buffer.flush();
                     Some(flushed)
                 } else {
                     None
                 };

                 unsafe {
                     if let Some(ref staging_buffer) = staging_buffer {
                         cmd_buf_raw.transition_buffers(&[
                             hal::BufferBarrier::<dyn hal::DynBuffer> {
                                 buffer: staging_buffer.raw(),
                                 usage: hal::StateTransition {
                                     from: BufferUses::MAP_WRITE,
                                     to: BufferUses::COPY_SRC,
                                 },
                             },
                         ]);
                     }
                 }

                 let mut instance_buffer_barriers = Vec::new();
                 for &TlasStore {
                     internal: UnsafeTlasStore { ref tlas, .. },
                     ref range,
                 } in &tlas_storage
                 {
                     let size = match wgt::BufferSize::new((range.end - range.start) as u64) {
                         None => continue,
                         Some(size) => size,
                     };
                     instance_buffer_barriers.push(hal::BufferBarrier::<dyn hal::DynBuffer> {
                         buffer: tlas.instance_buffer.as_ref(),
                         usage: hal::StateTransition {
                             from: BufferUses::COPY_DST,
                             to: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                         },
                     });
                     unsafe {
                         cmd_buf_raw.transition_buffers(&[
                             hal::BufferBarrier::<dyn hal::DynBuffer> {
                                 buffer: tlas.instance_buffer.as_ref(),
                                 usage: hal::StateTransition {
                                     from: BufferUses::TOP_LEVEL_ACCELERATION_STRUCTURE_INPUT,
                                     to: BufferUses::COPY_DST,
                                 },
                             },
                         ]);
                         let temp = hal::BufferCopy {
                             src_offset: range.start as u64,
                             dst_offset: 0,
                             size,
                         };
                         cmd_buf_raw.copy_buffer_to_buffer(
                             // the range whose size we just checked end is at (at that point in time) instance_buffer_staging_source.len()
                             // and since instance_buffer_staging_source doesn't shrink we can un wrap this without a panic
                             staging_buffer.as_ref().unwrap().raw(),
                             tlas.instance_buffer.as_ref(),
                             &[temp],
                         );
                     }
                 }

                 unsafe {
                     cmd_buf_raw.transition_buffers(&instance_buffer_barriers);
                     cmd_buf_raw.build_acceleration_structures(&tlas_descriptors);
                     cmd_buf_raw.place_acceleration_structure_barrier(
                         hal::AccelerationStructureBarrier {
                             usage: hal::StateTransition {
                                 from: hal::AccelerationStructureUses::BUILD_OUTPUT,
                                 to: hal::AccelerationStructureUses::SHADER_INPUT,
                             },
                         },
                     );
                 }

                 if let Some(staging_buffer) = staging_buffer {
                     cmd_buf_data
                         .temp_resources
                         .push(TempResource::StagingBuffer(staging_buffer));
                 }
             }

             cmd_buf_data
                 .temp_resources
                 .push(TempResource::ScratchBuffer(scratch_buffer));

             cmd_buf_data.as_actions.push(AsAction::Build(build_command));

-            cmd_buf_data_guard.mark_successful();
             Ok(())
+        })
     }
 }

View File

@@ -1,7 +1,7 @@
 use thiserror::Error;

 use crate::{
-    command::{CommandBuffer, EncoderStateError},
+    command::{CommandBuffer, CommandEncoderError, EncoderStateError},
     device::DeviceError,
     global::Global,
     id::{BufferId, CommandEncoderId, TextureId},
@@ -15,7 +15,7 @@ impl Global {
         command_encoder_id: CommandEncoderId,
         buffer_transitions: impl Iterator<Item = wgt::BufferTransition<BufferId>>,
         texture_transitions: impl Iterator<Item = wgt::TextureTransition<TextureId>>,
-    ) -> Result<(), TransitionResourcesError> {
+    ) -> Result<(), EncoderStateError> {
         profiling::scope!("CommandEncoder::transition_resources");

         let hub = &self.hub;
@@ -25,54 +25,51 @@ impl Global {
             .command_buffers
             .get(command_encoder_id.into_command_buffer_id());
         let mut cmd_buf_data = cmd_buf.data.lock();
-        let mut cmd_buf_data_guard = cmd_buf_data.record()?;
-        let cmd_buf_data = &mut *cmd_buf_data_guard;
+        cmd_buf_data.record_with(|cmd_buf_data| -> Result<(), CommandEncoderError> {
             // Get and lock device
             let device = &cmd_buf.device;
             device.check_is_valid()?;
             let snatch_guard = &device.snatchable_lock.read();

             let mut usage_scope = device.new_usage_scope();
             let indices = &device.tracker_indices;
             usage_scope.buffers.set_size(indices.buffers.size());
             usage_scope.textures.set_size(indices.textures.size());

             // Process buffer transitions
             for buffer_transition in buffer_transitions {
                 let buffer = hub.buffers.get(buffer_transition.buffer).get()?;
                 buffer.same_device_as(cmd_buf.as_ref())?;

                 usage_scope
                     .buffers
                     .merge_single(&buffer, buffer_transition.state)?;
             }

             // Process texture transitions
             for texture_transition in texture_transitions {
                 let texture = hub.textures.get(texture_transition.texture).get()?;
                 texture.same_device_as(cmd_buf.as_ref())?;

                 unsafe {
                     usage_scope.textures.merge_single(
                         &texture,
                         texture_transition.selector,
                         texture_transition.state,
                     )
                 }?;
             }

             // Record any needed barriers based on tracker data
             let cmd_buf_raw = cmd_buf_data.encoder.open()?;
             CommandBuffer::insert_barriers_from_scope(
                 cmd_buf_raw,
                 &mut cmd_buf_data.trackers,
                 &usage_scope,
                 snatch_guard,
             );
-            cmd_buf_data_guard.mark_successful();
             Ok(())
+        })
     }
 }

View File

@@ -1383,20 +1383,15 @@ impl Global {
         let cmd_buf = hub.command_buffers.get(id.into_command_buffer_id());
         let mut cmd_buf_data = cmd_buf.data.lock();
-        let cmd_buf_data_guard = cmd_buf_data.record();
-
-        if let Ok(mut cmd_buf_data_guard) = cmd_buf_data_guard {
-            let cmd_buf_raw = cmd_buf_data_guard
-                .encoder
-                .open()
-                .ok()
-                .and_then(|encoder| encoder.as_any_mut().downcast_mut());
-            let ret = hal_command_encoder_callback(cmd_buf_raw);
-            cmd_buf_data_guard.mark_successful();
-            ret
-        } else {
-            hal_command_encoder_callback(None)
-        }
+        cmd_buf_data.record_as_hal_mut(|opt_cmd_buf| -> R {
+            hal_command_encoder_callback(opt_cmd_buf.and_then(|cmd_buf| {
+                cmd_buf
+                    .encoder
+                    .open()
+                    .ok()
+                    .and_then(|encoder| encoder.as_any_mut().downcast_mut())
+            }))
+        })
     }

     /// # Safety