[spirv] Stop naga causing undefined behavior in rayQueryGet*Intersection (#6752)

Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
This commit is contained in:
Vecvec 2025-01-17 00:04:33 +00:00 committed by GitHub
parent d5d5157b5d
commit bdef8c0407
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 914 additions and 376 deletions

View File

@ -40,6 +40,12 @@ Bottom level categories:
## Unreleased
### Bug Fixes
#### Vulkan
- Stop naga causing undefined behavior when a ray query misses. By @Vecvec in [#6752](https://github.com/gfx-rs/wgpu/pull/6752).
### Changes
#### Refactored internal trace path parameter

View File

@ -1736,7 +1736,20 @@ impl BlockContext<'_> {
}
crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
crate::Expression::RayQueryGetIntersection { query, committed } => {
self.write_ray_query_get_intersection(query, block, committed)
let query_id = self.cached[query];
let func_id = self
.writer
.write_ray_query_get_intersection_function(committed, self.ir_module);
let ray_intersection = self.ir_module.special_types.ray_intersection.unwrap();
let intersection_type_id = self.get_type_id(LookupType::Handle(ray_intersection));
let id = self.gen_id();
block.body.push(Instruction::function_call(
intersection_type_id,
id,
func_id,
&[query_id],
));
id
}
};

View File

