[naga wgsl-in] Short-circuiting of && and || operators (#7339)

Addresses parts of #4394 and #6302
This commit is contained in:
Andy Leiserson 2025-11-19 17:06:49 -08:00 committed by GitHub
parent 1f99103be8
commit 119b4efada
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 1092 additions and 565 deletions

View File

@ -152,6 +152,10 @@ By @SupaMaggie70Incorporated in [#8206](https://github.com/gfx-rs/wgpu/pull/8206
- Validate that buffers are unmapped in `write_buffer` calls. By @ErichDonGubler in [#8454](https://github.com/gfx-rs/wgpu/pull/8454). - Validate that buffers are unmapped in `write_buffer` calls. By @ErichDonGubler in [#8454](https://github.com/gfx-rs/wgpu/pull/8454).
- Add WGSL parsing for mesh shaders. By @inner-daemons in [#8370](https://github.com/gfx-rs/wgpu/pull/8370). - Add WGSL parsing for mesh shaders. By @inner-daemons in [#8370](https://github.com/gfx-rs/wgpu/pull/8370).
#### naga
- The `||` and `&&` operators now "short circuit", i.e., do not evaluate the RHS if the result can be determined from just the LHS. By @andyleiserson in [#7339](https://github.com/gfx-rs/wgpu/pull/7339).
#### DX12 #### DX12
- Align copies b/w textures and buffers via a single intermediate buffer per copy when `D3D12_FEATURE_DATA_D3D12_OPTIONS13.UnrestrictedBufferTextureCopyPitchSupported` is `false`. By @ErichDonGubler in [#7721](https://github.com/gfx-rs/wgpu/pull/7721). - Align copies b/w textures and buffers via a single intermediate buffer per copy when `D3D12_FEATURE_DATA_D3D12_OPTIONS13.UnrestrictedBufferTextureCopyPitchSupported` is `false`. By @ErichDonGubler in [#7721](https://github.com/gfx-rs/wgpu/pull/7721).

View File

@ -426,6 +426,13 @@ impl TypeContext for ExpressionContext<'_, '_, '_> {
} }
impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
/// Report whether this context lowers runtime expressions.
///
/// Returns `true` only when `expr_type` is `ExpressionContextType::Runtime`;
/// constant and override contexts both yield `false`.
const fn is_runtime(&self) -> bool {
    // Keep the match exhaustive so that a new context kind forces a decision here.
    match self.expr_type {
        ExpressionContextType::Constant(_) | ExpressionContextType::Override => false,
        ExpressionContextType::Runtime(_) => true,
    }
}
#[allow(dead_code)] #[allow(dead_code)]
fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> { fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> {
ExpressionContext { ExpressionContext {
@ -588,6 +595,16 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
} }
} }
/// Look up the expression that `handle` refers to.
///
/// Contexts that carry a local (function) context read from that function's
/// expression arena; all other contexts read from the module's global
/// expression arena.
fn get(&self, handle: Handle<crate::Expression>) -> &crate::Expression {
    let arena = match self.expr_type {
        // A local context is available: the handle indexes the function's arena.
        ExpressionContextType::Runtime(ref local)
        | ExpressionContextType::Constant(Some(ref local)) => &local.function.expressions,
        // No local context: the handle indexes the module-level arena.
        ExpressionContextType::Constant(None) | ExpressionContextType::Override => {
            &self.module.global_expressions
        }
    };
    &arena[handle]
}
fn local( fn local(
&mut self, &mut self,
local: &Handle<ast::Local>, local: &Handle<ast::Local>,
@ -614,6 +631,52 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
} }
} }
fn with_nested_runtime_expression_ctx<'a, F, T>(
&mut self,
span: Span,
f: F,
) -> Result<'source, (T, crate::Block)>
where
for<'t> F: FnOnce(&mut ExpressionContext<'source, 't, 't>) -> Result<'source, T>,
{
let mut block = crate::Block::new();
let rctx = match self.expr_type {
ExpressionContextType::Runtime(ref mut rctx) => Ok(rctx),
ExpressionContextType::Constant(_) | ExpressionContextType::Override => {
Err(Error::UnexpectedOperationInConstContext(span))
}
}?;
rctx.block
.extend(rctx.emitter.finish(&rctx.function.expressions));
rctx.emitter.start(&rctx.function.expressions);
let nested_rctx = LocalExpressionContext {
local_table: rctx.local_table,
function: rctx.function,
block: &mut block,
emitter: rctx.emitter,
typifier: rctx.typifier,
local_expression_kind_tracker: rctx.local_expression_kind_tracker,
};
let mut nested_ctx = ExpressionContext {
expr_type: ExpressionContextType::Runtime(nested_rctx),
ast_expressions: self.ast_expressions,
types: self.types,
globals: self.globals,
module: self.module,
const_typifier: self.const_typifier,
layouter: self.layouter,
global_expression_kind_tracker: self.global_expression_kind_tracker,
};
let ret = f(&mut nested_ctx)?;
block.extend(rctx.emitter.finish(&rctx.function.expressions));
rctx.emitter.start(&rctx.function.expressions);
Ok((ret, block))
}
fn gather_component( fn gather_component(
&mut self, &mut self,
expr: Handle<ir::Expression>, expr: Handle<ir::Expression>,
@ -2375,6 +2438,130 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
expr.try_map(|handle| ctx.append_expression(handle, span)) expr.try_map(|handle| ctx.append_expression(handle, span))
} }
/// Generate IR for the short-circuiting operators `&&` and `||`.
///
/// `binary` has already lowered the LHS expression and resolved its type.
///
/// * `op` must be `LogicalAnd` or `LogicalOr` (checked with `debug_assert!`).
/// * `left` is the already-lowered LHS expression.
/// * `right` is the RHS as an AST expression; it is lowered here, and only
///   where it may need to be evaluated.
/// * `span` is attached to every expression, statement and local variable
///   generated here, and to any error returned.
///
/// In a runtime context the result is a `Typed::Reference` to a fresh unnamed
/// `bool` local variable that is assigned in both arms of an `If` statement.
/// In a constant or override context the result is a `Typed::Plain` expression:
/// a copy of the LHS literal when the LHS alone decides the result, otherwise a
/// `Binary` expression over both operands.
///
/// # Errors
///
/// Returns `Error::NotBool` in a constant or override context when the lowered
/// LHS is not a `bool` literal, and propagates any error from lowering the RHS
/// or from appending expressions.
fn logical(
&mut self,
op: crate::BinaryOperator,
left: Handle<crate::Expression>,
right: Handle<ast::Expression<'source>>,
span: Span,
ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<'source, Typed<crate::Expression>> {
debug_assert!(
op == crate::BinaryOperator::LogicalAnd || op == crate::BinaryOperator::LogicalOr
);
if ctx.is_runtime() {
// To simulate short-circuiting behavior, we want to generate IR
// like the following for `&&`. For `||`, the condition is `!_lhs`
// and the else value is `true`.
//
// var _e0: bool;
// if _lhs {
// _e0 = _rhs;
// } else {
// _e0 = false;
// }
let (condition, else_val) = if op == crate::BinaryOperator::LogicalAnd {
// `a && b`: evaluate `b` only when `a` is true; otherwise the result is `false`.
let condition = left;
let else_val = ctx.append_expression(
crate::Expression::Literal(crate::Literal::Bool(false)),
span,
)?;
(condition, else_val)
} else {
// `a || b`: evaluate `b` only when `!a` is true; otherwise the result is `true`.
let condition = ctx.append_expression(
crate::Expression::Unary {
op: crate::UnaryOperator::LogicalNot,
expr: left,
},
span,
)?;
let else_val = ctx.append_expression(
crate::Expression::Literal(crate::Literal::Bool(true)),
span,
)?;
(condition, else_val)
};
let bool_ty = ctx.ensure_type_exists(crate::TypeInner::Scalar(crate::Scalar::BOOL));
// Unnamed, uninitialized local that receives the operator's result.
let rctx = ctx.runtime_expression_ctx(span)?;
let result_var = rctx.function.local_variables.append(
crate::LocalVariable {
name: None,
ty: bool_ty,
init: None,
},
span,
);
let pointer =
ctx.append_expression(crate::Expression::LocalVariable(result_var), span)?;
// Lower the RHS inside a nested block so that its statements end up in
// the `accept` arm only.
let (right, mut accept) = ctx.with_nested_runtime_expression_ctx(span, |ctx| {
let right = self.expression_for_abstract(right, ctx)?;
ctx.grow_types(right)?;
Ok(right)
})?;
accept.push(
crate::Statement::Store {
pointer,
value: right,
},
span,
);
// The `reject` arm stores the value fixed by the LHS alone.
let mut reject = crate::Block::with_capacity(1);
reject.push(
crate::Statement::Store {
pointer,
value: else_val,
},
span,
);
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block.push(
crate::Statement::If {
condition,
accept,
reject,
},
span,
);
// Returned as a reference to the local variable rather than a loaded value.
Ok(Typed::Reference(crate::Expression::LocalVariable(
result_var,
)))
} else {
let left_expr = ctx.get(left);
// Constant or override context in either function or module scope
// NOTE(review): only a `bool` literal LHS is accepted here; any other
// lowered LHS expression, even one typed `bool`, is reported as `NotBool`.
// Confirm that this is intended for override expressions.
let &crate::Expression::Literal(crate::Literal::Bool(left_val)) = left_expr else {
return Err(Box::new(Error::NotBool(span)));
};
if op == crate::BinaryOperator::LogicalAnd && !left_val
|| op == crate::BinaryOperator::LogicalOr && left_val
{
// Short-circuit behavior: don't evaluate the RHS. Ideally we
// would do _some_ validity checks of the RHS here, but that's
// tricky, because the RHS is allowed to have things that aren't
// legal in const contexts.
Ok(Typed::Plain(left_expr.clone()))
} else {
// The LHS does not decide the result, so lower the RHS and emit an
// ordinary binary expression.
let right = self.expression_for_abstract(right, ctx)?;
ctx.grow_types(right)?;
Ok(Typed::Plain(crate::Expression::Binary { op, left, right }))
}
}
}
fn binary( fn binary(
&mut self, &mut self,
op: ir::BinaryOperator, op: ir::BinaryOperator,
@ -2383,6 +2570,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
span: Span, span: Span,
ctx: &mut ExpressionContext<'source, '_, '_>, ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<'source, Typed<ir::Expression>> { ) -> Result<'source, Typed<ir::Expression>> {
if op == ir::BinaryOperator::LogicalAnd || op == ir::BinaryOperator::LogicalOr {
let left = self.expression_for_abstract(left, ctx)?;
ctx.grow_types(left)?;
if !matches!(
resolve_inner!(ctx, left),
&ir::TypeInner::Scalar(ir::Scalar::BOOL)
) {
// Pass it through as-is, will fail validation
let right = self.expression_for_abstract(right, ctx)?;
ctx.grow_types(right)?;
Ok(Typed::Plain(crate::Expression::Binary { op, left, right }))
} else {
self.logical(op, left, right, span, ctx)
}
} else {
// Load both operands. // Load both operands.
let mut left = self.expression_for_abstract(left, ctx)?; let mut left = self.expression_for_abstract(left, ctx)?;
let mut right = self.expression_for_abstract(right, ctx)?; let mut right = self.expression_for_abstract(right, ctx)?;
@ -2435,6 +2638,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
Ok(Typed::Plain(ir::Expression::Binary { op, left, right })) Ok(Typed::Plain(ir::Expression::Binary { op, left, right }))
} }
}
/// Generate Naga IR for call expressions and statements, and type /// Generate Naga IR for call expressions and statements, and type
/// constructor expressions. /// constructor expressions.

