diff --git a/index.js b/index.js index 3a343d655..d20e2d897 100644 --- a/index.js +++ b/index.js @@ -142,7 +142,7 @@ function create (config) { math.import(require('./lib/function/matrix/size')); math.import(require('./lib/function/matrix/squeeze')); require('./lib/function/matrix/subset')(math, _config); - require('./lib/function/matrix/trace')(math, _config); + math.import(require('./lib/function/matrix/trace')); math.import(require('./lib/function/matrix/transpose')); require('./lib/function/matrix/zeros')(math, _config); diff --git a/lib/function/matrix/trace.js b/lib/function/matrix/trace.js index 377e0feb6..82ed19b52 100644 --- a/lib/function/matrix/trace.js +++ b/lib/function/matrix/trace.js @@ -1,13 +1,12 @@ 'use strict'; -module.exports = function (math) { - var util = require('../../util/index'), +var clone = require('../../util/object').clone; +var format = require('../../util/string').format; - Matrix = math.type.Matrix, - - object = util.object, - array = util.array, - string = util.string; +function factory (type, config, load, typed) { + + var matrix = load(require('../construction/matrix')); + var add = load(require('../arithmetic/add')); /** * Calculate the trace of a matrix: the sum of the elements on the main @@ -36,61 +35,160 @@ module.exports = function (math) { * * @return {Number} The trace of `x` */ - math.trace = function trace (x) { - if (arguments.length != 1) { - throw new math.error.ArgumentsError('trace', arguments.length, 1); - } - - // check x is a matrix - if (x instanceof Matrix) { - // use optimized operation for the matrix storage format - return x.trace(); - } + var trace = typed('trace', { - // size - var size; - if (x instanceof Array) { - // calculate sixe - size = array.size(x); - } - else { - // a scalar - size = []; - } + 'Array': function (x) { + // use dense matrix implementation + return trace(matrix(x)); + }, + 'Matrix': function (x) { + // result + var c; + // process storage format + switch (x.storage()) { + case 'dense': + c = _denseTrace(x); + break; + case 'ccs': + c = _ccsTrace(x); + break; + case 'crs': + c = _crsTrace(x); + break; + } + return c; + }, + + 'any': function (x) { + return clone(x); + } + }); + + var _denseTrace = function (m) { + // matrix size & data + var size = m._size; + var data = m._data; + + // process dimensions switch (size.length) { - case 0: - // scalar - return object.clone(x); - case 1: // vector if (size[0] == 1) { - // clone value - return object.clone(x[0]); + // return data[0] + return clone(data[0]); } - throw new RangeError('Array must be square (size: ' + string.format(size) + ')'); - + throw new RangeError('Matrix must be square (size: ' + format(size) + ')'); case 2: - // two dimensional array + // two dimensional var rows = size[0]; var cols = size[1]; - // check array is square - if (rows == cols) { - // diagonal sum + if (rows === cols) { + // calulate sum var sum = 0; // loop diagonal - for (var i = 0; i < x.length; i++) { - // sum - sum = math.add(sum, x[i][i]); - } + for (var i = 0; i < rows; i++) + sum = add(sum, data[i][i]); + // return trace return sum; } - throw new RangeError('Array must be square (size: ' + string.format(size) + ')'); - + throw new RangeError('Matrix must be square (size: ' + format(size) + ')'); default: - // multi dimensional array - throw new RangeError('Matrix must be two dimensional (size: ' + string.format(size) + ')'); + // multi dimensional + throw new RangeError('Matrix must be two dimensional (size: ' + format(size) + ')'); } }; -}; + + var _ccsTrace = function (m) { + // matrix arrays + var values = m._values; + var index = m._index; + var ptr = m._ptr; + var size = m._size; + // check dimensions + var rows = size[0]; + var columns = size[1]; + // matrix must be square + if (rows === columns) { + // calulate sum + var sum = 0; + // check we have data (avoid looping columns) + if (values.length > 0) { + // loop columns + for (var j = 0; j < columns; j++) { + // k0 <= k < k1 where k0 = _ptr[j] && k1 = _ptr[j+1] + var k0 = ptr[j]; + var k1 = ptr[j + 1]; + // loop k within [k0, k1[ + for (var k = k0; k < k1; k++) { + // row index + var i = index[k]; + // check row + if (i === j) { + // accumulate value + sum = add(sum, values[k]); + // exit loop + break; + } + if (i > j) { + // exit loop, no value on the diagonal for column j + break; + } + } + } + } + // return trace + return sum; + } + throw new RangeError('Matrix must be square (size: ' + format(size) + ')'); + }; + + var _crsTrace = function (m) { + // matrix arrays + var values = m._values; + var index = m._index; + var ptr = m._ptr; + var size = m._size; + // check dimensions + var rows = size[0]; + var columns = size[1]; + // matrix must be square + if (rows === columns) { + // calulate sum + var sum = 0; + // check we have data (avoid looping rows) + if (values.length > 0) { + // loop rows + for (var i = 0; i < rows; i++) { + // k0 <= k < k1 where k0 = _ptr[i] && k1 = _ptr[i+1] + var k0 = ptr[i]; + var k1 = ptr[i + 1]; + // loop k within [k0, k1[ + for (var k = k0; k < k1; k++) { + // column index + var j = index[k]; + // check row + if (i === j) { + // accumulate value + sum = add(sum, values[k]); + // exit loop + break; + } + if (j > i) { + // exit loop, no value on the diagonal for column j + break; + } + } + } + } + // return trace + return sum; + } + throw new RangeError('Matrix must be square (size: ' + format(size) + ')'); + }; + + return trace; +} + +exports.name = 'trace'; +exports.factory = factory; diff --git a/lib/type/matrix/CcsMatrix.js b/lib/type/matrix/CcsMatrix.js index 6fad5ddd3..d9371ad56 100644 --- a/lib/type/matrix/CcsMatrix.js +++ b/lib/type/matrix/CcsMatrix.js @@ -14,9 +14,8 @@ var isInteger = util.number.isInteger; var validateIndex = array.validateIndex; -function factory (type, config, load, typed) { +function factory (type, config, load) { - var add = load(require('../../function/arithmetic/add')); var multiply = load(require('../../function/arithmetic/multiply')); var equal = load(require('../../function/relational/equal')); @@ -1092,57 +1091,6 @@ function factory (type, config, load, typed) { }); }; - /** - * Calculate the trace of a matrix: the sum of the elements on the main - * diagonal of a square matrix. - * - * See also: - * - * diagonal - * - * @returns {Number} The matrix trace - */ - CcsMatrix.prototype.trace = function () { - // size - var size = this._size; - // check dimensions - var rows = size[0]; - var columns = size[1]; - // matrix must be square - if (rows === columns) { - // calulate sum - var sum = 0; - // check we have data (avoid looping columns) - if (this._values.length > 0) { - // loop columns - for (var j = 0; j < columns; j++) { - // k0 <= k < k1 where k0 = _ptr[j] && k1 = _ptr[j+1] - var k0 = this._ptr[j]; - var k1 = this._ptr[j + 1]; - // loop k within [k0, k1[ - for (var k = k0; k < k1; k++) { - // row index - var i = this._index[k]; - // check row - if (i === j) { - // accumulate value - sum = add(sum, this._values[k]); - // exit loop - break; - } - if (i > j) { - // exit loop, no value on the diagonal for column j - break; - } - } - } - } - // return trace - return sum; - } - throw new RangeError('Matrix must be square (size: ' + string.format(size) + ')'); - }; - /** * Multiply the matrix values times the argument. * diff --git a/lib/type/matrix/CrsMatrix.js b/lib/type/matrix/CrsMatrix.js index cb5104d50..254d0b274 100644 --- a/lib/type/matrix/CrsMatrix.js +++ b/lib/type/matrix/CrsMatrix.js @@ -14,7 +14,7 @@ var isInteger = util.number.isInteger; var validateIndex = array.validateIndex; -function factory (type, config, load, typed) { +function factory (type, config, load) { var add = load(require('../../function/arithmetic/add')); var multiply = load(require('../../function/arithmetic/multiply')); @@ -1085,57 +1085,6 @@ function factory (type, config, load, typed) { size: [rows, columns] }); }; - - /** - * Calculate the trace of a matrix: the sum of the elements on the main - * diagonal of a square matrix. - * - * See also: - * - * diagonal - * - * @returns {Number} The matrix trace - */ - CrsMatrix.prototype.trace = function () { - // size - var size = this._size; - // check dimensions - var rows = size[0]; - var columns = size[1]; - // matrix must be square - if (rows === columns) { - // calulate sum - var sum = 0; - // check we have data (avoid looping rows) - if (this._values.length > 0) { - // loop rows - for (var i = 0; i < rows; i++) { - // k0 <= k < k1 where k0 = _ptr[i] && k1 = _ptr[i+1] - var k0 = this._ptr[i]; - var k1 = this._ptr[i + 1]; - // loop k within [k0, k1[ - for (var k = k0; k < k1; k++) { - // column index - var j = this._index[k]; - // check row - if (i === j) { - // accumulate value - sum = add(sum, this._values[k]); - // exit loop - break; - } - if (j > i) { - // exit loop, no value on the diagonal for column j - break; - } - } - } - } - // return trace - return sum; - } - throw new RangeError('Matrix must be square (size: ' + string.format(size) + ')'); - }; /** * Multiply the matrix values times the argument. diff --git a/lib/type/matrix/DenseMatrix.js b/lib/type/matrix/DenseMatrix.js index de5838453..069dc1a48 100644 --- a/lib/type/matrix/DenseMatrix.js +++ b/lib/type/matrix/DenseMatrix.js @@ -13,13 +13,10 @@ var isInteger = util.number.isInteger; var validateIndex = array.validateIndex; -function factory (type, config, load, typed) { +function factory (type, config, load) { var add = load(require('../../function/arithmetic/add')); - var divideScalar = load(require('../../function/arithmetic/divideScalar')); var multiply = load(require('../../function/arithmetic/multiply')); - var subtract = load(require('../../function/arithmetic/subtract')); - var equal = load(require('../../function/relational/equal')); var Index = type.Index; var BigNumber = type.BigNumber; @@ -748,49 +745,6 @@ function factory (type, config, load, typed) { size: [rows, columns] }); }; - - /** - * Calculate the trace of a matrix: the sum of the elements on the main - * diagonal of a square matrix. - * - * See also: - * - * diagonal - * - * @returns {Number} The matrix trace - */ - DenseMatrix.prototype.trace = function () { - // size & data - var size = this._size; - var data = this._data; - // check dimensions - switch (size.length) { - case 1: - // vector - if (size[0] == 1) { - // return data[0] - return object.clone(data[0]); - } - throw new RangeError('Matrix must be square (size: ' + string.format(size) + ')'); - case 2: - // two dimensional array - var rows = size[0]; - var cols = size[1]; - if (rows === cols) { - // calulate sum - var sum = 0; - // loop diagonal - for (var i = 0; i < rows; i++) - sum = add(sum, data[i][i]); - // return trace - return sum; - } - throw new RangeError('Matrix must be square (size: ' + string.format(size) + ')'); - default: - // multi dimensional array - throw new RangeError('Matrix must be two dimensional (size: ' + string.format(size) + ')'); - } - }; /** * Generate a matrix from a JSON object diff --git a/test/function/matrix/trace.test.js b/test/function/matrix/trace.test.js index 83bd82315..395a24d37 100644 --- a/test/function/matrix/trace.test.js +++ b/test/function/matrix/trace.test.js @@ -210,4 +210,108 @@ describe('trace', function() { assert.equal(expression.toTex(), '\\mathrm{tr}\\left({\\begin{bmatrix}1&2\\\\3&4\\\\\\end{bmatrix}}\\right)'); }); + describe('DenseMatrix', function () { + + it('should calculate trace on a square matrix', function() { + var m = math.matrix([ + [1, 2], + [4, -2] + ]); + assert.equal(math.trace(m), -1); + + m = math.matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0] + ]); + assert.equal(math.trace(m), 0); + + m = math.matrix([ + [1, 0, 0, 0], + [0, 0, 2, 0], + [1, 0, 0, 0], + [0, 0, 1, 9] + ]); + assert.equal(math.trace(m), 10); + }); + + it('should throw an error for invalid matrix', function() { + var m = math.matrix([ + [1, 2, 3], + [4, 5, 6] + ]); + assert.throws(function () { math.trace(m); }); + }); + }); + + describe('CcsMatrix', function () { + + it('should calculate trace on a square matrix', function() { + var m = math.matrix([ + [1, 2], + [4, -2] + ], 'ccs'); + assert.equal(math.trace(m), -1); + + m = math.matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0] + ], 'ccs'); + assert.equal(math.trace(m), 0); + + m = math.matrix([ + [1, 0, 0, 0], + [0, 0, 2, 0], + [1, 0, 0, 0], + [0, 0, 1, 9] + ], 'ccs'); + assert.equal(math.trace(m), 10); + }); + + it('should throw an error for invalid matrix', function() { + var m = math.matrix([ + [1, 2, 3], + [4, 5, 6] + ], 'ccs'); + assert.throws(function () { math.trace(m); }); + }); + }); + + describe('CrsMatrix', function () { + + it('should calculate trace on a square matrix', function() { + var m = math.matrix([ + [1, 2], + [4, -2] + ], 'crs'); + assert.equal(math.trace(m), -1); + + m = math.matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0] + ], 'crs'); + assert.equal(math.trace(m), 0); + + m = math.matrix([ + [1, 0, 0, 0], + [0, 0, 2, 0], + [1, 0, 0, 0], + [0, 0, 1, 9] + ], 'crs'); + assert.equal(math.trace(m), 10); + }); + + it('should throw an error for invalid matrix', function() { + var m = math.matrix([ + [1, 2, 3], + [4, 5, 6] + ], 'crs'); + assert.throws(function () { math.trace(m); }); + }); + }); }); diff --git a/test/type/matrix/CcsMatrix.test.js b/test/type/matrix/CcsMatrix.test.js index c92c3e363..d5ac3c50d 100644 --- a/test/type/matrix/CcsMatrix.test.js +++ b/test/type/matrix/CcsMatrix.test.js @@ -1758,41 +1758,6 @@ describe('CcsMatrix', function() { }); }); - describe('trace', function () { - - it('should calculate trace on a square matrix', function() { - var m = new CcsMatrix([ - [1, 2], - [4, -2] - ]); - assert.equal(m.trace(), -1); - - m = new CcsMatrix([ - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0] - ]); - assert.equal(m.trace(), 0); - - m = new CcsMatrix([ - [1, 0, 0, 0], - [0, 0, 2, 0], - [1, 0, 0, 0], - [0, 0, 1, 9] - ]); - assert.equal(m.trace(), 10); - }); - - it('should throw an error for invalid matrix', function() { - var m = new CcsMatrix([ - [1, 2, 3], - [4, 5, 6] - ]); - assert.throws(function () { m.trace(); }); - }); - }); - describe('multiply', function () { it('should multiply matrix x scalar', function() { diff --git a/test/type/matrix/CrsMatrix.test.js b/test/type/matrix/CrsMatrix.test.js index 866a5b824..dde8d17e7 100644 --- a/test/type/matrix/CrsMatrix.test.js +++ b/test/type/matrix/CrsMatrix.test.js @@ -1740,41 +1740,6 @@ describe('CrsMatrix', function() { }); }); - describe('trace', function () { - - it('should calculate trace on a square matrix', function() { - var m = new CrsMatrix([ - [1, 2], - [4, -2] - ]); - assert.equal(m.trace(), -1); - - m = new CrsMatrix([ - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0] - ]); - assert.equal(m.trace(), 0); - - m = new CrsMatrix([ - [1, 0, 0, 0], - [0, 0, 2, 0], - [1, 0, 0, 0], - [0, 0, 1, 9] - ]); - assert.equal(m.trace(), 10); - }); - - it('should throw an error for invalid matrix', function() { - var m = new CrsMatrix([ - [1, 2, 3], - [4, 5, 6] - ]); - assert.throws(function () { m.trace(); }); - }); - }); - describe('multiply', function () { it('should multiply matrix x scalar', function() { diff --git a/test/type/matrix/DenseMatrix.test.js b/test/type/matrix/DenseMatrix.test.js index e10b7b8cd..b9f966b6f 100644 --- a/test/type/matrix/DenseMatrix.test.js +++ b/test/type/matrix/DenseMatrix.test.js @@ -1086,41 +1086,6 @@ describe('DenseMatrix', function() { }); }); - describe('trace', function () { - - it('should calculate trace on a square matrix', function() { - var m = new DenseMatrix([ - [1, 2], - [4, -2] - ]); - assert.equal(m.trace(), -1); - - m = new DenseMatrix([ - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0] - ]); - assert.equal(m.trace(), 0); - - m = new DenseMatrix([ - [1, 0, 0, 0], - [0, 0, 2, 0], - [1, 0, 0, 0], - [0, 0, 1, 9] - ]); - assert.equal(m.trace(), 10); - }); - - it('should throw an error for invalid matrix', function() { - var m = new DenseMatrix([ - [1, 2, 3], - [4, 5, 6] - ]); - assert.throws(function () { m.trace(); }); - }); - }); - describe('multiply', function () { it('should multiply matrix x scalar', function() {