// File listing metadata: 148 lines, 3.5 KiB, JavaScript

'use strict';
var clone = require('../../utils/object').clone;
var format = require('../../utils/string').format;
function factory (type, config, load, typed) {
  var matrix = load(require('../../type/matrix/function/matrix'));
  var add = load(require('../arithmetic/add'));

  /**
   * Calculate the trace of a matrix: the sum of the elements on the main
   * diagonal of a square matrix.
   *
   * Syntax:
   *
   *    math.trace(x)
   *
   * Examples:
   *
   *    math.trace([[1, 2], [3, 4]]); // returns 5
   *
   *    var A = [
   *      [1, 2, 3],
   *      [-1, 2, 3],
   *      [2, 0, 3]
   *    ]
   *    math.trace(A); // returns 6
   *
   * See also:
   *
   *    diag
   *
   * @param {Array | Matrix} x  A matrix
   *
   * @return {number} The trace of `x`
   */
  var trace = typed('trace', {
    'Array': function (x) {
      // convert to a (dense) matrix and reuse the Matrix implementation
      return trace(matrix(x));
    },

    'Matrix': function (x) {
      // dispatch on the storage format of the matrix
      var storage = x.storage();
      switch (storage) {
        case 'dense':
          return _denseTrace(x);
        case 'sparse':
          return _sparseTrace(x);
        default:
          // fail loudly instead of silently returning undefined
          throw new TypeError('Unsupported matrix storage format "' + storage + '"');
      }
    },

    // any other value (e.g. a scalar) is returned as a copy of itself
    'any': clone
  });

  /**
   * Trace of a dense matrix, read from its `_size` and `_data` fields.
   *
   * A one-dimensional matrix of size [1] is treated as a 1x1 matrix and its
   * single element is returned (cloned). Two-dimensional matrices must be
   * square; the diagonal is summed with `add`, starting from 0, so a 0x0
   * matrix yields 0.
   *
   * @param {Matrix} m  Dense matrix
   * @return {*} The sum of the diagonal elements
   * @throws {RangeError} When `m` is not square or has more than two dimensions
   */
  var _denseTrace = function (m) {
    var size = m._size;
    var data = m._data;

    switch (size.length) {
      case 1:
        // a single-element vector is its own trace
        if (size[0] === 1) {
          return clone(data[0]);
        }
        throw new RangeError('Matrix must be square (size: ' + format(size) + ')');

      case 2:
        var rows = size[0];
        var cols = size[1];
        if (rows === cols) {
          // calculate the sum along the main diagonal
          var sum = 0;
          for (var i = 0; i < rows; i++) {
            sum = add(sum, data[i][i]);
          }
          return sum;
        }
        throw new RangeError('Matrix must be square (size: ' + format(size) + ')');

      default:
        // multi dimensional
        throw new RangeError('Matrix must be two dimensional (size: ' + format(size) + ')');
    }
  };

  /**
   * Trace of a sparse matrix stored in compressed-column form.
   *
   * The entries of column j are the positions k with
   * `_ptr[j] <= k < _ptr[j + 1]`; `_index[k]` holds the row and `_values[k]`
   * holds the value. The column is scanned until its diagonal entry (row j) is
   * found. Row indices within a column are not assumed to be sorted, so there
   * is no early exit when a row index greater than j appears.
   *
   * @param {Matrix} m  Sparse matrix
   * @return {*} The sum of the stored diagonal elements (0 when none are stored)
   * @throws {RangeError} When `m` is not square
   */
  var _sparseTrace = function (m) {
    var values = m._values;
    var index = m._index;
    var ptr = m._ptr;
    var size = m._size;
    var rows = size[0];
    var columns = size[1];

    if (rows !== columns) {
      throw new RangeError('Matrix must be square (size: ' + format(size) + ')');
    }

    var sum = 0;
    // no stored values: the trace is 0, skip looping over the columns
    if (values.length > 0) {
      for (var j = 0; j < columns; j++) {
        var k1 = ptr[j + 1];
        for (var k = ptr[j]; k < k1; k++) {
          if (index[k] === j) {
            // found the diagonal entry of column j
            sum = add(sum, values[k]);
            break;
          }
        }
      }
    }
    return sum;
  };

  trace.toTex = {1: '\\mathrm{tr}\\left(${args[0]}\\right)'};

  return trace;
}
// Register the factory with the math.js loader under the public name "trace".
exports.factory = factory;
exports.name = 'trace';