mirror of
https://github.com/trekhleb/javascript-algorithms.git
synced 2025-12-08 19:06:00 +00:00
310 lines
7.3 KiB
JavaScript
310 lines
7.3 KiB
JavaScript
/**
|
|
* @typedef {number} Cell
|
|
* @typedef {Cell[][]|Cell[][][]} Matrix
|
|
* @typedef {number[]} Shape
|
|
* @typedef {number[]} CellIndices
|
|
*/
|
|
|
|
/**
|
|
* Gets the matrix's shape.
|
|
*
|
|
* @param {Matrix} m
|
|
* @returns {Shape}
|
|
*/
|
|
export const shape = (m) => {
|
|
const shapes = [];
|
|
let dimension = m;
|
|
while (dimension && Array.isArray(dimension)) {
|
|
shapes.push(dimension.length);
|
|
dimension = (dimension.length && [...dimension][0]) || null;
|
|
}
|
|
return shapes;
|
|
};
|
|
|
|
/**
|
|
* Checks if matrix has a correct type.
|
|
*
|
|
* @param {Matrix} m
|
|
* @throws {Error}
|
|
*/
|
|
const validateType = (m) => {
|
|
if (
|
|
!m
|
|
|| !Array.isArray(m)
|
|
|| !Array.isArray(m[0])
|
|
) {
|
|
throw new Error('Invalid matrix format');
|
|
}
|
|
};
|
|
|
|
/**
|
|
* Checks if matrix is two dimensional.
|
|
*
|
|
* @param {Matrix} m
|
|
* @throws {Error}
|
|
*/
|
|
const validate2D = (m) => {
|
|
validateType(m);
|
|
const aShape = shape(m);
|
|
if (aShape.length !== 2) {
|
|
throw new Error('Matrix is not of 2D shape');
|
|
}
|
|
};
|
|
|
|
/**
|
|
* Validates that matrices are of the same shape.
|
|
*
|
|
* @param {Matrix} a
|
|
* @param {Matrix} b
|
|
* @trows {Error}
|
|
*/
|
|
export const validateSameShape = (a, b) => {
|
|
validateType(a);
|
|
validateType(b);
|
|
|
|
const aShape = shape(a);
|
|
const bShape = shape(b);
|
|
|
|
if (aShape.length !== bShape.length) {
|
|
throw new Error('Matrices have different dimensions');
|
|
}
|
|
|
|
while (aShape.length && bShape.length) {
|
|
if (aShape.pop() !== bShape.pop()) {
|
|
throw new Error('Matrices have different shapes');
|
|
}
|
|
}
|
|
};
|
|
|
|
/**
|
|
* Generates the matrix of specific shape with specific values.
|
|
*
|
|
* @param {Shape} mShape - the shape of the matrix to generate
|
|
* @param {function({CellIndex}): Cell} fill - cell values of a generated matrix.
|
|
* @returns {Matrix}
|
|
*/
|
|
export const generate = (mShape, fill) => {
|
|
/**
|
|
* Generates the matrix recursively.
|
|
*
|
|
* @param {Shape} recShape - the shape of the matrix to generate
|
|
* @param {CellIndices} recIndices
|
|
* @returns {Matrix}
|
|
*/
|
|
const generateRecursively = (recShape, recIndices) => {
|
|
if (recShape.length === 1) {
|
|
return Array(recShape[0])
|
|
.fill(null)
|
|
.map((cellValue, cellIndex) => fill([...recIndices, cellIndex]));
|
|
}
|
|
const m = [];
|
|
for (let i = 0; i < recShape[0]; i += 1) {
|
|
m.push(generateRecursively(recShape.slice(1), [...recIndices, i]));
|
|
}
|
|
return m;
|
|
};
|
|
|
|
return generateRecursively(mShape, []);
|
|
};
|
|
|
|
/**
|
|
* Generates the matrix of zeros of specified shape.
|
|
*
|
|
* @param {Shape} mShape - shape of the matrix
|
|
* @returns {Matrix}
|
|
*/
|
|
export const zeros = (mShape) => {
|
|
return generate(mShape, () => 0);
|
|
};
|
|
|
|
/**
|
|
* @param {Matrix} a
|
|
* @param {Matrix} b
|
|
* @return Matrix
|
|
* @throws {Error}
|
|
*/
|
|
export const dot = (a, b) => {
|
|
// Validate inputs.
|
|
validate2D(a);
|
|
validate2D(b);
|
|
|
|
// Check dimensions.
|
|
const aShape = shape(a);
|
|
const bShape = shape(b);
|
|
if (aShape[1] !== bShape[0]) {
|
|
throw new Error('Matrices have incompatible shape for multiplication');
|
|
}
|
|
|
|
// Perform matrix multiplication.
|
|
const outputShape = [aShape[0], bShape[1]];
|
|
const c = zeros(outputShape);
|
|
|
|
for (let bCol = 0; bCol < b[0].length; bCol += 1) {
|
|
for (let aRow = 0; aRow < a.length; aRow += 1) {
|
|
let cellSum = 0;
|
|
for (let aCol = 0; aCol < a[aRow].length; aCol += 1) {
|
|
cellSum += a[aRow][aCol] * b[aCol][bCol];
|
|
}
|
|
c[aRow][bCol] = cellSum;
|
|
}
|
|
}
|
|
|
|
return c;
|
|
};
|
|
|
|
/**
|
|
* Transposes the matrix.
|
|
*
|
|
* @param {Matrix} m
|
|
* @returns Matrix
|
|
* @throws {Error}
|
|
*/
|
|
export const t = (m) => {
|
|
validate2D(m);
|
|
const mShape = shape(m);
|
|
const transposed = zeros([mShape[1], mShape[0]]);
|
|
for (let row = 0; row < m.length; row += 1) {
|
|
for (let col = 0; col < m[0].length; col += 1) {
|
|
transposed[col][row] = m[row][col];
|
|
}
|
|
}
|
|
return transposed;
|
|
};
|
|
|
|
/**
|
|
* Traverses the matrix.
|
|
*
|
|
* @param {Matrix} m
|
|
* @param {function(indices: CellIndices, c: Cell)} visit
|
|
*/
|
|
export const walk = (m, visit) => {
|
|
/**
|
|
* Traverses the matrix recursively.
|
|
*
|
|
* @param {Matrix} recM
|
|
* @param {CellIndices} cellIndices
|
|
* @return {Matrix}
|
|
*/
|
|
const recWalk = (recM, cellIndices) => {
|
|
const recMShape = shape(recM);
|
|
|
|
if (recMShape.length === 1) {
|
|
for (let i = 0; i < recM.length; i += 1) {
|
|
visit([...cellIndices, i], recM[i]);
|
|
}
|
|
}
|
|
for (let i = 0; i < recM.length; i += 1) {
|
|
recWalk(recM[i], [...cellIndices, i]);
|
|
}
|
|
};
|
|
|
|
recWalk(m, []);
|
|
};
|
|
|
|
/**
|
|
* Gets the matrix cell value at specific index.
|
|
*
|
|
* @param {Matrix} m - Matrix that contains the cell that needs to be updated
|
|
* @param {CellIndices} cellIndices - Array of cell indices
|
|
* @return {Cell}
|
|
*/
|
|
export const getCellAtIndex = (m, cellIndices) => {
|
|
// We start from the row at specific index.
|
|
let cell = m[cellIndices[0]];
|
|
// Going deeper into the next dimensions but not to the last one to preserve
|
|
// the pointer to the last dimension array.
|
|
for (let dimIdx = 1; dimIdx < cellIndices.length - 1; dimIdx += 1) {
|
|
cell = cell[cellIndices[dimIdx]];
|
|
}
|
|
// At this moment the cell variable points to the array at the last needed dimension.
|
|
return cell[cellIndices[cellIndices.length - 1]];
|
|
};
|
|
|
|
/**
|
|
* Update the matrix cell at specific index.
|
|
*
|
|
* @param {Matrix} m - Matrix that contains the cell that needs to be updated
|
|
* @param {CellIndices} cellIndices - Array of cell indices
|
|
* @param {Cell} cellValue - New cell value
|
|
*/
|
|
export const updateCellAtIndex = (m, cellIndices, cellValue) => {
|
|
// We start from the row at specific index.
|
|
let cell = m[cellIndices[0]];
|
|
// Going deeper into the next dimensions but not to the last one to preserve
|
|
// the pointer to the last dimension array.
|
|
for (let dimIdx = 1; dimIdx < cellIndices.length - 1; dimIdx += 1) {
|
|
cell = cell[cellIndices[dimIdx]];
|
|
}
|
|
// At this moment the cell variable points to the array at the last needed dimension.
|
|
cell[cellIndices[cellIndices.length - 1]] = cellValue;
|
|
};
|
|
|
|
/**
|
|
* Adds two matrices element-wise.
|
|
*
|
|
* @param {Matrix} a
|
|
* @param {Matrix} b
|
|
* @return {Matrix}
|
|
*/
|
|
export const add = (a, b) => {
|
|
validateSameShape(a, b);
|
|
const result = zeros(shape(a));
|
|
|
|
walk(a, (cellIndices, cellValue) => {
|
|
updateCellAtIndex(result, cellIndices, cellValue);
|
|
});
|
|
|
|
walk(b, (cellIndices, cellValue) => {
|
|
const currentCellValue = getCellAtIndex(result, cellIndices);
|
|
updateCellAtIndex(result, cellIndices, currentCellValue + cellValue);
|
|
});
|
|
|
|
return result;
|
|
};
|
|
|
|
/**
|
|
* Multiplies two matrices element-wise.
|
|
*
|
|
* @param {Matrix} a
|
|
* @param {Matrix} b
|
|
* @return {Matrix}
|
|
*/
|
|
export const mul = (a, b) => {
|
|
validateSameShape(a, b);
|
|
const result = zeros(shape(a));
|
|
|
|
walk(a, (cellIndices, cellValue) => {
|
|
updateCellAtIndex(result, cellIndices, cellValue);
|
|
});
|
|
|
|
walk(b, (cellIndices, cellValue) => {
|
|
const currentCellValue = getCellAtIndex(result, cellIndices);
|
|
updateCellAtIndex(result, cellIndices, currentCellValue * cellValue);
|
|
});
|
|
|
|
return result;
|
|
};
|
|
|
|
/**
|
|
* Subtract two matrices element-wise.
|
|
*
|
|
* @param {Matrix} a
|
|
* @param {Matrix} b
|
|
* @return {Matrix}
|
|
*/
|
|
export const sub = (a, b) => {
|
|
validateSameShape(a, b);
|
|
const result = zeros(shape(a));
|
|
|
|
walk(a, (cellIndices, cellValue) => {
|
|
updateCellAtIndex(result, cellIndices, cellValue);
|
|
});
|
|
|
|
walk(b, (cellIndices, cellValue) => {
|
|
const currentCellValue = getCellAtIndex(result, cellIndices);
|
|
updateCellAtIndex(result, cellIndices, currentCellValue - cellValue);
|
|
});
|
|
|
|
return result;
|
|
};
|