Disallow taking the address of a vector component (#7284)

This commit is contained in:
Andy Leiserson 2025-03-28 03:32:06 -07:00 committed by GitHub
parent efbfa36ded
commit c7c79a0dc9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 182 additions and 112 deletions

View File

@ -218,6 +218,8 @@ By @wumpf in [#7144](https://github.com/gfx-rs/wgpu/pull/7144)
The original names (e.g. `naga::Module`) remain present for compatibility. The original names (e.g. `naga::Module`) remain present for compatibility.
By @kpreid in [#7365](https://github.com/gfx-rs/wgpu/pull/7365). By @kpreid in [#7365](https://github.com/gfx-rs/wgpu/pull/7365).
- Refactored `use` statements to simplify future `no_std` support. By @bushrat011899 in [#7256](https://github.com/gfx-rs/wgpu/pull/7256) - Refactored `use` statements to simplify future `no_std` support. By @bushrat011899 in [#7256](https://github.com/gfx-rs/wgpu/pull/7256)
- Naga's WGSL frontend no longer allows using the `&` operator to take the address of a component of a vector,
which is not permitted by the WGSL specification. By @andyleiserson in [#7284](https://github.com/gfx-rs/wgpu/pull/7284)
#### Vulkan #### Vulkan

View File

@ -215,6 +215,7 @@ pub(crate) enum Error<'a> {
}, },
DeclMissingTypeAndInit(Span), DeclMissingTypeAndInit(Span),
MissingAttribute(&'static str, Span), MissingAttribute(&'static str, Span),
InvalidAddrOfOperand(Span),
InvalidAtomicPointer(Span), InvalidAtomicPointer(Span),
InvalidAtomicOperandType(Span), InvalidAtomicOperandType(Span),
InvalidRayQueryPointer(Span), InvalidRayQueryPointer(Span),
@ -676,6 +677,11 @@ impl<'a> Error<'a> {
)], )],
notes: vec![], notes: vec![],
}, },
Error::InvalidAddrOfOperand(span) => ParseError {
message: "cannot take the address of a vector component".to_string(),
labels: vec![(span, "invalid operand for address-of".into())],
notes: vec![],
},
Error::InvalidAtomicPointer(span) => ParseError { Error::InvalidAtomicPointer(span) => ParseError {
message: "atomic operation is done on a pointer to a non-atomic".to_string(), message: "atomic operation is done on a pointer to a non-atomic".to_string(),
labels: vec![(span, "atomic pointer is invalid".into())], labels: vec![(span, "atomic pointer is invalid".into())],

View File

@ -2077,6 +2077,21 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// reference is required, the Load Rule is not applied. // reference is required, the Load Rule is not applied.
match self.expression_for_reference(expr, ctx)? { match self.expression_for_reference(expr, ctx)? {
Typed::Reference(handle) => { Typed::Reference(handle) => {
let expr = &ctx.runtime_expression_ctx(span)?.function.expressions[handle];
if let &crate::Expression::Access { base, .. }
| &crate::Expression::AccessIndex { base, .. } = expr
{
if let Some(ty) = resolve_inner!(ctx, base).pointer_base_type() {
if matches!(
*ty.inner_with(&ctx.module.types),
crate::TypeInner::Vector { .. },
) {
return Err(Box::new(Error::InvalidAddrOfOperand(
ctx.get_expression_span(handle),
)));
}
}
}
// No code is generated. We just declare the reference a pointer now. // No code is generated. We just declare the reference a pointer now.
return Ok(Typed::Plain(handle)); return Ok(Typed::Plain(handle));
} }
@ -2149,30 +2164,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
} }
} }
let temp_inner; let temp_ty;
let composite_type: &crate::TypeInner = match lowered_base { let composite_type: &crate::TypeInner = match lowered_base {
Typed::Reference(handle) => { Typed::Reference(handle) => {
let inner = resolve_inner!(ctx, handle); temp_ty = resolve_inner!(ctx, handle)
match *inner { .pointer_base_type()
crate::TypeInner::Pointer { base, .. } => &ctx.module.types[base].inner, .expect("In Typed::Reference(handle), handle must be a Naga pointer");
crate::TypeInner::ValuePointer { temp_ty.inner_with(&ctx.module.types)
size: None, scalar, ..
} => {
temp_inner = crate::TypeInner::Scalar(scalar);
&temp_inner
}
crate::TypeInner::ValuePointer {
size: Some(size),
scalar,
..
} => {
temp_inner = crate::TypeInner::Vector { size, scalar };
&temp_inner
}
_ => unreachable!(
"In Typed::Reference(handle), handle must be a Naga pointer"
),
}
} }
Typed::Plain(handle) => { Typed::Plain(handle) => {

View File

@ -128,6 +128,25 @@ impl crate::TypeInner {
} }
} }
/// If `self` is a pointer type, return its base type.
pub const fn pointer_base_type(&self) -> Option<TypeResolution> {
match *self {
crate::TypeInner::Pointer { base, .. } => Some(TypeResolution::Handle(base)),
crate::TypeInner::ValuePointer {
size: None, scalar, ..
} => Some(TypeResolution::Value(crate::TypeInner::Scalar(scalar))),
crate::TypeInner::ValuePointer {
size: Some(size),
scalar,
..
} => Some(TypeResolution::Value(crate::TypeInner::Vector {
size,
scalar,
})),
_ => None,
}
}
pub fn is_atomic_pointer(&self, types: &crate::UniqueArena<crate::Type>) -> bool { pub fn is_atomic_pointer(&self, types: &crate::UniqueArena<crate::Type>) -> bool {
match *self { match *self {
crate::TypeInner::Pointer { base, .. } => match types[base].inner { crate::TypeInner::Pointer { base, .. } => match types[base].inner {

View File

@ -7,6 +7,7 @@ use super::{
}; };
use crate::arena::{Arena, UniqueArena}; use crate::arena::{Arena, UniqueArena};
use crate::arena::{Handle, HandleSet}; use crate::arena::{Handle, HandleSet};
use crate::proc::TypeResolution;
use crate::span::WithSpan; use crate::span::WithSpan;
use crate::span::{AddSpan as _, MapErrWithSpan as _}; use crate::span::{AddSpan as _, MapErrWithSpan as _};
@ -299,6 +300,10 @@ impl<'a> BlockContext<'a> {
fn resolve_pointer_type(&self, handle: Handle<crate::Expression>) -> &crate::TypeInner { fn resolve_pointer_type(&self, handle: Handle<crate::Expression>) -> &crate::TypeInner {
self.info[handle].ty.inner_with(self.types) self.info[handle].ty.inner_with(self.types)
} }
fn inner_type<'t>(&'t self, ty: &'t TypeResolution) -> &'t crate::TypeInner {
ty.inner_with(self.types)
}
} }
impl super::Validator { impl super::Validator {
@ -1039,23 +1044,15 @@ impl super::Validator {
} }
let pointer_ty = context.resolve_pointer_type(pointer); let pointer_ty = context.resolve_pointer_type(pointer);
let good = match pointer_ty
let good = match *pointer_ty { .pointer_base_type()
Ti::Pointer { base, space: _ } => match context.types[base].inner { .as_ref()
Ti::Atomic(scalar) => *value_ty == Ti::Scalar(scalar), .map(|ty| context.inner_type(ty))
ref other => value_ty == other, {
}, // The Naga IR allows storing a scalar to an atomic.
Ti::ValuePointer { Some(&Ti::Atomic(ref scalar)) => *value_ty == Ti::Scalar(*scalar),
size: Some(size), Some(other) => *value_ty == *other,
scalar, None => false,
space: _,
} => *value_ty == Ti::Vector { size, scalar },
Ti::ValuePointer {
size: None,
scalar,
space: _,
} => *value_ty == Ti::Scalar(scalar),
_ => false,
}; };
if !good { if !good {
return Err(FunctionError::InvalidStoreTypes { pointer, value } return Err(FunctionError::InvalidStoreTypes { pointer, value }

View File

@ -1,7 +1,7 @@
fn f() { fn f() {
var v: vec2<i32>; var v: mat2x2<f32>;
let px = &v.x; let px = &v[0];
*px = 10; *px = vec2<f32>(10.0);
} }
struct DynamicArray { struct DynamicArray {

View File

@ -1,7 +1,7 @@
; SPIR-V ; SPIR-V
; Version: 1.2 ; Version: 1.2
; Generator: rspirv ; Generator: rspirv
; Bound: 43 ; Bound: 46
OpCapability Shader OpCapability Shader
OpCapability Linkage OpCapability Linkage
OpExtension "SPV_KHR_storage_buffer_storage_class" OpExtension "SPV_KHR_storage_buffer_storage_class"
@ -9,9 +9,9 @@ OpExtension "SPV_KHR_storage_buffer_storage_class"
OpMemoryModel Logical GLSL450 OpMemoryModel Logical GLSL450
%3 = OpString "pointers.wgsl" %3 = OpString "pointers.wgsl"
OpSource Unknown 0 %3 "fn f() { OpSource Unknown 0 %3 "fn f() {
var v: vec2<i32>; var v: mat2x2<f32>;
let px = &v.x; let px = &v[0];
*px = 10; *px = vec2<f32>(10.0);
} }
struct DynamicArray { struct DynamicArray {
@ -35,80 +35,84 @@ fn index_dynamic_array(i: i32, v: u32) {
(*p)[i] = val + v; (*p)[i] = val + v;
} }
" "
OpMemberName %8 0 "arr" OpMemberName %9 0 "arr"
OpName %8 "DynamicArray" OpName %9 "DynamicArray"
OpName %9 "dynamic_array" OpName %11 "dynamic_array"
OpName %12 "f" OpName %14 "f"
OpName %15 "v" OpName %18 "v"
OpName %23 "i" OpName %26 "i"
OpName %24 "v" OpName %27 "v"
OpName %25 "index_unsized" OpName %28 "index_unsized"
OpName %35 "i" OpName %38 "i"
OpName %36 "v" OpName %39 "v"
OpName %37 "index_dynamic_array" OpName %40 "index_dynamic_array"
OpDecorate %7 ArrayStride 4 OpDecorate %8 ArrayStride 4
OpMemberDecorate %8 0 Offset 0 OpMemberDecorate %9 0 Offset 0
OpDecorate %8 Block OpDecorate %9 Block
OpDecorate %9 DescriptorSet 0 OpDecorate %11 DescriptorSet 0
OpDecorate %9 Binding 0 OpDecorate %11 Binding 0
%2 = OpTypeVoid %2 = OpTypeVoid
%4 = OpTypeInt 32 1 %6 = OpTypeFloat 32
%5 = OpTypeVector %4 2 %5 = OpTypeVector %6 2
%6 = OpTypeInt 32 0 %4 = OpTypeMatrix %5 2
%7 = OpTypeRuntimeArray %6 %7 = OpTypeInt 32 0
%8 = OpTypeStruct %7 %8 = OpTypeRuntimeArray %7
%10 = OpTypePointer StorageBuffer %8 %9 = OpTypeStruct %8
%9 = OpVariable %10 StorageBuffer %10 = OpTypeInt 32 1
%13 = OpTypeFunction %2 %12 = OpTypePointer StorageBuffer %9
%14 = OpConstant %4 10 %11 = OpVariable %12 StorageBuffer
%16 = OpTypePointer Function %5 %15 = OpTypeFunction %2
%17 = OpConstantNull %5 %16 = OpConstant %6 10.0
%17 = OpConstantComposite %5 %16 %16
%19 = OpTypePointer Function %4 %19 = OpTypePointer Function %4
%20 = OpConstant %6 0 %20 = OpConstantNull %4
%26 = OpTypeFunction %2 %4 %6 %22 = OpTypePointer Function %5
%28 = OpTypePointer StorageBuffer %7 %23 = OpConstant %7 0
%29 = OpTypePointer StorageBuffer %6 %29 = OpTypeFunction %2 %10 %7
%12 = OpFunction %2 None %13 %31 = OpTypePointer StorageBuffer %8
%11 = OpLabel %32 = OpTypePointer StorageBuffer %7
%15 = OpVariable %16 Function %17 %14 = OpFunction %2 None %15
OpBranch %18 %13 = OpLabel
%18 = OpLabel %18 = OpVariable %19 Function %20
OpBranch %21
%21 = OpLabel
OpLine %3 3 14 OpLine %3 3 14
OpLine %3 4 10
OpLine %3 4 4 OpLine %3 4 4
%21 = OpAccessChain %19 %15 %20 %24 = OpAccessChain %22 %18 %23
OpStore %21 %14 OpStore %24 %17
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd
%25 = OpFunction %2 None %26 %28 = OpFunction %2 None %29
%23 = OpFunctionParameter %4 %26 = OpFunctionParameter %10
%24 = OpFunctionParameter %6 %27 = OpFunctionParameter %7
%22 = OpLabel %25 = OpLabel
OpBranch %27 OpBranch %30
%27 = OpLabel %30 = OpLabel
OpLine %3 17 14 OpLine %3 17 14
%30 = OpAccessChain %29 %9 %20 %23 %33 = OpAccessChain %32 %11 %23 %26
%31 = OpLoad %6 %30 %34 = OpLoad %7 %33
OpLine %3 18 4 OpLine %3 18 4
%32 = OpIAdd %6 %31 %24 %35 = OpIAdd %7 %34 %27
OpLine %3 18 4 OpLine %3 18 4
%33 = OpAccessChain %29 %9 %20 %23 %36 = OpAccessChain %32 %11 %23 %26
OpStore %33 %32 OpStore %36 %35
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd
%37 = OpFunction %2 None %26 %40 = OpFunction %2 None %29
%35 = OpFunctionParameter %4 %38 = OpFunctionParameter %10
%36 = OpFunctionParameter %6 %39 = OpFunctionParameter %7
%34 = OpLabel %37 = OpLabel
OpBranch %38 OpBranch %41
%38 = OpLabel %41 = OpLabel
OpLine %3 22 51 OpLine %3 22 51
OpLine %3 24 14 OpLine %3 24 14
%39 = OpAccessChain %29 %9 %20 %35 %42 = OpAccessChain %32 %11 %23 %38
%40 = OpLoad %6 %39 %43 = OpLoad %7 %42
OpLine %3 25 4 OpLine %3 25 4
%41 = OpIAdd %6 %40 %36 %44 = OpIAdd %7 %43 %39
OpLine %3 25 4 OpLine %3 25 4
%42 = OpAccessChain %29 %9 %20 %35 %45 = OpAccessChain %32 %11 %23 %38
OpStore %42 %41 OpStore %45 %44
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd

View File

@ -6,10 +6,10 @@ struct DynamicArray {
var<storage, read_write> dynamic_array: DynamicArray; var<storage, read_write> dynamic_array: DynamicArray;
fn f() { fn f() {
var v: vec2<i32>; var v: mat2x2<f32>;
let px = (&v.x); let px = (&v[0]);
(*px) = 10i; (*px) = vec2(10f);
return; return;
} }

View File

@ -1368,14 +1368,13 @@ fn invalid_return_type() {
fn pointer_type_equivalence() { fn pointer_type_equivalence() {
check_validation! { check_validation! {
r#" r#"
fn f(pv: ptr<function, vec2<f32>>, pf: ptr<function, f32>) { } fn f(pv: ptr<function, vec2<f32>>) { }
fn g() { fn g() {
var m: mat2x2<f32>; var m: mat2x2<f32>;
let pv: ptr<function, vec2<f32>> = &m[0]; let pv: ptr<function, vec2<f32>> = &m[0];
let pf: ptr<function, f32> = &m[0].x;
f(pv, pf); f(pv);
} }
"#: "#:
Ok(_) Ok(_)
@ -3061,6 +3060,51 @@ fn reject_utf8_bom() {
); );
} }
#[test]
fn matrix_vector_pointers() {
check(
"fn foo() {
var v: vec2<f32>;
let p = &v[0];
}",
r#"error: cannot take the address of a vector component
wgsl:3:22
3 let p = &v[0];
^^^^ invalid operand for address-of
"#,
);
check(
"fn foo() {
var v: vec2<f32>;
let p = &v.x;
}",
r#"error: cannot take the address of a vector component
wgsl:3:22
3 let p = &v.x;
^^^ invalid operand for address-of
"#,
);
check(
"fn foo() {
var m: mat2x2<f32>;
let p = &m[0][0];
}",
r#"error: cannot take the address of a vector component
wgsl:3:22
3 let p = &m[0][0];
^^^^^^^ invalid operand for address-of
"#,
);
}
#[test] #[test]
fn issue7165() { fn issue7165() {
// Regression test for https://github.com/gfx-rs/wgpu/issues/7165 // Regression test for https://github.com/gfx-rs/wgpu/issues/7165