[wgsl-in] Unify ConcreteConstructor and ConcreteConstructorHandle. (#2577)

Replace the `ConcreteConstructor` and `ConcreteConstructorHandle`
types in `front::wgsl::lower::construction` with a single type
`Constructor` with a type parameter that determines how it refers to
Naga types.
This commit is contained in:
Jim Blandy 2023-10-24 13:49:30 -07:00 committed by GitHub
parent 86b6db6f76
commit ada3cd85bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -6,45 +6,57 @@ use crate::{Handle, Span};
use crate::front::wgsl::error::Error;
use crate::front::wgsl::lower::{ExpressionContext, Lowerer};
enum ConcreteConstructorHandle {
PartialVector {
size: crate::VectorSize,
},
/// A cooked form of `ast::ConstructorType` that uses Naga types whenever
/// possible.
enum Constructor<T> {
/// A vector construction whose component type is inferred from the
/// argument: `vec3(1.0)`.
PartialVector { size: crate::VectorSize },
/// A matrix construction whose component type is inferred from the
/// argument: `mat2x2(1,2,3,4)`.
PartialMatrix {
columns: crate::VectorSize,
rows: crate::VectorSize,
},
/// An array whose component type and size are inferred from the arguments:
/// `array(3,4,5)`.
PartialArray,
Type(Handle<crate::Type>),
/// A known Naga type.
///
/// When we match on this type, we need to see the `TypeInner` here, but at
/// the point that we build this value we'll still need mutable access to
/// the module later. To avoid borrowing from the module, the type parameter
/// `T` is `Handle<Type>` initially. Then we use `borrow_inner` to produce a
/// version holding a tuple `(Handle<Type>, &TypeInner)`.
Type(T),
}
impl ConcreteConstructorHandle {
fn borrow<'a>(&self, module: &'a crate::Module) -> ConcreteConstructor<'a> {
match *self {
Self::PartialVector { size } => ConcreteConstructor::PartialVector { size },
Self::PartialMatrix { columns, rows } => {
ConcreteConstructor::PartialMatrix { columns, rows }
impl Constructor<Handle<crate::Type>> {
/// Return an equivalent `Constructor` value that includes borrowed
/// `TypeInner` values alongside any type handles.
///
/// The returned form is more convenient to match on, since the patterns
/// can actually see what the handle refers to.
fn borrow_inner(
self,
module: &crate::Module,
) -> Constructor<(Handle<crate::Type>, &crate::TypeInner)> {
match self {
Constructor::PartialVector { size } => Constructor::PartialVector { size },
Constructor::PartialMatrix { columns, rows } => {
Constructor::PartialMatrix { columns, rows }
}
Self::PartialArray => ConcreteConstructor::PartialArray,
Self::Type(handle) => ConcreteConstructor::Type(handle, &module.types[handle].inner),
Constructor::PartialArray => Constructor::PartialArray,
Constructor::Type(handle) => Constructor::Type((handle, &module.types[handle].inner)),
}
}
}
enum ConcreteConstructor<'a> {
PartialVector {
size: crate::VectorSize,
},
PartialMatrix {
columns: crate::VectorSize,
rows: crate::VectorSize,
},
PartialArray,
Type(Handle<crate::Type>, &'a crate::TypeInner),
}
impl ConcreteConstructorHandle {
fn to_error_string(&self, ctx: &mut ExpressionContext) -> String {
impl Constructor<(Handle<crate::Type>, &crate::TypeInner)> {
fn to_error_string(&self, ctx: &ExpressionContext) -> String {
match *self {
Self::PartialVector { size } => {
format!("vec{}<?>", size as u32,)
@ -53,7 +65,7 @@ impl ConcreteConstructorHandle {
format!("mat{}x{}<?>", columns as u32, rows as u32,)
}
Self::PartialArray => "array<?, ?>".to_string(),
Self::Type(ty) => ctx.format_type(ty),
Self::Type((handle, _inner)) => ctx.format_type(handle),
}
}
}
@ -146,15 +158,24 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
};
let constructor = constructor_h.borrow(ctx.module);
// Even though we computed `constructor` above, wait until now to borrow
// a reference to the `TypeInner`, so that the component-handling code
// above can have mutable access to the type arena.
let constructor = constructor_h.borrow_inner(ctx.module);
let expr = match (components, constructor) {
// Empty constructor
(Components::None, dst_ty) => match dst_ty {
ConcreteConstructor::Type(ty, _) => {
return ctx.append_expression(crate::Expression::ZeroValue(ty), span)
Constructor::Type((result_ty, _)) => {
return ctx.append_expression(crate::Expression::ZeroValue(result_ty), span)
}
Constructor::PartialVector { .. }
| Constructor::PartialMatrix { .. }
| Constructor::PartialArray => {
// We have no arguments from which to infer the result type, so
// partial constructors aren't acceptable here.
return Err(Error::TypeNotInferrable(ty_span));
}
_ => return Err(Error::TypeNotInferrable(ty_span)),
},
// Scalar constructor & conversion (scalar -> scalar)
@ -164,7 +185,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ty_inner: &crate::TypeInner::Scalar { .. },
..
},
ConcreteConstructor::Type(_, &crate::TypeInner::Scalar { kind, width }),
Constructor::Type((_, &crate::TypeInner::Scalar { kind, width })),
) => crate::Expression::As {
expr: component,
kind,
@ -178,14 +199,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
..
},
ConcreteConstructor::Type(
Constructor::Type((
_,
&crate::TypeInner::Vector {
size: dst_size,
kind: dst_kind,
width: dst_width,
},
),
)),
) if dst_size == src_size => crate::Expression::As {
expr: component,
kind: dst_kind,
@ -199,7 +220,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
..
},
ConcreteConstructor::PartialVector { size: dst_size },
Constructor::PartialVector { size: dst_size },
) if dst_size == src_size => {
// This is a trivial conversion: the sizes match, and a Partial
// constructor doesn't specify a scalar type, so nothing can
@ -219,14 +240,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
},
..
},
ConcreteConstructor::Type(
Constructor::Type((
_,
&crate::TypeInner::Matrix {
columns: dst_columns,
rows: dst_rows,
width: dst_width,
},
),
)),
) if dst_columns == src_columns && dst_rows == src_rows => crate::Expression::As {
expr: component,
kind: crate::ScalarKind::Float,
@ -245,7 +266,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
},
..
},
ConcreteConstructor::PartialMatrix {
Constructor::PartialMatrix {
columns: dst_columns,
rows: dst_rows,
},
@ -263,7 +284,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ty_inner: &crate::TypeInner::Scalar { .. },
..
},
ConcreteConstructor::PartialVector { size },
Constructor::PartialVector { size },
) => crate::Expression::Splat {
size,
value: component,
@ -281,14 +302,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
},
..
},
ConcreteConstructor::Type(
Constructor::Type((
_,
&crate::TypeInner::Vector {
size,
kind: dst_kind,
width: dst_width,
},
),
)),
) if dst_kind == src_kind || dst_width == src_width => crate::Expression::Splat {
size,
value: component,
@ -303,7 +324,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
| &crate::TypeInner::Vector { kind, width, .. },
..
},
ConcreteConstructor::PartialVector { size },
Constructor::PartialVector { size },
)
| (
Components::Many {
@ -312,7 +333,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
&crate::TypeInner::Scalar { .. } | &crate::TypeInner::Vector { .. },
..
},
ConcreteConstructor::Type(_, &crate::TypeInner::Vector { size, width, kind }),
Constructor::Type((_, &crate::TypeInner::Vector { size, width, kind })),
) => {
let inner = crate::TypeInner::Vector { size, kind, width };
let ty = ctx.ensure_type_exists(inner);
@ -326,7 +347,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
first_component_ty_inner: &crate::TypeInner::Scalar { width, .. },
..
},
ConcreteConstructor::PartialMatrix { columns, rows },
Constructor::PartialMatrix { columns, rows },
)
| (
Components::Many {
@ -334,14 +355,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
first_component_ty_inner: &crate::TypeInner::Scalar { .. },
..
},
ConcreteConstructor::Type(
Constructor::Type((
_,
&crate::TypeInner::Matrix {
columns,
rows,
width,
},
),
)),
) => {
let vec_ty = ctx.ensure_type_exists(crate::TypeInner::Vector {
width,
@ -377,7 +398,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
first_component_ty_inner: &crate::TypeInner::Vector { width, .. },
..
},
ConcreteConstructor::PartialMatrix { columns, rows },
Constructor::PartialMatrix { columns, rows },
)
| (
Components::Many {
@ -385,14 +406,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
first_component_ty_inner: &crate::TypeInner::Vector { .. },
..
},
ConcreteConstructor::Type(
Constructor::Type((
_,
&crate::TypeInner::Matrix {
columns,
rows,
width,
},
),
)),
) => {
let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
columns,
@ -403,7 +424,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
// Array constructor - infer type
(components, ConcreteConstructor::PartialArray) => {
(components, Constructor::PartialArray) => {
let components = components.into_components_vec();
let base = ctx.register_type(components[0])?;
@ -426,10 +447,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// Array or Struct constructor
(
components,
ConcreteConstructor::Type(
Constructor::Type((
ty,
&crate::TypeInner::Array { .. } | &crate::TypeInner::Struct { .. },
),
)),
) => {
let components = components.into_components_vec();
crate::Expression::Compose { ty, components }
@ -438,19 +459,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// ERRORS
// Bad conversion (type cast)
(Components::One { span, ty_inner, .. }, _) => {
(Components::One { span, ty_inner, .. }, constructor) => {
let from_type = ctx.format_typeinner(ty_inner);
return Err(Error::BadTypeCast {
span,
from_type,
to_type: constructor_h.to_error_string(ctx),
to_type: constructor.to_error_string(ctx),
});
}
// Too many parameters for scalar constructor
(
Components::Many { spans, .. },
ConcreteConstructor::Type(_, &crate::TypeInner::Scalar { .. }),
Constructor::Type((_, &crate::TypeInner::Scalar { .. })),
) => {
let span = spans[1].until(spans.last().unwrap());
return Err(Error::UnexpectedComponents(span));
@ -459,12 +480,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// Parameters are of the wrong type for vector or matrix constructor
(
Components::Many { spans, .. },
ConcreteConstructor::Type(
Constructor::Type((
_,
&crate::TypeInner::Vector { .. } | &crate::TypeInner::Matrix { .. },
)
| ConcreteConstructor::PartialVector { .. }
| ConcreteConstructor::PartialMatrix { .. },
))
| Constructor::PartialVector { .. }
| Constructor::PartialMatrix { .. },
) => {
return Err(Error::InvalidConstructorComponentType(spans[0], 0));
}
@ -477,17 +498,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
Ok(expr)
}
/// Build a Naga IR [`Type`] for `constructor` if there is enough
/// information to do so.
/// Build a [`Constructor`] for a WGSL construction expression.
///
/// For `Partial` variants of [`ast::ConstructorType`], we don't know the
/// component type, so in that case we return the appropriate `Partial`
/// variant of [`ConcreteConstructorHandle`].
/// If `constructor` conveys enough information to determine which Naga [`Type`]
/// we're actually building (i.e., it's not a partial constructor), then
/// ensure the `Type` exists in [`ctx.module`], and return
/// [`Constructor::Type`].
///
/// But for the other `ConstructorType` variants, we have everything we need
/// to know to actually produce a Naga IR type. In this case we add to/find
/// in [`ctx.module`] a suitable Naga `Type` and return a
/// [`ConcreteConstructorHandle::Type`] value holding its handle.
/// Otherwise, return the [`Constructor`] partial variant corresponding to
/// `constructor`.
///
/// [`Type`]: crate::Type
/// [`ctx.module`]: ExpressionContext::module
@ -495,21 +514,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
&mut self,
constructor: &ast::ConstructorType<'source>,
ctx: &mut ExpressionContext<'source, '_, 'out>,
) -> Result<ConcreteConstructorHandle, Error<'source>> {
let c = match *constructor {
) -> Result<Constructor<Handle<crate::Type>>, Error<'source>> {
let handle = match *constructor {
ast::ConstructorType::Scalar { width, kind } => {
let ty = ctx.ensure_type_exists(crate::TypeInner::Scalar { width, kind });
ConcreteConstructorHandle::Type(ty)
}
ast::ConstructorType::PartialVector { size } => {
ConcreteConstructorHandle::PartialVector { size }
Constructor::Type(ty)
}
ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size },
ast::ConstructorType::Vector { size, kind, width } => {
let ty = ctx.ensure_type_exists(crate::TypeInner::Vector { size, kind, width });
ConcreteConstructorHandle::Type(ty)
Constructor::Type(ty)
}
ast::ConstructorType::PartialMatrix { rows, columns } => {
ConcreteConstructorHandle::PartialMatrix { rows, columns }
ast::ConstructorType::PartialMatrix { columns, rows } => {
Constructor::PartialMatrix { columns, rows }
}
ast::ConstructorType::Matrix {
rows,
@ -521,9 +538,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
rows,
width,
});
ConcreteConstructorHandle::Type(ty)
Constructor::Type(ty)
}
ast::ConstructorType::PartialArray => ConcreteConstructorHandle::PartialArray,
ast::ConstructorType::PartialArray => Constructor::PartialArray,
ast::ConstructorType::Array { base, size } => {
let base = self.resolve_ast_type(base, &mut ctx.as_global())?;
let size = self.array_size(size, &mut ctx.as_global())?;
@ -532,11 +549,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let stride = self.layouter[base].to_stride();
let ty = ctx.ensure_type_exists(crate::TypeInner::Array { base, size, stride });
ConcreteConstructorHandle::Type(ty)
Constructor::Type(ty)
}
ast::ConstructorType::Type(ty) => ConcreteConstructorHandle::Type(ty),
ast::ConstructorType::Type(ty) => Constructor::Type(ty),
};
Ok(c)
Ok(handle)
}
}