mirror of
https://github.com/gfx-rs/wgpu.git
synced 2025-12-08 21:26:17 +00:00
[naga hlsl-out] Factor out some repetitive code
This commit is contained in:
parent
bfa7ee8de5
commit
1b4eca97cf
@ -45,6 +45,11 @@ pub(crate) const F2U64_FUNCTION: &str = "naga_f2u64";
|
||||
pub(crate) const IMAGE_SAMPLE_BASE_CLAMP_TO_EDGE_FUNCTION: &str =
|
||||
"nagaTextureSampleBaseClampToEdge";
|
||||
|
||||
enum Index {
|
||||
Expression(Handle<crate::Expression>),
|
||||
Static(u32),
|
||||
}
|
||||
|
||||
struct EpStructMember {
|
||||
name: String,
|
||||
ty: Handle<crate::Type>,
|
||||
@ -1797,6 +1802,23 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_index(
|
||||
&mut self,
|
||||
module: &Module,
|
||||
index: Index,
|
||||
func_ctx: &back::FunctionCtx<'_>,
|
||||
) -> BackendResult {
|
||||
match index {
|
||||
Index::Static(index) => {
|
||||
write!(self.out, "{index}")?;
|
||||
}
|
||||
Index::Expression(index) => {
|
||||
self.write_expr(module, index, func_ctx)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Helper method used to write statements
|
||||
///
|
||||
/// # Notes
|
||||
@ -1953,13 +1975,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
//
|
||||
// We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
|
||||
// Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
|
||||
struct MatrixAccess {
|
||||
base: Handle<crate::Expression>,
|
||||
index: u32,
|
||||
}
|
||||
enum Index {
|
||||
Expression(Handle<crate::Expression>),
|
||||
Static(u32),
|
||||
enum MatrixAccess {
|
||||
Direct {
|
||||
base: Handle<crate::Expression>,
|
||||
index: u32,
|
||||
},
|
||||
Struct {
|
||||
columns: crate::VectorSize,
|
||||
base: Handle<crate::Expression>,
|
||||
},
|
||||
}
|
||||
|
||||
let get_members = |expr: Handle<crate::Expression>| {
|
||||
@ -1973,187 +1997,28 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
}
|
||||
};
|
||||
|
||||
let mut matrix = None;
|
||||
let mut vector = None;
|
||||
let mut scalar = None;
|
||||
|
||||
let mut current_expr = pointer;
|
||||
for _ in 0..3 {
|
||||
let resolved = func_ctx.resolve_type(current_expr, &module.types);
|
||||
|
||||
match (resolved, &func_ctx.expressions[current_expr]) {
|
||||
(
|
||||
&TypeInner::Pointer { base: ty, .. },
|
||||
&crate::Expression::AccessIndex { base, index },
|
||||
) if matches!(
|
||||
module.types[ty].inner,
|
||||
TypeInner::Matrix {
|
||||
rows: crate::VectorSize::Bi,
|
||||
..
|
||||
}
|
||||
) && get_members(base)
|
||||
.map(|members| members[index as usize].binding.is_none())
|
||||
== Some(true) =>
|
||||
{
|
||||
matrix = Some(MatrixAccess { base, index });
|
||||
break;
|
||||
}
|
||||
(
|
||||
&TypeInner::ValuePointer {
|
||||
size: Some(crate::VectorSize::Bi),
|
||||
..
|
||||
},
|
||||
&crate::Expression::Access { base, index },
|
||||
) => {
|
||||
vector = Some(Index::Expression(index));
|
||||
current_expr = base;
|
||||
}
|
||||
(
|
||||
&TypeInner::ValuePointer {
|
||||
size: Some(crate::VectorSize::Bi),
|
||||
..
|
||||
},
|
||||
&crate::Expression::AccessIndex { base, index },
|
||||
) => {
|
||||
vector = Some(Index::Static(index));
|
||||
current_expr = base;
|
||||
}
|
||||
(
|
||||
&TypeInner::ValuePointer { size: None, .. },
|
||||
&crate::Expression::Access { base, index },
|
||||
) => {
|
||||
scalar = Some(Index::Expression(index));
|
||||
current_expr = base;
|
||||
}
|
||||
(
|
||||
&TypeInner::ValuePointer { size: None, .. },
|
||||
&crate::Expression::AccessIndex { base, index },
|
||||
) => {
|
||||
scalar = Some(Index::Static(index));
|
||||
current_expr = base;
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
|
||||
write!(self.out, "{level}")?;
|
||||
|
||||
if let Some(MatrixAccess { index, base }) = matrix {
|
||||
let base_ty_res = &func_ctx.info[base].ty;
|
||||
let resolved = base_ty_res.inner_with(&module.types);
|
||||
let ty = match *resolved {
|
||||
TypeInner::Pointer { base, .. } => base,
|
||||
_ => base_ty_res.handle().unwrap(),
|
||||
};
|
||||
|
||||
if let Some(Index::Static(vec_index)) = vector {
|
||||
self.write_expr(module, base, func_ctx)?;
|
||||
write!(
|
||||
self.out,
|
||||
".{}_{}",
|
||||
&self.names[&NameKey::StructMember(ty, index)],
|
||||
vec_index
|
||||
)?;
|
||||
|
||||
if let Some(scalar_index) = scalar {
|
||||
write!(self.out, "[")?;
|
||||
match scalar_index {
|
||||
Index::Static(index) => {
|
||||
write!(self.out, "{index}")?;
|
||||
}
|
||||
Index::Expression(index) => {
|
||||
self.write_expr(module, index, func_ctx)?;
|
||||
}
|
||||
}
|
||||
write!(self.out, "]")?;
|
||||
}
|
||||
|
||||
write!(self.out, " = ")?;
|
||||
self.write_expr(module, value, func_ctx)?;
|
||||
writeln!(self.out, ";")?;
|
||||
} else {
|
||||
let access = WrappedStructMatrixAccess { ty, index };
|
||||
match (&vector, &scalar) {
|
||||
(&Some(_), &Some(_)) => {
|
||||
self.write_wrapped_struct_matrix_set_scalar_function_name(
|
||||
access,
|
||||
)?;
|
||||
}
|
||||
(&Some(_), &None) => {
|
||||
self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
|
||||
}
|
||||
(&None, _) => {
|
||||
self.write_wrapped_struct_matrix_set_function_name(access)?;
|
||||
}
|
||||
}
|
||||
|
||||
write!(self.out, "(")?;
|
||||
self.write_expr(module, base, func_ctx)?;
|
||||
write!(self.out, ", ")?;
|
||||
self.write_expr(module, value, func_ctx)?;
|
||||
|
||||
if let Some(Index::Expression(vec_index)) = vector {
|
||||
write!(self.out, ", ")?;
|
||||
self.write_expr(module, vec_index, func_ctx)?;
|
||||
|
||||
if let Some(scalar_index) = scalar {
|
||||
write!(self.out, ", ")?;
|
||||
match scalar_index {
|
||||
Index::Static(index) => {
|
||||
write!(self.out, "{index}")?;
|
||||
}
|
||||
Index::Expression(index) => {
|
||||
self.write_expr(module, index, func_ctx)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
writeln!(self.out, ");")?;
|
||||
}
|
||||
} else {
|
||||
// We handle `Store`s to __matCx2 column vectors and scalar elements via
|
||||
// the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
|
||||
struct MatrixData {
|
||||
columns: crate::VectorSize,
|
||||
base: Handle<crate::Expression>,
|
||||
}
|
||||
|
||||
enum Index {
|
||||
Expression(Handle<crate::Expression>),
|
||||
Static(u32),
|
||||
}
|
||||
|
||||
let mut matrix = None;
|
||||
let mut vector = None;
|
||||
let mut scalar = None;
|
||||
|
||||
let mut current_expr = pointer;
|
||||
for _ in 0..3 {
|
||||
let resolved = func_ctx.resolve_type(current_expr, &module.types);
|
||||
match (resolved, &func_ctx.expressions[current_expr]) {
|
||||
let matrix_access_on_lhs =
|
||||
find_matrix_in_access_chain(module, pointer, func_ctx).and_then(
|
||||
|(matrix_expr, vector, scalar)| match (
|
||||
func_ctx.resolve_type(matrix_expr, &module.types),
|
||||
&func_ctx.expressions[matrix_expr],
|
||||
) {
|
||||
(
|
||||
&TypeInner::ValuePointer {
|
||||
size: Some(crate::VectorSize::Bi),
|
||||
..
|
||||
},
|
||||
&crate::Expression::Access { base, index },
|
||||
) => {
|
||||
vector = Some(index);
|
||||
current_expr = base;
|
||||
}
|
||||
(
|
||||
&TypeInner::ValuePointer { size: None, .. },
|
||||
&crate::Expression::Access { base, index },
|
||||
) => {
|
||||
scalar = Some(Index::Expression(index));
|
||||
current_expr = base;
|
||||
}
|
||||
(
|
||||
&TypeInner::ValuePointer { size: None, .. },
|
||||
&TypeInner::Pointer { base: ty, .. },
|
||||
&crate::Expression::AccessIndex { base, index },
|
||||
) => {
|
||||
scalar = Some(Index::Static(index));
|
||||
current_expr = base;
|
||||
) if matches!(
|
||||
module.types[ty].inner,
|
||||
TypeInner::Matrix {
|
||||
rows: crate::VectorSize::Bi,
|
||||
..
|
||||
}
|
||||
) && get_members(base)
|
||||
.map(|members| members[index as usize].binding.is_none())
|
||||
== Some(true) =>
|
||||
{
|
||||
Some((MatrixAccess::Direct { base, index }, vector, scalar))
|
||||
}
|
||||
_ => {
|
||||
if let Some(MatrixType {
|
||||
@ -2162,24 +2027,95 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
width: 4,
|
||||
}) = get_inner_matrix_of_struct_array_member(
|
||||
module,
|
||||
current_expr,
|
||||
matrix_expr,
|
||||
func_ctx,
|
||||
true,
|
||||
) {
|
||||
matrix = Some(MatrixData {
|
||||
columns,
|
||||
base: current_expr,
|
||||
});
|
||||
Some((
|
||||
MatrixAccess::Struct {
|
||||
columns,
|
||||
base: matrix_expr,
|
||||
},
|
||||
vector,
|
||||
scalar,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
match matrix_access_on_lhs {
|
||||
Some((MatrixAccess::Direct { index, base }, vector, scalar)) => {
|
||||
let base_ty_res = &func_ctx.info[base].ty;
|
||||
let resolved = base_ty_res.inner_with(&module.types);
|
||||
let ty = match *resolved {
|
||||
TypeInner::Pointer { base, .. } => base,
|
||||
_ => base_ty_res.handle().unwrap(),
|
||||
};
|
||||
|
||||
if let Some(Index::Static(vec_index)) = vector {
|
||||
self.write_expr(module, base, func_ctx)?;
|
||||
write!(
|
||||
self.out,
|
||||
".{}_{}",
|
||||
&self.names[&NameKey::StructMember(ty, index)],
|
||||
vec_index
|
||||
)?;
|
||||
|
||||
if let Some(scalar_index) = scalar {
|
||||
write!(self.out, "[")?;
|
||||
self.write_index(module, scalar_index, func_ctx)?;
|
||||
write!(self.out, "]")?;
|
||||
}
|
||||
|
||||
write!(self.out, " = ")?;
|
||||
self.write_expr(module, value, func_ctx)?;
|
||||
writeln!(self.out, ";")?;
|
||||
} else {
|
||||
let access = WrappedStructMatrixAccess { ty, index };
|
||||
match (&vector, &scalar) {
|
||||
(&Some(_), &Some(_)) => {
|
||||
self.write_wrapped_struct_matrix_set_scalar_function_name(
|
||||
access,
|
||||
)?;
|
||||
}
|
||||
(&Some(_), &None) => {
|
||||
self.write_wrapped_struct_matrix_set_vec_function_name(
|
||||
access,
|
||||
)?;
|
||||
}
|
||||
(&None, _) => {
|
||||
self.write_wrapped_struct_matrix_set_function_name(access)?;
|
||||
}
|
||||
}
|
||||
|
||||
write!(self.out, "(")?;
|
||||
self.write_expr(module, base, func_ctx)?;
|
||||
write!(self.out, ", ")?;
|
||||
self.write_expr(module, value, func_ctx)?;
|
||||
|
||||
if let Some(Index::Expression(vec_index)) = vector {
|
||||
write!(self.out, ", ")?;
|
||||
self.write_expr(module, vec_index, func_ctx)?;
|
||||
|
||||
if let Some(scalar_index) = scalar {
|
||||
write!(self.out, ", ")?;
|
||||
self.write_index(module, scalar_index, func_ctx)?;
|
||||
}
|
||||
}
|
||||
writeln!(self.out, ");")?;
|
||||
}
|
||||
}
|
||||
Some((
|
||||
MatrixAccess::Struct { columns, base },
|
||||
Some(Index::Expression(vec_index)),
|
||||
scalar,
|
||||
)) => {
|
||||
// We handle `Store`s to __matCx2 column vectors and scalar elements via
|
||||
// the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
|
||||
|
||||
if let (Some(MatrixData { columns, base }), Some(vec_index)) =
|
||||
(matrix, vector)
|
||||
{
|
||||
if scalar.is_some() {
|
||||
write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
|
||||
} else {
|
||||
@ -2192,21 +2128,17 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
|
||||
if let Some(scalar_index) = scalar {
|
||||
write!(self.out, ", ")?;
|
||||
match scalar_index {
|
||||
Index::Static(index) => {
|
||||
write!(self.out, "{index}")?;
|
||||
}
|
||||
Index::Expression(index) => {
|
||||
self.write_expr(module, index, func_ctx)?;
|
||||
}
|
||||
}
|
||||
self.write_index(module, scalar_index, func_ctx)?;
|
||||
}
|
||||
|
||||
write!(self.out, ", ")?;
|
||||
self.write_expr(module, value, func_ctx)?;
|
||||
|
||||
writeln!(self.out, ");")?;
|
||||
} else {
|
||||
}
|
||||
Some((MatrixAccess::Struct { .. }, Some(Index::Static(_)), _))
|
||||
| Some((MatrixAccess::Struct { .. }, None, _))
|
||||
| None => {
|
||||
self.write_expr(module, pointer, func_ctx)?;
|
||||
write!(self.out, " = ")?;
|
||||
|
||||
@ -4423,12 +4355,17 @@ pub(super) fn get_inner_matrix_data(
|
||||
}
|
||||
}
|
||||
|
||||
/// If `base` is an access chain of the form `mat`, `mat[col]`, or `mat[col][row]`,
|
||||
/// returns a tuple of the matrix, the column (vector) index (if present), and
|
||||
/// the row (scalar) index (if present).
|
||||
fn find_matrix_in_access_chain(
|
||||
module: &Module,
|
||||
base: Handle<crate::Expression>,
|
||||
func_ctx: &back::FunctionCtx<'_>,
|
||||
) -> Option<Handle<crate::Expression>> {
|
||||
) -> Option<(Handle<crate::Expression>, Option<Index>, Option<Index>)> {
|
||||
let mut current_base = base;
|
||||
let mut vector = None;
|
||||
let mut scalar = None;
|
||||
loop {
|
||||
let resolved_tr = func_ctx
|
||||
.resolve_type(current_base, &module.types)
|
||||
@ -4436,15 +4373,22 @@ fn find_matrix_in_access_chain(
|
||||
let resolved = resolved_tr.as_ref()?.inner_with(&module.types);
|
||||
|
||||
match *resolved {
|
||||
TypeInner::Matrix { .. } => return Some((current_base, vector, scalar)),
|
||||
TypeInner::Scalar(_) | TypeInner::Vector { .. } => {}
|
||||
TypeInner::Matrix { .. } => return Some(current_base),
|
||||
_ => return None,
|
||||
}
|
||||
|
||||
current_base = match func_ctx.expressions[current_base] {
|
||||
crate::Expression::Access { base, .. } => base,
|
||||
crate::Expression::AccessIndex { base, .. } => base,
|
||||
let index;
|
||||
(current_base, index) = match func_ctx.expressions[current_base] {
|
||||
crate::Expression::Access { base, index } => (base, Index::Expression(index)),
|
||||
crate::Expression::AccessIndex { base, index } => (base, Index::Static(index)),
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
match *resolved {
|
||||
TypeInner::Scalar(_) => scalar = Some(index),
|
||||
TypeInner::Vector { .. } => vector = Some(index),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user