mirror of
https://github.com/gfx-rs/wgpu.git
synced 2025-12-08 21:26:17 +00:00
[naga msl-out] Annotate dot product functions as wrapped functions
This commit is contained in:
parent
e620027f95
commit
1f99103be8
@ -1,12 +1,13 @@
|
|||||||
use crate::proc::KeywordSet;
|
use crate::proc::{concrete_int_scalars, vector_size_str, vector_sizes, KeywordSet};
|
||||||
use crate::racy_lock::RacyLock;
|
use crate::racy_lock::RacyLock;
|
||||||
|
use alloc::{format, string::String, vec::Vec};
|
||||||
|
|
||||||
// MSLS - Metal Shading Language Specification:
|
// MSLS - Metal Shading Language Specification:
|
||||||
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||||
//
|
//
|
||||||
// C++ - Standard for Programming Language C++ (N4431)
|
// C++ - Standard for Programming Language C++ (N4431)
|
||||||
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4431.pdf
|
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4431.pdf
|
||||||
pub const RESERVED: &[&str] = &[
|
const RESERVED: &[&str] = &[
|
||||||
// Undocumented
|
// Undocumented
|
||||||
"assert", // found in https://github.com/gfx-rs/wgpu/issues/5347
|
"assert", // found in https://github.com/gfx-rs/wgpu/issues/5347
|
||||||
// Standard for Programming Language C++ (N4431): 2.5 Alternative tokens
|
// Standard for Programming Language C++ (N4431): 2.5 Alternative tokens
|
||||||
@ -346,6 +347,7 @@ pub const RESERVED: &[&str] = &[
|
|||||||
super::writer::MODF_FUNCTION,
|
super::writer::MODF_FUNCTION,
|
||||||
super::writer::ABS_FUNCTION,
|
super::writer::ABS_FUNCTION,
|
||||||
super::writer::DIV_FUNCTION,
|
super::writer::DIV_FUNCTION,
|
||||||
|
// DOT_FUNCTION_PREFIX variants are added dynamically below
|
||||||
super::writer::MOD_FUNCTION,
|
super::writer::MOD_FUNCTION,
|
||||||
super::writer::NEG_FUNCTION,
|
super::writer::NEG_FUNCTION,
|
||||||
super::writer::F2I32_FUNCTION,
|
super::writer::F2I32_FUNCTION,
|
||||||
@ -359,8 +361,31 @@ pub const RESERVED: &[&str] = &[
|
|||||||
super::writer::EXTERNAL_TEXTURE_WRAPPER_STRUCT,
|
super::writer::EXTERNAL_TEXTURE_WRAPPER_STRUCT,
|
||||||
];
|
];
|
||||||
|
|
||||||
|
// The set of concrete integer dot product function variants.
// Each entry has the form `{DOT_FUNCTION_PREFIX}_{scalar}{size}` (for example
// `naga_dot_int2`): one name for every pairing of a `concrete_int_scalars()`
// scalar, rendered as its MSL type name, with a `vector_sizes()` size,
// rendered through `vector_size_str`.
|
||||||
|
// This must match the set of names that could be produced by
|
||||||
|
// `Writer::get_dot_wrapper_function_helper_name`.
// The resulting list is merged into `RESERVED_SET` alongside `RESERVED`.
// NOTE(review): the closure is presumably evaluated once by `RacyLock` on
// first access -- confirm against `crate::racy_lock`.
|
||||||
|
static DOT_FUNCTION_NAMES: RacyLock<Vec<String>> = RacyLock::new(|| {
|
||||||
|
let mut names = Vec::new();
|
||||||
|
// Outer loop: the MSL type name of each concrete integer scalar.
for scalar in concrete_int_scalars().map(crate::Scalar::to_msl_name) {
|
||||||
|
// Inner loop: the textual suffix of each vector size.
for size_suffix in vector_sizes().map(vector_size_str) {
|
||||||
|
// Same `prefix`, `_`, type name, size suffix layout that the writer-side
// helper formats, so both sides agree on the spelling.
let fun_name = format!(
|
||||||
|
"{}_{}{}",
|
||||||
|
super::writer::DOT_FUNCTION_PREFIX,
|
||||||
|
scalar,
|
||||||
|
size_suffix
|
||||||
|
);
|
||||||
|
names.push(fun_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
names
|
||||||
|
});
|
||||||
|
|
||||||
/// The above set of reserved keywords, turned into a cached HashSet. This saves
|
/// The above set of reserved keywords, turned into a cached HashSet. This saves
|
||||||
/// significant time during [`Namer::reset`](crate::proc::Namer::reset).
|
/// significant time during [`Namer::reset`](crate::proc::Namer::reset).
|
||||||
///
|
///
|
||||||
/// See <https://github.com/gfx-rs/wgpu/pull/7338> for benchmarks.
|
/// See <https://github.com/gfx-rs/wgpu/pull/7338> for benchmarks.
///
/// Besides the entries of `RESERVED`, the set also holds every generated
/// integer dot product helper name from `DOT_FUNCTION_NAMES`.
|
||||||
pub static RESERVED_SET: RacyLock<KeywordSet> = RacyLock::new(|| KeywordSet::from_iter(RESERVED));
|
pub static RESERVED_SET: RacyLock<KeywordSet> = RacyLock::new(|| {
|
||||||
|
// Start from the static keyword list.
let mut set = KeywordSet::from_iter(RESERVED);
|
||||||
|
// Add the dynamically built dot helper names, borrowed as `&str`.
set.extend(DOT_FUNCTION_NAMES.iter().map(String::as_str));
|
||||||
|
set
|
||||||
|
});
|
||||||
|
|||||||
@ -19,7 +19,7 @@ use crate::{
|
|||||||
back::{self, get_entry_points, Baked},
|
back::{self, get_entry_points, Baked},
|
||||||
common,
|
common,
|
||||||
proc::{
|
proc::{
|
||||||
self,
|
self, concrete_int_scalars,
|
||||||
index::{self, BoundsCheck},
|
index::{self, BoundsCheck},
|
||||||
ExternalTextureNameKey, NameKey, TypeResolution,
|
ExternalTextureNameKey, NameKey, TypeResolution,
|
||||||
},
|
},
|
||||||
@ -55,6 +55,7 @@ pub(crate) const MODF_FUNCTION: &str = "naga_modf";
|
|||||||
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
|
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
|
||||||
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
|
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
|
||||||
pub(crate) const DIV_FUNCTION: &str = "naga_div";
|
pub(crate) const DIV_FUNCTION: &str = "naga_div";
|
||||||
|
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
|
||||||
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
|
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
|
||||||
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
|
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
|
||||||
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
|
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
|
||||||
@ -488,7 +489,7 @@ pub struct Writer<W> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl crate::Scalar {
|
impl crate::Scalar {
|
||||||
fn to_msl_name(self) -> &'static str {
|
pub(super) fn to_msl_name(self) -> &'static str {
|
||||||
use crate::ScalarKind as Sk;
|
use crate::ScalarKind as Sk;
|
||||||
match self {
|
match self {
|
||||||
Self {
|
Self {
|
||||||
@ -2334,26 +2335,28 @@ impl<W: Write> Writer<W> {
|
|||||||
crate::TypeInner::Vector {
|
crate::TypeInner::Vector {
|
||||||
scalar:
|
scalar:
|
||||||
crate::Scalar {
|
crate::Scalar {
|
||||||
|
// Resolve float values to MSL's builtin dot function.
|
||||||
kind: crate::ScalarKind::Float,
|
kind: crate::ScalarKind::Float,
|
||||||
..
|
..
|
||||||
},
|
},
|
||||||
..
|
..
|
||||||
} => "dot",
|
} => "dot",
|
||||||
crate::TypeInner::Vector { size, .. } => {
|
crate::TypeInner::Vector {
|
||||||
return self.put_dot_product(
|
size,
|
||||||
arg,
|
scalar:
|
||||||
arg1.unwrap(),
|
scalar @ crate::Scalar {
|
||||||
size as usize,
|
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
|
||||||
|writer, arg, index| {
|
..
|
||||||
// Write the vector expression; this expression is marked to be
|
|
||||||
// cached so unless it can't be cached (for example, it's a Constant)
|
|
||||||
// it shouldn't produce large expressions.
|
|
||||||
writer.put_expression(arg, context, true)?;
|
|
||||||
// Access the current component on the vector.
|
|
||||||
write!(writer.out, ".{}", back::COMPONENTS[index])?;
|
|
||||||
Ok(())
|
|
||||||
},
|
},
|
||||||
);
|
} => {
|
||||||
|
// Integer vector dot: call our mangled helper `{DOT_FUNCTION_PREFIX}_{type}{N}(a, b)` (e.g. `naga_dot_int2(a, b)`).
|
||||||
|
let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
|
||||||
|
write!(self.out, "{fun_name}(")?;
|
||||||
|
self.put_expression(arg, context, true)?;
|
||||||
|
write!(self.out, ", ")?;
|
||||||
|
self.put_expression(arg1.unwrap(), context, true)?;
|
||||||
|
write!(self.out, ")")?;
|
||||||
|
return Ok(());
|
||||||
}
|
}
|
||||||
_ => unreachable!(
|
_ => unreachable!(
|
||||||
"Correct TypeInner for dot product should be already validated"
|
"Correct TypeInner for dot product should be already validated"
|
||||||
@ -3370,26 +3373,15 @@ impl<W: Write> Writer<W> {
|
|||||||
} = *expr
|
} = *expr
|
||||||
{
|
{
|
||||||
match fun {
|
match fun {
|
||||||
crate::MathFunction::Dot => {
|
|
||||||
// WGSL's `dot` function works on any `vecN` type, but Metal's only
|
// WGSL's `dot` function works on any `vecN` type, but Metal's only
|
||||||
// works on floating-point vectors, so we emit inline code for
|
// works on floating-point vectors, so we emit inline code for
|
||||||
// integer vector `dot` calls. But that code uses each argument `N`
|
// integer vector `dot` calls. But that code uses each argument `N`
|
||||||
// times, once for each component (see `put_dot_product`), so to
|
// times, once for each component (see `put_dot_product`), so to
|
||||||
// avoid duplicated evaluation, we must bake integer operands.
|
// avoid duplicated evaluation, we must bake integer operands.
|
||||||
|
// This applies both when using the polyfill (because of the duplicate
|
||||||
// check what kind of product this is depending
|
// evaluation issue) and when we don't use the polyfill (because we
|
||||||
// on the resolve type of the Dot function itself
|
// need them to be emitted before casting to packed chars -- see the
|
||||||
let inner = context.resolve_type(expr_handle);
|
// comment at the call to `put_casting_to_packed_chars`).
|
||||||
if let crate::TypeInner::Scalar(scalar) = *inner {
|
|
||||||
match scalar.kind {
|
|
||||||
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
|
|
||||||
self.need_bake_expressions.insert(arg);
|
|
||||||
self.need_bake_expressions.insert(arg1.unwrap());
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
|
crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
|
||||||
self.need_bake_expressions.insert(arg);
|
self.need_bake_expressions.insert(arg);
|
||||||
self.need_bake_expressions.insert(arg1.unwrap());
|
self.need_bake_expressions.insert(arg1.unwrap());
|
||||||
@ -5806,6 +5798,24 @@ template <typename A>
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build the mangled helper name for integer vector dot products.
|
||||||
|
///
|
||||||
|
/// `scalar` must be a concrete integer scalar type.
/// `size` supplies the `{N}` suffix through `common::vector_size_str`.
|
||||||
|
///
|
||||||
|
/// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`).
///
/// Returns the name as a newly formatted `String`; `self` is not read.
|
||||||
|
fn get_dot_wrapper_function_helper_name(
|
||||||
|
&self,
|
||||||
|
scalar: crate::Scalar,
|
||||||
|
size: crate::VectorSize,
|
||||||
|
) -> String {
|
||||||
|
// Check for consistency with [`super::keywords::RESERVED_SET`]
// (debug builds only): the reserved dot helper names are generated from
// `concrete_int_scalars()`, so any other scalar would produce a name
// that is not among them.
|
||||||
|
debug_assert!(concrete_int_scalars().any(|s| s == scalar));
|
||||||
|
|
||||||
|
// MSL spelling of the scalar type, e.g. the `int` in `naga_dot_int3`.
let type_name = scalar.to_msl_name();
|
||||||
|
let size_suffix = common::vector_size_str(size);
|
||||||
|
format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn write_wrapped_math_function(
|
fn write_wrapped_math_function(
|
||||||
&mut self,
|
&mut self,
|
||||||
@ -5861,6 +5871,45 @@ template <typename A>
|
|||||||
writeln!(self.out, "}}")?;
|
writeln!(self.out, "}}")?;
|
||||||
writeln!(self.out)?;
|
writeln!(self.out)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
crate::MathFunction::Dot => match *arg_ty {
|
||||||
|
crate::TypeInner::Vector { size, scalar }
|
||||||
|
if matches!(
|
||||||
|
scalar.kind,
|
||||||
|
crate::ScalarKind::Sint | crate::ScalarKind::Uint
|
||||||
|
) =>
|
||||||
|
{
|
||||||
|
// De-duplicate per (fun, arg type) like other wrapped math functions
|
||||||
|
let wrapped = WrappedFunction::Math {
|
||||||
|
fun,
|
||||||
|
arg_ty: (Some(size), scalar),
|
||||||
|
};
|
||||||
|
if !self.wrapped_functions.insert(wrapped) {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut vec_ty = String::new();
|
||||||
|
put_numeric_type(&mut vec_ty, scalar, &[size])?;
|
||||||
|
let mut ret_ty = String::new();
|
||||||
|
put_numeric_type(&mut ret_ty, scalar, &[])?;
|
||||||
|
|
||||||
|
let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
|
||||||
|
|
||||||
|
// Emit function signature and body using put_dot_product for the expression
|
||||||
|
writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
|
||||||
|
let level = back::Level(1);
|
||||||
|
write!(self.out, "{level}return ")?;
|
||||||
|
self.put_dot_product("a", "b", size as usize, |writer, name, index| {
|
||||||
|
write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
|
||||||
|
Ok(())
|
||||||
|
})?;
|
||||||
|
writeln!(self.out, ";")?;
|
||||||
|
writeln!(self.out, "}}")?;
|
||||||
|
writeln!(self.out)?;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
},
|
||||||
|
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@ -13,14 +13,22 @@ metal::float2 test_fma(
|
|||||||
return metal::fma(a, b, c);
|
return metal::fma(a, b, c);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int naga_dot_int2(metal::int2 a, metal::int2 b) {
|
||||||
|
return ( + a.x * b.x + a.y * b.y);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint naga_dot_uint3(metal::uint3 a, metal::uint3 b) {
|
||||||
|
return ( + a.x * b.x + a.y * b.y + a.z * b.z);
|
||||||
|
}
|
||||||
|
|
||||||
int test_integer_dot_product(
|
int test_integer_dot_product(
|
||||||
) {
|
) {
|
||||||
metal::int2 a_2_ = metal::int2(1);
|
metal::int2 a_2_ = metal::int2(1);
|
||||||
metal::int2 b_2_ = metal::int2(1);
|
metal::int2 b_2_ = metal::int2(1);
|
||||||
int c_2_ = ( + a_2_.x * b_2_.x + a_2_.y * b_2_.y);
|
int c_2_ = naga_dot_int2(a_2_, b_2_);
|
||||||
metal::uint3 a_3_ = metal::uint3(1u);
|
metal::uint3 a_3_ = metal::uint3(1u);
|
||||||
metal::uint3 b_3_ = metal::uint3(1u);
|
metal::uint3 b_3_ = metal::uint3(1u);
|
||||||
uint c_3_ = ( + a_3_.x * b_3_.x + a_3_.y * b_3_.y + a_3_.z * b_3_.z);
|
uint c_3_ = naga_dot_uint3(a_3_, b_3_);
|
||||||
return 32;
|
return 32;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -43,6 +43,10 @@ long naga_abs(long val) {
|
|||||||
return metal::select(as_type<long>(-as_type<ulong>(val)), val, val >= 0);
|
return metal::select(as_type<long>(-as_type<ulong>(val)), val, val >= 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
long naga_dot_long2(metal::long2 a, metal::long2 b) {
|
||||||
|
return ( + a.x * b.x + a.y * b.y);
|
||||||
|
}
|
||||||
|
|
||||||
long int64_function(
|
long int64_function(
|
||||||
long x,
|
long x,
|
||||||
thread long& private_variable,
|
thread long& private_variable,
|
||||||
@ -111,11 +115,9 @@ long int64_function(
|
|||||||
long _e130 = val;
|
long _e130 = val;
|
||||||
val = as_type<long>(as_type<ulong>(_e130) + as_type<ulong>(metal::clamp(_e126, _e127, _e128)));
|
val = as_type<long>(as_type<ulong>(_e130) + as_type<ulong>(metal::clamp(_e126, _e127, _e128)));
|
||||||
long _e132 = val;
|
long _e132 = val;
|
||||||
metal::long2 _e133 = metal::long2(_e132);
|
|
||||||
long _e134 = val;
|
long _e134 = val;
|
||||||
metal::long2 _e135 = metal::long2(_e134);
|
|
||||||
long _e137 = val;
|
long _e137 = val;
|
||||||
val = as_type<long>(as_type<ulong>(_e137) + as_type<ulong>(( + _e133.x * _e135.x + _e133.y * _e135.y)));
|
val = as_type<long>(as_type<ulong>(_e137) + as_type<ulong>(naga_dot_long2(metal::long2(_e132), metal::long2(_e134))));
|
||||||
long _e139 = val;
|
long _e139 = val;
|
||||||
long _e140 = val;
|
long _e140 = val;
|
||||||
long _e142 = val;
|
long _e142 = val;
|
||||||
@ -135,6 +137,10 @@ ulong naga_f2u64(float value) {
|
|||||||
return static_cast<ulong>(metal::clamp(value, 0.0, 18446743000000000000.0));
|
return static_cast<ulong>(metal::clamp(value, 0.0, 18446743000000000000.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ulong naga_dot_ulong2(metal::ulong2 a, metal::ulong2 b) {
|
||||||
|
return ( + a.x * b.x + a.y * b.y);
|
||||||
|
}
|
||||||
|
|
||||||
ulong uint64_function(
|
ulong uint64_function(
|
||||||
ulong x_1,
|
ulong x_1,
|
||||||
constant UniformCompatible& input_uniform,
|
constant UniformCompatible& input_uniform,
|
||||||
@ -199,11 +205,9 @@ ulong uint64_function(
|
|||||||
ulong _e125 = val_1;
|
ulong _e125 = val_1;
|
||||||
val_1 = _e125 + metal::clamp(_e121, _e122, _e123);
|
val_1 = _e125 + metal::clamp(_e121, _e122, _e123);
|
||||||
ulong _e127 = val_1;
|
ulong _e127 = val_1;
|
||||||
metal::ulong2 _e128 = metal::ulong2(_e127);
|
|
||||||
ulong _e129 = val_1;
|
ulong _e129 = val_1;
|
||||||
metal::ulong2 _e130 = metal::ulong2(_e129);
|
|
||||||
ulong _e132 = val_1;
|
ulong _e132 = val_1;
|
||||||
val_1 = _e132 + ( + _e128.x * _e130.x + _e128.y * _e130.y);
|
val_1 = _e132 + naga_dot_ulong2(metal::ulong2(_e127), metal::ulong2(_e129));
|
||||||
ulong _e134 = val_1;
|
ulong _e134 = val_1;
|
||||||
ulong _e135 = val_1;
|
ulong _e135 = val_1;
|
||||||
ulong _e137 = val_1;
|
ulong _e137 = val_1;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user