View File

@ -40,6 +40,11 @@ fn bool_cast(x: vec3<f32>) -> vec3<f32> {
return vec3<f32>(y); return vec3<f32>(y);
} }
// Helper functions returning fixed `bool` values. They are the operands of the
// `short_circuit` expression in `logical`, `(p() || q()) && (r() || s())`.
// NOTE(review): presumably calls are used so the operands are real runtime
// expressions with visible evaluation points — confirm.
fn p() -> bool { return true; }
fn q() -> bool { return false; }
fn r() -> bool { return true; }
fn s() -> bool { return false; }
fn logical() { fn logical() {
let t = true; let t = true;
let f = false; let f = false;
@ -55,6 +60,7 @@ fn logical() {
let bitwise_or1 = vec3(t) | vec3(f); let bitwise_or1 = vec3(t) | vec3(f);
let bitwise_and0 = t & f; let bitwise_and0 = t & f;
let bitwise_and1 = vec4(t) & vec4(f); let bitwise_and1 = vec4(t) & vec4(f);
let short_circuit = (p() || q()) && (r() || s());
} }
fn arithmetic() { fn arithmetic() {

View File

@ -46,15 +46,68 @@ vec3 bool_cast(vec3 x) {
return vec3(y); return vec3(y);
} }
bool p() {
return true;
}
bool q() {
return false;
}
bool r() {
return true;
}
bool s() {
return false;
}
void logical() { void logical() {
bool local = false;
bool local_1 = false;
bool local_2 = false;
bool local_3 = false;
bool local_4 = false;
bool neg0_ = !(true); bool neg0_ = !(true);
bvec2 neg1_ = not(bvec2(true)); bvec2 neg1_ = not(bvec2(true));
bool or = (true || false); if (!(true)) {
bool and = (true && false); local = false;
} else {
local = true;
}
bool or = local;
if (true) {
local_1 = false;
} else {
local_1 = false;
}
bool and = local_1;
bool bitwise_or0_ = (true || false); bool bitwise_or0_ = (true || false);
bvec3 bitwise_or1_ = bvec3(bvec3(true).x || bvec3(false).x, bvec3(true).y || bvec3(false).y, bvec3(true).z || bvec3(false).z); bvec3 bitwise_or1_ = bvec3(bvec3(true).x || bvec3(false).x, bvec3(true).y || bvec3(false).y, bvec3(true).z || bvec3(false).z);
bool bitwise_and0_ = (true && false); bool bitwise_and0_ = (true && false);
bvec4 bitwise_and1_ = bvec4(bvec4(true).x && bvec4(false).x, bvec4(true).y && bvec4(false).y, bvec4(true).z && bvec4(false).z, bvec4(true).w && bvec4(false).w); bvec4 bitwise_and1_ = bvec4(bvec4(true).x && bvec4(false).x, bvec4(true).y && bvec4(false).y, bvec4(true).z && bvec4(false).z, bvec4(true).w && bvec4(false).w);
bool _e22 = p();
if (!(_e22)) {
bool _e26 = q();
local_2 = _e26;
} else {
local_2 = true;
}
bool _e28 = local_2;
if (_e28) {
bool _e31 = r();
if (!(_e31)) {
bool _e35 = s();
local_4 = _e35;
} else {
local_4 = true;
}
bool _e37 = local_4;
local_3 = _e37;
} else {
local_3 = false;
}
bool short_circuit = local_3;
return; return;
} }

View File

@ -48,16 +48,74 @@ float3 bool_cast(float3 x)
return float3(y); return float3(y);
} }
bool p()
{
return true;
}
bool q()
{
return false;
}
bool r()
{
return true;
}
bool s()
{
return false;
}
void logical() void logical()
{ {
bool local = (bool)0;
bool local_1 = (bool)0;
bool local_2 = (bool)0;
bool local_3 = (bool)0;
bool local_4 = (bool)0;
bool neg0_ = !(true); bool neg0_ = !(true);
bool2 neg1_ = !((true).xx); bool2 neg1_ = !((true).xx);
bool or_ = (true || false); if (!(true)) {
bool and_ = (true && false); local = false;
} else {
local = true;
}
bool or_ = local;
if (true) {
local_1 = false;
} else {
local_1 = false;
}
bool and_ = local_1;
bool bitwise_or0_ = (true | false); bool bitwise_or0_ = (true | false);
bool3 bitwise_or1_ = ((true).xxx | (false).xxx); bool3 bitwise_or1_ = ((true).xxx | (false).xxx);
bool bitwise_and0_ = (true & false); bool bitwise_and0_ = (true & false);
bool4 bitwise_and1_ = ((true).xxxx & (false).xxxx); bool4 bitwise_and1_ = ((true).xxxx & (false).xxxx);
const bool _e22 = p();
if (!(_e22)) {
const bool _e26 = q();
local_2 = _e26;
} else {
local_2 = true;
}
bool _e28 = local_2;
if (_e28) {
const bool _e31 = r();
if (!(_e31)) {
const bool _e35 = s();
local_4 = _e35;
} else {
local_4 = true;
}
bool _e37 = local_4;
local_3 = _e37;
} else {
local_3 = false;
}
bool short_circuit = local_3;
return; return;
} }