@ -766,6 +766,8 @@ pub struct Writer {
// Just a temporary list of SPIR-V ids
temp_list: Vec<Word>,
ray_get_intersection_function: Option<Word>,
}
bitflags::bitflags! {

View File

@ -2,8 +2,519 @@
Generating SPIR-V for ray query operations.
*/
use super::{Block, BlockContext, Instruction, LocalType, LookupType, NumericType};
use super::{
Block, BlockContext, Function, FunctionArgument, Instruction, LocalType, LookupFunctionType,
LookupType, NumericType, Writer,
};
use crate::arena::Handle;
use crate::{Type, TypeInner};
impl Writer {
pub(super) fn write_ray_query_get_intersection_function(
&mut self,
is_committed: bool,
ir_module: &crate::Module,
) -> spirv::Word {
if let Some(func_id) = self.ray_get_intersection_function {
return func_id;
}
let ray_intersection = ir_module.special_types.ray_intersection.unwrap();
let intersection_type_id = self.get_type_id(LookupType::Handle(ray_intersection));
let intersection_pointer_type_id =
self.get_type_id(LookupType::Local(LocalType::Pointer {
base: ray_intersection,
class: spirv::StorageClass::Function,
}));
let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::U32),
)));
let flag_type = ir_module
.types
.get(&Type {
name: None,
inner: TypeInner::Scalar(crate::Scalar::U32),
})
.unwrap();
let flag_pointer_type_id = self.get_type_id(LookupType::Local(LocalType::Pointer {
base: flag_type,
class: spirv::StorageClass::Function,
}));
let transform_type_id =
self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Matrix {
columns: crate::VectorSize::Quad,
rows: crate::VectorSize::Tri,
scalar: crate::Scalar::F32,
})));
let transform_type = ir_module
.types
.get(&Type {
name: None,
inner: TypeInner::Matrix {
columns: crate::VectorSize::Quad,
rows: crate::VectorSize::Tri,
scalar: crate::Scalar::F32,
},
})
.unwrap();
let transform_pointer_type_id = self.get_type_id(LookupType::Local(LocalType::Pointer {
base: transform_type,
class: spirv::StorageClass::Function,
}));
let barycentrics_type_id =
self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
size: crate::VectorSize::Bi,
scalar: crate::Scalar::F32,
})));
let barycentrics_type = ir_module
.types
.get(&Type {
name: None,
inner: TypeInner::Vector {
size: crate::VectorSize::Bi,
scalar: crate::Scalar::F32,
},
})
.unwrap();
let barycentrics_pointer_type_id =
self.get_type_id(LookupType::Local(LocalType::Pointer {
base: barycentrics_type,
class: spirv::StorageClass::Function,
}));
let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::BOOL),
)));
let bool_type = ir_module
.types
.get(&Type {
name: None,
inner: TypeInner::Scalar(crate::Scalar::BOOL),
})
.unwrap();
let bool_pointer_type_id = self.get_type_id(LookupType::Local(LocalType::Pointer {
base: bool_type,
class: spirv::StorageClass::Function,
}));
let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::F32),
)));
let float_type = ir_module
.types
.get(&Type {
name: None,
inner: TypeInner::Scalar(crate::Scalar::F32),
})
.unwrap();
let float_pointer_type_id = self.get_type_id(LookupType::Local(LocalType::Pointer {
base: float_type,
class: spirv::StorageClass::Function,
}));
let rq_ty = ir_module
.types
.get(&Type {
name: None,
inner: TypeInner::RayQuery,
})
.expect("ray_query type should have been populated by the variable passed into this!");
let argument_type_id = self.get_type_id(LookupType::Local(LocalType::Pointer {
base: rq_ty,
class: spirv::StorageClass::Function,
}));
let func_ty = self.get_function_type(LookupFunctionType {
parameter_type_ids: vec![argument_type_id],
return_type_id: intersection_type_id,
});
let mut function = Function::default();
let func_id = self.id_gen.next();
function.signature = Some(Instruction::function(
intersection_type_id,
func_id,
spirv::FunctionControl::empty(),
func_ty,
));
let blank_intersection = self.get_constant_null(intersection_type_id);
let query_id = self.id_gen.next();
let instruction = Instruction::function_parameter(argument_type_id, query_id);
function.parameters.push(FunctionArgument {
instruction,
handle_id: 0,
});
let label_id = self.id_gen.next();
let mut block = Block::new(label_id);
let blank_intersection_id = self.id_gen.next();
block.body.push(Instruction::variable(
intersection_pointer_type_id,
blank_intersection_id,
spirv::StorageClass::Function,
Some(blank_intersection),
));
let intersection_id = self.get_constant_scalar(crate::Literal::U32(if is_committed {
spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
} else {
spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
} as _));
let raw_kind_id = self.id_gen.next();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionTypeKHR,
flag_type_id,
raw_kind_id,
query_id,
intersection_id,
));
let kind_id = if is_committed {
// Nothing to do: the IR value matches `spirv::RayQueryCommittedIntersectionType`
raw_kind_id
} else {
// Remap from the candidate kind to IR
let condition_id = self.id_gen.next();
let committed_triangle_kind_id = self.get_constant_scalar(crate::Literal::U32(
spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
as _,
));
block.body.push(Instruction::binary(
spirv::Op::IEqual,
self.get_bool_type_id(),
condition_id,
raw_kind_id,
committed_triangle_kind_id,
));
let kind_id = self.id_gen.next();
block.body.push(Instruction::select(
flag_type_id,
kind_id,
condition_id,
self.get_constant_scalar(crate::Literal::U32(
crate::RayQueryIntersection::Triangle as _,
)),
self.get_constant_scalar(crate::Literal::U32(
crate::RayQueryIntersection::Aabb as _,
)),
));
kind_id
};
let idx_id = self.get_index_constant(0);
let access_idx = self.id_gen.next();
block.body.push(Instruction::access_chain(
flag_pointer_type_id,
access_idx,
blank_intersection_id,
&[idx_id],
));
block
.body
.push(Instruction::store(access_idx, kind_id, None));
let not_none_comp_id = self.id_gen.next();
let none_id =
self.get_constant_scalar(crate::Literal::U32(crate::RayQueryIntersection::None as _));
block.body.push(Instruction::binary(
spirv::Op::INotEqual,
self.get_bool_type_id(),
not_none_comp_id,
kind_id,
none_id,
));
let not_none_label_id = self.id_gen.next();
let mut not_none_block = Block::new(not_none_label_id);
let final_label_id = self.id_gen.next();
let mut final_block = Block::new(final_label_id);
block.body.push(Instruction::selection_merge(
final_label_id,
spirv::SelectionControl::NONE,
));
function.consume(
block,
Instruction::branch_conditional(not_none_comp_id, not_none_label_id, final_label_id),
);
let instance_custom_index_id = self.id_gen.next();
not_none_block
.body
.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
flag_type_id,
instance_custom_index_id,
query_id,
intersection_id,
));
let instance_id = self.id_gen.next();
not_none_block
.body
.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionInstanceIdKHR,
flag_type_id,
instance_id,
query_id,
intersection_id,
));
let sbt_record_offset_id = self.id_gen.next();
not_none_block
.body
.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR,
flag_type_id,
sbt_record_offset_id,
query_id,
intersection_id,
));
let geometry_index_id = self.id_gen.next();
not_none_block
.body
.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionGeometryIndexKHR,
flag_type_id,
geometry_index_id,
query_id,
intersection_id,
));
let primitive_index_id = self.id_gen.next();
not_none_block
.body
.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR,
flag_type_id,
primitive_index_id,
query_id,
intersection_id,
));
//Note: there is also `OpRayQueryGetIntersectionCandidateAABBOpaqueKHR`,
// but it's not a property of an intersection.
let object_to_world_id = self.id_gen.next();
not_none_block
.body
.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,
transform_type_id,
object_to_world_id,
query_id,
intersection_id,
));
let world_to_object_id = self.id_gen.next();
not_none_block
.body
.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionWorldToObjectKHR,
transform_type_id,
world_to_object_id,
query_id,
intersection_id,
));
// instance custom index
let idx_id = self.get_index_constant(2);
let access_idx = self.id_gen.next();
not_none_block.body.push(Instruction::access_chain(
flag_pointer_type_id,
access_idx,
blank_intersection_id,
&[idx_id],
));
not_none_block.body.push(Instruction::store(
access_idx,
instance_custom_index_id,
None,
));
// instance
let idx_id = self.get_index_constant(3);
let access_idx = self.id_gen.next();
not_none_block.body.push(Instruction::access_chain(
flag_pointer_type_id,
access_idx,
blank_intersection_id,
&[idx_id],
));
not_none_block
.body
.push(Instruction::store(access_idx, instance_id, None));
let idx_id = self.get_index_constant(4);
let access_idx = self.id_gen.next();
not_none_block.body.push(Instruction::access_chain(
flag_pointer_type_id,
access_idx,
blank_intersection_id,
&[idx_id],
));
not_none_block
.body
.push(Instruction::store(access_idx, sbt_record_offset_id, None));
let idx_id = self.get_index_constant(5);
let access_idx = self.id_gen.next();
not_none_block.body.push(Instruction::access_chain(
flag_pointer_type_id,
access_idx,
blank_intersection_id,
&[idx_id],
));
not_none_block
.body
.push(Instruction::store(access_idx, geometry_index_id, None));
let idx_id = self.get_index_constant(6);
let access_idx = self.id_gen.next();
not_none_block.body.push(Instruction::access_chain(
flag_pointer_type_id,
access_idx,
blank_intersection_id,
&[idx_id],
));
not_none_block
.body
.push(Instruction::store(access_idx, primitive_index_id, None));
let idx_id = self.get_index_constant(9);
let access_idx = self.id_gen.next();
not_none_block.body.push(Instruction::access_chain(
transform_pointer_type_id,
access_idx,
blank_intersection_id,
&[idx_id],
));
not_none_block
.body
.push(Instruction::store(access_idx, object_to_world_id, None));
let idx_id = self.get_index_constant(10);
let access_idx = self.id_gen.next();
not_none_block.body.push(Instruction::access_chain(
transform_pointer_type_id,
access_idx,
blank_intersection_id,
&[idx_id],
));
not_none_block
.body
.push(Instruction::store(access_idx, world_to_object_id, None));
let tri_comp_id = self.id_gen.next();
let tri_id = self.get_constant_scalar(crate::Literal::U32(
crate::RayQueryIntersection::Triangle as _,
));
not_none_block.body.push(Instruction::binary(
spirv::Op::IEqual,
self.get_bool_type_id(),
tri_comp_id,
kind_id,
tri_id,
));
let tri_label_id = self.id_gen.next();
let mut tri_block = Block::new(tri_label_id);
let merge_label_id = self.id_gen.next();
let merge_block = Block::new(merge_label_id);
// t
{
let block = if is_committed {
&mut not_none_block
} else {
&mut tri_block
};
let t_id = self.id_gen.next();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionTKHR,
scalar_type_id,
t_id,
query_id,
intersection_id,
));
let idx_id = self.get_index_constant(1);
let access_idx = self.id_gen.next();
block.body.push(Instruction::access_chain(
float_pointer_type_id,
access_idx,
blank_intersection_id,
&[idx_id],
));
block.body.push(Instruction::store(access_idx, t_id, None));
}
not_none_block.body.push(Instruction::selection_merge(
merge_label_id,
spirv::SelectionControl::NONE,
));
function.consume(
not_none_block,
Instruction::branch_conditional(not_none_comp_id, tri_label_id, merge_label_id),
);
let barycentrics_id = self.id_gen.next();
tri_block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
barycentrics_type_id,
barycentrics_id,
query_id,
intersection_id,
));
let front_face_id = self.id_gen.next();
tri_block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
bool_type_id,
front_face_id,
query_id,
intersection_id,
));
let idx_id = self.get_index_constant(7);
let access_idx = self.id_gen.next();
tri_block.body.push(Instruction::access_chain(
barycentrics_pointer_type_id,
access_idx,
blank_intersection_id,
&[idx_id],
));
tri_block
.body
.push(Instruction::store(access_idx, barycentrics_id, None));
let idx_id = self.get_index_constant(8);
let access_idx = self.id_gen.next();
tri_block.body.push(Instruction::access_chain(
bool_pointer_type_id,
access_idx,
blank_intersection_id,
&[idx_id],
));
tri_block
.body
.push(Instruction::store(access_idx, front_face_id, None));
function.consume(tri_block, Instruction::branch(merge_label_id));
function.consume(merge_block, Instruction::branch(final_label_id));
let loaded_blank_intersection_id = self.id_gen.next();
final_block.body.push(Instruction::load(
intersection_type_id,
loaded_blank_intersection_id,
blank_intersection_id,
None,
));
function.consume(
final_block,
Instruction::return_value(loaded_blank_intersection_id),
);
function.to_words(&mut self.logical_layout.function_definitions);
Instruction::function_end().to_words(&mut self.logical_layout.function_definitions);
self.ray_get_intersection_function = Some(func_id);
func_id
}
}
impl BlockContext<'_> {
pub(super) fn write_ray_query_function(
@ -101,191 +612,4 @@ impl BlockContext<'_> {
crate::RayQueryFunction::Terminate => {}
}
}
pub(super) fn write_ray_query_get_intersection(
&mut self,
query: Handle<crate::Expression>,
block: &mut Block,
is_committed: bool,
) -> spirv::Word {
let query_id = self.cached[query];
let intersection_id =
self.writer
.get_constant_scalar(crate::Literal::U32(if is_committed {
spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR
} else {
spirv::RayQueryIntersection::RayQueryCandidateIntersectionKHR
} as _));
let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::U32),
)));
let raw_kind_id = self.gen_id();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionTypeKHR,
flag_type_id,
raw_kind_id,
query_id,
intersection_id,
));
let kind_id = if is_committed {
// Nothing to do: the IR value matches `spirv::RayQueryCommittedIntersectionType`
raw_kind_id
} else {
// Remap from the candidate kind to IR
let condition_id = self.gen_id();
let committed_triangle_kind_id = self.writer.get_constant_scalar(crate::Literal::U32(
spirv::RayQueryCandidateIntersectionType::RayQueryCandidateIntersectionTriangleKHR
as _,
));
block.body.push(Instruction::binary(
spirv::Op::IEqual,
self.writer.get_bool_type_id(),
condition_id,
raw_kind_id,
committed_triangle_kind_id,
));
let kind_id = self.gen_id();
block.body.push(Instruction::select(
flag_type_id,
kind_id,
condition_id,
self.writer.get_constant_scalar(crate::Literal::U32(
crate::RayQueryIntersection::Triangle as _,
)),
self.writer.get_constant_scalar(crate::Literal::U32(
crate::RayQueryIntersection::Aabb as _,
)),
));
kind_id
};
let instance_custom_index_id = self.gen_id();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
flag_type_id,
instance_custom_index_id,
query_id,
intersection_id,
));
let instance_id = self.gen_id();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionInstanceIdKHR,
flag_type_id,
instance_id,
query_id,
intersection_id,
));
let sbt_record_offset_id = self.gen_id();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR,
flag_type_id,
sbt_record_offset_id,
query_id,
intersection_id,
));
let geometry_index_id = self.gen_id();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionGeometryIndexKHR,
flag_type_id,
geometry_index_id,
query_id,
intersection_id,
));
let primitive_index_id = self.gen_id();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR,
flag_type_id,
primitive_index_id,
query_id,
intersection_id,
));
let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::F32),
)));
let t_id = self.gen_id();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionTKHR,
scalar_type_id,
t_id,
query_id,
intersection_id,
));
let barycentrics_type_id =
self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector {
size: crate::VectorSize::Bi,
scalar: crate::Scalar::F32,
})));
let barycentrics_id = self.gen_id();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
barycentrics_type_id,
barycentrics_id,
query_id,
intersection_id,
));
let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(
NumericType::Scalar(crate::Scalar::BOOL),
)));
let front_face_id = self.gen_id();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
bool_type_id,
front_face_id,
query_id,
intersection_id,
));
//Note: there is also `OpRayQueryGetIntersectionCandidateAABBOpaqueKHR`,
// but it's not a property of an intersection.
let transform_type_id =
self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Matrix {
columns: crate::VectorSize::Quad,
rows: crate::VectorSize::Tri,
scalar: crate::Scalar::F32,
})));
let object_to_world_id = self.gen_id();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,
transform_type_id,
object_to_world_id,
query_id,
intersection_id,
));
let world_to_object_id = self.gen_id();
block.body.push(Instruction::ray_query_get_intersection(
spirv::Op::RayQueryGetIntersectionWorldToObjectKHR,
transform_type_id,
world_to_object_id,
query_id,
intersection_id,
));
let id = self.gen_id();
let intersection_type_id = self.get_type_id(LookupType::Handle(
self.ir_module.special_types.ray_intersection.unwrap(),
));
//Note: the arguments must match `generate_ray_intersection_type` layout
block.body.push(Instruction::composite_construct(
intersection_type_id,
id,
&[
kind_id,
t_id,
instance_custom_index_id,
instance_id,
sbt_record_offset_id,
geometry_index_id,
primitive_index_id,
barycentrics_id,
front_face_id,
object_to_world_id,
world_to_object_id,
],
));
id
}
}

