Prevent naga crashing on an aliased ray query. (#7759)

This commit is contained in:
Vecvec 2025-06-16 20:45:39 +12:00 committed by GitHub
parent bbb7cc79ef
commit 03775c54fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 377 additions and 20 deletions

View File

@ -82,6 +82,7 @@ Bottom level categories:
- Fix typing for `select`, which had issues particularly with a lack of automatic type conversion. By @ErichDonGubler in [#7572](https://github.com/gfx-rs/wgpu/pull/7572).
- Allow scalars as the first argument of the `distance` built-in function. By @bernhl in [#7530](https://github.com/gfx-rs/wgpu/pull/7530).
- Don't panic when handling `f16` for pipeline constants, i.e., `override`s in WGSL. By @ErichDonGubler in [#7801](https://github.com/gfx-rs/wgpu/pull/7801).
- Prevent aliased ray queries crashing naga when writing SPIR-V out. By @Vecvec in [#7759](https://github.com/gfx-rs/wgpu/pull/7759).
#### DX12

View File

@ -54,7 +54,7 @@ impl Writer {
let scalar_type_id = self.get_f32_type_id();
let float_pointer_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Function);
let argument_type_id = self.get_ray_query_pointer_id(ir_module);
let argument_type_id = self.get_ray_query_pointer_id();
let func_ty = self.get_function_type(LookupFunctionType {
parameter_type_ids: vec![argument_type_id],

View File

@ -301,25 +301,9 @@ impl Writer {
self.get_pointer_type_id(base_id, class)
}
pub(super) fn get_ray_query_pointer_id(&mut self, module: &crate::Module) -> Word {
let rq_ty = module
.types
.get(&crate::Type {
name: None,
inner: crate::TypeInner::RayQuery {
vertex_return: false,
},
})
.or_else(|| {
module.types.get(&crate::Type {
name: None,
inner: crate::TypeInner::RayQuery {
vertex_return: true,
},
})
})
.expect("ray_query type should have been populated by the variable passed into this!");
self.get_handle_pointer_type_id(rq_ty, spirv::StorageClass::Function)
pub(super) fn get_ray_query_pointer_id(&mut self) -> Word {
let rq_id = self.get_type_id(LookupType::Local(LocalType::RayQuery));
self.get_pointer_type_id(rq_id, spirv::StorageClass::Function)
}
/// Return a SPIR-V type for a pointer to `resolution`.

View File

@ -0,0 +1,16 @@
god_mode = true
targets = "SPIRV | METAL | HLSL"
[msl]
fake_missing_bindings = true
lang_version = [2, 4]
spirv_cross_compatibility = false
zero_initialize_workgroup_memory = false
[hlsl]
shader_model = "V6_5"
fake_missing_bindings = true
zero_initialize_workgroup_memory = true
[spv]
version = [1, 4]

View File

@ -0,0 +1,21 @@
alias rq = ray_query;
@group(0) @binding(0)
var acc_struct: acceleration_structure;
@compute @workgroup_size(1)
fn main_candidate() {
let pos = vec3<f32>(0.0);
let dir = vec3<f32>(0.0, 1.0, 0.0);
var rq: rq;
rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, pos, dir));
let intersection = rayQueryGetCandidateIntersection(&rq);
if (intersection.kind == RAY_QUERY_INTERSECTION_AABB) {
rayQueryGenerateIntersection(&rq, 10.0);
} else if (intersection.kind == RAY_QUERY_INTERSECTION_TRIANGLE) {
rayQueryConfirmIntersection(&rq);
} else {
rayQueryTerminate(&rq);
}
}

View File

@ -0,0 +1,94 @@
struct RayDesc_ {
uint flags;
uint cull_mask;
float tmin;
float tmax;
float3 origin;
int _pad5_0;
float3 dir;
int _end_pad_0;
};
struct RayIntersection {
uint kind;
float t;
uint instance_custom_data;
uint instance_index;
uint sbt_record_offset;
uint geometry_index;
uint primitive_index;
float2 barycentrics;
bool front_face;
int _pad9_0;
int _pad9_1;
row_major float4x3 object_to_world;
int _pad10_0;
row_major float4x3 world_to_object;
int _end_pad_0;
};
RayDesc RayDescFromRayDesc_(RayDesc_ arg0) {
RayDesc ret = (RayDesc)0;
ret.Origin = arg0.origin;
ret.TMin = arg0.tmin;
ret.Direction = arg0.dir;
ret.TMax = arg0.tmax;
return ret;
}
RaytracingAccelerationStructure acc_struct : register(t0);
RayDesc_ ConstructRayDesc_(uint arg0, uint arg1, float arg2, float arg3, float3 arg4, float3 arg5) {
RayDesc_ ret = (RayDesc_)0;
ret.flags = arg0;
ret.cull_mask = arg1;
ret.tmin = arg2;
ret.tmax = arg3;
ret.origin = arg4;
ret.dir = arg5;
return ret;
}
RayIntersection GetCandidateIntersection(RayQuery<RAY_FLAG_NONE> rq) {
RayIntersection ret = (RayIntersection)0;
CANDIDATE_TYPE kind = rq.CandidateType();
if (kind == CANDIDATE_NON_OPAQUE_TRIANGLE) {
ret.kind = 1;
ret.t = rq.CandidateTriangleRayT();
ret.barycentrics = rq.CandidateTriangleBarycentrics();
ret.front_face = rq.CandidateTriangleFrontFace();
} else {
ret.kind = 3;
}
ret.instance_custom_data = rq.CandidateInstanceID();
ret.instance_index = rq.CandidateInstanceIndex();
ret.sbt_record_offset = rq.CandidateInstanceContributionToHitGroupIndex();
ret.geometry_index = rq.CandidateGeometryIndex();
ret.primitive_index = rq.CandidatePrimitiveIndex();
ret.object_to_world = rq.CandidateObjectToWorld4x3();
ret.world_to_object = rq.CandidateWorldToObject4x3();
return ret;
}
[numthreads(1, 1, 1)]
void main_candidate()
{
RayQuery<RAY_FLAG_NONE> rq_1;
float3 pos = (0.0).xxx;
float3 dir = float3(0.0, 1.0, 0.0);
rq_1.TraceRayInline(acc_struct, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).flags, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).cull_mask, RayDescFromRayDesc_(ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir)));
RayIntersection intersection = GetCandidateIntersection(rq_1);
if ((intersection.kind == 3u)) {
rq_1.CommitProceduralPrimitiveHit(10.0);
return;
} else {
if ((intersection.kind == 1u)) {
rq_1.CommitNonOpaqueTriangleHit();
return;
} else {
rq_1.Abort();
return;
}
}
}

View File

@ -0,0 +1,12 @@
(
vertex:[
],
fragment:[
],
compute:[
(
entry_point:"main_candidate",
target_profile:"cs_6_5",
),
],
)

View File

@ -0,0 +1,62 @@
// language: metal2.4
#include <metal_stdlib>
#include <simd/simd.h>
using metal::uint;
struct _RayQuery {
metal::raytracing::intersector<metal::raytracing::instancing, metal::raytracing::triangle_data, metal::raytracing::world_space_data> intersector;
metal::raytracing::intersector<metal::raytracing::instancing, metal::raytracing::triangle_data, metal::raytracing::world_space_data>::result_type intersection;
bool ready = false;
};
constexpr metal::uint _map_intersection_type(const metal::raytracing::intersection_type ty) {
return ty==metal::raytracing::intersection_type::triangle ? 1 :
ty==metal::raytracing::intersection_type::bounding_box ? 4 : 0;
}
struct RayDesc {
uint flags;
uint cull_mask;
float tmin;
float tmax;
metal::float3 origin;
metal::float3 dir;
};
struct RayIntersection {
uint kind;
float t;
uint instance_custom_data;
uint instance_index;
uint sbt_record_offset;
uint geometry_index;
uint primitive_index;
metal::float2 barycentrics;
bool front_face;
char _pad9[11];
metal::float4x3 object_to_world;
metal::float4x3 world_to_object;
};
kernel void main_candidate(
metal::raytracing::instance_acceleration_structure acc_struct [[user(fake0)]]
) {
_RayQuery rq_1 = {};
metal::float3 pos = metal::float3(0.0);
metal::float3 dir = metal::float3(0.0, 1.0, 0.0);
RayDesc _e12 = RayDesc {4u, 255u, 0.1, 100.0, pos, dir};
rq_1.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle);
rq_1.intersector.set_opacity_cull_mode((_e12.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (_e12.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none);
rq_1.intersector.force_opacity((_e12.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e12.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none);
rq_1.intersector.accept_any_intersection((_e12.flags & 4) != 0);
rq_1.intersection = rq_1.intersector.intersect(metal::raytracing::ray(_e12.origin, _e12.dir, _e12.tmin, _e12.tmax), acc_struct, _e12.cull_mask); rq_1.ready = true;
RayIntersection intersection = RayIntersection {_map_intersection_type(rq_1.intersection.type), rq_1.intersection.distance, rq_1.intersection.user_instance_id, rq_1.intersection.instance_id, {}, rq_1.intersection.geometry_id, rq_1.intersection.primitive_id, rq_1.intersection.triangle_barycentric_coord, rq_1.intersection.triangle_front_facing, {}, rq_1.intersection.object_to_world_transform, rq_1.intersection.world_to_object_transform};
if (intersection.kind == 3u) {
return;
} else {
if (intersection.kind == 1u) {
return;
} else {
rq_1.ready = false;
return;
}
}
}

View File

@ -0,0 +1,167 @@
; SPIR-V
; Version: 1.4
; Generator: rspirv
; Bound: 102
OpCapability Shader
OpCapability RayQueryKHR
OpExtension "SPV_KHR_ray_query"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %16 "main_candidate" %13
OpExecutionMode %16 LocalSize 1 1 1
OpMemberDecorate %8 0 Offset 0
OpMemberDecorate %8 1 Offset 4
OpMemberDecorate %8 2 Offset 8
OpMemberDecorate %8 3 Offset 12
OpMemberDecorate %8 4 Offset 16
OpMemberDecorate %8 5 Offset 32
OpMemberDecorate %12 0 Offset 0
OpMemberDecorate %12 1 Offset 4
OpMemberDecorate %12 2 Offset 8
OpMemberDecorate %12 3 Offset 12
OpMemberDecorate %12 4 Offset 16
OpMemberDecorate %12 5 Offset 20
OpMemberDecorate %12 6 Offset 24
OpMemberDecorate %12 7 Offset 28
OpMemberDecorate %12 8 Offset 36
OpMemberDecorate %12 9 Offset 48
OpMemberDecorate %12 9 ColMajor
OpMemberDecorate %12 9 MatrixStride 16
OpMemberDecorate %12 10 Offset 112
OpMemberDecorate %12 10 ColMajor
OpMemberDecorate %12 10 MatrixStride 16
OpDecorate %13 DescriptorSet 0
OpDecorate %13 Binding 0
%2 = OpTypeVoid
%3 = OpTypeRayQueryKHR
%4 = OpTypeAccelerationStructureNV
%5 = OpTypeFloat 32
%6 = OpTypeVector %5 3
%7 = OpTypeInt 32 0
%8 = OpTypeStruct %7 %7 %5 %5 %6 %6
%9 = OpTypeVector %5 2
%10 = OpTypeBool
%11 = OpTypeMatrix %6 4
%12 = OpTypeStruct %7 %5 %7 %7 %7 %7 %7 %9 %10 %11 %11
%14 = OpTypePointer UniformConstant %4
%13 = OpVariable %14 UniformConstant
%17 = OpTypeFunction %2
%19 = OpConstant %5 0.0
%20 = OpConstantComposite %6 %19 %19 %19
%21 = OpConstant %5 1.0
%22 = OpConstantComposite %6 %19 %21 %19
%23 = OpConstant %7 4
%24 = OpConstant %7 255
%25 = OpConstant %5 0.1
%26 = OpConstant %5 100.0
%27 = OpConstantComposite %8 %23 %24 %25 %26 %20 %22
%28 = OpConstant %7 3
%29 = OpConstant %5 10.0
%30 = OpConstant %7 1
%32 = OpTypePointer Function %3
%40 = OpTypePointer Function %12
%41 = OpTypePointer Function %7
%42 = OpTypePointer Function %11
%43 = OpTypePointer Function %9
%44 = OpTypePointer Function %10
%45 = OpTypePointer Function %5
%46 = OpTypeFunction %12 %32
%48 = OpConstantNull %12
%52 = OpConstant %7 0
%67 = OpConstant %7 2
%71 = OpConstant %7 5
%73 = OpConstant %7 6
%75 = OpConstant %7 9
%77 = OpConstant %7 10
%86 = OpConstant %7 7
%88 = OpConstant %7 8
%47 = OpFunction %12 None %46
%49 = OpFunctionParameter %32
%50 = OpLabel
%51 = OpVariable %40 Function %48
%53 = OpRayQueryGetIntersectionTypeKHR %7 %49 %52
%54 = OpIEqual %10 %53 %52
%55 = OpSelect %7 %54 %30 %28
%56 = OpAccessChain %41 %51 %52
OpStore %56 %55
%57 = OpINotEqual %10 %55 %52
OpSelectionMerge %59 None
OpBranchConditional %57 %58 %59
%58 = OpLabel
%60 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %7 %49 %52
%61 = OpRayQueryGetIntersectionInstanceIdKHR %7 %49 %52
%62 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %7 %49 %52
%63 = OpRayQueryGetIntersectionGeometryIndexKHR %7 %49 %52
%64 = OpRayQueryGetIntersectionPrimitiveIndexKHR %7 %49 %52
%65 = OpRayQueryGetIntersectionObjectToWorldKHR %11 %49 %52
%66 = OpRayQueryGetIntersectionWorldToObjectKHR %11 %49 %52
%68 = OpAccessChain %41 %51 %67
OpStore %68 %60
%69 = OpAccessChain %41 %51 %28
OpStore %69 %61
%70 = OpAccessChain %41 %51 %23
OpStore %70 %62
%72 = OpAccessChain %41 %51 %71
OpStore %72 %63
%74 = OpAccessChain %41 %51 %73
OpStore %74 %64
%76 = OpAccessChain %42 %51 %75
OpStore %76 %65
%78 = OpAccessChain %42 %51 %77
OpStore %78 %66
%79 = OpIEqual %10 %55 %30
OpSelectionMerge %81 None
OpBranchConditional %57 %80 %81
%80 = OpLabel
%82 = OpRayQueryGetIntersectionTKHR %5 %49 %52
%83 = OpAccessChain %45 %51 %30
OpStore %83 %82
%84 = OpRayQueryGetIntersectionBarycentricsKHR %9 %49 %52
%85 = OpRayQueryGetIntersectionFrontFaceKHR %10 %49 %52
%87 = OpAccessChain %43 %51 %86
OpStore %87 %84
%89 = OpAccessChain %44 %51 %88
OpStore %89 %85
OpBranch %81
%81 = OpLabel
OpBranch %59
%59 = OpLabel
%90 = OpLoad %12 %51
OpReturnValue %90
OpFunctionEnd
%16 = OpFunction %2 None %17
%15 = OpLabel
%31 = OpVariable %32 Function
%18 = OpLoad %4 %13
OpBranch %33
%33 = OpLabel
%34 = OpCompositeExtract %7 %27 0
%35 = OpCompositeExtract %7 %27 1
%36 = OpCompositeExtract %5 %27 2
%37 = OpCompositeExtract %5 %27 3
%38 = OpCompositeExtract %6 %27 4
%39 = OpCompositeExtract %6 %27 5
OpRayQueryInitializeKHR %31 %18 %34 %35 %38 %36 %39 %37
%91 = OpFunctionCall %12 %47 %31
%92 = OpCompositeExtract %7 %91 0
%93 = OpIEqual %10 %92 %28
OpSelectionMerge %94 None
OpBranchConditional %93 %95 %96
%95 = OpLabel
OpRayQueryGenerateIntersectionKHR %31 %29
OpReturn
%96 = OpLabel
%97 = OpCompositeExtract %7 %91 0
%98 = OpIEqual %10 %97 %30
OpSelectionMerge %99 None
OpBranchConditional %98 %100 %101
%100 = OpLabel
OpRayQueryConfirmIntersectionKHR %31
OpReturn
%101 = OpLabel
OpReturn
%99 = OpLabel
OpBranch %94
%94 = OpLabel
OpReturn
OpFunctionEnd