mirror of
https://github.com/gfx-rs/wgpu.git
synced 2025-12-08 21:26:17 +00:00
[naga spv-out] Factor out wrapped divide/module generation.
Move the code to generate the definition of an overflow-safe divide/modulo SPIR-V function into its own Rust function, to reduce indentation and clarify influences. This commit isn't intended to cause any change in behavior.
This commit is contained in:
parent
cb9666c6a7
commit
f90f19c7e8
@ -303,6 +303,14 @@ impl NumericType {
|
||||
}
|
||||
}
|
||||
|
||||
const fn scalar(self) -> crate::Scalar {
|
||||
match self {
|
||||
NumericType::Scalar(scalar)
|
||||
| NumericType::Vector { scalar, .. }
|
||||
| NumericType::Matrix { scalar, .. } => scalar,
|
||||
}
|
||||
}
|
||||
|
||||
const fn with_scalar(self, scalar: crate::Scalar) -> Self {
|
||||
match self {
|
||||
NumericType::Scalar(_) => NumericType::Scalar(scalar),
|
||||
|
||||
@ -223,6 +223,10 @@ impl Writer {
|
||||
self.get_type_id(lookup_ty)
|
||||
}
|
||||
|
||||
pub(super) fn get_localtype_id(&mut self, local: LocalType) -> Word {
|
||||
self.get_type_id(LookupType::Local(local))
|
||||
}
|
||||
|
||||
pub(super) fn get_pointer_id(
|
||||
&mut self,
|
||||
handle: Handle<crate::Type>,
|
||||
@ -320,199 +324,27 @@ impl Writer {
|
||||
for (expr_handle, expr) in ir_function.expressions.iter() {
|
||||
match *expr {
|
||||
crate::Expression::Binary { op, left, right } => {
|
||||
let expr_ty = info[expr_handle].ty.inner_with(&ir_module.types);
|
||||
let Some(numeric_type) = NumericType::from_inner(expr_ty) else {
|
||||
continue;
|
||||
};
|
||||
match (op, expr_ty.scalar()) {
|
||||
// Division and modulo are undefined behaviour when the dividend is the
|
||||
// minimum representable value and the divisor is negative one, or when
|
||||
// the divisor is zero. These wrapped functions override the divisor to
|
||||
// one in these cases, matching the WGSL spec.
|
||||
(
|
||||
crate::BinaryOperator::Divide | crate::BinaryOperator::Modulo,
|
||||
Some(
|
||||
scalar @ crate::Scalar {
|
||||
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
|
||||
..
|
||||
},
|
||||
),
|
||||
) => {
|
||||
let return_type_id = self.get_expression_type_id(&info[expr_handle].ty);
|
||||
let left_type_id = self.get_expression_type_id(&info[left].ty);
|
||||
let right_type_id = self.get_expression_type_id(&info[right].ty);
|
||||
let wrapped = WrappedFunction::BinaryOp {
|
||||
op,
|
||||
left_type_id,
|
||||
right_type_id,
|
||||
};
|
||||
let function_id = *match self.wrapped_functions.entry(wrapped) {
|
||||
Entry::Occupied(_) => continue,
|
||||
Entry::Vacant(e) => e.insert(self.id_gen.next()),
|
||||
};
|
||||
if self.flags.contains(WriterFlags::DEBUG) {
|
||||
let function_name = match op {
|
||||
crate::BinaryOperator::Divide => "naga_div",
|
||||
crate::BinaryOperator::Modulo => "naga_mod",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.debugs
|
||||
.push(Instruction::name(function_id, function_name));
|
||||
let expr_ty_inner = info[expr_handle].ty.inner_with(&ir_module.types);
|
||||
if let Some(expr_ty) = NumericType::from_inner(expr_ty_inner) {
|
||||
match (op, expr_ty.scalar().kind) {
|
||||
// Division and modulo are undefined behaviour when the
|
||||
// dividend is the minimum representable value and the divisor
|
||||
// is negative one, or when the divisor is zero. These wrapped
|
||||
// functions override the divisor to one in these cases,
|
||||
// matching the WGSL spec.
|
||||
(
|
||||
crate::BinaryOperator::Divide | crate::BinaryOperator::Modulo,
|
||||
crate::ScalarKind::Sint | crate::ScalarKind::Uint,
|
||||
) => {
|
||||
self.write_wrapped_binary_op(
|
||||
op,
|
||||
expr_ty,
|
||||
&info[left].ty,
|
||||
&info[right].ty,
|
||||
)?;
|
||||
}
|
||||
let mut function = Function::default();
|
||||
|
||||
let function_type_id = self.get_function_type(LookupFunctionType {
|
||||
parameter_type_ids: vec![left_type_id, right_type_id],
|
||||
return_type_id,
|
||||
});
|
||||
function.signature = Some(Instruction::function(
|
||||
return_type_id,
|
||||
function_id,
|
||||
spirv::FunctionControl::empty(),
|
||||
function_type_id,
|
||||
));
|
||||
|
||||
let lhs_id = self.id_gen.next();
|
||||
let rhs_id = self.id_gen.next();
|
||||
if self.flags.contains(WriterFlags::DEBUG) {
|
||||
self.debugs.push(Instruction::name(lhs_id, "lhs"));
|
||||
self.debugs.push(Instruction::name(rhs_id, "rhs"));
|
||||
}
|
||||
let left_par = Instruction::function_parameter(left_type_id, lhs_id);
|
||||
let right_par = Instruction::function_parameter(right_type_id, rhs_id);
|
||||
for instruction in [left_par, right_par] {
|
||||
function.parameters.push(FunctionArgument {
|
||||
instruction,
|
||||
handle_id: 0,
|
||||
});
|
||||
}
|
||||
|
||||
let label_id = self.id_gen.next();
|
||||
let mut block = Block::new(label_id);
|
||||
|
||||
let bool_type = numeric_type.with_scalar(crate::Scalar::BOOL);
|
||||
let bool_type_id =
|
||||
self.get_type_id(LookupType::Local(LocalType::Numeric(bool_type)));
|
||||
|
||||
let maybe_splat_const = |writer: &mut Self, const_id| match numeric_type
|
||||
{
|
||||
NumericType::Scalar(_) => const_id,
|
||||
NumericType::Vector { size, .. } => {
|
||||
let constituent_ids = [const_id; crate::VectorSize::MAX];
|
||||
writer.get_constant_composite(
|
||||
LookupType::Local(LocalType::Numeric(numeric_type)),
|
||||
&constituent_ids[..size as usize],
|
||||
)
|
||||
}
|
||||
NumericType::Matrix { .. } => unreachable!(),
|
||||
};
|
||||
|
||||
let const_zero_id = self.get_constant_scalar_with(0, scalar)?;
|
||||
let composite_zero_id = maybe_splat_const(self, const_zero_id);
|
||||
let rhs_eq_zero_id = self.id_gen.next();
|
||||
block.body.push(Instruction::binary(
|
||||
spirv::Op::IEqual,
|
||||
bool_type_id,
|
||||
rhs_eq_zero_id,
|
||||
rhs_id,
|
||||
composite_zero_id,
|
||||
));
|
||||
let divisor_selector_id = match scalar.kind {
|
||||
crate::ScalarKind::Sint => {
|
||||
let (const_min_id, const_neg_one_id) = match scalar.width {
|
||||
4 => Ok((
|
||||
self.get_constant_scalar(crate::Literal::I32(i32::MIN)),
|
||||
self.get_constant_scalar(crate::Literal::I32(-1i32)),
|
||||
)),
|
||||
8 => Ok((
|
||||
self.get_constant_scalar(crate::Literal::I64(i64::MIN)),
|
||||
self.get_constant_scalar(crate::Literal::I64(-1i64)),
|
||||
)),
|
||||
_ => Err(Error::Validation("Unexpected scalar width")),
|
||||
}?;
|
||||
let composite_min_id = maybe_splat_const(self, const_min_id);
|
||||
let composite_neg_one_id =
|
||||
maybe_splat_const(self, const_neg_one_id);
|
||||
|
||||
let lhs_eq_int_min_id = self.id_gen.next();
|
||||
block.body.push(Instruction::binary(
|
||||
spirv::Op::IEqual,
|
||||
bool_type_id,
|
||||
lhs_eq_int_min_id,
|
||||
lhs_id,
|
||||
composite_min_id,
|
||||
));
|
||||
let rhs_eq_neg_one_id = self.id_gen.next();
|
||||
block.body.push(Instruction::binary(
|
||||
spirv::Op::IEqual,
|
||||
bool_type_id,
|
||||
rhs_eq_neg_one_id,
|
||||
rhs_id,
|
||||
composite_neg_one_id,
|
||||
));
|
||||
let lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
|
||||
block.body.push(Instruction::binary(
|
||||
spirv::Op::LogicalAnd,
|
||||
bool_type_id,
|
||||
lhs_eq_int_min_and_rhs_eq_neg_one_id,
|
||||
lhs_eq_int_min_id,
|
||||
rhs_eq_neg_one_id,
|
||||
));
|
||||
let rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id =
|
||||
self.id_gen.next();
|
||||
block.body.push(Instruction::binary(
|
||||
spirv::Op::LogicalOr,
|
||||
bool_type_id,
|
||||
rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id,
|
||||
rhs_eq_zero_id,
|
||||
lhs_eq_int_min_and_rhs_eq_neg_one_id,
|
||||
));
|
||||
rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id
|
||||
}
|
||||
crate::ScalarKind::Uint => rhs_eq_zero_id,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let const_one_id = self.get_constant_scalar_with(1, scalar)?;
|
||||
let composite_one_id = maybe_splat_const(self, const_one_id);
|
||||
let divisor_id = self.id_gen.next();
|
||||
block.body.push(Instruction::select(
|
||||
right_type_id,
|
||||
divisor_id,
|
||||
divisor_selector_id,
|
||||
composite_one_id,
|
||||
rhs_id,
|
||||
));
|
||||
let op = match (op, scalar.kind) {
|
||||
(crate::BinaryOperator::Divide, crate::ScalarKind::Sint) => {
|
||||
spirv::Op::SDiv
|
||||
}
|
||||
(crate::BinaryOperator::Divide, crate::ScalarKind::Uint) => {
|
||||
spirv::Op::UDiv
|
||||
}
|
||||
(crate::BinaryOperator::Modulo, crate::ScalarKind::Sint) => {
|
||||
spirv::Op::SRem
|
||||
}
|
||||
(crate::BinaryOperator::Modulo, crate::ScalarKind::Uint) => {
|
||||
spirv::Op::UMod
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let return_id = self.id_gen.next();
|
||||
block.body.push(Instruction::binary(
|
||||
op,
|
||||
return_type_id,
|
||||
return_id,
|
||||
lhs_id,
|
||||
divisor_id,
|
||||
));
|
||||
|
||||
function.consume(block, Instruction::return_value(return_id));
|
||||
function.to_words(&mut self.logical_layout.function_definitions);
|
||||
Instruction::function_end()
|
||||
.to_words(&mut self.logical_layout.function_definitions);
|
||||
_ => {}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
@ -522,6 +354,200 @@ impl Writer {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Write a SPIR-V function that performs the operator `op` with Naga IR semantics.
|
||||
///
|
||||
/// Define a function that performs an integer division or modulo operation,
|
||||
/// except that using a divisor of zero or causing signed overflow with a
|
||||
/// divisor of -1 returns the numerator unchanged, rather than exhibiting
|
||||
/// undefined behavior.
|
||||
///
|
||||
/// Store the generated function's id in the [`wrapped_functions`] table.
|
||||
///
|
||||
/// The operator `op` must be either [`Divide`] or [`Modulo`].
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// The `return_type`, `left_type` or `right_type` arguments must all be
|
||||
/// integer scalars or vectors. If not, this function panics.
|
||||
///
|
||||
/// [`wrapped_functions`]: Writer::wrapped_functions
|
||||
/// [`Divide`]: crate::BinaryOperator::Divide
|
||||
/// [`Modulo`]: crate::BinaryOperator::Modulo
|
||||
fn write_wrapped_binary_op(
|
||||
&mut self,
|
||||
op: crate::BinaryOperator,
|
||||
return_type: NumericType,
|
||||
left_type: &TypeResolution,
|
||||
right_type: &TypeResolution,
|
||||
) -> Result<(), Error> {
|
||||
let return_type_id = self.get_localtype_id(LocalType::Numeric(return_type));
|
||||
let left_type_id = self.get_expression_type_id(left_type);
|
||||
let right_type_id = self.get_expression_type_id(right_type);
|
||||
|
||||
// Check if we've already emitted this function.
|
||||
let wrapped = WrappedFunction::BinaryOp {
|
||||
op,
|
||||
left_type_id,
|
||||
right_type_id,
|
||||
};
|
||||
let function_id = match self.wrapped_functions.entry(wrapped) {
|
||||
Entry::Occupied(_) => return Ok(()),
|
||||
Entry::Vacant(e) => *e.insert(self.id_gen.next()),
|
||||
};
|
||||
|
||||
let scalar = return_type.scalar();
|
||||
|
||||
if self.flags.contains(WriterFlags::DEBUG) {
|
||||
let function_name = match op {
|
||||
crate::BinaryOperator::Divide => "naga_div",
|
||||
crate::BinaryOperator::Modulo => "naga_mod",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.debugs
|
||||
.push(Instruction::name(function_id, function_name));
|
||||
}
|
||||
let mut function = Function::default();
|
||||
|
||||
let function_type_id = self.get_function_type(LookupFunctionType {
|
||||
parameter_type_ids: vec![left_type_id, right_type_id],
|
||||
return_type_id,
|
||||
});
|
||||
function.signature = Some(Instruction::function(
|
||||
return_type_id,
|
||||
function_id,
|
||||
spirv::FunctionControl::empty(),
|
||||
function_type_id,
|
||||
));
|
||||
|
||||
let lhs_id = self.id_gen.next();
|
||||
let rhs_id = self.id_gen.next();
|
||||
if self.flags.contains(WriterFlags::DEBUG) {
|
||||
self.debugs.push(Instruction::name(lhs_id, "lhs"));
|
||||
self.debugs.push(Instruction::name(rhs_id, "rhs"));
|
||||
}
|
||||
let left_par = Instruction::function_parameter(left_type_id, lhs_id);
|
||||
let right_par = Instruction::function_parameter(right_type_id, rhs_id);
|
||||
for instruction in [left_par, right_par] {
|
||||
function.parameters.push(FunctionArgument {
|
||||
instruction,
|
||||
handle_id: 0,
|
||||
});
|
||||
}
|
||||
|
||||
let label_id = self.id_gen.next();
|
||||
let mut block = Block::new(label_id);
|
||||
|
||||
let bool_type = return_type.with_scalar(crate::Scalar::BOOL);
|
||||
let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(bool_type)));
|
||||
|
||||
let maybe_splat_const = |writer: &mut Self, const_id| match return_type {
|
||||
NumericType::Scalar(_) => const_id,
|
||||
NumericType::Vector { size, .. } => {
|
||||
let constituent_ids = [const_id; crate::VectorSize::MAX];
|
||||
writer.get_constant_composite(
|
||||
LookupType::Local(LocalType::Numeric(return_type)),
|
||||
&constituent_ids[..size as usize],
|
||||
)
|
||||
}
|
||||
NumericType::Matrix { .. } => unreachable!(),
|
||||
};
|
||||
|
||||
let const_zero_id = self.get_constant_scalar_with(0, scalar)?;
|
||||
let composite_zero_id = maybe_splat_const(self, const_zero_id);
|
||||
let rhs_eq_zero_id = self.id_gen.next();
|
||||
block.body.push(Instruction::binary(
|
||||
spirv::Op::IEqual,
|
||||
bool_type_id,
|
||||
rhs_eq_zero_id,
|
||||
rhs_id,
|
||||
composite_zero_id,
|
||||
));
|
||||
let divisor_selector_id = match scalar.kind {
|
||||
crate::ScalarKind::Sint => {
|
||||
let (const_min_id, const_neg_one_id) = match scalar.width {
|
||||
4 => Ok((
|
||||
self.get_constant_scalar(crate::Literal::I32(i32::MIN)),
|
||||
self.get_constant_scalar(crate::Literal::I32(-1i32)),
|
||||
)),
|
||||
8 => Ok((
|
||||
self.get_constant_scalar(crate::Literal::I64(i64::MIN)),
|
||||
self.get_constant_scalar(crate::Literal::I64(-1i64)),
|
||||
)),
|
||||
_ => Err(Error::Validation("Unexpected scalar width")),
|
||||
}?;
|
||||
let composite_min_id = maybe_splat_const(self, const_min_id);
|
||||
let composite_neg_one_id = maybe_splat_const(self, const_neg_one_id);
|
||||
|
||||
let lhs_eq_int_min_id = self.id_gen.next();
|
||||
block.body.push(Instruction::binary(
|
||||
spirv::Op::IEqual,
|
||||
bool_type_id,
|
||||
lhs_eq_int_min_id,
|
||||
lhs_id,
|
||||
composite_min_id,
|
||||
));
|
||||
let rhs_eq_neg_one_id = self.id_gen.next();
|
||||
block.body.push(Instruction::binary(
|
||||
spirv::Op::IEqual,
|
||||
bool_type_id,
|
||||
rhs_eq_neg_one_id,
|
||||
rhs_id,
|
||||
composite_neg_one_id,
|
||||
));
|
||||
let lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
|
||||
block.body.push(Instruction::binary(
|
||||
spirv::Op::LogicalAnd,
|
||||
bool_type_id,
|
||||
lhs_eq_int_min_and_rhs_eq_neg_one_id,
|
||||
lhs_eq_int_min_id,
|
||||
rhs_eq_neg_one_id,
|
||||
));
|
||||
let rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
|
||||
block.body.push(Instruction::binary(
|
||||
spirv::Op::LogicalOr,
|
||||
bool_type_id,
|
||||
rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id,
|
||||
rhs_eq_zero_id,
|
||||
lhs_eq_int_min_and_rhs_eq_neg_one_id,
|
||||
));
|
||||
rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id
|
||||
}
|
||||
crate::ScalarKind::Uint => rhs_eq_zero_id,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let const_one_id = self.get_constant_scalar_with(1, scalar)?;
|
||||
let composite_one_id = maybe_splat_const(self, const_one_id);
|
||||
let divisor_id = self.id_gen.next();
|
||||
block.body.push(Instruction::select(
|
||||
right_type_id,
|
||||
divisor_id,
|
||||
divisor_selector_id,
|
||||
composite_one_id,
|
||||
rhs_id,
|
||||
));
|
||||
let op = match (op, scalar.kind) {
|
||||
(crate::BinaryOperator::Divide, crate::ScalarKind::Sint) => spirv::Op::SDiv,
|
||||
(crate::BinaryOperator::Divide, crate::ScalarKind::Uint) => spirv::Op::UDiv,
|
||||
(crate::BinaryOperator::Modulo, crate::ScalarKind::Sint) => spirv::Op::SRem,
|
||||
(crate::BinaryOperator::Modulo, crate::ScalarKind::Uint) => spirv::Op::UMod,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let return_id = self.id_gen.next();
|
||||
block.body.push(Instruction::binary(
|
||||
op,
|
||||
return_type_id,
|
||||
return_id,
|
||||
lhs_id,
|
||||
divisor_id,
|
||||
));
|
||||
|
||||
function.consume(block, Instruction::return_value(return_id));
|
||||
function.to_words(&mut self.logical_layout.function_definitions);
|
||||
Instruction::function_end().to_words(&mut self.logical_layout.function_definitions);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_function(
|
||||
&mut self,
|
||||
ir_function: &crate::Function,
|
||||
@ -1138,7 +1164,7 @@ impl Writer {
|
||||
}
|
||||
LocalType::Image(image) => {
|
||||
let local_type = LocalType::Numeric(NumericType::Scalar(image.sampled_type));
|
||||
let type_id = self.get_type_id(LookupType::Local(local_type));
|
||||
let type_id = self.get_localtype_id(local_type);
|
||||
Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format)
|
||||
}
|
||||
LocalType::Sampler => Instruction::type_sampler(id),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user