View File

@ -21,7 +21,7 @@ struct FunctionInterface<'a> {
}
impl Function {
fn to_words(&self, sink: &mut impl Extend<Word>) {
pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) {
self.signature.as_ref().unwrap().to_words(sink);
for argument in self.parameters.iter() {
argument.instruction.to_words(sink);
@ -81,6 +81,7 @@ impl Writer {
saved_cached: CachedExpressions::default(),
gl450_ext_inst_id,
temp_list: Vec::new(),
ray_get_intersection_function: None,
})
}
@ -131,6 +132,7 @@ impl Writer {
global_variables: take(&mut self.global_variables).recycle(),
saved_cached: take(&mut self.saved_cached).recycle(),
temp_list: take(&mut self.temp_list).recycle(),
ray_get_intersection_function: None,
};
*self = fresh;
@ -1846,7 +1848,7 @@ impl Writer {
Ok(())
}
fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word {
pub(super) fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word {
match self
.lookup_function_type
.entry(lookup_function_type.clone())

View File

@ -1,16 +1,16 @@
; SPIR-V
; Version: 1.4
; Generator: rspirv
; Bound: 136
; Bound: 160
OpCapability Shader
OpCapability RayQueryKHR
OpExtension "SPV_KHR_ray_query"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %84 "main" %15 %17
OpEntryPoint GLCompute %105 "main_candidate" %15 %17
OpExecutionMode %84 LocalSize 1 1 1
OpExecutionMode %105 LocalSize 1 1 1
OpEntryPoint GLCompute %123 "main" %15 %17
OpEntryPoint GLCompute %143 "main_candidate" %15 %17
OpExecutionMode %123 LocalSize 1 1 1
OpExecutionMode %143 LocalSize 1 1 1
OpMemberDecorate %10 0 Offset 0
OpMemberDecorate %10 1 Offset 4
OpMemberDecorate %10 2 Offset 8
@ -64,20 +64,87 @@ OpMemberDecorate %18 0 Offset 0
%29 = OpConstant %3 0.1
%30 = OpConstant %3 100.0
%32 = OpTypePointer Function %11
%50 = OpConstant %6 1
%67 = OpTypeFunction %4 %4 %10
%68 = OpConstant %3 1.0
%69 = OpConstant %3 2.4
%70 = OpConstant %3 0.0
%85 = OpTypeFunction %2
%87 = OpTypePointer StorageBuffer %13
%88 = OpConstant %6 0
%90 = OpConstantComposite %4 %70 %70 %70
%91 = OpConstantComposite %4 %70 %68 %70
%94 = OpTypePointer StorageBuffer %6
%99 = OpTypePointer StorageBuffer %4
%108 = OpConstantComposite %12 %27 %28 %29 %30 %90 %91
%109 = OpConstant %6 3
%50 = OpTypePointer Function %10
%51 = OpTypePointer Function %6
%52 = OpTypePointer Function %9
%53 = OpTypePointer Function %7
%54 = OpTypePointer Function %8
%55 = OpTypePointer Function %3
%56 = OpTypeFunction %10 %32
%58 = OpConstantNull %10
%62 = OpConstant %6 1
%64 = OpConstant %6 0
%76 = OpConstant %6 2
%78 = OpConstant %6 3
%81 = OpConstant %6 5
%83 = OpConstant %6 6
%85 = OpConstant %6 9
%87 = OpConstant %6 10
%96 = OpConstant %6 7
%98 = OpConstant %6 8
%106 = OpTypeFunction %4 %4 %10
%107 = OpConstant %3 1.0
%108 = OpConstant %3 2.4
%109 = OpConstant %3 0.0
%124 = OpTypeFunction %2
%126 = OpTypePointer StorageBuffer %13
%128 = OpConstantComposite %4 %109 %109 %109
%129 = OpConstantComposite %4 %109 %107 %109
%132 = OpTypePointer StorageBuffer %6
%137 = OpTypePointer StorageBuffer %4
%146 = OpConstantComposite %12 %27 %28 %29 %30 %128 %129
%57 = OpFunction %10 None %56
%59 = OpFunctionParameter %32
%60 = OpLabel
%61 = OpVariable %50 Function %58
%63 = OpRayQueryGetIntersectionTypeKHR %6 %59 %62
%65 = OpAccessChain %51 %61 %64
OpStore %65 %63
%66 = OpINotEqual %8 %63 %64
OpSelectionMerge %68 None
OpBranchConditional %66 %67 %68
%67 = OpLabel
%69 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %6 %59 %62
%70 = OpRayQueryGetIntersectionInstanceIdKHR %6 %59 %62
%71 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %6 %59 %62
%72 = OpRayQueryGetIntersectionGeometryIndexKHR %6 %59 %62
%73 = OpRayQueryGetIntersectionPrimitiveIndexKHR %6 %59 %62
%74 = OpRayQueryGetIntersectionObjectToWorldKHR %9 %59 %62
%75 = OpRayQueryGetIntersectionWorldToObjectKHR %9 %59 %62
%77 = OpAccessChain %51 %61 %76
OpStore %77 %69
%79 = OpAccessChain %51 %61 %78
OpStore %79 %70
%80 = OpAccessChain %51 %61 %27
OpStore %80 %71
%82 = OpAccessChain %51 %61 %81
OpStore %82 %72
%84 = OpAccessChain %51 %61 %83
OpStore %84 %73
%86 = OpAccessChain %52 %61 %85
OpStore %86 %74
%88 = OpAccessChain %52 %61 %87
OpStore %88 %75
%89 = OpIEqual %8 %63 %62
%92 = OpRayQueryGetIntersectionTKHR %3 %59 %62
%93 = OpAccessChain %55 %61 %62
OpStore %93 %92
OpSelectionMerge %91 None
OpBranchConditional %66 %90 %91
%90 = OpLabel
%94 = OpRayQueryGetIntersectionBarycentricsKHR %7 %59 %62
%95 = OpRayQueryGetIntersectionFrontFaceKHR %8 %59 %62
%97 = OpAccessChain %53 %61 %96
OpStore %97 %94
%99 = OpAccessChain %54 %61 %98
OpStore %99 %95
OpBranch %91
%91 = OpLabel
OpBranch %68
%68 = OpLabel
%100 = OpLoad %10 %61
OpReturnValue %100
OpFunctionEnd
%25 = OpFunction %10 None %26
%21 = OpFunctionParameter %4
%22 = OpFunctionParameter %4
@ -114,90 +181,66 @@ OpBranch %44
%44 = OpLabel
OpBranch %41
%42 = OpLabel
%51 = OpRayQueryGetIntersectionTypeKHR %6 %31 %50
%52 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %6 %31 %50
%53 = OpRayQueryGetIntersectionInstanceIdKHR %6 %31 %50
%54 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %6 %31 %50
%55 = OpRayQueryGetIntersectionGeometryIndexKHR %6 %31 %50
%56 = OpRayQueryGetIntersectionPrimitiveIndexKHR %6 %31 %50
%57 = OpRayQueryGetIntersectionTKHR %3 %31 %50
%58 = OpRayQueryGetIntersectionBarycentricsKHR %7 %31 %50
%59 = OpRayQueryGetIntersectionFrontFaceKHR %8 %31 %50
%60 = OpRayQueryGetIntersectionObjectToWorldKHR %9 %31 %50
%61 = OpRayQueryGetIntersectionWorldToObjectKHR %9 %31 %50
%62 = OpCompositeConstruct %10 %51 %57 %52 %53 %54 %55 %56 %58 %59 %60 %61
OpReturnValue %62
%101 = OpFunctionCall %10 %57 %31
OpReturnValue %101
OpFunctionEnd
%66 = OpFunction %4 None %67
%64 = OpFunctionParameter %4
%65 = OpFunctionParameter %10
%63 = OpLabel
OpBranch %71
%71 = OpLabel
%72 = OpCompositeExtract %9 %65 10
%73 = OpCompositeConstruct %14 %64 %68
%74 = OpMatrixTimesVector %4 %72 %73
%75 = OpVectorShuffle %7 %74 %74 0 1
%76 = OpExtInst %7 %1 Normalize %75
%77 = OpVectorTimesScalar %7 %76 %69
%78 = OpCompositeExtract %9 %65 9
%79 = OpCompositeConstruct %14 %77 %70 %68
%80 = OpMatrixTimesVector %4 %78 %79
%81 = OpFSub %4 %64 %80
%82 = OpExtInst %4 %1 Normalize %81
OpReturnValue %82
%105 = OpFunction %4 None %106
%103 = OpFunctionParameter %4
%104 = OpFunctionParameter %10
%102 = OpLabel
OpBranch %110
%110 = OpLabel
%111 = OpCompositeExtract %9 %104 10
%112 = OpCompositeConstruct %14 %103 %107
%113 = OpMatrixTimesVector %4 %111 %112
%114 = OpVectorShuffle %7 %113 %113 0 1
%115 = OpExtInst %7 %1 Normalize %114
%116 = OpVectorTimesScalar %7 %115 %108
%117 = OpCompositeExtract %9 %104 9
%118 = OpCompositeConstruct %14 %116 %109 %107
%119 = OpMatrixTimesVector %4 %117 %118
%120 = OpFSub %4 %103 %119
%121 = OpExtInst %4 %1 Normalize %120
OpReturnValue %121
OpFunctionEnd
%84 = OpFunction %2 None %85
%83 = OpLabel
%86 = OpLoad %5 %15
%89 = OpAccessChain %87 %17 %88
OpBranch %92
%92 = OpLabel
%93 = OpFunctionCall %10 %25 %90 %91 %15
%95 = OpCompositeExtract %6 %93 0
%96 = OpIEqual %8 %95 %88
%97 = OpSelect %6 %96 %50 %88
%98 = OpAccessChain %94 %89 %88
OpStore %98 %97
%100 = OpCompositeExtract %3 %93 1
%101 = OpVectorTimesScalar %4 %91 %100
%102 = OpFunctionCall %4 %66 %101 %93
%103 = OpAccessChain %99 %89 %50
OpStore %103 %102
%123 = OpFunction %2 None %124
%122 = OpLabel
%125 = OpLoad %5 %15
%127 = OpAccessChain %126 %17 %64
OpBranch %130
%130 = OpLabel
%131 = OpFunctionCall %10 %25 %128 %129 %15
%133 = OpCompositeExtract %6 %131 0
%134 = OpIEqual %8 %133 %64
%135 = OpSelect %6 %134 %62 %64
%136 = OpAccessChain %132 %127 %64
OpStore %136 %135
%138 = OpCompositeExtract %3 %131 1
%139 = OpVectorTimesScalar %4 %129 %138
%140 = OpFunctionCall %4 %105 %139 %131
%141 = OpAccessChain %137 %127 %62
OpStore %141 %140
OpReturn
OpFunctionEnd
%105 = OpFunction %2 None %85
%104 = OpLabel
%110 = OpVariable %32 Function
%106 = OpLoad %5 %15
%107 = OpAccessChain %87 %17 %88
OpBranch %111
%111 = OpLabel
%112 = OpCompositeExtract %6 %108 0
%113 = OpCompositeExtract %6 %108 1
%114 = OpCompositeExtract %3 %108 2
%115 = OpCompositeExtract %3 %108 3
%116 = OpCompositeExtract %4 %108 4
%117 = OpCompositeExtract %4 %108 5
OpRayQueryInitializeKHR %110 %106 %112 %113 %116 %114 %117 %115
%118 = OpRayQueryGetIntersectionTypeKHR %6 %110 %88
%119 = OpIEqual %8 %118 %88
%120 = OpSelect %6 %119 %50 %109
%121 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %6 %110 %88
%122 = OpRayQueryGetIntersectionInstanceIdKHR %6 %110 %88
%123 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %6 %110 %88
%124 = OpRayQueryGetIntersectionGeometryIndexKHR %6 %110 %88
%125 = OpRayQueryGetIntersectionPrimitiveIndexKHR %6 %110 %88
%126 = OpRayQueryGetIntersectionTKHR %3 %110 %88
%127 = OpRayQueryGetIntersectionBarycentricsKHR %7 %110 %88
%128 = OpRayQueryGetIntersectionFrontFaceKHR %8 %110 %88
%129 = OpRayQueryGetIntersectionObjectToWorldKHR %9 %110 %88
%130 = OpRayQueryGetIntersectionWorldToObjectKHR %9 %110 %88
%131 = OpCompositeConstruct %10 %120 %126 %121 %122 %123 %124 %125 %127 %128 %129 %130
%132 = OpCompositeExtract %6 %131 0
%133 = OpIEqual %8 %132 %109
%134 = OpSelect %6 %133 %50 %88
%135 = OpAccessChain %94 %107 %88
OpStore %135 %134
%143 = OpFunction %2 None %124
%142 = OpLabel
%147 = OpVariable %32 Function
%144 = OpLoad %5 %15
%145 = OpAccessChain %126 %17 %64
OpBranch %148
%148 = OpLabel
%149 = OpCompositeExtract %6 %146 0
%150 = OpCompositeExtract %6 %146 1
%151 = OpCompositeExtract %3 %146 2
%152 = OpCompositeExtract %3 %146 3
%153 = OpCompositeExtract %4 %146 4
%154 = OpCompositeExtract %4 %146 5
OpRayQueryInitializeKHR %147 %144 %149 %150 %153 %151 %154 %152
%155 = OpFunctionCall %10 %57 %147
%156 = OpCompositeExtract %6 %155 0
%157 = OpIEqual %8 %156 %78
%158 = OpSelect %6 %157 %62 %64
%159 = OpAccessChain %132 %145 %64
OpStore %159 %158
OpReturn
OpFunctionEnd

View File

@ -1,88 +1,12 @@
use std::{iter, mem};
use std::iter;
use wgpu::{
util::{BufferInitDescriptor, DeviceExt},
*,
};
use crate::ray_tracing::AsBuildContext;
use wgpu::util::{BufferInitDescriptor, DeviceExt};
use wgpu::*;
use wgpu_test::{
fail, gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext,
};
struct AsBuildContext {
vertices: Buffer,
blas_size: BlasTriangleGeometrySizeDescriptor,
blas: Blas,
// Putting this last, forces the BLAS to die before the TLAS.
tlas_package: TlasPackage,
}
impl AsBuildContext {
fn new(ctx: &TestingContext) -> Self {
let vertices = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: None,
contents: &[0; mem::size_of::<[[f32; 3]; 3]>()],
usage: BufferUsages::BLAS_INPUT,
});
let blas_size = BlasTriangleGeometrySizeDescriptor {
vertex_format: VertexFormat::Float32x3,
vertex_count: 3,
index_format: None,
index_count: None,
flags: AccelerationStructureGeometryFlags::empty(),
};
let blas = ctx.device.create_blas(
&CreateBlasDescriptor {
label: Some("BLAS"),
flags: AccelerationStructureFlags::PREFER_FAST_TRACE,
update_mode: AccelerationStructureUpdateMode::Build,
},
BlasGeometrySizeDescriptors::Triangles {
descriptors: vec![blas_size.clone()],
},
);
let tlas = ctx.device.create_tlas(&CreateTlasDescriptor {
label: Some("TLAS"),
max_instances: 1,
flags: AccelerationStructureFlags::PREFER_FAST_TRACE,
update_mode: AccelerationStructureUpdateMode::Build,
});
let mut tlas_package = TlasPackage::new(tlas);
tlas_package[0] = Some(TlasInstance::new(
&blas,
[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
0,
0xFF,
));
Self {
vertices,
blas_size,
blas,
tlas_package,
}
}
fn blas_build_entry(&self) -> BlasBuildEntry {
BlasBuildEntry {
blas: &self.blas,
geometry: BlasGeometries::TriangleGeometries(vec![BlasTriangleGeometry {
size: &self.blas_size,
vertex_buffer: &self.vertices,
first_vertex: 0,
vertex_stride: mem::size_of::<[f32; 3]>() as BufferAddress,
index_buffer: None,
first_index: None,
transform_buffer: None,
transform_buffer_offset: None,
}]),
}
}
}
#[gpu_test]
static UNBUILT_BLAS: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
@ -256,7 +180,7 @@ fn out_of_order_as_build_use(ctx: TestingContext) {
label: None,
layout: None,
module: &shader,
entry_point: Some("comp_main"),
entry_point: Some("basic_usage"),
compilation_options: Default::default(),
cache: None,
});
@ -343,7 +267,7 @@ static BUILD_WITH_TRANSFORM: GpuTestConfiguration = GpuTestConfiguration::new()
fn build_with_transform(ctx: TestingContext) {
let vertices = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: None,
contents: &[0; mem::size_of::<[[f32; 3]; 3]>()],
contents: &[0; size_of::<[[f32; 3]; 3]>()],
usage: BufferUsages::BLAS_INPUT,
});
@ -404,7 +328,7 @@ fn build_with_transform(ctx: TestingContext) {
size: &blas_size,
vertex_buffer: &vertices,
first_vertex: 0,
vertex_stride: mem::size_of::<[f32; 3]>() as BufferAddress,
vertex_stride: size_of::<[f32; 3]>() as BufferAddress,
index_buffer: None,
first_index: None,
transform_buffer: Some(&transform),

View File

@ -108,7 +108,7 @@ fn acceleration_structure_use_after_free(ctx: TestingContext) {
label: None,
layout: None,
module: &shader,
entry_point: Some("comp_main"),
entry_point: Some("basic_usage"),
compilation_options: Default::default(),
cache: None,
});

