-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
sparse_segment_mean.ts
90 lines (84 loc) · 3.49 KB
/
sparse_segment_mean.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
/**
* @license
* Copyright 2021 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {ENGINE} from '../../engine';
import {SparseSegmentMean, SparseSegmentMeanInputs} from '../../kernel_names';
import {Tensor, Tensor1D} from '../../tensor';
import {convertToTensor} from '../../tensor_util_env';
import {TensorLike} from '../../types';
import {op} from '../operation';
/**
 * Computes the mean along sparse segments of a tensor.
 *
 * ```js
 * const c = tf.tensor2d([[1,2,3,4], [-1,-2,-3,-4], [6,7,8,9]]);
 * // Select two rows, one segment.
 * const result1 = tf.sparse.sparseSegmentMean(c,
 *                                           tf.tensor1d([0, 1], 'int32'),
 *                                           tf.tensor1d([0, 0], 'int32'));
 * result1.print(); // [[0, 0, 0, 0]]
 *
 * // Select two rows, two segments.
 * const result2 = tf.sparse.sparseSegmentMean(c,
 *                                           tf.tensor1d([0, 1], 'int32'),
 *                                           tf.tensor1d([0, 1], 'int32'));
 * result2.print(); // [[1, 2, 3, 4], [-1, -2, -3, -4]]
 *
 * // Select all rows, two segments.
 * const result3 = tf.sparse.sparseSegmentMean(c,
 *                                           tf.tensor1d([0, 1, 2], 'int32'),
 *                                           tf.tensor1d([0, 1, 1], 'int32'));
 * result3.print(); // [[1.0, 2.0, 3.0, 4.0], [2.5, 2.5, 2.5, 2.5]]
 * ```
 * @param data A Tensor of at least one dimension with data that will be
 *     assembled in the output.
 * @param indices A 1-D Tensor with indices into data. Has same rank as
 *     segmentIds.
 * @param segmentIds A 1-D Tensor with indices into the output Tensor. Values
 *     should be sorted and can be repeated.
 * @return Has same shape as data, except for dimension 0, whose size is equal
 *     to the number of segments.
 * @throws Error if `data` is a scalar, or if `indices` or `segmentIds` is not
 *     rank 1.
 *
 * @doc {heading: 'Operations', subheading: 'Sparse'}
 */
function sparseSegmentMean_(
    data: Tensor|TensorLike, indices: Tensor1D|TensorLike,
    segmentIds: Tensor1D|TensorLike): Tensor {
  const $data = convertToTensor(data, 'data', 'sparseSegmentMean');
  // Both index inputs are converted with an explicit 'int32' dtype.
  const $indices =
      convertToTensor(indices, 'indices', 'sparseSegmentMean', 'int32');
  const $segmentIds =
      convertToTensor(segmentIds, 'segmentIds', 'sparseSegmentMean', 'int32');

  // Validate ranks up front so callers get a readable error before the kernel
  // is dispatched.
  if ($data.rank < 1) {
    throw new Error(
        `Data should be at least 1 dimensional but received scalar`);
  }
  if ($indices.rank !== 1) {
    // Keep the message on one logical line: a template literal that spans
    // source lines would embed the line break in the thrown message.
    throw new Error(
        `Indices should be Tensor1D but received shape ` +
        `${$indices.shape}`);
  }
  if ($segmentIds.rank !== 1) {
    throw new Error(
        `Segment ids should be Tensor1D but received shape ` +
        `${$segmentIds.shape}`);
  }

  const inputs: SparseSegmentMeanInputs = {
    data: $data,
    indices: $indices,
    segmentIds: $segmentIds
  };

  return ENGINE.runKernel(SparseSegmentMean, inputs as {});
}
// Public `sparseSegmentMean` export, produced by passing the private
// `sparseSegmentMean_` implementation through `op`. The `/* @__PURE__ */`
// annotation marks the call as side-effect free so bundlers can drop it when
// the export is unused.
// NOTE(review): `op` comes from '../operation' (not visible here); it
// presumably derives the op name from the object key — confirm before renaming.
export const sparseSegmentMean = /* @__PURE__ */ op({sparseSegmentMean_});