[wgsl-in] Unify ConcreteConstructor and ConcreteConstructorHandle. (#2577)

Replace the `ConcreteConstructor` and `ConcreteConstructorHandle`
types in `front::wgsl::lower::construction` with a single type
`Constructor` with a type parameter that determines how it refers to
Naga types.
This commit is contained in:
Jim Blandy 2023-10-24 13:49:30 -07:00 committed by GitHub
parent 86b6db6f76
commit ada3cd85bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -6,45 +6,57 @@ use crate::{Handle, Span};
use crate::front::wgsl::error::Error; use crate::front::wgsl::error::Error;
use crate::front::wgsl::lower::{ExpressionContext, Lowerer}; use crate::front::wgsl::lower::{ExpressionContext, Lowerer};
enum ConcreteConstructorHandle { /// A cooked form of `ast::ConstructorType` that uses Naga types whenever
PartialVector { /// possible.
size: crate::VectorSize, enum Constructor<T> {
}, /// A vector construction whose component type is inferred from the
/// argument: `vec3(1.0)`.
PartialVector { size: crate::VectorSize },
/// A matrix construction whose component type is inferred from the
/// argument: `mat2x2(1,2,3,4)`.
PartialMatrix { PartialMatrix {
columns: crate::VectorSize, columns: crate::VectorSize,
rows: crate::VectorSize, rows: crate::VectorSize,
}, },
/// An array whose component type and size are inferred from the arguments:
/// `array(3,4,5)`.
PartialArray, PartialArray,
Type(Handle<crate::Type>),
/// A known Naga type.
///
/// When we match on this type, we need to see the `TypeInner` here, but at
/// the point that we build this value we'll still need mutable access to
/// the module later. To avoid borrowing from the module, the type parameter
/// `T` is `Handle<Type>` initially. Then we use `borrow_inner` to produce a
/// version holding a tuple `(Handle<Type>, &TypeInner)`.
Type(T),
} }
impl ConcreteConstructorHandle { impl Constructor<Handle<crate::Type>> {
fn borrow<'a>(&self, module: &'a crate::Module) -> ConcreteConstructor<'a> { /// Return an equivalent `Constructor` value that includes borrowed
match *self { /// `TypeInner` values alongside any type handles.
Self::PartialVector { size } => ConcreteConstructor::PartialVector { size }, ///
Self::PartialMatrix { columns, rows } => { /// The returned form is more convenient to match on, since the patterns
ConcreteConstructor::PartialMatrix { columns, rows } /// can actually see what the handle refers to.
fn borrow_inner(
self,
module: &crate::Module,
) -> Constructor<(Handle<crate::Type>, &crate::TypeInner)> {
match self {
Constructor::PartialVector { size } => Constructor::PartialVector { size },
Constructor::PartialMatrix { columns, rows } => {
Constructor::PartialMatrix { columns, rows }
} }
Self::PartialArray => ConcreteConstructor::PartialArray, Constructor::PartialArray => Constructor::PartialArray,
Self::Type(handle) => ConcreteConstructor::Type(handle, &module.types[handle].inner), Constructor::Type(handle) => Constructor::Type((handle, &module.types[handle].inner)),
} }
} }
} }
enum ConcreteConstructor<'a> { impl Constructor<(Handle<crate::Type>, &crate::TypeInner)> {
PartialVector { fn to_error_string(&self, ctx: &ExpressionContext) -> String {
size: crate::VectorSize,
},
PartialMatrix {
columns: crate::VectorSize,
rows: crate::VectorSize,
},
PartialArray,
Type(Handle<crate::Type>, &'a crate::TypeInner),
}
impl ConcreteConstructorHandle {
fn to_error_string(&self, ctx: &mut ExpressionContext) -> String {
match *self { match *self {
Self::PartialVector { size } => { Self::PartialVector { size } => {
format!("vec{}<?>", size as u32,) format!("vec{}<?>", size as u32,)
@ -53,7 +65,7 @@ impl ConcreteConstructorHandle {
format!("mat{}x{}<?>", columns as u32, rows as u32,) format!("mat{}x{}<?>", columns as u32, rows as u32,)
} }
Self::PartialArray => "array<?, ?>".to_string(), Self::PartialArray => "array<?, ?>".to_string(),
Self::Type(ty) => ctx.format_type(ty), Self::Type((handle, _inner)) => ctx.format_type(handle),
} }
} }
} }
@ -146,15 +158,24 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
} }
}; };
let constructor = constructor_h.borrow(ctx.module); // Even though we computed `constructor` above, wait until now to borrow
// a reference to the `TypeInner`, so that the component-handling code
// above can have mutable access to the type arena.
let constructor = constructor_h.borrow_inner(ctx.module);
let expr = match (components, constructor) { let expr = match (components, constructor) {
// Empty constructor // Empty constructor
(Components::None, dst_ty) => match dst_ty { (Components::None, dst_ty) => match dst_ty {
ConcreteConstructor::Type(ty, _) => { Constructor::Type((result_ty, _)) => {
return ctx.append_expression(crate::Expression::ZeroValue(ty), span) return ctx.append_expression(crate::Expression::ZeroValue(result_ty), span)
}
Constructor::PartialVector { .. }
| Constructor::PartialMatrix { .. }
| Constructor::PartialArray => {
// We have no arguments from which to infer the result type, so
// partial constructors aren't acceptable here.
return Err(Error::TypeNotInferrable(ty_span));
} }
_ => return Err(Error::TypeNotInferrable(ty_span)),
}, },
// Scalar constructor & conversion (scalar -> scalar) // Scalar constructor & conversion (scalar -> scalar)
@ -164,7 +185,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ty_inner: &crate::TypeInner::Scalar { .. }, ty_inner: &crate::TypeInner::Scalar { .. },
.. ..
}, },
ConcreteConstructor::Type(_, &crate::TypeInner::Scalar { kind, width }), Constructor::Type((_, &crate::TypeInner::Scalar { kind, width })),
) => crate::Expression::As { ) => crate::Expression::As {
expr: component, expr: component,
kind, kind,
@ -178,14 +199,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ty_inner: &crate::TypeInner::Vector { size: src_size, .. }, ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
.. ..
}, },
ConcreteConstructor::Type( Constructor::Type((
_, _,
&crate::TypeInner::Vector { &crate::TypeInner::Vector {
size: dst_size, size: dst_size,
kind: dst_kind, kind: dst_kind,
width: dst_width, width: dst_width,
}, },
), )),
) if dst_size == src_size => crate::Expression::As { ) if dst_size == src_size => crate::Expression::As {
expr: component, expr: component,
kind: dst_kind, kind: dst_kind,
@ -199,7 +220,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ty_inner: &crate::TypeInner::Vector { size: src_size, .. }, ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
.. ..
}, },
ConcreteConstructor::PartialVector { size: dst_size }, Constructor::PartialVector { size: dst_size },
) if dst_size == src_size => { ) if dst_size == src_size => {
// This is a trivial conversion: the sizes match, and a Partial // This is a trivial conversion: the sizes match, and a Partial
// constructor doesn't specify a scalar type, so nothing can // constructor doesn't specify a scalar type, so nothing can
@ -219,14 +240,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}, },
.. ..
}, },
ConcreteConstructor::Type( Constructor::Type((
_, _,
&crate::TypeInner::Matrix { &crate::TypeInner::Matrix {
columns: dst_columns, columns: dst_columns,
rows: dst_rows, rows: dst_rows,
width: dst_width, width: dst_width,
}, },
), )),
) if dst_columns == src_columns && dst_rows == src_rows => crate::Expression::As { ) if dst_columns == src_columns && dst_rows == src_rows => crate::Expression::As {
expr: component, expr: component,
kind: crate::ScalarKind::Float, kind: crate::ScalarKind::Float,
@ -245,7 +266,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}, },
.. ..
}, },
ConcreteConstructor::PartialMatrix { Constructor::PartialMatrix {
columns: dst_columns, columns: dst_columns,
rows: dst_rows, rows: dst_rows,
}, },
@ -263,7 +284,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ty_inner: &crate::TypeInner::Scalar { .. }, ty_inner: &crate::TypeInner::Scalar { .. },
.. ..
}, },
ConcreteConstructor::PartialVector { size }, Constructor::PartialVector { size },
) => crate::Expression::Splat { ) => crate::Expression::Splat {
size, size,
value: component, value: component,
@ -281,14 +302,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}, },
.. ..
}, },
ConcreteConstructor::Type( Constructor::Type((
_, _,
&crate::TypeInner::Vector { &crate::TypeInner::Vector {
size, size,
kind: dst_kind, kind: dst_kind,
width: dst_width, width: dst_width,
}, },
), )),
) if dst_kind == src_kind || dst_width == src_width => crate::Expression::Splat { ) if dst_kind == src_kind || dst_width == src_width => crate::Expression::Splat {
size, size,
value: component, value: component,
@ -303,7 +324,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
| &crate::TypeInner::Vector { kind, width, .. }, | &crate::TypeInner::Vector { kind, width, .. },
.. ..
}, },
ConcreteConstructor::PartialVector { size }, Constructor::PartialVector { size },
) )
| ( | (
Components::Many { Components::Many {
@ -312,7 +333,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
&crate::TypeInner::Scalar { .. } | &crate::TypeInner::Vector { .. }, &crate::TypeInner::Scalar { .. } | &crate::TypeInner::Vector { .. },
.. ..
}, },
ConcreteConstructor::Type(_, &crate::TypeInner::Vector { size, width, kind }), Constructor::Type((_, &crate::TypeInner::Vector { size, width, kind })),
) => { ) => {
let inner = crate::TypeInner::Vector { size, kind, width }; let inner = crate::TypeInner::Vector { size, kind, width };
let ty = ctx.ensure_type_exists(inner); let ty = ctx.ensure_type_exists(inner);
@ -326,7 +347,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
first_component_ty_inner: &crate::TypeInner::Scalar { width, .. }, first_component_ty_inner: &crate::TypeInner::Scalar { width, .. },
.. ..
}, },
ConcreteConstructor::PartialMatrix { columns, rows }, Constructor::PartialMatrix { columns, rows },
) )
| ( | (
Components::Many { Components::Many {
@ -334,14 +355,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
first_component_ty_inner: &crate::TypeInner::Scalar { .. }, first_component_ty_inner: &crate::TypeInner::Scalar { .. },
.. ..
}, },
ConcreteConstructor::Type( Constructor::Type((
_, _,
&crate::TypeInner::Matrix { &crate::TypeInner::Matrix {
columns, columns,
rows, rows,
width, width,
}, },
), )),
) => { ) => {
let vec_ty = ctx.ensure_type_exists(crate::TypeInner::Vector { let vec_ty = ctx.ensure_type_exists(crate::TypeInner::Vector {
width, width,
@ -377,7 +398,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
first_component_ty_inner: &crate::TypeInner::Vector { width, .. }, first_component_ty_inner: &crate::TypeInner::Vector { width, .. },
.. ..
}, },
ConcreteConstructor::PartialMatrix { columns, rows }, Constructor::PartialMatrix { columns, rows },
) )
| ( | (
Components::Many { Components::Many {
@ -385,14 +406,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
first_component_ty_inner: &crate::TypeInner::Vector { .. }, first_component_ty_inner: &crate::TypeInner::Vector { .. },
.. ..
}, },
ConcreteConstructor::Type( Constructor::Type((
_, _,
&crate::TypeInner::Matrix { &crate::TypeInner::Matrix {
columns, columns,
rows, rows,
width, width,
}, },
), )),
) => { ) => {
let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
columns, columns,
@ -403,7 +424,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
} }
// Array constructor - infer type // Array constructor - infer type
(components, ConcreteConstructor::PartialArray) => { (components, Constructor::PartialArray) => {
let components = components.into_components_vec(); let components = components.into_components_vec();
let base = ctx.register_type(components[0])?; let base = ctx.register_type(components[0])?;
@ -426,10 +447,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// Array or Struct constructor // Array or Struct constructor
( (
components, components,
ConcreteConstructor::Type( Constructor::Type((
ty, ty,
&crate::TypeInner::Array { .. } | &crate::TypeInner::Struct { .. }, &crate::TypeInner::Array { .. } | &crate::TypeInner::Struct { .. },
), )),
) => { ) => {
let components = components.into_components_vec(); let components = components.into_components_vec();
crate::Expression::Compose { ty, components } crate::Expression::Compose { ty, components }
@ -438,19 +459,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// ERRORS // ERRORS
// Bad conversion (type cast) // Bad conversion (type cast)
(Components::One { span, ty_inner, .. }, _) => { (Components::One { span, ty_inner, .. }, constructor) => {
let from_type = ctx.format_typeinner(ty_inner); let from_type = ctx.format_typeinner(ty_inner);
return Err(Error::BadTypeCast { return Err(Error::BadTypeCast {
span, span,
from_type, from_type,
to_type: constructor_h.to_error_string(ctx), to_type: constructor.to_error_string(ctx),
}); });
} }
// Too many parameters for scalar constructor // Too many parameters for scalar constructor
( (
Components::Many { spans, .. }, Components::Many { spans, .. },
ConcreteConstructor::Type(_, &crate::TypeInner::Scalar { .. }), Constructor::Type((_, &crate::TypeInner::Scalar { .. })),
) => { ) => {
let span = spans[1].until(spans.last().unwrap()); let span = spans[1].until(spans.last().unwrap());
return Err(Error::UnexpectedComponents(span)); return Err(Error::UnexpectedComponents(span));
@ -459,12 +480,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// Parameters are of the wrong type for vector or matrix constructor // Parameters are of the wrong type for vector or matrix constructor
( (
Components::Many { spans, .. }, Components::Many { spans, .. },
ConcreteConstructor::Type( Constructor::Type((
_, _,
&crate::TypeInner::Vector { .. } | &crate::TypeInner::Matrix { .. }, &crate::TypeInner::Vector { .. } | &crate::TypeInner::Matrix { .. },
) ))
| ConcreteConstructor::PartialVector { .. } | Constructor::PartialVector { .. }
| ConcreteConstructor::PartialMatrix { .. }, | Constructor::PartialMatrix { .. },
) => { ) => {
return Err(Error::InvalidConstructorComponentType(spans[0], 0)); return Err(Error::InvalidConstructorComponentType(spans[0], 0));
} }
@ -477,17 +498,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
Ok(expr) Ok(expr)
} }
/// Build a Naga IR [`Type`] for `constructor` if there is enough /// Build a [`Constructor`] for a WGSL construction expression.
/// information to do so.
/// ///
/// For `Partial` variants of [`ast::ConstructorType`], we don't know the /// If `constructor` conveys enough information to determine which Naga [`Type`]
/// component type, so in that case we return the appropriate `Partial` /// we're actually building (i.e., it's not a partial constructor), then
/// variant of [`ConcreteConstructorHandle`]. /// ensure the `Type` exists in [`ctx.module`], and return
/// [`Constructor::Type`].
/// ///
/// But for the other `ConstructorType` variants, we have everything we need /// Otherwise, return the [`Constructor`] partial variant corresponding to
/// to know to actually produce a Naga IR type. In this case we add to/find /// `constructor`.
/// in [`ctx.module`] a suitable Naga `Type` and return a
/// [`ConcreteConstructorHandle::Type`] value holding its handle.
/// ///
/// [`Type`]: crate::Type /// [`Type`]: crate::Type
/// [`ctx.module`]: ExpressionContext::module /// [`ctx.module`]: ExpressionContext::module
@ -495,21 +514,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
&mut self, &mut self,
constructor: &ast::ConstructorType<'source>, constructor: &ast::ConstructorType<'source>,
ctx: &mut ExpressionContext<'source, '_, 'out>, ctx: &mut ExpressionContext<'source, '_, 'out>,
) -> Result<ConcreteConstructorHandle, Error<'source>> { ) -> Result<Constructor<Handle<crate::Type>>, Error<'source>> {
let c = match *constructor { let handle = match *constructor {
ast::ConstructorType::Scalar { width, kind } => { ast::ConstructorType::Scalar { width, kind } => {
let ty = ctx.ensure_type_exists(crate::TypeInner::Scalar { width, kind }); let ty = ctx.ensure_type_exists(crate::TypeInner::Scalar { width, kind });
ConcreteConstructorHandle::Type(ty) Constructor::Type(ty)
}
ast::ConstructorType::PartialVector { size } => {
ConcreteConstructorHandle::PartialVector { size }
} }
ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size },
ast::ConstructorType::Vector { size, kind, width } => { ast::ConstructorType::Vector { size, kind, width } => {
let ty = ctx.ensure_type_exists(crate::TypeInner::Vector { size, kind, width }); let ty = ctx.ensure_type_exists(crate::TypeInner::Vector { size, kind, width });
ConcreteConstructorHandle::Type(ty) Constructor::Type(ty)
} }
ast::ConstructorType::PartialMatrix { rows, columns } => { ast::ConstructorType::PartialMatrix { columns, rows } => {
ConcreteConstructorHandle::PartialMatrix { rows, columns } Constructor::PartialMatrix { columns, rows }
} }
ast::ConstructorType::Matrix { ast::ConstructorType::Matrix {
rows, rows,
@ -521,9 +538,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
rows, rows,
width, width,
}); });
ConcreteConstructorHandle::Type(ty) Constructor::Type(ty)
} }
ast::ConstructorType::PartialArray => ConcreteConstructorHandle::PartialArray, ast::ConstructorType::PartialArray => Constructor::PartialArray,
ast::ConstructorType::Array { base, size } => { ast::ConstructorType::Array { base, size } => {
let base = self.resolve_ast_type(base, &mut ctx.as_global())?; let base = self.resolve_ast_type(base, &mut ctx.as_global())?;
let size = self.array_size(size, &mut ctx.as_global())?; let size = self.array_size(size, &mut ctx.as_global())?;
@ -532,11 +549,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let stride = self.layouter[base].to_stride(); let stride = self.layouter[base].to_stride();
let ty = ctx.ensure_type_exists(crate::TypeInner::Array { base, size, stride }); let ty = ctx.ensure_type_exists(crate::TypeInner::Array { base, size, stride });
ConcreteConstructorHandle::Type(ty) Constructor::Type(ty)
} }
ast::ConstructorType::Type(ty) => ConcreteConstructorHandle::Type(ty), ast::ConstructorType::Type(ty) => Constructor::Type(ty),
}; };
Ok(c) Ok(handle)
} }
} }