[naga spv-out] Factor out wrapped divide/module generation.

Move the code to generate the definition of an overflow-safe
divide/modulo SPIR-V function into its own Rust function, to reduce
indentation and clarify influences. This commit isn't intended to
cause any change in behavior.
This commit is contained in:
Jim Blandy 2025-02-13 14:55:07 -08:00
parent cb9666c6a7
commit f90f19c7e8
2 changed files with 226 additions and 192 deletions

View File

@ -303,6 +303,14 @@ impl NumericType {
}
}
const fn scalar(self) -> crate::Scalar {
match self {
NumericType::Scalar(scalar)
| NumericType::Vector { scalar, .. }
| NumericType::Matrix { scalar, .. } => scalar,
}
}
const fn with_scalar(self, scalar: crate::Scalar) -> Self {
match self {
NumericType::Scalar(_) => NumericType::Scalar(scalar),

View File

@ -223,6 +223,10 @@ impl Writer {
self.get_type_id(lookup_ty)
}
pub(super) fn get_localtype_id(&mut self, local: LocalType) -> Word {
self.get_type_id(LookupType::Local(local))
}
pub(super) fn get_pointer_id(
&mut self,
handle: Handle<crate::Type>,
@ -320,199 +324,27 @@ impl Writer {
for (expr_handle, expr) in ir_function.expressions.iter() {
match *expr {
crate::Expression::Binary { op, left, right } => {
let expr_ty = info[expr_handle].ty.inner_with(&ir_module.types);
let Some(numeric_type) = NumericType::from_inner(expr_ty) else {
continue;
};
match (op, expr_ty.scalar()) {
// Division and modulo are undefined behaviour when the dividend is the
// minimum representable value and the divisor is negative one, or when
// the divisor is zero. These wrapped functions override the divisor to
// one in these cases, matching the WGSL spec.
(
crate::BinaryOperator::Divide | crate::BinaryOperator::Modulo,
Some(
scalar @ crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
..
},
),
) => {
let return_type_id = self.get_expression_type_id(&info[expr_handle].ty);
let left_type_id = self.get_expression_type_id(&info[left].ty);
let right_type_id = self.get_expression_type_id(&info[right].ty);
let wrapped = WrappedFunction::BinaryOp {
op,
left_type_id,
right_type_id,
};
let function_id = *match self.wrapped_functions.entry(wrapped) {
Entry::Occupied(_) => continue,
Entry::Vacant(e) => e.insert(self.id_gen.next()),
};
if self.flags.contains(WriterFlags::DEBUG) {
let function_name = match op {
crate::BinaryOperator::Divide => "naga_div",
crate::BinaryOperator::Modulo => "naga_mod",
_ => unreachable!(),
};
self.debugs
.push(Instruction::name(function_id, function_name));
let expr_ty_inner = info[expr_handle].ty.inner_with(&ir_module.types);
if let Some(expr_ty) = NumericType::from_inner(expr_ty_inner) {
match (op, expr_ty.scalar().kind) {
// Division and modulo are undefined behaviour when the
// dividend is the minimum representable value and the divisor
// is negative one, or when the divisor is zero. These wrapped
// functions override the divisor to one in these cases,
// matching the WGSL spec.
(
crate::BinaryOperator::Divide | crate::BinaryOperator::Modulo,
crate::ScalarKind::Sint | crate::ScalarKind::Uint,
) => {
self.write_wrapped_binary_op(
op,
expr_ty,
&info[left].ty,
&info[right].ty,
)?;
}
let mut function = Function::default();
let function_type_id = self.get_function_type(LookupFunctionType {
parameter_type_ids: vec![left_type_id, right_type_id],
return_type_id,
});
function.signature = Some(Instruction::function(
return_type_id,
function_id,
spirv::FunctionControl::empty(),
function_type_id,
));
let lhs_id = self.id_gen.next();
let rhs_id = self.id_gen.next();
if self.flags.contains(WriterFlags::DEBUG) {
self.debugs.push(Instruction::name(lhs_id, "lhs"));
self.debugs.push(Instruction::name(rhs_id, "rhs"));
}
let left_par = Instruction::function_parameter(left_type_id, lhs_id);
let right_par = Instruction::function_parameter(right_type_id, rhs_id);
for instruction in [left_par, right_par] {
function.parameters.push(FunctionArgument {
instruction,
handle_id: 0,
});
}
let label_id = self.id_gen.next();
let mut block = Block::new(label_id);
let bool_type = numeric_type.with_scalar(crate::Scalar::BOOL);
let bool_type_id =
self.get_type_id(LookupType::Local(LocalType::Numeric(bool_type)));
let maybe_splat_const = |writer: &mut Self, const_id| match numeric_type
{
NumericType::Scalar(_) => const_id,
NumericType::Vector { size, .. } => {
let constituent_ids = [const_id; crate::VectorSize::MAX];
writer.get_constant_composite(
LookupType::Local(LocalType::Numeric(numeric_type)),
&constituent_ids[..size as usize],
)
}
NumericType::Matrix { .. } => unreachable!(),
};
let const_zero_id = self.get_constant_scalar_with(0, scalar)?;
let composite_zero_id = maybe_splat_const(self, const_zero_id);
let rhs_eq_zero_id = self.id_gen.next();
block.body.push(Instruction::binary(
spirv::Op::IEqual,
bool_type_id,
rhs_eq_zero_id,
rhs_id,
composite_zero_id,
));
let divisor_selector_id = match scalar.kind {
crate::ScalarKind::Sint => {
let (const_min_id, const_neg_one_id) = match scalar.width {
4 => Ok((
self.get_constant_scalar(crate::Literal::I32(i32::MIN)),
self.get_constant_scalar(crate::Literal::I32(-1i32)),
)),
8 => Ok((
self.get_constant_scalar(crate::Literal::I64(i64::MIN)),
self.get_constant_scalar(crate::Literal::I64(-1i64)),
)),
_ => Err(Error::Validation("Unexpected scalar width")),
}?;
let composite_min_id = maybe_splat_const(self, const_min_id);
let composite_neg_one_id =
maybe_splat_const(self, const_neg_one_id);
let lhs_eq_int_min_id = self.id_gen.next();
block.body.push(Instruction::binary(
spirv::Op::IEqual,
bool_type_id,
lhs_eq_int_min_id,
lhs_id,
composite_min_id,
));
let rhs_eq_neg_one_id = self.id_gen.next();
block.body.push(Instruction::binary(
spirv::Op::IEqual,
bool_type_id,
rhs_eq_neg_one_id,
rhs_id,
composite_neg_one_id,
));
let lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
block.body.push(Instruction::binary(
spirv::Op::LogicalAnd,
bool_type_id,
lhs_eq_int_min_and_rhs_eq_neg_one_id,
lhs_eq_int_min_id,
rhs_eq_neg_one_id,
));
let rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id =
self.id_gen.next();
block.body.push(Instruction::binary(
spirv::Op::LogicalOr,
bool_type_id,
rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id,
rhs_eq_zero_id,
lhs_eq_int_min_and_rhs_eq_neg_one_id,
));
rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id
}
crate::ScalarKind::Uint => rhs_eq_zero_id,
_ => unreachable!(),
};
let const_one_id = self.get_constant_scalar_with(1, scalar)?;
let composite_one_id = maybe_splat_const(self, const_one_id);
let divisor_id = self.id_gen.next();
block.body.push(Instruction::select(
right_type_id,
divisor_id,
divisor_selector_id,
composite_one_id,
rhs_id,
));
let op = match (op, scalar.kind) {
(crate::BinaryOperator::Divide, crate::ScalarKind::Sint) => {
spirv::Op::SDiv
}
(crate::BinaryOperator::Divide, crate::ScalarKind::Uint) => {
spirv::Op::UDiv
}
(crate::BinaryOperator::Modulo, crate::ScalarKind::Sint) => {
spirv::Op::SRem
}
(crate::BinaryOperator::Modulo, crate::ScalarKind::Uint) => {
spirv::Op::UMod
}
_ => unreachable!(),
};
let return_id = self.id_gen.next();
block.body.push(Instruction::binary(
op,
return_type_id,
return_id,
lhs_id,
divisor_id,
));
function.consume(block, Instruction::return_value(return_id));
function.to_words(&mut self.logical_layout.function_definitions);
Instruction::function_end()
.to_words(&mut self.logical_layout.function_definitions);
_ => {}
}
_ => {}
}
}
_ => {}
@ -522,6 +354,200 @@ impl Writer {
Ok(())
}
/// Write a SPIR-V function that performs the operator `op` with Naga IR semantics.
///
/// Define a function that performs an integer division or modulo operation,
/// except that using a divisor of zero or causing signed overflow with a
/// divisor of -1 returns the numerator unchanged, rather than exhibiting
/// undefined behavior.
///
/// Store the generated function's id in the [`wrapped_functions`] table.
///
/// The operator `op` must be either [`Divide`] or [`Modulo`].
///
/// # Panics
///
/// The `return_type`, `left_type` or `right_type` arguments must all be
/// integer scalars or vectors. If not, this function panics.
///
/// [`wrapped_functions`]: Writer::wrapped_functions
/// [`Divide`]: crate::BinaryOperator::Divide
/// [`Modulo`]: crate::BinaryOperator::Modulo
fn write_wrapped_binary_op(
&mut self,
op: crate::BinaryOperator,
return_type: NumericType,
left_type: &TypeResolution,
right_type: &TypeResolution,
) -> Result<(), Error> {
let return_type_id = self.get_localtype_id(LocalType::Numeric(return_type));
let left_type_id = self.get_expression_type_id(left_type);
let right_type_id = self.get_expression_type_id(right_type);
// Check if we've already emitted this function.
let wrapped = WrappedFunction::BinaryOp {
op,
left_type_id,
right_type_id,
};
let function_id = match self.wrapped_functions.entry(wrapped) {
Entry::Occupied(_) => return Ok(()),
Entry::Vacant(e) => *e.insert(self.id_gen.next()),
};
let scalar = return_type.scalar();
if self.flags.contains(WriterFlags::DEBUG) {
let function_name = match op {
crate::BinaryOperator::Divide => "naga_div",
crate::BinaryOperator::Modulo => "naga_mod",
_ => unreachable!(),
};
self.debugs
.push(Instruction::name(function_id, function_name));
}
let mut function = Function::default();
let function_type_id = self.get_function_type(LookupFunctionType {
parameter_type_ids: vec![left_type_id, right_type_id],
return_type_id,
});
function.signature = Some(Instruction::function(
return_type_id,
function_id,
spirv::FunctionControl::empty(),
function_type_id,
));
let lhs_id = self.id_gen.next();
let rhs_id = self.id_gen.next();
if self.flags.contains(WriterFlags::DEBUG) {
self.debugs.push(Instruction::name(lhs_id, "lhs"));
self.debugs.push(Instruction::name(rhs_id, "rhs"));
}
let left_par = Instruction::function_parameter(left_type_id, lhs_id);
let right_par = Instruction::function_parameter(right_type_id, rhs_id);
for instruction in [left_par, right_par] {
function.parameters.push(FunctionArgument {
instruction,
handle_id: 0,
});
}
let label_id = self.id_gen.next();
let mut block = Block::new(label_id);
let bool_type = return_type.with_scalar(crate::Scalar::BOOL);
let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(bool_type)));
let maybe_splat_const = |writer: &mut Self, const_id| match return_type {
NumericType::Scalar(_) => const_id,
NumericType::Vector { size, .. } => {
let constituent_ids = [const_id; crate::VectorSize::MAX];
writer.get_constant_composite(
LookupType::Local(LocalType::Numeric(return_type)),
&constituent_ids[..size as usize],
)
}
NumericType::Matrix { .. } => unreachable!(),
};
let const_zero_id = self.get_constant_scalar_with(0, scalar)?;
let composite_zero_id = maybe_splat_const(self, const_zero_id);
let rhs_eq_zero_id = self.id_gen.next();
block.body.push(Instruction::binary(
spirv::Op::IEqual,
bool_type_id,
rhs_eq_zero_id,
rhs_id,
composite_zero_id,
));
let divisor_selector_id = match scalar.kind {
crate::ScalarKind::Sint => {
let (const_min_id, const_neg_one_id) = match scalar.width {
4 => Ok((
self.get_constant_scalar(crate::Literal::I32(i32::MIN)),
self.get_constant_scalar(crate::Literal::I32(-1i32)),
)),
8 => Ok((
self.get_constant_scalar(crate::Literal::I64(i64::MIN)),
self.get_constant_scalar(crate::Literal::I64(-1i64)),
)),
_ => Err(Error::Validation("Unexpected scalar width")),
}?;
let composite_min_id = maybe_splat_const(self, const_min_id);
let composite_neg_one_id = maybe_splat_const(self, const_neg_one_id);
let lhs_eq_int_min_id = self.id_gen.next();
block.body.push(Instruction::binary(
spirv::Op::IEqual,
bool_type_id,
lhs_eq_int_min_id,
lhs_id,
composite_min_id,
));
let rhs_eq_neg_one_id = self.id_gen.next();
block.body.push(Instruction::binary(
spirv::Op::IEqual,
bool_type_id,
rhs_eq_neg_one_id,
rhs_id,
composite_neg_one_id,
));
let lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
block.body.push(Instruction::binary(
spirv::Op::LogicalAnd,
bool_type_id,
lhs_eq_int_min_and_rhs_eq_neg_one_id,
lhs_eq_int_min_id,
rhs_eq_neg_one_id,
));
let rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id = self.id_gen.next();
block.body.push(Instruction::binary(
spirv::Op::LogicalOr,
bool_type_id,
rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id,
rhs_eq_zero_id,
lhs_eq_int_min_and_rhs_eq_neg_one_id,
));
rhs_eq_zero_or_lhs_eq_int_min_and_rhs_eq_neg_one_id
}
crate::ScalarKind::Uint => rhs_eq_zero_id,
_ => unreachable!(),
};
let const_one_id = self.get_constant_scalar_with(1, scalar)?;
let composite_one_id = maybe_splat_const(self, const_one_id);
let divisor_id = self.id_gen.next();
block.body.push(Instruction::select(
right_type_id,
divisor_id,
divisor_selector_id,
composite_one_id,
rhs_id,
));
let op = match (op, scalar.kind) {
(crate::BinaryOperator::Divide, crate::ScalarKind::Sint) => spirv::Op::SDiv,
(crate::BinaryOperator::Divide, crate::ScalarKind::Uint) => spirv::Op::UDiv,
(crate::BinaryOperator::Modulo, crate::ScalarKind::Sint) => spirv::Op::SRem,
(crate::BinaryOperator::Modulo, crate::ScalarKind::Uint) => spirv::Op::UMod,
_ => unreachable!(),
};
let return_id = self.id_gen.next();
block.body.push(Instruction::binary(
op,
return_type_id,
return_id,
lhs_id,
divisor_id,
));
function.consume(block, Instruction::return_value(return_id));
function.to_words(&mut self.logical_layout.function_definitions);
Instruction::function_end().to_words(&mut self.logical_layout.function_definitions);
Ok(())
}
fn write_function(
&mut self,
ir_function: &crate::Function,
@ -1138,7 +1164,7 @@ impl Writer {
}
LocalType::Image(image) => {
let local_type = LocalType::Numeric(NumericType::Scalar(image.sampled_type));
let type_id = self.get_type_id(LookupType::Local(local_type));
let type_id = self.get_localtype_id(local_type);
Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format)
}
LocalType::Sampler => Instruction::type_sampler(id),