diff --git a/src/backend/function-node-base.js b/src/backend/function-node-base.js index 4d9e38cf..59a2ca4a 100644 --- a/src/backend/function-node-base.js +++ b/src/backend/function-node-base.js @@ -250,7 +250,8 @@ module.exports = class BaseFunctionNode { getParamType(paramName) { const paramIndex = this.paramNames.indexOf(paramName); - if (this.parent === null) return null; + if (paramIndex === -1) return null; + if (!this.parent) return null; if (this.paramTypes[paramIndex]) return this.paramTypes[paramIndex]; const calledFunctionArguments = this.parent.calledFunctionsArguments[this.functionName]; for (let i = 0; i < calledFunctionArguments.length; i++) { @@ -262,6 +263,20 @@ module.exports = class BaseFunctionNode { return null; } + getUserParamName(paramName) { + const paramIndex = this.paramNames.indexOf(paramName); + if (paramIndex === -1) return null; + if (!this.parent) return null; + const calledFunctionArguments = this.parent.calledFunctionsArguments[this.functionName]; + for (let i = 0; i < calledFunctionArguments.length; i++) { + const calledFunctionArgument = calledFunctionArguments[i]; + if (calledFunctionArgument[paramIndex] !== null) { + return calledFunctionArgument[paramIndex].name; + } + } + return null; + } + generate(options) { throw new Error('generate not defined on BaseFunctionNode'); } diff --git a/src/backend/web-gl/function-node.js b/src/backend/web-gl/function-node.js index 7ebde3a1..b57d3c63 100644 --- a/src/backend/web-gl/function-node.js +++ b/src/backend/web-gl/function-node.js @@ -400,7 +400,12 @@ module.exports = class WebGLFunctionNode extends FunctionNodeBase { if (this.constants && this.constants.hasOwnProperty(idtNode.name)) { retArr.push('constants_' + idtNode.name); } else { - retArr.push('user_' + idtNode.name); + const userParamName = funcParam.getUserParamName(idtNode.name); + if (userParamName !== null) { + retArr.push('user_' + userParamName); + } else { + retArr.push('user_' + idtNode.name); + } } } diff --git a/test/html/issues/96-param-names.html b/test/html/issues/96-param-names.html new file mode 100644 index 00000000..0554f0a7 --- /dev/null +++ b/test/html/issues/96-param-names.html @@ -0,0 +1,17 @@ + + + + + GPU.JS : FunctionBuilder unit testing + + + + + + +
+
+ + + + diff --git a/test/html/test-all.html b/test/html/test-all.html index 1b305264..effbbbb4 100644 --- a/test/html/test-all.html +++ b/test/html/test-all.html @@ -35,5 +35,6 @@ + diff --git a/test/src/issues/96-param-names.js b/test/src/issues/96-param-names.js new file mode 100644 index 00000000..5de3e02f --- /dev/null +++ b/test/src/issues/96-param-names.js @@ -0,0 +1,34 @@ +QUnit.test( "Issue #96 - param names (GPU only)", function() { + const A = [ + [1, 1], + [1, 1], + [1, 1] + ]; + + const B = [ + [1, 1, 1], + [1, 1, 1] + ]; + + const gpu = new GPU(); + + function multiply(m, n, y, x) { + var sum = 0; + for (var i = 0; i < 2; i++) { + sum += m[y][i] * n[i][x]; + } + return sum; + } + + var kernels = gpu.createKernels({ + multiplyResult: multiply + }, function (a, b) { + return multiply(b, a, this.thread.y, this.thread.x); + }) + .setDimensions([B.length, A.length]); + + const result = kernels(A, B).result; + QUnit.assert.deepEqual(QUnit.extend([], result[0]), [2,2]); + QUnit.assert.deepEqual(QUnit.extend([], result[1]), [2,2]); + QUnit.assert.deepEqual(QUnit.extend([], result[2]), [0,0]); +}); \ No newline at end of file