[naga msl-out] Split up write_wrapped_functions()

It was getting unwieldy, and upcoming commits are going to add
additional functions that will be wrapped.
This commit is contained in:
Jamie Nicol 2025-08-14 11:11:56 +01:00 committed by Jamie Nicol
parent 4b5e38ab49
commit bb21da3014

View File

@ -5513,14 +5513,13 @@ template <typename A>
}
}
pub(super) fn write_wrapped_functions(
fn write_wrapped_unary_op(
&mut self,
module: &crate::Module,
func_ctx: &back::FunctionCtx,
op: crate::UnaryOperator,
operand: Handle<crate::Expression>,
) -> BackendResult {
for (expr_handle, expr) in func_ctx.expressions.iter() {
match *expr {
crate::Expression::Unary { op, expr: operand } => {
let operand_ty = func_ctx.resolve_type(operand, &module.types);
match op {
// Negating the TYPE_MIN of a two's complement signed integer
@ -5532,16 +5531,15 @@ template <typename A>
crate::UnaryOperator::Negate
if operand_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
{
let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar()
else {
continue;
let Some((vector_size, scalar)) = operand_ty.vector_size_and_scalar() else {
return Ok(());
};
let wrapped = WrappedFunction::UnaryOp {
op,
ty: (vector_size, scalar),
};
if !self.wrapped_functions.insert(wrapped) {
continue;
return Ok(());
}
let unsigned_scalar = crate::Scalar {
@ -5557,25 +5555,34 @@ template <typename A>
}
Some(size) => {
put_numeric_type(&mut type_name, scalar, &[size])?;
put_numeric_type(
&mut unsigned_type_name,
unsigned_scalar,
&[size],
)?;
put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
}
};
writeln!(self.out, "{type_name} {NEG_FUNCTION}({type_name} val) {{")?;
let level = back::Level(1);
writeln!(self.out, "{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));")?;
writeln!(
self.out,
"{level}return as_type<{type_name}>(-as_type<{unsigned_type_name}>(val));"
)?;
writeln!(self.out, "}}")?;
writeln!(self.out)?;
}
_ => {}
}
Ok(())
}
crate::Expression::Binary { op, left, right } => {
let expr_ty = func_ctx.resolve_type(expr_handle, &module.types);
fn write_wrapped_binary_op(
&mut self,
module: &crate::Module,
func_ctx: &back::FunctionCtx,
expr: Handle<crate::Expression>,
op: crate::BinaryOperator,
left: Handle<crate::Expression>,
right: Handle<crate::Expression>,
) -> BackendResult {
let expr_ty = func_ctx.resolve_type(expr, &module.types);
let left_ty = func_ctx.resolve_type(left, &module.types);
let right_ty = func_ctx.resolve_type(right, &module.types);
match (op, expr_ty.scalar_kind()) {
@ -5590,10 +5597,10 @@ template <typename A>
Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
) => {
let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
continue;
return Ok(());
};
let Some(right_wrapped_ty) = right_ty.vector_size_and_scalar() else {
continue;
return Ok(());
};
let wrapped = WrappedFunction::BinaryOp {
op,
@ -5601,12 +5608,11 @@ template <typename A>
right_ty: right_wrapped_ty,
};
if !self.wrapped_functions.insert(wrapped) {
continue;
return Ok(());
}
let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar()
else {
continue;
let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
return Ok(());
};
let mut type_name = String::new();
match vector_size {
@ -5662,12 +5668,11 @@ template <typename A>
Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint),
) => {
let Some(left_wrapped_ty) = left_ty.vector_size_and_scalar() else {
continue;
return Ok(());
};
let Some((right_vector_size, right_scalar)) =
right_ty.vector_size_and_scalar()
let Some((right_vector_size, right_scalar)) = right_ty.vector_size_and_scalar()
else {
continue;
return Ok(());
};
let wrapped = WrappedFunction::BinaryOp {
op,
@ -5675,12 +5680,11 @@ template <typename A>
right_ty: (right_vector_size, right_scalar),
};
if !self.wrapped_functions.insert(wrapped) {
continue;
return Ok(());
}
let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar()
else {
continue;
let Some((vector_size, scalar)) = expr_ty.vector_size_and_scalar() else {
return Ok(());
};
let mut type_name = String::new();
match vector_size {
@ -5690,9 +5694,7 @@ template <typename A>
let mut rhs_type_name = String::new();
match right_vector_size {
None => put_numeric_type(&mut rhs_type_name, right_scalar, &[])?,
Some(size) => {
put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?
}
Some(size) => put_numeric_type(&mut rhs_type_name, right_scalar, &[size])?,
};
writeln!(
@ -5711,13 +5713,13 @@ template <typename A>
)));
}
};
write!(self.out, "{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == ")?;
write!(
self.out,
"{level}{rhs_type_name} divisor = metal::select(rhs, 1, (lhs == "
)?;
self.put_literal(min_val)?;
writeln!(self.out, " & rhs == -1) | (rhs == 0));")?;
writeln!(
self.out,
"{level}return lhs - (lhs / divisor) * divisor;"
)?
writeln!(self.out, "{level}return lhs - (lhs / divisor) * divisor;")?
}
crate::ScalarKind::Uint => writeln!(
self.out,
@ -5730,14 +5732,20 @@ template <typename A>
}
_ => {}
}
Ok(())
}
crate::Expression::Math {
fun,
arg,
arg1: _,
arg2: _,
arg3: _,
} => {
#[allow(clippy::too_many_arguments)]
fn write_wrapped_math_function(
&mut self,
module: &crate::Module,
func_ctx: &back::FunctionCtx,
fun: crate::MathFunction,
arg: Handle<crate::Expression>,
_arg1: Option<Handle<crate::Expression>>,
_arg2: Option<Handle<crate::Expression>>,
_arg3: Option<Handle<crate::Expression>>,
) -> BackendResult {
let arg_ty = func_ctx.resolve_type(arg, &module.types);
match fun {
// Taking the absolute value of the TYPE_MIN of a two's
@ -5747,19 +5755,16 @@ template <typename A>
// bitcast back to signed.
// This adheres to the WGSL spec in that the absolute of the
// type's minimum value should equal to the minimum value.
crate::MathFunction::Abs
if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) =>
{
let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar()
else {
continue;
crate::MathFunction::Abs if arg_ty.scalar_kind() == Some(crate::ScalarKind::Sint) => {
let Some((vector_size, scalar)) = arg_ty.vector_size_and_scalar() else {
return Ok(());
};
let wrapped = WrappedFunction::Math {
fun,
arg_ty: (vector_size, scalar),
};
if !self.wrapped_functions.insert(wrapped) {
continue;
return Ok(());
}
let unsigned_scalar = crate::Scalar {
@ -5775,11 +5780,7 @@ template <typename A>
}
Some(size) => {
put_numeric_type(&mut type_name, scalar, &[size])?;
put_numeric_type(
&mut unsigned_type_name,
unsigned_scalar,
&[size],
)?;
put_numeric_type(&mut unsigned_type_name, unsigned_scalar, &[size])?;
}
};
@ -5791,12 +5792,17 @@ template <typename A>
}
_ => {}
}
Ok(())
}
crate::Expression::As {
expr,
kind,
convert: Some(width),
} => {
fn write_wrapped_cast(
&mut self,
module: &crate::Module,
func_ctx: &back::FunctionCtx,
expr: Handle<crate::Expression>,
kind: crate::ScalarKind,
convert: Option<crate::Bytes>,
) -> BackendResult {
// Avoid undefined behaviour when casting from a float to integer
// when the value is out of range for the target type. Additionally
// ensure we clamp to the correct value as per the WGSL spec.
@ -5808,15 +5814,18 @@ template <typename A>
// truncate(X) and also exactly representable in the original
// floating point type.
let src_ty = func_ctx.resolve_type(expr, &module.types);
let Some(width) = convert else {
return Ok(());
};
let Some((vector_size, src_scalar)) = src_ty.vector_size_and_scalar() else {
continue;
return Ok(());
};
let dst_scalar = crate::Scalar { kind, width };
if src_scalar.kind != crate::ScalarKind::Float
|| (dst_scalar.kind != crate::ScalarKind::Sint
&& dst_scalar.kind != crate::ScalarKind::Uint)
{
continue;
return Ok(());
}
let wrapped = WrappedFunction::Cast {
src_scalar,
@ -5824,7 +5833,7 @@ template <typename A>
dst_scalar,
};
if !self.wrapped_functions.insert(wrapped) {
continue;
return Ok(());
}
let (min, max) = proc::min_max_float_representable_by(src_scalar, dst_scalar);
@ -5861,16 +5870,32 @@ template <typename A>
writeln!(self.out, "));")?;
writeln!(self.out, "}}")?;
writeln!(self.out)?;
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn write_wrapped_image_sample(
&mut self,
_module: &crate::Module,
_func_ctx: &back::FunctionCtx,
_image: Handle<crate::Expression>,
_sampler: Handle<crate::Expression>,
_gather: Option<crate::SwizzleComponent>,
_coordinate: Handle<crate::Expression>,
_array_index: Option<Handle<crate::Expression>>,
_offset: Option<Handle<crate::Expression>>,
_level: crate::SampleLevel,
_depth_ref: Option<Handle<crate::Expression>>,
clamp_to_edge: bool,
) -> BackendResult {
if !clamp_to_edge {
return Ok(());
}
crate::Expression::ImageSample {
clamp_to_edge: true,
..
} => {
let wrapped = WrappedFunction::ImageSample {
clamp_to_edge: true,
};
if !self.wrapped_functions.insert(wrapped) {
continue;
return Ok(());
}
writeln!(self.out, "{NAMESPACE}::float4 {IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION}({NAMESPACE}::texture2d<float, {NAMESPACE}::access::sample> tex, {NAMESPACE}::sampler samp, {NAMESPACE}::float2 coords) {{")?;
@ -5882,6 +5907,62 @@ template <typename A>
)?;
writeln!(self.out, "}}")?;
writeln!(self.out)?;
Ok(())
}
pub(super) fn write_wrapped_functions(
&mut self,
module: &crate::Module,
func_ctx: &back::FunctionCtx,
) -> BackendResult {
for (expr_handle, expr) in func_ctx.expressions.iter() {
match *expr {
crate::Expression::Unary { op, expr: operand } => {
self.write_wrapped_unary_op(module, func_ctx, op, operand)?;
}
crate::Expression::Binary { op, left, right } => {
self.write_wrapped_binary_op(module, func_ctx, expr_handle, op, left, right)?;
}
crate::Expression::Math {
fun,
arg,
arg1,
arg2,
arg3,
} => {
self.write_wrapped_math_function(module, func_ctx, fun, arg, arg1, arg2, arg3)?;
}
crate::Expression::As {
expr,
kind,
convert,
} => {
self.write_wrapped_cast(module, func_ctx, expr, kind, convert)?;
}
crate::Expression::ImageSample {
image,
sampler,
gather,
coordinate,
array_index,
offset,
level,
depth_ref,
clamp_to_edge,
} => {
self.write_wrapped_image_sample(
module,
func_ctx,
image,
sampler,
gather,
coordinate,
array_index,
offset,
level,
depth_ref,
clamp_to_edge,
)?;
}
_ => {}
}