Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions animated-transformer/src/lib/gtensor/gtensor.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -905,4 +905,125 @@ describe('gtensor', () => {
[0, 0, 0],
]);
});

fit('where', async () => {
const g1 = new gtensor.GTensor(
tf.tensor([
[
[1, 2],
[3, 4],
[5, 6],
],
[
[1, 2],
[3, 4],
[5, 6],
],
]),
['example', 'pos', 'repSize'],
);

const g2 = new gtensor.GTensor(
tf.tensor([
[0, 0],
[0, 0],
[0, 0],
]),
['pos', 'repSize'],
);

const condition = new gtensor.GTensor(
tf.tensor2d(
[
[1, 0],
[0, 1],
[1, 0],
],
[3, 2],
'bool',
),
['pos', 'repSize'],
);

const g1WhereCondition = g1.where(condition, g2);

expect(g1WhereCondition.dimNames).toEqual(['example', 'pos', 'repSize']);
tf.test_util.expectArraysEqual(g1WhereCondition.tensor.arraySync(), [
[
[1, 0],
[0, 4],
[5, 0],
], // example = 1
[
[1, 0],
[0, 4],
[5, 0],
],
]);
});

fit('where no broadcast over g2', async () => {
const g1 = new gtensor.GTensor(
tf.tensor([
[
[1, 2],
[3, 4],
[5, 6],
],
[
[1, 2],
[3, 4],
[5, 6],
],
]),
['example', 'pos', 'repSize'],
);

const g2 = new gtensor.GTensor(
tf.tensor(
[
[
[0, 0],
[0, 0],
[0, 0],
], // example = 1
[
[0, 0],
[0, 0],
[0, 0],
],
], // example = 2
),
['example', 'pos', 'repSize'],
);

const condition = new gtensor.GTensor(
tf.tensor2d(
[
[1, 0],
[0, 1],
[1, 0],
],
[3, 2],
'bool',
),
['pos', 'repSize'],
);

const g1WhereCondition = g1.where(condition, g2);

expect(g1WhereCondition.dimNames).toEqual(['example', 'pos', 'repSize']);
tf.test_util.expectArraysEqual(g1WhereCondition.tensor.arraySync(), [
[
[1, 0],
[0, 4],
[5, 0],
], // example = 1
[
[1, 0],
[0, 4],
[5, 0],
],
]);
});
});
39 changes: 37 additions & 2 deletions animated-transformer/src/lib/gtensor/gtensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,37 @@ export class GTensor<G extends DName> {
this.dimNames,
);
}

/* Returns the elements, from this of the gtensor or g2 depending on the condition.
If the condition is true, select from the gtensor, otherwise select from g2.
if gtensor.dims != g2.dims g2 is broadcasted to this.dimensions!
if gtensor.dims != cond.dims condition is broadcasted to this.dimensions! */
public where<D extends DName, G2 extends DName>(
condition: GTensor<D>,
g2: GTensor<G2>,
): GTensor<G> {
// Verify that D and G2 are smaller than G or return an error
if (condition.dimNames.length > this.dimNames.length) {
throw new ValueError('The rank of condition cannot be higher than the rank of this tensor');
}
if (g2.dimNames.length > this.dimNames.length) {
throw new ValueError('The rank of g2 cannot be higher than the rank of this tensor');
}
// Broadcast G2 to this tensor's dims
const g2big = g2.broadcastToCombinedShape(this);
const g1big = this.broadcastToCombinedShape(g2);
const g2bigLikeG1 = g2big.transposeLike(g1big);

// Broadcast condition to this tensor's dims
const conditionBig = condition.broadcastToCombinedShape(this);
const g1bigC = this.broadcastToCombinedShape(condition);
const conditionBigLikeG1 = conditionBig.transposeLike(g1bigC);

return new GTensor(
this.tensor.where(conditionBigLikeG1.tensor, g2bigLikeG1.tensor),
this.dimNames,
);
}
}

export class GVariable<G extends DName> extends GTensor<G> {
Expand Down Expand Up @@ -964,13 +995,17 @@ export function makeRange<T extends DName>(
* - dtype : The type of an element in the resulting tensor. Defaults to 'float32'
* // TODO add optianal broadcastTo dimensions/GTensor
* */
export function makeTriangularMatrix<N1 extends string, N2 extends string, T extends string | number>(
export function makeTriangularMatrix<
N1 extends string,
N2 extends string,
T extends string | number,
>(
size: number,
d1Name: N1,
d2Name: N2,
lowerLeftValue: T,
upperRightValue: T,
dtype: 'float32' | 'int32' | 'bool' | 'complex64' | 'string' = 'float32'
dtype: 'float32' | 'int32' | 'bool' | 'complex64' | 'string' = 'float32',
): GTensor<N1 | N2> {
// Create a range tensor for row indices
const rowIndices = tf.range(0, size, 1, 'int32');
Expand Down