View File

@ -56,16 +56,73 @@ metal::float3 bool_cast(
return static_cast<metal::float3>(y); return static_cast<metal::float3>(y);
} }
bool p(
) {
return true;
}
bool q(
) {
return false;
}
bool r(
) {
return true;
}
bool s(
) {
return false;
}
void logical( void logical(
) { ) {
bool local = {};
bool local_1 = {};
bool local_2 = {};
bool local_3 = {};
bool local_4 = {};
bool neg0_ = !(true); bool neg0_ = !(true);
metal::bool2 neg1_ = !(metal::bool2(true)); metal::bool2 neg1_ = !(metal::bool2(true));
bool or_ = true || false; if (!(true)) {
bool and_ = true && false; local = false;
} else {
local = true;
}
bool or_ = local;
if (true) {
local_1 = false;
} else {
local_1 = false;
}
bool and_ = local_1;
bool bitwise_or0_ = true | false; bool bitwise_or0_ = true | false;
metal::bool3 bitwise_or1_ = metal::bool3(true) | metal::bool3(false); metal::bool3 bitwise_or1_ = metal::bool3(true) | metal::bool3(false);
bool bitwise_and0_ = true & false; bool bitwise_and0_ = true & false;
metal::bool4 bitwise_and1_ = metal::bool4(true) & metal::bool4(false); metal::bool4 bitwise_and1_ = metal::bool4(true) & metal::bool4(false);
bool _e22 = p();
if (!(_e22)) {
bool _e26 = q();
local_2 = _e26;
} else {
local_2 = true;
}
bool _e28 = local_2;
if (_e28) {
bool _e31 = r();
if (!(_e31)) {
bool _e35 = s();
local_4 = _e35;
} else {
local_4 = true;
}
bool _e37 = local_4;
local_3 = _e37;
} else {
local_3 = false;
}
bool short_circuit = local_3;
return; return;
} }

