Potentially optimize dot4{I,U}8Packed on Metal (#7653)

* Potentially optimize `dot4{I,U}8Packed` on Metal

This might allow the Metal compiler to emit faster code (but that's not
confirmed). See
<https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226>
for the optimization. The limitation to Metal 2.1+ is discussed here:
<https://github.com/gfx-rs/wgpu/pull/7574#issuecomment-2835464472>.

* [naga] Factor out new part of `put_block` on msl

CI on test failed because the latest changes to `put_block` made its
stack too big. Factoring out the new code into a separate method fixes
this issue.
This commit is contained in:
Robert Bamler 2025-05-21 18:47:59 +02:00 committed by GitHub
parent bc0a023cd1
commit d7e6a0e1fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 220 additions and 43 deletions

View File

@ -65,7 +65,7 @@ Bottom level categories:
Naga now infers the correct binding layout when a resource appears only in an assignment to `_`. By @andyleiserson in [#7540](https://github.com/gfx-rs/wgpu/pull/7540).
- Implement `dot4U8Packed` and `dot4I8Packed` for all backends, using specialized intrinsics on SPIR-V and HLSL if available, and polyfills everywhere else. By @robamler in [#7494](https://github.com/gfx-rs/wgpu/pull/7494) and [#7574](https://github.com/gfx-rs/wgpu/pull/7574).
- Implement `dot4U8Packed` and `dot4I8Packed` for all backends, using specialized intrinsics on SPIR-V, HLSL, and Metal if available, and polyfills everywhere else. By @robamler in [#7494](https://github.com/gfx-rs/wgpu/pull/7494), [#7574](https://github.com/gfx-rs/wgpu/pull/7574), and [#7653](https://github.com/gfx-rs/wgpu/pull/7653).
- Add polyfilled `pack4x{I,U}8Clamped` built-ins to all backends and WGSL frontend. By @ErichDonGubler in [#7546](https://github.com/gfx-rs/wgpu/pull/7546).
- Allow textureLoad's sample index arg to be unsigned. By @jimblandy in [#7625](https://github.com/gfx-rs/wgpu/pull/7625).
- Properly convert arguments to atomic operations. By @jimblandy in [#7573](https://github.com/gfx-rs/wgpu/pull/7573).

View File

@ -121,6 +121,9 @@ const fn scalar_is_int(scalar: crate::Scalar) -> bool {
/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
/// Prefix for the names of temporaries that hold reinterpreted expressions,
/// i.e. the result of `as_type<T>(...)`.
///
/// [`Reinterpreted`] writes this prefix first when it formats such a name.
const REINTERPRET_PREFIX: &str = "reinterpreted_";
/// Wrapper for identifier names for clamped level-of-detail values
///
/// Values of this type implement [`core::fmt::Display`], formatting as
@ -156,6 +159,30 @@ impl Display for ArraySizeMember {
}
}
/// Identifier of a temporary that stores `as_type<target_type>(orig)`.
///
/// The [`core::fmt::Display`] implementation spells the identifier as
/// [`REINTERPRET_PREFIX`], followed by `target_type`, followed by the handle
/// of `orig` written with an `_e` prefix.
#[derive(Clone, Copy)]
struct Reinterpreted<'a> {
    /// Name of the type the expression is reinterpreted as; it becomes part
    /// of the identifier.
    target_type: &'a str,
    /// The expression whose value is being reinterpreted.
    orig: Handle<crate::Expression>,
}

impl<'a> Reinterpreted<'a> {
    /// Pair a target type name with the expression that gets reinterpreted.
    const fn new(target_type: &'a str, orig: Handle<crate::Expression>) -> Self {
        Reinterpreted { target_type, orig }
    }
}

impl Display for Reinterpreted<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let Self { target_type, orig } = *self;
        // Fixed prefix and type name first, then the expression handle.
        write!(f, "{REINTERPRET_PREFIX}{target_type}")?;
        orig.write_prefixed(f, "_e")
    }
}
struct TypeContext<'a> {
handle: Handle<crate::Type>,
gctx: proc::GlobalCtx<'a>,
@ -1470,14 +1497,14 @@ impl<W: Write> Writer<W> {
/// Emit code for the arithmetic expression of the dot product.
///
/// The argument `extractor` is a function that accepts a `Writer`, a handle to a vector,
/// and an index. writes out the expression for the component at that index.
fn put_dot_product(
/// The argument `extractor` is a function that accepts a `Writer`, a vector, and
/// an index. It writes out the expression for the vector component at that index.
fn put_dot_product<T: Copy>(
&mut self,
arg: Handle<crate::Expression>,
arg1: Handle<crate::Expression>,
arg: T,
arg1: T,
size: usize,
extractor: impl Fn(&mut Self, Handle<crate::Expression>, usize) -> BackendResult,
extractor: impl Fn(&mut Self, T, usize) -> BackendResult,
) -> BackendResult {
// Write parentheses around the dot product expression to prevent operators
// with different precedences from applying earlier.
@ -2206,27 +2233,53 @@ impl<W: Write> Writer<W> {
),
},
fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed) => {
let conversion = match fun {
Mf::Dot4I8Packed => "int",
Mf::Dot4U8Packed => "",
_ => unreachable!(),
};
if context.lang_version >= (2, 1) {
// Write potentially optimizable code using `packed_(u?)char4`.
// The two function arguments were already reinterpreted as packed (signed
// or unsigned) chars in `Self::put_block`.
let packed_type = match fun {
Mf::Dot4I8Packed => "packed_char4",
Mf::Dot4U8Packed => "packed_uchar4",
_ => unreachable!(),
};
return self.put_dot_product(
arg,
arg1.unwrap(),
4,
|writer, arg, index| {
write!(writer.out, "({}(", conversion)?;
writer.put_expression(arg, context, true)?;
if index == 3 {
write!(writer.out, ") >> 24)")?;
} else {
write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
}
Ok(())
},
);
return self.put_dot_product(
Reinterpreted::new(packed_type, arg),
Reinterpreted::new(packed_type, arg1.unwrap()),
4,
|writer, arg, index| {
// MSL implicitly promotes these (signed or unsigned) chars to
// `int` or `uint` in the multiplication, so no overflow can occur.
write!(writer.out, "{arg}[{index}]")?;
Ok(())
},
);
} else {
// Fall back to a polyfill since MSL < 2.1 doesn't seem to support
// bitcasting from uint to `packed_char4` or `packed_uchar4`.
// See <https://github.com/gfx-rs/wgpu/pull/7574#issuecomment-2835464472>.
let conversion = match fun {
Mf::Dot4I8Packed => "int",
Mf::Dot4U8Packed => "",
_ => unreachable!(),
};
return self.put_dot_product(
arg,
arg1.unwrap(),
4,
|writer, arg, index| {
write!(writer.out, "({}(", conversion)?;
writer.put_expression(arg, context, true)?;
if index == 3 {
write!(writer.out, ") >> 24)")?;
} else {
write!(writer.out, ") << {} >> 24)", (3 - index) * 8)?;
}
Ok(())
},
);
}
}
Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
Mf::Cross => "cross",
@ -3346,6 +3399,38 @@ impl<W: Write> Writer<W> {
Ok(())
}
/// Convert the arguments of `Dot4{I, U}Packed` to `packed_(u?)char4`.
///
/// Caches the results in temporary variables (whose names are derived from
/// the original variable names). This caching avoids the need to redo the
/// casting for each vector component when emitting the dot product.
fn put_casting_to_packed_chars(
&mut self,
fun: crate::MathFunction,
arg0: Handle<crate::Expression>,
arg1: Handle<crate::Expression>,
indent: back::Level,
context: &StatementContext<'_>,
) -> Result<(), Error> {
let packed_type = match fun {
crate::MathFunction::Dot4I8Packed => "packed_char4",
crate::MathFunction::Dot4U8Packed => "packed_uchar4",
_ => unreachable!(),
};
for arg in [arg0, arg1] {
write!(
self.out,
"{indent}{packed_type} {0} = as_type<{packed_type}>(",
Reinterpreted::new(packed_type, arg)
)?;
self.put_expression(arg, &context.expression, true)?;
writeln!(self.out, ");")?;
}
Ok(())
}
fn put_block(
&mut self,
level: back::Level,
@ -3362,17 +3447,45 @@ impl<W: Write> Writer<W> {
match *statement {
crate::Statement::Emit(ref range) => {
for handle in range.clone() {
// `ImageLoad` expressions covered by the `Restrict` bounds check policy
// may need to cache a clamped version of their level-of-detail argument.
if let crate::Expression::ImageLoad {
image,
level: mip_level,
..
} = context.expression.function.expressions[handle]
{
self.put_cache_restricted_level(
handle, image, mip_level, level, context,
)?;
use crate::MathFunction as Mf;
match context.expression.function.expressions[handle] {
// `ImageLoad` expressions covered by the `Restrict` bounds check policy
// may need to cache a clamped version of their level-of-detail argument.
crate::Expression::ImageLoad {
image,
level: mip_level,
..
} => {
self.put_cache_restricted_level(
handle, image, mip_level, level, context,
)?;
}
// If we are going to write a `Dot4I8Packed` or `Dot4U8Packed` on Metal
// 2.1+ then we introduce two intermediate variables that recast the two
// arguments as packed (signed or unsigned) chars. The actual dot product
// is implemented in `Self::put_expression`, and it uses both of these
// intermediate variables multiple times. There's no danger that the
// original arguments get modified between the definition of these
// intermediate variables and the implementation of the actual dot
// product since we require the inputs of `Dot4{I, U}8Packed` to be baked.
crate::Expression::Math {
fun: fun @ (Mf::Dot4I8Packed | Mf::Dot4U8Packed),
arg,
arg1,
..
} if context.expression.lang_version >= (2, 1) => {
self.put_casting_to_packed_chars(
fun,
arg,
arg1.unwrap(),
level,
context,
)?;
}
_ => (),
}
let ptr_class = context.expression.resolve_type(handle).pointer_space();

View File

@ -1,7 +1,7 @@
# Turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on SPIR-V and HLSL by
# using a version of SPIR-V / shader model that supports these without any extensions.
# Turn on optimizations for `dot4I8Packed` and `dot4U8Packed` on SPIR-V, HLSL, and Metal
# by using a language version / shader model that supports these (without any extensions).
targets = "SPIRV | HLSL"
targets = "SPIRV | HLSL | METAL"
[spv]
# We also need to provide the corresponding capabilities (which are part of SPIR-V >= 1.6).
@ -10,3 +10,6 @@ version = [1, 6]
[hlsl]
shader_model = "V6_4"
[msl]
lang_version = [2, 1]

View File

@ -1,7 +1,7 @@
# Explicitly turn off optimizations for `dot4I8Packed` and `dot4U8Packed`
# on SPIRV and HLSL.
# on SPIRV, HLSL, and Metal.
targets = "SPIRV | HLSL"
targets = "SPIRV | HLSL | METAL"
[spv]
# Provide some unrelated capability because an empty list of capabilities would
@ -11,3 +11,6 @@ capabilities = ["Matrix"]
[hlsl]
shader_model = "V6_3"
[msl]
lang_version = [2, 0]

View File

@ -0,0 +1,33 @@
// language: metal2.1
#include <metal_stdlib>
#include <simd/simd.h>
using metal::uint;
// Computes four packed 8-bit dot products, alternating between signed
// (`packed_char4`) and unsigned (`packed_uchar4`) lanes, and returns the last.
// Each operand is first reinterpreted with `as_type` from a `uint` into a
// packed 4-component temporary; the dot product is then the sum of the four
// component-wise products of two such temporaries.
uint test_packed_integer_dot_product(
) {
// Signed dot product of the constants 1u and 2u.
packed_char4 reinterpreted_packed_char4_e0 = as_type<packed_char4>(1u);
packed_char4 reinterpreted_packed_char4_e1 = as_type<packed_char4>(2u);
int c_5_ = ( + reinterpreted_packed_char4_e0[0] * reinterpreted_packed_char4_e1[0] + reinterpreted_packed_char4_e0[1] * reinterpreted_packed_char4_e1[1] + reinterpreted_packed_char4_e0[2] * reinterpreted_packed_char4_e1[2] + reinterpreted_packed_char4_e0[3] * reinterpreted_packed_char4_e1[3]);
// Unsigned dot product of the constants 3u and 4u.
packed_uchar4 reinterpreted_packed_uchar4_e3 = as_type<packed_uchar4>(3u);
packed_uchar4 reinterpreted_packed_uchar4_e4 = as_type<packed_uchar4>(4u);
uint c_6_ = ( + reinterpreted_packed_uchar4_e3[0] * reinterpreted_packed_uchar4_e4[0] + reinterpreted_packed_uchar4_e3[1] * reinterpreted_packed_uchar4_e4[1] + reinterpreted_packed_uchar4_e3[2] * reinterpreted_packed_uchar4_e4[2] + reinterpreted_packed_uchar4_e3[3] * reinterpreted_packed_uchar4_e4[3]);
// Signed dot product of two non-constant operands derived from `c_6_`.
uint _e7 = 5u + c_6_;
uint _e9 = 6u + c_6_;
packed_char4 reinterpreted_packed_char4_e7 = as_type<packed_char4>(_e7);
packed_char4 reinterpreted_packed_char4_e9 = as_type<packed_char4>(_e9);
int c_7_ = ( + reinterpreted_packed_char4_e7[0] * reinterpreted_packed_char4_e9[0] + reinterpreted_packed_char4_e7[1] * reinterpreted_packed_char4_e9[1] + reinterpreted_packed_char4_e7[2] * reinterpreted_packed_char4_e9[2] + reinterpreted_packed_char4_e7[3] * reinterpreted_packed_char4_e9[3]);
// Unsigned dot product of two non-constant operands derived from `c_6_`.
uint _e12 = 7u + c_6_;
uint _e14 = 8u + c_6_;
packed_uchar4 reinterpreted_packed_uchar4_e12 = as_type<packed_uchar4>(_e12);
packed_uchar4 reinterpreted_packed_uchar4_e14 = as_type<packed_uchar4>(_e14);
uint c_8_ = ( + reinterpreted_packed_uchar4_e12[0] * reinterpreted_packed_uchar4_e14[0] + reinterpreted_packed_uchar4_e12[1] * reinterpreted_packed_uchar4_e14[1] + reinterpreted_packed_uchar4_e12[2] * reinterpreted_packed_uchar4_e14[2] + reinterpreted_packed_uchar4_e12[3] * reinterpreted_packed_uchar4_e14[3]);
return c_8_;
}
// Kernel entry point: calls `test_packed_integer_dot_product` once and does
// not use the returned value.
kernel void main_(
) {
uint _e0 = test_packed_integer_dot_product();
return;
}

View File

@ -0,0 +1,25 @@
// language: metal2.0
#include <metal_stdlib>
#include <simd/simd.h>
using metal::uint;
// Computes four packed 8-bit dot products, alternating between signed and
// unsigned, and returns the last. Here the components are extracted with
// shifts instead of packed vector types: component `i` of an operand is
// obtained by shifting left by `(3 - i) * 8` bits and then right by 24 bits
// (component 3 needs only the right shift). The signed variants cast the
// operand to `int` first, so the right shift operates on a signed value.
uint test_packed_integer_dot_product(
) {
// Signed dot product of the constants 1u and 2u.
int c_5_ = ( + (int(1u) << 24 >> 24) * (int(2u) << 24 >> 24) + (int(1u) << 16 >> 24) * (int(2u) << 16 >> 24) + (int(1u) << 8 >> 24) * (int(2u) << 8 >> 24) + (int(1u) >> 24) * (int(2u) >> 24));
// Unsigned dot product of the constants 3u and 4u.
uint c_6_ = ( + ((3u) << 24 >> 24) * ((4u) << 24 >> 24) + ((3u) << 16 >> 24) * ((4u) << 16 >> 24) + ((3u) << 8 >> 24) * ((4u) << 8 >> 24) + ((3u) >> 24) * ((4u) >> 24));
// Signed dot product of two non-constant operands derived from `c_6_`.
uint _e7 = 5u + c_6_;
uint _e9 = 6u + c_6_;
int c_7_ = ( + (int(_e7) << 24 >> 24) * (int(_e9) << 24 >> 24) + (int(_e7) << 16 >> 24) * (int(_e9) << 16 >> 24) + (int(_e7) << 8 >> 24) * (int(_e9) << 8 >> 24) + (int(_e7) >> 24) * (int(_e9) >> 24));
// Unsigned dot product of two non-constant operands derived from `c_6_`.
uint _e12 = 7u + c_6_;
uint _e14 = 8u + c_6_;
uint c_8_ = ( + ((_e12) << 24 >> 24) * ((_e14) << 24 >> 24) + ((_e12) << 16 >> 24) * ((_e14) << 16 >> 24) + ((_e12) << 8 >> 24) * ((_e14) << 8 >> 24) + ((_e12) >> 24) * ((_e14) >> 24));
return c_8_;
}
// Kernel entry point: calls `test_packed_integer_dot_product` once and does
// not use the returned value.
kernel void main_(
) {
uint _e0 = test_packed_integer_dot_product();
return;
}