View File

@ -1,4 +1,93 @@
use std::mem;
use wgpu::util::BufferInitDescriptor;
use wgpu::{
util::DeviceExt, Blas, BlasBuildEntry, BlasGeometries, BlasGeometrySizeDescriptors,
BlasTriangleGeometry, BlasTriangleGeometrySizeDescriptor, Buffer, CreateBlasDescriptor,
CreateTlasDescriptor, TlasInstance, TlasPackage,
};
use wgpu_test::TestingContext;
use wgt::{
AccelerationStructureFlags, AccelerationStructureGeometryFlags,
AccelerationStructureUpdateMode, BufferAddress, BufferUsages, VertexFormat,
};
mod as_build;
mod as_create;
mod as_use_after_free;
mod scene;
mod shader;
pub struct AsBuildContext {
vertices: Buffer,
blas_size: BlasTriangleGeometrySizeDescriptor,
blas: Blas,
// Putting this last, forces the BLAS to die before the TLAS.
tlas_package: TlasPackage,
}
impl AsBuildContext {
pub fn new(ctx: &TestingContext) -> Self {
let vertices = ctx.device.create_buffer_init(&BufferInitDescriptor {
label: None,
contents: &[0; mem::size_of::<[[f32; 3]; 3]>()],
usage: BufferUsages::BLAS_INPUT,
});
let blas_size = BlasTriangleGeometrySizeDescriptor {
vertex_format: VertexFormat::Float32x3,
vertex_count: 3,
index_format: None,
index_count: None,
flags: AccelerationStructureGeometryFlags::empty(),
};
let blas = ctx.device.create_blas(
&CreateBlasDescriptor {
label: Some("BLAS"),
flags: AccelerationStructureFlags::PREFER_FAST_TRACE,
update_mode: AccelerationStructureUpdateMode::Build,
},
BlasGeometrySizeDescriptors::Triangles {
descriptors: vec![blas_size.clone()],
},
);
let tlas = ctx.device.create_tlas(&CreateTlasDescriptor {
label: Some("TLAS"),
max_instances: 1,
flags: AccelerationStructureFlags::PREFER_FAST_TRACE,
update_mode: AccelerationStructureUpdateMode::Build,
});
let mut tlas_package = TlasPackage::new(tlas);
tlas_package[0] = Some(TlasInstance::new(
&blas,
[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
0,
0xFF,
));
Self {
vertices,
blas_size,
blas,
tlas_package,
}
}
pub fn blas_build_entry(&self) -> BlasBuildEntry {
BlasBuildEntry {
blas: &self.blas,
geometry: BlasGeometries::TriangleGeometries(vec![BlasTriangleGeometry {
size: &self.blas_size,
vertex_buffer: &self.vertices,
first_vertex: 0,
vertex_stride: mem::size_of::<[f32; 3]>() as BufferAddress,
index_buffer: None,
first_index: None,
transform_buffer: None,
transform_buffer_offset: None,
}]),
}
}
}

View File

@ -0,0 +1,95 @@
use crate::ray_tracing::AsBuildContext;
use wgpu::{
include_wgsl, BindGroupDescriptor, BindGroupEntry, BindingResource, BufferDescriptor,
CommandEncoderDescriptor, ComputePassDescriptor, ComputePipelineDescriptor,
};
use wgpu_macros::gpu_test;
use wgpu_test::{GpuTestConfiguration, TestParameters, TestingContext};
use wgt::BufferUsages;
const STRUCT_SIZE: wgt::BufferAddress = 176;
#[gpu_test]
static ACCESS_ALL_STRUCT_MEMBERS: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(TestParameters::default().test_features_limits().features(
wgpu::Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE
| wgpu::Features::EXPERIMENTAL_RAY_QUERY,
))
.run_sync(access_all_struct_members);
fn access_all_struct_members(ctx: TestingContext) {
let buf = ctx.device.create_buffer(&BufferDescriptor {
label: None,
size: STRUCT_SIZE,
usage: BufferUsages::STORAGE,
mapped_at_creation: false,
});
//
// Create a clean `AsBuildContext`
//
let as_ctx = AsBuildContext::new(&ctx);
let mut encoder_build = ctx
.device
.create_command_encoder(&CommandEncoderDescriptor {
label: Some("Build"),
});
encoder_build
.build_acceleration_structures([&as_ctx.blas_build_entry()], [&as_ctx.tlas_package]);
ctx.queue.submit([encoder_build.finish()]);
//
// Create shader to use tlas with
//
let shader = ctx
.device
.create_shader_module(include_wgsl!("shader.wgsl"));
let compute_pipeline = ctx
.device
.create_compute_pipeline(&ComputePipelineDescriptor {
label: None,
layout: None,
module: &shader,
entry_point: Some("all_of_struct"),
compilation_options: Default::default(),
cache: None,
});
let bind_group = ctx.device.create_bind_group(&BindGroupDescriptor {
label: None,
layout: &compute_pipeline.get_bind_group_layout(0),
entries: &[
BindGroupEntry {
binding: 0,
resource: BindingResource::AccelerationStructure(as_ctx.tlas_package.tlas()),
},
BindGroupEntry {
binding: 1,
resource: BindingResource::Buffer(buf.as_entire_buffer_binding()),
},
],
});
//
// Submit once to check for no issues
//
let mut encoder_compute = ctx
.device
.create_command_encoder(&CommandEncoderDescriptor::default());
{
let mut pass = encoder_compute.begin_compute_pass(&ComputePassDescriptor {
label: None,
timestamp_writes: None,
});
pass.set_pipeline(&compute_pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
pass.dispatch_workgroups(1, 1, 1)
}
ctx.queue.submit([encoder_compute.finish()]);
}

View File

@ -1,11 +1,51 @@
@group(0) @binding(0)
var acc_struct: acceleration_structure;
struct Intersection {
kind: u32,
t: f32,
instance_custom_index: u32,
instance_id: u32,
sbt_record_offset: u32,
geometry_index: u32,
primitive_index: u32,
barycentrics: vec2<f32>,
front_face: u32,
object_to_world: mat4x3<f32>,
world_to_object: mat4x3<f32>,
}
@group(0) @binding(1)
var<storage, read_write> out: Intersection;
@workgroup_size(1)
@compute
fn comp_main() {
fn basic_usage() {
var rq: ray_query;
rayQueryInitialize(&rq, acc_struct, RayDesc(0u, 0xFFu, 0.001, 100000.0, vec3f(0.0, 0.0, 0.0), vec3f(0.0, 0.0, 1.0)));
rayQueryProceed(&rq);
let intersection = rayQueryGetCommittedIntersection(&rq);
}
@workgroup_size(1)
@compute
fn all_of_struct() {
var rq: ray_query;
rayQueryInitialize(&rq, acc_struct, RayDesc(0u, 0xFFu, 0.0, 0.0, vec3f(0.0, 0.0, 1.0), vec3f(0.0, 0.0, 1.0)));
rayQueryProceed(&rq);
let intersection = rayQueryGetCommittedIntersection(&rq);
// this prevents optimisation as we use the fields
out = Intersection(
intersection.kind,
intersection.t,
intersection.instance_custom_index,
intersection.instance_id,
intersection.sbt_record_offset,
intersection.geometry_index,
intersection.primitive_index,
intersection.barycentrics,
u32(intersection.front_face),
intersection.world_to_object,
intersection.object_to_world,
);
}