File diff suppressed because it is too large Load Diff

View File

@ -39,15 +39,69 @@ fn bool_cast(x: vec3<f32>) -> vec3<f32> {
return vec3<f32>(y); return vec3<f32>(y);
} }
fn p() -> bool {
return true;
}
fn q() -> bool {
return false;
}
fn r() -> bool {
return true;
}
fn s() -> bool {
return false;
}
fn logical() { fn logical() {
var local: bool;
var local_1: bool;
var local_2: bool;
var local_3: bool;
var local_4: bool;
let neg0_ = !(true); let neg0_ = !(true);
let neg1_ = !(vec2(true)); let neg1_ = !(vec2(true));
let or = (true || false); if !(true) {
let and = (true && false); local = false;
} else {
local = true;
}
let or = local;
if true {
local_1 = false;
} else {
local_1 = false;
}
let and = local_1;
let bitwise_or0_ = (true | false); let bitwise_or0_ = (true | false);
let bitwise_or1_ = (vec3(true) | vec3(false)); let bitwise_or1_ = (vec3(true) | vec3(false));
let bitwise_and0_ = (true & false); let bitwise_and0_ = (true & false);
let bitwise_and1_ = (vec4(true) & vec4(false)); let bitwise_and1_ = (vec4(true) & vec4(false));
let _e22 = p();
if !(_e22) {
let _e26 = q();
local_2 = _e26;
} else {
local_2 = true;
}
let _e28 = local_2;
if _e28 {
let _e31 = r();
if !(_e31) {
let _e35 = s();
local_4 = _e35;
} else {
local_4 = true;
}
let _e37 = local_4;
local_3 = _e37;
} else {
local_3 = false;
}
let short_circuit = local_3;
return; return;
} }

