[naga] Ensure that FooResult expressions are correctly populated.

Make Naga module validation require that `CallResult` and
`AtomicResult` expressions are indeed visited by exactly one `Call` /
`Atomic` statement.
This commit is contained in:
Jim Blandy 2024-06-03 07:59:25 -07:00 committed by Teodor Tanasoaia
parent 9a27ba53ca
commit 583cc6ab04
3 changed files with 145 additions and 55 deletions

View File

@ -22,6 +22,8 @@ pub enum CallError {
}, },
#[error("Result expression {0:?} has already been introduced earlier")] #[error("Result expression {0:?} has already been introduced earlier")]
ResultAlreadyInScope(Handle<crate::Expression>), ResultAlreadyInScope(Handle<crate::Expression>),
#[error("Result expression {0:?} is populated by multiple `Call` statements")]
ResultAlreadyPopulated(Handle<crate::Expression>),
#[error("Result value is invalid")] #[error("Result value is invalid")]
ResultValue(#[source] ExpressionError), ResultValue(#[source] ExpressionError),
#[error("Requires {required} arguments, but {seen} are provided")] #[error("Requires {required} arguments, but {seen} are provided")]
@ -45,6 +47,8 @@ pub enum AtomicError {
InvalidOperand(Handle<crate::Expression>), InvalidOperand(Handle<crate::Expression>),
#[error("Result type for {0:?} doesn't match the statement")] #[error("Result type for {0:?} doesn't match the statement")]
ResultTypeMismatch(Handle<crate::Expression>), ResultTypeMismatch(Handle<crate::Expression>),
#[error("Result expression {0:?} is populated by multiple `Atomic` statements")]
ResultAlreadyPopulated(Handle<crate::Expression>),
} }
#[derive(Clone, Debug, thiserror::Error)] #[derive(Clone, Debug, thiserror::Error)]
@ -174,6 +178,8 @@ pub enum FunctionError {
InvalidSubgroup(#[from] SubgroupError), InvalidSubgroup(#[from] SubgroupError),
#[error("Emit statement should not cover \"result\" expressions like {0:?}")] #[error("Emit statement should not cover \"result\" expressions like {0:?}")]
EmitResult(Handle<crate::Expression>), EmitResult(Handle<crate::Expression>),
#[error("Expression not visited by the appropriate statement")]
UnvisitedExpression(Handle<crate::Expression>),
} }
bitflags::bitflags! { bitflags::bitflags! {
@ -305,7 +311,13 @@ impl super::Validator {
} }
match context.expressions[expr] { match context.expressions[expr] {
crate::Expression::CallResult(callee) crate::Expression::CallResult(callee)
if fun.result.is_some() && callee == function => {} if fun.result.is_some() && callee == function =>
{
if !self.needs_visit.remove(expr.index()) {
return Err(CallError::ResultAlreadyPopulated(expr)
.with_span_handle(expr, context.expressions));
}
}
_ => { _ => {
return Err(CallError::ExpressionMismatch(result) return Err(CallError::ExpressionMismatch(result)
.with_span_handle(expr, context.expressions)) .with_span_handle(expr, context.expressions))
@ -397,7 +409,14 @@ impl super::Validator {
} }
_ => false, _ => false,
} }
} => {} } =>
{
if !self.needs_visit.remove(result.index()) {
return Err(AtomicError::ResultAlreadyPopulated(result)
.with_span_handle(result, context.expressions)
.into_other());
}
}
_ => { _ => {
return Err(AtomicError::ResultTypeMismatch(result) return Err(AtomicError::ResultTypeMismatch(result)
.with_span_handle(result, context.expressions) .with_span_handle(result, context.expressions)
@ -1290,11 +1309,20 @@ impl super::Validator {
self.valid_expression_set.clear(); self.valid_expression_set.clear();
self.valid_expression_list.clear(); self.valid_expression_list.clear();
self.needs_visit.clear();
for (handle, expr) in fun.expressions.iter() { for (handle, expr) in fun.expressions.iter() {
if expr.needs_pre_emit() { if expr.needs_pre_emit() {
self.valid_expression_set.insert(handle.index()); self.valid_expression_set.insert(handle.index());
} }
if self.flags.contains(super::ValidationFlags::EXPRESSIONS) { if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
// Mark expressions that need to be visited by a particular kind of
// statement.
if let crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } =
*expr
{
self.needs_visit.insert(handle.index());
}
match self.validate_expression( match self.validate_expression(
handle, handle,
expr, expr,
@ -1321,6 +1349,15 @@ impl super::Validator {
)? )?
.stages; .stages;
info.available_stages &= stages; info.available_stages &= stages;
if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
if let Some(unvisited) = self.needs_visit.iter().next() {
let index = std::num::NonZeroU32::new(unvisited as u32 + 1).unwrap();
let handle = Handle::new(index);
return Err(FunctionError::UnvisitedExpression(handle)
.with_span_handle(handle, &fun.expressions));
}
}
} }
Ok(info) Ok(info)
} }

View File

@ -246,6 +246,26 @@ pub struct Validator {
valid_expression_set: BitSet, valid_expression_set: BitSet,
override_ids: FastHashSet<u16>, override_ids: FastHashSet<u16>,
allow_overrides: bool, allow_overrides: bool,
/// A checklist of expressions that must be visited by a specific kind of
/// statement.
///
/// For example:
///
/// - [`CallResult`] expressions must be visited by a [`Call`] statement.
/// - [`AtomicResult`] expressions must be visited by an [`Atomic`] statement.
///
/// Be sure not to remove any [`Expression`] handle from this set unless
/// you've explicitly checked that it is the right kind of expression for
/// the visiting [`Statement`].
///
/// [`CallResult`]: crate::Expression::CallResult
/// [`Call`]: crate::Statement::Call
/// [`AtomicResult`]: crate::Expression::AtomicResult
/// [`Atomic`]: crate::Statement::Atomic
/// [`Expression`]: crate::Expression
/// [`Statement`]: crate::Statement
needs_visit: BitSet,
} }
#[derive(Clone, Debug, thiserror::Error)] #[derive(Clone, Debug, thiserror::Error)]
@ -398,6 +418,7 @@ impl Validator {
valid_expression_set: BitSet::new(), valid_expression_set: BitSet::new(),
override_ids: FastHashSet::default(), override_ids: FastHashSet::default(),
allow_overrides: true, allow_overrides: true,
needs_visit: BitSet::new(),
} }
} }

