[naga] Refactor BlockContext type resolution methods

Change `resolve_type` and `resolve_type_impl` to return
`TypeResolution`s. Add a new method `resolve_type_inner` that returns a
`TypeInner` (i.e. what `resolve_type` used to do).
This commit is contained in:
Andy Leiserson 2025-03-28 12:43:09 -07:00 committed by Teodor Tanasoaia
parent 14b5838a00
commit 19429a1dc9

View File

@ -280,11 +280,11 @@ impl<'a> BlockContext<'a> {
&self,
handle: Handle<crate::Expression>,
valid_expressions: &HandleSet<crate::Expression>,
) -> Result<&crate::TypeInner, WithSpan<ExpressionError>> {
) -> Result<&TypeResolution, WithSpan<ExpressionError>> {
if !valid_expressions.contains(handle) {
Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions))
} else {
Ok(self.info[handle].ty.inner_with(self.types))
Ok(&self.info[handle].ty)
}
}
@ -292,11 +292,20 @@ impl<'a> BlockContext<'a> {
&self,
handle: Handle<crate::Expression>,
valid_expressions: &HandleSet<crate::Expression>,
) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
) -> Result<&TypeResolution, WithSpan<FunctionError>> {
self.resolve_type_impl(handle, valid_expressions)
.map_err_inner(|source| FunctionError::Expression { handle, source }.with_span())
}
fn resolve_type_inner(
&self,
handle: Handle<crate::Expression>,
valid_expressions: &HandleSet<crate::Expression>,
) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
self.resolve_type(handle, valid_expressions)
.map(|tr| tr.inner_with(self.types))
}
fn resolve_pointer_type(&self, handle: Handle<crate::Expression>) -> &crate::TypeInner {
self.info[handle].ty.inner_with(self.types)
}
@ -330,7 +339,7 @@ impl super::Validator {
.with_span_handle(expr, context.expressions)
})?;
let arg_inner = &context.types[arg.ty].inner;
if !ty.non_struct_equivalent(arg_inner, context.types) {
if !ty.inner_with(context.types).non_struct_equivalent(arg_inner, context.types) {
return Err(CallError::ArgumentType {
index,
required: arg.ty,
@ -393,7 +402,7 @@ impl super::Validator {
context: &BlockContext,
) -> Result<(), WithSpan<FunctionError>> {
// The `pointer` operand must be a pointer to an atomic value.
let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?;
let pointer_inner = context.resolve_type_inner(pointer, &self.valid_expression_set)?;
let crate::TypeInner::Pointer {
base: pointer_base,
space: pointer_space,
@ -415,7 +424,7 @@ impl super::Validator {
};
// The `value` operand must be a scalar of the same type as the atomic.
let value_inner = context.resolve_type(value, &self.valid_expression_set)?;
let value_inner = context.resolve_type_inner(value, &self.valid_expression_set)?;
let crate::TypeInner::Scalar(value_scalar) = *value_inner else {
log::error!("Atomic operand type {:?}", *value_inner);
return Err(AtomicError::InvalidOperand(value)
@ -543,7 +552,7 @@ impl super::Validator {
// The comparison value must be a scalar of the same type as the
// atomic we're operating on.
let compare_inner =
context.resolve_type(compare, &self.valid_expression_set)?;
context.resolve_type_inner(compare, &self.valid_expression_set)?;
if !compare_inner.non_struct_equivalent(value_inner, context.types) {
log::error!(
"Atomic exchange comparison has a different type from the value"
@ -620,7 +629,7 @@ impl super::Validator {
result: Handle<crate::Expression>,
context: &BlockContext,
) -> Result<(), WithSpan<FunctionError>> {
let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?;
let (is_scalar, scalar) = match *argument_inner {
crate::TypeInner::Scalar(scalar) => (true, scalar),
@ -695,7 +704,7 @@ impl super::Validator {
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
let index_ty = context.resolve_type(index, &self.valid_expression_set)?;
let index_ty = context.resolve_type_inner(index, &self.valid_expression_set)?;
match *index_ty {
crate::TypeInner::Scalar(crate::Scalar::U32) => {}
_ => {
@ -710,7 +719,7 @@ impl super::Validator {
}
}
}
let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?;
if !matches!(*argument_inner,
crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. }
if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float)
@ -802,7 +811,7 @@ impl super::Validator {
ref accept,
ref reject,
} => {
match *context.resolve_type(condition, &self.valid_expression_set)? {
match *context.resolve_type_inner(condition, &self.valid_expression_set)? {
Ti::Scalar(crate::Scalar {
kind: crate::ScalarKind::Bool,
width: _,
@ -820,7 +829,7 @@ impl super::Validator {
ref cases,
} => {
let uint = match context
.resolve_type(selector, &self.valid_expression_set)?
.resolve_type_inner(selector, &self.valid_expression_set)?
.scalar_kind()
{
Some(crate::ScalarKind::Uint) => true,
@ -917,7 +926,7 @@ impl super::Validator {
.stages;
if let Some(condition) = break_if {
match *context.resolve_type(condition, &self.valid_expression_set)? {
match *context.resolve_type_inner(condition, &self.valid_expression_set)? {
Ti::Scalar(crate::Scalar {
kind: crate::ScalarKind::Bool,
width: _,
@ -961,7 +970,7 @@ impl super::Validator {
let okay = match (value_ty, expected_ty) {
(None, None) => true,
(Some(value_inner), Some(expected_inner)) => {
value_inner.non_struct_equivalent(expected_inner, context.types)
value_inner.inner_with(context.types).non_struct_equivalent(expected_inner, context.types)
}
(_, _) => false,
};
@ -1027,7 +1036,7 @@ impl super::Validator {
}
}
let value_ty = context.resolve_type(value, &self.valid_expression_set)?;
let value_ty = context.resolve_type_inner(value, &self.valid_expression_set)?;
match *value_ty {
Ti::Image { .. } | Ti::Sampler { .. } => {
return Err(FunctionError::InvalidStoreTexture {
@ -1145,7 +1154,7 @@ impl super::Validator {
// The `coordinate` operand must be a vector of the appropriate size.
if context
.resolve_type(coordinate, &self.valid_expression_set)?
.resolve_type_inner(coordinate, &self.valid_expression_set)?
.image_storage_coordinates()
.is_none_or(|coord_dim| coord_dim != dim)
{
@ -1167,7 +1176,7 @@ impl super::Validator {
// If present, `array_index` must be a scalar integer type.
if let Some(expr) = array_index {
if !matches!(
*context.resolve_type(expr, &self.valid_expression_set)?,
*context.resolve_type_inner(expr, &self.valid_expression_set)?,
Ti::Scalar(crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
width: _,
@ -1188,7 +1197,7 @@ impl super::Validator {
// The value we're writing had better match the scalar type
// for `image`'s format.
let actual_value_ty =
context.resolve_type(value, &self.valid_expression_set)?;
context.resolve_type_inner(value, &self.valid_expression_set)?;
if actual_value_ty != &value_ty {
return Err(FunctionError::InvalidStoreValue {
actual: value,
@ -1273,7 +1282,7 @@ impl super::Validator {
dim,
} => {
match context
.resolve_type(coordinate, &self.valid_expression_set)?
.resolve_type_inner(coordinate, &self.valid_expression_set)?
.image_storage_coordinates()
{
Some(coord_dim) if coord_dim == dim => {}
@ -1293,7 +1302,9 @@ impl super::Validator {
.with_span_handle(coordinate, context.expressions));
}
if let Some(expr) = array_index {
match *context.resolve_type(expr, &self.valid_expression_set)? {
match *context
.resolve_type_inner(expr, &self.valid_expression_set)?
{
Ti::Scalar(crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
width: _,
@ -1404,7 +1415,7 @@ impl super::Validator {
}
};
if *context.resolve_type(value, &self.valid_expression_set)? != value_ty {
if *context.resolve_type_inner(value, &self.valid_expression_set)? != value_ty {
return Err(FunctionError::InvalidImageAtomicValue(value)
.with_span_handle(value, context.expressions));
}
@ -1412,7 +1423,7 @@ impl super::Validator {
S::WorkGroupUniformLoad { pointer, result } => {
stages &= super::ShaderStages::COMPUTE;
let pointer_inner =
context.resolve_type(pointer, &self.valid_expression_set)?;
context.resolve_type_inner(pointer, &self.valid_expression_set)?;
match *pointer_inner {
Ti::Pointer {
space: AddressSpace::WorkGroup,
@ -1468,9 +1479,10 @@ impl super::Validator {
acceleration_structure,
descriptor,
} => {
match *context
.resolve_type(acceleration_structure, &self.valid_expression_set)?
{
match *context.resolve_type_inner(
acceleration_structure,
&self.valid_expression_set,
)? {
Ti::AccelerationStructure { vertex_return } => {
if (!vertex_return) && rq_vertex_return {
return Err(FunctionError::MissingAccelerationStructureVertexReturn(acceleration_structure, query).with_span_static(span, "invalid acceleration structure"));
@ -1483,8 +1495,8 @@ impl super::Validator {
.with_span_static(span, "invalid acceleration structure"))
}
}
let desc_ty_given =
context.resolve_type(descriptor, &self.valid_expression_set)?;
let desc_ty_given = context
.resolve_type_inner(descriptor, &self.valid_expression_set)?;
let desc_ty_expected = context
.special_types
.ray_desc
@ -1498,7 +1510,7 @@ impl super::Validator {
self.emit_expression(result, context)?;
}
crate::RayQueryFunction::GenerateIntersection { hit_t } => {
match *context.resolve_type(hit_t, &self.valid_expression_set)? {
match *context.resolve_type_inner(hit_t, &self.valid_expression_set)? {
Ti::Scalar(crate::Scalar {
kind: crate::ScalarKind::Float,
width: _,
@ -1534,7 +1546,7 @@ impl super::Validator {
}
if let Some(predicate) = predicate {
let predicate_inner =
context.resolve_type(predicate, &self.valid_expression_set)?;
context.resolve_type_inner(predicate, &self.valid_expression_set)?;
if !matches!(
*predicate_inner,
crate::TypeInner::Scalar(crate::Scalar::BOOL,)