[naga wgsl-in] Short-circuiting of && and || operators (#7339)

Addresses parts of #4394 and #6302
This commit is contained in:
Andy Leiserson 2025-11-19 17:06:49 -08:00 committed by GitHub
parent 1f99103be8
commit 119b4efada
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 1092 additions and 565 deletions

View File

@ -152,6 +152,10 @@ By @SupaMaggie70Incorporated in [#8206](https://github.com/gfx-rs/wgpu/pull/8206
- Validate that buffers are unmapped in `write_buffer` calls. By @ErichDonGubler in [#8454](https://github.com/gfx-rs/wgpu/pull/8454).
- Add WGSL parsing for mesh shaders. By @inner-daemons in [#8370](https://github.com/gfx-rs/wgpu/pull/8370).
#### naga
- The `||` and `&&` operators now "short circuit", i.e., do not evaluate the RHS if the result can be determined from just the LHS. By @andyleiserson in [#7339](https://github.com/gfx-rs/wgpu/pull/7339).
#### DX12
- Align copies b/w textures and buffers via a single intermediate buffer per copy when `D3D12_FEATURE_DATA_D3D12_OPTIONS13.UnrestrictedBufferTextureCopyPitchSupported` is `false`. By @ErichDonGubler in [#7721](https://github.com/gfx-rs/wgpu/pull/7721).

View File

@ -426,6 +426,13 @@ impl TypeContext for ExpressionContext<'_, '_, '_> {
}
impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
const fn is_runtime(&self) -> bool {
match self.expr_type {
ExpressionContextType::Runtime(_) => true,
ExpressionContextType::Constant(_) | ExpressionContextType::Override => false,
}
}
#[allow(dead_code)]
fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> {
ExpressionContext {
@ -588,6 +595,16 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
}
}
fn get(&self, handle: Handle<crate::Expression>) -> &crate::Expression {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx)
| ExpressionContextType::Constant(Some(ref ctx)) => &ctx.function.expressions[handle],
ExpressionContextType::Constant(None) | ExpressionContextType::Override => {
&self.module.global_expressions[handle]
}
}
}
fn local(
&mut self,
local: &Handle<ast::Local>,
@ -614,6 +631,52 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
}
}
fn with_nested_runtime_expression_ctx<'a, F, T>(
&mut self,
span: Span,
f: F,
) -> Result<'source, (T, crate::Block)>
where
for<'t> F: FnOnce(&mut ExpressionContext<'source, 't, 't>) -> Result<'source, T>,
{
let mut block = crate::Block::new();
let rctx = match self.expr_type {
ExpressionContextType::Runtime(ref mut rctx) => Ok(rctx),
ExpressionContextType::Constant(_) | ExpressionContextType::Override => {
Err(Error::UnexpectedOperationInConstContext(span))
}
}?;
rctx.block
.extend(rctx.emitter.finish(&rctx.function.expressions));
rctx.emitter.start(&rctx.function.expressions);
let nested_rctx = LocalExpressionContext {
local_table: rctx.local_table,
function: rctx.function,
block: &mut block,
emitter: rctx.emitter,
typifier: rctx.typifier,
local_expression_kind_tracker: rctx.local_expression_kind_tracker,
};
let mut nested_ctx = ExpressionContext {
expr_type: ExpressionContextType::Runtime(nested_rctx),
ast_expressions: self.ast_expressions,
types: self.types,
globals: self.globals,
module: self.module,
const_typifier: self.const_typifier,
layouter: self.layouter,
global_expression_kind_tracker: self.global_expression_kind_tracker,
};
let ret = f(&mut nested_ctx)?;
block.extend(rctx.emitter.finish(&rctx.function.expressions));
rctx.emitter.start(&rctx.function.expressions);
Ok((ret, block))
}
fn gather_component(
&mut self,
expr: Handle<ir::Expression>,
@ -2375,6 +2438,130 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
expr.try_map(|handle| ctx.append_expression(handle, span))
}
/// Generate IR for the short-circuiting operators `&&` and `||`.
///
/// `binary` has already lowered the LHS expression and resolved its type.
fn logical(
&mut self,
op: crate::BinaryOperator,
left: Handle<crate::Expression>,
right: Handle<ast::Expression<'source>>,
span: Span,
ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<'source, Typed<crate::Expression>> {
debug_assert!(
op == crate::BinaryOperator::LogicalAnd || op == crate::BinaryOperator::LogicalOr
);
if ctx.is_runtime() {
// To simulate short-circuiting behavior, we want to generate IR
// like the following for `&&`. For `||`, the condition is `!_lhs`
// and the else value is `true`.
//
// var _e0: bool;
// if _lhs {
// _e0 = _rhs;
// } else {
// _e0 = false;
// }
let (condition, else_val) = if op == crate::BinaryOperator::LogicalAnd {
let condition = left;
let else_val = ctx.append_expression(
crate::Expression::Literal(crate::Literal::Bool(false)),
span,
)?;
(condition, else_val)
} else {
let condition = ctx.append_expression(
crate::Expression::Unary {
op: crate::UnaryOperator::LogicalNot,
expr: left,
},
span,
)?;
let else_val = ctx.append_expression(
crate::Expression::Literal(crate::Literal::Bool(true)),
span,
)?;
(condition, else_val)
};
let bool_ty = ctx.ensure_type_exists(crate::TypeInner::Scalar(crate::Scalar::BOOL));
let rctx = ctx.runtime_expression_ctx(span)?;
let result_var = rctx.function.local_variables.append(
crate::LocalVariable {
name: None,
ty: bool_ty,
init: None,
},
span,
);
let pointer =
ctx.append_expression(crate::Expression::LocalVariable(result_var), span)?;
let (right, mut accept) = ctx.with_nested_runtime_expression_ctx(span, |ctx| {
let right = self.expression_for_abstract(right, ctx)?;
ctx.grow_types(right)?;
Ok(right)
})?;
accept.push(
crate::Statement::Store {
pointer,
value: right,
},
span,
);
let mut reject = crate::Block::with_capacity(1);
reject.push(
crate::Statement::Store {
pointer,
value: else_val,
},
span,
);
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block.push(
crate::Statement::If {
condition,
accept,
reject,
},
span,
);
Ok(Typed::Reference(crate::Expression::LocalVariable(
result_var,
)))
} else {
let left_expr = ctx.get(left);
// Constant or override context in either function or module scope
let &crate::Expression::Literal(crate::Literal::Bool(left_val)) = left_expr else {
return Err(Box::new(Error::NotBool(span)));
};
if op == crate::BinaryOperator::LogicalAnd && !left_val
|| op == crate::BinaryOperator::LogicalOr && left_val
{
// Short-circuit behavior: don't evaluate the RHS. Ideally we
// would do _some_ validity checks of the RHS here, but that's
// tricky, because the RHS is allowed to have things that aren't
// legal in const contexts.
Ok(Typed::Plain(left_expr.clone()))
} else {
let right = self.expression_for_abstract(right, ctx)?;
ctx.grow_types(right)?;
Ok(Typed::Plain(crate::Expression::Binary { op, left, right }))
}
}
}
fn binary(
&mut self,
op: ir::BinaryOperator,
@ -2383,57 +2570,74 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
span: Span,
ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<'source, Typed<ir::Expression>> {
// Load both operands.
let mut left = self.expression_for_abstract(left, ctx)?;
let mut right = self.expression_for_abstract(right, ctx)?;
if op == ir::BinaryOperator::LogicalAnd || op == ir::BinaryOperator::LogicalOr {
let left = self.expression_for_abstract(left, ctx)?;
ctx.grow_types(left)?;
// Convert `scalar op vector` to `vector op vector` by introducing
// `Splat` expressions.
ctx.binary_op_splat(op, &mut left, &mut right)?;
// Apply automatic conversions.
match op {
ir::BinaryOperator::ShiftLeft | ir::BinaryOperator::ShiftRight => {
// Shift operators require the right operand to be `u32` or
// `vecN<u32>`. We can let the validator sort out vector length
// issues, but the right operand must be, or convert to, a u32 leaf
// scalar.
right =
ctx.try_automatic_conversion_for_leaf_scalar(right, ir::Scalar::U32, span)?;
// Additionally, we must concretize the left operand if the right operand
// is not a const-expression.
// See https://www.w3.org/TR/WGSL/#overload-resolution-section.
//
// 2. Eliminate any candidate where one of its subexpressions resolves to
// an abstract type after feasible automatic conversions, but another of
// the candidates subexpressions is not a const-expression.
//
// We only have to explicitly do so for shifts as their operands may be
// of different types - for other binary ops this is achieved by finding
// the conversion consensus for both operands.
if !ctx.is_const(right) {
left = ctx.concretize(left)?;
}
}
// All other operators follow the same pattern: reconcile the
// scalar leaf types. If there's no reconciliation possible,
// leave the expressions as they are: validation will report the
// problem.
_ => {
ctx.grow_types(left)?;
if !matches!(
resolve_inner!(ctx, left),
&ir::TypeInner::Scalar(ir::Scalar::BOOL)
) {
// Pass it through as-is, will fail validation
let right = self.expression_for_abstract(right, ctx)?;
ctx.grow_types(right)?;
if let Ok(consensus_scalar) =
ctx.automatic_conversion_consensus([left, right].iter())
{
ctx.convert_to_leaf_scalar(&mut left, consensus_scalar)?;
ctx.convert_to_leaf_scalar(&mut right, consensus_scalar)?;
Ok(Typed::Plain(crate::Expression::Binary { op, left, right }))
} else {
self.logical(op, left, right, span, ctx)
}
} else {
// Load both operands.
let mut left = self.expression_for_abstract(left, ctx)?;
let mut right = self.expression_for_abstract(right, ctx)?;
// Convert `scalar op vector` to `vector op vector` by introducing
// `Splat` expressions.
ctx.binary_op_splat(op, &mut left, &mut right)?;
// Apply automatic conversions.
match op {
ir::BinaryOperator::ShiftLeft | ir::BinaryOperator::ShiftRight => {
// Shift operators require the right operand to be `u32` or
// `vecN<u32>`. We can let the validator sort out vector length
// issues, but the right operand must be, or convert to, a u32 leaf
// scalar.
right =
ctx.try_automatic_conversion_for_leaf_scalar(right, ir::Scalar::U32, span)?;
// Additionally, we must concretize the left operand if the right operand
// is not a const-expression.
// See https://www.w3.org/TR/WGSL/#overload-resolution-section.
//
// 2. Eliminate any candidate where one of its subexpressions resolves to
// an abstract type after feasible automatic conversions, but another of
// the candidates subexpressions is not a const-expression.
//
// We only have to explicitly do so for shifts as their operands may be
// of different types - for other binary ops this is achieved by finding
// the conversion consensus for both operands.
if !ctx.is_const(right) {
left = ctx.concretize(left)?;
}
}
// All other operators follow the same pattern: reconcile the
// scalar leaf types. If there's no reconciliation possible,
// leave the expressions as they are: validation will report the
// problem.
_ => {
ctx.grow_types(left)?;
ctx.grow_types(right)?;
if let Ok(consensus_scalar) =
ctx.automatic_conversion_consensus([left, right].iter())
{
ctx.convert_to_leaf_scalar(&mut left, consensus_scalar)?;
ctx.convert_to_leaf_scalar(&mut right, consensus_scalar)?;
}
}
}
}
Ok(Typed::Plain(ir::Expression::Binary { op, left, right }))
Ok(Typed::Plain(ir::Expression::Binary { op, left, right }))
}
}
/// Generate Naga IR for call expressions and statements, and type

View File

@ -40,6 +40,11 @@ fn bool_cast(x: vec3<f32>) -> vec3<f32> {
return vec3<f32>(y);
}
fn p() -> bool { return true; }
fn q() -> bool { return false; }
fn r() -> bool { return true; }
fn s() -> bool { return false; }
fn logical() {
let t = true;
let f = false;
@ -55,6 +60,7 @@ fn logical() {
let bitwise_or1 = vec3(t) | vec3(f);
let bitwise_and0 = t & f;
let bitwise_and1 = vec4(t) & vec4(f);
let short_circuit = (p() || q()) && (r() || s());
}
fn arithmetic() {

View File

@ -46,15 +46,68 @@ vec3 bool_cast(vec3 x) {
return vec3(y);
}
bool p() {
return true;
}
bool q() {
return false;
}
bool r() {
return true;
}
bool s() {
return false;
}
void logical() {
bool local = false;
bool local_1 = false;
bool local_2 = false;
bool local_3 = false;
bool local_4 = false;
bool neg0_ = !(true);
bvec2 neg1_ = not(bvec2(true));
bool or = (true || false);
bool and = (true && false);
if (!(true)) {
local = false;
} else {
local = true;
}
bool or = local;
if (true) {
local_1 = false;
} else {
local_1 = false;
}
bool and = local_1;
bool bitwise_or0_ = (true || false);
bvec3 bitwise_or1_ = bvec3(bvec3(true).x || bvec3(false).x, bvec3(true).y || bvec3(false).y, bvec3(true).z || bvec3(false).z);
bool bitwise_and0_ = (true && false);
bvec4 bitwise_and1_ = bvec4(bvec4(true).x && bvec4(false).x, bvec4(true).y && bvec4(false).y, bvec4(true).z && bvec4(false).z, bvec4(true).w && bvec4(false).w);
bool _e22 = p();
if (!(_e22)) {
bool _e26 = q();
local_2 = _e26;
} else {
local_2 = true;
}
bool _e28 = local_2;
if (_e28) {
bool _e31 = r();
if (!(_e31)) {
bool _e35 = s();
local_4 = _e35;
} else {
local_4 = true;
}
bool _e37 = local_4;
local_3 = _e37;
} else {
local_3 = false;
}
bool short_circuit = local_3;
return;
}

View File

@ -48,16 +48,74 @@ float3 bool_cast(float3 x)
return float3(y);
}
bool p()
{
return true;
}
bool q()
{
return false;
}
bool r()
{
return true;
}
bool s()
{
return false;
}
void logical()
{
bool local = (bool)0;
bool local_1 = (bool)0;
bool local_2 = (bool)0;
bool local_3 = (bool)0;
bool local_4 = (bool)0;
bool neg0_ = !(true);
bool2 neg1_ = !((true).xx);
bool or_ = (true || false);
bool and_ = (true && false);
if (!(true)) {
local = false;
} else {
local = true;
}
bool or_ = local;
if (true) {
local_1 = false;
} else {
local_1 = false;
}
bool and_ = local_1;
bool bitwise_or0_ = (true | false);
bool3 bitwise_or1_ = ((true).xxx | (false).xxx);
bool bitwise_and0_ = (true & false);
bool4 bitwise_and1_ = ((true).xxxx & (false).xxxx);
const bool _e22 = p();
if (!(_e22)) {
const bool _e26 = q();
local_2 = _e26;
} else {
local_2 = true;
}
bool _e28 = local_2;
if (_e28) {
const bool _e31 = r();
if (!(_e31)) {
const bool _e35 = s();
local_4 = _e35;
} else {
local_4 = true;
}
bool _e37 = local_4;
local_3 = _e37;
} else {
local_3 = false;
}
bool short_circuit = local_3;
return;
}

View File

@ -56,16 +56,73 @@ metal::float3 bool_cast(
return static_cast<metal::float3>(y);
}
bool p(
) {
return true;
}
bool q(
) {
return false;
}
bool r(
) {
return true;
}
bool s(
) {
return false;
}
void logical(
) {
bool local = {};
bool local_1 = {};
bool local_2 = {};
bool local_3 = {};
bool local_4 = {};
bool neg0_ = !(true);
metal::bool2 neg1_ = !(metal::bool2(true));
bool or_ = true || false;
bool and_ = true && false;
if (!(true)) {
local = false;
} else {
local = true;
}
bool or_ = local;
if (true) {
local_1 = false;
} else {
local_1 = false;
}
bool and_ = local_1;
bool bitwise_or0_ = true | false;
metal::bool3 bitwise_or1_ = metal::bool3(true) | metal::bool3(false);
bool bitwise_and0_ = true & false;
metal::bool4 bitwise_and1_ = metal::bool4(true) & metal::bool4(false);
bool _e22 = p();
if (!(_e22)) {
bool _e26 = q();
local_2 = _e26;
} else {
local_2 = true;
}
bool _e28 = local_2;
if (_e28) {
bool _e31 = r();
if (!(_e31)) {
bool _e35 = s();
local_4 = _e35;
} else {
local_4 = true;
}
bool _e37 = local_4;
local_3 = _e37;
} else {
local_3 = false;
}
bool short_circuit = local_3;
return;
}

File diff suppressed because it is too large Load Diff

View File

@ -39,15 +39,69 @@ fn bool_cast(x: vec3<f32>) -> vec3<f32> {
return vec3<f32>(y);
}
fn p() -> bool {
return true;
}
fn q() -> bool {
return false;
}
fn r() -> bool {
return true;
}
fn s() -> bool {
return false;
}
fn logical() {
var local: bool;
var local_1: bool;
var local_2: bool;
var local_3: bool;
var local_4: bool;
let neg0_ = !(true);
let neg1_ = !(vec2(true));
let or = (true || false);
let and = (true && false);
if !(true) {
local = false;
} else {
local = true;
}
let or = local;
if true {
local_1 = false;
} else {
local_1 = false;
}
let and = local_1;
let bitwise_or0_ = (true | false);
let bitwise_or1_ = (vec3(true) | vec3(false));
let bitwise_and0_ = (true & false);
let bitwise_and1_ = (vec4(true) & vec4(false));
let _e22 = p();
if !(_e22) {
let _e26 = q();
local_2 = _e26;
} else {
local_2 = true;
}
let _e28 = local_2;
if _e28 {
let _e31 = r();
if !(_e31) {
let _e35 = s();
local_4 = _e35;
} else {
local_4 = true;
}
let _e37 = local_4;
local_3 = _e37;
} else {
local_3 = false;
}
let short_circuit = local_3;
return;
}

View File

@ -107,8 +107,8 @@ fn main(
add_result_to_mask(&passed, 21u, subgroupBroadcast(subgroup_invocation_id, 1u) == 1u);
add_result_to_mask(&passed, 22u, subgroupShuffle(subgroup_invocation_id, subgroup_invocation_id) == subgroup_invocation_id);
add_result_to_mask(&passed, 23u, subgroupShuffle(subgroup_invocation_id, subgroup_size - 1u - subgroup_invocation_id) == subgroup_size - 1u - subgroup_invocation_id);
add_result_to_mask(&passed, 24u, subgroup_invocation_id == subgroup_size - 1u || subgroupShuffleDown(subgroup_invocation_id, 1u) == subgroup_invocation_id + 1u);
add_result_to_mask(&passed, 25u, subgroup_invocation_id == 0u || subgroupShuffleUp(subgroup_invocation_id, 1u) == subgroup_invocation_id - 1u);
add_result_to_mask(&passed, 24u, subgroupShuffleDown(subgroup_invocation_id, 1u) == subgroup_invocation_id + 1u || subgroup_invocation_id == subgroup_size - 1u);
add_result_to_mask(&passed, 25u, subgroupShuffleUp(subgroup_invocation_id, 1u) == subgroup_invocation_id - 1u || subgroup_invocation_id == 0u);
add_result_to_mask(&passed, 26u, subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u) == (subgroup_invocation_id ^ (subgroup_size - 1u)));
// Mac/Apple will fail this test.