Rework acceleration structure build tracking.

This commit is contained in:
Vecvec 2025-04-10 13:31:00 +12:00 committed by Connor Fitzgerald
parent 382a1e3c9b
commit 8010203281
8 changed files with 187 additions and 169 deletions

View File

@ -295,6 +295,7 @@ By @syl20bnr in [#7326](https://github.com/gfx-rs/wgpu/pull/7326).
- Reduce downlevel `max_color_attachments` limit from 8 to 4 for better GLES compatibility. By @adrian17 in [#6994](https://github.com/gfx-rs/wgpu/pull/6994). - Reduce downlevel `max_color_attachments` limit from 8 to 4 for better GLES compatibility. By @adrian17 in [#6994](https://github.com/gfx-rs/wgpu/pull/6994).
- Fix building a BLAS with a transform buffer by adding a flag to indicate usage of the transform buffer. By @Vecvec in - Fix building a BLAS with a transform buffer by adding a flag to indicate usage of the transform buffer. By @Vecvec in
[#7062](https://github.com/gfx-rs/wgpu/pull/7062). [#7062](https://github.com/gfx-rs/wgpu/pull/7062).
- Move incrementation of `Device::last_acceleration_structure_build_command_index` into queue submit. By @Vecvec in [#7462](https://github.com/gfx-rs/wgpu/pull/7462).
#### Vulkan #### Vulkan

View File

@ -234,6 +234,80 @@ fn out_of_order_as_build_use(ctx: TestingContext) {
}, },
None, None,
); );
let as_ctx = AsBuildContext::new(
&ctx,
AccelerationStructureFlags::empty(),
AccelerationStructureFlags::empty(),
);
//
// Build in the right order, then rebuild the BLAS so the TLAS is invalid, then use the TLAS.
//
let mut encoder_blas = ctx
.device
.create_command_encoder(&CommandEncoderDescriptor {
label: Some("BLAS 3"),
});
encoder_blas.build_acceleration_structures([&as_ctx.blas_build_entry()], []);
let mut encoder_blas2 = ctx
.device
.create_command_encoder(&CommandEncoderDescriptor {
label: Some("BLAS 4"),
});
encoder_blas2.build_acceleration_structures([&as_ctx.blas_build_entry()], []);
let mut encoder_tlas = ctx
.device
.create_command_encoder(&CommandEncoderDescriptor {
label: Some("TLAS 2"),
});
encoder_tlas.build_acceleration_structures([], [&as_ctx.tlas_package]);
ctx.queue.submit([
encoder_blas.finish(),
encoder_tlas.finish(),
encoder_blas2.finish(),
]);
let bind_group = ctx.device.create_bind_group(&BindGroupDescriptor {
label: None,
layout: &compute_pipeline.get_bind_group_layout(0),
entries: &[BindGroupEntry {
binding: 0,
resource: BindingResource::AccelerationStructure(as_ctx.tlas_package.tlas()),
}],
});
//
// Use TLAS
//
let mut encoder_compute = ctx
.device
.create_command_encoder(&CommandEncoderDescriptor::default());
{
let mut pass = encoder_compute.begin_compute_pass(&ComputePassDescriptor {
label: None,
timestamp_writes: None,
});
pass.set_pipeline(&compute_pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
pass.dispatch_workgroups(1, 1, 1)
}
fail(
&ctx.device,
|| {
ctx.queue.submit(Some(encoder_compute.finish()));
},
None,
);
} }
#[gpu_test] #[gpu_test]

View File

