From 7c583911885e28f7c523099e256aed4dd1cdaaa5 Mon Sep 17 00:00:00 2001 From: LauBritoMedina Date: Sun, 19 Jan 2025 12:25:06 +0100 Subject: [PATCH 1/3] Adds GTensor.where function to gtensor and unit tests. --- .../src/lib/gtensor/gtensor.spec.ts | 105 ++++++++++++++++++ .../src/lib/gtensor/gtensor.ts | 22 +++- 2 files changed, 125 insertions(+), 2 deletions(-) diff --git a/animated-transformer/src/lib/gtensor/gtensor.spec.ts b/animated-transformer/src/lib/gtensor/gtensor.spec.ts index cf15b70..835ebaa 100644 --- a/animated-transformer/src/lib/gtensor/gtensor.spec.ts +++ b/animated-transformer/src/lib/gtensor/gtensor.spec.ts @@ -905,4 +905,109 @@ describe('gtensor', () => { [0, 0, 0], ]); }); + + it('where', async () => { + const g1 = new gtensor.GTensor( + tf.tensor( + [ + [ + [1, 2], + [3, 4], + [5, 6], + ], + [ + [1, 2], + [3, 4], + [5, 6], + ], + ], + ), + ['example', 'pos', 'repSize'], + ); + + const g2 = new gtensor.GTensor( + tf.tensor( + [ + [0, 0], + [0, 0], + [0, 0], + ], + ), + ['pos', 'repSize'], + ); + + const condition = tf.tensor([1, 0, 0, 1, 1, 0], [3, 2], 'bool'); + + const g1WhereCondition = g1.where(condition, g2); + + expect(g1WhereCondition.dimNames).toEqual(['example', 'pos', 'repSize']); + tf.test_util.expectArraysEqual(g1WhereCondition.tensor.arraySync(), [ + [ + [1, 0], + [0, 4], + [5, 0], + ], // example = 1 + [ + [1, 0], + [0, 4], + [5, 0], + ], + ]); + }); + + it('where no broadcast over g2', async () => { + const g1 = new gtensor.GTensor( + tf.tensor( + [ + [ + [1, 2], + [3, 4], + [5, 6], + ], + [ + [1, 2], + [3, 4], + [5, 6], + ], + ], + ), + ['example', 'pos', 'repSize'], + ); + + const g2 = new gtensor.GTensor( + tf.tensor( + [ + [ + [0, 0], + [0, 0], + [0, 0], + ], // example = 1 + [ + [0, 0], + [0, 0], + [0, 0], + ], + ], // example = 2 + ), + ['example', 'pos', 'repSize'], + ); + + const condition = tf.tensor([1, 0, 0, 1, 1, 0], [3, 2], 'bool'); + + const g1WhereCondition = g1.where(condition, g2); + + expect(g1WhereCondition.dimNames).toEqual(['example', 'pos', 'repSize']); + tf.test_util.expectArraysEqual(g1WhereCondition.tensor.arraySync(), [ + [ + [1, 0], + [0, 4], + [5, 0], + ], // example = 1 + [ + [1, 0], + [0, 4], + [5, 0], + ], + ]); + }); }); diff --git a/animated-transformer/src/lib/gtensor/gtensor.ts b/animated-transformer/src/lib/gtensor/gtensor.ts index e67e24e..389fde9 100644 --- a/animated-transformer/src/lib/gtensor/gtensor.ts +++ b/animated-transformer/src/lib/gtensor/gtensor.ts @@ -840,6 +840,20 @@ export class GTensor { this.dimNames, ); } + + /* Returns the elements, either of the gtensor or g2 depending on the condition. + If the condition is true, select from the gtensor, otherwise select from g2. + if gtensor.dims != g2.dims substracted dimentions are broadcasted */ + public where(condition: tf.Tensor, g2: GTensor): GTensor { + const g2big = g2.broadcastToCombinedShape(this); + const g1big = this.broadcastToCombinedShape(g2); + const g1bigLikeG2 = g1big.transposeLike(g2big); + + const shape = Object.values(g1bigLikeG2.gshape()) as number[]; + const conditionBig = condition.broadcastTo(shape); + + return new GTensor(g1bigLikeG2.tensor.where(conditionBig, g2big.tensor), this.dimNames); + } } export class GVariable extends GTensor { @@ -964,13 +978,17 @@ export function makeRange( * - dtype : The type of an element in the resulting tensor. Defaults to 'float32' * // TODO add optianal broadcastTo dimensions/GTensor * */ -export function makeTriangularMatrix( +export function makeTriangularMatrix< + N1 extends string, + N2 extends string, + T extends string | number, +>( size: number, d1Name: N1, d2Name: N2, lowerLeftValue: T, upperRightValue: T, - dtype: 'float32' | 'int32' | 'bool' | 'complex64' | 'string' = 'float32' + dtype: 'float32' | 'int32' | 'bool' | 'complex64' | 'string' = 'float32', ): GTensor { // Create a range tensor for row indices const rowIndices = tf.range(0, size, 1, 'int32'); From 0b6646f6f8fad73706272ef08f19fba5abbd3ef9 Mon Sep 17 00:00:00 2001 From: LauBritoMedina Date: Wed, 19 Feb 2025 22:34:27 +0100 Subject: [PATCH 2/3] Addressing comments. G2 and condition are now a named tensor --- .../src/lib/gtensor/gtensor.spec.ts | 84 +++++++++++-------- .../src/lib/gtensor/gtensor.ts | 31 +++++-- 2 files changed, 74 insertions(+), 41 deletions(-) diff --git a/animated-transformer/src/lib/gtensor/gtensor.spec.ts b/animated-transformer/src/lib/gtensor/gtensor.spec.ts index 835ebaa..8a9129d 100644 --- a/animated-transformer/src/lib/gtensor/gtensor.spec.ts +++ b/animated-transformer/src/lib/gtensor/gtensor.spec.ts @@ -906,38 +906,45 @@ describe('gtensor', () => { ]); }); - it('where', async () => { + fit('where', async () => { const g1 = new gtensor.GTensor( - tf.tensor( + tf.tensor([ [ - [ - [1, 2], - [3, 4], - [5, 6], - ], - [ - [1, 2], - [3, 4], - [5, 6], - ], + [1, 2], + [3, 4], + [5, 6], ], - ), + [ + [1, 2], + [3, 4], + [5, 6], + ], + ]), ['example', 'pos', 'repSize'], ); const g2 = new gtensor.GTensor( - tf.tensor( + tf.tensor([ + [0, 0], + [0, 0], + [0, 0], + ]), + ['pos', 'repSize'], + ); + + const condition = new gtensor.GTensor( + tf.tensor2d( [ - [0, 0], - [0, 0], - [0, 0], + [1, 0], + [0, 1], + [1, 0], ], + [3, 2], + 'bool', ), - ['pos', 'repSize'], + ['pos', 'repsize'], ); - const condition = tf.tensor([1, 0, 0, 1, 1, 0], [3, 2], 'bool'); - const g1WhereCondition = g1.where(condition, g2); expect(g1WhereCondition.dimNames).toEqual(['example', 'pos', 'repSize']); @@ -955,22 +962,20 @@ describe('gtensor', () => { ]); }); - it('where no broadcast over g2', async () => { + fit('where no broadcast over g2', async () => { const g1 = new gtensor.GTensor( - tf.tensor( + tf.tensor([ [ - [ - [1, 2], - [3, 4], - [5, 6], - ], - [ - [1, 2], - [3, 4], - [5, 6], - ], + [1, 2], + [3, 4], + [5, 6], ], - ), + [ + [1, 2], + [3, 4], + [5, 6], + ], + ]), ['example', 'pos', 'repSize'], ); @@ -992,7 +997,18 @@ describe('gtensor', () => { ['example', 'pos', 'repSize'], ); - const condition = tf.tensor([1, 0, 0, 1, 1, 0], [3, 2], 'bool'); + const condition = new gtensor.GTensor( + tf.tensor2d( + [ + [1, 0], + [0, 1], + [1, 0], + ], + [3, 2], + 'bool', + ), + ['pos', 'repSize'], + ); const g1WhereCondition = g1.where(condition, g2); diff --git a/animated-transformer/src/lib/gtensor/gtensor.ts b/animated-transformer/src/lib/gtensor/gtensor.ts index 389fde9..2030075 100644 --- a/animated-transformer/src/lib/gtensor/gtensor.ts +++ b/animated-transformer/src/lib/gtensor/gtensor.ts @@ -841,18 +841,35 @@ export class GTensor { ); } - /* Returns the elements, either of the gtensor or g2 depending on the condition. + /* Returns the elements, from this of the gtensor or g2 depending on the condition. If the condition is true, select from the gtensor, otherwise select from g2. - if gtensor.dims != g2.dims substracted dimentions are broadcasted */ - public where(condition: tf.Tensor, g2: GTensor): GTensor { + if gtensor.dims != g2.dims g2 is broadcasted to this.dimensions! + if gtensor.dims != cond.dims condition is broadcasted to this.dimensions! */ + public where( + condition: GTensor, + g2: GTensor, + ): GTensor { + // Verify that D and G2 are smaller than G or return an error + if (condition.dimNames.length > this.dimNames.length) { + throw new ValueError('The rank of condition cannot be higher than the rank of this tensor'); + } + if (g2.dimNames.length > this.dimNames.length) { + throw new ValueError('The rank of g2 cannot be higher than the rank of this tensor'); + } + // Broadcast G2 to this tensor's dims const g2big = g2.broadcastToCombinedShape(this); const g1big = this.broadcastToCombinedShape(g2); - const g1bigLikeG2 = g1big.transposeLike(g2big); + const g2bigLikeG1 = g2big.transposeLike(g1big); - const shape = Object.values(g1bigLikeG2.gshape()) as number[]; - const conditionBig = condition.broadcastTo(shape); + // Broadcast condition to this tensor's dims + const conditionBig = condition.broadcastToCombinedShape(this); + const g1bigC = this.broadcastToCombinedShape(condition); + const conditionBigLikeG1 = conditionBig.transposeLike(g1bigC); - return new GTensor(g1bigLikeG2.tensor.where(conditionBig, g2big.tensor), this.dimNames); + return new GTensor( + this.tensor.where(conditionBigLikeG1.tensor, g2bigLikeG1.tensor), + this.dimNames, + ); } } From afbd001b511584abdae71b2074de917c11ada65e Mon Sep 17 00:00:00 2001 From: LauBritoMedina Date: Wed, 19 Feb 2025 22:38:38 +0100 Subject: [PATCH 3/3] fixing test typo --- animated-transformer/src/lib/gtensor/gtensor.spec.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/animated-transformer/src/lib/gtensor/gtensor.spec.ts b/animated-transformer/src/lib/gtensor/gtensor.spec.ts index 8a9129d..4e4bb39 100644 --- a/animated-transformer/src/lib/gtensor/gtensor.spec.ts +++ b/animated-transformer/src/lib/gtensor/gtensor.spec.ts @@ -942,7 +942,7 @@ describe('gtensor', () => { [3, 2], 'bool', ), - ['pos', 'repsize'], + ['pos', 'repSize'], ); const g1WhereCondition = g1.where(condition, g2);