[naga msl-out] Annotate dot product functions as wrapped functions

This commit is contained in:
David Rivera 2025-10-26 17:02:51 -04:00 committed by Andy Leiserson
parent e620027f95
commit 1f99103be8
4 changed files with 133 additions and 47 deletions

View File

@ -1,12 +1,13 @@
use crate::proc::KeywordSet;
use crate::proc::{concrete_int_scalars, vector_size_str, vector_sizes, KeywordSet};
use crate::racy_lock::RacyLock;
use alloc::{format, string::String, vec::Vec};
// MSLS - Metal Shading Language Specification:
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
//
// C++ - Standard for Programming Language C++ (N4431)
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4431.pdf
pub const RESERVED: &[&str] = &[
const RESERVED: &[&str] = &[
// Undocumented
"assert", // found in https://github.com/gfx-rs/wgpu/issues/5347
// Standard for Programming Language C++ (N4431): 2.5 Alternative tokens
@ -346,6 +347,7 @@ pub const RESERVED: &[&str] = &[
super::writer::MODF_FUNCTION,
super::writer::ABS_FUNCTION,
super::writer::DIV_FUNCTION,
// DOT_FUNCTION_PREFIX variants are added dynamically below
super::writer::MOD_FUNCTION,
super::writer::NEG_FUNCTION,
super::writer::F2I32_FUNCTION,
@ -359,8 +361,31 @@ pub const RESERVED: &[&str] = &[
super::writer::EXTERNAL_TEXTURE_WRAPPER_STRUCT,
];
// The set of concrete integer dot product function variants.
// This must match the set of names that could be produced by
// `Writer::get_dot_wrapper_function_helper_name`.
static DOT_FUNCTION_NAMES: RacyLock<Vec<String>> = RacyLock::new(|| {
let mut names = Vec::new();
for scalar in concrete_int_scalars().map(crate::Scalar::to_msl_name) {
for size_suffix in vector_sizes().map(vector_size_str) {
let fun_name = format!(
"{}_{}{}",
super::writer::DOT_FUNCTION_PREFIX,
scalar,
size_suffix
);
names.push(fun_name);
}
}
names
});
/// The above set of reserved keywords, turned into a cached HashSet. This saves
/// significant time during [`Namer::reset`](crate::proc::Namer::reset).
///
/// See <https://github.com/gfx-rs/wgpu/pull/7338> for benchmarks.
pub static RESERVED_SET: RacyLock<KeywordSet> = RacyLock::new(|| KeywordSet::from_iter(RESERVED));
pub static RESERVED_SET: RacyLock<KeywordSet> = RacyLock::new(|| {
let mut set = KeywordSet::from_iter(RESERVED);
set.extend(DOT_FUNCTION_NAMES.iter().map(String::as_str));
set
});

View File

@ -19,7 +19,7 @@ use crate::{
back::{self, get_entry_points, Baked},
common,
proc::{
self,
self, concrete_int_scalars,
index::{self, BoundsCheck},
ExternalTextureNameKey, NameKey, TypeResolution,
},
@ -55,6 +55,7 @@ pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
@ -488,7 +489,7 @@ pub struct Writer<W> {
}
impl crate::Scalar {
fn to_msl_name(self) -> &'static str {
pub(super) fn to_msl_name(self) -> &'static str {
use crate::ScalarKind as Sk;
match self {
Self {
@ -2334,26 +2335,28 @@ impl<W: Write> Writer<W> {
crate::TypeInner::Vector {
scalar:
crate::Scalar {
// Resolve float values to MSL's builtin dot function.
kind: crate::ScalarKind::Float,
..
},
..
} => "dot",
crate::TypeInner::Vector { size, .. } => {
return self.put_dot_product(
arg,
arg1.unwrap(),
size as usize,
|writer, arg, index| {
// Write the vector expression; this expression is marked to be
// cached so unless it can't be cached (for example, it's a Constant)
// it shouldn't produce large expressions.
writer.put_expression(arg, context, true)?;
// Access the current component on the vector.
write!(writer.out, ".{}", back::COMPONENTS[index])?;
Ok(())
crate::TypeInner::Vector {
size,
scalar:
scalar @ crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
..
},
);
} => {
// Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`.
let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
write!(self.out, "{fun_name}(")?;
self.put_expression(arg, context, true)?;
write!(self.out, ", ")?;
self.put_expression(arg1.unwrap(), context, true)?;
write!(self.out, ")")?;
return Ok(());
}
_ => unreachable!(
"Correct TypeInner for dot product should be already validated"
@ -3370,26 +3373,15 @@ impl<W: Write> Writer<W> {
} = *expr
{
match fun {
crate::MathFunction::Dot => {
// WGSL's `dot` function works on any `vecN` type, but Metal's only
// works on floating-point vectors, so we emit inline code for
// integer vector `dot` calls. But that code uses each argument `N`
// times, once for each component (see `put_dot_product`), so to
// avoid duplicated evaluation, we must bake integer operands.
// check what kind of product this is depending
// on the resolve type of the Dot function itself
let inner = context.resolve_type(expr_handle);
if let crate::TypeInner::Scalar(scalar) = *inner {
match scalar.kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
}
_ => {}
}
}
}
// WGSL's `dot` function works on any `vecN` type, but Metal's only
// works on floating-point vectors, so we emit inline code for
// integer vector `dot` calls. But that code uses each argument `N`
// times, once for each component (see `put_dot_product`), so to
// avoid duplicated evaluation, we must bake integer operands.
// This applies both when using the polyfill (because of the duplicate
// evaluation issue) and when we don't use the polyfill (because we
// need them to be emitted before casting to packed chars -- see the
// comment at the call to `put_casting_to_packed_chars`).
crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
@ -5806,6 +5798,24 @@ template <typename A>
Ok(())
}
/// Build the mangled helper name for integer vector dot products.
///
/// `scalar` must be a concrete integer scalar type.
///
/// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`).
fn get_dot_wrapper_function_helper_name(
&self,
scalar: crate::Scalar,
size: crate::VectorSize,
) -> String {
// Check for consistency with [`super::keywords::RESERVED_SET`]
debug_assert!(concrete_int_scalars().any(|s| s == scalar));
let type_name = scalar.to_msl_name();
let size_suffix = common::vector_size_str(size);
format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
}
#[allow(clippy::too_many_arguments)]
fn write_wrapped_math_function(
&mut self,
@ -5861,6 +5871,45 @@ template <typename A>
writeln!(self.out, "}}")?;
writeln!(self.out)?;
}
crate::MathFunction::Dot => match *arg_ty {
crate::TypeInner::Vector { size, scalar }
if matches!(
scalar.kind,
crate::ScalarKind::Sint | crate::ScalarKind::Uint
) =>
{
// De-duplicate per (fun, arg type) like other wrapped math functions
let wrapped = WrappedFunction::Math {
fun,
arg_ty: (Some(size), scalar),
};
if !self.wrapped_functions.insert(wrapped) {
return Ok(());
}
let mut vec_ty = String::new();
put_numeric_type(&mut vec_ty, scalar, &[size])?;
let mut ret_ty = String::new();
put_numeric_type(&mut ret_ty, scalar, &[])?;
let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
// Emit function signature and body using put_dot_product for the expression
writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
let level = back::Level(1);
write!(self.out, "{level}return ")?;
self.put_dot_product("a", "b", size as usize, |writer, name, index| {
write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
Ok(())
})?;
writeln!(self.out, ";")?;
writeln!(self.out, "}}")?;
writeln!(self.out)?;
}
_ => {}
},
_ => {}
}
Ok(())

View File

@ -13,14 +13,22 @@ metal::float2 test_fma(
return metal::fma(a, b, c);
}
int naga_dot_int2(metal::int2 a, metal::int2 b) {
return ( + a.x * b.x + a.y * b.y);
}
uint naga_dot_uint3(metal::uint3 a, metal::uint3 b) {
return ( + a.x * b.x + a.y * b.y + a.z * b.z);
}
int test_integer_dot_product(
) {
metal::int2 a_2_ = metal::int2(1);
metal::int2 b_2_ = metal::int2(1);
int c_2_ = ( + a_2_.x * b_2_.x + a_2_.y * b_2_.y);
int c_2_ = naga_dot_int2(a_2_, b_2_);
metal::uint3 a_3_ = metal::uint3(1u);
metal::uint3 b_3_ = metal::uint3(1u);
uint c_3_ = ( + a_3_.x * b_3_.x + a_3_.y * b_3_.y + a_3_.z * b_3_.z);
uint c_3_ = naga_dot_uint3(a_3_, b_3_);
return 32;
}

View File

@ -43,6 +43,10 @@ long naga_abs(long val) {
return metal::select(as_type<long>(-as_type<ulong>(val)), val, val >= 0);
}
long naga_dot_long2(metal::long2 a, metal::long2 b) {
return ( + a.x * b.x + a.y * b.y);
}
long int64_function(
long x,
thread long& private_variable,
@ -111,11 +115,9 @@ long int64_function(
long _e130 = val;
val = as_type<long>(as_type<ulong>(_e130) + as_type<ulong>(metal::clamp(_e126, _e127, _e128)));
long _e132 = val;
metal::long2 _e133 = metal::long2(_e132);
long _e134 = val;
metal::long2 _e135 = metal::long2(_e134);
long _e137 = val;
val = as_type<long>(as_type<ulong>(_e137) + as_type<ulong>(( + _e133.x * _e135.x + _e133.y * _e135.y)));
val = as_type<long>(as_type<ulong>(_e137) + as_type<ulong>(naga_dot_long2(metal::long2(_e132), metal::long2(_e134))));
long _e139 = val;
long _e140 = val;
long _e142 = val;
@ -135,6 +137,10 @@ ulong naga_f2u64(float value) {
return static_cast<ulong>(metal::clamp(value, 0.0, 18446743000000000000.0));
}
ulong naga_dot_ulong2(metal::ulong2 a, metal::ulong2 b) {
return ( + a.x * b.x + a.y * b.y);
}
ulong uint64_function(
ulong x_1,
constant UniformCompatible& input_uniform,
@ -199,11 +205,9 @@ ulong uint64_function(
ulong _e125 = val_1;
val_1 = _e125 + metal::clamp(_e121, _e122, _e123);
ulong _e127 = val_1;
metal::ulong2 _e128 = metal::ulong2(_e127);
ulong _e129 = val_1;
metal::ulong2 _e130 = metal::ulong2(_e129);
ulong _e132 = val_1;
val_1 = _e132 + ( + _e128.x * _e130.x + _e128.y * _e130.y);
val_1 = _e132 + naga_dot_ulong2(metal::ulong2(_e127), metal::ulong2(_e129));
ulong _e134 = val_1;
ulong _e135 = val_1;
ulong _e137 = val_1;