diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 846bd3df5..22f225d1a 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1524,6 +1524,58 @@ impl Writer { Ok(()) } + /// Emit code for the WGSL functions `pack4x{I, U}8[Clamp]`. + fn put_pack4x8( + &mut self, + arg: Handle, + context: &ExpressionContext<'_>, + was_signed: bool, + clamp_bounds: Option<(&str, &str)>, + ) -> Result<(), Error> { + let write_arg = |this: &mut Self| -> BackendResult { + if let Some((min, max)) = clamp_bounds { + // Clamping with scalar bounds works (component-wise) even for packed_[u]char4. + write!(this.out, "{NAMESPACE}::clamp(")?; + this.put_expression(arg, context, true)?; + write!(this.out, ", {min}, {max})")?; + } else { + this.put_expression(arg, context, true)?; + } + Ok(()) + }; + + if context.lang_version >= (2, 1) { + let packed_type = if was_signed { + "packed_char4" + } else { + "packed_uchar4" + }; + // Metal uses little endian byte order, which matches what WGSL expects here. + write!(self.out, "as_type({packed_type}(")?; + write_arg(self)?; + write!(self.out, "))")?; + } else { + // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars. + if was_signed { + write!(self.out, "uint(")?; + } + write!(self.out, "(")?; + write_arg(self)?; + write!(self.out, "[0] & 0xFF) | ((")?; + write_arg(self)?; + write!(self.out, "[1] & 0xFF) << 8) | ((")?; + write_arg(self)?; + write!(self.out, "[2] & 0xFF) << 16) | ((")?; + write_arg(self)?; + write!(self.out, "[3] & 0xFF) << 24)")?; + if was_signed { + write!(self.out, ")")?; + } + } + + Ok(()) + } + /// Emit code for the isign expression. /// fn put_isign( @@ -2490,53 +2542,41 @@ impl Writer { write!(self.out, "{fun_name}")?; self.put_call_parameters(iter::once(arg), context)?; } - fun @ (Mf::Pack4xI8 | Mf::Pack4xU8 | Mf::Pack4xI8Clamp | Mf::Pack4xU8Clamp) => { - let was_signed = matches!(fun, Mf::Pack4xI8 | Mf::Pack4xI8Clamp); - let clamp_bounds = match fun { - Mf::Pack4xI8Clamp => Some(("-128", "127")), - Mf::Pack4xU8Clamp => Some(("0", "255")), - _ => None, - }; - if was_signed { - write!(self.out, "uint(")?; - } - let write_arg = |this: &mut Self| -> BackendResult { - if let Some((min, max)) = clamp_bounds { - write!(this.out, "{NAMESPACE}::clamp(")?; - this.put_expression(arg, context, true)?; - write!(this.out, ", {min}, {max})")?; - } else { - this.put_expression(arg, context, true)?; - } - Ok(()) - }; - write!(self.out, "(")?; - write_arg(self)?; - write!(self.out, "[0] & 0xFF) | ((")?; - write_arg(self)?; - write!(self.out, "[1] & 0xFF) << 8) | ((")?; - write_arg(self)?; - write!(self.out, "[2] & 0xFF) << 16) | ((")?; - write_arg(self)?; - write!(self.out, "[3] & 0xFF) << 24)")?; - if was_signed { - write!(self.out, ")")?; - } + Mf::Pack4xI8 => self.put_pack4x8(arg, context, true, None)?, + Mf::Pack4xU8 => self.put_pack4x8(arg, context, false, None)?, + Mf::Pack4xI8Clamp => { + self.put_pack4x8(arg, context, true, Some(("-128", "127")))? + } + Mf::Pack4xU8Clamp => { + self.put_pack4x8(arg, context, false, Some(("0", "255")))? } fun @ (Mf::Unpack4xI8 | Mf::Unpack4xU8) => { - write!(self.out, "(")?; - if matches!(fun, Mf::Unpack4xU8) { - write!(self.out, "u")?; + let sign_prefix = if matches!(fun, Mf::Unpack4xU8) { + "u" + } else { + "" + }; + + if context.lang_version >= (2, 1) { + // Metal uses little endian byte order, which matches what WGSL expects here. + write!( + self.out, + "{sign_prefix}int4(as_type(" + )?; + self.put_expression(arg, context, true)?; + write!(self.out, "))")?; + } else { + // MSL < 2.1 doesn't support `as_type` casting between packed chars and scalars. + write!(self.out, "({sign_prefix}int4(")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " >> 8, ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " >> 16, ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " >> 24) << 24 >> 24)")?; } - write!(self.out, "int4(")?; - self.put_expression(arg, context, true)?; - write!(self.out, ", ")?; - self.put_expression(arg, context, true)?; - write!(self.out, " >> 8, ")?; - self.put_expression(arg, context, true)?; - write!(self.out, " >> 16, ")?; - self.put_expression(arg, context, true)?; - write!(self.out, " >> 24) << 24 >> 24)")?; } Mf::QuantizeToF16 => { match *context.resolve_type(arg) { @@ -3279,14 +3319,20 @@ impl Writer { self.need_bake_expressions.insert(arg); self.need_bake_expressions.insert(arg1.unwrap()); } - crate::MathFunction::FirstLeadingBit - | crate::MathFunction::Pack4xI8 + crate::MathFunction::FirstLeadingBit => { + self.need_bake_expressions.insert(arg); + } + crate::MathFunction::Pack4xI8 | crate::MathFunction::Pack4xU8 | crate::MathFunction::Pack4xI8Clamp | crate::MathFunction::Pack4xU8Clamp | crate::MathFunction::Unpack4xI8 | crate::MathFunction::Unpack4xU8 => { - self.need_bake_expressions.insert(arg); + // On MSL < 2.1, we emit a polyfill for these functions that uses the + // argument multiple times. This is no longer necessary on MSL >= 2.1. + if context.lang_version < (2, 1) { + self.need_bake_expressions.insert(arg); + } } crate::MathFunction::ExtractBits => { // Only argument 1 is re-used. diff --git a/naga/tests/in/wgsl/bits-optimized-msl.toml b/naga/tests/in/wgsl/bits-optimized-msl.toml new file mode 100644 index 000000000..9409d2ac7 --- /dev/null +++ b/naga/tests/in/wgsl/bits-optimized-msl.toml @@ -0,0 +1,4 @@ +targets = "METAL" + +[msl] +lang_version = [2, 1] diff --git a/naga/tests/in/wgsl/bits-optimized-msl.wgsl b/naga/tests/in/wgsl/bits-optimized-msl.wgsl new file mode 100644 index 000000000..a77266ad3 --- /dev/null +++ b/naga/tests/in/wgsl/bits-optimized-msl.wgsl @@ -0,0 +1,69 @@ +// Keep in sync with `bits_downlevel` and `bits_downlevel_webgl` + +@compute @workgroup_size(1) +fn main() { + var i = 0; + var i2 = vec2(0); + var i3 = vec3(0); + var i4 = vec4(0); + var u = 0u; + var u2 = vec2(0u); + var u3 = vec3(0u); + var u4 = vec4(0u); + var f2 = vec2(0.0); + var f4 = vec4(0.0); + u = pack4x8snorm(f4); + u = pack4x8unorm(f4); + u = pack2x16snorm(f2); + u = pack2x16unorm(f2); + u = pack2x16float(f2); + u = pack4xI8(i4); + u = pack4xU8(u4); + u = pack4xI8Clamp(i4); + u = pack4xU8Clamp(u4); + f4 = unpack4x8snorm(u); + f4 = unpack4x8unorm(u); + f2 = unpack2x16snorm(u); + f2 = unpack2x16unorm(u); + f2 = unpack2x16float(u); + i4 = unpack4xI8(u); + u4 = unpack4xU8(u); + i = insertBits(i, i, 5u, 10u); + i2 = insertBits(i2, i2, 5u, 10u); + i3 = insertBits(i3, i3, 5u, 10u); + i4 = insertBits(i4, i4, 5u, 10u); + u = insertBits(u, u, 5u, 10u); + u2 = insertBits(u2, u2, 5u, 10u); + u3 = insertBits(u3, u3, 5u, 10u); + u4 = insertBits(u4, u4, 5u, 10u); + i = extractBits(i, 5u, 10u); + i2 = extractBits(i2, 5u, 10u); + i3 = extractBits(i3, 5u, 10u); + i4 = extractBits(i4, 5u, 10u); + u = extractBits(u, 5u, 10u); + u2 = extractBits(u2, 5u, 10u); + u3 = extractBits(u3, 5u, 10u); + u4 = extractBits(u4, 5u, 10u); + i = firstTrailingBit(i); + u2 = firstTrailingBit(u2); + i3 = firstLeadingBit(i3); + u3 = firstLeadingBit(u3); + i = firstLeadingBit(i); + u = firstLeadingBit(u); + i = countOneBits(i); + i2 = countOneBits(i2); + i3 = countOneBits(i3); + i4 = countOneBits(i4); + u = countOneBits(u); + u2 = countOneBits(u2); + u3 = countOneBits(u3); + u4 = countOneBits(u4); + i = reverseBits(i); + i2 = reverseBits(i2); + i3 = reverseBits(i3); + i4 = reverseBits(i4); + u = reverseBits(u); + u2 = reverseBits(u2); + u3 = reverseBits(u3); + u4 = reverseBits(u4); +} diff --git a/naga/tests/out/msl/wgsl-bits-optimized-msl.msl b/naga/tests/out/msl/wgsl-bits-optimized-msl.msl new file mode 100644 index 000000000..e33ed65f4 --- /dev/null +++ b/naga/tests/out/msl/wgsl-bits-optimized-msl.msl @@ -0,0 +1,137 @@ +// language: metal2.1 +#include +#include + +using metal::uint; + + +kernel void main_( +) { + int i = 0; + metal::int2 i2_ = metal::int2(0); + metal::int3 i3_ = metal::int3(0); + metal::int4 i4_ = metal::int4(0); + uint u = 0u; + metal::uint2 u2_ = metal::uint2(0u); + metal::uint3 u3_ = metal::uint3(0u); + metal::uint4 u4_ = metal::uint4(0u); + metal::float2 f2_ = metal::float2(0.0); + metal::float4 f4_ = metal::float4(0.0); + metal::float4 _e28 = f4_; + u = metal::pack_float_to_snorm4x8(_e28); + metal::float4 _e30 = f4_; + u = metal::pack_float_to_unorm4x8(_e30); + metal::float2 _e32 = f2_; + u = metal::pack_float_to_snorm2x16(_e32); + metal::float2 _e34 = f2_; + u = metal::pack_float_to_unorm2x16(_e34); + metal::float2 _e36 = f2_; + u = as_type(half2(_e36)); + metal::int4 _e38 = i4_; + u = as_type(packed_char4(_e38)); + metal::uint4 _e40 = u4_; + u = as_type(packed_uchar4(_e40)); + metal::int4 _e42 = i4_; + u = as_type(packed_char4(metal::clamp(_e42, -128, 127))); + metal::uint4 _e44 = u4_; + u = as_type(packed_uchar4(metal::clamp(_e44, 0, 255))); + uint _e46 = u; + f4_ = metal::unpack_snorm4x8_to_float(_e46); + uint _e48 = u; + f4_ = metal::unpack_unorm4x8_to_float(_e48); + uint _e50 = u; + f2_ = metal::unpack_snorm2x16_to_float(_e50); + uint _e52 = u; + f2_ = metal::unpack_unorm2x16_to_float(_e52); + uint _e54 = u; + f2_ = float2(as_type(_e54)); + uint _e56 = u; + i4_ = int4(as_type(_e56)); + uint _e58 = u; + u4_ = uint4(as_type(_e58)); + int _e60 = i; + int _e61 = i; + i = metal::insert_bits(_e60, _e61, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int2 _e65 = i2_; + metal::int2 _e66 = i2_; + i2_ = metal::insert_bits(_e65, _e66, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int3 _e70 = i3_; + metal::int3 _e71 = i3_; + i3_ = metal::insert_bits(_e70, _e71, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int4 _e75 = i4_; + metal::int4 _e76 = i4_; + i4_ = metal::insert_bits(_e75, _e76, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + uint _e80 = u; + uint _e81 = u; + u = metal::insert_bits(_e80, _e81, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint2 _e85 = u2_; + metal::uint2 _e86 = u2_; + u2_ = metal::insert_bits(_e85, _e86, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint3 _e90 = u3_; + metal::uint3 _e91 = u3_; + u3_ = metal::insert_bits(_e90, _e91, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint4 _e95 = u4_; + metal::uint4 _e96 = u4_; + u4_ = metal::insert_bits(_e95, _e96, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + int _e100 = i; + i = metal::extract_bits(_e100, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int2 _e104 = i2_; + i2_ = metal::extract_bits(_e104, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int3 _e108 = i3_; + i3_ = metal::extract_bits(_e108, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::int4 _e112 = i4_; + i4_ = metal::extract_bits(_e112, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + uint _e116 = u; + u = metal::extract_bits(_e116, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint2 _e120 = u2_; + u2_ = metal::extract_bits(_e120, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint3 _e124 = u3_; + u3_ = metal::extract_bits(_e124, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + metal::uint4 _e128 = u4_; + u4_ = metal::extract_bits(_e128, metal::min(5u, 32u), metal::min(10u, 32u - metal::min(5u, 32u))); + int _e132 = i; + i = (((metal::ctz(_e132) + 1) % 33) - 1); + metal::uint2 _e134 = u2_; + u2_ = (((metal::ctz(_e134) + 1) % 33) - 1); + metal::int3 _e136 = i3_; + i3_ = metal::select(31 - metal::clz(metal::select(_e136, ~_e136, _e136 < 0)), int3(-1), _e136 == 0 || _e136 == -1); + metal::uint3 _e138 = u3_; + u3_ = metal::select(31 - metal::clz(_e138), uint3(-1), _e138 == 0 || _e138 == -1); + int _e140 = i; + i = metal::select(31 - metal::clz(metal::select(_e140, ~_e140, _e140 < 0)), int(-1), _e140 == 0 || _e140 == -1); + uint _e142 = u; + u = metal::select(31 - metal::clz(_e142), uint(-1), _e142 == 0 || _e142 == -1); + int _e144 = i; + i = metal::popcount(_e144); + metal::int2 _e146 = i2_; + i2_ = metal::popcount(_e146); + metal::int3 _e148 = i3_; + i3_ = metal::popcount(_e148); + metal::int4 _e150 = i4_; + i4_ = metal::popcount(_e150); + uint _e152 = u; + u = metal::popcount(_e152); + metal::uint2 _e154 = u2_; + u2_ = metal::popcount(_e154); + metal::uint3 _e156 = u3_; + u3_ = metal::popcount(_e156); + metal::uint4 _e158 = u4_; + u4_ = metal::popcount(_e158); + int _e160 = i; + i = metal::reverse_bits(_e160); + metal::int2 _e162 = i2_; + i2_ = metal::reverse_bits(_e162); + metal::int3 _e164 = i3_; + i3_ = metal::reverse_bits(_e164); + metal::int4 _e166 = i4_; + i4_ = metal::reverse_bits(_e166); + uint _e168 = u; + u = metal::reverse_bits(_e168); + metal::uint2 _e170 = u2_; + u2_ = metal::reverse_bits(_e170); + metal::uint3 _e172 = u3_; + u3_ = metal::reverse_bits(_e172); + metal::uint4 _e174 = u4_; + u4_ = metal::reverse_bits(_e174); + return; +}