Disallow taking the address of a vector component (#7284)

This commit is contained in:
Andy Leiserson 2025-03-28 03:32:06 -07:00 committed by GitHub
parent efbfa36ded
commit c7c79a0dc9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 182 additions and 112 deletions

View File

@ -218,6 +218,8 @@ By @wumpf in [#7144](https://github.com/gfx-rs/wgpu/pull/7144)
The original names (e.g. `naga::Module`) remain present for compatibility.
By @kpreid in [#7365](https://github.com/gfx-rs/wgpu/pull/7365).
- Refactored `use` statements to simplify future `no_std` support. By @bushrat011899 in [#7256](https://github.com/gfx-rs/wgpu/pull/7256)
- Naga's WGSL frontend no longer allows using the `&` operator to take the address of a component of a vector,
which is not permitted by the WGSL specification. By @andyleiserson in [#7284](https://github.com/gfx-rs/wgpu/pull/7284)
#### Vulkan

View File

@ -215,6 +215,7 @@ pub(crate) enum Error<'a> {
},
DeclMissingTypeAndInit(Span),
MissingAttribute(&'static str, Span),
InvalidAddrOfOperand(Span),
InvalidAtomicPointer(Span),
InvalidAtomicOperandType(Span),
InvalidRayQueryPointer(Span),
@ -676,6 +677,11 @@ impl<'a> Error<'a> {
)],
notes: vec![],
},
Error::InvalidAddrOfOperand(span) => ParseError {
message: "cannot take the address of a vector component".to_string(),
labels: vec![(span, "invalid operand for address-of".into())],
notes: vec![],
},
Error::InvalidAtomicPointer(span) => ParseError {
message: "atomic operation is done on a pointer to a non-atomic".to_string(),
labels: vec![(span, "atomic pointer is invalid".into())],

View File

@ -2077,6 +2077,21 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// reference is required, the Load Rule is not applied.
match self.expression_for_reference(expr, ctx)? {
Typed::Reference(handle) => {
let expr = &ctx.runtime_expression_ctx(span)?.function.expressions[handle];
if let &crate::Expression::Access { base, .. }
| &crate::Expression::AccessIndex { base, .. } = expr
{
if let Some(ty) = resolve_inner!(ctx, base).pointer_base_type() {
if matches!(
*ty.inner_with(&ctx.module.types),
crate::TypeInner::Vector { .. },
) {
return Err(Box::new(Error::InvalidAddrOfOperand(
ctx.get_expression_span(handle),
)));
}
}
}
// No code is generated. We just declare the reference a pointer now.
return Ok(Typed::Plain(handle));
}
@ -2149,30 +2164,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
}
let temp_inner;
let temp_ty;
let composite_type: &crate::TypeInner = match lowered_base {
Typed::Reference(handle) => {
let inner = resolve_inner!(ctx, handle);
match *inner {
crate::TypeInner::Pointer { base, .. } => &ctx.module.types[base].inner,
crate::TypeInner::ValuePointer {
size: None, scalar, ..
} => {
temp_inner = crate::TypeInner::Scalar(scalar);
&temp_inner
}
crate::TypeInner::ValuePointer {
size: Some(size),
scalar,
..
} => {
temp_inner = crate::TypeInner::Vector { size, scalar };
&temp_inner
}
_ => unreachable!(
"In Typed::Reference(handle), handle must be a Naga pointer"
),
}
temp_ty = resolve_inner!(ctx, handle)
.pointer_base_type()
.expect("In Typed::Reference(handle), handle must be a Naga pointer");
temp_ty.inner_with(&ctx.module.types)
}
Typed::Plain(handle) => {

View File

@ -128,6 +128,25 @@ impl crate::TypeInner {
}
}
/// If `self` is a pointer type, return its base type.
pub const fn pointer_base_type(&self) -> Option<TypeResolution> {
match *self {
crate::TypeInner::Pointer { base, .. } => Some(TypeResolution::Handle(base)),
crate::TypeInner::ValuePointer {
size: None, scalar, ..
} => Some(TypeResolution::Value(crate::TypeInner::Scalar(scalar))),
crate::TypeInner::ValuePointer {
size: Some(size),
scalar,
..
} => Some(TypeResolution::Value(crate::TypeInner::Vector {
size,
scalar,
})),
_ => None,
}
}
pub fn is_atomic_pointer(&self, types: &crate::UniqueArena<crate::Type>) -> bool {
match *self {
crate::TypeInner::Pointer { base, .. } => match types[base].inner {

View File

@ -7,6 +7,7 @@ use super::{
};
use crate::arena::{Arena, UniqueArena};
use crate::arena::{Handle, HandleSet};
use crate::proc::TypeResolution;
use crate::span::WithSpan;
use crate::span::{AddSpan as _, MapErrWithSpan as _};
@ -299,6 +300,10 @@ impl<'a> BlockContext<'a> {
fn resolve_pointer_type(&self, handle: Handle<crate::Expression>) -> &crate::TypeInner {
self.info[handle].ty.inner_with(self.types)
}
fn inner_type<'t>(&'t self, ty: &'t TypeResolution) -> &'t crate::TypeInner {
ty.inner_with(self.types)
}
}
impl super::Validator {
@ -1039,23 +1044,15 @@ impl super::Validator {
}
let pointer_ty = context.resolve_pointer_type(pointer);
let good = match *pointer_ty {
Ti::Pointer { base, space: _ } => match context.types[base].inner {
Ti::Atomic(scalar) => *value_ty == Ti::Scalar(scalar),
ref other => value_ty == other,
},
Ti::ValuePointer {
size: Some(size),
scalar,
space: _,
} => *value_ty == Ti::Vector { size, scalar },
Ti::ValuePointer {
size: None,
scalar,
space: _,
} => *value_ty == Ti::Scalar(scalar),
_ => false,
let good = match pointer_ty
.pointer_base_type()
.as_ref()
.map(|ty| context.inner_type(ty))
{
// The Naga IR allows storing a scalar to an atomic.
Some(&Ti::Atomic(ref scalar)) => *value_ty == Ti::Scalar(*scalar),
Some(other) => *value_ty == *other,
None => false,
};
if !good {
return Err(FunctionError::InvalidStoreTypes { pointer, value }

View File

@ -1,7 +1,7 @@
fn f() {
var v: vec2<i32>;
let px = &v.x;
*px = 10;
var v: mat2x2<f32>;
let px = &v[0];
*px = vec2<f32>(10.0);
}
struct DynamicArray {

View File

@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.2
; Generator: rspirv
; Bound: 43
; Bound: 46
OpCapability Shader
OpCapability Linkage
OpExtension "SPV_KHR_storage_buffer_storage_class"
@ -9,9 +9,9 @@ OpExtension "SPV_KHR_storage_buffer_storage_class"
OpMemoryModel Logical GLSL450
%3 = OpString "pointers.wgsl"
OpSource Unknown 0 %3 "fn f() {
var v: vec2<i32>;
let px = &v.x;
*px = 10;
var v: mat2x2<f32>;
let px = &v[0];
*px = vec2<f32>(10.0);
}
struct DynamicArray {
@ -35,80 +35,84 @@ fn index_dynamic_array(i: i32, v: u32) {
(*p)[i] = val + v;
}
"
OpMemberName %8 0 "arr"
OpName %8 "DynamicArray"
OpName %9 "dynamic_array"
OpName %12 "f"
OpName %15 "v"
OpName %23 "i"
OpName %24 "v"
OpName %25 "index_unsized"
OpName %35 "i"
OpName %36 "v"
OpName %37 "index_dynamic_array"
OpDecorate %7 ArrayStride 4
OpMemberDecorate %8 0 Offset 0
OpDecorate %8 Block
OpDecorate %9 DescriptorSet 0
OpDecorate %9 Binding 0
OpMemberName %9 0 "arr"
OpName %9 "DynamicArray"
OpName %11 "dynamic_array"
OpName %14 "f"
OpName %18 "v"
OpName %26 "i"
OpName %27 "v"
OpName %28 "index_unsized"
OpName %38 "i"
OpName %39 "v"
OpName %40 "index_dynamic_array"
OpDecorate %8 ArrayStride 4
OpMemberDecorate %9 0 Offset 0
OpDecorate %9 Block
OpDecorate %11 DescriptorSet 0
OpDecorate %11 Binding 0
%2 = OpTypeVoid
%4 = OpTypeInt 32 1
%5 = OpTypeVector %4 2
%6 = OpTypeInt 32 0
%7 = OpTypeRuntimeArray %6
%8 = OpTypeStruct %7
%10 = OpTypePointer StorageBuffer %8
%9 = OpVariable %10 StorageBuffer
%13 = OpTypeFunction %2
%14 = OpConstant %4 10
%16 = OpTypePointer Function %5
%17 = OpConstantNull %5
%6 = OpTypeFloat 32
%5 = OpTypeVector %6 2
%4 = OpTypeMatrix %5 2
%7 = OpTypeInt 32 0
%8 = OpTypeRuntimeArray %7
%9 = OpTypeStruct %8
%10 = OpTypeInt 32 1
%12 = OpTypePointer StorageBuffer %9
%11 = OpVariable %12 StorageBuffer
%15 = OpTypeFunction %2
%16 = OpConstant %6 10.0
%17 = OpConstantComposite %5 %16 %16
%19 = OpTypePointer Function %4
%20 = OpConstant %6 0
%26 = OpTypeFunction %2 %4 %6
%28 = OpTypePointer StorageBuffer %7
%29 = OpTypePointer StorageBuffer %6
%12 = OpFunction %2 None %13
%11 = OpLabel
%15 = OpVariable %16 Function %17
OpBranch %18
%18 = OpLabel
%20 = OpConstantNull %4
%22 = OpTypePointer Function %5
%23 = OpConstant %7 0
%29 = OpTypeFunction %2 %10 %7
%31 = OpTypePointer StorageBuffer %8
%32 = OpTypePointer StorageBuffer %7
%14 = OpFunction %2 None %15
%13 = OpLabel
%18 = OpVariable %19 Function %20
OpBranch %21
%21 = OpLabel
OpLine %3 3 14
OpLine %3 4 10
OpLine %3 4 4
%21 = OpAccessChain %19 %15 %20
OpStore %21 %14
%24 = OpAccessChain %22 %18 %23
OpStore %24 %17
OpReturn
OpFunctionEnd
%25 = OpFunction %2 None %26
%23 = OpFunctionParameter %4
%24 = OpFunctionParameter %6
%22 = OpLabel
OpBranch %27
%27 = OpLabel
%28 = OpFunction %2 None %29
%26 = OpFunctionParameter %10
%27 = OpFunctionParameter %7
%25 = OpLabel
OpBranch %30
%30 = OpLabel
OpLine %3 17 14
%30 = OpAccessChain %29 %9 %20 %23
%31 = OpLoad %6 %30
%33 = OpAccessChain %32 %11 %23 %26
%34 = OpLoad %7 %33
OpLine %3 18 4
%32 = OpIAdd %6 %31 %24
%35 = OpIAdd %7 %34 %27
OpLine %3 18 4
%33 = OpAccessChain %29 %9 %20 %23
OpStore %33 %32
%36 = OpAccessChain %32 %11 %23 %26
OpStore %36 %35
OpReturn
OpFunctionEnd
%37 = OpFunction %2 None %26
%35 = OpFunctionParameter %4
%36 = OpFunctionParameter %6
%34 = OpLabel
OpBranch %38
%38 = OpLabel
%40 = OpFunction %2 None %29
%38 = OpFunctionParameter %10
%39 = OpFunctionParameter %7
%37 = OpLabel
OpBranch %41
%41 = OpLabel
OpLine %3 22 51
OpLine %3 24 14
%39 = OpAccessChain %29 %9 %20 %35
%40 = OpLoad %6 %39
%42 = OpAccessChain %32 %11 %23 %38
%43 = OpLoad %7 %42
OpLine %3 25 4
%41 = OpIAdd %6 %40 %36
%44 = OpIAdd %7 %43 %39
OpLine %3 25 4
%42 = OpAccessChain %29 %9 %20 %35
OpStore %42 %41
%45 = OpAccessChain %32 %11 %23 %38
OpStore %45 %44
OpReturn
OpFunctionEnd

View File

@ -6,10 +6,10 @@ struct DynamicArray {
var<storage, read_write> dynamic_array: DynamicArray;
fn f() {
var v: vec2<i32>;
var v: mat2x2<f32>;
let px = (&v.x);
(*px) = 10i;
let px = (&v[0]);
(*px) = vec2(10f);
return;
}

View File

@ -1368,14 +1368,13 @@ fn invalid_return_type() {
fn pointer_type_equivalence() {
check_validation! {
r#"
fn f(pv: ptr<function, vec2<f32>>, pf: ptr<function, f32>) { }
fn f(pv: ptr<function, vec2<f32>>) { }
fn g() {
var m: mat2x2<f32>;
let pv: ptr<function, vec2<f32>> = &m[0];
let pf: ptr<function, f32> = &m[0].x;
f(pv, pf);
f(pv);
}
"#:
Ok(_)
@ -3061,6 +3060,51 @@ fn reject_utf8_bom() {
);
}
#[test]
fn matrix_vector_pointers() {
check(
"fn foo() {
var v: vec2<f32>;
let p = &v[0];
}",
r#"error: cannot take the address of a vector component
wgsl:3:22
3 let p = &v[0];
^^^^ invalid operand for address-of
"#,
);
check(
"fn foo() {
var v: vec2<f32>;
let p = &v.x;
}",
r#"error: cannot take the address of a vector component
wgsl:3:22
3 let p = &v.x;
^^^ invalid operand for address-of
"#,
);
check(
"fn foo() {
var m: mat2x2<f32>;
let p = &m[0][0];
}",
r#"error: cannot take the address of a vector component
wgsl:3:22
3 let p = &m[0][0];
^^^^^^^ invalid operand for address-of
"#,
);
}
#[test]
fn issue7165() {
// Regression test for https://github.com/gfx-rs/wgpu/issues/7165