fix(naga): properly impl. auto. type conv. for select

This commit is contained in:
Erich Gubler 2025-04-14 14:49:08 -04:00
parent 3c0803d1cc
commit db3c35db90
19 changed files with 1026 additions and 621 deletions

View File

@ -77,6 +77,7 @@ Naga now infers the correct binding layout when a resource appears only in an as
- Properly apply WGSL's automatic conversions to the arguments to texture sampling functions. By @jimblandy in [#7548](https://github.com/gfx-rs/wgpu/pull/7548).
- Properly evaluate `abs(most negative abstract int)`. By @jimblandy in [#7507](https://github.com/gfx-rs/wgpu/pull/7507).
- Generate vectorized code for `[un]pack4x{I,U}8[Clamp]` on SPIR-V and MSL 2.1+. By @robamler in [#7664](https://github.com/gfx-rs/wgpu/pull/7664).
- Fix typing for `select`, which had issues particularly with a lack of automatic type conversion. By @ErichDonGubler in [#7572](https://github.com/gfx-rs/wgpu/pull/7572).
#### DX12

View File

@ -18,6 +18,9 @@ webgpu:api,operation,rendering,depth:*
webgpu:api,operation,rendering,draw:*
webgpu:api,operation,shader_module,compilation_info:*
webgpu:api,operation,uncapturederror:iff_uncaptured:*
//FAIL: webgpu:shader,execution,expression,call,builtin,select:*
// - Fails with `const`/abstract int cases on all platforms because of <https://github.com/gfx-rs/wgpu/issues/4507>.
// - Fails with `vec3` & `f16` cases on macOS because of <https://github.com/gfx-rs/wgpu/issues/5262>.
//FAIL: webgpu:api,operation,uncapturederror:onuncapturederror_order_wrt_addEventListener
// There are also two unimplemented SKIPs in uncapturederror not enumerated here.
webgpu:api,validation,encoding,queries,general:occlusion_query,query_type:*

View File

@ -399,6 +399,16 @@ pub(crate) enum Error<'a> {
on_what: DiagnosticAttributeNotSupportedPosition,
spans: Vec<Span>,
},
SelectUnexpectedArgumentType {
arg_span: Span,
arg_type: String,
},
SelectRejectAndAcceptHaveNoCommonType {
reject_span: Span,
reject_type: String,
accept_span: Span,
accept_type: String,
},
}
impl From<ConflictingDiagnosticRuleError> for Error<'_> {
@ -1342,6 +1352,24 @@ impl<'a> Error<'a> {
],
}
}
Error::SelectUnexpectedArgumentType { arg_span, ref arg_type } => ParseError {
message: "unexpected argument type for `select` call".into(),
labels: vec![(arg_span, format!("this value of type {arg_type}").into())],
notes: vec!["expected a scalar or a `vecN` of scalars".into()],
},
Error::SelectRejectAndAcceptHaveNoCommonType {
reject_span,
ref reject_type,
accept_span,
ref accept_type,
} => ParseError {
message: "type mismatch for reject and accept values in `select` call".into(),
labels: vec![
(reject_span, format!("reject value of type {reject_type}").into()),
(accept_span, format!("accept value of type {accept_type}").into()),
],
notes: vec![],
},
}
}
}

View File

