[naga const_eval] Ensure eval_zero_value_and_splat() lowers a Splat of a ZeroValue correctly

eval_zero_value_and_splat() is called to lower ZeroValue and Splat
expressions into Literal and Compose expressions. However, in its
current form it either calls splat() *or* eval_zero_value_impl()
depending on the expression type.

splat() will lower a Splat of a scalar ZeroValue to a vector
ZeroValue, which means eval_zero_value_and_splat() can still return a
ZeroValue. Its callers, such as binary_op(), are unable to handle this
ZeroValue, so cannot proceed with const evaluation.

This patch makes it so that eval_zero_value_and_splat() will first
call splat(), *and then* call eval_zero_value_impl(), which will lower
the vector ZeroValue returned by splat() into a Compose of Literals.
Callers such as binary_op() are perfectly able to handle this Compose,
so can now proceed with const evaluation.
This commit is contained in:
Jamie Nicol 2025-01-29 11:45:07 +00:00 committed by Jim Blandy
parent 2a456f5c7b
commit d9777355c9

View File

@ -1398,14 +1398,19 @@ impl<'a> ConstantEvaluator<'a> {
/// [`Compose`]: Expression::Compose /// [`Compose`]: Expression::Compose
fn eval_zero_value_and_splat( fn eval_zero_value_and_splat(
&mut self, &mut self,
expr: Handle<Expression>, mut expr: Handle<Expression>,
span: Span, span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> { ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
match self.expressions[expr] { // The result of the splat() for a Splat of a scalar ZeroValue is a
Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span), // vector ZeroValue, so we must call eval_zero_value_impl() after
Expression::Splat { size, value } => self.splat(value, size, span), // splat() in order to ensure we have no ZeroValues remaining.
_ => Ok(expr), if let Expression::Splat { size, value } = self.expressions[expr] {
expr = self.splat(value, size, span)?;
} }
if let Expression::ZeroValue(ty) = self.expressions[expr] {
expr = self.eval_zero_value_impl(ty, span)?;
}
Ok(expr)
} }
/// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions. /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
@ -2976,4 +2981,84 @@ mod tests {
panic!("unexpected evaluation result") panic!("unexpected evaluation result")
} }
} }
#[test]
fn splat_of_zero_value() {
let mut types = UniqueArena::new();
let constants = Arena::new();
let overrides = Arena::new();
let mut global_expressions = Arena::new();
let f32_ty = types.insert(
Type {
name: None,
inner: TypeInner::Scalar(crate::Scalar::F32),
},
Default::default(),
);
let vec2_f32_ty = types.insert(
Type {
name: None,
inner: TypeInner::Vector {
size: VectorSize::Bi,
scalar: crate::Scalar::F32,
},
},
Default::default(),
);
let five =
global_expressions.append(Expression::Literal(Literal::F32(5.0)), Default::default());
let five_splat = global_expressions.append(
Expression::Splat {
size: VectorSize::Bi,
value: five,
},
Default::default(),
);
let zero = global_expressions.append(Expression::ZeroValue(f32_ty), Default::default());
let zero_splat = global_expressions.append(
Expression::Splat {
size: VectorSize::Bi,
value: zero,
},
Default::default(),
);
let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
let mut solver = ConstantEvaluator {
behavior: Behavior::Wgsl(WgslRestrictions::Const(None)),
types: &mut types,
constants: &constants,
overrides: &overrides,
expressions: &mut global_expressions,
expression_kind_tracker,
};
let solved_add = solver
.try_eval_and_append(
Expression::Binary {
op: crate::BinaryOperator::Add,
left: zero_splat,
right: five_splat,
},
Default::default(),
)
.unwrap();
let pass = match global_expressions[solved_add] {
Expression::Compose { ty, ref components } => {
ty == vec2_f32_ty
&& components.iter().all(|&component| {
let component = &global_expressions[component];
matches!(*component, Expression::Literal(Literal::F32(5.0)))
})
}
_ => false,
};
if !pass {
panic!("unexpected evaluation result")
}
}
} }