memory init tracker check (previously is_initialized) clamps range now

and is O(log n)!
Andreas Reich 2021-01-30 20:17:15 +01:00
parent 31d292b169
commit 018ad05f56
4 changed files with 220 additions and 168 deletions
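
For readers skimming the diff: `MemoryInitTracker::is_initialized(&range) -> bool` is replaced by `check(range) -> Option<Range<_>>`, which returns the still-uninitialized subrange of the query, clamped to the query (the end may be over-approximated, see the comment and tests below). Call sites can therefore feed the result straight into `extend(...)` instead of branching and pushing an action by hand. Below is a minimal, self-contained sketch of that behaviour; it is an illustration of the idea, not the wgpu-core code: it uses plain `u64` instead of `wgt::BufferAddress` and the standard-library `partition_point` (which the diff mentions as a future replacement for the hand-rolled binary search).

use std::ops::Range;

// Simplified stand-in for wgpu-core's MemoryInitTracker (illustration only).
struct MemoryInitTracker {
    // Ordered, non-overlapping list of all uninitialized ranges.
    uninitialized_ranges: Vec<Range<u64>>,
}

impl MemoryInitTracker {
    // Returns the uninitialized subrange of `query`, if any, clamped to `query`.
    // As in the real implementation, the end is not tightened when more than one
    // uninitialized range intersects the query (this keeps the lookup O(log n)).
    fn check(&self, query: Range<u64>) -> Option<Range<u64>> {
        // Index of the first range whose end is greater than query.start.
        let index = self
            .uninitialized_ranges
            .partition_point(|r| r.end <= query.start);
        let first = self.uninitialized_ranges.get(index)?;
        if first.start >= query.end {
            return None; // No uninitialized data inside the query.
        }
        let start = first.start.max(query.start);
        let end = match self.uninitialized_ranges.get(index + 1) {
            // A second uninitialized range also intersects the query:
            // give up on a tight upper bound and return the query end.
            Some(next) if next.start < query.end => query.end,
            _ => first.end.min(query.end),
        };
        Some(start..end)
    }
}

fn main() {
    // State matching the "partially filled" test below: a 0..25 buffer where
    // 0..5, 10..15 and 20..25 have been cleared, leaving 5..10 and 15..20.
    let tracker = MemoryInitTracker {
        uninitialized_ranges: vec![5..10, 15..20],
    };
    assert_eq!(tracker.check(0..25), Some(5..25));
    assert_eq!(tracker.check(0..5), None);
    assert_eq!(tracker.check(3..8), Some(5..8));
    assert_eq!(tracker.check(8..22), Some(8..22)); // end not tightened to 20
    assert_eq!(tracker.check(17..22), Some(17..20));
    assert_eq!(tracker.check(20..25), None);
}

At the call sites this turns the old `if !buffer.initialization_status.is_initialized(&range) { ...push(MemoryInitTrackerAction { .. }) }` blocks into a single `extend(check(range).map(|range| MemoryInitTrackerAction { id, range, kind }))`, as seen throughout the command code below.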

View File

@@ -333,16 +333,19 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.map_pass_err(scope)?;
cmd_buf.buffer_memory_init_actions.extend(
bind_group
.used_buffer_ranges
.iter()
.filter(|action| match buffer_guard.get(action.id) {
Ok(buffer) => {
!buffer.initialization_status.is_initialized(&action.range)
}
Err(_) => false,
})
.cloned(),
bind_group.used_buffer_ranges.iter().filter_map(
|action| match buffer_guard.get(action.id) {
Ok(buffer) => buffer
.initialization_status
.check(action.range.clone())
.map(|range| MemoryInitTrackerAction {
id: action.id,
range,
kind: action.kind,
}),
Err(_) => None,
},
),
);
if let Some((pipeline_layout_id, follow_ups)) = state.binder.provide_entry(
@@ -531,19 +534,16 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let stride = 3 * 4; // 3 integers, x/y/z group size
let used_buffer_range = offset..(offset + stride);
if !indirect_buffer
.initialization_status
.is_initialized(&used_buffer_range)
{
cmd_buf
.buffer_memory_init_actions
.push(MemoryInitTrackerAction {
cmd_buf.buffer_memory_init_actions.extend(
indirect_buffer
.initialization_status
.check(offset..(offset + stride))
.map(|range| MemoryInitTrackerAction {
id: buffer_id,
range: used_buffer_range,
range,
kind: MemoryInitKind::NeedsInitializedMemory,
});
}
}),
);
state
.flush_states(

View File

@@ -1119,16 +1119,19 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.map_pass_err(scope)?;
cmd_buf.buffer_memory_init_actions.extend(
bind_group
.used_buffer_ranges
.iter()
.filter(|action| match buffer_guard.get(action.id) {
Ok(buffer) => {
!buffer.initialization_status.is_initialized(&action.range)
}
Err(_) => false,
})
.cloned(),
bind_group.used_buffer_ranges.iter().filter_map(|action| {
match buffer_guard.get(action.id) {
Ok(buffer) => buffer
.initialization_status
.check(action.range.clone())
.map(|range| MemoryInitTrackerAction {
id: action.id,
range,
kind: action.kind,
}),
Err(_) => None,
}
}),
);
if let Some((pipeline_layout_id, follow_ups)) = state.binder.provide_entry(
@@ -1307,15 +1310,16 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
state.index.format = Some(index_format);
state.index.update_limit();
if !buffer.initialization_status.is_initialized(&(offset..end)) {
cmd_buf
.buffer_memory_init_actions
.push(MemoryInitTrackerAction {
cmd_buf.buffer_memory_init_actions.extend(
buffer
.initialization_status
.check(offset..end)
.map(|range| MemoryInitTrackerAction {
id: buffer_id,
range: offset..end,
range,
kind: MemoryInitKind::NeedsInitializedMemory,
});
}
}),
);
let range = hal::buffer::SubRange {
offset,
@@ -1360,16 +1364,16 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
};
vertex_state.bound = true;
let used_range = offset..(offset + vertex_state.total_size);
if !buffer.initialization_status.is_initialized(&used_range) {
cmd_buf
.buffer_memory_init_actions
.push(MemoryInitTrackerAction {
cmd_buf.buffer_memory_init_actions.extend(
buffer
.initialization_status
.check(offset..(offset + vertex_state.total_size))
.map(|range| MemoryInitTrackerAction {
id: buffer_id,
range: used_range,
range,
kind: MemoryInitKind::NeedsInitializedMemory,
});
}
}),
);
let range = hal::buffer::SubRange {
offset,
@@ -1623,18 +1627,16 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.map_pass_err(scope);
}
if !indirect_buffer
.initialization_status
.is_initialized(&(offset..end_offset))
{
cmd_buf
.buffer_memory_init_actions
.push(MemoryInitTrackerAction {
cmd_buf.buffer_memory_init_actions.extend(
indirect_buffer
.initialization_status
.check(offset..end_offset)
.map(|range| MemoryInitTrackerAction {
id: buffer_id,
range: offset..end_offset,
range,
kind: MemoryInitKind::NeedsInitializedMemory,
});
}
}),
);
match indexed {
false => unsafe {
@@ -1719,18 +1721,16 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
})
.map_pass_err(scope);
}
if !indirect_buffer
.initialization_status
.is_initialized(&(offset..end_offset))
{
cmd_buf
.buffer_memory_init_actions
.push(MemoryInitTrackerAction {
cmd_buf.buffer_memory_init_actions.extend(
indirect_buffer
.initialization_status
.check(offset..end_offset)
.map(|range| MemoryInitTrackerAction {
id: buffer_id,
range: offset..end_offset,
range,
kind: MemoryInitKind::NeedsInitializedMemory,
});
}
}),
);
let begin_count_offset = count_buffer_offset;
let end_count_offset = count_buffer_offset + 4;
@@ -1742,18 +1742,16 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
})
.map_pass_err(scope);
}
if !count_buffer
.initialization_status
.is_initialized(&(count_buffer_offset..end_count_offset))
{
cmd_buf
.buffer_memory_init_actions
.push(MemoryInitTrackerAction {
cmd_buf.buffer_memory_init_actions.extend(
count_buffer
.initialization_status
.check(count_buffer_offset..end_count_offset)
.map(|range| MemoryInitTrackerAction {
id: count_buffer_id,
range: count_buffer_offset..end_count_offset,
range,
kind: MemoryInitKind::NeedsInitializedMemory,
});
}
}),
);
match indexed {
false => unsafe {
@@ -1891,13 +1889,17 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
bundle
.buffer_memory_init_actions
.iter()
.filter(|action| match buffer_guard.get(action.id) {
Ok(buffer) => {
!buffer.initialization_status.is_initialized(&action.range)
}
Err(_) => false,
})
.cloned(),
.filter_map(|action| match buffer_guard.get(action.id) {
Ok(buffer) => buffer
.initialization_status
.check(action.range.clone())
.map(|range| MemoryInitTrackerAction {
id: action.id,
range,
kind: action.kind,
}),
Err(_) => None,
}),
);
unsafe {

View File

@@ -404,32 +404,26 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}
// Make sure source is initialized memory and mark dest as initialized.
let used_dst_buffer_range = destination_offset..(destination_offset + size);
if !dst_buffer
.initialization_status
.is_initialized(&used_dst_buffer_range)
{
cmd_buf
.buffer_memory_init_actions
.push(MemoryInitTrackerAction {
cmd_buf.buffer_memory_init_actions.extend(
dst_buffer
.initialization_status
.check(destination_offset..(destination_offset + size))
.map(|range| MemoryInitTrackerAction {
id: destination,
range: used_dst_buffer_range,
range,
kind: MemoryInitKind::ImplicitlyInitialized,
});
}
let used_src_buffer_range = source_offset..(source_offset + size);
if !src_buffer
.initialization_status
.is_initialized(&used_src_buffer_range)
{
cmd_buf
.buffer_memory_init_actions
.push(MemoryInitTrackerAction {
}),
);
cmd_buf.buffer_memory_init_actions.extend(
src_buffer
.initialization_status
.check(source_offset..(source_offset + size))
.map(|range| MemoryInitTrackerAction {
id: source,
range: used_src_buffer_range,
range,
kind: MemoryInitKind::NeedsInitializedMemory,
});
}
}),
);
let region = hal::command::BufferCopy {
src: source_offset,
@@ -544,20 +538,16 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
copy_size,
)?;
let used_src_buffer_range =
source.layout.offset..(source.layout.offset + required_buffer_bytes_in_copy);
if !src_buffer
.initialization_status
.is_initialized(&used_src_buffer_range)
{
cmd_buf
.buffer_memory_init_actions
.push(MemoryInitTrackerAction {
cmd_buf.buffer_memory_init_actions.extend(
src_buffer
.initialization_status
.check(source.layout.offset..(source.layout.offset + required_buffer_bytes_in_copy))
.map(|range| MemoryInitTrackerAction {
id: source.buffer,
range: used_src_buffer_range,
range,
kind: MemoryInitKind::NeedsInitializedMemory,
});
}
}),
);
let (block_width, _) = dst_texture.format.describe().block_dimensions;
if !conv::is_valid_copy_dst_texture_format(dst_texture.format) {
@@ -706,20 +696,20 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
))?
}
let used_dst_buffer_range =
destination.layout.offset..(destination.layout.offset + required_buffer_bytes_in_copy);
if !dst_buffer
.initialization_status
.is_initialized(&used_dst_buffer_range)
{
cmd_buf
.buffer_memory_init_actions
.push(MemoryInitTrackerAction {
cmd_buf.buffer_memory_init_actions.extend(
dst_buffer
.initialization_status
.check(
destination.layout.offset
..(destination.layout.offset + required_buffer_bytes_in_copy),
)
.map(|range| MemoryInitTrackerAction {
id: destination.buffer,
range: used_dst_buffer_range,
range,
kind: MemoryInitKind::ImplicitlyInitialized,
});
}
}),
);
// WebGPU uses the physical size of the texture for copies whereas vulkan uses
// the virtual size. We have passed validation, so it's safe to use the
// image extent data directly. We want the provided copy size to be no larger than

View File

@@ -1,6 +1,6 @@
use std::ops::Range;
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Copy)]
pub(crate) enum MemoryInitKind {
// The memory range is going to be written by an already initialized source, thus doesn't need extra attention other than marking as initialized.
ImplicitlyInitialized,
@@ -18,6 +18,7 @@ pub(crate) struct MemoryInitTrackerAction<ResourceId> {
/// Tracks initialization status of a linear range from 0..size
#[derive(Debug)]
pub(crate) struct MemoryInitTracker {
// Ordered, non-overlapping list of all uninitialized ranges.
uninitialized_ranges: Vec<Range<wgt::BufferAddress>>,
}
@@ -80,15 +81,59 @@ impl MemoryInitTracker {
}
}
pub(crate) fn is_initialized(&self, query_range: &Range<wgt::BufferAddress>) -> bool {
match self
.uninitialized_ranges
.iter()
.find(|r| r.end > query_range.start)
{
Some(r) => r.start >= query_range.end,
None => true,
// Finds the index of the first range whose end is bigger than `bound`, in O(log n) (with n being the number of uninitialized ranges)
fn lower_bound(&self, bound: wgt::BufferAddress) -> usize {
// This is equivalent to the following, except that it may return an out-of-bounds index instead of None:
// self.uninitialized_ranges.iter().position(|r| r.end > bound)
// In future Rust versions this operation can be done with partition_point.
// See https://github.com/rust-lang/rust/pull/73577/
let mut left = 0;
let mut right = self.uninitialized_ranges.len();
while left != right {
let mid = left + (right - left) / 2;
let value = unsafe { self.uninitialized_ranges.get_unchecked(mid) };
if value.end <= bound {
left = mid + 1;
} else {
right = mid;
}
}
left
}
// Checks whether there are any uninitialized ranges within the query range.
// If there are, the returned range is the subrange of query_range that contains all of these uninitialized regions.
// The returned range may be larger than necessary (a tradeoff for keeping this function O(log n)).
pub(crate) fn check(
&self,
query_range: Range<wgt::BufferAddress>,
) -> Option<Range<wgt::BufferAddress>> {
let index = self.lower_bound(query_range.start);
self.uninitialized_ranges
.get(index)
.map(|start_range| {
if start_range.start < query_range.end {
let start = start_range.start.max(query_range.start);
match self.uninitialized_ranges.get(index + 1) {
Some(next_range) => {
if next_range.start < query_range.end {
// Would need to keep iterating for a more accurate upper bound. Don't do that here.
Some(start..query_range.end)
} else {
Some(start..start_range.end.min(query_range.end))
}
}
None => Some(start..start_range.end.min(query_range.end)),
}
} else {
None
}
})
.flatten()
}
// Drains uninitialized ranges in a query range.
@@ -97,22 +142,16 @@ impl MemoryInitTracker {
&'a mut self,
drain_range: Range<wgt::BufferAddress>,
) -> MemoryInitTrackerDrain<'a> {
let next_index = self
.uninitialized_ranges
.iter()
.position(|r| r.end > drain_range.start)
.unwrap_or(std::usize::MAX);
MemoryInitTrackerDrain {
next_index,
next_index: self.lower_bound(drain_range.start),
drain_range,
uninitialized_ranges: &mut self.uninitialized_ranges,
}
}
// Clears uninitialized ranges in a query range.
pub(crate) fn clear(&mut self, drain_range: Range<wgt::BufferAddress>) {
self.drain(drain_range).for_each(drop);
pub(crate) fn clear(&mut self, range: Range<wgt::BufferAddress>) {
self.drain(range).for_each(drop);
}
}
@@ -122,36 +161,57 @@ mod test {
use std::ops::Range;
#[test]
fn is_initialized_for_empty_tracker() {
fn check_for_newly_created_tracker() {
let tracker = MemoryInitTracker::new(10);
assert!(!tracker.is_initialized(&(0..10)));
assert!(!tracker.is_initialized(&(0..3)));
assert!(!tracker.is_initialized(&(3..4)));
assert!(!tracker.is_initialized(&(4..10)));
assert_eq!(tracker.check(0..10), Some(0..10));
assert_eq!(tracker.check(0..3), Some(0..3));
assert_eq!(tracker.check(3..4), Some(3..4));
assert_eq!(tracker.check(4..10), Some(4..10));
}
#[test]
fn is_initialized_for_filled_tracker() {
fn check_for_cleared_tracker() {
let mut tracker = MemoryInitTracker::new(10);
tracker.clear(0..10);
assert!(tracker.is_initialized(&(0..10)));
assert!(tracker.is_initialized(&(0..3)));
assert!(tracker.is_initialized(&(3..4)));
assert!(tracker.is_initialized(&(4..10)));
assert_eq!(tracker.check(0..10), None);
assert_eq!(tracker.check(0..3), None);
assert_eq!(tracker.check(3..4), None);
assert_eq!(tracker.check(4..10), None);
}
#[test]
fn is_initialized_for_partially_filled_tracker() {
let mut tracker = MemoryInitTracker::new(10);
tracker.clear(4..6);
assert!(!tracker.is_initialized(&(0..10))); // entire range
assert!(!tracker.is_initialized(&(0..4))); // left non-overlapping
assert!(!tracker.is_initialized(&(3..5))); // left overlapping
assert!(tracker.is_initialized(&(4..6))); // entire initialized range
assert!(tracker.is_initialized(&(4..5))); // left part
assert!(tracker.is_initialized(&(5..6))); // right part
assert!(!tracker.is_initialized(&(5..7))); // right overlapping
assert!(!tracker.is_initialized(&(7..10))); // right non-overlapping
fn check_for_partially_filled_tracker() {
let mut tracker = MemoryInitTracker::new(25);
// Clear three regions, leaving two uninitialized gaps (5..10 and 15..20)
tracker.clear(0..5);
tracker.clear(10..15);
tracker.clear(20..25);
assert_eq!(tracker.check(0..25), Some(5..25)); // entire range
assert_eq!(tracker.check(0..5), None); // left non-overlapping
assert_eq!(tracker.check(3..8), Some(5..8)); // left overlapping region
assert_eq!(tracker.check(3..17), Some(5..17)); // left overlapping region + contained region
assert_eq!(tracker.check(8..22), Some(8..22)); // right overlapping region + contained region (yes, the range end is not tightened!)
assert_eq!(tracker.check(17..22), Some(17..20)); // right overlapping region
assert_eq!(tracker.check(20..25), None); // right non-overlapping
}
#[test]
fn clear_already_cleared() {
let mut tracker = MemoryInitTracker::new(30);
tracker.clear(10..20);
// Overlapping with non-cleared
tracker.clear(5..15); // Left overlap
tracker.clear(15..25); // Right overlap
tracker.clear(0..30); // Inner overlap
// Clear fully cleared
tracker.clear(0..30);
assert_eq!(tracker.check(0..30), None);
}
#[test]