@ -1,6 +1,7 @@
use alloc::{
borrow::ToOwned,
boxed::Box,
format,
string::{String, ToString},
vec::Vec,
};
@ -2541,12 +2542,68 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
"select" => {
let mut args = ctx.prepare_args(arguments, 3, span);
let reject = self.expression(args.next()?, ctx)?;
let accept = self.expression(args.next()?, ctx)?;
let reject_orig = args.next()?;
let accept_orig = args.next()?;
let mut values = [
self.expression_for_abstract(reject_orig, ctx)?,
self.expression_for_abstract(accept_orig, ctx)?,
];
let condition = self.expression(args.next()?, ctx)?;
args.finish()?;
let diagnostic_details =
|ctx: &ExpressionContext<'_, '_, '_>,
ty_res: &proc::TypeResolution,
orig_expr| {
(
ctx.ast_expressions.get_span(orig_expr),
format!("`{}`", ctx.as_diagnostic_display(ty_res)),
)
};
for (&value, orig_value) in
values.iter().zip([reject_orig, accept_orig])
{
let value_ty_res = resolve!(ctx, value);
if value_ty_res
.inner_with(&ctx.module.types)
.vector_size_and_scalar()
.is_none()
{
let (arg_span, arg_type) =
diagnostic_details(ctx, value_ty_res, orig_value);
return Err(Box::new(Error::SelectUnexpectedArgumentType {
arg_span,
arg_type,
}));
}
}
let mut consensus_scalar = ctx
.automatic_conversion_consensus(&values)
.map_err(|_idx| {
let [reject, accept] = values;
let [(reject_span, reject_type), (accept_span, accept_type)] =
[(reject_orig, reject), (accept_orig, accept)].map(
|(orig_expr, expr)| {
let ty_res = &ctx.typifier()[expr];
diagnostic_details(ctx, ty_res, orig_expr)
},
);
Error::SelectRejectAndAcceptHaveNoCommonType {
reject_span,
reject_type,
accept_span,
accept_type,
}
})?;
if !ctx.is_const(condition) {
consensus_scalar = consensus_scalar.concretize();
}
ctx.convert_slice_to_common_leaf_scalar(&mut values, consensus_scalar)?;
let [reject, accept] = values;
ir::Expression::Select {
reject,
accept,

View File

@ -563,6 +563,27 @@ pub enum ConstantEvaluatorError {
RuntimeExpr,
#[error("Unexpected override-expression")]
OverrideExpr,
#[error("Expected boolean expression for condition argument of `select`, got something else")]
SelectScalarConditionNotABool,
#[error(
"Expected vectors of the same size for reject and accept args., got {:?} and {:?}",
reject,
accept
)]
SelectVecRejectAcceptSizeMismatch {
reject: crate::VectorSize,
accept: crate::VectorSize,
},
#[error("Expected boolean vector for condition arg., got something else")]
SelectConditionNotAVecBool,
#[error(
"Expected same number of vector components between condition, accept, and reject args., got something else",
)]
SelectConditionVecSizeMismatch,
#[error(
"Expected reject and accept args. to be scalars of vectors of the same type, got something else",
)]
SelectAcceptRejectTypeMismatch,
}
impl<'a> ConstantEvaluator<'a> {
@ -904,9 +925,19 @@ impl<'a> ConstantEvaluator<'a> {
)),
}
}
Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
"select built-in function".into(),
)),
Expression::Select {
reject,
accept,
condition,
} => {
let mut arg = |expr| self.check_and_get(expr);
let reject = arg(reject)?;
let accept = arg(accept)?;
let condition = arg(condition)?;
self.select(reject, accept, condition, span)
}
Expression::Relational { fun, argument } => {
let argument = self.check_and_get(argument)?;
self.relational(fun, argument, span)
@ -2501,6 +2532,116 @@ impl<'a> ConstantEvaluator<'a> {
Ok(resolution)
}
fn select(
&mut self,
reject: Handle<Expression>,
accept: Handle<Expression>,
condition: Handle<Expression>,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
let mut arg = |arg| self.eval_zero_value_and_splat(arg, span);
let reject = arg(reject)?;
let accept = arg(accept)?;
let condition = arg(condition)?;
let select_single_component =
|this: &mut Self, reject_scalar, reject, accept, condition| {
let accept = this.cast(accept, reject_scalar, span)?;
if condition {
Ok(accept)
} else {
Ok(reject)
}
};
match (&self.expressions[reject], &self.expressions[accept]) {
(&Expression::Literal(reject_lit), &Expression::Literal(_accept_lit)) => {
let reject_scalar = reject_lit.scalar();
let &Expression::Literal(Literal::Bool(condition)) = &self.expressions[condition]
else {
return Err(ConstantEvaluatorError::SelectScalarConditionNotABool);
};
select_single_component(self, reject_scalar, reject, accept, condition)
}
(
&Expression::Compose {
ty: reject_ty,
components: ref reject_components,
},
&Expression::Compose {
ty: accept_ty,
components: ref accept_components,
},
) => {
let ty_deets = |ty| {
let (size, scalar) = self.types[ty].inner.vector_size_and_scalar().unwrap();
(size.unwrap(), scalar)
};
let expected_vec_size = {
let [(reject_vec_size, _), (accept_vec_size, _)] =
[reject_ty, accept_ty].map(ty_deets);
if reject_vec_size != accept_vec_size {
return Err(ConstantEvaluatorError::SelectVecRejectAcceptSizeMismatch {
reject: reject_vec_size,
accept: accept_vec_size,
});
}
reject_vec_size
};
let condition_components = match self.expressions[condition] {
Expression::Literal(Literal::Bool(condition)) => {
vec![condition; (expected_vec_size as u8).into()]
}
Expression::Compose {
ty: condition_ty,
components: ref condition_components,
} => {
let (condition_vec_size, condition_scalar) = ty_deets(condition_ty);
if condition_scalar.kind != ScalarKind::Bool {
return Err(ConstantEvaluatorError::SelectConditionNotAVecBool);
}
if condition_vec_size != expected_vec_size {
return Err(ConstantEvaluatorError::SelectConditionVecSizeMismatch);
}
condition_components
.iter()
.copied()
.map(|component| match &self.expressions[component] {
&Expression::Literal(Literal::Bool(condition)) => condition,
_ => unreachable!(),
})
.collect()
}
_ => return Err(ConstantEvaluatorError::SelectConditionNotAVecBool),
};
let evaluated = Expression::Compose {
ty: reject_ty,
components: reject_components
.clone()
.into_iter()
.zip(accept_components.clone().into_iter())
.zip(condition_components.into_iter())
.map(|((reject, accept), condition)| {
let reject_scalar = match &self.expressions[reject] {
&Expression::Literal(lit) => lit.scalar(),
_ => unreachable!(),
};
select_single_component(self, reject_scalar, reject, accept, condition)
})
.collect::<Result<_, _>>()?,
};
self.register_evaluated_expr(evaluated, span)
}
_ => Err(ConstantEvaluatorError::SelectAcceptRejectTypeMismatch),
}
}
}
fn first_trailing_bit(concrete_int: ConcreteInt<1>) -> ConcreteInt<1> {

View File

@ -0,0 +1,41 @@
const_assert select(0xdeadbeef, 42f, false) == 0xdeadbeef;
const_assert select(0xdeadbeefu, 42, false) == 0xdeadbeefu;
const_assert select(0xdeadi, 42, false) == 0xdeadi;
const_assert select(42f, 0xdeadbeef, true) == 0xdeadbeef;
const_assert select(42, 0xdeadbeefu, true) == 0xdeadbeefu;
const_assert select(42, 0xdeadi, true) == 0xdeadi;
const_assert select(42f, 9001, true) == 9001;
const_assert select(42f, 9001, true) == 9001f;
const_assert select(42, 9001i, true) == 9001;
const_assert select(42, 9001u, true) == 9001;
const_assert select(9001, 42f, false) == 9001;
const_assert select(9001, 42f, false) == 9001f;
const_assert select(9001i, 42, false) == 9001;
const_assert select(9001u, 42, false) == 9001;
const_assert !select(false, true, false);
const_assert select(false, true, true);
const_assert select(true, false, false);
const_assert !select(true, false, true);
const_assert all(select(vec2(2f), vec2(), true) == vec2(0));
const_assert all(select(vec2(1), vec2(2f), false) == vec2(1));
const_assert all(select(vec2(1), vec2(2f), false) == vec2(1));
const_assert all(select(vec2(1), vec2(2f), vec2(false, false)) == vec2(1));
const_assert all(select(vec2(1), vec2(2f), vec2(true)) == vec2(2));
const_assert all(select(vec2(1), vec2(2f), vec2(true)) == vec2(2));
const_assert all(select(vec2(1), vec2(2f), vec2(true, false)) == vec2(2, 1));
const_assert all(select(vec3(1), vec3(2f), vec3(true)) == vec3(2));
const_assert all(select(vec4(1), vec4(2f), vec4(true)) == vec4(2));
@compute @workgroup_size(1, 1)
fn main() {
_ = select(1, 2f, false);
var x0 = vec2(1, 2);
var i1: vec2<f32> = select(vec2<f32>(1., 0.), vec2<f32>(0., 1.), (x0.x < x0.y));
}

View File

@ -648,9 +648,8 @@ fn binding_arrays_cannot_hold_scalars() {
#[cfg(feature = "wgsl-in")]
#[test]
fn validation_error_messages() {
let cases = [
(
r#"@group(0) @binding(0) var my_sampler: sampler;
let cases = [(
r#"@group(0) @binding(0) var my_sampler: sampler;
fn foo(tex: texture_2d<f32>) -> vec4<f32> {
return textureSampleLevel(tex, my_sampler, vec2f(0, 0), 0.0);
@ -660,7 +659,7 @@ fn validation_error_messages() {
foo();
}
"#,
"\
"\
error: Function [1] 'main' is invalid
wgsl:7:17
\n7 fn main() {
@ -671,48 +670,7 @@ error: Function [1] 'main' is invalid
= Requires 1 arguments, but 0 are provided
",
),
(
"\
@compute @workgroup_size(1, 1)
fn main() {
// Bad: `9001` isn't a `bool`.
_ = select(1, 2, 9001);
}
",
"\
error: Entry point main at Compute is invalid
wgsl:4:9
4 _ = select(1, 2, 9001);
^^^^^^ naga::ir::Expression [3]
= Expression [3] is invalid
= Expected selection condition to be a boolean value, got Scalar(Scalar { kind: Sint, width: 4 })
",
),
(
"\
@compute @workgroup_size(1, 1)
fn main() {
// Bad: `bool` and abstract int args. don't match.
_ = select(true, 1, false);
}
",
"\
error: Entry point main at Compute is invalid
wgsl:4:9
4 _ = select(true, 1, false);
^^^^^^ naga::ir::Expression [3]
= Expression [3] is invalid
= Expected selection argument types to match, but reject value of type Scalar(Scalar { kind: Bool, width: 1 }) does not match accept value of value Scalar(Scalar { kind: Sint, width: 4 })
",
),
];
)];
for (source, expected_err) in cases {
let module = naga::front::wgsl::parse_str(source).unwrap();

View File

@ -2012,8 +2012,9 @@ fn invalid_runtime_sized_arrays() {
#[test]
fn select() {
check_validation! {
"
let snapshots = [
(
"
fn select_pointers(which: bool) -> i32 {
var x: i32 = 1;
var y: i32 = 2;
@ -2021,7 +2022,19 @@ fn select() {
return *p;
}
",
"
"\
error: unexpected argument type for `select` call
wgsl:5:28
5 let p = select(&x, &y, which);
^^ this value of type `ptr<function, i32>`
= note: expected a scalar or a `vecN` of scalars
",
),
(
"
fn select_arrays(which: bool) -> i32 {
var x: array<i32, 4>;
var y: array<i32, 4>;
@ -2029,7 +2042,19 @@ fn select() {
return s[0];
}
",
"
"\
error: unexpected argument type for `select` call
wgsl:5:28
5 let s = select(x, y, which);
^ this value of type `array<i32, 4>`
= note: expected a scalar or a `vecN` of scalars
",
),
(
"
struct S { member: i32 }
fn select_structs(which: bool) -> S {
var x: S = S(1);
@ -2037,18 +2062,58 @@ fn select() {
let s = select(x, y, which);
return s;
}
":
Err(
naga::valid::ValidationError::Function {
name,
source: naga::valid::FunctionError::Expression {
source: naga::valid::ExpressionError::SelectConditionNotABool { .. },
..
},
..
},
)
if name.starts_with("select_")
",
"\
error: unexpected argument type for `select` call
wgsl:6:28
6 let s = select(x, y, which);
^ this value of type `S`
= note: expected a scalar or a `vecN` of scalars
",
),
(
"
@compute @workgroup_size(1, 1)
fn main() {
// Bad: `9001` isn't a `bool`.
_ = select(1, 2, 9001);
}
",
"\
error: Expected boolean expression for condition argument of `select`, got something else
wgsl:5:17
5 _ = select(1, 2, 9001);
^^^^^^ see msg
",
),
(
"
@compute @workgroup_size(1, 1)
fn main() {
// Bad: `bool` and abstract int args. don't match.
_ = select(true, 1, false);
}
",
"\
error: type mismatch for reject and accept values in `select` call
wgsl:5:24
5 _ = select(true, 1, false);
^^^^ ^ accept value of type `{AbstractInt}`
reject value of type `bool`
",
),
];
for (input, snapshot) in snapshots {
check(input, snapshot);
}
}

View File

@ -14,7 +14,7 @@ const ivec4 v_i32_one = ivec4(1, 1, 1, 1);
vec4 builtins() {
int s1_ = (true ? 1 : 0);
vec4 s2_ = (true ? v_f32_one : v_f32_zero);
vec4 s3_ = mix(v_f32_one, v_f32_zero, bvec4(false, false, false, false));
vec4 s3_ = vec4(1.0, 1.0, 1.0, 1.0);
vec4 m1_ = mix(v_f32_zero, v_f32_one, v_f32_half);
vec4 m2_ = mix(v_f32_zero, v_f32_one, 0.1);
float b1_ = intBitsToFloat(1);

View File

@ -0,0 +1,17 @@
#version 310 es
precision highp float;
precision highp int;
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void main() {
ivec2 x0_ = ivec2(1, 2);
vec2 i1_ = vec2(0.0);
int _e12 = x0_.x;
int _e14 = x0_.y;
i1_ = ((_e12 < _e14) ? vec2(0.0, 1.0) : vec2(1.0, 0.0));
return;
}

View File

@ -7,7 +7,7 @@ float4 builtins()
{
int s1_ = (true ? int(1) : int(0));
float4 s2_ = (true ? v_f32_one : v_f32_zero);
float4 s3_ = (bool4(false, false, false, false) ? v_f32_zero : v_f32_one);
float4 s3_ = float4(1.0, 1.0, 1.0, 1.0);
float4 m1_ = lerp(v_f32_zero, v_f32_one, v_f32_half);
float4 m2_ = lerp(v_f32_zero, v_f32_one, 0.1);
float b1_ = asfloat(int(1));

View File

@ -0,0 +1,11 @@
[numthreads(1, 1, 1)]
void main()
{
int2 x0_ = int2(int(1), int(2));
float2 i1_ = (float2)0;
int _e12 = x0_.x;
int _e14 = x0_.y;
i1_ = ((_e12 < _e14) ? float2(0.0, 1.0) : float2(1.0, 0.0));
return;
}

View File

@ -0,0 +1,12 @@
(
vertex:[
],
fragment:[
],
compute:[
(
entry_point:"main",
target_profile:"cs_5_1",
),
],
)

View File

@ -13,7 +13,7 @@ metal::float4 builtins(
) {
int s1_ = true ? 1 : 0;
metal::float4 s2_ = true ? v_f32_one : v_f32_zero;
metal::float4 s3_ = metal::select(v_f32_one, v_f32_zero, metal::bool4(false, false, false, false));
metal::float4 s3_ = metal::float4(1.0, 1.0, 1.0, 1.0);
metal::float4 m1_ = metal::mix(v_f32_zero, v_f32_one, v_f32_half);
metal::float4 m2_ = metal::mix(v_f32_zero, v_f32_one, 0.1);
float b1_ = as_type<float>(1);

View File

@ -0,0 +1,16 @@
// language: metal1.0
#include <metal_stdlib>
#include <simd/simd.h>
using metal::uint;
kernel void main_(
) {
metal::int2 x0_ = metal::int2(1, 2);
metal::float2 i1_ = {};
int _e12 = x0_.x;
int _e14 = x0_.y;
i1_ = (_e12 < _e14) ? metal::float2(0.0, 1.0) : metal::float2(1.0, 0.0);
return;
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,47 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 36
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %8 "main"
OpExecutionMode %8 LocalSize 1 1 1
%2 = OpTypeVoid
%4 = OpTypeFloat 32
%3 = OpTypeVector %4 2
%6 = OpTypeInt 32 1
%5 = OpTypeVector %6 2
%9 = OpTypeFunction %2
%10 = OpConstant %4 1.0
%11 = OpConstant %6 1
%12 = OpConstant %6 2
%13 = OpConstantComposite %5 %11 %12
%14 = OpConstant %4 0.0
%15 = OpConstantComposite %3 %10 %14
%16 = OpConstantComposite %3 %14 %10
%18 = OpTypePointer Function %5
%20 = OpTypePointer Function %3
%21 = OpConstantNull %3
%23 = OpTypePointer Function %6
%25 = OpTypeInt 32 0
%24 = OpConstant %25 0
%28 = OpConstant %25 1
%31 = OpTypeBool
%34 = OpTypeVector %31 2
%8 = OpFunction %2 None %9
%7 = OpLabel
%17 = OpVariable %18 Function %13
%19 = OpVariable %20 Function %21
OpBranch %22
%22 = OpLabel
%26 = OpAccessChain %23 %17 %24
%27 = OpLoad %6 %26
%29 = OpAccessChain %23 %17 %28
%30 = OpLoad %6 %29
%32 = OpSLessThan %31 %27 %30
%35 = OpCompositeConstruct %34 %32 %32
%33 = OpSelect %3 %35 %16 %15
OpStore %19 %33
OpReturn
OpFunctionEnd

View File

@ -6,7 +6,7 @@ const v_i32_one: vec4<i32> = vec4<i32>(1i, 1i, 1i, 1i);
fn builtins() -> vec4<f32> {
let s1_ = select(0i, 1i, true);
let s2_ = select(v_f32_zero, v_f32_one, true);
let s3_ = select(v_f32_one, v_f32_zero, vec4<bool>(false, false, false, false));
let s3_ = vec4<f32>(1f, 1f, 1f, 1f);
let m1_ = mix(v_f32_zero, v_f32_one, v_f32_half);
let m2_ = mix(v_f32_zero, v_f32_one, 0.1f);
let b1_ = bitcast<f32>(1i);

View File

@ -0,0 +1,10 @@
@compute @workgroup_size(1, 1, 1)
fn main() {
var x0_: vec2<i32> = vec2<i32>(1i, 2i);
var i1_: vec2<f32>;
let _e12 = x0_.x;
let _e14 = x0_.y;
i1_ = select(vec2<f32>(1f, 0f), vec2<f32>(0f, 1f), (_e12 < _e14));
return;
}