-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
one_hot.ts
68 lines (62 loc) · 2.63 KB
/
one_hot.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {ENGINE} from '../engine';
import {OneHot, OneHotAttrs, OneHotInputs} from '../kernel_names';
import {NamedAttrMap} from '../kernel_registry';
import {Tensor} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {convertToTensor} from '../tensor_util_env';
import {DataType, TensorLike} from '../types';
import {op} from './operation';
/**
 * Creates a one-hot `tf.Tensor`.
 *
 * Every location addressed by `indices` is filled with `onValue` (1 unless
 * overridden); every other location is filled with `offValue` (0 unless
 * overridden). For an `indices` tensor of rank `R` the result has rank
 * `R+1`, and its new innermost axis has length `depth`.
 *
 * Class indices are zero-based. With 3 classes, for instance, the first
 * class is encoded as 0, the second as 1 and the third as 2.
 *
 * ```js
 * tf.oneHot(tf.tensor1d([0, 1], 'int32'), 3).print();
 * ```
 *
 * @param indices `tf.Tensor` (or tensor-like value) of zero-based indices.
 *     It is converted to a tensor with dtype `int32`.
 * @param depth Length of the one-hot axis. Must be at least 2.
 * @param onValue Value written where a location matches its index.
 * @param offValue Value written everywhere else.
 * @param dtype Dtype of the returned tensor; `'int32'` when omitted.
 * @returns The one-hot encoded tensor produced by the `OneHot` kernel.
 * @throws Error when `depth` is less than 2.
 *
 * @doc {heading: 'Tensors', subheading: 'Creation'}
 */
function oneHot_(
    indices: Tensor|TensorLike, depth: number, onValue = 1, offValue = 0,
    dtype: DataType = 'int32'): Tensor {
  // Keep the `< 2` comparison as-is: a NaN depth is not rejected here.
  const depthTooSmall = depth < 2;
  if (depthTooSmall) {
    throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`);
  }

  const kernelInputs: OneHotInputs = {
    indices: convertToTensor(indices, 'indices', 'oneHot', 'int32'),
  };
  const kernelAttrs: OneHotAttrs = {dtype, depth, onValue, offValue};

  // The kernel input/attr interfaces are not index-signature maps, so they
  // are widened to the generic maps that `runKernel` accepts.
  const namedInputs = kernelInputs as unknown as NamedTensorMap;
  const namedAttrs = kernelAttrs as unknown as NamedAttrMap;
  return ENGINE.runKernel(OneHot, namedInputs, namedAttrs);
}
export const oneHot = /* @__PURE__ */ op({oneHot_});