mirror of
https://github.com/gfx-rs/wgpu.git
synced 2025-12-08 21:26:17 +00:00
[naga const_eval] Ensure eval_zero_value_and_splat() lowers a Splat of a ZeroValue correctly
eval_zero_value_and_splat() is called to lower ZeroValue and Splat expressions into Literal and Compose expressions. However, in its current form it either calls splat() *or* eval_zero_value_impl() depending on the expression type. splat() will lower a Splat of a scalar ZeroValue to a vector ZeroValue, which means eval_zero_value_and_splat() can still return a ZeroValue. Its callers, such as binary_op(), are unable to handle this ZeroValue, so cannot proceed with const evaluation. This patch makes it so that eval_zero_value_and_splat() will first call splat(), *and then* call eval_zero_value_impl(), which will lower the vector ZeroValue returned by splat() into a Compose of Literals. Callers such as binary_op() are perfectly able to handle this Compose, so can now proceed with const evaluation.
This commit is contained in:
parent
2a456f5c7b
commit
d9777355c9
@ -1398,14 +1398,19 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
/// [`Compose`]: Expression::Compose
|
||||
fn eval_zero_value_and_splat(
|
||||
&mut self,
|
||||
expr: Handle<Expression>,
|
||||
mut expr: Handle<Expression>,
|
||||
span: Span,
|
||||
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
|
||||
match self.expressions[expr] {
|
||||
Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
|
||||
Expression::Splat { size, value } => self.splat(value, size, span),
|
||||
_ => Ok(expr),
|
||||
// The result of the splat() for a Splat of a scalar ZeroValue is a
|
||||
// vector ZeroValue, so we must call eval_zero_value_impl() after
|
||||
// splat() in order to ensure we have no ZeroValues remaining.
|
||||
if let Expression::Splat { size, value } = self.expressions[expr] {
|
||||
expr = self.splat(value, size, span)?;
|
||||
}
|
||||
if let Expression::ZeroValue(ty) = self.expressions[expr] {
|
||||
expr = self.eval_zero_value_impl(ty, span)?;
|
||||
}
|
||||
Ok(expr)
|
||||
}
|
||||
|
||||
/// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
|
||||
@ -2976,4 +2981,84 @@ mod tests {
|
||||
panic!("unexpected evaluation result")
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn splat_of_zero_value() {
|
||||
let mut types = UniqueArena::new();
|
||||
let constants = Arena::new();
|
||||
let overrides = Arena::new();
|
||||
let mut global_expressions = Arena::new();
|
||||
|
||||
let f32_ty = types.insert(
|
||||
Type {
|
||||
name: None,
|
||||
inner: TypeInner::Scalar(crate::Scalar::F32),
|
||||
},
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let vec2_f32_ty = types.insert(
|
||||
Type {
|
||||
name: None,
|
||||
inner: TypeInner::Vector {
|
||||
size: VectorSize::Bi,
|
||||
scalar: crate::Scalar::F32,
|
||||
},
|
||||
},
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let five =
|
||||
global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
|
||||
let five_splat = global_expressions.append(
|
||||
Expression::Splat {
|
||||
size: VectorSize::Bi,
|
||||
value: five,
|
||||
},
|
||||
Default::default(),
|
||||
);
|
||||
let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
|
||||
let zero_splat = global_expressions.append(
|
||||
Expression::Splat {
|
||||
size: VectorSize::Bi,
|
||||
value: zero,
|
||||
},
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
|
||||
let mut solver = ConstantEvaluator {
|
||||
behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
|
||||
types: &mut types,
|
||||
constants: &constants,
|
||||
overrides: &overrides,
|
||||
expressions: &mut global_expressions,
|
||||
expression_kind_tracker,
|
||||
};
|
||||
|
||||
let solved_add = solver
|
||||
.try_eval_and_append(
|
||||
Expression::Binary {
|
||||
op: crate::BinaryOperator::Add,
|
||||
left: zero_splat,
|
||||
right: five_splat,
|
||||
},
|
||||
Default::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let pass = match global_expressions[solved_add] {
|
||||
Expression::Compose { ty, ref components } => {
|
||||
ty == vec2_f32_ty
|
||||
&& components.iter().all(|&component| {
|
||||
let component = &global_expressions[component];
|
||||
matches!(*component, Expression::Literal(Literal::F32(5.0)))
|
||||
})
|
||||
}
|
||||
_ => false,
|
||||
};
|
||||
if !pass {
|
||||
panic!("unexpected evaluation result")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user