[naga] Vectorize [un]pack4x{I, U}8[Clamp] on msl

Implements more direct conversions between 32-bit integers and 4x8-bit
integer vectors using bit casting to/from `packed_[u]char4` when on
MSL 2.1+ (older versions of MSL don't seem to support these bit casts).

- `unpack4x{I, U}8(x)` becomes `[u]int4(as_type<packed_[u]char4>(x))`;
- `pack4x{I, U}8(x)` becomes `as_type<uint>(packed_[u]char4(x))`; and
- `pack4x{I, U}8Clamp(x)` becomes
  `as_type<uint>(packed_uchar4(metal::clamp(x, 0, 255)))`.

These bit casts match the WGSL spec for these functions because Metal
runs on little-endian machines.
This commit is contained in:
Robert Bamler 2025-05-03 20:56:26 +02:00 committed by Teodor Tanasoaia
parent b32eb4a120
commit 8969965978
4 changed files with 303 additions and 47 deletions

View File

@ -1524,6 +1524,58 @@ impl<W: Write> Writer<W> {
Ok(())
}
/// Emit code for the WGSL functions `pack4x{I, U}8[Clamp]`.
fn put_pack4x8(
&mut self,
arg: Handle<crate::Expression>,
context: &ExpressionContext<'_>,
was_signed: bool,
clamp_bounds: Option<(&str, &str)>,
) -> Result<(), Error> {
let write_arg = |this: &mut Self| -> BackendResult {
if let Some((min, max)) = clamp_bounds {
// Clamping with scalar bounds works (component-wise) even for packed_[u]char4.
write!(this.out, "{NAMESPACE}::clamp(")?;
this.put_expression(arg, context, true)?;
write!(this.out, ", {min}, {max})")?;
} else {
this.put_expression(arg, context, true)?;
}
Ok(())
};
if context.lang_version >= (2, 1) {
let packed_type = if was_signed {
"packed_char4"
} else {
"packed_uchar4"
};
// Metal uses little endian byte order, which matches what WGSL expects here.
write!(self.out, "as_type<uint>({packed_type}(")?;
write_arg(self)?;
write!(self.out, "))")?;
} else {
// MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
if was_signed {
write!(self.out, "uint(")?;
}
write!(self.out, "(")?;
write_arg(self)?;
write!(self.out, "[0] & 0xFF) | ((")?;
write_arg(self)?;
write!(self.out, "[1] & 0xFF) << 8) | ((")?;
write_arg(self)?;
write!(self.out, "[2] & 0xFF) << 16) | ((")?;
write_arg(self)?;
write!(self.out, "[3] & 0xFF) << 24)")?;
if was_signed {
write!(self.out, ")")?;
}
}
Ok(())
}
/// Emit code for the isign expression.
///
fn put_isign(
@ -2490,53 +2542,41 @@ impl<W: Write> Writer<W> {
write!(self.out, "{fun_name}")?;
self.put_call_parameters(iter::once(arg), context)?;
}
fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => {
let was_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp);
let clamp_bounds = match fun {
Mf::Pack4xI8Clamp => Some(("-128", "127")),
Mf::Pack4xU8Clamp => Some(("0", "255")),
_ => None,
};
if was_signed {
write!(self.out, "uint(")?;
}
let write_arg = |this: &mut Self| -> BackendResult {
if let Some((min, max)) = clamp_bounds {
write!(this.out, "{NAMESPACE}::clamp(")?;
this.put_expression(arg, context, true)?;
write!(this.out, ", {min}, {max})")?;
} else {
this.put_expression(arg, context, true)?;
}
Ok(())
};
write!(self.out, "(")?;
write_arg(self)?;
write!(self.out, "[0] & 0xFF) | ((")?;
write_arg(self)?;
write!(self.out, "[1] & 0xFF) << 8) | ((")?;
write_arg(self)?;
write!(self.out, "[2] & 0xFF) << 16) | ((")?;
write_arg(self)?;
write!(self.out, "[3] & 0xFF) << 24)")?;
if was_signed {
write!(self.out, ")")?;
}
Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?,
Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?,
Mf::Pack4xI8Clamp => {
self.put_pack4x8(arg, context, true, Some(("-128", "127")))?
}
Mf::Pack4xU8Clamp => {
self.put_pack4x8(arg, context, false, Some(("0", "255")))?
}
fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => {
write!(self.out, "(")?;
if matches!(fun, Mf::Unpack4xU8) {
write!(self.out, "u")?;
let sign_prefix = if matches!(fun, Mf::Unpack4xU8) {
"u"
} else {
""
};
if context.lang_version >= (2, 1) {
// Metal uses little endian byte order, which matches what WGSL expects here.
write!(
self.out,
"{sign_prefix}int4(as_type<packed_{sign_prefix}char4>("
)?;
self.put_expression(arg, context, true)?;
write!(self.out, "))")?;
} else {
// MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars.
write!(self.out, "({sign_prefix}int4(")?;
self.put_expression(arg, context, true)?;
write!(self.out, ", ")?;
self.put_expression(arg, context, true)?;
write!(self.out, " >> 8, ")?;
self.put_expression(arg, context, true)?;
write!(self.out, " >> 16, ")?;
self.put_expression(arg, context, true)?;
write!(self.out, " >> 24) << 24 >> 24)")?;
}
write!(self.out, "int4(")?;
self.put_expression(arg, context, true)?;
write!(self.out, ", ")?;
self.put_expression(arg, context, true)?;
write!(self.out, " >> 8, ")?;
self.put_expression(arg, context, true)?;
write!(self.out, " >> 16, ")?;
self.put_expression(arg, context, true)?;
write!(self.out, " >> 24) << 24 >> 24)")?;
}
Mf::QuantizeToF16 => {
match *context.resolve_type(arg) {
@ -3279,14 +3319,20 @@ impl<W: Write> Writer<W> {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
}
crate::MathFunction::FirstLeadingBit
| crate::MathFunction::Pack4xI8
crate::MathFunction::FirstLeadingBit => {
self.need_bake_expressions.insert(arg);
}
crate::MathFunction::Pack4xI8
| crate::MathFunction::Pack4xU8
| crate::MathFunction::Pack4xI8Clamp
| crate::MathFunction::Pack4xU8Clamp
| crate::MathFunction::Unpack4xI8
| crate::MathFunction::Unpack4xU8 => {
self.need_bake_expressions.insert(arg);
// On MSL < 2.1, we emit a polyfill for these functions that uses the
// argument multiple times. This is no longer necessary on MSL >= 2.1.
if context.lang_version < (2, 1) {
self.need_bake_expressions.insert(arg);
}
}
crate::MathFunction::ExtractBits => {
// Only argument 1 is re-used.

View File

@ -0,0 +1,4 @@
targets = "METAL"
[msl]
lang_version = [2, 1]

View File

@ -0,0 +1,69 @@
// Keep in sync with `bits_downlevel` and `bits_downlevel_webgl`
@compute @workgroup_size(1)
fn main() {
var i = 0;
var i2 = vec2<i32>(0);
var i3 = vec3<i32>(0);
var i4 = vec4<i32>(0);
var u = 0u;
var u2 = vec2<u32>(0u);
var u3 = vec3<u32>(0u);
var u4 = vec4<u32>(0u);
var f2 = vec2<f32>(0.0);
var f4 = vec4<f32>(0.0);
u = pack4x8snorm(f4);
u = pack4x8unorm(f4);
u = pack2x16snorm(f2);
u = pack2x16unorm(f2);
u = pack2x16float(f2);
u = pack4xI8(i4);
u = pack4xU8(u4);
u = pack4xI8Clamp(i4);
u = pack4xU8Clamp(u4);
f4 = unpack4x8snorm(u);
f4 = unpack4x8unorm(u);
f2 = unpack2x16snorm(u);
f2 = unpack2x16unorm(u);
f2 = unpack2x16float(u);
i4 = unpack4xI8(u);
u4 = unpack4xU8(u);
i = insertBits(i, i, 5u, 10u);
i2 = insertBits(i2, i2, 5u, 10u);
i3 = insertBits(i3, i3, 5u, 10u);
i4 = insertBits(i4, i4, 5u, 10u);
u = insertBits(u, u, 5u, 10u);
u2 = insertBits(u2, u2, 5u, 10u);
u3 = insertBits(u3, u3, 5u, 10u);
u4 = insertBits(u4, u4, 5u, 10u);
i = extractBits(i, 5u, 10u);
i2 = extractBits(i2, 5u, 10u);
i3 = extractBits(i3, 5u, 10u);
i4 = extractBits(i4, 5u, 10u);
u = extractBits(u, 5u, 10u);
u2 = extractBits(u2, 5u, 10u);
u3 = extractBits(u3, 5u, 10u);
u4 = extractBits(u4, 5u, 10u);
i = firstTrailingBit(i);
u2 = firstTrailingBit(u2);
i3 = firstLeadingBit(i3);
u3 = firstLeadingBit(u3);
i = firstLeadingBit(i);
u = firstLeadingBit(u);
i = countOneBits(i);
i2 = countOneBits(i2);
i3 = countOneBits(i3);
i4 = countOneBits(i4);
u = countOneBits(u);
u2 = countOneBits(u2);
u3 = countOneBits(u3);
u4 = countOneBits(u4);
i = reverseBits(i);
i2 = reverseBits(i2);
i3 = reverseBits(i3);
i4 = reverseBits(i4);
u = reverseBits(u);
u2 = reverseBits(u2);
u3 = reverseBits(u3);
u4 = reverseBits(u4);
}

View File

@ -0,0 +1,137 @@
// language: metal2.1
#include <metal_stdlib>
#include <simd/simd.h>
using metal::uint;
kernel void main_(
) {
int i = 0;
metal::int2 i2_ = metal::int2(0);
metal::int3 i3_ = metal::int3(0);
metal::int4 i4_ = metal::int4(0);
uint u = 0u;
metal::uint2 u2_ = metal::uint2(0u);
metal::uint3 u3_ = metal::uint3(0u);
metal::uint4 u4_ = metal::uint4(0u);
metal::float2 f2_ = metal::float2(0.0);
metal::float4 f4_ = metal::float4(0.0);
metal::float4 _e28 = f4_;
u = metal::pack_float_to_snorm4x8(_e28);
metal::float4 _e30 = f4_;
u = metal::pack_float_to_unorm4x8(_e30);
metal::float2 _e32 = f2_;
u = metal::pack_float_to_snorm2x16(_e32);
metal::float2 _e34 = f2_;
u = metal::pack_float_to_unorm2x16(_e34);
metal::float2 _e36 = f2_;
u = as_type<uint>(half2(_e36));
metal::int4 _e38 = i4_;
u = as_type<uint>(packed_char4(_e38));
metal::uint4 _e40 = u4_;
u = as_type<uint>(packed_uchar4(_e40));
metal::int4 _e42 = i4_;
u = as_type<uint>(packed_char4(metal::clamp(_e42, -128, 127)));
metal::uint4 _e44 = u4_;
u = as_type<uint>(packed_uchar4(metal::clamp(_e44, 0, 255)));
uint _e46 = u;
f4_ = metal::unpack_snorm4x8_to_float(_e46);
uint _e48 = u;
f4_ = metal::unpack_unorm4x8_to_float(_e48);
uint _e50 = u;
f2_ = metal::unpack_snorm2x16_to_float(_e50);
uint _e52 = u;
f2_ = metal::unpack_unorm2x16_to_float(_e52);
uint _e54 = u;
f2_ = float2(as_type<half2>(_e54));
uint _e56 = u;
i4_ = int4(as_type<packed_char4>(_e56));
uint _e58 = u;
u4_ = uint4(as_type<packed_uchar4>(_e58));
int _e60 = i;
int _e61 = i;
i = metal::insert_bits(_e60, _e61, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
metal::int2 _e65 = i2_;
metal::int2 _e66 = i2_;
i2_ = metal::insert_bits(_e65, _e66, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
metal::int3 _e70 = i3_;
metal::int3 _e71 = i3_;
i3_ = metal::insert_bits(_e70, _e71, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
metal::int4 _e75 = i4_;
metal::int4 _e76 = i4_;
i4_ = metal::insert_bits(_e75, _e76, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
uint _e80 = u;
uint _e81 = u;
u = metal::insert_bits(_e80, _e81, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
metal::uint2 _e85 = u2_;
metal::uint2 _e86 = u2_;
u2_ = metal::insert_bits(_e85, _e86, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
metal::uint3 _e90 = u3_;
metal::uint3 _e91 = u3_;
u3_ = metal::insert_bits(_e90, _e91, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
metal::uint4 _e95 = u4_;
metal::uint4 _e96 = u4_;
u4_ = metal::insert_bits(_e95, _e96, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
int _e100 = i;
i = metal::extract_bits(_e100, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
metal::int2 _e104 = i2_;
i2_ = metal::extract_bits(_e104, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
metal::int3 _e108 = i3_;
i3_ = metal::extract_bits(_e108, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
metal::int4 _e112 = i4_;
i4_ = metal::extract_bits(_e112, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
uint _e116 = u;
u = metal::extract_bits(_e116, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
metal::uint2 _e120 = u2_;
u2_ = metal::extract_bits(_e120, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
metal::uint3 _e124 = u3_;
u3_ = metal::extract_bits(_e124, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
metal::uint4 _e128 = u4_;
u4_ = metal::extract_bits(_e128, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u)));
int _e132 = i;
i = (((metal::ctz(_e132) + 1) % 33) - 1);
metal::uint2 _e134 = u2_;
u2_ = (((metal::ctz(_e134) + 1) % 33) - 1);
metal::int3 _e136 = i3_;
i3_ = metal::select(31 - metal::clz(metal::select(_e136, ~_e136, _e136 < 0)), int3(-1), _e136 == 0 || _e136 == -1);
metal::uint3 _e138 = u3_;
u3_ = metal::select(31 - metal::clz(_e138), uint3(-1), _e138 == 0 || _e138 == -1);
int _e140 = i;
i = metal::select(31 - metal::clz(metal::select(_e140, ~_e140, _e140 < 0)), int(-1), _e140 == 0 || _e140 == -1);
uint _e142 = u;
u = metal::select(31 - metal::clz(_e142), uint(-1), _e142 == 0 || _e142 == -1);
int _e144 = i;
i = metal::popcount(_e144);
metal::int2 _e146 = i2_;
i2_ = metal::popcount(_e146);
metal::int3 _e148 = i3_;
i3_ = metal::popcount(_e148);
metal::int4 _e150 = i4_;
i4_ = metal::popcount(_e150);
uint _e152 = u;
u = metal::popcount(_e152);
metal::uint2 _e154 = u2_;
u2_ = metal::popcount(_e154);
metal::uint3 _e156 = u3_;
u3_ = metal::popcount(_e156);
metal::uint4 _e158 = u4_;
u4_ = metal::popcount(_e158);
int _e160 = i;
i = metal::reverse_bits(_e160);
metal::int2 _e162 = i2_;
i2_ = metal::reverse_bits(_e162);
metal::int3 _e164 = i3_;
i3_ = metal::reverse_bits(_e164);
metal::int4 _e166 = i4_;
i4_ = metal::reverse_bits(_e166);
uint _e168 = u;
u = metal::reverse_bits(_e168);
metal::uint2 _e170 = u2_;
u2_ = metal::reverse_bits(_e170);
metal::uint3 _e172 = u3_;
u3_ = metal::reverse_bits(_e172);
metal::uint4 _e174 = u4_;
u4_ = metal::reverse_bits(_e174);
return;
}