View File

@ -107,8 +107,8 @@ fn main(
add_result_to_mask(&passed, 21u, subgroupBroadcast(subgroup_invocation_id, 1u) == 1u); add_result_to_mask(&passed, 21u, subgroupBroadcast(subgroup_invocation_id, 1u) == 1u);
add_result_to_mask(&passed, 22u, subgroupShuffle(subgroup_invocation_id, subgroup_invocation_id) == subgroup_invocation_id); add_result_to_mask(&passed, 22u, subgroupShuffle(subgroup_invocation_id, subgroup_invocation_id) == subgroup_invocation_id);
add_result_to_mask(&passed, 23u, subgroupShuffle(subgroup_invocation_id, subgroup_size - 1u - subgroup_invocation_id) == subgroup_size - 1u - subgroup_invocation_id); add_result_to_mask(&passed, 23u, subgroupShuffle(subgroup_invocation_id, subgroup_size - 1u - subgroup_invocation_id) == subgroup_size - 1u - subgroup_invocation_id);
add_result_to_mask(&passed, 24u, subgroup_invocation_id == subgroup_size - 1u || subgroupShuffleDown(subgroup_invocation_id, 1u) == subgroup_invocation_id + 1u); add_result_to_mask(&passed, 24u, subgroupShuffleDown(subgroup_invocation_id, 1u) == subgroup_invocation_id + 1u || subgroup_invocation_id == subgroup_size - 1u);
add_result_to_mask(&passed, 25u, subgroup_invocation_id == 0u || subgroupShuffleUp(subgroup_invocation_id, 1u) == subgroup_invocation_id - 1u); add_result_to_mask(&passed, 25u, subgroupShuffleUp(subgroup_invocation_id, 1u) == subgroup_invocation_id - 1u || subgroup_invocation_id == 0u);
add_result_to_mask(&passed, 26u, subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u) == (subgroup_invocation_id ^ (subgroup_size - 1u))); add_result_to_mask(&passed, 26u, subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u) == (subgroup_invocation_id ^ (subgroup_size - 1u)));
// Mac/Apple will fail this test. // Mac/Apple will fail this test.