[naga] Refactor BlockContext type resolution methods

Change `resolve_type` and `resolve_type_impl` to return
`TypeResolution`s. Add a new method `resolve_type_inner` that returns a
`TypeInner` (i.e. what `resolve_type` used to do).
This commit is contained in:
Andy Leiserson 2025-03-28 12:43:09 -07:00 committed by Teodor Tanasoaia
parent 14b5838a00
commit 19429a1dc9

View File

@ -280,11 +280,11 @@ impl<'a> BlockContext<'a> {
&self, &self,
handle: Handle<crate::Expression>, handle: Handle<crate::Expression>,
valid_expressions: &HandleSet<crate::Expression>, valid_expressions: &HandleSet<crate::Expression>,
) -> Result<&crate::TypeInner, WithSpan<ExpressionError>> { ) -> Result<&TypeResolution, WithSpan<ExpressionError>> {
if !valid_expressions.contains(handle) { if !valid_expressions.contains(handle) {
Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions)) Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions))
} else { } else {
Ok(self.info[handle].ty.inner_with(self.types)) Ok(&self.info[handle].ty)
} }
} }
@ -292,11 +292,20 @@ impl<'a> BlockContext<'a> {
&self, &self,
handle: Handle<crate::Expression>, handle: Handle<crate::Expression>,
valid_expressions: &HandleSet<crate::Expression>, valid_expressions: &HandleSet<crate::Expression>,
) -> Result<&crate::TypeInner, WithSpan<FunctionError>> { ) -> Result<&TypeResolution, WithSpan<FunctionError>> {
self.resolve_type_impl(handle, valid_expressions) self.resolve_type_impl(handle, valid_expressions)
.map_err_inner(|source| FunctionError::Expression { handle, source }.with_span()) .map_err_inner(|source| FunctionError::Expression { handle, source }.with_span())
} }
fn resolve_type_inner(
&self,
handle: Handle<crate::Expression>,
valid_expressions: &HandleSet<crate::Expression>,
) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
self.resolve_type(handle, valid_expressions)
.map(|tr| tr.inner_with(self.types))
}
fn resolve_pointer_type(&self, handle: Handle<crate::Expression>) -> &crate::TypeInner { fn resolve_pointer_type(&self, handle: Handle<crate::Expression>) -> &crate::TypeInner {
self.info[handle].ty.inner_with(self.types) self.info[handle].ty.inner_with(self.types)
} }
@ -330,7 +339,7 @@ impl super::Validator {
.with_span_handle(expr, context.expressions) .with_span_handle(expr, context.expressions)
})?; })?;
let arg_inner = &context.types[arg.ty].inner; let arg_inner = &context.types[arg.ty].inner;
if !ty.non_struct_equivalent(arg_inner, context.types) { if !ty.inner_with(context.types).non_struct_equivalent(arg_inner, context.types) {
return Err(CallError::ArgumentType { return Err(CallError::ArgumentType {
index, index,
required: arg.ty, required: arg.ty,
@ -393,7 +402,7 @@ impl super::Validator {
context: &BlockContext, context: &BlockContext,
) -> Result<(), WithSpan<FunctionError>> { ) -> Result<(), WithSpan<FunctionError>> {
// The `pointer` operand must be a pointer to an atomic value. // The `pointer` operand must be a pointer to an atomic value.
let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?; let pointer_inner = context.resolve_type_inner(pointer, &self.valid_expression_set)?;
let crate::TypeInner::Pointer { let crate::TypeInner::Pointer {
base: pointer_base, base: pointer_base,
space: pointer_space, space: pointer_space,
@ -415,7 +424,7 @@ impl super::Validator {
}; };
// The `value` operand must be a scalar of the same type as the atomic. // The `value` operand must be a scalar of the same type as the atomic.
let value_inner = context.resolve_type(value, &self.valid_expression_set)?; let value_inner = context.resolve_type_inner(value, &self.valid_expression_set)?;
let crate::TypeInner::Scalar(value_scalar) = *value_inner else { let crate::TypeInner::Scalar(value_scalar) = *value_inner else {
log::error!("Atomic operand type {:?}", *value_inner); log::error!("Atomic operand type {:?}", *value_inner);
return Err(AtomicError::InvalidOperand(value) return Err(AtomicError::InvalidOperand(value)
@ -543,7 +552,7 @@ impl super::Validator {
// The comparison value must be a scalar of the same type as the // The comparison value must be a scalar of the same type as the
// atomic we're operating on. // atomic we're operating on.
let compare_inner = let compare_inner =
context.resolve_type(compare, &self.valid_expression_set)?; context.resolve_type_inner(compare, &self.valid_expression_set)?;
if !compare_inner.non_struct_equivalent(value_inner, context.types) { if !compare_inner.non_struct_equivalent(value_inner, context.types) {
log::error!( log::error!(
"Atomic exchange comparison has a different type from the value" "Atomic exchange comparison has a different type from the value"
@ -620,7 +629,7 @@ impl super::Validator {
result: Handle<crate::Expression>, result: Handle<crate::Expression>,
context: &BlockContext, context: &BlockContext,
) -> Result<(), WithSpan<FunctionError>> { ) -> Result<(), WithSpan<FunctionError>> {
let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?;
let (is_scalar, scalar) = match *argument_inner { let (is_scalar, scalar) = match *argument_inner {
crate::TypeInner::Scalar(scalar) => (true, scalar), crate::TypeInner::Scalar(scalar) => (true, scalar),
@ -695,7 +704,7 @@ impl super::Validator {
| crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index) | crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => { | crate::GatherMode::ShuffleXor(index) => {
let index_ty = context.resolve_type(index, &self.valid_expression_set)?; let index_ty = context.resolve_type_inner(index, &self.valid_expression_set)?;
match *index_ty { match *index_ty {
crate::TypeInner::Scalar(crate::Scalar::U32) => {} crate::TypeInner::Scalar(crate::Scalar::U32) => {}
_ => { _ => {
@ -710,7 +719,7 @@ impl super::Validator {
} }
} }
} }
let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?;
if !matches!(*argument_inner, if !matches!(*argument_inner,
crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. } crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. }
if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float) if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float)
@ -802,7 +811,7 @@ impl super::Validator {
ref accept, ref accept,
ref reject, ref reject,
} => { } => {
match *context.resolve_type(condition, &self.valid_expression_set)? { match *context.resolve_type_inner(condition, &self.valid_expression_set)? {
Ti::Scalar(crate::Scalar { Ti::Scalar(crate::Scalar {
kind: crate::ScalarKind::Bool, kind: crate::ScalarKind::Bool,
width: _, width: _,
@ -820,7 +829,7 @@ impl super::Validator {
ref cases, ref cases,
} => { } => {
let uint = match context let uint = match context
.resolve_type(selector, &self.valid_expression_set)? .resolve_type_inner(selector, &self.valid_expression_set)?
.scalar_kind() .scalar_kind()
{ {
Some(crate::ScalarKind::Uint) => true, Some(crate::ScalarKind::Uint) => true,
@ -917,7 +926,7 @@ impl super::Validator {
.stages; .stages;
if let Some(condition) = break_if { if let Some(condition) = break_if {
match *context.resolve_type(condition, &self.valid_expression_set)? { match *context.resolve_type_inner(condition, &self.valid_expression_set)? {
Ti::Scalar(crate::Scalar { Ti::Scalar(crate::Scalar {
kind: crate::ScalarKind::Bool, kind: crate::ScalarKind::Bool,
width: _, width: _,
@ -961,7 +970,7 @@ impl super::Validator {
let okay = match (value_ty, expected_ty) { let okay = match (value_ty, expected_ty) {
(None, None) => true, (None, None) => true,
(Some(value_inner), Some(expected_inner)) => { (Some(value_inner), Some(expected_inner)) => {
value_inner.non_struct_equivalent(expected_inner, context.types) value_inner.inner_with(context.types).non_struct_equivalent(expected_inner, context.types)
} }
(_, _) => false, (_, _) => false,
}; };
@ -1027,7 +1036,7 @@ impl super::Validator {
} }
} }
let value_ty = context.resolve_type(value, &self.valid_expression_set)?; let value_ty = context.resolve_type_inner(value, &self.valid_expression_set)?;
match *value_ty { match *value_ty {
Ti::Image { .. } | Ti::Sampler { .. } => { Ti::Image { .. } | Ti::Sampler { .. } => {
return Err(FunctionError::InvalidStoreTexture { return Err(FunctionError::InvalidStoreTexture {
@ -1145,7 +1154,7 @@ impl super::Validator {
// The `coordinate` operand must be a vector of the appropriate size. // The `coordinate` operand must be a vector of the appropriate size.
if context if context
.resolve_type(coordinate, &self.valid_expression_set)? .resolve_type_inner(coordinate, &self.valid_expression_set)?
.image_storage_coordinates() .image_storage_coordinates()
.is_none_or(|coord_dim| coord_dim != dim) .is_none_or(|coord_dim| coord_dim != dim)
{ {
@ -1167,7 +1176,7 @@ impl super::Validator {
// If present, `array_index` must be a scalar integer type. // If present, `array_index` must be a scalar integer type.
if let Some(expr) = array_index { if let Some(expr) = array_index {
if !matches!( if !matches!(
*context.resolve_type(expr, &self.valid_expression_set)?, *context.resolve_type_inner(expr, &self.valid_expression_set)?,
Ti::Scalar(crate::Scalar { Ti::Scalar(crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
width: _, width: _,
@ -1188,7 +1197,7 @@ impl super::Validator {
// The value we're writing had better match the scalar type // The value we're writing had better match the scalar type
// for `image`'s format. // for `image`'s format.
let actual_value_ty = let actual_value_ty =
context.resolve_type(value, &self.valid_expression_set)?; context.resolve_type_inner(value, &self.valid_expression_set)?;
if actual_value_ty != &value_ty { if actual_value_ty != &value_ty {
return Err(FunctionError::InvalidStoreValue { return Err(FunctionError::InvalidStoreValue {
actual: value, actual: value,
@ -1273,7 +1282,7 @@ impl super::Validator {
dim, dim,
} => { } => {
match context match context
.resolve_type(coordinate, &self.valid_expression_set)? .resolve_type_inner(coordinate, &self.valid_expression_set)?
.image_storage_coordinates() .image_storage_coordinates()
{ {
Some(coord_dim) if coord_dim == dim => {} Some(coord_dim) if coord_dim == dim => {}
@ -1293,7 +1302,9 @@ impl super::Validator {
.with_span_handle(coordinate, context.expressions)); .with_span_handle(coordinate, context.expressions));
} }
if let Some(expr) = array_index { if let Some(expr) = array_index {
match *context.resolve_type(expr, &self.valid_expression_set)? { match *context
.resolve_type_inner(expr, &self.valid_expression_set)?
{
Ti::Scalar(crate::Scalar { Ti::Scalar(crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
width: _, width: _,
@ -1404,7 +1415,7 @@ impl super::Validator {
} }
}; };
if *context.resolve_type(value, &self.valid_expression_set)? != value_ty { if *context.resolve_type_inner(value, &self.valid_expression_set)? != value_ty {
return Err(FunctionError::InvalidImageAtomicValue(value) return Err(FunctionError::InvalidImageAtomicValue(value)
.with_span_handle(value, context.expressions)); .with_span_handle(value, context.expressions));
} }
@ -1412,7 +1423,7 @@ impl super::Validator {
S::WorkGroupUniformLoad { pointer, result } => { S::WorkGroupUniformLoad { pointer, result } => {
stages &= super::ShaderStages::COMPUTE; stages &= super::ShaderStages::COMPUTE;
let pointer_inner = let pointer_inner =
context.resolve_type(pointer, &self.valid_expression_set)?; context.resolve_type_inner(pointer, &self.valid_expression_set)?;
match *pointer_inner { match *pointer_inner {
Ti::Pointer { Ti::Pointer {
space: AddressSpace::WorkGroup, space: AddressSpace::WorkGroup,
@ -1468,9 +1479,10 @@ impl super::Validator {
acceleration_structure, acceleration_structure,
descriptor, descriptor,
} => { } => {
match *context match *context.resolve_type_inner(
.resolve_type(acceleration_structure, &self.valid_expression_set)? acceleration_structure,
{ &self.valid_expression_set,
)? {
Ti::AccelerationStructure { vertex_return } => { Ti::AccelerationStructure { vertex_return } => {
if (!vertex_return) && rq_vertex_return { if (!vertex_return) && rq_vertex_return {
return Err(FunctionError::MissingAccelerationStructureVertexReturn(acceleration_structure, query).with_span_static(span, "invalid acceleration structure")); return Err(FunctionError::MissingAccelerationStructureVertexReturn(acceleration_structure, query).with_span_static(span, "invalid acceleration structure"));
@ -1483,8 +1495,8 @@ impl super::Validator {
.with_span_static(span, "invalid acceleration structure")) .with_span_static(span, "invalid acceleration structure"))
} }
} }
let desc_ty_given = let desc_ty_given = context
context.resolve_type(descriptor, &self.valid_expression_set)?; .resolve_type_inner(descriptor, &self.valid_expression_set)?;
let desc_ty_expected = context let desc_ty_expected = context
.special_types .special_types
.ray_desc .ray_desc
@ -1498,7 +1510,7 @@ impl super::Validator {
self.emit_expression(result, context)?; self.emit_expression(result, context)?;
} }
crate::RayQueryFunction::GenerateIntersection { hit_t } => { crate::RayQueryFunction::GenerateIntersection { hit_t } => {
match *context.resolve_type(hit_t, &self.valid_expression_set)? { match *context.resolve_type_inner(hit_t, &self.valid_expression_set)? {
Ti::Scalar(crate::Scalar { Ti::Scalar(crate::Scalar {
kind: crate::ScalarKind::Float, kind: crate::ScalarKind::Float,
width: _, width: _,
@ -1534,7 +1546,7 @@ impl super::Validator {
} }
if let Some(predicate) = predicate { if let Some(predicate) = predicate {
let predicate_inner = let predicate_inner =
context.resolve_type(predicate, &self.valid_expression_set)?; context.resolve_type_inner(predicate, &self.valid_expression_set)?;
if !matches!( if !matches!(
*predicate_inner, *predicate_inner,
crate::TypeInner::Scalar(crate::Scalar::BOOL,) crate::TypeInner::Scalar(crate::Scalar::BOOL,)