trace() refactoring

This commit is contained in:
Rogelio J. Baucells 2015-04-22 12:41:02 -04:00
parent c5007d4bf2
commit ae0b2cf4e9
9 changed files with 254 additions and 306 deletions

View File

@ -142,7 +142,7 @@ function create (config) {
math.import(require('./lib/function/matrix/size'));
math.import(require('./lib/function/matrix/squeeze'));
require('./lib/function/matrix/subset')(math, _config);
require('./lib/function/matrix/trace')(math, _config);
math.import(require('./lib/function/matrix/trace'));
math.import(require('./lib/function/matrix/transpose'));
require('./lib/function/matrix/zeros')(math, _config);

View File

@ -1,13 +1,12 @@
'use strict';
module.exports = function (math) {
var util = require('../../util/index'),
var clone = require('../../util/object').clone;
var format = require('../../util/string').format;
Matrix = math.type.Matrix,
object = util.object,
array = util.array,
string = util.string;
function factory (type, config, load, typed) {
var matrix = load(require('../construction/matrix'));
var add = load(require('../arithmetic/add'));
/**
* Calculate the trace of a matrix: the sum of the elements on the main
@ -36,61 +35,160 @@ module.exports = function (math) {
*
* @return {Number} The trace of `x`
*/
math.trace = function trace (x) {
if (arguments.length != 1) {
throw new math.error.ArgumentsError('trace', arguments.length, 1);
}
// check x is a matrix
if (x instanceof Matrix) {
// use optimized operation for the matrix storage format
return x.trace();
}
var trace = typed('trace', {
// size
var size;
if (x instanceof Array) {
// calculate sixe
size = array.size(x);
}
else {
// a scalar
size = [];
}
'Array': function (x) {
// use dense matrix implementation
return trace(matrix(x));
},
'Matrix': function (x) {
// result
var c;
// process storage format
switch (x.storage()) {
case 'dense':
c = _denseTrace(x);
break;
case 'ccs':
c = _ccsTrace(x);
break;
case 'crs':
c = _crsTrace(x);
break;
}
return c;
},
'any': function (x) {
return clone(x);
}
});
var _denseTrace = function (m) {
// matrix size & data
var size = m._size;
var data = m._data;
// process dimensions
switch (size.length) {
case 0:
// scalar
return object.clone(x);
case 1:
// vector
if (size[0] == 1) {
// clone value
return object.clone(x[0]);
// return data[0]
return clone(data[0]);
}
throw new RangeError('Array must be square (size: ' + string.format(size) + ')');
throw new RangeError('Matrix must be square (size: ' + format(size) + ')');
case 2:
// two dimensional array
// two dimensional
var rows = size[0];
var cols = size[1];
// check array is square
if (rows == cols) {
// diagonal sum
if (rows === cols) {
// calulate sum
var sum = 0;
// loop diagonal
for (var i = 0; i < x.length; i++) {
// sum
sum = math.add(sum, x[i][i]);
}
for (var i = 0; i < rows; i++)
sum = add(sum, data[i][i]);
// return trace
return sum;
}
throw new RangeError('Array must be square (size: ' + string.format(size) + ')');
throw new RangeError('Matrix must be square (size: ' + format(size) + ')');
default:
// multi dimensional array
throw new RangeError('Matrix must be two dimensional (size: ' + string.format(size) + ')');
// multi dimensional
throw new RangeError('Matrix must be two dimensional (size: ' + format(size) + ')');
}
};
};
var _ccsTrace = function (m) {
// matrix arrays
var values = m._values;
var index = m._index;
var ptr = m._ptr;
var size = m._size;
// check dimensions
var rows = size[0];
var columns = size[1];
// matrix must be square
if (rows === columns) {
// calulate sum
var sum = 0;
// check we have data (avoid looping columns)
if (values.length > 0) {
// loop columns
for (var j = 0; j < columns; j++) {
// k0 <= k < k1 where k0 = _ptr[j] && k1 = _ptr[j+1]
var k0 = ptr[j];
var k1 = ptr[j + 1];
// loop k within [k0, k1[
for (var k = k0; k < k1; k++) {
// row index
var i = index[k];
// check row
if (i === j) {
// accumulate value
sum = add(sum, values[k]);
// exit loop
break;
}
if (i > j) {
// exit loop, no value on the diagonal for column j
break;
}
}
}
}
// return trace
return sum;
}
throw new RangeError('Matrix must be square (size: ' + format(size) + ')');
};
var _crsTrace = function (m) {
// matrix arrays
var values = m._values;
var index = m._index;
var ptr = m._ptr;
var size = m._size;
// check dimensions
var rows = size[0];
var columns = size[1];
// matrix must be square
if (rows === columns) {
// calulate sum
var sum = 0;
// check we have data (avoid looping rows)
if (values.length > 0) {
// loop rows
for (var i = 0; i < rows; i++) {
// k0 <= k < k1 where k0 = _ptr[i] && k1 = _ptr[i+1]
var k0 = ptr[i];
var k1 = ptr[i + 1];
// loop k within [k0, k1[
for (var k = k0; k < k1; k++) {
// column index
var j = index[k];
// check row
if (i === j) {
// accumulate value
sum = add(sum, values[k]);
// exit loop
break;
}
if (j > i) {
// exit loop, no value on the diagonal for column j
break;
}
}
}
}
// return trace
return sum;
}
throw new RangeError('Matrix must be square (size: ' + format(size) + ')');
};
return trace;
}
exports.name = 'trace';
exports.factory = factory;

View File

@ -14,9 +14,8 @@ var isInteger = util.number.isInteger;
var validateIndex = array.validateIndex;
function factory (type, config, load, typed) {
function factory (type, config, load) {
var add = load(require('../../function/arithmetic/add'));
var multiply = load(require('../../function/arithmetic/multiply'));
var equal = load(require('../../function/relational/equal'));
@ -1092,57 +1091,6 @@ function factory (type, config, load, typed) {
});
};
/**
* Calculate the trace of a matrix: the sum of the elements on the main
* diagonal of a square matrix.
*
* See also:
*
* diagonal
*
* @returns {Number} The matrix trace
*/
CcsMatrix.prototype.trace = function () {
// size
var size = this._size;
// check dimensions
var rows = size[0];
var columns = size[1];
// matrix must be square
if (rows === columns) {
// calulate sum
var sum = 0;
// check we have data (avoid looping columns)
if (this._values.length > 0) {
// loop columns
for (var j = 0; j < columns; j++) {
// k0 <= k < k1 where k0 = _ptr[j] && k1 = _ptr[j+1]
var k0 = this._ptr[j];
var k1 = this._ptr[j + 1];
// loop k within [k0, k1[
for (var k = k0; k < k1; k++) {
// row index
var i = this._index[k];
// check row
if (i === j) {
// accumulate value
sum = add(sum, this._values[k]);
// exit loop
break;
}
if (i > j) {
// exit loop, no value on the diagonal for column j
break;
}
}
}
}
// return trace
return sum;
}
throw new RangeError('Matrix must be square (size: ' + string.format(size) + ')');
};
/**
* Multiply the matrix values times the argument.
*

View File

@ -14,7 +14,7 @@ var isInteger = util.number.isInteger;
var validateIndex = array.validateIndex;
function factory (type, config, load, typed) {
function factory (type, config, load) {
var add = load(require('../../function/arithmetic/add'));
var multiply = load(require('../../function/arithmetic/multiply'));
@ -1085,57 +1085,6 @@ function factory (type, config, load, typed) {
size: [rows, columns]
});
};
/**
* Calculate the trace of a matrix: the sum of the elements on the main
* diagonal of a square matrix.
*
* See also:
*
* diagonal
*
* @returns {Number} The matrix trace
*/
CrsMatrix.prototype.trace = function () {
// size
var size = this._size;
// check dimensions
var rows = size[0];
var columns = size[1];
// matrix must be square
if (rows === columns) {
// calulate sum
var sum = 0;
// check we have data (avoid looping rows)
if (this._values.length > 0) {
// loop rows
for (var i = 0; i < rows; i++) {
// k0 <= k < k1 where k0 = _ptr[i] && k1 = _ptr[i+1]
var k0 = this._ptr[i];
var k1 = this._ptr[i + 1];
// loop k within [k0, k1[
for (var k = k0; k < k1; k++) {
// column index
var j = this._index[k];
// check row
if (i === j) {
// accumulate value
sum = add(sum, this._values[k]);
// exit loop
break;
}
if (j > i) {
// exit loop, no value on the diagonal for column j
break;
}
}
}
}
// return trace
return sum;
}
throw new RangeError('Matrix must be square (size: ' + string.format(size) + ')');
};
/**
* Multiply the matrix values times the argument.

View File

@ -13,13 +13,10 @@ var isInteger = util.number.isInteger;
var validateIndex = array.validateIndex;
function factory (type, config, load, typed) {
function factory (type, config, load) {
var add = load(require('../../function/arithmetic/add'));
var divideScalar = load(require('../../function/arithmetic/divideScalar'));
var multiply = load(require('../../function/arithmetic/multiply'));
var subtract = load(require('../../function/arithmetic/subtract'));
var equal = load(require('../../function/relational/equal'));
var Index = type.Index;
var BigNumber = type.BigNumber;
@ -748,49 +745,6 @@ function factory (type, config, load, typed) {
size: [rows, columns]
});
};
/**
* Calculate the trace of a matrix: the sum of the elements on the main
* diagonal of a square matrix.
*
* See also:
*
* diagonal
*
* @returns {Number} The matrix trace
*/
DenseMatrix.prototype.trace = function () {
// size & data
var size = this._size;
var data = this._data;
// check dimensions
switch (size.length) {
case 1:
// vector
if (size[0] == 1) {
// return data[0]
return object.clone(data[0]);
}
throw new RangeError('Matrix must be square (size: ' + string.format(size) + ')');
case 2:
// two dimensional array
var rows = size[0];
var cols = size[1];
if (rows === cols) {
// calulate sum
var sum = 0;
// loop diagonal
for (var i = 0; i < rows; i++)
sum = add(sum, data[i][i]);
// return trace
return sum;
}
throw new RangeError('Matrix must be square (size: ' + string.format(size) + ')');
default:
// multi dimensional array
throw new RangeError('Matrix must be two dimensional (size: ' + string.format(size) + ')');
}
};
/**
* Generate a matrix from a JSON object

View File

@ -210,4 +210,108 @@ describe('trace', function() {
assert.equal(expression.toTex(), '\\mathrm{tr}\\left({\\begin{bmatrix}1&2\\\\3&4\\\\\\end{bmatrix}}\\right)');
});
describe('DenseMatrix', function () {
it('should calculate trace on a square matrix', function() {
var m = math.matrix([
[1, 2],
[4, -2]
]);
assert.equal(math.trace(m), -1);
m = math.matrix([
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]
]);
assert.equal(math.trace(m), 0);
m = math.matrix([
[1, 0, 0, 0],
[0, 0, 2, 0],
[1, 0, 0, 0],
[0, 0, 1, 9]
]);
assert.equal(math.trace(m), 10);
});
it('should throw an error for invalid matrix', function() {
var m = math.matrix([
[1, 2, 3],
[4, 5, 6]
]);
assert.throws(function () { math.trace(m); });
});
});
describe('CcsMatrix', function () {
it('should calculate trace on a square matrix', function() {
var m = math.matrix([
[1, 2],
[4, -2]
], 'ccs');
assert.equal(math.trace(m), -1);
m = math.matrix([
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]
], 'ccs');
assert.equal(math.trace(m), 0);
m = math.matrix([
[1, 0, 0, 0],
[0, 0, 2, 0],
[1, 0, 0, 0],
[0, 0, 1, 9]
], 'ccs');
assert.equal(math.trace(m), 10);
});
it('should throw an error for invalid matrix', function() {
var m = math.matrix([
[1, 2, 3],
[4, 5, 6]
], 'ccs');
assert.throws(function () { math.trace(m); });
});
});
describe('CrsMatrix', function () {
it('should calculate trace on a square matrix', function() {
var m = math.matrix([
[1, 2],
[4, -2]
], 'crs');
assert.equal(math.trace(m), -1);
m = math.matrix([
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]
], 'crs');
assert.equal(math.trace(m), 0);
m = math.matrix([
[1, 0, 0, 0],
[0, 0, 2, 0],
[1, 0, 0, 0],
[0, 0, 1, 9]
], 'crs');
assert.equal(math.trace(m), 10);
});
it('should throw an error for invalid matrix', function() {
var m = math.matrix([
[1, 2, 3],
[4, 5, 6]
], 'crs');
assert.throws(function () { math.trace(m); });
});
});
});

View File

@ -1758,41 +1758,6 @@ describe('CcsMatrix', function() {
});
});
describe('trace', function () {
it('should calculate trace on a square matrix', function() {
var m = new CcsMatrix([
[1, 2],
[4, -2]
]);
assert.equal(m.trace(), -1);
m = new CcsMatrix([
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]
]);
assert.equal(m.trace(), 0);
m = new CcsMatrix([
[1, 0, 0, 0],
[0, 0, 2, 0],
[1, 0, 0, 0],
[0, 0, 1, 9]
]);
assert.equal(m.trace(), 10);
});
it('should throw an error for invalid matrix', function() {
var m = new CcsMatrix([
[1, 2, 3],
[4, 5, 6]
]);
assert.throws(function () { m.trace(); });
});
});
describe('multiply', function () {
it('should multiply matrix x scalar', function() {

View File

@ -1740,41 +1740,6 @@ describe('CrsMatrix', function() {
});
});
describe('trace', function () {
it('should calculate trace on a square matrix', function() {
var m = new CrsMatrix([
[1, 2],
[4, -2]
]);
assert.equal(m.trace(), -1);
m = new CrsMatrix([
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]
]);
assert.equal(m.trace(), 0);
m = new CrsMatrix([
[1, 0, 0, 0],
[0, 0, 2, 0],
[1, 0, 0, 0],
[0, 0, 1, 9]
]);
assert.equal(m.trace(), 10);
});
it('should throw an error for invalid matrix', function() {
var m = new CrsMatrix([
[1, 2, 3],
[4, 5, 6]
]);
assert.throws(function () { m.trace(); });
});
});
describe('multiply', function () {
it('should multiply matrix x scalar', function() {

View File

@ -1086,41 +1086,6 @@ describe('DenseMatrix', function() {
});
});
describe('trace', function () {
it('should calculate trace on a square matrix', function() {
var m = new DenseMatrix([
[1, 2],
[4, -2]
]);
assert.equal(m.trace(), -1);
m = new DenseMatrix([
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]
]);
assert.equal(m.trace(), 0);
m = new DenseMatrix([
[1, 0, 0, 0],
[0, 0, 2, 0],
[1, 0, 0, 0],
[0, 0, 1, 9]
]);
assert.equal(m.trace(), 10);
});
it('should throw an error for invalid matrix', function() {
var m = new DenseMatrix([
[1, 2, 3],
[4, 5, 6]
]);
assert.throws(function () { m.trace(); });
});
});
describe('multiply', function () {
it('should multiply matrix x scalar', function() {