mirror of
https://github.com/gfx-rs/wgpu.git
synced 2025-12-08 21:26:17 +00:00
[naga msl-out] Annotate dot product functions as wrapped functions
This commit is contained in:
parent
e620027f95
commit
1f99103be8
@ -1,12 +1,13 @@
|
||||
use crate::proc::KeywordSet;
|
||||
use crate::proc::{concrete_int_scalars, vector_size_str, vector_sizes, KeywordSet};
|
||||
use crate::racy_lock::RacyLock;
|
||||
use alloc::{format, string::String, vec::Vec};
|
||||
|
||||
// MSLS - Metal Shading Language Specification:
|
||||
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||
//
|
||||
// C++ - Standard for Programming Language C++ (N4431)
|
||||
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4431.pdf
|
||||
pub const RESERVED: &[&str] = &[
|
||||
const RESERVED: &[&str] = &[
|
||||
// Undocumented
|
||||
"assert", // found in https://github.com/gfx-rs/wgpu/issues/5347
|
||||
// Standard for Programming Language C++ (N4431): 2.5 Alternative tokens
|
||||
@ -346,6 +347,7 @@ pub const RESERVED: &[&str] = &[
|
||||
super::writer::MODF_FUNCTION,
|
||||
super::writer::ABS_FUNCTION,
|
||||
super::writer::DIV_FUNCTION,
|
||||
// DOT_FUNCTION_PREFIX variants are added dynamically below
|
||||
super::writer::MOD_FUNCTION,
|
||||
super::writer::NEG_FUNCTION,
|
||||
super::writer::F2I32_FUNCTION,
|
||||
@ -359,8 +361,31 @@ pub const RESERVED: &[&str] = &[
|
||||
super::writer::EXTERNAL_TEXTURE_WRAPPER_STRUCT,
|
||||
];
|
||||
|
||||
// The set of concrete integer dot product function variants.
|
||||
// This must match the set of names that could be produced by
|
||||
// `Writer::get_dot_wrapper_function_helper_name`.
|
||||
static DOT_FUNCTION_NAMES: RacyLock<Vec<String>> = RacyLock::new(|| {
|
||||
let mut names = Vec::new();
|
||||
for scalar in concrete_int_scalars().map(crate::Scalar::to_msl_name) {
|
||||
for size_suffix in vector_sizes().map(vector_size_str) {
|
||||
let fun_name = format!(
|
||||
"{}_{}{}",
|
||||
super::writer::DOT_FUNCTION_PREFIX,
|
||||
scalar,
|
||||
size_suffix
|
||||
);
|
||||
names.push(fun_name);
|
||||
}
|
||||
}
|
||||
names
|
||||
});
|
||||
|
||||
/// The above set of reserved keywords, turned into a cached HashSet. This saves
|
||||
/// significant time during [`Namer::reset`](crate::proc::Namer::reset).
|
||||
///
|
||||
/// See <https://github.com/gfx-rs/wgpu/pull/7338> for benchmarks.
|
||||
pub static RESERVED_SET: RacyLock<KeywordSet> = RacyLock::new(|| KeywordSet::from_iter(RESERVED));
|
||||
pub static RESERVED_SET: RacyLock<KeywordSet> = RacyLock::new(|| {
|
||||
let mut set = KeywordSet::from_iter(RESERVED);
|
||||
set.extend(DOT_FUNCTION_NAMES.iter().map(String::as_str));
|
||||
set
|
||||
});
|
||||
|
||||
@ -19,7 +19,7 @@ use crate::{
|
||||
back::{self, get_entry_points, Baked},
|
||||
common,
|
||||
proc::{
|
||||
self,
|
||||
self, concrete_int_scalars,
|
||||
index::{self, BoundsCheck},
|
||||
ExternalTextureNameKey, NameKey, TypeResolution,
|
||||
},
|
||||
@ -55,6 +55,7 @@ pub(crate) const MODF_FUNCTION: &str = "naga_modf";
|
||||
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
|
||||
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
|
||||
pub(crate) const DIV_FUNCTION: &str = "naga_div";
|
||||
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
|
||||
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
|
||||
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
|
||||
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
|
||||
@ -488,7 +489,7 @@ pub struct Writer<W> {
|
||||
}
|
||||
|
||||
impl crate::Scalar {
|
||||
fn to_msl_name(self) -> &'static str {
|
||||
pub(super) fn to_msl_name(self) -> &'static str {
|
||||
use crate::ScalarKind as Sk;
|
||||
match self {
|
||||
Self {
|
||||
@ -2334,26 +2335,28 @@ impl<W: Write> Writer<W> {
|
||||
crate::TypeInner::Vector {
|
||||
scalar:
|
||||
crate::Scalar {
|
||||
// Resolve float values to MSL's builtin dot function.
|
||||
kind: crate::ScalarKind::Float,
|
||||
..
|
||||
},
|
||||
..
|
||||
} => "dot",
|
||||
crate::TypeInner::Vector { size, .. } => {
|
||||
return self.put_dot_product(
|
||||
arg,
|
||||
arg1.unwrap(),
|
||||
size as usize,
|
||||
|writer, arg, index| {
|
||||
// Write the vector expression; this expression is marked to be
|
||||
// cached so unless it can't be cached (for example, it's a Constant)
|
||||
// it shouldn't produce large expressions.
|
||||
writer.put_expression(arg, context, true)?;
|
||||
// Access the current component on the vector.
|
||||
write!(writer.out, ".{}", back::COMPONENTS[index])?;
|
||||
Ok(())
|
||||
crate::TypeInner::Vector {
|
||||
size,
|
||||
scalar:
|
||||
scalar @ crate::Scalar {
|
||||
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
|
||||
..
|
||||
},
|
||||
);
|
||||
} => {
|
||||
// Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`.
|
||||
let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
|
||||
write!(self.out, "{fun_name}(")?;
|
||||
self.put_expression(arg, context, true)?;
|
||||
write!(self.out, ", ")?;
|
||||
self.put_expression(arg1.unwrap(), context, true)?;
|
||||
write!(self.out, ")")?;
|
||||
return Ok(());
|
||||
}
|
||||
_ => unreachable!(
|
||||
"Correct TypeInner for dot product should be already validated"
|
||||
@ -3370,26 +3373,15 @@ impl<W: Write> Writer<W> {
|
||||
} = *expr
|
||||
{
|
||||
match fun {
|
||||
crate::MathFunction::Dot => {
|
||||
// WGSL's `dot` function works on any `vecN` type, but Metal's only
|
||||
// works on floating-point vectors, so we emit inline code for
|
||||
// integer vector `dot` calls. But that code uses each argument `N`
|
||||
// times, once for each component (see `put_dot_product`), so to
|
||||
// avoid duplicated evaluation, we must bake integer operands.
|
||||
|
||||
// check what kind of product this is depending
|
||||
// on the resolve type of the Dot function itself
|
||||
let inner = context.resolve_type(expr_handle);
|
||||
if let crate::TypeInner::Scalar(scalar) = *inner {
|
||||
match scalar.kind {
|
||||
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
|
||||
self.need_bake_expressions.insert(arg);
|
||||
self.need_bake_expressions.insert(arg1.unwrap());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
// WGSL's `dot` function works on any `vecN` type, but Metal's only
|
||||
// works on floating-point vectors, so we emit inline code for
|
||||
// integer vector `dot` calls. But that code uses each argument `N`
|
||||
// times, once for each component (see `put_dot_product`), so to
|
||||
// avoid duplicated evaluation, we must bake integer operands.
|
||||
// This applies both when using the polyfill (because of the duplicate
|
||||
// evaluation issue) and when we don't use the polyfill (because we
|
||||
// need them to be emitted before casting to packed chars -- see the
|
||||
// comment at the call to `put_casting_to_packed_chars`).
|
||||
crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
|
||||
self.need_bake_expressions.insert(arg);
|
||||
self.need_bake_expressions.insert(arg1.unwrap());
|
||||
@ -5806,6 +5798,24 @@ template <typename A>
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build the mangled helper name for integer vector dot products.
|
||||
///
|
||||
/// `scalar` must be a concrete integer scalar type.
|
||||
///
|
||||
/// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`).
|
||||
fn get_dot_wrapper_function_helper_name(
|
||||
&self,
|
||||
scalar: crate::Scalar,
|
||||
size: crate::VectorSize,
|
||||
) -> String {
|
||||
// Check for consistency with [`super::keywords::RESERVED_SET`]
|
||||
debug_assert!(concrete_int_scalars().any(|s| s == scalar));
|
||||
|
||||
let type_name = scalar.to_msl_name();
|
||||
let size_suffix = common::vector_size_str(size);
|
||||
format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn write_wrapped_math_function(
|
||||
&mut self,
|
||||
@ -5861,6 +5871,45 @@ template <typename A>
|
||||
writeln!(self.out, "}}")?;
|
||||
writeln!(self.out)?;
|
||||
}
|
||||
|
||||
crate::MathFunction::Dot => match *arg_ty {
|
||||
crate::TypeInner::Vector { size, scalar }
|
||||
if matches!(
|
||||
scalar.kind,
|
||||
crate::ScalarKind::Sint | crate::ScalarKind::Uint
|
||||
) =>
|
||||
{
|
||||
// De-duplicate per (fun, arg type) like other wrapped math functions
|
||||
let wrapped = WrappedFunction::Math {
|
||||
fun,
|
||||
arg_ty: (Some(size), scalar),
|
||||
};
|
||||
if !self.wrapped_functions.insert(wrapped) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut vec_ty = String::new();
|
||||
put_numeric_type(&mut vec_ty, scalar, &[size])?;
|
||||
let mut ret_ty = String::new();
|
||||
put_numeric_type(&mut ret_ty, scalar, &[])?;
|
||||
|
||||
let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
|
||||
|
||||
// Emit function signature and body using put_dot_product for the expression
|
||||
writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
|
||||
let level = back::Level(1);
|
||||
write!(self.out, "{level}return ")?;
|
||||
self.put_dot_product("a", "b", size as usize, |writer, name, index| {
|
||||
write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
|
||||
Ok(())
|
||||
})?;
|
||||
writeln!(self.out, ";")?;
|
||||
writeln!(self.out, "}}")?;
|
||||
writeln!(self.out)?;
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
|
||||
_ => {}
|
||||
}
|
||||
Ok(())
|
||||
|
||||
@ -13,14 +13,22 @@ metal::float2 test_fma(
|
||||
return metal::fma(a, b, c);
|
||||
}
|
||||
|
||||
int naga_dot_int2(metal::int2 a, metal::int2 b) {
|
||||
return ( + a.x * b.x + a.y * b.y);
|
||||
}
|
||||
|
||||
uint naga_dot_uint3(metal::uint3 a, metal::uint3 b) {
|
||||
return ( + a.x * b.x + a.y * b.y + a.z * b.z);
|
||||
}
|
||||
|
||||
int test_integer_dot_product(
|
||||
) {
|
||||
metal::int2 a_2_ = metal::int2(1);
|
||||
metal::int2 b_2_ = metal::int2(1);
|
||||
int c_2_ = ( + a_2_.x * b_2_.x + a_2_.y * b_2_.y);
|
||||
int c_2_ = naga_dot_int2(a_2_, b_2_);
|
||||
metal::uint3 a_3_ = metal::uint3(1u);
|
||||
metal::uint3 b_3_ = metal::uint3(1u);
|
||||
uint c_3_ = ( + a_3_.x * b_3_.x + a_3_.y * b_3_.y + a_3_.z * b_3_.z);
|
||||
uint c_3_ = naga_dot_uint3(a_3_, b_3_);
|
||||
return 32;
|
||||
}
|
||||
|
||||
|
||||
@ -43,6 +43,10 @@ long naga_abs(long val) {
|
||||
return metal::select(as_type<long>(-as_type<ulong>(val)), val, val >= 0);
|
||||
}
|
||||
|
||||
long naga_dot_long2(metal::long2 a, metal::long2 b) {
|
||||
return ( + a.x * b.x + a.y * b.y);
|
||||
}
|
||||
|
||||
long int64_function(
|
||||
long x,
|
||||
thread long& private_variable,
|
||||
@ -111,11 +115,9 @@ long int64_function(
|
||||
long _e130 = val;
|
||||
val = as_type<long>(as_type<ulong>(_e130) + as_type<ulong>(metal::clamp(_e126, _e127, _e128)));
|
||||
long _e132 = val;
|
||||
metal::long2 _e133 = metal::long2(_e132);
|
||||
long _e134 = val;
|
||||
metal::long2 _e135 = metal::long2(_e134);
|
||||
long _e137 = val;
|
||||
val = as_type<long>(as_type<ulong>(_e137) + as_type<ulong>(( + _e133.x * _e135.x + _e133.y * _e135.y)));
|
||||
val = as_type<long>(as_type<ulong>(_e137) + as_type<ulong>(naga_dot_long2(metal::long2(_e132), metal::long2(_e134))));
|
||||
long _e139 = val;
|
||||
long _e140 = val;
|
||||
long _e142 = val;
|
||||
@ -135,6 +137,10 @@ ulong naga_f2u64(float value) {
|
||||
return static_cast<ulong>(metal::clamp(value, 0.0, 18446743000000000000.0));
|
||||
}
|
||||
|
||||
ulong naga_dot_ulong2(metal::ulong2 a, metal::ulong2 b) {
|
||||
return ( + a.x * b.x + a.y * b.y);
|
||||
}
|
||||
|
||||
ulong uint64_function(
|
||||
ulong x_1,
|
||||
constant UniformCompatible& input_uniform,
|
||||
@ -199,11 +205,9 @@ ulong uint64_function(
|
||||
ulong _e125 = val_1;
|
||||
val_1 = _e125 + metal::clamp(_e121, _e122, _e123);
|
||||
ulong _e127 = val_1;
|
||||
metal::ulong2 _e128 = metal::ulong2(_e127);
|
||||
ulong _e129 = val_1;
|
||||
metal::ulong2 _e130 = metal::ulong2(_e129);
|
||||
ulong _e132 = val_1;
|
||||
val_1 = _e132 + ( + _e128.x * _e130.x + _e128.y * _e130.y);
|
||||
val_1 = _e132 + naga_dot_ulong2(metal::ulong2(_e127), metal::ulong2(_e129));
|
||||
ulong _e134 = val_1;
|
||||
ulong _e135 = val_1;
|
||||
ulong _e137 = val_1;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user