MemoryInitTracker tests and interface adjustments

This commit is contained in:
Andreas Reich 2021-01-26 22:17:04 +01:00
parent 492027fe6e
commit 32b4e32ac6
3 changed files with 112 additions and 42 deletions

View File

@ -208,23 +208,20 @@ fn map_buffer<B: hal::Backend>(
//
// If this is a write mapping zeroing out the memory here is the only reasonable way as all data is pushed to GPU anyways.
let zero_init_needs_flush_now = !block.is_coherent() && buffer.sync_mapped_writes.is_none(); // No need to flush if it is flushed later anyways.
if let Some(uninitialized_ranges) = buffer
for uninitialized_range in buffer
.initialization_status
.drain_uninitialized_ranges(&(offset..(size + offset)))
{
for uninitialized_range in uninitialized_ranges {
let num_bytes = uninitialized_range.end - uninitialized_range.start;
unsafe {
ptr::write_bytes(
ptr.as_ptr().offset(uninitialized_range.start as isize),
0,
num_bytes as usize,
)
};
if zero_init_needs_flush_now {
block.flush_range(raw, uninitialized_range.start, Some(num_bytes))?;
}
let num_bytes = uninitialized_range.end - uninitialized_range.start;
unsafe {
ptr::write_bytes(
ptr.as_ptr().offset(uninitialized_range.start as isize),
0,
num_bytes as usize,
)
};
if zero_init_needs_flush_now {
block.flush_range(raw, uninitialized_range.start, Some(num_bytes))?;
}
}
@ -2612,12 +2609,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
buffer
.initialization_status
.drain_uninitialized_ranges(&(0..buffer.size))
.unwrap()
.for_each(drop);
stage
.initialization_status
.drain_uninitialized_ranges(&(0..buffer.size))
.unwrap()
.for_each(drop);
buffer.map_state = resource::BufferMapState::Init {

View File

@ -275,12 +275,9 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
// Ensure the overwritten bytes are marked as initialized so they don't need to be nulled prior to mapping or binding.
{
let dst = buffer_guard.get_mut(buffer_id).unwrap();
if let Some(uninitialized_ranges) = dst
.initialization_status
dst.initialization_status
.drain_uninitialized_ranges(&(buffer_offset..(buffer_offset + data_size)))
{
uninitialized_ranges.for_each(drop);
}
.for_each(drop);
}
Ok(())
@ -503,20 +500,18 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.get_mut(buffer_use.id)
.map_err(|_| QueueSubmitError::DestroyedBuffer(buffer_use.id))?;
if let Some(uninitialized_ranges) = buffer
let uninitialized_ranges = buffer
.initialization_status
.drain_uninitialized_ranges(&buffer_use.range)
{
match buffer_use.kind {
MemoryInitKind::ImplicitlyInitialized => {
uninitialized_ranges.for_each(drop);
}
MemoryInitKind::NeedsInitializedMemory => {
required_buffer_inits
.entry(buffer_use.id)
.or_default()
.extend(uninitialized_ranges);
}
.drain_uninitialized_ranges(&buffer_use.range);
match buffer_use.kind {
MemoryInitKind::ImplicitlyInitialized => {
uninitialized_ranges.for_each(drop);
}
MemoryInitKind::NeedsInitializedMemory => {
required_buffer_inits
.entry(buffer_use.id)
.or_default()
.extend(uninitialized_ranges);
}
}
}

View File

@ -44,7 +44,7 @@ impl MemoryInitTracker {
pub(crate) fn drain_uninitialized_ranges<'a>(
&'a mut self,
range: &Range<wgt::BufferAddress>,
) -> Option<impl Iterator<Item = Range<wgt::BufferAddress>> + 'a> {
) -> impl Iterator<Item = Range<wgt::BufferAddress>> + 'a {
let mut uninitialized_ranges: Vec<Range<wgt::BufferAddress>> = self
.uninitialized_ranges
.allocated_ranges()
@ -60,11 +60,7 @@ impl MemoryInitTracker {
})
.collect();
if uninitialized_ranges.is_empty() {
return None;
}
Some(std::iter::from_fn(move || {
std::iter::from_fn(move || {
let range: Option<Range<wgt::BufferAddress>> =
uninitialized_ranges.last().map(|r| r.clone());
match range {
@ -76,8 +72,92 @@ impl MemoryInitTracker {
}
None => None,
}
}))
})
}
}
// TODO: Add some unit tests for this construct
#[cfg(test)]
mod test {
use std::ops::Range;
use super::MemoryInitTracker;
#[test]
fn is_initialized_for_empty_tracker() {
let tracker = MemoryInitTracker::new(10);
assert!(!tracker.is_initialized(&(0..10)));
assert!(!tracker.is_initialized(&(0..3)));
assert!(!tracker.is_initialized(&(3..4)));
assert!(!tracker.is_initialized(&(4..10)));
}
#[test]
fn is_initialized_for_filled_tracker() {
let mut tracker = MemoryInitTracker::new(10);
tracker.drain_uninitialized_ranges(&(0..10)).for_each(drop);
assert!(tracker.is_initialized(&(0..10)));
assert!(tracker.is_initialized(&(0..3)));
assert!(tracker.is_initialized(&(3..4)));
assert!(tracker.is_initialized(&(4..10)));
}
#[test]
fn is_initialized_for_partially_filled_tracker() {
let mut tracker = MemoryInitTracker::new(10);
tracker.drain_uninitialized_ranges(&(4..6)).for_each(drop);
assert!(!tracker.is_initialized(&(0..10))); // entire range
assert!(!tracker.is_initialized(&(0..4))); // left non-overlapping
assert!(!tracker.is_initialized(&(3..5))); // left overlapping
assert!(tracker.is_initialized(&(4..6))); // entire initialized range
assert!(tracker.is_initialized(&(4..5))); // left part
assert!(tracker.is_initialized(&(5..6))); // right part
assert!(!tracker.is_initialized(&(5..7))); // right overlapping
assert!(!tracker.is_initialized(&(7..10))); // right non-overlapping
}
#[test]
fn drain_uninitialized_ranges_never_returns_ranges_twice_for_same_range() {
let mut tracker = MemoryInitTracker::new(19);
assert_eq!(tracker.drain_uninitialized_ranges(&(0..19)).count(), 1);
assert_eq!(tracker.drain_uninitialized_ranges(&(0..19)).count(), 0);
let mut tracker = MemoryInitTracker::new(17);
assert_eq!(tracker.drain_uninitialized_ranges(&(5..8)).count(), 1);
assert_eq!(tracker.drain_uninitialized_ranges(&(5..8)).count(), 0);
assert_eq!(tracker.drain_uninitialized_ranges(&(1..3)).count(), 1);
assert_eq!(tracker.drain_uninitialized_ranges(&(1..3)).count(), 0);
assert_eq!(tracker.drain_uninitialized_ranges(&(7..13)).count(), 1);
assert_eq!(tracker.drain_uninitialized_ranges(&(7..13)).count(), 0);
}
#[test]
fn drain_uninitialized_ranges_splits_ranges_correctly() {
let mut tracker = MemoryInitTracker::new(1337);
assert_eq!(
tracker
.drain_uninitialized_ranges(&(21..42))
.collect::<Vec<Range<wgt::BufferAddress>>>(),
vec![21..42]
);
assert_eq!(
tracker
.drain_uninitialized_ranges(&(900..1000))
.collect::<Vec<Range<wgt::BufferAddress>>>(),
vec![900..1000]
);
// Splitted ranges.
assert_eq!(
tracker
.drain_uninitialized_ranges(&(5..1003))
.collect::<Vec<Range<wgt::BufferAddress>>>(),
vec![1000..1003, 42..900, 5..21]
);
assert_eq!(
tracker
.drain_uninitialized_ranges(&(0..1337))
.collect::<Vec<Range<wgt::BufferAddress>>>(),
vec![1003..1337, 0..5]
);
}
}