View File

@ -1,18 +1,30 @@
use naga::{valid, Expression, Function, Scalar}; use naga::{valid, Expression, Function, Scalar};
/// Validation should fail if `AtomicResult` expressions are not
/// populated by `Atomic` statements.
#[test] #[test]
fn emit_atomic_result() { fn populate_atomic_result() {
use naga::{Module, Type, TypeInner}; use naga::{Module, Type, TypeInner};
// We want to ensure that the *only* problem with the code is the /// Different variants of the test case that we want to exercise.
// use of an `Emit` statement instead of an `Atomic` statement. So enum Variant {
// validate two versions of the module varying only in that /// An `AtomicResult` expression with an `Atomic` statement
// aspect. /// that populates it: valid.
// Atomic,
// Looking at uses of the `atomic` makes it easy to identify the
// differences between the two variants. /// An `AtomicResult` expression visited by an `Emit`
fn variant( /// statement: invalid.
atomic: bool, Emit,
/// An `AtomicResult` expression visited by no statement at
/// all: invalid
None,
}
// Looking at uses of `variant` should make it easy to identify
// the differences between the test cases.
fn try_variant(
variant: Variant,
) -> Result<naga::valid::ModuleInfo, naga::WithSpan<naga::valid::ValidationError>> { ) -> Result<naga::valid::ModuleInfo, naga::WithSpan<naga::valid::ValidationError>> {
let span = naga::Span::default(); let span = naga::Span::default();
let mut module = Module::default(); let mut module = Module::default();
@ -56,21 +68,25 @@ fn emit_atomic_result() {
span, span,
); );
if atomic { match variant {
fun.body.push( Variant::Atomic => {
naga::Statement::Atomic { fun.body.push(
pointer: ex_global, naga::Statement::Atomic {
fun: naga::AtomicFunction::Add, pointer: ex_global,
value: ex_42, fun: naga::AtomicFunction::Add,
result: ex_result, value: ex_42,
}, result: ex_result,
span, },
); span,
} else { );
fun.body.push( }
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)), Variant::Emit => {
span, fun.body.push(
); naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
span,
);
}
Variant::None => {}
} }
module.functions.append(fun, span); module.functions.append(fun, span);
@ -82,23 +98,34 @@ fn emit_atomic_result() {
.validate(&module) .validate(&module)
} }
variant(true).expect("module should validate"); try_variant(Variant::Atomic).expect("module should validate");
assert!(variant(false).is_err()); assert!(try_variant(Variant::Emit).is_err());
assert!(try_variant(Variant::None).is_err());
} }
#[test] #[test]
fn emit_call_result() { fn populate_call_result() {
use naga::{Module, Type, TypeInner}; use naga::{Module, Type, TypeInner};
// We want to ensure that the *only* problem with the code is the /// Different variants of the test case that we want to exercise.
// use of an `Emit` statement instead of a `Call` statement. So enum Variant {
// validate two versions of the module varying only in that /// A `CallResult` expression with an `Call` statement that
// aspect. /// populates it: valid.
// Call,
// Looking at uses of the `call` makes it easy to identify the
// differences between the two variants. /// A `CallResult` expression visited by an `Emit` statement:
fn variant( /// invalid.
call: bool, Emit,
/// A `CallResult` expression visited by no statement at all:
/// invalid
None,
}
// Looking at uses of `variant` should make it easy to identify
// the differences between the test cases.
fn try_variant(
variant: Variant,
) -> Result<naga::valid::ModuleInfo, naga::WithSpan<naga::valid::ValidationError>> { ) -> Result<naga::valid::ModuleInfo, naga::WithSpan<naga::valid::ValidationError>> {
let span = naga::Span::default(); let span = naga::Span::default();
let mut module = Module::default(); let mut module = Module::default();
@ -130,20 +157,24 @@ fn emit_call_result() {
.expressions .expressions
.append(Expression::CallResult(fun_callee), span); .append(Expression::CallResult(fun_callee), span);
if call { match variant {
fun_caller.body.push( Variant::Call => {
naga::Statement::Call { fun_caller.body.push(
function: fun_callee, naga::Statement::Call {
arguments: vec![], function: fun_callee,
result: Some(ex_result), arguments: vec![],
}, result: Some(ex_result),
span, },
); span,
} else { );
fun_caller.body.push( }
naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)), Variant::Emit => {
span, fun_caller.body.push(
); naga::Statement::Emit(naga::Range::new_from_bounds(ex_result, ex_result)),
span,
);
}
Variant::None => {}
} }
module.functions.append(fun_caller, span); module.functions.append(fun_caller, span);
@ -155,8 +186,9 @@ fn emit_call_result() {
.validate(&module) .validate(&module)
} }
variant(true).expect("should validate"); try_variant(Variant::Call).expect("should validate");
assert!(variant(false).is_err()); assert!(try_variant(Variant::Emit).is_err());
assert!(try_variant(Variant::None).is_err());
} }
#[test] #[test]