mirror of
https://github.com/josdejong/mathjs.git
synced 2026-01-18 14:59:29 +00:00
trace() refactoring
This commit is contained in:
parent
c5007d4bf2
commit
ae0b2cf4e9
2
index.js
2
index.js
@ -142,7 +142,7 @@ function create (config) {
|
||||
math.import(require('./lib/function/matrix/size'));
|
||||
math.import(require('./lib/function/matrix/squeeze'));
|
||||
require('./lib/function/matrix/subset')(math, _config);
|
||||
require('./lib/function/matrix/trace')(math, _config);
|
||||
math.import(require('./lib/function/matrix/trace'));
|
||||
math.import(require('./lib/function/matrix/transpose'));
|
||||
require('./lib/function/matrix/zeros')(math, _config);
|
||||
|
||||
|
||||
@ -1,13 +1,12 @@
|
||||
'use strict';
|
||||
|
||||
module.exports = function (math) {
|
||||
var util = require('../../util/index'),
|
||||
var clone = require('../../util/object').clone;
|
||||
var format = require('../../util/string').format;
|
||||
|
||||
Matrix = math.type.Matrix,
|
||||
|
||||
object = util.object,
|
||||
array = util.array,
|
||||
string = util.string;
|
||||
function factory (type, config, load, typed) {
|
||||
|
||||
var matrix = load(require('../construction/matrix'));
|
||||
var add = load(require('../arithmetic/add'));
|
||||
|
||||
/**
|
||||
* Calculate the trace of a matrix: the sum of the elements on the main
|
||||
@ -36,61 +35,160 @@ module.exports = function (math) {
|
||||
*
|
||||
* @return {Number} The trace of `x`
|
||||
*/
|
||||
math.trace = function trace (x) {
|
||||
if (arguments.length != 1) {
|
||||
throw new math.error.ArgumentsError('trace', arguments.length, 1);
|
||||
}
|
||||
|
||||
// check x is a matrix
|
||||
if (x instanceof Matrix) {
|
||||
// use optimized operation for the matrix storage format
|
||||
return x.trace();
|
||||
}
|
||||
var trace = typed('trace', {
|
||||
|
||||
// size
|
||||
var size;
|
||||
if (x instanceof Array) {
|
||||
// calculate sixe
|
||||
size = array.size(x);
|
||||
}
|
||||
else {
|
||||
// a scalar
|
||||
size = [];
|
||||
}
|
||||
'Array': function (x) {
|
||||
// use dense matrix implementation
|
||||
return trace(matrix(x));
|
||||
},
|
||||
|
||||
'Matrix': function (x) {
|
||||
// result
|
||||
var c;
|
||||
// process storage format
|
||||
switch (x.storage()) {
|
||||
case 'dense':
|
||||
c = _denseTrace(x);
|
||||
break;
|
||||
case 'ccs':
|
||||
c = _ccsTrace(x);
|
||||
break;
|
||||
case 'crs':
|
||||
c = _crsTrace(x);
|
||||
break;
|
||||
}
|
||||
return c;
|
||||
},
|
||||
|
||||
'any': function (x) {
|
||||
return clone(x);
|
||||
}
|
||||
});
|
||||
|
||||
var _denseTrace = function (m) {
|
||||
// matrix size & data
|
||||
var size = m._size;
|
||||
var data = m._data;
|
||||
|
||||
// process dimensions
|
||||
switch (size.length) {
|
||||
case 0:
|
||||
// scalar
|
||||
return object.clone(x);
|
||||
|
||||
case 1:
|
||||
// vector
|
||||
if (size[0] == 1) {
|
||||
// clone value
|
||||
return object.clone(x[0]);
|
||||
// return data[0]
|
||||
return clone(data[0]);
|
||||
}
|
||||
throw new RangeError('Array must be square (size: ' + string.format(size) + ')');
|
||||
|
||||
throw new RangeError('Matrix must be square (size: ' + format(size) + ')');
|
||||
case 2:
|
||||
// two dimensional array
|
||||
// two dimensional
|
||||
var rows = size[0];
|
||||
var cols = size[1];
|
||||
// check array is square
|
||||
if (rows == cols) {
|
||||
// diagonal sum
|
||||
if (rows === cols) {
|
||||
// calulate sum
|
||||
var sum = 0;
|
||||
// loop diagonal
|
||||
for (var i = 0; i < x.length; i++) {
|
||||
// sum
|
||||
sum = math.add(sum, x[i][i]);
|
||||
}
|
||||
for (var i = 0; i < rows; i++)
|
||||
sum = add(sum, data[i][i]);
|
||||
// return trace
|
||||
return sum;
|
||||
}
|
||||
throw new RangeError('Array must be square (size: ' + string.format(size) + ')');
|
||||
|
||||
throw new RangeError('Matrix must be square (size: ' + format(size) + ')');
|
||||
default:
|
||||
// multi dimensional array
|
||||
throw new RangeError('Matrix must be two dimensional (size: ' + string.format(size) + ')');
|
||||
// multi dimensional
|
||||
throw new RangeError('Matrix must be two dimensional (size: ' + format(size) + ')');
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
var _ccsTrace = function (m) {
|
||||
// matrix arrays
|
||||
var values = m._values;
|
||||
var index = m._index;
|
||||
var ptr = m._ptr;
|
||||
var size = m._size;
|
||||
// check dimensions
|
||||
var rows = size[0];
|
||||
var columns = size[1];
|
||||
// matrix must be square
|
||||
if (rows === columns) {
|
||||
// calulate sum
|
||||
var sum = 0;
|
||||
// check we have data (avoid looping columns)
|
||||
if (values.length > 0) {
|
||||
// loop columns
|
||||
for (var j = 0; j < columns; j++) {
|
||||
// k0 <= k < k1 where k0 = _ptr[j] && k1 = _ptr[j+1]
|
||||
var k0 = ptr[j];
|
||||
var k1 = ptr[j + 1];
|
||||
// loop k within [k0, k1[
|
||||
for (var k = k0; k < k1; k++) {
|
||||
// row index
|
||||
var i = index[k];
|
||||
// check row
|
||||
if (i === j) {
|
||||
// accumulate value
|
||||
sum = add(sum, values[k]);
|
||||
// exit loop
|
||||
break;
|
||||
}
|
||||
if (i > j) {
|
||||
// exit loop, no value on the diagonal for column j
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// return trace
|
||||
return sum;
|
||||
}
|
||||
throw new RangeError('Matrix must be square (size: ' + format(size) + ')');
|
||||
};
|
||||
|
||||
var _crsTrace = function (m) {
|
||||
// matrix arrays
|
||||
var values = m._values;
|
||||
var index = m._index;
|
||||
var ptr = m._ptr;
|
||||
var size = m._size;
|
||||
// check dimensions
|
||||
var rows = size[0];
|
||||
var columns = size[1];
|
||||
// matrix must be square
|
||||
if (rows === columns) {
|
||||
// calulate sum
|
||||
var sum = 0;
|
||||
// check we have data (avoid looping rows)
|
||||
if (values.length > 0) {
|
||||
// loop rows
|
||||
for (var i = 0; i < rows; i++) {
|
||||
// k0 <= k < k1 where k0 = _ptr[i] && k1 = _ptr[i+1]
|
||||
var k0 = ptr[i];
|
||||
var k1 = ptr[i + 1];
|
||||
// loop k within [k0, k1[
|
||||
for (var k = k0; k < k1; k++) {
|
||||
// column index
|
||||
var j = index[k];
|
||||
// check row
|
||||
if (i === j) {
|
||||
// accumulate value
|
||||
sum = add(sum, values[k]);
|
||||
// exit loop
|
||||
break;
|
||||
}
|
||||
if (j > i) {
|
||||
// exit loop, no value on the diagonal for column j
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// return trace
|
||||
return sum;
|
||||
}
|
||||
throw new RangeError('Matrix must be square (size: ' + format(size) + ')');
|
||||
};
|
||||
|
||||
return trace;
|
||||
}
|
||||
|
||||
exports.name = 'trace';
|
||||
exports.factory = factory;
|
||||
|
||||
@ -14,9 +14,8 @@ var isInteger = util.number.isInteger;
|
||||
|
||||
var validateIndex = array.validateIndex;
|
||||
|
||||
function factory (type, config, load, typed) {
|
||||
function factory (type, config, load) {
|
||||
|
||||
var add = load(require('../../function/arithmetic/add'));
|
||||
var multiply = load(require('../../function/arithmetic/multiply'));
|
||||
var equal = load(require('../../function/relational/equal'));
|
||||
|
||||
@ -1092,57 +1091,6 @@ function factory (type, config, load, typed) {
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Calculate the trace of a matrix: the sum of the elements on the main
|
||||
* diagonal of a square matrix.
|
||||
*
|
||||
* See also:
|
||||
*
|
||||
* diagonal
|
||||
*
|
||||
* @returns {Number} The matrix trace
|
||||
*/
|
||||
CcsMatrix.prototype.trace = function () {
|
||||
// size
|
||||
var size = this._size;
|
||||
// check dimensions
|
||||
var rows = size[0];
|
||||
var columns = size[1];
|
||||
// matrix must be square
|
||||
if (rows === columns) {
|
||||
// calulate sum
|
||||
var sum = 0;
|
||||
// check we have data (avoid looping columns)
|
||||
if (this._values.length > 0) {
|
||||
// loop columns
|
||||
for (var j = 0; j < columns; j++) {
|
||||
// k0 <= k < k1 where k0 = _ptr[j] && k1 = _ptr[j+1]
|
||||
var k0 = this._ptr[j];
|
||||
var k1 = this._ptr[j + 1];
|
||||
// loop k within [k0, k1[
|
||||
for (var k = k0; k < k1; k++) {
|
||||
// row index
|
||||
var i = this._index[k];
|
||||
// check row
|
||||
if (i === j) {
|
||||
// accumulate value
|
||||
sum = add(sum, this._values[k]);
|
||||
// exit loop
|
||||
break;
|
||||
}
|
||||
if (i > j) {
|
||||
// exit loop, no value on the diagonal for column j
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// return trace
|
||||
return sum;
|
||||
}
|
||||
throw new RangeError('Matrix must be square (size: ' + string.format(size) + ')');
|
||||
};
|
||||
|
||||
/**
|
||||
* Multiply the matrix values times the argument.
|
||||
*
|
||||
|
||||
@ -14,7 +14,7 @@ var isInteger = util.number.isInteger;
|
||||
|
||||
var validateIndex = array.validateIndex;
|
||||
|
||||
function factory (type, config, load, typed) {
|
||||
function factory (type, config, load) {
|
||||
|
||||
var add = load(require('../../function/arithmetic/add'));
|
||||
var multiply = load(require('../../function/arithmetic/multiply'));
|
||||
@ -1085,57 +1085,6 @@ function factory (type, config, load, typed) {
|
||||
size: [rows, columns]
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Calculate the trace of a matrix: the sum of the elements on the main
|
||||
* diagonal of a square matrix.
|
||||
*
|
||||
* See also:
|
||||
*
|
||||
* diagonal
|
||||
*
|
||||
* @returns {Number} The matrix trace
|
||||
*/
|
||||
CrsMatrix.prototype.trace = function () {
|
||||
// size
|
||||
var size = this._size;
|
||||
// check dimensions
|
||||
var rows = size[0];
|
||||
var columns = size[1];
|
||||
// matrix must be square
|
||||
if (rows === columns) {
|
||||
// calulate sum
|
||||
var sum = 0;
|
||||
// check we have data (avoid looping rows)
|
||||
if (this._values.length > 0) {
|
||||
// loop rows
|
||||
for (var i = 0; i < rows; i++) {
|
||||
// k0 <= k < k1 where k0 = _ptr[i] && k1 = _ptr[i+1]
|
||||
var k0 = this._ptr[i];
|
||||
var k1 = this._ptr[i + 1];
|
||||
// loop k within [k0, k1[
|
||||
for (var k = k0; k < k1; k++) {
|
||||
// column index
|
||||
var j = this._index[k];
|
||||
// check row
|
||||
if (i === j) {
|
||||
// accumulate value
|
||||
sum = add(sum, this._values[k]);
|
||||
// exit loop
|
||||
break;
|
||||
}
|
||||
if (j > i) {
|
||||
// exit loop, no value on the diagonal for column j
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// return trace
|
||||
return sum;
|
||||
}
|
||||
throw new RangeError('Matrix must be square (size: ' + string.format(size) + ')');
|
||||
};
|
||||
|
||||
/**
|
||||
* Multiply the matrix values times the argument.
|
||||
|
||||
@ -13,13 +13,10 @@ var isInteger = util.number.isInteger;
|
||||
|
||||
var validateIndex = array.validateIndex;
|
||||
|
||||
function factory (type, config, load, typed) {
|
||||
function factory (type, config, load) {
|
||||
|
||||
var add = load(require('../../function/arithmetic/add'));
|
||||
var divideScalar = load(require('../../function/arithmetic/divideScalar'));
|
||||
var multiply = load(require('../../function/arithmetic/multiply'));
|
||||
var subtract = load(require('../../function/arithmetic/subtract'));
|
||||
var equal = load(require('../../function/relational/equal'));
|
||||
|
||||
var Index = type.Index;
|
||||
var BigNumber = type.BigNumber;
|
||||
@ -748,49 +745,6 @@ function factory (type, config, load, typed) {
|
||||
size: [rows, columns]
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Calculate the trace of a matrix: the sum of the elements on the main
|
||||
* diagonal of a square matrix.
|
||||
*
|
||||
* See also:
|
||||
*
|
||||
* diagonal
|
||||
*
|
||||
* @returns {Number} The matrix trace
|
||||
*/
|
||||
DenseMatrix.prototype.trace = function () {
|
||||
// size & data
|
||||
var size = this._size;
|
||||
var data = this._data;
|
||||
// check dimensions
|
||||
switch (size.length) {
|
||||
case 1:
|
||||
// vector
|
||||
if (size[0] == 1) {
|
||||
// return data[0]
|
||||
return object.clone(data[0]);
|
||||
}
|
||||
throw new RangeError('Matrix must be square (size: ' + string.format(size) + ')');
|
||||
case 2:
|
||||
// two dimensional array
|
||||
var rows = size[0];
|
||||
var cols = size[1];
|
||||
if (rows === cols) {
|
||||
// calulate sum
|
||||
var sum = 0;
|
||||
// loop diagonal
|
||||
for (var i = 0; i < rows; i++)
|
||||
sum = add(sum, data[i][i]);
|
||||
// return trace
|
||||
return sum;
|
||||
}
|
||||
throw new RangeError('Matrix must be square (size: ' + string.format(size) + ')');
|
||||
default:
|
||||
// multi dimensional array
|
||||
throw new RangeError('Matrix must be two dimensional (size: ' + string.format(size) + ')');
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Generate a matrix from a JSON object
|
||||
|
||||
@ -210,4 +210,108 @@ describe('trace', function() {
|
||||
assert.equal(expression.toTex(), '\\mathrm{tr}\\left({\\begin{bmatrix}1&2\\\\3&4\\\\\\end{bmatrix}}\\right)');
|
||||
});
|
||||
|
||||
describe('DenseMatrix', function () {
|
||||
|
||||
it('should calculate trace on a square matrix', function() {
|
||||
var m = math.matrix([
|
||||
[1, 2],
|
||||
[4, -2]
|
||||
]);
|
||||
assert.equal(math.trace(m), -1);
|
||||
|
||||
m = math.matrix([
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0]
|
||||
]);
|
||||
assert.equal(math.trace(m), 0);
|
||||
|
||||
m = math.matrix([
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 2, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 9]
|
||||
]);
|
||||
assert.equal(math.trace(m), 10);
|
||||
});
|
||||
|
||||
it('should throw an error for invalid matrix', function() {
|
||||
var m = math.matrix([
|
||||
[1, 2, 3],
|
||||
[4, 5, 6]
|
||||
]);
|
||||
assert.throws(function () { math.trace(m); });
|
||||
});
|
||||
});
|
||||
|
||||
describe('CcsMatrix', function () {
|
||||
|
||||
it('should calculate trace on a square matrix', function() {
|
||||
var m = math.matrix([
|
||||
[1, 2],
|
||||
[4, -2]
|
||||
], 'ccs');
|
||||
assert.equal(math.trace(m), -1);
|
||||
|
||||
m = math.matrix([
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0]
|
||||
], 'ccs');
|
||||
assert.equal(math.trace(m), 0);
|
||||
|
||||
m = math.matrix([
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 2, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 9]
|
||||
], 'ccs');
|
||||
assert.equal(math.trace(m), 10);
|
||||
});
|
||||
|
||||
it('should throw an error for invalid matrix', function() {
|
||||
var m = math.matrix([
|
||||
[1, 2, 3],
|
||||
[4, 5, 6]
|
||||
], 'ccs');
|
||||
assert.throws(function () { math.trace(m); });
|
||||
});
|
||||
});
|
||||
|
||||
describe('CrsMatrix', function () {
|
||||
|
||||
it('should calculate trace on a square matrix', function() {
|
||||
var m = math.matrix([
|
||||
[1, 2],
|
||||
[4, -2]
|
||||
], 'crs');
|
||||
assert.equal(math.trace(m), -1);
|
||||
|
||||
m = math.matrix([
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0]
|
||||
], 'crs');
|
||||
assert.equal(math.trace(m), 0);
|
||||
|
||||
m = math.matrix([
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 2, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 9]
|
||||
], 'crs');
|
||||
assert.equal(math.trace(m), 10);
|
||||
});
|
||||
|
||||
it('should throw an error for invalid matrix', function() {
|
||||
var m = math.matrix([
|
||||
[1, 2, 3],
|
||||
[4, 5, 6]
|
||||
], 'crs');
|
||||
assert.throws(function () { math.trace(m); });
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@ -1758,41 +1758,6 @@ describe('CcsMatrix', function() {
|
||||
});
|
||||
});
|
||||
|
||||
describe('trace', function () {
|
||||
|
||||
it('should calculate trace on a square matrix', function() {
|
||||
var m = new CcsMatrix([
|
||||
[1, 2],
|
||||
[4, -2]
|
||||
]);
|
||||
assert.equal(m.trace(), -1);
|
||||
|
||||
m = new CcsMatrix([
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0]
|
||||
]);
|
||||
assert.equal(m.trace(), 0);
|
||||
|
||||
m = new CcsMatrix([
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 2, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 9]
|
||||
]);
|
||||
assert.equal(m.trace(), 10);
|
||||
});
|
||||
|
||||
it('should throw an error for invalid matrix', function() {
|
||||
var m = new CcsMatrix([
|
||||
[1, 2, 3],
|
||||
[4, 5, 6]
|
||||
]);
|
||||
assert.throws(function () { m.trace(); });
|
||||
});
|
||||
});
|
||||
|
||||
describe('multiply', function () {
|
||||
|
||||
it('should multiply matrix x scalar', function() {
|
||||
|
||||
@ -1740,41 +1740,6 @@ describe('CrsMatrix', function() {
|
||||
});
|
||||
});
|
||||
|
||||
describe('trace', function () {
|
||||
|
||||
it('should calculate trace on a square matrix', function() {
|
||||
var m = new CrsMatrix([
|
||||
[1, 2],
|
||||
[4, -2]
|
||||
]);
|
||||
assert.equal(m.trace(), -1);
|
||||
|
||||
m = new CrsMatrix([
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0]
|
||||
]);
|
||||
assert.equal(m.trace(), 0);
|
||||
|
||||
m = new CrsMatrix([
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 2, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 9]
|
||||
]);
|
||||
assert.equal(m.trace(), 10);
|
||||
});
|
||||
|
||||
it('should throw an error for invalid matrix', function() {
|
||||
var m = new CrsMatrix([
|
||||
[1, 2, 3],
|
||||
[4, 5, 6]
|
||||
]);
|
||||
assert.throws(function () { m.trace(); });
|
||||
});
|
||||
});
|
||||
|
||||
describe('multiply', function () {
|
||||
|
||||
it('should multiply matrix x scalar', function() {
|
||||
|
||||
@ -1086,41 +1086,6 @@ describe('DenseMatrix', function() {
|
||||
});
|
||||
});
|
||||
|
||||
describe('trace', function () {
|
||||
|
||||
it('should calculate trace on a square matrix', function() {
|
||||
var m = new DenseMatrix([
|
||||
[1, 2],
|
||||
[4, -2]
|
||||
]);
|
||||
assert.equal(m.trace(), -1);
|
||||
|
||||
m = new DenseMatrix([
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0],
|
||||
[0, 0, 0, 0]
|
||||
]);
|
||||
assert.equal(m.trace(), 0);
|
||||
|
||||
m = new DenseMatrix([
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 2, 0],
|
||||
[1, 0, 0, 0],
|
||||
[0, 0, 1, 9]
|
||||
]);
|
||||
assert.equal(m.trace(), 10);
|
||||
});
|
||||
|
||||
it('should throw an error for invalid matrix', function() {
|
||||
var m = new DenseMatrix([
|
||||
[1, 2, 3],
|
||||
[4, 5, 6]
|
||||
]);
|
||||
assert.throws(function () { m.trace(); });
|
||||
});
|
||||
});
|
||||
|
||||
describe('multiply', function () {
|
||||
|
||||
it('should multiply matrix x scalar', function() {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user