@ -4,6 +4,7 @@ use wgt::{BufferAddress, DynamicOffset};
use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec}; use alloc::{borrow::Cow, boxed::Box, sync::Arc, vec::Vec};
use core::{fmt, str}; use core::{fmt, str};
use crate::ray_tracing::AsAction;
use crate::{ use crate::{
binding_model::{ binding_model::{
BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError, BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError,
@ -24,7 +25,6 @@ use crate::{
hal_label, id, hal_label, id,
init_tracker::{BufferInitTrackerAction, MemoryInitKind}, init_tracker::{BufferInitTrackerAction, MemoryInitKind},
pipeline::ComputePipeline, pipeline::ComputePipeline,
ray_tracing::TlasAction,
resource::{ resource::{
self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled, self, Buffer, DestroyedResourceError, InvalidResourceError, Labeled,
MissingBufferUsageError, ParentDevice, MissingBufferUsageError, ParentDevice,
@ -208,7 +208,7 @@ struct State<'scope, 'snatch_guard, 'cmd_buf, 'raw_encoder> {
tracker: &'cmd_buf mut Tracker, tracker: &'cmd_buf mut Tracker,
buffer_memory_init_actions: &'cmd_buf mut Vec<BufferInitTrackerAction>, buffer_memory_init_actions: &'cmd_buf mut Vec<BufferInitTrackerAction>,
texture_memory_actions: &'cmd_buf mut CommandBufferTextureMemoryActions, texture_memory_actions: &'cmd_buf mut CommandBufferTextureMemoryActions,
tlas_actions: &'cmd_buf mut Vec<TlasAction>, as_actions: &'cmd_buf mut Vec<AsAction>,
temp_offsets: Vec<u32>, temp_offsets: Vec<u32>,
dynamic_offset_count: usize, dynamic_offset_count: usize,
@ -433,7 +433,7 @@ impl Global {
tracker: &mut cmd_buf_data.trackers, tracker: &mut cmd_buf_data.trackers,
buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions, buffer_memory_init_actions: &mut cmd_buf_data.buffer_memory_init_actions,
texture_memory_actions: &mut cmd_buf_data.texture_memory_actions, texture_memory_actions: &mut cmd_buf_data.texture_memory_actions,
tlas_actions: &mut cmd_buf_data.tlas_actions, as_actions: &mut cmd_buf_data.as_actions,
temp_offsets: Vec::new(), temp_offsets: Vec::new(),
dynamic_offset_count: 0, dynamic_offset_count: 0,
@ -680,12 +680,9 @@ fn set_bind_group(
.used .used
.acceleration_structures .acceleration_structures
.into_iter() .into_iter()
.map(|tlas| TlasAction { .map(|tlas| AsAction::UseTlas(tlas.clone()));
tlas: tlas.clone(),
kind: crate::ray_tracing::TlasActionKind::Use,
});
state.tlas_actions.extend(used_resource); state.as_actions.extend(used_resource);
let pipeline_layout = state.binder.pipeline_layout.clone(); let pipeline_layout = state.binder.pipeline_layout.clone();
let entries = state let entries = state

View File

@ -36,7 +36,7 @@ use crate::lock::{rank, Mutex};
use crate::snatch::SnatchGuard; use crate::snatch::SnatchGuard;
use crate::init_tracker::BufferInitTrackerAction; use crate::init_tracker::BufferInitTrackerAction;
use crate::ray_tracing::{BlasAction, TlasAction}; use crate::ray_tracing::AsAction;
use crate::resource::{Fallible, InvalidResourceError, Labeled, ParentDevice as _, QuerySet}; use crate::resource::{Fallible, InvalidResourceError, Labeled, ParentDevice as _, QuerySet};
use crate::storage::Storage; use crate::storage::Storage;
use crate::track::{DeviceTracker, Tracker, UsageScope}; use crate::track::{DeviceTracker, Tracker, UsageScope};
@ -463,8 +463,7 @@ pub struct CommandBufferMutable {
pub(crate) pending_query_resets: QueryResetMap, pub(crate) pending_query_resets: QueryResetMap,
blas_actions: Vec<BlasAction>, as_actions: Vec<AsAction>,
tlas_actions: Vec<TlasAction>,
temp_resources: Vec<TempResource>, temp_resources: Vec<TempResource>,
indirect_draw_validation_resources: crate::indirect_validation::DrawResources, indirect_draw_validation_resources: crate::indirect_validation::DrawResources,
@ -553,8 +552,7 @@ impl CommandBuffer {
buffer_memory_init_actions: Default::default(), buffer_memory_init_actions: Default::default(),
texture_memory_actions: Default::default(), texture_memory_actions: Default::default(),
pending_query_resets: QueryResetMap::new(), pending_query_resets: QueryResetMap::new(),
blas_actions: Default::default(), as_actions: Default::default(),
tlas_actions: Default::default(),
temp_resources: Default::default(), temp_resources: Default::default(),
indirect_draw_validation_resources: indirect_draw_validation_resources:
crate::indirect_validation::DrawResources::new(device.clone()), crate::indirect_validation::DrawResources::new(device.clone()),

View File

@ -3,11 +3,13 @@ use core::{
cmp::max, cmp::max,
num::NonZeroU64, num::NonZeroU64,
ops::{Deref, Range}, ops::{Deref, Range},
sync::atomic::Ordering,
}; };
use wgt::{math::align_to, BufferUsages, BufferUses, Features}; use wgt::{math::align_to, BufferUsages, BufferUses, Features};
use crate::device::resource::CommandIndices;
use crate::lock::RwLockWriteGuard;
use crate::ray_tracing::{AsAction, AsBuild, TlasBuild, ValidateAsActionsError};
use crate::{ use crate::{
command::CommandBufferMutable, command::CommandBufferMutable,
device::queue::TempResource, device::queue::TempResource,
@ -16,16 +18,14 @@ use crate::{
id::CommandEncoderId, id::CommandEncoderId,
init_tracker::MemoryInitKind, init_tracker::MemoryInitKind,
ray_tracing::{ ray_tracing::{
BlasAction, BlasBuildEntry, BlasGeometries, BlasTriangleGeometry, BlasBuildEntry, BlasGeometries, BlasTriangleGeometry, BuildAccelerationStructureError,
BuildAccelerationStructureError, TlasAction, TlasBuildEntry, TlasInstance, TlasPackage, TlasBuildEntry, TlasInstance, TlasPackage, TraceBlasBuildEntry, TraceBlasGeometries,
TraceBlasBuildEntry, TraceBlasGeometries, TraceBlasTriangleGeometry, TraceTlasInstance, TraceBlasTriangleGeometry, TraceTlasInstance, TraceTlasPackage,
TraceTlasPackage, ValidateBlasActionsError, ValidateTlasActionsError,
}, },
resource::{AccelerationStructure, Blas, Buffer, Labeled, StagingBuffer, Tlas, Trackable}, resource::{AccelerationStructure, Blas, Buffer, Labeled, StagingBuffer, Tlas},
scratch::ScratchBuffer, scratch::ScratchBuffer,
snatch::SnatchGuard, snatch::SnatchGuard,
track::PendingTransition, track::PendingTransition,
FastHashSet,
}; };
use crate::id::{BlasId, TlasId}; use crate::id::{BlasId, TlasId};
@ -81,39 +81,28 @@ impl Global {
device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?; device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
let build_command_index = NonZeroU64::new( let mut build_command = AsBuild::default();
device
.last_acceleration_structure_build_command_index for blas in blas_ids {
.fetch_add(1, Ordering::Relaxed), let blas = hub.blas_s.get(*blas).get()?;
) build_command.blas_s_built.push(blas);
.unwrap(); }
for tlas in tlas_ids {
let tlas = hub.tlas_s.get(*tlas).get()?;
build_command.tlas_s_built.push(TlasBuild {
tlas,
dependencies: Vec::new(),
});
}
let mut cmd_buf_data = cmd_buf.data.lock(); let mut cmd_buf_data = cmd_buf.data.lock();
let mut cmd_buf_data_guard = cmd_buf_data.record()?; let mut cmd_buf_data_guard = cmd_buf_data.record()?;
let cmd_buf_data = &mut *cmd_buf_data_guard; let cmd_buf_data = &mut *cmd_buf_data_guard;
cmd_buf_data.blas_actions.reserve(blas_ids.len()); cmd_buf_data.as_actions.push(AsAction::Build(build_command));
cmd_buf_data.tlas_actions.reserve(tlas_ids.len()); cmd_buf_data_guard.mark_successful();
for blas in blas_ids {
let blas = hub.blas_s.get(*blas).get()?;
cmd_buf_data.blas_actions.push(BlasAction {
blas,
kind: crate::ray_tracing::BlasActionKind::Build(build_command_index),
});
}
for tlas in tlas_ids {
let tlas = hub.tlas_s.get(*tlas).get()?;
cmd_buf_data.tlas_actions.push(TlasAction {
tlas,
kind: crate::ray_tracing::TlasActionKind::Build {
build_index: build_command_index,
dependencies: Vec::new(),
},
});
}
Ok(()) Ok(())
} }
@ -139,12 +128,7 @@ impl Global {
device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?; device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
let build_command_index = NonZeroU64::new( let mut build_command = AsBuild::default();
device
.last_acceleration_structure_build_command_index
.fetch_add(1, Ordering::Relaxed),
)
.unwrap();
#[cfg(feature = "trace")] #[cfg(feature = "trace")]
let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
@ -227,7 +211,7 @@ impl Global {
iter_blas( iter_blas(
blas_iter, blas_iter,
cmd_buf_data, cmd_buf_data,
build_command_index, &mut build_command,
&mut buf_storage, &mut buf_storage,
hub, hub,
)?; )?;
@ -281,12 +265,9 @@ impl Global {
let tlas = hub.tlas_s.get(entry.tlas_id).get()?; let tlas = hub.tlas_s.get(entry.tlas_id).get()?;
cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone()); cmd_buf_data.trackers.tlas_s.insert_single(tlas.clone());
cmd_buf_data.tlas_actions.push(TlasAction { build_command.tlas_s_built.push(TlasBuild {
tlas: tlas.clone(), tlas: tlas.clone(),
kind: crate::ray_tracing::TlasActionKind::Build { dependencies: Vec::new(),
build_index: build_command_index,
dependencies: Vec::new(),
},
}); });
let scratch_buffer_offset = scratch_buffer_tlas_size; let scratch_buffer_offset = scratch_buffer_tlas_size;
@ -388,6 +369,8 @@ impl Global {
.temp_resources .temp_resources
.push(TempResource::ScratchBuffer(scratch_buffer)); .push(TempResource::ScratchBuffer(scratch_buffer));
cmd_buf_data.as_actions.push(AsAction::Build(build_command));
cmd_buf_data_guard.mark_successful(); cmd_buf_data_guard.mark_successful();
Ok(()) Ok(())
} }
@ -410,12 +393,7 @@ impl Global {
device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?; device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
let build_command_index = NonZeroU64::new( let mut build_command = AsBuild::default();
device
.last_acceleration_structure_build_command_index
.fetch_add(1, Ordering::Relaxed),
)
.unwrap();
let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter let trace_blas: Vec<TraceBlasBuildEntry> = blas_iter
.map(|blas_entry| { .map(|blas_entry| {
@ -523,7 +501,7 @@ impl Global {
iter_blas( iter_blas(
blas_iter, blas_iter,
cmd_buf_data, cmd_buf_data,
build_command_index, &mut build_command,
&mut buf_storage, &mut buf_storage,
hub, hub,
)?; )?;
@ -604,19 +582,11 @@ impl Global {
instance_count += 1; instance_count += 1;
dependencies.push(blas.clone()); dependencies.push(blas.clone());
cmd_buf_data.blas_actions.push(BlasAction {
blas,
kind: crate::ray_tracing::BlasActionKind::Use,
});
} }
cmd_buf_data.tlas_actions.push(TlasAction { build_command.tlas_s_built.push(TlasBuild {
tlas: tlas.clone(), tlas: tlas.clone(),
kind: crate::ray_tracing::TlasActionKind::Build { dependencies,
build_index: build_command_index,
dependencies,
},
}); });
if instance_count > tlas.max_instance_count { if instance_count > tlas.max_instance_count {
@ -800,72 +770,69 @@ impl Global {
.temp_resources .temp_resources
.push(TempResource::ScratchBuffer(scratch_buffer)); .push(TempResource::ScratchBuffer(scratch_buffer));
cmd_buf_data.as_actions.push(AsAction::Build(build_command));
cmd_buf_data_guard.mark_successful(); cmd_buf_data_guard.mark_successful();
Ok(()) Ok(())
} }
} }
impl CommandBufferMutable { impl CommandBufferMutable {
// makes sure a blas is build before it is used pub(crate) fn validate_acceleration_structure_actions(
pub(crate) fn validate_blas_actions(&self) -> Result<(), ValidateBlasActionsError> {
profiling::scope!("CommandEncoder::[submission]::validate_blas_actions");
let mut built = FastHashSet::default();
for action in &self.blas_actions {
match &action.kind {
crate::ray_tracing::BlasActionKind::Build(id) => {
built.insert(action.blas.tracker_index());
*action.blas.built_index.write() = Some(*id);
}
crate::ray_tracing::BlasActionKind::Use => {
if !built.contains(&action.blas.tracker_index())
&& (*action.blas.built_index.read()).is_none()
{
return Err(ValidateBlasActionsError::UsedUnbuilt(
action.blas.error_ident(),
));
}
}
}
}
Ok(())
}
// makes sure a tlas is built before it is used
pub(crate) fn validate_tlas_actions(
&self, &self,
snatch_guard: &SnatchGuard, snatch_guard: &SnatchGuard,
) -> Result<(), ValidateTlasActionsError> { command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
profiling::scope!("CommandEncoder::[submission]::validate_tlas_actions"); ) -> Result<(), ValidateAsActionsError> {
for action in &self.tlas_actions { profiling::scope!("CommandEncoder::[submission]::validate_as_actions");
match &action.kind { for action in &self.as_actions {
crate::ray_tracing::TlasActionKind::Build { match action {
build_index, AsAction::Build(build) => {
dependencies, let build_command_index = NonZeroU64::new(
} => { command_index_guard.next_acceleration_structure_build_command_index,
*action.tlas.built_index.write() = Some(*build_index); )
action.tlas.dependencies.write().clone_from(dependencies); .unwrap();
command_index_guard.next_acceleration_structure_build_command_index += 1;
for blas in build.blas_s_built.iter() {
*blas.built_index.write() = Some(build_command_index);
}
for tlas_build in build.tlas_s_built.iter() {
for blas in &tlas_build.dependencies {
if blas.built_index.read().is_none() {
return Err(ValidateAsActionsError::UsedUnbuiltBlas(
blas.error_ident(),
tlas_build.tlas.error_ident(),
));
}
}
*tlas_build.tlas.built_index.write() = Some(build_command_index);
tlas_build
.tlas
.dependencies
.write()
.clone_from(&tlas_build.dependencies)
}
} }
crate::ray_tracing::TlasActionKind::Use => { AsAction::UseTlas(tlas) => {
let tlas_build_index = action.tlas.built_index.read(); let tlas_build_index = tlas.built_index.read();
let dependencies = action.tlas.dependencies.read(); let dependencies = tlas.dependencies.read();
if (*tlas_build_index).is_none() { if (*tlas_build_index).is_none() {
return Err(ValidateTlasActionsError::UsedUnbuilt( return Err(ValidateAsActionsError::UsedUnbuiltTlas(tlas.error_ident()));
action.tlas.error_ident(),
));
} }
for blas in dependencies.deref() { for blas in dependencies.deref() {
let blas_build_index = *blas.built_index.read(); let blas_build_index = *blas.built_index.read();
if blas_build_index.is_none() { if blas_build_index.is_none() {
return Err(ValidateTlasActionsError::UsedUnbuiltBlas( return Err(ValidateAsActionsError::UsedUnbuiltBlas(
action.tlas.error_ident(), tlas.error_ident(),
blas.error_ident(), blas.error_ident(),
)); ));
} }
if blas_build_index.unwrap() > tlas_build_index.unwrap() { if blas_build_index.unwrap() > tlas_build_index.unwrap() {
return Err(ValidateTlasActionsError::BlasNewerThenTlas( return Err(ValidateAsActionsError::BlasNewerThenTlas(
blas.error_ident(), blas.error_ident(),
action.tlas.error_ident(), tlas.error_ident(),
)); ));
} }
blas.try_raw(snatch_guard)?; blas.try_raw(snatch_guard)?;
@ -881,7 +848,7 @@ impl CommandBufferMutable {
fn iter_blas<'a>( fn iter_blas<'a>(
blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>, blas_iter: impl Iterator<Item = BlasBuildEntry<'a>>,
cmd_buf_data: &mut CommandBufferMutable, cmd_buf_data: &mut CommandBufferMutable,
build_command_index: NonZeroU64, build_command: &mut AsBuild,
buf_storage: &mut Vec<TriangleBufferStore<'a>>, buf_storage: &mut Vec<TriangleBufferStore<'a>>,
hub: &Hub, hub: &Hub,
) -> Result<(), BuildAccelerationStructureError> { ) -> Result<(), BuildAccelerationStructureError> {
@ -890,10 +857,7 @@ fn iter_blas<'a>(
let blas = hub.blas_s.get(entry.blas_id).get()?; let blas = hub.blas_s.get(entry.blas_id).get()?;
cmd_buf_data.trackers.blas_s.insert_single(blas.clone()); cmd_buf_data.trackers.blas_s.insert_single(blas.clone());
cmd_buf_data.blas_actions.push(BlasAction { build_command.blas_s_built.push(blas.clone());
blas: blas.clone(),
kind: crate::ray_tracing::BlasActionKind::Build(build_command_index),
});
match entry.geometries { match entry.geometries {
BlasGeometries::TriangleGeometries(triangle_geometries) => { BlasGeometries::TriangleGeometries(triangle_geometries) => {

View File

@ -10,6 +10,7 @@ use smallvec::SmallVec;
use thiserror::Error; use thiserror::Error;
use super::{life::LifetimeTracker, Device}; use super::{life::LifetimeTracker, Device};
use crate::device::resource::CommandIndices;
#[cfg(feature = "trace")] #[cfg(feature = "trace")]
use crate::device::trace::Action; use crate::device::trace::Action;
use crate::scratch::ScratchBuffer; use crate::scratch::ScratchBuffer;
@ -447,9 +448,7 @@ pub enum QueueSubmitError {
#[error(transparent)] #[error(transparent)]
CommandEncoder(#[from] CommandEncoderError), CommandEncoder(#[from] CommandEncoderError),
#[error(transparent)] #[error(transparent)]
ValidateBlasActionsError(#[from] crate::ray_tracing::ValidateBlasActionsError), ValidateAsActionsError(#[from] crate::ray_tracing::ValidateAsActionsError),
#[error(transparent)]
ValidateTlasActionsError(#[from] crate::ray_tracing::ValidateTlasActionsError),
} }
//TODO: move out common parts of write_xxx. //TODO: move out common parts of write_xxx.
@ -1126,6 +1125,7 @@ impl Queue {
&snatch_guard, &snatch_guard,
&mut submit_surface_textures_owned, &mut submit_surface_textures_owned,
&mut used_surface_textures, &mut used_surface_textures,
&mut command_index_guard,
); );
if let Err(err) = res { if let Err(err) = res {
first_error.get_or_insert(err); first_error.get_or_insert(err);
@ -1518,6 +1518,7 @@ fn validate_command_buffer(
snatch_guard: &SnatchGuard, snatch_guard: &SnatchGuard,
submit_surface_textures_owned: &mut FastHashMap<*const Texture, Arc<Texture>>, submit_surface_textures_owned: &mut FastHashMap<*const Texture, Arc<Texture>>,
used_surface_textures: &mut track::TextureUsageScope, used_surface_textures: &mut track::TextureUsageScope,
command_index_guard: &mut RwLockWriteGuard<CommandIndices>,
) -> Result<(), QueueSubmitError> { ) -> Result<(), QueueSubmitError> {
command_buffer.same_device_as(queue)?; command_buffer.same_device_as(queue)?;
@ -1557,10 +1558,9 @@ fn validate_command_buffer(
} }
} }
if let Err(e) = cmd_buf_data.validate_blas_actions() { if let Err(e) =
return Err(e.into()); cmd_buf_data.validate_acceleration_structure_actions(snatch_guard, command_index_guard)
} {
if let Err(e) = cmd_buf_data.validate_tlas_actions(snatch_guard) {
return Err(e.into()); return Err(e.into());
} }
} }

View File

@ -71,6 +71,7 @@ pub(crate) struct CommandIndices {
/// ///
/// [`last_successful_submission_index`]: Device::last_successful_submission_index /// [`last_successful_submission_index`]: Device::last_successful_submission_index
pub(crate) active_submission_index: hal::FenceValue, pub(crate) active_submission_index: hal::FenceValue,
pub(crate) next_acceleration_structure_build_command_index: u64,
} }
/// Structure describing a logical device. Some members are internally mutable, /// Structure describing a logical device. Some members are internally mutable,
@ -133,7 +134,6 @@ pub struct Device {
pub(crate) instance_flags: wgt::InstanceFlags, pub(crate) instance_flags: wgt::InstanceFlags,
pub(crate) deferred_destroy: Mutex<Vec<DeferredDestroy>>, pub(crate) deferred_destroy: Mutex<Vec<DeferredDestroy>>,
pub(crate) usage_scopes: UsageScopePool, pub(crate) usage_scopes: UsageScopePool,
pub(crate) last_acceleration_structure_build_command_index: AtomicU64,
pub(crate) indirect_validation: Option<crate::indirect_validation::IndirectValidation>, pub(crate) indirect_validation: Option<crate::indirect_validation::IndirectValidation>,
// Optional so that we can late-initialize this after the queue is created. // Optional so that we can late-initialize this after the queue is created.
pub(crate) timestamp_normalizer: pub(crate) timestamp_normalizer:
@ -284,6 +284,8 @@ impl Device {
rank::DEVICE_COMMAND_INDICES, rank::DEVICE_COMMAND_INDICES,
CommandIndices { CommandIndices {
active_submission_index: 0, active_submission_index: 0,
// By starting at one, we can put the result in a NonZeroU64.
next_acceleration_structure_build_command_index: 1,
}, },
), ),
last_successful_submission_index: AtomicU64::new(0), last_successful_submission_index: AtomicU64::new(0),
@ -321,8 +323,6 @@ impl Device {
instance_flags, instance_flags,
deferred_destroy: Mutex::new(rank::DEVICE_DEFERRED_DESTROY, Vec::new()), deferred_destroy: Mutex::new(rank::DEVICE_DEFERRED_DESTROY, Vec::new()),
usage_scopes: Mutex::new(rank::DEVICE_USAGE_SCOPES, Default::default()), usage_scopes: Mutex::new(rank::DEVICE_USAGE_SCOPES, Default::default()),
// By starting at one, we can put the result in a NonZeroU64.
last_acceleration_structure_build_command_index: AtomicU64::new(1),
timestamp_normalizer: OnceCellOrLock::new(), timestamp_normalizer: OnceCellOrLock::new(),
indirect_validation, indirect_validation,
}) })

View File

@ -8,7 +8,6 @@
// - ([non performance] extract function in build (rust function extraction with guards is a pain)) // - ([non performance] extract function in build (rust function extraction with guards is a pain))
use alloc::{boxed::Box, sync::Arc, vec::Vec}; use alloc::{boxed::Box, sync::Arc, vec::Vec};
use core::num::NonZeroU64;
use thiserror::Error; use thiserror::Error;
use wgt::{AccelerationStructureGeometryFlags, BufferAddress, IndexFormat, VertexFormat}; use wgt::{AccelerationStructureGeometryFlags, BufferAddress, IndexFormat, VertexFormat};
@ -137,18 +136,12 @@ pub enum BuildAccelerationStructureError {
} }
#[derive(Clone, Debug, Error)] #[derive(Clone, Debug, Error)]
pub enum ValidateBlasActionsError { pub enum ValidateAsActionsError {
#[error("Blas {0:?} is used before it is built")]
UsedUnbuilt(ResourceErrorIdent),
}
#[derive(Clone, Debug, Error)]
pub enum ValidateTlasActionsError {
#[error(transparent)] #[error(transparent)]
DestroyedResource(#[from] DestroyedResourceError), DestroyedResource(#[from] DestroyedResourceError),
#[error("Tlas {0:?} is used before it is built")] #[error("Tlas {0:?} is used before it is built")]
UsedUnbuilt(ResourceErrorIdent), UsedUnbuiltTlas(ResourceErrorIdent),
#[error("Blas {0:?} is used before it is built (in Tlas {1:?})")] #[error("Blas {0:?} is used before it is built (in Tlas {1:?})")]
UsedUnbuiltBlas(ResourceErrorIdent, ResourceErrorIdent), UsedUnbuiltBlas(ResourceErrorIdent, ResourceErrorIdent),
@ -200,31 +193,22 @@ pub struct TlasPackage<'a> {
pub lowest_unmodified: u32, pub lowest_unmodified: u32,
} }
#[derive(Debug, Copy, Clone)]
pub(crate) enum BlasActionKind {
Build(NonZeroU64),
Use,
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) enum TlasActionKind { pub(crate) struct TlasBuild {
Build {
build_index: NonZeroU64,
dependencies: Vec<Arc<Blas>>,
},
Use,
}
#[derive(Debug, Clone)]
pub(crate) struct BlasAction {
pub blas: Arc<Blas>,
pub kind: BlasActionKind,
}
#[derive(Debug, Clone)]
pub(crate) struct TlasAction {
pub tlas: Arc<Tlas>, pub tlas: Arc<Tlas>,
pub kind: TlasActionKind, pub dependencies: Vec<Arc<Blas>>,
}
#[derive(Debug, Clone, Default)]
pub(crate) struct AsBuild {
pub blas_s_built: Vec<Arc<Blas>>,
pub tlas_s_built: Vec<TlasBuild>,
}
#[derive(Debug, Clone)]
pub(crate) enum AsAction {
Build(AsBuild),
UseTlas(Arc<Tlas>),
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]