// tutanota/libs/tensorflow-stripped.js (vendored; ~45262 lines, 1.5 MiB in full)
//
// Commit e196f8f8e4 (abp, 2025-11-20 16:06:18 +01:00):
//   add tensorflow cpu backend as fallback for webgl backend
//
//   When the webgl backend is not available or unsupported,
//   we fall back to the tensorflow cpu backend.
//
//   Tensorflow cpu backend library review by abp and jhm.
//
//   Co-authored-by: jomapp <17314077+jomapp@users.noreply.github.com>
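// --- Illustration (not part of the vendored library) ---
// A minimal sketch of the fallback described in the commit message above,
// assuming the standard TensorFlow.js entry points setBackend()/ready()/
// getBackend() are exposed by this stripped build (the exact exports here may
// differ). Prefer 'webgl'; fall back to 'cpu' when WebGL is missing or fails
// to initialize.
async function selectBackendSketch(tf) {
try {
// setBackend() resolves to false when the backend cannot be set up and
// throws if the backend name is not registered at all.
if (await tf.setBackend('webgl')) {
await tf.ready();
return tf.getBackend(); // 'webgl'
}
} catch (e) {
// WebGL unavailable or unsupported; fall through to the cpu backend.
}
await tf.setBackend('cpu');
await tf.ready();
return tf.getBackend(); // 'cpu'
}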

function _mergeNamespaces(n, m) {
m.forEach(function (e) {
e && typeof e !== 'string' && !Array.isArray(e) && Object.keys(e).forEach(function (k) {
if (k !== 'default' && !(k in n)) {
var d = Object.getOwnPropertyDescriptor(e, k);
Object.defineProperty(n, k, d.get ? d : {
enumerable: true,
get: function () { return e[k]; }
});
}
});
});
return Object.freeze(n);
}
const EPSILON_FLOAT32$1 = 1e-7;
const EPSILON_FLOAT16$1 = 1e-4;
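// DataStorage maps a tensor's dataId to its backing values for one backend;
// get() asks the dataMover to migrate data from another backend on first access.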
class DataStorage {
constructor(backend, dataMover) {
this.backend = backend;
this.dataMover = dataMover;
this.data = new WeakMap();
this.dataIdsCount = 0;
}
get(dataId) {
if (!this.data.has(dataId)) {
this.dataMover.moveData(this.backend, dataId);
}
return this.data.get(dataId);
}
set(dataId, value) {
this.dataIdsCount++;
this.data.set(dataId, value);
}
has(dataId) {
return this.data.has(dataId);
}
delete(dataId) {
this.dataIdsCount--;
return this.data.delete(dataId);
}
numDataIds() {
return this.dataIdsCount;
}
}
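// Abstract base class for kernel backends (e.g. cpu, webgl). Concrete backends
// override these methods; the defaults below throw notYetImplemented().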
class KernelBackend {
refCount(dataId) {
return notYetImplemented('refCount');
}
incRef(dataId) {
return notYetImplemented('incRef');
}
timerAvailable() {
return true;
}
time(f) {
return notYetImplemented('time');
}
read(dataId) {
return notYetImplemented('read');
}
readSync(dataId) {
return notYetImplemented('readSync');
}
readToGPU(dataId, options) {
return notYetImplemented('readToGPU');
}
numDataIds() {
return notYetImplemented('numDataIds');
}
disposeData(dataId, force) {
return notYetImplemented('disposeData');
}
write(values, shape, dtype) {
return notYetImplemented('write');
}
move(dataId, values, shape, dtype, refCount) {
return notYetImplemented('move');
}
createTensorFromGPUData(values, shape, dtype) {
return notYetImplemented('createTensorFromGPUData');
}
memory() {
return notYetImplemented('memory');
}
floatPrecision() {
return notYetImplemented('floatPrecision');
}
epsilon() {
return this.floatPrecision() === 32 ? EPSILON_FLOAT32$1 : EPSILON_FLOAT16$1;
}
dispose() {
return notYetImplemented('dispose');
}
}
function notYetImplemented(kernelName) {
throw new Error(`'${kernelName}' not yet implemented or not found in the registry. ` +
`This kernel may not be supported by the tfjs backend you have chosen`);
}
function shuffle(array) {
let counter = array.length;
let index = 0;
while (counter > 0) {
index = (Math.random() * counter) | 0;
counter--;
swap(array, counter, index);
}
}
function clamp(min, x, max) {
return Math.max(min, Math.min(x, max));
}
function nearestLargerEven(val) {
return val % 2 === 0 ? val : val + 1;
}
function swap(object, left, right) {
const temp = object[left];
object[left] = object[right];
object[right] = temp;
}
function sum$3(arr) {
let sum = 0;
for (let i = 0; i < arr.length; i++) {
sum += arr[i];
}
return sum;
}
function assert$1(expr, msg) {
if (!expr) {
throw new Error(typeof msg === 'string' ? msg : msg());
}
}
function assertShapesMatch(shapeA, shapeB, errorMessagePrefix = '') {
assert$1(arraysEqual(shapeA, shapeB), () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`);
}
function assertNonNull(a) {
assert$1(a != null, () => `The input to the tensor constructor must be a non-null value.`);
}
function sizeFromShape(shape) {
if (shape.length === 0) {
return 1;
}
let size = shape[0];
for (let i = 1; i < shape.length; i++) {
size *= shape[i];
}
return size;
}
function arraysEqual(n1, n2) {
if (n1 === n2) {
return true;
}
if (n1 == null || n2 == null) {
return false;
}
if (n1.length !== n2.length) {
return false;
}
for (let i = 0; i < n1.length; i++) {
if (n1[i] !== n2[i]) {
return false;
}
}
return true;
}
function isInt(a) {
return a % 1 === 0;
}
function sizeToSquarishShape(size) {
const width = Math.ceil(Math.sqrt(size));
return [width, Math.ceil(size / width)];
}
function rightPad(a, size) {
if (size <= a.length) {
return a;
}
return a + ' '.repeat(size - a.length);
}
function repeatedTry(checkFn, delayFn = (counter) => 0, maxCounter, scheduleFn) {
return new Promise((resolve, reject) => {
let tryCount = 0;
const tryFn = () => {
if (checkFn()) {
resolve();
return;
}
tryCount++;
const nextBackoff = delayFn(tryCount);
if (maxCounter != null && tryCount >= maxCounter) {
reject();
return;
}
if (scheduleFn != null) {
scheduleFn(tryFn, nextBackoff);
}
else {
setTimeout(tryFn, nextBackoff);
}
};
tryFn();
});
}
function inferFromImplicitShape(shape, size) {
let shapeProd = 1;
let implicitIdx = -1;
for (let i = 0; i < shape.length; ++i) {
if (shape[i] >= 0) {
shapeProd *= shape[i];
}
else if (shape[i] === -1) {
if (implicitIdx !== -1) {
throw Error(`Shapes can only have 1 implicit size. ` +
`Found -1 at dim ${implicitIdx} and dim ${i}`);
}
implicitIdx = i;
}
else if (shape[i] < 0) {
throw Error(`Shapes can not be < 0. Found ${shape[i]} at dim ${i}`);
}
}
if (implicitIdx === -1) {
if (size > 0 && size !== shapeProd) {
throw Error(`Size(${size}) must match the product of shape ${shape}`);
}
return shape;
}
if (shapeProd === 0) {
throw Error(`Cannot infer the missing size in [${shape}] when ` +
`there are 0 elements`);
}
if (size % shapeProd !== 0) {
throw Error(`The implicit shape can't be a fractional number. ` +
`Got ${size} / ${shapeProd}`);
}
const newShape = shape.slice();
newShape[implicitIdx] = size / shapeProd;
return newShape;
}
function parseAxisParam(axis, shape) {
const rank = shape.length;
axis = axis == null ? shape.map((s, i) => i) : [].concat(axis);
assert$1(axis.every(ax => ax >= -rank && ax < rank), () => `All values in axis param must be in range [-${rank}, ${rank}) but ` +
`got axis ${axis}`);
assert$1(axis.every(ax => isInt(ax)), () => `All values in axis param must be integers but ` +
`got axis ${axis}`);
return axis.map(a => a < 0 ? rank + a : a);
}
function squeezeShape(shape, axis) {
const newShape = [];
const keptDims = [];
const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
const axes = (axis == null || isEmptyArray) ?
null :
parseAxisParam(axis, shape).sort();
let j = 0;
for (let i = 0; i < shape.length; ++i) {
if (axes != null) {
if (axes[j] === i && shape[i] !== 1) {
throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`);
}
if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
newShape.push(shape[i]);
keptDims.push(i);
}
if (axes[j] <= i) {
j++;
}
}
if (shape[i] !== 1) {
newShape.push(shape[i]);
keptDims.push(i);
}
}
return { newShape, keptDims };
}
function getTypedArrayFromDType(dtype, size) {
return getArrayFromDType(dtype, size);
}
function getArrayFromDType(dtype, size) {
let values = null;
if (dtype == null || dtype === 'float32') {
values = new Float32Array(size);
}
else if (dtype === 'int32') {
values = new Int32Array(size);
}
else if (dtype === 'bool') {
values = new Uint8Array(size);
}
else if (dtype === 'string') {
values = new Array(size);
}
else {
throw new Error(`Unknown data type ${dtype}`);
}
return values;
}
function checkConversionForErrors(vals, dtype) {
for (let i = 0; i < vals.length; i++) {
const num = vals[i];
if (isNaN(num) || !isFinite(num)) {
throw Error(`A tensor of type ${dtype} being uploaded contains ${num}.`);
}
}
}
function isValidDtype(dtype) {
return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' ||
dtype === 'int32' || dtype === 'string';
}
function hasEncodingLoss(oldType, newType) {
if (newType === 'complex64') {
return false;
}
if (newType === 'float32' && oldType !== 'complex64') {
return false;
}
if (newType === 'int32' && oldType !== 'float32' && oldType !== 'complex64') {
return false;
}
if (newType === 'bool' && oldType === 'bool') {
return false;
}
return true;
}
function bytesPerElement(dtype) {
if (dtype === 'float32' || dtype === 'int32') {
return 4;
}
else if (dtype === 'complex64') {
return 8;
}
else if (dtype === 'bool') {
return 1;
}
else {
throw new Error(`Unknown dtype ${dtype}`);
}
}
function bytesFromStringArray(arr) {
if (arr == null) {
return 0;
}
let bytes = 0;
arr.forEach(x => bytes += x.length);
return bytes;
}
function isString(value) {
return typeof value === 'string' || value instanceof String;
}
function isBoolean(value) {
return typeof value === 'boolean';
}
function isNumber(value) {
return typeof value === 'number';
}
function inferDtype(values) {
if (Array.isArray(values)) {
return inferDtype(values[0]);
}
if (values instanceof Float32Array) {
return 'float32';
}
else if (values instanceof Int32Array || values instanceof Uint8Array ||
values instanceof Uint8ClampedArray) {
return 'int32';
}
else if (isNumber(values)) {
return 'float32';
}
else if (isString(values)) {
return 'string';
}
else if (isBoolean(values)) {
return 'bool';
}
return 'float32';
}
function isFunction(f) {
return !!(f && f.constructor && f.call && f.apply);
}
function nearestDivisor(size, start) {
for (let i = start; i < size; ++i) {
if (size % i === 0) {
return i;
}
}
return size;
}
function computeStrides(shape) {
const rank = shape.length;
if (rank < 2) {
return [];
}
const strides = new Array(rank - 1);
strides[rank - 2] = shape[rank - 1];
for (let i = rank - 3; i >= 0; --i) {
strides[i] = strides[i + 1] * shape[i + 1];
}
return strides;
}
function createNestedArray(offset, shape, a, isComplex = false) {
const ret = new Array();
if (shape.length === 1) {
const d = shape[0] * (isComplex ? 2 : 1);
for (let i = 0; i < d; i++) {
ret[i] = a[offset + i];
}
}
else {
const d = shape[0];
const rest = shape.slice(1);
const len = rest.reduce((acc, c) => acc * c) * (isComplex ? 2 : 1);
for (let i = 0; i < d; i++) {
ret[i] = createNestedArray(offset + i * len, rest, a, isComplex);
}
}
return ret;
}
function toNestedArray(shape, a, isComplex = false) {
if (shape.length === 0) {
return a[0];
}
const size = shape.reduce((acc, c) => acc * c) * (isComplex ? 2 : 1);
if (size === 0) {
return [];
}
if (size !== a.length) {
throw new Error(`[${shape}] does not match the input size ${a.length}${isComplex ? ' for a complex tensor' : ''}.`);
}
return createNestedArray(0, shape, a, isComplex);
}
function convertBackendValuesAndArrayBuffer(data, dtype) {
if (Array.isArray(data)) {
return data;
}
if (dtype === 'float32') {
return data instanceof Float32Array ? data : new Float32Array(data);
}
else if (dtype === 'int32') {
return data instanceof Int32Array ? data : new Int32Array(data);
}
else if (dtype === 'bool' || dtype === 'string') {
return Uint8Array.from(new Int32Array(data));
}
else {
throw new Error(`Unknown dtype ${dtype}`);
}
}
function makeOnesTypedArray(size, dtype) {
const array = makeZerosTypedArray(size, dtype);
for (let i = 0; i < array.length; i++) {
array[i] = 1;
}
return array;
}
function makeZerosTypedArray(size, dtype) {
if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
return new Float32Array(size);
}
else if (dtype === 'int32') {
return new Int32Array(size);
}
else if (dtype === 'bool') {
return new Uint8Array(size);
}
else {
throw new Error(`Unknown data type ${dtype}`);
}
}
function makeZerosNestedTypedArray(shape, dtype) {
const size = shape.reduce((prev, curr) => prev * curr, 1);
if (dtype == null || dtype === 'float32') {
return toNestedArray(shape, new Float32Array(size));
}
else if (dtype === 'int32') {
return toNestedArray(shape, new Int32Array(size));
}
else if (dtype === 'bool') {
return toNestedArray(shape, new Uint8Array(size));
}
else {
throw new Error(`Unknown data type ${dtype}`);
}
}
function assertNonNegativeIntegerDimensions(shape) {
shape.forEach(dimSize => {
assert$1(Number.isInteger(dimSize) && dimSize >= 0, () => `Tensor must have a shape comprised of positive integers but got ` +
`shape [${shape}].`);
});
}
function locToIndex(locs, rank, strides) {
if (rank === 0) {
return 0;
}
else if (rank === 1) {
return locs[0];
}
let index = locs[locs.length - 1];
for (let i = 0; i < locs.length - 1; ++i) {
index += strides[i] * locs[i];
}
return index;
}
function indexToLoc(index, rank, strides) {
if (rank === 0) {
return [];
}
else if (rank === 1) {
return [index];
}
const locs = new Array(rank);
for (let i = 0; i < locs.length - 1; ++i) {
locs[i] = Math.floor(index / strides[i]);
index -= locs[i] * strides[i];
}
locs[locs.length - 1] = index;
return locs;
}
function isPromise(object) {
return object && object.then && typeof object.then === 'function';
}
const TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';
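// Environment holds the registered feature flags and the active platform.
// Flags can be overridden at load time via the 'tfjsflags' URL query parameter
// (see populateURLFlags()).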
class Environment {
constructor(global) {
this.global = global;
this.flags = {};
this.flagRegistry = {};
this.urlFlags = {};
this.getQueryParams = getQueryParams;
this.populateURLFlags();
}
setPlatform(platformName, platform) {
if (this.platform != null) {
if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
console.warn(`Platform ${this.platformName} has already been set. ` +
`Overwriting the platform with ${platformName}.`);
}
}
this.platformName = platformName;
this.platform = platform;
}
registerFlag(flagName, evaluationFn, setHook) {
this.flagRegistry[flagName] = { evaluationFn, setHook };
if (this.urlFlags[flagName] != null) {
const flagValue = this.urlFlags[flagName];
if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
console.warn(`Setting feature override from URL ${flagName}: ${flagValue}.`);
}
this.set(flagName, flagValue);
}
}
async getAsync(flagName) {
if (flagName in this.flags) {
return this.flags[flagName];
}
this.flags[flagName] = await this.evaluateFlag(flagName);
return this.flags[flagName];
}
get(flagName) {
if (flagName in this.flags) {
return this.flags[flagName];
}
const flagValue = this.evaluateFlag(flagName);
if (isPromise(flagValue)) {
throw new Error(`Flag ${flagName} cannot be synchronously evaluated. ` +
`Please use getAsync() instead.`);
}
this.flags[flagName] = flagValue;
return this.flags[flagName];
}
getNumber(flagName) {
return this.get(flagName);
}
getBool(flagName) {
return this.get(flagName);
}
getString(flagName) {
return this.get(flagName);
}
getFlags() {
return this.flags;
}
get features() {
return this.flags;
}
set(flagName, value) {
if (this.flagRegistry[flagName] == null) {
throw new Error(`Cannot set flag ${flagName} as it has not been registered.`);
}
this.flags[flagName] = value;
if (this.flagRegistry[flagName].setHook != null) {
this.flagRegistry[flagName].setHook(value);
}
}
evaluateFlag(flagName) {
if (this.flagRegistry[flagName] == null) {
throw new Error(`Cannot evaluate flag '${flagName}': no evaluation function found.`);
}
return this.flagRegistry[flagName].evaluationFn();
}
setFlags(flags) {
this.flags = Object.assign({}, flags);
}
reset() {
this.flags = {};
this.urlFlags = {};
this.populateURLFlags();
}
populateURLFlags() {
if (typeof this.global === 'undefined' ||
typeof this.global.location === 'undefined' ||
typeof this.global.location.search === 'undefined') {
return;
}
const urlParams = this.getQueryParams(this.global.location.search);
if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) {
const keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',');
keyValues.forEach(keyValue => {
const [key, value] = keyValue.split(':');
this.urlFlags[key] = parseValue(key, value);
});
}
}
}
function getQueryParams(queryString) {
const params = {};
queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (s, ...t) => {
decodeParam(params, t[0], t[1]);
return t.join('=');
});
return params;
}
function decodeParam(params, name, value) {
params[decodeURIComponent(name)] = decodeURIComponent(value || '');
}
function parseValue(flagName, value) {
const lowerCaseValue = value.toLowerCase();
if (lowerCaseValue === 'true' || lowerCaseValue === 'false') {
return lowerCaseValue === 'true';
}
else if (`${+lowerCaseValue}` === lowerCaseValue) {
return +lowerCaseValue;
}
else {
return value;
}
}
function env() {
return ENV$2;
}
let ENV$2 = null;
function setEnvironmentGlobal(environment) {
ENV$2 = environment;
}
let globalNameSpace;
function getGlobalNamespace() {
if (globalNameSpace == null) {
let ns;
if (typeof (window) !== 'undefined') {
ns = window;
}
else if (typeof (global) !== 'undefined') {
ns = global;
}
else if (typeof (process) !== 'undefined') {
ns = process;
}
else if (typeof (self) !== 'undefined') {
ns = self;
}
else {
throw new Error('Could not find a global object');
}
globalNameSpace = ns;
}
return globalNameSpace;
}
function getGlobalMap() {
const ns = getGlobalNamespace();
if (ns._tfGlobals == null) {
ns._tfGlobals = new Map();
}
return ns._tfGlobals;
}
function getGlobal(key, init) {
const globalMap = getGlobalMap();
if (globalMap.has(key)) {
return globalMap.get(key);
}
else {
const singleton = init();
globalMap.set(key, singleton);
return globalMap.get(key);
}
}
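// Kernel name constants; these strings are the keys used to register and look
// up kernels and gradients in the registries below.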
const Abs = 'Abs';
const Acos = 'Acos';
const Acosh = 'Acosh';
const Add = 'Add';
const AddN = 'AddN';
const All = 'All';
const Any = 'Any';
const ArgMax = 'ArgMax';
const ArgMin = 'ArgMin';
const Asin = 'Asin';
const Asinh = 'Asinh';
const Atan = 'Atan';
const Atanh = 'Atanh';
const Atan2 = 'Atan2';
const AvgPool = 'AvgPool';
const AvgPoolGrad = 'AvgPoolGrad';
const AvgPool3D = 'AvgPool3D';
const AvgPool3DGrad = 'AvgPool3DGrad';
const BatchMatMul = 'BatchMatMul';
const BatchToSpaceND = 'BatchToSpaceND';
const Bincount = 'Bincount';
const BitwiseAnd = 'BitwiseAnd';
const BroadcastTo = 'BroadcastTo';
const BroadcastArgs = 'BroadcastArgs';
const Cast = 'Cast';
const Ceil = 'Ceil';
const ClipByValue = 'ClipByValue';
const Complex = 'Complex';
const ComplexAbs = 'ComplexAbs';
const Concat = 'Concat';
const Conv2D = 'Conv2D';
const Conv2DBackpropFilter = 'Conv2DBackpropFilter';
const Conv2DBackpropInput = 'Conv2DBackpropInput';
const Conv3D = 'Conv3D';
const Conv3DBackpropFilterV2 = 'Conv3DBackpropFilterV2';
const Conv3DBackpropInputV2 = 'Conv3DBackpropInputV2';
const Cos = 'Cos';
const Cosh = 'Cosh';
const Cumprod = 'Cumprod';
const Cumsum = 'Cumsum';
const CropAndResize = 'CropAndResize';
const DenseBincount = 'DenseBincount';
const DepthToSpace = 'DepthToSpace';
const DepthwiseConv2dNative = 'DepthwiseConv2dNative';
const DepthwiseConv2dNativeBackpropFilter = 'DepthwiseConv2dNativeBackpropFilter';
const DepthwiseConv2dNativeBackpropInput = 'DepthwiseConv2dNativeBackpropInput';
const Diag = 'Diag';
const Dilation2D = 'Dilation2D';
const Dilation2DBackpropInput = 'Dilation2DBackpropInput';
const Dilation2DBackpropFilter = 'Dilation2DBackpropFilter';
const Draw = 'Draw';
const RealDiv = 'RealDiv';
const Einsum = 'Einsum';
const Elu$1 = 'Elu';
const EluGrad = 'EluGrad';
const Erf = 'Erf';
const Equal = 'Equal';
const Exp = 'Exp';
const ExpandDims = 'ExpandDims';
const Expm1 = 'Expm1';
const FFT = 'FFT';
const Fill = 'Fill';
const FlipLeftRight = 'FlipLeftRight';
const Floor = 'Floor';
const FloorDiv = 'FloorDiv';
const FusedBatchNorm = 'FusedBatchNorm';
const GatherV2 = 'GatherV2';
const GatherNd = 'GatherNd';
const Greater = 'Greater';
const GreaterEqual = 'GreaterEqual';
const Identity$1 = 'Identity';
const IFFT = 'IFFT';
const Imag = 'Imag';
const IsFinite = 'IsFinite';
const IsInf = 'IsInf';
const IsNan = 'IsNan';
const LeakyRelu = 'LeakyRelu';
const Less = 'Less';
const LessEqual = 'LessEqual';
const LinSpace = 'LinSpace';
const Log = 'Log';
const Log1p = 'Log1p';
const LogicalAnd = 'LogicalAnd';
const LogicalNot = 'LogicalNot';
const LogicalOr = 'LogicalOr';
const LogSoftmax$1 = 'LogSoftmax';
const LRN = 'LRN';
const LRNGrad = 'LRNGrad';
const Max = 'Max';
const Maximum = 'Maximum';
const MaxPool = 'MaxPool';
const MaxPoolGrad = 'MaxPoolGrad';
const MaxPool3D = 'MaxPool3D';
const MaxPool3DGrad = 'MaxPool3DGrad';
const MaxPoolWithArgmax = 'MaxPoolWithArgmax';
const Mean = 'Mean';
const Min = 'Min';
const Minimum = 'Minimum';
const MirrorPad = 'MirrorPad';
const Mod = 'Mod';
const Multinomial = 'Multinomial';
const Multiply = 'Multiply';
const Neg = 'Neg';
const NotEqual = 'NotEqual';
const NonMaxSuppressionV3 = 'NonMaxSuppressionV3';
const NonMaxSuppressionV4 = 'NonMaxSuppressionV4';
const NonMaxSuppressionV5 = 'NonMaxSuppressionV5';
const OnesLike = 'OnesLike';
const OneHot = 'OneHot';
const Pack = 'Pack';
const PadV2 = 'PadV2';
const Pow = 'Pow';
const Prelu = 'Prelu';
const Prod = 'Prod';
const RaggedGather = 'RaggedGather';
const RaggedRange = 'RaggedRange';
const RaggedTensorToTensor = 'RaggedTensorToTensor';
const Range = 'Range';
const Real = 'Real';
const Reciprocal = 'Reciprocal';
const Relu$1 = 'Relu';
const Reshape$1 = 'Reshape';
const ResizeNearestNeighbor = 'ResizeNearestNeighbor';
const ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad';
const ResizeBilinear = 'ResizeBilinear';
const ResizeBilinearGrad = 'ResizeBilinearGrad';
const Relu6$1 = 'Relu6';
const Reverse = 'Reverse';
const Round = 'Round';
const Rsqrt = 'Rsqrt';
const ScatterNd = 'ScatterNd';
const TensorScatterUpdate = 'TensorScatterUpdate';
const SearchSorted = 'SearchSorted';
const Select = 'Select';
const Selu$1 = 'Selu';
const Slice = 'Slice';
const Sin = 'Sin';
const Sinh = 'Sinh';
const Sign = 'Sign';
const Sigmoid$1 = 'Sigmoid';
const Softplus$1 = 'Softplus';
const Sqrt = 'Sqrt';
const Sum = 'Sum';
const SpaceToBatchND = 'SpaceToBatchND';
const SplitV = 'SplitV';
const Softmax$1 = 'Softmax';
const SparseFillEmptyRows = 'SparseFillEmptyRows';
const SparseReshape = 'SparseReshape';
const SparseSegmentMean = 'SparseSegmentMean';
const SparseSegmentSum = 'SparseSegmentSum';
const SparseToDense = 'SparseToDense';
const SquaredDifference = 'SquaredDifference';
const Square = 'Square';
const StaticRegexReplace = 'StaticRegexReplace';
const StridedSlice = 'StridedSlice';
const StringNGrams = 'StringNGrams';
const StringSplit = 'StringSplit';
const StringToHashBucketFast = 'StringToHashBucketFast';
const Sub = 'Sub';
const Tan = 'Tan';
const Tanh$1 = 'Tanh';
const Tile = 'Tile';
const TopK = 'TopK';
const Transform = 'Transform';
const Transpose = 'Transpose';
const Unique = 'Unique';
const Unpack = 'Unpack';
const UnsortedSegmentSum = 'UnsortedSegmentSum';
const ZerosLike = 'ZerosLike';
const Step = 'Step';
const FromPixels = 'FromPixels';
const RotateWithOffset = 'RotateWithOffset';
const _FusedMatMul = '_FusedMatMul';
const FusedConv2D = 'FusedConv2D';
const FusedDepthwiseConv2D = 'FusedDepthwiseConv2D';
function warn(...msg) {
if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
console.warn(...msg);
}
}
function log$3(...msg) {
if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
console.log(...msg);
}
}
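// Global registries: kernelRegistry is keyed by '<backendName>_<kernelName>'
// (see makeKey()), gradRegistry by kernel name alone.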
const kernelRegistry = getGlobal('kernelRegistry', () => new Map());
const gradRegistry = getGlobal('gradRegistry', () => new Map());
function getKernel(kernelName, backendName) {
const key = makeKey(kernelName, backendName);
return kernelRegistry.get(key);
}
function getGradient(kernelName) {
return gradRegistry.get(kernelName);
}
function getKernelsForBackend(backendName) {
const it = kernelRegistry.entries();
const result = [];
while (true) {
const { done, value } = it.next();
if (done) {
break;
}
const [key, config] = value;
const [backend,] = key.split('_');
if (backend === backendName) {
result.push(config);
}
}
return result;
}
function registerKernel(config) {
const { kernelName, backendName } = config;
const key = makeKey(kernelName, backendName);
if (kernelRegistry.has(key)) {
warn(`The kernel '${kernelName}' for backend ` +
`'${backendName}' is already registered`);
}
kernelRegistry.set(key, config);
}
function registerGradient(config) {
const { kernelName } = config;
if (gradRegistry.has(kernelName)) {
if (env().getBool('DEBUG')) {
warn(`Overriding the gradient for '${kernelName}'`);
}
}
gradRegistry.set(kernelName, config);
}
function makeKey(kernelName, backendName) {
return `${backendName}_${kernelName}`;
}
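// Browser platform for this stripped build, registered on the next tick;
// network fetch() is deliberately unsupported here.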
setTimeout(() => env().setPlatform('browser', new PlatformStub()));
function isTypedArrayBrowser(a) {
return a instanceof Float32Array || a instanceof Int32Array ||
a instanceof Uint8Array || a instanceof Uint8ClampedArray;
}
class PlatformStub {
constructor() {
// Initialize the state used by setTimeoutCustom(); missing in this stripped copy.
this.messageName = 'setTimeoutCustom';
this.functionRefs = [];
this.handledMessageCount = 0;
this.hasEventListener = false;
}
fetch(path, init) {
throw new Error("fetch is not supported in this build.");
}
now() {
return performance.now();
}
encode(text, encoding) {
if (encoding !== 'utf-8' && encoding !== 'utf8') {
throw new Error(`Browser's encoder only supports utf-8, but got ${encoding}`);
}
if (this.textEncoder == null) {
this.textEncoder = new TextEncoder();
}
return this.textEncoder.encode(text);
}
decode(bytes, encoding) {
return new TextDecoder(encoding).decode(bytes);
}
setTimeoutCustom(functionRef, delay) {
if (typeof window === 'undefined' ||
!env().getBool('USE_SETTIMEOUTCUSTOM')) {
setTimeout(functionRef, delay);
return;
}
this.functionRefs.push(functionRef);
setTimeout(() => {
window.postMessage({name: this.messageName, index: this.functionRefs.length - 1}, location.origin);
}, delay);
if (!this.hasEventListener) {
this.hasEventListener = true;
window.addEventListener('message', (event) => {
if (event.source === window && event.data.name === this.messageName) {
event.stopPropagation();
const functionRef = this.functionRefs[event.data.index];
functionRef();
this.handledMessageCount++;
if (this.handledMessageCount === this.functionRefs.length) {
this.functionRefs = [];
this.handledMessageCount = 0;
}
}
}, true);
}
}
isTypedArray(a) {
return isTypedArrayBrowser(a)
}
}
var commonjsGlobal = typeof globalThis !== 'undefined' ? globalThis : typeof window !== 'undefined' ? window : typeof global !== 'undefined' ? global : typeof self !== 'undefined' ? self : {};
function getDefaultExportFromCjs (x) {
return x && x.__esModule && Object.prototype.hasOwnProperty.call(x, 'default') ? x['default'] : x;
}
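// Vendored long.js: 64-bit two's-complement integers stored as (low, high)
// 32-bit pairs, with an optional WebAssembly fast path for mul/div/rem.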
var long = Long$1;
var wasm = null;
try {
wasm = new WebAssembly.Instance(new WebAssembly.Module(new Uint8Array([
0, 97, 115, 109, 1, 0, 0, 0, 1, 13, 2, 96, 0, 1, 127, 96, 4, 127, 127, 127, 127, 1, 127, 3, 7, 6, 0, 1, 1, 1, 1, 1, 6, 6, 1, 127, 1, 65, 0, 11, 7, 50, 6, 3, 109, 117, 108, 0, 1, 5, 100, 105, 118, 95, 115, 0, 2, 5, 100, 105, 118, 95, 117, 0, 3, 5, 114, 101, 109, 95, 115, 0, 4, 5, 114, 101, 109, 95, 117, 0, 5, 8, 103, 101, 116, 95, 104, 105, 103, 104, 0, 0, 10, 191, 1, 6, 4, 0, 35, 0, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 126, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 127, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 128, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 129, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 130, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11
])), {}).exports;
} catch (e) {
}
function Long$1(low, high, unsigned) {
this.low = low | 0;
this.high = high | 0;
this.unsigned = !!unsigned;
}
Object.defineProperty(Long$1.prototype, "__isLong__", { value: true });
function isLong(obj) {
return (obj && obj["__isLong__"]) === true;
}
Long$1.isLong = isLong;
var INT_CACHE = {};
var UINT_CACHE = {};
function fromInt(value, unsigned) {
var obj, cachedObj, cache;
if (unsigned) {
value >>>= 0;
if (cache = (0 <= value && value < 256)) {
cachedObj = UINT_CACHE[value];
if (cachedObj)
return cachedObj;
}
obj = fromBits(value, (value | 0) < 0 ? -1 : 0, true);
if (cache)
UINT_CACHE[value] = obj;
return obj;
} else {
value |= 0;
if (cache = (-128 <= value && value < 128)) {
cachedObj = INT_CACHE[value];
if (cachedObj)
return cachedObj;
}
obj = fromBits(value, value < 0 ? -1 : 0, false);
if (cache)
INT_CACHE[value] = obj;
return obj;
}
}
Long$1.fromInt = fromInt;
function fromNumber(value, unsigned) {
if (isNaN(value))
return unsigned ? UZERO : ZERO;
if (unsigned) {
if (value < 0)
return UZERO;
if (value >= TWO_PWR_64_DBL)
return MAX_UNSIGNED_VALUE;
} else {
if (value <= -TWO_PWR_63_DBL)
return MIN_VALUE;
if (value + 1 >= TWO_PWR_63_DBL)
return MAX_VALUE;
}
if (value < 0)
return fromNumber(-value, unsigned).neg();
return fromBits((value % TWO_PWR_32_DBL) | 0, (value / TWO_PWR_32_DBL) | 0, unsigned);
}
Long$1.fromNumber = fromNumber;
function fromBits(lowBits, highBits, unsigned) {
return new Long$1(lowBits, highBits, unsigned);
}
Long$1.fromBits = fromBits;
var pow_dbl = Math.pow;
function fromString(str, unsigned, radix) {
if (str.length === 0)
throw Error('empty string');
if (str === "NaN" || str === "Infinity" || str === "+Infinity" || str === "-Infinity")
return ZERO;
if (typeof unsigned === 'number') {
radix = unsigned,
unsigned = false;
} else {
unsigned = !! unsigned;
}
radix = radix || 10;
if (radix < 2 || 36 < radix)
throw RangeError('radix');
var p;
if ((p = str.indexOf('-')) > 0)
throw Error('interior hyphen');
else if (p === 0) {
return fromString(str.substring(1), unsigned, radix).neg();
}
var radixToPower = fromNumber(pow_dbl(radix, 8));
var result = ZERO;
for (var i = 0; i < str.length; i += 8) {
var size = Math.min(8, str.length - i),
value = parseInt(str.substring(i, i + size), radix);
if (size < 8) {
var power = fromNumber(pow_dbl(radix, size));
result = result.mul(power).add(fromNumber(value));
} else {
result = result.mul(radixToPower);
result = result.add(fromNumber(value));
}
}
result.unsigned = unsigned;
return result;
}
Long$1.fromString = fromString;
function fromValue(val, unsigned) {
if (typeof val === 'number')
return fromNumber(val, unsigned);
if (typeof val === 'string')
return fromString(val, unsigned);
return fromBits(val.low, val.high, typeof unsigned === 'boolean' ? unsigned : val.unsigned);
}
Long$1.fromValue = fromValue;
var TWO_PWR_16_DBL = 1 << 16;
var TWO_PWR_24_DBL = 1 << 24;
var TWO_PWR_32_DBL = TWO_PWR_16_DBL * TWO_PWR_16_DBL;
var TWO_PWR_64_DBL = TWO_PWR_32_DBL * TWO_PWR_32_DBL;
var TWO_PWR_63_DBL = TWO_PWR_64_DBL / 2;
var TWO_PWR_24 = fromInt(TWO_PWR_24_DBL);
var ZERO = fromInt(0);
Long$1.ZERO = ZERO;
var UZERO = fromInt(0, true);
Long$1.UZERO = UZERO;
var ONE = fromInt(1);
Long$1.ONE = ONE;
var UONE = fromInt(1, true);
Long$1.UONE = UONE;
var NEG_ONE = fromInt(-1);
Long$1.NEG_ONE = NEG_ONE;
var MAX_VALUE = fromBits(0xFFFFFFFF|0, 0x7FFFFFFF|0, false);
Long$1.MAX_VALUE = MAX_VALUE;
var MAX_UNSIGNED_VALUE = fromBits(0xFFFFFFFF|0, 0xFFFFFFFF|0, true);
Long$1.MAX_UNSIGNED_VALUE = MAX_UNSIGNED_VALUE;
var MIN_VALUE = fromBits(0, 0x80000000|0, false);
Long$1.MIN_VALUE = MIN_VALUE;
var LongPrototype = Long$1.prototype;
LongPrototype.toInt = function toInt() {
return this.unsigned ? this.low >>> 0 : this.low;
};
LongPrototype.toNumber = function toNumber() {
if (this.unsigned)
return ((this.high >>> 0) * TWO_PWR_32_DBL) + (this.low >>> 0);
return this.high * TWO_PWR_32_DBL + (this.low >>> 0);
};
LongPrototype.toString = function toString(radix) {
radix = radix || 10;
if (radix < 2 || 36 < radix)
throw RangeError('radix');
if (this.isZero())
return '0';
if (this.isNegative()) {
if (this.eq(MIN_VALUE)) {
var radixLong = fromNumber(radix),
div = this.div(radixLong),
rem1 = div.mul(radixLong).sub(this);
return div.toString(radix) + rem1.toInt().toString(radix);
} else
return '-' + this.neg().toString(radix);
}
var radixToPower = fromNumber(pow_dbl(radix, 6), this.unsigned),
rem = this;
var result = '';
while (true) {
var remDiv = rem.div(radixToPower),
intval = rem.sub(remDiv.mul(radixToPower)).toInt() >>> 0,
digits = intval.toString(radix);
rem = remDiv;
if (rem.isZero())
return digits + result;
else {
while (digits.length < 6)
digits = '0' + digits;
result = '' + digits + result;
}
}
};
LongPrototype.getHighBits = function getHighBits() {
return this.high;
};
LongPrototype.getHighBitsUnsigned = function getHighBitsUnsigned() {
return this.high >>> 0;
};
LongPrototype.getLowBits = function getLowBits() {
return this.low;
};
LongPrototype.getLowBitsUnsigned = function getLowBitsUnsigned() {
return this.low >>> 0;
};
LongPrototype.getNumBitsAbs = function getNumBitsAbs() {
if (this.isNegative())
return this.eq(MIN_VALUE) ? 64 : this.neg().getNumBitsAbs();
var val = this.high != 0 ? this.high : this.low;
for (var bit = 31; bit > 0; bit--)
if ((val & (1 << bit)) != 0)
break;
return this.high != 0 ? bit + 33 : bit + 1;
};
LongPrototype.isZero = function isZero() {
return this.high === 0 && this.low === 0;
};
LongPrototype.eqz = LongPrototype.isZero;
LongPrototype.isNegative = function isNegative() {
return !this.unsigned && this.high < 0;
};
LongPrototype.isPositive = function isPositive() {
return this.unsigned || this.high >= 0;
};
LongPrototype.isOdd = function isOdd() {
return (this.low & 1) === 1;
};
LongPrototype.isEven = function isEven() {
return (this.low & 1) === 0;
};
LongPrototype.equals = function equals(other) {
if (!isLong(other))
other = fromValue(other);
if (this.unsigned !== other.unsigned && (this.high >>> 31) === 1 && (other.high >>> 31) === 1)
return false;
return this.high === other.high && this.low === other.low;
};
LongPrototype.eq = LongPrototype.equals;
LongPrototype.notEquals = function notEquals(other) {
return !this.eq(other);
};
LongPrototype.neq = LongPrototype.notEquals;
LongPrototype.ne = LongPrototype.notEquals;
LongPrototype.lessThan = function lessThan(other) {
return this.comp(other) < 0;
};
LongPrototype.lt = LongPrototype.lessThan;
LongPrototype.lessThanOrEqual = function lessThanOrEqual(other) {
return this.comp(other) <= 0;
};
LongPrototype.lte = LongPrototype.lessThanOrEqual;
LongPrototype.le = LongPrototype.lessThanOrEqual;
LongPrototype.greaterThan = function greaterThan(other) {
return this.comp(other) > 0;
};
LongPrototype.gt = LongPrototype.greaterThan;
LongPrototype.greaterThanOrEqual = function greaterThanOrEqual(other) {
return this.comp(other) >= 0;
};
LongPrototype.gte = LongPrototype.greaterThanOrEqual;
LongPrototype.ge = LongPrototype.greaterThanOrEqual;
LongPrototype.compare = function compare(other) {
if (!isLong(other))
other = fromValue(other);
if (this.eq(other))
return 0;
var thisNeg = this.isNegative(),
otherNeg = other.isNegative();
if (thisNeg && !otherNeg)
return -1;
if (!thisNeg && otherNeg)
return 1;
if (!this.unsigned)
return this.sub(other).isNegative() ? -1 : 1;
return (other.high >>> 0) > (this.high >>> 0) || (other.high === this.high && (other.low >>> 0) > (this.low >>> 0)) ? -1 : 1;
};
LongPrototype.comp = LongPrototype.compare;
LongPrototype.negate = function negate() {
if (!this.unsigned && this.eq(MIN_VALUE))
return MIN_VALUE;
return this.not().add(ONE);
};
LongPrototype.neg = LongPrototype.negate;
LongPrototype.add = function add(addend) {
if (!isLong(addend))
addend = fromValue(addend);
var a48 = this.high >>> 16;
var a32 = this.high & 0xFFFF;
var a16 = this.low >>> 16;
var a00 = this.low & 0xFFFF;
var b48 = addend.high >>> 16;
var b32 = addend.high & 0xFFFF;
var b16 = addend.low >>> 16;
var b00 = addend.low & 0xFFFF;
var c48 = 0, c32 = 0, c16 = 0, c00 = 0;
c00 += a00 + b00;
c16 += c00 >>> 16;
c00 &= 0xFFFF;
c16 += a16 + b16;
c32 += c16 >>> 16;
c16 &= 0xFFFF;
c32 += a32 + b32;
c48 += c32 >>> 16;
c32 &= 0xFFFF;
c48 += a48 + b48;
c48 &= 0xFFFF;
return fromBits((c16 << 16) | c00, (c48 << 16) | c32, this.unsigned);
};
LongPrototype.subtract = function subtract(subtrahend) {
if (!isLong(subtrahend))
subtrahend = fromValue(subtrahend);
return this.add(subtrahend.neg());
};
LongPrototype.sub = LongPrototype.subtract;
LongPrototype.multiply = function multiply(multiplier) {
if (this.isZero())
return ZERO;
if (!isLong(multiplier))
multiplier = fromValue(multiplier);
if (wasm) {
var low = wasm.mul(this.low,
this.high,
multiplier.low,
multiplier.high);
return fromBits(low, wasm.get_high(), this.unsigned);
}
if (multiplier.isZero())
return ZERO;
if (this.eq(MIN_VALUE))
return multiplier.isOdd() ? MIN_VALUE : ZERO;
if (multiplier.eq(MIN_VALUE))
return this.isOdd() ? MIN_VALUE : ZERO;
if (this.isNegative()) {
if (multiplier.isNegative())
return this.neg().mul(multiplier.neg());
else
return this.neg().mul(multiplier).neg();
} else if (multiplier.isNegative())
return this.mul(multiplier.neg()).neg();
if (this.lt(TWO_PWR_24) && multiplier.lt(TWO_PWR_24))
return fromNumber(this.toNumber() * multiplier.toNumber(), this.unsigned);
var a48 = this.high >>> 16;
var a32 = this.high & 0xFFFF;
var a16 = this.low >>> 16;
var a00 = this.low & 0xFFFF;
var b48 = multiplier.high >>> 16;
var b32 = multiplier.high & 0xFFFF;
var b16 = multiplier.low >>> 16;
var b00 = multiplier.low & 0xFFFF;
var c48 = 0, c32 = 0, c16 = 0, c00 = 0;
c00 += a00 * b00;
c16 += c00 >>> 16;
c00 &= 0xFFFF;
c16 += a16 * b00;
c32 += c16 >>> 16;
c16 &= 0xFFFF;
c16 += a00 * b16;
c32 += c16 >>> 16;
c16 &= 0xFFFF;
c32 += a32 * b00;
c48 += c32 >>> 16;
c32 &= 0xFFFF;
c32 += a16 * b16;
c48 += c32 >>> 16;
c32 &= 0xFFFF;
c32 += a00 * b32;
c48 += c32 >>> 16;
c32 &= 0xFFFF;
c48 += a48 * b00 + a32 * b16 + a16 * b32 + a00 * b48;
c48 &= 0xFFFF;
return fromBits((c16 << 16) | c00, (c48 << 16) | c32, this.unsigned);
};
LongPrototype.mul = LongPrototype.multiply;
LongPrototype.divide = function divide(divisor) {
if (!isLong(divisor))
divisor = fromValue(divisor);
if (divisor.isZero())
throw Error('division by zero');
if (wasm) {
if (!this.unsigned &&
this.high === -2147483648 &&
divisor.low === -1 && divisor.high === -1) {
return this;
}
var low = (this.unsigned ? wasm.div_u : wasm.div_s)(
this.low,
this.high,
divisor.low,
divisor.high
);
return fromBits(low, wasm.get_high(), this.unsigned);
}
if (this.isZero())
return this.unsigned ? UZERO : ZERO;
var approx, rem, res;
if (!this.unsigned) {
if (this.eq(MIN_VALUE)) {
if (divisor.eq(ONE) || divisor.eq(NEG_ONE))
return MIN_VALUE;
else if (divisor.eq(MIN_VALUE))
return ONE;
else {
var halfThis = this.shr(1);
approx = halfThis.div(divisor).shl(1);
if (approx.eq(ZERO)) {
return divisor.isNegative() ? ONE : NEG_ONE;
} else {
rem = this.sub(divisor.mul(approx));
res = approx.add(rem.div(divisor));
return res;
}
}
} else if (divisor.eq(MIN_VALUE))
return this.unsigned ? UZERO : ZERO;
if (this.isNegative()) {
if (divisor.isNegative())
return this.neg().div(divisor.neg());
return this.neg().div(divisor).neg();
} else if (divisor.isNegative())
return this.div(divisor.neg()).neg();
res = ZERO;
} else {
if (!divisor.unsigned)
divisor = divisor.toUnsigned();
if (divisor.gt(this))
return UZERO;
if (divisor.gt(this.shru(1)))
return UONE;
res = UZERO;
}
rem = this;
while (rem.gte(divisor)) {
approx = Math.max(1, Math.floor(rem.toNumber() / divisor.toNumber()));
var log2 = Math.ceil(Math.log(approx) / Math.LN2),
delta = (log2 <= 48) ? 1 : pow_dbl(2, log2 - 48),
approxRes = fromNumber(approx),
approxRem = approxRes.mul(divisor);
while (approxRem.isNegative() || approxRem.gt(rem)) {
approx -= delta;
approxRes = fromNumber(approx, this.unsigned);
approxRem = approxRes.mul(divisor);
}
if (approxRes.isZero())
approxRes = ONE;
res = res.add(approxRes);
rem = rem.sub(approxRem);
}
return res;
};
LongPrototype.div = LongPrototype.divide;
LongPrototype.modulo = function modulo(divisor) {
if (!isLong(divisor))
divisor = fromValue(divisor);
if (wasm) {
var low = (this.unsigned ? wasm.rem_u : wasm.rem_s)(
this.low,
this.high,
divisor.low,
divisor.high
);
return fromBits(low, wasm.get_high(), this.unsigned);
}
return this.sub(this.div(divisor).mul(divisor));
};
LongPrototype.mod = LongPrototype.modulo;
LongPrototype.rem = LongPrototype.modulo;
LongPrototype.not = function not() {
return fromBits(~this.low, ~this.high, this.unsigned);
};
LongPrototype.and = function and(other) {
if (!isLong(other))
other = fromValue(other);
return fromBits(this.low & other.low, this.high & other.high, this.unsigned);
};
LongPrototype.or = function or(other) {
if (!isLong(other))
other = fromValue(other);
return fromBits(this.low | other.low, this.high | other.high, this.unsigned);
};
LongPrototype.xor = function xor(other) {
if (!isLong(other))
other = fromValue(other);
return fromBits(this.low ^ other.low, this.high ^ other.high, this.unsigned);
};
LongPrototype.shiftLeft = function shiftLeft(numBits) {
if (isLong(numBits))
numBits = numBits.toInt();
if ((numBits &= 63) === 0)
return this;
else if (numBits < 32)
return fromBits(this.low << numBits, (this.high << numBits) | (this.low >>> (32 - numBits)), this.unsigned);
else
return fromBits(0, this.low << (numBits - 32), this.unsigned);
};
LongPrototype.shl = LongPrototype.shiftLeft;
LongPrototype.shiftRight = function shiftRight(numBits) {
if (isLong(numBits))
numBits = numBits.toInt();
if ((numBits &= 63) === 0)
return this;
else if (numBits < 32)
return fromBits((this.low >>> numBits) | (this.high << (32 - numBits)), this.high >> numBits, this.unsigned);
else
return fromBits(this.high >> (numBits - 32), this.high >= 0 ? 0 : -1, this.unsigned);
};
LongPrototype.shr = LongPrototype.shiftRight;
LongPrototype.shiftRightUnsigned = function shiftRightUnsigned(numBits) {
if (isLong(numBits))
numBits = numBits.toInt();
numBits &= 63;
if (numBits === 0)
return this;
else {
var high = this.high;
if (numBits < 32) {
var low = this.low;
return fromBits((low >>> numBits) | (high << (32 - numBits)), high >>> numBits, this.unsigned);
} else if (numBits === 32)
return fromBits(high, 0, this.unsigned);
else
return fromBits(high >>> (numBits - 32), 0, this.unsigned);
}
};
LongPrototype.shru = LongPrototype.shiftRightUnsigned;
LongPrototype.shr_u = LongPrototype.shiftRightUnsigned;
LongPrototype.toSigned = function toSigned() {
if (!this.unsigned)
return this;
return fromBits(this.low, this.high, false);
};
LongPrototype.toUnsigned = function toUnsigned() {
if (this.unsigned)
return this;
return fromBits(this.low, this.high, true);
};
LongPrototype.toBytes = function toBytes(le) {
return le ? this.toBytesLE() : this.toBytesBE();
};
LongPrototype.toBytesLE = function toBytesLE() {
var hi = this.high,
lo = this.low;
return [
lo & 0xff,
lo >>> 8 & 0xff,
lo >>> 16 & 0xff,
lo >>> 24 ,
hi & 0xff,
hi >>> 8 & 0xff,
hi >>> 16 & 0xff,
hi >>> 24
];
};
LongPrototype.toBytesBE = function toBytesBE() {
var hi = this.high,
lo = this.low;
return [
hi >>> 24 ,
hi >>> 16 & 0xff,
hi >>> 8 & 0xff,
hi & 0xff,
lo >>> 24 ,
lo >>> 16 & 0xff,
lo >>> 8 & 0xff,
lo & 0xff
];
};
Long$1.fromBytes = function fromBytes(bytes, unsigned, le) {
return le ? Long$1.fromBytesLE(bytes, unsigned) : Long$1.fromBytesBE(bytes, unsigned);
};
Long$1.fromBytesLE = function fromBytesLE(bytes, unsigned) {
return new Long$1(
bytes[0] |
bytes[1] << 8 |
bytes[2] << 16 |
bytes[3] << 24,
bytes[4] |
bytes[5] << 8 |
bytes[6] << 16 |
bytes[7] << 24,
unsigned
);
};
Long$1.fromBytesBE = function fromBytesBE(bytes, unsigned) {
return new Long$1(
bytes[4] << 24 |
bytes[5] << 16 |
bytes[6] << 8 |
bytes[7],
bytes[0] << 24 |
bytes[1] << 16 |
bytes[2] << 8 |
bytes[3],
unsigned
);
};
var long$1 = getDefaultExportFromCjs(long);
var LongExports = _mergeNamespaces({
__proto__: null,
default: long$1
}, [long]);
const Long =
long$1 || LongExports;
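// FarmHash Fingerprint64 over byte arrays, built on the Long helpers above;
// used by string hashing kernels such as StringToHashBucketFast.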
function hexToLong(hex) {
return Long.fromString(hex, true, 16);
}
const k0 = hexToLong('c3a5c85c97cb3127');
const k1 = hexToLong('b492b66fbe98f273');
const k2 = hexToLong('9ae16a3b2f90404f');
function shiftMix(val) {
return val.xor(val.shru(47));
}
function fetch(s, offset, numBytes) {
const bytes = s.slice(offset, offset + numBytes);
return Long.fromBytes(Array.from(bytes), true, true);
}
function fetch64(s, offset) {
return fetch(s, offset, 8);
}
function fetch32(s, offset) {
return fetch(s, offset, 4);
}
function rotate64(val, shift) {
return shift === 0 ? val : val.shru(shift).or(val.shl(64 - shift));
}
function hashLen16(u, v, mul = hexToLong('9ddfea08eb382d69')) {
let a = u.xor(v).mul(mul);
a = a.xor(a.shru(47));
let b = v.xor(a).mul(mul);
b = b.xor(b.shru(47));
b = b.mul(mul);
return b;
}
function weakHashLen32WithSeeds(w, x, y, z, a, b) {
a = a.add(w);
b = rotate64(b.add(a).add(z), 21);
const c = a;
a = a.add(x);
a = a.add(y);
b = b.add(rotate64(a, 44));
return [a.add(z), b.add(c)];
}
function weakHashLen32WithSeedsStr(s, offset, a, b) {
return weakHashLen32WithSeeds(fetch64(s, offset), fetch64(s, offset + 8), fetch64(s, offset + 16), fetch64(s, offset + 24), a, b);
}
function hashLen0to16(s, len = s.length) {
if (len >= 8) {
const mul = k2.add(len * 2);
const a = fetch64(s, 0).add(k2);
const b = fetch64(s, len - 8);
const c = rotate64(b, 37).mul(mul).add(a);
const d = rotate64(a, 25).add(b).mul(mul);
return hashLen16(c, d, mul);
}
if (len >= 4) {
const mul = k2.add(len * 2);
const a = fetch32(s, 0);
return hashLen16(a.shl(3).add(len), fetch32(s, len - 4), mul);
}
if (len > 0) {
const a = s[0];
const b = s[len >> 1];
const c = s[len - 1];
const y = a + (b << 8);
const z = len + (c << 2);
return shiftMix(k2.mul(y).xor(k0.mul(z))).mul(k2);
}
return k2;
}
function hashLen17to32(s, len = s.length) {
const mul = k2.add(len * 2);
const a = fetch64(s, 0).mul(k1);
const b = fetch64(s, 8);
const c = fetch64(s, len - 8).mul(mul);
const d = fetch64(s, len - 16).mul(k2);
return hashLen16(rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d), a.add(rotate64(b.add(k2), 18)).add(c), mul);
}
function hashLen33to64(s, len = s.length) {
const mul = k2.add(len * 2);
const a = fetch64(s, 0).mul(k2);
const b = fetch64(s, 8);
const c = fetch64(s, len - 8).mul(mul);
const d = fetch64(s, len - 16).mul(k2);
const y = rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d);
const z = hashLen16(y, a.add(rotate64(b.add(k2), 18)).add(c), mul);
const e = fetch64(s, 16).mul(mul);
const f = fetch64(s, 24);
const g = y.add(fetch64(s, len - 32)).mul(mul);
const h = z.add(fetch64(s, len - 24)).mul(mul);
return hashLen16(rotate64(e.add(f), 43).add(rotate64(g, 30)).add(h), e.add(rotate64(f.add(a), 18)).add(g), mul);
}
function fingerPrint64(s, len = s.length) {
const seed = Long.fromNumber(81, true);
if (len <= 32) {
if (len <= 16) {
return hashLen0to16(s, len);
}
else {
return hashLen17to32(s, len);
}
}
else if (len <= 64) {
return hashLen33to64(s, len);
}
let x = seed;
let y = seed.mul(k1).add(113);
let z = shiftMix(y.mul(k2).add(113)).mul(k2);
let v = [Long.UZERO, Long.UZERO];
let w = [Long.UZERO, Long.UZERO];
x = x.mul(k2).add(fetch64(s, 0));
let offset = 0;
const end = ((len - 1) >> 6) * 64;
const last64 = end + ((len - 1) & 63) - 63;
do {
x = rotate64(x.add(y).add(v[0]).add(fetch64(s, offset + 8)), 37).mul(k1);
y = rotate64(y.add(v[1]).add(fetch64(s, offset + 48)), 42).mul(k1);
x = x.xor(w[1]);
y = y.add(v[0]).add(fetch64(s, offset + 40));
z = rotate64(z.add(w[0]), 33).mul(k1);
v = weakHashLen32WithSeedsStr(s, offset, v[1].mul(k1), x.add(w[0]));
w = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w[1]), y.add(fetch64(s, offset + 16)));
[z, x] = [x, z];
offset += 64;
} while (offset !== end);
const mul = k1.add(z.and(0xff).shl(1));
offset = last64;
w[0] = w[0].add((len - 1) & 63);
v[0] = v[0].add(w[0]);
w[0] = w[0].add(v[0]);
x = rotate64(x.add(y).add(v[0]).add(fetch64(s, offset + 8)), 37).mul(mul);
y = rotate64(y.add(v[1]).add(fetch64(s, offset + 48)), 42).mul(mul);
x = x.xor(w[1].mul(9));
y = y.add(v[0].mul(9).add(fetch64(s, offset + 40)));
z = rotate64(z.add(w[0]), 33).mul(mul);
v = weakHashLen32WithSeedsStr(s, offset, v[1].mul(mul), x.add(w[0]));
w = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w[1]), y.add(fetch64(s, offset + 16)));
[z, x] = [x, z];
return hashLen16(hashLen16(v[0], w[0], mul).add(shiftMix(y).mul(k0)).add(z), hashLen16(v[1], w[1], mul).add(x), mul);
}
function createScalarValue(value, dtype) {
if (dtype === 'string') {
return encodeString(value);
}
return toTypedArray([value], dtype);
}
function noConversionNeeded(a, dtype) {
return (a instanceof Float32Array && dtype === 'float32') ||
(a instanceof Int32Array && dtype === 'int32') ||
(a instanceof Uint8Array && dtype === 'bool');
}
function toTypedArray(a, dtype) {
if (dtype === 'string') {
throw new Error('Cannot convert a string[] to a TypedArray');
}
if (Array.isArray(a)) {
a = flatten$1(a);
}
if (env().getBool('DEBUG')) {
checkConversionForErrors(a, dtype);
}
if (noConversionNeeded(a, dtype)) {
return a;
}
if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
return new Float32Array(a);
}
else if (dtype === 'int32') {
return new Int32Array(a);
}
else if (dtype === 'bool') {
const bool = new Uint8Array(a.length);
for (let i = 0; i < bool.length; ++i) {
if (Math.round(a[i]) !== 0) {
bool[i] = 1;
}
}
return bool;
}
else {
throw new Error(`Unknown data type ${dtype}`);
}
}
function now() {
return env().platform.now();
}
function encodeString(s, encoding = 'utf-8') {
encoding = encoding || 'utf-8';
return env().platform.encode(s, encoding);
}
function decodeString(bytes, encoding = 'utf-8') {
encoding = encoding || 'utf-8';
return env().platform.decode(bytes, encoding);
}
function isTypedArray(a) {
if (env().platform.isTypedArray != null) {
return env().platform.isTypedArray(a);
}
else {
return isTypedArrayBrowser(a);
}
}
function flatten$1(arr, result = [], skipTypedArray = false) {
if (result == null) {
result = [];
}
if (typeof arr === 'boolean' || typeof arr === 'number' ||
typeof arr === 'string' || isPromise(arr) || arr == null ||
isTypedArray(arr) && skipTypedArray) {
result.push(arr);
}
else if (Array.isArray(arr) || isTypedArray(arr)) {
for (let i = 0; i < arr.length; ++i) {
flatten$1(arr[i], result, skipTypedArray);
}
}
else {
let maxIndex = -1;
for (const key of Object.keys(arr)) {
if (/^([1-9]+[0-9]*|0)$/.test(key)) {
maxIndex = Math.max(maxIndex, Number(key));
}
}
for (let i = 0; i <= maxIndex; i++) {
flatten$1(arr[i], result, skipTypedArray);
}
}
return result;
}
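// Profiler times individual kernel invocations via the backend timer and, when
// the CHECK_COMPUTATION_FOR_ERRORS flag is set, scans float32 outputs for
// NaN/Infinity values.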
class Profiler {
constructor(backendTimer, logger) {
this.backendTimer = backendTimer;
this.logger = logger;
if (logger == null) {
this.logger = new Logger();
}
}
profileKernel(kernelName, inputs, f) {
let outputs;
const holdResultWrapperFn = () => {
outputs = f();
};
let timer;
const start = now();
if (this.backendTimer.timerAvailable()) {
timer = this.backendTimer.time(holdResultWrapperFn);
}
else {
holdResultWrapperFn();
for (const output of outputs) {
output.dataSync();
}
timer = Promise.resolve({ kernelMs: now() - start });
}
if (env().getBool('CHECK_COMPUTATION_FOR_ERRORS')) {
for (let i = 0; i < outputs.length; i++) {
const output = outputs[i];
output.data().then(tensorVals => {
checkComputationForErrors(tensorVals, output.dtype, kernelName);
});
}
}
const kernelProfile = {
kernelName,
outputs,
inputs,
timeMs: timer.then(timing => timing.kernelMs),
extraInfo: timer.then(timing => timing.getExtraProfileInfo != null ?
timing.getExtraProfileInfo() :
'')
};
return kernelProfile;
}
logKernelProfile(kernelProfile) {
const { kernelName, outputs, timeMs, inputs, extraInfo } = kernelProfile;
outputs.forEach(result => {
Promise.all([result.data(), timeMs, extraInfo]).then(valueContainer => {
this.logger.logKernelProfile(kernelName, result, valueContainer[0], valueContainer[1], inputs, valueContainer[2]);
});
});
}
}
function checkComputationForErrors(vals, dtype, kernelName) {
if (dtype !== 'float32') {
return false;
}
for (let i = 0; i < vals.length; i++) {
const num = vals[i];
if (isNaN(num) || !isFinite(num)) {
console.warn(`Found ${num} in the result of '${kernelName}'`);
return true;
}
}
return false;
}
class Logger {
logKernelProfile(name, result, vals, timeMs, inputs, extraInfo) {
const time = typeof timeMs === 'number' ? rightPad(`${timeMs}ms`, 9) :
timeMs['error'];
const paddedName = rightPad(name, 25);
const rank = result.rank;
const size = result.size;
const shape = rightPad(result.shape.toString(), 14);
let inputShapesDescription = '';
for (const name in inputs) {
const input = inputs[name];
if (input != null) {
const inputShape = input.shape || result.shape;
const inputRank = inputShape.length;
inputShapesDescription +=
`${name}: ${inputRank}D ${inputRank > 0 ? inputShape : ''} `;
}
}
console.log(`%c${paddedName}\t%c${time}\t%c${rank}D ${shape}\t%c${size}\t%c${inputShapesDescription}\t%c${extraInfo}`, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue');
}
}
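// Autodiff tape filtering: keep only the nodes that lie on a path from the
// inputs xs to the output y.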
function getFilteredNodesXToY(tape, xs, y) {
const tensorsFromX = {};
const nodesFromX = {};
for (let i = 0; i < xs.length; i++) {
tensorsFromX[xs[i].id] = true;
}
for (let i = 0; i < tape.length; i++) {
const node = tape[i];
const nodeInputs = node.inputs;
for (const inputName in nodeInputs) {
const input = nodeInputs[inputName];
let anyInputFromX = false;
for (let j = 0; j < xs.length; j++) {
if (tensorsFromX[input.id]) {
node.outputs.forEach(output => tensorsFromX[output.id] = true);
anyInputFromX = true;
nodesFromX[node.id] = true;
break;
}
}
if (anyInputFromX) {
break;
}
}
}
const tensorsLeadToY = {};
tensorsLeadToY[y.id] = true;
const nodesToY = {};
for (let i = tape.length - 1; i >= 0; i--) {
const node = tape[i];
const nodeInputs = node.inputs;
for (let j = 0; j < node.outputs.length; j++) {
if (tensorsLeadToY[node.outputs[j].id]) {
for (const inputName in nodeInputs) {
tensorsLeadToY[nodeInputs[inputName].id] = true;
nodesToY[node.id] = true;
}
break;
}
}
}
const filteredTape = [];
for (let i = 0; i < tape.length; i++) {
const node = tape[i];
if (nodesFromX[node.id] && nodesToY[node.id]) {
const prunedInputs = {};
for (const inputName in node.inputs) {
const nodeInput = node.inputs[inputName];
if (tensorsFromX[nodeInput.id]) {
prunedInputs[inputName] = nodeInput;
}
}
const prunedNode = Object.assign({}, node);
prunedNode.inputs = prunedInputs;
prunedNode.outputs = node.outputs;
filteredTape.push(prunedNode);
}
}
return filteredTape;
}
function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy, add) {
for (let i = filteredTape.length - 1; i >= 0; i--) {
const node = filteredTape[i];
const dys = [];
node.outputs.forEach(o => {
const gradTensor = tensorAccumulatedGradientMap[o.id];
if (gradTensor != null) {
dys.push(gradTensor);
}
else {
dys.push(null);
}
});
if (node.gradient == null) {
throw new Error(`Cannot compute gradient: gradient function not found ` +
`for ${node.kernelName}.`);
}
const inputGradients = node.gradient(dys);
for (const inputName in node.inputs) {
if (!(inputName in inputGradients)) {
throw new Error(`Cannot backprop through input ${inputName}. ` +
`Available gradients found: ${Object.keys(inputGradients)}.`);
}
const dx = tidy(() => inputGradients[inputName]());
if (dx.dtype !== 'float32') {
throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` +
`${inputName} must have 'float32' dtype, but has '${dx.dtype}'`);
}
const x = node.inputs[inputName];
if (!arraysEqual(dx.shape, x.shape)) {
throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` +
`'${inputName}' has shape '${dx.shape}', which does not match ` +
`the shape of the input '${x.shape}'`);
}
if (tensorAccumulatedGradientMap[x.id] == null) {
tensorAccumulatedGradientMap[x.id] = dx;
}
else {
const curGradient = tensorAccumulatedGradientMap[x.id];
tensorAccumulatedGradientMap[x.id] = add(curGradient, dx);
curGradient.dispose();
}
}
}
}
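// Pretty-printing helpers for Tensor.toString(); dimensions with more than
// FORMAT_LIMIT_NUM_VALS entries are abbreviated to the first and last
// FORMAT_NUM_FIRST_LAST_VALS values.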
const FORMAT_LIMIT_NUM_VALS = 20;
const FORMAT_NUM_FIRST_LAST_VALS = 3;
const FORMAT_NUM_SIG_DIGITS = 7;
function tensorToString(vals, shape, dtype, verbose) {
const strides = computeStrides(shape);
const padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides);
const rank = shape.length;
const valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol);
const lines = ['Tensor'];
if (verbose) {
lines.push(` dtype: ${dtype}`);
lines.push(` rank: ${rank}`);
lines.push(` shape: [${shape}]`);
lines.push(` values:`);
}
lines.push(valsLines.map(l => ' ' + l).join('\n'));
return lines.join('\n');
}
function computeMaxSizePerColumn(vals, shape, dtype, strides) {
const n = sizeFromShape(shape);
const numCols = strides[strides.length - 1];
const padPerCol = new Array(numCols).fill(0);
const rank = shape.length;
const valuesOrTuples = dtype === 'complex64' ? createComplexTuples(vals) : vals;
if (rank > 1) {
for (let row = 0; row < n / numCols; row++) {
const offset = row * numCols;
for (let j = 0; j < numCols; j++) {
padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length);
}
}
}
return padPerCol;
}
function valToString(val, pad, dtype) {
let valStr;
if (Array.isArray(val)) {
valStr = `${parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS))} + ` +
`${parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS))}j`;
}
else if (isString(val)) {
valStr = `'${val}'`;
}
else if (dtype === 'bool') {
valStr = boolNumToString(val);
}
else {
valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString();
}
return rightPad(valStr, pad);
}
function boolNumToString(v) {
return v === 0 ? 'false' : 'true';
}
function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast = true) {
const storagePerElement = dtype === 'complex64' ? 2 : 1;
const size = shape[0];
const rank = shape.length;
if (rank === 0) {
if (dtype === 'complex64') {
const complexTuple = createComplexTuples(vals);
return [valToString(complexTuple[0], 0, dtype)];
}
if (dtype === 'bool') {
return [boolNumToString(vals[0])];
}
return [vals[0].toString()];
}
if (rank === 1) {
if (size > FORMAT_LIMIT_NUM_VALS) {
const firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement;
let firstVals = Array.from(vals.slice(0, firstValsSize));
let lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement));
if (dtype === 'complex64') {
firstVals = createComplexTuples(firstVals);
lastVals = createComplexTuples(lastVals);
}
return [
'[' +
firstVals.map((x, i) => valToString(x, padPerCol[i], dtype))
.join(', ') +
', ..., ' +
lastVals
.map((x, i) => valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype))
.join(', ') +
']'
];
}
const displayVals = dtype === 'complex64' ? createComplexTuples(vals) :
Array.from(vals);
return [
'[' +
displayVals.map((x, i) => valToString(x, padPerCol[i], dtype))
.join(', ') +
']'
];
}
const subshape = shape.slice(1);
const substrides = strides.slice(1);
const stride = strides[0] * storagePerElement;
const lines = [];
if (size > FORMAT_LIMIT_NUM_VALS) {
for (let i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) {
const start = i * stride;
const end = start + stride;
            lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false /* isLast */));
}
lines.push('...');
for (let i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) {
const start = i * stride;
const end = start + stride;
            lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */));
}
}
else {
for (let i = 0; i < size; i++) {
const start = i * stride;
const end = start + stride;
            lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */));
}
}
const sep = rank === 2 ? ',' : '';
lines[0] = '[' + (size > 0 ? lines[0] + sep : '');
for (let i = 1; i < lines.length - 1; i++) {
lines[i] = ' ' + lines[i] + sep;
}
let newLineSep = ',\n';
for (let i = 2; i < rank; i++) {
newLineSep += '\n';
}
lines[lines.length - 1] =
' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep);
return lines;
}
function createComplexTuples(vals) {
const complexTuples = [];
for (let i = 0; i < vals.length; i += 2) {
complexTuples.push([vals[i], vals[i + 1]]);
}
return complexTuples;
}
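/**
 * A mutable, CPU-side buffer of values with a fixed shape and dtype.
 * Values are read and written by multidimensional coordinates via get()/set(),
 * and the buffer can be turned into an immutable Tensor with toTensor().
 * complex64 buffers are intentionally unsupported.
 */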
class TensorBuffer {
constructor(shape, dtype, values) {
this.dtype = dtype;
this.shape = shape.slice();
this.size = sizeFromShape(shape);
if (values != null) {
const n = values.length;
assert$1(n === this.size, () => `Length of values '${n}' does not match the size ` +
`inferred by the shape '${this.size}'.`);
}
if (dtype === 'complex64') {
throw new Error(`complex64 dtype TensorBuffers are not supported. Please create ` +
`a TensorBuffer for the real and imaginary parts separately and ` +
`call tf.complex(real, imag).`);
}
this.values = values || getArrayFromDType(dtype, this.size);
this.strides = computeStrides(shape);
}
set(value, ...locs) {
if (locs.length === 0) {
locs = [0];
}
assert$1(locs.length === this.rank, () => `The number of provided coordinates (${locs.length}) must ` +
`match the rank (${this.rank})`);
const index = this.locToIndex(locs);
this.values[index] = value;
}
get(...locs) {
if (locs.length === 0) {
locs = [0];
}
let i = 0;
for (const loc of locs) {
if (loc < 0 || loc >= this.shape[i]) {
                const msg = `Requested out of range element at ${locs}. ` +
                    `Buffer shape=${this.shape}`;
throw new Error(msg);
}
i++;
}
let index = locs[locs.length - 1];
for (let i = 0; i < locs.length - 1; ++i) {
index += this.strides[i] * locs[i];
}
return this.values[index];
}
locToIndex(locs) {
if (this.rank === 0) {
return 0;
}
else if (this.rank === 1) {
return locs[0];
}
let index = locs[locs.length - 1];
for (let i = 0; i < locs.length - 1; ++i) {
index += this.strides[i] * locs[i];
}
return index;
}
indexToLoc(index) {
if (this.rank === 0) {
return [];
}
else if (this.rank === 1) {
return [index];
}
const locs = new Array(this.shape.length);
for (let i = 0; i < locs.length - 1; ++i) {
locs[i] = Math.floor(index / this.strides[i]);
index -= locs[i] * this.strides[i];
}
locs[locs.length - 1] = index;
return locs;
}
get rank() {
return this.shape.length;
}
toTensor() {
return trackerFn().makeTensor(this.values, this.shape, this.dtype);
}
}
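// The tensor tracker (engine) and op handler are injected later via
// setTensorTracker()/setOpHandler() so that the Tensor class below does not
// depend directly on the engine and op modules.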
let trackerFn = null;
let opHandler$1 = null;
function setTensorTracker(fn) {
trackerFn = fn;
}
function setOpHandler(handler) {
opHandler$1 = handler;
}
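/**
 * A handle to a block of backend-resident data with a shape and dtype.
 * The values themselves live in the current backend and are fetched with
 * data()/dataSync(); most other methods delegate to the injected op handler
 * or tensor tracker.
 */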
class Tensor {
constructor(shape, dtype, dataId, id) {
this.kept = false;
this.isDisposedInternal = false;
this.shape = shape.slice();
this.dtype = dtype || 'float32';
this.size = sizeFromShape(shape);
this.strides = computeStrides(shape);
this.dataId = dataId;
this.id = id;
this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher');
}
get rank() {
return this.shape.length;
}
async buffer() {
const vals = await this.data();
return opHandler$1.buffer(this.shape, this.dtype, vals);
}
bufferSync() {
return opHandler$1.buffer(this.shape, this.dtype, this.dataSync());
}
async array() {
const vals = await this.data();
return toNestedArray(this.shape, vals, this.dtype === 'complex64');
}
arraySync() {
return toNestedArray(this.shape, this.dataSync(), this.dtype === 'complex64');
}
async data() {
this.throwIfDisposed();
const data = trackerFn().read(this.dataId);
if (this.dtype === 'string') {
const bytes = await data;
try {
return bytes.map(b => decodeString(b));
}
catch (_a) {
throw new Error('Failed to decode the string bytes into utf-8. ' +
'To get the original bytes, call tensor.bytes().');
}
}
return data;
}
dataToGPU(options) {
this.throwIfDisposed();
return trackerFn().readToGPU(this.dataId, options);
}
dataSync() {
this.throwIfDisposed();
const data = trackerFn().readSync(this.dataId);
if (this.dtype === 'string') {
try {
return data.map(b => decodeString(b));
}
catch (_a) {
throw new Error('Failed to decode the string bytes into utf-8. ' +
'To get the original bytes, call tensor.bytes().');
}
}
return data;
}
async bytes() {
this.throwIfDisposed();
const data = await trackerFn().read(this.dataId);
if (this.dtype === 'string') {
return data;
}
else {
return new Uint8Array(data.buffer);
}
}
dispose() {
if (this.isDisposed) {
return;
}
if (this.kerasMask) {
this.kerasMask.dispose();
}
trackerFn().disposeTensor(this);
this.isDisposedInternal = true;
}
get isDisposed() {
return this.isDisposedInternal;
}
throwIfDisposed() {
if (this.isDisposed) {
throw new Error(`Tensor is disposed.`);
}
}
print(verbose = false) {
return opHandler$1.print(this, verbose);
}
clone() {
this.throwIfDisposed();
return opHandler$1.clone(this);
}
toString(verbose = false) {
const vals = this.dataSync();
return tensorToString(vals, this.shape, this.dtype, verbose);
}
cast(dtype) {
this.throwIfDisposed();
return opHandler$1.cast(this, dtype);
}
variable(trainable = true, name, dtype) {
this.throwIfDisposed();
return trackerFn().makeVariable(this, trainable, name, dtype);
}
}
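// Use duck typing for `instanceof Tensor` so that tensors created by another
// bundled copy of the library are still recognized.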
Object.defineProperty(Tensor, Symbol.hasInstance, {
value: (instance) => {
return !!instance && instance.data != null && instance.dataSync != null &&
instance.throwIfDisposed != null;
}
});
function getGlobalTensorClass() {
return getGlobal('Tensor', () => {
return Tensor;
});
}
getGlobalTensorClass();
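/**
 * A mutable tensor, registered with the engine under a unique name.
 * assign() swaps in the new value's data (shape and dtype must match) without
 * creating a new Variable object.
 */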
class Variable extends Tensor {
constructor(initialValue, trainable, name, tensorId) {
super(initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId);
this.trainable = trainable;
this.name = name;
}
assign(newValue) {
if (newValue.dtype !== this.dtype) {
throw new Error(`dtype of the new value (${newValue.dtype}) and ` +
`previous value (${this.dtype}) must match`);
}
if (!arraysEqual(newValue.shape, this.shape)) {
throw new Error(`shape of the new value (${newValue.shape}) and ` +
`previous value (${this.shape}) must match`);
}
trackerFn().disposeTensor(this);
this.dataId = newValue.dataId;
        trackerFn().incRef(this, null /* backend */);
}
dispose() {
trackerFn().disposeVariable(this);
this.isDisposedInternal = true;
}
}
Object.defineProperty(Variable, Symbol.hasInstance, {
value: (instance) => {
return instance instanceof Tensor && instance.assign != null &&
instance.assign instanceof Function;
}
});
var Rank;
(function (Rank) {
Rank["R0"] = "R0";
Rank["R1"] = "R1";
Rank["R2"] = "R2";
Rank["R3"] = "R3";
Rank["R4"] = "R4";
Rank["R5"] = "R5";
Rank["R6"] = "R6";
})(Rank || (Rank = {}));
var UpcastInt32AndMap;
(function (UpcastInt32AndMap) {
UpcastInt32AndMap["float32"] = "float32";
UpcastInt32AndMap["int32"] = "int32";
UpcastInt32AndMap["bool"] = "int32";
UpcastInt32AndMap["complex64"] = "complex64";
})(UpcastInt32AndMap || (UpcastInt32AndMap = {}));
var UpcastBoolAndMap;
(function (UpcastBoolAndMap) {
UpcastBoolAndMap["float32"] = "float32";
UpcastBoolAndMap["int32"] = "int32";
UpcastBoolAndMap["bool"] = "bool";
UpcastBoolAndMap["complex64"] = "complex64";
})(UpcastBoolAndMap || (UpcastBoolAndMap = {}));
var UpcastFloat32AndMap;
(function (UpcastFloat32AndMap) {
UpcastFloat32AndMap["float32"] = "float32";
UpcastFloat32AndMap["int32"] = "float32";
UpcastFloat32AndMap["bool"] = "float32";
UpcastFloat32AndMap["complex64"] = "complex64";
})(UpcastFloat32AndMap || (UpcastFloat32AndMap = {}));
var UpcastComplex64AndMap;
(function (UpcastComplex64AndMap) {
UpcastComplex64AndMap["float32"] = "complex64";
UpcastComplex64AndMap["int32"] = "complex64";
UpcastComplex64AndMap["bool"] = "complex64";
UpcastComplex64AndMap["complex64"] = "complex64";
})(UpcastComplex64AndMap || (UpcastComplex64AndMap = {}));
const upcastTypeMap = {
'float32': UpcastFloat32AndMap,
'int32': UpcastInt32AndMap,
'bool': UpcastBoolAndMap,
'complex64': UpcastComplex64AndMap
};
function upcastType(typeA, typeB) {
if (typeA === 'string' || typeB === 'string') {
if (typeA === 'string' && typeB === 'string') {
return 'string';
}
throw new Error(`Can not upcast ${typeA} with ${typeB}`);
}
return upcastTypeMap[typeA][typeB];
}
function sumOutType(type) {
return upcastType(type, 'int32');
}
function isWebGLData(values) {
return values != null && typeof values === 'object' && 'texture' in values &&
values.texture instanceof WebGLTexture;
}
function isWebGPUData(values) {
return typeof GPUBuffer !== 'undefined' && values != null &&
typeof values === 'object' && 'buffer' in values &&
values.buffer instanceof GPUBuffer;
}
function makeTypesMatch(a, b) {
if (a.dtype === b.dtype) {
return [a, b];
}
const dtype = upcastType(a.dtype, b.dtype);
return [a.cast(dtype), b.cast(dtype)];
}
function getTensorsInContainer(result) {
const list = [];
const seen = new Set();
walkTensorContainer(result, list, seen);
return list;
}
function walkTensorContainer(container, list, seen) {
if (container == null) {
return;
}
if (container instanceof Tensor) {
list.push(container);
return;
}
if (!isIterable(container)) {
return;
}
const iterable = container;
for (const k in iterable) {
const val = iterable[k];
if (!seen.has(val)) {
seen.add(val);
walkTensorContainer(val, list, seen);
}
}
}
function isIterable(obj) {
return Array.isArray(obj) || typeof obj === 'object';
}
function isRegisteredKernelInvocation(kernelInvocation) {
return kernelInvocation.kernelName != null;
}
class EngineState {
constructor() {
this.registeredVariables = {};
this.nextTapeNodeId = 0;
this.numBytes = 0;
this.numTensors = 0;
this.numStringTensors = 0;
this.numDataBuffers = 0;
this.gradientDepth = 0;
this.kernelDepth = 0;
this.scopeStack = [];
this.numDataMovesStack = [];
this.nextScopeId = 0;
this.tensorInfo = new WeakMap();
this.profiling = false;
this.activeProfile = {
newBytes: 0,
newTensors: 0,
peakBytes: 0,
kernels: [],
result: null,
get kernelNames() {
return Array.from(new Set(this.kernels.map(k => k.name)));
}
};
}
dispose() {
for (const variableName in this.registeredVariables) {
this.registeredVariables[variableName].dispose();
}
}
}
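/**
 * Central coordinator of the library: owns the registered backends, routes
 * kernel execution, tracks live tensors and memory, maintains the gradient
 * tape, and implements tidy()/keep() scoping and profiling.
 */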
class Engine {
constructor(ENV) {
this.ENV = ENV;
this.registry = {};
this.registryFactory = {};
this.pendingBackendInitId = 0;
this.state = new EngineState();
}
async ready() {
if (this.pendingBackendInit != null) {
return this.pendingBackendInit.then(() => { });
}
if (this.backendInstance != null) {
return;
}
const sortedBackends = this.getSortedBackends();
for (let i = 0; i < sortedBackends.length; i++) {
const backendName = sortedBackends[i];
const success = await this.initializeBackend(backendName).success;
if (success) {
await this.setBackend(backendName);
return;
}
}
        throw new Error(`Could not initialize any backends; all backend ` +
            `initializations failed.`);
}
get backend() {
if (this.pendingBackendInit != null) {
throw new Error(`Backend '${this.backendName}' has not yet been initialized. Make ` +
`sure to await tf.ready() or await tf.setBackend() before calling ` +
`other methods`);
}
if (this.backendInstance == null) {
const { name, asyncInit } = this.initializeBackendsAndReturnBest();
if (asyncInit) {
throw new Error(`The highest priority backend '${name}' has not yet been ` +
`initialized. Make sure to await tf.ready() or ` +
`await tf.setBackend() before calling other methods`);
}
this.setBackend(name);
}
return this.backendInstance;
}
backendNames() {
return Object.keys(this.registryFactory);
}
findBackend(backendName) {
if (!(backendName in this.registry)) {
if (backendName in this.registryFactory) {
const { asyncInit } = this.initializeBackend(backendName);
if (asyncInit) {
return null;
}
}
else {
return null;
}
}
return this.registry[backendName];
}
findBackendFactory(backendName) {
if (!(backendName in this.registryFactory)) {
return null;
}
return this.registryFactory[backendName].factory;
}
registerBackend(backendName, factory, priority = 1) {
if (backendName in this.registryFactory) {
warn(`${backendName} backend was already registered. ` +
`Reusing existing backend factory.`);
return false;
}
this.registryFactory[backendName] = { factory, priority };
return true;
}
async setBackend(backendName) {
if (this.registryFactory[backendName] == null) {
throw new Error(`Backend name '${backendName}' not found in registry`);
}
this.backendName = backendName;
if (this.registry[backendName] == null) {
this.backendInstance = null;
const { success, asyncInit } = this.initializeBackend(backendName);
const result = asyncInit ? await success : success;
if (!result) {
return false;
}
}
this.backendInstance = this.registry[backendName];
this.setupRegisteredKernels();
this.profiler = new Profiler(this.backendInstance);
return true;
}
setupRegisteredKernels() {
const kernels = getKernelsForBackend(this.backendName);
kernels.forEach(kernel => {
if (kernel.setupFunc != null) {
kernel.setupFunc(this.backendInstance);
}
});
}
disposeRegisteredKernels(backendName) {
const kernels = getKernelsForBackend(backendName);
kernels.forEach(kernel => {
if (kernel.disposeFunc != null) {
kernel.disposeFunc(this.registry[backendName]);
}
});
}
initializeBackend(backendName) {
const registryFactoryEntry = this.registryFactory[backendName];
if (registryFactoryEntry == null) {
throw new Error(`Cannot initialize backend ${backendName}, no registration found.`);
}
try {
const backend = registryFactoryEntry.factory();
if (backend && !(backend instanceof KernelBackend) &&
typeof backend.then === 'function') {
const promiseId = ++this.pendingBackendInitId;
const success = backend
.then(backendInstance => {
if (promiseId < this.pendingBackendInitId) {
return false;
}
this.registry[backendName] = backendInstance;
this.pendingBackendInit = null;
return true;
})
.catch(err => {
if (promiseId < this.pendingBackendInitId) {
return false;
}
this.pendingBackendInit = null;
warn(`Initialization of backend ${backendName} failed`);
warn(err.stack || err.message);
return false;
});
this.pendingBackendInit = success;
return { success, asyncInit: true };
}
else {
this.registry[backendName] = backend;
return { success: true, asyncInit: false };
}
}
catch (err) {
warn(`Initialization of backend ${backendName} failed`);
warn(err.stack || err.message);
return { success: false, asyncInit: false };
}
}
removeBackend(backendName) {
if (!(backendName in this.registryFactory)) {
throw new Error(`${backendName} backend not found in registry`);
}
if (this.backendName === backendName && this.pendingBackendInit != null) {
this.pendingBackendInitId++;
}
if (backendName in this.registry) {
this.disposeRegisteredKernels(backendName);
this.registry[backendName].dispose();
delete this.registry[backendName];
}
delete this.registryFactory[backendName];
if (this.backendName === backendName) {
this.pendingBackendInit = null;
this.backendName = null;
this.backendInstance = null;
}
}
getSortedBackends() {
if (Object.keys(this.registryFactory).length === 0) {
throw new Error('No backend found in registry.');
}
return Object.keys(this.registryFactory).sort((a, b) => {
return this.registryFactory[b].priority -
this.registryFactory[a].priority;
});
}
initializeBackendsAndReturnBest() {
const sortedBackends = this.getSortedBackends();
for (let i = 0; i < sortedBackends.length; i++) {
const backendName = sortedBackends[i];
const { success, asyncInit } = this.initializeBackend(backendName);
if (asyncInit || success) {
return { name: backendName, asyncInit };
}
}
        throw new Error(`Could not initialize any backends; all backend ` +
            `initializations failed.`);
}
moveData(backend, dataId) {
const info = this.state.tensorInfo.get(dataId);
const srcBackend = info.backend;
const values = this.readSync(dataId);
const refCount = srcBackend.refCount(dataId);
srcBackend.disposeData(dataId, true);
info.backend = backend;
backend.move(dataId, values, info.shape, info.dtype, refCount);
if (this.shouldCheckForMemLeaks()) {
this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
}
}
tidy(nameOrFn, fn) {
let name = null;
if (fn == null) {
if (typeof nameOrFn !== 'function') {
throw new Error('Please provide a function to tidy()');
}
fn = nameOrFn;
}
else {
if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
throw new Error('When calling with two arguments, the first argument ' +
'to tidy() must be a string');
}
if (typeof fn !== 'function') {
                throw new Error('When calling with two arguments, the second ' +
                    'argument to tidy() must be a function');
}
name = nameOrFn;
}
let result;
return this.scopedRun(() => this.startScope(name), () => this.endScope(result), () => {
result = fn();
if (result instanceof Promise) {
console.error('Cannot return a Promise inside of tidy.');
}
return result;
});
}
scopedRun(start, end, f) {
start();
try {
const res = f();
end();
return res;
}
catch (ex) {
end();
throw ex;
}
}
nextTensorId() {
return Engine.nextTensorId++;
}
nextVariableId() {
return Engine.nextVariableId++;
}
clone(x) {
const y = ENGINE.runKernel(Identity$1, { x });
const inputs = { x };
const grad = (dy) => ({
x: () => {
const dtype = 'float32';
const gradInputs = { x: dy };
const attrs = { dtype };
return ENGINE.runKernel(Cast, gradInputs,
attrs);
}
});
const saved = [];
this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});
return y;
}
runKernel(kernelName, inputs, attrs) {
const hasKernel = getKernel(kernelName, this.backendName) != null;
if (!hasKernel) {
throw new Error(`Kernel '${kernelName}' not registered for backend '${this.backendName}'`);
}
return this.runKernelFunc({ kernelName, inputs, attrs });
}
shouldCheckForMemLeaks() {
return this.ENV.getBool('IS_TEST');
}
checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos) {
const numDataIdsAfter = this.backend.numDataIds();
let numOutputDataIds = 0;
outInfos.forEach(info => {
numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1);
});
const numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1];
const dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
if (dataIdsLeaked > 0) {
throw new Error(`Backend '${this.backendName}' has an internal memory leak ` +
`(${dataIdsLeaked} data ids) after running '${kernelName}'`);
}
}
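    // Executes either a registered kernel or a custom forward function,
    // converting raw outputs to tensors, recording a tape node when gradients
    // are being taken, and performing memory-leak checks / profiling when the
    // corresponding flags are enabled.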
runKernelFunc(kernelParams) {
let outputs;
let saved = [];
const isTapeOn = this.isTapeOn();
const startingBytecount = this.state.numBytes;
const startingNumTensors = this.state.numTensors;
if (this.shouldCheckForMemLeaks()) {
this.state.numDataMovesStack.push(0);
}
let kernelFunc;
let out;
const kernelOrScopeName = isRegisteredKernelInvocation(kernelParams) ?
kernelParams.kernelName :
this.state.activeScope != null ? this.state.activeScope.name : '';
if (isRegisteredKernelInvocation(kernelParams)) {
const { kernelName, inputs, attrs } = kernelParams;
const kernel = getKernel(kernelName, this.backendName);
assert$1(kernel != null, () => `Cannot find registered kernel '${kernelName}' for backend '${this.backendName}'`);
kernelFunc = () => {
const numDataIdsBefore = this.backend.numDataIds();
out = kernel.kernelFunc({ inputs, attrs, backend: this.backend });
const outInfos = Array.isArray(out) ? out : [out];
if (this.shouldCheckForMemLeaks()) {
this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
}
const outTensors = outInfos.map((outInfo) => {
if (outInfo.rank != null) {
return outInfo;
}
return this.makeTensorFromTensorInfo(outInfo);
});
if (isTapeOn) {
const tensorsToSave = this.getTensorsForGradient(kernelName, inputs, outTensors);
saved = this.saveTensorsForBackwardMode(tensorsToSave);
}
return outTensors;
};
}
else {
const { forwardFunc } = kernelParams;
const saveFunc = (tensors) => {
if (!isTapeOn) {
return;
}
saved = tensors.map(tensor => this.keep(this.clone(tensor)));
};
kernelFunc = () => {
const numDataIdsBefore = this.backend.numDataIds();
out = this.tidy(() => forwardFunc(this.backend, saveFunc));
const outs = (Array.isArray(out) ? out : [out]);
if (this.shouldCheckForMemLeaks()) {
this.checkKernelForMemLeak(kernelOrScopeName, numDataIdsBefore, outs);
}
return outs;
};
}
const { inputs, attrs } = kernelParams;
const backwardsFunc = isRegisteredKernelInvocation(kernelParams) ?
null :
kernelParams.backwardsFunc;
let kernelProfile;
this.scopedRun(
() => this.state.kernelDepth++, () => this.state.kernelDepth--, () => {
if (!this.ENV.getBool('DEBUG') && !this.state.profiling) {
outputs = kernelFunc();
}
else {
kernelProfile = this.profiler.profileKernel(kernelOrScopeName, inputs, () => kernelFunc());
if (this.ENV.getBool('DEBUG')) {
this.profiler.logKernelProfile(kernelProfile);
}
outputs = kernelProfile.outputs;
}
});
if (isTapeOn) {
this.addTapeNode(kernelOrScopeName, inputs, outputs, backwardsFunc, saved, attrs);
}
if (this.state.profiling) {
this.state.activeProfile.kernels.push({
name: kernelOrScopeName,
bytesAdded: this.state.numBytes - startingBytecount,
totalBytesSnapshot: this.state.numBytes,
tensorsAdded: this.state.numTensors - startingNumTensors,
totalTensorsSnapshot: this.state.numTensors,
inputShapes: Object.keys(inputs).map(key => inputs[key] != null ? inputs[key].shape : null),
outputShapes: outputs.map(item => item.shape),
kernelTimeMs: kernelProfile.timeMs,
extraInfo: kernelProfile.extraInfo
});
}
return (Array.isArray(out) ? outputs : outputs[0]);
}
saveTensorsForBackwardMode(tensors) {
const saved = tensors.map(tensor => this.keep(this.clone(tensor)));
return saved;
}
getTensorsForGradient(kernelName, inputs, outputs) {
const gradConfig = getGradient(kernelName);
if (gradConfig != null) {
const inputsToSave = gradConfig.inputsToSave || [];
const outputsToSave = gradConfig.outputsToSave || [];
let inputTensorsToSave;
if (gradConfig.saveAllInputs) {
assert$1(Array.isArray(inputs), () => 'saveAllInputs is true, expected inputs to be an array.');
inputTensorsToSave = Object.keys(inputs).map((key) => inputs[key]);
}
else {
inputTensorsToSave = inputsToSave.map((inputName) => inputs[inputName]);
}
const outputTensorsToSave = outputs.filter((_, i) => outputsToSave[i]);
return inputTensorsToSave.concat(outputTensorsToSave);
}
return [];
}
makeTensor(values, shape, dtype, backend) {
if (values == null) {
throw new Error('Values passed to engine.makeTensor() are null');
}
dtype = dtype || 'float32';
backend = backend || this.backend;
let backendVals = values;
if (dtype === 'string' && isString(values[0])) {
backendVals = values.map(d => encodeString(d));
}
const dataId = backend.write(backendVals, shape, dtype);
const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
this.trackTensor(t, backend);
if (dtype === 'string') {
const info = this.state.tensorInfo.get(dataId);
const newBytes = bytesFromStringArray(backendVals);
this.state.numBytes += newBytes - info.bytes;
info.bytes = newBytes;
}
return t;
}
makeTensorFromDataId(dataId, shape, dtype, backend) {
dtype = dtype || 'float32';
const tensorInfo = { dataId, shape, dtype };
return this.makeTensorFromTensorInfo(tensorInfo, backend);
}
makeTensorFromTensorInfo(tensorInfo, backend) {
const { dataId, shape, dtype } = tensorInfo;
const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
this.trackTensor(t, backend);
return t;
}
makeVariable(initialValue, trainable = true, name, dtype) {
name = name || this.nextVariableId().toString();
if (dtype != null && dtype !== initialValue.dtype) {
initialValue = initialValue.cast(dtype);
}
const v = new Variable(initialValue, trainable, name, this.nextTensorId());
if (this.state.registeredVariables[v.name] != null) {
throw new Error(`Variable with name ${v.name} was already registered`);
}
this.state.registeredVariables[v.name] = v;
this.incRef(v, this.backend);
return v;
}
trackTensor(a, backend) {
this.state.numTensors++;
if (a.dtype === 'string') {
this.state.numStringTensors++;
}
let bytes = 0;
if (a.dtype !== 'complex64' && a.dtype !== 'string') {
bytes = a.size * bytesPerElement(a.dtype);
}
this.state.numBytes += bytes;
if (!this.state.tensorInfo.has(a.dataId)) {
this.state.numDataBuffers++;
this.state.tensorInfo.set(a.dataId, {
backend: backend || this.backend,
dtype: a.dtype,
shape: a.shape,
bytes
});
}
if (!(a instanceof Variable)) {
this.track(a);
}
}
incRef(a, backend) {
this.trackTensor(a, backend);
this.backend.incRef(a.dataId);
}
removeDataId(dataId, backend) {
if (this.state.tensorInfo.has(dataId) &&
this.state.tensorInfo.get(dataId).backend === backend) {
this.state.tensorInfo.delete(dataId);
this.state.numDataBuffers--;
}
}
disposeTensor(a) {
if (!this.state.tensorInfo.has(a.dataId)) {
return;
}
const info = this.state.tensorInfo.get(a.dataId);
this.state.numTensors--;
if (a.dtype === 'string') {
this.state.numStringTensors--;
this.state.numBytes -= info.bytes;
}
if (a.dtype !== 'complex64' && a.dtype !== 'string') {
const bytes = a.size * bytesPerElement(a.dtype);
this.state.numBytes -= bytes;
}
if (info.backend.disposeData(a.dataId)) {
this.removeDataId(a.dataId, info.backend);
}
}
disposeVariables() {
for (const varName in this.state.registeredVariables) {
const v = this.state.registeredVariables[varName];
this.disposeVariable(v);
}
}
disposeVariable(v) {
this.disposeTensor(v);
if (this.state.registeredVariables[v.name] != null) {
delete this.state.registeredVariables[v.name];
}
}
memory() {
const info = this.backend.memory();
info.numTensors = this.state.numTensors;
info.numDataBuffers = this.state.numDataBuffers;
info.numBytes = this.state.numBytes;
if (this.state.numStringTensors > 0) {
info.unreliable = true;
if (info.reasons == null) {
info.reasons = [];
}
info.reasons.push('Memory usage by string tensors is approximate ' +
'(2 bytes per character)');
}
return info;
}
async profile(query) {
this.state.profiling = true;
const startBytes = this.state.numBytes;
const startNumTensors = this.state.numTensors;
this.state.activeProfile.kernels = [];
this.state.activeProfile.result = await query();
this.state.profiling = false;
this.state.activeProfile.peakBytes = Math.max(...this.state.activeProfile.kernels.map(d => d.totalBytesSnapshot));
this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
this.state.activeProfile.newTensors =
this.state.numTensors - startNumTensors;
for (const kernel of this.state.activeProfile.kernels) {
kernel.kernelTimeMs = await kernel.kernelTimeMs;
kernel.extraInfo = await kernel.extraInfo;
}
return this.state.activeProfile;
}
isTapeOn() {
return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
}
addTapeNode(kernelName, inputs, outputs, gradientsFunc, saved, attrs) {
const tapeNode = { id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved };
const gradConfig = getGradient(kernelName);
if (gradConfig != null) {
gradientsFunc = gradConfig.gradFunc;
}
if (gradientsFunc != null) {
tapeNode.gradient = (dys) => {
dys = dys.map((dy, i) => {
if (dy == null) {
const output = outputs[i];
const vals = makeZerosTypedArray(output.size, output.dtype);
return this.makeTensor(vals, output.shape, output.dtype);
}
return dy;
});
return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs);
};
}
this.state.activeTape.push(tapeNode);
}
keep(result) {
result.kept = true;
return result;
}
startTape() {
if (this.state.gradientDepth === 0) {
this.state.activeTape = [];
}
this.state.gradientDepth++;
}
endTape() {
this.state.gradientDepth--;
}
startScope(name) {
const scopeInfo = {
track: [],
name: 'unnamed scope',
id: this.state.nextScopeId++
};
if (name) {
scopeInfo.name = name;
}
this.state.scopeStack.push(scopeInfo);
this.state.activeScope = scopeInfo;
}
endScope(result) {
const tensorsToTrackInParent = getTensorsInContainer(result);
const tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(t => t.id));
for (let i = 0; i < this.state.activeScope.track.length; i++) {
const tensor = this.state.activeScope.track[i];
if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) {
tensor.dispose();
}
}
const oldScope = this.state.scopeStack.pop();
this.state.activeScope = this.state.scopeStack.length === 0 ?
null :
this.state.scopeStack[this.state.scopeStack.length - 1];
tensorsToTrackInParent.forEach(tensor => {
if (!tensor.kept && tensor.scopeId === oldScope.id) {
this.track(tensor);
}
});
}
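    // Runs f() with the gradient tape enabled, then backpropagates dy (or a
    // tensor of ones) from the result y to every tensor in xs, returning
    // {value: y, grads}. Tape state is released once the outermost gradient
    // scope ends.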
gradients(f, xs, dy, allowNoGradients = false) {
assert$1(xs.length > 0, () => 'gradients() received an empty list of xs.');
if (dy != null && dy.dtype !== 'float32') {
throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`);
}
const y = this.scopedRun(() => this.startTape(), () => this.endTape(), () => this.tidy('forward', f));
assert$1(y instanceof Tensor, () => 'The result y returned by f() must be a tensor.');
const filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
throw new Error('Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
'that the f you passed encloses all operations that lead from x ' +
'to y.');
}
return this.tidy('backward', () => {
const accumulatedGradientMap = {};
accumulatedGradientMap[y.id] = (dy == null) ? ones$1(y.shape) : dy;
backpropagateGradients(accumulatedGradientMap, filteredTape,
f => this.tidy(f),
add$2);
const grads = xs.map(x => accumulatedGradientMap[x.id]);
if (this.state.gradientDepth === 0) {
this.state.activeTape.forEach(node => {
for (const tensor of node.saved) {
tensor.dispose();
}
});
this.state.activeTape = null;
}
return { value: y, grads };
});
}
customGrad(f) {
assert$1(isFunction(f), () => 'The f passed in customGrad(f) must be a function.');
return (...inputs) => {
assert$1(inputs.every(t => t instanceof Tensor), () => 'The args passed in customGrad(f)(x1, x2,...) must all be ' +
'tensors');
let res;
const inputMap = {};
inputs.forEach((input, i) => {
inputMap[i] = input;
});
const forwardFunc = (_, save) => {
res = f(...[...inputs, save]);
assert$1(res.value instanceof Tensor, () => 'The function f passed in customGrad(f) must return an ' +
'object where `obj.value` is a tensor');
assert$1(isFunction(res.gradFunc), () => 'The function f passed in customGrad(f) must return an ' +
'object where `obj.gradFunc` is a function.');
return res.value;
};
const backwardsFunc = (dy, saved) => {
const gradRes = res.gradFunc(dy, saved);
const grads = Array.isArray(gradRes) ? gradRes : [gradRes];
assert$1(grads.length === inputs.length, () => 'The function f passed in customGrad(f) must return an ' +
'object where `obj.gradFunc` is a function that returns ' +
'the same number of tensors as inputs passed to f(...).');
assert$1(grads.every(t => t instanceof Tensor), () => 'The function f passed in customGrad(f) must return an ' +
'object where `obj.gradFunc` is a function that returns ' +
'a list of only tensors.');
const gradMap = {};
grads.forEach((grad, i) => {
gradMap[i] = () => grad;
});
return gradMap;
};
return this.runKernelFunc({
forwardFunc,
backwardsFunc,
inputs: inputMap,
});
};
}
readSync(dataId) {
const info = this.state.tensorInfo.get(dataId);
return info.backend.readSync(dataId);
}
read(dataId) {
const info = this.state.tensorInfo.get(dataId);
return info.backend.read(dataId);
}
readToGPU(dataId, options) {
const info = this.state.tensorInfo.get(dataId);
return info.backend.readToGPU(dataId, options);
}
async time(query) {
const start = now();
const timingInfo = await this.backend.time(query);
timingInfo.wallMs = now() - start;
return timingInfo;
}
track(result) {
if (this.state.activeScope != null) {
result.scopeId = this.state.activeScope.id;
this.state.activeScope.track.push(result);
}
return result;
}
get registeredVariables() {
return this.state.registeredVariables;
}
reset() {
this.pendingBackendInitId++;
this.state.dispose();
this.ENV.reset();
this.state = new EngineState();
for (const backendName in this.registry) {
this.disposeRegisteredKernels(backendName);
this.registry[backendName].dispose();
delete this.registry[backendName];
}
this.backendName = null;
this.backendInstance = null;
this.pendingBackendInit = null;
}
}
Engine.nextTensorId = 0;
Engine.nextVariableId = 0;
function ones$1(shape) {
const values = makeOnesTypedArray(sizeFromShape(shape), 'float32');
return ENGINE.makeTensor(values, shape, 'float32');
}
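// The engine singleton is stored on the shared global namespace so that all
// copies of the library loaded in the same page reuse one engine (and one set
// of backends and flags).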
function getOrMakeEngine() {
const ns = getGlobalNamespace();
if (ns._tfengine == null) {
const environment = new Environment(ns);
ns._tfengine = new Engine(environment);
}
setEnvironmentGlobal(ns._tfengine.ENV);
setTensorTracker(() => ns._tfengine);
return ns._tfengine;
}
const ENGINE = getOrMakeEngine();
function add$2(a, b) {
const inputs = { a, b };
return ENGINE.runKernel(Add, inputs);
}
function _isNavigatorDefined() {
return typeof navigator !== 'undefined' && navigator != null;
}
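// Heuristic mobile detection: treats React Native as mobile, otherwise sniffs
// the user-agent string with the regular expressions below, falling back to
// navigator.userAgentData.mobile when no user-agent string is available.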
function isMobile(nav) {
if (nav || _isNavigatorDefined()) {
if (!nav) {
nav = navigator;
}
if (nav.product === 'ReactNative') {
return true;
}
const a = nav.userAgent || nav.vendor ||
(typeof window !== 'undefined' ? window.opera : '');
if (!a) {
const navAny = nav;
return navAny.userAgentData && navAny.userAgentData.mobile;
}
return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i
.test(a) ||
/1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i
.test(a.substr(0, 4));
}
return false;
}
function isBrowser() {
return (typeof window !== 'undefined' && window.document != null) ||
(typeof WorkerGlobalScope !== 'undefined');
}
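// Default environment flags. Each flag is registered with an evaluation
// function (and optionally a hook that runs when the flag is set) and can be
// overridden at runtime through env().set().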
const ENV$1 = env();
ENV$1.registerFlag('DEBUG', () => false, debugValue => {
if (debugValue) {
console.warn('Debugging mode is ON. The output of every math call will ' +
'be downloaded to CPU and checked for NaNs. ' +
'This significantly impacts performance.');
}
});
ENV$1.registerFlag('IS_BROWSER', () => isBrowser());
ENV$1.registerFlag('IS_NODE', () => (typeof process !== 'undefined') &&
(typeof process.versions !== 'undefined') &&
(typeof process.versions.node !== 'undefined'));
ENV$1.registerFlag('IS_CHROME', () => typeof navigator !== 'undefined' && navigator != null &&
navigator.userAgent != null && /Chrome/.test(navigator.userAgent) &&
/Google Inc/.test(navigator.vendor));
ENV$1.registerFlag('IS_SAFARI', () => typeof navigator !== 'undefined' && navigator != null &&
navigator.userAgent != null && /Safari/.test(navigator.userAgent) &&
/Apple/.test(navigator.vendor));
ENV$1.registerFlag('PROD', () => false);
ENV$1.registerFlag('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', () => ENV$1.getBool('DEBUG'));
ENV$1.registerFlag('DEPRECATION_WARNINGS_ENABLED', () => true);
ENV$1.registerFlag('IS_TEST', () => false);
ENV$1.registerFlag('CHECK_COMPUTATION_FOR_ERRORS', () => ENV$1.getBool('DEBUG'));
ENV$1.registerFlag('WRAP_TO_IMAGEBITMAP', () => false);
ENV$1.registerFlag('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU', () => false);
ENV$1.registerFlag('USE_SETTIMEOUTCUSTOM', () => false);
function buffer(shape, dtype = 'float32', values) {
dtype = dtype || 'float32';
assertNonNegativeIntegerDimensions(shape);
return new TensorBuffer(shape, dtype, values);
}
function inferShape(val, dtype) {
let firstElem = val;
if (isTypedArray(val)) {
return dtype === 'string' ? [] : [val.length];
}
if (isWebGLData(val)) {
const usedChannels = val.channels || 'RGBA';
return [val.height, val.width * usedChannels.length];
}
else if (isWebGPUData(val)) {
return [val.buffer.size / (dtype == null ? 4 : bytesPerElement(dtype))];
}
if (!Array.isArray(val)) {
return [];
}
const shape = [];
while (Array.isArray(firstElem) ||
isTypedArray(firstElem) && dtype !== 'string') {
shape.push(firstElem.length);
firstElem = firstElem[0];
}
if (Array.isArray(val) &&
env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) {
deepAssertShapeConsistency(val, shape, []);
}
return shape;
}
function deepAssertShapeConsistency(val, shape, indices) {
indices = indices || [];
if (!(Array.isArray(val)) && !isTypedArray(val)) {
assert$1(shape.length === 0, () => `Element arr[${indices.join('][')}] is a primitive, ` +
`but should be an array/TypedArray of ${shape[0]} elements`);
return;
}
assert$1(shape.length > 0, () => `Element arr[${indices.join('][')}] should be a primitive, ` +
`but is an array of ${val.length} elements`);
assert$1(val.length === shape[0], () => `Element arr[${indices.join('][')}] should have ${shape[0]} ` +
`elements, but has ${val.length} elements`);
const subShape = shape.slice(1);
for (let i = 0; i < val.length; ++i) {
deepAssertShapeConsistency(val[i], subShape, indices.concat(i));
}
}
function assertDtype(expectedDtype, actualDType, argName, functionName) {
if (expectedDtype === 'string_or_numeric') {
return;
}
if (expectedDtype == null) {
throw new Error(`Expected dtype cannot be null.`);
}
if (expectedDtype !== 'numeric' && expectedDtype !== actualDType ||
expectedDtype === 'numeric' && actualDType === 'string') {
throw new Error(`Argument '${argName}' passed to '${functionName}' must ` +
`be ${expectedDtype} tensor, but got ${actualDType} tensor`);
}
}
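// Converts a Tensor, typed array, nested array, or primitive into a Tensor.
// The dtype is inferred from the value and coerced to parseAsDtype when that
// is a concrete dtype ('bool', 'int32' or 'float32'); plain values are
// flattened or copied into a typed array before being handed to the engine.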
function convertToTensor(x, argName, functionName, parseAsDtype = 'numeric') {
if (x instanceof getGlobalTensorClass()) {
assertDtype(parseAsDtype, x.dtype, argName, functionName);
return x;
}
let inferredDtype = inferDtype(x);
if (inferredDtype !== 'string' &&
['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) {
inferredDtype = parseAsDtype;
}
assertDtype(parseAsDtype, inferredDtype, argName, functionName);
if ((x == null) ||
(!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' &&
typeof x !== 'boolean' && typeof x !== 'string')) {
const type = x == null ? 'null' : x.constructor.name;
throw new Error(`Argument '${argName}' passed to '${functionName}' must be a ` +
`Tensor or TensorLike, but got '${type}'`);
}
const inferredShape = inferShape(x, inferredDtype);
if (!isTypedArray(x) && !Array.isArray(x)) {
x = [x];
}
const skipTypedArray = true;
const values = inferredDtype !== 'string' ?
toTypedArray(x, inferredDtype) :
flatten$1(x, [], skipTypedArray);
return ENGINE.makeTensor(values, inferredShape, inferredDtype);
}
function convertToTensorArray(arg, argName, functionName, parseAsDtype = 'numeric') {
if (!Array.isArray(arg)) {
throw new Error(`Argument ${argName} passed to ${functionName} must be a ` +
'`Tensor[]` or `TensorLike[]`');
}
const tensors = arg;
return tensors.map((t, i) => convertToTensor(t, `${argName}[${i}]`, functionName, parseAsDtype));
}
const OP_SCOPE_SUFFIX = '__op';
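// Wraps an op implementation (an object with a single `<name>_` key) so that
// every call runs inside its own named engine scope: intermediate tensors
// created by the op are disposed when the scope ends, and the wrapper function
// is named `<name>__op` for profiling.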
function op(f) {
const keys = Object.keys(f);
if (keys.length !== 1) {
throw new Error(`Please provide an object with a single key ` +
`(operation name) mapping to a function. Got an object with ` +
`${keys.length} keys.`);
}
let opName = keys[0];
const fn = f[opName];
if (opName.endsWith('_')) {
opName = opName.substring(0, opName.length - 1);
}
opName = opName + OP_SCOPE_SUFFIX;
const f2 = (...args) => {
ENGINE.startScope(opName);
try {
const result = fn(...args);
if (isPromise(result)) {
console.error('Cannot return a Promise inside of tidy.');
}
ENGINE.endScope(result);
return result;
}
catch (ex) {
ENGINE.endScope(null);
throw ex;
}
};
Object.defineProperty(f2, 'name', { value: opName, configurable: true });
return f2;
}
function cast_(x, dtype) {
const $x = convertToTensor(x, 'x', 'cast');
if (!isValidDtype(dtype)) {
throw new Error(`Failed to cast to unknown dtype ${dtype}`);
}
if (dtype === 'string' && $x.dtype !== 'string' ||
dtype !== 'string' && $x.dtype === 'string') {
        throw new Error('Only strings can be cast to strings');
}
const inputs = { x: $x };
const attrs = { dtype };
return ENGINE.runKernel(Cast, inputs, attrs);
}
const cast$3 = op({ cast_ });
function clone_(x) {
const $x = convertToTensor(x, 'x', 'clone', 'string_or_numeric');
const inputs = { x: $x };
return ENGINE.runKernel(Identity$1, inputs);
}
const clone = op({ clone_ });
function print(x, verbose = false) {
console.log(x.toString(verbose));
}
getOrMakeEngine();
const opHandler = {
buffer,
cast: cast$3,
clone,
print
};
setOpHandler(opHandler);
function enableProdMode() {
env().set('PROD', true);
}
function engine() {
return ENGINE;
}
function memory() {
return ENGINE.memory();
}
function tidy(nameOrFn, fn) {
return ENGINE.tidy(nameOrFn, fn);
}
function dispose(container) {
const tensors = getTensorsInContainer(container);
tensors.forEach(tensor => tensor.dispose());
}
function keep(result) {
return ENGINE.keep(result);
}
function registerBackend(name, factory, priority = 1) {
return ENGINE.registerBackend(name, factory, priority);
}
function backend() {
return ENGINE.backend;
}
function add_(a, b) {
let $a = convertToTensor(a, 'a', 'add');
let $b = convertToTensor(b, 'b', 'add');
[$a, $b] = makeTypesMatch($a, $b);
const inputs = { a: $a, b: $b };
return ENGINE.runKernel(Add, inputs);
}
const add$1 = op({ add_ });
function floorDiv_(a, b) {
let $a = convertToTensor(a, 'a', 'floorDiv');
let $b = convertToTensor(b, 'b', 'floorDiv');
[$a, $b] = makeTypesMatch($a, $b);
const inputs = { a: $a, b: $b };
return ENGINE.runKernel(FloorDiv, inputs);
}
const floorDiv$2 = op({ floorDiv_ });
function div_(a, b) {
let $a = convertToTensor(a, 'a', 'div');
let $b = convertToTensor(b, 'b', 'div');
[$a, $b] = makeTypesMatch($a, $b);
if ($a.dtype === 'int32' && $b.dtype === 'int32') {
return floorDiv$2($a, $b);
}
const inputs = { a: $a, b: $b };
const attrs = {};
return ENGINE.runKernel(RealDiv, inputs, attrs);
}
const div$1 = op({ div_ });
function mul_(a, b) {
let $a = convertToTensor(a, 'a', 'mul');
let $b = convertToTensor(b, 'b', 'mul');
[$a, $b] = makeTypesMatch($a, $b);
const inputs = { a: $a, b: $b };
return ENGINE.runKernel(Multiply, inputs);
}
const mul = op({ mul_ });
function abs_(x) {
const $x = convertToTensor(x, 'x', 'abs');
if ($x.dtype === 'complex64') {
const inputs = { x: $x };
return ENGINE.runKernel(ComplexAbs, inputs);
}
else {
const inputs = { x: $x };
return ENGINE.runKernel(Abs, inputs);
}
}
const abs$2 = op({ abs_ });
function any_(x, axis = null, keepDims = false) {
const $x = convertToTensor(x, 'x', 'any', 'bool');
const inputs = { x: $x };
const attrs = { axis, keepDims };
return ENGINE.runKernel(Any, inputs, attrs);
}
const any$2 = op({ any_ });
function argMax_(x, axis = 0) {
const $x = convertToTensor(x, 'x', 'argMax');
const inputs = { x: $x };
const attrs = { axis };
return ENGINE.runKernel(ArgMax, inputs, attrs);
}
const argMax$2 = op({ argMax_ });
function computeDilation2DInfo(inputShape, filterShape, strides, pad, dataFormat = 'NHWC', dilations) {
const inputChannels = inputShape[3];
const $filterShape = [...filterShape, inputChannels];
const $dataFormat = convertConv2DDataFormat(dataFormat);
    return computeConv2DInfo(inputShape, $filterShape, strides, dilations, pad, null /* roundingMode */, null /* depthwise */, $dataFormat);
}
function computePool2DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat = 'channelsLast') {
const [filterHeight, filterWidth] = parseTupleParam(filterSize);
let filterShape;
if (dataFormat === 'channelsLast') {
filterShape = [filterHeight, filterWidth, inShape[3], inShape[3]];
}
else if (dataFormat === 'channelsFirst') {
filterShape = [filterHeight, filterWidth, inShape[1], inShape[1]];
}
else {
throw new Error(`Unknown dataFormat ${dataFormat}`);
}
return computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, false, dataFormat);
}
function computePool3DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat = 'NDHWC') {
const [filterDepth, filterHeight, filterWidth] = parse3TupleParam(filterSize);
let filterShape;
let $dataFormat;
if (dataFormat === 'NDHWC') {
$dataFormat = 'channelsLast';
filterShape =
[filterDepth, filterHeight, filterWidth, inShape[4], inShape[4]];
}
else if (dataFormat === 'NCDHW') {
$dataFormat = 'channelsFirst';
filterShape =
[filterDepth, filterHeight, filterWidth, inShape[1], inShape[1]];
}
else {
throw new Error(`Unknown dataFormat ${dataFormat}`);
}
return computeConv3DInfo(inShape, filterShape, strides, dilations, pad, false, $dataFormat, roundingMode);
}
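// Computes the full set of 2D convolution parameters (output shape, padding,
// effective filter size after dilation, strides) for either 'channelsLast'
// (NHWC) or 'channelsFirst' (NCHW) inputs. When `depthwise` is true the
// output channel count is filterChannels * inChannels.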
function computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, depthwise = false, dataFormat = 'channelsLast') {
let [batchSize, inHeight, inWidth, inChannels] = [-1, -1, -1, -1];
if (dataFormat === 'channelsLast') {
[batchSize, inHeight, inWidth, inChannels] = inShape;
}
else if (dataFormat === 'channelsFirst') {
[batchSize, inChannels, inHeight, inWidth] = inShape;
}
else {
throw new Error(`Unknown dataFormat ${dataFormat}`);
}
const [filterHeight, filterWidth, , filterChannels] = filterShape;
const [strideHeight, strideWidth] = parseTupleParam(strides);
const [dilationHeight, dilationWidth] = parseTupleParam(dilations);
const effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
const effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
const { padInfo, outHeight, outWidth } = getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight, effectiveFilterWidth, roundingMode, dataFormat);
const outChannels = depthwise ? filterChannels * inChannels : filterChannels;
let outShape;
if (dataFormat === 'channelsFirst') {
outShape = [batchSize, outChannels, outHeight, outWidth];
}
else if (dataFormat === 'channelsLast') {
outShape = [batchSize, outHeight, outWidth, outChannels];
}
return {
batchSize,
dataFormat,
inHeight,
inWidth,
inChannels,
outHeight,
outWidth,
outChannels,
padInfo,
strideHeight,
strideWidth,
filterHeight,
filterWidth,
effectiveFilterHeight,
effectiveFilterWidth,
dilationHeight,
dilationWidth,
inShape,
outShape,
filterShape
};
}
function computeConv3DInfo(inShape, filterShape, strides, dilations, pad, depthwise = false, dataFormat = 'channelsLast', roundingMode) {
let [batchSize, inDepth, inHeight, inWidth, inChannels] = [-1, -1, -1, -1, -1];
if (dataFormat === 'channelsLast') {
[batchSize, inDepth, inHeight, inWidth, inChannels] = inShape;
}
else if (dataFormat === 'channelsFirst') {
[batchSize, inChannels, inDepth, inHeight, inWidth] = inShape;
}
else {
throw new Error(`Unknown dataFormat ${dataFormat}`);
}
const [filterDepth, filterHeight, filterWidth, , filterChannels] = filterShape;
const [strideDepth, strideHeight, strideWidth] = parse3TupleParam(strides);
const [dilationDepth, dilationHeight, dilationWidth] = parse3TupleParam(dilations);
const effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth);
const effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
const effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
const { padInfo, outDepth, outHeight, outWidth } = get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth, roundingMode);
const outChannels = depthwise ? filterChannels * inChannels : filterChannels;
let outShape;
if (dataFormat === 'channelsFirst') {
outShape = [batchSize, outChannels, outDepth, outHeight, outWidth];
}
else if (dataFormat === 'channelsLast') {
outShape = [batchSize, outDepth, outHeight, outWidth, outChannels];
}
return {
batchSize,
dataFormat,
inDepth,
inHeight,
inWidth,
inChannels,
outDepth,
outHeight,
outWidth,
outChannels,
padInfo,
strideDepth,
strideHeight,
strideWidth,
filterDepth,
filterHeight,
filterWidth,
effectiveFilterDepth,
effectiveFilterHeight,
effectiveFilterWidth,
dilationDepth,
dilationHeight,
dilationWidth,
inShape,
outShape,
filterShape
};
}
function computeOutputShape2D(inShape, fieldSize, stride, zeroPad, roundingMode) {
if (zeroPad == null) {
zeroPad = computeDefaultPad(inShape, fieldSize, stride);
}
const inputRows = inShape[0];
const inputCols = inShape[1];
const outputRows = round$2((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
const outputCols = round$2((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
return [outputRows, outputCols];
}
function computeOutputShape4D(inShape, filterShape, outChannels, strides, zeroPad, roundingMode) {
if (zeroPad == null) {
zeroPad = computeDefaultPad(inShape, filterShape[0], strides[0]);
}
const outShape = [0, 0, 0, outChannels];
for (let index = 0; index < 3; index++) {
if (inShape[index] + 2 * zeroPad >= filterShape[index]) {
outShape[index] = round$2((inShape[index] - filterShape[index] + 2 * zeroPad) / strides[index] +
1, roundingMode);
}
}
return outShape;
}
function computeDefaultPad(inputShape, fieldSize, stride, dilation = 1) {
const effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation);
return Math.floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2);
}
function parseTupleParam(param) {
if (typeof param === 'number') {
return [param, param, param];
}
if (param.length === 2) {
return [param[0], param[1], 1];
}
return param;
}
function parse3TupleParam(param) {
return typeof param === 'number' ? [param, param, param] : param;
}
function getEffectiveFilterSize(filterSize, dilation) {
if (dilation <= 1) {
return filterSize;
}
return filterSize + (filterSize - 1) * (dilation - 1);
}
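// Resolves the padding argument into explicit top/bottom/left/right amounts
// and the resulting output size:
//   number  -> symmetric padding of that many pixels,
//   'same'  -> out = ceil(in / stride), pad = max(0, (out - 1) * stride + filter - in),
//   'valid' -> no padding, out = ceil((in - filter + 1) / stride),
//   object  -> explicit [top, bottom] / [left, right] pairs per spatial dimension.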
function getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode, dataFormat) {
let padInfo;
let outHeight;
let outWidth;
if (typeof pad === 'number') {
const padType = (pad === 0) ? 'VALID' : 'NUMBER';
padInfo = { top: pad, bottom: pad, left: pad, right: pad, type: padType };
const outShape = computeOutputShape2D([inHeight, inWidth], filterHeight, strideHeight, pad, roundingMode);
outHeight = outShape[0];
outWidth = outShape[1];
}
else if (pad === 'same') {
outHeight = Math.ceil(inHeight / strideHeight);
outWidth = Math.ceil(inWidth / strideWidth);
const padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
const padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
const top = Math.floor(padAlongHeight / 2);
const bottom = padAlongHeight - top;
const left = Math.floor(padAlongWidth / 2);
const right = padAlongWidth - left;
padInfo = { top, bottom, left, right, type: 'SAME' };
}
else if (pad === 'valid') {
padInfo = { top: 0, bottom: 0, left: 0, right: 0, type: 'VALID' };
outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
}
else if (typeof pad === 'object') {
const top = dataFormat === 'channelsLast' ? pad[1][0] : pad[2][0];
const bottom = dataFormat === 'channelsLast' ? pad[1][1] : pad[2][1];
const left = dataFormat === 'channelsLast' ? pad[2][0] : pad[3][0];
const right = dataFormat === 'channelsLast' ? pad[2][1] : pad[3][1];
const padType = (top === 0 && bottom === 0 && left === 0 && right === 0) ?
'VALID' :
'EXPLICIT';
padInfo = { top, bottom, left, right, type: padType };
outHeight = round$2((inHeight - filterHeight + top + bottom) / strideHeight + 1, roundingMode);
outWidth = round$2((inWidth - filterWidth + left + right) / strideWidth + 1, roundingMode);
}
else {
throw Error(`Unknown padding parameter: ${pad}`);
}
return { padInfo, outHeight, outWidth };
}
function get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, roundingMode) {
let padInfo;
let outDepth;
let outHeight;
let outWidth;
if (pad === 'valid') {
pad = 0;
}
if (typeof pad === 'number') {
const padType = (pad === 0) ? 'VALID' : 'NUMBER';
padInfo = {
top: pad,
bottom: pad,
left: pad,
right: pad,
front: pad,
back: pad,
type: padType
};
const outShape = computeOutputShape4D([inDepth, inHeight, inWidth, 1], [filterDepth, filterHeight, filterWidth], 1, [strideDepth, strideHeight, strideWidth], pad, roundingMode);
outDepth = outShape[0];
outHeight = outShape[1];
outWidth = outShape[2];
}
else if (pad === 'same') {
outDepth = Math.ceil(inDepth / strideDepth);
outHeight = Math.ceil(inHeight / strideHeight);
outWidth = Math.ceil(inWidth / strideWidth);
const padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth;
const padAlongHeight = (outHeight - 1) * strideHeight + filterHeight - inHeight;
const padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth;
const front = Math.floor(padAlongDepth / 2);
const back = padAlongDepth - front;
const top = Math.floor(padAlongHeight / 2);
const bottom = padAlongHeight - top;
const left = Math.floor(padAlongWidth / 2);
const right = padAlongWidth - left;
padInfo = { top, bottom, left, right, front, back, type: 'SAME' };
}
else {
throw Error(`Unknown padding parameter: ${pad}`);
}
return { padInfo, outDepth, outHeight, outWidth };
}
function round$2(value, roundingMode) {
if (!roundingMode) {
return Math.trunc(value);
}
switch (roundingMode) {
case 'round':
return Math.round(value);
case 'ceil':
return Math.ceil(value);
case 'floor':
return Math.floor(value);
default:
throw new Error(`Unknown roundingMode ${roundingMode}`);
}
}
function tupleValuesAreOne(param) {
const [dimA, dimB, dimC] = parseTupleParam(param);
return dimA === 1 && dimB === 1 && dimC === 1;
}
function eitherStridesOrDilationsAreOne(strides, dilations) {
return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations);
}
function stridesOrDilationsArePositive(values) {
return parseTupleParam(values).every(value => value > 0);
}
function convertConv2DDataFormat(dataFormat) {
if (dataFormat === 'NHWC') {
return 'channelsLast';
}
else if (dataFormat === 'NCHW') {
return 'channelsFirst';
}
else {
throw new Error(`Unknown dataFormat ${dataFormat}`);
}
}
function checkPadOnDimRoundingMode(opDesc, pad, dimRoundingMode) {
if (dimRoundingMode != null) {
if (typeof pad === 'string') {
throw Error(`Error in ${opDesc}: pad must be an integer when using ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}
else if (typeof pad === 'number') {
assert$1(isInt(pad), () => `Error in ${opDesc}: pad must be an integer when using ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
}
else if (typeof pad === 'object') {
pad.forEach(p => {
p.forEach(v => {
assert$1(isInt(v), () => `Error in ${opDesc}: pad must be an integer when using ` +
`dimRoundingMode ${dimRoundingMode} but got pad ${v}.`);
});
});
}
else {
throw Error(`Error in ${opDesc}: Unknown padding parameter: ${pad}`);
}
}
}
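// The op wrappers below (reshape, concat, matMul, ...) convert their arguments to
// tensors, validate shapes/dtypes, and dispatch the matching kernel via ENGINE.runKernel.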
function reshape_(x, shape) {
const $x = convertToTensor(x, 'x', 'reshape', 'string_or_numeric');
const inputs = { x: $x };
const attrs = { shape };
return ENGINE.runKernel(Reshape$1, inputs, attrs);
}
const reshape$2 = op({ reshape_ });
function concat_(tensors, axis = 0) {
assert$1(tensors.length >= 1, () => 'Pass at least one tensor to concat');
const $tensors = convertToTensorArray(tensors, 'tensors', 'concat', 'string_or_numeric');
if ($tensors[0].dtype === 'complex64') {
$tensors.forEach(tensor => {
if (tensor.dtype !== 'complex64') {
throw new Error(`Cannot concatenate complex64 tensors with a tensor
with dtype ${tensor.dtype}. `);
}
});
}
if ($tensors.length === 1) {
return clone($tensors[0]);
}
const inputs = $tensors;
const attr = { axis };
return ENGINE.runKernel(Concat, inputs, attr);
}
const concat$2 = op({ concat_ });
function matMul_(a, b, transposeA = false, transposeB = false) {
let $a = convertToTensor(a, 'a', 'matMul');
let $b = convertToTensor(b, 'b', 'matMul');
[$a, $b] = makeTypesMatch($a, $b);
const inputs = { a: $a, b: $b };
const attrs = { transposeA, transposeB };
return ENGINE.runKernel(BatchMatMul, inputs, attrs);
}
const matMul$1 = op({ matMul_ });
function sigmoid_(x) {
const $x = convertToTensor(x, 'x', 'sigmoid', 'float32');
const inputs = { x: $x };
return ENGINE.runKernel(Sigmoid$1, inputs);
}
const sigmoid$2 = op({ sigmoid_ });
function slice_(x, begin, size) {
const $x = convertToTensor(x, 'x', 'slice', 'string_or_numeric');
if ($x.rank === 0) {
throw new Error('Slicing scalar is not possible');
}
const inputs = { x: $x };
const attrs = { begin, size };
return ENGINE.runKernel(Slice, inputs, attrs);
}
const slice$2 = op({ slice_ });
function tanh_(x) {
const $x = convertToTensor(x, 'x', 'tanh', 'float32');
const inputs = { x: $x };
return ENGINE.runKernel(Tanh$1, inputs);
}
const tanh$2 = op({ tanh_ });
function batchToSpaceND_(x, blockShape, crops) {
const $x = convertToTensor(x, 'x', 'batchToSpaceND');
const prod = blockShape.reduce((a, b) => a * b);
assert$1($x.rank >= 1 + blockShape.length, () => `input rank is ${$x.rank} but should be >= 1 + blockShape.length (${1 + blockShape.length})`);
assert$1(crops.length === blockShape.length, () => `crops.length is ${crops.length} but should be equal to blockShape.length ${blockShape.length}`);
assert$1($x.shape[0] % prod === 0, () => `input tensor batch is ${$x.shape[0]} but is not divisible by the product of ` +
`the elements of blockShape ${blockShape.join(' * ')} === ${prod}`);
const inputs = { x: $x };
const attrs = { blockShape, crops };
return ENGINE.runKernel(BatchToSpaceND, inputs, attrs);
}
const batchToSpaceND$2 = op({ batchToSpaceND_ });
function broadcastTo_(x, shape) {
let input = convertToTensor(x, 'broadcastTo', 'x');
const xShape = input.shape;
assertNonNegativeIntegerDimensions(shape);
if (shape.length < input.rank) {
throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${input.rank}.`);
}
if (shape.length > input.rank) {
const newShape = input.shape.slice();
while (newShape.length < shape.length) {
newShape.unshift(1);
}
input = reshape$2(input, newShape);
}
const inputShape = input.shape;
const reps = Array.from(shape);
for (let i = shape.length - 1; i >= 0; i--) {
if (inputShape[i] === shape[i]) {
reps[i] = 1;
}
else if (input.shape[i] !== 1) {
throw new Error(`broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);
}
}
const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0);
if (axes.length === 0) {
return clone(input);
}
const inputs = { x: input };
const attrs = { reps };
return ENGINE.runKernel(Tile, inputs, attrs);
}
const broadcastTo = op({ broadcastTo_ });
function fill$2(shape, value, dtype) {
assertNonNegativeIntegerDimensions(shape);
dtype = dtype || inferDtype(value);
const attrs = { shape, value, dtype };
return ENGINE.runKernel(Fill, {}, attrs);
}
function clipByValue_(x, clipValueMin, clipValueMax) {
const $x = convertToTensor(x, 'x', 'clipByValue');
assert$1((clipValueMin <= clipValueMax), () => `Error in clip: min (${clipValueMin}) must be ` +
`less than or equal to max (${clipValueMax}).`);
if (clipValueMin === clipValueMax) {
return fill$2($x.shape, clipValueMin, $x.dtype);
}
const inputs = { x: $x };
const attrs = { clipValueMin, clipValueMax };
return ENGINE.runKernel(ClipByValue, inputs, attrs);
}
const clipByValue$2 = op({ clipByValue_ });
function complex_(real, imag) {
const $real = convertToTensor(real, 'real', 'complex');
const $imag = convertToTensor(imag, 'imag', 'complex');
assertShapesMatch($real.shape, $imag.shape, `real and imag shapes, ${$real.shape} and ${$imag.shape}, ` +
`must match in call to tf.complex().`);
const inputs = { real: $real, imag: $imag };
return ENGINE.runKernel(Complex, inputs);
}
const complex$2 = op({ complex_ });
function conv2d_(x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode) {
const $x = convertToTensor(x, 'x', 'conv2d', 'float32');
const $filter = convertToTensor(filter, 'filter', 'conv2d', 'float32');
let x4D = $x;
let reshapedTo4D = false;
if ($x.rank === 3) {
reshapedTo4D = true;
x4D = reshape$2($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
}
assert$1(x4D.rank === 4, () => `Error in conv2d: input must be rank 4, but got rank ${x4D.rank}.`);
assert$1($filter.rank === 4, () => `Error in conv2d: filter must be rank 4, but got rank ` +
`${$filter.rank}.`);
checkPadOnDimRoundingMode('conv2d', pad, dimRoundingMode);
const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
assert$1(inDepth === $filter.shape[2], () => `Error in conv2d: depth of input (${inDepth}) must match ` +
`input depth for filter ${$filter.shape[2]}.`);
assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in conv2D: Either strides or dilations must be 1. ' +
`Got strides ${strides} and dilations '${dilations}'`);
assert$1(stridesOrDilationsArePositive(dilations), () => 'Error in conv2D: Dilated rates should be larger than 0.');
assert$1(stridesOrDilationsArePositive(strides), () => 'Error in conv2D: Strides should be larger than 0.');
const inputs = { x: x4D, filter: $filter };
const attrs = { strides, pad, dataFormat, dilations, dimRoundingMode };
const res = ENGINE.runKernel(Conv2D, inputs, attrs);
if (reshapedTo4D) {
return reshape$2(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
const conv2d$1 = op({ conv2d_ });
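// conv2DBackpropInput_/conv3DBackpropInput_: gradient of a convolution with respect to
// its input; rank-3 (resp. rank-4) inputs are temporarily reshaped to a batched form.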
function conv2DBackpropInput_(xShape, dy, filter, strides, pad, dataFormat = 'NHWC', dimRoundingMode) {
assert$1(xShape.length === dy.rank, () => `Length of inShape ` +
`(${xShape.length}) and rank of dy (${dy.rank}) must match`);
let xShape4D = xShape;
let dy4D = dy;
let reshapedTo4D = false;
if (dy.rank === 3) {
reshapedTo4D = true;
dy4D = reshape$2(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
xShape4D = [1, xShape[0], xShape[1], xShape[2]];
}
assert$1(xShape4D.length === 4, () => `Error in conv2dDerInput: inShape must be length 4, but got length ` +
`${xShape4D.length}.`);
assert$1(dy4D.rank === 4, () => `Error in conv2dDerInput: dy must be rank 4, but got ` +
`rank ${dy4D.rank}`);
assert$1(filter.rank === 4, () => `Error in conv2dDerInput: filter must be rank 4, but got ` +
`rank ${filter.rank}`);
const inDepth = dataFormat === 'NHWC' ? xShape4D[3] : xShape4D[1];
const outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];
assert$1(inDepth === filter.shape[2], () => `Error in conv2dDerInput: depth of input (${inDepth}) must ` +
`match input depth for filter ${filter.shape[2]}.`);
assert$1(outDepth === filter.shape[3], () => `Error in conv2dDerInput: depth of output (${outDepth}) must ` +
`match output depth for filter ${filter.shape[3]}.`);
checkPadOnDimRoundingMode('conv2dDerInput', pad, dimRoundingMode);
const inputs = { dy: dy4D, filter };
const attrs = { strides, pad, dataFormat, dimRoundingMode, inputShape: xShape4D };
const res = ENGINE.runKernel(Conv2DBackpropInput, inputs, attrs);
if (reshapedTo4D) {
return reshape$2(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
const conv2DBackpropInput$2 = op({ conv2DBackpropInput_ });
function conv3DBackpropInput_(xShape, dy, filter, strides, pad) {
assert$1(xShape.length === dy.rank, () => `Length of inShape ` +
`(${xShape.length}) and rank of dy (${dy.rank}) must match`);
let xShape5D = xShape;
let dy5D = dy;
let reshapedTo5D = false;
if (dy.rank === 4) {
reshapedTo5D = true;
dy5D = reshape$2(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
xShape5D = [1, xShape[0], xShape[1], xShape[2], xShape[3]];
}
const inDepth = xShape5D[4];
const outDepth = dy5D.shape[4];
assert$1(xShape5D.length === 5, () => `Error in conv3dDerInput: inShape must be length 5, but got length ` +
`${xShape5D.length}.`);
assert$1(dy5D.rank === 5, () => `Error in conv3dDerInput: dy must be rank 5, but got ` +
`rank ${dy5D.rank}`);
assert$1(filter.rank === 5, () => `Error in conv3dDerInput: filter must be rank 5, but got ` +
`rank ${filter.rank}`);
assert$1(inDepth === filter.shape[3], () => `Error in conv3dDerInput: depth of input (${inDepth}) must ` +
`match input depth for filter ${filter.shape[3]}.`);
assert$1(outDepth === filter.shape[4], () => `Error in conv3dDerInput: depth of output (${outDepth}) must ` +
`match output depth for filter ${filter.shape[4]}.`);
const inputs = { dy: dy5D, filter };
const attrs = { pad, strides, inputShape: xShape5D };
const res = ENGINE.runKernel(Conv3DBackpropInputV2, inputs, attrs);
if (reshapedTo5D) {
return reshape$2(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
}
return res;
}
const conv3DBackpropInput$1 = op({ conv3DBackpropInput_ });
function cos_(x) {
const $x = convertToTensor(x, 'x', 'cos', 'float32');
const inputs = { x: $x };
return ENGINE.runKernel(Cos, inputs);
}
const cos$2 = op({ cos_ });
function cosh_(x) {
const $x = convertToTensor(x, 'x', 'cosh', 'float32');
const inputs = { x: $x };
return ENGINE.runKernel(Cosh, inputs);
}
const cosh$2 = op({ cosh_ });
function cumprod_(x, axis = 0, exclusive = false, reverse = false) {
const $x = convertToTensor(x, 'x', 'cumprod');
const inputs = { x: $x };
const attrs = { axis, exclusive, reverse };
return ENGINE.runKernel(Cumprod, inputs, attrs);
}
const cumprod$2 = op({ cumprod_ });
function cumsum_(x, axis = 0, exclusive = false, reverse = false) {
const $x = convertToTensor(x, 'x', 'cumsum');
const inputs = { x: $x };
const attrs = { axis, exclusive, reverse };
return ENGINE.runKernel(Cumsum, inputs, attrs);
}
const cumsum$2 = op({ cumsum_ });
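// Broadcasting helpers: getBroadcastDims$1 lists the input dims that get expanded,
// getReductionAxes lists the output axes along which an input was broadcast (used to
// sum gradients back), and assertAndGetBroadcastShape computes the common shape or throws.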
function getBroadcastDims$1(inShape, outShape) {
const inRank = inShape.length;
const dims = [];
for (let i = 0; i < inRank; i++) {
const dim = inRank - 1 - i;
const a = inShape[dim] || 1;
const b = outShape[outShape.length - 1 - i] || 1;
if (b > 1 && a === 1) {
dims.unshift(dim);
}
}
return dims;
}
function getReductionAxes(inShape, outShape) {
const result = [];
for (let i = 0; i < outShape.length; i++) {
const inDim = inShape[inShape.length - i - 1];
const outAxis = outShape.length - i - 1;
const outDim = outShape[outAxis];
if (inDim == null || (inDim === 1 && outDim > 1)) {
result.unshift(outAxis);
}
}
return result;
}
function assertAndGetBroadcastShape(shapeA, shapeB) {
const l = Math.max(shapeA.length, shapeB.length);
const result = new Array(l);
for (let i = 0; i < l; i++) {
let a = shapeA[shapeA.length - i - 1];
if (a == null) {
a = 1;
}
let b = shapeB[shapeB.length - i - 1];
if (b == null) {
b = 1;
}
if (a === 1) {
result[l - i - 1] = b;
}
else if (b === 1) {
result[l - i - 1] = a;
}
else if (a !== b) {
const errMsg = `Operands could not be broadcast together with shapes ` +
`${shapeA} and ${shapeB}.`;
throw Error(errMsg);
}
else {
result[l - i - 1] = a;
}
}
return result;
}
function equal_(a, b) {
let $a = convertToTensor(a, 'a', 'equal', 'string_or_numeric');
let $b = convertToTensor(b, 'b', 'equal', 'string_or_numeric');
[$a, $b] = makeTypesMatch($a, $b);
assertAndGetBroadcastShape($a.shape, $b.shape);
const inputs = { a: $a, b: $b };
return ENGINE.runKernel(Equal, inputs);
}
const equal$2 = op({ equal_ });
function where_(condition, a, b) {
const $a = convertToTensor(a, 'a', 'where');
const $b = convertToTensor(b, 'b', 'where');
const $condition = convertToTensor(condition, 'condition', 'where', 'bool');
const broadcastShape = assertAndGetBroadcastShape(assertAndGetBroadcastShape($condition.shape, $a.shape), $b.shape);
const $broadcastedCondition = broadcastTo($condition, broadcastShape);
const $broadcastedA = broadcastTo($a, broadcastShape);
const $broadcastedB = broadcastTo($b, broadcastShape);
const inputs = {
condition: $broadcastedCondition,
t: $broadcastedA,
e: $broadcastedB
};
return ENGINE.runKernel(Select, inputs);
}
const where = op({ where_ });
function zerosLike_(x) {
const $x = convertToTensor(x, 'x', 'zerosLike');
const inputs = { x: $x };
return ENGINE.runKernel(ZerosLike, inputs);
}
const zerosLike$2 = op({ zerosLike_ });
function elu_(x) {
const $x = convertToTensor(x, 'x', 'elu', 'float32');
const inputs = { x: $x };
return ENGINE.runKernel(Elu$1, inputs);
}
const elu$3 = op({ elu_ });
function erf_(x) {
let $x = convertToTensor(x, 'x', 'erf');
assert$1($x.dtype === 'int32' || $x.dtype === 'float32', () => 'Input dtype must be `int32` or `float32`.');
if ($x.dtype === 'int32') {
$x = cast$3($x, 'float32');
}
const inputs = { x: $x };
return ENGINE.runKernel(Erf, inputs);
}
const erf$2 = op({ erf_ });
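// Reduction-axis helpers: computeOutAndReduceShapes splits a shape into kept and reduced
// dims, expandShapeToKeepDim re-inserts size-1 dims for keepDims results, and
// getAxesPermutation/getUndoAxesPermutation move the reduced axes to the inner-most
// positions and back.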
function axesAreInnerMostDims(axes, rank) {
for (let i = 0; i < axes.length; ++i) {
if (axes[axes.length - i - 1] !== rank - 1 - i) {
return false;
}
}
return true;
}
function combineLocations(outputLoc, reduceLoc, axes) {
const rank = outputLoc.length + reduceLoc.length;
const loc = [];
let outIdx = 0;
let reduceIdx = 0;
for (let dim = 0; dim < rank; dim++) {
if (axes.indexOf(dim) === -1) {
loc.push(outputLoc[outIdx++]);
}
else {
loc.push(reduceLoc[reduceIdx++]);
}
}
return loc;
}
function computeOutAndReduceShapes(aShape, axes) {
const outShape = [];
const rank = aShape.length;
for (let dim = 0; dim < rank; dim++) {
if (axes.indexOf(dim) === -1) {
outShape.push(aShape[dim]);
}
}
const reduceShape = axes.map(dim => aShape[dim]);
return [outShape, reduceShape];
}
function expandShapeToKeepDim(shape, axes) {
const reduceSubShape = axes.map(x => 1);
return combineLocations(shape, reduceSubShape, axes);
}
function assertAxesAreInnerMostDims(msg, axes, rank) {
assert$1(axesAreInnerMostDims(axes, rank), () => `${msg} supports only inner-most axes for now. ` +
`Got axes ${axes} and rank-${rank} input.`);
}
function getAxesPermutation(axes, rank) {
if (axesAreInnerMostDims(axes, rank)) {
return null;
}
const result = [];
for (let i = 0; i < rank; ++i) {
if (axes.indexOf(i) === -1) {
result.push(i);
}
}
axes.forEach(axis => result.push(axis));
return result;
}
function getUndoAxesPermutation(axes) {
return axes.map((axis, i) => [i, axis])
.sort((a, b) => a[1] - b[1])
.map(x => x[0]);
}
function getInnerMostAxes(numAxes, rank) {
const res = [];
for (let i = rank - numAxes; i < rank; ++i) {
res.push(i);
}
return res;
}
function max_(x, axis = null, keepDims = false) {
const $x = convertToTensor(x, 'x', 'max');
const inputs = { x: $x };
const attrs = { reductionIndices: axis, keepDims };
return ENGINE.runKernel(Max, inputs, attrs);
}
const max$2 = op({ max_ });
function min_(x, axis = null, keepDims = false) {
const $x = convertToTensor(x, 'x', 'min');
const inputs = { x: $x };
const attrs = { axis, keepDims };
return ENGINE.runKernel(Min, inputs, attrs);
}
const min$2 = op({ min_ });
function pow_(base, exp) {
let $base = convertToTensor(base, 'base', 'pow');
let $exp = convertToTensor(exp, 'exp', 'pow');
[$base, $exp] = makeTypesMatch($base, $exp);
const inputs = { a: $base, b: $exp };
return ENGINE.runKernel(Pow, inputs);
}
const pow$2 = op({ pow_ });
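// makeTensor: shared factory behind tensor/scalar/tensor1d/tensor2d. It infers the dtype
// when omitted, checks a provided shape against the shape inferred from the values, and
// rejects direct construction of complex64 tensors.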
function makeTensor(values, shape, inferredShape, dtype) {
if (dtype == null) {
dtype = inferDtype(values);
}
else if (dtype === 'complex64') {
throw new Error(`Cannot construct a complex64 tensor directly. ` +
`Please use tf.complex(real, imag).`);
}
if (isWebGPUData(values) || isWebGLData(values)) {
if (dtype !== 'float32' && dtype !== 'int32') {
throw new Error(`Creating tensor from GPU data only supports ` +
`'float32'|'int32' dtype, while the dtype is ${dtype}.`);
}
return ENGINE.backend.createTensorFromGPUData(values, shape || inferredShape, dtype);
}
if (!isTypedArray(values) && !Array.isArray(values) &&
typeof values !== 'number' && typeof values !== 'boolean' &&
typeof values !== 'string') {
throw new Error('values passed to tensor(values) must be a number/boolean/string or ' +
'an array of numbers/booleans/strings, or a TypedArray');
}
if (shape != null) {
assertNonNegativeIntegerDimensions(shape);
const providedSize = sizeFromShape(shape);
const inferredSize = sizeFromShape(inferredShape);
assert$1(providedSize === inferredSize, () => `Based on the provided shape, [${shape}], the tensor should have ` +
`${providedSize} values but has ${inferredSize}`);
for (let i = 0; i < inferredShape.length; ++i) {
const inferred = inferredShape[i];
const flatDimsDontMatch = i === inferredShape.length - 1 ?
inferred !== sizeFromShape(shape.slice(i)) :
true;
assert$1(inferredShape[i] === shape[i] || !flatDimsDontMatch, () => `Error creating a new Tensor. Inferred shape ` +
`(${inferredShape}) does not match the provided ` +
`shape (${shape}). `);
}
}
if (!isTypedArray(values) && !Array.isArray(values)) {
values = [values];
}
shape = shape || inferredShape;
values = dtype !== 'string' ?
toTypedArray(values, dtype) :
flatten$1(values, [], true);
return ENGINE.makeTensor(values, shape, dtype);
}
function scalar(value, dtype) {
if (((isTypedArray(value) && dtype !== 'string') || Array.isArray(value)) &&
dtype !== 'complex64') {
throw new Error('Error creating a new Scalar: value must be a primitive ' +
'(number|boolean|string)');
}
if (dtype === 'string' && isTypedArray(value) &&
!(value instanceof Uint8Array)) {
throw new Error('When making a scalar from encoded string, ' +
'the value must be `Uint8Array`.');
}
const shape = [];
const inferredShape = [];
return makeTensor(value, shape, inferredShape, dtype);
}
function sqrt_(x) {
const $x = convertToTensor(x, 'x', 'sqrt', 'float32');
const inputs = { x: $x };
return ENGINE.runKernel(Sqrt, inputs);
}
const sqrt$2 = op({ sqrt_ });
function square_(x) {
const $x = convertToTensor(x, 'x', 'square');
const attrs = {};
return ENGINE.runKernel('Square', { x: $x }, attrs);
}
const square$2 = op({ square_ });
function sum_(x, axis = null, keepDims = false) {
let $x = convertToTensor(x, 'x', 'sum');
if ($x.dtype === 'bool') {
$x = cast$3($x, 'int32');
}
const inputs = { x: $x };
const attrs = { axis, keepDims };
return ENGINE.runKernel(Sum, inputs, attrs);
}
const sum$2 = op({ sum_ });
function norm_(x, ord = 'euclidean', axis = null, keepDims = false) {
x = convertToTensor(x, 'x', 'norm');
const norm = normImpl(x, ord, axis);
let keepDimsShape = norm.shape;
if (keepDims) {
const axes = parseAxisParam(axis, x.shape);
keepDimsShape = expandShapeToKeepDim(norm.shape, axes);
}
return reshape$2(norm, keepDimsShape);
}
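// normImpl: vector norms (1, 2/'euclidean', +/-Infinity) over a single axis and matrix
// norms (1, 'fro'/'euclidean', +/-Infinity) over a pair of axes; other ord values throw.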
function normImpl(x, p, axis = null) {
if (x.rank === 0) {
return abs$2(x);
}
if (x.rank !== 1 && axis === null) {
return normImpl(reshape$2(x, [-1]), p, axis);
}
if (x.rank === 1 || typeof axis === 'number' ||
Array.isArray(axis) && axis.length === 1) {
if (p === 1) {
return sum$2(abs$2(x), axis);
}
if (p === Infinity) {
return max$2(abs$2(x), axis);
}
if (p === -Infinity) {
return min$2(abs$2(x), axis);
}
if (p === 'euclidean' || p === 2) {
return sqrt$2(sum$2(pow$2(abs$2(x), scalar(2, 'int32')), axis));
}
throw new Error(`Error in norm: invalid ord value: ${p}`);
}
if (Array.isArray(axis) && axis.length === 2) {
if (p === 1) {
return max$2(sum$2(abs$2(x), axis[0]), axis[1] - 1);
}
if (p === Infinity) {
return max$2(sum$2(abs$2(x), axis[1]), axis[0]);
}
if (p === -Infinity) {
return min$2(sum$2(abs$2(x), axis[1]), axis[0]);
}
if (p === 'fro' || p === 'euclidean') {
return sqrt$2(sum$2(square$2(x), axis));
}
throw new Error(`Error in norm: invalid ord value: ${p}`);
}
throw new Error(`Error in norm: invalid axis: ${axis}`);
}
const norm = op({ norm_ });
function exp_(x) {
const $x = convertToTensor(x, 'x', 'exp');
const inputs = { x: $x };
return ENGINE.runKernel(Exp, inputs);
}
const exp$2 = op({ exp_ });
function expandDims_(x, axis = 0) {
const $x = convertToTensor(x, 'x', 'expandDims', 'string_or_numeric');
assert$1(axis <= $x.rank, () => 'Axis must be <= rank of the tensor');
const inputs = { input: $x };
const attrs = { dim: axis };
return ENGINE.runKernel(ExpandDims, inputs, attrs);
}
const expandDims$3 = op({ expandDims_ });
function tile_(x, reps) {
const $x = convertToTensor(x, 'x', 'tile', 'string_or_numeric');
assert$1($x.rank === reps.length, () => `Error in tile: rank of input ${$x.rank} ` +
`must match length of reps ${reps}.`);
const inputs = { x: $x };
const attrs = { reps };
return ENGINE.runKernel(Tile, inputs, attrs);
}
const tile$3 = op({ tile_ });
function eye_(numRows, numColumns, batchShape, dtype = 'float32') {
if (numColumns == null) {
numColumns = numRows;
}
const buff = buffer([numRows, numColumns], dtype);
const n = numRows <= numColumns ? numRows : numColumns;
for (let i = 0; i < n; ++i) {
buff.set(1, i, i);
}
const out = reshape$2(buff.toTensor(), [numRows, numColumns]);
if (batchShape == null) {
return out;
}
else {
if (batchShape.length === 1) {
return tile$3(expandDims$3(out, 0), [batchShape[0], 1, 1]);
}
else if (batchShape.length === 2) {
return tile$3(expandDims$3(expandDims$3(out, 0), 0), [batchShape[0], batchShape[1], 1, 1]);
}
else if (batchShape.length === 3) {
return tile$3(expandDims$3(expandDims$3(expandDims$3(out, 0), 0), 0), [
batchShape[0], batchShape[1], batchShape[2], 1, 1
]);
}
else {
throw new Error(`eye() currently supports only 1D, 2D and 3D ` +
`batchShapes, but received ${batchShape.length}D.`);
}
}
}
const eye = op({ eye_ });
function floor_(x) {
const $x = convertToTensor(x, 'x', 'floor', 'float32');
const inputs = { x: $x };
return ENGINE.runKernel(Floor, inputs);
}
const floor$2 = op({ floor_ });
function gather_(x, indices, axis = 0, batchDims = 0) {
const $x = convertToTensor(x, 'x', 'gather');
const $indices = convertToTensor(indices, 'indices', 'gather', 'int32');
const inputs = { x: $x, indices: $indices };
const attrs = { axis, batchDims };
return ENGINE.runKernel(GatherV2, inputs, attrs);
}
const gather$1 = op({ gather_ });
function greater_(a, b) {
let $a = convertToTensor(a, 'a', 'greater', 'string_or_numeric');
let $b = convertToTensor(b, 'b', 'greater', 'string_or_numeric');
[$a, $b] = makeTypesMatch($a, $b);
assertAndGetBroadcastShape($a.shape, $b.shape);
const inputs = { a: $a, b: $b };
return ENGINE.runKernel(Greater, inputs);
}
const greater$2 = op({ greater_ });
function greaterEqual_(a, b) {
let $a = convertToTensor(a, 'a', 'greaterEqual', 'string_or_numeric');
let $b = convertToTensor(b, 'b', 'greaterEqual', 'string_or_numeric');
[$a, $b] = makeTypesMatch($a, $b);
assertAndGetBroadcastShape($a.shape, $b.shape);
const inputs = { a: $a, b: $b };
return ENGINE.runKernel(GreaterEqual, inputs);
}
const greaterEqual$2 = op({ greaterEqual_ });
function imag_(input) {
const $input = convertToTensor(input, 'input', 'imag');
const inputs = { input: $input };
return ENGINE.runKernel(Imag, inputs);
}
const imag$2 = op({ imag_ });
function leakyRelu_(x, alpha = 0.2) {
const $x = convertToTensor(x, 'x', 'leakyRelu');
const inputs = { x: $x };
const attrs = { alpha };
return ENGINE.runKernel(LeakyRelu, inputs, attrs);
}
const leakyRelu$2 = op({ leakyRelu_ });
function less_(a, b) {
let $a = convertToTensor(a, 'a', 'less', 'string_or_numeric');
let $b = convertToTensor(b, 'b', 'less', 'string_or_numeric');
[$a, $b] = makeTypesMatch($a, $b);
assertAndGetBroadcastShape($a.shape, $b.shape);
const inputs = { a: $a, b: $b };
return ENGINE.runKernel(Less, inputs);
}
const less$2 = op({ less_ });
function lessEqual_(a, b) {
let $a = convertToTensor(a, 'a', 'lessEqual', 'string_or_numeric');
let $b = convertToTensor(b, 'b', 'lessEqual', 'string_or_numeric');
[$a, $b] = makeTypesMatch($a, $b);
assertAndGetBroadcastShape($a.shape, $b.shape);
const inputs = { a: $a, b: $b };
return ENGINE.runKernel(LessEqual, inputs);
}
const lessEqual$2 = op({ lessEqual_ });
function log_(x) {
const $x = convertToTensor(x, 'x', 'log', 'float32');
const inputs = { x: $x };
return ENGINE.runKernel(Log, inputs);
}
const log$2 = op({ log_ });
function log1p_(x) {
const $x = convertToTensor(x, 'x', 'log1p');
const inputs = { x: $x };
return ENGINE.runKernel(Log1p, inputs);
}
const log1p$2 = op({ log1p_ });
function variableGrads(f, varList) {
assert$1(isFunction(f), () => 'The f passed in variableGrads(f) must be a function');
assert$1(varList == null ||
Array.isArray(varList) && varList.every(v => v instanceof Variable), () => 'The varList passed in variableGrads(f, varList) must be an array ' +
'of variables');
const specifiedVarList = varList != null;
if (!specifiedVarList) {
varList = [];
for (const varName in ENGINE.registeredVariables) {
varList.push(ENGINE.registeredVariables[varName]);
}
}
const specifiedNonTrainable = specifiedVarList ? varList.filter(variable => !variable.trainable) : null;
const originalVarCount = varList.length;
varList = varList.filter(variable => variable.trainable);
assert$1(varList.length > 0, () => `variableGrads() expects at least one of the input variables to ` +
`be trainable, but none of the ${originalVarCount} variables is ` +
`trainable.`);
const allowNoGradients = true;
const { value, grads } = ENGINE.gradients(f, varList, null, allowNoGradients);
assert$1(grads.some(g => g != null), () => 'Cannot find a connection between any variable and the result of ' +
'the loss function y=f(x). Please make sure the operations that ' +
'use variables are inside the function f passed to minimize().');
assert$1(value.rank === 0, () => `The f passed in variableGrads(f) must return a scalar, but it ` +
`returned a rank-${value.rank} tensor`);
const namedGrads = {};
varList.forEach((v, i) => {
if (grads[i] != null) {
namedGrads[v.name] = grads[i];
}
});
if (specifiedNonTrainable != null) {
specifiedNonTrainable.forEach(v => namedGrads[v.name] = null);
}
return { value, grads: namedGrads };
}
function customGrad(f) {
return ENGINE.customGrad(f);
}
function neg_(x) {
const $x = convertToTensor(x, 'x', 'neg');
const inputs = { x: $x };
return ENGINE.runKernel(Neg, inputs);
}
const neg$2 = op({ neg_ });
function softplus_(x) {
const $x = convertToTensor(x, 'x', 'softplus');
const inputs = { x: $x };
return ENGINE.runKernel(Softplus$1, inputs);
}
const softplus$2 = op({ softplus_ });
function sub_(a, b) {
let $a = convertToTensor(a, 'a', 'sub');
let $b = convertToTensor(b, 'b', 'sub');
[$a, $b] = makeTypesMatch($a, $b);
const inputs = { a: $a, b: $b };
return ENGINE.runKernel(Sub, inputs);
}
const sub$2 = op({ sub_ });
function logSoftmax_(logits, axis = -1) {
const $logits = convertToTensor(logits, 'logits', 'logSoftmax');
if (axis === -1) {
axis = $logits.rank - 1;
}
if (axis !== $logits.rank - 1) {
throw Error('Log Softmax along a non-last dimension is not yet supported. ' +
`Logits was rank ${$logits.rank} and axis was ${axis}`);
}
const customOp = customGrad((logits, save) => {
const keepDims = true;
const xMax = max$2(logits, axis, true);
const shifted = sub$2(logits, xMax);
const value = sub$2(cast$3(shifted, 'float32'), log$2(sum$2(exp$2(shifted), axis, keepDims)));
save([value]);
const gradFunc = (dy, saved) => {
const [value] = saved;
const keepDims = true;
const softmax = exp$2(value);
return sub$2(dy, mul(sum$2(dy, axis, keepDims), softmax));
};
return { value, gradFunc };
});
return customOp($logits);
}
const logSoftmax = op({ logSoftmax_ });
function logicalAnd_(a, b) {
const $a = convertToTensor(a, 'a', 'logicalAnd', 'bool');
const $b = convertToTensor(b, 'b', 'logicalAnd', 'bool');
assertAndGetBroadcastShape($a.shape, $b.shape);
const inputs = { a: $a, b: $b };
return ENGINE.runKernel(LogicalAnd, inputs);
}
const logicalAnd$2 = op({ logicalAnd_ });
function logicalNot_(x) {
const $x = convertToTensor(x, 'x', 'logicalNot', 'bool');
const inputs = { x: $x };
return ENGINE.runKernel(LogicalNot, inputs);
}
const logicalNot$2 = op({ logicalNot_ });
function maximum_(a, b) {
let $a = convertToTensor(a, 'a', 'maximum');
let $b = convertToTensor(b, 'b', 'maximum');
[$a, $b] = makeTypesMatch($a, $b);
if ($a.dtype === 'bool') {
$a = cast$3($a, 'int32');
$b = cast$3($b, 'int32');
}
assertAndGetBroadcastShape($a.shape, $b.shape);
const inputs = { a: $a, b: $b };
return ENGINE.runKernel(Maximum, inputs);
}
const maximum$2 = op({ maximum_ });
function mean_(x, axis = null, keepDims = false) {
const $x = convertToTensor(x, 'x', 'mean');
const inputs = { x: $x };
const attrs = { axis, keepDims };
return ENGINE.runKernel(Mean, inputs, attrs);
}
const mean$1 = op({ mean_ });
function zeros$1(shape, dtype = 'float32') {
assertNonNegativeIntegerDimensions(shape);
if (dtype === 'complex64') {
const real = zeros$1(shape, 'float32');
const imag = zeros$1(shape, 'float32');
return complex$2(real, imag);
}
const values = makeZerosTypedArray(sizeFromShape(shape), dtype);
return ENGINE.makeTensor(values, shape, dtype);
}
function ones(shape, dtype = 'float32') {
assertNonNegativeIntegerDimensions(shape);
if (dtype === 'complex64') {
const real = ones(shape, 'float32');
const imag = zeros$1(shape, 'float32');
return complex$2(real, imag);
}
const values = makeOnesTypedArray(sizeFromShape(shape), dtype);
return ENGINE.makeTensor(values, shape, dtype);
}
function minimum_(a, b) {
let $a = convertToTensor(a, 'a', 'minimum');
let $b = convertToTensor(b, 'b', 'minimum');
[$a, $b] = makeTypesMatch($a, $b);
if ($a.dtype === 'bool') {
$a = cast$3($a, 'int32');
$b = cast$3($b, 'int32');
}
assertAndGetBroadcastShape($a.shape, $b.shape);
const inputs = { a: $a, b: $b };
return ENGINE.runKernel(Minimum, inputs);
}
const minimum$2 = op({ minimum_ });
function notEqual_(a, b) {
let $a = convertToTensor(a, 'a', 'notEqual', 'string_or_numeric');
let $b = convertToTensor(b, 'b', 'notEqual', 'string_or_numeric');
[$a, $b] = makeTypesMatch($a, $b);
assertAndGetBroadcastShape($a.shape, $b.shape);
const inputs = { a: $a, b: $b };
return ENGINE.runKernel(NotEqual, inputs);
}
const notEqual$2 = op({ notEqual_ });
function oneHot_(indices, depth, onValue = 1, offValue = 0, dtype = 'int32') {
if (depth < 2) {
throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`);
}
const $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32');
const inputs = { indices: $indices };
const attrs = { dtype, depth, onValue, offValue };
return ENGINE.runKernel(OneHot, inputs, attrs);
}
const oneHot$2 = op({ oneHot_ });
function onesLike_(x) {
const $x = convertToTensor(x, 'x', 'onesLike');
const inputs = { x: $x };
return ENGINE.runKernel(OnesLike, inputs);
}
const onesLike$2 = op({ onesLike_ });
function pad_(x, paddings, constantValue = 0) {
const $x = convertToTensor(x, 'x', 'pad');
if ($x.rank === 0) {
throw new Error('pad(scalar) is not defined. Pass non-scalar to pad');
}
const attrs = { paddings, constantValue };
const inputs = { x: $x };
return ENGINE.runKernel(PadV2, inputs, attrs);
}
const pad = op({ pad_ });
function spaceToBatchND_(x, blockShape, paddings) {
const $x = convertToTensor(x, 'x', 'spaceToBatchND');
assert$1($x.rank >= 1 + blockShape.length, () => `input rank ${$x.rank} should be >= 1 + blockShape.length (${1 + blockShape.length})`);
assert$1(paddings.length === blockShape.length, () => `paddings.length (${paddings.length}) must equal blockShape.length (${blockShape.length})`);
assert$1($x.shape.reduce((a, b, i) => {
if (i > 0 && i <= blockShape.length) {
return a &&
((b + paddings[i - 1][0] + paddings[i - 1][1]) %
blockShape[i - 1] ===
0);
}
return a;
}, true), () => `input spatial dimensions ${$x.shape.slice(1)} with paddings ${paddings.toString()} must be divisible by blockShapes ${blockShape.toString()}`);
const inputs = { x: $x };
const attrs = { blockShape, paddings };
return ENGINE.runKernel(SpaceToBatchND, inputs, attrs);
}
const spaceToBatchND$2 = op({ spaceToBatchND_ });
function prelu_(x, alpha) {
const $x = convertToTensor(x, 'x', 'prelu');
const $alpha = convertToTensor(alpha, 'alpha', 'prelu');
const inputs = { x: $x, alpha: $alpha };
return ENGINE.runKernel(Prelu, inputs);
}
const prelu$2 = op({ prelu_ });
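// Vendored seedrandom PRNGs follow (alea, xor128, xorwow, xorshift7, xor4096, tychei and
// the ARC4-based default). They are collected into `seedrandom` below and back the
// seeded random-tensor factories.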
var alea$1 = {exports: {}};
(function (module) {
(function(global, module, define) {
function Alea(seed) {
var me = this, mash = Mash();
me.next = function() {
var t = 2091639 * me.s0 + me.c * 2.3283064365386963e-10;
me.s0 = me.s1;
me.s1 = me.s2;
return me.s2 = t - (me.c = t | 0);
};
me.c = 1;
me.s0 = mash(' ');
me.s1 = mash(' ');
me.s2 = mash(' ');
me.s0 -= mash(seed);
if (me.s0 < 0) { me.s0 += 1; }
me.s1 -= mash(seed);
if (me.s1 < 0) { me.s1 += 1; }
me.s2 -= mash(seed);
if (me.s2 < 0) { me.s2 += 1; }
mash = null;
}
function copy(f, t) {
t.c = f.c;
t.s0 = f.s0;
t.s1 = f.s1;
t.s2 = f.s2;
return t;
}
function impl(seed, opts) {
var xg = new Alea(seed),
state = opts && opts.state,
prng = xg.next;
prng.int32 = function() { return (xg.next() * 0x100000000) | 0; };
prng.double = function() {
return prng() + (prng() * 0x200000 | 0) * 1.1102230246251565e-16;
};
prng.quick = prng;
if (state) {
if (typeof(state) == 'object') copy(state, xg);
prng.state = function() { return copy(xg, {}); };
}
return prng;
}
function Mash() {
var n = 0xefc8249d;
var mash = function(data) {
data = String(data);
for (var i = 0; i < data.length; i++) {
n += data.charCodeAt(i);
var h = 0.02519603282416938 * n;
n = h >>> 0;
h -= n;
h *= n;
n = h >>> 0;
h -= n;
n += h * 0x100000000;
}
return (n >>> 0) * 2.3283064365386963e-10;
};
return mash;
}
if (module && module.exports) {
module.exports = impl;
} else {
this.alea = impl;
}
})(
commonjsGlobal,
module);
} (alea$1));
var aleaExports = alea$1.exports;
var xor128$1 = {exports: {}};
(function (module) {
(function(global, module, define) {
function XorGen(seed) {
var me = this, strseed = '';
me.x = 0;
me.y = 0;
me.z = 0;
me.w = 0;
me.next = function() {
var t = me.x ^ (me.x << 11);
me.x = me.y;
me.y = me.z;
me.z = me.w;
return me.w ^= (me.w >>> 19) ^ t ^ (t >>> 8);
};
if (seed === (seed | 0)) {
me.x = seed;
} else {
strseed += seed;
}
for (var k = 0; k < strseed.length + 64; k++) {
me.x ^= strseed.charCodeAt(k) | 0;
me.next();
}
}
function copy(f, t) {
t.x = f.x;
t.y = f.y;
t.z = f.z;
t.w = f.w;
return t;
}
function impl(seed, opts) {
var xg = new XorGen(seed),
state = opts && opts.state,
prng = function() { return (xg.next() >>> 0) / 0x100000000; };
prng.double = function() {
do {
var top = xg.next() >>> 11,
bot = (xg.next() >>> 0) / 0x100000000,
result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
if (typeof(state) == 'object') copy(state, xg);
prng.state = function() { return copy(xg, {}); };
}
return prng;
}
if (module && module.exports) {
module.exports = impl;
} else {
this.xor128 = impl;
}
})(
commonjsGlobal,
module);
} (xor128$1));
var xor128Exports = xor128$1.exports;
var xorwow$1 = {exports: {}};
(function (module) {
(function(global, module, define) {
function XorGen(seed) {
var me = this, strseed = '';
me.next = function() {
var t = (me.x ^ (me.x >>> 2));
me.x = me.y; me.y = me.z; me.z = me.w; me.w = me.v;
return (me.d = (me.d + 362437 | 0)) +
(me.v = (me.v ^ (me.v << 4)) ^ (t ^ (t << 1))) | 0;
};
me.x = 0;
me.y = 0;
me.z = 0;
me.w = 0;
me.v = 0;
if (seed === (seed | 0)) {
me.x = seed;
} else {
strseed += seed;
}
for (var k = 0; k < strseed.length + 64; k++) {
me.x ^= strseed.charCodeAt(k) | 0;
if (k == strseed.length) {
me.d = me.x << 10 ^ me.x >>> 4;
}
me.next();
}
}
function copy(f, t) {
t.x = f.x;
t.y = f.y;
t.z = f.z;
t.w = f.w;
t.v = f.v;
t.d = f.d;
return t;
}
function impl(seed, opts) {
var xg = new XorGen(seed),
state = opts && opts.state,
prng = function() { return (xg.next() >>> 0) / 0x100000000; };
prng.double = function() {
do {
var top = xg.next() >>> 11,
bot = (xg.next() >>> 0) / 0x100000000,
result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
if (typeof(state) == 'object') copy(state, xg);
prng.state = function() { return copy(xg, {}); };
}
return prng;
}
if (module && module.exports) {
module.exports = impl;
} else {
this.xorwow = impl;
}
})(
commonjsGlobal,
module);
} (xorwow$1));
var xorwowExports = xorwow$1.exports;
var xorshift7$1 = {exports: {}};
(function (module) {
(function(global, module, define) {
function XorGen(seed) {
var me = this;
me.next = function() {
var X = me.x, i = me.i, t, v;
t = X[i]; t ^= (t >>> 7); v = t ^ (t << 24);
t = X[(i + 1) & 7]; v ^= t ^ (t >>> 10);
t = X[(i + 3) & 7]; v ^= t ^ (t >>> 3);
t = X[(i + 4) & 7]; v ^= t ^ (t << 7);
t = X[(i + 7) & 7]; t = t ^ (t << 13); v ^= t ^ (t << 9);
X[i] = v;
me.i = (i + 1) & 7;
return v;
};
function init(me, seed) {
var j, X = [];
if (seed === (seed | 0)) {
X[0] = seed;
} else {
seed = '' + seed;
for (j = 0; j < seed.length; ++j) {
X[j & 7] = (X[j & 7] << 15) ^
(seed.charCodeAt(j) + X[(j + 1) & 7] << 13);
}
}
while (X.length < 8) X.push(0);
for (j = 0; j < 8 && X[j] === 0; ++j);
if (j == 8) X[7] = -1;
me.x = X;
me.i = 0;
for (j = 256; j > 0; --j) {
me.next();
}
}
init(me, seed);
}
function copy(f, t) {
t.x = f.x.slice();
t.i = f.i;
return t;
}
function impl(seed, opts) {
if (seed == null) seed = +(new Date);
var xg = new XorGen(seed),
state = opts && opts.state,
prng = function() { return (xg.next() >>> 0) / 0x100000000; };
prng.double = function() {
do {
var top = xg.next() >>> 11,
bot = (xg.next() >>> 0) / 0x100000000,
result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
if (state.x) copy(state, xg);
prng.state = function() { return copy(xg, {}); };
}
return prng;
}
if (module && module.exports) {
module.exports = impl;
} else {
this.xorshift7 = impl;
}
})(
commonjsGlobal,
module);
} (xorshift7$1));
var xorshift7Exports = xorshift7$1.exports;
var xor4096$1 = {exports: {}};
(function (module) {
(function(global, module, define) {
function XorGen(seed) {
var me = this;
me.next = function() {
var w = me.w,
X = me.X, i = me.i, t, v;
me.w = w = (w + 0x61c88647) | 0;
v = X[(i + 34) & 127];
t = X[i = ((i + 1) & 127)];
v ^= v << 13;
t ^= t << 17;
v ^= v >>> 15;
t ^= t >>> 12;
v = X[i] = v ^ t;
me.i = i;
return (v + (w ^ (w >>> 16))) | 0;
};
function init(me, seed) {
var t, v, i, j, w, X = [], limit = 128;
if (seed === (seed | 0)) {
v = seed;
seed = null;
} else {
seed = seed + '\0';
v = 0;
limit = Math.max(limit, seed.length);
}
for (i = 0, j = -32; j < limit; ++j) {
if (seed) v ^= seed.charCodeAt((j + 32) % seed.length);
if (j === 0) w = v;
v ^= v << 10;
v ^= v >>> 15;
v ^= v << 4;
v ^= v >>> 13;
if (j >= 0) {
w = (w + 0x61c88647) | 0;
t = (X[j & 127] ^= (v + w));
i = (0 == t) ? i + 1 : 0;
}
}
if (i >= 128) {
X[(seed && seed.length || 0) & 127] = -1;
}
i = 127;
for (j = 4 * 128; j > 0; --j) {
v = X[(i + 34) & 127];
t = X[i = ((i + 1) & 127)];
v ^= v << 13;
t ^= t << 17;
v ^= v >>> 15;
t ^= t >>> 12;
X[i] = v ^ t;
}
me.w = w;
me.X = X;
me.i = i;
}
init(me, seed);
}
function copy(f, t) {
t.i = f.i;
t.w = f.w;
t.X = f.X.slice();
return t;
}
function impl(seed, opts) {
if (seed == null) seed = +(new Date);
var xg = new XorGen(seed),
state = opts && opts.state,
prng = function() { return (xg.next() >>> 0) / 0x100000000; };
prng.double = function() {
do {
var top = xg.next() >>> 11,
bot = (xg.next() >>> 0) / 0x100000000,
result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
if (state.X) copy(state, xg);
prng.state = function() { return copy(xg, {}); };
}
return prng;
}
if (module && module.exports) {
module.exports = impl;
} else {
this.xor4096 = impl;
}
})(
commonjsGlobal,
module);
} (xor4096$1));
var xor4096Exports = xor4096$1.exports;
var tychei$1 = {exports: {}};
(function (module) {
(function(global, module, define) {
function XorGen(seed) {
var me = this, strseed = '';
me.next = function() {
var b = me.b, c = me.c, d = me.d, a = me.a;
b = (b << 25) ^ (b >>> 7) ^ c;
c = (c - d) | 0;
d = (d << 24) ^ (d >>> 8) ^ a;
a = (a - b) | 0;
me.b = b = (b << 20) ^ (b >>> 12) ^ c;
me.c = c = (c - d) | 0;
me.d = (d << 16) ^ (c >>> 16) ^ a;
return me.a = (a - b) | 0;
};
me.a = 0;
me.b = 0;
me.c = 2654435769 | 0;
me.d = 1367130551;
if (seed === Math.floor(seed)) {
me.a = (seed / 0x100000000) | 0;
me.b = seed | 0;
} else {
strseed += seed;
}
for (var k = 0; k < strseed.length + 20; k++) {
me.b ^= strseed.charCodeAt(k) | 0;
me.next();
}
}
function copy(f, t) {
t.a = f.a;
t.b = f.b;
t.c = f.c;
t.d = f.d;
return t;
}
function impl(seed, opts) {
var xg = new XorGen(seed),
state = opts && opts.state,
prng = function() { return (xg.next() >>> 0) / 0x100000000; };
prng.double = function() {
do {
var top = xg.next() >>> 11,
bot = (xg.next() >>> 0) / 0x100000000,
result = (top + bot) / (1 << 21);
} while (result === 0);
return result;
};
prng.int32 = xg.next;
prng.quick = prng;
if (state) {
if (typeof(state) == 'object') copy(state, xg);
prng.state = function() { return copy(xg, {}); };
}
return prng;
}
if (module && module.exports) {
module.exports = impl;
} else {
this.tychei = impl;
}
})(
commonjsGlobal,
module);
} (tychei$1));
var tycheiExports = tychei$1.exports;
var seedrandom$1 = {exports: {}};
(function (module) {
(function (global, pool, math) {
var width = 256,
chunks = 6,
digits = 52,
rngname = 'random',
startdenom = math.pow(width, chunks),
significance = math.pow(2, digits),
overflow = significance * 2,
mask = width - 1,
nodecrypto;
function seedrandom(seed, options, callback) {
var key = [];
options = (options == true) ? { entropy: true } : (options || {});
var shortseed = mixkey(flatten(
options.entropy ? [seed, tostring(pool)] :
(seed == null) ? autoseed() : seed, 3), key);
var arc4 = new ARC4(key);
var prng = function() {
var n = arc4.g(chunks),
d = startdenom,
x = 0;
while (n < significance) {
n = (n + x) * width;
d *= width;
x = arc4.g(1);
}
while (n >= overflow) {
n /= 2;
d /= 2;
x >>>= 1;
}
return (n + x) / d;
};
prng.int32 = function() { return arc4.g(4) | 0; };
prng.quick = function() { return arc4.g(4) / 0x100000000; };
prng.double = prng;
mixkey(tostring(arc4.S), pool);
return (options.pass || callback ||
function(prng, seed, is_math_call, state) {
if (state) {
if (state.S) { copy(state, arc4); }
prng.state = function() { return copy(arc4, {}); };
}
if (is_math_call) { math[rngname] = prng; return seed; }
else return prng;
})(
prng,
shortseed,
'global' in options ? options.global : (this == math),
options.state);
}
function ARC4(key) {
var t, keylen = key.length,
me = this, i = 0, j = me.i = me.j = 0, s = me.S = [];
if (!keylen) { key = [keylen++]; }
while (i < width) {
s[i] = i++;
}
for (i = 0; i < width; i++) {
s[i] = s[j = mask & (j + key[i % keylen] + (t = s[i]))];
s[j] = t;
}
(me.g = function(count) {
var t, r = 0,
i = me.i, j = me.j, s = me.S;
while (count--) {
t = s[i = mask & (i + 1)];
r = r * width + s[mask & ((s[i] = s[j = mask & (j + t)]) + (s[j] = t))];
}
me.i = i; me.j = j;
return r;
})(width);
}
function copy(f, t) {
t.i = f.i;
t.j = f.j;
t.S = f.S.slice();
return t;
}
function flatten(obj, depth) {
var result = [], typ = (typeof obj), prop;
if (depth && typ == 'object') {
for (prop in obj) {
try { result.push(flatten(obj[prop], depth - 1)); } catch (e) {}
}
}
return (result.length ? result : typ == 'string' ? obj : obj + '\0');
}
function mixkey(seed, key) {
var stringseed = seed + '', smear, j = 0;
while (j < stringseed.length) {
key[mask & j] =
mask & ((smear ^= key[mask & j] * 19) + stringseed.charCodeAt(j++));
}
return tostring(key);
}
function autoseed() {
try {
var out;
if (nodecrypto && (out = nodecrypto.randomBytes)) {
out = out(width);
} else {
out = new Uint8Array(width);
(global.crypto || global.msCrypto).getRandomValues(out);
}
return tostring(out);
} catch (e) {
var browser = global.navigator,
plugins = browser && browser.plugins;
return [+new Date, global, plugins, global.screen, tostring(pool)];
}
}
function tostring(a) {
return String.fromCharCode.apply(0, a);
}
mixkey(math.random(), pool);
if (module.exports) {
module.exports = seedrandom;
try {
nodecrypto = require('crypto');
} catch (ex) {}
} else {
math['seed' + rngname] = seedrandom;
}
})(
(typeof self !== 'undefined') ? self : commonjsGlobal,
[],
Math
);
} (seedrandom$1));
var seedrandomExports = seedrandom$1.exports;
var alea = aleaExports;
var xor128 = xor128Exports;
var xorwow = xorwowExports;
var xorshift7 = xorshift7Exports;
var xor4096 = xor4096Exports;
var tychei = tycheiExports;
var sr = seedrandomExports;
sr.alea = alea;
sr.xor128 = xor128;
sr.xorwow = xorwow;
sr.xorshift7 = xorshift7;
sr.xor4096 = xor4096;
sr.tychei = tychei;
var seedrandom = sr;
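// MPRandGauss: seeded Gaussian sampler using the Marsaglia polar method; with
// `truncated` set, samples outside mean +/- 2 * stdDev are rejected and redrawn.
// UniformRandom draws from [min, max) with the same seeded alea PRNG.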
class MPRandGauss {
constructor(mean, stdDeviation, dtype, truncated, seed) {
this.mean = mean;
this.stdDev = stdDeviation;
this.dtype = dtype;
this.nextVal = NaN;
this.truncated = truncated;
if (this.truncated) {
this.upper = this.mean + this.stdDev * 2;
this.lower = this.mean - this.stdDev * 2;
}
const seedValue = seed ? seed : Math.random();
this.random = seedrandom.alea(seedValue.toString());
}
nextValue() {
if (!isNaN(this.nextVal)) {
const value = this.nextVal;
this.nextVal = NaN;
return value;
}
let resultX, resultY;
let isValid = false;
while (!isValid) {
let v1, v2, s;
do {
v1 = 2 * this.random() - 1;
v2 = 2 * this.random() - 1;
s = v1 * v1 + v2 * v2;
} while (s >= 1 || s === 0);
const mul = Math.sqrt(-2 * Math.log(s) / s);
resultX = this.mean + this.stdDev * v1 * mul;
resultY = this.mean + this.stdDev * v2 * mul;
if (!this.truncated || this.isValidTruncated(resultX)) {
isValid = true;
}
}
if (!this.truncated || this.isValidTruncated(resultY)) {
this.nextVal = this.convertValue(resultY);
}
return this.convertValue(resultX);
}
convertValue(value) {
if (this.dtype == null || this.dtype === 'float32') {
return value;
}
return Math.round(value);
}
isValidTruncated(value) {
return value <= this.upper && value >= this.lower;
}
}
class UniformRandom {
constructor(min = 0, max = 1, dtype, seed) {
this.canReturnFloat = () => (this.dtype == null || this.dtype === 'float32');
this.min = min;
this.range = max - min;
this.dtype = dtype;
if (seed == null) {
seed = Math.random();
}
if (typeof seed === 'number') {
seed = seed.toString();
}
if (!this.canReturnFloat() && this.range <= 1) {
throw new Error(`The difference between ${min} and ${max} must be > 1 when the dtype is not float32, but got ${max - min}.`);
}
this.random = seedrandom.alea(seed);
}
convertValue(value) {
if (this.canReturnFloat()) {
return value;
}
return Math.round(value);
}
nextValue() {
return this.convertValue(this.min + this.range * this.random());
}
}
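// randomNormal_/randomUniform_ (and truncatedNormal_ further below) fill a TensorBuffer
// with values from the samplers above and return buffer.toTensor().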
function randomNormal_(shape, mean = 0, stdDev = 1, dtype, seed) {
assertNonNegativeIntegerDimensions(shape);
if (dtype != null && dtype === 'bool') {
throw new Error(`Unsupported data type ${dtype}`);
}
const randGauss = new MPRandGauss(mean, stdDev, dtype, false , seed);
const res = buffer(shape, dtype);
for (let i = 0; i < res.values.length; i++) {
res.values[i] = randGauss.nextValue();
}
return res.toTensor();
}
const randomNormal$1 = op({ randomNormal_ });
function randomUniform_(shape, minval = 0, maxval = 1, dtype = 'float32', seed) {
assertNonNegativeIntegerDimensions(shape);
const res = buffer(shape, dtype);
const random = new UniformRandom(minval, maxval, null, seed);
for (let i = 0; i < res.values.length; i++) {
res.values[i] = random.nextValue();
}
return res.toTensor();
}
const randomUniform = op({ randomUniform_ });
function range$3(start, stop, step = 1, dtype = 'float32') {
if (step === 0) {
throw new Error('Cannot have a step of zero');
}
const attrs = { start, stop, step, dtype };
return ENGINE.runKernel(Range, {} , attrs);
}
function real_(input) {
const $input = convertToTensor(input, 'input', 'real');
const inputs = { input: $input };
return ENGINE.runKernel(Real, inputs);
}
const real$2 = op({ real_ });
function relu_(x) {
const $x = convertToTensor(x, 'x', 'relu');
const inputs = { x: $x };
return ENGINE.runKernel(Relu$1, inputs);
}
const relu$2 = op({ relu_ });
function relu6_(x) {
const $x = convertToTensor(x, 'x', 'relu6');
const inputs = { x: $x };
return ENGINE.runKernel(Relu6$1, inputs);
}
const relu6$2 = op({ relu6_ });
function reverse_(x, axis) {
const $x = convertToTensor(x, 'x', 'reverse');
const inputs = { x: $x };
const attrs = { dims: axis };
return ENGINE.runKernel(Reverse, inputs, attrs);
}
const reverse$2 = op({ reverse_ });
function rsqrt_(x) {
const $x = convertToTensor(x, 'x', 'rsqrt', 'float32');
const inputs = { x: $x };
return ENGINE.runKernel(Rsqrt, inputs);
}
const rsqrt$2 = op({ rsqrt_ });
function selu_(x) {
const $x = convertToTensor(x, 'x', 'selu');
const inputs = { x: $x };
return ENGINE.runKernel(Selu$1, inputs);
}
const selu$2 = op({ selu_ });
function sin_(x) {
const $x = convertToTensor(x, 'x', 'sin', 'float32');
const inputs = { x: $x };
return ENGINE.runKernel(Sin, inputs);
}
const sin$2 = op({ sin_ });
function sinh_(x) {
const $x = convertToTensor(x, 'x', 'sinh');
const inputs = { x: $x };
return ENGINE.runKernel(Sinh, inputs);
}
const sinh$2 = op({ sinh_ });
function slice1d_(x, begin, size) {
const $x = convertToTensor(x, 'x', 'slice1d');
assert$1($x.rank === 1, () => `slice1d expects a rank-1 tensor, but got a rank-${$x.rank} tensor`);
return slice$2($x, [begin], [size]);
}
const slice1d = op({ slice1d_ });
function slice2d_(x, begin, size) {
const $x = convertToTensor(x, 'x', 'slice2d');
assert$1($x.rank === 2, () => `slice2d expects a rank-2 tensor, but got a rank-${$x.rank} tensor`);
return slice$2($x, begin, size);
}
const slice2d = op({ slice2d_ });
function slice3d_(x, begin, size) {
const $x = convertToTensor(x, 'x', 'slice3d');
assert$1($x.rank === 3, () => `slice3d expects a rank-3 tensor, but got a rank-${$x.rank} tensor`);
return slice$2($x, begin, size);
}
const slice3d = op({ slice3d_ });
function slice4d_(x, begin, size) {
const $x = convertToTensor(x, 'x', 'slice4d');
assert$1($x.rank === 4, () => `slice4d expects a rank-4 tensor, but got a rank-${$x.rank} tensor`);
return slice$2($x, begin, size);
}
const slice4d = op({ slice4d_ });
function softmax_(logits, dim = -1) {
const $logits = convertToTensor(logits, 'logits', 'softmax', 'float32');
if (dim === -1) {
dim = $logits.rank - 1;
}
if (dim !== $logits.rank - 1) {
throw Error('Softmax along a non-last dimension is not yet supported. ' +
`Logits was rank ${$logits.rank} and dim was ${dim}`);
}
const inputs = { logits: $logits };
const attrs = { dim };
return ENGINE.runKernel(Softmax$1, inputs, attrs);
}
const softmax$2 = op({ softmax_ });
function split_(x, numOrSizeSplits, axis = 0) {
const $x = convertToTensor(x, 'x', 'split');
const inputs = { x: $x };
const attr = { numOrSizeSplits, axis };
return ENGINE.runKernel(SplitV, inputs, attr);
}
const split$1 = op({ split_ });
function squeeze_(x, axis) {
const $x = convertToTensor(x, 'x', 'squeeze', 'string_or_numeric');
return reshape$2($x, squeezeShape($x.shape, axis).newShape);
}
const squeeze = op({ squeeze_ });
function stack_(tensors, axis = 0) {
const $tensors = convertToTensorArray(tensors, 'tensors', 'stack', 'string_or_numeric');
assert$1($tensors.length >= 1, () => 'Pass at least one tensor to tf.stack');
if ($tensors.length > 0) {
assert$1(axis <= $tensors[0].rank, () => 'Axis must be <= rank of the tensor');
}
const inputs = $tensors;
const attrs = { axis };
return ENGINE.runKernel(Pack, inputs, attrs);
}
const stack = op({ stack_ });
function step_(x, alpha = 0.0) {
const $x = convertToTensor(x, 'x', 'step');
const inputs = { x: $x };
const attrs = { alpha };
return ENGINE.runKernel(Step, inputs, attrs);
}
const step$2 = op({ step_ });
function tensor(values, shape, dtype) {
const inferredShape = inferShape(values, dtype);
return makeTensor(values, shape, inferredShape, dtype);
}
function tensor1d(values, dtype) {
assertNonNull(values);
const inferredShape = inferShape(values, dtype);
if (inferredShape.length !== 1) {
throw new Error('tensor1d() requires values to be a flat/TypedArray');
}
const shape = null;
return makeTensor(values, shape, inferredShape, dtype);
}
function tensor2d(values, shape, dtype) {
assertNonNull(values);
if (shape != null && shape.length !== 2) {
throw new Error('tensor2d() requires shape to have two numbers');
}
const inferredShape = inferShape(values, dtype);
if (inferredShape.length !== 2 && inferredShape.length !== 1) {
throw new Error('tensor2d() requires values to be number[][] or flat/TypedArray');
}
if (inferredShape.length === 1 && shape == null) {
throw new Error('tensor2d() requires shape to be provided when `values` ' +
'are a flat/TypedArray');
}
return makeTensor(values, shape, inferredShape, dtype);
}
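// scatterND helpers: validateUpdateShape/validateInput check that indices, updates and
// the output shape are mutually consistent; calculateShapes derives the slice rank,
// slice size, strides and number of updates used by the scatter kernels.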
function validateUpdateShape(shape, indices, updates) {
const sliceDim = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1;
const batchDim = (indices.rank > 1) ? indices.rank - 1 : 1;
const shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' +
`shape[sliceDim:], got updates.shape: ${updates.shape}` +
`, indices.shape: ${indices.shape}, shape: ${shape}` +
`, sliceDim: ${sliceDim}, and batchDim: ${batchDim}.`;
if (updates.rank < batchDim) {
throw new Error(shapeError + ` updates.rank < ${batchDim}. `);
}
if (shape.length < sliceDim + (updates.rank - batchDim)) {
throw new Error(shapeError +
` Output shape length < ${sliceDim + (updates.rank - batchDim)}`);
}
if (updates.rank !== batchDim + shape.length - sliceDim) {
throw new Error(shapeError + ` updates.rank != ${batchDim + shape.length - sliceDim}`);
}
for (let d = 0; d < batchDim; ++d) {
if (updates.shape[d] !== indices.shape[d]) {
throw new Error(shapeError +
` updates.shape[${d}] (${updates.shape[d]}) != indices.shape[${d}] (${indices.shape[d]}).`);
}
}
for (let d = 0; d < updates.rank - batchDim; ++d) {
if (updates.shape[d + batchDim] !== shape[d + sliceDim]) {
throw new Error(shapeError +
` updates.shape[${d + batchDim}] (${updates.shape[d + batchDim]}) != shape[${d + sliceDim}] (${shape[d + sliceDim]})`);
}
}
}
function validateInput(updates, indices, shape) {
if (indices.rank < 1) {
throw new Error('tf.scatterND() expects the indices to be rank 1 or higher,' +
` but the rank was ${indices.rank}.`);
}
if (updates.rank < 1) {
throw new Error('tf.scatterND() expects the updates to be rank 1 or higher,' +
` but the rank was ${updates.rank}.`);
}
if (indices.dtype !== 'int32') {
throw new Error(`The dtype of 'indices' should be int32, but got dtype: ${indices.dtype}`);
}
if (shape.length < 1) {
throw new Error(`Output rank must be greater or equal to 1, but got shape: ${shape}`);
}
if (shape.length === 0) {
if (indices.size === 0) {
throw new Error(`Indices specified for empty output. indices shape: ${indices.shape}`);
}
if (updates.size === 0) {
throw new Error(`Updates specified for empty output. updates shape: ${updates.shape}`);
}
}
validateUpdateShape(shape, indices, updates);
}
function calculateShapes(updates, indices, shape) {
const indicesRank = indices.shape.length;
const sliceRank = (indicesRank > 1) ? indices.shape[indicesRank - 1] : 1;
const totalNd = shape.length;
let sliceSize = 1;
for (let i = sliceRank; i < totalNd; ++i) {
sliceSize *= shape[i];
}
const safeSliceDim = (sliceRank < 1) ? 1 : sliceRank;
const numUpdates = sizeFromShape(indices.shape) / safeSliceDim;
const strides = [...computeStrides(shape.slice(0, sliceRank)), 1];
const outputSize = sizeFromShape(shape);
return { sliceRank, numUpdates, sliceSize, strides, outputSize };
}
function truncatedNormal_(shape, mean = 0, stdDev = 1, dtype, seed) {
assertNonNegativeIntegerDimensions(shape);
if (dtype != null && dtype === 'bool') {
throw new Error(`Unsupported data type ${dtype}`);
}
const randGauss = new MPRandGauss(mean, stdDev, dtype, true , seed);
const res = buffer(shape, dtype);
for (let i = 0; i < res.values.length; i++) {
res.values[i] = randGauss.nextValue();
}
return res.toTensor();
}
const truncatedNormal = op({ truncatedNormal_ });
function unsortedSegmentSum_(x, segmentIds, numSegments) {
const $x = convertToTensor(x, 'x', 'unsortedSegmentSum');
const $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'unsortedSegmentSum', 'int32');
assert$1(isInt(numSegments), () => 'numSegments must be of dtype int');
const inputs = { x: $x, segmentIds: $segmentIds };
const attrs = { numSegments };
return ENGINE.runKernel(UnsortedSegmentSum, inputs, attrs);
}
const unsortedSegmentSum$2 = op({ unsortedSegmentSum_ });
function unstack_(x, axis = 0) {
const $x = convertToTensor(x, 'x', 'unstack', 'string_or_numeric');
assert$1(axis >= -$x.shape.length && axis < $x.shape.length, () => `Axis = ${axis} is not in [-${$x.shape.length}, ${$x.shape.length})`);
const inputs = { value: $x };
const attrs = { axis };
return ENGINE.runKernel(Unpack, inputs, attrs);
}
const unstack = op({ unstack_ });
function variable(initialValue, trainable = true, name, dtype) {
return ENGINE.makeVariable(initialValue, trainable, name, dtype);
}
function whereImpl$2(condShape, condVals) {
const indices = [];
for (let i = 0; i < condVals.length; i++) {
if (condVals[i]) {
indices.push(i);
}
}
const inBuffer = buffer(condShape, 'int32');
const out = buffer([indices.length, condShape.length], 'int32');
for (let i = 0; i < indices.length; i++) {
const loc = inBuffer.indexToLoc(indices[i]);
const offset = i * condShape.length;
out.values.set(loc, offset);
}
return out.toTensor();
}
function transpose_(x, perm, conjugate) {
const $x = convertToTensor(x, 'x', 'transpose');
if (perm == null) {
perm = $x.shape.map((s, i) => i).reverse();
}
assert$1($x.rank === perm.length, () => `Error in transpose: rank of input ${$x.rank} ` +
`must match length of perm ${perm}.`);
perm.forEach(axis => {
assert$1(axis >= 0 && axis < $x.rank, () => `All entries in 'perm' must be between 0 and ${$x.rank - 1}` +
` but got ${perm}`);
});
if ($x.rank <= 1) {
return $x.clone();
}
const inputs = { x: $x };
const attrs = { perm };
if ($x.dtype === 'complex64') {
return tidy(() => {
let $real = real$2($x);
let $imag = imag$2($x);
$real = ENGINE.runKernel(Transpose, { x: $real }, attrs);
$imag = ENGINE.runKernel(Transpose, { x: $imag }, attrs);
if (conjugate) {
$imag = neg$2($imag);
}
return complex$2($real, $imag);
});
}
return ENGINE.runKernel(Transpose, inputs, attrs);
}
const transpose$2 = op({ transpose_ });
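// Resolves the dropout noise shape: defaults to the input's shape, and for a
// same-rank noiseShape fills any null entries with the corresponding input dimension.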
function getNoiseShape(x, noiseShape) {
if (noiseShape == null) {
return x.shape.slice();
}
if (arraysEqual(x.shape, noiseShape)) {
return noiseShape;
}
if (x.shape.length === noiseShape.length) {
const newDimension = [];
for (let i = 0; i < x.shape.length; i++) {
if (noiseShape[i] == null && x.shape[i] != null) {
newDimension.push(x.shape[i]);
}
else {
newDimension.push(noiseShape[i]);
}
}
return newDimension;
}
return noiseShape;
}
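// Inverted dropout: floor(uniform(0,1) + keepProb) yields a 0/1 mask that keeps each
// element with probability keepProb; dividing the mask by keepProb rescales the
// surviving activations so the expected value of the output matches the input.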
function dropout_(x, rate, noiseShape, seed) {
const $x = convertToTensor(x, 'x', 'dropout');
assert$1($x.dtype === 'float32', () => `x has to be a floating point tensor since it's going to be ` +
`scaled, but got a ${$x.dtype} tensor instead.`);
assert$1(rate >= 0 && rate < 1, () => `rate must be a float in the range [0, 1), but got ${rate}.`);
if (rate === 0) {
return x instanceof Tensor ? $x.clone() : $x;
}
const $noiseShape = getNoiseShape($x, noiseShape);
const keepProb = 1 - rate;
const multiplier = div$1(floor$2(add$1(randomUniform($noiseShape, 0, 1, 'float32', seed), keepProb)), keepProb);
return mul($x, multiplier);
}
const dropout$2 = op({ dropout_ });
function conv2DBackpropFilter_(x, dy, filterShape, strides, pad, dataFormat = 'NHWC', dimRoundingMode) {
let x4D = x;
if (x.rank === 3) {
x4D = reshape$2(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
}
let dy4D = dy;
if (dy4D.rank === 3) {
dy4D = reshape$2(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
}
assert$1(x4D.rank === 4, () => `Error in conv2dDerFilter: input must be rank 4, but got shape ` +
`${x4D.shape}.`);
assert$1(dy4D.rank === 4, () => `Error in conv2dDerFilter: dy must be rank 4, but got shape ` +
`${dy4D.shape}.`);
assert$1(filterShape.length === 4, () => `Error in conv2dDerFilter: filterShape must be length 4, but got ` +
`${filterShape}.`);
const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
const outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];
assert$1(inDepth === filterShape[2], () => `Error in conv2dDerFilter: depth of input (${inDepth}) must ` +
    `match input depth in filter (${filterShape[2]}).`);
assert$1(outDepth === filterShape[3], () => `Error in conv2dDerFilter: depth of dy (${outDepth}) must ` +
`match output depth for filter (${filterShape[3]}).`);
checkPadOnDimRoundingMode('conv2dDerFilter', pad, dimRoundingMode);
const inputs = { x: x4D, dy: dy4D };
const attrs = { strides, pad, dataFormat, dimRoundingMode, filterShape };
return ENGINE.runKernel(Conv2DBackpropFilter, inputs, attrs);
}
const conv2DBackpropFilter$2 = op({ conv2DBackpropFilter_ });
function getFusedDyActivation(dy, y, activation) {
if (activation == null || activation === 'linear') {
return dy;
}
if (activation === 'relu') {
return mul(dy, step$2(y));
}
throw new Error(`Cannot compute gradient for fused activation ${activation}.`);
}
function getFusedBiasGradient(bias, dyActivation) {
let res = dyActivation;
const reduceAxes = getReductionAxes(bias.shape, dyActivation.shape);
if (reduceAxes.length > 0) {
res = sum$2(res, reduceAxes);
}
return reshape$2(res, bias.shape);
}
function applyActivation$1(x, activation, preluActivationWeights, leakyreluAlpha) {
if (activation === 'linear') {
return x;
}
else if (activation === 'relu') {
return relu$2(x);
}
else if (activation === 'elu') {
return elu$3(x);
}
else if (activation === 'relu6') {
return relu6$2(x);
}
else if (activation === 'prelu') {
return prelu$2(x, preluActivationWeights);
}
else if (activation === 'leakyrelu') {
return leakyRelu$2(x, leakyreluAlpha);
}
else if (activation === 'sigmoid') {
return sigmoid$2(x);
}
throw new Error(`Unknown fused activation ${activation}.`);
}
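// Decides whether a fused kernel may be used: fuse only when no gradient tape is
// active (gradientDepth === 0) or when the activation is linear.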
const shouldFuse = (gradientDepth, activation) => {
const gradientMode = gradientDepth > 0;
return !gradientMode || activation === 'linear';
};
function depthwiseConv2dNativeBackpropFilter_(x, dy, filterShape, strides, pad, dilations = [1, 1], dimRoundingMode) {
let x4D = x;
if (x.rank === 3) {
x4D = reshape$2(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
}
let dy4D = dy;
if (dy4D.rank === 3) {
dy4D = reshape$2(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
}
const inputs = { x: x4D, dy: dy4D };
const attrs = { strides, pad, dimRoundingMode, dilations, filterShape };
return ENGINE.runKernel(DepthwiseConv2dNativeBackpropFilter, inputs, attrs);
}
const depthwiseConv2dNativeBackpropFilter$2 = op({ depthwiseConv2dNativeBackpropFilter_ });
function depthwiseConv2dNativeBackpropInput_(xShape, dy, filter, strides, pad, dilations = [1, 1], dimRoundingMode) {
let dy4D = dy;
let reshapedTo4D = false;
if (dy.rank === 3) {
reshapedTo4D = true;
dy4D = reshape$2(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
}
const inputs = { dy: dy4D, filter };
const attrs = { strides, pad, dimRoundingMode, dilations, inputShape: xShape };
const res = ENGINE.runKernel(DepthwiseConv2dNativeBackpropInput, inputs, attrs);
if (reshapedTo4D) {
return reshape$2(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
const depthwiseConv2dNativeBackpropInput$2 = op({ depthwiseConv2dNativeBackpropInput_ });
function fusedMatMul_({ a, b, transposeA = false, transposeB = false, bias, activation = 'linear', preluActivationWeights, leakyreluAlpha = 0.2, }) {
if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
let result = matMul$1(a, b, transposeA, transposeB);
if (bias != null) {
result = add$1(result, bias);
}
return applyActivation$1(result, activation, preluActivationWeights, leakyreluAlpha);
}
let $a = convertToTensor(a, 'a', 'fused matMul');
let $b = convertToTensor(b, 'b', 'fused matMul');
[$a, $b] = makeTypesMatch($a, $b);
const innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];
const innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];
const outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];
const outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];
const outerDimsA = $a.shape.slice(0, -2);
const outerDimsB = $b.shape.slice(0, -2);
const batchDimA = sizeFromShape(outerDimsA);
const batchDimB = sizeFromShape(outerDimsB);
assert$1(innerShapeA === innerShapeB, () => `Error in fused matMul: inner shapes (${innerShapeA}) and (` +
`${innerShapeB}) of Tensors with shapes ${$a.shape} and ` +
`${$b.shape} and transposeA=${transposeA}` +
` and transposeB=${transposeB} must match.`);
const outShapeOuterDims = assertAndGetBroadcastShape($a.shape.slice(0, -2), $b.shape.slice(0, -2));
const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
const a3D = transposeA ?
reshape$2($a, [batchDimA, innerShapeA, outerShapeA]) :
reshape$2($a, [batchDimA, outerShapeA, innerShapeA]);
const b3D = transposeB ?
reshape$2($b, [batchDimB, outerShapeB, innerShapeB]) :
reshape$2($b, [batchDimB, innerShapeB, outerShapeB]);
let $bias;
if (bias != null) {
$bias = convertToTensor(bias, 'bias', 'fused matMul');
[$bias] = makeTypesMatch($bias, $a);
assertAndGetBroadcastShape(outShape, $bias.shape);
}
let $preluActivationWeights;
if (preluActivationWeights != null) {
$preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused matMul');
}
const grad = (dy, saved) => {
const [a3D, b3D, y, $bias] = saved;
const dyActivation = getFusedDyActivation(reshape$2(dy, y.shape), y, activation);
let aDer;
let bDer;
if (!transposeA && !transposeB) {
aDer = matMul$1(dyActivation, b3D, false, true);
bDer = matMul$1(a3D, dyActivation, true, false);
}
else if (!transposeA && transposeB) {
aDer = matMul$1(dyActivation, b3D, false, false);
bDer = matMul$1(dyActivation, a3D, true, false);
}
else if (transposeA && !transposeB) {
aDer = matMul$1(b3D, dyActivation, false, true);
bDer = matMul$1(a3D, dyActivation, false, false);
}
else {
aDer = matMul$1(b3D, dyActivation, true, true);
bDer = matMul$1(dyActivation, a3D, true, true);
}
if (bias != null) {
const biasDer = getFusedBiasGradient($bias, dyActivation);
return [aDer, bDer, biasDer];
}
else {
return [aDer, bDer];
}
};
const inputs = {
a: a3D,
b: b3D,
bias: $bias,
preluActivationWeights: $preluActivationWeights
};
const attrs = { transposeA, transposeB, activation, leakyreluAlpha };
if (bias == null) {
const customOp = customGrad((a3D, b3D, save) => {
const res = ENGINE.runKernel(_FusedMatMul, inputs, attrs);
save([a3D, b3D, res]);
return { value: reshape$2(res, outShape), gradFunc: grad };
});
return customOp(a3D, b3D);
}
else {
const customOpWithBias = customGrad((a3D, b3D, $bias, save) => {
const res = ENGINE.runKernel(_FusedMatMul, inputs, attrs);
save([a3D, b3D, res, $bias]);
return { value: reshape$2(res, outShape), gradFunc: grad };
});
return customOpWithBias(a3D, b3D, $bias);
}
}
const matMul = op({ fusedMatMul_ });
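// Sorted-array helpers used by non-max suppression: binaryInsert keeps `arr` sorted
// under `comparator`; binarySearch_ returns the index of `target`, or
// -(insertionPoint + 1) when it is absent.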
function binaryInsert(arr, element, comparator) {
const index = binarySearch(arr, element, comparator);
const insertionPoint = index < 0 ? -(index + 1) : index;
arr.splice(insertionPoint, 0, element);
}
function binarySearch(arr, target, comparator) {
return binarySearch_(arr, target, comparator || defaultComparator);
}
function defaultComparator(a, b) {
return a > b ? 1 : a < b ? -1 : 0;
}
function binarySearch_(arr, target, comparator) {
let left = 0;
let right = arr.length;
let middle = 0;
let found = false;
while (left < right) {
middle = left + ((right - left) >>> 1);
const compareResult = comparator(target, arr[middle]);
if (compareResult > 0) {
left = middle + 1;
}
else {
right = middle;
found = !compareResult;
}
}
return found ? left : -left - 1;
}
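// Non-max suppression entry points. V3 is plain NMS; V4 additionally pads the output
// to maxOutputSize and reports the number of valid boxes; V5 applies Soft-NMS score
// decay controlled by softNmsSigma and also returns the selected scores.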
function nonMaxSuppressionV3Impl$2(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, 0 /* softNmsSigma */);
}
function nonMaxSuppressionV4Impl$2(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize) {
return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, 0 /* softNmsSigma */, false /* returnScoresTensor */, padToMaxOutputSize, true /* returnValidOutputs */);
}
function nonMaxSuppressionV5Impl$2(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, true /* returnScoresTensor */);
}
function nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, returnScoresTensor = false, padToMaxOutputSize = false, returnValidOutputs = false) {
const candidates = [];
for (let i = 0; i < scores.length; i++) {
if (scores[i] > scoreThreshold) {
candidates.push({ score: scores[i], boxIndex: i, suppressBeginIndex: 0 });
}
}
candidates.sort(ascendingComparator);
const scale = softNmsSigma > 0 ? (-0.5 / softNmsSigma) : 0.0;
const selectedIndices = [];
const selectedScores = [];
while (selectedIndices.length < maxOutputSize && candidates.length > 0) {
const candidate = candidates.pop();
const { score: originalScore, boxIndex, suppressBeginIndex } = candidate;
if (originalScore < scoreThreshold) {
break;
}
let ignoreCandidate = false;
for (let j = selectedIndices.length - 1; j >= suppressBeginIndex; --j) {
const iou = intersectionOverUnion(boxes, boxIndex, selectedIndices[j]);
if (iou >= iouThreshold) {
ignoreCandidate = true;
break;
}
candidate.score =
candidate.score * suppressWeight(iouThreshold, scale, iou);
if (candidate.score <= scoreThreshold) {
break;
}
}
candidate.suppressBeginIndex = selectedIndices.length;
if (!ignoreCandidate) {
if (candidate.score === originalScore) {
selectedIndices.push(boxIndex);
selectedScores.push(candidate.score);
}
else if (candidate.score > scoreThreshold) {
binaryInsert(candidates, candidate, ascendingComparator);
}
}
}
const validOutputs = selectedIndices.length;
const elemsToPad = maxOutputSize - validOutputs;
if (padToMaxOutputSize && elemsToPad > 0) {
selectedIndices.push(...new Array(elemsToPad).fill(0));
selectedScores.push(...new Array(elemsToPad).fill(0.0));
}
const result = { selectedIndices };
if (returnScoresTensor) {
result['selectedScores'] = selectedScores;
}
if (returnValidOutputs) {
result['validOutputs'] = validOutputs;
}
return result;
}
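// Intersection-over-union of boxes i and j; each box occupies four consecutive
// values [y1, x1, y2, x2] in the flat `boxes` array, with corners in either order.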
function intersectionOverUnion(boxes, i, j) {
const iCoord = boxes.subarray(i * 4, i * 4 + 4);
const jCoord = boxes.subarray(j * 4, j * 4 + 4);
const yminI = Math.min(iCoord[0], iCoord[2]);
const xminI = Math.min(iCoord[1], iCoord[3]);
const ymaxI = Math.max(iCoord[0], iCoord[2]);
const xmaxI = Math.max(iCoord[1], iCoord[3]);
const yminJ = Math.min(jCoord[0], jCoord[2]);
const xminJ = Math.min(jCoord[1], jCoord[3]);
const ymaxJ = Math.max(jCoord[0], jCoord[2]);
const xmaxJ = Math.max(jCoord[1], jCoord[3]);
const areaI = (ymaxI - yminI) * (xmaxI - xminI);
const areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ);
if (areaI <= 0 || areaJ <= 0) {
return 0.0;
}
const intersectionYmin = Math.max(yminI, yminJ);
const intersectionXmin = Math.max(xminI, xminJ);
const intersectionYmax = Math.min(ymaxI, ymaxJ);
const intersectionXmax = Math.min(xmaxI, xmaxJ);
const intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) *
Math.max(intersectionXmax - intersectionXmin, 0.0);
return intersectionArea / (areaI + areaJ - intersectionArea);
}
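// Soft-NMS weight: a Gaussian penalty exp(scale * iou^2) when the IoU is at or below
// the threshold, and hard suppression (weight 0) above it.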
function suppressWeight(iouThreshold, scale, iou) {
const weight = Math.exp(scale * iou * iou);
return iou <= iouThreshold ? weight : 0.0;
}
function ascendingComparator(c1, c2) {
return (c1.score - c2.score) ||
((c1.score === c2.score) && (c2.boxIndex - c1.boxIndex));
}
function bandPart_(a, numLower, numUpper) {
const $a = convertToTensor(a, 'a', 'bandPart');
assert$1($a.rank >= 2, () => `bandPart(): Rank must be at least 2, got ${$a.rank}.`);
const shape = $a.shape;
const [M, N] = $a.shape.slice(-2);
let $numLower;
let $numUpper;
if (typeof numLower === 'number') {
assert$1(numLower % 1 === 0, () => `bandPart(): numLower must be an integer, got ${numLower}.`);
assert$1(numLower <= M, () => `bandPart(): numLower (${numLower})` +
` must not be greater than the number of rows (${M}).`);
$numLower =
convertToTensor(numLower < 0 ? M : numLower, 'numLower', 'bandPart');
}
else {
assert$1(numLower.dtype === 'int32', () => `bandPart(): numLower's dtype must be an int32.`);
$numLower = where(less$2(numLower, 0), M, minimum$2(numLower, M));
}
if (typeof numUpper === 'number') {
assert$1(numUpper % 1 === 0, () => `bandPart(): numUpper must be an integer, got ${numUpper}.`);
assert$1(numUpper <= N, () => `bandPart(): numUpper (${numUpper})` +
` must not be greater than the number of columns (${N}).`);
$numUpper =
convertToTensor(numUpper < 0 ? N : numUpper, 'numUpper', 'bandPart');
}
else {
assert$1(numUpper.dtype === 'int32', () => `bandPart(): numUpper's dtype must be an int32.`);
$numUpper = where(less$2(numUpper, 0), N, minimum$2(numUpper, N));
}
const i = reshape$2(range$3(0, M, 1, 'int32'), [-1, 1]);
const j = range$3(0, N, 1, 'int32');
const ij = sub$2(i, j);
const inBand = logicalAnd$2(lessEqual$2(ij, $numLower), greaterEqual$2(ij, neg$2($numUpper)));
const zero = zeros$1([M, N], $a.dtype);
return reshape$2(stack(unstack(reshape$2($a, [-1, M, N]))
.map(mat => where(inBand, mat, zero))), shape);
}
const bandPart = op({ bandPart_ });
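// Gram-Schmidt orthogonalization: accepts either a list of 1-D vectors or a 2-D
// tensor of row vectors, and returns the orthonormalized result in the same form.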
function gramSchmidt_(xs) {
let inputIsTensor2D;
if (Array.isArray(xs)) {
inputIsTensor2D = false;
assert$1(xs != null && xs.length > 0, () => 'Gram-Schmidt process: input must not be null, undefined, or ' +
'empty');
const dim = xs[0].shape[0];
for (let i = 1; i < xs.length; ++i) {
assert$1(xs[i].shape[0] === dim, () => 'Gram-Schmidt: Non-unique lengths found in the input vectors: ' +
`(${xs[i].shape[0]} vs. ${dim})`);
}
}
else {
inputIsTensor2D = true;
xs = split$1(xs, xs.shape[0], 0).map(x => squeeze(x, [0]));
}
assert$1(xs.length <= xs[0].shape[0], () => `Gram-Schmidt: Number of vectors (${xs.length}) exceeds ` +
`number of dimensions (${xs[0].shape[0]}).`);
const ys = [];
const xs1d = xs;
for (let i = 0; i < xs.length; ++i) {
ys.push(ENGINE.tidy(() => {
let x = xs1d[i];
if (i > 0) {
for (let j = 0; j < i; ++j) {
const proj = mul(sum$2(mul(ys[j], x)), ys[j]);
x = sub$2(x, proj);
}
}
return div$1(x, norm(x, 'euclidean'));
}));
}
if (inputIsTensor2D) {
return stack(ys, 0);
}
else {
return ys;
}
}
const gramSchmidt = op({ gramSchmidt_ });
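// QR decomposition via Householder reflections, computed per 2-D matrix; inputs of
// rank > 2 are unstacked over the batch dimensions and restacked afterwards.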
function qr_(x, fullMatrices = false) {
assert$1(x.rank >= 2, () => `qr() requires input tensor to have a rank >= 2, but got rank ${x.rank}`);
if (x.rank === 2) {
return qr2d(x, fullMatrices);
}
else {
const outerDimsProd = x.shape.slice(0, x.shape.length - 2)
.reduce((value, prev) => value * prev);
const x2ds = unstack(reshape$2(x, [
outerDimsProd, x.shape[x.shape.length - 2],
x.shape[x.shape.length - 1]
]), 0);
const q2ds = [];
const r2ds = [];
x2ds.forEach(x2d => {
const [q2d, r2d] = qr2d(x2d, fullMatrices);
q2ds.push(q2d);
r2ds.push(r2d);
});
const q = reshape$2(stack(q2ds, 0), x.shape);
const r = reshape$2(stack(r2ds, 0), x.shape);
return [q, r];
}
}
function qr2d(x, fullMatrices = false) {
return ENGINE.tidy(() => {
assert$1(x.shape.length === 2, () => `qr2d() requires a 2D Tensor, but got a ${x.shape.length}D Tensor.`);
const m = x.shape[0];
const n = x.shape[1];
let q = eye(m);
let r = clone(x);
const one2D = tensor2d([[1]], [1, 1]);
let w = clone(one2D);
const iters = m >= n ? n : m;
for (let j = 0; j < iters; ++j) {
const rTemp = r;
const wTemp = w;
const qTemp = q;
[w, r, q] = ENGINE.tidy(() => {
const rjEnd1 = slice$2(r, [j, j], [m - j, 1]);
const normX = norm(rjEnd1);
const rjj = slice$2(r, [j, j], [1, 1]);
const s = where(greater$2(rjj, 0), tensor2d([[-1]]), tensor2d([[1]]));
const u1 = sub$2(rjj, mul(s, normX));
const wPre = div$1(rjEnd1, u1);
if (wPre.shape[0] === 1) {
w = clone(one2D);
}
else {
w = concat$2([
one2D,
slice$2(wPre, [1, 0], [wPre.shape[0] - 1, wPre.shape[1]])
], 0);
}
const tau = neg$2(div$1(matMul$1(s, u1), normX));
const rjEndAll = slice$2(r, [j, 0], [m - j, n]);
const tauTimesW = mul(tau, w);
const wT = transpose$2(w);
if (j === 0) {
r = sub$2(rjEndAll, matMul$1(tauTimesW, matMul$1(wT, rjEndAll)));
}
else {
const rTimesTau = sub$2(rjEndAll, matMul$1(tauTimesW, matMul$1(wT, rjEndAll)));
r = concat$2([slice$2(r, [0, 0], [j, n]), rTimesTau], 0);
}
const tawTimesWT = transpose$2(tauTimesW);
const qAllJEnd = slice$2(q, [0, j], [m, q.shape[1] - j]);
if (j === 0) {
q = sub$2(qAllJEnd, matMul$1(matMul$1(qAllJEnd, w), tawTimesWT));
}
else {
const qTimesTau = sub$2(qAllJEnd, matMul$1(matMul$1(qAllJEnd, w), tawTimesWT));
q = concat$2([slice$2(q, [0, 0], [m, j]), qTimesTau], 1);
}
return [w, r, q];
});
dispose([rTemp, wTemp, qTemp]);
}
if (!fullMatrices && m > n) {
q = slice$2(q, [0, 0], [m, n]);
r = slice$2(r, [0, 0], [n, n]);
}
return [q, r];
});
}
const qr = op({ qr_ });
function stringToHashBucketFast_(input, numBuckets) {
const $input = convertToTensor(input, 'input', 'stringToHashBucketFast', 'string');
const attrs = { numBuckets };
if (numBuckets <= 0) {
throw new Error(`Number of buckets must be at least 1`);
}
const inputs = { input: $input };
return ENGINE.runKernel(StringToHashBucketFast, inputs, attrs);
}
const stringToHashBucketFast$2 = op({ stringToHashBucketFast_ });
const linalg = {
bandPart,
gramSchmidt,
qr
};
const GLOBAL_CUSTOM_OBJECT = new Map();
const GLOBAL_CUSTOM_NAMES = new Map();
class Serializable {
getClassName() {
return this.constructor.className;
}
static fromConfig(cls, config) {
return new cls(config);
}
}
class SerializationMap {
constructor() {
this.classNameMap = {};
}
static getMap() {
if (SerializationMap.instance == null) {
SerializationMap.instance = new SerializationMap();
}
return SerializationMap.instance;
}
static register(cls) {
SerializationMap.getMap().classNameMap[cls.className] =
[cls, cls.fromConfig];
}
}
function registerClass(cls, pkg, name) {
assert$1(cls.className != null, () => `Class being registered does not have the static className ` +
`property defined.`);
assert$1(typeof cls.className === 'string', () => `className is required to be a string, but got type ` +
typeof cls.className);
assert$1(cls.className.length > 0, () => `Class being registered has an empty-string as its className, ` +
`which is disallowed.`);
if (typeof pkg === 'undefined') {
pkg = 'Custom';
}
if (typeof name === 'undefined') {
name = cls.className;
}
const className = name;
const registerName = pkg + '>' + className;
SerializationMap.register(cls);
GLOBAL_CUSTOM_OBJECT.set(registerName, cls);
GLOBAL_CUSTOM_NAMES.set(cls, registerName);
return cls;
}
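// Base class for all optimizers. minimize() evaluates f(), computes gradients with
// respect to the trainable variables (or the given varList), applies them via
// applyGradients(), and disposes the intermediates. Usage sketch (assuming an
// optimizer instance `opt` and a scalar-returning loss closure):
//   opt.minimize(() => lossFn());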
class Optimizer extends Serializable {
minimize(f, returnCost = false, varList) {
const { value, grads } = this.computeGradients(f, varList);
if (varList != null) {
const gradArray = varList.map(v => ({ name: v.name, tensor: grads[v.name] }));
this.applyGradients(gradArray);
}
else {
this.applyGradients(grads);
}
dispose(grads);
if (returnCost) {
return value;
}
else {
value.dispose();
return null;
}
}
get iterations() {
if (this.iterations_ == null) {
this.iterations_ = 0;
}
return this.iterations_;
}
incrementIterations() {
this.iterations_ = this.iterations + 1;
}
computeGradients(f, varList) {
return variableGrads(f, varList);
}
dispose() {
if (this.iterations_ != null) {
dispose(this.iterations_);
}
}
async saveIterations() {
if (this.iterations_ == null) {
this.iterations_ = 0;
}
return {
name: 'iter',
tensor: scalar(this.iterations_, 'int32')
};
}
async getWeights() {
throw new Error('getWeights() is not implemented for this optimizer yet.');
}
async setWeights(weightValues) {
throw new Error(`setWeights() is not implemented for this optimizer class ` +
`${this.getClassName()}`);
}
async extractIterations(weightValues) {
this.iterations_ = (await weightValues[0].tensor.data())[0];
return weightValues.slice(1);
}
}
Object.defineProperty(Optimizer, Symbol.hasInstance, {
value: (instance) => {
return instance.minimize != null && instance.computeGradients != null &&
instance.applyGradients != null;
}
});
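// Adadelta: keeps decayed accumulators of squared gradients and squared updates
// (decay rho) and scales each step by
// sqrt(accumulatedUpdate + eps) / sqrt(accumulatedGrad + eps).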
class AdadeltaOptimizer extends Optimizer {
static get className() {
return 'Adadelta';
}
constructor(learningRate, rho, epsilon = null) {
super();
this.learningRate = learningRate;
this.rho = rho;
this.epsilon = epsilon;
this.accumulatedGrads = [];
this.accumulatedUpdates = [];
if (epsilon == null) {
this.epsilon = ENGINE.backend.epsilon();
}
}
applyGradients(variableGradients) {
const variableNames = Array.isArray(variableGradients) ?
variableGradients.map(item => item.name) :
Object.keys(variableGradients);
variableNames.forEach((name, i) => {
const value = ENGINE.registeredVariables[name];
const trainable = false;
if (this.accumulatedGrads[i] == null) {
this.accumulatedGrads[i] = {
originalName: `${name}/accum_grad`,
variable: tidy(() => zerosLike$2(value).variable(trainable))
};
}
if (this.accumulatedUpdates[i] == null) {
this.accumulatedUpdates[i] = {
originalName: `${name}/accum_var`,
variable: tidy(() => zerosLike$2(value).variable(trainable))
};
}
const gradient = Array.isArray(variableGradients) ?
variableGradients[i].tensor :
variableGradients[name];
if (gradient == null) {
return;
}
const accumulatedGrad = this.accumulatedGrads[i].variable;
const accumulatedUpdate = this.accumulatedUpdates[i].variable;
tidy(() => {
const newAccumulatedGrad = add$1(mul(accumulatedGrad, this.rho), mul(square$2(gradient), 1 - this.rho));
const updates = mul(div$1(sqrt$2(add$1(accumulatedUpdate, this.epsilon)), sqrt$2(add$1(accumulatedGrad, this.epsilon))), gradient);
const newAccumulatedUpdate = add$1(mul(accumulatedUpdate, this.rho), mul(square$2(updates), 1 - this.rho));
accumulatedGrad.assign(newAccumulatedGrad);
accumulatedUpdate.assign(newAccumulatedUpdate);
const newValue = add$1(mul(updates, -this.learningRate), value);
value.assign(newValue);
});
});
this.incrementIterations();
}
dispose() {
if (this.accumulatedUpdates != null) {
dispose(this.accumulatedGrads.map(v => v.variable));
dispose(this.accumulatedUpdates.map(v => v.variable));
}
}
async getWeights() {
const variables = [...this.accumulatedGrads, ...this.accumulatedUpdates];
return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable })));
}
async setWeights(weightValues) {
weightValues = await this.extractIterations(weightValues);
const variableCount = weightValues.length / 2;
const trainable = false;
this.accumulatedGrads =
weightValues.slice(0, variableCount).map(v => ({
originalName: v.name,
variable: v.tensor.variable(trainable)
}));
this.accumulatedUpdates =
weightValues.slice(variableCount, variableCount * 2)
.map(v => ({
originalName: v.name,
variable: v.tensor.variable(trainable)
}));
}
getConfig() {
return {
'learningRate': this.learningRate,
'rho': this.rho,
'epsilon': this.epsilon
};
}
static fromConfig(cls, config) {
return new cls(config['learningRate'], config['rho'], config['epsilon']);
}
}
class AdagradOptimizer extends Optimizer {
static get className() {
return 'Adagrad';
}
constructor(learningRate, initialAccumulatorValue = 0.1) {
super();
this.learningRate = learningRate;
this.initialAccumulatorValue = initialAccumulatorValue;
this.accumulatedGrads = [];
}
applyGradients(variableGradients) {
const variableNames = Array.isArray(variableGradients) ?
variableGradients.map(item => item.name) :
Object.keys(variableGradients);
variableNames.forEach((name, i) => {
const value = ENGINE.registeredVariables[name];
if (this.accumulatedGrads[i] == null) {
const trainable = false;
this.accumulatedGrads[i] = {
originalName: `${name}/accumulator`,
variable: tidy(() => fill$2(value.shape, this.initialAccumulatorValue)
.variable(trainable))
};
}
const gradient = Array.isArray(variableGradients) ?
variableGradients[i].tensor :
variableGradients[name];
if (gradient == null) {
return;
}
const accumulatedGrad = this.accumulatedGrads[i].variable;
tidy(() => {
const newAccumulatedGrad = add$1(accumulatedGrad, square$2(gradient));
accumulatedGrad.assign(newAccumulatedGrad);
const newValue = add$1(mul(div$1(gradient, sqrt$2(add$1(newAccumulatedGrad, ENGINE.backend.epsilon()))), -this.learningRate), value);
value.assign(newValue);
});
});
this.incrementIterations();
}
dispose() {
if (this.accumulatedGrads != null) {
dispose(this.accumulatedGrads.map(v => v.variable));
}
}
async getWeights() {
return [await this.saveIterations()].concat(this.accumulatedGrads.map(v => ({ name: v.originalName, tensor: v.variable })));
}
async setWeights(weightValues) {
weightValues = await this.extractIterations(weightValues);
const trainable = false;
this.accumulatedGrads = weightValues.map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) }));
}
getConfig() {
return {
'learningRate': this.learningRate,
'initialAccumulatorValue': this.initialAccumulatorValue,
};
}
static fromConfig(cls, config) {
return new cls(config['learningRate'], config['initialAccumulatorValue']);
}
}
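// Adam: exponential moving averages of the gradient (beta1) and its square (beta2),
// bias-corrected by the accumulated beta powers accBeta1/accBeta2 before the
// update value -= lr * mHat / (sqrt(vHat) + eps).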
class AdamOptimizer extends Optimizer {
static get className() {
return 'Adam';
}
constructor(learningRate, beta1, beta2, epsilon = null) {
super();
this.learningRate = learningRate;
this.beta1 = beta1;
this.beta2 = beta2;
this.epsilon = epsilon;
this.accumulatedFirstMoment = [];
this.accumulatedSecondMoment = [];
tidy(() => {
this.accBeta1 = scalar(beta1).variable();
this.accBeta2 = scalar(beta2).variable();
});
if (epsilon == null) {
this.epsilon = ENGINE.backend.epsilon();
}
}
applyGradients(variableGradients) {
const varNames = Array.isArray(variableGradients) ?
variableGradients.map(v => v.name) :
Object.keys(variableGradients);
tidy(() => {
const oneMinusAccBeta1 = sub$2(1, this.accBeta1);
const oneMinusAccBeta2 = sub$2(1, this.accBeta2);
varNames.forEach((name, i) => {
const value = ENGINE.registeredVariables[name];
const trainable = false;
if (this.accumulatedFirstMoment[i] == null) {
this.accumulatedFirstMoment[i] = {
originalName: `${name}/m`,
variable: tidy(() => zerosLike$2(value).variable(trainable))
};
}
if (this.accumulatedSecondMoment[i] == null) {
this.accumulatedSecondMoment[i] = {
originalName: `${name}/v`,
variable: tidy(() => zerosLike$2(value).variable(trainable))
};
}
const gradient = Array.isArray(variableGradients) ?
variableGradients[i].tensor :
variableGradients[name];
if (gradient == null) {
return;
}
const firstMoment = this.accumulatedFirstMoment[i].variable;
const secondMoment = this.accumulatedSecondMoment[i].variable;
const newFirstMoment = add$1(mul(firstMoment, this.beta1), mul(gradient, 1 - this.beta1));
const newSecondMoment = add$1(mul(secondMoment, this.beta2), mul(square$2(gradient), 1 - this.beta2));
const biasCorrectedFirstMoment = div$1(newFirstMoment, oneMinusAccBeta1);
const biasCorrectedSecondMoment = div$1(newSecondMoment, oneMinusAccBeta2);
firstMoment.assign(newFirstMoment);
secondMoment.assign(newSecondMoment);
const newValue = add$1(mul(div$1(biasCorrectedFirstMoment, add$1(sqrt$2(biasCorrectedSecondMoment), this.epsilon)), -this.learningRate), value);
value.assign(newValue);
});
this.accBeta1.assign(mul(this.accBeta1, this.beta1));
this.accBeta2.assign(mul(this.accBeta2, this.beta2));
});
this.incrementIterations();
}
dispose() {
this.accBeta1.dispose();
this.accBeta2.dispose();
if (this.accumulatedFirstMoment != null) {
dispose(this.accumulatedFirstMoment.map(v => v.variable));
}
if (this.accumulatedSecondMoment != null) {
dispose(this.accumulatedSecondMoment.map(v => v.variable));
}
}
async getWeights() {
const variables = [...this.accumulatedFirstMoment, ...this.accumulatedSecondMoment];
return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable })));
}
async setWeights(weightValues) {
weightValues = await this.extractIterations(weightValues);
tidy(() => {
this.accBeta1.assign(pow$2(this.beta1, this.iterations_ + 1));
this.accBeta2.assign(pow$2(this.beta2, this.iterations_ + 1));
});
const variableCount = weightValues.length / 2;
const trainable = false;
this.accumulatedFirstMoment =
weightValues.slice(0, variableCount).map(v => ({
originalName: v.name,
variable: v.tensor.variable(trainable)
}));
this.accumulatedSecondMoment =
weightValues.slice(variableCount, variableCount * 2)
.map(v => ({
originalName: v.name,
variable: v.tensor.variable(trainable)
}));
}
getConfig() {
return {
'learningRate': this.learningRate,
'beta1': this.beta1,
'beta2': this.beta2,
'epsilon': this.epsilon,
};
}
static fromConfig(cls, config) {
return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon']);
}
}
class AdamaxOptimizer extends Optimizer {
static get className() {
return 'Adamax';
}
constructor(learningRate, beta1, beta2, epsilon = null, decay = 0.0) {
super();
this.learningRate = learningRate;
this.beta1 = beta1;
this.beta2 = beta2;
this.epsilon = epsilon;
this.decay = decay;
this.accumulatedFirstMoment = [];
this.accumulatedWeightedInfNorm = [];
tidy(() => {
this.iteration = scalar(0).variable();
this.accBeta1 = scalar(beta1).variable();
});
if (epsilon == null) {
this.epsilon = ENGINE.backend.epsilon();
}
}
applyGradients(variableGradients) {
const variableNames = Array.isArray(variableGradients) ?
variableGradients.map(item => item.name) :
Object.keys(variableGradients);
tidy(() => {
const oneMinusAccBeta1 = sub$2(1, this.accBeta1);
const lr = div$1(-this.learningRate, add$1(mul(this.iteration, this.decay), 1));
variableNames.forEach((name, i) => {
const value = ENGINE.registeredVariables[name];
const trainable = false;
if (this.accumulatedFirstMoment[i] == null) {
this.accumulatedFirstMoment[i] = {
originalName: `${name}/m`,
variable: zerosLike$2(value).variable(trainable)
};
}
if (this.accumulatedWeightedInfNorm[i] == null) {
this.accumulatedWeightedInfNorm[i] = {
originalName: `${name}/v`,
variable: zerosLike$2(value).variable(trainable)
};
}
const gradient = Array.isArray(variableGradients) ?
variableGradients[i].tensor :
variableGradients[name];
if (gradient == null) {
return;
}
const firstMoment = this.accumulatedFirstMoment[i].variable;
const weightedInfNorm = this.accumulatedWeightedInfNorm[i].variable;
const newFirstMoment = add$1(mul(firstMoment, this.beta1), mul(gradient, 1 - this.beta1));
const ut0 = mul(weightedInfNorm, this.beta2);
const ut1 = abs$2(gradient);
const newWeightedInfNorm = maximum$2(ut0, ut1);
firstMoment.assign(newFirstMoment);
weightedInfNorm.assign(newWeightedInfNorm);
const newValue = add$1(mul(div$1(lr, oneMinusAccBeta1), div$1(newFirstMoment, add$1(newWeightedInfNorm, this.epsilon))), value);
value.assign(newValue);
});
this.iteration.assign(add$1(this.iteration, 1));
this.accBeta1.assign(mul(this.accBeta1, this.beta1));
});
this.incrementIterations();
}
dispose() {
this.accBeta1.dispose();
this.iteration.dispose();
if (this.accumulatedFirstMoment != null) {
dispose(this.accumulatedFirstMoment.map(v => v.variable));
}
if (this.accumulatedWeightedInfNorm != null) {
dispose(this.accumulatedWeightedInfNorm.map(v => v.variable));
}
}
async getWeights() {
throw new Error('getWeights() is not implemented for Adamax yet.');
}
async setWeights(weightValues) {
throw new Error('setWeights() is not implemented for Adamax yet.');
}
getConfig() {
return {
'learningRate': this.learningRate,
'beta1': this.beta1,
'beta2': this.beta2,
'epsilon': this.epsilon,
'decay': this.decay
};
}
static fromConfig(cls, config) {
return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon'], config['decay']);
}
}
class SGDOptimizer extends Optimizer {
static get className() {
return 'SGD';
}
constructor(learningRate) {
super();
this.learningRate = learningRate;
this.setLearningRate(learningRate);
}
applyGradients(variableGradients) {
const varNames = Array.isArray(variableGradients) ?
variableGradients.map(v => v.name) :
Object.keys(variableGradients);
varNames.forEach((name, i) => {
const gradient = Array.isArray(variableGradients) ?
variableGradients[i].tensor :
variableGradients[name];
if (gradient == null) {
return;
}
const value = ENGINE.registeredVariables[name];
tidy(() => {
const newValue = add$1(mul(this.c, gradient), value);
value.assign(newValue);
});
});
this.incrementIterations();
}
setLearningRate(learningRate) {
this.learningRate = learningRate;
if (this.c != null) {
this.c.dispose();
}
this.c = keep(scalar(-learningRate));
}
dispose() {
this.c.dispose();
}
async getWeights() {
return [await this.saveIterations()];
}
async setWeights(weightValues) {
weightValues = await this.extractIterations(weightValues);
if (weightValues.length !== 0) {
throw new Error('SGD optimizer does not have settable weights.');
}
}
getConfig() {
return { 'learningRate': this.learningRate };
}
static fromConfig(cls, config) {
return new cls(config['learningRate']);
}
}
class MomentumOptimizer extends SGDOptimizer {
static get className() {
return 'Momentum';
}
constructor(learningRate, momentum, useNesterov = false) {
super(learningRate);
this.learningRate = learningRate;
this.momentum = momentum;
this.useNesterov = useNesterov;
this.accumulations = [];
this.m = scalar(this.momentum);
}
applyGradients(variableGradients) {
const variableNames = Array.isArray(variableGradients) ?
variableGradients.map(item => item.name) :
Object.keys(variableGradients);
variableNames.forEach((name, i) => {
const value = ENGINE.registeredVariables[name];
if (this.accumulations[i] == null) {
const trainable = false;
this.accumulations[i] = {
originalName: `${name}/momentum`,
variable: tidy(() => zerosLike$2(value).variable(trainable))
};
}
const accumulation = this.accumulations[i].variable;
const gradient = Array.isArray(variableGradients) ?
variableGradients[i].tensor :
variableGradients[name];
if (gradient == null) {
return;
}
tidy(() => {
let newValue;
const newAccumulation = add$1(mul(this.m, accumulation), gradient);
if (this.useNesterov) {
newValue = add$1(mul(this.c, add$1(gradient, mul(newAccumulation, this.m))), value);
}
else {
newValue = add$1(mul(this.c, newAccumulation), value);
}
accumulation.assign(newAccumulation);
value.assign(newValue);
});
});
this.incrementIterations();
}
dispose() {
this.m.dispose();
if (this.accumulations != null) {
dispose(this.accumulations.map(v => v.variable));
}
}
setMomentum(momentum) {
this.momentum = momentum;
}
async getWeights() {
return [await this.saveIterations()].concat(this.accumulations.map(v => ({ name: v.originalName, tensor: v.variable })));
}
async setWeights(weightValues) {
weightValues = await this.extractIterations(weightValues);
const trainable = false;
this.accumulations = weightValues.map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) }));
}
getConfig() {
return {
'learningRate': this.learningRate,
'momentum': this.momentum,
'useNesterov': this.useNesterov
};
}
static fromConfig(cls, config) {
return new cls(config['learningRate'], config['momentum'], config['useNesterov']);
}
}
class RMSPropOptimizer extends Optimizer {
static get className() {
return 'RMSProp';
}
constructor(learningRate, decay = 0.9, momentum = 0.0, epsilon = null, centered = false) {
super();
this.learningRate = learningRate;
this.decay = decay;
this.momentum = momentum;
this.epsilon = epsilon;
this.accumulatedMeanSquares = [];
this.accumulatedMoments = [];
this.accumulatedMeanGrads = [];
this.centered = centered;
if (epsilon == null) {
this.epsilon = ENGINE.backend.epsilon();
}
if (learningRate == null) {
throw new Error(`learningRate for RMSPropOptimizer must be defined.`);
}
}
applyGradients(variableGradients) {
const variableNames = Array.isArray(variableGradients) ?
variableGradients.map(item => item.name) :
Object.keys(variableGradients);
variableNames.forEach((name, i) => {
const value = ENGINE.registeredVariables[name];
const trainable = false;
if (this.accumulatedMeanSquares[i] == null) {
this.accumulatedMeanSquares[i] = {
originalName: `${name}/rms`,
variable: tidy(() => zerosLike$2(value).variable(trainable))
};
}
if (this.accumulatedMoments[i] == null) {
this.accumulatedMoments[i] = {
originalName: `${name}/momentum`,
variable: tidy(() => zerosLike$2(value).variable(trainable))
};
}
if (this.accumulatedMeanGrads[i] == null && this.centered) {
this.accumulatedMeanGrads[i] = {
originalName: `${name}/mg`,
variable: tidy(() => zerosLike$2(value).variable(trainable))
};
}
const gradient = Array.isArray(variableGradients) ?
variableGradients[i].tensor :
variableGradients[name];
if (gradient == null) {
return;
}
const accumulatedMeanSquare = this.accumulatedMeanSquares[i].variable;
const accumulatedMoments = this.accumulatedMoments[i].variable;
tidy(() => {
const newAccumulatedMeanSquare = add$1(mul(accumulatedMeanSquare, this.decay), mul(square$2(gradient), 1 - this.decay));
if (this.centered) {
const accumulatedMeanGrad = this.accumulatedMeanGrads[i].variable;
const newAccumulatedMeanGrad = add$1(mul(accumulatedMeanGrad, this.decay), mul(gradient, 1 - this.decay));
const gradContribution = div$1(mul(gradient, this.learningRate), sqrt$2(sub$2(newAccumulatedMeanSquare, add$1(square$2(newAccumulatedMeanGrad), this.epsilon))));
const newAccumulatedMoments = add$1(mul(accumulatedMoments, this.momentum), gradContribution);
accumulatedMeanSquare.assign(newAccumulatedMeanSquare);
accumulatedMeanGrad.assign(newAccumulatedMeanGrad);
accumulatedMoments.assign(newAccumulatedMoments);
const newValue = sub$2(value, newAccumulatedMoments);
value.assign(newValue);
}
else {
const newAccumulatedMeanSquare = add$1(mul(accumulatedMeanSquare, this.decay), mul(square$2(gradient), 1 - this.decay));
const newAccumulatedMoments = add$1(mul(accumulatedMoments, this.momentum), div$1(mul(gradient, this.learningRate), sqrt$2(add$1(newAccumulatedMeanSquare, this.epsilon))));
accumulatedMeanSquare.assign(newAccumulatedMeanSquare);
accumulatedMoments.assign(newAccumulatedMoments);
const newValue = sub$2(value, newAccumulatedMoments);
value.assign(newValue);
}
});
});
this.incrementIterations();
}
dispose() {
if (this.accumulatedMeanSquares != null) {
dispose(this.accumulatedMeanSquares.map(v => v.variable));
}
if (this.accumulatedMeanGrads != null && this.centered) {
dispose(this.accumulatedMeanGrads.map(v => v.variable));
}
if (this.accumulatedMoments != null) {
dispose(this.accumulatedMoments.map(v => v.variable));
}
}
async getWeights() {
const variables = [...this.accumulatedMeanSquares, ...this.accumulatedMoments];
if (this.centered) {
variables.push(...this.accumulatedMeanGrads);
}
return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable })));
}
async setWeights(weightValues) {
weightValues = await this.extractIterations(weightValues);
const variableCount = this.centered ? weightValues.length / 3 : weightValues.length / 2;
const trainable = false;
this.accumulatedMeanSquares =
weightValues.slice(0, variableCount).map(v => ({
originalName: v.name,
variable: v.tensor.variable(trainable)
}));
this.accumulatedMoments =
weightValues.slice(variableCount, variableCount * 2)
.map(v => ({
originalName: v.name,
variable: v.tensor.variable(trainable)
}));
if (this.centered) {
this.accumulatedMeanGrads =
weightValues.slice(variableCount * 2, variableCount * 3)
.map(v => ({
originalName: v.name,
variable: v.tensor.variable(trainable)
}));
}
}
getConfig() {
return {
'learningRate': this.learningRate,
'decay': this.decay,
'momentum': this.momentum,
'epsilon': this.epsilon,
'centered': this.centered
};
}
static fromConfig(cls, config) {
return new cls(config['learningRate'], config['decay'], config['momentum'], config['epsilon'], config['centered']);
}
}
const OPTIMIZERS = [
AdadeltaOptimizer,
AdagradOptimizer,
AdamOptimizer,
AdamaxOptimizer,
MomentumOptimizer,
RMSPropOptimizer,
SGDOptimizer,
];
function registerOptimizers() {
for (const optimizer of OPTIMIZERS) {
registerClass(optimizer);
}
}
const DTYPE_VALUE_SIZE_MAP = {
'float32': 4,
'float16': 2,
'int32': 4,
'uint16': 2,
'uint8': 1,
'bool': 1,
'complex64': 8
};
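// Presents a list of ArrayBuffers (weight shards) as one contiguous, sliceable
// buffer without copying until slice() is called. When all shards except possibly
// the last share the same size, shard lookup is O(1); otherwise a binary search
// over the shard ranges is used.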
class CompositeArrayBuffer {
static join(buffers) {
return new CompositeArrayBuffer(buffers).slice();
}
constructor(buffers) {
this.shards = [];
this.previousShardIndex = 0;
if (buffers == null) {
return;
}
if (!(buffers instanceof Array)) {
buffers = [buffers];
}
buffers = buffers.map((bufferOrTypedArray) => {
if (isTypedArray(bufferOrTypedArray)) {
return bufferOrTypedArray.buffer;
}
return bufferOrTypedArray;
});
if (buffers.length === 0) {
return;
}
this.bufferUniformSize = buffers[0].byteLength;
let start = 0;
for (let i = 0; i < buffers.length; i++) {
const buffer = buffers[i];
if (i !== buffers.length - 1 &&
buffer.byteLength !== this.bufferUniformSize) {
this.bufferUniformSize = undefined;
}
const end = start + buffer.byteLength;
this.shards.push({ buffer, start, end });
start = end;
}
this.byteLength = this.shards.length === 0 ?
    0 :
    this.shards[this.shards.length - 1].end;
}
slice(start = 0, end = this.byteLength) {
if (this.shards.length === 0) {
return new ArrayBuffer(0);
}
start = isNaN(Number(start)) ? 0 : start;
end = isNaN(Number(end)) ? 0 : end;
start = Math.max(0, start);
end = Math.min(this.byteLength, end);
if (end <= start) {
return new ArrayBuffer(0);
}
const startShardIndex = this.findShardForByte(start);
if (startShardIndex === -1) {
throw new Error(`Could not find start shard for byte ${start}`);
}
const size = end - start;
const outputBuffer = new ArrayBuffer(size);
const outputArray = new Uint8Array(outputBuffer);
let sliced = 0;
for (let i = startShardIndex; i < this.shards.length; i++) {
const shard = this.shards[i];
const globalStart = start + sliced;
const localStart = globalStart - shard.start;
const outputStart = sliced;
const globalEnd = Math.min(end, shard.end);
const localEnd = globalEnd - shard.start;
const outputSlice = new Uint8Array(shard.buffer, localStart, localEnd - localStart);
outputArray.set(outputSlice, outputStart);
sliced += outputSlice.length;
if (end < shard.end) {
break;
}
}
return outputBuffer;
}
findShardForByte(byteIndex) {
if (this.shards.length === 0 || byteIndex < 0 ||
byteIndex >= this.byteLength) {
return -1;
}
if (this.bufferUniformSize != null) {
this.previousShardIndex = Math.floor(byteIndex / this.bufferUniformSize);
return this.previousShardIndex;
}
function check(shard) {
if (byteIndex < shard.start) {
return -1;
}
if (byteIndex >= shard.end) {
return 1;
}
return 0;
}
if (check(this.shards[this.previousShardIndex]) === 0) {
return this.previousShardIndex;
}
const index = search(this.shards, check);
if (index === -1) {
return -1;
}
this.previousShardIndex = index;
return this.previousShardIndex;
}
}
function search(sortedArray, compare) {
let min = 0;
let max = sortedArray.length;
while (min <= max) {
const middle = Math.floor((max - min) / 2) + min;
const side = compare(sortedArray[middle]);
if (side === 0) {
return middle;
}
else if (side < 0) {
max = middle;
}
else {
min = middle + 1;
}
}
return -1;
}
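// Weight (de)serialization helpers. Each string value is encoded as a 4-byte length
// prefix (NUM_BYTES_STRING_LENGTH) followed by its raw bytes; numeric tensors are
// stored as their typed-array contents, concatenated in spec order.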
const NUM_BYTES_STRING_LENGTH = 4;
async function encodeWeights(tensors, group) {
const specs = [];
const dataPromises = [];
const names = Array.isArray(tensors) ?
tensors.map(tensor => tensor.name) :
Object.keys(tensors);
for (let i = 0; i < names.length; ++i) {
const name = names[i];
const t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name];
if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' &&
t.dtype !== 'string' && t.dtype !== 'complex64') {
throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`);
}
const spec = { name, shape: t.shape, dtype: t.dtype };
if (t.dtype === 'string') {
const utf8bytes = new Promise(async (resolve) => {
const vals = await t.bytes();
const totalNumBytes = vals.reduce((p, c) => p + c.length, 0) +
NUM_BYTES_STRING_LENGTH * vals.length;
const bytes = new Uint8Array(totalNumBytes);
let offset = 0;
for (let i = 0; i < vals.length; i++) {
const val = vals[i];
const bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer);
bytes.set(bytesOfLength, offset);
offset += NUM_BYTES_STRING_LENGTH;
bytes.set(val, offset);
offset += val.length;
}
resolve(bytes);
});
dataPromises.push(utf8bytes);
}
else {
dataPromises.push(t.data());
}
if (group != null) {
spec.group = group;
}
specs.push(spec);
}
const tensorValues = await Promise.all(dataPromises);
return { data: concatenateTypedArrays(tensorValues), specs };
}
function decodeWeights(weightData, specs) {
const compositeBuffer = new CompositeArrayBuffer(weightData);
const out = {};
let offset = 0;
for (const spec of specs) {
const byteLength = getWeightBytelength(spec, (start, end) => {
return compositeBuffer.slice(offset + start, offset + end);
});
out[spec.name] = decodeWeight(spec, compositeBuffer
.slice(offset, offset + byteLength));
offset += byteLength;
}
return out;
}
function getWeightBytelength(spec, slice) {
const size = sizeFromShape(spec.shape);
let bytesPerValue;
if ('quantization' in spec) {
const quantization = spec.quantization;
bytesPerValue = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
}
else if (spec.dtype === 'string') {
let byteLength = 0;
for (let i = 0; i < size; i++) {
byteLength += NUM_BYTES_STRING_LENGTH + new Uint32Array(slice(byteLength, byteLength + NUM_BYTES_STRING_LENGTH))[0];
}
return byteLength;
}
else {
bytesPerValue = DTYPE_VALUE_SIZE_MAP[spec.dtype];
}
return size * bytesPerValue;
}
function decodeWeight(spec, byteBuffer) {
const name = spec.name;
const dtype = spec.dtype;
const shape = spec.shape;
const size = sizeFromShape(shape);
let values;
let offset = 0;
if ('quantization' in spec) {
const quantization = spec.quantization;
if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
if (!('min' in quantization && 'scale' in quantization)) {
throw new Error(`Weight ${spec.name} with quantization ${quantization.dtype} ` +
`doesn't have corresponding metadata min and scale.`);
}
}
else if (quantization.dtype === 'float16') {
if (dtype !== 'float32') {
throw new Error(`Weight ${spec.name} is quantized with ${quantization.dtype} ` +
`which only supports weights of type float32 not ${dtype}.`);
}
}
else {
throw new Error(`Weight ${spec.name} has unknown ` +
`quantization dtype ${quantization.dtype}. ` +
`Supported quantization dtypes are: ` +
`'uint8', 'uint16', and 'float16'.`);
}
const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
const quantizedArray = (quantization.dtype === 'uint8') ?
new Uint8Array(byteBuffer) :
new Uint16Array(byteBuffer);
if (dtype === 'float32') {
if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
values = new Float32Array(quantizedArray.length);
for (let i = 0; i < quantizedArray.length; i++) {
const v = quantizedArray[i];
values[i] = v * quantization.scale + quantization.min;
}
}
else if (quantization.dtype === 'float16') {
const float16Decode = getFloat16Decoder();
values = float16Decode(quantizedArray);
}
else {
throw new Error(`Unsupported quantization type ${quantization.dtype} ` +
`for weight type float32.`);
}
}
else if (dtype === 'int32') {
if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') {
throw new Error(`Unsupported quantization type ${quantization.dtype} ` +
`for weight type int32.`);
}
values = new Int32Array(quantizedArray.length);
for (let i = 0; i < quantizedArray.length; i++) {
const v = quantizedArray[i];
values[i] = Math.round(v * quantization.scale + quantization.min);
}
}
else {
throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
}
offset += size * quantizationSizeFactor;
}
else if (dtype === 'string') {
const size = sizeFromShape(spec.shape);
values = [];
for (let i = 0; i < size; i++) {
const byteLength = new Uint32Array(byteBuffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];
offset += NUM_BYTES_STRING_LENGTH;
const bytes = new Uint8Array(byteBuffer.slice(offset, offset + byteLength));
values.push(bytes);
offset += byteLength;
}
}
else {
const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype];
if (dtype === 'float32') {
values = new Float32Array(byteBuffer);
}
else if (dtype === 'int32') {
values = new Int32Array(byteBuffer);
}
else if (dtype === 'bool') {
values = new Uint8Array(byteBuffer);
}
else if (dtype === 'complex64') {
values = new Float32Array(byteBuffer);
const real = new Float32Array(values.length / 2);
const image = new Float32Array(values.length / 2);
for (let i = 0; i < real.length; i++) {
real[i] = values[i * 2];
image[i] = values[i * 2 + 1];
}
const realTensor = tensor(real, shape, 'float32');
const imageTensor = tensor(image, shape, 'float32');
const complexTensor = complex$2(realTensor, imageTensor);
realTensor.dispose();
imageTensor.dispose();
return complexTensor;
}
else {
throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
}
offset += size * dtypeFactor;
}
return tensor(values, shape, dtype);
}
function concatenateTypedArrays(xs) {
if (xs === null) {
throw new Error(`Invalid input value: ${JSON.stringify(xs)}`);
}
let totalByteLength = 0;
const normalizedXs = [];
xs.forEach((x) => {
totalByteLength += x.byteLength;
normalizedXs.push(x.byteLength === x.buffer.byteLength ? x :
new x.constructor(x));
if (!(x instanceof Float32Array || x instanceof Int32Array ||
x instanceof Uint8Array)) {
throw new Error(`Unsupported TypedArray subtype: ${x.constructor.name}`);
}
});
const y = new Uint8Array(totalByteLength);
let offset = 0;
normalizedXs.forEach((x) => {
y.set(new Uint8Array(x.buffer), offset);
offset += x.byteLength;
});
return y.buffer;
}
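// Base64 and byte-length helpers. Node's Buffer is used when Blob/atob/btoa are
// unavailable (i.e. outside the browser).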
const useNodeBuffer = typeof Buffer !== 'undefined' &&
(typeof Blob === 'undefined' || typeof atob === 'undefined' ||
typeof btoa === 'undefined');
function stringByteLength(str) {
if (useNodeBuffer) {
return Buffer.byteLength(str, 'utf8');
}
return new Blob([str]).size;
}
function arrayBufferToBase64String(buffer) {
if (useNodeBuffer) {
return Buffer.from(buffer).toString('base64');
}
const buf = new Uint8Array(buffer);
let s = '';
for (let i = 0, l = buf.length; i < l; i++) {
s += String.fromCharCode(buf[i]);
}
return btoa(s);
}
function base64StringToArrayBuffer(str) {
if (useNodeBuffer) {
const buf = Buffer.from(str, 'base64');
return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);
}
const s = atob(str);
const buffer = new Uint8Array(s.length);
for (let i = 0; i < s.length; ++i) {
buffer.set([s.charCodeAt(i)], i);
}
return buffer.buffer;
}
function concatenateArrayBuffers(buffers) {
return CompositeArrayBuffer.join(buffers);
}
function getModelJSONForModelArtifacts(artifacts, manifest) {
const result = {
modelTopology: artifacts.modelTopology,
format: artifacts.format,
generatedBy: artifacts.generatedBy,
convertedBy: artifacts.convertedBy,
weightsManifest: manifest
};
if (artifacts.signature != null) {
result.signature = artifacts.signature;
}
if (artifacts.userDefinedMetadata != null) {
result.userDefinedMetadata = artifacts.userDefinedMetadata;
}
if (artifacts.modelInitializer != null) {
result.modelInitializer = artifacts.modelInitializer;
}
if (artifacts.initializerSignature != null) {
result.initializerSignature = artifacts.initializerSignature;
}
if (artifacts.trainingConfig != null) {
result.trainingConfig = artifacts.trainingConfig;
}
return result;
}
function getModelArtifactsInfoForJSON(modelArtifacts) {
if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
throw new Error('Expected JSON model topology, received ArrayBuffer.');
}
return {
dateSaved: new Date(),
modelTopologyType: 'JSON',
modelTopologyBytes: modelArtifacts.modelTopology == null ?
0 :
stringByteLength(JSON.stringify(modelArtifacts.modelTopology)),
weightSpecsBytes: modelArtifacts.weightSpecs == null ?
0 :
stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)),
weightDataBytes: modelArtifacts.weightData == null ?
0 :
new CompositeArrayBuffer(modelArtifacts.weightData).byteLength,
};
}
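// Table-based float16 -> float32 decoding (mantissa, exponent and offset lookup
// tables), used when loading weights quantized as float16.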
function computeFloat16MantisaTable() {
const convertMantissa = (i) => {
let m = i << 13;
let e = 0;
while ((m & 0x00800000) === 0) {
e -= 0x00800000;
m <<= 1;
}
m &= ~0x00800000;
e += 0x38800000;
return m | e;
};
const mantisaTable = new Uint32Array(2048);
mantisaTable[0] = 0;
for (let i = 1; i < 1024; i++) {
mantisaTable[i] = convertMantissa(i);
}
for (let i = 1024; i < 2048; i++) {
mantisaTable[i] = 0x38000000 + ((i - 1024) << 13);
}
return mantisaTable;
}
function computeFloat16ExponentTable() {
const exponentTable = new Uint32Array(64);
exponentTable[0] = 0;
exponentTable[31] = 0x47800000;
exponentTable[32] = 0x80000000;
exponentTable[63] = 0xc7800000;
for (let i = 1; i < 31; i++) {
exponentTable[i] = i << 23;
}
for (let i = 33; i < 63; i++) {
exponentTable[i] = 0x80000000 + ((i - 32) << 23);
}
return exponentTable;
}
function computeFloat16OffsetTable() {
const offsetTable = new Uint32Array(64);
for (let i = 0; i < 64; i++) {
offsetTable[i] = 1024;
}
offsetTable[0] = offsetTable[32] = 0;
return offsetTable;
}
function getFloat16Decoder() {
const mantisaTable = computeFloat16MantisaTable();
const exponentTable = computeFloat16ExponentTable();
const offsetTable = computeFloat16OffsetTable();
return (quantizedArray) => {
const buffer = new ArrayBuffer(4 * quantizedArray.length);
const bufferUint32View = new Uint32Array(buffer);
for (let index = 0; index < quantizedArray.length; index++) {
const float16Bits = quantizedArray[index];
const float32Bits = mantisaTable[offsetTable[float16Bits >> 10] + (float16Bits & 0x3ff)] +
exponentTable[float16Bits >> 10];
bufferUint32View[index] = float32Bits;
}
return new Float32Array(buffer);
};
}
class IORouterRegistry {
constructor() {
this.saveRouters = [];
this.loadRouters = [];
}
static getInstance() {
if (IORouterRegistry.instance == null) {
IORouterRegistry.instance = new IORouterRegistry();
}
return IORouterRegistry.instance;
}
static registerSaveRouter(saveRouter) {
IORouterRegistry.getInstance().saveRouters.push(saveRouter);
}
static registerLoadRouter(loadRouter) {
IORouterRegistry.getInstance().loadRouters.push(loadRouter);
}
static getSaveHandlers(url) {
return IORouterRegistry.getHandlers(url, 'save');
}
static getLoadHandlers(url, loadOptions) {
return IORouterRegistry.getHandlers(url, 'load', loadOptions);
}
static getHandlers(url, handlerType, loadOptions) {
const validHandlers = [];
const routers = handlerType === 'load' ?
IORouterRegistry.getInstance().loadRouters :
IORouterRegistry.getInstance().saveRouters;
routers.forEach(router => {
const handler = router(url, loadOptions);
if (handler !== null) {
validHandlers.push(handler);
}
});
return validHandlers;
}
}
const getSaveHandlers = (url) => IORouterRegistry.getSaveHandlers(url);
const DATABASE_NAME = 'tensorflowjs';
const DATABASE_VERSION = 1;
const MODEL_STORE_NAME = 'models_store';
const INFO_STORE_NAME = 'model_info_store';
function getIndexedDBFactory() {
if (!env().getBool('IS_BROWSER')) {
        throw new Error('Failed to obtain IndexedDB factory because the current environment ' +
            'is not a web browser.');
}
const theWindow = typeof window === 'undefined' ? self : window;
const factory = theWindow.indexedDB || theWindow.mozIndexedDB ||
theWindow.webkitIndexedDB || theWindow.msIndexedDB ||
theWindow.shimIndexedDB;
if (factory == null) {
throw new Error('The current browser does not appear to support IndexedDB.');
}
return factory;
}
function setUpDatabase(openRequest) {
const db = openRequest.result;
db.createObjectStore(MODEL_STORE_NAME, { keyPath: 'modelPath' });
db.createObjectStore(INFO_STORE_NAME, { keyPath: 'modelPath' });
}
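// IOHandler persisting model artifacts in the browser's IndexedDB under the
// 'indexeddb://' URL scheme. databaseAction() serves double duty: with
// modelArtifacts it writes the info and model records (rolling the info
// record back if the model write fails); without them it reads the model.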
class BrowserIndexedDB {
constructor(modelPath) {
this.indexedDB = getIndexedDBFactory();
if (modelPath == null || !modelPath) {
throw new Error('For IndexedDB, modelPath must not be null, undefined or empty.');
}
this.modelPath = modelPath;
}
async save(modelArtifacts) {
if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
            throw new Error('BrowserIndexedDB.save() does not support saving model topology ' +
                'in binary formats yet.');
}
return this.databaseAction(this.modelPath, modelArtifacts);
}
async load() {
return this.databaseAction(this.modelPath);
}
databaseAction(modelPath, modelArtifacts) {
return new Promise((resolve, reject) => {
const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
openRequest.onupgradeneeded = () => setUpDatabase(openRequest);
openRequest.onsuccess = () => {
const db = openRequest.result;
if (modelArtifacts == null) {
const modelTx = db.transaction(MODEL_STORE_NAME, 'readonly');
const modelStore = modelTx.objectStore(MODEL_STORE_NAME);
const getRequest = modelStore.get(this.modelPath);
getRequest.onsuccess = () => {
if (getRequest.result == null) {
db.close();
return reject(new Error(`Cannot find model with path '${this.modelPath}' ` +
`in IndexedDB.`));
}
else {
resolve(getRequest.result.modelArtifacts);
}
};
getRequest.onerror = error => {
db.close();
return reject(getRequest.error);
};
modelTx.oncomplete = () => db.close();
}
else {
modelArtifacts.weightData = CompositeArrayBuffer.join(modelArtifacts.weightData);
const modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts);
const infoTx = db.transaction(INFO_STORE_NAME, 'readwrite');
let infoStore = infoTx.objectStore(INFO_STORE_NAME);
let putInfoRequest;
try {
putInfoRequest =
infoStore.put({ modelPath: this.modelPath, modelArtifactsInfo });
}
catch (error) {
return reject(error);
}
let modelTx;
putInfoRequest.onsuccess = () => {
modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite');
const modelStore = modelTx.objectStore(MODEL_STORE_NAME);
let putModelRequest;
try {
putModelRequest = modelStore.put({
modelPath: this.modelPath,
modelArtifacts,
modelArtifactsInfo
});
}
catch (error) {
return reject(error);
}
putModelRequest.onsuccess = () => resolve({ modelArtifactsInfo });
putModelRequest.onerror = error => {
infoStore = infoTx.objectStore(INFO_STORE_NAME);
const deleteInfoRequest = infoStore.delete(this.modelPath);
deleteInfoRequest.onsuccess = () => {
db.close();
return reject(putModelRequest.error);
};
deleteInfoRequest.onerror = error => {
db.close();
return reject(putModelRequest.error);
};
};
};
putInfoRequest.onerror = error => {
db.close();
return reject(putInfoRequest.error);
};
infoTx.oncomplete = () => {
if (modelTx == null) {
db.close();
}
else {
modelTx.oncomplete = () => db.close();
}
};
}
};
openRequest.onerror = error => reject(openRequest.error);
});
}
}
BrowserIndexedDB.URL_SCHEME = 'indexeddb://';
const indexedDBRouter = (url) => {
if (!env().getBool('IS_BROWSER')) {
return null;
}
else {
if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) {
return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length));
}
else {
return null;
}
}
};
IORouterRegistry.registerSaveRouter(indexedDBRouter);
IORouterRegistry.registerLoadRouter(indexedDBRouter);
function browserIndexedDB(modelPath) {
return new BrowserIndexedDB(modelPath);
}
const PATH_SEPARATOR = '/';
const PATH_PREFIX = 'tensorflowjs_models';
const INFO_SUFFIX = 'info';
const MODEL_TOPOLOGY_SUFFIX = 'model_topology';
const WEIGHT_SPECS_SUFFIX = 'weight_specs';
const WEIGHT_DATA_SUFFIX = 'weight_data';
const MODEL_METADATA_SUFFIX = 'model_metadata';
function getModelKeys(path) {
return {
info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR),
topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR),
weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR),
weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR),
modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR)
};
}
function removeItems(keys) {
for (const key of Object.values(keys)) {
window.localStorage.removeItem(key);
}
}
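// IOHandler that stores a model in window.localStorage under the
// 'localstorage://' scheme. Topology, weight specs and metadata are saved as
// JSON strings and the weight buffer as a base64 string; load() reassembles
// them and throws if any required key is missing.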
class BrowserLocalStorage {
constructor(modelPath) {
if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' ||
typeof window.localStorage === 'undefined') {
throw new Error('The current environment does not support local storage.');
}
this.LS = window.localStorage;
if (modelPath == null || !modelPath) {
throw new Error('For local storage, modelPath must not be null, undefined or empty.');
}
this.modelPath = modelPath;
this.keys = getModelKeys(this.modelPath);
}
async save(modelArtifacts) {
if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
throw new Error('BrowserLocalStorage.save() does not support saving model topology ' +
'in binary formats yet.');
}
else {
const topology = JSON.stringify(modelArtifacts.modelTopology);
const weightSpecs = JSON.stringify(modelArtifacts.weightSpecs);
const modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts);
const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);
try {
this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo));
this.LS.setItem(this.keys.topology, topology);
this.LS.setItem(this.keys.weightSpecs, weightSpecs);
this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(weightBuffer));
const metadata = {
format: modelArtifacts.format,
generatedBy: modelArtifacts.generatedBy,
convertedBy: modelArtifacts.convertedBy,
signature: modelArtifacts.signature != null ?
modelArtifacts.signature :
undefined,
userDefinedMetadata: modelArtifacts.userDefinedMetadata != null ?
modelArtifacts.userDefinedMetadata :
undefined,
modelInitializer: modelArtifacts.modelInitializer != null ?
modelArtifacts.modelInitializer :
undefined,
initializerSignature: modelArtifacts.initializerSignature != null ?
modelArtifacts.initializerSignature :
undefined,
trainingConfig: modelArtifacts.trainingConfig != null ?
modelArtifacts.trainingConfig :
undefined
};
this.LS.setItem(this.keys.modelMetadata, JSON.stringify(metadata));
return { modelArtifactsInfo };
}
catch (err) {
removeItems(this.keys);
throw new Error(`Failed to save model '${this.modelPath}' to local storage: ` +
`size quota being exceeded is a possible cause of this failure: ` +
`modelTopologyBytes=${modelArtifactsInfo.modelTopologyBytes}, ` +
`weightSpecsBytes=${modelArtifactsInfo.weightSpecsBytes}, ` +
`weightDataBytes=${modelArtifactsInfo.weightDataBytes}.`);
}
}
}
async load() {
const info = JSON.parse(this.LS.getItem(this.keys.info));
if (info == null) {
throw new Error(`In local storage, there is no model with name '${this.modelPath}'`);
}
if (info.modelTopologyType !== 'JSON') {
throw new Error('BrowserLocalStorage does not support loading non-JSON model ' +
'topology yet.');
}
const out = {};
const topology = JSON.parse(this.LS.getItem(this.keys.topology));
if (topology == null) {
throw new Error(`In local storage, the topology of model '${this.modelPath}' ` +
`is missing.`);
}
out.modelTopology = topology;
const weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs));
if (weightSpecs == null) {
throw new Error(`In local storage, the weight specs of model '${this.modelPath}' ` +
`are missing.`);
}
out.weightSpecs = weightSpecs;
const metadataString = this.LS.getItem(this.keys.modelMetadata);
if (metadataString != null) {
const metadata = JSON.parse(metadataString);
out.format = metadata.format;
out.generatedBy = metadata.generatedBy;
out.convertedBy = metadata.convertedBy;
if (metadata.signature != null) {
out.signature = metadata.signature;
}
if (metadata.userDefinedMetadata != null) {
out.userDefinedMetadata = metadata.userDefinedMetadata;
}
if (metadata.modelInitializer != null) {
out.modelInitializer = metadata.modelInitializer;
}
if (metadata.initializerSignature != null) {
out.initializerSignature = metadata.initializerSignature;
}
if (metadata.trainingConfig != null) {
out.trainingConfig = metadata.trainingConfig;
}
}
const weightDataBase64 = this.LS.getItem(this.keys.weightData);
if (weightDataBase64 == null) {
throw new Error(`In local storage, the binary weight values of model ` +
`'${this.modelPath}' are missing.`);
}
out.weightData = base64StringToArrayBuffer(weightDataBase64);
return out;
}
}
BrowserLocalStorage.URL_SCHEME = 'localstorage://';
const localStorageRouter = (url) => {
if (!env().getBool('IS_BROWSER')) {
return null;
}
else {
if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) {
return browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length));
}
else {
return null;
}
}
};
IORouterRegistry.registerSaveRouter(localStorageRouter);
IORouterRegistry.registerLoadRouter(localStorageRouter);
function browserLocalStorage(modelPath) {
return new BrowserLocalStorage(modelPath);
}
const DEFAULT_FILE_NAME_PREFIX = 'model';
const DEFAULT_JSON_EXTENSION_NAME = '.json';
const DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin';
function defer(f) {
return new Promise(resolve => setTimeout(resolve)).then(f);
}
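// IOHandler for the 'downloads://' scheme: saving creates object URLs for a
// model JSON blob and a .weights.bin blob and dispatches click events on
// (possibly newly created) anchor elements to trigger browser downloads.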
class BrowserDownloads {
constructor(fileNamePrefix) {
if (!env().getBool('IS_BROWSER')) {
throw new Error('browserDownloads() cannot proceed because the current environment ' +
'is not a browser.');
}
if (fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) {
fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length);
}
if (fileNamePrefix == null || fileNamePrefix.length === 0) {
fileNamePrefix = DEFAULT_FILE_NAME_PREFIX;
}
this.modelJsonFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME;
this.weightDataFileName =
fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME;
}
async save(modelArtifacts) {
if (typeof (document) === 'undefined') {
throw new Error('Browser downloads are not supported in ' +
'this environment since `document` is not present');
}
const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);
const weightsURL = window.URL.createObjectURL(new Blob([weightBuffer], { type: 'application/octet-stream' }));
if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
throw new Error('BrowserDownloads.save() does not support saving model topology ' +
'in binary formats yet.');
}
else {
const weightsManifest = [{
paths: ['./' + this.weightDataFileName],
weights: modelArtifacts.weightSpecs
}];
const modelJSON = getModelJSONForModelArtifacts(modelArtifacts, weightsManifest);
const modelJsonURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelJSON)], { type: 'application/json' }));
const jsonAnchor = this.modelJsonAnchor == null ?
document.createElement('a') :
this.modelJsonAnchor;
jsonAnchor.download = this.modelJsonFileName;
jsonAnchor.href = modelJsonURL;
await defer(() => jsonAnchor.dispatchEvent(new MouseEvent('click')));
if (modelArtifacts.weightData != null) {
const weightDataAnchor = this.weightDataAnchor == null ?
document.createElement('a') :
this.weightDataAnchor;
weightDataAnchor.download = this.weightDataFileName;
weightDataAnchor.href = weightsURL;
await defer(() => weightDataAnchor.dispatchEvent(new MouseEvent('click')));
}
return { modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts) };
}
}
}
BrowserDownloads.URL_SCHEME = 'downloads://';
const browserDownloadsRouter = (url) => {
if (!env().getBool('IS_BROWSER')) {
return null;
}
else {
if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) {
return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length));
}
else {
return null;
}
}
};
IORouterRegistry.registerSaveRouter(browserDownloadsRouter);
function browserDownloads(fileNamePrefix = 'model') {
return new BrowserDownloads(fileNamePrefix);
}
class PassthroughLoader {
constructor(modelArtifacts) {
this.modelArtifacts = modelArtifacts;
}
load() {
return this.modelArtifacts;
}
}
class PassthroughSaver {
constructor(saveHandler) {
this.saveHandler = saveHandler;
}
save(modelArtifacts) {
return this.saveHandler(modelArtifacts);
}
}
class PassthroughAsync {
constructor(handler) {
if (handler.load) {
this.load = () => Promise.resolve(handler.load());
}
if (handler.save) {
this.save = (modelArtifacts) => Promise.resolve(handler.save(modelArtifacts));
}
}
}
function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) {
const args = arguments;
return new PassthroughAsync(fromMemorySync(...args));
}
function fromMemorySync(modelArtifacts, weightSpecs, weightData, trainingConfig) {
if (arguments.length === 1) {
const isModelArtifacts = modelArtifacts.modelTopology != null ||
modelArtifacts.weightSpecs != null;
if (isModelArtifacts) {
return new PassthroughLoader(modelArtifacts);
}
else {
console.warn('Please call tf.io.fromMemory() with only one argument. ' +
'The argument should be of type ModelArtifacts. ' +
'The multi-argument signature of tf.io.fromMemory() has been ' +
'deprecated and will be removed in a future release.');
return new PassthroughLoader({ modelTopology: modelArtifacts });
}
}
else {
console.warn('Please call tf.io.fromMemory() with only one argument. ' +
'The argument should be of type ModelArtifacts. ' +
'The multi-argument signature of tf.io.fromMemory() has been ' +
'deprecated and will be removed in a future release.');
return new PassthroughLoader({
modelTopology: modelArtifacts,
weightSpecs,
weightData,
trainingConfig
});
}
}
function withSaveHandler(saveHandler) {
return new PassthroughSaver(saveHandler);
}
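// Shape checks and slice geometry for gatherND: validates ranks and dtypes,
// then returns [resultShape, number of slices, slice size, strides], which
// backends use to copy each indexed slice out of the input tensor.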
function prepareAndValidate(tensor, indices) {
const tensorRank = tensor.shape.length;
const indicesRank = indices.shape.length;
if (tensorRank < 1) {
throw new Error('tf.gatherND() expects the input to be rank 1 or higher,' +
` but the rank was ${tensorRank}.`);
}
if (indicesRank < 1) {
throw new Error('tf.gatherND() expects the indices to be rank 1 or higher,' +
` but the rank was ${indicesRank}.`);
}
if (indices.dtype !== 'int32') {
throw new Error('tf.gatherND() expects the indices to be int32 type,' +
` but the dtype was ${indices.dtype}.`);
}
if (indices.shape[indicesRank - 1] > tensorRank) {
throw new Error('index innermost dimension length must be <= tensor rank; saw: ' +
`${indices.shape[indicesRank - 1]} vs. ${tensorRank}`);
}
if (sizeFromShape(tensor.shape) === 0) {
throw new Error('Requested more than 0 entries, but input is empty.' +
` Input shape: ${tensor.shape}.`);
}
const indicesShape = indices.shape;
const sliceRank = indicesShape[indicesShape.length - 1];
let nResult = 1;
for (let i = 0; i < indicesShape.length - 1; ++i) {
nResult *= indicesShape[i];
}
const inputShape = tensor.shape;
const resultShape = indicesShape.slice();
resultShape.pop();
let sliceSize = 1;
for (let i = sliceRank; i < tensorRank; ++i) {
sliceSize *= inputShape[i];
resultShape.push(inputShape[i]);
}
const strides = [...computeStrides(tensor.shape).map(stride => stride / sliceSize),
1].slice(0, sliceRank);
return [resultShape, nResult, sliceSize, strides];
}
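// Strided-slice helpers. NEW_AXIS / SHRINK_AXIS are sentinels used when
// assembling the final output shape, and the functions below normalize
// begin/end/stride vectors while honoring beginMask, endMask and
// ellipsisMask, mirroring the semantics of TensorFlow's StridedSlice op.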
const NEW_AXIS = -2;
const SHRINK_AXIS = -1;
function assertParamsValid(input, begin, size) {
const inputRank = input.shape.length;
assert$1(inputRank === begin.length, () => `Error in slice${inputRank}D: Length of begin ${begin} must ` +
`match the rank of the array (${inputRank}).`);
assert$1(inputRank === size.length, () => `Error in slice${inputRank}D: Length of size ${size} must ` +
`match the rank of the array (${inputRank}).`);
for (let i = 0; i < inputRank; ++i) {
assert$1(begin[i] + size[i] <= input.shape[i], () => `Error in slice${inputRank}D: begin[${i}] + size[${i}] ` +
`(${begin[i] + size[i]}) would overflow input.shape[${i}] (${input.shape[i]})`);
}
}
function maskToAxes(mask) {
const axes = [];
let axis = 0;
while (mask > 0) {
if (mask & 1) {
axes.push(axis);
}
mask /= 2;
axis++;
}
return axes;
}
function computeOutShape$2(begin, end, strides) {
const size = [];
for (let axis = 0; axis < begin.length; axis++) {
size[axis] = Math.ceil((end[axis] - begin[axis]) / strides[axis]);
}
return size;
}
function stridesWithElidedDims(strides, ellipsisInsertionIndex, numElidedAxes, inputShape) {
const newStrides = [...strides];
for (let i = newStrides.length; i < inputShape.length; i++) {
newStrides.push(1);
}
for (let i = 0; i < numElidedAxes; i++) {
if (i === 0) {
newStrides[ellipsisInsertionIndex] = 1;
}
else {
            newStrides.splice(ellipsisInsertionIndex, 0, 1);
newStrides.pop();
}
}
return newStrides;
}
function unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, normalizedAxis) {
if (normalizedAxis <= ellipsisInsertionIndex) {
return normalizedAxis;
}
return normalizedAxis - (numElidedAxes - 1);
}
function getElidedAxes(numElidedAxes, ellipsisInsertionIndex) {
const elidedAxes = [];
for (let i = 0; i < numElidedAxes; i++) {
elidedAxes.push(ellipsisInsertionIndex + i);
}
return elidedAxes;
}
function getNormalizedAxes(inputShape, ellipsisAxes, numInterpolatedAxes, begin, end, strides, beginMask, endMask, ellipsisMask) {
const inputRank = inputShape.length;
let normalizedBegin = new Array(inputRank), normalizedEnd = new Array(inputRank), normalizedStrides = new Array(inputRank);
if (ellipsisAxes.length && numInterpolatedAxes > 0) {
const fullIndex = ellipsisAxes[0];
const numElidedAxes = numInterpolatedAxes + 1;
normalizedBegin = startIndicesWithElidedDims(beginMask, fullIndex, numElidedAxes, begin, inputShape);
normalizedEnd = stopIndicesWithElidedDims(endMask, fullIndex, numElidedAxes, end, inputShape);
normalizedStrides =
stridesWithElidedDims(strides, fullIndex, numElidedAxes, inputShape);
}
else {
for (let axis = 0; axis < inputRank; axis++) {
normalizedBegin[axis] = startForAxis(beginMask, begin, strides, inputShape, axis, ellipsisMask);
normalizedEnd[axis] =
stopForAxis(endMask, end, strides, inputShape, axis, ellipsisMask);
normalizedStrides[axis] = stridesForAxis(strides, axis, ellipsisMask);
}
}
return {
begin: normalizedBegin,
end: normalizedEnd,
strides: normalizedStrides
};
}
function startIndicesWithElidedDims(beginMask, ellipsisInsertionIndex, numElidedAxes, originalBegin, inputShape) {
const newIndices = [...inputShape];
const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
for (let axis = 0; axis < newIndices.length; axis++) {
if (elidedAxes.indexOf(axis) > -1) {
newIndices[axis] = 0;
}
else {
const originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
let originalValue = originalBegin[originalAxis];
if (beginMask & 1 << originalAxis) {
originalValue = 0;
}
newIndices[axis] = originalValue;
}
}
return newIndices;
}
function stopIndicesWithElidedDims(endMask, ellipsisInsertionIndex, numElidedAxes, originalEnd, inputShape) {
const newIndices = [...inputShape];
const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
for (let axis = 0; axis < newIndices.length; axis++) {
if (elidedAxes.indexOf(axis) > -1) {
newIndices[axis] = Number.MAX_SAFE_INTEGER;
}
else {
const originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
let originalValue = originalEnd[originalAxis];
if (endMask & 1 << originalAxis) {
originalValue = Number.MAX_SAFE_INTEGER;
}
newIndices[axis] = originalValue;
}
}
for (let i = 0; i < newIndices.length; i++) {
const axisSize = inputShape[i];
if (newIndices[i] < 0) {
newIndices[i] += axisSize;
}
newIndices[i] = clamp(0, newIndices[i], inputShape[i]);
}
return newIndices;
}
function stridesForAxis(strides, axis, ellipsisMask) {
let stride = strides[axis];
if (ellipsisMask & (1 << axis) || stride == null) {
stride = 1;
}
return stride;
}
function startForAxis(beginMask, startIndices, strides, inputShape, axis, ellipsisMask) {
let start = startIndices[axis];
const stride = strides[axis] || 1;
if (beginMask & 1 << axis || ellipsisMask & 1 << axis || start == null) {
if (stride > 0) {
start = Number.MIN_SAFE_INTEGER;
}
else {
start = Number.MAX_SAFE_INTEGER;
}
}
const axisSize = inputShape[axis];
if (start < 0) {
start += axisSize;
}
start = clamp(0, start, axisSize - 1);
return start;
}
function stopForAxis(endMask, stopIndices, strides, inputShape, axis, ellipsisMask) {
let stop = stopIndices[axis];
const stride = strides[axis] || 1;
if (endMask & (1 << axis) || ellipsisMask & (1 << axis) || stop == null) {
if (stride > 0) {
stop = Number.MAX_SAFE_INTEGER;
}
else {
stop = Number.MIN_SAFE_INTEGER;
}
}
const axisSize = inputShape[axis];
if (stop < 0) {
stop += axisSize;
}
if (stride > 0) {
stop = clamp(0, stop, axisSize);
}
else {
stop = clamp(-1, stop, axisSize - 1);
}
return stop;
}
function isSliceContinous(shape, begin, size) {
let firstNonOneAxis = size.length;
for (let i = 0; i < size.length; i++) {
if (size[i] > 1) {
firstNonOneAxis = i;
break;
}
}
for (let i = firstNonOneAxis + 1; i < size.length; i++) {
if (begin[i] > 0 || size[i] !== shape[i]) {
return false;
}
}
return true;
}
function computeFlatOffset(begin, strides) {
let flatOffset = begin.length > 0 ? begin[begin.length - 1] : 1;
for (let i = 0; i < begin.length - 1; i++) {
flatOffset += begin[i] * strides[i];
}
return flatOffset;
}
function parseSliceParams(x, begin, size) {
let begin_;
const xRank = x.shape.length;
if (typeof begin === 'number') {
begin_ = [begin, ...new Array(xRank - 1).fill(0)];
}
else if (begin.length < xRank) {
begin_ = begin.concat(new Array(xRank - begin.length).fill(0));
}
else {
begin_ = begin.slice();
}
begin_.forEach(d => {
assert$1(d !== -1, () => 'slice() does not support negative begin indexing.');
});
let size_;
if (size == null) {
size_ = new Array(xRank).fill(-1);
}
else if (typeof size === 'number') {
size_ = [size, ...new Array(xRank - 1).fill(-1)];
}
else if (size.length < xRank) {
size_ = size.concat(new Array(xRank - size.length).fill(-1));
}
else {
size_ = size;
}
size_ = size_.map((d, i) => {
if (d >= 0) {
return d;
}
else {
assert$1(d === -1, () => `Negative size values should be exactly -1 but got ` +
`${d} for the slice() size at index ${i}.`);
return x.shape[i] - begin_[i];
}
});
return [begin_, size_];
}
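// sliceInfo expands a sparse strided-slice spec (user-supplied masks plus an
// optional ellipsis) into a dense per-dimension spec via buildDenseSpec, then
// derives the processing and final output shapes along with the flags
// (isIdentity, sliceDim0, isSimpleSlice) that let backends take a fast path
// when the slice is trivial.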
function sliceInfo(xShape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
let stridesNonNull;
if (strides == null) {
stridesNonNull = new Array(begin.length);
stridesNonNull.fill(1);
}
else {
stridesNonNull = strides;
}
if (ellipsisMask != null && (ellipsisMask & (ellipsisMask - 1)) !== 0) {
throw new Error('Multiple ellipses in slice is not allowed.');
}
let ellipsisSeen = false;
const sparseSpec = {
dims: stridesNonNull.length,
numAddAxisAfterEllipsis: 0,
begin: begin.slice(),
end: end.slice(),
strides: stridesNonNull.slice(),
beginMask,
endMask,
ellipsisMask,
newAxisMask,
shrinkAxisMask
};
for (let i = 0; i < sparseSpec.dims; i++) {
if (ellipsisSeen && ((1 << i) & newAxisMask) !== 0) {
sparseSpec.numAddAxisAfterEllipsis++;
}
if ((1 << i) & ellipsisMask) {
ellipsisSeen = true;
}
}
if (!ellipsisSeen) {
sparseSpec.ellipsisMask |= (1 << sparseSpec.dims);
sparseSpec.dims++;
}
const denseSpec = {
dims: xShape.length,
beginMask: 0,
endMask: 0,
beginValid: false,
endValid: false
};
buildDenseSpec(sparseSpec, denseSpec);
let isIdentity = true;
let sliceDim0 = true;
let isSimpleSlice = true;
const processingShape = [];
const finalShape = [];
for (let i = 0; i < xShape.length; ++i) {
if (denseSpec.strides[i] === 0) {
throw Error(`strides[${i}] must be non-zero`);
}
const shrinkI = !!(denseSpec.shrinkAxisMask & (1 << i));
const dimI = xShape[i];
if (dimI === -1) {
processingShape.push(shrinkI ? 1 : -1);
continue;
}
const masks = [denseSpec.beginMask & (1 << i), denseSpec.endMask & (1 << i)];
const validRange = [
denseSpec.strides[i] > 0 ? 0 : -1,
denseSpec.strides[i] > 0 ? dimI : dimI - 1
];
if (shrinkI && denseSpec.strides[i] <= 0) {
throw Error('only stride 1 allowed on non-range indexing.');
}
isSimpleSlice = isSimpleSlice && (denseSpec.strides[i] === 1);
const beginAndEndMasked = !!((denseSpec.beginMask & (1 << i)) && (denseSpec.endMask & (1 << i)));
if (denseSpec.beginValid && denseSpec.endValid) {
if (shrinkI) {
const xFwd = denseSpec.begin[i] < 0 ? dimI + denseSpec.begin[i] :
denseSpec.begin[i];
denseSpec.begin[i] = xFwd;
denseSpec.end[i] = denseSpec.begin[i] + 1;
if (xFwd < 0 || xFwd >= dimI) {
throw Error(`slice index ${denseSpec.begin[i]} of dimension ${i} out of bounds.`);
}
}
else {
denseSpec.begin[i] = canonical(denseSpec.begin[i], 0, denseSpec.strides[i], dimI, masks, validRange);
denseSpec.end[i] = canonical(denseSpec.end[i], 1, denseSpec.strides[i], dimI, masks, validRange);
}
const takeAllInDimension = denseSpec.strides[i] === 1 &&
denseSpec.begin[i] === 0 && denseSpec.end[i] === dimI;
isIdentity = isIdentity && takeAllInDimension;
sliceDim0 = sliceDim0 &&
((i === 0 && denseSpec.strides[i] === 1) || takeAllInDimension);
}
else {
isIdentity =
isIdentity && ((denseSpec.strides[i] === 1) && beginAndEndMasked);
sliceDim0 = sliceDim0 &&
((i === 0 && denseSpec.strides[i] === 1) || beginAndEndMasked);
}
let intervalLength;
let knownInterval = false;
if (denseSpec.beginValid && denseSpec.endValid) {
intervalLength = denseSpec.end[i] - denseSpec.begin[i];
knownInterval = true;
}
else if (shrinkI) {
intervalLength = 1;
knownInterval = true;
}
else if (beginAndEndMasked) {
if (dimI >= 0) {
if (denseSpec.strides[i] < 0) {
intervalLength = -dimI;
}
else {
intervalLength = dimI;
}
knownInterval = true;
}
}
if (knownInterval) {
let sizeI;
if (intervalLength === 0 ||
((intervalLength < 0) !== (denseSpec.strides[i] < 0))) {
sizeI = 0;
}
else {
sizeI = Math.trunc(intervalLength / denseSpec.strides[i]) +
(intervalLength % denseSpec.strides[i] !== 0 ? 1 : 0);
}
processingShape.push(sizeI);
}
else {
processingShape.push(-1);
}
}
for (let denseDim = 0; denseDim < denseSpec.finalShapeGatherIndices.length; ++denseDim) {
const gatherIndex = denseSpec.finalShapeGatherIndices[denseDim];
if (gatherIndex >= 0) {
finalShape.push(processingShape[gatherIndex]);
}
else if (gatherIndex === NEW_AXIS) {
finalShape.push(1);
}
}
const finalShapeSparse = finalShape.filter((dim, i) => denseSpec.finalShapeGatherIndices[i] !== NEW_AXIS);
return {
finalShapeSparse,
finalShape,
isIdentity,
sliceDim0,
isSimpleSlice,
begin: denseSpec.begin,
end: denseSpec.end,
strides: denseSpec.strides
};
}
function buildDenseSpec(sparse, dense) {
dense.beginMask = 0;
dense.endMask = 0;
dense.shrinkAxisMask = 0;
let fullIndex = 0;
dense.beginValid = sparse.begin != null;
dense.endValid = sparse.end != null;
dense.begin = new Array(dense.dims);
dense.end = new Array(dense.dims);
dense.strides = new Array(dense.dims);
dense.finalShapeGatherIndices = [];
dense.finalShapeGatherIndicesSparse = [];
dense.inputShapeGatherIndicesSparse = new Array(dense.dims);
for (let i = 0; i < sparse.dims; i++) {
if ((1 << i) & sparse.ellipsisMask) {
const nextIndex = Math.min(dense.dims - (sparse.dims - i) + 1 + sparse.numAddAxisAfterEllipsis, dense.dims);
for (; fullIndex < nextIndex; fullIndex++) {
dense.begin[fullIndex] = 0;
dense.end[fullIndex] = 0;
dense.strides[fullIndex] = 1;
dense.beginMask |= (1 << fullIndex);
dense.endMask |= (1 << fullIndex);
dense.finalShapeGatherIndices.push(fullIndex);
dense.finalShapeGatherIndicesSparse.push(-1);
dense.inputShapeGatherIndicesSparse[fullIndex] = i;
}
}
else if ((1 << i) & sparse.newAxisMask) {
dense.finalShapeGatherIndices.push(NEW_AXIS);
dense.finalShapeGatherIndicesSparse.push(-1);
}
else {
if (fullIndex === dense.begin.length) {
throw Error(`Index out of range using input dim ${fullIndex}; input ` +
`has only ${dense.dims} dims, ${dense.begin.length}.`);
}
if (sparse.begin != null) {
dense.begin[fullIndex] = sparse.begin[i];
}
if (sparse.end != null) {
dense.end[fullIndex] = sparse.end[i];
}
dense.strides[fullIndex] = sparse.strides[i];
if (sparse.beginMask & (1 << i)) {
dense.beginMask |= (1 << fullIndex);
}
if (sparse.endMask & (1 << i)) {
dense.endMask |= (1 << fullIndex);
}
if (sparse.shrinkAxisMask & (1 << i)) {
dense.finalShapeGatherIndices.push(SHRINK_AXIS);
dense.finalShapeGatherIndicesSparse.push(-1);
dense.shrinkAxisMask |= (1 << fullIndex);
}
else {
dense.finalShapeGatherIndices.push(fullIndex);
dense.finalShapeGatherIndicesSparse.push(i);
}
dense.inputShapeGatherIndicesSparse[fullIndex] = i;
fullIndex++;
}
}
}
function canonical(x, c, strideI, dimI, masks, validRange) {
if (masks[c]) {
return strideI > 0 ? validRange[c] : validRange[(c + 1) & 1];
}
else {
const xFwd = x < 0 ? dimI + x : x;
return xFwd < validRange[0] ? validRange[0] :
xFwd > validRange[1] ? validRange[1] : xFwd;
}
}
var slice_util = Object.freeze({
__proto__: null,
assertParamsValid: assertParamsValid,
computeFlatOffset: computeFlatOffset,
computeOutShape: computeOutShape$2,
getNormalizedAxes: getNormalizedAxes,
isSliceContinous: isSliceContinous,
maskToAxes: maskToAxes,
parseSliceParams: parseSliceParams,
sliceInfo: sliceInfo,
startForAxis: startForAxis,
startIndicesWithElidedDims: startIndicesWithElidedDims,
stopForAxis: stopForAxis,
stopIndicesWithElidedDims: stopIndicesWithElidedDims,
stridesForAxis: stridesForAxis,
stridesWithElidedDims: stridesWithElidedDims
});
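// Factory methods exposed through the `train` namespace: each returns a
// configured optimizer instance (SGD, momentum, RMSProp, Adam, Adadelta,
// Adamax, Adagrad) with the default hyperparameters shown in the signatures.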
class OptimizerConstructors {
static sgd(learningRate) {
return new SGDOptimizer(learningRate);
}
static momentum(learningRate, momentum, useNesterov = false) {
return new MomentumOptimizer(learningRate, momentum, useNesterov);
}
static rmsprop(learningRate, decay = .9, momentum = 0.0, epsilon = null, centered = false) {
return new RMSPropOptimizer(learningRate, decay, momentum, epsilon, centered);
}
static adam(learningRate = 0.001, beta1 = 0.9, beta2 = 0.999, epsilon = null) {
return new AdamOptimizer(learningRate, beta1, beta2, epsilon);
}
static adadelta(learningRate = .001, rho = .95, epsilon = null) {
return new AdadeltaOptimizer(learningRate, rho, epsilon);
}
static adamax(learningRate = 0.002, beta1 = 0.9, beta2 = 0.999, epsilon = null, decay = 0.0) {
return new AdamaxOptimizer(learningRate, beta1, beta2, epsilon, decay);
}
static adagrad(learningRate, initialAccumulatorValue = 0.1) {
return new AdagradOptimizer(learningRate, initialAccumulatorValue);
}
}
const train = OptimizerConstructors;
const delayCallback = (() => {
if (typeof requestAnimationFrame !== 'undefined') {
return requestAnimationFrame;
}
else if (typeof setImmediate !== 'undefined') {
return setImmediate;
}
return (f) => f();
})();
function nextFrame() {
return new Promise(resolve => delayCallback(() => resolve()));
}
function assertParamsConsistent(shapes, axis) {
const rank = shapes[0].length;
shapes.forEach((shape, i) => {
assert$1(shape.length === rank, () => `Error in concat${rank}D: rank of tensors[${i}] must be the same ` +
`as the rank of the rest (${rank})`);
});
assert$1(axis >= 0 && axis < rank, () => `Error in concat${rank}D: axis must be between 0 and ${rank - 1}.`);
const firstShape = shapes[0];
shapes.forEach((shape, i) => {
for (let r = 0; r < rank; r++) {
assert$1((r === axis) || (shape[r] === firstShape[r]), () => `Error in concat${rank}D: Shape of tensors[${i}] (${shape}) ` +
`does not match the shape of the rest (${firstShape}) ` +
                `along the non-concatenated axis ${r}.`);
}
});
}
function computeOutShape$1(shapes, axis) {
const outputShape = shapes[0].slice();
for (let i = 1; i < shapes.length; i++) {
outputShape[axis] += shapes[i][axis];
}
return outputShape;
}
var RowPartitionType$1;
(function (RowPartitionType) {
RowPartitionType[RowPartitionType["FIRST_DIM_SIZE"] = 0] = "FIRST_DIM_SIZE";
RowPartitionType[RowPartitionType["VALUE_ROWIDS"] = 1] = "VALUE_ROWIDS";
RowPartitionType[RowPartitionType["ROW_LENGTHS"] = 2] = "ROW_LENGTHS";
RowPartitionType[RowPartitionType["ROW_SPLITS"] = 3] = "ROW_SPLITS";
RowPartitionType[RowPartitionType["ROW_LIMITS"] = 4] = "ROW_LIMITS";
RowPartitionType[RowPartitionType["ROW_STARTS"] = 5] = "ROW_STARTS";
})(RowPartitionType$1 || (RowPartitionType$1 = {}));
function combineRaggedTensorToTensorShapes(raggedRank, shape, valueShape) {
let outputShape = new Array();
if (valueShape == null && shape == null) {
return outputShape;
}
if (shape == null) {
while (outputShape.length < raggedRank + valueShape.length) {
outputShape.push(-1);
}
}
else {
outputShape = shape.slice();
}
if (valueShape == null) {
return outputShape;
}
if (raggedRank + valueShape.length !== outputShape.length) {
throw new Error(`rt input.shape and shape=${shape} are incompatible: rt input.rank = ${raggedRank +
valueShape.length}, but shape.rank = ${outputShape.length}`);
}
for (let i = 1; i < valueShape.length; ++i) {
const valueDim = valueShape[i];
const outputShapeDimIndex = outputShape[outputShape.length - valueShape.length + i];
const outputShapeDim = outputShape[outputShapeDimIndex];
if (valueDim >= 0) {
if (outputShapeDim >= 0) {
if (outputShapeDim !== valueDim) {
throw new Error(`rt input.shape and shape=${shape} are incompatible: rt input.shape[${i + raggedRank}] = ${valueDim} but shape[${i + raggedRank}] = ${outputShapeDim}`);
}
}
else {
outputShape[outputShapeDimIndex] = valueDim;
}
}
}
return outputShape;
}
function getRowPartitionTypesHelper(rowPartitionTypeStrings) {
const stringToType = {
'FIRST_DIM_SIZE': RowPartitionType$1.FIRST_DIM_SIZE,
'VALUE_ROWIDS': RowPartitionType$1.VALUE_ROWIDS,
'ROW_LENGTHS': RowPartitionType$1.ROW_LENGTHS,
'ROW_SPLITS': RowPartitionType$1.ROW_SPLITS,
'ROW_LIMITS': RowPartitionType$1.ROW_LIMITS,
'ROW_STARTS': RowPartitionType$1.ROW_STARTS
};
const result = [];
for (const typeStr of rowPartitionTypeStrings) {
if (typeStr in stringToType) {
result.push(stringToType[typeStr]);
}
else {
break;
}
}
return result;
}
function getRaggedRank(rowPartitionTypes) {
if (rowPartitionTypes.length === 0) {
return 0;
}
if (rowPartitionTypes[0] === RowPartitionType$1.FIRST_DIM_SIZE) {
return rowPartitionTypes.length - 1;
}
return rowPartitionTypes.length;
}
function validateDefaultValueShape(defaultValueShape, valueShape) {
if (defaultValueShape == null || valueShape == null) {
return;
}
const defaultNDims = defaultValueShape.length;
const valuesNDims = valueShape.length;
if (defaultNDims >= valuesNDims) {
throw new Error(`defaultValue.shape=${defaultValueShape} and ragged tensor flatValues.shape=${valueShape}, are incompatible: defaultValue.rank = ${defaultNDims} must be less than ragged tensor input flatValues.rank = ${valuesNDims})`);
}
for (let i = 0; i < Math.min(defaultNDims, valuesNDims - 1); ++i) {
const defaultDim = defaultValueShape[i];
const valueDim = valueShape[i + 1];
if (defaultDim >= 0 && valueDim >= 0 && defaultDim !== 1 &&
defaultDim !== valueDim) {
throw new Error(`defaultValue.shape=${defaultValueShape}, and ragged tensor input flatValues.shape=${valueShape} are incompatible: defaultValue.shape[${i - defaultValueShape.length}] = ${defaultDim} but ragged tensor input.flatValues.shape[${i - defaultValueShape.length}] = ${valueDim}`);
}
}
}
const PARALLELIZE_THRESHOLD = 30;
function computeOptimalWindowSize(inSize) {
if (inSize <= PARALLELIZE_THRESHOLD) {
return inSize;
}
return nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
}
function getImageCenter(center, imageHeight, imageWidth) {
const centerX = imageWidth * (typeof center === 'number' ? center : center[0]);
const centerY = imageHeight * (typeof center === 'number' ? center : center[1]);
return [centerX, centerY];
}
function getReshaped(inputShape, blockShape, prod, batchToSpace = true) {
let reshaped = [];
if (batchToSpace) {
reshaped = reshaped.concat(blockShape.slice(0));
reshaped.push(inputShape[0] / prod);
reshaped = reshaped.concat(inputShape.slice(1));
}
else {
reshaped = reshaped.concat(inputShape[0]);
const spatialLength = blockShape.length;
for (let i = 0; i < spatialLength; ++i) {
reshaped =
reshaped.concat([inputShape[i + 1] / blockShape[i], blockShape[i]]);
}
reshaped = reshaped.concat(inputShape.slice(spatialLength + 1));
}
return reshaped;
}
function getPermuted(reshapedRank, blockShapeRank, batchToSpace = true) {
const permuted = [];
if (batchToSpace) {
permuted.push(blockShapeRank);
for (let i = blockShapeRank + 1; i < reshapedRank; ++i) {
if (i <= 2 * blockShapeRank) {
permuted.push(i);
permuted.push(i - (blockShapeRank + 1));
}
else {
permuted.push(i);
}
}
}
else {
const permutedBeforeBatch = [];
const permutedAfterBatch = [];
for (let i = 1; i < reshapedRank; ++i) {
if (i >= blockShapeRank * 2 + 1 || i % 2 === 1) {
permutedAfterBatch.push(i);
}
else {
permutedBeforeBatch.push(i);
}
}
permuted.push(...permutedBeforeBatch);
permuted.push(0);
permuted.push(...permutedAfterBatch);
}
return permuted;
}
function getReshapedPermuted(inputShape, blockShape, prod, batchToSpace = true) {
const reshapedPermuted = [];
if (batchToSpace) {
reshapedPermuted.push(inputShape[0] / prod);
}
else {
reshapedPermuted.push(inputShape[0] * prod);
}
for (let i = 1; i < inputShape.length; ++i) {
if (i <= blockShape.length) {
if (batchToSpace) {
reshapedPermuted.push(blockShape[i - 1] * inputShape[i]);
}
else {
reshapedPermuted.push(inputShape[i] / blockShape[i - 1]);
}
}
else {
reshapedPermuted.push(inputShape[i]);
}
}
return reshapedPermuted;
}
function getSliceBeginCoords(crops, blockShape) {
const sliceBeginCoords = [0];
for (let i = 0; i < blockShape; ++i) {
sliceBeginCoords.push(crops[i][0]);
}
return sliceBeginCoords;
}
function getSliceSize(uncroppedShape, crops, blockShape) {
const sliceSize = uncroppedShape.slice(0, 1);
for (let i = 0; i < blockShape; ++i) {
sliceSize.push(uncroppedShape[i + 1] - crops[i][0] - crops[i][1]);
}
return sliceSize;
}
const SELU_SCALEALPHA = 1.7580993408473768599402175208123;
const SELU_SCALE = 1.0507009873554804934193349852946;
const ERF_P = 0.3275911;
const ERF_A1 = 0.254829592;
const ERF_A2 = -0.284496736;
const ERF_A3 = 1.421413741;
const ERF_A4 = -1.453152027;
const ERF_A5 = 1.061405429;
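// Complex-number helpers for the FFT kernels. Complex data is stored as a
// single Float32Array with real and imaginary parts interleaved
// ([re0, im0, re1, im1, ...]); the functions below merge, split and index
// that layout and compute the twiddle factors e^(+/-2*pi*i*k/n).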
function mergeRealAndImagArrays(real, imag) {
if (real.length !== imag.length) {
throw new Error(`Cannot merge real and imag arrays of different lengths. real:` +
`${real.length}, imag: ${imag.length}.`);
}
const result = new Float32Array(real.length * 2);
for (let i = 0; i < result.length; i += 2) {
result[i] = real[i / 2];
result[i + 1] = imag[i / 2];
}
return result;
}
function splitRealAndImagArrays(complex) {
const real = new Float32Array(complex.length / 2);
const imag = new Float32Array(complex.length / 2);
for (let i = 0; i < complex.length; i += 2) {
real[i / 2] = complex[i];
imag[i / 2] = complex[i + 1];
}
return { real, imag };
}
function complexWithEvenIndex(complex) {
const len = Math.ceil(complex.length / 4);
const real = new Float32Array(len);
const imag = new Float32Array(len);
for (let i = 0; i < complex.length; i += 4) {
real[Math.floor(i / 4)] = complex[i];
imag[Math.floor(i / 4)] = complex[i + 1];
}
return { real, imag };
}
function complexWithOddIndex(complex) {
const len = Math.floor(complex.length / 4);
const real = new Float32Array(len);
const imag = new Float32Array(len);
for (let i = 2; i < complex.length; i += 4) {
real[Math.floor(i / 4)] = complex[i];
imag[Math.floor(i / 4)] = complex[i + 1];
}
return { real, imag };
}
function getComplexWithIndex(complex, index) {
const real = complex[index * 2];
const imag = complex[index * 2 + 1];
return { real, imag };
}
function assignToTypedArray(data, real, imag, index) {
data[index * 2] = real;
data[index * 2 + 1] = imag;
}
function exponents(n, inverse) {
const real = new Float32Array(n / 2);
const imag = new Float32Array(n / 2);
for (let i = 0; i < Math.ceil(n / 2); i++) {
const x = (inverse ? 2 : -2) * Math.PI * (i / n);
real[i] = Math.cos(x);
imag[i] = Math.sin(x);
}
return { real, imag };
}
function exponent(k, n, inverse) {
const x = (inverse ? 2 : -2) * Math.PI * (k / n);
const real = Math.cos(x);
const imag = Math.sin(x);
return { real, imag };
}
const ARROW = '->';
const ARROW_REGEX = /->/g;
const COMMA = ',';
const ELLIPSIS = '...';
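// Parses an einsum equation of the form 'subscripts->subscripts' into the
// dimension bookkeeping used by the einsum kernel. Illustrative result for
// 'ij,jk->ik' with two input tensors: allDims = ['i', 'k', 'j'] (output axes
// listed first), idDims = [[0, 2], [2, 1]] and summedDims = [2], i.e. axis
// 'j' is the one that gets contracted.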
function decodeEinsumEquation(equation, numTensors) {
equation = equation.replace(/\s/g, '');
const numArrows = (equation.length - equation.replace(ARROW_REGEX, '').length) /
ARROW.length;
if (numArrows < 1) {
throw new Error('Equations without an arrow are not supported.');
}
else if (numArrows > 1) {
throw new Error(`Equation must contain exactly one arrow ("${ARROW}").`);
}
const [inputString, outputString] = equation.split(ARROW);
assert$1(inputString.indexOf(ELLIPSIS) === -1, () => `The ellipsis notation ("${ELLIPSIS}") is not supported yet.`);
const inputTerms = inputString.split(COMMA);
const numInputs = inputTerms.length;
if (numTensors !== numInputs) {
throw new Error(`Expected ${numInputs} input tensors, received ${numTensors}`);
}
if (numInputs > 2) {
throw new Error('Support for more than 2 input tensors is not implemented yet.');
}
const allDims = [];
for (let i = 0; i < outputString.length; ++i) {
const dimName = outputString[i];
if (!inputTerms.some(inputTerm => inputTerm.indexOf(dimName) !== -1)) {
throw new Error(`Output subscripts contain the label ${dimName} ` +
`not present in the input subscripts.`);
}
if (allDims.indexOf(dimName) === -1) {
allDims.push(dimName);
}
}
for (let i = 0; i < inputString.length; ++i) {
const dimName = inputString[i];
if (allDims.indexOf(dimName) === -1 && dimName !== COMMA) {
allDims.push(dimName);
}
}
const idDims = new Array(inputTerms.length);
for (let i = 0; i < numInputs; ++i) {
if (new Set(inputTerms[i].split('')).size !== inputTerms[i].length) {
throw new Error(`Found duplicate axes in input component ${inputTerms[i]}. ` +
`Support for duplicate axes in input is not implemented yet.`);
}
idDims[i] = [];
for (let j = 0; j < inputTerms[i].length; ++j) {
idDims[i].push(allDims.indexOf(inputTerms[i][j]));
}
}
const numDims = allDims.length;
const numOutDims = outputString.length;
const summedDims = [];
for (let i = numOutDims; i < numDims; ++i) {
summedDims.push(i);
}
return { allDims, summedDims, idDims };
}
function getEinsumPermutation(nDims, idDims) {
let permutationIndices = new Array(nDims);
permutationIndices.fill(-1);
for (let i = 0; i < idDims.length; ++i) {
permutationIndices[idDims[i]] = i;
}
const expandDims = [];
for (let i = 0; i < nDims; ++i) {
if (permutationIndices[i] === -1) {
expandDims.push(i);
}
}
permutationIndices = permutationIndices.filter(d => d !== -1);
return { permutationIndices, expandDims };
}
function checkEinsumDimSizes(nDims, idDims, tensors) {
const dimSizes = new Array(nDims);
for (let i = 0; i < tensors.length; ++i) {
const shape = tensors[i].shape;
for (let j = 0; j < idDims[i].length; ++j) {
if (dimSizes[idDims[i][j]] === undefined) {
dimSizes[idDims[i][j]] = shape[j];
}
else {
assert$1(dimSizes[idDims[i][j]] === shape[j], () => `Expected dimension ${dimSizes[idDims[i][j]]} at axis ${j} ` +
`of input shaped ${JSON.stringify(shape)}, ` +
`but got dimension ${shape[j]}`);
}
}
}
}
function getEinsumComputePath(summedDims, idDims) {
const path = summedDims;
const steps = [];
let nSteps = 0;
if (summedDims.length === 0) {
path.push(-1);
}
nSteps = summedDims.length + 1;
for (let i = 0; i < nSteps; ++i) {
steps.push([]);
}
const computedTermIndices = [];
for (let i = 0; i < path.length; ++i) {
const summedDim = path[i];
const termIndices = findTermsWithDim(idDims, summedDim);
for (const termIndex of termIndices) {
if (computedTermIndices.indexOf(termIndex) === -1) {
steps[i].push(termIndex);
computedTermIndices.push(termIndex);
}
}
}
return { path, steps };
}
function isIdentityPermutation(perm) {
return perm.every((dim, index) => dim === index);
}
function findTermsWithDim(idDims, dim) {
const termIndices = [];
for (let i = 0; i < idDims.length; ++i) {
if (idDims[i].length === 0 || idDims[i].indexOf(dim) !== -1 || dim === -1) {
termIndices.push(i);
}
}
return termIndices;
}
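// Resolves the numOrSizeSplits argument of tf.split into concrete sizes: a
// number must divide the axis evenly, while an array may contain at most one
// -1, which is replaced by whatever length remains on that axis.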
function prepareSplitSize(x, numOrSizeSplits, axis = 0) {
let splitSizes = [];
if (typeof (numOrSizeSplits) === 'number') {
assert$1(x.shape[axis] % numOrSizeSplits === 0, () => 'Number of splits must evenly divide the axis.');
splitSizes =
new Array(numOrSizeSplits).fill(x.shape[axis] / numOrSizeSplits);
}
else {
const numOfNegs = numOrSizeSplits.reduce((count, value) => {
if (value === -1) {
count += 1;
}
return count;
}, 0);
assert$1(numOfNegs <= 1, () => 'There should be only one negative value in split array.');
const negIndex = numOrSizeSplits.indexOf(-1);
if (negIndex !== -1) {
const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a);
numOrSizeSplits[negIndex] = x.shape[axis] - total;
}
assert$1(x.shape[axis] === numOrSizeSplits.reduce((a, b) => a + b), () => 'The sum of sizes must match the size of the axis dimension.');
splitSizes = numOrSizeSplits;
}
return splitSizes;
}
function getSparseFillEmptyRowsIndicesDenseShapeMismatch(indicesLength) {
return `Received SparseTensor with denseShape[0] = 0 but
indices.shape[0] = ${indicesLength}`;
}
function getSparseFillEmptyRowsNegativeIndexErrorMessage(index, value) {
return `indices(${index}, 0) is invalid: ${value} < 0`;
}
function getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(index, value, limit) {
return `indices(${index}, 0) is invalid: ${value} >= ${limit}`;
}
function getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(dim1, dim2) {
return `only one output dimension may be -1, not both ${dim1} and ${dim2}`;
}
function getSparseReshapeNegativeOutputDimErrorMessage(dim, value) {
return `size ${dim} must be non-negative, not ${value}`;
}
function getSparseReshapeEmptyTensorZeroOutputDimErrorMessage() {
return 'reshape cannot infer the missing input size for an empty tensor ' +
'unless all specified input sizes are non-zero';
}
function getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape) {
const inputSize = sizeFromShape(inputShape);
const outputSize = sizeFromShape(outputShape);
return `Input to reshape is a SparseTensor with ${inputSize}
dense values, but the requested shape requires a multiple of ${outputSize}. inputShape=${inputShape} outputShape= ${outputShape}`;
}
function getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape) {
const inputSize = sizeFromShape(inputShape);
const outputSize = sizeFromShape(outputShape);
return `Input to reshape is a tensor with ${inputSize} dense values, but the requested shape has ${outputSize}. inputShape=${inputShape} outputShape=${outputShape}`;
}
function getSparseSegmentReductionNegativeSegmentIdsErrorMessage() {
return `segment ids must be >= 0`;
}
function getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage() {
return `segment ids are not increasing`;
}
function getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(segmentId, outputRows) {
return `Segment id ${segmentId} out of range [0, ${outputRows}), possibly because segmentIds input is not sorted.`;
}
function getSparseSegmentReductionIndicesOutOfRangeErrorMessage(index, indexValue, inputRows) {
return `Bad: indices[${index}] == ${indexValue} out of range [0, ${inputRows})`;
}
function segOpComputeOptimalWindowSize(inSize, numSegments) {
let done = false;
let res;
if (inSize <= PARALLELIZE_THRESHOLD) {
res = inSize;
done = true;
}
else {
res = nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
}
while (!done) {
if (res > numSegments || res === inSize) {
done = true;
}
else {
res = nearestDivisor(inSize, res + 1);
}
}
return res;
}
function computeOutShape(aShape, axis, numSegments) {
const outShape = [];
const rank = aShape.length;
for (let dim = 0; dim < rank; dim++) {
if (dim !== axis) {
outShape.push(aShape[dim]);
}
else {
outShape.push(numSegments);
}
}
return outShape;
}
function collectGatherOpShapeInfo(x, indices, axis, batchDims) {
const indicesRank = indices.shape.length;
const xRank = x.shape.length;
if (batchDims !== 0) {
if (batchDims < -indicesRank || batchDims > indicesRank) {
throw new Error(`Expect batchDims in the range of [-${indicesRank}, ${indicesRank}], but got ${batchDims}`);
}
}
if (batchDims < 0) {
batchDims += indicesRank;
}
if (batchDims > xRank) {
throw new Error(`batchDims (${batchDims}) must be less than rank(x) (
${xRank}).`);
}
if (axis < batchDims) {
throw new Error(`batchDims (${batchDims}) must be less than or equal to axis (${axis}).`);
}
for (let i = 0; i < batchDims; ++i) {
if (x.shape[i] !== indices.shape[i]) {
throw new Error(`x.shape[${i}]: ${x.shape[i]} should be equal to indices.shape[${i}]: ${indices.shape[i]}.`);
}
}
const dimSize = x.shape[axis];
const outputShape = [];
let batchSize = 1;
let outerSize = 1;
let sliceSize = 1;
for (let i = 0; i < batchDims; ++i) {
outputShape.push(x.shape[i]);
batchSize *= x.shape[i];
}
for (let i = batchDims; i < axis; i++) {
outputShape.push(x.shape[i]);
outerSize *= x.shape[i];
}
for (let i = batchDims; i < indicesRank; i++) {
outputShape.push(indices.shape[i]);
}
for (let i = axis + 1; i < xRank; i++) {
outputShape.push(x.shape[i]);
sliceSize *= x.shape[i];
}
return { batchSize, sliceSize, outerSize, dimSize, outputShape };
}
var segment_util = Object.freeze({
__proto__: null,
collectGatherOpShapeInfo: collectGatherOpShapeInfo,
computeOutShape: computeOutShape,
segOpComputeOptimalWindowSize: segOpComputeOptimalWindowSize
});
function fromUint8ToStringArray(vals) {
try {
return vals.map(val => decodeString(val));
}
catch (err) {
throw new Error(`Failed to decode encoded string bytes into utf-8, error: ${err}`);
}
}
function fromStringArrayToUint8(strings) {
return strings.map(s => encodeString(s));
}
var backend_util = Object.freeze({
__proto__: null,
ERF_A1: ERF_A1,
ERF_A2: ERF_A2,
ERF_A3: ERF_A3,
ERF_A4: ERF_A4,
ERF_A5: ERF_A5,
ERF_P: ERF_P,
PARALLELIZE_THRESHOLD: PARALLELIZE_THRESHOLD,
get RowPartitionType () { return RowPartitionType$1; },
SELU_SCALE: SELU_SCALE,
SELU_SCALEALPHA: SELU_SCALEALPHA,
applyActivation: applyActivation$1,
assertAndGetBroadcastShape: assertAndGetBroadcastShape,
assertAxesAreInnerMostDims: assertAxesAreInnerMostDims,
assertParamsConsistent: assertParamsConsistent,
assignToTypedArray: assignToTypedArray,
axesAreInnerMostDims: axesAreInnerMostDims,
calculateShapes: calculateShapes,
checkEinsumDimSizes: checkEinsumDimSizes,
checkPadOnDimRoundingMode: checkPadOnDimRoundingMode,
combineLocations: combineLocations,
combineRaggedTensorToTensorShapes: combineRaggedTensorToTensorShapes,
complexWithEvenIndex: complexWithEvenIndex,
complexWithOddIndex: complexWithOddIndex,
computeConv2DInfo: computeConv2DInfo,
computeConv3DInfo: computeConv3DInfo,
computeDefaultPad: computeDefaultPad,
computeDilation2DInfo: computeDilation2DInfo,
computeOptimalWindowSize: computeOptimalWindowSize,
computeOutAndReduceShapes: computeOutAndReduceShapes,
computeOutShape: computeOutShape$1,
computePool2DInfo: computePool2DInfo,
computePool3DInfo: computePool3DInfo,
convertConv2DDataFormat: convertConv2DDataFormat,
decodeEinsumEquation: decodeEinsumEquation,
eitherStridesOrDilationsAreOne: eitherStridesOrDilationsAreOne,
expandShapeToKeepDim: expandShapeToKeepDim,
exponent: exponent,
exponents: exponents,
fromStringArrayToUint8: fromStringArrayToUint8,
fromUint8ToStringArray: fromUint8ToStringArray,
getAxesPermutation: getAxesPermutation,
getBroadcastDims: getBroadcastDims$1,
getComplexWithIndex: getComplexWithIndex,
getEinsumComputePath: getEinsumComputePath,
getEinsumPermutation: getEinsumPermutation,
getFusedBiasGradient: getFusedBiasGradient,
getFusedDyActivation: getFusedDyActivation,
getImageCenter: getImageCenter,
getInnerMostAxes: getInnerMostAxes,
getPermuted: getPermuted,
getRaggedRank: getRaggedRank,
getReductionAxes: getReductionAxes,
getReshaped: getReshaped,
getReshapedPermuted: getReshapedPermuted,
getRowPartitionTypesHelper: getRowPartitionTypesHelper,
getSliceBeginCoords: getSliceBeginCoords,
getSliceSize: getSliceSize,
getSparseFillEmptyRowsIndicesDenseShapeMismatch: getSparseFillEmptyRowsIndicesDenseShapeMismatch,
getSparseFillEmptyRowsNegativeIndexErrorMessage: getSparseFillEmptyRowsNegativeIndexErrorMessage,
getSparseFillEmptyRowsOutOfRangeIndexErrorMessage: getSparseFillEmptyRowsOutOfRangeIndexErrorMessage,
getSparseReshapeEmptyTensorZeroOutputDimErrorMessage: getSparseReshapeEmptyTensorZeroOutputDimErrorMessage,
getSparseReshapeInputOutputMismatchErrorMessage: getSparseReshapeInputOutputMismatchErrorMessage,
getSparseReshapeInputOutputMultipleErrorMessage: getSparseReshapeInputOutputMultipleErrorMessage,
getSparseReshapeMultipleNegativeOneOutputDimErrorMessage: getSparseReshapeMultipleNegativeOneOutputDimErrorMessage,
getSparseReshapeNegativeOutputDimErrorMessage: getSparseReshapeNegativeOutputDimErrorMessage,
getSparseSegmentReductionIndicesOutOfRangeErrorMessage: getSparseSegmentReductionIndicesOutOfRangeErrorMessage,
getSparseSegmentReductionNegativeSegmentIdsErrorMessage: getSparseSegmentReductionNegativeSegmentIdsErrorMessage,
getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage: getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage,
getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage: getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage,
getUndoAxesPermutation: getUndoAxesPermutation,
isIdentityPermutation: isIdentityPermutation,
log: log$3,
mergeRealAndImagArrays: mergeRealAndImagArrays,
prepareAndValidate: prepareAndValidate,
prepareSplitSize: prepareSplitSize,
segment_util: segment_util,
shouldFuse: shouldFuse,
slice_util: slice_util,
splitRealAndImagArrays: splitRealAndImagArrays,
stridesOrDilationsArePositive: stridesOrDilationsArePositive,
tupleValuesAreOne: tupleValuesAreOne,
upcastType: upcastType,
validateDefaultValueShape: validateDefaultValueShape,
validateInput: validateInput,
validateUpdateShape: validateUpdateShape,
warn: warn
});
registerOptimizers();
const contexts = {};
const WEBGL_ATTRIBUTES = {
alpha: false,
antialias: false,
premultipliedAlpha: false,
preserveDrawingBuffer: false,
depth: false,
stencil: false,
failIfMajorPerformanceCaveat: true
};
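// WebGL context management for the webgl backend: contexts are cached per
// WebGL version, recreated if lost, and configured with the fixed-function
// state the texture-based kernels expect (scissor test and back-face culling
// on; depth, stencil, blend, dither, polygon offset and sample coverage off).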
function setWebGLContext(webGLVersion, gl) {
contexts[webGLVersion] = gl;
}
function getWebGLContext(webGLVersion, customCanvas) {
if (!(webGLVersion in contexts) || customCanvas != null) {
const newCtx = getWebGLRenderingContext(webGLVersion, customCanvas);
if (newCtx !== null) {
contexts[webGLVersion] = newCtx;
}
else {
console.log('Could not get context for WebGL version', webGLVersion);
return null;
}
}
const gl = contexts[webGLVersion];
if (gl == null || gl.isContextLost()) {
delete contexts[webGLVersion];
return getWebGLContext(webGLVersion);
}
gl.disable(gl.DEPTH_TEST);
gl.disable(gl.STENCIL_TEST);
gl.disable(gl.BLEND);
gl.disable(gl.DITHER);
gl.disable(gl.POLYGON_OFFSET_FILL);
gl.disable(gl.SAMPLE_COVERAGE);
gl.enable(gl.SCISSOR_TEST);
gl.enable(gl.CULL_FACE);
gl.cullFace(gl.BACK);
return contexts[webGLVersion];
}
function createCanvas(webGLVersion) {
if (!env().getBool('IS_SAFARI') && typeof OffscreenCanvas !== 'undefined' &&
webGLVersion === 2) {
return new OffscreenCanvas(300, 150);
}
else if (typeof document !== 'undefined') {
return document.createElement('canvas');
}
else {
throw new Error('Cannot create a canvas in this context');
}
}
function getWebGLRenderingContext(webGLVersion, customCanvas) {
if (webGLVersion !== 1 && webGLVersion !== 2) {
throw new Error('Cannot get WebGL rendering context, WebGL is disabled.');
}
const canvas = customCanvas == null ? createCanvas(webGLVersion) : customCanvas;
canvas.addEventListener('webglcontextlost', (ev) => {
ev.preventDefault();
delete contexts[webGLVersion];
}, false);
if (env().getBool('SOFTWARE_WEBGL_ENABLED')) {
WEBGL_ATTRIBUTES.failIfMajorPerformanceCaveat = false;
}
if (webGLVersion === 1) {
return (
canvas.getContext('webgl', WEBGL_ATTRIBUTES) ||
canvas
.getContext('experimental-webgl', WEBGL_ATTRIBUTES));
}
return canvas.getContext('webgl2', WEBGL_ATTRIBUTES);
}
var PackingScheme;
(function (PackingScheme) {
PackingScheme[PackingScheme["DENSE"] = 0] = "DENSE";
PackingScheme[PackingScheme["SHARED_BATCH"] = 1] = "SHARED_BATCH";
})(PackingScheme || (PackingScheme = {}));
var TextureUsage;
(function (TextureUsage) {
TextureUsage[TextureUsage["RENDER"] = 0] = "RENDER";
TextureUsage[TextureUsage["UPLOAD"] = 1] = "UPLOAD";
TextureUsage[TextureUsage["PIXELS"] = 2] = "PIXELS";
TextureUsage[TextureUsage["DOWNLOAD"] = 3] = "DOWNLOAD";
})(TextureUsage || (TextureUsage = {}));
var PhysicalTextureType;
(function (PhysicalTextureType) {
PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT16"] = 0] = "UNPACKED_FLOAT16";
PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT32"] = 1] = "UNPACKED_FLOAT32";
PhysicalTextureType[PhysicalTextureType["PACKED_4X1_UNSIGNED_BYTE"] = 2] = "PACKED_4X1_UNSIGNED_BYTE";
PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT32"] = 3] = "PACKED_2X2_FLOAT32";
PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT16"] = 4] = "PACKED_2X2_FLOAT16";
})(PhysicalTextureType || (PhysicalTextureType = {}));
function getUnpackedMatrixTextureShapeWidthHeight(rows, columns) {
return [columns, rows];
}
function getUnpackedArraySizeFromMatrixSize(matrixSize, channelsPerTexture) {
return matrixSize * channelsPerTexture;
}
function getDenseTexShape(shape) {
const size = sizeFromShape(shape);
const texelsNeeded = Math.ceil(size / 4);
return sizeToSquarishShape(texelsNeeded);
}
function getPackedMatrixTextureShapeWidthHeight(rows, columns) {
return [
Math.max(1, Math.ceil(columns / 2)), Math.max(1, Math.ceil(rows / 2))
];
}
function getPackedRGBAArraySizeFromMatrixShape(rows, columns) {
const [w, h] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
return w * h * 4;
}
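// Chooses the texture formats and types used for float storage. WebGL2 can use
// single-channel R32F/R16F textures; WebGL1 only has RGBA textures and needs
// the half-float texture extension for 16-bit uploads.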
function getTextureConfig(gl, textureHalfFloatExtension) {
const glany = gl;
let internalFormatFloat;
let internalFormatHalfFloat;
let internalFormatPackedHalfFloat;
let internalFormatPackedFloat;
let textureFormatFloat;
let downloadTextureFormat;
let downloadUnpackNumChannels;
let defaultNumChannels;
let textureTypeHalfFloat;
let textureTypeFloat;
if (env().getNumber('WEBGL_VERSION') === 2) {
internalFormatFloat = glany.R32F;
internalFormatHalfFloat = glany.R16F;
internalFormatPackedHalfFloat = glany.RGBA16F;
internalFormatPackedFloat = glany.RGBA32F;
textureFormatFloat = glany.RED;
downloadUnpackNumChannels = 4;
defaultNumChannels = 1;
textureTypeHalfFloat = glany.HALF_FLOAT;
textureTypeFloat = glany.FLOAT;
downloadTextureFormat = glany.RGBA8;
}
else {
internalFormatFloat = gl.RGBA;
internalFormatHalfFloat = gl.RGBA;
internalFormatPackedHalfFloat = gl.RGBA;
internalFormatPackedFloat = glany.RGBA;
textureFormatFloat = gl.RGBA;
downloadUnpackNumChannels = 4;
defaultNumChannels = 4;
textureTypeHalfFloat = textureHalfFloatExtension != null ?
textureHalfFloatExtension.HALF_FLOAT_OES :
null;
textureTypeFloat = gl.FLOAT;
downloadTextureFormat = gl.RGBA;
}
return {
internalFormatFloat,
internalFormatHalfFloat,
internalFormatPackedHalfFloat,
internalFormatPackedFloat,
textureFormatFloat,
downloadTextureFormat,
downloadUnpackNumChannels,
defaultNumChannels,
textureTypeHalfFloat,
textureTypeFloat
};
}
function callAndCheck(gl, func) {
const returnValue = func();
if (env().getBool('DEBUG')) {
checkWebGLError(gl);
}
return returnValue;
}
function checkWebGLError(gl) {
const error = gl.getError();
if (error !== gl.NO_ERROR) {
throw new Error('WebGL Error: ' + getWebGLErrorMessage(gl, error));
}
}
const MIN_FLOAT16 = 5.96e-8;
const MAX_FLOAT16 = 65504;
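// When float32 render targets are unavailable the backend stores values as
// float16, so only zero and magnitudes within the normal float16 range
// (roughly 5.96e-8 to 65504) can be represented exactly.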
function canBeRepresented(num) {
if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0 ||
(MIN_FLOAT16 < Math.abs(num) && Math.abs(num) < MAX_FLOAT16)) {
return true;
}
return false;
}
function getWebGLErrorMessage(gl, status) {
switch (status) {
case gl.NO_ERROR:
return 'NO_ERROR';
case gl.INVALID_ENUM:
return 'INVALID_ENUM';
case gl.INVALID_VALUE:
return 'INVALID_VALUE';
case gl.INVALID_OPERATION:
return 'INVALID_OPERATION';
case gl.INVALID_FRAMEBUFFER_OPERATION:
return 'INVALID_FRAMEBUFFER_OPERATION';
case gl.OUT_OF_MEMORY:
return 'OUT_OF_MEMORY';
case gl.CONTEXT_LOST_WEBGL:
return 'CONTEXT_LOST_WEBGL';
default:
return `Unknown error code ${status}`;
}
}
function getExtensionOrThrow(gl, extensionName) {
return throwIfNull(gl, () => gl.getExtension(extensionName), 'Extension "' + extensionName + '" not supported on this browser.');
}
function createVertexShader$1(gl, vertexShaderSource) {
const vertexShader = throwIfNull(gl, () => gl.createShader(gl.VERTEX_SHADER), 'Unable to create vertex WebGLShader.');
callAndCheck(gl, () => gl.shaderSource(vertexShader, vertexShaderSource));
callAndCheck(gl, () => gl.compileShader(vertexShader));
if (gl.getShaderParameter(vertexShader, gl.COMPILE_STATUS) === false) {
console.log(gl.getShaderInfoLog(vertexShader));
throw new Error('Failed to compile vertex shader.');
}
return vertexShader;
}
function createFragmentShader(gl, fragmentShaderSource) {
const fragmentShader = throwIfNull(gl, () => gl.createShader(gl.FRAGMENT_SHADER), 'Unable to create fragment WebGLShader.');
callAndCheck(gl, () => gl.shaderSource(fragmentShader, fragmentShaderSource));
callAndCheck(gl, () => gl.compileShader(fragmentShader));
if (env().get('ENGINE_COMPILE_ONLY')) {
return fragmentShader;
}
if (gl.getShaderParameter(fragmentShader, gl.COMPILE_STATUS) === false) {
logShaderSourceAndInfoLog(fragmentShaderSource, gl.getShaderInfoLog(fragmentShader));
throw new Error('Failed to compile fragment shader.');
}
return fragmentShader;
}
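// Pretty-prints a failed shader compile: the offending line (parsed from the
// driver's info log) is highlighted with console styling so errors in the
// generated GLSL are easier to locate.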
const lineNumberRegex = /ERROR: [0-9]+:([0-9]+):/g;
function logShaderSourceAndInfoLog(shaderSource, shaderInfoLog) {
const lineNumberRegexResult = lineNumberRegex.exec(shaderInfoLog);
if (lineNumberRegexResult == null) {
console.log(`Couldn't parse line number in error: ${shaderInfoLog}`);
console.log(shaderSource);
return;
}
const lineNumber = +lineNumberRegexResult[1];
const shaderLines = shaderSource.split('\n');
const pad = shaderLines.length.toString().length + 2;
const linesWithLineNumbers = shaderLines.map((line, lineNumber) => rightPad((lineNumber + 1).toString(), pad) + line);
let maxLineLength = 0;
for (let i = 0; i < linesWithLineNumbers.length; i++) {
maxLineLength = Math.max(linesWithLineNumbers[i].length, maxLineLength);
}
const beforeErrorLines = linesWithLineNumbers.slice(0, lineNumber - 1);
const errorLine = linesWithLineNumbers.slice(lineNumber - 1, lineNumber);
const afterErrorLines = linesWithLineNumbers.slice(lineNumber);
console.log(beforeErrorLines.join('\n'));
console.log(shaderInfoLog.split('\n')[0]);
console.log(`%c ${rightPad(errorLine[0], maxLineLength)}`, 'border:1px solid red; background-color:#e3d2d2; color:#a61717');
console.log(afterErrorLines.join('\n'));
}
function createProgram(gl) {
return throwIfNull(gl, () => gl.createProgram(), 'Unable to create WebGLProgram.');
}
function linkProgram(gl, program) {
callAndCheck(gl, () => gl.linkProgram(program));
if (env().get('ENGINE_COMPILE_ONLY')) {
return;
}
if (gl.getProgramParameter(program, gl.LINK_STATUS) === false) {
console.log(gl.getProgramInfoLog(program));
throw new Error('Failed to link vertex and fragment shaders.');
}
}
function validateProgram(gl, program) {
callAndCheck(gl, () => gl.validateProgram(program));
if (gl.getProgramParameter(program, gl.VALIDATE_STATUS) === false) {
console.log(gl.getProgramInfoLog(program));
throw new Error('Shader program validation failed.');
}
}
function createStaticVertexBuffer(gl, data) {
const buffer = throwIfNull(gl, () => gl.createBuffer(), 'Unable to create WebGLBuffer');
callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer));
callAndCheck(gl, () => gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW));
return buffer;
}
function createStaticIndexBuffer(gl, data) {
const buffer = throwIfNull(gl, () => gl.createBuffer(), 'Unable to create WebGLBuffer');
callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, buffer));
callAndCheck(gl, () => gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW));
return buffer;
}
function createTexture(gl) {
return throwIfNull(gl, () => gl.createTexture(), 'Unable to create WebGLTexture.');
}
function validateTextureSize(width, height) {
const maxTextureSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
if ((width <= 0) || (height <= 0)) {
const requested = `[${width}x${height}]`;
throw new Error('Requested texture size ' + requested + ' is invalid.');
}
if ((width > maxTextureSize) || (height > maxTextureSize)) {
const requested = `[${width}x${height}]`;
const max = `[${maxTextureSize}x${maxTextureSize}]`;
throw new Error('Requested texture size ' + requested +
' greater than WebGL maximum on this browser / GPU ' + max + '.');
}
}
function createFramebuffer(gl) {
return throwIfNull(gl, () => gl.createFramebuffer(), 'Unable to create WebGLFramebuffer.');
}
function bindVertexBufferToProgramAttribute(gl, program, attribute, buffer, arrayEntriesPerItem, itemStrideInBytes, itemOffsetInBytes) {
const loc = gl.getAttribLocation(program, attribute);
if (loc === -1) {
return false;
}
callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer));
callAndCheck(gl, () => gl.vertexAttribPointer(loc, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes, itemOffsetInBytes));
callAndCheck(gl, () => gl.enableVertexAttribArray(loc));
return true;
}
function bindTextureUnit(gl, texture, textureUnit) {
validateTextureUnit(gl, textureUnit);
callAndCheck(gl, () => gl.activeTexture(gl.TEXTURE0 + textureUnit));
callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
}
function getProgramUniformLocationOrThrow(gl, program, uniformName) {
return throwIfNull(gl, () => gl.getUniformLocation(program, uniformName), 'uniform "' + uniformName + '" not present in program.');
}
function getProgramUniformLocation(gl, program, uniformName) {
return gl.getUniformLocation(program, uniformName);
}
function bindTextureToProgramUniformSampler(gl, texture, uniformSamplerLocation, textureUnit) {
callAndCheck(gl, () => bindTextureUnit(gl, texture, textureUnit));
callAndCheck(gl, () => gl.uniform1i(uniformSamplerLocation, textureUnit));
}
function bindColorTextureToFramebuffer(gl, texture, framebuffer) {
callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer));
callAndCheck(gl, () => gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0));
}
function unbindColorTextureFromFramebuffer(gl, framebuffer) {
callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer));
callAndCheck(gl, () => gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0));
}
function validateFramebuffer(gl) {
const status = gl.checkFramebufferStatus(gl.FRAMEBUFFER);
if (status !== gl.FRAMEBUFFER_COMPLETE) {
throw new Error('Error binding framebuffer: ' + getFramebufferErrorMessage(gl, status));
}
}
function getFramebufferErrorMessage(gl, status) {
switch (status) {
case gl.FRAMEBUFFER_INCOMPLETE_ATTACHMENT:
return 'FRAMEBUFFER_INCOMPLETE_ATTACHMENT';
case gl.FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT:
return 'FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT';
case gl.FRAMEBUFFER_INCOMPLETE_DIMENSIONS:
return 'FRAMEBUFFER_INCOMPLETE_DIMENSIONS';
case gl.FRAMEBUFFER_UNSUPPORTED:
return 'FRAMEBUFFER_UNSUPPORTED';
default:
return `unknown error ${status}`;
}
}
function throwIfNull(gl, returnTOrNull, failureMessage) {
const tOrNull = callAndCheck(gl, () => returnTOrNull());
if (tOrNull == null) {
throw new Error(failureMessage);
}
return tOrNull;
}
function validateTextureUnit(gl, textureUnit) {
const maxTextureUnit = gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS - 1;
const glTextureUnit = textureUnit + gl.TEXTURE0;
if (glTextureUnit < gl.TEXTURE0 || glTextureUnit > maxTextureUnit) {
const textureUnitRange = `[gl.TEXTURE0, gl.TEXTURE${maxTextureUnit}]`;
throw new Error(`textureUnit must be in ${textureUnitRange}.`);
}
}
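// Shape helpers: tensors of any rank are viewed as [batch, rows, cols], where
// the batch dimension collapses all but the last two logical dimensions.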
function getBatchDim(shape, dimsToSkip = 2) {
return sizeFromShape(shape.slice(0, shape.length - dimsToSkip));
}
function getRowsCols(shape) {
if (shape.length === 0) {
throw Error('Cannot get rows and columns of an empty shape array.');
}
return [
shape.length > 1 ? shape[shape.length - 2] : 1, shape[shape.length - 1]
];
}
function getShapeAs3D(shape) {
let shapeAs3D = [1, 1, 1];
const isScalar = shape.length === 0 || (shape.length === 1 && shape[0] === 1);
if (!isScalar) {
shapeAs3D =
[getBatchDim(shape), ...getRowsCols(shape)];
}
return shapeAs3D;
}
function getTextureShapeFromLogicalShape(logShape, isPacked = false) {
let maxTexSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
let maxSizeForNarrowTex = env().getNumber('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE');
if (maxSizeForNarrowTex === Infinity &&
env().getBool('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE')) {
maxSizeForNarrowTex = maxTexSize / 2;
}
if (isPacked) {
maxTexSize = maxTexSize * 2;
maxSizeForNarrowTex = maxSizeForNarrowTex * 2;
logShape = logShape.map((d, i) => i >= logShape.length - 2 ?
nearestLargerEven(logShape[i]) :
logShape[i]);
if (logShape.length === 1) {
logShape = [2, logShape[0]];
}
}
if (logShape.length !== 2) {
const squeezeResult = squeezeShape(logShape);
logShape = squeezeResult.newShape;
}
let size = sizeFromShape(logShape);
let textureShape = null;
if (logShape.length <= 1 && size <= maxTexSize) {
textureShape = [1, size];
}
else if (logShape.length === 2 && logShape[0] <= maxTexSize &&
logShape[1] <= maxTexSize) {
textureShape = logShape;
}
else if (logShape.length === 3 && logShape[0] * logShape[1] <= maxTexSize &&
logShape[2] <= maxTexSize) {
textureShape = [logShape[0] * logShape[1], logShape[2]];
}
else if (logShape.length === 3 && logShape[0] <= maxTexSize &&
logShape[1] * logShape[2] <= maxTexSize) {
textureShape = [logShape[0], logShape[1] * logShape[2]];
}
else if (logShape.length === 4 &&
logShape[0] * logShape[1] * logShape[2] <= maxTexSize &&
logShape[3] <= maxTexSize) {
textureShape = [logShape[0] * logShape[1] * logShape[2], logShape[3]];
}
else if (logShape.length === 4 && logShape[0] <= maxTexSize &&
logShape[1] * logShape[2] * logShape[3] <= maxTexSize) {
textureShape = [logShape[0], logShape[1] * logShape[2] * logShape[3]];
}
const isLongNarrowTex = textureShape != null &&
Math.max(...textureShape) > maxSizeForNarrowTex &&
Math.min(...textureShape) <= (isPacked ? 2 : 1) &&
Math.min(...textureShape) > 0;
if (textureShape == null || isLongNarrowTex) {
if (isPacked) {
const batchDim = getBatchDim(logShape);
let rows = 2, cols = 2;
if (logShape.length) {
[rows, cols] = getRowsCols(logShape);
}
size = batchDim * (rows / 2) * (cols / 2);
textureShape =
sizeToSquarishShape(size).map(d => d * 2);
}
else {
textureShape = sizeToSquarishShape(size);
}
}
return textureShape;
}
function isEven(n) {
return n % 2 === 0;
}
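// A reshape between packed textures is "free" (needs no re-layout shader) when
// the trailing two dimensions line up texel-for-texel: identical inner shapes,
// degenerate/empty shapes, or even row/column counts that keep the 2x2 packing
// aligned.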
function isReshapeFree(shape1, shape2) {
shape1 = shape1.slice(-2);
shape2 = shape2.slice(-2);
if (arraysEqual(shape1, shape2)) {
return true;
}
if (!shape1.length || !shape2.length) {
return true;
}
if (shape1[0] === 0 || shape1[1] === 0 || shape2[0] === 0 ||
shape2[1] === 0) {
return true;
}
if (shape1.length !== shape2.length) {
const shape1Cols = shape1[shape1.length - 1];
const shape2Cols = shape2[shape2.length - 1];
if (shape1Cols === shape2Cols) {
return true;
}
if (isEven(shape1Cols) && isEven(shape2Cols) &&
(shape1[0] === 1 || shape2[0] === 1)) {
return true;
}
}
return shape1[1] === shape2[1] && isEven(shape1[0]) && isEven(shape2[0]);
}
let MAX_TEXTURE_SIZE;
let MAX_TEXTURES_IN_SHADER;
function getWebGLMaxTextureSize(webGLVersion) {
if (MAX_TEXTURE_SIZE == null) {
const gl = getWebGLContext(webGLVersion);
MAX_TEXTURE_SIZE = gl.getParameter(gl.MAX_TEXTURE_SIZE);
}
return MAX_TEXTURE_SIZE;
}
function getMaxTexturesInShader(webGLVersion) {
if (MAX_TEXTURES_IN_SHADER == null) {
const gl = getWebGLContext(webGLVersion);
MAX_TEXTURES_IN_SHADER = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS);
}
return Math.min(16, MAX_TEXTURES_IN_SHADER);
}
function getWebGLDisjointQueryTimerVersion(webGLVersion) {
if (webGLVersion === 0) {
return 0;
}
let queryTimerVersion;
const gl = getWebGLContext(webGLVersion);
if (hasExtension(gl, 'EXT_disjoint_timer_query_webgl2') &&
webGLVersion === 2) {
queryTimerVersion = 2;
}
else if (hasExtension(gl, 'EXT_disjoint_timer_query')) {
queryTimerVersion = 1;
}
else {
queryTimerVersion = 0;
}
return queryTimerVersion;
}
function hasExtension(gl, extensionName) {
const ext = gl.getExtension(extensionName);
return ext != null;
}
function isWebGLVersionEnabled(webGLVersion) {
try {
const gl = getWebGLContext(webGLVersion);
if (gl != null) {
return true;
}
}
catch (e) {
console.log('Error when getting WebGL context: ', e);
return false;
}
return false;
}
function isCapableOfRenderingToFloatTexture(webGLVersion) {
if (webGLVersion === 0) {
return false;
}
const gl = getWebGLContext(webGLVersion);
if (webGLVersion === 1) {
if (!hasExtension(gl, 'OES_texture_float')) {
return false;
}
}
else {
if (!hasExtension(gl, 'EXT_color_buffer_float')) {
return false;
}
}
const isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);
return isFrameBufferComplete;
}
function isDownloadFloatTextureEnabled(webGLVersion) {
if (webGLVersion === 0) {
return false;
}
const gl = getWebGLContext(webGLVersion);
if (webGLVersion === 1) {
if (!hasExtension(gl, 'OES_texture_float')) {
return false;
}
if (!hasExtension(gl, 'WEBGL_color_buffer_float')) {
return false;
}
}
else {
if (hasExtension(gl, 'EXT_color_buffer_float')) {
return createFloatTextureAndBindToFramebuffer(gl);
}
const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
if (hasExtension(gl, COLOR_BUFFER_HALF_FLOAT)) {
const textureHalfFloatExtension = gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
return createHalfFloatTextureAndBindToFramebuffer(gl, textureHalfFloatExtension);
}
return false;
}
const isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);
return isFrameBufferComplete;
}
function createFloatTextureAndBindToFramebuffer(gl) {
const texConfig = getTextureConfig(gl);
const texture = gl.createTexture();
gl.bindTexture(gl.TEXTURE_2D, texture);
const width = 1;
const height = 1;
gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeFloat, null);
const frameBuffer = gl.createFramebuffer();
gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
const isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
gl.bindTexture(gl.TEXTURE_2D, null);
gl.bindFramebuffer(gl.FRAMEBUFFER, null);
gl.deleteTexture(texture);
gl.deleteFramebuffer(frameBuffer);
return isFrameBufferComplete;
}
function createHalfFloatTextureAndBindToFramebuffer(gl, textureHalfFloatExtension) {
const texConfig = getTextureConfig(gl, textureHalfFloatExtension);
const texture = gl.createTexture();
gl.bindTexture(gl.TEXTURE_2D, texture);
const width = 1;
const height = 1;
gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatHalfFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeHalfFloat, null);
const frameBuffer = gl.createFramebuffer();
gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
const isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
gl.bindTexture(gl.TEXTURE_2D, null);
gl.bindFramebuffer(gl.FRAMEBUFFER, null);
gl.deleteTexture(texture);
gl.deleteFramebuffer(frameBuffer);
return isFrameBufferComplete;
}
function isWebGLFenceEnabled(webGLVersion) {
if (webGLVersion !== 2) {
return false;
}
const gl = getWebGLContext(webGLVersion);
const isEnabled = gl.fenceSync != null;
return isEnabled;
}
function assertNotComplex$1(tensor, opName) {
if (!Array.isArray(tensor)) {
tensor = [tensor];
}
tensor.forEach(t => {
if (t != null) {
assert$1(t.dtype !== 'complex64', () => `${opName} does not support complex64 tensors ` +
'in the WebGL backend.');
}
});
}
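// WebGL feature flags. Each flag is registered with a lazy default that probes
// the actual context (WebGL version, float texture support, timer queries,
// fences, ...) the first time it is read; the environment can still override
// them before the backend initializes.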
const ENV = env();
ENV.registerFlag('HAS_WEBGL', () => ENV.getNumber('WEBGL_VERSION') > 0);
ENV.registerFlag('WEBGL_VERSION', () => {
if (isWebGLVersionEnabled(2)) {
return 2;
}
else if (isWebGLVersionEnabled(1)) {
return 1;
}
return 0;
});
ENV.registerFlag('WEBGL_CHECK_NUMERICAL_PROBLEMS', () => false);
ENV.registerFlag('WEBGL_BUFFER_SUPPORTED', () => ENV.get('WEBGL_VERSION') === 2);
ENV.registerFlag('WEBGL_CPU_FORWARD', () => true);
ENV.registerFlag('WEBGL_FORCE_F16_TEXTURES', () => false);
ENV.registerFlag('WEBGL_PACK', () => ENV.getBool('HAS_WEBGL'));
ENV.registerFlag('WEBGL_PACK_NORMALIZATION', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_PACK_CLIP', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_PACK_DEPTHWISECONV', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_PACK_BINARY_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_PACK_UNARY_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_PACK_ARRAY_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_PACK_IMAGE_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_PACK_REDUCE', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_LAZILY_UNPACK', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_CONV_IM2COL', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_PACK_CONV2DTRANSPOSE', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_MAX_TEXTURE_SIZE', () => getWebGLMaxTextureSize(ENV.getNumber('WEBGL_VERSION')));
ENV.registerFlag('WEBGL_MAX_TEXTURES_IN_SHADER', () => getMaxTexturesInShader(ENV.getNumber('WEBGL_VERSION')));
ENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', () => {
const webGLVersion = ENV.getNumber('WEBGL_VERSION');
if (webGLVersion === 0) {
return 0;
}
return getWebGLDisjointQueryTimerVersion(webGLVersion);
});
ENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', () => ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0 &&
!isMobile());
ENV.registerFlag('WEBGL_RENDER_FLOAT32_CAPABLE', () => isCapableOfRenderingToFloatTexture(ENV.getNumber('WEBGL_VERSION')));
ENV.registerFlag('WEBGL_RENDER_FLOAT32_ENABLED', () => {
return ENV.getBool('WEBGL_FORCE_F16_TEXTURES') ?
false :
ENV.getBool('WEBGL_RENDER_FLOAT32_CAPABLE');
});
ENV.registerFlag('WEBGL_DOWNLOAD_FLOAT_ENABLED', () => isDownloadFloatTextureEnabled(ENV.getNumber('WEBGL_VERSION')));
ENV.registerFlag('WEBGL_FENCE_API_ENABLED', () => isWebGLFenceEnabled(ENV.getNumber('WEBGL_VERSION')));
ENV.registerFlag('WEBGL_SIZE_UPLOAD_UNIFORM', () => {
const useUniforms = ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED');
return useUniforms ? 4 : 0;
});
ENV.registerFlag('WEBGL_DELETE_TEXTURE_THRESHOLD', () => {
return -1;
}, threshold => {
if (!(typeof threshold === 'number')) {
throw new Error('WEBGL_DELETE_TEXTURE_THRESHOLD must be a number but ' +
`got ${threshold}.`);
}
if (threshold < 0 && threshold !== -1) {
throw new Error(`WEBGL_DELETE_TEXTURE_THRESHOLD must be -1 (indicating never ` +
`delete) or at least 0, but got ${threshold}.`);
}
});
ENV.registerFlag('WEBGL_FLUSH_THRESHOLD', () => {
return isMobile() ? 1 : -1;
}, threshold => {
if (!(typeof threshold === 'number')) {
throw new Error('WEBGL_FLUSH_THRESHOLD must be a number but got ' +
`${threshold}.`);
}
if (threshold < 0 && threshold !== -1) {
throw new Error(`WEBGL_FLUSH_THRESHOLD must be -1 (indicating never ` +
`manual flush) or at least 0, but got ${threshold}.`);
}
});
ENV.registerFlag('CPU_HANDOFF_SIZE_THRESHOLD', () => 128);
ENV.registerFlag('WEBGL_USE_SHAPES_UNIFORMS', () => false);
ENV.registerFlag('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD', () => 100000);
ENV.registerFlag('TOPK_K_CPU_HANDOFF_THRESHOLD', () => 128);
ENV.registerFlag('WEBGL_EXP_CONV', () => false);
ENV.registerFlag('SOFTWARE_WEBGL_ENABLED', () => ENV.getBool('IS_TEST'));
ENV.registerFlag('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', () => Infinity);
ENV.registerFlag('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE', () => false);
ENV.registerFlag('WEBGL2_ISNAN_CUSTOM', () => false);
ENV.registerFlag('ENGINE_COMPILE_ONLY', () => false);
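// Differences between GLSL ES 1.00 (WebGL1) and 3.00 (WebGL2): keyword names
// (attribute/varying vs in/out), the texture sampling builtin, the fragment
// output variable, and polyfills for isnan/isinf/round where the language
// version lacks reliable builtins.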
function getGlslDifferences() {
let version;
let attribute;
let varyingVs;
let varyingFs;
let texture2D;
let output;
let defineOutput;
let defineSpecialNaN;
let defineSpecialInf;
let defineRound;
if (env().getNumber('WEBGL_VERSION') === 2) {
version = '#version 300 es';
attribute = 'in';
varyingVs = 'out';
varyingFs = 'in';
texture2D = 'texture';
output = 'outputColor';
defineOutput = 'out vec4 outputColor;';
defineSpecialNaN = env().getBool('WEBGL2_ISNAN_CUSTOM') ? `
bool isnan_custom(float val) {
uint floatToUint = floatBitsToUint(val);
return (floatToUint & 0x7fffffffu) > 0x7f800000u;
}
bvec4 isnan_custom(vec4 val) {
return bvec4(isnan_custom(val.x),
isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w));
}
#define isnan(value) isnan_custom(value)
` :
'';
defineSpecialInf = ``;
defineRound = `
#define round(value) newRound(value)
int newRound(float value) {
return int(floor(value + 0.5));
}
ivec4 newRound(vec4 value) {
return ivec4(floor(value + vec4(0.5)));
}
`;
}
else {
version = '';
attribute = 'attribute';
varyingVs = 'varying';
varyingFs = 'varying';
texture2D = 'texture2D';
output = 'gl_FragColor';
defineOutput = '';
defineSpecialNaN = `
#define isnan(value) isnan_custom(value)
bool isnan_custom(float val) {
return (val > 0. || val < 1. || val == 0.) ? false : true;
}
bvec4 isnan_custom(vec4 val) {
return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w));
}
`;
defineSpecialInf = `
uniform float INFINITY;
bool isinf(float val) {
return abs(val) == INFINITY;
}
bvec4 isinf(vec4 val) {
return equal(abs(val), vec4(INFINITY));
}
`;
defineRound = `
int round(float value) {
return int(floor(value + 0.5));
}
ivec4 round(vec4 value) {
return ivec4(floor(value + vec4(0.5)));
}
`;
}
return {
version,
attribute,
varyingVs,
varyingFs,
texture2D,
output,
defineOutput,
defineSpecialNaN,
defineSpecialInf,
defineRound
};
}
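// The helpers below emit GLSL snippets that convert a flat element index into
// per-dimension coordinates using the row-major strides of the given shape,
// either baked in as literals or read from shape/stride uniforms.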
function getLogicalCoordinatesFromFlatIndex(coords, shape, index = 'index') {
const strides = computeStrides(shape);
return strides
.map((stride, i) => {
const line1 = `int ${coords[i]} = ${index} / ${stride}`;
const line2 = i === strides.length - 1 ?
`int ${coords[i + 1]} = ${index} - ${coords[i]} * ${stride}` :
`index -= ${coords[i]} * ${stride}`;
return `${line1}; ${line2};`;
})
.join('');
}
function getOutputLogicalCoordinatesFromFlatIndexByUniform(coords, shape, index = 'index') {
const strides = computeStrides(shape);
return strides
.map((_, i) => {
const line1 = `int ${coords[i]} = ${index} / outShapeStrides[${i}]`;
const line2 = i === strides.length - 1 ?
`int ${coords[i + 1]} = ${index} - ${coords[i]} * outShapeStrides[${i}]` :
`index -= ${coords[i]} * outShapeStrides[${i}]`;
return `${line1}; ${line2};`;
})
.join('');
}
function symbolicallyComputeStrides(indicesArr, variableName) {
const numCoords = indicesArr.length;
const shape = indicesArr.map(d => `${variableName}[${d}]`);
const strides = new Array(numCoords - 1);
strides[numCoords - 2] = shape[numCoords - 1];
for (let i = numCoords - 3; i >= 0; --i) {
strides[i] = `(${strides[i + 1]} * ${shape[i + 1]})`;
}
return strides;
}
function getLogicalCoordinatesFromFlatIndexByUniform(coords, variableName, index = 'index') {
const indicesArray = coords.map((_, i) => i);
const strides = symbolicallyComputeStrides(indicesArray, variableName);
return strides
.map((_, i) => {
const line1 = `int ${coords[i]} = ${index} / ${strides[i]}`;
const line2 = i === strides.length - 1 ?
`int ${coords[i + 1]} = ${index} - ${coords[i]} * ${strides[i]}` :
`index -= ${coords[i]} * ${strides[i]}`;
return `${line1}; ${line2};`;
})
.join('');
}
function getFlatIndexFrom3D(shape) {
const strides = computeStrides(shape).map(d => d.toString());
return `
int getFlatIndex(ivec3 coords) {
return coords.x * ${strides[0]} + coords.y * ${strides[1]} + coords.z;
}
`;
}
function getFlatIndexFrom3DOutput() {
return `
int getFlatIndex(ivec3 coords) {
return coords.x * outShapeStrides[0] + coords.y * outShapeStrides[1] + coords.z;
}
`;
}
const ENCODE_FLOAT_SNIPPET = `
const float FLOAT_MAX = 1.70141184e38;
const float FLOAT_MIN = 1.17549435e-38;
lowp vec4 encode_float(highp float v) {
if (isnan(v)) {
return vec4(255, 255, 255, 255);
}
highp float av = abs(v);
if(av < FLOAT_MIN) {
return vec4(0.0, 0.0, 0.0, 0.0);
} else if(v > FLOAT_MAX) {
return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;
} else if(v < -FLOAT_MAX) {
return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;
}
highp vec4 c = vec4(0,0,0,0);
highp float e = floor(log2(av));
highp float m = exp2(fract(log2(av))) - 1.0;
c[2] = floor(128.0 * m);
m -= c[2] / 128.0;
c[1] = floor(32768.0 * m);
m -= c[1] / 32768.0;
c[0] = floor(8388608.0 * m);
highp float ebias = e + 127.0;
c[3] = floor(ebias / 2.0);
ebias -= c[3] * 2.0;
c[2] += floor(ebias) * 128.0;
c[3] += 128.0 * step(0.0, -v);
return c / 255.0;
}
`;
const { getBroadcastDims } = backend_util;
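// Assembles the full fragment shader source for a program: the GLSL prologue,
// input sampler and uniform declarations, output-coordinate helpers for the
// output texture layout, per-input sampling functions, and finally the
// program's user code.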
function makeShader(inputsInfo, outputShape, program) {
const prefixSnippets = [];
inputsInfo.forEach(x => {
const size = sizeFromShape(x.shapeInfo.logicalShape);
if (x.shapeInfo.isUniform) {
prefixSnippets.push(`uniform float ${x.name}${size > 1 ? `[${size}]` : ''};`);
}
else {
prefixSnippets.push(`uniform sampler2D ${x.name};`);
prefixSnippets.push(`uniform int offset${x.name};`);
}
if (program.enableShapeUniforms) {
const { uniformShape } = getUniformInfoFromShape(program.packedInputs, x.shapeInfo.logicalShape, x.shapeInfo.texShape);
switch (uniformShape.length) {
case 1:
prefixSnippets.push(`uniform int ${x.name}Shape;`);
break;
case 2:
prefixSnippets.push(`uniform ivec2 ${x.name}Shape;`);
break;
case 3:
prefixSnippets.push(`uniform ivec3 ${x.name}Shape;`);
break;
case 4:
prefixSnippets.push(`uniform ivec4 ${x.name}Shape;`);
break;
}
prefixSnippets.push(`uniform ivec2 ${x.name}TexShape;`);
}
});
if (program.enableShapeUniforms) {
switch (outputShape.logicalShape.length) {
case 1:
prefixSnippets.push(`uniform int outShape;`);
break;
case 2:
prefixSnippets.push(`uniform ivec2 outShape;`);
prefixSnippets.push(`uniform int outShapeStrides;`);
break;
case 3:
prefixSnippets.push(`uniform ivec3 outShape;`);
prefixSnippets.push(`uniform ivec2 outShapeStrides;`);
break;
case 4:
prefixSnippets.push(`uniform ivec4 outShape;`);
prefixSnippets.push(`uniform ivec3 outShapeStrides;`);
break;
}
prefixSnippets.push(`uniform ivec2 outTexShape;`);
}
if (program.customUniforms) {
program.customUniforms.forEach((d) => {
prefixSnippets.push(`uniform ${d.type} ${d.name}${d.arrayIndex ? `[${d.arrayIndex}]` : ''};`);
});
}
const inputPrefixSnippet = prefixSnippets.join('\n');
const inputSamplingSnippet = inputsInfo
.map(x => getInputSamplingSnippet(x, outputShape, program.packedInputs, program.enableShapeUniforms))
.join('\n');
const outTexShape = outputShape.texShape;
const glsl = getGlslDifferences();
const floatTextureSampleSnippet = getFloatTextureSampleSnippet(glsl);
let outputSamplingSnippet;
let floatTextureSetOutputSnippet;
let shaderPrefix = getShaderPrefix(glsl);
if (outputShape.isPacked) {
outputSamplingSnippet = getPackedOutputSamplingSnippet(outputShape.logicalShape, outTexShape, program.enableShapeUniforms);
floatTextureSetOutputSnippet = getFloatTextureSetRGBASnippet(glsl);
}
else {
outputSamplingSnippet = getOutputSamplingSnippet(outputShape.logicalShape, outTexShape, program.enableShapeUniforms);
floatTextureSetOutputSnippet = getFloatTextureSetRSnippet(glsl);
}
if (program.packedInputs) {
shaderPrefix += SHADER_PACKED_PREFIX;
}
const source = [
shaderPrefix, floatTextureSampleSnippet, floatTextureSetOutputSnippet,
inputPrefixSnippet, outputSamplingSnippet, inputSamplingSnippet,
program.userCode
].join('\n');
return source;
}
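// Picks an input sampling function based on the input's rank (up to 6D for
// unpacked textures); packed inputs of rank greater than 3 share a generic
// N-D sampler.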
function getSamplerFromInInfo(inInfo, enableShapeUniforms = false) {
const shape = inInfo.shapeInfo.logicalShape;
switch (shape.length) {
case 0:
return getSamplerScalar(inInfo, enableShapeUniforms);
case 1:
return getSampler1D(inInfo, enableShapeUniforms);
case 2:
return getSampler2D(inInfo, enableShapeUniforms);
case 3:
return getSampler3D(inInfo, enableShapeUniforms);
case 4:
return getSampler4D(inInfo, enableShapeUniforms);
case 5:
return getSampler5D(inInfo);
case 6:
return getSampler6D(inInfo);
default:
throw new Error(`${shape.length}-D input sampling` +
` is not yet supported`);
}
}
function getPackedSamplerFromInInfo(inInfo, enableShapeUniforms) {
const shape = inInfo.shapeInfo.logicalShape;
switch (shape.length) {
case 0:
return getPackedSamplerScalar(inInfo);
case 1:
return getPackedSampler1D(inInfo, enableShapeUniforms);
case 2:
return getPackedSampler2D(inInfo, enableShapeUniforms);
case 3:
return getPackedSampler3D(inInfo, enableShapeUniforms);
default:
return getPackedSamplerND(inInfo, enableShapeUniforms);
}
}
function getInputSamplingSnippet(inInfo, outShapeInfo, usesPackedTextures = false, enableShapeUniforms) {
let res = '';
if (usesPackedTextures) {
res += getPackedSamplerFromInInfo(inInfo, enableShapeUniforms);
}
else {
res += getSamplerFromInInfo(inInfo, enableShapeUniforms);
}
const inShape = inInfo.shapeInfo.logicalShape;
const outShape = outShapeInfo.logicalShape;
if (inShape.length <= outShape.length) {
if (usesPackedTextures) {
res += getPackedSamplerAtOutputCoords(inInfo, outShapeInfo);
}
else {
res += getSamplerAtOutputCoords(inInfo, outShapeInfo);
}
}
return res;
}
function getPackedOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) {
switch (outShape.length) {
case 0:
return getOutputScalarCoords();
case 1:
return getOutputPacked1DCoords(outShape, outTexShape, enableShapeUniforms);
case 2:
return getOutputPacked2DCoords(outShape, outTexShape, enableShapeUniforms);
case 3:
return getOutputPacked3DCoords(outShape, outTexShape, enableShapeUniforms);
default:
return getOutputPackedNDCoords(outShape, outTexShape, enableShapeUniforms);
}
}
function getOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) {
switch (outShape.length) {
case 0:
return getOutputScalarCoords();
case 1:
return getOutput1DCoords(outShape, outTexShape, enableShapeUniforms);
case 2:
return getOutput2DCoords(outShape, outTexShape, enableShapeUniforms);
case 3:
return getOutput3DCoords(outShape, outTexShape, enableShapeUniforms);
case 4:
return getOutput4DCoords(outShape, outTexShape, enableShapeUniforms);
case 5:
return getOutput5DCoords(outShape, outTexShape);
case 6:
return getOutput6DCoords(outShape, outTexShape);
default:
throw new Error(`${outShape.length}-D output sampling is not yet supported`);
}
}
function getFloatTextureSampleSnippet(glsl) {
return `
float sampleTexture(sampler2D textureSampler, vec2 uv) {
return ${glsl.texture2D}(textureSampler, uv).r;
}
`;
}
function getFloatTextureSetRSnippet(glsl) {
return `
void setOutput(float val) {
${glsl.output} = vec4(val, 0, 0, 0);
}
`;
}
function getFloatTextureSetRGBASnippet(glsl) {
return `
void setOutput(vec4 val) {
${glsl.output} = val;
}
`;
}
function getShaderPrefix(glsl) {
const SHADER_PREFIX = `${glsl.version}
precision highp float;
precision highp int;
precision highp sampler2D;
${glsl.varyingFs} vec2 resultUV;
${glsl.defineOutput}
const vec2 halfCR = vec2(0.5, 0.5);
struct ivec5
{
int x;
int y;
int z;
int w;
int u;
};
struct ivec6
{
int x;
int y;
int z;
int w;
int u;
int v;
};
uniform float NAN;
${glsl.defineSpecialNaN}
${glsl.defineSpecialInf}
${glsl.defineRound}
int imod(int x, int y) {
return x - y * (x / y);
}
int idiv(int a, int b, float sign) {
int res = a / b;
int mod = imod(a, b);
if (sign < 0. && mod != 0) {
res -= 1;
}
return res;
}
#define HASHSCALE1 443.8975
float random(float seed){
vec2 p = resultUV * seed;
vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);
p3 += dot(p3, p3.yzx + 19.19);
return fract((p3.x + p3.y) * p3.z);
}
${SAMPLE_1D_SNIPPET}
${SAMPLE_2D_SNIPPET}
${SAMPLE_3D_SNIPPET}
`;
return SHADER_PREFIX;
}
const SAMPLE_1D_SNIPPET = `
vec2 uvFromFlat(int texNumR, int texNumC, int index) {
int texR = index / texNumC;
int texC = index - texR * texNumC;
return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
}
vec2 packedUVfrom1D(int texNumR, int texNumC, int index) {
int texelIndex = index / 2;
int texR = texelIndex / texNumC;
int texC = texelIndex - texR * texNumC;
return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
}
`;
const SAMPLE_2D_SNIPPET = `
vec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR,
int texNumC, int row, int col) {
int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2);
int texR = texelIndex / texNumC;
int texC = texelIndex - texR * texNumC;
return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
}
`;
const SAMPLE_3D_SNIPPET = `
vec2 packedUVfrom3D(int texNumR, int texNumC,
int texelsInBatch, int texelsInLogicalRow, int b,
int row, int col) {
int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2);
int texR = index / texNumC;
int texC = index - texR * texNumC;
return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
}
`;
const SHADER_PACKED_PREFIX = `
float getChannel(vec4 frag, vec2 innerDims) {
vec2 modCoord = mod(innerDims, 2.);
return modCoord.x == 0. ?
(modCoord.y == 0. ? frag.r : frag.g) :
(modCoord.y == 0. ? frag.b : frag.a);
}
float getChannel(vec4 frag, int dim) {
float modCoord = mod(float(dim), 2.);
return modCoord == 0. ? frag.r : frag.g;
}
`;
function getOutputScalarCoords() {
return `
int getOutputCoords() {
return 0;
}
`;
}
function getOutputPacked1DCoords(shape, texShape, enableShapeUniforms) {
const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
if (packedTexShape[0] === 1) {
if (enableShapeUniforms) {
return `
int getOutputCoords() {
return 2 * int(resultUV.x * ceil(float(outTexShape[1]) / 2.0));
}
`;
}
return `
int getOutputCoords() {
return 2 * int(resultUV.x * ${packedTexShape[1]}.0);
}
`;
}
if (packedTexShape[1] === 1) {
if (enableShapeUniforms) {
return `
int getOutputCoords() {
return 2 * int(resultUV.y * ceil(float(outTexShape[0]) / 2.0));
}
`;
}
return `
int getOutputCoords() {
return 2 * int(resultUV.y * ${packedTexShape[0]}.0);
}
`;
}
if (enableShapeUniforms) {
return `
int getOutputCoords() {
ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(packedTexShape[0], packedTexShape[1]));
return 2 * (resTexRC.x * packedTexShape[1] + resTexRC.y);
}
`;
}
return `
int getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
return 2 * (resTexRC.x * ${packedTexShape[1]} + resTexRC.y);
}
`;
}
function getOutput1DCoords(shape, texShape, enableShapeUniforms) {
if (texShape[0] === 1) {
if (enableShapeUniforms) {
return `
int getOutputCoords() {
return int(resultUV.x * float(outTexShape[1]));
}
`;
}
return `
int getOutputCoords() {
return int(resultUV.x * ${texShape[1]}.0);
}
`;
}
if (texShape[1] === 1) {
if (enableShapeUniforms) {
return `
int getOutputCoords() {
return int(resultUV.y * float(outTexShape[0]));
}
`;
}
return `
int getOutputCoords() {
return int(resultUV.y * ${texShape[0]}.0);
}
`;
}
if (enableShapeUniforms) {
return `
int getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(outTexShape[0], outTexShape[1]));
return resTexRC.x * outTexShape[1] + resTexRC.y;
}
`;
}
return `
int getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]}));
return resTexRC.x * ${texShape[1]} + resTexRC.y;
}
`;
}
function getOutputPacked3DCoords(shape, texShape, enableShapeUniforms) {
if (enableShapeUniforms) {
return `
ivec3 getOutputCoords() {
ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
int texelsInLogicalRow = int(ceil(float(outShape[2]) / 2.0));
int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[1]) / 2.0));
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(packedTexShape[0], packedTexShape[1]));
int index = resTexRC.x * packedTexShape[1] + resTexRC.y;
int b = index / texelsInBatch;
index -= b * texelsInBatch;
int r = 2 * (index / texelsInLogicalRow);
int c = imod(index, texelsInLogicalRow) * 2;
return ivec3(b, r, c);
}
`;
}
const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
const texelsInLogicalRow = Math.ceil(shape[2] / 2);
const texelsInBatch = texelsInLogicalRow * Math.ceil(shape[1] / 2);
return `
ivec3 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;
int b = index / ${texelsInBatch};
index -= b * ${texelsInBatch};
int r = 2 * (index / ${texelsInLogicalRow});
int c = imod(index, ${texelsInLogicalRow}) * 2;
return ivec3(b, r, c);
}
`;
}
function getOutput3DCoords(shape, texShape, enableShapeUniforms) {
if (enableShapeUniforms) {
const coordsFromIndexSnippet = getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], shape);
return `
ivec3 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(outTexShape[0], outTexShape[1]));
int index = resTexRC.x * outTexShape[1] + resTexRC.y;
${coordsFromIndexSnippet}
return ivec3(r, c, d);
}
`;
}
const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
return `
ivec3 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]}));
int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
${coordsFromIndexSnippet}
return ivec3(r, c, d);
}
`;
}
function getOutputPackedNDCoords(shape, texShape, enableShapeUniforms) {
if (enableShapeUniforms) {
return `
ivec4 getOutputCoords() {
ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(packedTexShape[0], packedTexShape[1]));
int index = resTexRC.x * packedTexShape[1] + resTexRC.y;
int texelsInLogicalRow = int(ceil(float(outShape[3]) / 2.0));
int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[2]) / 2.0));
int texelsInBatchN = texelsInBatch * outShape[1];
int b2 = index / texelsInBatchN;
index -= b2 * texelsInBatchN;
int b = index / texelsInBatch;
index -= b * texelsInBatch;
int r = 2 * (index / texelsInLogicalRow);
int c = imod(index, texelsInLogicalRow) * 2;
return ivec4(b2, b, r, c);
}
`;
}
const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
const texelsInLogicalRow = Math.ceil(shape[shape.length - 1] / 2);
const texelsInBatch = texelsInLogicalRow * Math.ceil(shape[shape.length - 2] / 2);
let texelsInBatchN = texelsInBatch;
let batches = ``;
let coords = 'b, r, c';
for (let b = 2; b < shape.length - 1; b++) {
texelsInBatchN *= shape[shape.length - b - 1];
batches = `
int b${b} = index / ${texelsInBatchN};
index -= b${b} * ${texelsInBatchN};
` + batches;
coords = `b${b}, ` + coords;
}
return `
ivec${shape.length} getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;
${batches}
int b = index / ${texelsInBatch};
index -= b * ${texelsInBatch};
int r = 2 * (index / ${texelsInLogicalRow});
int c = imod(index, ${texelsInLogicalRow}) * 2;
return ivec${shape.length}(${coords});
}
`;
}
function getOutput4DCoords(shape, texShape, enableShapeUniforms) {
if (enableShapeUniforms) {
const coordsFromIndexSnippet = getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd', 'd2'], shape);
return `
ivec4 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(outTexShape[0], outTexShape[1]));
int index = resTexRC.x * outTexShape[1] + resTexRC.y;
${coordsFromIndexSnippet}
return ivec4(r, c, d, d2);
}
`;
}
const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2'], shape);
return `
ivec4 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]}));
int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
${coordsFromIndexSnippet}
return ivec4(r, c, d, d2);
}
`;
}
function getOutput5DCoords(shape, texShape) {
const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3'], shape);
return `
ivec5 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx * vec2(${texShape[0]},
${texShape[1]}));
int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
${coordsFromIndexSnippet}
ivec5 outShape = ivec5(r, c, d, d2, d3);
return outShape;
}
`;
}
function getOutput6DCoords(shape, texShape) {
const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3', 'd4'], shape);
return `
ivec6 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]}));
int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
${coordsFromIndexSnippet}
ivec6 result = ivec6(r, c, d, d2, d3, d4);
return result;
}
`;
}
function getOutputPacked2DCoords(shape, texShape, enableShapeUniforms) {
const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
if (arraysEqual(shape, texShape)) {
if (enableShapeUniforms) {
return `
ivec2 getOutputCoords() {
ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
return 2 * ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1]));
}
`;
}
return `
ivec2 getOutputCoords() {
return 2 * ivec2(resultUV.yx * vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
}
`;
}
const texelsInLogicalRow = Math.ceil(shape[1] / 2);
if (enableShapeUniforms) {
return `
ivec2 getOutputCoords() {
ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
int texelsInLogicalRow = int(ceil(float(outShape[1]) / 2.0));
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(packedTexShape[0], packedTexShape[1]));
int index = resTexRC.x * packedTexShape[1] + resTexRC.y;
int r = 2 * (index / texelsInLogicalRow);
int c = imod(index, texelsInLogicalRow) * 2;
return ivec2(r, c);
}
`;
}
return `
ivec2 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;
int r = 2 * (index / ${texelsInLogicalRow});
int c = imod(index, ${texelsInLogicalRow}) * 2;
return ivec2(r, c);
}
`;
}
function getOutput2DCoords(shape, texShape, enableShapeUniforms) {
if (arraysEqual(shape, texShape)) {
if (enableShapeUniforms) {
return `
ivec2 getOutputCoords() {
return ivec2(resultUV.yx * vec2(outTexShape[0], outTexShape[1]));
}
`;
}
return `
ivec2 getOutputCoords() {
return ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]}));
}
`;
}
if (shape[1] === 1) {
if (enableShapeUniforms) {
return `
ivec2 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(outTexShape[0], outTexShape[1]));
int index = resTexRC.x * outTexShape[1] + resTexRC.y;
return ivec2(index, 0);
}
`;
}
return `
ivec2 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]}));
int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
return ivec2(index, 0);
}
`;
}
if (shape[0] === 1) {
if (enableShapeUniforms) {
return `
ivec2 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(outTexShape[0], outTexShape[1]));
int index = resTexRC.x * outTexShape[1] + resTexRC.y;
return ivec2(0, index);
}
`;
}
return `
ivec2 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]}));
int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
return ivec2(0, index);
}
`;
}
if (enableShapeUniforms) {
return `
ivec2 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(outTexShape[0], outTexShape[1]));
int index = resTexRC.x * outTexShape[1] + resTexRC.y;
int r = index / outShape[1];
int c = index - r * outShape[1];
return ivec2(r, c);
}
`;
}
return `
ivec2 getOutputCoords() {
ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]}));
int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
int r = index / ${shape[1]};
int c = index - r * ${shape[1]};
return ivec2(r, c);
}
`;
}
function getFlatOffsetUniformName(texName) {
return `offset${texName}`;
}
function getPackedSamplerScalar(inputInfo) {
const texName = inputInfo.name;
const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
const glsl = getGlslDifferences();
return `
vec4 ${funcName}() {
return ${glsl.texture2D}(${texName}, halfCR);
}
`;
}
function getSamplerScalar(inputInfo, enableShapeUniforms) {
const texName = inputInfo.name;
const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
if (inputInfo.shapeInfo.isUniform) {
return `float ${funcName}() {return ${texName};}`;
}
const [texNumR, texNumC] = inputInfo.shapeInfo.texShape;
if (texNumR === 1 && texNumC === 1) {
return `
float ${funcName}() {
return sampleTexture(${texName}, halfCR);
}
`;
}
const offset = getFlatOffsetUniformName(texName);
if (enableShapeUniforms) {
return `
float ${funcName}() {
vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], ${offset});
return sampleTexture(${texName}, uv);
}
`;
}
const [tNumR, tNumC] = inputInfo.shapeInfo.texShape;
return `
float ${funcName}() {
vec2 uv = uvFromFlat(${tNumR}, ${tNumC}, ${offset});
return sampleTexture(${texName}, uv);
}
`;
}
function getPackedSampler1D(inputInfo, enableShapeUniforms) {
const texName = inputInfo.name;
const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
const texShape = inputInfo.shapeInfo.texShape;
const glsl = getGlslDifferences();
if (enableShapeUniforms) {
return `
vec4 ${funcName}(int index) {
ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
vec2 uv = packedUVfrom1D(
packedTexShape[0], packedTexShape[1], index);
return ${glsl.texture2D}(${texName}, uv);
}
`;
}
const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
return `
vec4 ${funcName}(int index) {
vec2 uv = packedUVfrom1D(
${packedTexShape[0]}, ${packedTexShape[1]}, index);
return ${glsl.texture2D}(${texName}, uv);
}
`;
}
function getSampler1D(inputInfo, enableShapeUniforms) {
const texName = inputInfo.name;
const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
if (inputInfo.shapeInfo.isUniform) {
return `
float ${funcName}(int index) {
${getUniformSampler(inputInfo)}
}
`;
}
const texShape = inputInfo.shapeInfo.texShape;
const tNumR = texShape[0];
const tNumC = texShape[1];
if (tNumC === 1 && tNumR === 1) {
return `
float ${funcName}(int index) {
return sampleTexture(${texName}, halfCR);
}
`;
}
const offset = getFlatOffsetUniformName(texName);
if (tNumC === 1) {
if (enableShapeUniforms) {
return `
float ${funcName}(int index) {
vec2 uv = vec2(0.5, (float(index + ${offset}) + 0.5) / float(${texName}TexShape[0]));
return sampleTexture(${texName}, uv);
}
`;
}
return `
float ${funcName}(int index) {
vec2 uv = vec2(0.5, (float(index + ${offset}) + 0.5) / ${tNumR}.0);
return sampleTexture(${texName}, uv);
}
`;
}
if (tNumR === 1) {
if (enableShapeUniforms) {
return `
float ${funcName}(int index) {
vec2 uv = vec2((float(index + ${offset}) + 0.5) / float(${texName}TexShape[1]), 0.5);
return sampleTexture(${texName}, uv);
}
`;
}
return `
float ${funcName}(int index) {
vec2 uv = vec2((float(index + ${offset}) + 0.5) / ${tNumC}.0, 0.5);
return sampleTexture(${texName}, uv);
}
`;
}
if (enableShapeUniforms) {
return `
float ${funcName}(int index) {
vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index + ${offset});
return sampleTexture(${texName}, uv);
}
`;
}
return `
float ${funcName}(int index) {
vec2 uv = uvFromFlat(${tNumR}, ${tNumC}, index + ${offset});
return sampleTexture(${texName}, uv);
}
`;
}
function getPackedSampler2D(inputInfo, enableShapeUniforms) {
const shape = inputInfo.shapeInfo.logicalShape;
const texName = inputInfo.name;
const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
const texShape = inputInfo.shapeInfo.texShape;
const texNumR = texShape[0];
const texNumC = texShape[1];
const glsl = getGlslDifferences();
if (texShape != null && arraysEqual(shape, texShape)) {
if (enableShapeUniforms) {
return `
vec4 ${funcName}(int row, int col) {
vec2 uv = (vec2(col, row) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]);
return ${glsl.texture2D}(${texName}, uv);
}
`;
}
return `
vec4 ${funcName}(int row, int col) {
vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);
return ${glsl.texture2D}(${texName}, uv);
}
`;
}
if (enableShapeUniforms) {
return `
vec4 ${funcName}(int row, int col) {
ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
int valuesPerRow = int(ceil(float(${texName}Shape[1]) / 2.0));
vec2 uv = packedUVfrom2D(valuesPerRow, packedTexShape[0], packedTexShape[1], row, col);
return ${glsl.texture2D}(${texName}, uv);
}
`;
}
const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
const valuesPerRow = Math.ceil(shape[1] / 2);
return `
vec4 ${funcName}(int row, int col) {
vec2 uv = packedUVfrom2D(${valuesPerRow}, ${packedTexShape[0]}, ${packedTexShape[1]}, row, col);
return ${glsl.texture2D}(${texName}, uv);
}
`;
}
function getSampler2D(inputInfo, enableShapeUniforms) {
const shape = inputInfo.shapeInfo.logicalShape;
const texName = inputInfo.name;
const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
const texShape = inputInfo.shapeInfo.texShape;
if (texShape != null && arraysEqual(shape, texShape)) {
if (enableShapeUniforms) {
return `
float ${funcName}(int row, int col) {
vec2 uv = (vec2(col, row) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]);
return sampleTexture(${texName}, uv);
}
`;
}
const texNumR = texShape[0];
const texNumC = texShape[1];
return `
float ${funcName}(int row, int col) {
vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
}
`;
}
const { newShape, keptDims } = squeezeShape(shape);
const squeezedShape = newShape;
if (squeezedShape.length < shape.length) {
const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
const params = ['row', 'col'];
return `
${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)}
float ${funcName}(int row, int col) {
return ${funcName}(${getSqueezedParams(params, keptDims)});
}
`;
}
if (inputInfo.shapeInfo.isUniform) {
return `
float ${funcName}(int row, int col) {
int index = round(dot(vec2(row, col), vec2(${shape[1]}, 1)));
${getUniformSampler(inputInfo)}
}
`;
}
const texNumR = texShape[0];
const texNumC = texShape[1];
const offset = getFlatOffsetUniformName(texName);
if (texNumC === 1) {
if (enableShapeUniforms) {
return `
float ${funcName}(int row, int col) {
float index = dot(vec3(row, col, ${offset}), vec3(${texName}Shape[1], 1, 1));
vec2 uv = vec2(0.5, (index + 0.5) / float(${texName}TexShape[0]));
return sampleTexture(${texName}, uv);
}
`;
}
return `
float ${funcName}(int row, int col) {
float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1));
vec2 uv = vec2(0.5, (index + 0.5) / ${texNumR}.0);
return sampleTexture(${texName}, uv);
}
`;
}
if (texNumR === 1) {
if (enableShapeUniforms) {
return `
float ${funcName}(int row, int col) {
float index = dot(vec3(row, col, ${offset}), vec3(${texName}Shape[1], 1, 1));
vec2 uv = vec2((index + 0.5) / float(${texName}TexShape[1]), 0.5);
return sampleTexture(${texName}, uv);
}
`;
}
return `
float ${funcName}(int row, int col) {
float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1));
vec2 uv = vec2((index + 0.5) / ${texNumC}.0, 0.5);
return sampleTexture(${texName}, uv);
}
`;
}
if (enableShapeUniforms) {
return `
float ${funcName}(int row, int col) {
int index = row * ${texName}Shape[1] + col + ${offset};
vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index);
return sampleTexture(${texName}, uv);
}
`;
}
return `
float ${funcName}(int row, int col) {
int index = row * ${shape[1]} + col + ${offset};
vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
return sampleTexture(${texName}, uv);
}
`;
}
function getPackedSampler3D(inputInfo, enableShapeUniforms) {
const shape = inputInfo.shapeInfo.logicalShape;
const texName = inputInfo.name;
const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
const texShape = inputInfo.shapeInfo.texShape;
const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
if (shape[0] === 1) {
const squeezedShape = shape.slice(1);
const keptDims = [1, 2];
const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
const params = ['b', 'row', 'col'];
return `
${getPackedSamplerFromInInfo(newInputInfo, enableShapeUniforms)}
vec4 ${funcName}(int b, int row, int col) {
return ${funcName}(${getSqueezedParams(params, keptDims)});
}
`;
}
const glsl = getGlslDifferences();
if (enableShapeUniforms) {
return `
vec4 ${funcName}(int b, int row, int col) {
ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
int valuesPerRow = int(ceil(float(${texName}Shape[2]) / 2.0));
int texelsInBatch = valuesPerRow * int(ceil(float(${texName}Shape[1]) / 2.0));
vec2 uv = packedUVfrom3D(
packedTexShape[0], packedTexShape[1], texelsInBatch, valuesPerRow, b, row, col);
return ${glsl.texture2D}(${texName}, uv);
}
`;
}
const texNumR = packedTexShape[0];
const texNumC = packedTexShape[1];
const valuesPerRow = Math.ceil(shape[2] / 2);
const texelsInBatch = valuesPerRow * Math.ceil(shape[1] / 2);
return `
vec4 ${funcName}(int b, int row, int col) {
vec2 uv = packedUVfrom3D(
${texNumR}, ${texNumC}, ${texelsInBatch}, ${valuesPerRow}, b, row, col);
return ${glsl.texture2D}(${texName}, uv);
}
`;
}
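// Emits a GLSL sampler for an unpacked 3D tensor. Size-1 dimensions are squeezed,
// uniform-array inputs are indexed directly, and texture layouts whose rows/columns
// line up with the tensor strides avoid the generic flat-index path.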
function getSampler3D(inputInfo, enableShapeUniforms) {
const shape = inputInfo.shapeInfo.logicalShape;
const texName = inputInfo.name;
const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
const stride0 = shape[1] * shape[2];
const stride1 = shape[2];
const { newShape, keptDims } = squeezeShape(shape);
const squeezedShape = newShape;
if (squeezedShape.length < shape.length) {
const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
const params = ['row', 'col', 'depth'];
return `
${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)}
float ${funcName}(int row, int col, int depth) {
return ${funcName}(${getSqueezedParams(params, keptDims)});
}
`;
}
if (inputInfo.shapeInfo.isUniform) {
return `
float ${funcName}(int row, int col, int depth) {
int index = round(dot(vec3(row, col, depth),
vec3(${stride0}, ${stride1}, 1)));
${getUniformSampler(inputInfo)}
}
`;
}
const texShape = inputInfo.shapeInfo.texShape;
const texNumR = texShape[0];
const texNumC = texShape[1];
const flatOffset = inputInfo.shapeInfo.flatOffset;
if (texNumC === stride0 && flatOffset == null) {
if (enableShapeUniforms) {
return `
float ${funcName}(int row, int col, int depth) {
int stride1 = ${texName}Shape[2];
float texR = float(row);
float texC = dot(vec2(col, depth), vec2(stride1, 1));
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texName}TexShape[1], ${texName}TexShape[0]);
return sampleTexture(${texName}, uv);
}
`;
}
return `
float ${funcName}(int row, int col, int depth) {
float texR = float(row);
float texC = dot(vec2(col, depth), vec2(${stride1}, 1));
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
}
`;
}
if (texNumC === stride1 && flatOffset == null) {
if (enableShapeUniforms) {
return `
float ${funcName}(int row, int col, int depth) {
float texR = dot(vec2(row, col), vec2(${texName}Shape[1], 1));
float texC = float(depth);
vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]);
return sampleTexture(${texName}, uv);
}
`;
}
return `
float ${funcName}(int row, int col, int depth) {
float texR = dot(vec2(row, col), vec2(${shape[1]}, 1));
float texC = float(depth);
vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
}
`;
}
const offset = getFlatOffsetUniformName(texName);
if (enableShapeUniforms) {
return `
float ${funcName}(int row, int col, int depth) {
int stride0 = ${texName}Shape[1] * ${texName}Shape[2];
int stride1 = ${texName}Shape[2];
int index = row * stride0 + col * stride1 + depth + ${offset};
vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index);
return sampleTexture(${texName}, uv);
}
`;
}
return `
float ${funcName}(int row, int col, int depth) {
int index = row * ${stride0} + col * ${stride1} + depth + ${offset};
vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
return sampleTexture(${texName}, uv);
}
`;
}
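// Emits a packed sampler for tensors of rank > 3 by folding all outer batch
// dimensions into a single flat texel index.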
function getPackedSamplerND(inputInfo, enableShapeUniforms) {
const texName = inputInfo.name;
const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
const glsl = getGlslDifferences();
if (enableShapeUniforms) {
return `
vec4 ${funcName}(int b2, int b, int row, int col) {
int valuesPerRow = int(ceil(float(${texName}Shape[3]) / 2.0));
int texelsInBatch = valuesPerRow * int(ceil(float(${texName}Shape[2]) / 2.0));
int index = b * texelsInBatch + (row / 2) * valuesPerRow + (col / 2);
texelsInBatch *= ${texName}Shape[1];
index = b2 * texelsInBatch + index;
ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
int texR = index / packedTexShape[1];
int texC = index - texR * packedTexShape[1];
        vec2 uv = (vec2(texC, texR) + halfCR) / vec2(packedTexShape[1], packedTexShape[0]);
        return ${glsl.texture2D}(${texName}, uv);
}
`;
}
const shape = inputInfo.shapeInfo.logicalShape;
const rank = shape.length;
const texShape = inputInfo.shapeInfo.texShape;
const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
const texNumR = packedTexShape[0];
const texNumC = packedTexShape[1];
const valuesPerRow = Math.ceil(shape[rank - 1] / 2);
let texelsInBatch = valuesPerRow * Math.ceil(shape[rank - 2] / 2);
let params = `int b, int row, int col`;
let index = `b * ${texelsInBatch} + (row / 2) * ${valuesPerRow} + (col / 2)`;
for (let b = 2; b < rank - 1; b++) {
params = `int b${b}, ` + params;
texelsInBatch *= shape[rank - b - 1];
index = `b${b} * ${texelsInBatch} + ` + index;
}
return `
vec4 ${funcName}(${params}) {
int index = ${index};
int texR = index / ${texNumC};
int texC = index - texR * ${texNumC};
vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}, ${texNumR});
return ${glsl.texture2D}(${texName}, uv);
}
`;
}
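// Emits a GLSL sampler for an unpacked 4D tensor, mirroring the 3D case with an
// additional stride.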
function getSampler4D(inputInfo, enableShapeUniforms) {
const shape = inputInfo.shapeInfo.logicalShape;
const texName = inputInfo.name;
const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
const stride2 = shape[3];
const stride1 = shape[2] * stride2;
const stride0 = shape[1] * stride1;
const { newShape, keptDims } = squeezeShape(shape);
if (newShape.length < shape.length) {
const newInputInfo = squeezeInputInfo(inputInfo, newShape);
const params = ['row', 'col', 'depth', 'depth2'];
return `
${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)}
float ${funcName}(int row, int col, int depth, int depth2) {
return ${funcName}(${getSqueezedParams(params, keptDims)});
}
`;
}
if (inputInfo.shapeInfo.isUniform) {
return `
float ${funcName}(int row, int col, int depth, int depth2) {
int index = round(dot(vec4(row, col, depth, depth2),
vec4(${stride0}, ${stride1}, ${stride2}, 1)));
${getUniformSampler(inputInfo)}
}
`;
}
const flatOffset = inputInfo.shapeInfo.flatOffset;
const texShape = inputInfo.shapeInfo.texShape;
const texNumR = texShape[0];
const texNumC = texShape[1];
const stride2Str = `int stride2 = ${texName}Shape[3];`;
const stride1Str = `int stride1 = ${texName}Shape[2] * stride2;`;
const stride0Str = `int stride0 = ${texName}Shape[1] * stride1;`;
if (texNumC === stride0 && flatOffset == null) {
if (enableShapeUniforms) {
return `
float ${funcName}(int row, int col, int depth, int depth2) {
${stride2Str}
${stride1Str}
float texR = float(row);
float texC =
dot(vec3(col, depth, depth2),
vec3(stride1, stride2, 1));
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texName}TexShape[1], ${texName}TexShape[0]);
return sampleTexture(${texName}, uv);
}
`;
}
return `
float ${funcName}(int row, int col, int depth, int depth2) {
float texR = float(row);
float texC =
dot(vec3(col, depth, depth2),
vec3(${stride1}, ${stride2}, 1));
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
}
`;
}
if (texNumC === stride2 && flatOffset == null) {
if (enableShapeUniforms) {
return `
float ${funcName}(int row, int col, int depth, int depth2) {
float texR = dot(vec3(row, col, depth),
vec3(${texName}Shape[1] * ${texName}Shape[2], ${texName}Shape[2], 1));
float texC = float(depth2);
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texName}TexShape[1], ${texName}TexShape[0]);
return sampleTexture(${texName}, uv);
}
`;
}
return `
float ${funcName}(int row, int col, int depth, int depth2) {
float texR = dot(vec3(row, col, depth),
vec3(${shape[1] * shape[2]}, ${shape[2]}, 1));
float texC = float(depth2);
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
}
`;
}
const offset = getFlatOffsetUniformName(texName);
if (enableShapeUniforms) {
return `
float ${funcName}(int row, int col, int depth, int depth2) {
${stride2Str}
${stride1Str}
${stride0Str}
int index = row * stride0 + col * stride1 +
depth * stride2 + depth2;
vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index + ${offset});
return sampleTexture(${texName}, uv);
}
`;
}
return `
float ${funcName}(int row, int col, int depth, int depth2) {
int index = row * ${stride0} + col * ${stride1} +
depth * ${stride2} + depth2;
vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index + ${offset});
return sampleTexture(${texName}, uv);
}
`;
}
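// The rank-5 and rank-6 unpacked samplers below follow the same pattern; they do not
// support shape uniforms and always inline shapes and strides into the shader source.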
function getSampler5D(inputInfo) {
const shape = inputInfo.shapeInfo.logicalShape;
const texName = inputInfo.name;
const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
const stride3 = shape[4];
const stride2 = shape[3] * stride3;
const stride1 = shape[2] * stride2;
const stride0 = shape[1] * stride1;
const { newShape, keptDims } = squeezeShape(shape);
if (newShape.length < shape.length) {
const newInputInfo = squeezeInputInfo(inputInfo, newShape);
const params = ['row', 'col', 'depth', 'depth2', 'depth3'];
return `
${getSamplerFromInInfo(newInputInfo)}
float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
return ${funcName}(${getSqueezedParams(params, keptDims)});
}
`;
}
if (inputInfo.shapeInfo.isUniform) {
return `
float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
float index = dot(
vec4(row, col, depth, depth2),
vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) +
depth3;
${getUniformSampler(inputInfo)}
}
`;
}
const flatOffset = inputInfo.shapeInfo.flatOffset;
const texShape = inputInfo.shapeInfo.texShape;
const texNumR = texShape[0];
const texNumC = texShape[1];
if (texNumC === stride0 && flatOffset == null) {
return `
float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
int texR = row;
float texC = dot(vec4(col, depth, depth2, depth3),
vec4(${stride1}, ${stride2}, ${stride3}, 1));
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
}
`;
}
if (texNumC === stride3 && flatOffset == null) {
return `
float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
float texR = dot(
vec4(row, col, depth, depth2),
vec4(${shape[1] * shape[2] * shape[3]},
${shape[2] * shape[3]}, ${shape[3]}, 1));
int texC = depth3;
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
}
`;
}
const offset = getFlatOffsetUniformName(texName);
return `
float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} +
depth2 * ${stride3} + depth3 + ${offset};
vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
return sampleTexture(${texName}, uv);
}
`;
}
function getSampler6D(inputInfo) {
const shape = inputInfo.shapeInfo.logicalShape;
const texName = inputInfo.name;
const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
const { newShape, keptDims } = squeezeShape(shape);
if (newShape.length < shape.length) {
const newInputInfo = squeezeInputInfo(inputInfo, newShape);
const params = ['row', 'col', 'depth', 'depth2', 'depth3', 'depth4'];
return `
${getSamplerFromInInfo(newInputInfo)}
float ${funcName}(int row, int col, int depth,
int depth2, int depth3, int depth4) {
return ${funcName}(${getSqueezedParams(params, keptDims)});
}
`;
}
const stride4 = shape[5];
const stride3 = shape[4] * stride4;
const stride2 = shape[3] * stride3;
const stride1 = shape[2] * stride2;
const stride0 = shape[1] * stride1;
if (inputInfo.shapeInfo.isUniform) {
return `
float ${funcName}(int row, int col, int depth,
int depth2, int depth3, int depth4) {
int index = round(dot(
vec4(row, col, depth, depth2),
vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) +
dot(
vec2(depth3, depth4),
vec2(${stride4}, 1)));
${getUniformSampler(inputInfo)}
}
`;
}
const flatOffset = inputInfo.shapeInfo.flatOffset;
const texShape = inputInfo.shapeInfo.texShape;
const texNumR = texShape[0];
const texNumC = texShape[1];
if (texNumC === stride0 && flatOffset == null) {
return `
float ${funcName}(int row, int col, int depth,
int depth2, int depth3, int depth4) {
int texR = row;
float texC = dot(vec4(col, depth, depth2, depth3),
vec4(${stride1}, ${stride2}, ${stride3}, ${stride4})) +
float(depth4);
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
}
`;
}
if (texNumC === stride4 && flatOffset == null) {
return `
float ${funcName}(int row, int col, int depth,
int depth2, int depth3, int depth4) {
float texR = dot(vec4(row, col, depth, depth2),
vec4(${shape[1] * shape[2] * shape[3] * shape[4]},
${shape[2] * shape[3] * shape[4]},
${shape[3] * shape[4]},
${shape[4]})) + float(depth3);
int texC = depth4;
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${texNumC}.0, ${texNumR}.0);
return sampleTexture(${texName}, uv);
}
`;
}
const offset = getFlatOffsetUniformName(texName);
return `
float ${funcName}(int row, int col, int depth,
int depth2, int depth3, int depth4) {
int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} +
depth2 * ${stride3} + depth3 * ${stride4} + depth4 + ${offset};
vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
return sampleTexture(${texName}, uv);
}
`;
}
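// Returns the GLSL body used when the input is passed as a uniform float array: the
// caller computes `index`, and the body linearly scans the array (or returns the
// value directly when the input has fewer than two elements).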
function getUniformSampler(inputInfo) {
const texName = inputInfo.name;
const inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape);
if (inSize < 2) {
return `return ${texName};`;
}
return `
for (int i = 0; i < ${inSize}; i++) {
if (i == index) {
return ${texName}[i];
}
}
`;
}
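// Emits `get<Name>AtOutCoords()` for packed inputs: reads the input at the current
// output coordinates, zeroing broadcast dimensions and replicating texel channels
// when the input is lower-rank or scalar relative to the output.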
function getPackedSamplerAtOutputCoords(inputInfo, outShapeInfo) {
const texName = inputInfo.name;
const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
const funcName = 'get' + texFuncSnippet + 'AtOutCoords';
const inRank = inputInfo.shapeInfo.logicalShape.length;
const outRank = outShapeInfo.logicalShape.length;
const broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
const type = getCoordsDataType(outRank);
const rankDiff = outRank - inRank;
let coordsSnippet;
const fields = ['x', 'y', 'z', 'w', 'u', 'v'];
if (inRank === 0) {
coordsSnippet = '';
}
else if (outRank < 2 && broadcastDims.length >= 1) {
coordsSnippet = 'coords = 0;';
}
else {
coordsSnippet =
broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`)
.join('\n');
}
let unpackedCoordsSnippet = '';
if (outRank < 2 && inRank > 0) {
unpackedCoordsSnippet = 'coords';
}
else {
unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape
.map((s, i) => `coords.${fields[i + rankDiff]}`)
.join(', ');
}
let output = `return outputValue;`;
const inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape);
const isInputScalar = inSize === 1;
const outSize = sizeFromShape(outShapeInfo.logicalShape);
const isOutputScalar = outSize === 1;
if (inRank === 1 && !isInputScalar && !isOutputScalar) {
output = `
return vec4(outputValue.xy, outputValue.xy);
`;
}
else if (isInputScalar && !isOutputScalar) {
if (outRank === 1) {
output = `
return vec4(outputValue.x, outputValue.x, 0., 0.);
`;
}
else {
output = `
return vec4(outputValue.x);
`;
}
}
else if (broadcastDims.length) {
const rows = inRank - 2;
const cols = inRank - 1;
if (broadcastDims.indexOf(rows) > -1 && broadcastDims.indexOf(cols) > -1) {
output = `return vec4(outputValue.x);`;
}
else if (broadcastDims.indexOf(rows) > -1) {
output = `return vec4(outputValue.x, outputValue.y, ` +
`outputValue.x, outputValue.y);`;
}
else if (broadcastDims.indexOf(cols) > -1) {
output = `return vec4(outputValue.xx, outputValue.zz);`;
}
}
return `
vec4 ${funcName}() {
${type} coords = getOutputCoords();
${coordsSnippet}
vec4 outputValue = get${texFuncSnippet}(${unpackedCoordsSnippet});
${output}
}
`;
}
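// Unpacked counterpart of the sampler above. When the input texture layout matches
// the output exactly, it samples directly at resultUV instead of recomputing coordinates.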
function getSamplerAtOutputCoords(inputInfo, outShapeInfo) {
const texName = inputInfo.name;
const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
const funcName = 'get' + texFuncSnippet + 'AtOutCoords';
const outTexShape = outShapeInfo.texShape;
const inTexShape = inputInfo.shapeInfo.texShape;
const inRank = inputInfo.shapeInfo.logicalShape.length;
const outRank = outShapeInfo.logicalShape.length;
if (!inputInfo.shapeInfo.isUniform && inRank === outRank &&
inputInfo.shapeInfo.flatOffset == null &&
arraysEqual(inTexShape, outTexShape)) {
return `
float ${funcName}() {
return sampleTexture(${texName}, resultUV);
}
`;
}
const type = getCoordsDataType(outRank);
const broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
const rankDiff = outRank - inRank;
let coordsSnippet;
const fields = ['x', 'y', 'z', 'w', 'u', 'v'];
if (inRank === 0) {
coordsSnippet = '';
}
else if (outRank < 2 && broadcastDims.length >= 1) {
coordsSnippet = 'coords = 0;';
}
else {
coordsSnippet =
broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`)
.join('\n');
}
let unpackedCoordsSnippet = '';
if (outRank < 2 && inRank > 0) {
unpackedCoordsSnippet = 'coords';
}
else {
unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape
.map((s, i) => `coords.${fields[i + rankDiff]}`)
.join(', ');
}
return `
float ${funcName}() {
${type} coords = getOutputCoords();
${coordsSnippet}
return get${texFuncSnippet}(${unpackedCoordsSnippet});
}
`;
}
function getCoordsDataType(rank) {
if (rank <= 1) {
return 'int';
}
else if (rank === 2) {
return 'ivec2';
}
else if (rank === 3) {
return 'ivec3';
}
else if (rank === 4) {
return 'ivec4';
}
else if (rank === 5) {
return 'ivec5';
}
else if (rank === 6) {
return 'ivec6';
}
else {
throw Error(`GPU for rank ${rank} is not yet supported`);
}
}
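// Decides which shape is uploaded as the shape uniform: the squeezed shape when
// squeezing removes dimensions (unpacked inputs whose logical shape differs from the
// texture shape, or packed rank-3 inputs with a leading batch of 1), else the full shape.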
function getUniformInfoFromShape(isPacked, shape, texShape) {
const { newShape, keptDims } = squeezeShape(shape);
const rank = shape.length;
const useSqueezePackedShape = isPacked && rank === 3 && shape[0] === 1;
const squeezeShape$1 = useSqueezePackedShape ? shape.slice(1) : newShape;
const useSqueezeShape = (!isPacked && rank > 1 && !arraysEqual(shape, texShape) &&
newShape.length < rank) ||
useSqueezePackedShape;
const uniformShape = useSqueezeShape ? squeezeShape$1 : shape;
return { useSqueezeShape, uniformShape, keptDims };
}
function squeezeInputInfo(inInfo, squeezedShape) {
const newInputInfo = JSON.parse(JSON.stringify(inInfo));
newInputInfo.shapeInfo.logicalShape = squeezedShape;
return newInputInfo;
}
function getSqueezedParams(params, keptDims) {
return keptDims.map(d => params[d]).join(', ');
}
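// Compiles a GPGPU program: builds shape infos for the inputs and output, generates
// the fragment shader source, creates and links the WebGL program, and (unless
// ENGINE_COMPILE_ONLY is set) builds its VAO and resolves all uniform locations.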
function compileProgram(gpgpu, program, inputs, output) {
const inputInfos = inputs.map((input, i) => {
const shapeInfo = {
logicalShape: input.shape,
texShape: input.isUniform ? null : input.texData.texShape,
isUniform: input.isUniform,
isPacked: input.isUniform ? false : input.texData.isPacked,
flatOffset: null
};
if (input.texData != null && input.texData.slice != null &&
input.texData.slice.flatOffset > 0) {
shapeInfo.flatOffset = input.texData.slice.flatOffset;
}
return { name: program.variableNames[i], shapeInfo };
});
const inShapeInfos = inputInfos.map(x => x.shapeInfo);
const outShapeInfo = {
logicalShape: output.shape,
texShape: output.texData.texShape,
isUniform: false,
isPacked: output.texData.isPacked,
flatOffset: null
};
const source = makeShader(inputInfos, outShapeInfo, program);
const fragmentShader = createFragmentShader(gpgpu.gl, source);
const webGLProgram = gpgpu.createProgram(fragmentShader);
if (!env().get('ENGINE_COMPILE_ONLY')) {
gpgpu.buildVao(webGLProgram);
return Object.assign({ program,
fragmentShader,
source,
webGLProgram,
inShapeInfos,
outShapeInfo }, getUniformLocations(gpgpu, program, webGLProgram));
}
else {
return {
program,
fragmentShader,
source,
webGLProgram,
inShapeInfos,
outShapeInfo,
variablesLocations: null,
customUniformLocations: null,
infLoc: null,
nanLoc: null,
outShapeLocation: null,
outShapeStridesLocation: null,
outTexShapeLocation: null
};
}
}
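// Looks up uniform locations for NAN/INFINITY, each input variable (sampler, offset,
// and optional shape/texShape), the output shape uniforms, and any custom uniforms.
// Lookups never throw; missing uniforms simply resolve to null.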
function getUniformLocations(gpgpu, program, webGLProgram) {
const variablesLocations = [];
const customUniformLocations = [];
let outShapeLocation;
let outTexShapeLocation;
let outShapeStridesLocation;
let infLoc = null;
let nanLoc = null;
nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false);
if (env().getNumber('WEBGL_VERSION') === 1) {
infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false);
}
const shouldThrow = false;
for (const varName of program.variableNames) {
const varLocs = {
name: varName,
uniform: gpgpu.getUniformLocation(webGLProgram, varName, shouldThrow),
offset: gpgpu.getUniformLocation(webGLProgram, `offset${varName}`, shouldThrow),
};
if (program.enableShapeUniforms) {
varLocs.shape = gpgpu.getUniformLocation(webGLProgram, `${varName}Shape`, shouldThrow);
varLocs.texShape = gpgpu.getUniformLocation(webGLProgram, `${varName}TexShape`, shouldThrow);
}
variablesLocations.push(varLocs);
}
if (program.enableShapeUniforms) {
outShapeLocation =
gpgpu.getUniformLocation(webGLProgram, 'outShape', shouldThrow);
outShapeStridesLocation =
gpgpu.getUniformLocation(webGLProgram, 'outShapeStrides', shouldThrow);
outTexShapeLocation =
gpgpu.getUniformLocation(webGLProgram, 'outTexShape', shouldThrow);
}
if (program.customUniforms) {
for (const d of program.customUniforms) {
customUniformLocations.push(gpgpu.getUniformLocation(webGLProgram, d.name, shouldThrow));
}
}
return {
variablesLocations,
customUniformLocations,
infLoc,
nanLoc,
outShapeLocation,
outShapeStridesLocation,
outTexShapeLocation
};
}
function validateBinaryAndProgram(shapeInfos, inputs) {
if (shapeInfos.length !== inputs.length) {
throw Error(`Binary was compiled with ${shapeInfos.length} inputs, but ` +
`was executed with ${inputs.length} inputs`);
}
shapeInfos.forEach((s, i) => {
const shapeA = s.logicalShape;
const input = inputs[i];
const shapeB = input.shape;
if (!arraysEqual(shapeA, shapeB)) {
throw Error(`Binary was compiled with different shapes than ` +
`the current args. Shapes ${shapeA} and ${shapeB} must match`);
}
if (s.isUniform && input.isUniform) {
return;
}
const texShapeA = s.texShape;
const texShapeB = input.isUniform ? null : input.texData.texShape;
if (!arraysEqual(texShapeA, texShapeB)) {
throw Error(`Binary was compiled with different texture shapes than the` +
` current args. Shape ${texShapeA} and ${texShapeB} must match`);
}
});
}
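// Executes a compiled binary: binds the output texture as the render target, sets the
// program and VAO, uploads NaN/Infinity and per-input shape, texShape and slice-offset
// uniforms (or uniform-array values), binds the input textures, uploads output
// shape/stride and custom uniforms, then issues the draw call.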
function runProgram(gpgpu, binary, inputs, output, customUniformValues) {
if (!binary.program.enableShapeUniforms) {
validateBinaryAndProgram(binary.inShapeInfos, inputs);
validateBinaryAndProgram([binary.outShapeInfo], [output]);
}
const outTex = output.texData.texture;
const outTexShape = output.texData.texShape;
if (output.texData.isPacked) {
gpgpu.setOutputPackedMatrixTexture(outTex.texture, outTexShape[0], outTexShape[1]);
}
else {
gpgpu.setOutputMatrixTexture(outTex.texture, outTexShape[0], outTexShape[1]);
}
gpgpu.setProgram(binary.webGLProgram);
gpgpu.bindVertexArray(binary.webGLProgram.vao);
if (env().getNumber('WEBGL_VERSION') === 1) {
if (binary.infLoc !== null) {
gpgpu.gl.uniform1f(binary.infLoc, Infinity);
}
}
if (binary.nanLoc !== null) {
gpgpu.gl.uniform1f(binary.nanLoc, NaN);
}
for (let i = 0; i < inputs.length; ++i) {
const input = inputs[i];
const { uniform: varLoc, offset: varOffsetLoc, shape: varShapeLoc, texShape: varTexShapeLoc, } = binary.variablesLocations[i];
if (varShapeLoc) {
const { uniformShape } = getUniformInfoFromShape(binary.program.packedInputs, input.shape, input.texData.texShape);
switch (uniformShape.length) {
case 1:
gpgpu.gl.uniform1iv(varShapeLoc, new Int32Array(uniformShape));
break;
case 2:
gpgpu.gl.uniform2iv(varShapeLoc, new Int32Array(uniformShape));
break;
case 3:
gpgpu.gl.uniform3iv(varShapeLoc, new Int32Array(uniformShape));
break;
case 4:
gpgpu.gl.uniform4iv(varShapeLoc, new Int32Array(uniformShape));
break;
}
}
if (varTexShapeLoc) {
gpgpu.gl.uniform2i(varTexShapeLoc, input.texData.texShape[0], input.texData.texShape[1]);
}
if (varLoc == null) {
continue;
}
if (input.isUniform) {
if (sizeFromShape(input.shape) < 2) {
gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]);
}
else {
let vals = input.uniformValues;
if (!(vals instanceof Float32Array)) {
vals = new Float32Array(vals);
}
gpgpu.gl.uniform1fv(varLoc, vals);
}
continue;
}
if (input.texData.slice != null && varOffsetLoc != null) {
gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset);
}
gpgpu.setInputMatrixTexture(input.texData.texture.texture, varLoc, i);
}
const outShapeLoc = binary.outShapeLocation;
if (outShapeLoc) {
switch (output.shape.length) {
case 1:
gpgpu.gl.uniform1iv(outShapeLoc, new Int32Array(output.shape));
break;
case 2:
gpgpu.gl.uniform2iv(outShapeLoc, new Int32Array(output.shape));
break;
case 3:
gpgpu.gl.uniform3iv(outShapeLoc, new Int32Array(output.shape));
break;
case 4:
gpgpu.gl.uniform4iv(outShapeLoc, new Int32Array(output.shape));
break;
}
}
if (binary.outShapeStridesLocation) {
const strides = computeStrides(output.shape);
switch (output.shape.length) {
case 2:
gpgpu.gl.uniform1iv(binary.outShapeStridesLocation, new Int32Array(strides));
break;
case 3:
gpgpu.gl.uniform2iv(binary.outShapeStridesLocation, new Int32Array(strides));
break;
case 4:
gpgpu.gl.uniform3iv(binary.outShapeStridesLocation, new Int32Array(strides));
break;
}
}
if (binary.outTexShapeLocation) {
gpgpu.gl.uniform2i(binary.outTexShapeLocation, output.texData.texShape[0], output.texData.texShape[1]);
}
if (binary.program.customUniforms && customUniformValues) {
for (let i = 0; i < binary.program.customUniforms.length; ++i) {
const d = binary.program.customUniforms[i];
const customLoc = binary.customUniformLocations[i];
const customValue = customUniformValues[i];
if (d.type === 'float') {
gpgpu.gl.uniform1fv(customLoc, customValue);
}
else if (d.type === 'vec2') {
gpgpu.gl.uniform2fv(customLoc, customValue);
}
else if (d.type === 'vec3') {
gpgpu.gl.uniform3fv(customLoc, customValue);
}
else if (d.type === 'vec4') {
gpgpu.gl.uniform4fv(customLoc, customValue);
}
else if (d.type === 'int') {
gpgpu.gl.uniform1iv(customLoc, customValue);
}
else if (d.type === 'ivec2') {
gpgpu.gl.uniform2iv(customLoc, customValue);
}
else if (d.type === 'ivec3') {
gpgpu.gl.uniform3iv(customLoc, customValue);
}
else if (d.type === 'ivec4') {
gpgpu.gl.uniform4iv(customLoc, customValue);
}
else {
throw Error(`uniform type ${d.type} is not supported yet.`);
}
}
}
gpgpu.executeProgram();
}
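// Builds the cache key for a compiled shader from the program name, user code and
// WebGL version plus per-input shape and texture-layout characteristics; with shape
// uniforms enabled the key captures layout traits rather than concrete shapes so the
// same binary can be reused across shapes.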
function makeShaderKey(program, inputs, output) {
let keyInputs = '';
inputs.concat(output).forEach(x => {
const hasOffset = x.texData != null && x.texData.slice != null &&
x.texData.slice.flatOffset > 0;
if (program.enableShapeUniforms && !x.isUniform) {
const xTexShape = x.texData.texShape;
const { useSqueezeShape, uniformShape, keptDims } = getUniformInfoFromShape(program.packedInputs, x.shape, xTexShape);
let rank1 = '', rank2 = '', rank34 = '';
if (uniformShape.length === 1 && program.packedInputs) {
const packedTexShape = [Math.ceil(xTexShape[0] / 2), Math.ceil(xTexShape[1] / 2)];
rank1 = `${packedTexShape[0] > 1}_${packedTexShape[1] > 1}`;
}
else if (uniformShape.length === 2 && !program.packedInputs) {
rank2 = `${uniformShape[0] > 1}_${uniformShape[1] > 1}`;
}
else if (uniformShape.length > 2 && !program.packedInputs) {
const strides = computeStrides(uniformShape);
rank34 = `${strides[0] === xTexShape[1]}_${strides[strides.length - 1] === xTexShape[1]}`;
}
const xRank = x.shape.length;
const isLogicalShapTexShapeEqual = uniformShape.length === 2 && arraysEqual(x.shape, xTexShape);
const isScalar = sizeFromShape(x.shape) === 1;
const broadcastDims = getBroadcastDims$1(x.shape, output.shape);
const isInOutTexShapeEqual = !program.packedInputs &&
xRank === output.shape.length &&
arraysEqual(xTexShape, output.texData.texShape);
const isTexShapeGreaterThanOne = program.packedInputs || uniformShape.length > 2 ?
'' :
`${xTexShape[0] > 1}_${xTexShape[1] > 1}`;
keyInputs += `${xRank}_${isInOutTexShapeEqual}_${useSqueezeShape ? keptDims : ''}_${uniformShape.length}_${isScalar}_${broadcastDims}_${isLogicalShapTexShapeEqual}_${rank1}_${rank2}_${rank34}_${isTexShapeGreaterThanOne}_${hasOffset}`;
}
else {
const texShape = x.isUniform ? 'uniform' : x.texData.texShape;
keyInputs += `${x.shape}_${texShape}_${hasOffset}`;
}
});
const keyUserCode = program.userCode;
let key = program.constructor.name;
key += '_' + keyInputs + '_' + keyUserCode +
`${env().getNumber('WEBGL_VERSION')}`;
return key;
}
function useShapeUniforms(rank) {
return env().getBool('WEBGL_USE_SHAPES_UNIFORMS') && rank <= 4;
}
class DecodeMatrixProgram {
constructor(outputShape) {
this.variableNames = ['A'];
this.packedInputs = false;
this.packedOutput = true;
this.outPackingScheme = PackingScheme.DENSE;
this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
const glsl = getGlslDifferences();
this.outputShape = outputShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
this.userCode = `
ivec3 outCoordsFromFlatIndex(int index) {
${this.enableShapeUniforms ?
getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], outputShape) :
getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape)}
return ivec3(r, c, d);
}
void main() {
ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1]));
int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y);
vec4 result = vec4(0.);
for (int i=0; i<4; i++) {
int flatIndex = index + i;
ivec3 rc = outCoordsFromFlatIndex(flatIndex);
result[i] = getA(rc.x, rc.y, rc.z);
}
${glsl.output} = result;
}
`;
}
}
class DecodeMatrixPackedProgram {
constructor(outputShape) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
this.outPackingScheme = PackingScheme.DENSE;
this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
const glsl = getGlslDifferences();
this.outputShape = outputShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
this.userCode = `
ivec3 outCoordsFromFlatIndex(int index) {
${this.enableShapeUniforms ?
getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], outputShape) :
getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape)}
return ivec3(r, c, d);
}
void main() {
ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1]));
int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y);
vec4 result = vec4(0.);
for (int i=0; i<4; i++) {
int flatIndex = index + i;
ivec3 rc = outCoordsFromFlatIndex(flatIndex);
result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));
}
${glsl.output} = result;
}
`;
}
}
class EncodeFloatProgram {
constructor(outputShape) {
this.variableNames = ['A'];
this.outTexUsage = TextureUsage.DOWNLOAD;
const glsl = getGlslDifferences();
this.outputShape = outputShape;
this.userCode = `
${ENCODE_FLOAT_SNIPPET}
void main() {
float x = getAAtOutCoords();
${glsl.output} = encode_float(x);
}
`;
}
}
class EncodeFloatPackedProgram {
constructor(outputShape) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = false;
this.outTexUsage = TextureUsage.DOWNLOAD;
const glsl = getGlslDifferences();
this.outputShape = outputShape;
this.userCode = `
${ENCODE_FLOAT_SNIPPET}
void main() {
ivec3 coords = getOutputCoords();
float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));
${glsl.output} = encode_float(x);
}
`;
}
}
const CHANNEL_CHAR_TO_INDEX_MAP = {
'R': 0,
'G': 1,
'B': 2,
'A': 3
};
class EncodeMatrixProgram {
constructor(outputShape, inputIsUnsignedByte = false, usedChannels = 'RGBA') {
this.variableNames = ['A'];
this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
const glsl = getGlslDifferences();
this.outputShape = outputShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
let output = `result`;
if (inputIsUnsignedByte) {
output = `floor(result * 255. + 0.5)`;
}
let mainLoop = '';
for (let usedChannelIndex = 0; usedChannelIndex < usedChannels.length; usedChannelIndex++) {
const curChannel = usedChannels[usedChannelIndex];
mainLoop += `
if(offset == ${usedChannelIndex}) {
result = values[${CHANNEL_CHAR_TO_INDEX_MAP[curChannel]}];
}`;
}
this.userCode = `
${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() :
getFlatIndexFrom3D(outputShape)}
void main() {
ivec3 coords = getOutputCoords();
int flatIndex = getFlatIndex(coords);
float result = 0.;
int offset = imod(flatIndex, ${usedChannels.length});
flatIndex = idiv(flatIndex, ${usedChannels.length}, 1.);
int r = flatIndex / texShape[1];
if (r < texShape[0]) {
int c = imod(flatIndex, texShape[1]);
vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]);
vec4 values = ${glsl.texture2D}(A, uv);
${mainLoop}
}
${glsl.output} = vec4(${output}, 0., 0., 0.);
}
`;
}
}
class EncodeMatrixPackedProgram {
constructor(outputShape, inputIsUnsignedByte = false) {
this.variableNames = ['A'];
this.packedInputs = false;
this.packedOutput = true;
this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
const glsl = getGlslDifferences();
this.outputShape = outputShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
let mainLoop = '';
let output = 'result';
if (inputIsUnsignedByte) {
output = 'floor(result * 255. + 0.5)';
}
for (let row = 0; row <= 1; row++) {
for (let col = 0; col <= 1; col++) {
const channel = row * 2 + col;
mainLoop += `
localCoords = coords;
if(localCoords[2] + ${col} < ${this.enableShapeUniforms ? 'outShape[2]' : `${outputShape[2]}`}) {
localCoords[2] += ${col};
if (localCoords[1] + ${row} < ${this.enableShapeUniforms ? 'outShape[1]' : `${outputShape[1]}`}) {
localCoords[1] += ${row};
flatIndex = getFlatIndex(localCoords);
offset = imod(flatIndex, 4);
flatIndex = idiv(flatIndex, 4, 1.);
int r = flatIndex / texShape[1];
int c = imod(flatIndex, texShape[1]);
vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]);
values = ${glsl.texture2D}(A, uv);
if (offset == 0) {
result[${channel}] = values[0];
} else if (offset == 1) {
result[${channel}] = values[1];
} else if (offset == 2) {
result[${channel}] = values[2];
} else {
result[${channel}] = values[3];
}
}
}
`;
}
}
this.userCode = `
${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() :
getFlatIndexFrom3D(outputShape)}
void main() {
ivec3 coords = getOutputCoords();
vec4 result = vec4(0.);
int flatIndex, r, c, offset;
ivec3 localCoords;
vec2 uv;
vec4 values;
${mainLoop}
${glsl.output} = ${output};
}
`;
}
}
function createVertexShader(gl) {
const glsl = getGlslDifferences();
const vertexShaderSource = `${glsl.version}
precision highp float;
${glsl.attribute} vec3 clipSpacePos;
${glsl.attribute} vec2 uv;
${glsl.varyingVs} vec2 resultUV;
void main() {
gl_Position = vec4(clipSpacePos, 1);
resultUV = uv;
}`;
return createVertexShader$1(gl, vertexShaderSource);
}
function createVertexBuffer(gl) {
const vertexArray = new Float32Array([-1, 1, 0, 0, 1, -1, -1, 0, 0, 0, 1, 1, 0, 1, 1, 1, -1, 0, 1, 0]);
return createStaticVertexBuffer(gl, vertexArray);
}
function createIndexBuffer(gl) {
const triangleVertexIndices = new Uint16Array([0, 1, 2, 2, 1, 3]);
return createStaticIndexBuffer(gl, triangleVertexIndices);
}
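// Allocates a 2D texture with clamp-to-edge wrapping and nearest filtering, using
// texImage2D on WebGL1 and immutable texStorage2D on WebGL2.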
function createAndConfigureTexture(gl, width, height, internalFormat, textureFormat, textureType) {
validateTextureSize(width, height);
const texture = createTexture(gl);
const tex2d = gl.TEXTURE_2D;
callAndCheck(gl, () => gl.bindTexture(tex2d, texture));
callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE));
callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE));
callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST));
callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST));
if (env().getNumber('WEBGL_VERSION') === 1) {
callAndCheck(gl, () => gl.texImage2D(tex2d, 0, internalFormat, width, height, 0, textureFormat, textureType, null));
}
else {
callAndCheck(gl, () => gl
.texStorage2D(tex2d, 1, internalFormat, width, height));
}
callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
return { texture, texShape: [height, width] };
}
function getInternalFormatForFloat32MatrixTexture(textureConfig) {
return textureConfig.internalFormatFloat;
}
function createFloat32MatrixTexture(gl, rows, columns, textureConfig) {
const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat32MatrixTexture(textureConfig), textureConfig.textureFormatFloat, gl.FLOAT);
}
function getInternalFormatForFloat16MatrixTexture(textureConfig) {
return textureConfig.internalFormatHalfFloat;
}
function createFloat16MatrixTexture(gl, rows, columns, textureConfig) {
const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16MatrixTexture(textureConfig), textureConfig.textureFormatFloat, textureConfig.textureTypeHalfFloat);
}
function getInternalFormatForUnsignedBytesMatrixTexture(textureConfig) {
return textureConfig.downloadTextureFormat;
}
function createUnsignedBytesMatrixTexture(gl, rows, columns, textureConfig) {
const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
return createAndConfigureTexture(gl, width, height, getInternalFormatForUnsignedBytesMatrixTexture(textureConfig), gl.RGBA, gl.UNSIGNED_BYTE);
}
function getInternalFormatForPackedMatrixTexture(textureConfig) {
return textureConfig.internalFormatPackedFloat;
}
function createPackedMatrixTexture(gl, rows, columns, textureConfig) {
const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
return createAndConfigureTexture(gl, width, height, getInternalFormatForPackedMatrixTexture(textureConfig), gl.RGBA, gl.FLOAT);
}
function getInternalFormatForFloat16PackedMatrixTexture(textureConfig) {
return textureConfig.internalFormatPackedHalfFloat;
}
function createFloat16PackedMatrixTexture(gl, rows, columns, textureConfig) {
const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16PackedMatrixTexture(textureConfig), gl.RGBA, textureConfig.textureTypeHalfFloat);
}
function bindVertexProgramAttributeStreams(gl, program, vertexBuffer) {
const posOffset = 0;
const uvOffset = 3 * 4;
const stride = (3 * 4) + (2 * 4);
callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer));
const success = bindVertexBufferToProgramAttribute(gl, program, 'clipSpacePos', vertexBuffer, 3, stride, posOffset);
return success &&
bindVertexBufferToProgramAttribute(gl, program, 'uv', vertexBuffer, 2, stride, uvOffset);
}
function uploadDenseMatrixToTexture(gl, texture, width, height, data, textureConfig) {
callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
let dataForUpload, texelDataType, internalFormat;
if (data instanceof Uint8Array) {
dataForUpload = new Uint8Array(width * height * 4);
texelDataType = gl.UNSIGNED_BYTE;
internalFormat = gl.RGBA;
}
else {
dataForUpload = new Float32Array(width * height * 4);
texelDataType = gl.FLOAT;
internalFormat = textureConfig.internalFormatPackedFloat;
}
dataForUpload.set(data);
if (env().getNumber('WEBGL_VERSION') === 2) {
callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, width, height, gl.RGBA, texelDataType, dataForUpload));
}
else {
callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA, texelDataType, dataForUpload));
}
callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
}
function uploadPixelDataToTexture(gl, texture, pixels) {
callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
if (pixels.data instanceof Uint8Array) {
if (env().getNumber('WEBGL_VERSION') === 2) {
callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, pixels.width, pixels.height, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data));
}
else {
callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data));
}
}
else {
if (env().getNumber('WEBGL_VERSION') === 2) {
callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels));
}
else {
callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, pixels));
}
}
callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
}
function createBufferFromOutputTexture(gl2, rows, columns, textureConfig) {
const buffer = gl2.createBuffer();
callAndCheck(gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer));
const bytesPerFloat = 4;
const valuesPerTexel = 4;
const bufferSizeBytes = bytesPerFloat * valuesPerTexel * rows * columns;
callAndCheck(gl2, () => gl2.bufferData(gl2.PIXEL_PACK_BUFFER, bufferSizeBytes, gl2.STREAM_READ));
callAndCheck(gl2, () => gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0));
callAndCheck(gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null));
return buffer;
}
function downloadFloat32MatrixFromBuffer(gl, buffer, size) {
const gl2 = gl;
const downloadTarget = new Float32Array(size);
gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);
gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget);
gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);
return downloadTarget;
}
function downloadByteEncodedFloatMatrixFromOutputTexture(gl, rows, columns, textureConfig) {
const [w, h] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
const numChannels = 4;
const downloadTarget = new Uint8Array(getUnpackedArraySizeFromMatrixSize(rows * columns, numChannels));
callAndCheck(gl, () => gl.readPixels(0, 0, w, h, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE, downloadTarget));
return new Float32Array(downloadTarget.buffer);
}
function downloadPackedMatrixFromBuffer(gl, buffer, batch, rows, cols, physicalRows, physicalCols, textureConfig) {
const gl2 = gl;
const downloadTarget = new Float32Array(getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols));
gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);
gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget);
gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);
return downloadTarget;
}
function downloadMatrixFromPackedOutputTexture(gl, physicalRows, physicalCols) {
const packedRGBA = new Float32Array(physicalRows * physicalCols * 4);
callAndCheck(gl, () => gl.readPixels(0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, packedRGBA));
return packedRGBA;
}
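// Thin wrapper around a WebGL1/2 context: sets up the required extensions,
// vertex/index buffers, framebuffer and VAO helpers, and exposes texture creation,
// upload and download, program compilation and execution, fences and disjoint
// timer queries.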
class GPGPUContext {
constructor(gl) {
this.outputTexture = null;
this.program = null;
this.disposed = false;
this.itemsToPoll = [];
const glVersion = env().getNumber('WEBGL_VERSION');
if (gl != null) {
this.gl = gl;
setWebGLContext(glVersion, gl);
}
else {
this.gl = getWebGLContext(glVersion);
}
gl = this.gl;
if (env().getNumber('WEBGL_VERSION') === 2) {
const gl2 = gl;
this.createVertexArray = () => {
return callAndCheck(gl2, () => gl2.createVertexArray());
};
this.bindVertexArray = (vao) => {
return callAndCheck(gl2, () => gl2.bindVertexArray(vao));
};
this.deleteVertexArray = (vao) => {
return callAndCheck(gl2, () => gl2.deleteVertexArray(vao));
};
this.getVertexArray = () => {
return callAndCheck(gl2, () => gl2.getParameter(gl2.VERTEX_ARRAY_BINDING));
};
}
else if (gl != null) {
const ext = gl.getExtension('OES_vertex_array_object');
if (ext == null) {
throw new Error('All WebGL1 implementations are expected to offer' +
' OES_vertex_array_object.');
}
this.createVertexArray = () => {
return callAndCheck(gl, () => ext.createVertexArrayOES());
};
this.bindVertexArray = (vao) => {
return callAndCheck(gl, () => ext.bindVertexArrayOES(vao));
};
this.deleteVertexArray = (vao) => {
return callAndCheck(gl, () => ext.deleteVertexArrayOES(vao));
};
this.getVertexArray = () => {
return callAndCheck(gl, () => gl.getParameter(ext.VERTEX_ARRAY_BINDING_OES));
};
}
let COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float';
const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
this.parallelCompilationExtension =
this.gl.getExtension('KHR_parallel_shader_compile');
if (env().getNumber('WEBGL_VERSION') === 1) {
const TEXTURE_FLOAT = 'OES_texture_float';
const TEXTURE_HALF_FLOAT = 'OES_texture_half_float';
this.textureFloatExtension =
getExtensionOrThrow(this.gl, TEXTURE_FLOAT);
if (hasExtension(this.gl, TEXTURE_HALF_FLOAT)) {
this.textureHalfFloatExtension =
getExtensionOrThrow(this.gl, TEXTURE_HALF_FLOAT);
}
else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {
throw new Error('GL context does not support half float textures, yet the ' +
'environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
}
this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT);
if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
this.colorBufferHalfFloatExtension =
getExtensionOrThrow(this.gl, COLOR_BUFFER_HALF_FLOAT);
}
else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {
throw new Error('GL context does not support color renderable half floats, yet ' +
'the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
}
}
else {
COLOR_BUFFER_FLOAT = 'EXT_color_buffer_float';
if (hasExtension(this.gl, COLOR_BUFFER_FLOAT)) {
this.colorBufferFloatExtension =
this.gl.getExtension(COLOR_BUFFER_FLOAT);
}
else if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
this.colorBufferHalfFloatExtension =
this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
}
else {
throw new Error('GL context does not support color renderable floats');
}
}
this.vertexBuffer = createVertexBuffer(this.gl);
this.indexBuffer = createIndexBuffer(this.gl);
this.framebuffer = createFramebuffer(this.gl);
this.textureConfig =
getTextureConfig(this.gl, this.textureHalfFloatExtension);
}
get debug() {
return env().getBool('DEBUG');
}
dispose() {
if (this.disposed) {
return;
}
if (this.program != null) {
console.warn('Disposing a GPGPUContext that still has a bound WebGLProgram.' +
' This is probably a resource leak, delete the program with ' +
'GPGPUContext.deleteProgram before disposing.');
}
if (this.outputTexture != null) {
console.warn('Disposing a GPGPUContext that still has a bound output matrix ' +
'texture. This is probably a resource leak, delete the output ' +
'matrix texture with GPGPUContext.deleteMatrixTexture before ' +
'disposing.');
}
const gl = this.gl;
callAndCheck(gl, () => gl.finish());
callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null));
callAndCheck(gl, () => gl.deleteFramebuffer(this.framebuffer));
callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, null));
callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null));
callAndCheck(gl, () => gl.deleteBuffer(this.indexBuffer));
this.disposed = true;
}
createFloat32MatrixTexture(rows, columns) {
this.throwIfDisposed();
return createFloat32MatrixTexture(this.gl, rows, columns, this.textureConfig);
}
createFloat16MatrixTexture(rows, columns) {
this.throwIfDisposed();
return createFloat16MatrixTexture(this.gl, rows, columns, this.textureConfig);
}
createUnsignedBytesMatrixTexture(rows, columns) {
this.throwIfDisposed();
return createUnsignedBytesMatrixTexture(this.gl, rows, columns, this.textureConfig);
}
uploadPixelDataToTexture(texture, pixels) {
this.throwIfDisposed();
uploadPixelDataToTexture(this.gl, texture, pixels);
}
uploadDenseMatrixToTexture(texture, width, height, data) {
this.throwIfDisposed();
uploadDenseMatrixToTexture(this.gl, texture, width, height, data, this.textureConfig);
}
createFloat16PackedMatrixTexture(rows, columns) {
this.throwIfDisposed();
return createFloat16PackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
}
createPackedMatrixTexture(rows, columns) {
this.throwIfDisposed();
return createPackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
}
deleteMatrixTexture(texture) {
this.throwIfDisposed();
if (this.outputTexture === texture) {
unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
this.outputTexture = null;
}
callAndCheck(this.gl, () => this.gl.deleteTexture(texture));
}
downloadByteEncodedFloatMatrixFromOutputTexture(texture, rows, columns) {
return this.downloadMatrixDriver(texture, () => downloadByteEncodedFloatMatrixFromOutputTexture(this.gl, rows, columns, this.textureConfig));
}
downloadPackedMatrixFromBuffer(buffer, batch, rows, columns, physicalRows, physicalCols) {
return downloadPackedMatrixFromBuffer(this.gl, buffer, batch, rows, columns, physicalRows, physicalCols);
}
downloadFloat32MatrixFromBuffer(buffer, size) {
return downloadFloat32MatrixFromBuffer(this.gl, buffer, size);
}
createBufferFromTexture(texture, rows, columns) {
this.bindTextureToFrameBuffer(texture);
const result = createBufferFromOutputTexture(this.gl, rows, columns);
this.unbindTextureToFrameBuffer();
return result;
}
createAndWaitForFence() {
const fenceContext = this.createFence(this.gl);
return this.pollFence(fenceContext);
}
createFence(gl) {
let query;
let isFencePassed;
if (env().getBool('WEBGL_FENCE_API_ENABLED')) {
const gl2 = gl;
const sync = gl2.fenceSync(gl2.SYNC_GPU_COMMANDS_COMPLETE, 0);
gl.flush();
isFencePassed = () => {
const status = gl2.clientWaitSync(sync, 0, 0);
return status === gl2.ALREADY_SIGNALED ||
status === gl2.CONDITION_SATISFIED;
};
query = sync;
}
else if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) {
query = this.beginQuery();
this.endQuery();
isFencePassed = () => this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));
}
else {
isFencePassed = () => true;
}
return { query, isFencePassed };
}
downloadMatrixFromPackedTexture(texture, physicalRows, physicalCols) {
return this.downloadMatrixDriver(texture, () => downloadMatrixFromPackedOutputTexture(this.gl, physicalRows, physicalCols));
}
createProgram(fragmentShader) {
this.throwIfDisposed();
const gl = this.gl;
if (this.vertexShader == null) {
this.vertexShader = createVertexShader(gl);
}
const program = createProgram(gl);
callAndCheck(gl, () => gl.attachShader(program, this.vertexShader));
callAndCheck(gl, () => gl.attachShader(program, fragmentShader));
linkProgram(gl, program);
const program2 = Object.assign(program, { vao: this.createVertexArray() });
if (this.debug) {
validateProgram(gl, program2);
}
return program2;
}
buildVao(program) {
this.setProgram(program);
this.bindVertexArray(program.vao);
const gl = this.gl;
callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, this.indexBuffer));
bindVertexProgramAttributeStreams(gl, program, this.vertexBuffer);
}
deleteProgram(program) {
this.throwIfDisposed();
if (program === this.program) {
this.program = null;
}
if (program != null) {
callAndCheck(this.gl, () => this.gl.deleteProgram(program));
this.deleteVertexArray(program.vao);
}
}
setProgram(program) {
this.throwIfDisposed();
this.program = program;
if (this.program != null) {
if (this.debug) {
validateProgram(this.gl, this.program);
}
}
callAndCheck(this.gl, () => this.gl.useProgram(program));
}
getUniformLocation(program, uniformName, shouldThrow = true) {
this.throwIfDisposed();
if (shouldThrow) {
return getProgramUniformLocationOrThrow(this.gl, program, uniformName);
}
else {
return getProgramUniformLocation(this.gl, program, uniformName);
}
}
getAttributeLocation(program, attribute) {
this.throwIfDisposed();
return callAndCheck(this.gl, () => this.gl.getAttribLocation(program, attribute));
}
getUniformLocationNoThrow(program, uniformName) {
this.throwIfDisposed();
return this.gl.getUniformLocation(program, uniformName);
}
setInputMatrixTexture(inputMatrixTexture, uniformLocation, textureUnit) {
this.throwIfDisposed();
this.throwIfNoProgram();
bindTextureToProgramUniformSampler(this.gl, inputMatrixTexture, uniformLocation, textureUnit);
}
setOutputMatrixTexture(outputMatrixTexture, rows, columns) {
this.setOutputMatrixTextureDriver(outputMatrixTexture, columns, rows);
}
setOutputPackedMatrixTexture(outputPackedMatrixTexture, rows, columns) {
this.throwIfDisposed();
const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
this.setOutputMatrixTextureDriver(outputPackedMatrixTexture, width, height);
}
setOutputMatrixWriteRegion(startRow, numRows, startColumn, numColumns) {
this.setOutputMatrixWriteRegionDriver(startColumn, startRow, numColumns, numRows);
}
setOutputPackedMatrixWriteRegion(startRow, numRows, startColumn, numColumns) {
throw new Error('setOutputPackedMatrixWriteRegion not implemented.');
}
debugValidate() {
if (this.program != null) {
validateProgram(this.gl, this.program);
}
validateFramebuffer(this.gl);
}
executeProgram() {
this.throwIfDisposed();
this.throwIfNoProgram();
const gl = this.gl;
if (this.debug) {
const boundVao = this.getVertexArray();
console.assert(boundVao === this.program.vao, 'VAO changed between setProgram and executeProgram!');
this.debugValidate();
}
callAndCheck(gl, () => gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0));
}
blockUntilAllProgramsCompleted() {
this.throwIfDisposed();
callAndCheck(this.gl, () => this.gl.finish());
}
getQueryTimerExtension() {
if (this.disjointQueryTimerExtension == null) {
this.disjointQueryTimerExtension =
getExtensionOrThrow(this.gl, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ?
'EXT_disjoint_timer_query_webgl2' :
'EXT_disjoint_timer_query');
}
return this.disjointQueryTimerExtension;
}
getQueryTimerExtensionWebGL2() {
return this.getQueryTimerExtension();
}
getQueryTimerExtensionWebGL1() {
return this.getQueryTimerExtension();
}
beginQuery() {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
const gl2 = this.gl;
const ext = this.getQueryTimerExtensionWebGL2();
const query = gl2.createQuery();
gl2.beginQuery(ext.TIME_ELAPSED_EXT, query);
return query;
}
const ext = this.getQueryTimerExtensionWebGL1();
const query = ext.createQueryEXT();
ext.beginQueryEXT(ext.TIME_ELAPSED_EXT, query);
return query;
}
endQuery() {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
const gl2 = this.gl;
const ext = this.getQueryTimerExtensionWebGL2();
gl2.endQuery(ext.TIME_ELAPSED_EXT);
return;
}
const ext = this.getQueryTimerExtensionWebGL1();
ext.endQueryEXT(ext.TIME_ELAPSED_EXT);
}
async waitForQueryAndGetTime(query) {
await repeatedTry(() => this.disposed ||
this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')));
return this.getQueryTime(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));
}
getQueryTime(query, queryTimerVersion) {
if (queryTimerVersion === 0) {
return null;
}
if (queryTimerVersion === 2) {
const gl2 = this.gl;
const timeElapsedNanos = gl2.getQueryParameter(query, gl2.QUERY_RESULT);
return timeElapsedNanos / 1000000;
}
else {
const ext = this.getQueryTimerExtensionWebGL1();
const timeElapsedNanos = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT);
return timeElapsedNanos / 1000000;
}
}
isQueryAvailable(query, queryTimerVersion) {
if (queryTimerVersion === 0) {
return true;
}
if (queryTimerVersion === 2) {
const gl2 = this.gl;
const ext = this.getQueryTimerExtensionWebGL2();
const available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE);
if (this.disjoint == null) {
this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
}
return available && !this.disjoint;
}
else {
const ext = this.getQueryTimerExtensionWebGL1();
const available = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_AVAILABLE_EXT);
if (this.disjoint == null) {
this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
}
return available && !this.disjoint;
}
}
pollFence(fenceContext) {
return new Promise(resolve => {
this.addItemToPoll(() => fenceContext.isFencePassed(), () => resolve());
});
}
pollItems() {
const index = linearSearchLastTrue(this.itemsToPoll.map(x => x.isDoneFn));
for (let i = 0; i <= index; ++i) {
const { resolveFn } = this.itemsToPoll[i];
resolveFn();
}
this.itemsToPoll = this.itemsToPoll.slice(index + 1);
}
addItemToPoll(isDoneFn, resolveFn) {
this.itemsToPoll.push({ isDoneFn, resolveFn });
if (this.itemsToPoll.length > 1) {
return;
}
let scheduleFn = undefined;
if ('setTimeoutCustom' in env().platform) {
scheduleFn = env().platform.setTimeoutCustom.bind(env().platform);
}
repeatedTry(() => {
this.pollItems();
return this.itemsToPoll.length === 0;
}, () => 0, null, scheduleFn);
}
bindTextureToFrameBuffer(texture) {
this.throwIfDisposed();
bindColorTextureToFramebuffer(this.gl, texture, this.framebuffer);
if (this.debug) {
validateFramebuffer(this.gl);
}
}
unbindTextureToFrameBuffer() {
if (this.outputTexture != null) {
bindColorTextureToFramebuffer(this.gl, this.outputTexture, this.framebuffer);
if (this.debug) {
validateFramebuffer(this.gl);
}
}
else {
unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
}
}
downloadMatrixDriver(texture, downloadAndDecode) {
this.bindTextureToFrameBuffer(texture);
const result = downloadAndDecode();
this.unbindTextureToFrameBuffer();
return result;
}
setOutputMatrixTextureDriver(outputMatrixTextureMaybePacked, width, height) {
this.throwIfDisposed();
const gl = this.gl;
bindColorTextureToFramebuffer(gl, outputMatrixTextureMaybePacked, this.framebuffer);
if (this.debug) {
validateFramebuffer(gl);
}
this.outputTexture = outputMatrixTextureMaybePacked;
callAndCheck(gl, () => gl.viewport(0, 0, width, height));
callAndCheck(gl, () => gl.scissor(0, 0, width, height));
}
setOutputMatrixWriteRegionDriver(x, y, width, height) {
this.throwIfDisposed();
callAndCheck(this.gl, () => this.gl.scissor(x, y, width, height));
}
throwIfDisposed() {
if (this.disposed) {
throw new Error('Attempted to use disposed GPGPUContext.');
}
}
throwIfNoProgram() {
if (this.program == null) {
throw new Error('No GPU program is currently set.');
}
}
}
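// Returns the index of the last element in a contiguous prefix of predicates that
// evaluate to true (-1 if the first predicate is already false). Used by pollItems()
// to resolve completed fence/query polls in submission order.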
function linearSearchLastTrue(arr) {
let i = 0;
for (; i < arr.length; ++i) {
const isDone = arr[i]();
if (!isDone) {
break;
}
}
return i - 1;
}
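// Guard used by the CPU kernels below: throws if any of the given tensor infos has
// dtype 'complex64', since these kernels only handle real-valued data.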
function assertNotComplex(tensor, opName) {
if (!Array.isArray(tensor)) {
tensor = [tensor];
}
tensor.forEach(t => {
if (t != null) {
assert$1(t.dtype !== 'complex64', () => `${opName} does not support complex64 tensors in the CPU backend.`);
}
});
}
function simpleAbsImpl(vals) {
const resultValues = new Float32Array(vals.length);
for (let i = 0; i < vals.length; ++i) {
resultValues[i] = Math.abs(vals[i]);
}
return resultValues;
}
const abs$1 = (args) => {
const { x } = args.inputs;
const cpuBackend = args.backend;
assertNotComplex(x, 'abs');
const values = cpuBackend.data.get(x.dataId).values;
const resultValues = simpleAbsImpl(values);
return cpuBackend.makeOutput(resultValues, x.shape, x.dtype);
};
const absConfig$1 = {
kernelName: Abs,
backendName: 'cpu',
kernelFunc: abs$1,
};
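// Turns an element-wise op (a, b) => result into a broadcasting binary kernel
// implementation returning [resultValues, resultShape]. The fast path runs when
// neither input needs broadcasting; otherwise every output index is mapped back to
// the (possibly zeroed) input coordinates. For example, addImpl below called as
// addImpl([2], [1], [1, 2], [10], 'float32') returns [Float32Array [11, 12], [2]].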
function createSimpleBinaryKernelImpl(op) {
return (aShape, bShape, aVals, bVals, dtype) => {
const newShape = assertAndGetBroadcastShape(aShape, bShape);
const resultRank = newShape.length;
const resultStrides = computeStrides(newShape);
const resultSize = sizeFromShape(newShape);
const result = getTypedArrayFromDType(dtype, resultSize);
const aRank = aShape.length;
const bRank = bShape.length;
const aStrides = computeStrides(aShape);
const bStrides = computeStrides(bShape);
const aBroadcastDims = getBroadcastDims$1(aShape, newShape);
const bBroadcastDims = getBroadcastDims$1(bShape, newShape);
if (aBroadcastDims.length + bBroadcastDims.length === 0) {
for (let i = 0; i < result.length; ++i) {
result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
}
}
else {
for (let i = 0; i < result.length; ++i) {
const loc = indexToLoc(i, resultRank, resultStrides);
const aLoc = loc.slice(-aRank);
aBroadcastDims.forEach(d => aLoc[d] = 0);
const aIndex = locToIndex(aLoc, aRank, aStrides);
const bLoc = loc.slice(-bRank);
bBroadcastDims.forEach(d => bLoc[d] = 0);
const bIndex = locToIndex(bLoc, bRank, bStrides);
result[i] = op(aVals[aIndex], bVals[bIndex]);
}
}
return [result, newShape];
};
}
function complex$1(args) {
const { inputs, backend } = args;
const { real, imag } = inputs;
const realVals = backend.data.get(real.dataId).values;
const imagVals = backend.data.get(imag.dataId).values;
const complexInfo = backend.makeTensorInfo(real.shape, 'complex64');
const complex = backend.data.get(complexInfo.dataId);
complex.complexTensorInfos = {
real: backend.makeTensorInfo(real.shape, 'float32', realVals),
imag: backend.makeTensorInfo(imag.shape, 'float32', imagVals)
};
return complexInfo;
}
const complexConfig$1 = {
kernelName: Complex,
backendName: 'cpu',
kernelFunc: complex$1
};
function zeros(backend, shape, dtype = 'float32') {
if (dtype === 'complex64') {
const real = zeros(backend, shape, 'float32');
const imag = zeros(backend, shape, 'float32');
return complex$1({ inputs: { real, imag }, backend });
}
const values = makeZerosTypedArray(sizeFromShape(shape), dtype);
return backend.makeTensorInfo(shape, dtype, values);
}
function identity$1(args) {
const { inputs, backend } = args;
const { x } = inputs;
backend.incRef(x.dataId);
return { dataId: x.dataId, shape: x.shape, dtype: x.dtype };
}
const identityConfig$1 = {
kernelName: Identity$1,
backendName: 'cpu',
kernelFunc: identity$1
};
function real$1(args) {
const { inputs, backend } = args;
const { input } = inputs;
const real = backend.data.get(input.dataId).complexTensorInfos.real;
const realVal = backend.data.get(real.dataId).values;
return backend.makeTensorInfo(real.shape, real.dtype, realVal);
}
const realConfig$1 = {
kernelName: Real,
backendName: 'cpu',
kernelFunc: real$1
};
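// Dtype conversion helper for the Cast kernel: int32 targets are converted with
// Int32Array.from and bool targets are computed by comparing each value against zero
// with the broadcasting binary impl; any other target dtype throws here (cast$2 below
// handles complex64 and lossless casts before calling this).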
function castImpl(values, shape, inputType, dtype) {
if (dtype === 'int32') {
const resultValues = Int32Array.from(values);
return [shape, 'int32', resultValues];
}
if (dtype === 'bool') {
const zero = toTypedArray([0], inputType);
const [resultData, resultShape] = createSimpleBinaryKernelImpl((a, b) => (a !== b) ? 1 : 0)(shape, [], values, zero, 'bool');
return [resultShape, 'bool', resultData];
}
throw new Error(`Error in Cast: failed to cast ${inputType} to ${dtype}`);
}
function cast$2(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { dtype } = attrs;
if (dtype === 'complex64') {
if (x.dtype === 'complex64') {
return identity$1({ inputs: { x }, backend });
}
const zerosTensorInfo = zeros(backend, x.shape, x.dtype);
const floatX = cast$2({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
const result = complex$1({ inputs: { real: floatX, imag: zerosTensorInfo }, backend });
backend.disposeIntermediateTensorInfo(zerosTensorInfo);
backend.disposeIntermediateTensorInfo(floatX);
return result;
}
if (x.dtype === 'complex64') {
const realPart = real$1({ inputs: { input: x }, backend });
const result = cast$2({ inputs: { x: realPart }, backend, attrs: { dtype } });
backend.disposeIntermediateTensorInfo(realPart);
return result;
}
if (!hasEncodingLoss(x.dtype, dtype)) {
const result = identity$1({ inputs: { x }, backend });
return { dataId: result.dataId, shape: result.shape, dtype };
}
const values = backend.data.get(x.dataId).values;
const [resultShape, resultType, resultData] = castImpl(values, x.shape, x.dtype, dtype);
return backend.makeTensorInfo(resultShape, resultType, resultData);
}
const castConfig$1 = {
kernelName: Cast,
backendName: 'cpu',
kernelFunc: cast$2
};
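// Builds a CPU binary kernel from a simple (real-valued) implementation and an
// optional complex implementation. With a complex impl, complex64 inputs are cast and
// processed per real/imaginary component; without one, string inputs are decoded from
// Uint8Array before running the simple impl.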
function binaryKernelFunc$1(name, simpleImpl, complexImpl, dtype) {
if (complexImpl == null) {
return ({ inputs, backend }) => {
const { a, b } = inputs;
const cpuBackend = backend;
assertNotComplex([a, b], name);
const aVals = cpuBackend.data.get(a.dataId).values;
const bVals = cpuBackend.data.get(b.dataId).values;
const decodedAVals = a.dtype === 'string' ?
fromUint8ToStringArray(aVals) :
aVals;
const decodedBVals = a.dtype === 'string' ?
fromUint8ToStringArray(bVals) :
bVals;
const $dtype = dtype || a.dtype;
const [resultData, resultShape] = simpleImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype);
return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
};
}
return ({ inputs, backend }) => {
const { a, b } = inputs;
const cpuBackend = backend;
if (a.dtype === 'complex64' || b.dtype === 'complex64') {
const $aComplex = cast$2({ inputs: { x: a }, backend: cpuBackend, attrs: { dtype: 'complex64' } });
const $aComplexVals = cpuBackend.data.get($aComplex.dataId);
const aReal = $aComplexVals.complexTensorInfos.real;
const aImag = $aComplexVals.complexTensorInfos.imag;
const aRealVals = cpuBackend.data.get(aReal.dataId).values;
const aImagVals = cpuBackend.data.get(aImag.dataId).values;
const $bComplex = cast$2({ inputs: { x: b }, backend: cpuBackend, attrs: { dtype: 'complex64' } });
const $bComplexVals = cpuBackend.data.get($bComplex.dataId);
const bReal = $bComplexVals.complexTensorInfos.real;
const bImag = $bComplexVals.complexTensorInfos.imag;
const bRealVals = cpuBackend.data.get(bReal.dataId).values;
const bImagVals = cpuBackend.data.get(bImag.dataId).values;
const [resultRealData, resultImagData, resultShape] = complexImpl(a.shape, b.shape, aRealVals, aImagVals, bRealVals, bImagVals);
const resultReal = cpuBackend.makeTensorInfo(resultShape, 'float32', resultRealData);
const resultImag = cpuBackend.makeTensorInfo(resultShape, 'float32', resultImagData);
const result = complex$1({ inputs: { real: resultReal, imag: resultImag }, backend: cpuBackend });
cpuBackend.disposeIntermediateTensorInfo($aComplex);
cpuBackend.disposeIntermediateTensorInfo($bComplex);
cpuBackend.disposeIntermediateTensorInfo(resultReal);
cpuBackend.disposeIntermediateTensorInfo(resultImag);
return result;
}
else {
const aVals = cpuBackend.data.get(a.dataId).values;
const bVals = cpuBackend.data.get(b.dataId).values;
const $dtype = dtype || a.dtype;
const [resultData, resultShape] = simpleImpl(a.shape, b.shape, aVals, bVals, $dtype);
return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
}
};
}
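// Complex counterpart of createSimpleBinaryKernelImpl: the op receives
// (aReal, aImag, bReal, bImag) and returns {real, imag}; inputs are merged into
// interleaved [real, imag] arrays and broadcasting uses the same index mapping.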
function createComplexBinaryKernelImpl(op) {
return (aShape, bShape, aRealVals, aImagVals, bRealVals, bImagVals) => {
const resultShape = assertAndGetBroadcastShape(aShape, bShape);
const resultSize = sizeFromShape(resultShape);
const resultRank = resultShape.length;
const resultStrides = computeStrides(resultShape);
const resultRealVals = getTypedArrayFromDType('float32', resultSize);
const resultImagVals = getTypedArrayFromDType('float32', resultSize);
const aBroadcastDims = getBroadcastDims$1(aShape, resultShape);
const bBroadcastDims = getBroadcastDims$1(bShape, resultShape);
const aVals = mergeRealAndImagArrays(aRealVals, aImagVals);
const bVals = mergeRealAndImagArrays(bRealVals, bImagVals);
const aRank = aShape.length;
const aStrides = computeStrides(aShape);
const bRank = bShape.length;
const bStrides = computeStrides(bShape);
if (aBroadcastDims.length + bBroadcastDims.length === 0) {
for (let i = 0; i < resultRealVals.length; i++) {
const aIdx = i % aVals.length;
const bIdx = i % bVals.length;
const result = op(aVals[aIdx * 2], aVals[aIdx * 2 + 1], bVals[bIdx * 2], bVals[bIdx * 2 + 1]);
resultRealVals[i] = result.real;
resultImagVals[i] = result.imag;
}
}
else {
for (let i = 0; i < resultRealVals.length; i++) {
const loc = indexToLoc(i, resultRank, resultStrides);
const aLoc = loc.slice(-aRank);
aBroadcastDims.forEach(d => aLoc[d] = 0);
const aIndex = locToIndex(aLoc, aRank, aStrides);
const bLoc = loc.slice(-bRank);
bBroadcastDims.forEach(d => bLoc[d] = 0);
const bIndex = locToIndex(bLoc, bRank, bStrides);
const opResult = op(aVals[aIndex * 2], aVals[aIndex * 2 + 1], bVals[bIndex * 2], bVals[bIndex * 2 + 1]);
resultRealVals[i] = opResult.real;
resultImagVals[i] = opResult.imag;
}
}
return [resultRealVals, resultImagVals, resultShape];
};
}
const addImpl = createSimpleBinaryKernelImpl(((a, b) => a + b));
const addComplexImpl = createComplexBinaryKernelImpl(((aReal, aImag, bReal, bImag) => {
return { real: aReal + bReal, imag: aImag + bImag };
}));
const add = binaryKernelFunc$1(Add, addImpl, addComplexImpl);
const addConfig$1 = {
kernelName: Add,
backendName: 'cpu',
kernelFunc: add
};
function bincountImpl(xVals, weightsVals, weightsDtype, weightsShape, size) {
const weightsSize = sizeFromShape(weightsShape);
const outVals = makeZerosTypedArray(size, weightsDtype);
for (let i = 0; i < xVals.length; i++) {
const value = xVals[i];
if (value < 0) {
throw new Error('Input x must be non-negative!');
}
if (value >= size) {
continue;
}
if (weightsSize > 0) {
outVals[value] += weightsVals[i];
}
else {
outVals[value] += 1;
}
}
return outVals;
}
function bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput = false) {
const numRows = xBuf.shape[0];
const numCols = xBuf.shape[1];
const outBuf = buffer([numRows, size], weightsBuf.dtype);
for (let i = 0; i < numRows; i++) {
for (let j = 0; j < numCols; j++) {
const value = xBuf.get(i, j);
if (value < 0) {
throw new Error('Input x must be non-negative!');
}
if (value >= size) {
continue;
}
if (binaryOutput) {
outBuf.set(1, i, value);
}
else {
if (weightsBuf.size > 0) {
outBuf.set(outBuf.get(i, value) + weightsBuf.get(i, j), i, value);
}
else {
outBuf.set(outBuf.get(i, value) + 1, i, value);
}
}
}
}
return outBuf;
}
const bitwiseAndImpl = createSimpleBinaryKernelImpl(((a, b) => a & b));
const bitwiseAnd$1 = binaryKernelFunc$1(BitwiseAnd, bitwiseAndImpl);
const bitwiseAndConfig$1 = {
kernelName: BitwiseAnd,
backendName: 'cpu',
kernelFunc: bitwiseAnd$1
};
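// Helpers for element-wise unary CPU kernels: createSimpleUnaryImpl maps an op over a
// flat values array, and unaryKernelFunc$1 / unaryKernelFuncFromImpl wrap such an impl
// into a kernel function (decoding string tensors when necessary).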
function createSimpleUnaryImpl(op) {
return (values, dtype, attrs) => {
const newValues = getArrayFromDType(dtype, values.length);
for (let i = 0; i < values.length; ++i) {
newValues[i] = op(values[i], attrs);
}
return newValues;
};
}
function unaryKernelFunc$1(name, op, dtype) {
const impl = createSimpleUnaryImpl(op);
return unaryKernelFuncFromImpl(name, impl, dtype);
}
function unaryKernelFuncFromImpl(name, unaryImpl, dtype) {
return ({ inputs, attrs, backend }) => {
const { x } = inputs;
assertNotComplex(x, name);
const cpuBackend = backend;
const values = cpuBackend.data.get(x.dataId).values;
let decoded;
if (x.dtype === 'string') {
if (!Array.isArray(values)) {
throw new Error('String tensor\'s value was not an instance of Array');
}
decoded = fromUint8ToStringArray(values);
}
else {
decoded = values;
}
const $dtype = dtype || x.dtype;
const newValues = unaryImpl(decoded, $dtype, attrs);
return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues);
};
}
const ceilImpl = createSimpleUnaryImpl((xi) => Math.ceil(xi));
const ceil$1 = unaryKernelFuncFromImpl(Ceil, ceilImpl);
const ceilConfig$1 = {
kernelName: Ceil,
backendName: 'cpu',
kernelFunc: ceil$1,
};
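// Concat implementation: when the inputs can simply be appended (simplyConcat and a
// non-string dtype) their flat values are copied back to back; otherwise inputs are
// treated as 2D [rows, cols] blocks and copied column range by column range.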
function concatImpl$1(inputs, outShape, dtype, simplyConcat) {
const outVals = getArrayFromDType(dtype, sizeFromShape(outShape));
if (simplyConcat && dtype !== 'string') {
let offset = 0;
inputs.forEach(input => {
const size = sizeFromShape(input.shape);
outVals.set(input.vals, offset);
offset += size;
});
}
else {
let colOffset = 0;
inputs.forEach(input => {
const decodedData = dtype === 'string' ?
fromUint8ToStringArray(input.vals) :
input.vals;
let tIdx = 0;
for (let row = 0; row < input.shape[0]; ++row) {
const resIdx = row * outShape[1] + colOffset;
for (let col = 0; col < input.shape[1]; ++col) {
outVals[resIdx + col] = decodedData[tIdx++];
}
}
colOffset += input.shape[1];
});
}
return outVals;
}
const equalImpl = createSimpleBinaryKernelImpl((a, b) => (a === b) ? 1 : 0);
const equal$1 = binaryKernelFunc$1(Equal, equalImpl, null, 'bool');
const equalConfig$1 = {
kernelName: Equal,
backendName: 'cpu',
kernelFunc: equal$1
};
const expImpl = createSimpleUnaryImpl((xi) => Math.exp(xi));
const exp$1 = unaryKernelFuncFromImpl(Exp, expImpl, 'float32');
const expConfig$1 = {
kernelName: Exp,
backendName: 'cpu',
kernelFunc: exp$1,
};
const expm1Impl = createSimpleUnaryImpl((xi) => Math.expm1(xi));
const expm1$1 = unaryKernelFuncFromImpl(Expm1, expm1Impl);
const expm1Config$1 = {
kernelName: Expm1,
backendName: 'cpu',
kernelFunc: expm1$1,
};
const floorImpl = createSimpleUnaryImpl((xi) => Math.floor(xi));
const floor$1 = unaryKernelFuncFromImpl(Floor, floorImpl);
const floorConfig$1 = {
kernelName: Floor,
backendName: 'cpu',
kernelFunc: floor$1,
};
const floorDivImpl = createSimpleBinaryKernelImpl((a, b) => Math.floor(a / b));
const floorDiv$1 = binaryKernelFunc$1(FloorDiv, floorDivImpl, null, 'int32');
const floorDivConfig$1 = {
kernelName: FloorDiv,
backendName: 'cpu',
kernelFunc: floorDiv$1
};
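// GatherND: each of the numSlices index rows selects a slice of sliceSize elements
// from the params buffer; out-of-range flattened indices throw.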
function gatherNdImpl(indicesData, paramsBuf, dtype, numSlices, sliceRank, sliceSize, strides, paramsShape, paramsSize) {
const outBuf = buffer([numSlices, sliceSize], dtype);
for (let i = 0; i < numSlices; i++) {
const index = [];
let flattenIndex = 0;
for (let j = 0; j < sliceRank; j++) {
const dim = indicesData[i * sliceRank + j];
flattenIndex += dim * strides[j];
index.push(dim);
}
if (flattenIndex < 0 || flattenIndex >= paramsSize / sliceSize) {
throw new Error(`Invalid indices: ${index} does not index into ${paramsShape}`);
}
for (let k = 0; k < sliceSize; k++) {
outBuf.values[i * sliceSize + k] =
paramsBuf.get(...paramsBuf.indexToLoc(flattenIndex * sliceSize + k));
}
}
return outBuf;
}
function gatherV2Impl(xBuf, indicesBuf, flattenOutputShape) {
const outBuf = buffer(flattenOutputShape, xBuf.dtype);
for (let i = 0; i < outBuf.size; ++i) {
const newLoc = outBuf.indexToLoc(i);
const originalLoc = newLoc.slice();
const batchIdx = originalLoc[0];
const indicesIdx = originalLoc[2];
const indicesIndex = indicesBuf.locToIndex([batchIdx, indicesIdx]);
originalLoc[2] = indicesBuf.values[indicesIndex];
const originalIndex = xBuf.locToIndex(originalLoc);
if (0 <= originalIndex && originalIndex < xBuf.values.length) {
outBuf.values[i] = xBuf.values[originalIndex];
}
}
return outBuf;
}
const greaterImpl = createSimpleBinaryKernelImpl((a, b) => (a > b) ? 1 : 0);
const greater$1 = binaryKernelFunc$1(Greater, greaterImpl, null, 'bool');
const greaterConfig$1 = {
kernelName: Greater,
backendName: 'cpu',
kernelFunc: greater$1
};
const greaterEqualImpl = createSimpleBinaryKernelImpl((a, b) => (a >= b) ? 1 : 0);
const greaterEqual$1 = binaryKernelFunc$1(GreaterEqual, greaterEqualImpl, null, 'bool');
const greaterEqualConfig$1 = {
kernelName: GreaterEqual,
backendName: 'cpu',
kernelFunc: greaterEqual$1
};
const lessImpl = createSimpleBinaryKernelImpl((a, b) => (a < b) ? 1 : 0);
const less$1 = binaryKernelFunc$1(Less, lessImpl, null, 'bool');
const lessConfig$1 = {
kernelName: Less,
backendName: 'cpu',
kernelFunc: less$1
};
const lessEqualImpl = createSimpleBinaryKernelImpl((a, b) => (a <= b) ? 1 : 0);
const lessEqual$1 = binaryKernelFunc$1(LessEqual, lessEqualImpl, null, 'bool');
const lessEqualConfig$1 = {
kernelName: LessEqual,
backendName: 'cpu',
kernelFunc: lessEqual$1
};
function linSpaceImpl(start, stop, num) {
const step = (stop - start) / (num - 1);
const values = makeZerosTypedArray(num, 'float32');
values[0] = start;
for (let i = 1; i < values.length; i++) {
values[i] = values[i - 1] + step;
}
return values;
}
const logImpl = createSimpleUnaryImpl((xi) => Math.log(xi));
const log$1 = unaryKernelFuncFromImpl(Log, logImpl);
const logConfig$1 = {
kernelName: Log,
backendName: 'cpu',
kernelFunc: log$1,
};
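// Max reduction over the innermost reduceSize elements per output position (inputs are
// pre-transposed by the kernel so the reduced axes are innermost); NaN values propagate.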
function maxImpl$1(aVals, reduceSize, outShape, dtype) {
const vals = getTypedArrayFromDType(dtype, sizeFromShape(outShape));
for (let i = 0; i < vals.length; ++i) {
const offset = i * reduceSize;
let max = aVals[offset];
for (let j = 0; j < reduceSize; ++j) {
const value = aVals[offset + j];
if (Number.isNaN(value) ||
value > max) {
max = value;
}
}
vals[i] = max;
}
return vals;
}
const maximumImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => Math.max(aValue, bValue)));
const maximum$1 = binaryKernelFunc$1(Maximum, maximumImpl);
const maximumConfig$1 = {
kernelName: Maximum,
backendName: 'cpu',
kernelFunc: maximum$1
};
const minimumImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => Math.min(aValue, bValue)));
const minimum$1 = binaryKernelFunc$1(Minimum, minimumImpl);
const minimumConfig$1 = {
kernelName: Minimum,
backendName: 'cpu',
kernelFunc: minimum$1
};
const multiplyImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => aValue * bValue));
const multiplyComplexImpl = createComplexBinaryKernelImpl(((aReal, aImag, bReal, bImag) => {
return {
real: aReal * bReal - aImag * bImag,
imag: aReal * bImag + aImag * bReal
};
}));
const multiply$1 = binaryKernelFunc$1(Multiply, multiplyImpl, multiplyComplexImpl);
const multiplyConfig$1 = {
kernelName: Multiply,
backendName: 'cpu',
kernelFunc: multiply$1
};
function negImpl(xVals, xShape, xDtype) {
const minusOne = createScalarValue(-1, xDtype);
return multiplyImpl([], xShape, minusOne, xVals, xDtype);
}
function neg$1(args) {
const { inputs, backend } = args;
const { x } = inputs;
assertNotComplex(x, 'neg');
const xVals = backend.data.get(x.dataId).values;
const [res, newShape] = negImpl(xVals, x.shape, x.dtype);
return backend.makeTensorInfo(newShape, x.dtype, res);
}
const negConfig$1 = {
kernelName: Neg,
backendName: 'cpu',
kernelFunc: neg$1
};
const notEqualImpl = createSimpleBinaryKernelImpl(((a, b) => (a !== b) ? 1 : 0));
const notEqual$1 = binaryKernelFunc$1(NotEqual, notEqualImpl, null, 'bool');
const notEqualConfig$1 = {
kernelName: NotEqual,
backendName: 'cpu',
kernelFunc: notEqual$1
};
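// Transpose: maps every input index to its permuted output location using the strides
// of the old and new shapes.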
function transposeImpl$1(xVals, xShape, dtype, perm, newShape) {
const xRank = xShape.length;
const xSize = sizeFromShape(xShape);
const xStrides = computeStrides(xShape);
const newStrides = computeStrides(newShape);
const result = getTypedArrayFromDType(dtype, sizeFromShape(newShape));
for (let i = 0; i < xSize; ++i) {
const loc = indexToLoc(i, xRank, xStrides);
const newLoc = new Array(loc.length);
for (let i = 0; i < newLoc.length; i++) {
newLoc[i] = loc[perm[i]];
}
const newIndex = locToIndex(newLoc, xRank, newStrides);
result[newIndex] = xVals[i];
}
return result;
}
function transpose$1(args) {
const { inputs, attrs, backend } = args;
const { x } = inputs;
const { perm } = attrs;
assertNotComplex(x, 'transpose');
const xRank = x.shape.length;
const newShape = new Array(xRank);
for (let i = 0; i < newShape.length; i++) {
newShape[i] = x.shape[perm[i]];
}
const values = backend.data.get(x.dataId).values;
const result = transposeImpl$1(values, x.shape, x.dtype, perm, newShape);
const dataId = backend.write(result, newShape, x.dtype);
return { dataId, shape: newShape, dtype: x.dtype };
}
const transposeConfig$1 = {
kernelName: Transpose,
backendName: 'cpu',
kernelFunc: transpose$1
};
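// Product reduction over the innermost reduceSize elements per output position; the
// output dtype is the input dtype upcast with int32.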
function prodImpl(xShape, xDtype, xVals, reductionAxes) {
const [outShape, reduceShape] = computeOutAndReduceShapes(xShape, reductionAxes);
const outDtype = upcastType(xDtype, 'int32');
const outVals = makeZerosTypedArray(sizeFromShape(outShape), outDtype);
const reduceSize = sizeFromShape(reduceShape);
for (let i = 0; i < outVals.length; ++i) {
const offset = i * reduceSize;
let prod = 1;
for (let j = 0; j < reduceSize; ++j) {
prod *= xVals[offset + j];
}
outVals[i] = prod;
}
return { outVals, outShape, outDtype };
}
function prod$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis, keepDims } = attrs;
assertNotComplex(x, 'prod');
const xRank = x.shape.length;
const axes = parseAxisParam(axis, x.shape);
const permutation = getAxesPermutation(axes, xRank);
let reductionAxes = axes;
let permutedX = x;
const intermediateTensorInfos = [];
if (permutation != null) {
permutedX = transpose$1({ inputs: { x }, backend, attrs: { perm: permutation } });
intermediateTensorInfos.push(permutedX);
reductionAxes = getInnerMostAxes(reductionAxes.length, xRank);
}
const xVals = backend.data.get(permutedX.dataId).values;
const { outVals, outShape, outDtype } = prodImpl(permutedX.shape, permutedX.dtype, xVals, reductionAxes);
let resultShape = outShape;
if (keepDims) {
resultShape = expandShapeToKeepDim(outShape, axes);
}
intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
return backend.makeTensorInfo(resultShape, outDtype, outVals);
}
const prodConfig$1 = {
kernelName: Prod,
backendName: 'cpu',
kernelFunc: prod$1
};
function validateIndices(indices, indicesShape, numParams) {
indices.forEach((index, i) => {
if (index < 0 || index >= numParams) {
const locString = indexToLoc(i, indicesShape.length, computeStrides(indicesShape))
.join(',');
throw new Error(`indices[${locString}] = ${index} is not in [0, ${numParams})`);
}
});
}
function validateSplits(paramsNestedSplits, numParamsDenseValues) {
for (let dim = 0; dim < paramsNestedSplits.length; ++dim) {
const splits = paramsNestedSplits[dim];
const lastSplit = (dim === paramsNestedSplits.length - 1) ?
numParamsDenseValues :
paramsNestedSplits[dim + 1].length;
if (splits.length === 0) {
throw new Error('Ragged splits may not be empty');
}
if (splits[0] < 0) {
throw new Error('Ragged splits must be non-negative');
}
if (splits[splits.length - 1] > lastSplit) {
throw new Error('Ragged splits must not point past values');
}
for (let i = 1; i < splits.length; ++i) {
if (splits[i - 1] > splits[i]) {
throw new Error('Ragged splits must be sorted in ascending order');
}
}
}
}
function makeSplits(indices, indicesShape, paramsNestedSplits, numParamsDenseValues) {
const valueSlices = [];
let numValues = 0;
const numSplits = indicesShape.length - 1 + paramsNestedSplits.length;
const outSplits = new Array(numSplits).fill(null).map(() => [0]);
validateSplits(paramsNestedSplits, numParamsDenseValues);
let nrows = 1;
for (let dim = 0; dim < indicesShape.length - 1; ++dim) {
nrows *= indicesShape[dim];
const rowLength = indicesShape[dim + 1];
for (let i = 1; i < nrows + 1; ++i) {
outSplits[dim].push(i * rowLength);
}
}
for (let i = 0; i < indices.length; ++i) {
let start = indices[i];
let limit = indices[i] + 1;
for (let dim = 0; dim < paramsNestedSplits.length; ++dim) {
const splits = paramsNestedSplits[dim];
const outDim = dim + indicesShape.length - 1;
if (outDim >= 0) {
const outSplitsOutDim = outSplits[outDim];
const delta = outSplitsOutDim[outSplitsOutDim.length - 1] - splits[start];
for (let j = start; j < limit; ++j) {
outSplits[outDim].push(splits[j + 1] + delta);
}
}
start = splits[start];
limit = splits[limit];
}
if (limit !== start) {
valueSlices.push([start, limit]);
numValues += limit - start;
}
}
return { outSplits, valueSlices, numValues };
}
function getSplits(outSplits) {
const splitsOut = [];
for (let i = 0; i < outSplits.length; ++i) {
const numSplits = outSplits[i].length;
const splits = getArrayFromDType('int32', numSplits);
splitsOut.push(splits);
outSplits[i].forEach((value, j) => splits[j] = value);
}
return splitsOut;
}
function computeFlatOuterDims(orig, numOutDims) {
const outDims = orig.slice(0, numOutDims);
while (outDims.length < numOutDims) {
outDims.push(1);
}
for (let inDim = numOutDims; inDim < orig.length; inDim++) {
outDims[numOutDims - 1] *= orig[inDim];
}
return outDims;
}
function writeValueSlices(paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, values, valuesShape) {
const denseM = computeFlatOuterDims(paramsDenseValuesShape, 2)[1];
const valuesM = computeFlatOuterDims(valuesShape, 2)[1];
let outPos = 0;
for (const slice of valueSlices) {
for (let i = slice[0]; i < slice[1]; ++i) {
for (let j = 0; j < valueSize; ++j) {
values[outPos * valuesM + j] = paramsDenseValues[i * denseM + j];
}
++outPos;
}
}
}
function getValues(paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, valueSlices, numValues) {
const valuesShape = paramsDenseValuesShape.slice();
valuesShape[0] = numValues;
const valuesOut = getArrayFromDType(paramsDenseValuesDType, sizeFromShape(valuesShape));
const numElements = paramsDenseValues.length;
const valueSize = numElements === 0 ? 0 : (numElements / paramsDenseValuesShape[0]);
writeValueSlices(paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, valuesOut, valuesShape);
return [valuesOut, valuesShape];
}
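// RaggedGather: validates the indices and nested splits, then builds the output nested
// splits and copies the referenced slices of the dense values.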
function raggedGatherImpl(paramsNestedSplits, paramsNestedSplitsShapes, paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, indices, indicesShape, outputRaggedRank) {
if (paramsNestedSplits.length === 0) {
throw new Error('paramsNestedSplits must be non empty');
}
if (paramsNestedSplitsShapes[0].length === 0) {
throw new Error('Split tensors must not be scalars');
}
const numParams = paramsNestedSplitsShapes[0][0] - 1;
validateIndices(indices, indicesShape, numParams);
if (paramsDenseValuesShape.length === 0) {
throw new Error('params.rank must be nonzero');
}
const numParamsDenseValues = paramsDenseValuesShape[0];
const { outSplits, valueSlices, numValues } = makeSplits(indices, indicesShape, paramsNestedSplits, numParamsDenseValues);
const outputNestedSplits = getSplits(outSplits);
const outputDenseValues = getValues(paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, valueSlices, numValues);
return [outputNestedSplits, outputDenseValues[0], outputDenseValues[1]];
}
const INT32_MAX = 2147483647;
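// RaggedRange: produces nested row splits plus dense values for per-row ranges defined
// by (possibly scalar-broadcast) starts, limits and deltas.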
function raggedRangeImpl(starts, startsShape, startsDType, limits, limitsShape, deltas, deltasShape) {
if (startsShape.length > 1) {
throw new Error('starts must be a scalar or vector');
}
if (limitsShape.length > 1) {
throw new Error('limits must be a scalar or vector');
}
if (deltasShape.length > 1) {
throw new Error('deltas must be a scalar or vector');
}
const broadcastStarts = startsShape.length === 0;
const broadcastLimits = limitsShape.length === 0;
const broadcastDeltas = deltasShape.length === 0;
const inSizes = [];
if (!broadcastStarts) {
inSizes.push(startsShape[0]);
}
if (!broadcastLimits) {
inSizes.push(limitsShape[0]);
}
if (!broadcastDeltas) {
inSizes.push(deltasShape[0]);
}
for (let i = 1; i < inSizes.length; ++i) {
if (inSizes[i] !== inSizes[i - 1]) {
throw new Error('starts, limits, and deltas must have the same shape');
}
}
const nRows = inSizes.length === 0 ? 1 : inSizes[0];
const rtNestedSplits = getArrayFromDType('int32', nRows + 1);
rtNestedSplits[0] = 0;
for (let row = 0; row < nRows; ++row) {
const start = broadcastStarts ? starts[0] : starts[row];
const limit = broadcastLimits ? limits[0] : limits[row];
const delta = broadcastDeltas ? deltas[0] : deltas[row];
if (delta === 0) {
throw new Error('Requires delta != 0');
}
let size;
if (((delta > 0) && (limit < start)) || ((delta < 0) && (limit > start))) {
size = 0;
}
else {
size = Math.ceil(Math.abs((limit - start) / delta));
if (size > INT32_MAX) {
throw new Error(`Requires ((limit - start) / delta) <= ${INT32_MAX}`);
}
}
rtNestedSplits[row + 1] = rtNestedSplits[row] + size;
}
const nVals = rtNestedSplits[nRows];
const rtDenseValues = getArrayFromDType(startsDType, nVals);
let valueIndex = 0;
for (let row = 0; row < nRows; ++row) {
const rowSize = rtNestedSplits[row + 1] - rtNestedSplits[row];
let value = broadcastStarts ? starts[0] : starts[row];
const delta = broadcastDeltas ? deltas[0] : deltas[row];
for (let i = 0; i < rowSize; ++i) {
rtDenseValues[valueIndex++] = value;
value += delta;
}
}
return [rtNestedSplits, rtDenseValues];
}
var RowPartitionType = RowPartitionType$1;
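// Converts a ragged tensor (flat values plus row partition tensors) into a dense
// tensor, filling output positions that have no source value with defaultValue.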
class RaggedTensorToTensorOp {
constructor(shape, shapeShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypeStrings) {
this.shape = shape;
this.shapeShape = shapeShape;
this.values = values;
this.valuesShape = valuesShape;
this.valuesDType = valuesDType;
this.defaultValue = defaultValue;
this.defaultValueShape = defaultValueShape;
this.rowPartitionValues = rowPartitionValues;
this.rowPartitionValuesShapes = rowPartitionValuesShapes;
this.rowPartitionTypes =
getRowPartitionTypesHelper(rowPartitionTypeStrings);
this.raggedRank = getRaggedRank(this.rowPartitionTypes);
}
getRowPartitionTypeByDimension(dimension) {
if (this.rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) {
return this.rowPartitionTypes[dimension + 1];
}
else {
return this.rowPartitionTypes[dimension];
}
}
getRowPartitionTensor(dimension) {
if (this.rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) {
return this.rowPartitionValues[dimension + 1];
}
else {
return this.rowPartitionValues[dimension];
}
}
getMaxWidth(dimension) {
const rowPartitionTensor = this.getRowPartitionTensor(dimension - 1);
switch (this.getRowPartitionTypeByDimension(dimension - 1)) {
case RowPartitionType.VALUE_ROWIDS:
return RaggedTensorToTensorOp.getMaxWidthValueRowID(rowPartitionTensor);
case RowPartitionType.ROW_SPLITS:
return RaggedTensorToTensorOp.getMaxWidthRowSplit(rowPartitionTensor);
default:
throw new Error(`Cannot handle partition type ${RowPartitionType[this.getRowPartitionTypeByDimension(dimension - 1)]}`);
}
}
static getMaxWidthRowSplit(rowSplit) {
const tensorLength = rowSplit.length;
if (tensorLength === 0 || tensorLength === 1) {
return 0;
}
let maxWidth = 0;
for (let i = 0; i < tensorLength - 1; ++i) {
const currentWidth = rowSplit[i + 1] - rowSplit[i];
if (currentWidth > maxWidth) {
maxWidth = currentWidth;
}
}
return maxWidth;
}
static getMaxWidthValueRowID(valueRowIds) {
const indexLength = valueRowIds.length;
if (indexLength === 0) {
return 0;
}
let firstEqualIndex = 0;
let firstEqualIndexValue = valueRowIds[0];
let maxWidth = 0;
for (let i = 1; i < indexLength; ++i) {
const value = valueRowIds[i];
if (value !== firstEqualIndexValue) {
firstEqualIndexValue = value;
maxWidth = Math.max(i - firstEqualIndex, maxWidth);
firstEqualIndex = i;
}
}
return Math.max(indexLength - firstEqualIndex, maxWidth);
}
tensorShapeFromTensor(t, tShape, isPartial = true) {
if (tShape.length === 0) {
if (t[0] === -1) {
return [];
}
throw new Error(`The only valid scalar shape tensor is the fully unknown shape specified as -1.`);
}
return makeShape(t, isPartial);
}
calculateOutputSize(firstDim) {
const valueShape = this.valuesShape;
const defaultValueShape = this.defaultValueShape;
validateDefaultValueShape(defaultValueShape, valueShape);
const shape = this.tensorShapeFromTensor(this.shape, this.shapeShape);
const outputShape = combineRaggedTensorToTensorShapes(this.raggedRank, shape, valueShape);
const result = outputShape;
if (result[0] < 0) {
result[0] = firstDim;
}
for (let i = 1; i <= this.raggedRank; ++i) {
if (result[i] < 0) {
result[i] = this.getMaxWidth(i);
}
}
return result;
}
calculateFirstParentOutputIndex(firstDimension, outputIndexMultiplier, firstDimensionOutput) {
const minDimension = Math.min(firstDimension, firstDimensionOutput);
const result = [];
let currentOutputIndex = 0;
for (let i = 0; i < minDimension; ++i, currentOutputIndex += outputIndexMultiplier) {
result.push(currentOutputIndex);
}
for (let i = minDimension; i < firstDimension; ++i) {
result.push(-1);
}
assert$1(result.length === firstDimension, () => 'Final length of result must be equal to firstDimension.');
return result;
}
calculateOutputIndexRowSplit(rowSplit, parentOutputIndex, outputIndexMultiplier, outputSize) {
const rowSplitSize = rowSplit.length;
const result = [];
for (let i = 0; i < rowSplitSize - 1; ++i) {
const rowLength = rowSplit[i + 1] - rowSplit[i];
let realLength = Math.min(outputSize, rowLength);
let parentOutputIndexCurrent = parentOutputIndex[i];
if (parentOutputIndexCurrent === -1) {
realLength = 0;
}
for (let j = 0; j < realLength; ++j) {
result.push(parentOutputIndexCurrent);
parentOutputIndexCurrent += outputIndexMultiplier;
}
for (let j = 0; j < rowLength - realLength; ++j) {
result.push(-1);
}
}
if (rowSplitSize > 0 && result.length !== rowSplit[rowSplitSize - 1]) {
throw new Error('Invalid row split size.');
}
return result;
}
calculateOutputIndexValueRowID(valueRowIds, parentOutputIndex, outputIndexMultiplier, outputSize) {
const indexSize = valueRowIds.length;
const result = [];
if (indexSize === 0) {
return [];
}
let currentOutputColumn = 0;
let currentValueRowId = valueRowIds[0];
if (currentValueRowId >= parentOutputIndex.length) {
throw new Error(`Got currentValueRowId=${currentValueRowId}, which is not less than ${parentOutputIndex.length}`);
}
let currentOutputIndex = parentOutputIndex[currentValueRowId];
result.push(currentOutputIndex);
for (let i = 1; i < indexSize; ++i) {
const nextValueRowId = valueRowIds[i];
if (nextValueRowId === currentValueRowId) {
if (currentOutputIndex >= 0) {
++currentOutputColumn;
if (currentOutputColumn < outputSize) {
currentOutputIndex += outputIndexMultiplier;
}
else {
currentOutputIndex = -1;
}
}
}
else {
currentOutputColumn = 0;
currentValueRowId = nextValueRowId;
if (nextValueRowId >= parentOutputIndex.length) {
throw new Error(`Got nextValueRowId=${nextValueRowId} which is not less than ${parentOutputIndex.length}`);
}
currentOutputIndex = parentOutputIndex[nextValueRowId];
}
result.push(currentOutputIndex);
}
if (result.length !== valueRowIds.length) {
throw new Error('Invalid row ids.');
}
return result;
}
calculateOutputIndex(dimension, parentOutputIndex, outputIndexMultiplier, outputSize) {
const rowPartitionTensor = this.getRowPartitionTensor(dimension);
const partitionType = this.getRowPartitionTypeByDimension(dimension);
switch (partitionType) {
case RowPartitionType.VALUE_ROWIDS:
return this.calculateOutputIndexValueRowID(rowPartitionTensor, parentOutputIndex, outputIndexMultiplier, outputSize);
case RowPartitionType.ROW_SPLITS:
if (rowPartitionTensor.length - 1 > parentOutputIndex.length) {
throw new Error(`Row partition size is greater than output size: ${rowPartitionTensor.length - 1} > ${parentOutputIndex.length}`);
}
return this.calculateOutputIndexRowSplit(rowPartitionTensor, parentOutputIndex, outputIndexMultiplier, outputSize);
default:
throw new Error(`Unsupported partition type: ${RowPartitionType[partitionType]}`);
}
}
getFirstDimensionSize() {
const firstPartitionTensor = this.rowPartitionValues[0];
if (this.rowPartitionTypes.length === 0) {
throw new Error('No row_partition_types given.');
}
const firstPartitionType = this.rowPartitionTypes[0];
switch (firstPartitionType) {
case RowPartitionType.FIRST_DIM_SIZE:
return firstPartitionTensor[0];
case RowPartitionType.VALUE_ROWIDS:
throw new Error('Cannot handle VALUE_ROWIDS in first dimension.');
case RowPartitionType.ROW_SPLITS:
return this.rowPartitionValuesShapes[0][0] - 1;
default:
throw new Error(`Cannot handle type ${RowPartitionType[firstPartitionType]}`);
}
}
compute() {
const firstPartitionTensor = this.rowPartitionValues[0];
if (firstPartitionTensor.length <= 0) {
throw new Error('Invalid first partition input. ' +
'Tensor requires at least one element.');
}
const firstDimension = this.getFirstDimensionSize();
const outputSize = this.calculateOutputSize(firstDimension);
const multiplier = new Array(this.raggedRank + 1);
multiplier[multiplier.length - 1] = 1;
for (let i = multiplier.length - 2; i >= 0; --i) {
multiplier[i] = multiplier[i + 1] * outputSize[i + 1];
}
const outputShape = makeShape(outputSize, false);
const outputTensor = getArrayFromDType(this.valuesDType, sizeFromShape(outputShape));
const fullSize = multiplier[0] * outputSize[0];
if (fullSize > 0) {
let outputIndex = this.calculateFirstParentOutputIndex(firstDimension, multiplier[0], outputSize[0]);
for (let i = 1; i <= this.raggedRank; ++i) {
const newOutputIndex = this.calculateOutputIndex(i - 1, outputIndex, multiplier[i], outputSize[i]);
outputIndex = newOutputIndex;
}
this.setOutput(this.raggedRank, outputIndex, outputTensor, outputShape);
}
return [outputShape, outputTensor];
}
setOutput(raggedRank, outputIndex, outputTensor, outputShape) {
if (outputTensor.length === 0) {
return;
}
const valuesBase = this.values;
const outputBase = outputTensor;
let elementShape = outputShape.slice();
elementShape = elementShape.slice(raggedRank + 1);
const valueElementSize = sizeFromShape(elementShape);
const outputIndexSize = outputIndex.length;
let defaultValue = this.defaultValue;
if (defaultValue.length !== valueElementSize && defaultValue.length !== 1) {
const srcShape = this.defaultValueShape;
tidy(() => {
const defaultValueTensor = reshape$2(defaultValue, srcShape);
const bCastDefault = broadcastTo(defaultValueTensor, elementShape);
defaultValue = bCastDefault.dataSync();
});
}
let srcStart = 0;
let dstStart = 0;
let dstEnd = 0;
for (let srcI = 0; srcI <= outputIndexSize; ++srcI) {
let dstI = srcI < outputIndexSize ? outputIndex[srcI] : -1;
if (dstI === dstEnd) {
++dstEnd;
continue;
}
if (dstStart < dstEnd) {
const src = valuesBase.subarray(srcStart * valueElementSize);
const dst = outputBase.subarray(dstStart * valueElementSize);
const nVals = (dstEnd - dstStart) * valueElementSize;
copyArray(dst, src, nVals);
}
if (srcI >= outputIndexSize) {
const outputSize = outputTensor.length;
dstI = Math.floor(outputSize / valueElementSize);
}
if (dstI > dstEnd) {
if (this.defaultValue.length === 1) {
outputBase
.subarray(dstEnd * valueElementSize, dstI * valueElementSize)
.fill(this.defaultValue[0]);
dstEnd = dstI;
}
else {
while (dstI > dstEnd) {
const dst = outputBase.slice(dstEnd * valueElementSize);
copyArray(dst, defaultValue, valueElementSize);
++dstEnd;
}
}
}
if (dstI < 0) {
srcStart = srcI + 1;
dstStart = dstEnd;
}
else {
srcStart = srcI;
dstStart = dstEnd;
dstEnd = dstStart + 1;
}
}
}
}
function copyArray(dst, src, size) {
for (let i = 0; i < size; i++) {
dst[i] = src[i];
}
}
function makeShape(shape, isPartial) {
const out = [];
for (let dim of shape) {
if (dim < 0) {
if (!isPartial) {
throw new Error(`Dimension ${dim} must be >= 0`);
}
if (dim < -1) {
throw new Error(`Dimension ${dim} must be >= -1`);
}
dim = -1;
}
out.push(dim);
}
return out;
}
function raggedTensorToTensorImpl(shape, shapesShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes) {
return new RaggedTensorToTensorOp(shape, shapesShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes)
.compute();
}
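// Range: returns an empty array for inconsistent start/stop/step combinations,
// otherwise fills values by repeatedly adding step (step flips to -1 when a decreasing
// range is requested with step === 1).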
function rangeImpl(start, stop, step, dtype) {
const sameStartStop = start === stop;
const increasingRangeNegativeStep = start < stop && step < 0;
const decreasingRangePositiveStep = stop < start && step > 1;
if (sameStartStop || increasingRangeNegativeStep ||
decreasingRangePositiveStep) {
return makeZerosTypedArray(0, dtype);
}
const numElements = Math.abs(Math.ceil((stop - start) / step));
const values = makeZerosTypedArray(numElements, dtype);
if (stop < start && step === 1) {
step = -1;
}
values[0] = start;
for (let i = 1; i < values.length; i++) {
values[i] = values[i - 1] + step;
}
return values;
}
const rsqrtImpl = createSimpleUnaryImpl((xi) => 1 / Math.sqrt(xi));
const rsqrt$1 = unaryKernelFuncFromImpl(Rsqrt, rsqrtImpl);
const rsqrtConfig$1 = {
kernelName: Rsqrt,
backendName: 'cpu',
kernelFunc: rsqrt$1,
};
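// Scatter: writes numUpdates slices of sliceSize elements into a
// [outputSize / sliceSize, sliceSize] buffer at positions derived from indices;
// duplicate indices either accumulate (sumDupeIndices) or overwrite.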
function scatterImpl(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices) {
const flattenShape = [outputSize / sliceSize, sliceSize];
const indicesData = indices.values;
const updatesData = updates.values;
if (outputSize === 0) {
return buffer(shape, updates.dtype);
}
const outBuf = (defaultValue instanceof TensorBuffer) ?
defaultValue :
buffer(flattenShape, updates.dtype);
if (typeof defaultValue === 'string') {
outBuf.values.fill(defaultValue);
}
else if (typeof defaultValue === 'number') {
outBuf.values.fill(defaultValue);
}
else if (typeof defaultValue === 'boolean') {
outBuf.values.fill(+defaultValue);
}
for (let i = 0; i < numUpdates; i++) {
const index = [];
let flattenIndex = 0;
for (let j = 0; j < sliceRank; j++) {
const dim = indicesData[i * sliceRank + j];
index.push(dim);
flattenIndex += dim * strides[j];
}
if (flattenIndex < 0 || flattenIndex >= outputSize / sliceSize) {
throw new Error(`Invalid indices: ${index} does not index into ${shape}`);
}
for (let k = 0; k < sliceSize; k++) {
if (sumDupeIndices) {
outBuf.values[flattenIndex * sliceSize + k] +=
updatesData[i * sliceSize + k];
}
else {
outBuf.values[flattenIndex * sliceSize + k] = updates.rank === 0 ?
updatesData[0] :
updatesData[i * sliceSize + k];
}
}
}
return outBuf;
}
const sigmoidImpl = createSimpleUnaryImpl((xi) => 1 / (1 + Math.exp(-xi)));
const sigmoid$1 = unaryKernelFunc$1(Sigmoid$1, (xi) => 1 / (1 + Math.exp(-xi)));
const sigmoidConfig$1 = {
kernelName: Sigmoid$1,
backendName: 'cpu',
kernelFunc: sigmoid$1,
};
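// Slice: uses a flat subarray/slice when the requested region is contiguous in memory,
// otherwise copies element by element through buffer coordinates.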
function sliceImpl(vals, begin, size, shape, dtype) {
const isContinous = isSliceContinous(shape, begin, size);
const length = sizeFromShape(size);
const xStrides = computeStrides(shape);
if (isContinous) {
const flatOffset = computeFlatOffset(begin, xStrides);
if (dtype === 'string') {
return vals.slice(flatOffset, flatOffset + length);
}
return vals.subarray(flatOffset, flatOffset + length);
}
const decodedData = dtype === 'string' ?
fromUint8ToStringArray(vals) :
vals;
const inBuf = buffer(shape, dtype, decodedData);
const outBuf = buffer(size, dtype);
for (let i = 0; i < outBuf.size; ++i) {
const outLoc = outBuf.indexToLoc(i);
const inLoc = outLoc.map((idx, j) => idx + begin[j]);
outBuf.set(inBuf.get(...inLoc), ...outLoc);
}
if (dtype === 'string') {
return fromStringArrayToUint8(outBuf.values);
}
return outBuf.values;
}
function slice$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { begin, size } = attrs;
assertNotComplex(x, 'slice');
const [$begin, $size] = parseSliceParams(x, begin, size);
assertParamsValid(x, $begin, $size);
const vals = backend.data.get(x.dataId).values;
const outVals = sliceImpl(vals, $begin, $size, x.shape, x.dtype);
return backend.makeTensorInfo($size, x.dtype, outVals);
}
const sliceConfig$1 = {
kernelName: Slice,
backendName: 'cpu',
kernelFunc: slice$1
};
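// SparseFillEmptyRows: inserts a defaultValue entry for every dense row that has no
// sparse entry, returning new indices/values, a per-row empty indicator and a reverse
// index map from old to new positions.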
function sparseFillEmptyRowsImpl(indices, indicesShape, indicesDType, values, valuesDType, denseShape, defaultValue) {
const indicesCount = indicesShape[0];
const denseRows = denseShape[0];
const emptyRowIndicator = new Array(denseRows);
const reverseIndexMap = new Array(indicesCount);
const rank = indicesShape[1];
if (denseRows === 0) {
if (indicesCount !== 0) {
throw new Error(getSparseFillEmptyRowsIndicesDenseShapeMismatch(indicesCount));
}
const outputIndices = getArrayFromDType(indicesDType, 0);
const outputValues = getArrayFromDType(valuesDType, 0);
return [
outputIndices, [0, rank], outputValues, emptyRowIndicator, reverseIndexMap
];
}
let rowsAreOrdered = true;
let lastIndicesRow = 0;
const csrOffset = new Array(denseRows).fill(0);
for (let i = 0; i < indicesCount; ++i) {
const row = indices[i * rank];
if (row < 0) {
throw new Error(getSparseFillEmptyRowsNegativeIndexErrorMessage(i, row));
}
if (row >= denseRows) {
throw new Error(getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(i, row, denseRows));
}
++csrOffset[row];
rowsAreOrdered = rowsAreOrdered && (row >= lastIndicesRow);
lastIndicesRow = row;
}
let allRowsFull = true;
for (let row = 0; row < denseRows; ++row) {
const rowEmpty = (csrOffset[row] === 0);
emptyRowIndicator[row] = rowEmpty;
allRowsFull = allRowsFull && !rowEmpty;
csrOffset[row] = Math.max(csrOffset[row], 1);
if (row > 0) {
csrOffset[row] += csrOffset[row - 1];
}
}
if (allRowsFull && rowsAreOrdered) {
const outputIndices = indices;
const outputValues = values;
for (let i = 0; i < indicesCount; ++i) {
reverseIndexMap[i] = i;
}
return [
outputIndices, [indicesCount, rank], outputValues, emptyRowIndicator,
reverseIndexMap
];
}
else {
const fullIndicesCount = csrOffset[denseRows - 1];
const outputIndices = getArrayFromDType(indicesDType, fullIndicesCount * rank);
const outputValues = getArrayFromDType(valuesDType, fullIndicesCount);
const filledCount = new Array(denseRows).fill(0);
for (let i = 0; i < indicesCount; ++i) {
const row = indices[i * rank];
const offset = filledCount[row];
const outputI = ((row === 0) ? 0 : csrOffset[row - 1]) + offset;
filledCount[row]++;
for (let j = 0; j < rank; ++j) {
outputIndices[outputI * rank + j] = indices[i * rank + j];
}
outputValues[outputI] = values[i];
reverseIndexMap[i] = outputI;
}
for (let row = 0; row < denseRows; ++row) {
const rowCount = filledCount[row];
if (rowCount === 0) {
const startingIndex = (row === 0) ? 0 : csrOffset[row - 1];
outputIndices[startingIndex * rank + 0] = row;
for (let col = 1; col < rank; ++col) {
outputIndices[startingIndex * rank + col] = 0;
}
outputValues[startingIndex] = defaultValue;
}
}
return [
outputIndices, [fullIndicesCount, rank], outputValues, emptyRowIndicator,
reverseIndexMap
];
}
}
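// SparseReshape: resolves at most one -1 output dimension, checks that the total
// element count is preserved and recomputes each nnz index under the new strides.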
function sparseReshapeImpl(inputIndices, inputIndicesShape, inputDType, inputShape, targetShape) {
const denseSize = sizeFromShape(inputShape);
const nnz = inputIndicesShape[0];
const outputRank = targetShape.length;
const outputShape = [];
let product = 1;
let unknownIndex = -1;
for (let d = 0; d < outputRank; ++d) {
const size = targetShape[d];
if (size === -1) {
if (unknownIndex !== -1) {
throw new Error(getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(unknownIndex, d));
}
unknownIndex = d;
outputShape.push(1);
}
else {
if (size < 0) {
throw new Error(getSparseReshapeNegativeOutputDimErrorMessage(d, size));
}
product *= size;
outputShape.push(size);
}
}
if (unknownIndex !== -1) {
if (product <= 0) {
throw new Error(getSparseReshapeEmptyTensorZeroOutputDimErrorMessage());
}
const missing = Math.trunc(denseSize / product);
if (product * missing !== denseSize) {
throw new Error(getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape));
}
outputShape[unknownIndex] = missing;
}
const outputSize = sizeFromShape(outputShape);
if (outputSize !== denseSize) {
throw new Error(getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape));
}
const inputRank = inputShape.length;
const inputStrides = [];
if (inputRank > 0) {
inputStrides[inputRank - 1] = 1;
for (let d = inputRank - 2; d >= 0; --d) {
inputStrides[d] = inputStrides[d + 1] * inputShape[d + 1];
}
}
const outputStrides = [];
if (outputRank > 0) {
outputStrides[outputRank - 1] = 1;
for (let d = outputRank - 2; d >= 0; --d) {
outputStrides[d] = outputStrides[d + 1] * outputShape[d + 1];
}
}
const newIndices = getArrayFromDType(inputDType, nnz * outputRank);
for (let i = 0; i < nnz; ++i) {
let id = 0;
for (let j = 0; j < inputRank; ++j) {
id += inputIndices[i * inputRank + j] * inputStrides[j];
}
for (let j = 0; j < outputRank; ++j) {
newIndices[i * outputRank + j] = Math.trunc(id / outputStrides[j]);
id %= outputStrides[j];
}
}
return [newIndices, [nnz, outputRank], outputShape];
}
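// Shared implementation of sparse segment sum/mean: rows selected by indices are
// accumulated per segment id (divided by the segment length when isMean), and segments
// with no entries are filled with defaultValue.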
function sparseSegmentReductionImpl(input, inputShape, inputDType, indices, segmentIds, isMean = false, defaultValue = 0) {
const numIndices = indices.length;
const inputFlat = [inputShape[0], input.length / inputShape[0]];
const numCol = inputFlat[1];
const lastSegmentIdPlusOne = numIndices > 0 ? segmentIds[numIndices - 1] + 1 : 0;
const outputRows = lastSegmentIdPlusOne;
if (outputRows < 0) {
throw new Error(getSparseSegmentReductionNegativeSegmentIdsErrorMessage());
}
const outputShape = inputShape.slice();
outputShape[0] = outputRows;
const outputLength = outputShape.reduce((product, value) => product * value, 1);
const output = getArrayFromDType(inputDType, outputLength);
if (numIndices === 0) {
if (outputRows > 0) {
output.fill(defaultValue);
}
return [output, outputShape];
}
if (outputRows <= 0) {
throw new Error(getSparseSegmentReductionNegativeSegmentIdsErrorMessage());
}
let start = 0, end = 1;
let uninitializedIndex = 0;
let outIndex = segmentIds[start];
while (true) {
let nextIndex = 0;
if (end < numIndices) {
nextIndex = segmentIds[end];
if (outIndex === nextIndex) {
++end;
continue;
}
if (outIndex >= nextIndex) {
throw new Error(getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage());
}
}
if (outIndex < 0 || outIndex >= outputRows) {
throw new Error(getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(outIndex, outputRows));
}
if (outIndex > uninitializedIndex) {
output.fill(defaultValue, uninitializedIndex * numCol, outIndex * numCol);
}
for (let i = start; i < end; ++i) {
const index = indices[i];
if (index < 0 || index >= inputFlat[0]) {
throw new Error(getSparseSegmentReductionIndicesOutOfRangeErrorMessage(i, indices[i], inputFlat[0]));
}
for (let j = 0; j < numCol; j++) {
output[outIndex * numCol + j] += input[index * numCol + j];
}
}
if (isMean) {
for (let j = 0; j < numCol; j++) {
output[outIndex * numCol + j] /= end - start;
}
}
start = end;
++end;
uninitializedIndex = outIndex + 1;
outIndex = nextIndex;
if (end > numIndices) {
break;
}
}
if (uninitializedIndex < outputRows) {
output.fill(defaultValue, uninitializedIndex * numCol, outputRows * numCol);
}
return [output, outputShape];
}
const sqrtImpl = createSimpleUnaryImpl((xi) => Math.sqrt(xi));
const sqrt$1 = unaryKernelFunc$1(Sqrt, (xi) => Math.sqrt(xi));
const sqrtConfig$1 = {
kernelName: Sqrt,
backendName: 'cpu',
kernelFunc: sqrt$1,
};
const squaredDifferenceImpl = createSimpleBinaryKernelImpl(((a, b) => {
const diff = a - b;
return diff * diff;
}));
const squaredDifference$1 = binaryKernelFunc$1(SquaredDifference, squaredDifferenceImpl);
const squaredDifferenceConfig$1 = {
kernelName: SquaredDifference,
backendName: 'cpu',
kernelFunc: squaredDifference$1
};
const staticRegexReplaceImpl = createSimpleUnaryImpl((x, attrs) => {
const { pattern, replaceGlobal, rewrite } = attrs;
return x.replace(new RegExp(pattern, replaceGlobal ? 'g' : ''), rewrite);
});
const staticRegexReplace$1 = unaryKernelFuncFromImpl(StaticRegexReplace, staticRegexReplaceImpl);
const staticRegexReplaceConfig$1 = {
kernelName: StaticRegexReplace,
backendName: 'cpu',
kernelFunc: staticRegexReplace$1,
};
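// StridedSlice: each output coordinate maps to begin[axis] + coordinate * strides[axis]
// in the input buffer.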
function stridedSliceImpl(outShape, xBuf, strides, begin) {
const outBuf = buffer(outShape, xBuf.dtype);
for (let i = 0; i < outBuf.size; i++) {
const loc = outBuf.indexToLoc(i);
const newLoc = new Array(loc.length);
for (let j = 0; j < newLoc.length; j++) {
newLoc[j] = loc[j] * strides[j] + begin[j];
}
outBuf.set(xBuf.get(...newLoc), ...loc);
}
return outBuf;
}
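// StringNGrams: builds padded n-grams (as Uint8Array-encoded strings) for each ragged
// row described by the splits vector, joining tokens with the separator.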
class StringNGramsOp {
constructor(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
this.separator = encodeString(separator);
this.nGramWidths = nGramWidths;
this.leftPad = encodeString(leftPad);
this.rightPad = encodeString(rightPad);
this.padWidth = padWidth;
this.preserveShort = preserveShortSequences;
}
getPadWidth(nGramWidth) {
return Math.min(this.padWidth < 0 ? nGramWidth - 1 : this.padWidth, nGramWidth - 1);
}
getNumNGrams(length, nGramWidth) {
const padWidth = this.getPadWidth(nGramWidth);
return Math.max(0, ((length + 2 * padWidth) - nGramWidth) + 1);
}
createNGrams(data, splitIndex, output, outputStartIndex, numNGrams, nGramWidth) {
for (let nGramIndex = 0; nGramIndex < numNGrams; ++nGramIndex) {
const padWidth = this.getPadWidth(nGramWidth);
const leftPadding = Math.max(0, padWidth - nGramIndex);
const rightPadding = Math.max(0, padWidth - (numNGrams - (nGramIndex + 1)));
const numTokens = nGramWidth - (leftPadding + rightPadding);
const dataStartIndex = splitIndex + (leftPadding > 0 ? 0 : nGramIndex - padWidth);
let nGramSize = 0;
nGramSize += leftPadding * this.leftPad.length;
for (let n = 0; n < numTokens; ++n) {
nGramSize += data[dataStartIndex + n].length;
}
nGramSize += rightPadding * this.rightPad.length;
const numSeparators = leftPadding + rightPadding + numTokens - 1;
nGramSize += numSeparators * this.separator.length;
output[outputStartIndex + nGramIndex] = new Uint8Array(nGramSize);
const nGram = output[outputStartIndex + nGramIndex];
let nextNGramIndex = 0;
const appendToNGram = (str) => str.forEach((value) => nGram[nextNGramIndex++] = value);
for (let n = 0; n < leftPadding; ++n) {
appendToNGram(this.leftPad);
appendToNGram(this.separator);
}
for (let n = 0; n < numTokens - 1; ++n) {
appendToNGram(data[dataStartIndex + n]);
appendToNGram(this.separator);
}
if (numTokens > 0) {
appendToNGram(data[dataStartIndex + numTokens - 1]);
for (let n = 0; n < rightPadding; ++n) {
appendToNGram(this.separator);
appendToNGram(this.rightPad);
}
}
else {
for (let n = 0; n < rightPadding - 1; ++n) {
appendToNGram(this.rightPad);
appendToNGram(this.separator);
}
appendToNGram(this.rightPad);
}
}
}
compute(data, splits) {
const inputDataSize = data.length;
const splitsSize = splits.length;
if (splitsSize > 0) {
let prevSplit = splits[0];
if (prevSplit !== 0) {
throw new Error(`First split value must be 0, got ${prevSplit}`);
}
for (let i = 1; i < splitsSize; ++i) {
let validSplits = splits[i] >= prevSplit;
validSplits = validSplits && (splits[i] <= inputDataSize);
if (!validSplits) {
throw new Error(`Invalid split value ${splits[i]}, must be in [${prevSplit}, ${inputDataSize}]`);
}
prevSplit = splits[i];
}
if (prevSplit !== inputDataSize) {
throw new Error(`Last split value must be data size. Expected ${inputDataSize}, got ${prevSplit}`);
}
}
const numBatchItems = splitsSize - 1;
const nGramsSplits = getArrayFromDType('int32', splitsSize);
if (inputDataSize === 0 || splitsSize === 0) {
const empty = new Array(inputDataSize);
for (let i = 0; i <= numBatchItems; ++i) {
nGramsSplits[i] = 0;
}
return [empty, nGramsSplits];
}
nGramsSplits[0] = 0;
for (let i = 1; i <= numBatchItems; ++i) {
const length = splits[i] - splits[i - 1];
let numNGrams = 0;
this.nGramWidths.forEach((nGramWidth) => {
numNGrams += this.getNumNGrams(length, nGramWidth);
});
if (this.preserveShort && length > 0 && numNGrams === 0) {
numNGrams = 1;
}
nGramsSplits[i] = nGramsSplits[i - 1] + numNGrams;
}
const nGrams = new Array(nGramsSplits[numBatchItems]);
for (let i = 0; i < numBatchItems; ++i) {
const splitIndex = splits[i];
let outputStartIdx = nGramsSplits[i];
this.nGramWidths.forEach((nGramWidth) => {
const length = splits[i + 1] - splits[i];
const numNGrams = this.getNumNGrams(length, nGramWidth);
this.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth);
outputStartIdx += numNGrams;
});
if (this.preserveShort && outputStartIdx === nGramsSplits[i]) {
const dataLength = splits[i + 1] - splits[i];
if (dataLength === 0) {
continue;
}
const nGramWidth = dataLength + 2 * this.padWidth;
const numNGrams = 1;
this.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth);
}
}
return [nGrams, nGramsSplits];
}
}
function stringNGramsImpl(data, dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
return new StringNGramsOp(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences)
.compute(data, dataSplits);
}
function split(str, delimiters, skipEmpty, result) {
if (!str.length) {
return;
}
if (delimiters.length === 0) {
for (let i = 0; i < str.length; ++i) {
result.push(str.subarray(i, i + 1));
}
return;
}
if (delimiters.length === 1) {
const delimiter = delimiters[0];
let f = str.indexOf(delimiter);
while (f !== -1) {
const token = str.subarray(0, f);
if (!skipEmpty || token.length !== 0) {
result.push(token);
}
str = str.subarray(f + 1);
f = str.indexOf(delimiter);
}
if (!skipEmpty || str.length !== 0) {
result.push(str);
}
return;
}
let tokenStart = 0;
for (let i = 0; i < str.length + 1; i++) {
if ((i === str.length) || (delimiters.indexOf(str[i]) !== -1)) {
const token = str.subarray(tokenStart, i);
if (!skipEmpty || token.length !== 0) {
result.push(token);
}
tokenStart = i + 1;
}
}
}
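// CPU implementation of the StringSplit kernel. `input` is an array of UTF-8
// encoded strings and `delimiter` the encoded delimiter characters. Returns the
// sparse representation [indices, values, denseShape].
// Illustrative example (inputs/outputs shown as plain strings for brevity):
//   stringSplitImpl(['a,b', 'c'], ',', true)
//   -> indices [0,0, 0,1, 1,0] (flattened pairs), values ['a','b','c'], shape [2, 2]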
function stringSplitImpl(input, delimiter, skipEmpty) {
const batchSize = input.length;
const tokens = [];
let outputSize = 0;
let maxNumEntries = 0;
const numIndices = new Array(batchSize);
for (let i = 0; i < batchSize; ++i) {
const prevTokensLength = tokens.length;
split(input[i], delimiter, skipEmpty, tokens);
const nEntries = tokens.length - prevTokensLength;
numIndices[i] = nEntries;
outputSize += nEntries;
maxNumEntries = Math.max(maxNumEntries, nEntries);
}
const indices = getArrayFromDType('int32', outputSize * 2);
const values = new Array(outputSize);
const shape = [batchSize, maxNumEntries];
let c = 0;
for (let i = 0; i < batchSize; ++i) {
for (let j = 0; j < numIndices[i]; ++j) {
indices[c * 2] = i;
indices[c * 2 + 1] = j;
values[c] = tokens[c];
++c;
}
}
return [indices, values, shape];
}
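// CPU implementation of StringToHashBucketFast: maps each encoded string to a
// bucket in [0, numBuckets) via its 64-bit fingerprint modulo numBuckets.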
function stringToHashBucketFastImpl(input, numBuckets) {
const output = getArrayFromDType('int32', input.length);
for (let i = 0; i < input.length; ++i) {
output[i] =
fingerPrint64(input[i]).modulo(numBuckets).getLowBitsUnsigned();
}
return output;
}
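// Element-wise subtraction for the CPU backend, with a separate path for
// complex64 tensors (real and imaginary parts subtracted independently).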
const subImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => aValue - bValue));
const subComplexImpl = createComplexBinaryKernelImpl(((aReal, aImag, bReal, bImag) => {
return { real: aReal - bReal, imag: aImag - bImag };
}));
const sub$1 = binaryKernelFunc$1(Sub, subImpl, subComplexImpl);
const subConfig$1 = {
kernelName: Sub,
backendName: 'cpu',
kernelFunc: sub$1
};
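// Tiles a TensorBuffer by repeating it `reps[i]` times along dimension i; each
// output location maps back to the source via modulo indexing.
// E.g. a 1-D buffer [1, 2] tiled with reps [2] yields [1, 2, 1, 2].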
function tileImpl(xBuf, reps) {
const newShape = new Array(xBuf.rank);
for (let i = 0; i < newShape.length; i++) {
newShape[i] = xBuf.shape[i] * reps[i];
}
const result = buffer(newShape, xBuf.dtype);
for (let i = 0; i < result.values.length; ++i) {
const newLoc = result.indexToLoc(i);
const originalLoc = new Array(xBuf.rank);
for (let j = 0; j < originalLoc.length; j++) {
originalLoc[j] = newLoc[j] % xBuf.shape[j];
}
const originalIndex = xBuf.locToIndex(originalLoc);
result.values[i] = xBuf.values[originalIndex];
}
return result;
}
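// comparePair orders {value, index} pairs by descending value (ties broken by
// ascending index). select$2 is an in-place Floyd-Rivest selection: after
// select$2(array, k) the k largest pairs occupy array[0..k-1] (not necessarily
// sorted among themselves).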
const comparePair = (a, b) => {
const valueDiff = b.value - a.value;
return valueDiff === 0 ? a.index - b.index : valueDiff;
};
function select$2(array, k, left = 0, right = array.length - 1) {
while (right > left) {
if (right - left > 600) {
const n = right - left + 1;
const i = k - left + 1;
const z = Math.log(n);
const s = 0.5 * Math.exp(2 * z / 3);
const sd = 0.5 * Math.sqrt(z * s * (n - s) / n) * Math.sign(i - n / 2);
const newLeft = Math.max(left, Math.floor(k - i * s / n + sd));
const newRight = Math.min(right, Math.floor(k + (n - i) * s / n + sd));
select$2(array, k, newLeft, newRight);
}
const t = array[k];
let i = left;
let j = right;
swap(array, left, k);
if (comparePair(array[right], t) > 0) {
swap(array, left, right);
}
while (i < j) {
swap(array, i, j);
i++;
j--;
while (comparePair(array[i], t) < 0) {
i = i + 1;
}
while (comparePair(array[j], t) > 0) {
j = j - 1;
}
}
if (comparePair(array[left], t) === 0) {
swap(array, left, j);
}
else {
j = j + 1;
swap(array, j, right);
}
if (j <= k) {
left = j + 1;
}
if (k <= j) {
right = j - 1;
}
}
}
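// CPU top-k along the last axis: for every batch row the k largest values and
// their indices are selected (via select$2) and optionally sorted descending.
// Illustrative example:
//   topKImpl(new Float32Array([1, 3, 2, 4]), [4], 'float32', 2, true)
//   -> values [4, 3], indices [3, 1] (as rank-1 buffers of length 2)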
function topKImpl(x, xShape, xDtype, k, sorted) {
const lastDim = xShape[xShape.length - 1];
const [batch, size] = [x.length / lastDim, lastDim];
const allTopKVals = getTypedArrayFromDType(xDtype, batch * k);
const allTopKIndices = getTypedArrayFromDType('int32', batch * k);
for (let b = 0; b < batch; b++) {
const offset = b * size;
const vals = x.subarray(offset, offset + size);
let valAndInd = new Array(vals.length);
vals.forEach((value, index) => valAndInd[index] = { value, index });
if (k < valAndInd.length) {
select$2(valAndInd, k);
valAndInd = valAndInd.slice(0, k);
}
if (sorted) {
valAndInd.sort(comparePair);
}
const outOffset = b * k;
const topKVals = allTopKVals.subarray(outOffset, outOffset + k);
const topKIndices = allTopKIndices.subarray(outOffset, outOffset + k);
for (let i = 0; i < k; i++) {
topKVals[i] = valAndInd[i].value;
topKIndices[i] = valAndInd[i].index;
}
}
const outputShape = xShape.slice();
outputShape[outputShape.length - 1] = k;
return [
buffer(outputShape, xDtype, allTopKVals),
buffer(outputShape, 'int32', allTopKIndices)
];
}
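// CPU implementation of Unique along `axis`: the tensor is viewed as
// [pre, shape[axis], post]; slices along the axis are keyed by their joined
// values, duplicates are dropped, and `indices` maps every original slice to
// its position in the de-duplicated output.
// E.g. values [1, 1, 2, 4, 4] with shape [5] -> outputValues [1, 2, 4],
// indices [0, 0, 1, 2, 2].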
function uniqueImpl(values, axis, shape, dtype) {
const $axis = parseAxisParam(axis, shape)[0];
const newShape = [1, shape[0], 1];
for (let i = 0; i < $axis; i++) {
newShape[0] *= shape[i];
}
newShape[1] = shape[$axis];
for (let i = $axis + 1; i < shape.length; i++) {
newShape[2] *= shape[i];
}
const uniqueElements = new Map();
const indices = new Int32Array(shape[$axis]);
const inputBuffer = new TensorBuffer(newShape, dtype, values);
const uniqueIndices = [];
const is1DTensor = newShape[0] === 1 && newShape[2] === 1;
for (let i = 0; i < shape[$axis]; i++) {
let element;
if (is1DTensor) {
element = values[i].toString();
}
else {
const axisValues = [];
for (let m = 0; m < newShape[0]; m++) {
for (let n = 0; n < newShape[2]; n++) {
axisValues.push(inputBuffer.get(m, i, n));
}
}
element = axisValues.join(',');
}
const existingIndex = uniqueElements.get(element);
if (existingIndex != null) {
indices[i] = existingIndex;
}
else {
const uniqueIndex = uniqueElements.size;
uniqueElements.set(element, uniqueIndex);
indices[i] = uniqueIndex;
uniqueIndices.push(i);
}
}
const outputTmpShape = newShape.slice();
outputTmpShape[1] = uniqueElements.size;
const outputBuffer = new TensorBuffer(outputTmpShape, dtype);
uniqueIndices.forEach((uniqueElementIndex, i) => {
for (let m = 0; m < newShape[0]; m++) {
for (let n = 0; n < newShape[2]; n++) {
outputBuffer.set(inputBuffer.get(m, uniqueElementIndex, n), m, i, n);
}
}
});
const outputShape = shape.slice();
outputShape[$axis] = outputTmpShape[1];
return {
outputValues: outputBuffer.values,
outputShape,
indices,
};
}
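// The CPU kernel implementations above are re-exported as `shared` so that
// other backends (here the WebGL backend, see the destructuring below) can
// reuse them, e.g. for small tensors that are cheaper to compute on the CPU
// than to upload to the GPU.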
var shared = Object.freeze({
__proto__: null,
addImpl: addImpl,
bincountImpl: bincountImpl,
bincountReduceImpl: bincountReduceImpl,
bitwiseAndImpl: bitwiseAndImpl,
castImpl: castImpl,
ceilImpl: ceilImpl,
concatImpl: concatImpl$1,
equalImpl: equalImpl,
expImpl: expImpl,
expm1Impl: expm1Impl,
floorDivImpl: floorDivImpl,
floorImpl: floorImpl,
gatherNdImpl: gatherNdImpl,
gatherV2Impl: gatherV2Impl,
greaterEqualImpl: greaterEqualImpl,
greaterImpl: greaterImpl,
lessEqualImpl: lessEqualImpl,
lessImpl: lessImpl,
linSpaceImpl: linSpaceImpl,
logImpl: logImpl,
maxImpl: maxImpl$1,
maximumImpl: maximumImpl,
minimumImpl: minimumImpl,
multiplyImpl: multiplyImpl,
negImpl: negImpl,
notEqualImpl: notEqualImpl,
prodImpl: prodImpl,
raggedGatherImpl: raggedGatherImpl,
raggedRangeImpl: raggedRangeImpl,
raggedTensorToTensorImpl: raggedTensorToTensorImpl,
rangeImpl: rangeImpl,
rsqrtImpl: rsqrtImpl,
scatterImpl: scatterImpl,
sigmoidImpl: sigmoidImpl,
simpleAbsImpl: simpleAbsImpl,
sliceImpl: sliceImpl,
sparseFillEmptyRowsImpl: sparseFillEmptyRowsImpl,
sparseReshapeImpl: sparseReshapeImpl,
sparseSegmentReductionImpl: sparseSegmentReductionImpl,
sqrtImpl: sqrtImpl,
squaredDifferenceImpl: squaredDifferenceImpl,
staticRegexReplaceImpl: staticRegexReplaceImpl,
stridedSliceImpl: stridedSliceImpl,
stringNGramsImpl: stringNGramsImpl,
stringSplitImpl: stringSplitImpl,
stringToHashBucketFastImpl: stringToHashBucketFastImpl,
subImpl: subImpl,
tileImpl: tileImpl,
topKImpl: topKImpl,
transposeImpl: transposeImpl$1,
uniqueImpl: uniqueImpl
});
const { addImpl: addImplCPU, bincountImpl: bincountImplCPU, bincountReduceImpl: bincountReduceImplCPU, bitwiseAndImpl: bitwiseAndImplCPU, castImpl: castImplCPU, ceilImpl: ceilImplCPU, concatImpl: concatImplCPU, equalImpl: equalImplCPU, expImpl: expImplCPU, expm1Impl: expm1ImplCPU, floorImpl: floorImplCPU, gatherNdImpl: gatherNdImplCPU, gatherV2Impl: gatherV2ImplCPU, greaterImpl: greaterImplCPU, greaterEqualImpl: greaterEqualImplCPU, lessImpl: lessImplCPU, lessEqualImpl: lessEqualImplCPU, linSpaceImpl: linSpaceImplCPU, logImpl: logImplCPU, maxImpl: maxImplCPU, maximumImpl: maximumImplCPU, minimumImpl: minimumImplCPU, multiplyImpl: multiplyImplCPU, negImpl: negImplCPU, notEqualImpl: notEqualImplCPU, prodImpl: prodImplCPU, raggedGatherImpl: raggedGatherImplCPU, raggedRangeImpl: raggedRangeImplCPU, raggedTensorToTensorImpl: raggedTensorToTensorImplCPU, rangeImpl: rangeImplCPU, rsqrtImpl: rsqrtImplCPU, scatterImpl: scatterImplCPU, sigmoidImpl: sigmoidImplCPU, simpleAbsImpl: simpleAbsImplCPU, sliceImpl: sliceImplCPU, sparseFillEmptyRowsImpl: sparseFillEmptyRowsImplCPU, sparseReshapeImpl: sparseReshapeImplCPU, sparseSegmentReductionImpl: sparseSegmentReductionImplCPU, sqrtImpl: sqrtImplCPU, staticRegexReplaceImpl: staticRegexReplaceImplCPU, stridedSliceImpl: stridedSliceImplCPU, stringNGramsImpl: stringNGramsImplCPU, stringSplitImpl: stringSplitImplCPU, stringToHashBucketFastImpl: stringToHashBucketFastImplCPU, subImpl: subImplCPU, tileImpl: tileImplCPU, topKImpl: topKImplCPU, transposeImpl: transposeImplCPU, uniqueImpl: uniqueImplCPU, } = shared;
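// GLSL helpers used by the packed WebGL programs below: getChannels maps a
// coordinate variable of the given rank to its vector components
// ('rc.x', 'rc.y', ...); getSourceCoords$2 joins the dimension names into a
// comma-separated argument list.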
function getVecChannels(name, rank) {
return ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank).map(d => `${name}.${d}`);
}
function getChannels(name, rank) {
if (rank === 1) {
return [name];
}
return getVecChannels(name, rank);
}
function getSourceCoords$2(rank, dims) {
if (rank === 1) {
return 'rc';
}
let coords = '';
for (let i = 0; i < rank; i++) {
coords += dims[i];
if (i < rank - 1) {
coords += ',';
}
}
return coords;
}
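// Shader program that converts an unpacked texture (one value per texel) into
// the packed layout (a 2x2 block of values per RGBA texel), zero-filling the
// lanes that fall outside the logical output shape.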
class PackProgram {
constructor(outputShape) {
this.variableNames = ['A'];
this.packedInputs = false;
this.packedOutput = true;
this.outputShape = outputShape;
this.rank = outputShape.length;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
if (this.rank === 0) {
this.userCode = `
void main() {
setOutput(vec4(getA(), 0., 0., 0.));
}
`;
}
else {
const channels = getChannels('rc', this.rank);
const dtype = getCoordsDataType(this.rank);
const outOfBoundsCondition = this.getOutOfBoundsCondition(channels);
const setup = this.getSetup(channels);
const output = this.getOutput(channels);
this.userCode = `
void main() {
${dtype} rc = getOutputCoords();
if(${outOfBoundsCondition}) {
setOutput(vec4(0));
} else {
${setup}
setOutput(vec4(${output}));
}
}
`;
}
}
getSourceCoordsArr(dims) {
const coords = [];
for (let row = 0; row <= 1; row++) {
for (let col = 0; col <= 1; col++) {
let coord = `${row === 0 ? 'r' : 'rp1'}, ${col === 0 ? 'c' : 'cp1'}`;
for (let d = 2; d < this.rank; d++) {
coord = `${dims[dims.length - 1 - d]},` + coord;
}
coords.push(coord);
}
}
return coords;
}
getOutOfBoundsCondition(dims) {
if (this.rank === 1) {
return `rc > ${this.enableShapeUniforms ? 'outShape' : this.outputShape[0]}`;
}
let cond = '';
for (let i = this.rank - 2; i < this.rank; i++) {
cond += `${dims[i]} >= ${this.enableShapeUniforms ? `outShape[${i}]` : this.outputShape[i]}`;
if (i < this.rank - 1) {
cond += '||';
}
}
return cond;
}
getSetup(dims) {
if (this.rank === 1) {
return '';
}
const innerDims = dims.slice(-2);
const col = this.enableShapeUniforms ? `outShape[${this.rank} - 1]` :
this.outputShape[this.rank - 1];
const row = this.enableShapeUniforms ? `outShape[${this.rank} - 2]` :
this.outputShape[this.rank - 2];
return `
int r = ${innerDims[0]};
int c = ${innerDims[1]};
int rp1 = r + 1;
int cp1 = c + 1;
bool cEdge = cp1 >= ${col};
bool rEdge = rp1 >= ${row};
`;
}
getOutput(dims) {
const sourceCoords = this.getSourceCoordsArr(dims);
if (this.rank === 1) {
const outShape = this.enableShapeUniforms ? 'outShape' : this.outputShape[0];
return `getA(rc), (rc + 1 >= ${outShape} ? 0. : getA(rc + 1)), 0, 0`;
}
return `getA(${sourceCoords[0]}),
cEdge ? 0. : getA(${sourceCoords[1]}),
rEdge ? 0. : getA(${sourceCoords[2]}),
rEdge || cEdge ? 0. : getA(${sourceCoords[3]})`;
}
}
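// Shader program that reshapes a packed 3-D texture without unpacking it: for
// each of the four output channels it computes the flat index, maps it back to
// input coordinates and samples the corresponding channel of the input texel.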
class ReshapePackedProgram {
constructor(outputShape, inputShape) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
this.customUniforms = [{ name: 'inputShape', type: 'ivec3' }];
this.outputShape = outputShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
let mainLoop = ``;
for (let i = 0; i < 4; i++) {
let thisRC = `thisRC = rc;`;
if (i % 2 === 1) {
thisRC += `thisRC.z += 1;`;
}
if (i > 1) {
thisRC += `thisRC.y += 1;`;
}
mainLoop += `
${thisRC}
${i > 0 ? `if(thisRC.y < rows && thisRC.z < cols){` : ''}
int flatIndex = getFlatIndex(thisRC);
ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);
vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));
result[${i}] =
getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);
${i > 0 ? '}' : ''}
`;
}
this.userCode = `
${getReshapedInputCoords(inputShape, this.enableShapeUniforms)}
${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() :
getFlatIndexFrom3D(outputShape)}
void main() {
ivec3 rc = getOutputCoords();
vec4 result = vec4(0.);
ivec3 thisRC;
int rows = ${this.enableShapeUniforms ? 'outShape[1]' : outputShape[1]};
int cols = ${this.enableShapeUniforms ? 'outShape[2]' : outputShape[2]};
${mainLoop}
setOutput(result);
}
`;
}
}
function getReshapedInputCoords(shape, enableShapeUniforms) {
const coordsFromIndexSnippet = enableShapeUniforms ?
getLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], 'inputShape') :
getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
return `
ivec3 inputCoordsFromReshapedOutCoords(int index) {
${coordsFromIndexSnippet}
return ivec3(r, c, d);
}
`;
}
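// Pools WebGL textures by shape and physical type so they can be reused across
// kernels. acquireTexture serves from the free list when possible;
// releaseTexture either returns a texture to the pool or deletes it once the
// total number of allocated bytes exceeds WEBGL_DELETE_TEXTURE_THRESHOLD.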
class TextureManager {
constructor(gpgpu) {
this.gpgpu = gpgpu;
this.numUsedTextures = 0;
this.numFreeTextures = 0;
this._numBytesAllocated = 0;
this._numBytesFree = 0;
this.freeTextures = {};
this.usedTextures = {};
this.logEnabled = false;
}
acquireTexture(shapeRC, usage, isPacked) {
const physicalTexType = getPhysicalFromLogicalTextureType(usage, isPacked);
const shapeKey = getKeyFromTextureShape(shapeRC, physicalTexType, isPacked);
if (!(shapeKey in this.freeTextures)) {
this.freeTextures[shapeKey] = [];
}
if (!(shapeKey in this.usedTextures)) {
this.usedTextures[shapeKey] = [];
}
const texBytes = computeBytes(shapeRC, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
if (this.freeTextures[shapeKey].length > 0) {
this.numFreeTextures--;
this.numUsedTextures++;
this._numBytesFree -= texBytes;
this.log();
const newTexture = this.freeTextures[shapeKey].pop();
this.usedTextures[shapeKey].push(newTexture);
return newTexture;
}
let newTexture;
if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT32) {
newTexture = this.gpgpu.createPackedMatrixTexture(shapeRC[0], shapeRC[1]);
}
else if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT16) {
newTexture =
this.gpgpu.createFloat16PackedMatrixTexture(shapeRC[0], shapeRC[1]);
}
else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT32) {
newTexture =
this.gpgpu.createFloat32MatrixTexture(shapeRC[0], shapeRC[1]);
}
else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT16) {
newTexture =
this.gpgpu.createFloat16MatrixTexture(shapeRC[0], shapeRC[1]);
}
else if (physicalTexType === PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE) {
newTexture =
this.gpgpu.createUnsignedBytesMatrixTexture(shapeRC[0], shapeRC[1]);
}
this.usedTextures[shapeKey].push(newTexture);
this.numUsedTextures++;
this._numBytesAllocated += texBytes;
this.log();
return newTexture;
}
releaseTexture(texture, shape, logicalTexType, isPacked) {
if (this.freeTextures == null) {
return;
}
const physicalTexType = getPhysicalFromLogicalTextureType(logicalTexType, isPacked);
const shapeKey = getKeyFromTextureShape(shape, physicalTexType, isPacked);
if (!(shapeKey in this.freeTextures)) {
this.freeTextures[shapeKey] = [];
}
const texBytes = computeBytes(shape, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
const deleteTexThreshold = env()
.getNumber('WEBGL_DELETE_TEXTURE_THRESHOLD');
if (deleteTexThreshold !== -1 &&
this._numBytesAllocated > deleteTexThreshold) {
this.gpgpu.deleteMatrixTexture(texture.texture);
this._numBytesAllocated -= texBytes;
}
else {
this.freeTextures[shapeKey].push(texture);
this.numFreeTextures++;
this._numBytesFree += texBytes;
}
this.numUsedTextures--;
const texList = this.usedTextures[shapeKey];
const texIndex = texList && texList.indexOf(texture);
if (texIndex == null || texIndex < 0) {
throw new Error('Cannot release a texture that was never provided by this ' +
'texture manager');
}
texList[texIndex] = texList[texList.length - 1];
texList.pop();
this.log();
}
log() {
if (!this.logEnabled) {
return;
}
const total = this.numFreeTextures + this.numUsedTextures;
console.log('Free/Used', `${this.numFreeTextures} / ${this.numUsedTextures}`, `(${total})`);
const freeRatio = this._numBytesFree / this._numBytesAllocated;
console.log(`Bytes allocated: ${this._numBytesAllocated}`);
console.log(`Bytes unused: ${this._numBytesFree} (${Math.round(100 * freeRatio)}%)`);
}
get numBytesAllocated() {
return this._numBytesAllocated;
}
get numBytesFree() {
return this._numBytesFree;
}
getNumUsedTextures() {
return this.numUsedTextures;
}
getNumFreeTextures() {
return this.numFreeTextures;
}
dispose() {
if (this.freeTextures == null) {
return;
}
for (const texShape in this.freeTextures) {
this.freeTextures[texShape].forEach(tex => {
this.gpgpu.deleteMatrixTexture(tex.texture);
});
}
for (const texShape in this.usedTextures) {
this.usedTextures[texShape].forEach(tex => {
this.gpgpu.deleteMatrixTexture(tex.texture);
});
}
this.freeTextures = null;
this.usedTextures = null;
this.numUsedTextures = 0;
this.numFreeTextures = 0;
this._numBytesAllocated = 0;
this._numBytesFree = 0;
}
}
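// Helpers for the texture bookkeeping above: bytes per texel for each internal
// format, total bytes for a logical shape, and the mapping from logical texture
// usage to a physical texture type.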
function numBytesForInternalFormat(gl, internalFormat) {
const glany = gl;
if (internalFormat === glany.R32F) {
return 4;
}
else if (internalFormat === glany.R16F) {
return 2;
}
else if (internalFormat === glany.RGBA32F) {
return 16;
}
else if (internalFormat === gl.RGBA) {
return 16;
}
else if (internalFormat === glany.RGBA16F) {
return 8;
}
else if (internalFormat === glany.RGBA8) {
return 4;
}
throw new Error(`Unknown internal format ${internalFormat}`);
}
function computeBytes(shape, physicalTexType, gl, textureConfig, isPacked) {
const internalFormat = internalFormatForPhysicalTexType(physicalTexType, textureConfig);
let numElements;
if (isPacked) {
const [packedWidth, packedHeight] = getPackedMatrixTextureShapeWidthHeight(shape[0], shape[1]);
numElements = packedWidth * packedHeight;
}
else {
const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(shape[0], shape[1]);
numElements = width * height;
}
const bytesPerElement = numBytesForInternalFormat(gl, internalFormat);
return numElements * bytesPerElement;
}
function internalFormatForPhysicalTexType(physicalTexType, textureConfig) {
switch (physicalTexType) {
case PhysicalTextureType.PACKED_2X2_FLOAT32:
return getInternalFormatForPackedMatrixTexture(textureConfig);
case PhysicalTextureType.PACKED_2X2_FLOAT16:
return getInternalFormatForFloat16PackedMatrixTexture(textureConfig);
case PhysicalTextureType.UNPACKED_FLOAT32:
return getInternalFormatForFloat32MatrixTexture(textureConfig);
case PhysicalTextureType.UNPACKED_FLOAT16:
return getInternalFormatForFloat16MatrixTexture(textureConfig);
case PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE:
return getInternalFormatForUnsignedBytesMatrixTexture(textureConfig);
default:
throw new Error(`Unknown physical texture type ${physicalTexType}`);
}
}
function getPhysicalTextureForRendering(isPacked) {
if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) {
if (isPacked) {
return PhysicalTextureType.PACKED_2X2_FLOAT32;
}
return PhysicalTextureType.UNPACKED_FLOAT32;
}
if (isPacked) {
return PhysicalTextureType.PACKED_2X2_FLOAT16;
}
return PhysicalTextureType.UNPACKED_FLOAT16;
}
function getPhysicalFromLogicalTextureType(logicalTexType, isPacked) {
if (logicalTexType === TextureUsage.UPLOAD) {
return PhysicalTextureType.PACKED_2X2_FLOAT32;
}
else if (logicalTexType === TextureUsage.RENDER || logicalTexType == null) {
return getPhysicalTextureForRendering(isPacked);
}
else if (logicalTexType === TextureUsage.DOWNLOAD ||
logicalTexType === TextureUsage.PIXELS) {
return PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE;
}
throw new Error(`Unknown logical texture type ${logicalTexType}`);
}
function getKeyFromTextureShape(shapeRowsCol, physicalTexType, isPacked) {
return `${shapeRowsCol[0]}_${shapeRowsCol[1]}_${physicalTexType}_${isPacked}`;
}
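// Generic element-wise shader programs: UnaryOpProgram (and its packed variant
// below) wraps a GLSL snippet operating on a single float / vec4. The constants
// that follow are the snippets for common unary ops and activations
// (abs, relu, relu6, elu, sigmoid, ...).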
class UnaryOpProgram {
constructor(aShape, opSnippet) {
this.variableNames = ['A'];
this.outputShape = aShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
this.userCode = `
float unaryOperation(float x) {
${opSnippet}
}
void main() {
float x = getAAtOutCoords();
float y = unaryOperation(x);
setOutput(y);
}
`;
}
}
const CHECK_NAN_SNIPPET$1 = `if (isnan(x)) return x;`;
const LINEAR$1 = `return x;`;
const ABS$1 = `return abs(x);`;
const ELU$2 = `return (x >= 0.0) ? x : (exp(x) - 1.0);`;
const RELU$2 = CHECK_NAN_SNIPPET$1 + `
return (x < 0.0) ? 0.0 : x;
`;
const RELU6$2 = CHECK_NAN_SNIPPET$1 + `
return (x < 0.0) ? 0.0 : min(6.0, x);
`;
const CLONE = 'return x;';
const SIGMOID$2 = `return 1.0 / (1.0 + exp(-1.0 * x));`;
const LINEAR = `return x;`;
const ELU$1 = `
vec4 result;
result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);
result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);
result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);
result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);
return result;
`;
const RELU$1 = `
vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));
bvec4 isNaN = isnan(x);
result.r = isNaN.r ? x.r : result.r;
result.g = isNaN.g ? x.g : result.g;
result.b = isNaN.b ? x.b : result.b;
result.a = isNaN.a ? x.a : result.a;
return result;
`;
const RELU6$1 = `
vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));
bvec4 isNaN = isnan(x);
result.r = isNaN.r ? x.r : result.r;
result.g = isNaN.g ? x.g : result.g;
result.b = isNaN.b ? x.b : result.b;
result.a = isNaN.a ? x.a : result.a;
return result;
`;
const SIGMOID$1 = `return 1.0 / (1.0 + exp(-1.0 * x));`;
class UnaryOpPackedProgram {
constructor(aShape, opSnippet) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = aShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
this.userCode = `
vec4 unaryOperation(vec4 x) {
${opSnippet}
}
void main() {
vec4 x = getAAtOutCoords();
vec4 y = unaryOperation(x);
setOutput(y);
}
`;
}
}
class UnpackProgram {
constructor(outputShape) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = false;
this.outputShape = outputShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
const rank = outputShape.length;
const channels = getChannels('rc', rank);
const dtype = getCoordsDataType(rank);
const sourceCoords = getSourceCoords$2(rank, channels);
const innerDims = channels.slice(-2);
const coords = rank <= 1 ? 'rc' : `vec2(${innerDims.join(',')})`;
this.userCode = `
void main() {
${dtype} rc = getOutputCoords();
vec4 packedInput = getA(${sourceCoords});
setOutput(getChannel(packedInput, ${coords}));
}
`;
}
}
const whereImpl$1 = whereImpl$2;
const EPSILON_FLOAT32 = 1e-7;
const EPSILON_FLOAT16 = 1e-4;
const binaryCaches = {};
function getBinaryCache(webGLVersion) {
if (webGLVersion in binaryCaches) {
return binaryCaches[webGLVersion];
}
binaryCaches[webGLVersion] = {};
return binaryCaches[webGLVersion];
}
const CPU_HANDOFF_SIZE_THRESHOLD = env().getNumber('CPU_HANDOFF_SIZE_THRESHOLD');
const BEFORE_PAGING_CONSTANT = 600;
function numMBBeforeWarning() {
if (env().global.screen == null) {
return 1024;
}
return (env().global.screen.height * env().global.screen.width *
window.devicePixelRatio) *
BEFORE_PAGING_CONSTANT / 1024 / 1024;
}
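// The WebGL kernel backend. Tensor data lives in `texData`; values are uploaded
// lazily to textures (uploadToGPU), compiled shaders are cached per shape
// signature in `binaryCache`, and small inputs may be handed off to the CPU
// implementations imported above (shouldExecuteOnCPU).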
class MathBackendWebGL extends KernelBackend {
nextDataId() {
return MathBackendWebGL.nextDataId++;
}
constructor(gpuResource) {
super();
this.pendingRead = new WeakMap();
this.pendingDisposal = new WeakSet();
this.dataRefCount = new WeakMap();
this.numBytesInGPU = 0;
this.uploadWaitMs = 0;
this.downloadWaitMs = 0;
this.lastGlFlushTime = 0;
this.warnedAboutMemory = false;
this.pendingDeletes = 0;
this.disposed = false;
if (!env().getBool('HAS_WEBGL')) {
throw new Error('WebGL is not supported on this device');
}
let newGPGPU;
if (gpuResource != null) {
if (gpuResource instanceof GPGPUContext) {
newGPGPU = gpuResource;
}
else {
const gl = getWebGLContext(env().getNumber('WEBGL_VERSION'), gpuResource);
newGPGPU = new GPGPUContext(gl);
}
this.binaryCache = {};
this.gpgpuCreatedLocally = false;
}
else {
const gl = getWebGLContext(env().getNumber('WEBGL_VERSION'));
newGPGPU = new GPGPUContext(gl);
this.binaryCache = getBinaryCache(env().getNumber('WEBGL_VERSION'));
this.gpgpuCreatedLocally = true;
}
this.gpgpu = newGPGPU;
this.canvas = this.gpgpu.gl.canvas;
this.textureManager = new TextureManager(this.gpgpu);
this.numMBBeforeWarning = numMBBeforeWarning();
this.texData = new DataStorage(this, engine());
}
numDataIds() {
return this.texData.numDataIds() - this.pendingDeletes;
}
writeTexture(texture, shape, dtype, texHeight, texWidth, channels) {
const input = this.makeTensorInfo(shape, dtype);
const inData = this.texData.get(input.dataId);
inData.isPacked = false;
inData.texture = { texture, texShape: [texHeight, texWidth] };
inData.texShape = [texHeight, texWidth];
const shapeAs3D = getShapeAs3D(shape);
const program = new EncodeMatrixProgram(shapeAs3D, false, channels);
const output = this.runWebGLProgram(program, [input], dtype, [[texHeight, texWidth]]);
output.shape = shape;
inData.texture = null;
this.disposeIntermediateTensorInfo(input);
return output.dataId;
}
write(values, shape, dtype) {
if (env().getBool('WEBGL_CHECK_NUMERICAL_PROBLEMS') ||
env().getBool('DEBUG')) {
this.checkNumericalProblems(values);
}
if (dtype === 'complex64' && values != null) {
throw new Error(`Cannot write to a complex64 dtype. ` +
`Please use tf.complex(real, imag).`);
}
const dataId = { id: this.nextDataId() };
this.texData.set(dataId, { shape, dtype, values, usage: TextureUsage.UPLOAD, refCount: 1 });
return dataId;
}
refCount(dataId) {
if (this.texData.has(dataId)) {
const tensorData = this.texData.get(dataId);
return tensorData.refCount;
}
return 0;
}
incRef(dataId) {
const texData = this.texData.get(dataId);
texData.refCount++;
}
decRef(dataId) {
if (this.texData.has(dataId)) {
const texData = this.texData.get(dataId);
texData.refCount--;
}
}
move(dataId, values, shape, dtype, refCount) {
if (env().getBool('DEBUG')) {
this.checkNumericalProblems(values);
}
if (dtype === 'complex64') {
throw new Error(`Cannot write to a complex64 dtype. ` +
`Please use tf.complex(real, imag).`);
}
this.texData.set(dataId, { shape, dtype, values, usage: TextureUsage.UPLOAD, refCount });
}
disposeIntermediateTensorInfo(tensorInfo) {
this.disposeData(tensorInfo.dataId);
}
readSync(dataId) {
const texData = this.texData.get(dataId);
const { values, dtype, complexTensorInfos, slice, shape, isPacked } = texData;
if (slice != null) {
let program;
if (isPacked) {
program = new UnaryOpPackedProgram(shape, CLONE);
}
else {
program = new UnaryOpProgram(shape, CLONE);
}
const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype);
const data = this.readSync(res.dataId);
this.disposeIntermediateTensorInfo(res);
return data;
}
if (values != null) {
return this.convertAndCacheOnCPU(dataId);
}
if (dtype === 'string') {
return values;
}
const shouldTimeProgram = this.activeTimers != null;
let start;
if (shouldTimeProgram) {
start = now();
}
let result;
if (dtype === 'complex64') {
const realValues = this.readSync(complexTensorInfos.real.dataId);
const imagValues = this.readSync(complexTensorInfos.imag.dataId);
result = mergeRealAndImagArrays(realValues, imagValues);
}
else {
result = this.getValuesFromTexture(dataId);
}
if (shouldTimeProgram) {
this.downloadWaitMs += now() - start;
}
return this.convertAndCacheOnCPU(dataId, result);
}
async read(dataId) {
if (this.pendingRead.has(dataId)) {
const subscribers = this.pendingRead.get(dataId);
return new Promise(resolve => subscribers.push(resolve));
}
const texData = this.texData.get(dataId);
const { values, shape, slice, dtype, complexTensorInfos, isPacked } = texData;
if (slice != null) {
let program;
if (isPacked) {
program = new UnaryOpPackedProgram(shape, CLONE);
}
else {
program = new UnaryOpProgram(shape, CLONE);
}
const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype);
const data = this.read(res.dataId);
this.disposeIntermediateTensorInfo(res);
return data;
}
if (values != null) {
return this.convertAndCacheOnCPU(dataId);
}
if (env().getBool('DEBUG')) {
if (!env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') &&
env().getNumber('WEBGL_VERSION') === 2) {
throw new Error(`tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and ` +
`WEBGL_VERSION=2 not yet supported.`);
}
}
let buffer = null;
let tmpDownloadTarget;
if (dtype !== 'complex64' && env().get('WEBGL_BUFFER_SUPPORTED')) {
tmpDownloadTarget = this.decode(dataId);
const tmpData = this.texData.get(tmpDownloadTarget.dataId);
buffer = this.gpgpu.createBufferFromTexture(tmpData.texture.texture, ...getDenseTexShape(shape));
}
this.pendingRead.set(dataId, []);
if (dtype !== 'complex64') {
await this.gpgpu.createAndWaitForFence();
}
let vals;
if (dtype === 'complex64') {
const ps = await Promise.all([
this.read(complexTensorInfos.real.dataId),
this.read(complexTensorInfos.imag.dataId)
]);
const realValues = ps[0];
const imagValues = ps[1];
vals = mergeRealAndImagArrays(realValues, imagValues);
}
else if (buffer == null) {
vals = this.getValuesFromTexture(dataId);
}
else {
const size = sizeFromShape(shape);
vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size);
}
if (tmpDownloadTarget != null) {
this.disposeIntermediateTensorInfo(tmpDownloadTarget);
}
if (buffer != null) {
const gl = this.gpgpu.gl;
callAndCheck(gl, () => gl.deleteBuffer(buffer));
}
const dTypeVals = this.convertAndCacheOnCPU(dataId, vals);
const subscribers = this.pendingRead.get(dataId);
this.pendingRead.delete(dataId);
subscribers.forEach(resolve => resolve(dTypeVals));
if (this.pendingDisposal.has(dataId)) {
this.pendingDisposal.delete(dataId);
if (this.disposeData(dataId)) {
engine().removeDataId(dataId, this);
}
this.pendingDeletes--;
}
return dTypeVals;
}
readToGPU(dataId, options = {}) {
const texData = this.texData.get(dataId);
const { values, shape, slice, dtype, isPacked, texture } = texData;
if (dtype === 'complex64') {
throw new Error('Does not support reading texture for complex64 dtype.');
}
if (slice != null) {
let program;
if (isPacked) {
program = new UnaryOpPackedProgram(shape, CLONE);
}
else {
program = new UnaryOpProgram(shape, CLONE);
}
const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype);
const gpuResource = this.readToGPU(res, options);
this.disposeIntermediateTensorInfo(res);
return gpuResource;
}
if (texture == null) {
if (values != null) {
throw new Error('Data is not on GPU but on CPU.');
}
else {
throw new Error('There is no data on GPU or CPU.');
}
}
const tmpTarget = this.decode(dataId, options.customTexShape);
const tensorRef = engine().makeTensorFromTensorInfo(tmpTarget);
const tmpData = this.texData.get(tmpTarget.dataId);
return Object.assign({ tensorRef }, tmpData.texture);
}
bufferSync(t) {
const data = this.readSync(t.dataId);
if (t.dtype === 'string') {
try {
const strings = data.map(d => decodeString(d));
return buffer(t.shape, t.dtype, strings);
}
catch (_a) {
throw new Error('Failed to decode encoded string bytes into utf-8');
}
}
return buffer(t.shape, t.dtype, data);
}
checkNumericalProblems(values) {
if (values == null) {
return;
}
for (let i = 0; i < values.length; i++) {
const num = values[i];
if (!canBeRepresented(num)) {
if (env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) {
throw Error(`The value ${num} cannot be represented with your ` +
`current settings. Consider enabling float32 rendering: ` +
`'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'`);
}
throw Error(`The value ${num} cannot be represented on this device.`);
}
}
}
getValuesFromTexture(dataId) {
const { shape, dtype, isPacked } = this.texData.get(dataId);
const size = sizeFromShape(shape);
if (env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) {
const tmpTarget = this.decode(dataId);
const tmpData = this.texData.get(tmpTarget.dataId);
const vals = this.gpgpu
.downloadMatrixFromPackedTexture(tmpData.texture.texture, ...getDenseTexShape(shape))
.subarray(0, size);
this.disposeIntermediateTensorInfo(tmpTarget);
return vals;
}
const shouldUsePackedProgram = env().getBool('WEBGL_PACK') && isPacked === true;
const outputShape = shouldUsePackedProgram ? getShapeAs3D(shape) : shape;
const program = shouldUsePackedProgram ?
new EncodeFloatPackedProgram(outputShape) :
new EncodeFloatProgram(outputShape);
const output = this.runWebGLProgram(program, [{ shape: outputShape, dtype, dataId }], 'float32');
const tmpData = this.texData.get(output.dataId);
const vals = this.gpgpu
.downloadByteEncodedFloatMatrixFromOutputTexture(tmpData.texture.texture, tmpData.texShape[0], tmpData.texShape[1])
.subarray(0, size);
this.disposeIntermediateTensorInfo(output);
return vals;
}
timerAvailable() {
return env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0;
}
time(f) {
const oldActiveTimers = this.activeTimers;
const newActiveTimers = [];
let outerMostTime = false;
if (this.programTimersStack == null) {
this.programTimersStack = newActiveTimers;
outerMostTime = true;
}
else {
this.activeTimers.push(newActiveTimers);
}
this.activeTimers = newActiveTimers;
f();
const flattenedActiveTimerQueries = flatten$1(this.activeTimers.map((d) => d.query))
.filter(d => d != null);
const flattenedActiveTimerNames = flatten$1(this.activeTimers.map((d) => d.name))
.filter(d => d != null);
this.activeTimers = oldActiveTimers;
if (outerMostTime) {
this.programTimersStack = null;
}
const res = {
uploadWaitMs: this.uploadWaitMs,
downloadWaitMs: this.downloadWaitMs,
kernelMs: null,
wallMs: null
};
return (async () => {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') >
0) {
const kernelMs = await Promise.all(flattenedActiveTimerQueries);
res['kernelMs'] = sum$3(kernelMs);
res['getExtraProfileInfo'] = () => kernelMs
.map((d, i) => ({ name: flattenedActiveTimerNames[i], ms: d }))
.map(d => `${d.name}: ${d.ms}`)
.join(', ');
}
else {
res['kernelMs'] = {
error: 'WebGL query timers are not supported in this environment.'
};
}
this.uploadWaitMs = 0;
this.downloadWaitMs = 0;
return res;
})();
}
memory() {
return {
unreliable: false,
numBytesInGPU: this.numBytesInGPU,
numBytesInGPUAllocated: this.textureManager.numBytesAllocated,
numBytesInGPUFree: this.textureManager.numBytesFree
};
}
startTimer() {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
return this.gpgpu.beginQuery();
}
return { startMs: now(), endMs: null };
}
endTimer(query) {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
this.gpgpu.endQuery();
return query;
}
query.endMs = now();
return query;
}
async getQueryTime(query) {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
return this.gpgpu.waitForQueryAndGetTime(query);
}
const timerQuery = query;
return timerQuery.endMs - timerQuery.startMs;
}
disposeData(dataId, force = false) {
if (this.pendingDisposal.has(dataId)) {
return false;
}
if (!this.texData.has(dataId)) {
return true;
}
if (force) {
this.texData.get(dataId).refCount = 0;
}
else {
this.texData.get(dataId).refCount--;
}
if (!force && this.texData.get(dataId).refCount > 0) {
return false;
}
if (this.pendingRead.has(dataId)) {
this.pendingDisposal.add(dataId);
this.pendingDeletes++;
return false;
}
this.releaseGPUData(dataId);
const { complexTensorInfos } = this.texData.get(dataId);
if (complexTensorInfos != null) {
this.disposeData(complexTensorInfos.real.dataId, force);
this.disposeData(complexTensorInfos.imag.dataId, force);
}
this.texData.delete(dataId);
return true;
}
releaseGPUData(dataId) {
const { texture, dtype, texShape, usage, isPacked, slice } = this.texData.get(dataId);
const key = slice && slice.origDataId || dataId;
const refCount = this.dataRefCount.get(key);
if (refCount > 1) {
this.dataRefCount.set(key, refCount - 1);
}
else {
this.dataRefCount.delete(key);
if (texture != null) {
this.numBytesInGPU -= this.computeBytes(texShape, dtype);
this.textureManager.releaseTexture(texture, texShape, usage, isPacked);
}
}
const texData = this.texData.get(dataId);
texData.texture = null;
texData.texShape = null;
texData.isPacked = false;
texData.slice = null;
}
getTexture(dataId) {
this.uploadToGPU(dataId);
return this.texData.get(dataId).texture.texture;
}
getDataInfo(dataId) {
return this.texData.get(dataId);
}
shouldExecuteOnCPU(inputs, sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD) {
return env().getBool('WEBGL_CPU_FORWARD') &&
inputs.every(input => this.texData.get(input.dataId).texture == null &&
sizeFromShape(input.shape) < sizeThreshold);
}
getGPGPUContext() {
return this.gpgpu;
}
where(condition) {
warn('tf.where() in webgl locks the UI thread. ' +
'Call tf.whereAsync() instead');
const condVals = condition.dataSync();
return whereImpl$1(condition.shape, condVals);
}
packedUnaryOp(x, op, dtype) {
const program = new UnaryOpPackedProgram(x.shape, op);
const outInfo = this.compileAndRun(program, [x], dtype);
return engine().makeTensorFromTensorInfo(outInfo);
}
abs(x) {
if (this.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') {
const outValues = simpleAbsImplCPU(this.texData.get(x.dataId).values);
return this.makeOutput(x.shape, x.dtype, outValues);
}
if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
return this.packedUnaryOp(x, ABS$1, x.dtype);
}
const program = new UnaryOpProgram(x.shape, ABS$1);
const outInfo = this.compileAndRun(program, [x]);
return engine().makeTensorFromTensorInfo(outInfo);
}
makeTensorInfo(shape, dtype, values) {
let dataId;
if (dtype === 'string' && values != null && values.length > 0 &&
isString(values[0])) {
const encodedValues = values.map(d => encodeString(d));
dataId = this.write(encodedValues, shape, dtype);
}
else {
dataId = this.write(values, shape, dtype);
}
this.texData.get(dataId).usage = null;
return { dataId, shape, dtype };
}
makeOutput(shape, dtype, values) {
return engine().makeTensorFromTensorInfo(this.makeTensorInfo(shape, dtype, values), this);
}
unpackTensor(input) {
const program = new UnpackProgram(input.shape);
return this.runWebGLProgram(program, [input], input.dtype);
}
packTensor(input) {
const program = new PackProgram(input.shape);
const preventEagerUnpackingOutput = true;
return this.runWebGLProgram(program, [input], input.dtype, null, preventEagerUnpackingOutput);
}
packedReshape(input, afterShape) {
const input3DShape = [
getBatchDim(input.shape),
...getRowsCols(input.shape)
];
const input3D = {
dtype: input.dtype,
shape: input3DShape,
dataId: input.dataId
};
const afterShapeAs3D = [
getBatchDim(afterShape), ...getRowsCols(afterShape)
];
const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);
const preventEagerUnpackingOfOutput = true;
const customValues = [input3DShape];
const output = this.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput);
return { dataId: output.dataId, shape: afterShape, dtype: output.dtype };
}
decode(dataId, customTexShape) {
const texData = this.texData.get(dataId);
const { isPacked, shape, dtype } = texData;
if (customTexShape != null) {
const size = sizeFromShape(shape);
const texSize = customTexShape[0] * customTexShape[1] * 4;
assert$1(size <= texSize, () => 'customTexShape is too small. ' +
'Row * Column * 4 should be equal or larger than the ' +
'size of the tensor data.');
}
const shapeAs3D = getShapeAs3D(shape);
let program;
if (isPacked) {
program = new DecodeMatrixPackedProgram(shapeAs3D);
}
else {
program = new DecodeMatrixProgram(shapeAs3D);
}
const preventEagerUnpackingOfOutput = true;
const customValues = [customTexShape != null ? customTexShape :
getDenseTexShape(shapeAs3D)];
const out = this.runWebGLProgram(program, [{ shape: shapeAs3D, dtype, dataId }], dtype, customValues, preventEagerUnpackingOfOutput, customTexShape);
return { dtype, shape, dataId: out.dataId };
}
runWebGLProgram(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput = false, customTexShape) {
const output = this.makeTensorInfo(program.outputShape, outputDtype);
const outData = this.texData.get(output.dataId);
if (program.packedOutput) {
outData.isPacked = true;
}
if (program.outPackingScheme === PackingScheme.DENSE) {
const texelShape = customTexShape != null ?
customTexShape :
getDenseTexShape(program.outputShape);
outData.texShape = texelShape.map(d => d * 2);
}
if (program.outTexUsage != null) {
outData.usage = program.outTexUsage;
}
if (sizeFromShape(output.shape) === 0) {
outData.values =
getTypedArrayFromDType(output.dtype, 0);
return output;
}
const dataToDispose = [];
const inputsData = inputs.map(input => {
if (input.dtype === 'complex64') {
throw new Error(`GPGPUProgram does not support complex64 input. For complex64 ` +
`dtypes, please separate the program into real and imaginary ` +
`parts.`);
}
let texData = this.texData.get(input.dataId);
if (texData.texture == null) {
if (!program.packedInputs &&
sizeFromShape(input.shape) <=
env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) {
return {
shape: input.shape,
texData: null,
isUniform: true,
uniformValues: texData.values
};
}
if (program.packedInputs) {
texData.isPacked = true;
texData.shape = input.shape;
}
}
this.uploadToGPU(input.dataId);
if (!!texData.isPacked !== !!program.packedInputs) {
input = texData.isPacked ? this.unpackTensor(input) :
this.packTensor(input);
dataToDispose.push(input);
texData = this.texData.get(input.dataId);
}
else if (texData.isPacked &&
!isReshapeFree(texData.shape, input.shape)) {
const savedInput = input;
const targetShape = input.shape;
input.shape = texData.shape;
input = this.packedReshape(input, targetShape);
dataToDispose.push(input);
texData = this.texData.get(input.dataId);
savedInput.shape = targetShape;
}
return { shape: input.shape, texData, isUniform: false };
});
this.uploadToGPU(output.dataId);
const outputData = { shape: output.shape, texData: outData, isUniform: false };
const key = makeShaderKey(program, inputsData, outputData);
const binary = this.getAndSaveBinary(key, () => {
return compileProgram(this.gpgpu, program, inputsData, outputData);
});
const shouldTimeProgram = this.activeTimers != null;
let query;
if (shouldTimeProgram) {
query = this.startTimer();
}
if (!env().get('ENGINE_COMPILE_ONLY')) {
runProgram(this.gpgpu, binary, inputsData, outputData, customUniformValues);
}
dataToDispose.forEach(info => this.disposeIntermediateTensorInfo(info));
if (shouldTimeProgram) {
query = this.endTimer(query);
this.activeTimers.push({ name: program.constructor.name, query: this.getQueryTime(query) });
}
const glFlushThreshold = env().getNumber('WEBGL_FLUSH_THRESHOLD');
if (glFlushThreshold > 0) {
const time = now();
if ((time - this.lastGlFlushTime) > glFlushThreshold) {
this.gpgpu.gl.flush();
this.lastGlFlushTime = time;
}
}
if (!env().getBool('WEBGL_LAZILY_UNPACK') && outData.isPacked &&
preventEagerUnpackingOfOutput === false) {
const unpacked = this.unpackTensor(output);
this.disposeIntermediateTensorInfo(output);
return unpacked;
}
return output;
}
compileAndRun(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput = false) {
outputDtype = outputDtype || inputs[0].dtype;
const outInfo = this.runWebGLProgram(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput);
return outInfo;
}
getAndSaveBinary(key, getBinary) {
if (!(key in this.binaryCache)) {
this.binaryCache[key] = getBinary();
}
return this.binaryCache[key];
}
getTextureManager() {
return this.textureManager;
}
dispose() {
if (this.disposed) {
return;
}
if (!env().getBool('IS_TEST')) {
const allKeys = Object.keys(this.binaryCache);
allKeys.forEach(key => {
this.gpgpu.deleteProgram(this.binaryCache[key].webGLProgram);
delete this.binaryCache[key];
});
}
this.textureManager.dispose();
if (this.canvas != null &&
(typeof (HTMLCanvasElement) !== 'undefined' &&
this.canvas instanceof HTMLCanvasElement)) {
this.canvas.remove();
}
else {
this.canvas = null;
}
if (this.gpgpuCreatedLocally) {
this.gpgpu.program = null;
this.gpgpu.dispose();
}
this.disposed = true;
}
floatPrecision() {
if (this.floatPrecisionValue == null) {
this.floatPrecisionValue = tidy(() => {
if (!env().get('WEBGL_RENDER_FLOAT32_ENABLED')) {
const debugFlag = env().getBool('DEBUG');
env().set('DEBUG', false);
const underflowCheckValue = this.abs(scalar(1e-8)).dataSync()[0];
env().set('DEBUG', debugFlag);
if (underflowCheckValue > 0) {
return 32;
}
}
return 16;
});
}
return this.floatPrecisionValue;
}
epsilon() {
return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16;
}
uploadToGPU(dataId) {
const texData = this.texData.get(dataId);
const { shape, dtype, values, texture, usage, isPacked } = texData;
if (texture != null) {
return;
}
const shouldTimeProgram = this.activeTimers != null;
let start;
if (shouldTimeProgram) {
start = now();
}
let texShape = texData.texShape;
if (texShape == null) {
texShape = getTextureShapeFromLogicalShape(shape, isPacked);
texData.texShape = texShape;
}
if (values != null) {
const shapeAs3D = getShapeAs3D(shape);
let program;
let width = texShape[1], height = texShape[0];
const isByteArray = values instanceof Uint8Array || values instanceof Uint8ClampedArray;
if (isPacked || !isByteArray) {
[width, height] = getPackedMatrixTextureShapeWidthHeight(texShape[0], texShape[1]);
}
if (isPacked) {
program = new EncodeMatrixPackedProgram(shapeAs3D, isByteArray);
}
else {
program = new EncodeMatrixProgram(shapeAs3D, isByteArray);
}
const tempDenseInputTexShape = isByteArray ? [height, width] : texShape;
const tempDenseInputHandle = this.makeTensorInfo(tempDenseInputTexShape, dtype);
const tempDenseInputTexData = this.texData.get(tempDenseInputHandle.dataId);
if (isByteArray) {
tempDenseInputTexData.usage = TextureUsage.PIXELS;
}
else {
tempDenseInputTexData.usage = TextureUsage.UPLOAD;
}
tempDenseInputTexData.texShape = tempDenseInputTexShape;
this.gpgpu.uploadDenseMatrixToTexture(this.getTexture(tempDenseInputHandle.dataId), width, height, values);
const customValues = [[height, width]];
const preventEagerUnpacking = true;
const encodedOutputTarget = this.runWebGLProgram(program, [tempDenseInputHandle], dtype, customValues, preventEagerUnpacking);
const outputTexData = this.texData.get(encodedOutputTarget.dataId);
texData.texShape = outputTexData.texShape;
texData.isPacked = outputTexData.isPacked;
texData.usage = outputTexData.usage;
if (!env().get('ENGINE_COMPILE_ONLY')) {
texData.texture = outputTexData.texture;
texData.values = null;
this.texData.delete(encodedOutputTarget.dataId);
}
else {
this.disposeData(encodedOutputTarget.dataId);
}
this.disposeIntermediateTensorInfo(tempDenseInputHandle);
if (shouldTimeProgram) {
this.uploadWaitMs += now() - start;
}
}
else {
const newTexture = this.acquireTexture(texShape, usage, dtype, isPacked);
texData.texture = newTexture;
}
}
convertAndCacheOnCPU(dataId, float32Values) {
const texData = this.texData.get(dataId);
const { dtype } = texData;
if (float32Values != null) {
texData.values = float32ToTypedArray(float32Values, dtype);
}
return texData.values;
}
acquireTexture(texShape, texType, dtype, isPacked) {
this.numBytesInGPU += this.computeBytes(texShape, dtype);
if (!this.warnedAboutMemory &&
this.numBytesInGPU > this.numMBBeforeWarning * 1024 * 1024) {
const mb = (this.numBytesInGPU / 1024 / 1024).toFixed(2);
this.warnedAboutMemory = true;
console.warn(`High memory usage in GPU: ${mb} MB, ` +
`most likely due to a memory leak`);
}
return this.textureManager.acquireTexture(texShape, texType, isPacked);
}
computeBytes(shape, dtype) {
return shape[0] * shape[1] * bytesPerElement(dtype);
}
checkCompileCompletion() {
for (const [, binary] of Object.entries(this.binaryCache)) {
this.checkCompletion_(binary);
}
}
async checkCompileCompletionAsync() {
const ps = [];
if (this.gpgpu.parallelCompilationExtension) {
for (const [, binary] of Object.entries(this.binaryCache)) {
ps.push(this.checkCompletionAsync_(binary));
}
return Promise.all(ps);
}
else {
for (const [, binary] of Object.entries(this.binaryCache)) {
const p = new Promise((resolve) => {
try {
this.checkCompletion_(binary);
resolve(true);
}
catch (error) {
throw error;
}
});
ps.push(p);
}
return Promise.all(ps);
}
}
async checkCompletionAsync_(binary) {
if (this.gpgpu.gl.getProgramParameter(binary.webGLProgram, this.gpgpu.parallelCompilationExtension.COMPLETION_STATUS_KHR)) {
return this.checkCompletion_(binary);
}
else {
await nextFrame();
return this.checkCompletionAsync_(binary);
}
}
checkCompletion_(binary) {
if (this.gpgpu.gl.getProgramParameter(binary.webGLProgram, this.gpgpu.gl.LINK_STATUS) === false) {
console.log(this.gpgpu.gl.getProgramInfoLog(binary.webGLProgram));
if (this.gpgpu.gl.getShaderParameter(binary.fragmentShader, this.gpgpu.gl.COMPILE_STATUS) === false) {
logShaderSourceAndInfoLog(binary.source, this.gpgpu.gl.getShaderInfoLog(binary.fragmentShader));
throw new Error('Failed to compile fragment shader.');
}
throw new Error('Failed to link vertex and fragment shaders.');
}
return true;
}
getUniformLocations() {
for (const binary of Object.values(this.binaryCache)) {
this.gpgpu.buildVao(binary.webGLProgram);
const { variablesLocations, customUniformLocations, infLoc, nanLoc, outShapeLocation, outShapeStridesLocation, outTexShapeLocation } = getUniformLocations(this.gpgpu, binary.program, binary.webGLProgram);
binary.variablesLocations = variablesLocations;
binary.customUniformLocations = customUniformLocations;
binary.infLoc = infLoc;
binary.nanLoc = nanLoc;
binary.outShapeLocation = outShapeLocation;
binary.outShapeStridesLocation = outShapeStridesLocation;
binary.outTexShapeLocation = outTexShapeLocation;
}
}
createTensorFromGPUData(values, shape, dtype) {
values.channels = values.channels || 'RGBA';
const { texture, height, width, channels } = values;
const backend = engine().backend;
if (!backend.gpgpu.gl.isTexture(texture)) {
throw new Error(`The texture is invalid. Also, please make sure the texture and ` +
`the TFJS WebGL backend are using the same canvas. If you want to ` +
`use your own custom canvas, you have to create and use the custom ` +
`TFJS WebGL backend created from the canvas through ` +
`'new tf.MathBackendWebGL(customCanvas)'.`);
}
const dataId = backend.writeTexture(texture, shape, dtype, height, width, channels);
return engine().makeTensorFromDataId(dataId, shape, dtype, backend);
}
}
MathBackendWebGL.nextDataId = 0;
function float32ToTypedArray(a, dtype) {
if (dtype === 'float32' || dtype === 'complex64') {
return a;
}
else if (dtype === 'int32' || dtype === 'bool') {
const result = (dtype === 'int32') ? new Int32Array(a.length) :
new Uint8Array(a.length);
for (let i = 0; i < result.length; ++i) {
result[i] = Math.round(a[i]);
}
return result;
}
else {
throw new Error(`Unknown dtype ${dtype}`);
}
}
if (isBrowser()) {
registerBackend('webgl', () => new MathBackendWebGL(), 2 /* priority */);
}
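// Element-wise binary shader programs, analogous to the unary ones above. The
// packed variant can additionally zero the texel lanes that map past the edge
// of the output shape (checkOutOfBounds).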
const CHECK_NAN_SNIPPET = `
if (isnan(a)) return a;
if (isnan(b)) return b;
`;
class BinaryOpProgram {
constructor(op, aShape, bShape) {
this.variableNames = ['A', 'B'];
this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
this.userCode = `
float binaryOperation(float a, float b) {
${op}
}
void main() {
float a = getAAtOutCoords();
float b = getBAtOutCoords();
setOutput(binaryOperation(a, b));
}
`;
}
}
const CHECK_NAN_SNIPPET_PACKED = `
result.r = isNaN.r ? NAN : result.r;
result.g = isNaN.g ? NAN : result.g;
result.b = isNaN.b ? NAN : result.b;
result.a = isNaN.a ? NAN : result.a;
`;
class BinaryOpPackedProgram {
constructor(op, aShape, bShape, checkOutOfBounds = false) {
this.variableNames = ['A', 'B'];
this.supportsBroadcasting = true;
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
const rank = this.outputShape.length;
this.enableShapeUniforms = useShapeUniforms(rank);
let checkOutOfBoundsString = '';
if (checkOutOfBounds) {
if (rank === 0 || sizeFromShape(this.outputShape) === 1) {
checkOutOfBoundsString = `
result.y = 0.;
result.z = 0.;
result.w = 0.;
`;
}
else {
const dtype = getCoordsDataType(rank);
checkOutOfBoundsString = `
${dtype} coords = getOutputCoords();
`;
if (rank === 1) {
if (this.enableShapeUniforms) {
checkOutOfBoundsString += `
result.y = (coords + 1) >= outShape ? 0. : result.y;
result.z = 0.;
result.w = 0.;
`;
}
else {
checkOutOfBoundsString += `
result.y = (coords + 1) >= ${this.outputShape[0]} ? 0. : result.y;
result.z = 0.;
result.w = 0.;
`;
}
}
else {
const channels = getChannels('coords', rank);
if (this.enableShapeUniforms) {
checkOutOfBoundsString += `
bool nextRowOutOfBounds =
(${channels[rank - 2]} + 1) >= outShape[${rank} - 2];
bool nextColOutOfBounds =
(${channels[rank - 1]} + 1) >= outShape[${rank} - 1];
result.y = nextColOutOfBounds ? 0. : result.y;
result.z = nextRowOutOfBounds ? 0. : result.z;
result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
`;
}
else {
checkOutOfBoundsString += `
bool nextRowOutOfBounds =
(${channels[rank - 2]} + 1) >= ${this.outputShape[rank - 2]};
bool nextColOutOfBounds =
(${channels[rank - 1]} + 1) >= ${this.outputShape[rank - 1]};
result.y = nextColOutOfBounds ? 0. : result.y;
result.z = nextRowOutOfBounds ? 0. : result.z;
result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
`;
}
}
}
}
this.userCode = `
vec4 binaryOperation(vec4 a, vec4 b) {
${op}
}
void main() {
vec4 a = getAAtOutCoords();
vec4 b = getBAtOutCoords();
vec4 result = binaryOperation(a, b);
${checkOutOfBoundsString}
setOutput(result);
}
`;
}
}
function identity(args) {
const { inputs, backend } = args;
const { x } = inputs;
backend.incRef(x.dataId);
return { dataId: x.dataId, shape: x.shape, dtype: x.dtype };
}
const identityConfig = {
kernelName: Identity$1,
backendName: 'webgl',
kernelFunc: identity
};
function complex(args) {
const { inputs, backend } = args;
const { real, imag } = inputs;
const complexInfo = backend.makeTensorInfo(real.shape, 'complex64');
const complex = backend.texData.get(complexInfo.dataId);
const realTensorInfo = identity({ inputs: { x: real }, backend });
const imagTensorInfo = identity({ inputs: { x: imag }, backend });
complex.complexTensorInfos = { real: realTensorInfo, imag: imagTensorInfo };
return complexInfo;
}
const complexConfig = {
kernelName: Complex,
backendName: 'webgl',
kernelFunc: complex
};
const LEAKYRELU = `return (a < 0.) ? b * a : a;`;
const LEAKYRELU_PACKED = `
vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
`;
function leakyRelu$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { alpha } = attrs;
const $alpha = backend.makeTensorInfo([], 'float32', createScalarValue(alpha, 'float32'));
const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
new BinaryOpPackedProgram(LEAKYRELU_PACKED, x.shape, $alpha.shape) :
new BinaryOpProgram(LEAKYRELU, x.shape, $alpha.shape);
const result = backend.runWebGLProgram(program, [x, $alpha], 'float32');
backend.disposeIntermediateTensorInfo($alpha);
return result;
}
const leakyReluConfig$1 = {
kernelName: LeakyRelu,
backendName: 'webgl',
kernelFunc: leakyRelu$1
};
const PRELU = `return (a < 0.) ? b * a : a;`;
const PRELU_PACKED = `
vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
`;
function prelu$1(args) {
const { inputs, backend } = args;
const { x, alpha } = inputs;
const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
new BinaryOpPackedProgram(PRELU_PACKED, x.shape, alpha.shape) :
new BinaryOpProgram(PRELU, x.shape, alpha.shape);
return backend.runWebGLProgram(program, [x, alpha], 'float32');
}
const preluConfig$1 = {
kernelName: Prelu,
backendName: 'webgl',
kernelFunc: prelu$1
};
const CHECK_NAN_SNIPPET_UNARY = `if (isnan(x)) return x;`;
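// Factories that turn a GLSL snippet (plus an optional packed variant and an
// optional CPU fallback implementation) into a WebGL kernel function; most of
// the element-wise kernels that follow are created through these.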
function unaryKernelFunc({ opSnippet, packedOpSnippet, cpuKernelImpl, dtype }) {
return ({ inputs, backend }) => {
const { x } = inputs;
const webglBackend = backend;
const $dtype = dtype || x.dtype;
if (webglBackend.shouldExecuteOnCPU([x]) && cpuKernelImpl != null) {
const xData = webglBackend.texData.get(x.dataId);
const outValues = cpuKernelImpl(xData.values, $dtype);
return webglBackend.makeTensorInfo(x.shape, $dtype, outValues);
}
const shouldUsePackedProgram = env().getBool('WEBGL_PACK_UNARY_OPERATIONS') && packedOpSnippet != null;
let program;
if (shouldUsePackedProgram) {
program = new UnaryOpPackedProgram(x.shape, packedOpSnippet);
}
else {
program = new UnaryOpProgram(x.shape, opSnippet);
}
return webglBackend.runWebGLProgram(program, [x], $dtype);
};
}
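// Factory for broadcasting binary kernels: complex64 inputs are processed as
// separate real/imag ops, string or CPU-resident inputs use the provided CPU
// impl, and everything else dispatches a packed or unpacked binary shader.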
function binaryKernelFunc({ opSnippet, packedOpSnippet, checkOutOfBounds = false, supportsComplex = false, cpuKernelImpl, dtype }) {
return ({ inputs, backend }) => {
const { a, b } = inputs;
const webglBackend = backend;
if (supportsComplex && a.dtype === 'complex64') {
const aData = webglBackend.texData.get(a.dataId);
const bData = webglBackend.texData.get(b.dataId);
const [real, imag] = [
[aData.complexTensorInfos.real, bData.complexTensorInfos.real],
[aData.complexTensorInfos.imag, bData.complexTensorInfos.imag]
].map(complexParts => {
const [aPart, bPart] = complexParts;
const aHandle = {
dataId: aPart.dataId,
dtype: aPart.dtype,
shape: a.shape
};
const bHandle = {
dataId: bPart.dataId,
dtype: bPart.dtype,
shape: b.shape
};
const program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
return webglBackend.runWebGLProgram(program, [aHandle, bHandle], upcastType(aPart.dtype, bPart.dtype));
});
const complexOutput = complex({ inputs: { real, imag }, backend: webglBackend });
webglBackend.disposeIntermediateTensorInfo(real);
webglBackend.disposeIntermediateTensorInfo(imag);
return complexOutput;
}
const $dtype = dtype || upcastType(a.dtype, b.dtype);
if ((a.dtype === 'string' || b.dtype === 'string' ||
webglBackend.shouldExecuteOnCPU([a, b])) &&
cpuKernelImpl != null) {
const aVals = webglBackend.texData.get(a.dataId).values;
const bVals = webglBackend.texData.get(b.dataId).values;
const decodedAVals = a.dtype === 'string' ?
fromUint8ToStringArray(aVals) :
aVals;
const decodedBVals = a.dtype === 'string' ?
fromUint8ToStringArray(bVals) :
bVals;
const [outValues, outShape] = cpuKernelImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype);
const out = webglBackend.makeTensorInfo(outShape, $dtype);
const outData = webglBackend.texData.get(out.dataId);
outData.values = outValues;
return out;
}
const shouldUsePackedProgram = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') &&
packedOpSnippet != null;
let program;
if (shouldUsePackedProgram) {
program = new BinaryOpPackedProgram(packedOpSnippet, a.shape, b.shape, checkOutOfBounds);
}
else {
program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
}
return webglBackend.runWebGLProgram(program, [a, b], $dtype);
};
}
function mapActivationToShaderProgram(activation, packed = false) {
if (activation === 'linear') {
if (packed) {
return LINEAR;
}
return LINEAR$1;
}
else if (activation === 'relu') {
if (packed) {
return RELU$1;
}
return RELU$2;
}
else if (activation === 'elu') {
if (packed) {
return ELU$1;
}
return ELU$2;
}
else if (activation === 'relu6') {
if (packed) {
return RELU6$1;
}
return RELU6$2;
}
else if (activation === 'prelu') {
if (packed) {
return PRELU_PACKED;
}
return PRELU;
}
else if (activation === 'leakyrelu') {
if (packed) {
return LEAKYRELU_PACKED;
}
return LEAKYRELU;
}
else if (activation === 'sigmoid') {
if (packed) {
return SIGMOID$1;
}
return SIGMOID$2;
}
throw new Error(`Activation ${activation} has not been implemented for the WebGL backend.`);
}
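// Packed matmul shader: each texel accumulates a 2x2 output block, walking the
// shared dimension two elements per iteration, with optional bias add and a
// fused activation (prelu/leakyrelu fetch their weights per output coord).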
class MatMulPackedProgram {
constructor(aShape, bShape, outputShape, transposeA = false, transposeB = false, addBias = false, activation = null, hasPreluActivation = false, hasLeakyreluActivation = false) {
this.variableNames = ['matrixA', 'matrixB'];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = outputShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
const sharedDim = transposeA ? aShape[1] : aShape[2];
const sharedDimensionPacked = Math.ceil(sharedDim / 2);
const aSample = transposeA ? 'i * 2, rc.y' : 'rc.y, i * 2';
const bSample = transposeB ? 'rc.z, i * 2' : 'i * 2, rc.z';
const aSwizzle = transposeA ? ['a.xxyy', 'a.zzww'] : ['a.xxzz', 'a.yyww'];
const bSwizzle = transposeB ? ['b.xzxz', 'b.ywyw'] : ['b.xyxy', 'b.zwzw'];
let activationSnippet = '', applyActivationSnippet = '';
if (activation) {
if (hasPreluActivation) {
activationSnippet = `vec4 activation(vec4 a) {
vec4 b = getPreluActivationWeightsAtOutCoords();
${activation}
}`;
}
else if (hasLeakyreluActivation) {
activationSnippet = `vec4 activation(vec4 a) {
vec4 b = getLeakyreluAlphaAtOutCoords();
${activation}
}`;
}
else {
activationSnippet = `vec4 activation(vec4 x) {
${activation}
}`;
}
applyActivationSnippet = `result = activation(result);`;
}
const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
if (addBias) {
this.variableNames.push('bias');
}
if (hasPreluActivation) {
this.variableNames.push('preluActivationWeights');
}
if (hasLeakyreluActivation) {
this.variableNames.push('leakyreluAlpha');
}
let batchASnippet = 'rc.x';
let batchBSnippet = 'rc.x';
if (aShape[0] < bShape[0]) {
batchASnippet = `imod(rc.x, ${aShape[0]})`;
}
else if (bShape[0] < aShape[0]) {
batchBSnippet = `imod(rc.x, ${bShape[0]})`;
}
this.userCode = `
${activationSnippet}
const float sharedDimension = ${sharedDimensionPacked}.0;
vec4 dot2x2ARowBCol(ivec3 rc) {
vec4 result = vec4(0);
int batchA = ${batchASnippet};
int batchB = ${batchBSnippet};
for (int i = 0; i < ${sharedDimensionPacked}; i++) {
vec4 a = getMatrixA(batchA, ${aSample});
vec4 b = getMatrixB(batchB, ${bSample});
result += (${aSwizzle[0]} * ${bSwizzle[0]});
result += (${aSwizzle[1]} * ${bSwizzle[1]});
}
return result;
}
void main() {
ivec3 rc = getOutputCoords();
vec4 result = dot2x2ARowBCol(rc);
${addBiasSnippet}
${applyActivationSnippet}
setOutput(result);
}
`;
}
}
const COMPLEX_MULTIPLY = {
REAL: 'return areal * breal - aimag * bimag;',
IMAG: 'return areal * bimag + aimag * breal;'
};
class BinaryOpComplexProgram {
constructor(op, aShape, bShape) {
this.variableNames = ['AReal', 'AImag', 'BReal', 'BImag'];
this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
this.userCode = `
float binaryOpComplex(
float areal, float aimag, float breal, float bimag) {
${op}
}
void main() {
float areal = getARealAtOutCoords();
float aimag = getAImagAtOutCoords();
float breal = getBRealAtOutCoords();
float bimag = getBImagAtOutCoords();
setOutput(binaryOpComplex(areal, aimag, breal, bimag));
}
`;
}
}
const MUL = 'return a * b;';
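// Multiply kernel: complex64 inputs run BinaryOpComplexProgram on the real and
// imaginary parts, CPU-resident inputs use multiplyImplCPU, and the rest run a
// (packed) MUL binary shader.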
function multiply(args) {
const { inputs, backend } = args;
const { a, b } = inputs;
const dtype = upcastType(a.dtype, b.dtype);
if (a.dtype === 'complex64') {
const aData = backend.texData.get(a.dataId);
const bData = backend.texData.get(b.dataId);
const realProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.REAL, a.shape, b.shape);
const imagProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.IMAG, a.shape, b.shape);
const inputs = [
{
dataId: aData.complexTensorInfos.real.dataId,
dtype: aData.complexTensorInfos.real.dtype,
shape: a.shape
},
{
dataId: aData.complexTensorInfos.imag.dataId,
dtype: aData.complexTensorInfos.imag.dtype,
shape: a.shape
},
{
dataId: bData.complexTensorInfos.real.dataId,
dtype: bData.complexTensorInfos.real.dtype,
shape: b.shape
},
{
dataId: bData.complexTensorInfos.imag.dataId,
dtype: bData.complexTensorInfos.imag.dtype,
shape: b.shape
}
];
const realPart = backend.runWebGLProgram(realProgram, inputs, 'float32');
const imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32');
const complexOutput = complex({ inputs: { real: realPart, imag: imagPart }, backend });
backend.disposeIntermediateTensorInfo(realPart);
backend.disposeIntermediateTensorInfo(imagPart);
return complexOutput;
}
if (backend.shouldExecuteOnCPU([a, b])) {
const aData = backend.texData.get(a.dataId);
const bData = backend.texData.get(b.dataId);
const [outValues, outShape] = multiplyImplCPU(a.shape, b.shape, aData.values, bData.values, dtype);
const out = backend.makeTensorInfo(outShape, dtype);
const outData = backend.texData.get(out.dataId);
outData.values = outValues;
return out;
}
let program;
if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
program = new BinaryOpPackedProgram(MUL, a.shape, b.shape);
}
else {
program = new BinaryOpProgram(MUL, a.shape, b.shape);
}
return backend.runWebGLProgram(program, [a, b], dtype);
}
const multiplyConfig = {
kernelName: Multiply,
backendName: 'webgl',
kernelFunc: multiply
};
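// Reshapes a packed texture by viewing source and target shapes as
// [batch, rows, cols] and letting ReshapePackedProgram rearrange the texels.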
function packedReshape(input, afterShape, backend) {
const input3DShape = [getBatchDim(input.shape),
...getRowsCols(input.shape)];
const input3D = {
dtype: input.dtype,
shape: input3DShape,
dataId: input.dataId
};
const afterShapeAs3D = [getBatchDim(afterShape),
...getRowsCols(afterShape)];
const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);
const preventEagerUnpackingOfOutput = true;
const customValues = [input3DShape];
const output = backend.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput);
return { dataId: output.dataId, shape: afterShape, dtype: output.dtype };
}
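// Reshape kernel: when the reshape is free it just aliases the dataId under
// the new shape; packed textures whose layout changes go through packedReshape.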
function reshape$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { shape } = attrs;
const webglBackend = backend;
const xSize = sizeFromShape(x.shape);
const $shape = inferFromImplicitShape(shape, xSize);
const $xSize = sizeFromShape($shape);
assert$1(xSize === $xSize, () => `The new shape (${$shape}) has ${$xSize} elements and the old ` +
`shape (${x.shape}) has ${xSize} elements. The new shape and old ` +
`shape must have the same number of elements.`);
const xTexData = webglBackend.texData.get(x.dataId);
if (xTexData.isPacked && !isReshapeFree(x.shape, $shape) &&
!(xTexData.texture !== null && isReshapeFree(xTexData.shape, $shape))) {
return packedReshape(x, $shape, webglBackend);
}
webglBackend.incRef(x.dataId);
return { dataId: x.dataId, shape: $shape, dtype: x.dtype };
}
const reshapeConfig$1 = {
kernelName: Reshape$1,
backendName: 'webgl',
kernelFunc: reshape$1
};
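// Mean reduction stage: sums each window in vec4 chunks via dot(values, ones);
// when a divisor is supplied, values are pre-scaled by 1/divisor so the pass
// emits mean contributions instead of raw sums.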
class MeanProgram {
constructor(reduceInfo, divisor) {
this.variableNames = ['x'];
const { windowSize, batchSize, inSize, outSize } = reduceInfo;
this.outputShape = [batchSize, outSize];
const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
const windowSizeVec4Remainder = windowSize % 4;
let updateSnippet = `sumValue += dot(values, ones);`;
if (divisor != null) {
const denominator = 1 / divisor;
updateSnippet = `sumValue += dot(values * ${isInt(denominator) ? denominator.toPrecision(2) :
denominator}, ones);`;
}
let checkOutOfBounds = '';
if (inSize % windowSize > 0) {
checkOutOfBounds = `
if (inIdx < 0 || inIdx >= ${inSize}) {
return 0.0;
}
`;
}
this.userCode = `
const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
float getValue(int batch, int inIdx) {
${checkOutOfBounds}
return getX(batch, inIdx);
}
void main() {
ivec2 coords = getOutputCoords();
int batch = coords[0];
int outIdx = coords[1];
int inOffset = outIdx * ${windowSize};
float sumValue = 0.0;
for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
int inIdx = inOffset + i;
vec4 values = vec4(
getValue(batch, inIdx),
getValue(batch, inIdx + 1),
getValue(batch, inIdx + 2),
getValue(batch, inIdx + 3)
);
${updateSnippet}
}
int inIdx = inOffset + ${windowSizeNearestVec4};
if (${windowSizeVec4Remainder === 1}) {
vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0);
${updateSnippet}
} else if (${windowSizeVec4Remainder === 2}) {
vec4 values = vec4(
getValue(batch, inIdx),
getValue(batch, inIdx + 1), 0.0, 0.0);
${updateSnippet}
} else if (${windowSizeVec4Remainder === 3}) {
vec4 values = vec4(
getValue(batch, inIdx),
getValue(batch, inIdx + 1),
getValue(batch, inIdx + 2), 0.0);
${updateSnippet}
}
setOutput(sumValue);
}
`;
}
}
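// Generic reduction shader (sum/prod/min/max/all/any) over fixed-size windows
// of the flattened [batch, inSize] input; min/max force NaN propagation.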
class ReduceProgram {
constructor(reduceInfo, reduceType) {
this.variableNames = ['x'];
const { windowSize, batchSize, inSize, outSize } = reduceInfo;
this.outputShape = [batchSize, outSize];
let initializationValue = '0.0';
let compareOp = ``;
if (reduceType === 'prod') {
initializationValue = '1.0';
}
else if (reduceType === 'min') {
initializationValue = '1.0 / 1e-20';
compareOp = `min`;
}
else if (reduceType === 'max') {
initializationValue = '-1.0 / 1e-20';
compareOp = `max`;
}
let returnValue = `${reduceType}(${reduceType}(${reduceType}(` +
'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
if (reduceType === 'sum') {
returnValue = `sumValue`;
}
else if (reduceType === 'prod') {
returnValue = `prodValue`;
}
else if (reduceType === 'all') {
returnValue = `allValue`;
}
else if (reduceType === 'any') {
returnValue = `anyValue`;
}
const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
const windowSizeVec4Remainder = windowSize % 4;
let updateSnippet = `
if (${reduceType === 'sum'}) {
sumValue += dot(values, ones);
} else if (${reduceType === 'prod'}) {
vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);
prodValue *= tmp[0] * tmp[1];
} else {
minMaxValue = ${compareOp}(values, minMaxValue);
if (${reduceType === 'min'} || ${reduceType === 'max'}) {
minMaxValue = ${compareOp}(values, minMaxValue);
bvec4 isNaN = isnan(values);
if (isNaN.r || isNaN.g || isNaN.b || isNaN.a) {
minMaxValue = vec4(NAN);
}
}
}
`;
let vecType = `vec4`;
if (reduceType === 'all') {
initializationValue = '1.0';
updateSnippet = `
bool reducedAllValue = all(values);
float floatedReducedAllValue = float(reducedAllValue);
allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);
`;
vecType = `bvec4`;
}
else if (reduceType === 'any') {
initializationValue = '0.0';
updateSnippet = `
bool reducedAnyValue = any(values);
float floatedReducedAnyValue = float(reducedAnyValue);
anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);
`;
vecType = `bvec4`;
}
let checkOutOfBounds = '';
if (inSize % windowSize > 0) {
checkOutOfBounds = `
if (inIdx < 0 || inIdx >= ${inSize}) {
return initializationValue;
}
`;
}
this.userCode = `
const float initializationValue = ${initializationValue};
const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
float getValue(int batch, int inIdx) {
${checkOutOfBounds}
return getX(batch, inIdx);
}
void main() {
ivec2 coords = getOutputCoords();
int batch = coords[0];
int outIdx = coords[1];
int inOffset = outIdx * ${windowSize};
vec4 minMaxValue = vec4(${initializationValue});
float prodValue = 1.0;
float sumValue = 0.0;
float allValue = 1.0;
float anyValue = 0.0;
for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
int inIdx = inOffset + i;
${vecType} values = ${vecType}(
getValue(batch, inIdx),
getValue(batch, inIdx + 1),
getValue(batch, inIdx + 2),
getValue(batch, inIdx + 3)
);
${updateSnippet}
}
int inIdx = inOffset + ${windowSizeNearestVec4};
if (${windowSizeVec4Remainder === 1}) {
${vecType} values = ${vecType}(
getValue(batch, inIdx),
initializationValue,
initializationValue,
initializationValue
);
${updateSnippet}
} else if (${windowSizeVec4Remainder === 2}) {
${vecType} values = ${vecType}(
getValue(batch, inIdx),
getValue(batch, inIdx + 1),
initializationValue,
initializationValue
);
${updateSnippet}
} else if (${windowSizeVec4Remainder === 3}) {
${vecType} values = ${vecType}(
getValue(batch, inIdx),
getValue(batch, inIdx + 1),
getValue(batch, inIdx + 2),
initializationValue
);
${updateSnippet}
}
setOutput(${returnValue});
}
`;
}
}
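// Reductions run in stages: each pass shrinks the inner dimension by an
// optimal window size until one value per batch row remains.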
function getReductionStages(inShape) {
const stages = [];
while (stages.length === 0 || stages[stages.length - 1].outSize !== 1) {
const outSize = stages.length ? stages[stages.length - 1].outSize : inShape[1];
const windowSize = computeOptimalWindowSize(outSize);
stages.push({
inSize: outSize,
windowSize,
outSize: Math.ceil(outSize / windowSize)
});
}
return stages;
}
function reduce(x, dtype, reductionType, backend) {
const reductionStages = getReductionStages(x.shape);
let result = x;
for (let i = 0; i < reductionStages.length; i++) {
const { inSize, windowSize, outSize } = reductionStages[i];
let program;
let previousResult;
if (reductionType === 'mean') {
program = i === 0 ?
new MeanProgram({ windowSize, inSize, batchSize: x.shape[0], outSize }, inSize) :
new MeanProgram({ windowSize, inSize, batchSize: x.shape[0], outSize });
}
else {
program = new ReduceProgram({ windowSize, inSize, batchSize: x.shape[0], outSize }, reductionType);
}
previousResult = result;
result = backend.runWebGLProgram(program, [result], dtype);
if (previousResult.dataId !== x.dataId) {
backend.disposeIntermediateTensorInfo(previousResult);
}
}
return result;
}
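// Transpose shader: each output coordinate reads the input at the coordinates
// permuted according to newDim.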
class TransposeProgram {
constructor(aShape, newDim) {
this.variableNames = ['A'];
const outputShape = new Array(aShape.length);
for (let i = 0; i < outputShape.length; i++) {
outputShape[i] = aShape[newDim[i]];
}
this.outputShape = outputShape;
this.rank = outputShape.length;
const dtype = getCoordsDataType(this.rank);
const switched = getSwitchedCoords(newDim);
this.userCode = `
void main() {
${dtype} resRC = getOutputCoords();
setOutput(getA(${switched}));
}
`;
}
}
function getSwitchedCoords(newDim) {
const rank = newDim.length;
if (rank > 6) {
throw Error(`Transpose for rank ${rank} is not yet supported`);
}
const originalOrder = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v'];
const switchedCoords = new Array(rank);
for (let i = 0; i < newDim.length; i++) {
switchedCoords[newDim[i]] = originalOrder[i];
}
return switchedCoords.join();
}
class TransposePackedProgram {
constructor(aShape, newDim) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
const outputShape = new Array(aShape.length);
for (let i = 0; i < outputShape.length; i++) {
outputShape[i] = aShape[newDim[i]];
}
this.outputShape = outputShape;
this.rank = outputShape.length;
if (this.rank > 6) {
throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`);
}
const dtype = getCoordsDataType(this.rank);
const outputOrder = getVecChannels('rc', this.rank);
const switchedOrder = new Array(this.rank);
for (let i = 0; i < newDim.length; i++) {
switchedOrder[newDim[i]] = outputOrder[i];
}
const innerDims = `vec2(${switchedOrder.slice(-2).join()})`;
const nextColumn = `++${outputOrder[this.rank - 1]} < ${outputShape[this.rank - 1]}`;
const getc = `getChannel(getA(${switchedOrder.join()}), ${innerDims})`;
this.userCode = `
void main() {
${dtype} rc = getOutputCoords();
vec4 result = vec4(0.);
result[0] = ${getc};
if(${nextColumn}) {
result[1] = ${getc};
}
--${outputOrder[this.rank - 1]};
if(++${outputOrder[this.rank - 2]} < ${outputShape[this.rank - 2]}) {
result[2] = ${getc};
if(${nextColumn}) {
result[3] = ${getc};
}
}
setOutput(result);
}
`;
}
}
function transposeImpl(x, perm, backend) {
const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
new TransposePackedProgram(x.shape, perm) :
new TransposeProgram(x.shape, perm);
return backend.runWebGLProgram(program, [x], x.dtype);
}
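// Sum reduction: moves the reduced axes innermost (transposing if needed),
// flattens to [batch, inSize], and runs the staged 'sum' reduce.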
function sumImpl(x, axis, keepDims, backend) {
const reductionIndices = axis;
const xRank = x.shape.length;
const origAxes = parseAxisParam(reductionIndices, x.shape);
let axes = origAxes;
const permutedAxes = getAxesPermutation(axes, xRank);
const sumInputIsTransposed = permutedAxes != null;
let sumInput = x;
if (sumInputIsTransposed) {
sumInput = transposeImpl(x, permutedAxes, backend);
axes = getInnerMostAxes(axes.length, xRank);
}
assertAxesAreInnerMostDims('sum', axes, xRank);
const [sumOutShape, reduceShape] = computeOutAndReduceShapes(sumInput.shape, axes);
let outShape = sumOutShape;
if (keepDims) {
outShape = expandShapeToKeepDim(sumOutShape, origAxes);
}
const inSize = sizeFromShape(reduceShape);
const xSize = sizeFromShape(x.shape);
const batchSize = xSize / inSize;
const reshapedInput = reshape$1({ inputs: { x: sumInput }, attrs: { shape: [batchSize, inSize] }, backend });
const outType = sumOutType(x.dtype);
const reduced = reduce(reshapedInput, outType, 'sum', backend);
const out = reshape$1({ inputs: { x: reduced }, attrs: { shape: outShape }, backend });
backend.disposeIntermediateTensorInfo(reshapedInput);
backend.disposeIntermediateTensorInfo(reduced);
if (sumInputIsTransposed) {
backend.disposeIntermediateTensorInfo(sumInput);
}
return out;
}
function sum$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis, keepDims } = attrs;
return sumImpl(x, axis, keepDims, backend);
}
const sumConfig$1 = {
kernelName: Sum,
backendName: 'webgl',
kernelFunc: sum$1
};
function transpose(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { perm } = attrs;
const webglBackend = backend;
const xRank = x.shape.length;
const newShape = new Array(xRank);
for (let i = 0; i < newShape.length; i++) {
newShape[i] = x.shape[perm[i]];
}
let out;
if (webglBackend.shouldExecuteOnCPU([x])) {
const xTexData = webglBackend.texData.get(x.dataId);
const values = xTexData.values;
const outValues = transposeImplCPU(values, x.shape, x.dtype, perm, newShape);
out = webglBackend.makeTensorInfo(newShape, x.dtype);
const outData = webglBackend.texData.get(out.dataId);
outData.values = outValues;
}
else {
out = transposeImpl(x, perm, webglBackend);
}
return out;
}
const transposeConfig = {
kernelName: Transpose,
backendName: 'webgl',
kernelFunc: transpose
};
const MATMUL_SHARED_DIM_THRESHOLD = 1000;
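// Batched matmul: broadcasts the batch dims, reshapes both operands to rank 3,
// and uses MatMulPackedProgram, except that a vector-times-matrix case with a
// large shared dim and no fused ops is rewritten as multiply + sum.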
function batchMatMulImpl({ a, b, transposeA, transposeB, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
const aRank = a.shape.length;
const bRank = b.shape.length;
const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
const outerDimsA = a.shape.slice(0, -2);
const outerDimsB = b.shape.slice(0, -2);
const batchDimA = sizeFromShape(outerDimsA);
const batchDimB = sizeFromShape(outerDimsB);
const outShapeOuterDims = assertAndGetBroadcastShape(a.shape.slice(0, -2), b.shape.slice(0, -2));
const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
assert$1(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
`${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
`${b.shape} and transposeA=${transposeA}` +
` and transposeB=${transposeB} must match.`);
const a3dShape = transposeA ?
[batchDimA, innerShapeA, outerShapeA] :
[batchDimA, outerShapeA, innerShapeA];
const b3dShape = transposeB ?
[batchDimB, outerShapeB, innerShapeB] :
[batchDimB, innerShapeB, outerShapeB];
const a3d = reshape$1({ inputs: { x: a }, backend, attrs: { shape: a3dShape } });
const b3d = reshape$1({ inputs: { x: b }, backend, attrs: { shape: b3dShape } });
const intermediates = [a3d, b3d];
const batchDim = Math.max(batchDimA, batchDimB);
const sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
const hasBias = bias != null;
const hasPreluActivationWeights = preluActivationWeights != null;
const hasLeakyreluAlpha = activation === 'leakyrelu';
const fusedActivation = activation != null ?
mapActivationToShaderProgram(activation, true) :
null;
const containsFusedOps = hasBias || hasPreluActivationWeights ||
hasLeakyreluAlpha || fusedActivation != null;
let out;
if ((outerShapeA === 1 || outerShapeB === 1) &&
sharedDim > MATMUL_SHARED_DIM_THRESHOLD && containsFusedOps === false) {
let aVec = a3d;
let bVec = b3d;
if (transposeA) {
aVec = transpose({ inputs: { x: a3d }, backend, attrs: { perm: [0, 2, 1] } });
intermediates.push(aVec);
}
if (transposeB) {
bVec = transpose({ inputs: { x: b3d }, backend, attrs: { perm: [0, 2, 1] } });
intermediates.push(bVec);
}
const shouldReshapeA = outerShapeB !== 1;
const shouldReshapeB = outerShapeB === 1;
let aVec3d = aVec;
if (shouldReshapeA) {
aVec3d = reshape$1({
inputs: { x: aVec },
backend,
attrs: { shape: [batchDim, sharedDim, 1] }
});
intermediates.push(aVec3d);
}
const axis = outerShapeB === 1 ? 2 : 1;
let bVec3d = bVec;
if (shouldReshapeB) {
bVec3d = reshape$1({
inputs: { x: bVec },
backend,
attrs: { shape: [batchDim, 1, sharedDim] }
});
intermediates.push(bVec3d);
}
const product = multiply({ inputs: { a: aVec3d, b: bVec3d }, backend });
out = sum$1({ inputs: { x: product }, backend, attrs: { axis, keepDims: true } });
intermediates.push(product);
}
else {
const dtype = upcastType(a.dtype, b.dtype);
const program = new MatMulPackedProgram(a3dShape, b3dShape, [batchDim, outerShapeA, outerShapeB], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
const inputs = [a3d, b3d];
if (bias != null) {
inputs.push(bias);
}
if (hasPreluActivationWeights) {
inputs.push(preluActivationWeights);
}
if (hasLeakyreluAlpha) {
const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
inputs.push($leakyreluAlpha);
intermediates.push($leakyreluAlpha);
}
out = backend.runWebGLProgram(program, inputs, dtype);
}
const outReshaped = reshape$1({ inputs: { x: out }, backend, attrs: { shape: outShape } });
intermediates.push(out);
for (const i of intermediates) {
backend.disposeIntermediateTensorInfo(i);
}
return outReshaped;
}
function _fusedMatMul$1(args) {
const { inputs, backend, attrs } = args;
const { a, b, bias, preluActivationWeights } = inputs;
const { transposeA, transposeB, activation, leakyreluAlpha } = attrs;
return batchMatMulImpl({
a,
b,
transposeA,
transposeB,
backend,
bias,
preluActivationWeights,
leakyreluAlpha,
activation
});
}
const _fusedMatMulConfig$1 = {
kernelName: _FusedMatMul,
backendName: 'webgl',
kernelFunc: _fusedMatMul$1,
};
const ABS = `return abs(x);`;
function abs(args) {
const { inputs, backend } = args;
const { x } = inputs;
if (backend.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') {
const xData = backend.texData.get(x.dataId);
const outValues = simpleAbsImplCPU(xData.values);
return backend.makeTensorInfo(x.shape, x.dtype, outValues);
}
let program;
if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
program = new UnaryOpPackedProgram(x.shape, ABS);
}
else {
program = new UnaryOpProgram(x.shape, ABS);
}
return backend.runWebGLProgram(program, [x], x.dtype);
}
const absConfig = {
kernelName: Abs,
backendName: 'webgl',
kernelFunc: abs
};
const ACOS = CHECK_NAN_SNIPPET$1 + `
if (abs(x) > 1.) {
return NAN;
}
return acos(x);
`;
const acos$1 = unaryKernelFunc({ opSnippet: ACOS });
const acosConfig$1 = {
kernelName: Acos,
backendName: 'webgl',
kernelFunc: acos$1,
};
const ACOSH = CHECK_NAN_SNIPPET$1 + `
if (x < 1.0) return NAN;
return log(x + sqrt(x * x - 1.0));`;
const acosh$1 = unaryKernelFunc({ opSnippet: ACOSH });
const acoshConfig$1 = {
kernelName: Acosh,
backendName: 'webgl',
kernelFunc: acosh$1,
};
const ADD = 'return a + b;';
const addKernelFunc = binaryKernelFunc({
opSnippet: ADD,
packedOpSnippet: ADD,
supportsComplex: true,
cpuKernelImpl: addImplCPU
});
const addConfig = {
kernelName: Add,
backendName: 'webgl',
kernelFunc: addKernelFunc
};
class AddNProgram {
constructor(outputShape, shapes) {
this.outputShape = [];
this.outputShape = outputShape;
this.variableNames = shapes.map((_, i) => `T${i}`);
const snippets = [];
this.variableNames.forEach(variable => {
snippets.push(`float v${variable} = get${variable}AtOutCoords();`);
});
const operation = this.variableNames
.map(variable => {
return `v${variable}`;
})
.join(' + ');
this.userCode = `
void main() {
${snippets.join('\n ')}
float result = ${operation};
setOutput(result);
}
`;
}
}
class AddNPackedProgram {
constructor(outputShape, shapes) {
this.outputShape = [];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = outputShape;
this.variableNames = shapes.map((_, i) => `T${i}`);
const snippets = [];
this.variableNames.forEach(variable => {
snippets.push(`vec4 v${variable} = get${variable}AtOutCoords();`);
});
const operation = this.variableNames
.map(variable => {
return `v${variable}`;
})
.join(' + ');
this.userCode = `
void main() {
${snippets.join('\n ')}
vec4 result = ${operation};
setOutput(result);
}
`;
}
}
function addN$1(args) {
const { inputs, backend } = args;
const tensors = inputs;
if (tensors.length === 1) {
return identity({ inputs: { x: tensors[0] }, backend });
}
if (tensors.length > env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) {
const midIndex = Math.floor(tensors.length / 2);
const leftSide = addN$1({ inputs: tensors.slice(0, midIndex), backend });
const rightSide = addN$1({ inputs: tensors.slice(midIndex), backend });
return addN$1({ inputs: [leftSide, rightSide], backend });
}
const dtype = tensors.map(t => t.dtype).reduce((d1, d2) => upcastType(d1, d2));
const shapes = tensors.map(t => t.shape);
const usePackedOp = env().getBool('WEBGL_PACK');
const program = usePackedOp ?
new AddNPackedProgram(tensors[0].shape, shapes) :
new AddNProgram(tensors[0].shape, shapes);
return backend.runWebGLProgram(program, tensors, dtype);
}
const addNConfig$1 = {
kernelName: AddN,
backendName: 'webgl',
kernelFunc: addN$1
};
function all$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis, keepDims } = attrs;
const xRank = x.shape.length;
const origAxes = parseAxisParam(axis, x.shape);
let axes = origAxes;
const permutedAxes = getAxesPermutation(axes, xRank);
let permutedX = x;
if (permutedAxes != null) {
permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
axes = getInnerMostAxes(axes.length, xRank);
}
assertAxesAreInnerMostDims('all', axes, xRank);
const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes);
const inSize = sizeFromShape(reduceShape);
const a2D = reshape$1({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
const reduced = reduce(a2D, a2D.dtype, 'all', backend);
let res;
if (keepDims) {
const newShape = expandShapeToKeepDim(outShape, origAxes);
res = reshape$1({ inputs: { x: reduced }, backend, attrs: { shape: newShape } });
}
else {
res = reshape$1({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
}
backend.disposeIntermediateTensorInfo(a2D);
backend.disposeIntermediateTensorInfo(reduced);
if (permutedAxes != null) {
backend.disposeIntermediateTensorInfo(permutedX);
}
return res;
}
const allConfig$1 = {
kernelName: All,
backendName: 'webgl',
kernelFunc: all$1
};
function any$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis, keepDims } = attrs;
const xRank = x.shape.length;
const origAxes = parseAxisParam(axis, x.shape);
let axes = origAxes;
const permutedAxes = getAxesPermutation(axes, xRank);
let permutedX = x;
if (permutedAxes != null) {
permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
axes = getInnerMostAxes(axes.length, xRank);
}
assertAxesAreInnerMostDims('any', axes, xRank);
const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes);
const inSize = sizeFromShape(reduceShape);
const a2D = reshape$1({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
const reduced = reduce(a2D, a2D.dtype, 'any', backend);
let res;
if (keepDims) {
const newShape = expandShapeToKeepDim(outShape, origAxes);
res = reshape$1({ inputs: { x: reduced }, backend, attrs: { shape: newShape } });
}
else {
res = reshape$1({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
}
backend.disposeIntermediateTensorInfo(a2D);
backend.disposeIntermediateTensorInfo(reduced);
if (permutedAxes != null) {
backend.disposeIntermediateTensorInfo(permutedX);
}
return res;
}
const anyConfig$1 = {
kernelName: Any,
backendName: 'webgl',
kernelFunc: any$1
};
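// Arg-reduction shader: scans a window per output element and writes the index
// of the best (max/min) value; later passes refine indices from earlier passes
// via bestIndicesA.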
class ArgMinMaxProgram {
constructor(reduceInfo, op, firstPass) {
this.variableNames = ['A'];
const { windowSize, batchSize, outSize } = reduceInfo;
if (!firstPass) {
this.variableNames.push('bestIndicesA');
}
this.outputShape = [batchSize, outSize];
const compOp = (op === 'max') ? '>' : '<';
const indexSnippet = firstPass ?
'inOffset + i;' :
'round(getBestIndicesA(batch, inOffset + i));';
this.userCode = `
void main() {
ivec2 coords = getOutputCoords();
int batch = coords[0];
int outIdx = coords[1];
int inOffset = outIdx * ${windowSize};
int bestIndex = inOffset;
float bestValue = getA(batch, bestIndex);
for (int i = 0; i < ${windowSize}; i++) {
int inIdx = ${indexSnippet};
float candidate = getA(batch, inIdx);
if (candidate ${compOp} bestValue) {
bestValue = candidate;
bestIndex = inIdx;
}
}
setOutput(float(bestIndex));
}
`;
}
}
class ArgMinMaxPackedProgram {
constructor(shape, windowSize, op, firstPass) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
assert$1(shape.length > 2, () => `Packed arg${op.charAt(0).toUpperCase() +
op.slice(1)} supports only inputs with rank above 2.`);
const inSize = shape[shape.length - 1];
const outSize = Math.ceil(inSize / windowSize);
this.outputShape = shape.slice(0, -1);
if (outSize > 1) {
this.outputShape.push(outSize);
}
if (!firstPass) {
this.variableNames.push('bestIndicesA');
}
const outShape = this.outputShape;
const rank = outShape.length;
const dtype = getCoordsDataType(rank);
const coords = getChannels('coords', rank);
let sourceLocSetup;
let sourceRank;
if (outSize === 1) {
sourceRank = rank + 1;
const sourceLocDType = getCoordsDataType(sourceRank);
sourceLocSetup = `
${sourceLocDType} sourceLocR = ${sourceLocDType}(${coords.join()}, 0);
++${coords[rank - 1]};
${sourceLocDType} sourceLocG = ${sourceLocDType}(${coords.join()}, 0);
++${coords[rank - 2]};
${sourceLocDType} sourceLocA = ${sourceLocDType}(${coords.join()}, 0);
--${coords[rank - 1]};
${sourceLocDType} sourceLocB = ${sourceLocDType}(${coords.join()}, 0);
--${coords[rank - 2]};`;
}
else {
sourceRank = rank;
sourceLocSetup = `
${dtype} sourceLocR = coords;
++${coords[rank - 1]};
${dtype} sourceLocG = coords;
++${coords[rank - 2]};
${dtype} sourceLocA = coords;
--${coords[rank - 1]};
${dtype} sourceLocB = coords;
--${coords[rank - 2]};`;
}
const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);
const inChannel = '.' + channels[sourceRank - 1];
const intChannels = channels.map(x => 'int ' + x);
const srcRCoords = getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r');
const srcGCoords = getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g');
const srcBCoords = getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b');
const srcACoords = getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a');
const compOp = (op === 'max') ? 'greaterThan' : 'lessThan';
const fetchCandidateIdx = firstPass ? '' : `
inIdx = round(vec4(getBestIndicesAChannel(${srcRCoords.join()}),
getBestIndicesAChannel(${srcGCoords.join()}),
getBestIndicesAChannel(${srcBCoords.join()}),
getBestIndicesAChannel(${srcACoords.join()})));`;
const fetchValue = `vec4(
getAChannel(${srcRCoords.join()}),
hasNextCol ? getAChannel(${srcGCoords.join()}) : 0.,
hasNextRow ? getAChannel(${srcBCoords.join()}) : 0.,
hasNextRow && hasNextCol ? getAChannel(${srcACoords.join()}) : 0.)`;
const getBestIndicesAChannelSnippet = firstPass ? '' : `
float getBestIndicesAChannel(${intChannels.join()}) {
return getChannel(getBestIndicesA(${channels.join()}),
vec2(${channels.slice(-2).join()}));
}`;
this.userCode = `
float getAChannel(${intChannels.join()}) {
return getChannel(getA(${channels.join()}),
vec2(${channels.slice(-2).join()}));
}
${getBestIndicesAChannelSnippet}
void main() {
${dtype} coords = getOutputCoords();
bool hasNextCol = ${coords[rank - 1]} < ${outShape[rank - 1] - 1};
bool hasNextRow = ${coords[rank - 2]} < ${outShape[rank - 2] - 1};
${sourceLocSetup}
ivec4 srcIdx = ivec4(sourceLocR${inChannel}, sourceLocG${inChannel},
sourceLocB${inChannel}, sourceLocA${inChannel}) * ${windowSize};
ivec4 inIdx = srcIdx;
vec4 bestIndex = vec4(inIdx);
vec4 bestValue = ${fetchValue};
for (int i = 0; i < ${windowSize}; i++) {
inIdx = srcIdx;
${fetchCandidateIdx}
vec4 candidate = ${fetchValue};
bvec4 nan = isnan(candidate);
bvec4 replace = bvec4(
vec4(${compOp}(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));
bestValue = vec4(replace.x ? candidate.x : bestValue.x,
replace.y ? candidate.y : bestValue.y,
replace.z ? candidate.z : bestValue.z,
replace.w ? candidate.w : bestValue.w);
bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));
srcIdx++;
}
setOutput(bestIndex);
}
`;
}
}
function argReduce(backend, x, reduceType, bestIndicesA = null) {
let batchSize = x.shape[0];
let inSize = x.shape[1];
if (bestIndicesA != null) {
batchSize = bestIndicesA.shape[0];
inSize = bestIndicesA.shape[1];
}
const windowSize = computeOptimalWindowSize(inSize);
const reduceInfo = { windowSize, inSize, batchSize, outSize: Math.ceil(inSize / windowSize) };
const program = new ArgMinMaxProgram(reduceInfo, reduceType, bestIndicesA == null);
const inputs = [x];
if (bestIndicesA != null) {
inputs.push(bestIndicesA);
}
const output = backend.runWebGLProgram(program, inputs, 'int32');
if (output.shape[1] === 1) {
return output;
}
const result = argReduce(backend, x, reduceType, output);
backend.disposeIntermediateTensorInfo(output);
return result;
}
function argReducePacked(backend, x, reduceType, bestIndicesA = null) {
const inShape = bestIndicesA != null ? bestIndicesA.shape : x.shape;
const inSize = inShape[inShape.length - 1];
const windowSize = computeOptimalWindowSize(inSize);
const program = new ArgMinMaxPackedProgram(inShape, windowSize, reduceType, bestIndicesA == null);
const inputs = bestIndicesA == null ? [x] : [x, bestIndicesA];
const output = backend.runWebGLProgram(program, inputs, 'int32');
if (output.shape.length === x.shape.length) {
const result = argReducePacked(backend, x, reduceType, output);
backend.disposeIntermediateTensorInfo(output);
return result;
}
return output;
}
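// argMin/argMax over the innermost axis: uses the packed variant when
// WEBGL_PACK_REDUCE is enabled and the input rank is above 2, otherwise
// unpacks, flattens to 2D and reduces iteratively.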
function argMinMaxReduce(backend, x, axis, reduceType) {
const axes = [axis];
assertAxesAreInnerMostDims('arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes, x.shape.length);
if (!env().getBool('WEBGL_PACK_REDUCE') || x.shape.length <= 2) {
const intermediateTensorInfos = [];
const xtexData = backend.texData.get(x.dataId);
const xIsPacked = xtexData !== null && xtexData.isPacked;
let xUnPacked = x;
if (xIsPacked) {
xUnPacked = backend.unpackTensor(x);
intermediateTensorInfos.push(xUnPacked);
}
const [outShape, reduceShape] = computeOutAndReduceShapes(xUnPacked.shape, axes);
const inSize = sizeFromShape(reduceShape);
const a2D = reshape$1({ inputs: { x: xUnPacked }, backend, attrs: { shape: [-1, inSize] } });
intermediateTensorInfos.push(a2D);
const reduced = argReduce(backend, a2D, reduceType);
intermediateTensorInfos.push(reduced);
const reshaped = reshape$1({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
return reshaped;
}
return argReducePacked(backend, x, reduceType);
}
function argMax$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis } = attrs;
let axes = parseAxisParam(axis, x.shape);
const permutedAxes = getAxesPermutation(axes, x.shape.length);
let $x = x;
const intermediateTensorInfos = [];
if (permutedAxes != null) {
$x = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
intermediateTensorInfos.push($x);
axes = getInnerMostAxes(axes.length, $x.shape.length);
}
assertAxesAreInnerMostDims('argMax', [axes[0]], $x.shape.length);
const out = argMinMaxReduce(backend, $x, axes[0], 'max');
intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
return out;
}
const argMaxConfig$1 = {
kernelName: ArgMax,
backendName: 'webgl',
kernelFunc: argMax$1
};
function argMin$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis } = attrs;
let axes = parseAxisParam(axis, x.shape);
const permutedAxes = getAxesPermutation(axes, x.shape.length);
let $x = x;
const intermediateTensorInfos = [];
if (permutedAxes != null) {
$x = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
intermediateTensorInfos.push($x);
axes = getInnerMostAxes(axes.length, $x.shape.length);
}
assertAxesAreInnerMostDims('argMin', [axes[0]], $x.shape.length);
const out = argMinMaxReduce(backend, $x, axes[0], 'min');
intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
return out;
}
const argMinConfig$1 = {
kernelName: ArgMin,
backendName: 'webgl',
kernelFunc: argMin$1
};
const ASIN = CHECK_NAN_SNIPPET$1 + `
if (abs(x) > 1.) {
return NAN;
}
return asin(x);
`;
const asin$1 = unaryKernelFunc({ opSnippet: ASIN });
const asinConfig$1 = {
kernelName: Asin,
backendName: 'webgl',
kernelFunc: asin$1,
};
const ASINH = CHECK_NAN_SNIPPET$1 + `return log(x + sqrt(x * x + 1.0));`;
const asinh$1 = unaryKernelFunc({ opSnippet: ASINH });
const asinhConfig$1 = {
kernelName: Asinh,
backendName: 'webgl',
kernelFunc: asinh$1,
};
const ATAN = CHECK_NAN_SNIPPET$1 + `
return atan(x);
`;
const atan$1 = unaryKernelFunc({ opSnippet: ATAN });
const atanConfig$1 = {
kernelName: Atan,
backendName: 'webgl',
kernelFunc: atan$1,
};
const ATAN2 = CHECK_NAN_SNIPPET + `
return atan(a, b);
`;
const ATAN2_PACKED = `
vec4 result = atan(a, b);
bvec4 isNaNA = isnan(a);
bvec4 isNaNB = isnan(b);
bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w);
` +
CHECK_NAN_SNIPPET_PACKED + `
return result;
`;
const atan2$1 = binaryKernelFunc({ opSnippet: ATAN2, packedOpSnippet: ATAN2_PACKED });
const atan2Config$1 = {
kernelName: Atan2,
backendName: 'webgl',
kernelFunc: atan2$1,
};
const ATANH = CHECK_NAN_SNIPPET$1 + `
if ((x < -1.0) || (x > 1.0)) return NAN;
return (log(1.0 + x) - log(1.0 - x)) / 2.0;`;
const atanh$1 = unaryKernelFunc({ opSnippet: ATANH });
const atanhConfig$1 = {
kernelName: Atanh,
backendName: 'webgl',
kernelFunc: atanh$1,
};
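// 2D pooling shader (avg/max). With computePositions it emits the (optionally
// flattened) index of the max value; otherwise it reduces the filter window in
// vec4 chunks.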
class Pool2DProgram {
constructor(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) {
this.variableNames = ['x'];
if (poolType === 'avg' && computePositions) {
throw new Error('Cannot compute positions for average pool.');
}
const filterWidth = convInfo.filterWidth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padTop = convInfo.padInfo.top;
const padLeft = convInfo.padInfo.left;
this.outputShape = convInfo.outShape;
const isAvgPool = poolType === 'avg';
const batchFlattenPositionStr = `((batch * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;
const flattenPositionStr = `(xR * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;
let initializationValue = '0.0';
if (!isAvgPool) {
initializationValue = '-1.0 / 1e-20';
}
if (computePositions) {
const compareOp = '>=';
this.userCode = `
const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
const ivec2 pads = ivec2(${padTop}, ${padLeft});
void main() {
ivec4 coords = getOutputCoords();
int batch = coords[0];
int d = coords[3];
ivec2 xRCCorner = coords.yz * strides - pads;
int xRCorner = xRCCorner.x;
int xCCorner = xRCCorner.y;
float minMaxValue = 0.0;
float minMaxValueFound = 0.0;
int minMaxPosition = 0;
float avgValue = 0.0;
for (int wR = 0; wR < ${effectiveFilterHeight};
wR += ${dilationHeight}) {
int xR = xRCorner + wR;
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}
for (int wC = 0; wC < ${effectiveFilterWidth};
wC += ${dilationWidth}) {
int xC = xCCorner + wC;
if (xC < 0 || xC >= ${convInfo.inWidth}) {
continue;
}
float value = getX(batch, xR, xC, d);
float currMinMaxValue = mix(
value, minMaxValue, minMaxValueFound);
if (value ${compareOp} currMinMaxValue) {
minMaxValue = value;
minMaxValueFound = 1.0;
minMaxPosition = ${flattenPositions ? (includeBatchInIndex ? batchFlattenPositionStr :
flattenPositionStr) :
`wR * ${effectiveFilterWidth} + wC`};
}
}
}
setOutput(float(minMaxPosition));
}
`;
return;
}
const compareOp = 'max';
let returnValue = `${poolType}(${poolType}(${poolType}(` +
'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
if (poolType === 'avg') {
returnValue = `avgValue / max(count, 1.0)`;
}
const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
const filterWidthVec4Remainder = filterWidth % 4;
const updateSnippet = `
if (${isAvgPool}) {
avgValue += dot(values, ones);
} else {
minMaxValue = ${compareOp}(values, minMaxValue);
}
`;
this.userCode = `
const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
const ivec2 pads = ivec2(${padTop}, ${padLeft});
const float initializationValue = ${initializationValue};
const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
float count = 0.0;
float getValue(int batch, int xR, int xC, int d) {
if (xC < 0 || xC >= ${convInfo.inWidth}) {
return initializationValue;
}
count += 1.0;
return getX(batch, xR, xC, d);
}
void main() {
ivec4 coords = getOutputCoords();
int batch = coords[0];
int d = coords[3];
ivec2 xRCCorner = coords.yz * strides - pads;
int xRCorner = xRCCorner.x;
int xCCorner = xRCCorner.y;
vec4 minMaxValue = vec4(${initializationValue});
float avgValue = 0.0;
count = 0.0;
for (int wR = 0; wR < ${effectiveFilterHeight};
wR += ${dilationHeight}) {
int xR = xRCorner + wR;
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}
for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {
int xC = xCCorner + wC * ${dilationWidth};
vec4 values = vec4(
getValue(batch, xR, xC, d),
getValue(batch, xR, xC + ${dilationWidth}, d),
getValue(batch, xR, xC + 2 * ${dilationWidth}, d),
getValue(batch, xR, xC + 3 * ${dilationWidth}, d)
);
${updateSnippet}
}
int xC = xCCorner + ${filterWidthNearestVec4};
if (${filterWidthVec4Remainder === 1}) {
vec4 values = vec4(
getValue(batch, xR, xC, d),
initializationValue,
initializationValue,
initializationValue
);
${updateSnippet}
} else if (${filterWidthVec4Remainder === 2}) {
vec4 values = vec4(
getValue(batch, xR, xC, d),
getValue(batch, xR, xC + ${dilationWidth}, d),
initializationValue,
initializationValue
);
${updateSnippet}
} else if (${filterWidthVec4Remainder === 3}) {
vec4 values = vec4(
getValue(batch, xR, xC, d),
getValue(batch, xR, xC + ${dilationWidth}, d),
getValue(batch, xR, xC + 2 * ${dilationWidth}, d),
initializationValue
);
${updateSnippet}
}
}
setOutput(${returnValue});
}
`;
}
}
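// 3D pooling shader: same structure as Pool2DProgram with an extra loop over
// the filter depth.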
class Pool3DProgram {
constructor(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) {
this.variableNames = ['x'];
if (poolType === 'avg' && computePositions) {
throw new Error('Cannot compute positions for average pool.');
}
const filterWidth = convInfo.filterWidth;
const strideDepth = convInfo.strideDepth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationDepth = convInfo.dilationDepth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterDepth = convInfo.effectiveFilterDepth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padFront = convInfo.padInfo.front;
const padTop = convInfo.padInfo.top;
const padLeft = convInfo.padInfo.left;
this.outputShape = convInfo.outShape;
const isAvgPool = poolType === 'avg';
let initializationValue = '0.0';
if (!isAvgPool) {
initializationValue = '-1.0 / 1e-20';
}
if (computePositions) {
const compareOp = '>=';
this.userCode = `
const ivec3 strides =
ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
void main() {
ivec5 coords = getOutputCoords();
int batch = coords.x;
int ch = coords.u;
ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
int xDCorner = xCorner.x;
int xRCorner = xCorner.y;
int xCCorner = xCorner.z;
float minMaxValue = 0.0;
float minMaxValueFound = 0.0;
int minMaxPosition = 0;
for (int wD = 0; wD < ${effectiveFilterDepth};
wD += ${dilationDepth}) {
int xD = xDCorner + wD;
if (xD < 0 || xD >= ${convInfo.inDepth}) {
continue;
}
for (int wR = 0; wR < ${effectiveFilterHeight};
wR += ${dilationHeight}) {
int xR = xRCorner + wR;
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}
for (int wC = 0; wC < ${effectiveFilterWidth};
wC += ${dilationWidth}) {
int xC = xCCorner + wC;
if (xC < 0 || xC >= ${convInfo.inWidth}) {
continue;
}
float value = getX(batch, xD, xR, xC, ch);
float currMinMaxValue = mix(
value, minMaxValue, minMaxValueFound);
if (value ${compareOp} currMinMaxValue) {
minMaxValue = value;
minMaxValueFound = 1.0;
minMaxPosition = ${flattenPositions ?
(includeBatchInIndex ?
`(((batch * ${convInfo.inDepth} + xD) * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch` :
`((xD * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch`) :
`wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} +
wR * ${effectiveFilterWidth} + wC`};
}
}
}
}
setOutput(float(minMaxPosition));
}
`;
return;
}
const compareOp = 'max';
let returnValue = `${poolType}(${poolType}(${poolType}(` +
'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
if (poolType === 'avg') {
returnValue = `avgValue / max(count, 1.0)`;
}
const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
const filterWidthVec4Remainder = filterWidth % 4;
const updateSnippet = `
if (${isAvgPool}) {
avgValue += dot(values, ones);
} else {
minMaxValue = ${compareOp}(values, minMaxValue);
}
`;
this.userCode = `
const ivec3 strides =
ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
const float initializationValue = ${initializationValue};
const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
float count = 0.0;
float getValue(int batch, int xD, int xR, int xC, int ch) {
if (xC < 0 || xC >= ${convInfo.inWidth}) {
return initializationValue;
}
count += 1.0;
return getX(batch, xD, xR, xC, ch);
}
void main() {
ivec5 coords = getOutputCoords();
int batch = coords.x;
int ch = coords.u;
ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
int xDCorner = xCorner.x;
int xRCorner = xCorner.y;
int xCCorner = xCorner.z;
vec4 minMaxValue = vec4(${initializationValue});
float avgValue = 0.0;
count = 0.0;
for (int wD = 0; wD < ${effectiveFilterDepth};
wD += ${dilationDepth}) {
int xD = xDCorner + wD;
if (xD < 0 || xD >= ${convInfo.inDepth}) {
continue;
}
for (int wR = 0; wR < ${effectiveFilterHeight};
wR += ${dilationHeight}) {
int xR = xRCorner + wR;
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}
for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {
int xC = xCCorner + wC * ${dilationWidth};
vec4 values = vec4(
getValue(batch, xD, xR, xC, ch),
getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),
getValue(batch, xD, xR, xC + 3 * ${dilationWidth}, ch)
);
${updateSnippet}
}
int xC = xCCorner + ${filterWidthNearestVec4};
if (${filterWidthVec4Remainder === 1}) {
vec4 values = vec4(
getValue(batch, xD, xR, xC, ch),
initializationValue,
initializationValue,
initializationValue
);
${updateSnippet}
} else if (${filterWidthVec4Remainder === 2}) {
vec4 values = vec4(
getValue(batch, xD, xR, xC, ch),
getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
initializationValue,
initializationValue
);
${updateSnippet}
} else if (${filterWidthVec4Remainder === 3}) {
vec4 values = vec4(
getValue(batch, xD, xR, xC, ch),
getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),
initializationValue
);
${updateSnippet}
}
}
}
setOutput(${returnValue});
}
`;
}
}
function avgPool$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
assertNotComplex$1(x, 'avgPool');
const { filterSize, strides, pad, dimRoundingMode } = attrs;
const dilations = 1;
assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool: Either strides or dilations must be 1. ' +
`Got strides ${strides} and dilations '${dilations}'`);
const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
arraysEqual(convInfo.inShape, convInfo.outShape)) {
return identity({ inputs: { x }, backend });
}
const avgPoolProgram = new Pool2DProgram(convInfo, 'avg', false);
return backend.runWebGLProgram(avgPoolProgram, [x], 'float32');
}
const avgPoolConfig$1 = {
kernelName: AvgPool,
backendName: 'webgl',
kernelFunc: avgPool$1
};
function avgPool3D$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs;
const dilations = [1, 1, 1];
const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat);
const avgPoolProgram = new Pool3DProgram(convInfo, 'avg', false);
return backend.runWebGLProgram(avgPoolProgram, [x], 'float32');
}
const avgPool3DConfig$1 = {
kernelName: AvgPool3D,
backendName: 'webgl',
kernelFunc: avgPool3D$1
};
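// Average-pool gradient: spreads each dy element uniformly over its filter
// window, scaled by 1 / (filterHeight * filterWidth).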
class AvgPool2DBackpropProgram {
constructor(convInfo) {
this.variableNames = ['dy'];
this.outputShape = convInfo.inShape;
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
const avgMultiplier = 1 / (filterHeight * filterWidth);
this.userCode = `
const ivec2 pads = ivec2(${padTop}, ${padLeft});
const float avgMultiplier = float(${avgMultiplier});
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int d = coords[3];
ivec2 dyRCCorner = coords.yz - pads;
int dyRCorner = dyRCCorner.x;
int dyCCorner = dyRCCorner.y;
float dotProd = 0.0;
for (int wR = 0; wR < ${effectiveFilterHeight};
wR += ${dilationHeight}) {
float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
continue;
}
int idyR = int(dyR);
for (int wC = 0; wC < ${effectiveFilterWidth};
wC+= ${dilationWidth}) {
float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
fract(dyC) > 0.0) {
continue;
}
int idyC = int(dyC);
float dyValue = getDy(b, idyR, idyC, d);
dotProd += dyValue * avgMultiplier;
}
}
setOutput(dotProd);
}
`;
}
}
class AvgPool3DBackpropProgram {
constructor(convInfo) {
this.variableNames = ['dy'];
this.outputShape = convInfo.inShape;
const filterDepth = convInfo.filterDepth;
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const strideDepth = convInfo.strideDepth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationDepth = convInfo.dilationDepth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterDepth = convInfo.effectiveFilterDepth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
const avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
this.userCode = `
const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
const float avgMultiplier = float(${avgMultiplier});
void main() {
ivec5 coords = getOutputCoords();
int batch = coords.x;
int ch = coords.u;
ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
int dyDCorner = dyCorner.x;
int dyRCorner = dyCorner.y;
int dyCCorner = dyCorner.z;
float dotProd = 0.0;
for (int wD = 0; wD < ${effectiveFilterDepth};
wD += ${dilationDepth}) {
float dyD = float(dyDCorner + wD) / ${strideDepth}.0;
if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) {
continue;
}
int idyD = int(dyD);
for (int wR = 0; wR < ${effectiveFilterHeight};
wR += ${dilationHeight}) {
float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
fract(dyR) > 0.0) {
continue;
}
int idyR = int(dyR);
for (int wC = 0; wC < ${effectiveFilterWidth};
wC += ${dilationWidth}) {
float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
fract(dyC) > 0.0) {
continue;
}
int idyC = int(dyC);
float dyValue = getDy(batch, idyD, idyR, idyC, ch);
dotProd += dyValue * avgMultiplier;
}
}
}
setOutput(dotProd);
}
`;
}
}
function avgPool3DGrad$1(args) {
const { inputs, backend, attrs } = args;
const { dy, input } = inputs;
const x = input;
const { filterSize, strides, pad, dimRoundingMode } = attrs;
const dilations = [1, 1, 1];
const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
const avgPoolBackpropProgram = new AvgPool3DBackpropProgram(convInfo);
return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype);
}
const avgPool3DGradConfig$2 = {
kernelName: AvgPool3DGrad,
backendName: 'webgl',
kernelFunc: avgPool3DGrad$1
};
function avgPoolGrad$2(args) {
const { inputs, backend, attrs } = args;
const { dy, input } = inputs;
const x = input;
assertNotComplex$1([dy, input], 'avgPoolGrad');
const { filterSize, strides, pad } = attrs;
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad);
const avgPoolBackpropProgram = new AvgPool2DBackpropProgram(convInfo);
return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype);
}
const avgPoolGradConfig$2 = {
kernelName: AvgPoolGrad,
backendName: 'webgl',
kernelFunc: avgPoolGrad$2
};
function batchMatMul$1(args) {
const { inputs, backend, attrs } = args;
const { a, b } = inputs;
const { transposeA, transposeB } = attrs;
return batchMatMulImpl({ a, b, transposeA, transposeB, backend });
}
const batchMatMulConfig$1 = {
kernelName: BatchMatMul,
backendName: 'webgl',
kernelFunc: batchMatMul$1,
};
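// Batch normalization as a single element-wise shader:
// output = (x - mean) * scale * inversesqrt(variance + varianceEpsilon) + offset.
// offset and scale are optional and default to 0 and 1; the packed variant
// applies the same formula to four values per texel.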
class BatchNormProgram {
constructor(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
this.outputShape = [];
this.variableNames = ['x', 'mean', 'variance'];
assertAndGetBroadcastShape(xShape, meanShape);
assertAndGetBroadcastShape(xShape, varianceShape);
let offsetSnippet = '0.0';
if (offsetShape != null) {
assertAndGetBroadcastShape(xShape, offsetShape);
this.variableNames.push('offset');
offsetSnippet = 'getOffsetAtOutCoords()';
}
let scaleSnippet = '1.0';
if (scaleShape != null) {
assertAndGetBroadcastShape(xShape, scaleShape);
this.variableNames.push('scale');
scaleSnippet = 'getScaleAtOutCoords()';
}
this.outputShape = xShape;
this.userCode = `
void main() {
float x = getXAtOutCoords();
float mean = getMeanAtOutCoords();
float variance = getVarianceAtOutCoords();
float offset = ${offsetSnippet};
float scale = ${scaleSnippet};
float inv = scale * inversesqrt(variance + float(${varianceEpsilon}));
setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1)));
}
`;
}
}
class BatchNormPackedProgram {
constructor(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
this.packedInputs = true;
this.packedOutput = true;
this.variableNames = ['x', 'mean', 'variance'];
assertAndGetBroadcastShape(xShape, meanShape);
assertAndGetBroadcastShape(xShape, varianceShape);
let offsetSnippet = 'vec4(0.0)';
if (offsetShape != null) {
assertAndGetBroadcastShape(xShape, offsetShape);
this.variableNames.push('offset');
offsetSnippet = 'getOffsetAtOutCoords()';
}
let scaleSnippet = 'vec4(1.0)';
if (scaleShape != null) {
assertAndGetBroadcastShape(xShape, scaleShape);
this.variableNames.push('scale');
scaleSnippet = 'getScaleAtOutCoords()';
}
this.outputShape = xShape;
this.userCode = `
void main() {
vec4 offset = ${offsetSnippet};
vec4 scale = ${scaleSnippet};
vec4 x = getXAtOutCoords();
vec4 mean = getMeanAtOutCoords();
vec4 variance = getVarianceAtOutCoords();
vec4 inv = scale * inversesqrt(variance + vec4(${varianceEpsilon}));
setOutput((x - mean) * inv + offset);
}
`;
}
}
const batchNorm$1 = ({ inputs, backend, attrs }) => {
const { x, mean, variance, offset, scale } = inputs;
assert$1(mean.shape.length === variance.shape.length, () => 'Batch normalization gradient requires mean and variance to have ' +
'equal ranks.');
assert$1(offset == null || mean.shape.length === offset.shape.length, () => 'Batch normalization gradient requires mean and offset to have ' +
'equal ranks.');
assert$1(scale == null || mean.shape.length === scale.shape.length, () => 'Batch normalization gradient requires mean and scale to have ' +
'equal ranks.');
let { varianceEpsilon } = attrs;
if (varianceEpsilon == null) {
varianceEpsilon = 0.001;
}
const finalInputs = [x, mean, variance];
let offsetShape = null;
if (offset != null) {
offsetShape = offset.shape;
finalInputs.push(offset);
}
let scaleShape = null;
if (scale != null) {
scaleShape = scale.shape;
finalInputs.push(scale);
}
const program = env().getBool('WEBGL_PACK_NORMALIZATION') ?
new BatchNormPackedProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon) :
new BatchNormProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon);
const output = backend.runWebGLProgram(program, finalInputs, finalInputs[0].dtype);
return output;
};
const batchNormConfig$1 = {
kernelName: FusedBatchNorm,
backendName: 'webgl',
kernelFunc: batchNorm$1,
};
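// Copies a rectangular region of the source tensor. The slice origin is passed
// in through the `start` uniform array, so a compiled program can be reused for
// different offsets as long as the output shape matches.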
class SliceProgram {
constructor(destSize) {
this.variableNames = ['source'];
this.outputShape = destSize;
this.rank = destSize.length;
const dtype = getCoordsDataType(this.rank);
this.customUniforms = [{ name: 'start', arrayIndex: this.rank, type: 'int' }];
const sourceCoords = getCoords$1(this.rank);
let body;
const coordSum = destSize.map((_, i) => {
return `sourceLoc.${coords[i]} = start[${i}] + coords.${coords[i]};`;
});
body = `
${dtype} sourceLoc;
${dtype} coords = getOutputCoords();
${coordSum.join('\n')}
`;
this.userCode = `
void main() {
${body}
setOutput(getSource(${sourceCoords}));
}
`;
}
}
const coords = ['x', 'y', 'z', 'w', 'u', 'v'];
function getCoords$1(rank) {
if (rank === 1) {
return 'sourceLoc';
}
else if (rank <= 6) {
return coords.slice(0, rank).map(x => 'sourceLoc.' + x).join(',');
}
else {
throw Error(`Slicing for rank ${rank} is not yet supported`);
}
}
class SlicePackedProgram {
constructor(destSize) {
this.variableNames = ['source'];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = destSize;
this.rank = destSize.length;
this.customUniforms = [{ name: 'start', arrayIndex: this.rank, type: 'int' }];
const dtype = getCoordsDataType(this.rank);
const coords = getChannels('coords', this.rank);
const sourceLoc = getChannels('sourceLoc', this.rank);
const innerDims = this.rank === 1 ? 'sourceLoc' : `vec2(${sourceLoc.slice(-2).join()})`;
const getChannel = `getChannel(getSource(${sourceLoc.join()}), ${innerDims})`;
const upperRow = `
result.x = ${getChannel};
if (++${coords[this.rank - 1]} < ${destSize[this.rank - 1]}) {
++${sourceLoc[this.rank - 1]};
result.y = ${getChannel};
--${sourceLoc[this.rank - 1]};
}
`;
const lowerRow = this.rank === 1 ? '' : `
--${coords[this.rank - 1]};
if (++${coords[this.rank - 2]} < ${destSize[this.rank - 2]}) {
++${sourceLoc[this.rank - 2]};
result.z = ${getChannel};
if (++${coords[this.rank - 1]} < ${destSize[this.rank - 1]}) {
++${sourceLoc[this.rank - 1]};
result.w = ${getChannel};
}
}
`;
const sourceLocSetup = this.rank <= 4 ?
`sourceLoc = coords +
${dtype}(${destSize.map((_, i) => `start[${i}]`).join()});` :
destSize.map((_, i) => `${sourceLoc[i]} = ${coords[i]} + start[${i}];`)
.join('\n');
this.userCode = `
void main() {
${dtype} coords = getOutputCoords();
${dtype} sourceLoc;
${sourceLocSetup}
vec4 result = vec4(0.);
${upperRow}
${lowerRow}
setOutput(result);
}
`;
}
}
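// Zero-copy slice: the result shares the source texture and only records a flat
// offset into it. slice() below only takes this path for unpacked, contiguous
// slices; the source's reference count is bumped so the underlying data is not
// disposed while the view is still alive.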
function shallowSlice(x, begin, size, backend) {
const xTexData = backend.texData.get(x.dataId);
const t = backend.makeTensorInfo(size, x.dtype);
const newTexData = backend.texData.get(t.dataId);
Object.assign(newTexData, xTexData);
newTexData.refCount = 1;
newTexData.shape = size;
newTexData.dtype = x.dtype;
let flatOffset = computeFlatOffset(begin, computeStrides(x.shape));
if (xTexData.slice) {
flatOffset += xTexData.slice.flatOffset;
}
newTexData.slice = {
flatOffset,
origDataId: xTexData.slice && xTexData.slice.origDataId || x.dataId
};
const refCount = backend.dataRefCount.get(newTexData.slice.origDataId) || 1;
backend.dataRefCount.set(newTexData.slice.origDataId, refCount + 1);
return t;
}
function slice(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { begin, size } = attrs;
const [$begin, $size] = parseSliceParams(x, begin, size);
assertParamsValid(x, $begin, $size);
if (sizeFromShape($size) === 0) {
return backend.makeTensorInfo($size, x.dtype, []);
}
if (backend.shouldExecuteOnCPU([x]) || x.dtype === 'string') {
const xTexData = backend.texData.get(x.dataId);
const outValues = sliceImplCPU(xTexData.values, $begin, $size, x.shape, x.dtype);
return backend.makeTensorInfo($size, x.dtype, outValues);
}
const { isPacked } = backend.texData.get(x.dataId);
const isContinous = isSliceContinous(x.shape, $begin, $size);
if (isPacked || !isContinous) {
const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
new SlicePackedProgram($size) :
new SliceProgram($size);
const customValues = [$begin];
return backend.runWebGLProgram(program, [x], x.dtype, customValues);
}
backend.uploadToGPU(x.dataId);
return shallowSlice(x, $begin, $size, backend);
}
const sliceConfig = {
kernelName: Slice,
backendName: 'webgl',
kernelFunc: slice
};
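// BatchToSpaceND expressed through existing kernels: reshape to split the block
// dimensions out of the batch, transpose them next to their spatial dimensions,
// reshape back, then slice away the crops. Only ranks up to 4 are supported on
// this backend.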
const batchToSpaceND$1 = (args) => {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { blockShape, crops } = attrs;
assert$1(x.shape.length <= 4, () => 'batchToSpaceND for rank > 4 with a WebGL backend not ' +
'implemented yet');
const prod = blockShape.reduce((a, b) => a * b);
const reshaped = getReshaped(x.shape, blockShape, prod);
const permuted = getPermuted(reshaped.length, blockShape.length);
const reshapedPermuted = getReshapedPermuted(x.shape, blockShape, prod);
const sliceBeginCoords = getSliceBeginCoords(crops, blockShape.length);
const sliceSize = getSliceSize(reshapedPermuted, crops, blockShape.length);
const toDispose = [];
const reshapedIntermediate = reshape$1({ inputs: { x }, backend, attrs: { shape: reshaped } });
const transposedIntermediate = transpose({ inputs: { x: reshapedIntermediate }, backend, attrs: { perm: permuted } });
const reshapedIntermediate2 = reshape$1({
inputs: { x: transposedIntermediate },
backend,
attrs: { shape: reshapedPermuted }
});
const sliced = slice({
inputs: { x: reshapedIntermediate2 },
backend,
attrs: { begin: sliceBeginCoords, size: sliceSize }
});
toDispose.push(reshapedIntermediate);
toDispose.push(transposedIntermediate);
toDispose.push(reshapedIntermediate2);
toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
return sliced;
};
const batchToSpaceNDConfig$1 = {
kernelName: BatchToSpaceND,
backendName: 'webgl',
kernelFunc: batchToSpaceND$1
};
function bincount$1(args) {
const { inputs, backend, attrs } = args;
const { x, weights } = inputs;
const { size } = attrs;
const xVals = backend.readSync(x.dataId);
const weightsVals = backend.readSync(weights.dataId);
const outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size);
return backend.makeTensorInfo([size], weights.dtype, outVals);
}
const bincountConfig$1 = {
kernelName: Bincount,
backendName: 'webgl',
kernelFunc: bincount$1
};
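// Bitwise AND. GLSL ES 1.0 (WebGL1) has no integer bitwise operators, so on
// WebGL1, or whenever the backend prefers the CPU for these inputs, the values
// are read back and combined with the CPU implementation instead.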
const BITWISEAND = `
int r = int(a.r) & int(b.r);
int g = int(a.g) & int(b.g);
int rb = int(a.b) & int(b.b);
int ra = int(a.a) & int(b.a);
return vec4(r, g, rb, ra);
`;
const BITWISEAND_UNPACKED = `
return float(int(a.r) & int(b.r));
`;
function bitwiseAnd(args) {
const { inputs, backend } = args;
const { a, b } = inputs;
const shouldUsePackedProgram = env().getBool('WEBGL_PACK_BINARY_OPERATIONS');
const versionNumber = env().getNumber('WEBGL_VERSION');
if ((backend.shouldExecuteOnCPU([a, b])) || versionNumber === 1) {
const aVals = backend.texData.get(a.dataId).values;
const bVals = backend.texData.get(b.dataId).values;
const [outValues, outShape] = bitwiseAndImplCPU(a.shape, b.shape, aVals, bVals, a.dtype);
const out = backend.makeTensorInfo(outShape, a.dtype);
const outData = backend.texData.get(out.dataId);
outData.values = outValues;
return out;
}
let program;
if (shouldUsePackedProgram) {
program = new BinaryOpPackedProgram(BITWISEAND, a.shape, b.shape, false);
}
else {
program = new BinaryOpProgram(BITWISEAND_UNPACKED, a.shape, b.shape);
}
return backend.runWebGLProgram(program, [a, b], a.dtype);
}
const bitwiseAndConfig = {
kernelName: BitwiseAnd,
backendName: 'webgl',
kernelFunc: bitwiseAnd
};
function broadcastArgs$1(args) {
const { inputs, backend } = args;
const { s0, s1 } = inputs;
const s0Vals = backend.readSync(s0.dataId);
const s1Vals = backend.readSync(s1.dataId);
const broadcastShape = assertAndGetBroadcastShape(Array.from(s0Vals), Array.from(s1Vals));
return backend.makeTensorInfo([broadcastShape.length], 'int32', Int32Array.from(broadcastShape));
}
const broadcastArgsConfig$1 = {
kernelName: BroadcastArgs,
backendName: 'webgl',
kernelFunc: broadcastArgs$1
};
const NOT_EQUAL = `return float(a != b);`;
const notEqual = binaryKernelFunc({ opSnippet: NOT_EQUAL, cpuKernelImpl: notEqualImplCPU, dtype: 'bool' });
const notEqualConfig = {
kernelName: NotEqual,
backendName: 'webgl',
kernelFunc: notEqual,
};
function real(args) {
const { inputs, backend } = args;
const { input } = inputs;
const inputData = backend.texData.get(input.dataId);
return identity({ inputs: { x: inputData.complexTensorInfos.real }, backend });
}
const realConfig = {
kernelName: Real,
backendName: 'webgl',
kernelFunc: real
};
const TO_INT = `return float(int(x));`;
function int(input, backend) {
const program = new UnaryOpProgram(input.shape, TO_INT);
const output = backend.runWebGLProgram(program, [input], 'int32');
return { dataId: output.dataId, shape: output.shape, dtype: output.dtype };
}
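// Cast between dtypes: complex64 casts are assembled from real/imag components,
// lossless casts are an identity that merely relabels the dtype, int32 casts
// truncate through a small shader, and bool casts are computed as `x != 0`.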
function cast$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { dtype } = attrs;
if (dtype === 'complex64') {
if (x.dtype === 'complex64') {
return identity({ inputs: { x }, backend });
}
const zerosTensor = zeros$1(x.shape);
const floatX = cast$1({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
const result = complex({ inputs: { real: floatX, imag: zerosTensor }, backend });
zerosTensor.dispose();
backend.disposeIntermediateTensorInfo(floatX);
return result;
}
if (x.dtype === 'complex64') {
const realPart = real({ inputs: { input: x }, backend });
const result = cast$1({ inputs: { x: realPart }, backend, attrs: { dtype } });
backend.disposeIntermediateTensorInfo(realPart);
return result;
}
if (!hasEncodingLoss(x.dtype, dtype)) {
const result = identity({ inputs: { x }, backend });
return { dataId: result.dataId, shape: result.shape, dtype };
}
if (backend.shouldExecuteOnCPU([x])) {
const values = backend.texData.get(x.dataId).values;
const [resultShape, resultType, resultData] = castImplCPU(values, x.shape, x.dtype, dtype);
return backend.makeTensorInfo(resultShape, resultType, resultData);
}
if (dtype === 'int32') {
return int(x, backend);
}
if (dtype === 'bool') {
const zerosTensorInfo = backend.makeTensorInfo([], 'bool', getTypedArrayFromDType('bool', 1));
const binaryInputs = { a: x, b: zerosTensorInfo };
const result = notEqual({ inputs: binaryInputs, backend });
backend.disposeIntermediateTensorInfo(zerosTensorInfo);
return result;
}
throw new Error(`Error in Cast: failed to cast ${x.dtype} to ${dtype}`);
}
const castConfig = {
kernelName: Cast,
backendName: 'webgl',
kernelFunc: cast$1
};
const CEIL = `return ceil(x);`;
const ceil = unaryKernelFunc({ opSnippet: CEIL, packedOpSnippet: CEIL, cpuKernelImpl: ceilImplCPU });
const ceilConfig = {
kernelName: Ceil,
backendName: 'webgl',
kernelFunc: ceil
};
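// Clamps every element to [minVal, maxVal] (passed as uniforms). NaNs are
// passed through unchanged rather than being clamped.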
class ClipProgram {
constructor(aShape) {
this.variableNames = ['A'];
this.customUniforms = [
{ name: 'minVal', type: 'float' },
{ name: 'maxVal', type: 'float' }
];
this.outputShape = aShape;
this.userCode = `
void main() {
float value = getAAtOutCoords();
if (isnan(value)) {
setOutput(value);
return;
}
setOutput(clamp(value, minVal, maxVal));
}
`;
}
}
class ClipPackedProgram {
constructor(aShape) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
this.customUniforms = [
{ name: 'minVal', type: 'float' },
{ name: 'maxVal', type: 'float' }
];
this.outputShape = aShape;
this.userCode = `
void main() {
vec4 value = getAAtOutCoords();
if (any(isnan(value))) {
setOutput(value);
return;
}
setOutput(clamp(value, vec4(minVal), vec4(maxVal)));
}
`;
}
}
function clipByValue$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { clipValueMin, clipValueMax } = attrs;
let program;
if (env().getBool('WEBGL_PACK_CLIP')) {
program = new ClipPackedProgram(x.shape);
}
else {
program = new ClipProgram(x.shape);
}
const customValues = [[clipValueMin], [clipValueMax]];
return backend.runWebGLProgram(program, [x], x.dtype, customValues);
}
const clipByValueConfig$1 = {
kernelName: ClipByValue,
backendName: 'webgl',
kernelFunc: clipByValue$1
};
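// Magnitude of a complex tensor. The max/min scaling (mx * length(vec2(1, min/mx)))
// computes sqrt(re^2 + im^2) without the intermediate overflow the naive form
// would hit for large components.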
class ComplexAbsProgram {
constructor(shape) {
this.variableNames = ['real', 'imag'];
this.outputShape = shape;
this.userCode = `
void main() {
float re = abs(getRealAtOutCoords());
float im = abs(getImagAtOutCoords());
float mx = max(re, im);
setOutput(
mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))
);
}
`;
}
}
function makeComplexComponentTensorInfo(complexTensor, complexPart) {
return {
dataId: complexPart.dataId,
dtype: complexPart.dtype,
shape: complexTensor.shape
};
}
function complexAbs$1(args) {
const { inputs, backend } = args;
const { x } = inputs;
const xData = backend.texData.get(x.dataId);
const program = new ComplexAbsProgram(x.shape);
const programInputs = [
makeComplexComponentTensorInfo(x, xData.complexTensorInfos.real),
makeComplexComponentTensorInfo(x, xData.complexTensorInfos.imag),
];
return backend.runWebGLProgram(program, programInputs, programInputs[0].dtype);
}
const complexAbsConfig$1 = {
kernelName: ComplexAbs,
backendName: 'webgl',
kernelFunc: complexAbs$1
};
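// Concatenation shaders. ConcatProgram expects inputs already flattened to 2D
// and joins them along the second axis using precomputed column offsets;
// ConcatPackedProgram concatenates packed (2x2 texel) inputs along an arbitrary
// axis without that intermediate reshape.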
class ConcatProgram {
constructor(shapes) {
this.outputShape = [];
        this.outputShape = computeOutShape$1(shapes, 1 /* axis */);
this.variableNames = shapes.map((_, i) => `T${i}`);
const offsets = new Array(shapes.length - 1);
offsets[0] = shapes[0][1];
for (let i = 1; i < offsets.length; i++) {
offsets[i] = offsets[i - 1] + shapes[i][1];
}
const snippets = [`if (yC < ${offsets[0]}) setOutput(getT0(yR, yC));`];
for (let i = 1; i < offsets.length; i++) {
const shift = offsets[i - 1];
snippets.push(`else if (yC < ${offsets[i]}) ` +
`setOutput(getT${i}(yR, yC-${shift}));`);
}
const lastIndex = offsets.length;
const lastShift = offsets[offsets.length - 1];
snippets.push(`else setOutput(getT${lastIndex}(yR, yC-${lastShift}));`);
this.userCode = `
void main() {
ivec2 coords = getOutputCoords();
int yR = coords.x;
int yC = coords.y;
${snippets.join('\n ')}
}
`;
}
}
class ConcatPackedProgram {
constructor(shapes, axis) {
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = [];
this.outputShape = computeOutShape$1(shapes, axis);
const shape = this.outputShape;
const rank = shape.length;
const dtype = getCoordsDataType(rank);
const coords = getChannels('coords', rank);
const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank);
this.variableNames = shapes.map((_, i) => `T${i}`);
const offsets = new Array(shapes.length - 1);
offsets[0] = shapes[0][axis];
for (let i = 1; i < offsets.length; i++) {
offsets[i] = offsets[i - 1] + shapes[i][axis];
}
const channel = channels[axis];
const lastChannels = channels.slice(-2);
const allChannels = channels.join();
let getValueSnippet = `if (${channel} < ${offsets[0]}) {
return getChannel(
getT0(${allChannels}), vec2(${lastChannels.join()}));
}`;
for (let i = 1; i < offsets.length; i++) {
const shift = offsets[i - 1];
getValueSnippet += `
if (${channel} < ${offsets[i]} && ${channel} >= ${offsets[i - 1]}) {
return getChannel(
getT${i}(${shiftedChannels(channels, channel, shift)}),
vec2(${shiftedChannels(lastChannels, channel, shift)}));
}`;
}
const lastIndex = offsets.length;
const shift = offsets[offsets.length - 1];
getValueSnippet += `
return getChannel(
getT${lastIndex}(${shiftedChannels(channels, channel, shift)}),
vec2(${shiftedChannels(lastChannels, channel, shift)}));`;
this.userCode = `
float getValue(${channels.map(x => 'int ' + x)}) {
${getValueSnippet}
}
void main() {
${dtype} coords = getOutputCoords();
vec4 result = vec4(getValue(${coords}), 0., 0., 0.);
${coords[rank - 1]} = ${coords[rank - 1]} + 1;
if (${coords[rank - 1]} < ${shape[rank - 1]}) {
result.g = getValue(${coords});
}
${coords[rank - 2]} = ${coords[rank - 2]} + 1;
if (${coords[rank - 2]} < ${shape[rank - 2]}) {
result.a = getValue(${coords});
}
${coords[rank - 1]} = ${coords[rank - 1]} - 1;
if (${coords[rank - 2]} < ${shape[rank - 2]} &&
${coords[rank - 1]} < ${shape[rank - 1]}) {
result.b = getValue(${coords});
}
setOutput(result);
}
`;
}
}
function shiftedChannels(channels, channel, shift) {
const channelIdx = channels.indexOf(channel);
const res = channels.map((c, idx) => {
if (idx === channelIdx) {
return `${c} - ${shift}`;
}
else {
return c;
}
});
return res.join();
}
function imag$1(args) {
const { inputs, backend } = args;
const { input } = inputs;
const inputData = backend.texData.get(input.dataId);
return identity({ inputs: { x: inputData.complexTensorInfos.imag }, backend });
}
const imagConfig$1 = {
kernelName: Imag,
backendName: 'webgl',
kernelFunc: imag$1
};
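// Shared concat implementation: complex64 inputs are concatenated component-wise,
// string tensors and CPU-resident inputs use the CPU implementation, and when
// there are more inputs than the shader can sample (WEBGL_MAX_TEXTURES_IN_SHADER)
// the inputs are concatenated recursively in chunks.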
function concatImpl(inputs, axis, backend) {
const dtype = inputs[0].dtype;
if (dtype === 'complex64') {
const reals = inputs.map((t) => real({ inputs: { input: t }, backend }));
const imags = inputs.map((t) => imag$1({ inputs: { input: t }, backend }));
const realConcated = concatImpl(reals, axis, backend);
const imagConcated = concatImpl(imags, axis, backend);
const result = complex({ inputs: { real: realConcated, imag: imagConcated }, backend });
reals.forEach(r => backend.disposeIntermediateTensorInfo(r));
imags.forEach(i => backend.disposeIntermediateTensorInfo(i));
backend.disposeIntermediateTensorInfo(realConcated);
backend.disposeIntermediateTensorInfo(imagConcated);
return result;
}
let runOnCpu = backend.shouldExecuteOnCPU(inputs);
if (dtype === 'string') {
runOnCpu = true;
}
if (runOnCpu) {
const tensors2D = inputs.map(t => {
const innerSize = sizeFromShape(t.shape.slice(axis));
const shape = [-1, innerSize];
return reshape$1({ inputs: { x: t }, backend, attrs: { shape } });
});
const inputsValShapes = tensors2D.map(t => {
return { vals: backend.readSync(t.dataId), shape: t.shape };
});
        const outShape = computeOutShape$1(tensors2D.map(t => t.shape), 1 /* axis */);
const simplyConcat = tensors2D[0].shape[0] === 1;
const outVals = concatImplCPU(inputsValShapes, outShape, dtype, simplyConcat);
const finalOutShape = computeOutShape$1(inputs.map(t => t.shape), axis);
const outInfo = backend.makeTensorInfo(finalOutShape, dtype, outVals);
tensors2D.forEach(t => backend.disposeIntermediateTensorInfo(t));
return outInfo;
}
const $inputs = inputs.filter(t => sizeFromShape(t.shape) > 0);
const shouldPack = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') &&
$inputs[0].shape.length > 1;
if ($inputs.length === 1) {
const program = shouldPack ?
new UnaryOpProgram(inputs[0].shape, CLONE) :
new UnaryOpPackedProgram(inputs[0].shape, CLONE);
return backend.runWebGLProgram(program, inputs, dtype);
}
const maxTexturesInShader = env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER');
if ($inputs.length > maxTexturesInShader) {
const reducedInputs = [];
for (let i = 0; i < $inputs.length; i += maxTexturesInShader) {
const subArray = $inputs.slice(i, i + maxTexturesInShader);
reducedInputs.push(concatImpl(subArray, axis, backend));
}
const result = concatImpl(reducedInputs, axis, backend);
for (const i of reducedInputs) {
backend.disposeIntermediateTensorInfo(i);
}
return result;
}
if (shouldPack) {
const program = new ConcatPackedProgram($inputs.map(t => t.shape), axis);
return backend.runWebGLProgram(program, $inputs, dtype);
}
const { tensors2D, outShape } = computeTensors2D($inputs, axis, backend);
const program = new ConcatProgram(tensors2D.map(t => t.shape));
const result = backend.runWebGLProgram(program, tensors2D, dtype);
tensors2D.forEach(r => backend.disposeIntermediateTensorInfo(r));
const reshapedResult = reshape$1({ inputs: { x: result }, attrs: { shape: outShape }, backend });
backend.disposeIntermediateTensorInfo(result);
return reshapedResult;
}
function computeTensors2D(inputs, axis, backend) {
const outShape = computeOutShape$1(inputs.map(t => t.shape), axis);
const tensors2D = inputs.map(x => reshape$1({
inputs: { x },
attrs: { shape: [-1, sizeFromShape(x.shape.slice(axis))] },
backend
}));
return { tensors2D, outShape };
}
function concat$1(args) {
const { inputs, backend, attrs } = args;
const { axis } = attrs;
const $axis = parseAxisParam(axis, inputs[0].shape)[0];
const shapes = inputs.map(t => t.shape);
assertParamsConsistent(shapes, $axis);
const outShape = computeOutShape$1(inputs.map(t => t.shape), $axis);
if (sizeFromShape(outShape) === 0) {
return backend.makeTensorInfo(outShape, inputs[0].dtype, []);
}
const $inputs = inputs.filter(t => sizeFromShape(t.shape) > 0);
if ($inputs.length === 1) {
return identity({ inputs: { x: $inputs[0] }, backend });
}
return concatImpl($inputs, $axis, backend);
}
const concatConfig$1 = {
kernelName: Concat,
backendName: 'webgl',
kernelFunc: concat$1
};
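// Direct 2D convolution. Each output element iterates over the filter window and
// accumulates dot products over the input channels in groups of four
// (inputDepthNearestVec4), handling the remaining 1-3 channels separately.
// Supports channelsLast and channelsFirst layouts plus optional fused bias,
// prelu and leakyrelu activations.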
class Conv2DProgram {
constructor(convInfo, addBias = false, activation = null, hasPreluActivationWeights = false, hasLeakyreluAlpha = false) {
this.variableNames = ['x', 'W'];
this.outputShape = convInfo.outShape;
const padTop = convInfo.padInfo.top;
const padLeft = convInfo.padInfo.left;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
const inputDepthVec4Remainder = convInfo.inChannels % 4;
const isChannelsLast = convInfo.dataFormat === 'channelsLast';
const rowDim = isChannelsLast ? 1 : 2;
const colDim = isChannelsLast ? 2 : 3;
const channelDim = isChannelsLast ? 3 : 1;
let activationSnippet = '', applyActivationSnippet = '';
if (activation) {
if (hasPreluActivationWeights) {
activationSnippet = `float activation(float a) {
float b = getPreluActivationWeightsAtOutCoords();
${activation}
}`;
}
else if (hasLeakyreluAlpha) {
activationSnippet = `float activation(float a) {
float b = getLeakyreluAlphaAtOutCoords();
${activation}
}`;
}
else {
activationSnippet = `
float activation(float x) {
${activation}
}
`;
}
applyActivationSnippet = `result = activation(result);`;
}
const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
if (addBias) {
this.variableNames.push('bias');
}
if (hasPreluActivationWeights) {
this.variableNames.push('preluActivationWeights');
}
if (hasLeakyreluAlpha) {
this.variableNames.push('leakyreluAlpha');
}
this.userCode = `
${activationSnippet}
const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
const ivec2 pads = ivec2(${padTop}, ${padLeft});
void main() {
ivec4 coords = getOutputCoords();
int batch = coords[0];
int d2 = coords[${channelDim}];
ivec2 xRCCorner =
ivec2(coords[${rowDim}], coords[${colDim}]) * strides - pads;
int xRCorner = xRCCorner.x;
int xCCorner = xRCCorner.y;
float dotProd = 0.0;
for (int wR = 0; wR < ${filterHeight}; wR++) {
int xR = xRCorner + wR * ${dilationHeight};
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}
for (int wC = 0; wC < ${filterWidth}; wC++) {
int xC = xCCorner + wC * ${dilationWidth};
if (xC < 0 || xC >= ${convInfo.inWidth}) {
continue;
}
for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) {
vec4 wValues = vec4(
getW(wR, wC, d1, d2),
getW(wR, wC, d1 + 1, d2),
getW(wR, wC, d1 + 2, d2),
getW(wR, wC, d1 + 3, d2)
);
if (${isChannelsLast}) {
vec4 xValues = vec4(
getX(batch, xR, xC, d1),
getX(batch, xR, xC, d1 + 1),
getX(batch, xR, xC, d1 + 2),
getX(batch, xR, xC, d1 + 3)
);
dotProd += dot(xValues, wValues);
} else {
vec4 xValues = vec4(
getX(batch, d1, xR, xC),
getX(batch, d1 + 1, xR, xC),
getX(batch, d1 + 2, xR, xC),
getX(batch, d1 + 3, xR, xC)
);
dotProd += dot(xValues, wValues);
}
}
if (${inputDepthVec4Remainder === 1}) {
if (${isChannelsLast}) {
dotProd +=
getX(batch, xR, xC, ${inputDepthNearestVec4}) *
getW(wR, wC, ${inputDepthNearestVec4}, d2);
} else {
dotProd +=
getX(batch, ${inputDepthNearestVec4}, xR, xC) *
getW(wR, wC, ${inputDepthNearestVec4}, d2);
}
} else if (${inputDepthVec4Remainder === 2}) {
vec2 wValues = vec2(
getW(wR, wC, ${inputDepthNearestVec4}, d2),
getW(wR, wC, ${inputDepthNearestVec4} + 1, d2)
);
if (${isChannelsLast}) {
vec2 xValues = vec2(
getX(batch, xR, xC, ${inputDepthNearestVec4}),
getX(batch, xR, xC, ${inputDepthNearestVec4} + 1)
);
dotProd += dot(xValues, wValues);
} else {
vec2 xValues = vec2(
getX(batch, ${inputDepthNearestVec4}, xR, xC),
getX(batch, ${inputDepthNearestVec4} + 1, xR, xC)
);
dotProd += dot(xValues, wValues);
}
} else if (${inputDepthVec4Remainder === 3}) {
vec3 wValues = vec3(
getW(wR, wC, ${inputDepthNearestVec4}, d2),
getW(wR, wC, ${inputDepthNearestVec4} + 1, d2),
getW(wR, wC, ${inputDepthNearestVec4} + 2, d2)
);
if (${isChannelsLast}) {
vec3 xValues = vec3(
getX(batch, xR, xC, ${inputDepthNearestVec4}),
getX(batch, xR, xC, ${inputDepthNearestVec4} + 1),
getX(batch, xR, xC, ${inputDepthNearestVec4} + 2)
);
dotProd += dot(xValues, wValues);
} else {
vec3 xValues = vec3(
getX(batch, ${inputDepthNearestVec4}, xR, xC),
getX(batch, ${inputDepthNearestVec4} + 1, xR, xC),
getX(batch, ${inputDepthNearestVec4} + 2, xR, xC)
);
dotProd += dot(xValues, wValues);
}
}
}
}
float result = dotProd;
${addBiasSnippet}
${applyActivationSnippet}
setOutput(result);
}
`;
}
}
class Conv3DProgram {
constructor(convInfo) {
this.variableNames = ['x', 'W'];
this.outputShape = convInfo.outShape;
const padFront = convInfo.padInfo.front;
const padTop = convInfo.padInfo.top;
const padLeft = convInfo.padInfo.left;
const strideDepth = convInfo.strideDepth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationDepth = convInfo.dilationDepth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const filterDepth = convInfo.filterDepth;
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
const inputDepthVec4Remainder = convInfo.inChannels % 4;
this.userCode = `
const ivec3 strides = ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
void main() {
ivec5 coords = getOutputCoords();
int batch = coords.x;
int d2 = coords.u;
ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
int xFCorner = xFRCCorner.x;
int xRCorner = xFRCCorner.y;
int xCCorner = xFRCCorner.z;
float dotProd = 0.0;
for (int wF = 0; wF < ${filterDepth}; wF++) {
int xF = xFCorner + wF * ${dilationDepth};
if (xF < 0 || xF >= ${convInfo.inDepth}) {
continue;
}
for (int wR = 0; wR < ${filterHeight}; wR++) {
int xR = xRCorner + wR * ${dilationHeight};
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}
for (int wC = 0; wC < ${filterWidth}; wC++) {
int xC = xCCorner + wC * ${dilationWidth};
if (xC < 0 || xC >= ${convInfo.inWidth}) {
continue;
}
for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) {
vec4 xValues = vec4(
getX(batch, xF, xR, xC, d1),
getX(batch, xF, xR, xC, d1 + 1),
getX(batch, xF, xR, xC, d1 + 2),
getX(batch, xF, xR, xC, d1 + 3)
);
vec4 wValues = vec4(
getW(wF, wR, wC, d1, d2),
getW(wF, wR, wC, d1 + 1, d2),
getW(wF, wR, wC, d1 + 2, d2),
getW(wF, wR, wC, d1 + 3, d2)
);
dotProd += dot(xValues, wValues);
}
if (${inputDepthVec4Remainder === 1}) {
dotProd +=
getX(batch, xF, xR, xC, ${inputDepthNearestVec4}) *
getW(wF, wR, wC, ${inputDepthNearestVec4}, d2);
} else if (${inputDepthVec4Remainder === 2}) {
vec2 xValues = vec2(
getX(batch, xF, xR, xC, ${inputDepthNearestVec4}),
getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1)
);
vec2 wValues = vec2(
getW(wF, wR, wC, ${inputDepthNearestVec4}, d2),
getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2)
);
dotProd += dot(xValues, wValues);
} else if (${inputDepthVec4Remainder === 3}) {
vec3 xValues = vec3(
getX(batch, xF, xR, xC, ${inputDepthNearestVec4}),
getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1),
getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 2)
);
vec3 wValues = vec3(
getW(wF, wR, wC, ${inputDepthNearestVec4}, d2),
getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2),
getW(wF, wR, wC, ${inputDepthNearestVec4} + 2, d2)
);
dotProd += dot(xValues, wValues);
}
}
}
}
setOutput(dotProd);
}
`;
}
}
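// Packed 2D convolution used for small stride widths (see conv2d() below). The
// generated mainLoop unrolls the filter columns and caches texels in the
// xTexelC*/xTexelC*Ready locals so a packed input texel is not fetched twice per
// row. dotProd starts at a tiny constant that is subtracted again before bias
// and activation, presumably as a workaround for problematic all-zero
// accumulators on some drivers.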
class Conv2DPackedProgram {
constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) {
this.variableNames = ['x', 'W'];
this.packedInputs = true;
this.packedOutput = true;
this.customUniforms = [
{ name: 'pads', type: 'ivec2' },
{ name: 'strides', type: 'ivec2' },
{ name: 'dilations', type: 'ivec2' },
{ name: 'inDims', type: 'ivec2' },
];
this.outputShape = convInfo.outShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
const padLeft = convInfo.padInfo.left;
const strideWidth = convInfo.strideWidth;
const dilationWidth = convInfo.dilationWidth;
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const texelsAcross = filterWidth;
let mainLoop = `
int xR; int xC; int xCOffset;
vec4 wTexel; vec4 previous; vec4 final;`;
for (let c = 0; c < filterWidth; c++) {
mainLoop += `
vec4 xTexelC${c * 2};
int xTexelC${c * 2}Ready;
vec4 xTexelC${c * 2 + 1};
int xTexelC${c * 2 + 1}Ready;
vec4 xC${c};`;
}
mainLoop += `
for (int r = 0; r < ${filterHeight}; r++) {
for (int d1 = 0; d1 < ${convInfo.inChannels}; d1 += 2) {
`;
for (let c = 0; c < filterWidth; c++) {
mainLoop += `
xTexelC${c * 2} = vec4(0.0);
xTexelC${c * 2}Ready = 0;
xTexelC${c * 2 + 1} = vec4(0.0);
xTexelC${c * 2 + 1}Ready = 0;
xC${c} = vec4(0.0);`;
}
mainLoop += `
xR = xRCorner + r * dilations[0];
if (xR >=0 && xR < inDims[0]) {
`;
for (let texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) {
const colIndex = texelC * 2;
mainLoop += `
xC = xCCorner + ${colIndex * dilationWidth};
`;
if (strideWidth === 1) {
if (colIndex < filterWidth) {
if (padLeft % 2 === 1) {
mainLoop += `
xCOffset = xC + 1;
if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);
if (xCOffset + 1 >= inDims[1]) {
xTexelC${colIndex}.zw = vec2(0.0);
}
xTexelC${colIndex}Ready = 1;
}
`;
if (dilationWidth === 1 && colIndex > 0) {
mainLoop += `
xC${colIndex} = vec4(xTexelC${colIndex - 2}.zw, xTexelC${colIndex}.xy);
`;
}
else {
mainLoop += `
xCOffset = xC + 1 - 2;
if (xCOffset >= 0 && xCOffset < inDims[1]) {
previous = getX(batch, xR, xCOffset, d1);
if (xCOffset + 1 >= inDims[1]) {
previous.zw = vec2(0.0);
}
xC${colIndex} = vec4(previous.zw, xTexelC${colIndex}.xy);
} else {
xC${colIndex} = vec4(0.0, 0.0, xTexelC${colIndex}.xy);
}
`;
}
}
else {
mainLoop += `
if (xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
xTexelC${colIndex} = getX(batch, xR, xC, d1);
if (xC + 1 >= inDims[1]) {
xTexelC${colIndex}.zw = vec2(0.0);
}
xTexelC${colIndex}Ready = 1;
}
xC${colIndex} = xTexelC${colIndex};
`;
}
if (colIndex + 1 < filterWidth) {
const nextTexelOffset = padLeft % 2 === 0 ?
nearestLargerEven(dilationWidth) :
dilationWidth;
if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) ||
(dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) {
mainLoop += `
xCOffset = xC + imod(pads[1], 2) + ${nextTexelOffset};
if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
if (xCOffset + 1 >= inDims[1]) {
xTexelC${colIndex + 1}.zw = vec2(0.0);
}
xTexelC${colIndex + 1}Ready = 1;
}
`;
if (dilationWidth > 1) {
mainLoop += `
xCOffset -= 2;
if (xCOffset >= 0 && xCOffset < inDims[1]) {
previous = getX(batch, xR, xCOffset, d1);
xC${colIndex + 1} = vec4(previous.zw, xTexelC${colIndex + 1}.xy);
} else {
xC${colIndex + 1} = vec4(0.0, 0.0, xTexelC${colIndex + 1}.xy);
}
`;
}
else {
mainLoop += `
xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.xy);
`;
}
}
else {
if (nextTexelOffset === 1) {
mainLoop += `
xC${colIndex + 1} = xTexelC${colIndex};
`;
}
else {
mainLoop += `
xCOffset = xC + ${nextTexelOffset};
if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
if (xCOffset + 1 >= inDims[1]) {
xTexelC${colIndex + 1}.zw = vec2(0.0);
}
xTexelC${colIndex + 1}Ready = 1;
}
xC${colIndex + 1} = xTexelC${colIndex + 1};
`;
}
}
}
}
}
else {
if (colIndex < filterWidth) {
if (padLeft % 2 === 1) {
mainLoop += `
xCOffset = xC + 1 - strides[1];
if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);
if (xCOffset + 1 >= inDims[1]) {
xTexelC${colIndex}.zw = vec2(0.0);
}
xTexelC${colIndex}Ready = 1;
}
if(xC + 1 >= 0 && xC + 1 < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
xTexelC${colIndex + 1} = getX(batch, xR, xC + 1, d1);
if (xC + 2 >= inDims[1]) {
xTexelC${colIndex + 1}.zw = vec2(0.0);
}
xTexelC${colIndex + 1}Ready = 1;
}
xC${colIndex} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
`;
if (colIndex + 1 < filterWidth) {
mainLoop += `
final = vec4(0.0);
xCOffset = xC + 1 + strides[1];
if(xCOffset >= 0 && xCOffset < inDims[1]) {
final = getX(batch, xR, xCOffset, d1);
}
xC${colIndex + 1} = vec4(xTexelC${colIndex + 1}.xy, final.xy);
`;
}
}
else {
mainLoop += `
if(xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
xTexelC${colIndex} = getX(batch, xR, xC, d1);
if (xC + 1 >= inDims[1]) {
xTexelC${colIndex}.zw = vec2(0.0);
}
xTexelC${colIndex}Ready = 1;
}
xCOffset = xC + strides[1];
if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
if (xCOffset + 1 >= inDims[1]) {
xTexelC${colIndex + 1}.zw = vec2(0.);
}
xTexelC${colIndex + 1}Ready = 1;
}
xC${colIndex} = vec4(
xTexelC${colIndex}.xy, xTexelC${colIndex + 1}.xy);
`;
if (colIndex + 1 < filterWidth) {
mainLoop += `
xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
`;
}
}
}
}
if (colIndex < filterWidth) {
mainLoop += `
wTexel = getW(r, ${colIndex}, d1, d2);
dotProd += xC${colIndex}.xxzz * vec4(wTexel.xy, wTexel.xy);
if(d1 + 1 < ${convInfo.inChannels}) {
dotProd += xC${colIndex}.yyww * vec4(wTexel.zw, wTexel.zw);
}
`;
if (colIndex + 1 < filterWidth) {
mainLoop += `
wTexel = getW(r, ${colIndex + 1}, d1, d2);
dotProd += xC${colIndex + 1}.xxzz * vec4(wTexel.xy, wTexel.xy);
if(d1 + 1 < ${convInfo.inChannels}) {
dotProd += xC${colIndex + 1}.yyww * vec4(wTexel.zw, wTexel.zw);
}
`;
}
}
}
mainLoop += `
}
`;
mainLoop += `
}
`;
mainLoop += `
}
`;
let activationSnippet = '', applyActivationSnippet = '';
if (activation) {
if (hasPreluActivation) {
activationSnippet = `vec4 activation(vec4 a) {
vec4 b = getPreluActivationWeightsAtOutCoords();
${activation}
}`;
}
else if (hasLeakyReluAlpha) {
activationSnippet = `vec4 activation(vec4 a) {
vec4 b = getLeakyreluAlphaAtOutCoords();
${activation}
}`;
}
else {
activationSnippet = `vec4 activation(vec4 x) {
${activation}
}`;
}
applyActivationSnippet = `result = activation(result);`;
}
const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
if (addBias) {
this.variableNames.push('bias');
}
if (hasPreluActivation) {
this.variableNames.push('preluActivationWeights');
}
if (hasLeakyReluAlpha) {
this.variableNames.push('leakyreluAlpha');
}
this.userCode = `
${activationSnippet}
void main() {
ivec4 coords = getOutputCoords();
int batch = coords.x;
ivec2 xRCCorner = coords.yz * strides - pads;
int d2 = coords.w;
int xRCorner = xRCCorner.x;
int xCCorner = xRCCorner.y;
vec4 dotProd = vec4(0.000000000000001);
${mainLoop}
vec4 result = dotProd - vec4(0.000000000000001);
${addBiasSnippet}
${applyActivationSnippet}
setOutput(result);
}
`;
}
}
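// Builds the im2col matrix consumed by conv2dWithIm2Row: each column corresponds
// to one output position and holds the convolution window flattened to
// filterHeight * filterWidth * inChannels values, written four elements at a
// time for the packed layout.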
class Im2ColPackedProgram {
constructor(outputShape, convInfo) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
this.customUniforms = [
{ name: 'inputShape', type: 'ivec4' },
{ name: 'pad', type: 'ivec2' },
{ name: 'stride', type: 'ivec2' },
{ name: 'dilation', type: 'ivec2' },
{ name: 'inChannels', type: 'int' },
{ name: 'itemsPerBlockRow', type: 'int' },
{ name: 'outWidth', type: 'int' },
];
this.outputShape = outputShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
const { dataFormat } = convInfo;
const glsl = getGlslDifferences();
const isChannelsLast = dataFormat === 'channelsLast';
const rowDim = isChannelsLast ? 1 : 2;
const colDim = isChannelsLast ? 2 : 3;
const boundsCheckingSnippet = this.enableShapeUniforms ?
'if(blockIndex < outShape[2] && pos < outShape[1]) {' :
`if(blockIndex < ${outputShape[2]} && pos < ${outputShape[1]}) {`;
let unrolled = ``;
for (let row = 0; row <= 1; row++) {
for (let col = 0; col <= 1; col++) {
unrolled += `
blockIndex = rc.z + ${col};
pos = rc.y + ${row};
${boundsCheckingSnippet}
offsetY = int(blockIndex / outWidth) * stride[0] - pad[0];
d0 = offsetY + dilation[0] * (pos / itemsPerBlockRow);
if(d0 < inputShape[${rowDim}] && d0 >= 0) {
offsetX = imod(blockIndex, outWidth) * stride[1] - pad[1];
d1 = offsetX + dilation[1] * (imod(pos, itemsPerBlockRow) /
inChannels);
if(d1 < inputShape[${colDim}] && d1 >= 0) {
ch = imod(pos, inChannels);
if (${isChannelsLast}) {
innerDims = vec2(d1, ch);
result[${row * 2 + col}] = getChannel(
getA(rc.x, d0, int(innerDims.x),
int(innerDims.y)), innerDims);
} else {
innerDims = vec2(d0, d1);
result[${row * 2 + col}] = getChannel(
getA(rc.x, ch, int(innerDims.x),
int(innerDims.y)), innerDims);
}
}
}
}
`;
}
}
this.userCode = `
void main() {
ivec3 rc = getOutputCoords();
vec4 result = vec4(0);
int blockIndex, pos, offsetY, d0, offsetX, d1, ch;
vec2 innerDims;
${unrolled}
${glsl.output} = result;
}
`;
}
}
function getShapeForBatchMatMul(shape, isChannelsLast) {
const length = shape.length;
if (length >= 3) {
return isChannelsLast ?
[
                ...shape.slice(0, -3) /* batch */,
                shape[length - 3] * shape[length - 2] /* height * width */,
                shape[length - 1] /* channel */
            ] :
            [
                ...shape.slice(0, -3) /* batch */, shape[length - 3] /* channel */,
                shape[length - 2] * shape[length - 1] /* height * width */
];
}
else if (!isChannelsLast && length === 1 && shape[0] > 1) {
return [shape[0], 1];
}
else {
return null;
}
}
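// Fast path used for 1x1, stride-1, dilation-1 convolutions (see conv2d() below):
// the convolution is rewritten as a batched matmul of the flattened input against
// the [inChannels, outChannels] filter. If the input texture is already packed
// and its width is odd, the texture is reinterpreted in place (canOptimize) so
// the reshape costs nothing.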
function conv2dByMatMul({ x, filter, convInfo, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
const xShape = x.shape;
const xTexData = backend.texData.get(x.dataId);
const sharedMatMulDim = convInfo.inChannels;
const outerShapeX = xShape[0] * xShape[1] * xShape[2];
const outerShapeFilter = convInfo.outChannels;
const isChannelsLast = convInfo.dataFormat === 'channelsLast';
const transposeA = false;
const transposeB = false;
let out;
const intermediates = [];
if (preluActivationWeights != null) {
const targetShape = getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast);
if (targetShape != null) {
preluActivationWeights = reshape$1({
inputs: { x: preluActivationWeights },
backend,
attrs: { shape: targetShape }
});
intermediates.push(preluActivationWeights);
}
}
if (bias != null) {
const targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast);
if (targetShape != null) {
bias = reshape$1({ inputs: { x: bias }, backend, attrs: { shape: targetShape } });
intermediates.push(bias);
}
}
const batchMatMulWillBeUnpacked = (outerShapeX === 1 || outerShapeFilter === 1) &&
sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD;
const canOptimize = !batchMatMulWillBeUnpacked && xTexData.isPacked &&
isChannelsLast && xTexData.texture != null && xShape[2] % 2 !== 0 &&
arraysEqual(xTexData.shape.slice(-3), xShape.slice(-3));
if (canOptimize) {
const targetShape = xShape[0] * xShape[1] * (xShape[2] + 1);
const xReshaped = {
dataId: x.dataId,
shape: [1, targetShape, convInfo.inChannels],
dtype: x.dtype
};
const originalXTexDataShape = xTexData.shape;
xTexData.shape = xTexData.shape.slice();
xTexData.shape[xTexData.shape.length - 2]++;
assert$1(isReshapeFree(xTexData.shape, xReshaped.shape), () => `packed reshape ${xTexData.shape} to ${xReshaped.shape} isn't free`);
const filterReshaped = reshape$1({
inputs: { x: filter },
backend,
attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] }
});
intermediates.push(filterReshaped);
const pointwiseConv = batchMatMulImpl({
a: xReshaped,
b: filterReshaped,
backend,
transposeA,
transposeB,
bias,
activation,
preluActivationWeights,
leakyreluAlpha
});
const pointwiseConvTexData = backend.texData.get(pointwiseConv.dataId);
assert$1(pointwiseConvTexData.isPacked, () => 'batchMatMul result is expected to be packed');
xTexData.shape = originalXTexDataShape;
pointwiseConvTexData.shape = convInfo.outShape;
out = identity({ inputs: { x: pointwiseConv }, backend });
out.shape = convInfo.outShape;
intermediates.push(pointwiseConv);
}
else {
const numCols = convInfo.outHeight * convInfo.outWidth;
const xReshaped = reshape$1({
inputs: { x },
backend,
attrs: {
shape: isChannelsLast ?
[convInfo.batchSize, numCols, convInfo.inChannels] :
[convInfo.batchSize, convInfo.inChannels, numCols]
}
});
const filterReshaped = reshape$1({
inputs: { x: filter },
backend,
attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] }
});
const result = batchMatMulImpl({
a: isChannelsLast ? xReshaped : filterReshaped,
b: isChannelsLast ? filterReshaped : xReshaped,
transposeA: !isChannelsLast,
transposeB,
backend,
bias,
activation,
preluActivationWeights,
leakyreluAlpha
});
out = reshape$1({ inputs: { x: result }, backend, attrs: { shape: convInfo.outShape } });
intermediates.push(xReshaped);
intermediates.push(filterReshaped);
intermediates.push(result);
}
for (const i of intermediates) {
backend.disposeIntermediateTensorInfo(i);
}
return out;
}
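// General convolution via im2col: the input is expanded to a
// [batch, filterHeight*filterWidth*inChannels, outHeight*outWidth] matrix by
// Im2ColPackedProgram and multiplied against the reshaped filter with
// MatMulPackedProgram; bias, prelu weights and leakyrelu alpha are fused into
// the matmul when present.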
function conv2dWithIm2Row({ x, filter, convInfo, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
const { filterWidth, filterHeight, inChannels, outWidth, outHeight, dataFormat } = convInfo;
const isChannelsLast = dataFormat === 'channelsLast';
const sharedDim = filterWidth * filterHeight * inChannels;
const numCols = outHeight * outWidth;
const x2ColShape = [convInfo.batchSize, sharedDim, numCols];
const transposeA = true;
const transposeB = false;
const intermediates = [];
if (preluActivationWeights != null) {
const targetShape = getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast);
if (targetShape != null) {
preluActivationWeights = reshape$1({
inputs: { x: preluActivationWeights },
backend,
attrs: { shape: targetShape }
});
intermediates.push(preluActivationWeights);
}
}
if (bias != null) {
const targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast);
if (targetShape != null) {
bias = reshape$1({ inputs: { x: bias }, backend, attrs: { shape: targetShape } });
intermediates.push(bias);
}
}
const w2Row = reshape$1({
inputs: { x: filter },
backend,
attrs: { shape: [1, sharedDim, sizeFromShape(filter.shape) / sharedDim] }
});
intermediates.push(w2Row);
const im2ColProgram = new Im2ColPackedProgram(x2ColShape, convInfo);
const customValues = [
x.shape, [convInfo.padInfo.top, convInfo.padInfo.left],
[convInfo.strideHeight, convInfo.strideWidth],
[convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inChannels],
[convInfo.filterWidth * convInfo.inChannels], [convInfo.outWidth]
];
const im2Col = backend.runWebGLProgram(im2ColProgram, [x], 'float32', customValues);
const im2ColReshaped = reshape$1({ inputs: { x: im2Col }, backend, attrs: { shape: x2ColShape } });
intermediates.push(im2Col);
intermediates.push(im2ColReshaped);
const hasBias = bias != null;
const hasPreluActivationWeights = preluActivationWeights != null;
const hasLeakyreluAlpha = activation === 'leakyrelu';
const fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
const matmulProgram = new MatMulPackedProgram(isChannelsLast ? im2ColReshaped.shape :
w2Row.shape, isChannelsLast ? w2Row.shape :
im2ColReshaped.shape, isChannelsLast ? [convInfo.batchSize, numCols, convInfo.outChannels] :
[convInfo.batchSize, convInfo.outChannels, numCols], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
const inputs = isChannelsLast ? [im2ColReshaped, w2Row] : [w2Row, im2ColReshaped];
if (bias) {
inputs.push(bias);
}
if (hasPreluActivationWeights) {
inputs.push(preluActivationWeights);
}
if (hasLeakyreluAlpha) {
const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
inputs.push($leakyreluAlpha);
intermediates.push($leakyreluAlpha);
}
const product = backend.runWebGLProgram(matmulProgram, inputs, 'float32');
const out = reshape$1({ inputs: { x: product }, backend, attrs: { shape: convInfo.outShape } });
intermediates.push(product);
for (const i of intermediates) {
backend.disposeIntermediateTensorInfo(i);
}
return out;
}
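// Conv2D entry point. Chooses a strategy per convInfo: conv2dByMatMul for
// 1x1/stride-1/dilation-1 convolutions, Conv2DPackedProgram when WEBGL_EXP_CONV
// is enabled (channelsLast, strideWidth <= 2), conv2dWithIm2Row when
// WEBGL_CONV_IM2COL is enabled, and the direct Conv2DProgram otherwise.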
function conv2d(args) {
const { inputs, backend, attrs } = args;
const { x, filter } = inputs;
const { strides, pad, dataFormat, dilations, dimRoundingMode } = attrs;
const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
let out;
if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
(convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) {
out = conv2dByMatMul({ x, filter, convInfo, backend });
}
else if (convInfo.strideWidth <= 2 && $dataFormat === 'channelsLast'
&& env().getBool('WEBGL_EXP_CONV')) {
const program = new Conv2DPackedProgram(convInfo);
const customValues = [
[convInfo.padInfo.top, convInfo.padInfo.left],
[convInfo.strideHeight, convInfo.strideWidth],
[convInfo.dilationHeight, convInfo.dilationWidth],
[convInfo.inHeight, convInfo.inWidth]
];
out =
backend.runWebGLProgram(program, [x, filter], 'float32', customValues);
}
else if (env().getBool('WEBGL_CONV_IM2COL')) {
out = conv2dWithIm2Row({ x, filter, convInfo, backend });
}
else {
const program = new Conv2DProgram(convInfo);
out = backend.runWebGLProgram(program, [x, filter], 'float32');
}
const outReshaped = reshape$1({ inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } });
backend.disposeIntermediateTensorInfo(out);
return outReshaped;
}
const conv2DConfig$1 = {
kernelName: Conv2D,
backendName: 'webgl',
kernelFunc: conv2d,
};
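// Conv2D gradients: Conv2DDerFilterProgram accumulates x * dy over the batch and
// output positions to produce the filter gradient, while Conv2DDerInputProgram
// convolves dy with the spatially flipped filter (wRPerm/wCPerm) to produce the
// input gradient. The Conv3D counterparts follow the same pattern below.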
class Conv2DDerFilterProgram {
constructor(convInfo) {
this.variableNames = ['x', 'dy'];
this.outputShape = convInfo.filterShape;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const padTop = convInfo.padInfo.top;
const padLeft = convInfo.padInfo.left;
const isChannelsLast = convInfo.dataFormat === 'channelsLast';
this.userCode = `
void main() {
ivec4 coords = getOutputCoords();
int wR = coords.x;
int wC = coords.y;
int d1 = coords.z;
int d2 = coords.w;
float dotProd = 0.0;
for (int b = 0; b < ${convInfo.batchSize}; b++) {
for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
int xR = wR + yR * ${strideHeight} - ${padTop};
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}
for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
int xC = wC + yC * ${strideWidth} - ${padLeft};
if (xC < 0 || xC >= ${convInfo.inWidth}) {
continue;
}
${isChannelsLast ?
`float dyValue = getDy(b, yR, yC, d2);
float xValue = getX(b, xR, xC, d1);
dotProd += (xValue * dyValue);` :
`float dyValue = getDy(b, d2, yR, yC);
float xValue = getX(b, d1, xR, xC);
dotProd += (xValue * dyValue);`}
}
}
}
setOutput(dotProd);
}
`;
}
}
class Conv2DDerInputProgram {
constructor(convInfo) {
this.variableNames = ['dy', 'W'];
this.outputShape = convInfo.inShape;
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const isChannelsLast = convInfo.dataFormat === 'channelsLast';
const padTop = filterHeight - 1 - convInfo.padInfo.top;
const padLeft = filterWidth - 1 - convInfo.padInfo.left;
const rowDim = isChannelsLast ? 1 : 2;
const colDim = isChannelsLast ? 2 : 3;
const channelDim = isChannelsLast ? 3 : 1;
this.userCode = `
const ivec2 pads = ivec2(${padTop}, ${padLeft});
void main() {
ivec4 coords = getOutputCoords();
int batch = coords[0];
int d1 = coords[${channelDim}];
ivec2 dyCorner = ivec2(coords[${rowDim}], coords[${colDim}]) - pads;
int dyRCorner = dyCorner.x;
int dyCCorner = dyCorner.y;
float dotProd = 0.0;
for (int wR = 0; wR < ${filterHeight}; wR++) {
float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
continue;
}
int idyR = int(dyR);
int wRPerm = ${filterHeight} - 1 - wR;
for (int wC = 0; wC < ${filterWidth}; wC++) {
float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
fract(dyC) > 0.0) {
continue;
}
int idyC = int(dyC);
int wCPerm = ${filterWidth} - 1 - wC;
for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) {
if (${isChannelsLast}) {
float xValue = getDy(batch, idyR, idyC, d2);
float wValue = getW(wRPerm, wCPerm, d1, d2);
dotProd += xValue * wValue;
} else {
float xValue = getDy(batch, d2, idyR, idyC);
float wValue = getW(wRPerm, wCPerm, d1, d2);
dotProd += xValue * wValue;
}
}
}
}
setOutput(dotProd);
}
`;
}
}
class Conv3DDerFilterProgram {
constructor(convInfo) {
this.variableNames = ['x', 'dy'];
this.outputShape = convInfo.filterShape;
const strideDepth = convInfo.strideDepth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const padFront = convInfo.padInfo.front;
const padTop = convInfo.padInfo.top;
const padLeft = convInfo.padInfo.left;
this.userCode = `
void main() {
ivec5 coords = getOutputCoords();
int wF = coords.x;
int wR = coords.y;
int wC = coords.z;
int d1 = coords.w;
int d2 = coords.u;
float dotProd = 0.0;
for (int b = 0; b < ${convInfo.batchSize}; b++) {
for (int yF = 0; yF < ${convInfo.outDepth}; yF++) {
int xF = wF + yF * ${strideDepth} - ${padFront};
if (xF < 0 || xF >= ${convInfo.inDepth}) {
continue;
}
for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
int xR = wR + yR * ${strideHeight} - ${padTop};
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}
for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
int xC = wC + yC * ${strideWidth} - ${padLeft};
if (xC < 0 || xC >= ${convInfo.inWidth}) {
continue;
}
float dyValue = getDy(b, yF, yR, yC, d2);
float xValue = getX(b, xF, xR, xC, d1);
dotProd += (xValue * dyValue);
}
}
}
}
setOutput(dotProd);
}
`;
}
}
class Conv3DDerInputProgram {
constructor(convInfo) {
this.variableNames = ['dy', 'W'];
this.outputShape = convInfo.inShape;
const filterDepth = convInfo.filterDepth;
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const strideDepth = convInfo.strideDepth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const padFront = filterDepth - 1 - convInfo.padInfo.front;
const padTop = filterHeight - 1 - convInfo.padInfo.top;
const padLeft = filterWidth - 1 - convInfo.padInfo.left;
this.userCode = `
const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
void main() {
ivec5 coords = getOutputCoords();
int batch = coords.x;
int d1 = coords.u;
ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
int dyFCorner = dyCorner.x;
int dyRCorner = dyCorner.y;
int dyCCorner = dyCorner.z;
float dotProd = 0.0;
for (int wF = 0; wF < ${filterDepth}; wF++) {
float dyF = float(dyFCorner + wF) / ${strideDepth}.0;
if (dyF < 0.0 || dyF >= ${convInfo.outDepth}.0 || fract(dyF) > 0.0) {
continue;
}
int idyF = int(dyF);
int wFPerm = ${filterDepth} - 1 - wF;
for (int wR = 0; wR < ${filterHeight}; wR++) {
float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
fract(dyR) > 0.0) {
continue;
}
int idyR = int(dyR);
int wRPerm = ${filterHeight} - 1 - wR;
for (int wC = 0; wC < ${filterWidth}; wC++) {
float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
fract(dyC) > 0.0) {
continue;
}
int idyC = int(dyC);
int wCPerm = ${filterWidth} - 1 - wC;
for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) {
float xValue = getDy(batch, idyF, idyR, idyC, d2);
float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2);
dotProd += xValue * wValue;
}
}
}
}
setOutput(dotProd);
}
`;
}
}
function conv2DBackpropFilter$1(args) {
const { inputs, backend, attrs } = args;
const { x, dy } = inputs;
const { strides, pad, dataFormat, dimRoundingMode, filterShape } = attrs;
const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
const program = new Conv2DDerFilterProgram(convInfo);
return backend.runWebGLProgram(program, [x, dy], 'float32');
}
const conv2DBackpropFilterConfig$1 = {
kernelName: Conv2DBackpropFilter,
backendName: 'webgl',
kernelFunc: conv2DBackpropFilter$1,
};
class Conv2DDerInputPackedProgram {
constructor(convInfo) {
this.variableNames = ['dy', 'W'];
this.packedInputs = true;
this.packedOutput = true;
this.customUniforms = [
{ name: 'strides', type: 'vec2' },
];
this.outputShape = convInfo.inShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const padTop = filterHeight - 1 - convInfo.padInfo.top;
const padLeft = filterWidth - 1 - convInfo.padInfo.left;
this.userCode = `
const ivec2 pads = ivec2(${padTop}, ${padLeft});
void main() {
ivec4 coords = getOutputCoords();
int batch = coords[0];
int d1 = coords[3];
ivec2 dyCorner = ivec2(coords[1], coords[2]) - pads;
int dyRCorner = dyCorner.x;
int dyCCorner = dyCorner.y;
vec4 result = vec4(0.);
for (int wR = 0; wR < ${filterHeight}; wR++) {
float dyR = float(dyRCorner + wR) / strides[0];
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
continue;
}
int idyR = int(dyR);
int wRPerm = ${filterHeight} - 1 - wR;
for (int wC = 0; wC < ${filterWidth}; wC++) {
int wCPerm = ${filterWidth} - 1 - wC;
float dyC = float(dyCCorner + wC) / strides[1];
bool idyCVal = (dyC >= 0.0) && (dyC < ${convInfo.outWidth}.0)
&& (fract(dyC) == 0.0);
int idyC = int(dyC);
float dyC2 = float(dyCCorner + wC + 1) / strides[1];
bool idyCVal2 = (dyC2 >= 0.0) && (dyC2 < ${convInfo.outWidth}.0)
&& (fract(dyC2) == 0.0);
int idyC2 = int(dyC2);
if (idyCVal && idyCVal2) {
for (int d2 = 0; d2 < ${convInfo.outChannels}; d2 += 2) {
vec4 wValue = getW(wRPerm, wCPerm, d1, d2);
vec4 dySample = getDy(batch, idyR, idyC, d2);
vec4 dySample2 = (idyC / 2 == idyC2 / 2) ?
dySample : getDy(batch, idyR, idyC2, d2);
vec2 dyValue = mod(float(idyC), 2.) == 0. ?
dySample.xy : dySample.zw;
result.xy += vec2(dot(dyValue, wValue.xy),
dot(dyValue, wValue.zw));
dyValue = mod(float(idyC2), 2.) == 0. ?
dySample2.xy : dySample2.zw;
result.zw += vec2(dot(dyValue, wValue.xy),
dot(dyValue, wValue.zw));
}
} else if (idyCVal) {
for (int d2 = 0; d2 < ${convInfo.outChannels}; d2 += 2) {
vec4 wValue = getW(wRPerm, wCPerm, d1, d2);
vec4 dySample = getDy(batch, idyR, idyC, d2);
vec2 dyValue = mod(float(idyC), 2.) == 0. ?
dySample.xy : dySample.zw;
result.xy += vec2(dot(dyValue, wValue.xy),
dot(dyValue, wValue.zw));
}
} else if (idyCVal2) {
for (int d2 = 0; d2 < ${convInfo.outChannels}; d2 += 2) {
vec4 wValue = getW(wRPerm, wCPerm, d1, d2);
vec4 dySample = getDy(batch, idyR, idyC2, d2);
vec2 dyValue = mod(float(idyC2), 2.) == 0. ?
dySample.xy : dySample.zw;
result.zw += vec2(dot(dyValue, wValue.xy),
dot(dyValue, wValue.zw));
}
}
}
}
setOutput(result);
}
`;
}
}
function conv2DBackpropInput$1(args) {
const { inputs, backend, attrs } = args;
const { dy, filter } = inputs;
const { inputShape, strides, pad, dataFormat, dimRoundingMode } = attrs;
const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
if (env().getBool('WEBGL_PACK_CONV2DTRANSPOSE') &&
$dataFormat === 'channelsLast') {
const customValues = [
[convInfo.strideHeight, convInfo.strideWidth],
];
const program = new Conv2DDerInputPackedProgram(convInfo);
return backend.runWebGLProgram(program, [dy, filter], 'float32', customValues);
}
else {
const program = new Conv2DDerInputProgram(convInfo);
return backend.runWebGLProgram(program, [dy, filter], 'float32');
}
}
const conv2DBackpropInputConfig$1 = {
kernelName: Conv2DBackpropInput,
backendName: 'webgl',
kernelFunc: conv2DBackpropInput$1,
};
function conv3D$1(args) {
const { inputs, backend, attrs } = args;
const { x, filter } = inputs;
const { strides, pad, dilations } = attrs;
const convInfo = computeConv3DInfo(x.shape, filter.shape, strides, dilations, pad);
const program = new Conv3DProgram(convInfo);
return backend.runWebGLProgram(program, [x, filter], 'float32');
}
const conv3DConfig$1 = {
kernelName: Conv3D,
backendName: 'webgl',
kernelFunc: conv3D$1,
};
function conv3DBackpropFilterV2$1(args) {
const { inputs, backend, attrs } = args;
const { x, dy } = inputs;
const { strides, pad, filterShape } = attrs;
    const convInfo = computeConv3DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad);
const program = new Conv3DDerFilterProgram(convInfo);
return backend.runWebGLProgram(program, [x, dy], 'float32');
}
const conv3DBackpropFilterV2Config$1 = {
kernelName: Conv3DBackpropFilterV2,
backendName: 'webgl',
kernelFunc: conv3DBackpropFilterV2$1
};
function conv3DBackpropInput(args) {
const { inputs, backend, attrs } = args;
const { dy, filter } = inputs;
const { pad, strides, inputShape } = attrs;
    const convInfo = computeConv3DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad);
const program = new Conv3DDerInputProgram(convInfo);
return backend.runWebGLProgram(program, [dy, filter], 'float32');
}
const conv3DBackpropInputConfig = {
kernelName: Conv3DBackpropInputV2,
backendName: 'webgl',
kernelFunc: conv3DBackpropInput,
};
const COS = CHECK_NAN_SNIPPET_UNARY + `
return cos(x);
`;
const COS_PACKED = `
vec4 result = cos(x);
bvec4 isNaN = isnan(x);
${CHECK_NAN_SNIPPET_PACKED}
return result;
`;
const cos$1 = unaryKernelFunc({ opSnippet: COS, packedOpSnippet: COS_PACKED });
const cosConfig$1 = {
kernelName: Cos,
backendName: 'webgl',
kernelFunc: cos$1,
};
const COSH = `
float e2x = exp(-x);
return (e2x + 1.0 / e2x) / 2.0;
`;
const cosh$1 = unaryKernelFunc({ opSnippet: COSH });
const coshConfig$1 = {
kernelName: Cosh,
backendName: 'webgl',
kernelFunc: cosh$1,
};
class CropAndResizeProgram {
constructor(imageShape, boxShape, cropSize, method, extrapolationValue) {
this.variableNames = ['Image', 'Boxes', 'BoxInd'];
this.outputShape = [];
const [batch, imageHeight, imageWidth, depth] = imageShape;
const [numBoxes,] = boxShape;
const [cropHeight, cropWidth] = cropSize;
this.outputShape = [numBoxes, cropHeight, cropWidth, depth];
const methodId = method === 'bilinear' ? 1 : 0;
const [inputHeightFloat, inputWidthFloat] = [`${imageHeight - 1}.0`, `${imageWidth - 1}.0`];
const [heightRatio, heightScale, inY] = cropHeight > 1 ?
[
`${(imageHeight - 1) / (cropHeight - 1)}`,
'(y2-y1) * height_ratio',
`y1*${inputHeightFloat} + float(y)*(height_scale)`,
] :
[
'0.0',
'0.0',
`0.5 * (y1+y2) * ${inputHeightFloat}`,
];
const [widthRatio, widthScale, inX] = cropWidth > 1 ?
[
`${(imageWidth - 1) / (cropWidth - 1)}`,
'(x2-x1) * width_ratio',
`x1*${inputWidthFloat} + float(x)*(width_scale)`,
] :
[
'0.0',
'0.0',
`0.5 * (x1+x2) * ${inputWidthFloat}`,
];
this.userCode = `
const float height_ratio = float(${heightRatio});
const float width_ratio = float(${widthRatio});
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int y = coords[1];
int x = coords[2];
int d = coords[3];
float y1 = getBoxes(b,0);
float x1 = getBoxes(b,1);
float y2 = getBoxes(b,2);
float x2 = getBoxes(b,3);
int bInd = round(getBoxInd(b));
if(bInd < 0 || bInd >= ${batch}) {
return;
}
float height_scale = ${heightScale};
float width_scale = ${widthScale};
float in_y = ${inY};
if( in_y < 0.0 || in_y > ${inputHeightFloat} ) {
setOutput(float(${extrapolationValue}));
return;
}
float in_x = ${inX};
if( in_x < 0.0 || in_x > ${inputWidthFloat} ) {
setOutput(float(${extrapolationValue}));
return;
}
vec2 sourceFracIndexCR = vec2(in_x,in_y);
if(${methodId} == 1) {
ivec2 sourceFloorCR = ivec2(sourceFracIndexCR);
ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR));
float topLeft = getImage(b, sourceFloorCR.y, sourceFloorCR.x, d);
float bottomLeft = getImage(b, sourceCeilCR.y, sourceFloorCR.x, d);
float topRight = getImage(b, sourceFloorCR.y, sourceCeilCR.x, d);
float bottomRight = getImage(b, sourceCeilCR.y, sourceCeilCR.x, d);
vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR);
float top = topLeft + (topRight - topLeft) * fracCR.x;
float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x;
float newValue = top + (bottom - top) * fracCR.y;
setOutput(newValue);
} else {
ivec2 sourceNearestCR = ivec2(floor(
sourceFracIndexCR + vec2(0.5,0.5)));
float newValue = getImage(b, sourceNearestCR.y, sourceNearestCR.x, d);
setOutput(newValue);
}
}
`;
}
}
const cropAndResize$1 = (args) => {
const { inputs, backend, attrs } = args;
const { image, boxes, boxInd } = inputs;
const { cropSize, method, extrapolationValue } = attrs;
const program = new CropAndResizeProgram(image.shape, boxes.shape, cropSize, method, extrapolationValue);
return backend.runWebGLProgram(program, [image, boxes, boxInd], 'float32');
};
const cropAndResizeConfig$1 = {
kernelName: CropAndResize,
backendName: 'webgl',
kernelFunc: cropAndResize$1
};
var CumOpType;
(function (CumOpType) {
CumOpType["Prod"] = "*";
CumOpType["Sum"] = "+";
})(CumOpType || (CumOpType = {}));
class CumProgram {
constructor(op, outputShape, exclusive, reverse) {
this.op = op;
this.outputShape = outputShape;
this.variableNames = ['x'];
this.customUniforms = [{ name: 'index', type: 'float' }];
const rank = this.outputShape.length;
const initVal = this.op === CumOpType.Prod ? '1.0' : '0.0';
const val = exclusive ? initVal : `getX(${getCoords(rank, 'coords', this.op)})`;
const length = this.outputShape[this.outputShape.length - 1];
let condition = '';
let idxString = '';
if (exclusive) {
condition = reverse ? `end != ${length - 1}` : 'end != 0';
idxString = reverse ? 'end + 1' : 'end - 1';
}
else {
condition = reverse ? `end + pow2 < ${length}` : 'end >= pow2';
idxString = (reverse ? 'end + pow2' : 'end - pow2');
}
this.userCode = `
void main() {
${getCoordsDataType(rank)} coords = getOutputCoords();
int end = ${getFinalCoord(rank, 'coords', this.op)};
float val = ${val};
int pow2 = int(pow(2.0, index));
if (${condition}) {
int idx = ${idxString};
${getFinalCoord(rank, 'coords', this.op)} = idx;
val ${this.op}= getX(${getCoords(rank, 'coords', this.op)});
}
setOutput(val);
}
`;
}
}
function getCoords(rank, name, op) {
if (rank === 1) {
return `${name}`;
}
else if (rank === 2) {
return `${name}.x, ${name}.y`;
}
else if (rank === 3) {
return `${name}.x, ${name}.y, ${name}.z`;
}
else if (rank === 4) {
return `${name}.x, ${name}.y, ${name}.z, ${name}.w`;
}
else {
throw new Error(`Cumulative ${op} for rank ${rank} is not yet supported`);
}
}
function getFinalCoord(rank, name, op) {
if (rank === 1) {
return `${name}`;
}
else if (rank === 2) {
return `${name}.y`;
}
else if (rank === 3) {
return `${name}.z`;
}
else if (rank === 4) {
return `${name}.w`;
}
else {
throw new Error(`Cumulative ${op} for rank ${rank} is not yet supported`);
}
}
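// cumImpl below first transposes the target axis to the inner-most position and then
// runs ceil(log2(size)) passes of CumProgram: in pass i (uniform `index`) each element
// combines itself with the element 2^i positions before it (after it, when reversed),
// effectively a Hillis-Steele scan, so a cumulative sum/product needs only O(log n)
// GPU dispatches; one extra pass shifts the result for the exclusive variant.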
function cumImpl(op, x, backend, axis, exclusive, reverse) {
const xRank = x.shape.length;
const permutation = getAxesPermutation([axis], xRank);
let permutedX = x;
if (permutation != null) {
permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutation } });
}
const permutedAxis = getInnerMostAxes(1, xRank)[0];
if (permutedAxis !== xRank - 1) {
        throw new Error(`WebGL cumulative ${op} shader expects an inner-most axis=${x.shape.length - 1} ` +
            `but got axis=${axis}`);
}
const size = permutedX.shape[permutedAxis];
let result = identity({ inputs: { x: permutedX }, backend });
for (let i = 0; i <= Math.ceil(Math.log2(size)) - 1; i++) {
const program = new CumProgram(op, permutedX.shape, false, reverse);
const customValues = [[i]];
const prevResult = result;
result =
backend.runWebGLProgram(program, [result], result.dtype, customValues);
backend.disposeIntermediateTensorInfo(prevResult);
}
if (exclusive) {
const program = new CumProgram(op, permutedX.shape, exclusive, reverse);
const prevResult = result;
result = backend.runWebGLProgram(program, [result], result.dtype);
backend.disposeIntermediateTensorInfo(prevResult);
}
if (permutation != null) {
const reversePermutation = getUndoAxesPermutation(permutation);
const reverseTransposedResult = transpose({ inputs: { x: result }, backend, attrs: { perm: reversePermutation } });
backend.disposeIntermediateTensorInfo(result);
backend.disposeIntermediateTensorInfo(permutedX);
return reverseTransposedResult;
}
return result;
}
function cumprod$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis, exclusive, reverse } = attrs;
return cumImpl(CumOpType.Prod, x, backend, axis, exclusive, reverse);
}
const cumprodConfig$1 = {
kernelName: Cumprod,
backendName: 'webgl',
kernelFunc: cumprod$1
};
function cumsum$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis, exclusive, reverse } = attrs;
return cumImpl(CumOpType.Sum, x, backend, axis, exclusive, reverse);
}
const cumsumConfig$1 = {
kernelName: Cumsum,
backendName: 'webgl',
kernelFunc: cumsum$1
};
function denseBincount$1(args) {
const { inputs, backend, attrs } = args;
const { x, weights } = inputs;
const { size, binaryOutput } = attrs;
if (x.shape.length === 1) {
const xVals = backend.readSync(x.dataId);
const weightsVals = backend.readSync(weights.dataId);
const outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size);
return backend.makeTensorInfo([size], weights.dtype, outVals);
}
else if (x.shape.length === 2) {
const xBuf = backend.bufferSync(x);
const weightsBuf = backend.bufferSync(weights);
const outBuf = bincountReduceImplCPU(xBuf, weightsBuf, size, binaryOutput);
return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values);
}
    throw new Error(`Error in denseBincount: input must be at most rank 2, but got rank ` +
        `${x.shape.length}.`);
}
const denseBincountConfig$1 = {
kernelName: DenseBincount,
backendName: 'webgl',
kernelFunc: denseBincount$1
};
class DepthToSpaceProgram {
constructor(outputShape, blockSize, dataFormat) {
this.variableNames = ['x'];
this.outputShape = [];
this.outputShape = outputShape;
this.blockSize = blockSize;
this.dataFormat = dataFormat;
this.userCode = `
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int h = ${this.getHeightCoordString()};
int w = ${this.getWidthCoordString()};
int d = ${this.getDepthCoordString()};
int in_h = h / ${blockSize};
int offset_h = imod(h, ${blockSize});
int in_w = w / ${blockSize};
int offset_w = imod(w, ${blockSize});
int offset_d = (offset_h * ${blockSize} + offset_w) *
${this.getOutputDepthSize()};
int in_d = d + offset_d;
float result = ${this.getInputSamplingString()};
setOutput(result);
}
`;
}
getHeightCoordString() {
if (this.dataFormat === 'NHWC') {
return `coords[1]`;
}
else {
return `coords[2]`;
}
}
getWidthCoordString() {
if (this.dataFormat === 'NHWC') {
return `coords[2]`;
}
else {
return `coords[3]`;
}
}
getDepthCoordString() {
if (this.dataFormat === 'NHWC') {
return `coords[3]`;
}
else {
return `coords[1]`;
}
}
getOutputDepthSize() {
if (this.dataFormat === 'NHWC') {
return this.outputShape[3];
}
else {
return this.outputShape[1];
}
}
getInputSamplingString() {
if (this.dataFormat === 'NHWC') {
return `getX(b, in_h, in_w, in_d)`;
}
else {
return `getX(b, in_d, in_h, in_w)`;
}
}
}
function depthToSpace$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { blockSize, dataFormat } = attrs;
const batchSize = x.shape[0];
const inputHeight = (dataFormat === 'NHWC') ? x.shape[1] : x.shape[2];
const inputWidth = (dataFormat === 'NHWC') ? x.shape[2] : x.shape[3];
const inputDepth = (dataFormat === 'NHWC') ? x.shape[3] : x.shape[1];
const outputHeight = inputHeight * blockSize;
const outputWidth = inputWidth * blockSize;
const outputDepth = inputDepth / (blockSize * blockSize);
const outputShape = (dataFormat === 'NHWC') ?
[batchSize, outputHeight, outputWidth, outputDepth] :
[batchSize, outputDepth, outputHeight, outputWidth];
const program = new DepthToSpaceProgram(outputShape, blockSize, dataFormat);
return backend.runWebGLProgram(program, [x], x.dtype);
}
const depthToSpaceConfig$1 = {
kernelName: DepthToSpace,
backendName: 'webgl',
kernelFunc: depthToSpace$1
};
class DepthwiseConv2DProgram {
constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) {
this.variableNames = ['x', 'W'];
this.customUniforms = [
{ name: 'pads', type: 'ivec2' },
{ name: 'strides', type: 'ivec2' },
{ name: 'dilations', type: 'ivec2' },
{ name: 'inDims', type: 'ivec2' },
];
this.outputShape = convInfo.outShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const channelMul = convInfo.outChannels / convInfo.inChannels;
let activationSnippet = '', applyActivationSnippet = '';
if (activation) {
if (hasPreluActivation) {
activationSnippet = `float activation(float a) {
float b = getPreluActivationWeightsAtOutCoords();
${activation}
}`;
}
else if (hasLeakyReluAlpha) {
activationSnippet = `float activation(float a) {
float b = getLeakyreluAlphaAtOutCoords();
${activation}
}`;
}
else {
activationSnippet = `
float activation(float x) {
${activation}
}
`;
}
applyActivationSnippet = `result = activation(result);`;
}
const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
if (addBias) {
this.variableNames.push('bias');
}
if (hasPreluActivation) {
this.variableNames.push('preluActivationWeights');
}
if (hasLeakyReluAlpha) {
this.variableNames.push('leakyreluAlpha');
}
this.userCode = `
${activationSnippet}
void main() {
ivec4 coords = getOutputCoords();
int batch = coords.x;
ivec2 xRCCorner = coords.yz * strides - pads;
int d2 = coords.w;
int d1 = d2 / ${channelMul};
int q = d2 - d1 * ${channelMul};
int xRCorner = xRCCorner.x;
int xCCorner = xRCCorner.y;
float dotProd = 0.0;
for (int wR = 0; wR < ${filterHeight}; wR++) {
int xR = xRCorner + wR * dilations[0];
if (xR < 0 || xR >= inDims[0]) {
continue;
}
for (int wC = 0; wC < ${filterWidth}; wC++) {
int xC = xCCorner + wC * dilations[1];
if (xC < 0 || xC >= inDims[1]) {
continue;
}
float xVal = getX(batch, xR, xC, d1);
float wVal = getW(wR, wC, d1, q);
dotProd += xVal * wVal;
}
}
float result = dotProd;
${addBiasSnippet}
${applyActivationSnippet}
setOutput(result);
}
`;
}
}
class DepthwiseConvPacked2DProgram {
constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) {
this.variableNames = ['x', 'W'];
this.packedInputs = true;
this.packedOutput = true;
this.customUniforms = [
{ name: 'pads', type: 'ivec2' },
{ name: 'strides', type: 'ivec2' },
{ name: 'dilations', type: 'ivec2' },
{ name: 'inDims', type: 'ivec2' },
];
this.outputShape = convInfo.outShape;
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
const channelMul = convInfo.outChannels / convInfo.inChannels;
const padLeft = convInfo.padInfo.left;
const strideWidth = convInfo.strideWidth;
const dilationWidth = convInfo.dilationWidth;
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const texelsAcross = filterWidth;
let mainLoop = `
int xR; int xC; int xCOffset;
vec4 wTexel; vec4 previous; vec4 final;`;
for (let c = 0; c < filterWidth; c++) {
mainLoop += `
vec4 xTexelC${c * 2};
int xTexelC${c * 2}Ready;
vec4 xTexelC${c * 2 + 1};
int xTexelC${c * 2 + 1}Ready;
vec4 xC${c};`;
}
mainLoop += `
for (int r = 0; r < ${filterHeight}; r++) {
`;
for (let c = 0; c < filterWidth; c++) {
mainLoop += `
xTexelC${c * 2} = vec4(0.0);
xTexelC${c * 2}Ready = 0;
xTexelC${c * 2 + 1} = vec4(0.0);
xTexelC${c * 2 + 1}Ready = 0;
xC${c} = vec4(0.0);`;
}
mainLoop += `
xR = xRCorner + r * dilations[0];
if (xR >=0 && xR < inDims[0]) {
`;
for (let texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) {
const colIndex = texelC * 2;
mainLoop += `
xC = xCCorner + ${colIndex * dilationWidth};
`;
if (strideWidth === 1) {
if (colIndex < filterWidth) {
if (padLeft % 2 === 1) {
mainLoop += `
xCOffset = xC + 1;
if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);
if (xCOffset + 1 >= inDims[1]) {
xTexelC${colIndex}.zw = vec2(0.0);
}
xTexelC${colIndex}Ready = 1;
}
`;
if (dilationWidth === 1 && colIndex > 0) {
mainLoop += `
xC${colIndex} = vec4(xTexelC${colIndex - 2}.zw, xTexelC${colIndex}.xy);
`;
}
else {
mainLoop += `
xCOffset = xC + 1 - 2;
if (xCOffset >= 0 && xCOffset < inDims[1]) {
previous = getX(batch, xR, xCOffset, d1);
if (xCOffset + 1 >= inDims[1]) {
previous.zw = vec2(0.0);
}
xC${colIndex} = vec4(previous.zw, xTexelC${colIndex}.xy);
} else {
xC${colIndex} = vec4(0.0, 0.0, xTexelC${colIndex}.xy);
}
`;
}
}
else {
mainLoop += `
if (xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
xTexelC${colIndex} = getX(batch, xR, xC, d1);
if (xC + 1 >= inDims[1]) {
xTexelC${colIndex}.zw = vec2(0.0);
}
xTexelC${colIndex}Ready = 1;
}
xC${colIndex} = xTexelC${colIndex};
`;
}
if (colIndex + 1 < filterWidth) {
const nextTexelOffset = padLeft % 2 === 0 ?
nearestLargerEven(dilationWidth) :
dilationWidth;
if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) ||
(dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) {
mainLoop += `
xCOffset = xC + imod(pads[1], 2) + ${nextTexelOffset};
if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
if (xCOffset + 1 >= inDims[1]) {
xTexelC${colIndex + 1}.zw = vec2(0.0);
}
xTexelC${colIndex + 1}Ready = 1;
}
`;
if (dilationWidth > 1) {
mainLoop += `
xCOffset -= 2;
if (xCOffset >= 0 && xCOffset < inDims[1]) {
previous = getX(batch, xR, xCOffset, d1);
xC${colIndex + 1} = vec4(previous.zw, xTexelC${colIndex + 1}.xy);
} else {
xC${colIndex + 1} = vec4(0.0, 0.0, xTexelC${colIndex + 1}.xy);
}
`;
}
else {
mainLoop += `
xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.xy);
`;
}
}
else {
if (nextTexelOffset === 1) {
mainLoop += `
xC${colIndex + 1} = xTexelC${colIndex};
`;
}
else {
mainLoop += `
xCOffset = xC + ${nextTexelOffset};
if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
if (xCOffset + 1 >= inDims[1]) {
xTexelC${colIndex + 1}.zw = vec2(0.0);
}
xTexelC${colIndex + 1}Ready = 1;
}
xC${colIndex + 1} = xTexelC${colIndex + 1};
`;
}
}
}
}
}
else {
if (colIndex < filterWidth) {
if (padLeft % 2 === 1) {
mainLoop += `
xCOffset = xC + 1 - strides[1];
if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);
if (xCOffset + 1 >= inDims[1]) {
xTexelC${colIndex}.zw = vec2(0.0);
}
xTexelC${colIndex}Ready = 1;
}
if(xC + 1 >= 0 && xC + 1 < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
xTexelC${colIndex + 1} = getX(batch, xR, xC + 1, d1);
if (xC + 2 >= inDims[1]) {
xTexelC${colIndex + 1}.zw = vec2(0.0);
}
xTexelC${colIndex + 1}Ready = 1;
}
xC${colIndex} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
`;
if (colIndex + 1 < filterWidth) {
mainLoop += `
final = vec4(0.0);
xCOffset = xC + 1 + strides[1];
if(xCOffset >= 0 && xCOffset < inDims[1]) {
final = getX(batch, xR, xCOffset, d1);
}
xC${colIndex + 1} = vec4(xTexelC${colIndex + 1}.xy, final.xy);
`;
}
}
else {
mainLoop += `
if(xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
xTexelC${colIndex} = getX(batch, xR, xC, d1);
if (xC + 1 >= inDims[1]) {
xTexelC${colIndex}.zw = vec2(0.0);
}
xTexelC${colIndex}Ready = 1;
}
xCOffset = xC + strides[1];
if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
if (xCOffset + 1 >= inDims[1]) {
xTexelC${colIndex + 1}.zw = vec2(0.);
}
xTexelC${colIndex + 1}Ready = 1;
}
xC${colIndex} = vec4(
xTexelC${colIndex}.xy, xTexelC${colIndex + 1}.xy);
`;
if (colIndex + 1 < filterWidth) {
mainLoop += `
xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
`;
}
}
}
}
if (colIndex < filterWidth) {
mainLoop += `
wTexel = getW(r, ${colIndex}, d1, q);
dotProd += xC${colIndex} * vec4(wTexel.xz, wTexel.xz);
`;
if (colIndex + 1 < filterWidth) {
mainLoop += `
wTexel = getW(r, ${colIndex + 1}, d1, q);
dotProd += xC${colIndex + 1} * vec4(wTexel.xz, wTexel.xz);
`;
}
}
}
mainLoop += `
}
`;
mainLoop += `
}
`;
let activationSnippet = '', applyActivationSnippet = '';
if (activation) {
if (hasPreluActivation) {
activationSnippet = `vec4 activation(vec4 a) {
vec4 b = getPreluActivationWeightsAtOutCoords();
${activation}
}`;
}
else if (hasLeakyReluAlpha) {
activationSnippet = `vec4 activation(vec4 a) {
vec4 b = getLeakyreluAlphaAtOutCoords();
${activation}
}`;
}
else {
activationSnippet = `vec4 activation(vec4 x) {
${activation}
}`;
}
applyActivationSnippet = `result = activation(result);`;
}
const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
if (addBias) {
this.variableNames.push('bias');
}
if (hasPreluActivation) {
this.variableNames.push('preluActivationWeights');
}
if (hasLeakyReluAlpha) {
this.variableNames.push('leakyreluAlpha');
}
this.userCode = `
${activationSnippet}
void main() {
ivec4 coords = getOutputCoords();
int batch = coords.x;
ivec2 xRCCorner = coords.yz * strides - pads;
int d2 = coords.w;
int d1 = d2 / ${channelMul};
int q = d2 - d1 * ${channelMul};
int xRCorner = xRCCorner.x;
int xCCorner = xRCCorner.y;
vec4 dotProd = vec4(0.000000000000001);
${mainLoop}
vec4 result = dotProd - vec4(0.000000000000001);
${addBiasSnippet}
${applyActivationSnippet}
setOutput(result);
}
`;
}
}
function depthwiseConv2dNative$1(args) {
const { inputs, backend, attrs } = args;
const { x, filter } = inputs;
const { strides, pad, dilations, dimRoundingMode } = attrs;
let $dilations = dilations;
if ($dilations == null) {
$dilations = [1, 1];
}
assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
`1. Got strides ${strides} and dilations '${$dilations}'`);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
let program;
if (env().getBool('WEBGL_PACK_DEPTHWISECONV') && convInfo.strideWidth <= 2 &&
convInfo.outChannels / convInfo.inChannels === 1) {
program = new DepthwiseConvPacked2DProgram(convInfo);
}
else {
program = new DepthwiseConv2DProgram(convInfo);
}
const customValues = [
[convInfo.padInfo.top, convInfo.padInfo.left],
[convInfo.strideHeight, convInfo.strideWidth],
[convInfo.dilationHeight, convInfo.dilationWidth],
[convInfo.inHeight, convInfo.inWidth]
];
return backend.runWebGLProgram(program, [x, filter], 'float32', customValues);
}
const depthwiseConv2dNativeConfig$1 = {
kernelName: DepthwiseConv2dNative,
backendName: 'webgl',
kernelFunc: depthwiseConv2dNative$1,
};
class DepthwiseConv2DDerFilterProgram {
constructor(convInfo) {
this.variableNames = ['x', 'dy'];
this.outputShape = convInfo.filterShape;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const padTop = convInfo.padInfo.top;
const padLeft = convInfo.padInfo.left;
const channelMul = convInfo.outChannels / convInfo.inChannels;
this.userCode = `
void main() {
ivec4 coords = getOutputCoords();
int wR = coords.x;
int wC = coords.y;
int d1 = coords.z;
int dm = coords.w;
int d2 = d1 * ${channelMul} + dm;
float dotProd = 0.0;
for (int b = 0; b < ${convInfo.batchSize}; b++) {
for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
int xR = wR + yR * ${strideHeight} - ${padTop};
if (xR < 0 || xR >= ${convInfo.inHeight}) {
continue;
}
for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
int xC = wC + yC * ${strideWidth} - ${padLeft};
if (xC < 0 || xC >= ${convInfo.inWidth}) {
continue;
}
float dyValue = getDy(b, yR, yC, d2);
float xValue = getX(b, xR, xC, d1);
dotProd += (xValue * dyValue);
}
}
}
setOutput(dotProd);
}
`;
}
}
class DepthwiseConv2DDerInputProgram {
constructor(convInfo) {
this.variableNames = ['dy', 'W'];
this.outputShape = convInfo.inShape;
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const padTop = filterHeight - 1 - convInfo.padInfo.top;
const padLeft = filterWidth - 1 - convInfo.padInfo.left;
const channelMul = convInfo.outChannels / convInfo.inChannels;
this.userCode = `
const ivec2 pads = ivec2(${padTop}, ${padLeft});
void main() {
ivec4 coords = getOutputCoords();
int batch = coords[0];
int d1 = coords[3];
ivec2 dyCorner = coords.yz - pads;
int dyRCorner = dyCorner.x;
int dyCCorner = dyCorner.y;
float dotProd = 0.0;
for (int wR = 0; wR < ${filterHeight}; wR++) {
float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
continue;
}
int idyR = int(dyR);
int wRPerm = ${filterHeight} - 1 - wR;
for (int wC = 0; wC < ${filterWidth}; wC++) {
float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
fract(dyC) > 0.0) {
continue;
}
int idyC = int(dyC);
int wCPerm = ${filterWidth} - 1 - wC;
for (int dm = 0; dm < ${channelMul}; dm++) {
int d2 = d1 * ${channelMul} + dm;
float xValue = getDy(batch, idyR, idyC, d2);
float wValue = getW(wRPerm, wCPerm, d1, dm);
dotProd += xValue * wValue;
}
}
}
setOutput(dotProd);
}
`;
}
}
function depthwiseConv2dNativeBackpropFilter$1(args) {
const { inputs, backend, attrs } = args;
const { x, dy } = inputs;
const { strides, dilations, pad, dimRoundingMode, filterShape } = attrs;
    const convInfo = computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
const program = new DepthwiseConv2DDerFilterProgram(convInfo);
return backend.runWebGLProgram(program, [x, dy], 'float32');
}
const depthwiseConv2dNativeBackpropFilterConfig$1 = {
kernelName: DepthwiseConv2dNativeBackpropFilter,
backendName: 'webgl',
kernelFunc: depthwiseConv2dNativeBackpropFilter$1
};
function depthwiseConv2dNativeBackpropInput$1(args) {
const { inputs, backend, attrs } = args;
const { dy, filter } = inputs;
const { strides, dilations, pad, dimRoundingMode, inputShape } = attrs;
    const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
const program = new DepthwiseConv2DDerInputProgram(convInfo);
return backend.runWebGLProgram(program, [dy, filter], 'float32');
}
const depthwiseConv2dNativeBackpropInputConfig$1 = {
kernelName: DepthwiseConv2dNativeBackpropInput,
backendName: 'webgl',
kernelFunc: depthwiseConv2dNativeBackpropInput$1
};
class DiagProgram {
constructor(size) {
this.variableNames = ['X'];
this.outputShape = [size, size];
this.userCode = `
void main() {
ivec2 coords = getOutputCoords();
float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;
setOutput(val);
}
`;
}
}
function diag$1(args) {
const { inputs, backend } = args;
const { x } = inputs;
const outShape = [...x.shape, ...x.shape];
const xSize = sizeFromShape(x.shape);
const flat = reshape$1({ inputs: { x }, backend, attrs: { shape: [xSize] } });
const program = new DiagProgram(xSize);
const res = backend.runWebGLProgram(program, [flat], flat.dtype);
const out = reshape$1({ inputs: { x: res }, backend, attrs: { shape: outShape } });
backend.disposeIntermediateTensorInfo(flat);
backend.disposeIntermediateTensorInfo(res);
return out;
}
const diagConfig$1 = {
kernelName: Diag,
backendName: 'webgl',
kernelFunc: diag$1
};
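// Dilation2DProgram implements grayscale morphological dilation: every output element
// is the maximum of (input value + filter value) over the filter window, seeded with
// -3.4e38 as a stand-in for -infinity so windows that fall entirely outside the input
// still produce a defined value.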
class Dilation2DProgram {
constructor(convInfo) {
this.variableNames = ['x', 'W'];
this.outputShape = convInfo.outShape;
const { inHeight, inWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth } = convInfo;
const { top: padTop, left: padLeft } = padInfo;
this.userCode = `
const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
const ivec2 pads = ivec2(${padTop}, ${padLeft});
const float neg_infinity = -3.4e38;
void main() {
ivec4 coords = getOutputCoords();
int batch = coords.x;
int d1 = coords.w;
ivec2 outTopLeftCorner =
coords.yz * strides - pads;
int hBeg = outTopLeftCorner.x;
int wBeg = outTopLeftCorner.y;
float curVal = neg_infinity;
for (int h = 0; h < ${filterHeight}; h++) {
int hIn = hBeg + h * ${dilationHeight};
if (hIn >= 0 && hIn < ${inHeight}) {
for (int w = 0; w < ${filterWidth}; w++) {
int wIn = wBeg + w * ${dilationWidth};
if (wIn >= 0 && wIn < ${inWidth}) {
float xVal = getX(batch, hIn, wIn, d1);
float wVal = getW(h, w, d1);
float val = xVal + wVal;
if (val > curVal) {
curVal = val;
}
}
}
}
}
float result = curVal;
setOutput(result);
}
`;
}
}
function dilation2D(args) {
const { inputs, backend, attrs } = args;
const { x, filter } = inputs;
const { strides, pad, dilations } = attrs;
    const convInfo = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
let out;
const program = new Dilation2DProgram(convInfo);
out = backend.runWebGLProgram(program, [x, filter], 'float32');
const outReshaped = reshape$1({ inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } });
backend.disposeIntermediateTensorInfo(out);
return outReshaped;
}
const dilation2DConfig$1 = {
kernelName: Dilation2D,
backendName: 'webgl',
kernelFunc: dilation2D,
};
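// einsum has no dedicated shader: the equation is decoded into a compute path, and each
// step transposes/reshapes an operand so it broadcasts against the running product,
// multiplies them, and sums out dimensions that are no longer needed, disposing the
// intermediate tensors along the way.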
function einsum$1(args) {
const { inputs, backend, attrs } = args;
const { equation } = attrs;
const tensors = inputs;
const { allDims, summedDims, idDims } = decodeEinsumEquation(equation, tensors.length);
checkEinsumDimSizes(allDims.length, idDims, tensors);
const { path, steps } = getEinsumComputePath(summedDims, idDims);
const nSteps = steps.length;
let out = null;
let numDimsRemaining = allDims.length;
const tensorsToDispose = [];
for (let i = 0; i < nSteps; ++i) {
for (const idTerm of steps[i]) {
const { permutationIndices: perm, expandDims: dimsToExpand } = getEinsumPermutation(numDimsRemaining, idDims[idTerm]);
let x;
if (isIdentityPermutation(perm)) {
x = tensors[idTerm];
}
else {
x = transpose({ inputs: { x: tensors[idTerm] }, backend, attrs: { perm } });
tensorsToDispose.push(x);
}
const targetShape = x.shape.slice();
for (let k = 0; k < dimsToExpand.length; ++k) {
targetShape.splice(dimsToExpand[k], 0, 1);
}
if (!arraysEqual(x.shape, targetShape)) {
x = reshape$1({ inputs: { x }, backend, attrs: { shape: targetShape } });
tensorsToDispose.push(x);
}
if (out === null) {
out = x;
}
else {
out = multiply({ inputs: { a: x, b: out }, backend });
tensorsToDispose.push(out);
}
}
if (i < nSteps - 1) {
if (path[i] >= 0) {
out = sum$1({
inputs: { x: out },
backend,
attrs: {
axis: path[i] - (allDims.length - numDimsRemaining),
keepDims: false
}
});
tensorsToDispose.push(out);
}
numDimsRemaining--;
}
}
for (const tensorInfo of tensorsToDispose) {
if (tensorInfo === out) {
continue;
}
backend.disposeIntermediateTensorInfo(tensorInfo);
}
return out;
}
const einsumConfig$1 = {
kernelName: Einsum,
backendName: 'webgl',
kernelFunc: einsum$1
};
const ELU = `return (x >= 0.0) ? x : (exp(x) - 1.0);`;
const ELU_PACKED = `
vec4 result;
result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);
result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);
result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);
result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);
return result;
`;
const elu$2 = unaryKernelFunc({ opSnippet: ELU, packedOpSnippet: ELU_PACKED });
const eluConfig$1 = {
kernelName: Elu$1,
backendName: 'webgl',
kernelFunc: elu$2
};
const ELU_DER = `return (b >= 0.0) ? a : a * (b + 1.0);`;
const ELU_DER_PACKED = `
vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));
return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));
`;
const eluGrad$1 = (args) => {
const { inputs, backend } = args;
const { dy, y } = inputs;
const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
new BinaryOpPackedProgram(ELU_DER_PACKED, dy.shape, y.shape) :
new BinaryOpProgram(ELU_DER, dy.shape, y.shape);
return backend.runWebGLProgram(program, [dy, y], dy.dtype);
};
const eluGradConfig$2 = {
kernelName: EluGrad,
backendName: 'webgl',
kernelFunc: eluGrad$1
};
const PACKED_EQUAL = `
return vec4(equal(a, b));
`;
const EQUAL = `return float(a == b);`;
const equal = binaryKernelFunc({
opSnippet: EQUAL,
packedOpSnippet: PACKED_EQUAL,
dtype: 'bool',
cpuKernelImpl: equalImplCPU,
});
const equalConfig = {
kernelName: Equal,
backendName: 'webgl',
kernelFunc: equal
};
const ERF = `
float p = ${ERF_P};
float a1 = ${ERF_A1};
float a2 = ${ERF_A2};
float a3 = ${ERF_A3};
float a4 = ${ERF_A4};
float a5 = ${ERF_A5};
float sign = sign(x);
x = abs(x);
float t = 1.0 / (1.0 + p * x);
return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x));
`;
const erf$1 = unaryKernelFunc({ opSnippet: ERF });
const erfConfig$1 = {
kernelName: Erf,
backendName: 'webgl',
kernelFunc: erf$1,
};
const EXP = CHECK_NAN_SNIPPET_UNARY + `
return exp(x);
`;
const EXP_PACKED = `
vec4 result = exp(x);
bvec4 isNaN = isnan(x);
result.r = isNaN.r ? x.r : result.r;
result.g = isNaN.g ? x.g : result.g;
result.b = isNaN.b ? x.b : result.b;
result.a = isNaN.a ? x.a : result.a;
return result;
`;
const exp = unaryKernelFunc({
opSnippet: EXP,
packedOpSnippet: EXP_PACKED,
cpuKernelImpl: expImplCPU,
dtype: 'float32',
});
const expConfig = {
kernelName: Exp,
backendName: 'webgl',
kernelFunc: exp
};
function expandDims$2(args) {
const { inputs, attrs, backend } = args;
const { dim } = attrs;
const { input } = inputs;
const inputRank = input.shape.length;
const newShape = input.shape.slice();
let $dim = dim;
if (dim < 0) {
assert$1(-(inputRank + 1) <= dim, () => `Axis must be in the interval [${-(inputRank + 1)}, ${inputRank}]`);
$dim = inputRank + dim + 1;
}
newShape.splice($dim, 0, 1);
return reshape$1({ inputs: { x: input }, backend, attrs: { shape: newShape } });
}
const expandDimsConfig$1 = {
kernelName: ExpandDims,
backendName: 'webgl',
kernelFunc: expandDims$2,
};
const EXPM1 = `return exp(x) - 1.0;`;
const expm1 = unaryKernelFunc({ opSnippet: EXPM1, packedOpSnippet: EXPM1, cpuKernelImpl: expm1ImplCPU });
const expm1Config = {
kernelName: Expm1,
backendName: 'webgl',
kernelFunc: expm1
};
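// FFTProgram evaluates the (inverse) discrete Fourier transform directly rather than
// with a fast butterfly: every output element loops over the full inner dimension, so a
// length-n transform costs O(n^2) work per batch row. The inverse variant flips the sign
// of the exponent and divides the result by n.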
class FFTProgram {
constructor(component, inputShape, inverse) {
this.variableNames = ['real', 'imag'];
const innerDim = inputShape[1];
this.outputShape = inputShape;
const exponentMultiplierSnippet = inverse ? `2.0 * ${Math.PI}` : `-2.0 * ${Math.PI}`;
const resultDenominator = inverse ? `${innerDim}.0` : '1.0';
let opString;
if (component === 'real') {
opString = 'return real * expR - imag * expI;';
}
else if (component === 'imag') {
opString = 'return real * expI + imag * expR;';
}
else {
throw new Error(`FFT component must be either "real" or "imag", got ${component}.`);
}
this.userCode = `
const float exponentMultiplier = ${exponentMultiplierSnippet};
float unaryOpComplex(float real, float expR, float imag, float expI) {
${opString}
}
float mulMatDFT(int batch, int index) {
float indexRatio = float(index) / float(${innerDim});
float exponentMultiplierTimesIndexRatio =
exponentMultiplier * indexRatio;
float result = 0.0;
for (int i = 0; i < ${innerDim}; i++) {
float x = exponentMultiplierTimesIndexRatio * float(i);
float expR = cos(x);
float expI = sin(x);
float real = getReal(batch, i);
float imag = getImag(batch, i);
result +=
unaryOpComplex(real, expR, imag, expI) / ${resultDenominator};
}
return result;
}
void main() {
ivec2 coords = getOutputCoords();
setOutput(mulMatDFT(coords[0], coords[1]));
}
`;
}
}
function fftImpl$1(x, inverse, backend) {
const xData = backend.texData.get(x.dataId);
const inputSize = sizeFromShape(x.shape);
const innerDimensionSize = x.shape[x.shape.length - 1];
const batch = inputSize / innerDimensionSize;
const input2D = reshape$1({ inputs: { x }, backend, attrs: { shape: [batch, innerDimensionSize] } });
const xShape = input2D.shape;
const realProgram = new FFTProgram('real', xShape, inverse);
const imagProgram = new FFTProgram('imag', xShape, inverse);
const inputs = [
{
dataId: xData.complexTensorInfos.real.dataId,
dtype: xData.complexTensorInfos.real.dtype,
shape: xShape
},
{
dataId: xData.complexTensorInfos.imag.dataId,
dtype: xData.complexTensorInfos.imag.dtype,
shape: xShape
}
];
const realPart = backend.runWebGLProgram(realProgram, inputs, 'float32');
const imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32');
const complexOutput = complex({ inputs: { real: realPart, imag: imagPart }, backend });
backend.disposeIntermediateTensorInfo(realPart);
backend.disposeIntermediateTensorInfo(imagPart);
const complexOutputReshaped = reshape$1({ inputs: { x: complexOutput }, backend, attrs: { shape: x.shape } });
backend.disposeIntermediateTensorInfo(input2D);
backend.disposeIntermediateTensorInfo(complexOutput);
return complexOutputReshaped;
}
function fft$1(args) {
const { inputs, backend } = args;
const { input } = inputs;
    return fftImpl$1(input, false /* inverse */, backend);
}
const fftConfig$1 = {
kernelName: FFT,
backendName: 'webgl',
kernelFunc: fft$1
};
class FillProgram {
constructor(shape, value) {
this.outputShape = [];
this.customUniforms = [{ name: 'value', type: 'float' }];
this.variableNames = ['x'];
this.outputShape = shape;
this.userCode = `
void main() {
setOutput(value);
}
`;
}
}
function fill$1(args) {
const { backend, attrs } = args;
const { shape, value } = attrs;
let { dtype } = attrs;
dtype = dtype || inferDtype(value);
if (dtype === 'string') {
const values = getArrayFromDType(dtype, sizeFromShape(shape));
values.fill(value);
return backend.makeTensorInfo(shape, dtype, values);
}
else {
const program = new FillProgram(shape, value);
const customValues = [[value]];
return backend.runWebGLProgram(program, [], dtype, customValues);
}
}
const fillConfig$1 = {
kernelName: Fill,
backendName: 'webgl',
kernelFunc: fill$1
};
class FlipLeftRightProgram {
constructor(imageShape) {
this.variableNames = ['Image'];
this.outputShape = [];
const imageWidth = imageShape[2];
this.outputShape = imageShape;
this.userCode = `
void main() {
ivec4 coords = getOutputCoords();
int x = coords[2];
int coordX = ${imageWidth} - x - 1;
float outputValue;
if(coordX >= 0 && coordX < ${imageWidth}) {
outputValue = getImage(coords[0], coords[1], coordX, coords[3]);
} else {
outputValue = getImage(coords[0], coords[1], coords[2], coords[3]);
}
setOutput(outputValue);
}
`;
}
}
const flipLeftRightConfig$1 = {
kernelName: FlipLeftRight,
backendName: 'webgl',
kernelFunc: ({ inputs, backend }) => {
const { image } = inputs;
const webglBackend = backend;
const program = new FlipLeftRightProgram(image.shape);
const output = webglBackend.runWebGLProgram(program, [image], image.dtype);
return output;
}
};
const FLOOR = `return floor(x);`;
const floor = unaryKernelFunc({ opSnippet: FLOOR, packedOpSnippet: FLOOR, cpuKernelImpl: floorImplCPU });
const floorConfig = {
kernelName: Floor,
backendName: 'webgl',
kernelFunc: floor,
};
const INT_DIV = `
float s = sign(a) * sign(b);
int ia = round(a);
int ib = round(b);
if (ib != 0) {
return float(idiv(ia, ib, s));
} else {
return NAN;
}
`;
const INT_DIV_PACKED = `
ivec4 ia = round(a);
ivec4 ib = round(b);
bvec4 cond = notEqual(ib, ivec4(0));
ivec4 result = ivec4(0);
vec4 s = sign(a) * sign(b);
if (cond[0]) {
result[0] = idiv(ia[0], ib[0], s[0]);
}
if (cond[1]) {
result[1] = idiv(ia[1], ib[1], s[1]);
}
if (cond[2]) {
result[2] = idiv(ia[2], ib[2], s[2]);
}
if (cond[3]) {
result[3] = idiv(ia[3], ib[3], s[3]);
}
return vec4(result);
`;
const floorDiv = binaryKernelFunc({ opSnippet: INT_DIV, packedOpSnippet: INT_DIV_PACKED, dtype: 'int32' });
const floorDivConfig = {
kernelName: FloorDiv,
backendName: 'webgl',
kernelFunc: floorDiv
};
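// The FromPixels programs below are fed by fromPixels(): HTMLImageElement/HTMLVideoElement
// sources are first drawn into a shared 2D canvas, the pixel data is uploaded as a PIXELS
// texture, and the shader then selects the requested RGBA channel per output depth and
// rescales the normalized sample back to an integer in [0, 255].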
class FromPixelsProgram {
constructor(outputShape) {
this.variableNames = ['A'];
const glsl = getGlslDifferences();
const [height, width,] = outputShape;
this.outputShape = outputShape;
this.userCode = `
void main() {
ivec3 coords = getOutputCoords();
int texR = coords[0];
int texC = coords[1];
int depth = coords[2];
vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0);
vec4 values = ${glsl.texture2D}(A, uv);
float value;
if (depth == 0) {
value = values.r;
} else if (depth == 1) {
value = values.g;
} else if (depth == 2) {
value = values.b;
} else if (depth == 3) {
value = values.a;
}
setOutput(floor(value * 255.0 + 0.5));
}
`;
}
}
class FromPixelsPackedProgram {
constructor(outputShape) {
this.variableNames = ['A'];
this.packedInputs = false;
this.packedOutput = true;
const glsl = getGlslDifferences();
const [height, width,] = outputShape;
this.outputShape = outputShape;
this.userCode = `
void main() {
ivec3 coords = getOutputCoords();
int texR = coords[0];
int texC = coords[1];
int depth = coords[2];
vec4 result = vec4(0.);
for(int row=0; row<=1; row++) {
for(int col=0; col<=1; col++) {
texC = coords[1] + row;
depth = coords[2] + col;
vec2 uv = (vec2(texC, texR) + halfCR) /
vec2(${width}.0, ${height}.0);
vec4 values = ${glsl.texture2D}(A, uv);
float value;
if (depth == 0) {
value = values.r;
} else if (depth == 1) {
value = values.g;
} else if (depth == 2) {
value = values.b;
} else if (depth == 3) {
value = values.a;
}
result[row * 2 + col] = floor(value * 255.0 + 0.5);
}
}
${glsl.output} = result;
}
`;
}
}
const fromPixelsConfig = {
kernelName: FromPixels,
backendName: 'webgl',
kernelFunc: fromPixels,
};
let fromPixels2DContext;
let willReadFrequently = env().getBool('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU');
function fromPixels(args) {
const { inputs, backend, attrs } = args;
let { pixels } = inputs;
const { numChannels } = attrs;
const isVideo = typeof (HTMLVideoElement) !== 'undefined' &&
pixels instanceof HTMLVideoElement;
const isImage = typeof (HTMLImageElement) !== 'undefined' &&
pixels instanceof HTMLImageElement;
const [width, height] = isVideo ?
[
pixels.videoWidth,
pixels.videoHeight
] :
[pixels.width, pixels.height];
const texShape = [height, width];
const outShape = [height, width, numChannels];
if (isImage || isVideo) {
const newWillReadFrequently = env().getBool('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU');
if (fromPixels2DContext == null ||
newWillReadFrequently !== willReadFrequently) {
willReadFrequently = newWillReadFrequently;
fromPixels2DContext =
document.createElement('canvas').getContext('2d', { willReadFrequently });
}
fromPixels2DContext.canvas.width = width;
fromPixels2DContext.canvas.height = height;
fromPixels2DContext.drawImage(pixels, 0, 0, width, height);
pixels = fromPixels2DContext.canvas;
}
const tempPixelHandle = backend.makeTensorInfo(texShape, 'int32');
backend.texData.get(tempPixelHandle.dataId).usage = TextureUsage.PIXELS;
backend.gpgpu.uploadPixelDataToTexture(backend.getTexture(tempPixelHandle.dataId), pixels);
const program = env().getBool('WEBGL_PACK') ?
new FromPixelsPackedProgram(outShape) :
new FromPixelsProgram(outShape);
const res = backend.runWebGLProgram(program, [tempPixelHandle], 'int32');
backend.disposeData(tempPixelHandle.dataId);
return res;
}
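// fusedConv2d picks one of several strategies: a 1x1 convolution with unit stride and
// dilation (and SAME/VALID padding) is lowered to a matrix multiply (conv2dByMatMul);
// otherwise it falls back to a packed conv shader, an im2col + matmul variant, or the
// plain Conv2DProgram, depending on the WEBGL_EXP_CONV and WEBGL_CONV_IM2COL flags.
// Bias, PReLU weights and the leaky-relu alpha are folded into the same program so no
// separate activation pass is needed.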
function fusedConv2d(args) {
const { inputs, backend, attrs } = args;
const { x, filter, bias, preluActivationWeights } = inputs;
const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
let out;
const intermediates = [];
const hasBias = bias != null;
const hasPreluActivationWeights = preluActivationWeights != null;
const hasLeakyreluAlpha = activation === 'leakyrelu';
const prepareInputs = () => {
const inputs = [x, filter];
const alignInputWithDataFormat = (input, dataFormat) => {
if (dataFormat === 'NCHW' && input.shape.length === 1 &&
input.shape[0] !== 1) {
const alignedInput = reshape$1({
inputs: { x: input },
backend,
attrs: { shape: [input.shape[0], 1, 1] }
});
intermediates.push(alignedInput);
return alignedInput;
}
return input;
};
if (hasBias) {
inputs.push(alignInputWithDataFormat(bias, dataFormat));
}
if (hasPreluActivationWeights) {
inputs.push(alignInputWithDataFormat(preluActivationWeights, dataFormat));
}
if (hasLeakyreluAlpha) {
const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
inputs.push($leakyreluAlpha);
intermediates.push($leakyreluAlpha);
}
return inputs;
};
if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
(convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) {
out = conv2dByMatMul({
x,
filter,
convInfo,
backend,
bias,
activation,
preluActivationWeights,
leakyreluAlpha
});
}
else if (convInfo.strideWidth <= 2 && $dataFormat === 'channelsLast'
&& env().getBool('WEBGL_EXP_CONV')) {
const fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
const program = new Conv2DPackedProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
const customValues = [
[convInfo.padInfo.top, convInfo.padInfo.left],
[convInfo.strideHeight, convInfo.strideWidth],
[convInfo.dilationHeight, convInfo.dilationWidth],
[convInfo.inHeight, convInfo.inWidth]
];
const inputs = prepareInputs();
out = backend.runWebGLProgram(program, inputs, 'float32', customValues);
}
else if (env().getBool('WEBGL_CONV_IM2COL')) {
out = conv2dWithIm2Row({
x,
filter,
convInfo,
backend,
bias,
activation,
preluActivationWeights,
leakyreluAlpha
});
}
else {
const fusedActivation = activation ? mapActivationToShaderProgram(activation, false) : null;
const program = new Conv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
const inputs = prepareInputs();
out = backend.runWebGLProgram(program, inputs, 'float32');
}
const outReshaped = reshape$1({ inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } });
intermediates.push(out);
intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
return outReshaped;
}
const fusedConv2DConfig$1 = {
kernelName: FusedConv2D,
backendName: 'webgl',
kernelFunc: fusedConv2d,
};
function fusedDepthwiseConv2D$1(args) {
const { inputs, backend, attrs } = args;
const { x, filter, bias, preluActivationWeights } = inputs;
const { strides, pad, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
const intermediates = [];
let $dilations = dilations;
if ($dilations == null) {
$dilations = [1, 1];
}
assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
`1. Got strides ${strides} and dilations '${$dilations}'`);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
const shouldPackDepthwiseConv = env().getBool('WEBGL_PACK_DEPTHWISECONV') &&
convInfo.strideWidth <= 2 &&
convInfo.outChannels / convInfo.inChannels === 1;
const fusedActivation = activation ?
mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) :
null;
const programInputs = [x, filter];
const hasBias = bias != null;
const hasPreluActivationWeights = preluActivationWeights != null;
const hasLeakyreluAlpha = activation === 'leakyrelu';
if (hasBias) {
programInputs.push(bias);
}
if (hasPreluActivationWeights) {
programInputs.push(preluActivationWeights);
}
if (hasLeakyreluAlpha) {
const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
programInputs.push($leakyreluAlpha);
intermediates.push($leakyreluAlpha);
}
let program;
if (shouldPackDepthwiseConv) {
program = new DepthwiseConvPacked2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
}
else {
program = new DepthwiseConv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
}
const customValues = [
[convInfo.padInfo.top, convInfo.padInfo.left],
[convInfo.strideHeight, convInfo.strideWidth],
[convInfo.dilationHeight, convInfo.dilationWidth],
[convInfo.inHeight, convInfo.inWidth]
];
const result = backend.runWebGLProgram(program, programInputs, 'float32', customValues);
intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
return result;
}
const fusedDepthwiseConv2DConfig$1 = {
kernelName: FusedDepthwiseConv2D,
backendName: 'webgl',
kernelFunc: fusedDepthwiseConv2D$1,
};
class GatherNDProgram {
constructor(sliceDim, strides, shape, paramsShape) {
this.sliceDim = sliceDim;
this.strides = strides;
this.paramsShape = paramsShape;
this.variableNames = ['x', 'indices'];
this.outputShape = shape;
const dtype = getCoordsDataType(shape.length);
let mainLoop = `
int index;`;
for (let j = 0; j < this.sliceDim; j++) {
mainLoop += `
index = round(getIndices(coords[0], ${j}));
out_of_bounds = out_of_bounds || index < 0;
out_of_bounds = out_of_bounds || index >= ${this.paramsShape[j]};
flattenIndex += index * ${this.strides[j]};`;
}
this.userCode = `
void main() {
${dtype} coords = getOutputCoords();
int flattenIndex = 0;
bool out_of_bounds = false;
${mainLoop}
setOutput(out_of_bounds ? 0.0 : getX(flattenIndex, coords[1]));
}
`;
}
}
function gatherNd$1(args) {
const { inputs, backend } = args;
const { params, indices } = inputs;
const indicesShape = indices.shape;
const sliceRank = indicesShape[indicesShape.length - 1];
const paramsSize = sizeFromShape(params.shape);
const [resultShape, numSlices, sliceSize, strides] = prepareAndValidate(params, indices);
const flattenIndices = reshape$1({ inputs: { x: indices }, backend, attrs: { shape: [numSlices, sliceRank] } });
const flattenX = reshape$1({
inputs: { x: params },
backend,
attrs: { shape: [(sizeFromShape(params.shape) / sliceSize), sliceSize] }
});
if (backend.shouldExecuteOnCPU([params, indices]) ||
params.dtype === 'string') {
const indicesData = backend.readSync(indices.dataId);
const paramsBuf = backend.bufferSync(params);
const outValue = gatherNdImplCPU(indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize);
return backend.makeTensorInfo(resultShape, params.dtype, outValue.values);
}
const program = new GatherNDProgram(sliceRank, strides, [numSlices, sliceSize], params.shape);
const res = backend.runWebGLProgram(program, [flattenX, flattenIndices], flattenX.dtype);
const reshaped = reshape$1({ inputs: { x: res }, backend, attrs: { shape: resultShape } });
backend.disposeIntermediateTensorInfo(flattenIndices);
backend.disposeIntermediateTensorInfo(flattenX);
backend.disposeIntermediateTensorInfo(res);
return reshaped;
}
const gatherNdConfig$1 = {
kernelName: GatherNd,
backendName: 'webgl',
kernelFunc: gatherNd$1
};
class GatherProgram {
constructor(aShape, outputShape) {
this.variableNames = ['A', 'indices'];
this.outputShape = outputShape;
this.rank = outputShape.length;
const dtype = getCoordsDataType(this.rank);
const sourceCoords = getSourceCoords$1(aShape);
this.userCode = `
void main() {
${dtype} resRC = getOutputCoords();
int index = int(getIndices(resRC.x, resRC.z));
float inBounds = (index >= 0) && (index < ${aShape[2]}) ? 1.0 : 0.0;
setOutput(inBounds * getA(${sourceCoords}));
}
`;
}
}
function getSourceCoords$1(aShape, axis) {
const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
const sourceCoords = [];
for (let i = 0; i < aShape.length; i++) {
if (i === 2) {
sourceCoords.push('index');
}
else {
sourceCoords.push(`${currentCoords[i]}`);
}
}
return sourceCoords.join();
}
function gatherV2$1(args) {
const { inputs, backend, attrs } = args;
const { x, indices } = inputs;
const { axis, batchDims } = attrs;
const parsedAxis = parseAxisParam(axis, x.shape)[0];
if (env().get('DEBUG')) {
const indicesVals = backend.readSync(indices.dataId);
const axisDim = x.shape[parsedAxis];
for (let i = 0; i < indicesVals.length; ++i) {
const index = indicesVals[i];
assert$1(index <= axisDim - 1 && index >= 0, () => `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`);
}
}
const shapeInfo = collectGatherOpShapeInfo(x, indices, parsedAxis, batchDims);
const indicesSize = sizeFromShape(indices.shape);
const toDispose = [];
const flattenX = reshape$1({
inputs: { x },
backend,
attrs: {
shape: [
shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize,
shapeInfo.sliceSize
]
}
});
const flattenIndex = reshape$1({
inputs: { x: indices },
backend,
attrs: { shape: [shapeInfo.batchSize, indicesSize / shapeInfo.batchSize] }
});
toDispose.push(flattenX);
toDispose.push(flattenIndex);
const flattenOutputShape = [
shapeInfo.batchSize, shapeInfo.outerSize, indicesSize / shapeInfo.batchSize,
shapeInfo.sliceSize
];
if (backend.shouldExecuteOnCPU([x, indices]) || x.dtype === 'string') {
const indicesBuf = backend.bufferSync(flattenIndex);
const xBuf = backend.bufferSync(flattenX);
const outBuf = gatherV2ImplCPU(xBuf, indicesBuf, flattenOutputShape);
toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
return backend.makeTensorInfo(shapeInfo.outputShape, outBuf.dtype, outBuf.values);
}
const program = new GatherProgram(flattenX.shape, flattenOutputShape);
const res = backend.runWebGLProgram(program, [flattenX, flattenIndex], flattenX.dtype);
toDispose.push(res);
const reshaped = reshape$1({ inputs: { x: res }, backend, attrs: { shape: shapeInfo.outputShape } });
toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
return reshaped;
}
const gatherV2Config$1 = {
kernelName: GatherV2,
backendName: 'webgl',
kernelFunc: gatherV2$1
};
const GREATER = `return float(a > b);`;
const GREATER_PACKED = `
return vec4(greaterThan(a, b));
`;
const greater = binaryKernelFunc({
opSnippet: GREATER,
packedOpSnippet: GREATER_PACKED,
cpuKernelImpl: greaterImplCPU,
dtype: 'bool'
});
const greaterConfig = {
kernelName: Greater,
backendName: 'webgl',
kernelFunc: greater
};
const GREATER_EQUAL = `return float(a >= b);`;
const GREATER_EQUAL_PACKED = `
return vec4(greaterThanEqual(a, b));
`;
const greaterEqual = binaryKernelFunc({
opSnippet: GREATER_EQUAL,
packedOpSnippet: GREATER_EQUAL_PACKED,
dtype: 'bool',
cpuKernelImpl: greaterEqualImplCPU
});
const greaterEqualConfig = {
kernelName: GreaterEqual,
backendName: 'webgl',
kernelFunc: greaterEqual
};
function ifft$1(args) {
const { inputs, backend } = args;
const { input } = inputs;
return fftImpl$1(input, true /* inverse */, backend);
}
const ifftConfig$1 = {
kernelName: IFFT,
backendName: 'webgl',
kernelFunc: ifft$1
};
const IS_FINITE = `return float(!isnan(x) && !isinf(x));`;
const isFinite$2 = unaryKernelFunc({ opSnippet: IS_FINITE, dtype: 'bool' });
const isFiniteConfig$1 = {
kernelName: IsFinite,
backendName: 'webgl',
kernelFunc: isFinite$2,
};
const IS_INF = `return float(isinf(x));`;
const isInf$1 = unaryKernelFunc({ opSnippet: IS_INF, dtype: 'bool' });
const isInfConfig$1 = {
kernelName: IsInf,
backendName: 'webgl',
kernelFunc: isInf$1,
};
const IS_NAN = `return float(isnan(x));`;
const isNaN$2 = unaryKernelFunc({ opSnippet: IS_NAN, dtype: 'bool' });
const isNaNConfig$1 = {
kernelName: IsNan,
backendName: 'webgl',
kernelFunc: isNaN$2,
};
const LESS = `return float(a < b);`;
const LESS_PACKED = `
return vec4(lessThan(a, b));
`;
const less = binaryKernelFunc({
opSnippet: LESS,
packedOpSnippet: LESS_PACKED,
cpuKernelImpl: lessImplCPU,
dtype: 'bool'
});
const lessConfig = {
kernelName: Less,
backendName: 'webgl',
kernelFunc: less
};
const LESS_EQUAL = `return float(a <= b);`;
const LESS_EQUAL_PACKED = `
return vec4(lessThanEqual(a, b));
`;
const lessEqual = binaryKernelFunc({
opSnippet: LESS_EQUAL,
packedOpSnippet: LESS_EQUAL_PACKED,
cpuKernelImpl: lessEqualImplCPU,
dtype: 'bool'
});
const lessEqualConfig = {
kernelName: LessEqual,
backendName: 'webgl',
kernelFunc: lessEqual
};
function linSpace$1(args) {
const { backend, attrs } = args;
const { start, stop, num } = attrs;
const outVals = linSpaceImplCPU(start, stop, num);
return backend.makeTensorInfo([outVals.length], 'float32', outVals);
}
const linSpaceConfig$1 = {
kernelName: LinSpace,
backendName: 'webgl',
kernelFunc: linSpace$1
};
const LOG = CHECK_NAN_SNIPPET_UNARY + `
return x < 0.0 ? 0./0. : log(x);
`;
const LOG_PACKED = `
vec4 result = log(x);
bvec4 isNaN = isnan(x);
result.r = isNaN.r ? x.r : (x.r < 0.0 ? 0./0. : result.r);
result.g = isNaN.g ? x.g : (x.g < 0.0 ? 0./0. : result.g);
result.b = isNaN.b ? x.b : (x.b < 0.0 ? 0./0. : result.b);
result.a = isNaN.a ? x.a : (x.a < 0.0 ? 0./0. : result.a);
return result;
`;
const log = unaryKernelFunc({ opSnippet: LOG, packedOpSnippet: LOG_PACKED, cpuKernelImpl: logImplCPU });
const logConfig = {
kernelName: Log,
backendName: 'webgl',
kernelFunc: log
};
const LOG1P = CHECK_NAN_SNIPPET_UNARY + `
return log(1.0 + x);
`;
const log1p$1 = unaryKernelFunc({ opSnippet: LOG1P });
const log1pConfig$1 = {
kernelName: Log1p,
backendName: 'webgl',
kernelFunc: log1p$1,
};
const LOGICAL_AND = `return float(a >= 1.0 && b >= 1.0);`;
const LOGICAL_AND_PACKED = `
return vec4(
vec4(greaterThanEqual(a, vec4(1.0))) *
vec4(greaterThanEqual(b, vec4(1.0))));
`;
const logicalAnd$1 = binaryKernelFunc({
opSnippet: LOGICAL_AND,
packedOpSnippet: LOGICAL_AND_PACKED,
dtype: 'bool'
});
const logicalAndConfig$1 = {
kernelName: LogicalAnd,
backendName: 'webgl',
kernelFunc: logicalAnd$1
};
const LOGICAL_NOT = `return float(!(x >= 1.0));`;
const logicalNot$1 = unaryKernelFunc({ opSnippet: LOGICAL_NOT });
const logicalNotConfig$1 = {
kernelName: LogicalNot,
backendName: 'webgl',
kernelFunc: logicalNot$1,
};
const LOGICAL_OR = `return float(a >= 1.0 || b >= 1.0);`;
const LOGICAL_OR_PACKED = `
return min(
vec4(greaterThanEqual(a, vec4(1.0))) +
vec4(greaterThanEqual(b, vec4(1.0))),
vec4(1.0));
`;
const logicalOr$1 = binaryKernelFunc({ opSnippet: LOGICAL_OR, packedOpSnippet: LOGICAL_OR_PACKED, dtype: 'bool' });
const logicalOrConfig$1 = {
kernelName: LogicalOr,
backendName: 'webgl',
kernelFunc: logicalOr$1
};
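// Local response normalization (LRN): for each channel d, sums the squares of the
// channels within depthRadius and scales the input by (bias + alpha * sum)^(-beta).
// The packed variant computes four neighbouring outputs per shader invocation.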
class LRNProgram {
constructor(xShape, radius, bias, alpha, beta) {
this.variableNames = ['x'];
this.outputShape = [];
const rad = radius;
const maxD = xShape[3] - 1;
this.outputShape = xShape;
let powOperator;
const basis = `float(${bias}) + float(${alpha}) * sum`;
if (beta === 0.5) {
powOperator = `inversesqrt(${basis})`;
}
else if (beta === 1.0) {
powOperator = `1.0/(${basis})`;
}
else {
powOperator = `exp(log(${basis}) * float(-${beta}));`;
}
this.userCode = `
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int r = coords[1];
int c = coords[2];
int d = coords[3];
float x = getX(b, r, c, d);
float sum = 0.0;
for (int j = -${rad}; j <= ${rad}; j++) {
int idx = d + j;
if (idx >= 0 && idx <= ${maxD}) {
float z = getX(b, r, c, idx);
sum += z * z;
}
}
float val = x * ${powOperator};
setOutput(val);
}
`;
}
}
class LRNPackedProgram {
constructor(xShape, radius, bias, alpha, beta) {
this.variableNames = ['x'];
this.outputShape = [];
this.packedInputs = true;
this.packedOutput = true;
const rad = radius;
const maxD = xShape[3] - 1;
this.outputShape = xShape;
let powOperator;
const basis = `float(${bias}) + float(${alpha}) * sum`;
if (beta === 0.5) {
powOperator = `inversesqrt(${basis})`;
}
else if (beta === 1.0) {
powOperator = `1.0/(${basis})`;
}
else {
powOperator = `exp(log(${basis}) * float(-${beta}));`;
}
this.userCode = `
void main() {
ivec4 coords = getOutputCoords();
int b = coords.x;
int r = coords.y;
int c = coords.z;
int d = coords.w;
bool hasNextCol = d < ${this.outputShape[3]};
bool hasNextRow = c < ${this.outputShape[2]};
vec4 sum = vec4(0.);
vec4 xFragAtOutputCoords = getX(b, r, c, d);
vec4 xAtOutputCoords = vec4(
getChannel(xFragAtOutputCoords, vec2(c, d)),
hasNextCol ?
getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,
hasNextRow ?
getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,
(hasNextRow && hasNextCol) ?
getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0
);
int firstChannel = d - ${rad};
vec2 cache = vec2(0.);
if(firstChannel >= 0){
vec4 firstChannelFrag = getX(b, r, c, firstChannel);
cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));
if(hasNextRow){
cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));
}
}
ivec2 depth = ivec2(d, d + 1);
for (int j = - ${rad}; j <= ${rad}; j++) {
ivec2 idx = depth + j;
bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));
bvec2 belowUpperBound = lessThanEqual(idx, ivec2(${maxD}));
bool depthInRange = aboveLowerBound.x && belowUpperBound.x;
bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;
if(depthInRange || depthPlusOneInRange){
vec4 z = vec4(0.);
vec4 xFragAtCurrentDepth;
z.xz = cache.xy;
if(depthPlusOneInRange && hasNextCol){
xFragAtCurrentDepth = idx.y != d ?
getX(b, r, c, idx.y) : xFragAtOutputCoords;
z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));
if(hasNextRow){
z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));
}
}
cache.xy = z.yw;
sum += z * z;
}
}
vec4 result = xAtOutputCoords * ${powOperator};
setOutput(result);
}
`;
}
}
const lrn = (args) => {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { depthRadius, bias, alpha, beta } = attrs;
const program = env().getBool('WEBGL_PACK_NORMALIZATION') ?
new LRNPackedProgram(x.shape, depthRadius, bias, alpha, beta) :
new LRNProgram(x.shape, depthRadius, bias, alpha, beta);
return backend.runWebGLProgram(program, [x], x.dtype);
};
const LRNConfig$1 = {
kernelName: LRN,
backendName: 'webgl',
kernelFunc: lrn
};
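// LRNGradProgram computes the gradient of LRN with respect to its input from the
// original input, the forward output and dy, re-deriving the per-channel norm term.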
class LRNGradProgram {
constructor(inputShape, depthRadius, bias, alpha, beta) {
this.variableNames = ['inputImage', 'outputImage', 'dy'];
this.outputShape = [];
this.outputShape = inputShape;
this.depth = inputShape[3];
this.depthRadius = depthRadius;
this.bias = bias;
this.alpha = alpha;
this.beta = beta;
this.userCode = `
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int r = coords[1];
int c = coords[2];
float result = 0.0;
for (int d = 0; d < ${this.depth}; ++d) {
int depthBegin = int(max(0.0, float(d - ${depthRadius})));
int depthEnd = int(min(float(${this.depth}),
float(d + ${depthRadius} + 1)));
const int MIN_DEPTH_BEGIN = 0;
const int MAX_DEPTH_END = ${this.depth};
float norm = 0.0;
for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) {
if (k < depthBegin){
continue;
}
else if (k >= depthBegin && k < depthEnd) {
norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k);
}
else {
break;
}
}
norm = float(${alpha}) * norm + float(${bias});
for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){
if (k < depthBegin){
continue;
}
else if (k >= depthBegin && k < depthEnd){
float dyi = -2.0 * float(${alpha})
* float(${beta})
* getInputImage(b, r, c, k) * getOutputImage(b, r, c, d)
/ norm;
if (k == d) {
dyi += pow(norm, -1.0 * ${beta});
}
if (k == coords[3]) {
dyi *= getDy(b, r, c, d);
result += dyi;
}
}
else {
break;
}
}
}
setOutput(result);
}
`;
}
}
const lrnGrad = (args) => {
const { inputs, backend, attrs } = args;
const { x, y, dy } = inputs;
const { depthRadius, bias, alpha, beta } = attrs;
const program = new LRNGradProgram(x.shape, depthRadius, bias, alpha, beta);
return backend.runWebGLProgram(program, [x, y, dy], x.dtype);
};
const LRNGradConfig$1 = {
kernelName: LRNGrad,
backendName: 'webgl',
kernelFunc: lrnGrad
};
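// Max reduction: reshape the (possibly transposed) input to [batch, reduceSize],
// run the shared reduce program with op 'max', then reshape to the final out shape.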
function maxImpl(x, reduceShape, outShape, backend) {
const inSize = sizeFromShape(reduceShape);
const xSize = sizeFromShape(x.shape);
const batchSize = xSize / inSize;
const reshapedInput = reshape$1({ inputs: { x }, attrs: { shape: [batchSize, inSize] }, backend });
const reduced = reduce(reshapedInput, x.dtype, 'max', backend);
const reshapedOutput = reshape$1({ inputs: { x: reduced }, attrs: { shape: outShape }, backend });
backend.disposeIntermediateTensorInfo(reshapedInput);
backend.disposeIntermediateTensorInfo(reduced);
return reshapedOutput;
}
function max$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { reductionIndices, keepDims } = attrs;
const xRank = x.shape.length;
const origAxes = parseAxisParam(reductionIndices, x.shape);
let axes = origAxes;
const permutedAxes = getAxesPermutation(axes, xRank);
const maxInputIsTransposed = permutedAxes != null;
const shouldExecuteOnCPU = backend.shouldExecuteOnCPU([x]);
let maxInput = x;
if (maxInputIsTransposed) {
if (shouldExecuteOnCPU) {
const xTexData = backend.texData.get(maxInput.dataId);
const values = xTexData.values;
const newShape = new Array(xRank);
for (let i = 0; i < newShape.length; i++) {
newShape[i] = x.shape[permutedAxes[i]];
}
const maxInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape);
maxInput = backend.makeTensorInfo(newShape, x.dtype);
const maxInputData = backend.texData.get(maxInput.dataId);
maxInputData.values = maxInputValues;
}
else {
maxInput = transposeImpl(x, permutedAxes, backend);
}
axes = getInnerMostAxes(axes.length, xRank);
}
assertAxesAreInnerMostDims('max', axes, xRank);
const [maxOutShape, reduceShape] = computeOutAndReduceShapes(maxInput.shape, axes);
let outShape = maxOutShape;
if (keepDims) {
outShape = expandShapeToKeepDim(maxOutShape, origAxes);
}
let out;
if (shouldExecuteOnCPU) {
const xTexData = backend.texData.get(maxInput.dataId);
const values = xTexData.values;
const outValues = maxImplCPU(values, sizeFromShape(reduceShape), outShape, x.dtype);
out = backend.makeTensorInfo(outShape, x.dtype);
const outData = backend.texData.get(out.dataId);
outData.values = outValues;
}
else {
out = maxImpl(maxInput, reduceShape, outShape, backend);
}
if (maxInputIsTransposed) {
backend.disposeIntermediateTensorInfo(maxInput);
}
return out;
}
const maxConfig$1 = {
kernelName: Max,
backendName: 'webgl',
kernelFunc: max$1
};
const MAXIMUM = CHECK_NAN_SNIPPET + `
return max(a, b);
`;
const MAXIMUM_PACKED = `
vec4 result = vec4(max(a, b));
bvec4 isNaNA = isnan(a);
bvec4 isNaNB = isnan(b);
bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w);
` +
CHECK_NAN_SNIPPET_PACKED + `
return result;
`;
const maximum = binaryKernelFunc({
opSnippet: MAXIMUM,
packedOpSnippet: MAXIMUM_PACKED,
cpuKernelImpl: maximumImplCPU
});
const maximumConfig = {
kernelName: Maximum,
backendName: 'webgl',
kernelFunc: maximum
};
function maxPool$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
assertNotComplex$1(x, 'maxPool');
const { filterSize, strides, pad, dimRoundingMode } = attrs;
const dilations = 1;
assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
`Got strides ${strides} and dilations '${dilations}'`);
const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
arraysEqual(convInfo.inShape, convInfo.outShape)) {
return identity({ inputs: { x }, backend });
}
const maxPoolProgram = new Pool2DProgram(convInfo, 'max', false);
return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype);
}
const maxPoolConfig$1 = {
kernelName: MaxPool,
backendName: 'webgl',
kernelFunc: maxPool$1
};
function maxPool3d(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { filterSize, strides, pad, dataFormat, dimRoundingMode } = attrs;
const dilations = [1, 1, 1];
const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat);
const maxPoolProgram = new Pool3DProgram(convInfo, 'max', false);
return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype);
}
const maxPool3DConfig$1 = {
kernelName: MaxPool3D,
backendName: 'webgl',
kernelFunc: maxPool3d
};
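// MaxPool backprop shaders: the forward pooling program is re-run with position
// tracking enabled to record the argmax offset inside each filter window, and dy is
// routed back to the input element whose recorded position matches the current
// window offset.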
class MaxPool2DBackpropProgram {
constructor(convInfo) {
this.variableNames = ['dy', 'maxPos'];
this.outputShape = convInfo.inShape;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationHeight = convInfo.dilationHeight;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
const lastIndex = effectiveFilterHeight * effectiveFilterWidth - 1;
this.userCode = `
const ivec2 pads = ivec2(${padTop}, ${padLeft});
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int d = coords[3];
ivec2 dyRCCorner = coords.yz - pads;
int dyRCorner = dyRCCorner.x;
int dyCCorner = dyRCCorner.y;
float dotProd = 0.0;
for (int wR = 0; wR < ${effectiveFilterHeight};
wR += ${dilationHeight}) {
float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
continue;
}
int idyR = int(dyR);
for (int wC = 0; wC < ${effectiveFilterWidth}; wC++) {
float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
fract(dyC) > 0.0) {
continue;
}
int idyC = int(dyC);
float dyValue = getDy(b, idyR, idyC, d);
int maxPosValue = ${lastIndex} - int(getMaxPos(b, idyR, idyC, d));
int curPosValue = wR * ${effectiveFilterWidth} + wC;
float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);
dotProd += dyValue * mask;
}
}
setOutput(dotProd);
}
`;
}
}
class MaxPool3DBackpropProgram {
constructor(convInfo) {
this.variableNames = ['dy', 'maxPos'];
this.outputShape = convInfo.inShape;
const strideDepth = convInfo.strideDepth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationDepth = convInfo.dilationDepth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterDepth = convInfo.effectiveFilterDepth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
const lastIndex = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1;
this.userCode = `
const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
void main() {
ivec5 coords = getOutputCoords();
int batch = coords.x;
int ch = coords.u;
ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
int dyDCorner = dyCorner.x;
int dyRCorner = dyCorner.y;
int dyCCorner = dyCorner.z;
float dotProd = 0.0;
for (int wD = 0; wD < ${effectiveFilterDepth};
wD += ${dilationDepth}) {
float dyD = float(dyDCorner + wD) / ${strideDepth}.0;
if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) {
continue;
}
int idyD = int(dyD);
for (int wR = 0; wR < ${effectiveFilterHeight};
wR += ${dilationHeight}) {
float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
fract(dyR) > 0.0) {
continue;
}
int idyR = int(dyR);
for (int wC = 0; wC < ${effectiveFilterWidth};
wC += ${dilationWidth}) {
float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
fract(dyC) > 0.0) {
continue;
}
int idyC = int(dyC);
float dyValue = getDy(batch, idyD, idyR, idyC, ch);
int maxPosValue = ${lastIndex} -
int(getMaxPos(batch, idyD, idyR, idyC, ch));
int curPosValue =
wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} +
wR * ${effectiveFilterWidth} + wC;
float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);
dotProd += dyValue * mask;
}
}
}
setOutput(dotProd);
}
`;
}
}
function maxPool3DGrad$1(args) {
const { inputs, backend, attrs } = args;
const { dy, input } = inputs;
const x = input;
const { filterSize, strides, pad, dimRoundingMode } = attrs;
const dilations = [1, 1, 1];
const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
const maxPool3dPositionsProgram = new Pool3DProgram(convInfo, 'max', true /* get positions */);
const maxPool3dPositions = backend.runWebGLProgram(maxPool3dPositionsProgram, [x], x.dtype);
const maxPoolBackpropProgram = new MaxPool3DBackpropProgram(convInfo);
const result = backend.runWebGLProgram(maxPoolBackpropProgram, [dy, maxPool3dPositions], x.dtype);
backend.disposeIntermediateTensorInfo(maxPool3dPositions);
return result;
}
const maxPool3DGradConfig$2 = {
kernelName: MaxPool3DGrad,
backendName: 'webgl',
kernelFunc: maxPool3DGrad$1
};
function maxPoolGrad$2(args) {
const { inputs, backend, attrs } = args;
const { dy, input, output } = inputs;
const x = input;
assertNotComplex$1([input, output], 'maxPoolGrad');
const { filterSize, strides, pad, dimRoundingMode } = attrs;
const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
const getPositions = true;
const maxPoolPositionsProgram = new Pool2DProgram(convInfo, 'max', getPositions);
const maxPoolPositions = backend.runWebGLProgram(maxPoolPositionsProgram, [x], x.dtype);
const maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo);
const result = backend.runWebGLProgram(maxPoolBackPropProgram, [dy, maxPoolPositions], x.dtype);
backend.disposeIntermediateTensorInfo(maxPoolPositions);
return result;
}
const maxPoolGradConfig$2 = {
kernelName: MaxPoolGrad,
backendName: 'webgl',
kernelFunc: maxPoolGrad$2
};
function maxPoolWithArgmaxImpl$1(x, includeBatchInIndex, convInfo, backend) {
let program = new Pool2DProgram(convInfo, 'max', false);
const poolOutput = backend.runWebGLProgram(program, [x], 'float32');
program = new Pool2DProgram(convInfo, 'max', true, true, includeBatchInIndex);
const indexOutput = backend.runWebGLProgram(program, [x], 'float32');
return [poolOutput, indexOutput];
}
const maxPoolWithArgmaxConfig$1 = {
kernelName: MaxPoolWithArgmax,
backendName: 'webgl',
kernelFunc: ({ inputs, attrs, backend }) => {
const { x } = inputs;
const { filterSize, strides, pad, includeBatchInIndex } = attrs;
const webglBackend = backend;
assert$1(x.shape.length === 4, () => `Error in maxPool: input must be rank 4 but got rank ${x.shape.length}.`);
const dilations = [1, 1];
assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
`Got strides ${strides} and dilations '${dilations}'`);
const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad);
const [result, indexes] = maxPoolWithArgmaxImpl$1(x, includeBatchInIndex, convInfo, webglBackend);
return [result, indexes];
}
};
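// Mean reduction follows the same reshape -> reduce -> reshape pattern as max above,
// but always accumulates in 'float32'.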
function meanImpl(x, reduceShape, outShape, backend) {
const inSize = sizeFromShape(reduceShape);
const xSize = sizeFromShape(x.shape);
const batchSize = xSize / inSize;
const reshapedInput = reshape$1({ inputs: { x }, attrs: { shape: [batchSize, inSize] }, backend });
const reduced = reduce(reshapedInput, 'float32', 'mean', backend);
const reshapedOutput = reshape$1({ inputs: { x: reduced }, attrs: { shape: outShape }, backend });
backend.disposeIntermediateTensorInfo(reshapedInput);
backend.disposeIntermediateTensorInfo(reduced);
return reshapedOutput;
}
const meanConfig$1 = {
kernelName: Mean,
backendName: 'webgl',
kernelFunc: ({ inputs, attrs, backend }) => {
const { x } = inputs;
const { keepDims, axis } = attrs;
const webglBackend = backend;
const xRank = x.shape.length;
const origAxes = parseAxisParam(axis, x.shape);
let axes = origAxes;
const permutedAxes = getAxesPermutation(axes, xRank);
const meanInputIsTransposed = permutedAxes != null;
const shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]);
const intermediates = [];
let meanInput = x;
if (meanInputIsTransposed) {
if (shouldExecuteOnCPU) {
const xTexData = webglBackend.texData.get(meanInput.dataId);
const values = xTexData.values;
const newShape = new Array(xRank);
for (let i = 0; i < newShape.length; i++) {
newShape[i] = x.shape[permutedAxes[i]];
}
const meanInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape);
meanInput = webglBackend.makeTensorInfo(newShape, x.dtype);
const meanInputData = webglBackend.texData.get(meanInput.dataId);
meanInputData.values = meanInputValues;
}
else {
meanInput = transposeImpl(x, permutedAxes, webglBackend);
}
intermediates.push(meanInput);
axes = getInnerMostAxes(axes.length, xRank);
}
assertAxesAreInnerMostDims('sum', axes, xRank);
const [meanOutShape, reduceShape] = computeOutAndReduceShapes(meanInput.shape, axes);
let outShape = meanOutShape;
if (keepDims) {
outShape = expandShapeToKeepDim(meanOutShape, origAxes);
}
const out = meanImpl(meanInput, reduceShape, outShape, webglBackend);
for (const i of intermediates) {
webglBackend.disposeIntermediateTensorInfo(i);
}
return out;
}
};
function min$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis, keepDims } = attrs;
const xRank = x.shape.length;
const origAxes = parseAxisParam(axis, x.shape);
let axes = origAxes;
const permutedAxes = getAxesPermutation(axes, xRank);
let permutedX = x;
if (permutedAxes != null) {
permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
axes = getInnerMostAxes(axes.length, x.shape.length);
}
assertAxesAreInnerMostDims('min', axes, xRank);
const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes);
const inSize = sizeFromShape(reduceShape);
const a2D = reshape$1({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
const reduced = reduce(a2D, a2D.dtype, 'min', backend);
let res;
if (keepDims) {
const newShape = expandShapeToKeepDim(outShape, origAxes);
res = reshape$1({ inputs: { x: reduced }, backend, attrs: { shape: newShape } });
}
else {
res = reshape$1({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
}
backend.disposeIntermediateTensorInfo(a2D);
backend.disposeIntermediateTensorInfo(reduced);
if (permutedAxes != null) {
backend.disposeIntermediateTensorInfo(permutedX);
}
return res;
}
const minConfig$1 = {
kernelName: Min,
backendName: 'webgl',
kernelFunc: min$1
};
const MINIMUM = CHECK_NAN_SNIPPET + `
return min(a, b);
`;
const MINIMUM_PACKED = `
vec4 result = vec4(min(a, b));
bvec4 isNaNA = isnan(a);
bvec4 isNaNB = isnan(b);
bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w);
` +
CHECK_NAN_SNIPPET_PACKED + `
return result;
`;
const minimum = binaryKernelFunc({
opSnippet: MINIMUM,
packedOpSnippet: MINIMUM_PACKED,
cpuKernelImpl: minimumImplCPU
});
const minimumConfig = {
kernelName: Minimum,
backendName: 'webgl',
kernelFunc: minimum
};
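// MirrorPad: output coordinates that fall in the padded region are reflected back
// into the input range; the offset is 0 for 'reflect' mode and 1 for 'symmetric'.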
class MirrorPadProgram {
constructor(xShape, paddings, mode) {
this.variableNames = ['x'];
this.outputShape = paddings.map((p, i) => p[0] + xShape[i] + p[1]);
const rank = xShape.length;
const dtype = getCoordsDataType(rank);
const start = paddings.map(p => p[0]).join(',');
const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
const unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
const offset = mode === 'reflect' ? 0 : 1;
if (rank === 1) {
this.userCode = `
int start = ${start};
int end = ${end};
void main() {
int outC = getOutputCoords();
if (outC < start) {
outC = start * 2 - outC - ${offset};
} else if(outC >= end) {
outC = (end - 1) * 2 - outC + ${offset};
}
setOutput(getX(outC - start));
}
`;
return;
}
this.userCode = `
${dtype} start = ${dtype}(${start});
${dtype} end = ${dtype}(${end});
void main() {
${dtype} outC = getOutputCoords();
for (int i = 0; i < ${rank}; i++) {
if (outC[i] < start[i]) {
outC[i] = start[i] * 2 - outC[i] - ${offset};
} else if(outC[i] >= end[i]) {
outC[i] = (end[i] - 1) * 2 - outC[i] + ${offset};
}
}
${dtype} coords = outC - start;
setOutput(getX(${unpackedCoords}));
}
`;
}
}
class MirrorPadPackedProgram {
constructor(xShape, paddings, mode) {
this.variableNames = ['x'];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = paddings.map((p, i) => p[0] + xShape[i] + p[1]);
const rank = xShape.length;
const dtype = getCoordsDataType(rank);
const start = paddings.map(p => p[0]).join(',');
const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
const coords = getChannels('rc', rank);
const source = getChannels('source', rank);
const cLimit = `${coords[rank - 1]} < ${this.outputShape[rank - 1]}`;
const innerDims = rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`;
const offset = mode === 'reflect' ? 0 : 1;
let mainLoop = '';
if (rank === 1) {
const padSetup = `
${dtype} source = rc;
if (source < start) {
source = start * 2 - source - ${offset};
} else if (source >= end) {
source = (end - 1) * 2 - source + ${offset};
}
source -= start;
`;
mainLoop = `
${dtype} rc = outputLoc;
${padSetup}
result[0] = getChannel(getX(${source.join()}), ${innerDims});
${coords[rank - 1]} += 1;
if(${cLimit}) {
${padSetup}
result[1] = getChannel(getX(${source.join()}), ${innerDims});
}
`;
}
else {
const padSetup = `
${dtype} source = rc;
${dtype} lt = ${dtype}(lessThan(source, start));
${dtype} gte = ${dtype}(greaterThanEqual(source, end));
${dtype} orig = 1 - (lt + gte);
source = orig * source +
lt * (start * 2 - source - ${offset}) +
gte * ((end - 1) * 2 - source + ${offset});
source -= start;
`;
mainLoop = `
${dtype} rc = outputLoc;
${padSetup}
result[0] = getChannel(getX(${source.join()}), ${innerDims});
${coords[rank - 1]} += 1;
if(${cLimit}) {
${padSetup}
result[1] = getChannel(getX(${source.join()}), ${innerDims});
}
rc = outputLoc;
${coords[rank - 2]} += 1;
if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) {
${padSetup}
result[2] = getChannel(getX(${source.join()}), ${innerDims});
${coords[rank - 1]} += 1;
if(${cLimit}) {
${padSetup}
result[3] = getChannel(getX(${source.join()}), ${innerDims});
}
}
`;
}
this.userCode = `
const ${dtype} start = ${dtype}(${start});
const ${dtype} end = ${dtype}(${end});
void main() {
${dtype} outputLoc = getOutputCoords();
vec4 result = vec4(0.);
${mainLoop}
setOutput(result);
}
`;
}
}
const mirrorPadKernelFunc = ({ inputs, backend, attrs }) => {
const { x } = inputs;
const { paddings, mode } = attrs;
const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
new MirrorPadPackedProgram(x.shape, paddings, mode) :
new MirrorPadProgram(x.shape, paddings, mode);
const output = backend.runWebGLProgram(program, [x], x.dtype);
return output;
};
const mirrorPadConfig$1 = {
kernelName: MirrorPad,
backendName: 'webgl',
kernelFunc: mirrorPadKernelFunc,
};
const MOD = `if (b == 0.0) return NAN;
return mod(a, b);`;
const MOD_PACKED = `
vec4 result = mod(a, b);
bvec4 isNaN = equal(b, vec4(0.0));
` +
CHECK_NAN_SNIPPET_PACKED + `
return result;
`;
const mod$1 = binaryKernelFunc({
opSnippet: MOD,
packedOpSnippet: MOD_PACKED,
});
const modConfig$1 = {
kernelName: Mod,
backendName: 'webgl',
kernelFunc: mod$1
};
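// Multinomial sampling: each output element draws a uniform random value from the
// `seed` uniform and walks the cumulative distribution of `probs` for its batch,
// emitting the first outcome whose CDF exceeds the draw (or the last outcome).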
class MultinomialProgram {
constructor(batchSize, numOutcomes, numSamples) {
this.variableNames = ['probs'];
this.customUniforms = [{ name: 'seed', type: 'float' }];
this.outputShape = [batchSize, numSamples];
this.userCode = `
void main() {
ivec2 coords = getOutputCoords();
int batch = coords[0];
float r = random(seed);
float cdf = 0.0;
for (int i = 0; i < ${numOutcomes - 1}; i++) {
cdf += getProbs(batch, i);
if (r < cdf) {
setOutput(float(i));
return;
}
}
setOutput(float(${numOutcomes - 1}));
}
`;
}
}
const DIV = `
if (a == b) {
return 1.0;
};
return a / b;`;
const DIV_PACKED = `
vec4 result = a / b;
if(a.x == b.x) {
result.x = 1.;
}
if(a.y == b.y) {
result.y = 1.;
}
if(a.z == b.z) {
result.z = 1.;
}
if(a.w == b.w) {
result.w = 1.;
}
return result;
`;
const realDiv = binaryKernelFunc({ opSnippet: DIV, packedOpSnippet: DIV_PACKED, checkOutOfBounds: true });
const realDivConfig$1 = {
kernelName: RealDiv,
backendName: 'webgl',
kernelFunc: realDiv,
};
const SUB = 'return a - b;';
const sub = binaryKernelFunc({
opSnippet: SUB,
packedOpSnippet: SUB,
supportsComplex: true,
cpuKernelImpl: subImplCPU
});
const subConfig = {
kernelName: Sub,
backendName: 'webgl',
kernelFunc: sub
};
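// Softmax is composed from the kernels above in the numerically stable form
// softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))) along the requested dim.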
function softmax$1(args) {
const { inputs, backend, attrs } = args;
const { logits } = inputs;
const { dim } = attrs;
const axes = parseAxisParam([dim], logits.shape);
const maxLogit = max$1({
inputs: { x: logits },
backend,
attrs: { reductionIndices: axes, keepDims: false }
});
const expandedShape = expandShapeToKeepDim(maxLogit.shape, axes);
const maxLogitsReshaped = reshape$1({ inputs: { x: maxLogit }, backend, attrs: { shape: expandedShape } });
const a = sub({ inputs: { a: logits, b: maxLogitsReshaped }, backend });
const b = exp({ inputs: { x: a }, backend });
const sumExp = sum$1({ inputs: { x: b }, backend, attrs: { axis: axes, keepDims: false } });
const sumExpReshaped = reshape$1({ inputs: { x: sumExp }, backend, attrs: { shape: expandedShape } });
const res = realDiv({ inputs: { a: b, b: sumExpReshaped }, backend });
backend.disposeIntermediateTensorInfo(maxLogit);
backend.disposeIntermediateTensorInfo(maxLogitsReshaped);
backend.disposeIntermediateTensorInfo(a);
backend.disposeIntermediateTensorInfo(b);
backend.disposeIntermediateTensorInfo(sumExp);
backend.disposeIntermediateTensorInfo(sumExpReshaped);
return res;
}
const softmaxConfig$1 = {
kernelName: Softmax$1,
backendName: 'webgl',
kernelFunc: softmax$1
};
function multinomial$1(args) {
const { inputs, backend, attrs } = args;
const { logits } = inputs;
const { numSamples, seed, normalized } = attrs;
const probs = normalized ?
logits :
softmax$1({ inputs: { logits }, backend, attrs: { dim: logits.shape.length - 1 } });
const batchSize = probs.shape[0];
const numOutcomes = probs.shape[1];
const program = new MultinomialProgram(batchSize, numOutcomes, numSamples);
const customValues = [[seed]];
const res = backend.runWebGLProgram(program, [probs], 'int32', customValues);
if (!normalized) {
backend.disposeIntermediateTensorInfo(probs);
}
return res;
}
const multinomialConfig$1 = {
kernelName: Multinomial,
backendName: 'webgl',
kernelFunc: multinomial$1
};
const NEG = CHECK_NAN_SNIPPET$1 + `
return -x;
`;
const NEG_PACKED = `
vec4 result = -x;
bvec4 isNaN = isnan(x);
result.r = isNaN.r ? x.r : result.r;
result.g = isNaN.g ? x.g : result.g;
result.b = isNaN.b ? x.b : result.b;
result.a = isNaN.a ? x.a : result.a;
return result;
`;
function neg(args) {
const { inputs, backend } = args;
const { x } = inputs;
if (backend.shouldExecuteOnCPU([x])) {
const xData = backend.texData.get(x.dataId);
const [outValues, newShape] = negImplCPU(xData.values, x.shape, x.dtype);
return backend.makeTensorInfo(newShape, x.dtype, outValues);
}
let program;
if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
program = new UnaryOpPackedProgram(x.shape, NEG_PACKED);
}
else {
program = new UnaryOpProgram(x.shape, NEG);
}
return backend.runWebGLProgram(program, [x], x.dtype);
}
const negConfig = {
kernelName: Neg,
backendName: 'webgl',
kernelFunc: neg
};
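// The non-max suppression kernels below read boxes and scores back to the CPU
// synchronously and reuse the shared CPU implementations, hence the warning that
// they block the UI thread.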
const nonMaxSuppressionV3Impl$1 = nonMaxSuppressionV3Impl$2;
function nonMaxSuppressionV3$1(args) {
warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' +
'Call tf.nonMaxSuppressionAsync() instead');
const { inputs, backend, attrs } = args;
const { boxes, scores } = inputs;
const { maxOutputSize, iouThreshold, scoreThreshold } = attrs;
const boxesVals = backend.readSync(boxes.dataId);
const scoresVals = backend.readSync(scores.dataId);
const { selectedIndices } = nonMaxSuppressionV3Impl$1(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
return backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
}
const nonMaxSuppressionV3Config$1 = {
kernelName: NonMaxSuppressionV3,
backendName: 'webgl',
kernelFunc: nonMaxSuppressionV3$1
};
const nonMaxSuppressionV4Impl$1 = nonMaxSuppressionV4Impl$2;
function nonMaxSuppressionV4$1(args) {
warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' +
'Call tf.nonMaxSuppressionAsync() instead');
const { inputs, backend, attrs } = args;
const { boxes, scores } = inputs;
const { maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize } = attrs;
const boxesVals = backend.readSync(boxes.dataId);
const scoresVals = backend.readSync(scores.dataId);
const { selectedIndices, validOutputs } = nonMaxSuppressionV4Impl$1(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize);
return [
backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)),
backend.makeTensorInfo([], 'int32', new Int32Array([validOutputs]))
];
}
const nonMaxSuppressionV4Config$1 = {
kernelName: NonMaxSuppressionV4,
backendName: 'webgl',
kernelFunc: nonMaxSuppressionV4$1
};
const nonMaxSuppressionV5Impl$1 = nonMaxSuppressionV5Impl$2;
function nonMaxSuppressionV5$1(args) {
warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' +
'Call tf.nonMaxSuppressionAsync() instead');
const { inputs, backend, attrs } = args;
const { boxes, scores } = inputs;
const { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma } = attrs;
const boxesVals = backend.readSync(boxes.dataId);
const scoresVals = backend.readSync(scores.dataId);
const maxOutputSizeVal = maxOutputSize;
const iouThresholdVal = iouThreshold;
const scoreThresholdVal = scoreThreshold;
const softNmsSigmaVal = softNmsSigma;
const { selectedIndices, selectedScores } = nonMaxSuppressionV5Impl$1(boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, scoreThresholdVal, softNmsSigmaVal);
return [
backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)),
backend.makeTensorInfo([selectedScores.length], 'float32', new Float32Array(selectedScores))
];
}
const nonMaxSuppressionV5Config$1 = {
kernelName: NonMaxSuppressionV5,
backendName: 'webgl',
kernelFunc: nonMaxSuppressionV5$1
};
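// OneHot: flattens the indices to 1-D, writes onValue where the index equals the
// output column and offValue elsewhere, then reshapes to [...indices.shape, depth].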
class OneHotProgram {
constructor(numIndices, depth, onValue, offValue) {
this.variableNames = ['indices'];
this.outputShape = [numIndices, depth];
this.userCode = `
void main() {
ivec2 coords = getOutputCoords();
int index = round(getIndices(coords.x));
setOutput(mix(float(${offValue}), float(${onValue}),
float(index == coords.y)));
}
`;
}
}
const oneHot$1 = (args) => {
const { inputs, backend, attrs } = args;
const { indices } = inputs;
const { dtype, depth, onValue, offValue } = attrs;
const indicesSize = sizeFromShape(indices.shape);
const program = new OneHotProgram(indicesSize, depth, onValue, offValue);
const reshaped = reshape$1({ inputs: { x: indices }, backend, attrs: { shape: [indicesSize] } });
const result = backend.runWebGLProgram(program, [reshaped], dtype);
backend.disposeIntermediateTensorInfo(reshaped);
const outShape = [...indices.shape, depth];
const out = reshape$1({ inputs: { x: result }, backend, attrs: { shape: outShape } });
backend.disposeIntermediateTensorInfo(result);
return out;
};
const oneHotConfig$1 = {
kernelName: OneHot,
backendName: 'webgl',
kernelFunc: oneHot$1
};
function zerosLike$1(args) {
const { inputs, backend } = args;
const { x } = inputs;
if (x.dtype === 'complex64') {
const realPart = real({ inputs: { input: x }, backend });
const r = zerosLike$1({ inputs: { x: realPart }, backend });
const imagPart = imag$1({ inputs: { input: x }, backend });
const i = zerosLike$1({ inputs: { x: imagPart }, backend });
const result = complex({ inputs: { real: r, imag: i }, backend });
backend.disposeIntermediateTensorInfo(realPart);
backend.disposeIntermediateTensorInfo(r);
backend.disposeIntermediateTensorInfo(imagPart);
backend.disposeIntermediateTensorInfo(i);
return result;
}
else {
return fill$1({
attrs: {
shape: x.shape,
dtype: x.dtype,
value: x.dtype === 'string' ? '' : 0
},
backend
});
}
}
const zerosLikeConfig$1 = {
kernelName: ZerosLike,
backendName: 'webgl',
kernelFunc: zerosLike$1
};
function onesLike$1(args) {
const { inputs, backend } = args;
const { x } = inputs;
if (x.dtype === 'string') {
throw new Error('onesLike is not supported under string dtype');
}
else if (x.dtype === 'complex64') {
const realPart = real({ inputs: { input: x }, backend });
const r = onesLike$1({ inputs: { x: realPart }, backend });
const imagPart = imag$1({ inputs: { input: x }, backend });
const i = zerosLike$1({ inputs: { x: imagPart }, backend });
const result = complex({ inputs: { real: r, imag: i }, backend });
backend.disposeIntermediateTensorInfo(realPart);
backend.disposeIntermediateTensorInfo(r);
backend.disposeIntermediateTensorInfo(imagPart);
backend.disposeIntermediateTensorInfo(i);
return result;
}
else {
return fill$1({ attrs: { shape: x.shape, dtype: x.dtype, value: 1 }, backend });
}
}
const onesLikeConfig$1 = {
kernelName: OnesLike,
backendName: 'webgl',
kernelFunc: onesLike$1
};
function pack$1(args) {
const { inputs, backend, attrs } = args;
const { axis } = attrs;
if (inputs.length === 1) {
return expandDims$2({ inputs: { input: inputs[0] }, backend, attrs: { dim: axis } });
}
const shape = inputs[0].shape;
const dtype = inputs[0].dtype;
inputs.forEach(t => {
assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes');
assert$1(dtype === t.dtype, () => 'All tensors passed to stack must have matching dtypes');
});
const intermediateTensorInfos = [];
const expandedTensors = inputs.map(t => {
const expandedT = expandDims$2({ inputs: { input: t }, backend, attrs: { dim: axis } });
intermediateTensorInfos.push(expandedT);
return expandedT;
});
const result = concat$1({ inputs: expandedTensors, backend, attrs: { axis } });
intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
return result;
}
const packConfig$1 = {
kernelName: Pack,
backendName: 'webgl',
kernelFunc: pack$1
};
class PadProgram {
constructor(xShape, paddings, constantValue) {
this.variableNames = ['x'];
this.customUniforms = [{ name: 'value', type: 'float' }];
this.outputShape = paddings.map((p, i) => p[0] + xShape[i] + p[1]);
const rank = xShape.length;
const type = getCoordsDataType(rank);
const start = paddings.map(p => p[0]).join(',');
const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
const unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
if (rank === 1) {
this.userCode = `
int start = ${start};
int end = ${end};
void main() {
int outC = getOutputCoords();
if (outC < start || outC >= end) {
setOutput(value);
} else {
setOutput(getX(outC - start));
}
}
`;
return;
}
this.userCode = `
${type} start = ${type}(${start});
${type} end = ${type}(${end});
void main() {
${type} outC = getOutputCoords();
if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) {
setOutput(value);
} else {
${type} coords = outC - start;
setOutput(getX(${unpackedCoords}));
}
}
`;
}
}
class PadPackedProgram {
constructor(xShape, paddings, constantValue) {
this.variableNames = ['x'];
this.packedInputs = true;
this.packedOutput = true;
this.customUniforms = [{ name: 'value', type: 'float' }];
this.outputShape = paddings.map((p, i) => p[0] + xShape[i] + p[1]);
const rank = xShape.length;
const dtype = getCoordsDataType(rank);
const start = paddings.map(p => p[0]).join(',');
const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
const coords = getChannels('rc', rank);
const source = getChannels('source', rank);
const cLimit = `${coords[rank - 1]} < ${this.outputShape[rank - 1]}`;
const innerDims = rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`;
const componentSetup = [
`${dtype} rc = outputLoc;`, `${coords[rank - 1]} += 1;
if(${cLimit}) {
`,
rank === 1 ? '' : `}
rc = outputLoc;
${coords[rank - 2]} += 1;
if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) {`,
rank === 1 ? '' : ` ${coords[rank - 1]} += 1;
if(${cLimit}) {`
];
const paddingArea = rank === 1 ?
'rc < start || rc >= end' :
'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))';
let mainLoop = '';
for (let i = 0, j = rank === 1 ? 2 : 4; i < j; i++) {
mainLoop += `
${componentSetup[i]}
if (${paddingArea}) {
result[${i}] = float(value);
} else {
${dtype} source = rc - start;
result[${i}] = getChannel(getX(${source.join()}), ${innerDims});
}
`;
}
mainLoop += (rank === 1 ? `} ` : `}}`);
this.userCode = `
const ${dtype} start = ${dtype}(${start});
const ${dtype} end = ${dtype}(${end});
void main() {
${dtype} outputLoc = getOutputCoords();
vec4 result = vec4(0.);
${mainLoop}
setOutput(result);
}
`;
}
}
const padV2$1 = (args) => {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { paddings, constantValue } = attrs;
if (sizeFromShape(x.shape) === 0) {
const outputShape = paddings.map((p, i) => p[0] + x.shape[i] + p[1]);
return fill$1({
backend,
attrs: { shape: outputShape, value: constantValue, dtype: x.dtype }
});
}
const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
new PadPackedProgram(x.shape, paddings, constantValue) :
new PadProgram(x.shape, paddings, constantValue);
const customValues = [[constantValue]];
return backend.runWebGLProgram(program, [x], x.dtype, customValues);
};
const padV2Config$1 = {
kernelName: PadV2,
backendName: 'webgl',
kernelFunc: padV2$1
};
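// Pow: NaN for a negative base with a non-integer exponent, 1.0 when the exponent is
// 0, and sign-corrected pow(abs(a), b) so odd integer exponents keep the base's sign.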
const POW = `
if(a < 0.0 && floor(b) < b){
return NAN;
}
if (b == 0.0) {
return 1.0;
}
return (round(mod(b, 2.0)) != 1) ?
pow(abs(a), b) : sign(a) * pow(abs(a), b);
`;
const POW_PACKED = `
vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1)));
vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1);
vec4 result = multiplier * pow(abs(a), b);
bvec4 isExpZero = equal(b, vec4(0.0));
result.r = isExpZero.r ? 1.0 : result.r;
result.g = isExpZero.g ? 1.0 : result.g;
result.b = isExpZero.b ? 1.0 : result.b;
result.a = isExpZero.a ? 1.0 : result.a;
bvec4 isNaN1 = lessThan(a, vec4(0.0));
bvec4 isNaN2 = lessThan(floor(b), b);
bvec4 isNaN = bvec4(isNaN1.x && isNaN2.x, isNaN1.y && isNaN2.y, isNaN1.z && isNaN2.z, isNaN1.w && isNaN2.w);
` +
CHECK_NAN_SNIPPET_PACKED + `
return result;
`;
const pow$1 = binaryKernelFunc({ opSnippet: POW, packedOpSnippet: POW_PACKED });
const powConfig$1 = {
kernelName: Pow,
backendName: 'webgl',
kernelFunc: pow$1
};
function prod(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis, keepDims } = attrs;
const xRank = x.shape.length;
const toDispose = [];
const origAxes = parseAxisParam(axis, x.shape);
let axes = origAxes;
const permutedAxes = getAxesPermutation(axes, xRank);
let permutedX = x;
if (permutedAxes != null) {
permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
axes = getInnerMostAxes(axes.length, xRank);
toDispose.push(permutedX);
}
assertAxesAreInnerMostDims('prod', axes, xRank);
let res;
if (backend.shouldExecuteOnCPU([permutedX])) {
const xVals = backend.texData.get(permutedX.dataId).values;
const { outVals, outShape, outDtype } = prodImplCPU(permutedX.shape, permutedX.dtype, xVals, axes);
res = backend.makeTensorInfo(outShape, outDtype, outVals);
}
else {
const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes);
const inSize = sizeFromShape(reduceShape);
const a2D = reshape$1({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
const outputDType = sumOutType(x.dtype);
const reduced = reduce(a2D, outputDType, 'prod', backend);
res = reshape$1({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
toDispose.push(a2D);
toDispose.push(reduced);
}
if (keepDims) {
toDispose.push(res);
const newShape = expandShapeToKeepDim(res.shape, origAxes);
res = reshape$1({ inputs: { x: res }, backend, attrs: { shape: newShape } });
}
toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
return res;
}
const prodConfig = {
kernelName: Prod,
backendName: 'webgl',
kernelFunc: prod
};
function raggedGather$1(args) {
const { inputs, backend, attrs } = args;
const { paramsNestedSplits, paramsDenseValues, indices } = inputs;
const { outputRaggedRank } = attrs;
const $paramsNestedSplits = paramsNestedSplits.map(t => backend.readSync(t.dataId));
const $paramsNestedSplitsShapes = paramsNestedSplits.map(t => t.shape);
const $paramsDenseValues = backend.readSync(paramsDenseValues.dataId);
const $indices = backend.readSync(indices.dataId);
const [outputNestedSplits, outputDenseValues, outputDenseValuesShape] = raggedGatherImplCPU($paramsNestedSplits, $paramsNestedSplitsShapes, $paramsDenseValues, paramsDenseValues.shape, paramsDenseValues.dtype, $indices, indices.shape, outputRaggedRank);
const outputNestedSplitsTensors = outputNestedSplits.map((splits) => backend.makeTensorInfo([splits.length], 'int32', splits));
const outputDenseValuesTensor = backend.makeTensorInfo(outputDenseValuesShape, paramsDenseValues.dtype, outputDenseValues);
return outputNestedSplitsTensors.concat([outputDenseValuesTensor]);
}
const raggedGatherConfig$1 = {
kernelName: RaggedGather,
backendName: 'webgl',
kernelFunc: raggedGather$1,
};
function raggedRange$1(args) {
const { inputs, backend } = args;
const { starts, limits, deltas } = inputs;
const $starts = backend.readSync(starts.dataId);
const $limits = backend.readSync(limits.dataId);
const $deltas = backend.readSync(deltas.dataId);
const [rtNestedSplitsData, rtDenseValuesData] = raggedRangeImplCPU($starts, starts.shape, starts.dtype, $limits, limits.shape, $deltas, deltas.shape);
const rtNestedSplits = backend.makeTensorInfo([rtNestedSplitsData.length], 'int32', rtNestedSplitsData);
const rtDenseValues = backend.makeTensorInfo([rtDenseValuesData.length], starts.dtype, rtDenseValuesData);
return [rtNestedSplits, rtDenseValues];
}
const raggedRangeConfig$1 = {
kernelName: RaggedRange,
backendName: 'webgl',
kernelFunc: raggedRange$1,
};
function raggedTensorToTensor$1(args) {
const { inputs, backend, attrs } = args;
const { shape, values, defaultValue, rowPartitionTensors } = inputs;
const { rowPartitionTypes } = attrs;
const $shape = backend.readSync(shape.dataId);
const $values = backend.readSync(values.dataId);
const $defaultValue = backend.readSync(defaultValue.dataId);
const $rowPartitionValues = rowPartitionTensors.map(t => backend.readSync(t.dataId));
const rowPartitionValuesShapes = rowPartitionTensors.map(t => t.shape);
const [outputShape, output] = raggedTensorToTensorImplCPU($shape, shape.shape, $values, values.shape, values.dtype, $defaultValue, defaultValue.shape, $rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes);
return backend.makeTensorInfo(outputShape, values.dtype, output);
}
const raggedTensorToTensorConfig$1 = {
kernelName: RaggedTensorToTensor,
backendName: 'webgl',
kernelFunc: raggedTensorToTensor$1,
};
const range$2 = (args) => {
const { backend, attrs } = args;
const { start, stop, step, dtype } = attrs;
const values = rangeImplCPU(start, stop, step, dtype);
return backend.makeTensorInfo([values.length], dtype, values);
};
const rangeConfig$1 = {
kernelName: Range,
backendName: 'webgl',
kernelFunc: range$2
};
const RECIPROCAL = `return 1.0 / x;`;
const reciprocal$1 = unaryKernelFunc({ opSnippet: RECIPROCAL });
const reciprocalConfig$1 = {
kernelName: Reciprocal,
backendName: 'webgl',
kernelFunc: reciprocal$1,
};
const RELU = CHECK_NAN_SNIPPET$1 + `
return (x < 0.0) ? 0.0 : x;
`;
const RELU_PACKED = `
vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));
bvec4 isNaN = isnan(x);
result.r = isNaN.r ? x.r : result.r;
result.g = isNaN.g ? x.g : result.g;
result.b = isNaN.b ? x.b : result.b;
result.a = isNaN.a ? x.a : result.a;
return result;
`;
const relu$1 = unaryKernelFunc({ opSnippet: RELU, packedOpSnippet: RELU_PACKED });
const reluConfig$1 = {
kernelName: Relu$1,
backendName: 'webgl',
kernelFunc: relu$1
};
const RELU6 = CHECK_NAN_SNIPPET$1 + `
return (x < 0.0) ? 0.0 : min(6.0, x);
`;
const RELU6_PACKED = `
vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));
bvec4 isNaN = isnan(x);
result.r = isNaN.r ? x.r : result.r;
result.g = isNaN.g ? x.g : result.g;
result.b = isNaN.b ? x.b : result.b;
result.a = isNaN.a ? x.a : result.a;
return result;
`;
const relu6$1 = unaryKernelFunc({ opSnippet: RELU6, packedOpSnippet: RELU6_PACKED });
const relu6Config$1 = {
kernelName: Relu6$1,
backendName: 'webgl',
kernelFunc: relu6$1
};
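// Bilinear resize: each output pixel is mapped to a fractional source coordinate
// (optionally align-corners / half-pixel-centered) and interpolated from the four
// surrounding input texels; the packed variant computes four outputs per invocation.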
class ResizeBilinearProgram {
constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
this.variableNames = ['A'];
this.outputShape = [];
const [batch, oldHeight, oldWidth, depth] = inputShape;
this.outputShape = [batch, newHeight, newWidth, depth];
const effectiveInSize = [
(alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
(alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
];
const effectiveOutSize = [
(alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
(alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
];
let sourceFracIndexRC;
if (halfPixelCenters) {
sourceFracIndexRC =
`(vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC` +
` - vec2(0.5)`;
}
else {
sourceFracIndexRC = `vec2(yRC) * effectiveInputOverOutputRatioRC`;
}
this.userCode = `
const vec2 effectiveInputOverOutputRatioRC = vec2(
${effectiveInSize[0] / effectiveOutSize[0]},
${effectiveInSize[1] / effectiveOutSize[1]});
const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0);
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int d = coords[3];
ivec2 yRC = coords.yz;
vec2 sourceFracIndexRC = ${sourceFracIndexRC};
ivec2 sourceFloorRC = ivec2(max(sourceFracIndexRC, vec2(0.0)));
ivec2 sourceCeilRC = ivec2(
min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));
float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d);
float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d);
float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d);
float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d);
vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);
float top = topLeft + (topRight - topLeft) * fracRC.y;
float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;
float newValue = top + (bottom - top) * fracRC.x;
setOutput(newValue);
}
`;
}
}
class ResizeBilinearPackedProgram {
constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = [];
const [batch, oldHeight, oldWidth, depth] = inputShape;
this.outputShape = [batch, newHeight, newWidth, depth];
const effectiveInSize = [
(alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
(alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
];
const effectiveOutSize = [
(alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
(alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
];
let sourceFracIndexRC;
if (halfPixelCenters) {
sourceFracIndexRC = `(vec3(yRC) + vec3(0.5)) * ` +
`effectiveInputOverOutputRatioRC - vec3(0.5)`;
}
else {
sourceFracIndexRC = `vec3(yRC) * effectiveInputOverOutputRatioRC`;
}
this.userCode = `
const vec3 effectiveInputOverOutputRatioRC = vec3(
${effectiveInSize[0] / effectiveOutSize[0]},
${effectiveInSize[1] / effectiveOutSize[1]},
${effectiveInSize[1] / effectiveOutSize[1]});
const vec3 inputShapeRC = vec3(${oldHeight}.0, ${oldWidth}.0,
${oldWidth}.0);
float getAValue(int b, int r, int c, int d) {
return getChannel(getA(b, r, c, d), vec2(c, d));
}
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int d = coords[3];
ivec3 yRC = coords.yzz + ivec3(0, 0, 1);
vec3 sourceFracIndexRC = ${sourceFracIndexRC};
ivec3 sourceFloorRC = ivec3(max(sourceFracIndexRC, vec3(0.0)));
ivec3 sourceCeilRC = ivec3(
min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));
bool hasNextCol = d < ${depth - 1};
bool hasNextRow = coords.z < ${newWidth - 1};
vec4 topLeft = vec4(
getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d),
hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1)
: 0.0,
hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d)
: 0.0,
(hasNextRow && hasNextCol) ?
getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0);
vec4 bottomLeft = vec4(
getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d),
hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1)
: 0.0,
hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d)
: 0.0,
(hasNextRow && hasNextCol) ?
getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0);
vec4 topRight = vec4(
getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d),
hasNextCol ? getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1)
: 0.0,
hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d)
: 0.0,
(hasNextRow && hasNextCol) ?
getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0);
vec4 bottomRight = vec4(
getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d),
hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1)
: 0.0,
hasNextRow ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d)
: 0.0,
(hasNextRow && hasNextCol) ?
getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0);
vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC);
vec4 top = mix(topLeft, topRight, fracRC.yyzz);
vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz);
vec4 newValue = mix(top, bottom, fracRC.x);
setOutput(newValue);
}
`;
}
}
function resizeBilinear$1(args) {
const { inputs, backend, attrs } = args;
const { images } = inputs;
const { alignCorners, halfPixelCenters, size } = attrs;
const [newHeight, newWidth] = size;
const program = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ?
new ResizeBilinearPackedProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters) :
new ResizeBilinearProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters);
return backend.runWebGLProgram(program, [images], 'float32');
}
const resizeBilinearConfig$1 = {
kernelName: ResizeBilinear,
backendName: 'webgl',
kernelFunc: resizeBilinear$1
};
class ResizeBilinearBackpropProgram {
constructor(dyShape, inputShape, alignCorners) {
this.variableNames = ['dy'];
this.outputShape = [];
this.outputShape = inputShape;
const [, xHeight, xWidth,] = inputShape;
const [, yHeight, yWidth] = dyShape;
const effectiveXSize = [
(alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
(alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
];
const effectiveYSize = [
(alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
(alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
];
const heightScale = effectiveXSize[0] / effectiveYSize[0];
const widthScale = effectiveXSize[1] / effectiveYSize[1];
const invHeightScale = 1 / heightScale;
const invWidthScale = 1 / widthScale;
const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
this.userCode = `
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int d = coords[3];
int r = coords[1];
int c = coords[2];
float accumulator = 0.0;
const float heightScale = float(${heightScale});
const float widthScale = float(${widthScale});
const float invHeightScale = float(${invHeightScale});
const float invWidthScale = float(${invWidthScale});
const int winHeight = int(${winHeight});
const int winWidth = int(${winWidth});
float startRLerp = floor(float(r) * invHeightScale);
int startDyR = int(startRLerp - float(winHeight / 2));
float startCLerp = floor(float(c) * invWidthScale);
int startDyC = int(startCLerp - float(winWidth / 2));
for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {
int dyR = dyROffset + startDyR;
if (dyR < 0 || dyR >= ${yHeight}) {
continue;
}
for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {
int dyC = dyCOffset + startDyC;
if (dyC < 0 || dyC >= ${yWidth}) {
continue;
}
float dxR = float(dyR) * heightScale;
int topDxRIndex = int(floor(dxR));
int bottomDxRIndex = int(min(ceil(dxR), ${xHeight - 1}.0));
float dxRLerp = dxR - float(topDxRIndex);
float inverseDxRLerp = 1.0 - dxRLerp;
float dxC = float(dyC) * widthScale;
int leftDxCIndex = int(floor(dxC));
int rightDxCIndex = int(min(ceil(dxC), ${xWidth - 1}.0));
float dxCLerp = dxC - float(leftDxCIndex);
float inverseDxCLerp = 1.0 - dxCLerp;
if (r == topDxRIndex && c == leftDxCIndex) {
accumulator +=
getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;
}
if (r == topDxRIndex && c == rightDxCIndex) {
accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;
}
if (r == bottomDxRIndex && c == leftDxCIndex) {
accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;
}
if (r == bottomDxRIndex && c == rightDxCIndex) {
accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;
}
}
}
setOutput(accumulator);
}
`;
}
}
function resizeBilinearGrad$1(args) {
const { inputs, backend, attrs } = args;
const { images, dy } = inputs;
const { alignCorners } = attrs;
const program = new ResizeBilinearBackpropProgram(dy.shape, images.shape, alignCorners);
return backend.runWebGLProgram(program, [dy], dy.dtype);
}
const resizeBilinearGradConfig$2 = {
kernelName: ResizeBilinearGrad,
backendName: 'webgl',
kernelFunc: resizeBilinearGrad$1
};
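// Nearest-neighbor resize: each output coordinate is mapped to the closest source
// pixel, with optional align-corners scaling and half-pixel centers.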
class ResizeNearestNeighborProgram {
constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
this.variableNames = ['A'];
this.outputShape = [];
const [batch, oldHeight, oldWidth, depth] = inputShape;
this.outputShape = [batch, newHeight, newWidth, depth];
const effectiveInSize = [
(alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
(alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
];
const effectiveOutSize = [
(alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
(alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
];
const roundBase = alignCorners ? '0.5' : '0.0';
let sourceFracIndexRC;
if (halfPixelCenters) {
sourceFracIndexRC =
`max((vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC` +
`, vec2(0.0))`;
}
else {
sourceFracIndexRC = `vec2(yRC) * effectiveInputOverOutputRatioRC`;
}
this.userCode = `
const vec2 effectiveInputOverOutputRatioRC = vec2(
${effectiveInSize[0] / effectiveOutSize[0]},
${effectiveInSize[1] / effectiveOutSize[1]});
const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0);
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int d = coords[3];
ivec2 yRC = coords.yz;
vec2 sourceFracIndexRC = ${sourceFracIndexRC};
ivec2 sourceNearestRC = ivec2(
min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ${roundBase})));
float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d);
setOutput(newValue);
}
`;
}
}
class ResizeNearestNeighborPackedProgram {
constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
this.variableNames = ['A'];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = [];
const [batch, oldHeight, oldWidth, depth] = inputShape;
this.outputShape = [batch, newHeight, newWidth, depth];
const effectiveInSize = [
(alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
(alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
];
const effectiveOutSize = [
(alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
(alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
];
const roundBase = alignCorners ? '0.5' : '0.0';
let sourceFracIndexRC;
if (halfPixelCenters) {
sourceFracIndexRC = `max((vec3(yRC) + vec3(0.5)) * ` +
`effectiveInputOverOutputRatioRC, vec3(0.0))`;
}
else {
sourceFracIndexRC = `vec3(yRC) * effectiveInputOverOutputRatioRC`;
}
this.userCode = `
const vec3 effectiveInputOverOutputRatioRC = vec3(
${effectiveInSize[0] / effectiveOutSize[0]},
${effectiveInSize[1] / effectiveOutSize[1]},
${effectiveInSize[1] / effectiveOutSize[1]});
const vec3 inputShapeRC = vec3(${oldHeight}.0, ${oldWidth}.0,
${oldWidth}.0);
float getAValue(int b, int r, int c, int d) {
return getChannel(getA(b, r, c, d), vec2(c, d));
}
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int d = coords[3];
ivec3 yRC = coords.yzz + ivec3(0, 0, 1);
vec3 sourceFracIndexRC = ${sourceFracIndexRC};
ivec3 sourceNearestRC = ivec3(
min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ${roundBase})));
bool hasNextCol = d < ${depth - 1};
bool hasNextRow = coords.z < ${newWidth - 1};
vec4 newValue = vec4(
getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d),
hasNextCol ? getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d + 1)
: 0.0,
hasNextRow ? getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d)
: 0.0,
(hasNextRow && hasNextCol) ?
getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d + 1) : 0.0);
setOutput(newValue);
}
`;
}
}
function resizeNearestNeighbor$1(args) {
const { inputs, backend, attrs } = args;
const { images } = inputs;
const { alignCorners, halfPixelCenters, size } = attrs;
const [newHeight, newWidth] = size;
const program = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ?
new ResizeNearestNeighborPackedProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters) :
new ResizeNearestNeighborProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters);
return backend.runWebGLProgram(program, [images], images.dtype);
}
const resizeNearestNeighborConfig$1 = {
kernelName: ResizeNearestNeighbor,
backendName: 'webgl',
kernelFunc: resizeNearestNeighbor$1
};
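// Gradient of nearest-neighbor resize: every dy pixel is routed back to the single
// source pixel it was sampled from, and those contributions are summed.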
class ResizeNearestNeigborBackpropProgram {
constructor(dyShape, inputShape, alignCorners) {
this.variableNames = ['dy'];
this.outputShape = [];
this.outputShape = inputShape;
const [, xHeight, xWidth,] = inputShape;
const [, yHeight, yWidth] = dyShape;
const effectiveXSize = [
(alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
(alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
];
const effectiveYSize = [
(alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
(alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
];
const heightScale = effectiveXSize[0] / effectiveYSize[0];
const widthScale = effectiveXSize[1] / effectiveYSize[1];
const invHeightScale = 1 / heightScale;
const invWidthScale = 1 / widthScale;
const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
this.userCode = `
void main() {
ivec4 coords = getOutputCoords();
int b = coords[0];
int d = coords[3];
int r = coords[1];
int c = coords[2];
float accumulator = 0.0;
const float heightScale = float(${heightScale});
const float widthScale = float(${widthScale});
const float invHeightScale = float(${invHeightScale});
const float invWidthScale = float(${invWidthScale});
const int winHeight = int(${winHeight});
const int winWidth = int(${winWidth});
float startRLerp = floor(float(r) * invHeightScale);
int startDyR = int(floor(startRLerp - float(winHeight / 2)));
float startCLerp = floor(float(c) * invWidthScale);
int startDyC = int(floor(startCLerp - float(winWidth / 2)));
for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {
int dyR = dyROffset + startDyR;
if (dyR < 0 || dyR >= ${yHeight}) {
continue;
}
for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {
int dyC = dyCOffset + startDyC;
if (dyC < 0 || dyC >= ${yWidth}) {
continue;
}
float sourceFracRow =
float(${effectiveXSize[0]}) *
(float(dyR) / float(${effectiveYSize[0]}));
float sourceFracCol =
float(${effectiveXSize[1]}) *
(float(dyC) / float(${effectiveYSize[1]}));
int sourceNearestRow = int(min(
float(int(${xHeight}) - 1),
${alignCorners} ? float(round(sourceFracRow)) :
float(floor(sourceFracRow))));
int sourceNearestCol = int(min(
float(int(${xWidth}) - 1),
${alignCorners} ? float(round(sourceFracCol)) :
float(floor(sourceFracCol))));
if (r == sourceNearestRow && c == sourceNearestCol) {
accumulator += getDy(b, dyR, dyC, d);
}
}
}
setOutput(accumulator);
}
`;
}
}
function resizeNearestNeighborGrad$1(args) {
const { inputs, backend, attrs } = args;
const { images, dy } = inputs;
const { alignCorners } = attrs;
const program = new ResizeNearestNeigborBackpropProgram(dy.shape, images.shape, alignCorners);
return backend.runWebGLProgram(program, [dy], dy.dtype);
}
const resizeNearestNeighborGradConfig$2 = {
kernelName: ResizeNearestNeighborGrad,
backendName: 'webgl',
kernelFunc: resizeNearestNeighborGrad$1
};
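// Reverse kernels (rank <= 4): coordinates along the reversed axes are mirrored
// (dim - coord - 1) when reading from the input.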
class ReverseProgram {
constructor(xShape, axis) {
this.variableNames = ['x'];
const rank = xShape.length;
if (rank > 4) {
throw new Error(`WebGL backend: Reverse of rank-${rank} tensor is not yet supported`);
}
this.outputShape = xShape;
if (rank === 1) {
this.userCode = `
void main() {
int coord = getOutputCoords();
setOutput(getX(${xShape[0]} - coord - 1));
}
`;
return;
}
const getInCoord = (i) => {
if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {
return `${xShape[i]} - coords[${i}] - 1`;
}
return `coords[${i}]`;
};
const inCoords = xShape.map((_, i) => getInCoord(i)).join(',');
const type = getCoordsDataType(rank);
this.userCode = `
void main() {
${type} coords = getOutputCoords();
setOutput(getX(${inCoords}));
}
`;
}
}
class ReversePackedProgram {
constructor(xShape, axis) {
this.variableNames = ['x'];
this.packedInputs = true;
this.packedOutput = true;
const rank = xShape.length;
if (rank > 4) {
throw new Error(`WebGL backend: Reverse of rank-${rank} tensor is not yet supported`);
}
this.outputShape = xShape;
const channels = getChannels('rc', rank);
const nextColumn = `${channels[rank - 1]} + 1 < ${this.outputShape[rank - 1]}`;
const nextRow = `${channels[rank - 2]} + 1 < ${this.outputShape[rank - 2]}`;
const type = getCoordsDataType(rank);
if (rank === 1) {
this.userCode = `
void main(){
int rc = getOutputCoords();
vec4 result = vec4(0.);
result.r = getChannel(getX(${xShape[0]} - rc - 1),
${xShape[0]} - rc - 1);
if(${nextColumn}){
result.g = getChannel(getX(${xShape[0]} - (rc + 1) - 1),
${xShape[0]} - (rc + 1) - 1);
}
setOutput(result);
}
`;
}
else {
this.userCode = `
void main() {
${type} rc = getOutputCoords();
vec4 result = vec4(0.);
result.r = ${getR(channels.slice())};
if(${nextColumn}){
result.g = ${getG(channels.slice())};
}
if(${nextRow}) {
result.b = ${getB(channels.slice())};
if(${nextColumn}) {
result.a = ${getA(channels.slice())};
}
}
setOutput(result);
}
`;
}
function getR(channels) {
return getChannel(channels);
}
function getG(channels) {
channels[rank - 1] = '(' + channels[rank - 1] + ` + 1)`;
return getChannel(channels);
}
function getB(channels) {
channels[rank - 2] = '(' + channels[rank - 2] + ` + 1)`;
return getChannel(channels);
}
function getA(channels) {
channels[rank - 1] = '(' + channels[rank - 1] + ` + 1)`;
channels[rank - 2] = '(' + channels[rank - 2] + ` + 1)`;
return getChannel(channels);
}
function getChannel(channels) {
const inCoordsArray = xShape.map((_, i) => getInCoord(i, channels));
const inCoords = inCoordsArray.join(',');
const innerDims = inCoordsArray.slice(-2).join(',');
return `getChannel(getX(${inCoords}), vec2(${innerDims}))`;
}
function getInCoord(i, channels1) {
if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {
return `${xShape[i]} - ${channels1[i]} - 1`;
}
else {
return `${channels1[i]}`;
}
}
}
}
function reverse$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { dims } = attrs;
const xRank = x.shape.length;
const $dims = parseAxisParam(dims, x.shape);
if (xRank === 0) {
return identity({ inputs: { x }, backend });
}
const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
new ReversePackedProgram(x.shape, $dims) :
new ReverseProgram(x.shape, $dims);
return backend.runWebGLProgram(program, [x], x.dtype);
}
const reverseConfig$1 = {
kernelName: Reverse,
backendName: 'webgl',
kernelFunc: reverse$1
};
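// Rotation kernel: the `params` uniform carries [centerX, centerY, sin(radians),
// cos(radians)]; samples falling outside the image use the fill value, which may
// be a scalar or one value per channel.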
class RotateProgram {
constructor(imageShape, fillValue) {
this.variableNames = ['Image'];
this.outputShape = [];
this.customUniforms = [{ name: 'params', type: 'vec4' }];
const imageHeight = imageShape[1];
const imageWidth = imageShape[2];
this.outputShape = imageShape;
let fillSnippet = '';
if (typeof fillValue === 'number') {
fillSnippet = `float outputValue = ${fillValue.toFixed(2)};`;
}
else {
fillSnippet = `
vec3 fill = vec3(${fillValue.join(',')});
float outputValue = fill[coords[3]];`;
}
this.userCode = `
void main() {
ivec4 coords = getOutputCoords();
int x = coords[2];
int y = coords[1];
float coordXFloat = (float(x) - params[0]) * params[3] -
(float(y) - params[1]) * params[2];
float coordYFloat = (float(x) - params[0]) * params[2] +
(float(y) - params[1]) * params[3];
int coordX = int(round(coordXFloat + params[0]));
int coordY = int(round(coordYFloat + params[1]));
${fillSnippet}
if(coordX >= 0 && coordX < ${imageWidth} && coordY >= 0 && coordY < ${imageHeight}) {
outputValue = getImage(coords[0], coordY, coordX, coords[3]);
}
setOutput(outputValue);
}
`;
}
}
const rotateWithOffsetConfig$1 = {
kernelName: RotateWithOffset,
backendName: 'webgl',
kernelFunc: ({ inputs, attrs, backend }) => {
const { image } = inputs;
const { radians, fillValue, center } = attrs;
const webglBackend = backend;
const program = new RotateProgram(image.shape, fillValue);
const [centerX, centerY] = getImageCenter(center, image.shape[1], image.shape[2]);
const customValues = [[centerX, centerY, Math.sin(radians), Math.cos(radians)]];
const output = webglBackend.runWebGLProgram(program, [image], image.dtype, customValues);
return output;
}
};
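// Round half to even (banker's rounding), matching TF semantics:
// 0.5 -> 0.0, 1.5 -> 2.0, 2.5 -> 2.0.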
const ROUND = `
float base = floor(x);
if ((x - base) < 0.5) {
return floor(x);
} else if ((x - base) > 0.5) {
return ceil(x);
} else {
if (mod(base, 2.0) == 0.0) {
return base;
} else {
return base + 1.0;
}
}
`;
const round$1 = unaryKernelFunc({ opSnippet: ROUND });
const roundConfig$1 = {
kernelName: Round,
backendName: 'webgl',
kernelFunc: round$1,
};
const RSQRT = `return inversesqrt(x);`;
const rsqrt = unaryKernelFunc({ opSnippet: RSQRT, cpuKernelImpl: rsqrtImplCPU });
const rsqrtConfig = {
kernelName: Rsqrt,
backendName: 'webgl',
kernelFunc: rsqrt
};
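// Scatter kernel: every output row scans all updates, sums those whose flattened
// index matches, and falls back to defaultValue when nothing matched.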
class ScatterProgram {
constructor(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex = true, defaultIsTensor = false) {
this.variableNames = ['updates', 'indices', 'defaultValue'];
this.outputShape = shape;
const stridesType = getCoordsDataType(strides.length);
const dtype = getCoordsDataType(shape.length);
let indicesString = '';
if (indicesRank === 1) {
indicesString = 'i';
}
else if (indicesRank === 2) {
indicesString = 'i, j';
}
const indicesSnippet = `getIndices(${indicesString})`;
let updatesString = '';
if (updatesRank === 1) {
updatesString = 'i';
}
else if (updatesRank === 2) {
updatesString = 'i, coords[1]';
}
const updatesSnippet = `getUpdates(${updatesString})`;
let defaultValuesString = '';
if (defaultIsTensor) {
defaultValuesString = 'coords[0], coords[1]';
}
const defaultValueSnippet = `getDefaultValue(${defaultValuesString})`;
const strideString = sliceDim > 1 ? 'strides[j]' : 'strides';
this.userCode = `
${stridesType} strides = ${stridesType}(${strides});
void main() {
${dtype} coords = getOutputCoords();
float sum = 0.0;
bool found = false;
for (int i = 0; i < ${updateSize}; i++) {
int flattenedIndex = 0;
for (int j = 0; j < ${sliceDim}; j++) {
int index = round(${indicesSnippet});
flattenedIndex += index * ${strideString};
}
if (flattenedIndex == coords[0]) {
sum += ${updatesSnippet};
found = true;
}
}
setOutput(mix(${defaultValueSnippet}, sum, float(found)));
}
`;
}
}
class ScatterPackedProgram {
constructor(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex = true, defaultIsTensor = false) {
this.variableNames = ['updates', 'indices', 'defaultValue'];
this.packedInputs = true;
this.packedOutput = true;
this.outputShape = shape;
const stridesType = getCoordsDataType(strides.length);
const dtype = getCoordsDataType(shape.length);
let indicesString = '';
if (indicesRank === 1) {
indicesString = 'i';
}
else if (indicesRank === 2) {
indicesString = 'i, j';
}
const indicesSnippet = `getIndices(${indicesString})`;
let updatesString = '';
if (updatesRank === 1) {
updatesString = 'i';
}
else if (updatesRank === 2) {
updatesString = 'i, coords[1]';
}
const updatesSnippet = `getUpdates(${updatesString})`;
let defaultValuesString = '';
if (defaultIsTensor) {
defaultValuesString = 'coords[0], coords[1]';
}
const defaultValueSnippet = `getDefaultValue(${defaultValuesString})`;
const strideString = sliceDim > 1 ? 'strides[j]' : 'strides';
const strideString2 = sliceDim > 1 ? 'strides[j + 1]' : 'strides';
this.userCode = `
${stridesType} strides = ${stridesType}(${strides});
void main() {
${dtype} coords = getOutputCoords();
vec4 sum = vec4(0.);
vec4 found = vec4(0.);
for (int i = 0; i < ${updateSize}; i+=2) {
ivec2 flattenedIndex = ivec2(0);
for (int j = 0; j < ${sliceDim}; j+=2) {
ivec4 index = round(${indicesSnippet});
flattenedIndex += index.xz * ${strideString};
if (j + 1 < ${sliceDim}) {
flattenedIndex += index.yw * ${strideString2};
}
}
if (flattenedIndex[0] == coords[0] || flattenedIndex[1] == coords[0] ||
flattenedIndex[0] == coords[0] + 1 || flattenedIndex[1] == coords[0] + 1) {
vec4 updVals = ${updatesSnippet};
if (flattenedIndex[0] == coords[0]) {
sum.xy += updVals.xy;
found.xy = vec2(1.);
} else if (flattenedIndex[0] == coords[0] + 1) {
sum.zw += updVals.xy;
found.zw = vec2(1.);
}
if (flattenedIndex[1] == coords[0]) {
sum.xy += updVals.zw;
found.xy = vec2(1.);
} else if (flattenedIndex[1] == coords[0] + 1) {
sum.zw += updVals.zw;
found.zw = vec2(1.);
}
}
}
setOutput(mix(${defaultValueSnippet}, sum, found));
}
`;
}
}
function scatterNd$1(args) {
const { inputs, backend, attrs } = args;
const { indices, updates } = inputs;
const { shape } = attrs;
const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, shape);
const flattenShape = [outputSize / sliceSize, sliceSize];
if (outputSize === 0) {
return backend.makeTensorInfo(shape, indices.dtype);
}
const flattenIndices = reshape$1({ inputs: { x: indices }, backend, attrs: { shape: [numUpdates, sliceRank] } });
const flattenX = reshape$1({ inputs: { x: updates }, backend, attrs: { shape: [numUpdates, sliceSize] } });
const defaultValue = backend.makeTensorInfo([], 'float32', new Float32Array([0]));
let program;
if (env().getBool('WEBGL_PACK')) {
program = new ScatterPackedProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape);
}
else {
program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape);
}
const res = backend.runWebGLProgram(program, [flattenX, flattenIndices, defaultValue], flattenX.dtype);
const reshaped = reshape$1({ inputs: { x: res }, backend, attrs: { shape } });
backend.disposeIntermediateTensorInfo(flattenIndices);
backend.disposeIntermediateTensorInfo(flattenX);
backend.disposeIntermediateTensorInfo(res);
backend.disposeIntermediateTensorInfo(defaultValue);
return reshaped;
}
const scatterNdConfig$1 = {
kernelName: ScatterNd,
backendName: 'webgl',
kernelFunc: scatterNd$1
};
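// SearchSorted: a binary search per value returns the left- or right-most insertion
// point depending on `side`. WebGL1 lacks dynamic while loops, so the search is
// bounded by ceil(log2(numInputs + 1)) iterations instead.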
class SearchSortedProgram {
constructor(batchSize, numInputs, numValues, side) {
this.variableNames = ['sortedSequence', 'values'];
this.customUniforms = [{ name: 'numInputs', type: 'int' }];
this.outputShape = [batchSize, numValues];
const webGL2LoopHead = 'while (left < right) {';
const webGL1LoopHead = `for (int i = 0; i < ${Math.ceil(Math.log2(numInputs + 1))}; ++i) { if (left >= right) break;`;
const loopHead = env().getNumber('WEBGL_VERSION') === 2 ? webGL2LoopHead :
webGL1LoopHead;
const boundComparator = side === 'left' ? '<' : '<=';
this.userCode = `
int findBound(int batch, float value) {
int left = 0;
int right = numInputs;
int mid;
${loopHead}
mid = (left + right) / 2;
if (getSortedSequence(batch, mid) ${boundComparator} value) {
left = mid + 1;
} else {
right = mid;
}
}
return right;
}
void main() {
ivec2 coords = getOutputCoords();
int batch = coords[0];
int valueIndex = coords[1];
float value = getValues(batch, valueIndex);
setOutput(float(findBound(batch, value)));
}
`;
}
}
function searchSorted$1(args) {
const { inputs, backend, attrs } = args;
const { sortedSequence, values } = inputs;
const { side } = attrs;
const program = new SearchSortedProgram(sortedSequence.shape[0], sortedSequence.shape[1], values.shape[1], side);
const customValues = [[sortedSequence.shape[1]]];
return backend.runWebGLProgram(program, [sortedSequence, values], 'int32', customValues);
}
const searchSortedConfig$1 = {
kernelName: SearchSorted,
backendName: 'webgl',
kernelFunc: searchSorted$1,
};
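// Select (where) kernel: the condition tensor may have lower rank than a/b, in which
// case only the leading output coordinates are used to index it.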
class SelectProgram {
constructor(cRank, shape, rank) {
this.variableNames = ['c', 'a', 'b'];
this.outputShape = shape;
let cCoords;
let abCoords;
if (rank > 4) {
throw new Error(`Where for rank ${rank} is not yet supported`);
}
if (rank === 1) {
abCoords = `resRC`;
cCoords = `resRC`;
}
else {
const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
const cCoordVars = [];
const abCoordVars = [];
for (let i = 0; i < shape.length; i++) {
abCoordVars.push(`${currentCoords[i]}`);
if (i < cRank) {
cCoordVars.push(`${currentCoords[i]}`);
}
}
cCoords = cCoordVars.join();
abCoords = abCoordVars.join();
}
const dtype = getCoordsDataType(rank);
this.userCode = `
void main() {
${dtype} resRC = getOutputCoords();
float cVal = getC(${cCoords});
if (cVal >= 1.0) {
setOutput(getA(${abCoords}));
} else {
setOutput(getB(${abCoords}));
}
}
`;
}
}
function select$1(args) {
const { inputs, backend } = args;
const { condition, t, e } = inputs;
const program = new SelectProgram(condition.shape.length, t.shape, t.shape.length);
return backend.runWebGLProgram(program, [condition, t, e], upcastType(t.dtype, e.dtype));
}
const selectConfig$1 = {
kernelName: Select,
backendName: 'webgl',
kernelFunc: select$1
};
const SELU = `
float scaleAlpha = ${SELU_SCALEALPHA};
float scale = ${SELU_SCALE};
return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0);
`;
const selu$1 = unaryKernelFunc({ opSnippet: SELU });
const seluConfig$1 = {
kernelName: Selu$1,
backendName: 'webgl',
kernelFunc: selu$1,
};
const SIGMOID = CHECK_NAN_SNIPPET_UNARY + `
return 1.0 / (1.0 + exp(-1.0 * x));
`;
const SIGMOID_PACKED = `
vec4 result = 1.0 / (1.0 + exp(-1.0 * x));
bvec4 isNaN = isnan(x);
result.r = isNaN.r ? x.r : result.r;
result.g = isNaN.g ? x.g : result.g;
result.b = isNaN.b ? x.b : result.b;
result.a = isNaN.a ? x.a : result.a;
return result;
`;
const sigmoid = unaryKernelFunc({
opSnippet: SIGMOID,
packedOpSnippet: SIGMOID_PACKED,
cpuKernelImpl: sigmoidImplCPU
});
const sigmoidConfig = {
kernelName: Sigmoid$1,
backendName: 'webgl',
kernelFunc: sigmoid,
};
const SIGN = `
if (isnan(x)) { return 0.0; }
return sign(x);
`;
const sign$1 = unaryKernelFunc({ opSnippet: SIGN });
const signConfig$1 = {
kernelName: Sign,
backendName: 'webgl',
kernelFunc: sign$1,
};
const SIN = CHECK_NAN_SNIPPET_UNARY + `
return sin(x);
`;
const SIN_PACKED = `
vec4 result = sin(x);
bvec4 isNaN = isnan(x);
${CHECK_NAN_SNIPPET_PACKED}
return result;
`;
const sin$1 = unaryKernelFunc({ opSnippet: SIN, packedOpSnippet: SIN_PACKED });
const sinConfig$1 = {
kernelName: Sin,
backendName: 'webgl',
kernelFunc: sin$1,
};
const SINH = `
float e2x = exp(x);
return (e2x - 1.0 / e2x) / 2.0;
`;
const sinh$1 = unaryKernelFunc({ opSnippet: SINH });
const sinhConfig$1 = {
kernelName: Sinh,
backendName: 'webgl',
kernelFunc: sinh$1,
};
const SOFTPLUS = `
float epsilon = 1.1920928955078125e-7;
float threshold = log(epsilon) + 2.0;
bool too_large = x > -threshold;
bool too_small = x < threshold;
float result;
float exp_x = exp(x);
if (too_large){
result = x;
}
else if (too_small){
result = exp_x;
}
else{
result = log(exp_x + 1.0);
}
return result;
`;
const softplus$1 = unaryKernelFunc({ opSnippet: SOFTPLUS });
const softplusConfig$1 = {
kernelName: Softplus$1,
backendName: 'webgl',
kernelFunc: softplus$1,
};
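// SpaceToBatchND is expressed as pad -> reshape -> transpose -> reshape instead of a
// dedicated shader.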
const spaceToBatchND$1 = (args) => {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { blockShape, paddings } = attrs;
assert$1(x.shape.length <= 4, () => 'spaceToBatchND for rank > 4 with a WebGL backend not ' +
'implemented yet');
const prod = blockShape.reduce((a, b) => a * b);
const completePaddings = [[0, 0]];
completePaddings.push(...paddings);
for (let i = 1 + blockShape.length; i < x.shape.length; ++i) {
completePaddings.push([0, 0]);
}
const toDispose = [];
const paddedX = padV2$1({
inputs: { x },
backend,
attrs: { paddings: completePaddings, constantValue: 0 }
});
const reshapedPaddedShape = getReshaped(paddedX.shape, blockShape, prod, false);
const permutedReshapedPaddedPermutation = getPermuted(reshapedPaddedShape.length, blockShape.length, false);
const flattenShape = getReshapedPermuted(paddedX.shape, blockShape, prod, false);
const reshapedPaddedX = reshape$1({ inputs: { x: paddedX }, backend, attrs: { shape: reshapedPaddedShape } });
const paddedXT = transpose({
inputs: { x: reshapedPaddedX },
backend,
attrs: { perm: permutedReshapedPaddedPermutation }
});
const result = reshape$1({ inputs: { x: paddedXT }, backend, attrs: { shape: flattenShape } });
toDispose.push(paddedX);
toDispose.push(reshapedPaddedX);
toDispose.push(paddedXT);
toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
return result;
};
const spaceToBatchNDConfig$1 = {
kernelName: SpaceToBatchND,
backendName: 'webgl',
kernelFunc: spaceToBatchND$1
};
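// The sparse kernels that follow (fill-empty-rows, reshape, segment mean/sum) have no
// WebGL shader: inputs are downloaded with readSync and the shared CPU implementations
// compute the results.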
function sparseFillEmptyRows$1(args) {
const { inputs, backend } = args;
const { indices, values, denseShape, defaultValue } = inputs;
if (denseShape.shape.length !== 1) {
throw new Error(`Dense shape must be a vector, saw:
${denseShape.shape}`);
}
if (indices.shape.length !== 2) {
throw new Error(`Indices must be a matrix, saw:
${indices.shape}`);
}
if (values.shape.length !== 1) {
throw new Error(`Values must be a vector, saw:
${values.shape}`);
}
if (defaultValue.shape.length !== 0) {
throw new Error(`Default value must be a scalar, saw:
${defaultValue.shape}`);
}
const $indices = backend.readSync(indices.dataId);
const $values = backend.readSync(values.dataId);
const $denseShape = backend.readSync(denseShape.dataId);
const $defaultValue = backend.readSync(defaultValue.dataId)[0];
const [outputIndices, outputIndicesShape, outputValues, emptyRowIndicator, reverseIndexMap] = sparseFillEmptyRowsImplCPU($indices, indices.shape, indices.dtype, $values, values.dtype, $denseShape, $defaultValue);
return [
backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices),
backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues),
backend.makeTensorInfo([emptyRowIndicator.length], 'bool', new Uint8Array(emptyRowIndicator.map((value) => Number(value)))),
backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap)),
];
}
const sparseFillEmptyRowsConfig$1 = {
kernelName: SparseFillEmptyRows,
backendName: 'webgl',
kernelFunc: sparseFillEmptyRows$1,
};
function sparseReshape$1(args) {
const { inputs, backend } = args;
const { inputIndices, inputShape, newShape } = inputs;
if (inputIndices.shape.length !== 2) {
throw new Error(`Input indices should be a matrix but received shape ${inputIndices.shape}`);
}
if (inputShape.shape.length !== 1) {
throw new Error(`Input shape should be a vector but received shape ${inputShape.shape}`);
}
if (newShape.shape.length !== 1) {
throw new Error(`Target shape should be a vector but received shape ${newShape.shape}`);
}
const $inputShape = Array.from(backend.readSync(inputShape.dataId));
const $inputIndices = backend.readSync(inputIndices.dataId);
const targetShape = Array.from(backend.readSync(newShape.dataId));
const [newIndices, indicesShape, outputShape] = sparseReshapeImplCPU($inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape, targetShape);
return [
backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices),
backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape)),
];
}
const sparseReshapeConfig$1 = {
kernelName: SparseReshape,
backendName: 'webgl',
kernelFunc: sparseReshape$1,
};
function sparseSegmentMean$1(args) {
const { inputs, backend } = args;
const { data, indices, segmentIds } = inputs;
if (data.shape.length < 1) {
throw new Error(`Data should be at least 1 dimensional but received scalar`);
}
if (indices.shape.length !== 1) {
throw new Error(`Indices should be a vector but received shape
${indices.shape}`);
}
if (segmentIds.shape.length !== 1) {
throw new Error(`Segment ids should be a vector but received shape
${segmentIds.shape}`);
}
const $data = backend.readSync(data.dataId);
const $indices = backend.readSync(indices.dataId);
const $segmentIds = backend.readSync(segmentIds.dataId);
const [outputData, outputDataShape] = sparseSegmentReductionImplCPU($data, data.shape, data.dtype, $indices, $segmentIds, true);
return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
}
const sparseSegmentMeanConfig$1 = {
kernelName: SparseSegmentMean,
backendName: 'webgl',
kernelFunc: sparseSegmentMean$1,
};
function sparseSegmentSum$1(args) {
const { inputs, backend } = args;
const { data, indices, segmentIds } = inputs;
if (data.shape.length < 1) {
throw new Error(`Data should be at least 1 dimensional but received scalar`);
}
if (indices.shape.length !== 1) {
throw new Error(`Indices should be a vector but received shape
${indices.shape}`);
}
if (segmentIds.shape.length !== 1) {
throw new Error(`Segment ids should be a vector but received shape
${segmentIds.shape}`);
}
const $data = backend.readSync(data.dataId);
const $indices = backend.readSync(indices.dataId);
const $segmentIds = backend.readSync(segmentIds.dataId);
const [outputData, outputDataShape] = sparseSegmentReductionImplCPU($data, data.shape, data.dtype, $indices, $segmentIds);
return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
}
const sparseSegmentSumConfig$1 = {
kernelName: SparseSegmentSum,
backendName: 'webgl',
kernelFunc: sparseSegmentSum$1,
};
function sparseToDense$1(args) {
const { inputs, backend, attrs } = args;
const { sparseIndices, sparseValues, defaultValue } = inputs;
const { outputShape } = attrs;
const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(sparseValues, sparseIndices, outputShape);
const sumDupeIndices = false;
if (sparseValues.dtype === 'string') {
const indicesBuf = backend.bufferSync(sparseIndices);
const updatesBuf = backend.bufferSync(sparseValues);
const $defaultValue = decodeString(backend.readSync(defaultValue.dataId)[0]);
const outBuf = scatterImplCPU(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values);
}
const program = new ScatterProgram(numUpdates, sliceRank, sparseIndices.shape.length, sparseValues.shape.length, strides, [outputSize, 1], sumDupeIndices);
const res = backend.runWebGLProgram(program, [sparseValues, sparseIndices, defaultValue], sparseValues.dtype);
const reshaped = reshape$1({ inputs: { x: res }, backend, attrs: { shape: outputShape } });
backend.disposeIntermediateTensorInfo(res);
return reshaped;
}
const sparseToDenseConfig$1 = {
kernelName: SparseToDense,
backendName: 'webgl',
kernelFunc: sparseToDense$1
};
function splitV$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { numOrSizeSplits, axis } = attrs;
const $axis = parseAxisParam(axis, x.shape)[0];
const splitSizes = prepareSplitSize(x, numOrSizeSplits, $axis);
const xRank = x.shape.length;
const begin = new Array(xRank).fill(0);
const size = x.shape.slice();
return splitSizes.map(s => {
const sliceSize = [...size];
sliceSize[$axis] = s;
const sliceT = slice({ inputs: { x }, backend, attrs: { begin, size: sliceSize } });
begin[$axis] += s;
return sliceT;
});
}
const splitVConfig$1 = {
kernelName: SplitV,
backendName: 'webgl',
kernelFunc: splitV$1
};
const SQRT = `return sqrt(x);`;
const sqrt = unaryKernelFunc({ opSnippet: SQRT, packedOpSnippet: SQRT, cpuKernelImpl: sqrtImplCPU });
const sqrtConfig = {
kernelName: Sqrt,
backendName: 'webgl',
kernelFunc: sqrt
};
const SQUARE = `return x * x;`;
const square$1 = unaryKernelFunc({ opSnippet: SQUARE });
const squareConfig$1 = {
kernelName: Square,
backendName: 'webgl',
kernelFunc: square$1,
};
const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);';
const squaredDifference = binaryKernelFunc({ opSnippet: SQUARED_DIFFERENCE, packedOpSnippet: SQUARED_DIFFERENCE });
const squaredDifferenceConfig = {
kernelName: SquaredDifference,
backendName: 'webgl',
kernelFunc: squaredDifference,
};
function staticRegexReplace(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
if (x.dtype !== 'string') {
throw new Error('Input must be of datatype string');
}
const $x = backend.readSync(x.dataId);
const stringInput = fromUint8ToStringArray($x);
const output = staticRegexReplaceImplCPU(stringInput, 'string', attrs);
return backend.makeTensorInfo(x.shape, 'string', output);
}
const staticRegexReplaceConfig = {
kernelName: StaticRegexReplace,
backendName: 'webgl',
kernelFunc: staticRegexReplace,
};
function step$1({ inputs, attrs, backend }) {
const { x } = inputs;
const opSnippet = CHECK_NAN_SNIPPET$1 + `
return x > 0.0 ? 1.0 : float(${attrs.alpha});
`;
const program = new UnaryOpProgram(x.shape, opSnippet);
return backend.runWebGLProgram(program, [x], x.dtype);
}
const stepConfig$1 = {
kernelName: Step,
backendName: 'webgl',
kernelFunc: step$1,
};
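// Strided slice shader: output coordinate i reads the input at
// coords[i] * strides[i] + begin[i].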
class StridedSliceProgram {
constructor(begin, strides, size) {
this.variableNames = ['x'];
this.outputShape = size;
const rank = size.length;
const inputDtype = getCoordsDataType(size.length);
const dtype = getCoordsDataType(size.length);
let newCoords = '';
if (rank === 1) {
newCoords = 'coords * strides + begin';
}
else {
let outputAxis = 0;
newCoords =
size.map((_, i) => {
outputAxis++;
return size.length === 1 ?
`coords * strides[${i}] + begin[${i}]` :
`coords[${outputAxis - 1}] * strides[${i}] + begin[${i}]`;
})
.join(',');
}
this.userCode = `
${inputDtype} begin = ${inputDtype}(${begin});
${inputDtype} strides = ${inputDtype}(${strides});
void main() {
${dtype} coords = getOutputCoords();
setOutput(getX(${newCoords}));
}
`;
}
}
function stridedSlice$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask } = attrs;
const { finalShapeSparse, finalShape, isIdentity, sliceDim0, isSimpleSlice, begin: $begin, end: $end, strides: $strides } = sliceInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
let result;
if (isIdentity) {
result = reshape$1({ inputs: { x }, backend, attrs: { shape: finalShape } });
}
else if (sliceDim0 || isSimpleSlice) {
assert$1(x.shape.length >= 1, () => `Input must have rank at least 1, got: ${x.shape.length}`);
const size = computeOutShape$2($begin, $end, $strides);
const sliced = slice({ inputs: { x }, backend, attrs: { begin: $begin, size } });
result =
reshape$1({ inputs: { x: sliced }, backend, attrs: { shape: finalShape } });
backend.disposeIntermediateTensorInfo(sliced);
}
else {
const shouldExecuteOnCPU = backend.shouldExecuteOnCPU([x]);
if (shouldExecuteOnCPU) {
const values = backend.readSync(x.dataId);
const xBuf = buffer(x.shape, x.dtype, values);
const resultValues = stridedSliceImplCPU(finalShapeSparse, xBuf, $strides, $begin);
result = backend.makeTensorInfo(finalShape, x.dtype, resultValues.values);
}
else {
const program = new StridedSliceProgram($begin, $strides, finalShapeSparse);
result = backend.runWebGLProgram(program, [x], x.dtype);
}
}
const resultReshaped = reshape$1({ inputs: { x: result }, backend, attrs: { shape: finalShape } });
backend.disposeIntermediateTensorInfo(result);
return resultReshaped;
}
const stridedSliceConfig$1 = {
kernelName: StridedSlice,
backendName: 'webgl',
kernelFunc: stridedSlice$1
};
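// The string kernels below likewise run on the CPU: string tensors are read back with
// readSync and processed by the shared CPU implementations.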
function stringNGrams$1(args) {
const { inputs, backend, attrs } = args;
const { separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences } = attrs;
const { data, dataSplits } = inputs;
const $data = backend.readSync(data.dataId);
const $dataSplits = backend.readSync(dataSplits.dataId);
const [nGrams, nGramsSplits] = stringNGramsImplCPU($data, $dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences);
return [
backend.makeTensorInfo([nGrams.length], 'string', nGrams),
backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits),
];
}
const stringNGramsConfig$1 = {
kernelName: StringNGrams,
backendName: 'webgl',
kernelFunc: stringNGrams$1,
};
function stringSplit$1(args) {
const { inputs, backend, attrs } = args;
const { skipEmpty } = attrs;
const { input, delimiter } = inputs;
if (input.dtype !== 'string') {
throw new Error('Input must be of datatype string');
}
if (input.shape.length !== 1) {
throw new Error(`Input must be a vector, got shape: ${input.shape}`);
}
if (delimiter.shape.length !== 0) {
throw new Error(`Delimiter must be a scalar, got shape: ${delimiter.shape}`);
}
const $input = backend.readSync(input.dataId);
const $delimiter = backend.readSync(delimiter.dataId)[0];
const [indices, values, shape] = stringSplitImplCPU($input, $delimiter, skipEmpty);
const outputSize = values.length;
return [
backend.makeTensorInfo([outputSize, 2], 'int32', indices),
backend.makeTensorInfo([outputSize], 'string', values),
backend.makeTensorInfo([2], 'int32', new Int32Array(shape))
];
}
const stringSplitConfig$1 = {
kernelName: StringSplit,
backendName: 'webgl',
kernelFunc: stringSplit$1,
};
function stringToHashBucketFast$1(args) {
const { inputs, backend, attrs } = args;
const { numBuckets } = attrs;
const { input } = inputs;
if (input.dtype !== 'string') {
throw new Error('Input must be of datatype string');
}
if (numBuckets <= 0) {
throw new Error(`Number of buckets must be at least 1`);
}
const $input = backend.readSync(input.dataId);
const output = stringToHashBucketFastImplCPU($input, numBuckets);
return backend.makeTensorInfo(input.shape, 'int32', output);
}
const stringToHashBucketFastConfig$1 = {
kernelName: StringToHashBucketFast,
backendName: 'webgl',
kernelFunc: stringToHashBucketFast$1,
};
const TAN = `return tan(x);`;
const tan$1 = unaryKernelFunc({ opSnippet: TAN });
const tanConfig$1 = {
kernelName: Tan,
backendName: 'webgl',
kernelFunc: tan$1,
};
const TANH = `
float e2x = exp(-2.0 * abs(x));
return sign(x) * (1.0 - e2x) / (1.0 + e2x);
`;
const tanh$1 = unaryKernelFunc({ opSnippet: TANH });
const tanhConfig$1 = {
kernelName: Tanh$1,
backendName: 'webgl',
kernelFunc: tanh$1,
};
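// tensorScatterUpdate: tensor, indices and updates are flattened to 2D and
// ScatterProgram is reused, with the existing tensor supplying the default values.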
function tensorScatterUpdate$1(args) {
const { inputs, backend } = args;
const { tensor, indices, updates } = inputs;
const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, tensor.shape);
const flattenShape = [outputSize / sliceSize, sliceSize];
if (outputSize === 0) {
return backend.makeTensorInfo(tensor.shape, indices.dtype);
}
const flattenIndices = reshape$1({ inputs: { x: indices }, backend, attrs: { shape: [numUpdates, sliceRank] } });
const flattenX = reshape$1({ inputs: { x: updates }, backend, attrs: { shape: [numUpdates, sliceSize] } });
const flattenTensor = reshape$1({ inputs: { x: tensor }, backend, attrs: { shape: flattenShape } });
const program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape, false, true);
const res = backend.runWebGLProgram(program, [flattenX, flattenIndices, flattenTensor], flattenTensor.dtype);
const reshaped = reshape$1({ inputs: { x: res }, backend, attrs: { shape: tensor.shape } });
backend.disposeIntermediateTensorInfo(flattenIndices);
backend.disposeIntermediateTensorInfo(flattenX);
backend.disposeIntermediateTensorInfo(flattenTensor);
backend.disposeIntermediateTensorInfo(res);
return reshaped;
}
const tensorScatterUpdateConfig$1 = {
kernelName: TensorScatterUpdate,
backendName: 'webgl',
kernelFunc: tensorScatterUpdate$1
};
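// Tile kernel (rank <= 5): each output coordinate wraps back into the input via
// imod(coord, inputDim).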
class TileProgram {
constructor(aShape, reps) {
this.variableNames = ['A'];
const outputShape = new Array(aShape.length);
for (let i = 0; i < outputShape.length; i++) {
outputShape[i] = aShape[i] * reps[i];
}
this.outputShape = outputShape;
this.rank = outputShape.length;
const dtype = getCoordsDataType(this.rank);
const sourceCoords = getSourceCoords(aShape);
this.userCode = `
void main() {
${dtype} resRC = getOutputCoords();
setOutput(getA(${sourceCoords}));
}
`;
}
}
function getSourceCoords(aShape) {
const rank = aShape.length;
if (rank > 5) {
throw new Error(`Tile for rank ${rank} is not yet supported`);
}
if (rank === 1) {
return `imod(resRC, ${aShape[0]})`;
}
const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u'];
const sourceCoords = [];
for (let i = 0; i < aShape.length; i++) {
sourceCoords.push(`imod(${currentCoords[i]}, ${aShape[i]})`);
}
return sourceCoords.join();
}
function tile$2(params) {
const { inputs, backend, attrs } = params;
const { x } = inputs;
const { reps } = attrs;
if (x.dtype === 'string' || x.shape.length > 5) {
const data = backend.readSync(x.dataId);
const value = x.dtype === 'string' ?
data.map(d => decodeString(d)) :
data;
const buf = buffer(x.shape, x.dtype, value);
const outBuf = tileImplCPU(buf, reps);
return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
}
const program = new TileProgram(x.shape, reps);
const output = backend.runWebGLProgram(program, [x], x.dtype);
return output;
}
const tileConfig$1 = {
kernelName: Tile,
backendName: 'webgl',
kernelFunc: tile$2,
};
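// SwapProgram and MergeProgram are the two passes of the bitonic top-k below: Swap
// performs one compare-and-swap pass over candidate indices at distance `inc`, and
// Merge halves the candidate list by keeping, for each pair `k` apart, the index of
// the larger value.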
class SwapProgram {
constructor(shape) {
this.variableNames = ['x', 'indices'];
this.customUniforms = [
{ name: 'n', type: 'int' },
{ name: 'firstPass', type: 'int' },
{ name: 'negativeInf', type: 'float' },
{ name: 'dir', type: 'int' },
{ name: 'inc', type: 'int' }
];
this.outputShape = shape;
this.userCode = `
void main() {
ivec2 coords = getOutputCoords();
int batch = coords[0];
int elemIdx = coords[1];
bool isFirstInPair = imod(elemIdx, 2 * inc) < inc;
int i = isFirstInPair ? elemIdx : elemIdx - inc;
int i0 = firstPass == 1 ? i : int(getIndices(batch, i));
int i1 = firstPass == 1 ? i + inc : int(getIndices(batch, i + inc));
float x0 = i0 < n ? getX(batch, i0) : negativeInf;
float x1 = i1 < n ? getX(batch, i1) : negativeInf;
bool reverse = imod(elemIdx, 2 * dir) >= dir;
bool isGreater = x0 > x1 || (x0 == x1 && i1 > i0);
if (reverse == isGreater) {
int iTemp = i0;
i0 = i1;
i1 = iTemp;
}
if (isFirstInPair) {
setOutput(float(i0));
} else {
setOutput(float(i1));
}
}
`;
}
}
class MergeProgram {
constructor(shape) {
this.variableNames = ['x', 'indices'];
this.customUniforms = [
{ name: 'n', type: 'int' },
{ name: 'firstPass', type: 'int' },
{ name: 'k', type: 'int' }
];
this.outputShape = shape;
this.userCode = `
void main() {
ivec2 coords = getOutputCoords();
int batch = coords[0];
int elemIdx = coords[1];
int i = elemIdx < k ? elemIdx : (elemIdx * 2 - imod(elemIdx, k));
int i0 = firstPass == 1 ? i : int(getIndices(batch, i));
int i1 = firstPass == 1 ? i + k : int(getIndices(batch, i + k));
float x0 = getX(batch, i0);
float x1 = i1 < n ? getX(batch, i1) : x0;
setOutput(x0 >= x1 ? float(i0) : float(i1));
}
`;
}
}
function disposeIntermediateTensorInfoOrNull(backend, tensorInfo) {
if (tensorInfo !== null) {
backend.disposeIntermediateTensorInfo(tensorInfo);
}
}
function roundUpToPow2(num) {
let pow2 = 1;
while (pow2 < num) {
pow2 *= 2;
}
return pow2;
}
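// topK: small inputs (or CPU-resident tensors) are handed off to the CPU
// implementation; otherwise the last dimension is padded to a power of two and the
// Swap/Merge passes above select the top k indices, which are then gathered and
// reshaped back to the input's leading dimensions.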
function topK$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { k, sorted } = attrs;
const TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD = env().getNumber('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD');
const TOPK_K_CPU_HANDOFF_THRESHOLD = env().getNumber('TOPK_K_CPU_HANDOFF_THRESHOLD');
const xShape = x.shape;
const lastDim = xShape[xShape.length - 1];
if (backend.shouldExecuteOnCPU([x]) ||
lastDim < TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD ||
k > TOPK_K_CPU_HANDOFF_THRESHOLD) {
const xVals = backend.readSync(x.dataId);
const [allTopKVals, allTopKIndices] = topKImplCPU(xVals, xShape, x.dtype, k, sorted);
return [
backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values),
backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)
];
}
if (k === 0) {
xShape[xShape.length - 1] = 0;
return [
backend.makeTensorInfo(xShape, x.dtype, []),
backend.makeTensorInfo(xShape, 'int32', [])
];
}
if (lastDim === 1) {
return [
x, fill$1({ attrs: { shape: xShape, dtype: 'int32', value: 0 }, backend })
];
}
const xtexData = backend.texData.get(x.dataId);
const xIsPacked = xtexData !== null && xtexData.isPacked;
const xUnPacked = xIsPacked ? backend.unpackTensor(x) : x;
const xSize = sizeFromShape(xShape);
const batch = xSize / lastDim;
const x2D = reshape$1({ inputs: { x: xUnPacked }, attrs: { shape: [batch, lastDim] }, backend });
if (xIsPacked) {
disposeIntermediateTensorInfoOrNull(backend, xUnPacked);
}
const kPow2 = roundUpToPow2(k);
const lastDimPow2 = roundUpToPow2(lastDim);
let indices = null;
const getInputs = () => indices === null ? [x2D, x2D] : [x2D, indices];
const runSwap = (dir, inc, shape) => {
const inputs = getInputs();
const program = new SwapProgram(shape);
const firstPass = indices === null ? 1 : 0;
const customValues = [[lastDim], [firstPass], [Number.NEGATIVE_INFINITY], [dir], [inc]];
const prevIndices = indices;
indices = backend.runWebGLProgram(program, inputs, 'int32', customValues);
disposeIntermediateTensorInfoOrNull(backend, prevIndices);
};
for (let len = 1; len < kPow2; len *= 2) {
const dir = len * 2;
for (let inc = len; inc >= 1; inc /= 2) {
runSwap(dir, inc, [batch, lastDimPow2]);
}
}
for (let indicesSize = lastDimPow2; indicesSize > kPow2; indicesSize /= 2) {
const inputs = getInputs();
const mergeProgram = new MergeProgram([batch, indicesSize / 2]);
const firstPass = indices === null ? 1 : 0;
const customValues = [[lastDim], [firstPass], [kPow2]];
const prevIndices = indices;
indices =
backend.runWebGLProgram(mergeProgram, inputs, 'int32', customValues);
disposeIntermediateTensorInfoOrNull(backend, prevIndices);
const len = kPow2 / 2;
const dir = len * 2;
for (let inc = len; inc >= 1; inc /= 2) {
runSwap(dir, inc, indices.shape);
}
}
let prevIndices = indices;
indices = slice({ inputs: { x: indices }, backend, attrs: { begin: 0, size: [batch, k] } });
disposeIntermediateTensorInfoOrNull(backend, prevIndices);
let values = gatherV2$1({ inputs: { x: x2D, indices }, backend, attrs: { axis: 1, batchDims: 1 } });
disposeIntermediateTensorInfoOrNull(backend, x2D);
const newShape = xShape.slice(0, -1);
newShape.push(k);
prevIndices = indices;
indices = reshape$1({ inputs: { x: indices }, attrs: { shape: newShape }, backend });
disposeIntermediateTensorInfoOrNull(backend, prevIndices);
const prevValues = values;
values = reshape$1({ inputs: { x: values }, attrs: { shape: newShape }, backend });
disposeIntermediateTensorInfoOrNull(backend, prevValues);
return [values, indices];
}
const topKConfig$1 = {
kernelName: TopK,
backendName: 'webgl',
kernelFunc: topK$1
};
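// Projective image transform: eight per-batch transform coefficients map each output
// (x, y) to an input coordinate; fillMode (constant/reflect/wrap/nearest) handles
// out-of-range samples, and interpolation is nearest or bilinear.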
class TransformProgram {
constructor(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape) {
this.variableNames = ['Image', 'Transforms'];
this.outputShape = outShape;
const interpolationModeId = interpolation === 'nearest' ? 1 : 2;
let fillModeId;
switch (fillMode) {
case 'constant':
fillModeId = 1;
break;
case 'reflect':
fillModeId = 2;
break;
case 'wrap':
fillModeId = 3;
break;
case 'nearest':
fillModeId = 4;
break;
default:
fillModeId = 1;
break;
}
this.userCode = `
float mapCoord(float outCoord, float len) {
float inCoord = outCoord;
if(${fillModeId} == 2) {
if (inCoord < 0.0) {
if (len <= 1.0) {
inCoord = 0.0;
} else {
float sz2 = 2.0 * len;
if (inCoord < sz2) {
inCoord = sz2 * float(int(float(-inCoord / sz2))) +
inCoord;
}
inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1.0;
}
} else if (inCoord > len - 1.0) {
if (len <= 1.0) {
inCoord = 0.0;
} else {
float sz2 = 2.0 * len;
inCoord -= sz2 * float(int(float(inCoord / sz2)));
if (inCoord >= len) {
inCoord = sz2 - inCoord - 1.0;
}
}
}
return clamp(inCoord, 0.0, len - 1.0);
} else if (${fillModeId} == 3) {
if (inCoord < 0.0) {
if (len <= 1.0) {
inCoord = 0.0;
} else {
float sz = len - 1.0;
inCoord += len * (float(int(float(-inCoord / sz))) + 1.0);
}
} else if (inCoord > len - 1.0) {
if (len <= 1.0) {
inCoord = 0.0;
} else {
float sz = len - 1.0;
inCoord -= len * float(int(float(inCoord / sz)));
}
}
return clamp(inCoord, 0.0, len - 1.0);
} else if (${fillModeId} == 4) {
return clamp(outCoord, 0.0, len - 1.0);
} else {
return outCoord;
}
}
float readWithFillValue(int batch, int coordY, int coordX,
int channel) {
float outputValue;
if (0 <= coordY && coordY < ${imageHeight} && 0 <= coordX && coordX < ${imageWidth}) {
outputValue = getImage(batch, coordY, coordX, channel);
} else {
outputValue = float(${fillValue});
}
return outputValue;
}
void main() {
ivec4 coords = getOutputCoords();
float outputValue;
int batch = coords[0];
int x = coords[2];
int y = coords[1];
int channel = coords[3];
float xf = float(x);
float yf = float(y);
float a1 = getTransforms(batch, 0);
float a2 = getTransforms(batch, 1);
float a3 = getTransforms(batch, 2);
float b1 = getTransforms(batch, 3);
float b2 = getTransforms(batch, 4);
float b3 = getTransforms(batch, 5);
float c1 = getTransforms(batch, 6);
float c2 = getTransforms(batch, 7);
float projection = c1 * xf + c2 * yf + 1.0;
if (projection == 0.0) {
outputValue = float(${fillValue});
} else {
float inX = (a1 * xf + a2 * yf + a3) / projection;
float inY = (b1 * xf + b2 * yf + b3) / projection;
float mapX = mapCoord(inX, float(${imageWidth}));
float mapY = mapCoord(inY, float(${imageHeight}));
if (${interpolationModeId} == 1) {
int coordY = int(round(mapY));
int coordX = int(round(mapX));
outputValue = readWithFillValue(batch, coordY, coordX,
channel);
} else {
float yFloor = floor(mapY);
float xFloor = floor(mapX);
float yCeil = yFloor + 1.0;
float xCeil = xFloor + 1.0;
float valueYFloor = (xCeil - mapX) *
readWithFillValue(batch, int(yFloor), int(xFloor), channel) +
(mapX - xFloor) *
readWithFillValue(batch, int(yFloor), int(xCeil), channel);
float valueYCeil = (xCeil - mapX) *
readWithFillValue(batch, int(yCeil), int(xFloor), channel) +
(mapX - xFloor) *
readWithFillValue(batch, int(yCeil), int(xCeil), channel);
outputValue = (yCeil - mapY) * valueYFloor +
(mapY - yFloor) * valueYCeil;
}
}
setOutput(outputValue);
}
`;
}
}
function transform$1(args) {
const { inputs, backend, attrs } = args;
const { image, transforms } = inputs;
const { interpolation, fillMode, fillValue, outputShape } = attrs;
const [batch, imageHeight, imageWidth, numChannels] = image.shape;
const [outHeight, outWidth] = outputShape != null ? outputShape : [imageHeight, imageWidth];
const outShape = [batch, outHeight, outWidth,
numChannels];
const program = new TransformProgram(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape);
return backend.runWebGLProgram(program, [image, transforms], 'float32');
}
const transformConfig$1 = {
kernelName: Transform,
backendName: 'webgl',
kernelFunc: transform$1
};
function unique$2(args) {
const { inputs, attrs, backend } = args;
const { axis } = attrs;
const { x } = inputs;
assertNotComplex$1(x, 'unique');
console.warn('WARNING: ', 'UI might be locked temporarily as data is being downloaded');
const values = backend.readSync(x.dataId);
const { outputValues, outputShape, indices } = uniqueImplCPU(values, axis, x.shape, x.dtype);
return [
backend.makeTensorInfo(outputShape, x.dtype, outputValues),
backend.makeTensorInfo([indices.length], 'int32', indices),
];
}
const uniqueConfig$1 = {
kernelName: Unique,
backendName: 'webgl',
kernelFunc: unique$2,
};
function unpack$1(args) {
const { inputs, backend, attrs } = args;
const { value } = inputs;
let { axis } = attrs;
if (axis < 0) {
axis += value.shape.length;
}
const x = value;
const xRank = x.shape.length;
const num = value.shape[axis];
const outShape = new Array(xRank - 1);
let outIndex = 0;
for (let i = 0; i < xRank; i++) {
if (i !== axis) {
outShape[outIndex++] = x.shape[i];
}
}
const toDispose = [];
const begin = new Array(xRank).fill(0);
const size = x.shape.slice();
size[axis] = 1;
const res = new Array(num);
for (let i = 0; i < res.length; i++) {
begin[axis] = i;
const sliced = slice({ inputs: { x }, backend, attrs: { begin, size } });
const reshaped = reshape$1({ inputs: { x: sliced }, backend, attrs: { shape: outShape } });
res[i] = reshaped;
toDispose.push(sliced);
}
toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
return res;
}
const unpackConfig$1 = {
kernelName: Unpack,
backendName: 'webgl',
kernelFunc: unpack$1
};
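// Segment reduction shader used by unsortedSegmentSum: values are read in vec4 chunks
// and masked by whether their segment id equals the segment being accumulated.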
class SegmentOpProgram {
constructor(segOpInfo, segOpType) {
this.variableNames = ['x', 'segmentIds'];
const windowSize = segOpInfo.windowSize;
const batchSize = segOpInfo.batchSize;
const inSize = segOpInfo.inSize;
const numSegments = segOpInfo.numSegments;
const outSize = numSegments * Math.ceil(inSize / windowSize);
this.outputShape = [batchSize, outSize];
const initializationValue = '0.0';
const returnValue = `sumValue`;
const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
const windowSizeVec4Remainder = windowSize % 4;
const updateSnippet = `
sumValue += dot(values, segFilter);
`;
let checkValueOutOfBounds = '';
if (inSize % windowSize > 0) {
checkValueOutOfBounds = `
if (inIdx < 0 || inIdx >= ${inSize}) {
return initializationValue;
}
`;
}
let checkSegmentIdOutOfBounds = '';
if (inSize % windowSize > 0) {
checkSegmentIdOutOfBounds = `
if (inIdx < 0 || inIdx >= ${inSize}) {
return -1.0;
}
`;
}
this.userCode = `
const float initializationValue = ${initializationValue};
float getValue(int batch, int inIdx) {
${checkValueOutOfBounds}
return getX(batch, inIdx);
}
float getSegmentIdAtIndex(int inIdx) {
${checkSegmentIdOutOfBounds}
return getSegmentIds(inIdx);
}
void main() {
ivec2 coords = getOutputCoords();
int batch = coords[0];
int outIdx = coords[1];
int inOffset = int(floor(float(outIdx) / float(
${numSegments})) * float(${windowSize}));
int currentSeg = int(mod(float(outIdx), float(${numSegments})));
float sumValue = 0.0;
for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
int inIdx = inOffset + i;
vec4 values = vec4(
getValue(batch, inIdx),
getValue(batch, inIdx + 1),
getValue(batch, inIdx + 2),
getValue(batch, inIdx + 3)
);
vec4 segFilter = vec4(
int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0
);
${updateSnippet}
}
int inIdx = inOffset + ${windowSizeNearestVec4};
if (${windowSizeVec4Remainder === 1}) {
vec4 values = vec4(
getValue(batch, inIdx),
initializationValue,
initializationValue,
initializationValue
);
int inIdxSeg = int(getSegmentIdAtIndex(inIdx));
vec4 segFilter = vec4(
int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
0,
0,
0
);
${updateSnippet}
} else if (${windowSizeVec4Remainder === 2}) {
vec4 values = vec4(
getValue(batch, inIdx),
getValue(batch, inIdx + 1),
initializationValue,
initializationValue
);
vec4 segFilter = vec4(
int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
0,
0
);
${updateSnippet}
} else if (${windowSizeVec4Remainder === 3}) {
vec4 values = vec4(
getValue(batch, inIdx),
getValue(batch, inIdx + 1),
getValue(batch, inIdx + 2),
initializationValue
);
vec4 segFilter = vec4(
int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
0
);
${updateSnippet}
}
setOutput(${returnValue});
}
`;
}
}
function unsortedSegmentSum$1(args) {
const { inputs, backend, attrs } = args;
const { x, segmentIds } = inputs;
const { numSegments } = attrs;
const xRank = x.shape.length;
const toDispose = [];
let axis = 0;
const permutation = getAxesPermutation([axis], xRank);
let permutedX = x;
if (permutation != null) {
permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutation } });
toDispose.push(permutedX);
axis = getInnerMostAxes(1, xRank)[0];
}
const outShape = computeOutShape(permutedX.shape, axis, numSegments);
const inSize = sizeFromShape([permutedX.shape[axis]]);
const a2D = reshape$1({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
toDispose.push(a2D);
const outputDType = sumOutType(x.dtype);
const segOpCompute = (x, segOpType, segmentIds, dtype, numSegments) => {
const batchSize = x.shape[0];
const inSize = x.shape[1];
const windowSize = segOpComputeOptimalWindowSize(inSize, numSegments);
const segOpInfo = { windowSize, inSize, batchSize, numSegments };
const program = new SegmentOpProgram(segOpInfo, segOpType);
const output = backend.compileAndRun(program, [x, segmentIds], dtype);
toDispose.push(output);
if (output.shape[1] === numSegments) {
return output;
}
const rangeInfo = range$2({
backend,
attrs: { start: 0, stop: numSegments, step: 1, dtype: 'float32' }
});
const tileInfo = tile$2({
inputs: { x: rangeInfo },
backend,
attrs: { reps: [inSize / windowSize] }
});
toDispose.push(rangeInfo);
toDispose.push(tileInfo);
const result = segOpCompute(output, segOpType, tileInfo, dtype, numSegments);
return result;
};
const segOpResult = segOpCompute(a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments);
const reshaped = reshape$1({ inputs: { x: segOpResult }, backend, attrs: { shape: outShape } });
let result = reshaped;
if (permutation != null) {
toDispose.push(reshaped);
const perm = getUndoAxesPermutation(permutation);
result = transpose({ inputs: { x: result }, backend, attrs: { perm } });
}
toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
return result;
}
const unsortedSegmentSumConfig$1 = {
kernelName: UnsortedSegmentSum,
backendName: 'webgl',
kernelFunc: unsortedSegmentSum$1
};
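// All WebGL kernel configs bundled in this build; each entry is registered
// with the kernel registry in the loop that follows the list.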
const kernelConfigs$1 = [
_fusedMatMulConfig$1,
absConfig,
acosConfig$1,
acoshConfig$1,
addConfig,
addNConfig$1,
allConfig$1,
anyConfig$1,
argMaxConfig$1,
argMinConfig$1,
asinConfig$1,
asinhConfig$1,
atanConfig$1,
atan2Config$1,
atanhConfig$1,
avgPoolConfig$1,
avgPool3DConfig$1,
avgPool3DGradConfig$2,
avgPoolGradConfig$2,
batchMatMulConfig$1,
batchNormConfig$1,
batchToSpaceNDConfig$1,
bincountConfig$1,
bitwiseAndConfig,
broadcastArgsConfig$1,
castConfig,
ceilConfig,
clipByValueConfig$1,
complexConfig,
complexAbsConfig$1,
concatConfig$1,
conv2DConfig$1,
conv2DBackpropFilterConfig$1,
conv2DBackpropInputConfig$1,
conv3DConfig$1,
conv3DBackpropFilterV2Config$1,
conv3DBackpropInputConfig,
cosConfig$1,
coshConfig$1,
cropAndResizeConfig$1,
cumprodConfig$1,
cumsumConfig$1,
denseBincountConfig$1,
depthToSpaceConfig$1,
depthwiseConv2dNativeConfig$1,
depthwiseConv2dNativeBackpropFilterConfig$1,
depthwiseConv2dNativeBackpropInputConfig$1,
diagConfig$1,
dilation2DConfig$1,
einsumConfig$1,
eluConfig$1,
eluGradConfig$2,
equalConfig,
erfConfig$1,
expConfig,
expandDimsConfig$1,
expm1Config,
fftConfig$1,
fillConfig$1,
flipLeftRightConfig$1,
floorConfig,
floorDivConfig,
fromPixelsConfig,
fusedConv2DConfig$1,
fusedDepthwiseConv2DConfig$1,
gatherNdConfig$1,
gatherV2Config$1,
greaterConfig,
greaterEqualConfig,
identityConfig,
ifftConfig$1,
imagConfig$1,
isFiniteConfig$1,
isInfConfig$1,
isNaNConfig$1,
leakyReluConfig$1,
lessConfig,
lessEqualConfig,
linSpaceConfig$1,
logConfig,
log1pConfig$1,
logicalAndConfig$1,
logicalNotConfig$1,
logicalOrConfig$1,
LRNConfig$1,
LRNGradConfig$1,
maxConfig$1,
maximumConfig,
maxPoolConfig$1,
maxPool3DConfig$1,
maxPool3DGradConfig$2,
maxPoolGradConfig$2,
maxPoolWithArgmaxConfig$1,
meanConfig$1,
minConfig$1,
minimumConfig,
mirrorPadConfig$1,
modConfig$1,
multinomialConfig$1,
multiplyConfig,
negConfig,
nonMaxSuppressionV3Config$1,
nonMaxSuppressionV4Config$1,
nonMaxSuppressionV5Config$1,
notEqualConfig,
oneHotConfig$1,
onesLikeConfig$1,
packConfig$1,
padV2Config$1,
powConfig$1,
preluConfig$1,
prodConfig,
raggedGatherConfig$1,
raggedRangeConfig$1,
raggedTensorToTensorConfig$1,
rangeConfig$1,
realConfig,
realDivConfig$1,
reciprocalConfig$1,
reluConfig$1,
relu6Config$1,
reshapeConfig$1,
resizeBilinearConfig$1,
resizeBilinearGradConfig$2,
resizeNearestNeighborConfig$1,
resizeNearestNeighborGradConfig$2,
reverseConfig$1,
rotateWithOffsetConfig$1,
roundConfig$1,
rsqrtConfig,
scatterNdConfig$1,
searchSortedConfig$1,
selectConfig$1,
seluConfig$1,
sigmoidConfig,
signConfig$1,
sinConfig$1,
sinhConfig$1,
sliceConfig,
softmaxConfig$1,
softplusConfig$1,
spaceToBatchNDConfig$1,
sparseFillEmptyRowsConfig$1,
sparseReshapeConfig$1,
sparseSegmentMeanConfig$1,
sparseSegmentSumConfig$1,
sparseToDenseConfig$1,
splitVConfig$1,
sqrtConfig,
squareConfig$1,
squaredDifferenceConfig,
staticRegexReplaceConfig,
stepConfig$1,
stridedSliceConfig$1,
stringNGramsConfig$1,
stringSplitConfig$1,
stringToHashBucketFastConfig$1,
subConfig,
sumConfig$1,
tanConfig$1,
tanhConfig$1,
tensorScatterUpdateConfig$1,
tileConfig$1,
topKConfig$1,
transformConfig$1,
transposeConfig,
uniqueConfig$1,
unpackConfig$1,
unsortedSegmentSumConfig$1,
zerosLikeConfig$1
];
for (const kernelConfig of kernelConfigs$1) {
registerKernel(kernelConfig);
}
const whereImpl = whereImpl$2;
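// Plain-JS CPU backend. Tensor values are stored as typed arrays in a
// DataStorage map keyed by dataId; kernels operate directly on those arrays.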
class MathBackendCPU extends KernelBackend {
nextDataId() {
return MathBackendCPU.nextDataId++;
}
constructor() {
super();
this.blockSize = 48;
this.firstUse = true;
this.data = new DataStorage(this, engine());
}
write(values, shape, dtype) {
if (this.firstUse) {
this.firstUse = false;
if (env().get('IS_NODE')) {
warn('\n============================\n' +
'Hi, looks like you are running TensorFlow.js in ' +
'Node.js. To speed things up dramatically, install our node ' +
'backend, visit https://github.com/tensorflow/tfjs-node for more details. ' +
'\n============================');
}
}
const dataId = { id: this.nextDataId() };
this.data.set(dataId, { values, dtype, refCount: 1 });
return dataId;
}
makeTensorInfo(shape, dtype, values) {
let outId;
if (dtype === 'string' && values != null && values.length > 0 &&
isString(values[0])) {
const encodedValues = values.map(d => encodeString(d));
outId = this.write(encodedValues, shape, dtype);
}
else {
outId = this.write(values, shape, dtype);
}
return { dataId: outId, shape, dtype };
}
refCount(dataId) {
if (this.data.has(dataId)) {
const tensorData = this.data.get(dataId);
return tensorData.refCount;
}
return 0;
}
incRef(dataId) {
const tensorData = this.data.get(dataId);
tensorData.refCount++;
}
decRef(dataId) {
if (this.data.has(dataId)) {
const tensorData = this.data.get(dataId);
tensorData.refCount--;
}
}
move(dataId, values, shape, dtype, refCount) {
this.data.set(dataId, { values, dtype, refCount });
}
numDataIds() {
return this.data.numDataIds();
}
async read(dataId) {
return this.readSync(dataId);
}
readSync(dataId) {
const { dtype, complexTensorInfos } = this.data.get(dataId);
if (dtype === 'complex64') {
const realValues = this.readSync(complexTensorInfos.real.dataId);
const imagValues = this.readSync(complexTensorInfos.imag.dataId);
return mergeRealAndImagArrays(realValues, imagValues);
}
return convertBackendValuesAndArrayBuffer(this.data.get(dataId).values, dtype);
}
bufferSync(t) {
const data = this.readSync(t.dataId);
if (t.dtype === 'string') {
try {
const strings = data.map(d => decodeString(d));
return buffer(t.shape, t.dtype, strings);
}
catch (_a) {
throw new Error('Failed to decode encoded string bytes into utf-8');
}
}
return buffer(t.shape, t.dtype, data);
}
makeOutput(values, shape, dtype) {
return engine().makeTensorFromTensorInfo(this.makeTensorInfo(shape, dtype, values), this);
}
disposeData(dataId, force = false) {
if (this.data.has(dataId)) {
this.data.get(dataId).refCount--;
if (!force && this.data.get(dataId).refCount > 0) {
return false;
}
const { complexTensorInfos } = this.data.get(dataId);
if (complexTensorInfos != null) {
this.disposeData(complexTensorInfos.real.dataId, true);
this.disposeData(complexTensorInfos.imag.dataId, true);
}
this.data.delete(dataId);
}
return true;
}
disposeIntermediateTensorInfo(tensorInfo) {
this.disposeData(tensorInfo.dataId);
}
async time(f) {
const start = now();
f();
const kernelMs = now() - start;
return { kernelMs };
}
memory() {
return {
unreliable: true,
reasons: ['The reported memory is an upper bound. Due to automatic garbage ' +
'collection, the true allocated memory may be less.']
};
}
where(condition) {
assertNotComplex([condition], 'where');
const condVals = this.readSync(condition.dataId);
return whereImpl(condition.shape, condVals);
}
dispose() { }
floatPrecision() {
return 32;
}
epsilon() {
return super.epsilon();
}
}
MathBackendCPU.nextDataId = 0;
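// Register the CPU backend under the name 'cpu' with priority 1.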
registerBackend('cpu', () => new MathBackendCPU(), 1);
const elu$1 = unaryKernelFunc$1(Elu$1, (xi) => xi >= 0 ? xi : (Math.exp(xi) - 1));
const eluConfig = {
kernelName: Elu$1,
backendName: 'cpu',
kernelFunc: elu$1,
};
function leakyRelu(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { alpha } = attrs;
assertNotComplex([x], 'leakyRelu');
const xSize = sizeFromShape(x.shape);
const xVals = backend.data.get(x.dataId).values;
const outVals = getTypedArrayFromDType('float32', xSize);
for (let i = 0; i < xVals.length; i++) {
outVals[i] = xVals[i] < 0 ? alpha * xVals[i] : xVals[i];
}
return backend.makeTensorInfo(x.shape, 'float32', outVals);
}
const leakyReluConfig = {
kernelName: LeakyRelu,
backendName: 'cpu',
kernelFunc: leakyRelu
};
const preluImpl = createSimpleBinaryKernelImpl((xValue, aValue) => xValue < 0 ? aValue * xValue : xValue);
function prelu(args) {
const { inputs, backend } = args;
const { x, alpha } = inputs;
assertNotComplex([x, alpha], 'prelu');
const aVals = backend.data.get(x.dataId).values;
const bVals = backend.data.get(alpha.dataId).values;
const [resultData, resultShape] = preluImpl(x.shape, alpha.shape, aVals, bVals, 'float32');
return backend.makeTensorInfo(resultShape, 'float32', resultData);
}
const preluConfig = {
kernelName: Prelu,
backendName: 'cpu',
kernelFunc: prelu,
};
const relu = unaryKernelFunc$1(Relu$1, (xi) => Math.max(0, xi));
const reluConfig = {
kernelName: Relu$1,
backendName: 'cpu',
kernelFunc: relu,
};
const relu6 = unaryKernelFunc$1(Relu6$1, (xi) => Math.min(Math.max(0, xi), 6));
const relu6Config = {
kernelName: Relu6$1,
backendName: 'cpu',
kernelFunc: relu6,
};
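// Applies the activation requested by a fused kernel (matMul/conv) by
// dispatching to the corresponding standalone CPU activation kernel.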
function applyActivation(backend, x, activation, preluActivationWeights, leakyreluAlpha) {
if (activation === 'linear') {
return identity$1({ inputs: { x }, backend });
}
else if (activation === 'relu') {
return relu({ inputs: { x }, backend });
}
else if (activation === 'elu') {
return elu$1({ inputs: { x }, backend });
}
else if (activation === 'relu6') {
return relu6({ inputs: { x }, backend });
}
else if (activation === 'prelu') {
return prelu({ inputs: { x, alpha: preluActivationWeights }, backend });
}
else if (activation === 'leakyrelu') {
return leakyRelu({ inputs: { x }, backend, attrs: { alpha: leakyreluAlpha } });
}
else if (activation === 'sigmoid') {
return sigmoid$1({ inputs: { x }, backend });
}
throw new Error(`Activation ${activation} has not been implemented for the CPU backend.`);
}
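// Reshape is metadata-only on CPU: it increments the ref count of the existing
// data and returns a TensorInfo with the new shape, without copying values.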
function reshape(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { shape } = attrs;
const xSize = sizeFromShape(x.shape);
const $shape = inferFromImplicitShape(shape, xSize);
const $xSize = sizeFromShape($shape);
assert$1(xSize === $xSize, () => `The new shape (${$shape}) has ${$xSize} elements and the old ` +
`shape (${x.shape}) has ${xSize} elements. The new shape and old ` +
`shape must have the same number of elements.`);
backend.incRef(x.dataId);
const xData = backend.data.get(x.dataId);
if (xData.complexTensorInfos != null) {
const real = xData.complexTensorInfos.real;
const imag = xData.complexTensorInfos.imag;
real.shape = $shape;
imag.shape = $shape;
}
return { dataId: x.dataId, shape: $shape, dtype: x.dtype };
}
const reshapeConfig = {
kernelName: Reshape$1,
backendName: 'cpu',
kernelFunc: reshape
};
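// Batched matrix multiply with broadcasting over the outer (batch) dimensions.
// Both operands are reshaped to rank 3 and multiplied in tiles of
// backend.blockSize for better cache locality.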
function batchMatMul(args) {
const { inputs, backend, attrs } = args;
const { a, b } = inputs;
const { transposeA, transposeB } = attrs;
assertNotComplex([a, b], 'matMul');
const aRank = a.shape.length;
const bRank = b.shape.length;
const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
const outerDimsA = a.shape.slice(0, -2);
const outerDimsB = b.shape.slice(0, -2);
const batchDimA = sizeFromShape(outerDimsA);
const batchDimB = sizeFromShape(outerDimsB);
const outShapeOuterDims = assertAndGetBroadcastShape(a.shape.slice(0, -2), b.shape.slice(0, -2));
const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
assert$1(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
`${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
`${b.shape} and transposeA=${transposeA}` +
` and transposeB=${transposeB} must match.`);
const a3dShape = transposeA ? [batchDimA, innerShapeA, outerShapeA] :
[batchDimA, outerShapeA, innerShapeA];
const b3dShape = transposeB ? [batchDimB, outerShapeB, innerShapeB] :
[batchDimB, innerShapeB, outerShapeB];
const a3d = reshape({ inputs: { x: a }, backend, attrs: { shape: a3dShape } });
const b3d = reshape({ inputs: { x: b }, backend, attrs: { shape: b3dShape } });
const sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
const leftDim = transposeA ? a3d.shape[2] : a3d.shape[1];
const rightDim = transposeB ? b3d.shape[1] : b3d.shape[2];
const batchDim = Math.max(batchDimA, batchDimB);
const a3dValues = backend.data.get(a3d.dataId).values;
const b3dValues = backend.data.get(b3d.dataId).values;
const a3dStrides = computeStrides(a3d.shape);
const b3dStrides = computeStrides(b3d.shape);
const [aBatch, aOuterStep, aInnerStep] = transposeA ?
[a3dStrides[0], 1, a3dStrides[1]] :
[a3dStrides[0], a3dStrides[1], 1];
const [bInnerStep, bOuterStep, bBatch] = transposeB ?
[1, b3dStrides[1], b3dStrides[0]] :
[b3dStrides[1], 1, b3dStrides[0]];
const size = leftDim * rightDim;
const result = buffer([batchDim, leftDim, rightDim], a3d.dtype);
const resVals = result.values;
const blockSize = backend.blockSize;
for (let bi = 0; bi < batchDim; bi++) {
const batchIndexA = bi % batchDimA;
const batchIndexB = bi % batchDimB;
for (let i0 = 0; i0 < leftDim; i0 += blockSize) {
const iBlock = Math.min(i0 + blockSize, leftDim);
for (let j0 = 0; j0 < rightDim; j0 += blockSize) {
const jBlock = Math.min(j0 + blockSize, rightDim);
for (let k0 = 0; k0 < sharedDim; k0 += blockSize) {
const kBlock = Math.min(k0 + blockSize, sharedDim);
for (let i = i0; i < iBlock; i++) {
for (let j = j0; j < jBlock; j++) {
let sum = 0.0;
for (let k = k0; k < kBlock; k++) {
const aVal =
a3dValues[batchIndexA * aBatch + i * aOuterStep + k * aInnerStep];
const bVal =
b3dValues[k * bInnerStep + j * bOuterStep + batchIndexB * bBatch];
sum += aVal * bVal;
}
resVals[bi * size + (i * rightDim + j)] += sum;
}
}
}
}
}
}
backend.disposeIntermediateTensorInfo(a3d);
backend.disposeIntermediateTensorInfo(b3d);
return backend.makeTensorInfo(outShape, result.dtype, result.values);
}
const batchMatMulConfig = {
kernelName: BatchMatMul,
backendName: 'cpu',
kernelFunc: batchMatMul,
};
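// Fused matMul + optional bias add + optional activation, implemented by
// chaining the individual CPU kernels and disposing the intermediates.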
function _fusedMatMul(args) {
const { inputs, backend, attrs } = args;
const { a, b, bias, preluActivationWeights } = inputs;
const { transposeA, transposeB, activation, leakyreluAlpha } = attrs;
let current;
let addRes;
let activationRes;
const intermediates = [];
const matMulRes = batchMatMul({ inputs: { a, b }, attrs: { transposeA, transposeB }, backend });
current = matMulRes;
if (bias) {
addRes = add({ inputs: { a: current, b: bias }, backend });
intermediates.push(current);
current = addRes;
}
if (activation) {
activationRes = applyActivation(backend, current, activation, preluActivationWeights, leakyreluAlpha);
intermediates.push(current);
current = activationRes;
}
for (const i of intermediates) {
backend.disposeIntermediateTensorInfo(i);
}
return current;
}
const _fusedMatMulConfig = {
kernelName: _FusedMatMul,
backendName: 'cpu',
kernelFunc: _fusedMatMul,
};
const acos = unaryKernelFunc$1(Acos, (xi) => Math.acos(xi));
const acosConfig = {
kernelName: Acos,
backendName: 'cpu',
kernelFunc: acos,
};
const acosh = unaryKernelFunc$1(Acosh, (xi) => Math.acosh(xi));
const acoshConfig = {
kernelName: Acosh,
backendName: 'cpu',
kernelFunc: acosh,
};
function addN(args) {
const { inputs, backend } = args;
const tensors = inputs;
assertNotComplex(inputs, 'addN');
const vals = tensors.map(t => backend.data.get(t.dataId).values);
const outBuf = buffer(tensors[0].shape, tensors[0].dtype);
const outVals = outBuf.values;
for (let i = 0; i < tensors.length; i++) {
const currVals = vals[i];
for (let j = 0; j < outVals.length; j++) {
outVals[j] += currVals[j];
}
}
return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
}
const addNConfig = {
kernelName: AddN,
backendName: 'cpu',
kernelFunc: addN
};
function all(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis, keepDims } = attrs;
assertNotComplex(x, 'all');
const origAxes = parseAxisParam(axis, x.shape);
let axes = origAxes;
const permutedAxes = getAxesPermutation(axes, x.shape.length);
let $x = x;
if (permutedAxes != null) {
$x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
axes = getInnerMostAxes(axes.length, x.shape.length);
}
assertAxesAreInnerMostDims('all', axes, $x.shape.length);
const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
const reduceSize = sizeFromShape(reduceShape);
const vals = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
const aVals = backend.data.get($x.dataId).values;
for (let i = 0; i < vals.length; ++i) {
const offset = i * reduceSize;
let all = aVals[offset];
for (let j = 0; j < reduceSize; ++j) {
const value = aVals[offset + j];
all = all && value;
}
vals[i] = all;
}
if (permutedAxes != null) {
backend.disposeIntermediateTensorInfo($x);
}
const result = backend.makeTensorInfo(outShape, $x.dtype, vals);
if (keepDims) {
const expandedShape = expandShapeToKeepDim(outShape, origAxes);
const reshapedResult = reshape({ inputs: { x: result }, backend, attrs: { shape: expandedShape } });
backend.disposeIntermediateTensorInfo(result);
return reshapedResult;
}
return result;
}
const allConfig = {
kernelName: All,
backendName: 'cpu',
kernelFunc: all
};
function any(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis, keepDims } = attrs;
assertNotComplex(x, 'any');
const origAxes = parseAxisParam(axis, x.shape);
let axes = origAxes;
const permutedAxes = getAxesPermutation(axes, x.shape.length);
let $x = x;
if (permutedAxes != null) {
$x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
axes = getInnerMostAxes(axes.length, x.shape.length);
}
assertAxesAreInnerMostDims('any', axes, $x.shape.length);
const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
const reduceSize = sizeFromShape(reduceShape);
const vals = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
const aVals = backend.data.get($x.dataId).values;
for (let i = 0; i < vals.length; ++i) {
const offset = i * reduceSize;
let anyVal = aVals[offset];
for (let j = 0; j < reduceSize; ++j) {
const value = aVals[offset + j];
anyVal = anyVal || value;
}
vals[i] = anyVal;
}
if (permutedAxes != null) {
backend.disposeIntermediateTensorInfo($x);
}
const result = backend.makeTensorInfo(outShape, $x.dtype, vals);
if (keepDims) {
const expandedShape = expandShapeToKeepDim(outShape, origAxes);
const reshapedResult = reshape({ inputs: { x: result }, backend, attrs: { shape: expandedShape } });
backend.disposeIntermediateTensorInfo(result);
return reshapedResult;
}
return result;
}
const anyConfig = {
kernelName: Any,
backendName: 'cpu',
kernelFunc: any
};
function argMax(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis } = attrs;
assertNotComplex(x, 'argMax');
let axes = parseAxisParam(axis, x.shape);
const permutedAxes = getAxesPermutation(axes, x.shape.length);
let $x = x;
const intermediateTensorInfos = [];
if (permutedAxes != null) {
$x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
intermediateTensorInfos.push($x);
axes = getInnerMostAxes(axes.length, $x.shape.length);
}
axes = [axes[0]];
assertAxesAreInnerMostDims('argMax', axes, $x.shape.length);
const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
const outSize = sizeFromShape(outShape);
const vals = makeZerosTypedArray(outSize, 'int32');
const reduceSize = sizeFromShape(reduceShape);
const aVals = backend.data.get($x.dataId).values;
for (let i = 0; i < vals.length; ++i) {
const offset = i * reduceSize;
let max = aVals[offset];
let maxIndex = 0;
for (let j = 0; j < reduceSize; ++j) {
const value = aVals[offset + j];
if (value > max) {
max = value;
maxIndex = j;
}
}
vals[i] = maxIndex;
}
intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
return backend.makeTensorInfo(outShape, 'int32', vals);
}
const argMaxConfig = {
kernelName: ArgMax,
backendName: 'cpu',
kernelFunc: argMax
};
function argMin(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis } = attrs;
assertNotComplex(x, 'argMin');
let axes = parseAxisParam(axis, x.shape);
const permutedAxes = getAxesPermutation(axes, x.shape.length);
let $x = x;
const intermediateTensorInfos = [];
if (permutedAxes != null) {
$x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
intermediateTensorInfos.push($x);
axes = getInnerMostAxes(axes.length, $x.shape.length);
}
axes = [axes[0]];
assertAxesAreInnerMostDims('argMin', axes, $x.shape.length);
const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
const outSize = sizeFromShape(outShape);
const vals = makeZerosTypedArray(outSize, 'int32');
const reduceSize = sizeFromShape(reduceShape);
const aVals = backend.data.get($x.dataId).values;
for (let i = 0; i < vals.length; ++i) {
const offset = i * reduceSize;
let min = aVals[offset];
let minIndex = 0;
for (let j = 0; j < reduceSize; ++j) {
const value = aVals[offset + j];
if (value < min) {
min = value;
minIndex = j;
}
}
vals[i] = minIndex;
}
intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
return backend.makeTensorInfo(outShape, 'int32', vals);
}
const argMinConfig = {
kernelName: ArgMin,
backendName: 'cpu',
kernelFunc: argMin
};
const asin = unaryKernelFunc$1(Asin, (xi) => Math.asin(xi));
const asinConfig = {
kernelName: Asin,
backendName: 'cpu',
kernelFunc: asin,
};
const asinh = unaryKernelFunc$1(Asinh, (xi) => Math.asinh(xi));
const asinhConfig = {
kernelName: Asinh,
backendName: 'cpu',
kernelFunc: asinh,
};
const atan = unaryKernelFunc$1(Atan, (xi) => Math.atan(xi));
const atanConfig = {
kernelName: Atan,
backendName: 'cpu',
kernelFunc: atan,
};
const atan2Impl = createSimpleBinaryKernelImpl((aValue, bValue) => Math.atan2(aValue, bValue));
const atan2 = binaryKernelFunc$1(Atan2, atan2Impl);
const atan2Config = {
kernelName: Atan2,
backendName: 'cpu',
kernelFunc: atan2,
};
const atanh = unaryKernelFunc$1(Atanh, (xi) => Math.atanh(xi));
const atanhConfig = {
kernelName: Atanh,
backendName: 'cpu',
kernelFunc: atanh,
};
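// Shared 2D pooling helper (max/avg) over NHWC-ordered values.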
function pool(xValues, xShape, dtype, strides, convInfo, poolType) {
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padTop = convInfo.padInfo.top;
const padLeft = convInfo.padInfo.left;
const initialValue = (poolType === 'max' ? Number.NEGATIVE_INFINITY :
Number.POSITIVE_INFINITY);
const output = buffer(convInfo.outShape, dtype);
const outputVals = output.values;
const outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3];
const outputRowStrides = convInfo.outShape[2] * convInfo.outShape[3];
const outputColStrides = convInfo.outShape[3];
for (let b = 0; b < convInfo.batchSize; ++b) {
const outputBatchOffset = b * outputBatchStrides;
const inputBatchOffset = b * strides[0];
for (let d = 0; d < convInfo.inChannels; ++d) {
for (let yR = 0; yR < convInfo.outHeight; ++yR) {
const xRCorner = yR * strideHeight - padTop;
const xRMin = Math.max(0, xRCorner);
const xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
const outputRowOffset = outputBatchOffset + yR * outputRowStrides;
for (let yC = 0; yC < convInfo.outWidth; ++yC) {
const xCCorner = yC * strideWidth - padLeft;
const xCMin = Math.max(0, xCCorner);
const xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
let minMaxValue = initialValue;
let avgValue = 0;
let count = 0;
for (let xR = xRMin; xR < xRMax; xR += dilationHeight) {
const xROffset = inputBatchOffset + xR * strides[1];
for (let xC = xCMin; xC < xCMax; xC += dilationWidth) {
const xCOffset = xROffset + xC * strides[2];
const pixel = xValues[xCOffset + d];
if ((poolType === 'max' && pixel > minMaxValue)) {
minMaxValue = pixel;
}
else if (poolType === 'avg') {
avgValue += pixel;
count++;
}
}
if (isNaN(minMaxValue)) {
break;
}
}
const outputOffset = outputRowOffset + yC * outputColStrides + d;
outputVals[outputOffset] =
poolType === 'avg' ? avgValue / count : minMaxValue;
}
}
}
}
return output;
}
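// Records the position of the maximum inside each 2D pooling window, either as
// a flattened input index or as a window-relative offset.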
function maxPoolPositions(xValues, xShape, dtype, convInfo, flattenPositions = false, includeBatchInIndex = false) {
const maxPositions = buffer(convInfo.outShape, 'int32');
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padTop = convInfo.padInfo.top;
const padLeft = convInfo.padInfo.left;
const xBuf = buffer(xShape, dtype, xValues);
for (let b = 0; b < convInfo.batchSize; ++b) {
for (let d = 0; d < convInfo.inChannels; ++d) {
for (let yR = 0; yR < convInfo.outHeight; ++yR) {
const xRCorner = yR * strideHeight - padTop;
let xRMin = xRCorner;
while (xRMin < 0) {
xRMin += dilationHeight;
}
const xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
for (let yC = 0; yC < convInfo.outWidth; ++yC) {
const xCCorner = yC * strideWidth - padLeft;
let xCMin = xCCorner;
while (xCMin < 0) {
xCMin += dilationWidth;
}
const xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
let maxValue = Number.NEGATIVE_INFINITY;
let maxPosition = -1;
for (let xR = xRMin; xR < xRMax; xR += dilationHeight) {
const wR = xR - xRCorner;
for (let xC = xCMin; xC < xCMax; xC += dilationWidth) {
const wC = xC - xCCorner;
const pixel = xBuf.get(b, xR, xC, d);
if (pixel > maxValue) {
maxValue = pixel;
if (flattenPositions) {
maxPosition = includeBatchInIndex ?
((b * convInfo.inHeight + xR) * convInfo.inWidth + xC) *
convInfo.inChannels +
d :
(xR * convInfo.inWidth + xC) * convInfo.inChannels + d;
}
else {
maxPosition = wR * effectiveFilterWidth + wC;
}
}
}
}
maxPositions.set(maxPosition, b, yR, yC, d);
}
}
}
}
return maxPositions;
}
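// Shared 3D pooling helper (max/avg) over NDHWC-ordered values.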
function pool3d(xValues, xShape, dtype, strides, convInfo, poolType) {
const strideDepth = convInfo.strideDepth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationDepth = convInfo.dilationDepth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterDepth = convInfo.effectiveFilterDepth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padFront = convInfo.padInfo.front;
const padTop = convInfo.padInfo.top;
const padLeft = convInfo.padInfo.left;
const initialValue = (poolType === 'max' ? Number.NEGATIVE_INFINITY :
Number.POSITIVE_INFINITY);
const output = buffer(convInfo.outShape, dtype);
const outputVals = output.values;
const outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] *
convInfo.outShape[3] * convInfo.outShape[4];
const outputDepthStrides = convInfo.outShape[2] * convInfo.outShape[3] * convInfo.outShape[4];
const outputRowStrides = convInfo.outShape[3] * convInfo.outShape[4];
const outputColStrides = convInfo.outShape[4];
for (let batch = 0; batch < convInfo.batchSize; ++batch) {
const outputBatchOffset = batch * outputBatchStrides;
const inputBatchOffset = batch * strides[0];
for (let channel = 0; channel < convInfo.inChannels; ++channel) {
for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
const xDepthCorner = yDepth * strideDepth - padFront;
let xDepthMin = xDepthCorner;
while (xDepthMin < 0) {
xDepthMin += dilationDepth;
}
const xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
const outputDepthOffset = outputBatchOffset + yDepth * outputDepthStrides;
for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {
const xRowCorner = yRow * strideHeight - padTop;
let xRowMin = xRowCorner;
while (xRowMin < 0) {
xRowMin += dilationHeight;
}
const xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
const outputRowOffset = outputDepthOffset + yRow * outputRowStrides;
for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {
const xColCorner = yCol * strideWidth - padLeft;
let xColMin = xColCorner;
while (xColMin < 0) {
xColMin += dilationWidth;
}
const xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);
const outputColOffset = outputRowOffset + yCol * outputColStrides;
let minMaxValue = initialValue;
let avgValue = 0;
let count = 0;
for (let xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) {
const xDepthOffset = inputBatchOffset + xDepth * strides[1];
for (let xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) {
const xRowOffset = xDepthOffset + xRow * strides[2];
for (let xCol = xColMin; xCol < xColMax; xCol += dilationWidth) {
const xColOffset = xRowOffset + xCol * strides[3];
const pixel = xValues[xColOffset + channel];
if ((poolType === 'max' && pixel > minMaxValue)) {
minMaxValue = pixel;
}
else if (poolType === 'avg') {
avgValue += pixel;
count++;
}
if (isNaN(minMaxValue)) {
break;
}
}
if (isNaN(minMaxValue)) {
break;
}
}
if (isNaN(minMaxValue)) {
break;
}
}
const outputOffset = outputColOffset + channel;
outputVals[outputOffset] = poolType === 'avg' ?
avgValue / Math.max(count, 1) :
minMaxValue;
}
}
}
}
}
return output;
}
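// Records the window-relative position of the maximum for each 3D pooling
// window.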
function maxPool3dPositions(xBuf, convInfo) {
const maxPositions = buffer(convInfo.outShape, 'int32');
const strideDepth = convInfo.strideDepth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationDepth = convInfo.dilationDepth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterDepth = convInfo.effectiveFilterDepth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padFront = convInfo.padInfo.front;
const padTop = convInfo.padInfo.top;
const padLeft = convInfo.padInfo.left;
for (let batch = 0; batch < convInfo.batchSize; ++batch) {
for (let channel = 0; channel < convInfo.inChannels; ++channel) {
for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
const xDepthCorner = yDepth * strideDepth - padFront;
let xDepthMin = xDepthCorner;
while (xDepthMin < 0) {
xDepthMin += dilationDepth;
}
const xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {
const xRowCorner = yRow * strideHeight - padTop;
let xRowMin = xRowCorner;
while (xRowMin < 0) {
xRowMin += dilationHeight;
}
const xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {
const xColCorner = yCol * strideWidth - padLeft;
let xColMin = xColCorner;
while (xColMin < 0) {
xColMin += dilationWidth;
}
const xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);
let maxValue = Number.NEGATIVE_INFINITY;
let maxPosition = -1;
for (let xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) {
const wDepth = xDepth - xDepthCorner;
for (let xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) {
const wRow = xRow - xRowCorner;
for (let xCol = xColMin; xCol < xColMax; xCol += dilationWidth) {
const wCol = xCol - xColCorner;
const pixel = xBuf.get(batch, xDepth, xRow, xCol, channel);
if (pixel >= maxValue) {
maxValue = pixel;
maxPosition =
wDepth * effectiveFilterHeight * effectiveFilterWidth +
wRow * effectiveFilterHeight + wCol;
}
}
}
}
maxPositions.set(maxPosition, batch, yDepth, yRow, yCol, channel);
}
}
}
}
}
return maxPositions;
}
function avgPool(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
assertNotComplex(x, 'avgPool');
const { filterSize, strides, pad, dimRoundingMode } = attrs;
const dilations = 1;
assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool: Either strides or dilations must be 1. ' +
`Got strides ${strides} and dilations '${dilations}'`);
const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
let res;
if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
arraysEqual(convInfo.inShape, convInfo.outShape)) {
res = identity$1({ inputs: { x }, backend });
}
else {
const xValues = backend.data.get(x.dataId).values;
const strides = computeStrides(x.shape);
const buffer = pool(xValues, x.shape, x.dtype, strides, convInfo, 'avg');
res = backend.makeTensorInfo(convInfo.outShape, x.dtype, buffer.values);
}
return res;
}
const avgPoolConfig = {
kernelName: AvgPool,
backendName: 'cpu',
kernelFunc: avgPool
};
function avgPool3D(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs;
assertNotComplex(x, 'avgPool3d');
    const convInfo = computePool3DInfo(x.shape, filterSize, strides, 1, pad, dimRoundingMode, dataFormat);
const xValues = backend.data.get(x.dataId).values;
const outBuf = pool3d(xValues, x.shape, x.dtype, computeStrides(x.shape), convInfo, 'avg');
return backend.makeTensorInfo(outBuf.shape, 'float32', outBuf.values);
}
const avgPool3DConfig = {
kernelName: AvgPool3D,
backendName: 'cpu',
kernelFunc: avgPool3D
};
function avgPool3DGrad(args) {
const { inputs, backend, attrs } = args;
const { dy, input } = inputs;
const { filterSize, strides, pad, dimRoundingMode } = attrs;
assertNotComplex([dy, input], 'avgPool3DGrad');
    const convInfo = computePool3DInfo(input.shape, filterSize, strides, 1, pad, dimRoundingMode);
const strideDepth = convInfo.strideDepth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const filterDepth = convInfo.filterDepth;
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const dilationDepth = convInfo.dilationDepth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterDepth = convInfo.effectiveFilterDepth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
const dx = buffer(input.shape, 'float32');
const avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
const dyBuf = backend.bufferSync(dy);
for (let batch = 0; batch < convInfo.batchSize; ++batch) {
for (let channel = 0; channel < convInfo.inChannels; ++channel) {
for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
const dyDepthCorner = dxDepth - padFront;
const dyRowCorner = dxRow - padTop;
const dyColCorner = dxCol - padLeft;
let dotProd = 0;
for (let wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
const dyDepth = (dyDepthCorner + wDepth) / strideDepth;
if (dyDepth < 0 || dyDepth >= convInfo.outDepth ||
Math.floor(dyDepth) !== dyDepth) {
continue;
}
for (let wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
const dyRow = (dyRowCorner + wRow) / strideHeight;
if (dyRow < 0 || dyRow >= convInfo.outHeight ||
Math.floor(dyRow) !== dyRow) {
continue;
}
for (let wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
const dyCol = (dyColCorner + wCol) / strideWidth;
if (dyCol < 0 || dyCol >= convInfo.outWidth ||
Math.floor(dyCol) !== dyCol) {
continue;
}
const pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
dotProd += pixel;
}
}
}
dx.set(dotProd * avgMultiplier, batch, dxDepth, dxRow, dxCol, channel);
}
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
const avgPool3DGradConfig$1 = {
kernelName: AvgPool3DGrad,
backendName: 'cpu',
kernelFunc: avgPool3DGrad
};
function avgPoolGrad$1(args) {
const { inputs, backend, attrs } = args;
const { dy, input } = inputs;
const x = input;
assertNotComplex([dy, input], 'avgPoolGrad');
const { filterSize, strides, pad } = attrs;
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1, pad);
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
const dx = buffer(x.shape, 'float32');
const avgMultiplier = 1 / (filterHeight * filterWidth);
const dyData = backend.data.get(dy.dataId).values;
const dyBuf = buffer(dy.shape, 'float32', dyData);
for (let b = 0; b < convInfo.batchSize; ++b) {
for (let d = 0; d < convInfo.inChannels; ++d) {
for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {
for (let dxC = 0; dxC < convInfo.inWidth; ++dxC) {
const dyRCorner = dxR - padTop;
const dyCCorner = dxC - padLeft;
let dotProd = 0;
for (let wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
const dyR = (dyRCorner + wR) / strideHeight;
if (dyR < 0 || dyR >= convInfo.outHeight ||
Math.floor(dyR) !== dyR) {
continue;
}
for (let wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
const dyC = (dyCCorner + wC) / strideWidth;
if (dyC < 0 || dyC >= convInfo.outWidth ||
Math.floor(dyC) !== dyC) {
continue;
}
const pixel = dyBuf.get(b, dyR, dyC, d);
dotProd += pixel;
}
}
dx.set(dotProd * avgMultiplier, b, dxR, dxC, d);
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
const avgPoolGradConfig$1 = {
kernelName: AvgPoolGrad,
backendName: 'cpu',
kernelFunc: avgPoolGrad$1
};
function batchNorm(args) {
const { inputs, backend, attrs } = args;
const { x, scale, offset, mean, variance } = inputs;
assert$1(mean.shape.length === variance.shape.length, () => 'Batch normalization gradient requires mean and variance to have ' +
'equal ranks.');
assert$1(offset == null || mean.shape.length === offset.shape.length, () => 'Batch normalization gradient requires mean and offset to have ' +
'equal ranks.');
assert$1(scale == null || mean.shape.length === scale.shape.length, () => 'Batch normalization gradient requires mean and scale to have ' +
'equal ranks.');
assertNotComplex([x, mean, variance, scale, offset], 'batchNorm');
let { varianceEpsilon } = attrs;
if (varianceEpsilon == null) {
varianceEpsilon = 0.001;
}
const xVals = backend.data.get(x.dataId).values;
const mVals = backend.data.get(mean.dataId).values;
const varVals = backend.data.get(variance.dataId).values;
const sVals = scale ? backend.data.get(scale.dataId).values :
new Float32Array([1]);
const offVals = offset ?
backend.data.get(offset.dataId).values :
new Float32Array([0]);
const outVals = new Float32Array(xVals.length);
const offValsLength = offVals.length;
const sValsLength = sVals.length;
const varValsLength = varVals.length;
const mValsLength = mVals.length;
let offi = 0;
let mi = 0;
let si = 0;
let vi = 0;
for (let i = 0; i < xVals.length; ++i) {
outVals[i] = offVals[offi++] +
(xVals[i] - mVals[mi++]) * sVals[si++] /
Math.sqrt(varVals[vi++] + varianceEpsilon);
if (offi >= offValsLength) {
offi = 0;
}
if (mi >= mValsLength) {
mi = 0;
}
if (si >= sValsLength) {
si = 0;
}
if (vi >= varValsLength) {
vi = 0;
}
}
return backend.makeTensorInfo(x.shape, x.dtype, outVals);
}
const batchNormConfig = {
kernelName: FusedBatchNorm,
backendName: 'cpu',
kernelFunc: batchNorm,
};
function batchToSpaceND(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { blockShape, crops } = attrs;
assertNotComplex([x], 'batchToSpaceND');
const prod = blockShape.reduce((a, b) => a * b);
const reshaped = getReshaped(x.shape, blockShape, prod);
const permuted = getPermuted(reshaped.length, blockShape.length);
const reshapedPermuted = getReshapedPermuted(x.shape, blockShape, prod);
const sliceBeginCoords = getSliceBeginCoords(crops, blockShape.length);
const sliceSize = getSliceSize(reshapedPermuted, crops, blockShape.length);
const xReshaped = reshape({ inputs: { x }, backend, attrs: { shape: reshaped } });
const xTransposed = transpose$1({ inputs: { x: xReshaped }, backend, attrs: { perm: permuted } });
const xTransposedReshaped = reshape({ inputs: { x: xTransposed }, backend, attrs: { shape: reshapedPermuted } });
const result = slice$1({
inputs: { x: xTransposedReshaped },
backend,
attrs: { begin: sliceBeginCoords, size: sliceSize }
});
backend.disposeIntermediateTensorInfo(xReshaped);
backend.disposeIntermediateTensorInfo(xTransposed);
backend.disposeIntermediateTensorInfo(xTransposedReshaped);
return result;
}
const batchToSpaceNDConfig = {
kernelName: BatchToSpaceND,
backendName: 'cpu',
kernelFunc: batchToSpaceND
};
function bincount(args) {
const { inputs, backend, attrs } = args;
const { x, weights } = inputs;
const { size } = attrs;
const xVals = backend.data.get(x.dataId).values;
const weightsVals = backend.data.get(weights.dataId).values;
const outVals = bincountImpl(xVals, weightsVals, weights.dtype, weights.shape, size);
return backend.makeTensorInfo([size], weights.dtype, outVals);
}
const bincountConfig = {
kernelName: Bincount,
backendName: 'cpu',
kernelFunc: bincount
};
function broadcastArgs(args) {
const { inputs, backend } = args;
const { s0, s1 } = inputs;
const s0Vals = backend.data.get(s0.dataId).values;
const s1Vals = backend.data.get(s1.dataId).values;
const broadcastShape = assertAndGetBroadcastShape(Array.from(s0Vals), Array.from(s1Vals));
return backend.makeTensorInfo([broadcastShape.length], 'int32', Int32Array.from(broadcastShape));
}
const broadcastArgsConfig = {
kernelName: BroadcastArgs,
backendName: 'cpu',
kernelFunc: broadcastArgs
};
const clipByValue = unaryKernelFunc$1(ClipByValue, (xi, attrs) => {
const clipAttrs = attrs;
if (xi > clipAttrs.clipValueMax) {
return clipAttrs.clipValueMax;
}
return xi < clipAttrs.clipValueMin ? clipAttrs.clipValueMin : xi;
});
const clipByValueConfig = {
kernelName: ClipByValue,
backendName: 'cpu',
kernelFunc: clipByValue,
};
const complexAbs = (args) => {
const { x } = args.inputs;
const cpuBackend = args.backend;
const resultValues = new Float32Array(sizeFromShape(x.shape));
const complexVals = cpuBackend.data.get(x.dataId);
const real = complexVals.complexTensorInfos.real;
const imag = complexVals.complexTensorInfos.imag;
const realVals = cpuBackend.data.get(real.dataId).values;
const imagVals = cpuBackend.data.get(imag.dataId).values;
for (let i = 0; i < realVals.length; i++) {
const real = realVals[i];
const imag = imagVals[i];
resultValues[i] = Math.hypot(real, imag);
}
return cpuBackend.makeOutput(resultValues, x.shape, 'float32');
};
const complexAbsConfig = {
kernelName: ComplexAbs,
backendName: 'cpu',
kernelFunc: complexAbs,
};
function imag(args) {
const { inputs, backend } = args;
const { input } = inputs;
const imag = backend.data.get(input.dataId).complexTensorInfos.imag;
const imagVal = backend.data.get(imag.dataId).values;
return backend.makeTensorInfo(imag.shape, imag.dtype, imagVal);
}
const imagConfig = {
kernelName: Imag,
backendName: 'cpu',
kernelFunc: imag
};
function concat(args) {
const { inputs, backend, attrs } = args;
const { axis } = attrs;
const $axis = parseAxisParam(axis, inputs[0].shape)[0];
const shapes = inputs.map(t => t.shape);
assertParamsConsistent(shapes, $axis);
let outShape = computeOutShape$1(inputs.map(t => t.shape), $axis);
if (sizeFromShape(outShape) === 0) {
return backend.makeTensorInfo(outShape, inputs[0].dtype, []);
}
const $inputs = inputs.filter(t => sizeFromShape(t.shape) > 0);
if ($inputs.length === 1) {
return identity$1({ inputs: { x: $inputs[0] }, backend });
}
if ($inputs[0].dtype === 'complex64') {
const reals = $inputs.map((t) => real$1({ inputs: { input: t }, backend }));
const imags = $inputs.map((t) => imag({ inputs: { input: t }, backend }));
const realConcated = concat({ inputs: reals, backend, attrs: { axis: $axis } });
const imagConcated = concat({ inputs: imags, backend, attrs: { axis: $axis } });
const result = complex$1({ inputs: { real: realConcated, imag: imagConcated }, backend });
reals.forEach(r => backend.disposeIntermediateTensorInfo(r));
imags.forEach(i => backend.disposeIntermediateTensorInfo(i));
backend.disposeIntermediateTensorInfo(realConcated);
backend.disposeIntermediateTensorInfo(imagConcated);
return result;
}
const inputs2D = $inputs.map(t => {
const innerSize = sizeFromShape(t.shape.slice($axis));
const shape = [-1, innerSize];
return reshape({ inputs: { x: t }, backend, attrs: { shape } });
});
const inputsValShapes = inputs2D.map(t => {
return { vals: backend.data.get(t.dataId).values, shape: t.shape };
});
outShape =
        computeOutShape$1(inputs2D.map(t => t.shape), 1);
const simplyConcat = inputs2D[0].shape[0] === 1;
const outVals = concatImpl$1(inputsValShapes, outShape, inputs[0].dtype, simplyConcat);
const finalOutShape = computeOutShape$1($inputs.map(t => t.shape), $axis);
const outInfo = backend.makeTensorInfo(finalOutShape, inputs[0].dtype, outVals);
inputs2D.forEach(t => backend.disposeIntermediateTensorInfo(t));
return outInfo;
}
const concatConfig = {
kernelName: Concat,
backendName: 'cpu',
kernelFunc: concat
};
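// Direct 2D convolution: nested loops over batch, output rows/cols, filter
// taps and channels, supporting both channelsLast and channelsFirst layouts.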
function conv2D(args) {
const { inputs, backend, attrs } = args;
const { x, filter } = inputs;
const { strides, pad, dataFormat, dilations, dimRoundingMode } = attrs;
assertNotComplex([x, filter], 'conv2d');
const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false, $dataFormat);
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const padLeft = convInfo.padInfo.left;
const padTop = convInfo.padInfo.top;
const isChannelsLast = convInfo.dataFormat === 'channelsLast';
const y = new TensorBuffer(convInfo.outShape, x.dtype);
const xStrides = computeStrides(x.shape);
const filterStrides = computeStrides(filter.shape);
const xBatchStride = xStrides[0];
const xRowStride = isChannelsLast ? xStrides[1] : xStrides[2];
const xColStride = isChannelsLast ? xStrides[2] : 1;
const xChannelStride = isChannelsLast ? 1 : xStrides[1];
const yBatchStride = y.strides[0];
const yRowStride = isChannelsLast ? y.strides[1] : y.strides[2];
const yColStride = isChannelsLast ? y.strides[2] : 1;
const yChannelStride = isChannelsLast ? 1 : y.strides[1];
const xVals = backend.data.get(x.dataId).values;
const wVals = backend.data.get(filter.dataId).values;
const yVals = y.values;
for (let b = 0; b < convInfo.batchSize; ++b) {
const xOffset1 = b * xBatchStride;
const yOffset1 = b * yBatchStride;
for (let yR = 0; yR < convInfo.outHeight; ++yR) {
const yOffset2 = yOffset1 + yR * yRowStride;
const xRCorner = yR * convInfo.strideHeight - padTop;
for (let wR = 0; wR < filterHeight; ++wR) {
const xR = xRCorner + wR * dilationHeight;
if (xR < 0 || xR >= convInfo.inHeight) {
continue;
}
const wOffset1 = wR * filterStrides[0];
const xOffset2 = xOffset1 + xR * xRowStride;
for (let yC = 0; yC < convInfo.outWidth; ++yC) {
const yOffset3 = yOffset2 + yC * yColStride;
const xCCorner = yC * convInfo.strideWidth - padLeft;
for (let wC = 0; wC < filterWidth; ++wC) {
const xC = xCCorner + wC * dilationWidth;
if (xC < 0 || xC >= convInfo.inWidth) {
continue;
}
const wOffset2 = wOffset1 + wC * filterStrides[1];
const xOffset3 = xOffset2 + xC * xColStride;
let wOffset3 = wOffset2;
for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
const xVal = xVals[xOffset3 + d1 * xChannelStride];
for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
yVals[yOffset3 + d2 * yChannelStride] +=
xVal * wVals[wOffset3 + d2];
}
wOffset3 += convInfo.outChannels;
}
}
}
}
}
}
return backend.makeTensorInfo(y.shape, y.dtype, yVals);
}
const conv2DConfig = {
kernelName: Conv2D,
backendName: 'cpu',
kernelFunc: conv2D
};
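// Gradient of Conv2D with respect to the filter.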
function conv2DBackpropFilter(args) {
const { inputs, backend, attrs } = args;
const { x, dy } = inputs;
const { strides, pad, dataFormat, dimRoundingMode, filterShape } = attrs;
assertNotComplex([x, dy], 'conv2dBackpropFilter');
const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(x.shape, filterShape, strides, 1, pad, dimRoundingMode, false, $dataFormat);
const { strideHeight, strideWidth, filterHeight, filterWidth } = convInfo;
const isChannelsLast = convInfo.dataFormat === 'channelsLast';
const dW = new TensorBuffer(convInfo.filterShape, 'float32');
const leftPad = convInfo.padInfo.left;
const topPad = convInfo.padInfo.top;
const xVals = backend.data.get(x.dataId).values;
const dyVals = backend.data.get(dy.dataId).values;
const xBuf = new TensorBuffer(x.shape, x.dtype, xVals);
const dyBuf = new TensorBuffer(dy.shape, dy.dtype, dyVals);
for (let wR = 0; wR < filterHeight; ++wR) {
const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
const yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
for (let wC = 0; wC < filterWidth; ++wC) {
const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
const yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
let dotProd = 0;
for (let b = 0; b < convInfo.batchSize; ++b) {
for (let yR = yRMin; yR < yRMax; ++yR) {
const xR = wR + yR * strideHeight - topPad;
for (let yC = yCMin; yC < yCMax; ++yC) {
const xC = wC + yC * strideWidth - leftPad;
if (isChannelsLast) {
dotProd += xBuf.get(b, xR, xC, d1) *
dyBuf.get(b, yR, yC, d2);
}
else {
dotProd += xBuf.get(b, d1, xR, xC) *
dyBuf.get(b, d2, yR, yC);
}
}
}
}
dW.set(dotProd, wR, wC, d1, d2);
}
}
}
}
return backend.makeTensorInfo(dW.shape, dW.dtype, dW.values);
}
const conv2DBackpropFilterConfig = {
kernelName: Conv2DBackpropFilter,
backendName: 'cpu',
kernelFunc: conv2DBackpropFilter
};
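// Gradient of Conv2D with respect to the input.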
function conv2DBackpropInput(args) {
const { inputs, backend, attrs } = args;
const { dy, filter } = inputs;
const { inputShape, strides, pad, dataFormat, dimRoundingMode } = attrs;
assertNotComplex([dy, filter], 'conv2dBackpropInput');
const filterStrides = computeStrides(filter.shape);
const dyStrides = computeStrides(dy.shape);
let $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, 1, pad, dimRoundingMode, false, $dataFormat);
const dx = new TensorBuffer(convInfo.inShape, 'float32');
const dxValues = dx.values;
const dyValues = backend.data.get(dy.dataId).values;
const fltValues = backend.data.get(filter.dataId).values;
const [fltS0, fltS1, fltS2] = filterStrides;
const { batchSize, filterHeight, filterWidth, inChannels, inHeight, inWidth, outChannels, outHeight, outWidth, strideHeight, strideWidth } = convInfo;
$dataFormat = convInfo.dataFormat;
const topPad = filterHeight - 1 - convInfo.padInfo.top;
const leftPad = filterWidth - 1 - convInfo.padInfo.left;
const isChannelsLast = $dataFormat === 'channelsLast';
const xBatchStride = dx.strides[0];
const xRowStride = isChannelsLast ? dx.strides[1] : dx.strides[2];
const xColStride = isChannelsLast ? dx.strides[2] : 1;
const xChannelStride = isChannelsLast ? 1 : dx.strides[1];
const yBatchStride = dyStrides[0];
const yRowStride = isChannelsLast ? dyStrides[1] : dyStrides[2];
const yColStride = isChannelsLast ? dyStrides[2] : 1;
const yChannelStride = isChannelsLast ? 1 : dyStrides[1];
for (let b = 0; b < batchSize; ++b) {
for (let d1 = 0; d1 < inChannels; ++d1) {
for (let xR = 0; xR < inHeight; ++xR) {
const xRCorner = xR - topPad;
const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
const yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
for (let xC = 0; xC < inWidth; ++xC) {
const xCCorner = xC - leftPad;
const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
const yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
let dotProd = 0;
for (let yR = xRMin; yR < yRMax; ++yR) {
const wR = yR * strideHeight - xRCorner;
for (let yC = xCMin; yC < yCMax; ++yC) {
const wC = yC * strideWidth - xCCorner;
const dyOffset = yBatchStride * b + yRowStride * yR + yColStride * yC;
const fltOffset = fltS0 * (filterHeight - 1 - wR) +
fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;
for (let d2 = 0; d2 < outChannels; ++d2) {
const pixel = dyValues[dyOffset + yChannelStride * d2];
const weight = fltValues[fltOffset + d2];
dotProd += pixel * weight;
}
}
}
const dxOffset = xBatchStride * b + xRowStride * xR +
xColStride * xC + xChannelStride * d1;
dxValues[dxOffset] = dotProd;
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
const conv2DBackpropInputConfig = {
kernelName: Conv2DBackpropInput,
backendName: 'cpu',
kernelFunc: conv2DBackpropInput
};
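// Direct 3D convolution over NDHWC input.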
function conv3D(args) {
const { inputs, backend, attrs } = args;
const { x, filter } = inputs;
const { strides, pad, dilations } = attrs;
assertNotComplex([x, filter], 'conv3d');
const convInfo = computeConv3DInfo(x.shape, filter.shape, strides, dilations, pad);
const { filterDepth, filterHeight, filterWidth, dilationDepth, dilationHeight, dilationWidth, padInfo } = convInfo;
const padFront = padInfo.front;
const padLeft = padInfo.left;
const padTop = padInfo.top;
const y = new TensorBuffer(convInfo.outShape, x.dtype);
const xVals = backend.data.get(x.dataId).values;
const wVals = backend.data.get(filter.dataId).values;
const yVals = y.values;
const xStrides = computeStrides(x.shape);
const filterStrides = computeStrides(filter.shape);
for (let b = 0; b < convInfo.batchSize; ++b) {
const xOffset1 = b * xStrides[0];
const yOffset1 = b * y.strides[0];
for (let yF = 0; yF < convInfo.outDepth; ++yF) {
const yOffset2 = yOffset1 + yF * y.strides[1];
const xFCorner = yF * convInfo.strideDepth - padFront;
for (let wF = 0; wF < filterDepth; ++wF) {
const xF = xFCorner + wF * dilationDepth;
if (xF < 0 || xF >= convInfo.inDepth) {
continue;
}
const wOffset1 = wF * filterStrides[0];
const xOffset2 = xOffset1 + xF * xStrides[1];
for (let yR = 0; yR < convInfo.outHeight; ++yR) {
const yOffset3 = yOffset2 + yR * y.strides[2];
const xRCorner = yR * convInfo.strideHeight - padTop;
for (let wR = 0; wR < filterHeight; ++wR) {
const xR = xRCorner + wR * dilationHeight;
if (xR < 0 || xR >= convInfo.inHeight) {
continue;
}
const wOffset2 = wOffset1 + wR * filterStrides[1];
const xOffset3 = xOffset2 + xR * xStrides[2];
for (let yC = 0; yC < convInfo.outWidth; ++yC) {
const yOffset4 = yOffset3 + yC * convInfo.outChannels;
const xCCorner = yC * convInfo.strideWidth - padLeft;
for (let wC = 0; wC < filterWidth; ++wC) {
const xC = xCCorner + wC * dilationWidth;
if (xC < 0 || xC >= convInfo.inWidth) {
continue;
}
const wOffset3 = wOffset2 + wC * filterStrides[2];
const xOffset4 = xOffset3 + xC * convInfo.inChannels;
let wOffset4 = wOffset3;
for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
const xVal = xVals[xOffset4 + d1];
for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
yVals[yOffset4 + d2] += xVal * wVals[wOffset4 + d2];
}
wOffset4 += convInfo.outChannels;
}
}
}
}
}
}
}
}
return backend.makeTensorInfo(y.shape, y.dtype, y.values);
}
const conv3DConfig = {
kernelName: Conv3D,
backendName: 'cpu',
kernelFunc: conv3D
};
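// Gradient of Conv3D with respect to the filter.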
function conv3DBackpropFilterV2(args) {
const { inputs, backend, attrs } = args;
const { x, dy } = inputs;
const { strides, pad, filterShape } = attrs;
assertNotComplex([x, dy], 'conv3dBackpropFilterV2');
const xStrides = computeStrides(x.shape);
const dyStrides = computeStrides(dy.shape);
    const convInfo = computeConv3DInfo(x.shape, filterShape, strides, 1, pad);
const strideDepth = convInfo.strideDepth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const filterDepth = convInfo.filterDepth;
const filterHeight = convInfo.filterHeight;
const filterWidth = convInfo.filterWidth;
const dw = new TensorBuffer(convInfo.filterShape, 'float32');
const dwValues = dw.values;
const [dwS0, dwS1, dwS2, dwS3] = dw.strides;
const dyValues = backend.data.get(dy.dataId).values;
const [dyS0, dyS1, dyS2, dyS3] = dyStrides;
const xValues = backend.data.get(x.dataId).values;
const [xS0, xS1, xS2, xS3] = xStrides;
const frontPad = convInfo.padInfo.front;
const leftPad = convInfo.padInfo.left;
const topPad = convInfo.padInfo.top;
for (let wF = 0; wF < filterDepth; ++wF) {
const yFMin = Math.max(0, Math.ceil((frontPad - wF) / strideDepth));
const yFMax = Math.min(convInfo.outDepth, (convInfo.inDepth + frontPad - wF) / strideDepth);
const wOffset1 = wF * dwS0;
for (let wR = 0; wR < filterHeight; ++wR) {
const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
const yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
const wOffset2 = wR * dwS1 + wOffset1;
for (let wC = 0; wC < filterWidth; ++wC) {
const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
const yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
const wOffset3 = wC * dwS2 + wOffset2;
for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
const wOffset4 = d1 * dwS3 + wOffset3;
for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
let dotProd = 0;
for (let b = 0; b < convInfo.batchSize; ++b) {
const xOffset1 = b * xS0;
const yOffset1 = b * dyS0;
for (let yF = yFMin; yF < yFMax; ++yF) {
const xF = wF + yF * strideDepth - frontPad;
const xOffset2 = xF * xS1 + xOffset1;
const yOffset2 = yF * dyS1 + yOffset1;
for (let yR = yRMin; yR < yRMax; ++yR) {
const xR = wR + yR * strideHeight - topPad;
const xOffset3 = xR * xS2 + xOffset2;
const yOffset3 = yR * dyS2 + yOffset2;
for (let yC = yCMin; yC < yCMax; ++yC) {
const xC = wC + yC * strideWidth - leftPad;
const xOffset4 = xC * xS3 + xOffset3;
const yOffset4 = yC * dyS3 + yOffset3;
dotProd += xValues[xOffset4 + d1] * dyValues[yOffset4 + d2];
}
}
}
}
dwValues[wOffset4 + d2] = dotProd;
}
}
}
}
}
return backend.makeTensorInfo(dw.shape, dw.dtype, dw.values);
}
const conv3DBackpropFilterV2Config = {
kernelName: Conv3DBackpropFilterV2,
backendName: 'cpu',
kernelFunc: conv3DBackpropFilterV2
};
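// Gradient of Conv3D with respect to the input: each input position gathers
// contributions from dy by correlating it with the spatially flipped filter
// (note the filterDepth - 1 - wF style indexing below).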
function conv3DBackpropInputV2(args) {
const { inputs, backend, attrs } = args;
const { dy, filter } = inputs;
const { pad, strides, inputShape } = attrs;
assertNotComplex([dy], 'conv3dBackpropInputV2');
const dyStrides = computeStrides(dy.shape);
const filterStrides = computeStrides(filter.shape);
    const convInfo = computeConv3DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad);
const dx = new TensorBuffer(convInfo.inShape, 'float32');
const dxValues = dx.values;
const [dxS0, dxS1, dxS2, dxS3] = dx.strides;
const dyValues = backend.data.get(dy.dataId).values;
const [dyS0, dyS1, dyS2, dyS3] = dyStrides;
const fltValues = backend.data.get(filter.dataId).values;
const [fltS0, fltS1, fltS2, fltS3] = filterStrides;
const { batchSize, filterDepth, filterHeight, filterWidth, inChannels, inDepth, inHeight, inWidth, outChannels, outDepth, outHeight, outWidth, strideDepth, strideHeight, strideWidth } = convInfo;
const frontPad = filterDepth - 1 - convInfo.padInfo.front;
const topPad = filterHeight - 1 - convInfo.padInfo.top;
const leftPad = filterWidth - 1 - convInfo.padInfo.left;
for (let b = 0; b < batchSize; ++b) {
for (let d1 = 0; d1 < inChannels; ++d1) {
for (let xF = 0; xF < inDepth; ++xF) {
const xFCorner = xF - frontPad;
const xFMin = Math.max(0, Math.ceil(xFCorner / strideDepth));
const yFMax = Math.min(outDepth, (filterDepth + xFCorner) / strideDepth);
for (let xR = 0; xR < inHeight; ++xR) {
const xRCorner = xR - topPad;
const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
const yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
for (let xC = 0; xC < inWidth; ++xC) {
const xCCorner = xC - leftPad;
const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
const yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
let dotProd = 0;
for (let yF = xFMin; yF < yFMax; ++yF) {
const wF = yF * strideDepth - xFCorner;
for (let yR = xRMin; yR < yRMax; ++yR) {
const wR = yR * strideHeight - xRCorner;
for (let yC = xCMin; yC < yCMax; ++yC) {
const wC = yC * strideWidth - xCCorner;
const dyOffset = dyS0 * b + dyS1 * yF + dyS2 * yR + dyS3 * yC;
const fltOffset = fltS0 * (filterDepth - 1 - wF) +
fltS1 * (filterHeight - 1 - wR) +
fltS2 * (filterWidth - 1 - wC) + fltS3 * d1;
for (let d2 = 0; d2 < outChannels; ++d2) {
const pixel = dyValues[dyOffset + d2];
const weight = fltValues[fltOffset + d2];
dotProd += pixel * weight;
}
}
}
}
dxValues[dxS0 * b + dxS1 * xF + dxS2 * xR + dxS3 * xC + d1] =
dotProd;
}
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
const conv3DBackpropInputV2Config = {
kernelName: Conv3DBackpropInputV2,
backendName: 'cpu',
kernelFunc: conv3DBackpropInputV2
};
const cos = unaryKernelFunc$1(Cos, (xi) => Math.cos(xi));
const cosConfig = {
kernelName: Cos,
backendName: 'cpu',
kernelFunc: cos,
};
const cosh = unaryKernelFunc$1(Cosh, (xi) => Math.cosh(xi));
const coshConfig = {
kernelName: Cosh,
backendName: 'cpu',
kernelFunc: cosh,
};
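// CropAndResize: crops each normalized box [y1, x1, y2, x2] out of its source
// image and resizes it to cropSize using either bilinear or nearest-neighbor
// sampling. Samples outside the image are filled with extrapolationValue, and
// boxes whose batch index is >= the batch size are skipped.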
function cropAndResize(args) {
const { inputs, backend, attrs } = args;
const { image, boxes, boxInd } = inputs;
const { cropSize, method, extrapolationValue } = attrs;
const [batch, imageHeight, imageWidth, numChannels] = image.shape;
const numBoxes = boxes.shape[0];
const [cropHeight, cropWidth] = cropSize;
const output = buffer([numBoxes, cropHeight, cropWidth, numChannels], 'float32');
const boxVals = backend.data.get(boxes.dataId).values;
const boxIndVals = backend.data.get(boxInd.dataId).values;
const imageVals = backend.data.get(image.dataId).values;
const inStride = computeStrides(image.shape);
const outStride = computeStrides(output.shape);
for (let b = 0; b < numBoxes; b++) {
const startInd = b * 4;
const y1 = boxVals[startInd];
const x1 = boxVals[startInd + 1];
const y2 = boxVals[startInd + 2];
const x2 = boxVals[startInd + 3];
const bInd = boxIndVals[b];
if (bInd >= batch) {
continue;
}
const heightScale = (cropHeight > 1) ? (y2 - y1) * (imageHeight - 1) / (cropHeight - 1) : 0;
const widthScale = (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : 0;
for (let y = 0; y < cropHeight; y++) {
const yInd = (cropHeight > 1) ?
y1 * (imageHeight - 1) + y * (heightScale) :
0.5 * (y1 + y2) * (imageHeight - 1);
if (yInd < 0 || yInd > imageHeight - 1) {
for (let x = 0; x < cropWidth; x++) {
for (let c = 0; c < numChannels; c++) {
const ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
output.values[ind] = extrapolationValue;
}
}
continue;
}
if (method === 'bilinear') {
const topInd = Math.floor(yInd);
const bottomInd = Math.ceil(yInd);
const yLerp = yInd - topInd;
for (let x = 0; x < cropWidth; x++) {
const xInd = (cropWidth > 1) ?
x1 * (imageWidth - 1) + x * widthScale :
0.5 * (x1 + x2) * (imageWidth - 1);
if (xInd < 0 || xInd > imageWidth - 1) {
for (let c = 0; c < numChannels; c++) {
const ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
output.values[ind] = extrapolationValue;
}
continue;
}
const leftInd = Math.floor(xInd);
const rightInd = Math.ceil(xInd);
const xLerp = xInd - leftInd;
for (let c = 0; c < numChannels; c++) {
let ind = c + leftInd * inStride[2] + topInd * inStride[1] +
bInd * inStride[0];
const topLeft = imageVals[ind];
ind = c + rightInd * inStride[2] + topInd * inStride[1] +
bInd * inStride[0];
const topRight = imageVals[ind];
ind = c + leftInd * inStride[2] + bottomInd * inStride[1] +
bInd * inStride[0];
const bottomLeft = imageVals[ind];
ind = c + rightInd * inStride[2] + bottomInd * inStride[1] +
bInd * inStride[0];
const bottomRight = imageVals[ind];
const top = topLeft + (topRight - topLeft) * xLerp;
const bottom = bottomLeft + (bottomRight - bottomLeft) * xLerp;
ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
output.values[ind] = top + ((bottom - top) * yLerp);
}
}
}
else {
for (let x = 0; x < cropWidth; ++x) {
const xInd = (cropWidth > 1) ?
x1 * (imageWidth - 1) + x * widthScale :
0.5 * (x1 + x2) * (imageWidth - 1);
if (xInd < 0 || xInd > imageWidth - 1) {
for (let c = 0; c < numChannels; c++) {
const ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
output.values[ind] = extrapolationValue;
}
continue;
}
const closestX = Math.round(xInd);
const closestY = Math.round(yInd);
for (let c = 0; c < numChannels; c++) {
const inInd = c + closestX * inStride[2] + closestY * inStride[1] +
bInd * inStride[0];
const outInd = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
output.values[outInd] = imageVals[inInd];
}
}
}
}
}
return backend.makeTensorInfo(output.shape, output.dtype, output.values);
}
const cropAndResizeConfig = {
kernelName: CropAndResize,
backendName: 'cpu',
kernelFunc: cropAndResize
};
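// Cumulative product along `axis`: the axis is transposed to the innermost
// position if necessary, a running product is computed over each inner row
// (starting at 1 and shifted by one element when `exclusive` is set, scanning
// from the end when `reverse` is set), then the transpose is undone.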
function cumprod(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis, exclusive, reverse } = attrs;
assertNotComplex(x, 'cumprod');
const permutation = getAxesPermutation([axis], x.shape.length);
let $x = x;
if (permutation != null) {
$x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutation } });
}
const permutedAxis = getInnerMostAxes(1, x.shape.length)[0];
if (permutedAxis !== $x.shape.length - 1) {
throw new Error(`backend.cumprod in CPU expects an inner-most ` +
`axis=${$x.shape.length - 1} but got axis=${permutedAxis}`);
}
const resultDtype = upcastType($x.dtype, 'int32');
const vals = makeOnesTypedArray(sizeFromShape($x.shape), resultDtype);
const aVals = backend.data.get($x.dataId).values;
const finalDim = $x.shape[$x.shape.length - 1];
const indexAdjuster = reverse ?
(i, j) => i + finalDim - j - 1 :
(i, j) => i + j;
for (let i = 0; i < aVals.length; i += finalDim) {
for (let j = 0; j < finalDim; j++) {
const idx = indexAdjuster(i, j);
if (j === 0) {
vals[idx] = exclusive ? 1 : aVals[idx];
}
else {
const prevIdx = indexAdjuster(i, j - 1);
vals[idx] = exclusive ? aVals[prevIdx] * vals[prevIdx] :
aVals[idx] * vals[prevIdx];
}
}
}
const result = backend.makeTensorInfo($x.shape, resultDtype, vals);
if (permutation != null) {
const reversePermutation = getUndoAxesPermutation(permutation);
const reverseTransposedResult = transpose$1({ inputs: { x: result }, backend, attrs: { perm: reversePermutation } });
backend.disposeIntermediateTensorInfo(result);
backend.disposeIntermediateTensorInfo($x);
return reverseTransposedResult;
}
return result;
}
const cumprodConfig = {
kernelName: Cumprod,
backendName: 'cpu',
kernelFunc: cumprod
};
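// Cumulative sum along `axis`: same structure as cumprod above, but with a
// running sum that starts at 0 in the exclusive case.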
function cumsum(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis, exclusive, reverse } = attrs;
assertNotComplex(x, 'cumsum');
const permutation = getAxesPermutation([axis], x.shape.length);
let $x = x;
if (permutation != null) {
$x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutation } });
}
const permutedAxis = getInnerMostAxes(1, x.shape.length)[0];
if (permutedAxis !== $x.shape.length - 1) {
throw new Error(`backend.cumsum in CPU expects an inner-most ` +
`axis=${$x.shape.length - 1} but got axis=${permutedAxis}`);
}
const resultDtype = upcastType($x.dtype, 'int32');
const vals = makeZerosTypedArray(sizeFromShape($x.shape), resultDtype);
const aVals = backend.data.get($x.dataId).values;
const finalDim = $x.shape[$x.shape.length - 1];
const indexAdjuster = reverse ?
(i, j) => i + finalDim - j - 1 :
(i, j) => i + j;
for (let i = 0; i < aVals.length; i += finalDim) {
for (let j = 0; j < finalDim; j++) {
const idx = indexAdjuster(i, j);
if (j === 0) {
vals[idx] = exclusive ? 0 : aVals[idx];
}
else {
const prevIdx = indexAdjuster(i, j - 1);
vals[idx] = exclusive ? aVals[prevIdx] + vals[prevIdx] :
aVals[idx] + vals[prevIdx];
}
}
}
const result = backend.makeTensorInfo($x.shape, resultDtype, vals);
if (permutation != null) {
const reversePermutation = getUndoAxesPermutation(permutation);
const reverseTransposedResult = transpose$1({ inputs: { x: result }, backend, attrs: { perm: reversePermutation } });
backend.disposeIntermediateTensorInfo(result);
backend.disposeIntermediateTensorInfo($x);
return reverseTransposedResult;
}
return result;
}
const cumsumConfig = {
kernelName: Cumsum,
backendName: 'cpu',
kernelFunc: cumsum
};
function denseBincount(args) {
const { inputs, backend, attrs } = args;
const { x, weights } = inputs;
const { size, binaryOutput } = attrs;
if (x.shape.length === 1) {
const xVals = backend.data.get(x.dataId).values;
const weightsVals = backend.data.get(weights.dataId).values;
const outVals = bincountImpl(xVals, weightsVals, weights.dtype, weights.shape, size);
return backend.makeTensorInfo([size], weights.dtype, outVals);
}
else if (x.shape.length === 2) {
const xBuf = backend.bufferSync(x);
const weightsBuf = backend.bufferSync(weights);
const outBuf = bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput);
return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values);
}
    throw new Error(`Error in denseBincount: input must be at most rank 2, but got rank ` +
        `${x.shape.length}.`);
}
const denseBincountConfig = {
kernelName: DenseBincount,
backendName: 'cpu',
kernelFunc: denseBincount
};
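// DepthToSpace (NHWC only): rearranges blockSize * blockSize groups of
// channels into spatial blocks, producing an output of shape
// [batch, H * blockSize, W * blockSize, depth / blockSize^2].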
function depthToSpace(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { blockSize, dataFormat } = attrs;
assert$1(dataFormat === 'NHWC', () => `Only NHWC dataFormat supported on CPU for depthToSpace. Got ${dataFormat}`);
const batchSize = x.shape[0];
const inputHeight = x.shape[1];
const inputWidth = x.shape[2];
const inputDepth = x.shape[3];
const outputHeight = inputHeight * blockSize;
const outputWidth = inputWidth * blockSize;
const outputDepth = inputDepth / (blockSize * blockSize);
const xValues = backend.data.get(x.dataId).values;
const result = new Float32Array(batchSize * outputHeight * outputWidth * outputDepth);
let outputIdx = 0;
for (let b = 0; b < batchSize; ++b) {
for (let h = 0; h < outputHeight; ++h) {
const inH = Math.floor(h / blockSize);
const offsetH = (h % blockSize);
for (let w = 0; w < outputWidth; ++w) {
const inW = Math.floor(w / blockSize);
const offsetW = (w % blockSize);
const offsetD = (offsetH * blockSize + offsetW) * outputDepth;
for (let d = 0; d < outputDepth; ++d) {
const inD = d + offsetD;
const inputIdx = inD + inputDepth * (inW + inputWidth * (inH + inputHeight * b));
result[outputIdx++] = xValues[inputIdx];
}
}
}
}
return backend.makeTensorInfo([batchSize, outputHeight, outputWidth, outputDepth], x.dtype, result);
}
const depthToSpaceConfig = {
kernelName: DepthToSpace,
backendName: 'cpu',
kernelFunc: depthToSpace
};
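// Depthwise 2D convolution: each input channel d1 is convolved with its own
// chMul (= outChannels / inChannels) filters, and the results are written to
// output channels d1 * chMul .. d1 * chMul + chMul - 1.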
function depthwiseConv2dNative(args) {
const { inputs, backend, attrs } = args;
const { x, filter } = inputs;
const { strides, pad, dilations, dimRoundingMode } = attrs;
assertNotComplex([x, filter], 'depthwiseConv2DNative');
const xStrides = computeStrides(x.shape);
const filterStrides = computeStrides(filter.shape);
let $dilations = dilations;
if ($dilations == null) {
$dilations = [1, 1];
}
assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
`1. Got strides ${strides} and dilations '${$dilations}'`);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true /* depthwise */);
const { filterHeight, filterWidth, dilationHeight, dilationWidth, padInfo } = convInfo;
const padLeft = padInfo.left;
const padTop = padInfo.top;
const chMul = convInfo.outChannels / convInfo.inChannels;
const y = new TensorBuffer(convInfo.outShape, x.dtype);
const xVals = backend.data.get(x.dataId).values;
const wVals = backend.data.get(filter.dataId).values;
const yVals = y.values;
for (let b = 0; b < convInfo.batchSize; ++b) {
const xOffset1 = b * xStrides[0];
const yOffset1 = b * y.strides[0];
for (let yR = 0; yR < convInfo.outHeight; ++yR) {
const yOffset2 = yOffset1 + yR * y.strides[1];
const xRCorner = yR * convInfo.strideHeight - padTop;
for (let wR = 0; wR < filterHeight; ++wR) {
const xR = xRCorner + wR * dilationHeight;
if (xR < 0 || xR >= convInfo.inHeight) {
continue;
}
const wOffset1 = wR * filterStrides[0];
const xOffset2 = xOffset1 + xR * xStrides[1];
for (let yC = 0; yC < convInfo.outWidth; ++yC) {
const yOffset3 = yOffset2 + yC * y.strides[2];
const xCCorner = yC * convInfo.strideWidth - padLeft;
for (let wC = 0; wC < filterWidth; ++wC) {
const xC = xCCorner + wC * dilationWidth;
if (xC < 0 || xC >= convInfo.inWidth) {
continue;
}
const wOffset2 = wOffset1 + wC * filterStrides[1];
const xOffset3 = xOffset2 + xC * convInfo.inChannels;
let yOffset4 = yOffset3;
let wOffset3 = wOffset2;
for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
const xVal = xVals[xOffset3 + d1];
for (let q = 0; q < chMul; ++q) {
yVals[yOffset4 + q] += xVal * wVals[wOffset3 + q];
}
yOffset4 += chMul;
wOffset3 += chMul;
}
}
}
}
}
}
return backend.makeTensorInfo(y.shape, y.dtype, y.values);
}
const depthwiseConv2dNativeConfig = {
kernelName: DepthwiseConv2dNative,
backendName: 'cpu',
kernelFunc: depthwiseConv2dNative
};
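// Gradient of DepthwiseConv2dNative with respect to the filter: each filter
// tap (wR, wC, d1, dm) accumulates x[b, xR, xC, d1] * dy[b, yR, yC, d2] over
// all batches and valid output positions, where d2 = d1 * chMul + dm.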
function depthwiseConv2dNativeBackpropFilter(args) {
const { inputs, backend, attrs } = args;
const { x, dy } = inputs;
const { strides, dilations, pad, dimRoundingMode, filterShape } = attrs;
assertNotComplex([x, dy], 'depthwiseConv2dNativeBackpropFilter');
    const convInfo = computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
const { strideHeight, strideWidth, filterHeight, filterWidth } = convInfo;
const dW = new TensorBuffer(convInfo.filterShape, 'float32');
const leftPad = convInfo.padInfo.left;
const topPad = convInfo.padInfo.top;
const chMul = convInfo.outChannels / convInfo.inChannels;
const xVals = backend.data.get(x.dataId).values;
const xBuf = new TensorBuffer(x.shape, x.dtype, xVals);
const dyVals = backend.data.get(dy.dataId).values;
const dyBuf = new TensorBuffer(dy.shape, dy.dtype, dyVals);
for (let wR = 0; wR < filterHeight; ++wR) {
const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
const yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
for (let wC = 0; wC < filterWidth; ++wC) {
const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
const yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
const d1 = Math.trunc(d2 / chMul);
const dm = d2 % chMul;
let dotProd = 0;
for (let b = 0; b < convInfo.batchSize; ++b) {
for (let yR = yRMin; yR < yRMax; ++yR) {
const xR = wR + yR * strideHeight - topPad;
for (let yC = yCMin; yC < yCMax; ++yC) {
const xC = wC + yC * strideWidth - leftPad;
dotProd += xBuf.get(b, xR, xC, d1) *
dyBuf.get(b, yR, yC, d2);
}
}
}
dW.set(dotProd, wR, wC, d1, dm);
}
}
}
return backend.makeTensorInfo(dW.shape, dW.dtype, dW.values);
}
const depthwiseConv2dNativeBackpropFilterConfig = {
kernelName: DepthwiseConv2dNativeBackpropFilter,
backendName: 'cpu',
kernelFunc: depthwiseConv2dNativeBackpropFilter
};
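// Gradient of DepthwiseConv2dNative with respect to the input: every input
// position accumulates dy correlated with the spatially flipped filter,
// summing over the chMul output channels that belong to its input channel.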
function depthwiseConv2dNativeBackpropInput(args) {
const { inputs, backend, attrs } = args;
const { dy, filter } = inputs;
const { strides, dilations, pad, dimRoundingMode, inputShape } = attrs;
assertNotComplex([dy, filter], 'depthwiseConv2DNativeBackpropInput');
const dyStrides = computeStrides(dy.shape);
const filterStrides = computeStrides(filter.shape);
    const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
const dx = new TensorBuffer(convInfo.inShape, 'float32');
const dxValues = dx.values;
const [dxS0, dxS1, dxS2] = dx.strides;
const dyValues = backend.data.get(dy.dataId).values;
const [dyS0, dyS1, dyS2] = dyStrides;
const fltValues = backend.data.get(filter.dataId).values;
const [fltS0, fltS1, fltS2] = filterStrides;
const { batchSize, filterHeight, filterWidth, inChannels, inHeight, inWidth, outChannels, outHeight, outWidth, strideHeight, strideWidth } = convInfo;
const topPad = filterHeight - 1 - convInfo.padInfo.top;
const leftPad = filterWidth - 1 - convInfo.padInfo.left;
const chMul = outChannels / inChannels;
for (let b = 0; b < batchSize; ++b) {
for (let d1 = 0; d1 < inChannels; ++d1) {
for (let xR = 0; xR < inHeight; ++xR) {
const xRCorner = xR - topPad;
const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
const yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
for (let xC = 0; xC < inWidth; ++xC) {
const xCCorner = xC - leftPad;
const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
const yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
let dotProd = 0;
for (let yR = xRMin; yR < yRMax; ++yR) {
const wR = yR * strideHeight - xRCorner;
for (let yC = xCMin; yC < yCMax; ++yC) {
const wC = yC * strideWidth - xCCorner;
const dyOffset = dyS0 * b + dyS1 * yR + dyS2 * yC;
const fltOffset = fltS0 * (filterHeight - 1 - wR) +
fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;
for (let dm = 0; dm < chMul; ++dm) {
const d2 = d1 * chMul + dm;
const pixel = dyValues[dyOffset + d2];
const weight = fltValues[fltOffset + dm];
dotProd += pixel * weight;
}
}
}
dxValues[dxS0 * b + dxS1 * xR + dxS2 * xC + d1] = dotProd;
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
const depthwiseConv2dNativeBackpropInputConfig = {
kernelName: DepthwiseConv2dNativeBackpropInput,
backendName: 'cpu',
kernelFunc: depthwiseConv2dNativeBackpropInput
};
function diag(args) {
const { inputs, backend } = args;
const { x } = inputs;
const xSize = sizeFromShape(x.shape);
const xVals = backend.data.get(x.dataId).values;
const outBuf = buffer([xSize, xSize], x.dtype);
const vals = outBuf.values;
for (let i = 0; i < xVals.length; i++) {
vals[i * xSize + i] = xVals[i];
}
const outShape = [...x.shape, ...x.shape];
return backend.makeTensorInfo(outShape, outBuf.dtype, outBuf.values);
}
const diagConfig = {
kernelName: Diag,
backendName: 'cpu',
kernelFunc: diag
};
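// Grayscale morphological dilation: for every output position and channel,
// takes the maximum of input + filter over the (dilated) filter window.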
const dilation2DConfig = {
kernelName: Dilation2D,
backendName: 'cpu',
kernelFunc: ({ inputs, backend, attrs }) => {
const { x, filter } = inputs;
const { strides, pad, dilations } = attrs;
const cpuBackend = backend;
const xVals = cpuBackend.data.get(x.dataId).values;
const xRank = x.shape.length;
const filterVals = cpuBackend.data.get(filter.dataId).values;
const filterRank = filter.shape.length;
        const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
const outSize = sizeFromShape(outShape);
const outRank = outShape.length;
const outputVals = getArrayFromDType(x.dtype, outSize);
for (let b = 0; b < batchSize; ++b) {
for (let hOut = 0; hOut < outHeight; ++hOut) {
const hBeg = hOut * strideHeight - padInfo.top;
for (let wOut = 0; wOut < outWidth; ++wOut) {
const wBeg = wOut * strideWidth - padInfo.left;
for (let d = 0; d < inChannels; ++d) {
let curVal = Number.MIN_SAFE_INTEGER;
for (let h = 0; h < filterHeight; ++h) {
const hIn = hBeg + h * dilationHeight;
if (hIn >= 0 && hIn < inHeight) {
for (let w = 0; w < filterWidth; ++w) {
const wIn = wBeg + w * dilationWidth;
if (wIn >= 0 && wIn < inWidth) {
const xIndex = locToIndex([b, hIn, wIn, d], xRank, computeStrides(x.shape));
const filterIndex = locToIndex([h, w, d], filterRank, computeStrides(filter.shape));
const val = xVals[xIndex] + filterVals[filterIndex];
if (val > curVal) {
curVal = val;
}
}
}
}
}
const outputIndex = locToIndex([b, hOut, wOut, d], outRank, computeStrides(outShape));
outputVals[outputIndex] = curVal;
}
}
}
}
const dataId = cpuBackend.write(toTypedArray(outputVals, x.dtype), outShape, x.dtype);
return { dataId, shape: outShape, dtype: x.dtype };
}
};
const dilation2DBackpropFilterConfig = {
kernelName: Dilation2DBackpropFilter,
backendName: 'cpu',
kernelFunc: ({ inputs, backend, attrs }) => {
const { x, filter, dy } = inputs;
const { strides, pad, dilations } = attrs;
const cpuBackend = backend;
const $x = toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
const $filter = toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
        const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
assert$1(dy.rank === outShape.length, () => `Error in ${Dilation2DBackpropFilter}, dy ` +
`must have the same rank as output ${outShape.length}, but got ` +
`${dy.rank}`);
const $dy = toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values);
const gradients = makeZerosNestedTypedArray(filter.shape, filter.dtype);
for (let b = 0; b < batchSize; ++b) {
for (let hOut = 0; hOut < outHeight; ++hOut) {
const hBeg = hOut * strideHeight - padInfo.top;
for (let wOut = 0; wOut < outWidth; ++wOut) {
const wBeg = wOut * strideWidth - padInfo.left;
for (let d = 0; d < inChannels; ++d) {
let curVal = Number.MIN_SAFE_INTEGER;
let hMax = 0;
let wMax = 0;
for (let h = 0; h < filterHeight; ++h) {
const hIn = hBeg + h * dilationHeight;
if (hIn >= 0 && hIn < inHeight) {
for (let w = 0; w < filterWidth; ++w) {
const wIn = wBeg + w * dilationWidth;
if (wIn >= 0 && wIn < inWidth) {
const val = $x[b][hIn][wIn][d] + $filter[h][w][d];
if (val > curVal) {
curVal = val;
hMax = h;
wMax = w;
}
}
}
}
}
gradients[hMax][wMax][d] += $dy[b][hOut][wOut][d];
}
}
}
}
const dataId = cpuBackend.write(toTypedArray(gradients, x.dtype), filter.shape, filter.dtype);
return { dataId, shape: filter.shape, dtype: filter.dtype };
}
};
const dilation2DBackpropInputConfig = {
kernelName: Dilation2DBackpropInput,
backendName: 'cpu',
kernelFunc: ({ inputs, backend, attrs }) => {
const { x, filter, dy } = inputs;
const { strides, pad, dilations } = attrs;
const cpuBackend = backend;
const $x = toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
const $filter = toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
        const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
assert$1(dy.rank === outShape.length, () => `Error in ${Dilation2DBackpropInput}, dy ` +
`must have the same rank as output ${outShape.length}, but got ` +
`${dy.rank}`);
const $dy = toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values);
const gradients = makeZerosNestedTypedArray(x.shape, x.dtype);
for (let b = 0; b < batchSize; ++b) {
for (let hOut = 0; hOut < outHeight; ++hOut) {
const hBeg = hOut * strideHeight - padInfo.top;
for (let wOut = 0; wOut < outWidth; ++wOut) {
const wBeg = wOut * strideWidth - padInfo.left;
for (let d = 0; d < inChannels; ++d) {
let curVal = Number.MIN_SAFE_INTEGER;
let hInMax = (hBeg < 0) ? 0 : hBeg;
let wInMax = (wBeg < 0) ? 0 : wBeg;
for (let h = 0; h < filterHeight; ++h) {
const hIn = hBeg + h * dilationHeight;
if (hIn >= 0 && hIn < inHeight) {
for (let w = 0; w < filterWidth; ++w) {
const wIn = wBeg + w * dilationWidth;
if (wIn >= 0 && wIn < inWidth) {
const val = $x[b][hIn][wIn][d] + $filter[h][w][d];
if (val > curVal) {
curVal = val;
hInMax = hIn;
wInMax = wIn;
}
}
}
}
}
gradients[b][hInMax][wInMax][d] += $dy[b][hOut][wOut][d];
}
}
}
}
const dataId = cpuBackend.write(toTypedArray(gradients, x.dtype), x.shape, x.dtype);
return { dataId, shape: x.shape, dtype: x.dtype };
}
};
function draw(args) {
const { inputs, backend, attrs } = args;
const { image } = inputs;
const { canvas, options } = attrs;
const { contextOptions, imageOptions } = options || {};
const alpha = (imageOptions === null || imageOptions === void 0 ? void 0 : imageOptions.alpha) || 1;
const contextType = (contextOptions === null || contextOptions === void 0 ? void 0 : contextOptions.contextType) || '2d';
if (contextType !== '2d') {
throw new Error(`Context type ${contextOptions.contextType} is not supported by the CPU backend.`);
}
const ctx = canvas.getContext(contextType, (contextOptions === null || contextOptions === void 0 ? void 0 : contextOptions.contextAttributes) || {});
if (ctx == null) {
throw new Error(`Could not get the context with ${contextType} type.`);
}
const [height, width] = image.shape.slice(0, 2);
const depth = image.shape.length === 2 ? 1 : image.shape[2];
const data = backend.data.get(image.dataId).values;
const multiplier = image.dtype === 'float32' ? 255 : 1;
const bytes = new Uint8ClampedArray(width * height * 4);
for (let i = 0; i < height * width; ++i) {
const rgba = [0, 0, 0, 255 * alpha];
for (let d = 0; d < depth; d++) {
const value = data[i * depth + d];
if (image.dtype === 'float32') {
if (value < 0 || value > 1) {
throw new Error(`Tensor values for a float32 Tensor must be in the ` +
`range [0 - 1] but encountered ${value}.`);
}
}
else if (image.dtype === 'int32') {
if (value < 0 || value > 255) {
throw new Error(`Tensor values for a int32 Tensor must be in the ` +
`range [0 - 255] but encountered ${value}.`);
}
}
if (depth === 1) {
rgba[0] = value * multiplier;
rgba[1] = value * multiplier;
rgba[2] = value * multiplier;
}
else {
rgba[d] = value * multiplier;
}
}
const j = i * 4;
bytes[j + 0] = Math.round(rgba[0]);
bytes[j + 1] = Math.round(rgba[1]);
bytes[j + 2] = Math.round(rgba[2]);
bytes[j + 3] = Math.round(rgba[3]);
}
canvas.width = width;
canvas.height = height;
const imageData = new ImageData(bytes, width, height);
ctx.putImageData(imageData, 0, 0);
return image;
}
const drawConfig = {
kernelName: Draw,
backendName: 'cpu',
kernelFunc: draw
};
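// Sum reduction: bool inputs are cast to int32, the reduction axes are
// permuted to the innermost positions, and each output value is the sum of a
// contiguous block of `reduceSize` values of the permuted input.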
function sum(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis, keepDims } = attrs;
assertNotComplex(x, 'sum');
let $x;
if (x.dtype === 'bool') {
$x = cast$2({ inputs: { x }, backend, attrs: { dtype: 'int32' } });
}
else {
$x = identity$1({ inputs: { x }, backend });
}
const xRank = $x.shape.length;
const axes = parseAxisParam(axis, $x.shape);
const permutation = getAxesPermutation(axes, xRank);
let reductionAxes = axes;
let permutedX = $x;
if (permutation != null) {
permutedX =
transpose$1({ inputs: { x: $x }, backend, attrs: { perm: permutation } });
reductionAxes = getInnerMostAxes(reductionAxes.length, xRank);
}
assertAxesAreInnerMostDims('sum', reductionAxes, permutedX.shape.length);
const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, reductionAxes);
const resultDtype = upcastType(permutedX.dtype, 'int32');
let result = zeros(backend, outShape, resultDtype);
const reduceSize = sizeFromShape(reduceShape);
const vals = backend.data.get(result.dataId).values;
const aVals = backend.data.get(permutedX.dataId).values;
for (let i = 0; i < vals.length; ++i) {
const offset = i * reduceSize;
let sum = 0;
for (let j = 0; j < reduceSize; ++j) {
sum += aVals[offset + j];
}
vals[i] = sum;
}
if (keepDims) {
const newShape = expandShapeToKeepDim(result.shape, axes);
const oldResult = result;
result = reshape({ inputs: { x: result }, backend, attrs: { shape: newShape } });
backend.disposeIntermediateTensorInfo(oldResult);
}
backend.disposeIntermediateTensorInfo($x);
if (permutation != null) {
backend.disposeIntermediateTensorInfo(permutedX);
}
return result;
}
const sumConfig = {
kernelName: Sum,
backendName: 'cpu',
kernelFunc: sum
};
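// Einsum: follows the compute path returned by getEinsumComputePath. Each
// operand is permuted and expanded so its dimensions line up, multiplied into
// a running product, and summed-out dimensions are reduced with `sum` along
// the way; intermediate tensors are disposed at the end.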
function einsum(args) {
const { inputs, backend, attrs } = args;
const { equation } = attrs;
const tensors = inputs;
const { allDims, summedDims, idDims } = decodeEinsumEquation(equation, tensors.length);
checkEinsumDimSizes(allDims.length, idDims, tensors);
const { path, steps } = getEinsumComputePath(summedDims, idDims);
const nSteps = steps.length;
let out = null;
let numDimsRemaining = allDims.length;
const tensorsToDispose = [];
for (let i = 0; i < nSteps; ++i) {
for (const idTerm of steps[i]) {
const { permutationIndices: perm, expandDims: dimsToExpand } = getEinsumPermutation(numDimsRemaining, idDims[idTerm]);
let x;
if (isIdentityPermutation(perm)) {
x = tensors[idTerm];
}
else {
x = transpose$1({ inputs: { x: tensors[idTerm] }, backend, attrs: { perm } });
tensorsToDispose.push(x);
}
const targetShape = x.shape.slice();
for (let k = 0; k < dimsToExpand.length; ++k) {
targetShape.splice(dimsToExpand[k], 0, 1);
}
if (!arraysEqual(x.shape, targetShape)) {
x = reshape({ inputs: { x }, backend, attrs: { shape: targetShape } });
tensorsToDispose.push(x);
}
if (out === null) {
out = x;
}
else {
out = multiply$1({ inputs: { a: x, b: out }, backend });
tensorsToDispose.push(out);
}
}
if (i < nSteps - 1) {
if (path[i] >= 0) {
out = sum({
inputs: { x: out },
backend,
attrs: {
axis: path[i] - (allDims.length - numDimsRemaining),
keepDims: false
}
});
tensorsToDispose.push(out);
}
numDimsRemaining--;
}
}
for (const tensorInfo of tensorsToDispose) {
if (tensorInfo === out) {
continue;
}
backend.disposeIntermediateTensorInfo(tensorInfo);
}
return out;
}
const einsumConfig = {
kernelName: Einsum,
backendName: 'cpu',
kernelFunc: einsum
};
function eluGrad(args) {
const { inputs, backend } = args;
const { dy, y } = inputs;
assertNotComplex([dy, y], 'eluGrad');
const resultValues = new Float32Array(sizeFromShape(y.shape));
const values = backend.data.get(y.dataId).values;
const dyValues = backend.data.get(dy.dataId).values;
for (let i = 0; i < values.length; ++i) {
const v = values[i];
if (v >= 0) {
resultValues[i] = dyValues[i];
}
else {
resultValues[i] = dyValues[i] * (v + 1);
}
}
return backend.makeTensorInfo(y.shape, 'float32', resultValues);
}
const eluGradConfig$1 = {
kernelName: EluGrad,
backendName: 'cpu',
kernelFunc: eluGrad
};
const p = ERF_P;
const a1 = ERF_A1;
const a2 = ERF_A2;
const a3 = ERF_A3;
const a4 = ERF_A4;
const a5 = ERF_A5;
const erf = unaryKernelFunc$1(Erf, (xi) => {
const sign = Math.sign(xi);
const v = Math.abs(xi);
const t = 1.0 / (1.0 + p * v);
return sign *
(1.0 -
(((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t *
Math.exp(-v * v));
});
const erfConfig = {
kernelName: Erf,
backendName: 'cpu',
kernelFunc: erf,
};
function expandDims$1(args) {
const { inputs, backend, attrs } = args;
const { input } = inputs;
const { dim } = attrs;
const inputRank = input.shape.length;
const newShape = input.shape.slice();
let $dim = dim;
if (dim < 0) {
assert$1(-(inputRank + 1) <= dim, () => `Axis must be in the interval [${-(inputRank + 1)}, ${inputRank}]`);
$dim = inputRank + dim + 1;
}
newShape.splice($dim, 0, 1);
return reshape({ inputs: { x: input }, backend, attrs: { shape: newShape } });
}
const expandDimsConfig = {
kernelName: ExpandDims,
backendName: 'cpu',
kernelFunc: expandDims$1
};
const realDivImpl = createSimpleBinaryKernelImpl((a, b) => a / b);
const div = binaryKernelFunc$1(RealDiv, realDivImpl);
const realDivConfig = {
kernelName: RealDiv,
backendName: 'cpu',
kernelFunc: div
};
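// Runs a 1D FFT (or inverse FFT) on every row of a [batch, innerDim] complex
// tensor by slicing out each row, transforming it with fftImpl, and writing
// the interleaved real/imag results into one complex output.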
function fftBatch(input, inverse, cpuBackend) {
const inputShape = input.shape;
const batch = inputShape[0];
const innerDim = inputShape[1];
const inputVals = cpuBackend.data.get(input.dataId);
const real2D = inputVals.complexTensorInfos.real;
const imag2D = inputVals.complexTensorInfos.imag;
const resultShape = [batch, innerDim];
const resultSize = sizeFromShape(resultShape);
const resultReal = getTypedArrayFromDType('float32', resultSize);
const resultImag = getTypedArrayFromDType('float32', resultSize);
for (let b = 0; b < batch; b++) {
const r = slice$1({
inputs: { x: real2D },
backend: cpuBackend,
attrs: { begin: [b, 0], size: [1, innerDim] }
});
const i = slice$1({
inputs: { x: imag2D },
backend: cpuBackend,
attrs: { begin: [b, 0], size: [1, innerDim] }
});
const input = complex$1({ inputs: { real: r, imag: i }, backend: cpuBackend });
const { real, imag } = fftImpl(input, inverse, cpuBackend);
const res = mergeRealAndImagArrays(real, imag);
for (let d = 0; d < innerDim; d++) {
const c = getComplexWithIndex(res, d);
resultReal[b * innerDim + d] = c.real;
resultImag[b * innerDim + d] = c.imag;
}
cpuBackend.disposeIntermediateTensorInfo(r);
cpuBackend.disposeIntermediateTensorInfo(i);
cpuBackend.disposeIntermediateTensorInfo(input);
}
const $realInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', resultReal);
const $imagInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', resultImag);
const result = complex$1({ inputs: { real: $realInfo, imag: $imagInfo }, backend: cpuBackend });
cpuBackend.disposeIntermediateTensorInfo($realInfo);
cpuBackend.disposeIntermediateTensorInfo($imagInfo);
return result;
}
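// Single-row FFT: uses the radix-2 Cooley-Tukey recursion when the length is
// a power of two (scaling by 1/size for the inverse transform), and falls
// back to a naive O(n^2) DFT (fourierTransformByMatmul) otherwise.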
function fftImpl(input, inverse, cpuBackend) {
const inputSize = sizeFromShape(input.shape);
const inputVals = cpuBackend.data.get(input.dataId);
const realVals = cpuBackend.data.get(inputVals.complexTensorInfos.real.dataId).values;
const imagVals = cpuBackend.data.get(inputVals.complexTensorInfos.imag.dataId).values;
if (isExponentOf2(inputSize)) {
const result = fftRadix2(realVals, imagVals, inputSize, inverse, cpuBackend);
const resultShape = [input.shape[0], input.shape[1]];
if (inverse) {
const realInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', result.real);
const imagInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', result.imag);
const sizeInfo = cpuBackend.makeTensorInfo([], 'float32', createScalarValue(inputSize, 'float32'));
const sizeInfoCopy = identity$1({ inputs: { x: sizeInfo }, backend: cpuBackend });
const divRealInfo = realDivConfig.kernelFunc({ inputs: { a: realInfo, b: sizeInfo }, backend: cpuBackend });
const divImagInfo = realDivConfig.kernelFunc({ inputs: { a: imagInfo, b: sizeInfoCopy }, backend: cpuBackend });
const divRealVals = cpuBackend.data.get(divRealInfo.dataId).values;
const divImagVals = cpuBackend.data.get(divImagInfo.dataId).values;
cpuBackend.disposeIntermediateTensorInfo(realInfo);
cpuBackend.disposeIntermediateTensorInfo(imagInfo);
cpuBackend.disposeIntermediateTensorInfo(sizeInfo);
cpuBackend.disposeIntermediateTensorInfo(sizeInfoCopy);
cpuBackend.disposeIntermediateTensorInfo(divRealInfo);
cpuBackend.disposeIntermediateTensorInfo(divImagInfo);
return { real: divRealVals, imag: divImagVals };
}
return result;
}
else {
const data = mergeRealAndImagArrays(realVals, imagVals);
const rawOutput = fourierTransformByMatmul(data, inputSize, inverse);
return splitRealAndImagArrays(rawOutput);
}
}
function isExponentOf2(size) {
return (size & size - 1) === 0;
}
function fftRadix2(realVals, imagVals, size, inverse, cpuBackend) {
if (size === 1) {
return { real: realVals, imag: imagVals };
}
const data = mergeRealAndImagArrays(realVals, imagVals);
const half = size / 2;
const evenComplex = complexWithEvenIndex(data);
const evenRealVals = evenComplex.real;
const evenImagVals = evenComplex.imag;
const evenShape = [evenRealVals.length];
const evenRealInfo = cpuBackend.makeTensorInfo(evenShape, 'float32', evenRealVals);
const evenImagInfo = cpuBackend.makeTensorInfo(evenShape, 'float32', evenImagVals);
const evenTensorInfo = complex$1({ inputs: { real: evenRealInfo, imag: evenImagInfo }, backend: cpuBackend });
const oddComplex = complexWithOddIndex(data);
const oddRealVals = oddComplex.real;
const oddImagVals = oddComplex.imag;
const oddShape = [oddRealVals.length];
const oddRealInfo = cpuBackend.makeTensorInfo(oddShape, 'float32', oddRealVals);
const oddImagInfo = cpuBackend.makeTensorInfo(oddShape, 'float32', oddImagVals);
const oddTensorInfo = complex$1({ inputs: { real: oddRealInfo, imag: oddImagInfo }, backend: cpuBackend });
const $evenComplex = fftRadix2(evenRealVals, evenImagVals, half, inverse, cpuBackend);
const $evenRealVals = $evenComplex.real;
const $evenImagVals = $evenComplex.imag;
const $evenShape = [$evenRealVals.length];
const $evenRealInfo = cpuBackend.makeTensorInfo($evenShape, 'float32', $evenRealVals);
const $evenImagInfo = cpuBackend.makeTensorInfo($evenShape, 'float32', $evenImagVals);
const $evenTensorInfo = complex$1({
inputs: { real: $evenRealInfo, imag: $evenImagInfo },
backend: cpuBackend
});
const $oddComplex = fftRadix2(oddRealVals, oddImagVals, half, inverse, cpuBackend);
const $oddRealVals = $oddComplex.real;
const $oddImagVals = $oddComplex.imag;
const $oddShape = [$oddRealVals.length];
const $oddRealInfo = cpuBackend.makeTensorInfo($oddShape, 'float32', $oddRealVals);
const $oddImagInfo = cpuBackend.makeTensorInfo($oddShape, 'float32', $oddImagVals);
const $oddTensorInfo = complex$1({ inputs: { real: $oddRealInfo, imag: $oddImagInfo }, backend: cpuBackend });
const e = exponents(size, inverse);
const eShape = [e.real.length];
const eRealInfo = cpuBackend.makeTensorInfo(eShape, 'float32', e.real);
const eImagInfo = cpuBackend.makeTensorInfo(eShape, 'float32', e.imag);
const complexInfo = complex$1({ inputs: { real: eRealInfo, imag: eImagInfo }, backend: cpuBackend });
const exponentInfo = multiply$1({ inputs: { a: complexInfo, b: $oddTensorInfo }, backend: cpuBackend });
const addPart = add({
inputs: { a: $evenTensorInfo, b: exponentInfo },
backend: cpuBackend
});
const subPart = sub$1({
inputs: { a: $evenTensorInfo, b: exponentInfo },
backend: cpuBackend
});
const addPartReal = real$1({ inputs: { input: addPart }, backend: cpuBackend });
const subPartReal = real$1({ inputs: { input: subPart }, backend: cpuBackend });
const addPartImag = imag({ inputs: { input: addPart }, backend: cpuBackend });
const subPartImag = imag({ inputs: { input: subPart }, backend: cpuBackend });
const $real = concat({
inputs: [addPartReal, subPartReal],
backend: cpuBackend,
attrs: { axis: 0 }
});
const $imag = concat({
inputs: [addPartImag, subPartImag],
backend: cpuBackend,
attrs: { axis: 0 }
});
const $realVals = cpuBackend.data.get($real.dataId).values;
const $imagVals = cpuBackend.data.get($imag.dataId).values;
cpuBackend.disposeIntermediateTensorInfo(evenRealInfo);
cpuBackend.disposeIntermediateTensorInfo(evenImagInfo);
cpuBackend.disposeIntermediateTensorInfo(evenTensorInfo);
cpuBackend.disposeIntermediateTensorInfo(oddRealInfo);
cpuBackend.disposeIntermediateTensorInfo(oddImagInfo);
cpuBackend.disposeIntermediateTensorInfo(oddTensorInfo);
cpuBackend.disposeIntermediateTensorInfo($evenRealInfo);
cpuBackend.disposeIntermediateTensorInfo($evenImagInfo);
cpuBackend.disposeIntermediateTensorInfo($evenTensorInfo);
cpuBackend.disposeIntermediateTensorInfo($oddRealInfo);
cpuBackend.disposeIntermediateTensorInfo($oddImagInfo);
cpuBackend.disposeIntermediateTensorInfo($oddTensorInfo);
cpuBackend.disposeIntermediateTensorInfo(eRealInfo);
cpuBackend.disposeIntermediateTensorInfo(eImagInfo);
cpuBackend.disposeIntermediateTensorInfo(complexInfo);
cpuBackend.disposeIntermediateTensorInfo(exponentInfo);
cpuBackend.disposeIntermediateTensorInfo(addPart);
cpuBackend.disposeIntermediateTensorInfo(subPart);
cpuBackend.disposeIntermediateTensorInfo(addPartReal);
cpuBackend.disposeIntermediateTensorInfo(addPartImag);
cpuBackend.disposeIntermediateTensorInfo(subPartReal);
cpuBackend.disposeIntermediateTensorInfo(subPartImag);
cpuBackend.disposeIntermediateTensorInfo($real);
cpuBackend.disposeIntermediateTensorInfo($imag);
return { real: $realVals, imag: $imagVals };
}
function fourierTransformByMatmul(data, size, inverse) {
const ret = new Float32Array(size * 2);
for (let r = 0; r < size; r++) {
let real = 0.0;
let imag = 0.0;
for (let c = 0; c < size; c++) {
const e = exponent(r * c, size, inverse);
const term = getComplexWithIndex(data, c);
real += term.real * e.real - term.imag * e.imag;
imag += term.real * e.imag + term.imag * e.real;
}
if (inverse) {
real /= size;
imag /= size;
}
assignToTypedArray(ret, real, imag, r);
}
return ret;
}
function fft(args) {
const { inputs, backend } = args;
const { input } = inputs;
const inputSize = sizeFromShape(input.shape);
const innerDimensionSize = input.shape[input.shape.length - 1];
const batch = inputSize / innerDimensionSize;
const input2D = reshape({
inputs: { x: input },
backend,
attrs: { shape: [batch, innerDimensionSize] }
});
const result = fftBatch(input2D, false, backend);
const resultReshaped = reshape({ inputs: { x: result }, backend, attrs: { shape: input.shape } });
backend.disposeIntermediateTensorInfo(input2D);
backend.disposeIntermediateTensorInfo(result);
return resultReshaped;
}
const fftConfig = {
kernelName: FFT,
backendName: 'cpu',
kernelFunc: fft
};
function fill(args) {
const { backend, attrs } = args;
const { shape, value, dtype } = attrs;
const $dtype = dtype || inferDtype(value);
const values = getArrayFromDType($dtype, sizeFromShape(shape));
fillValues(values, value, $dtype);
return backend.makeTensorInfo(shape, $dtype, values);
}
const fillConfig = {
kernelName: Fill,
backendName: 'cpu',
kernelFunc: fill
};
function fillValues(values, value, dtype) {
    // Both dtype branches are identical: plain arrays (string dtype) and
    // typed arrays both support fill().
    values.fill(value);
}
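// FlipLeftRight: mirrors every image row about the vertical axis by copying
// each pixel from column (imageWidth - col - 1) of the same row.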
const flipLeftRightConfig = {
kernelName: FlipLeftRight,
backendName: 'cpu',
kernelFunc: ({ inputs, attrs, backend }) => {
const { image } = inputs;
const cpuBackend = backend;
const output = getTypedArrayFromDType(image.dtype, sizeFromShape(image.shape));
const [batch, imageHeight, imageWidth, numChannels] = image.shape;
const imageVals = cpuBackend.data.get(image.dataId).values;
for (let batchIdx = 0; batchIdx < batch; batchIdx++) {
const batchOffset = batchIdx * imageWidth * imageHeight * numChannels;
for (let row = 0; row < imageHeight; row++) {
const rowOffset = row * (imageWidth * numChannels);
for (let col = 0; col < imageWidth; col++) {
const colOffset = col * numChannels;
for (let channel = 0; channel < numChannels; channel++) {
const coordX = Math.round(imageWidth - col - 1);
const outIdx = batchOffset + rowOffset + colOffset + channel;
let outputValue = imageVals[outIdx];
if (coordX >= 0 && coordX < imageWidth) {
const rotatedColOffset = coordX * numChannels;
const imageIdx = batchOffset + rowOffset + rotatedColOffset + channel;
outputValue = imageVals[imageIdx];
}
output[outIdx] = outputValue;
}
}
}
}
const dataId = cpuBackend.write(output, image.shape, image.dtype);
return { dataId, shape: image.shape, dtype: image.dtype };
}
};
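// Fused Conv2D: plain conv2D followed by an optional bias add and an optional
// activation. 1-D bias / prelu weights are reshaped to [C, 1, 1] first so
// they broadcast correctly in the NCHW data format.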
function fusedConv2D(args) {
const { inputs, backend, attrs } = args;
const { x, filter, bias, preluActivationWeights } = inputs;
const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
let result = conv2D({
inputs: { x, filter },
backend,
attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
});
if (bias) {
const resultOld = result;
if (dataFormat === 'NCHW' && bias.shape.length === 1 &&
bias.shape[0] !== 1) {
const reshapedBias = reshape({ inputs: { x: bias }, backend, attrs: { shape: [bias.shape[0], 1, 1] } });
result =
add({ inputs: { a: result, b: reshapedBias }, backend });
backend.disposeIntermediateTensorInfo(reshapedBias);
}
else {
result = add({ inputs: { a: result, b: bias }, backend });
}
backend.disposeIntermediateTensorInfo(resultOld);
}
if (activation) {
const resultOld = result;
if (dataFormat === 'NCHW' && activation === 'prelu' &&
preluActivationWeights.shape.length === 1 &&
preluActivationWeights.shape[0] !== 1) {
const reshapedAlpha = reshape({
inputs: { x: preluActivationWeights },
backend,
attrs: { shape: [preluActivationWeights.shape[0], 1, 1] }
});
result = applyActivation(backend, result, activation, reshapedAlpha, leakyreluAlpha);
backend.disposeIntermediateTensorInfo(reshapedAlpha);
}
else {
result = applyActivation(backend, result, activation, preluActivationWeights, leakyreluAlpha);
}
backend.disposeIntermediateTensorInfo(resultOld);
}
return result;
}
const fusedConv2DConfig = {
kernelName: FusedConv2D,
backendName: 'cpu',
kernelFunc: fusedConv2D
};
function fusedDepthwiseConv2D(args) {
const { inputs, backend, attrs } = args;
const { x, filter, bias, preluActivationWeights } = inputs;
const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
let result = depthwiseConv2dNative({
inputs: { x, filter },
backend,
attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
});
if (bias) {
const oldResult = result;
result = add({ inputs: { a: result, b: bias }, backend });
backend.disposeIntermediateTensorInfo(oldResult);
}
if (activation) {
const oldResult = result;
result = applyActivation(backend, result, activation, preluActivationWeights, leakyreluAlpha);
backend.disposeIntermediateTensorInfo(oldResult);
}
return result;
}
const fusedDepthwiseConv2DConfig = {
kernelName: FusedDepthwiseConv2D,
backendName: 'cpu',
kernelFunc: fusedDepthwiseConv2D
};
function gatherNd(args) {
const { inputs, backend } = args;
const { params, indices } = inputs;
const paramsSize = sizeFromShape(params.shape);
const indicesShape = indices.shape;
const sliceRank = indicesShape[indicesShape.length - 1];
const [resultShape, numSlices, sliceSize, strides] = prepareAndValidate(params, indices);
if (numSlices === 0) {
return backend.makeTensorInfo(resultShape, params.dtype, []);
}
const indicesData = backend.data.get(indices.dataId).values;
const paramsBuf = backend.bufferSync(params);
const outBuf = gatherNdImpl(indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize);
return backend.makeTensorInfo(resultShape, params.dtype, outBuf.values);
}
const gatherNdConfig = {
kernelName: GatherNd,
backendName: 'cpu',
kernelFunc: gatherNd
};
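// GatherV2: validates that every index lies in [0, axisDim - 1], flattens x to
// [batchSize, outerSize, dimSize, sliceSize] and the indices to
// [batchSize, indicesPerBatch], then gathers slices with gatherV2Impl.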
function gatherV2(args) {
const { inputs, backend, attrs } = args;
const { x, indices } = inputs;
const { axis, batchDims } = attrs;
assertNotComplex([x, indices], 'gatherV2');
const parsedAxis = parseAxisParam(axis, x.shape)[0];
const indicesVals = backend.data.get(indices.dataId).values;
const axisDim = x.shape[parsedAxis];
for (let i = 0; i < indicesVals.length; ++i) {
const index = indicesVals[i];
assert$1(index <= axisDim - 1 && index >= 0, () => `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`);
}
let $batchDims = batchDims;
if (batchDims == null) {
$batchDims = 0;
}
const indicesSize = sizeFromShape(indices.shape);
const shapeInfo = collectGatherOpShapeInfo(x, indices, parsedAxis, $batchDims);
const flattenX = reshape({
inputs: { x },
backend,
attrs: {
shape: [
shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize,
shapeInfo.sliceSize
]
}
});
const flattenIndex = reshape({
inputs: { x: indices },
backend,
attrs: { shape: [shapeInfo.batchSize, indicesSize / shapeInfo.batchSize] }
});
const flattenOutputShape = [
shapeInfo.batchSize, shapeInfo.outerSize, indicesSize / shapeInfo.batchSize,
shapeInfo.sliceSize
];
const indicesBuf = backend.bufferSync(flattenIndex);
const xBuf = backend.bufferSync(flattenX);
const outBuf = gatherV2Impl(xBuf, indicesBuf, flattenOutputShape);
backend.disposeIntermediateTensorInfo(flattenX);
backend.disposeIntermediateTensorInfo(flattenIndex);
return backend.makeTensorInfo(shapeInfo.outputShape, outBuf.dtype, outBuf.values);
}
const gatherV2Config = {
kernelName: GatherV2,
backendName: 'cpu',
kernelFunc: gatherV2
};
function ifft(args) {
const { inputs, backend } = args;
const { input } = inputs;
const inputSize = sizeFromShape(input.shape);
const innerDimensionSize = input.shape[input.shape.length - 1];
const batch = inputSize / innerDimensionSize;
const input2D = reshape({
inputs: { x: input },
backend,
attrs: { shape: [batch, innerDimensionSize] }
});
const result = fftBatch(input2D, true, backend);
const resultReshaped = reshape({ inputs: { x: result }, backend, attrs: { shape: input.shape } });
backend.disposeIntermediateTensorInfo(input2D);
backend.disposeIntermediateTensorInfo(result);
return resultReshaped;
}
const ifftConfig = {
kernelName: IFFT,
backendName: 'cpu',
kernelFunc: ifft
};
const isFinite$1 = unaryKernelFunc$1(IsFinite, (xi) => Number.isFinite(xi) ? 1 : 0, 'bool');
const isFiniteConfig = {
kernelName: IsFinite,
backendName: 'cpu',
kernelFunc: isFinite$1,
};
const isInf = unaryKernelFunc$1(IsInf, (xi) => Math.abs(xi) === Infinity ? 1 : 0, 'bool');
const isInfConfig = {
kernelName: IsInf,
backendName: 'cpu',
kernelFunc: isInf,
};
const isNaN$1 = unaryKernelFunc$1(IsNan, (xi) => Number.isNaN(xi) ? 1 : 0, 'bool');
const isNaNConfig = {
kernelName: IsNan,
backendName: 'cpu',
kernelFunc: isNaN$1,
};
function linSpace(args) {
const { backend, attrs } = args;
const { start, stop, num } = attrs;
const outVals = linSpaceImpl(start, stop, num);
return backend.makeTensorInfo([outVals.length], 'float32', outVals);
}
const linSpaceConfig = {
kernelName: LinSpace,
backendName: 'cpu',
kernelFunc: linSpace
};
const log1p = unaryKernelFunc$1(Log1p, (xi) => Math.log1p(xi));
const log1pConfig = {
kernelName: Log1p,
backendName: 'cpu',
kernelFunc: log1p,
};
const logicalAndImpl = createSimpleBinaryKernelImpl((a, b) => a && b);
const logicalAnd = binaryKernelFunc$1(LogicalAnd, logicalAndImpl, null , 'bool');
const logicalAndConfig = {
kernelName: LogicalAnd,
backendName: 'cpu',
kernelFunc: logicalAnd
};
const logicalNot = unaryKernelFunc$1(LogicalNot, (xi) => xi ? 0 : 1, 'bool');
const logicalNotConfig = {
kernelName: LogicalNot,
backendName: 'cpu',
kernelFunc: logicalNot,
};
const logicalOrImpl = createSimpleBinaryKernelImpl((a, b) => a || b);
const logicalOr = binaryKernelFunc$1(LogicalOr, logicalOrImpl, null , 'bool');
const logicalOrConfig = {
kernelName: LogicalOr,
backendName: 'cpu',
kernelFunc: logicalOr
};
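// Local response normalization across channels: every value is divided by
// (bias + alpha * sum of squares over a window of +/- depthRadius channels)
// raised to the power beta.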
function lRN(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { depthRadius, bias, alpha, beta } = attrs;
assertNotComplex(x, 'LRN');
const channels = x.shape[3];
const maxD = channels - 1;
const xValues = backend.data.get(x.dataId).values;
const size = sizeFromShape(x.shape);
const result = new Float32Array(size);
function sumAcrossChannels(offset) {
const currentChannel = offset % channels;
let beginSumOffset = offset - currentChannel + Math.max(0, currentChannel - depthRadius);
const endSumOffset = offset - currentChannel + Math.min(currentChannel + depthRadius, maxD);
let sum = 0.0;
for (; beginSumOffset <= endSumOffset; beginSumOffset++) {
const z = xValues[beginSumOffset];
sum += z * z;
}
return sum;
}
for (let offset = 0; offset < size; offset++) {
const sum = sumAcrossChannels(offset);
const val = xValues[offset] * Math.pow(bias + alpha * sum, -beta);
result[offset] = val;
}
return backend.makeTensorInfo(x.shape, x.dtype, result);
}
const LRNConfig = {
kernelName: LRN,
backendName: 'cpu',
kernelFunc: lRN
};
function lRNGrad(args) {
const { inputs, backend, attrs } = args;
const { x, y, dy } = inputs;
const { depthRadius, bias, alpha, beta } = attrs;
assertNotComplex(dy, 'LRNGrad');
const dySize = sizeFromShape(dy.shape);
const channels = dy.shape[3];
const dyValues = backend.data.get(dy.dataId).values;
const xValues = backend.data.get(x.dataId).values;
const yValues = backend.data.get(y.dataId).values;
const result = new Float32Array(dySize);
const size = dySize;
for (let offset = 0; offset < size; offset++) {
const currentChannel = offset % channels;
const depthBegin = (offset - currentChannel) + Math.max(0, currentChannel - depthRadius);
const depthEnd = (offset - currentChannel) +
Math.min(channels, currentChannel + depthRadius + 1);
let norm = 0;
for (let k = depthBegin; k < depthEnd; k++) {
norm += Math.pow(xValues[k], 2);
}
norm = alpha * norm + bias;
for (let k = depthBegin; k < depthEnd; k++) {
let dyi = -2 * alpha * beta * xValues[k] * yValues[offset] / norm;
if (offset === k) {
dyi += Math.pow(norm, -beta);
}
dyi *= dyValues[offset];
result[k] += dyi;
}
}
return backend.makeTensorInfo(dy.shape, x.dtype, result);
}
const LRNGradConfig = {
kernelName: LRNGrad,
backendName: 'cpu',
kernelFunc: lRNGrad
};
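// Max reduction: transposes the reduction axes to the innermost positions if
// needed, then takes the maximum over each block of `reduceSize` values;
// keepDims only changes the reported output shape.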
function max(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { reductionIndices, keepDims } = attrs;
const cpuBackend = backend;
let xShape = x.shape;
const xRank = xShape.length;
const origAxes = parseAxisParam(reductionIndices, xShape);
let axes = origAxes;
const permutedAxes = getAxesPermutation(axes, xRank);
let xVals = cpuBackend.data.get(x.dataId).values;
if (permutedAxes != null) {
const newShape = new Array(xRank);
for (let i = 0; i < newShape.length; i++) {
newShape[i] = xShape[permutedAxes[i]];
}
xVals = transposeImpl$1(xVals, xShape, x.dtype, permutedAxes, newShape);
axes = getInnerMostAxes(axes.length, xRank);
xShape = newShape;
}
assertNotComplex(x, 'max');
assertAxesAreInnerMostDims('max', axes, xRank);
const [maxOutShape, reduceShape] = computeOutAndReduceShapes(xShape, axes);
const reduceSize = sizeFromShape(reduceShape);
const result = maxImpl$1(xVals, reduceSize, maxOutShape, x.dtype);
const dataId = cpuBackend.write(result, maxOutShape, x.dtype);
let outShape = maxOutShape;
if (keepDims) {
const newShape = expandShapeToKeepDim(maxOutShape, origAxes);
outShape = newShape;
}
return { dataId, shape: outShape, dtype: x.dtype };
}
const maxConfig = {
kernelName: Max,
backendName: 'cpu',
kernelFunc: max
};
function maxPool(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
assertNotComplex(x, 'maxPool');
const { filterSize, strides, pad, dimRoundingMode } = attrs;
const dilations = 1;
assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
`Got strides ${strides} and dilations '${dilations}'`);
const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
let res;
if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
arraysEqual(convInfo.inShape, convInfo.outShape)) {
res = identity$1({ inputs: { x }, backend });
}
else {
const xValues = backend.data.get(x.dataId).values;
const strides = computeStrides(x.shape);
const buffer = pool(xValues, x.shape, x.dtype, strides, convInfo, 'max');
res = backend.makeTensorInfo(convInfo.outShape, x.dtype, buffer.values);
}
return res;
}
const maxPoolConfig = {
kernelName: MaxPool,
backendName: 'cpu',
kernelFunc: maxPool
};
function maxPool3D(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs;
assertNotComplex(x, 'maxPool3d');
    const convInfo = computePool3DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode, dataFormat);
const xValues = backend.data.get(x.dataId).values;
const outBuf = pool3d(xValues, x.shape, x.dtype, computeStrides(x.shape), convInfo, 'max');
return backend.makeTensorInfo(outBuf.shape, 'float32', outBuf.values);
}
const maxPool3DConfig = {
kernelName: MaxPool3D,
backendName: 'cpu',
kernelFunc: maxPool3D
};
function maxPool3DGrad(args) {
const { inputs, backend, attrs } = args;
const { dy, input } = inputs;
const { filterSize, strides, pad, dimRoundingMode } = attrs;
assertNotComplex([dy, input], 'maxPool3DGrad');
    const convInfo = computePool3DInfo(input.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
const inputBuf = backend.bufferSync(input);
const maxPosBuf = maxPool3dPositions(inputBuf, convInfo);
const strideDepth = convInfo.strideDepth;
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationDepth = convInfo.dilationDepth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterDepth = convInfo.effectiveFilterDepth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
const dx = buffer(input.shape, 'float32');
const dyBuf = backend.bufferSync(dy);
for (let batch = 0; batch < convInfo.batchSize; ++batch) {
for (let channel = 0; channel < convInfo.inChannels; ++channel) {
for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
const dyDepthCorner = dxDepth - padFront;
const dyRowCorner = dxRow - padTop;
const dyColCorner = dxCol - padLeft;
let dotProd = 0;
for (let wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
const dyDepth = (dyDepthCorner + wDepth) / strideDepth;
if (dyDepth < 0 || dyDepth >= convInfo.outDepth ||
Math.floor(dyDepth) !== dyDepth) {
continue;
}
for (let wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
const dyRow = (dyRowCorner + wRow) / strideHeight;
if (dyRow < 0 || dyRow >= convInfo.outHeight ||
Math.floor(dyRow) !== dyRow) {
continue;
}
for (let wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
const dyCol = (dyColCorner + wCol) / strideWidth;
if (dyCol < 0 || dyCol >= convInfo.outWidth ||
Math.floor(dyCol) !== dyCol) {
continue;
}
const maxPos = effectiveFilterDepth * effectiveFilterHeight *
effectiveFilterWidth -
1 -
maxPosBuf.get(batch, dyDepth, dyRow, dyCol, channel);
const curPos = wDepth * effectiveFilterHeight * effectiveFilterWidth +
wRow * effectiveFilterWidth + wCol;
const mask = maxPos === curPos ? 1 : 0;
if (mask === 0) {
continue;
}
const pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
dotProd += pixel * mask;
}
}
}
dx.set(dotProd, batch, dxDepth, dxRow, dxCol, channel);
}
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
const maxPool3DGradConfig$1 = {
kernelName: MaxPool3DGrad,
backendName: 'cpu',
kernelFunc: maxPool3DGrad
};
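/*
 * MaxPoolGrad: gradient of 2D max pooling. maxPoolPositions records, for each
 * output cell, which position inside the effective filter window held the
 * maximum; the incoming dy value is routed back only to that input element,
 * so non-maximal inputs receive a zero gradient.
 */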
function maxPoolGrad$1(args) {
const { inputs, backend, attrs } = args;
const { dy, input, output } = inputs;
const x = input;
assertNotComplex([input, output], 'maxPoolGrad');
const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
const xValues = backend.data.get(x.dataId).values;
const maxPosBuf = buffer(convInfo.outShape, x.dtype, maxPoolPositions(xValues, x.shape, x.dtype, convInfo).values);
const strideHeight = convInfo.strideHeight;
const strideWidth = convInfo.strideWidth;
const dilationHeight = convInfo.dilationHeight;
const dilationWidth = convInfo.dilationWidth;
const effectiveFilterHeight = convInfo.effectiveFilterHeight;
const effectiveFilterWidth = convInfo.effectiveFilterWidth;
const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
const dx = buffer(x.shape, 'float32');
const dyData = backend.data.get(dy.dataId).values;
const dyBuf = buffer(dy.shape, 'float32', dyData);
for (let b = 0; b < convInfo.batchSize; ++b) {
for (let d = 0; d < convInfo.inChannels; ++d) {
for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {
for (let dxC = 0; dxC < convInfo.inWidth; ++dxC) {
const dyRCorner = dxR - padTop;
const dyCCorner = dxC - padLeft;
let dotProd = 0;
for (let wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
const dyR = (dyRCorner + wR) / strideHeight;
if (dyR < 0 || dyR >= convInfo.outHeight ||
Math.floor(dyR) !== dyR) {
continue;
}
for (let wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
const dyC = (dyCCorner + wC) / strideWidth;
if (dyC < 0 || dyC >= convInfo.outWidth ||
Math.floor(dyC) !== dyC) {
continue;
}
const maxPos = effectiveFilterHeight * effectiveFilterWidth - 1 -
maxPosBuf.get(b, dyR, dyC, d);
const curPos = wR * effectiveFilterWidth + wC;
const mask = maxPos === curPos ? 1 : 0;
if (mask === 0) {
continue;
}
const pixel = dyBuf.get(b, dyR, dyC, d);
dotProd += pixel * mask;
}
}
dx.set(dotProd, b, dxR, dxC, d);
}
}
}
}
return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
const maxPoolGradConfig$1 = {
kernelName: MaxPoolGrad,
backendName: 'cpu',
kernelFunc: maxPoolGrad$1
};
function maxPoolWithArgmaxImpl(xValues, xShape, dtype, includeBatchInIndex, convInfo) {
const strides = computeStrides(xShape);
const maxPools = pool(xValues, xShape, dtype, strides, convInfo, 'max');
const maxPositions = maxPoolPositions(xValues, xShape, dtype, convInfo, true, includeBatchInIndex);
return [maxPools.values, maxPositions.values];
}
const maxPoolWithArgmaxConfig = {
kernelName: MaxPoolWithArgmax,
backendName: 'cpu',
kernelFunc: ({ inputs, attrs, backend }) => {
const { x } = inputs;
const { filterSize, strides, pad, includeBatchInIndex } = attrs;
const cpuBackend = backend;
assertNotComplex(x, 'MaxPoolWithArgmax');
const values = cpuBackend.data.get(x.dataId).values;
const convInfo = computePool2DInfo(x.shape, filterSize, strides, [1, 1], pad);
const [pooled, indexes] = maxPoolWithArgmaxImpl(values, x.shape, x.dtype, includeBatchInIndex, convInfo);
const pooledDataId = cpuBackend.write(pooled, convInfo.outShape, x.dtype);
const indexesDataId = cpuBackend.write(indexes, convInfo.outShape, x.dtype);
return [
{ dataId: pooledDataId, shape: convInfo.outShape, dtype: x.dtype },
{ dataId: indexesDataId, shape: convInfo.outShape, dtype: 'int32' }
];
}
};
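/*
 * Mean: implemented with existing kernels by casting x to float32, dividing
 * by the reduction size and then summing over the requested axes, which
 * equals sum(x) / n over those axes.
 */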
function mean(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis, keepDims } = attrs;
const axes = parseAxisParam(axis, x.shape);
const shapes = computeOutAndReduceShapes(x.shape, axes);
const reduceShape = shapes[1];
const reduceSize = sizeFromShape(reduceShape);
const toDispose = [];
const reduceSizeScalar = backend.makeTensorInfo([], 'float32', new Float32Array([reduceSize]));
toDispose.push(reduceSizeScalar);
const $x = cast$2({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
toDispose.push($x);
const res = div({ inputs: { a: $x, b: reduceSizeScalar }, backend });
toDispose.push(res);
const result = sum({ inputs: { x: res }, backend, attrs: { axis, keepDims } });
toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
return result;
}
const meanConfig = {
kernelName: Mean,
backendName: 'cpu',
kernelFunc: mean
};
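/*
 * Min reduction. Axes are permuted to the innermost positions if needed, then
 * each output element takes the smallest value of its reduceSize block. The
 * explicit Number.isNaN check makes NaN propagate: any NaN in the reduced
 * block yields NaN for that output element.
 */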
function min(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { axis, keepDims } = attrs;
assertNotComplex(x, 'min');
const origAxes = parseAxisParam(axis, x.shape);
let axes = origAxes;
const permutedAxes = getAxesPermutation(axes, x.shape.length);
let $x = x;
if (permutedAxes != null) {
$x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
axes = getInnerMostAxes(axes.length, x.shape.length);
}
assertAxesAreInnerMostDims('min', axes, $x.shape.length);
const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
const reduceSize = sizeFromShape(reduceShape);
const vals = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
const aVals = backend.data.get($x.dataId).values;
for (let i = 0; i < vals.length; ++i) {
const offset = i * reduceSize;
let min = aVals[offset];
for (let j = 0; j < reduceSize; ++j) {
const value = aVals[offset + j];
if (Number.isNaN(value) ||
value < min) {
min = value;
}
}
vals[i] = min;
}
if (permutedAxes != null) {
backend.disposeIntermediateTensorInfo($x);
}
const result = backend.makeTensorInfo(outShape, $x.dtype, vals);
if (keepDims) {
const expandedShape = expandShapeToKeepDim(outShape, origAxes);
const reshapedResult = reshape({ inputs: { x: result }, backend, attrs: { shape: expandedShape } });
backend.disposeIntermediateTensorInfo(result);
return reshapedResult;
}
return result;
}
const minConfig = {
kernelName: Min,
backendName: 'cpu',
kernelFunc: min
};
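/*
 * MirrorPad: output coordinates that fall in a padded region are reflected
 * back into the input range. offset is 0 for 'reflect' (the edge value is not
 * repeated) and 1 for 'symmetric' (the edge value is repeated).
 */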
function mirrorPad(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { paddings, mode } = attrs;
assertNotComplex(x, 'mirrorPad');
    const outShape = paddings.map((p, i) => p[0] + x.shape[i] + p[1]);
const start = paddings.map(p => p[0]);
const end = paddings.map((p, i) => p[0] + x.shape[i]);
const offset = mode === 'reflect' ? 0 : 1;
const xVals = backend.data.get(x.dataId).values;
const xRank = x.shape.length;
const xStrides = computeStrides(x.shape);
const resultSize = sizeFromShape(outShape);
const resultRank = outShape.length;
const resultStrides = computeStrides(outShape);
const resVals = getTypedArrayFromDType(x.dtype, resultSize);
for (let i = 0; i < resultSize; i++) {
let coords = indexToLoc(i, resultRank, resultStrides);
        for (let d = 0; d < resultRank; d++) {
            if (coords[d] < start[d]) {
                coords[d] = start[d] * 2 - coords[d] - offset;
            }
            else if (coords[d] >= end[d]) {
                coords[d] = (end[d] - 1) * 2 - coords[d] + offset;
            }
        }
coords = coords.map((c, i) => c - start[i]);
const inIndex = locToIndex(coords, xRank, xStrides);
resVals[i] = xVals[inIndex];
}
const outId = backend.write(resVals, outShape, x.dtype);
return { dataId: outId, shape: outShape, dtype: x.dtype };
}
const mirrorPadConfig = {
kernelName: MirrorPad,
backendName: 'cpu',
kernelFunc: mirrorPad
};
const modImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => {
const rem = aValue % bValue;
if ((aValue < 0 && bValue < 0) || (aValue >= 0 && bValue >= 0)) {
return rem;
}
else {
return (rem + bValue) % bValue;
}
}));
const mod = binaryKernelFunc$1(Mod, modImpl);
const modConfig = {
kernelName: Mod,
backendName: 'cpu',
kernelFunc: mod
};
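/*
 * Softmax along the last axis, computed in the numerically stable form
 * softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)). Only
 * dim === -1 (i.e. the last dimension) is supported.
 *
 * Usage sketch with made-up values, assuming the public tf.softmax API that
 * dispatches to this kernel on the cpu backend:
 *   // tf.softmax(tf.tensor1d([1, 2, 3])) ~> [0.090, 0.245, 0.665]
 */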
function softmax(args) {
const { inputs, backend, attrs } = args;
const { logits } = inputs;
const { dim } = attrs;
const logitsRank = logits.shape.length;
let $dim = dim;
if ($dim === -1) {
$dim = logitsRank - 1;
}
if ($dim !== logitsRank - 1) {
throw Error('Softmax along a non-last dimension is not yet supported. ' +
`Logits was rank ${logitsRank} and dim was ${$dim}`);
}
const axes = parseAxisParam([$dim], logits.shape);
const maxLogit = max({
inputs: { x: logits },
backend,
attrs: { reductionIndices: axes, keepDims: false }
});
const expandedShape = expandShapeToKeepDim(maxLogit.shape, axes);
const maxLogitReshaped = reshape({ inputs: { x: maxLogit }, backend, attrs: { shape: expandedShape } });
const a = sub$1({ inputs: { a: logits, b: maxLogitReshaped }, backend });
const b = exp$1({ inputs: { x: a }, backend });
const sumExp = sum({ inputs: { x: b }, backend, attrs: { axis: axes, keepDims: false } });
const sumReshaped = reshape({ inputs: { x: sumExp }, backend, attrs: { shape: expandedShape } });
const result = div({ inputs: { a: b, b: sumReshaped }, backend });
backend.disposeIntermediateTensorInfo(maxLogit);
backend.disposeIntermediateTensorInfo(maxLogitReshaped);
backend.disposeIntermediateTensorInfo(a);
backend.disposeIntermediateTensorInfo(b);
backend.disposeIntermediateTensorInfo(sumExp);
backend.disposeIntermediateTensorInfo(sumReshaped);
return result;
}
const softmaxConfig = {
kernelName: Softmax$1,
backendName: 'cpu',
kernelFunc: softmax
};
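/*
 * Multinomial: per batch row, a cumulative distribution is built over the
 * first numEvents - 1 probabilities and a seeded uniform draw from
 * seedrandom.alea is inverted against it; draws beyond the last CDF entry
 * fall through to the final event. Logits are passed through softmax first
 * unless `normalized` is set.
 */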
function multinomial(args) {
const { inputs, backend, attrs } = args;
const { logits } = inputs;
const { numSamples, seed, normalized } = attrs;
assertNotComplex(logits, 'multinomial');
const probabilities = normalized ?
logits :
softmax({ inputs: { logits }, backend, attrs: { dim: -1 } });
const batchSize = probabilities.shape[0];
const numEvents = probabilities.shape[1];
const probVals = backend.data.get(probabilities.dataId).values;
const resShape = [batchSize, numSamples];
const resVals = makeZerosTypedArray(sizeFromShape(resShape), 'int32');
for (let b = 0; b < batchSize; ++b) {
const offset = b * numEvents;
const cdf = new Float32Array(numEvents - 1);
cdf[0] = probVals[offset];
for (let event = 1; event < cdf.length; ++event) {
cdf[event] = cdf[event - 1] + probVals[offset + event];
}
const random = seedrandom.alea(seed.toString());
const outOffset = b * numSamples;
for (let sampleId = 0; sampleId < numSamples; ++sampleId) {
const r = random();
resVals[outOffset + sampleId] = cdf.length;
for (let event = 0; event < cdf.length; event++) {
if (r < cdf[event]) {
resVals[outOffset + sampleId] = event;
break;
}
}
}
}
if (!normalized) {
backend.disposeIntermediateTensorInfo(probabilities);
}
return backend.makeTensorInfo(resShape, 'int32', resVals);
}
const multinomialConfig = {
kernelName: Multinomial,
backendName: 'cpu',
kernelFunc: multinomial
};
const nonMaxSuppressionV3Impl = nonMaxSuppressionV3Impl$2;
function nonMaxSuppressionV3(args) {
const { inputs, backend, attrs } = args;
const { boxes, scores } = inputs;
const { maxOutputSize, iouThreshold, scoreThreshold } = attrs;
assertNotComplex(boxes, 'NonMaxSuppression');
const boxesVals = backend.data.get(boxes.dataId).values;
const scoresVals = backend.data.get(scores.dataId).values;
const { selectedIndices } = nonMaxSuppressionV3Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
return backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
}
const nonMaxSuppressionV3Config = {
kernelName: NonMaxSuppressionV3,
backendName: 'cpu',
kernelFunc: nonMaxSuppressionV3
};
const nonMaxSuppressionV4Impl = nonMaxSuppressionV4Impl$2;
function nonMaxSuppressionV4(args) {
const { inputs, backend, attrs } = args;
const { boxes, scores } = inputs;
const { maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize } = attrs;
assertNotComplex(boxes, 'NonMaxSuppressionPadded');
const boxesVals = backend.data.get(boxes.dataId).values;
const scoresVals = backend.data.get(scores.dataId).values;
const { selectedIndices, validOutputs } = nonMaxSuppressionV4Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize);
return [
backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)),
backend.makeTensorInfo([], 'int32', new Int32Array([validOutputs]))
];
}
const nonMaxSuppressionV4Config = {
kernelName: NonMaxSuppressionV4,
backendName: 'cpu',
kernelFunc: nonMaxSuppressionV4
};
const nonMaxSuppressionV5Impl = nonMaxSuppressionV5Impl$2;
function nonMaxSuppressionV5(args) {
const { inputs, backend, attrs } = args;
const { boxes, scores } = inputs;
const { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma } = attrs;
assertNotComplex(boxes, 'NonMaxSuppressionWithScore');
const boxesVals = backend.data.get(boxes.dataId).values;
const scoresVals = backend.data.get(scores.dataId).values;
const maxOutputSizeVal = maxOutputSize;
const iouThresholdVal = iouThreshold;
const scoreThresholdVal = scoreThreshold;
const softNmsSigmaVal = softNmsSigma;
const { selectedIndices, selectedScores } = nonMaxSuppressionV5Impl(boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, scoreThresholdVal, softNmsSigmaVal);
return [
backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)),
backend.makeTensorInfo([selectedScores.length], 'float32', new Float32Array(selectedScores))
];
}
const nonMaxSuppressionV5Config = {
kernelName: NonMaxSuppressionV5,
backendName: 'cpu',
kernelFunc: nonMaxSuppressionV5
};
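/*
 * OneHot: each index value v in [0, depth) selects column v of its output row
 * for onValue, with offValue everywhere else; out-of-range indices produce a
 * row of offValue only. The output shape is indices.shape extended by depth.
 */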
function oneHot(args) {
const { inputs, backend, attrs } = args;
const { indices } = inputs;
const { dtype, depth, onValue, offValue } = attrs;
assertNotComplex(indices, 'oneHot');
const indicesSize = sizeFromShape(indices.shape);
const res = new Float32Array(indicesSize * depth);
res.fill(offValue);
const indicesVal = backend.data.get(indices.dataId).values;
for (let event = 0; event < indicesSize; ++event) {
if (indicesVal[event] >= 0 && indicesVal[event] < depth) {
res[event * depth + indicesVal[event]] = onValue;
}
}
return backend.makeTensorInfo([...indices.shape, depth], dtype, res);
}
const oneHotConfig = {
kernelName: OneHot,
backendName: 'cpu',
kernelFunc: oneHot
};
function zerosLike(args) {
const { inputs, backend } = args;
const { x } = inputs;
if (x.dtype === 'string') {
throw new Error('zerosLike is not supported for string tensors');
}
else if (x.dtype === 'complex64') {
const realPart = real$1({ inputs: { input: x }, backend });
const r = zerosLike({ inputs: { x: realPart }, backend });
const imagPart = imag({ inputs: { input: x }, backend });
const i = zerosLike({ inputs: { x: imagPart }, backend });
const result = complex$1({ inputs: { real: r, imag: i }, backend });
backend.disposeIntermediateTensorInfo(realPart);
backend.disposeIntermediateTensorInfo(r);
backend.disposeIntermediateTensorInfo(imagPart);
backend.disposeIntermediateTensorInfo(i);
return result;
}
else {
return fill({ backend, attrs: { shape: x.shape, value: 0, dtype: x.dtype } });
}
}
const zerosLikeConfig = {
kernelName: ZerosLike,
backendName: 'cpu',
kernelFunc: zerosLike
};
function onesLike(args) {
const { inputs, backend } = args;
const { x } = inputs;
if (x.dtype === 'string') {
throw new Error('onesLike is not supported for string tensors');
}
else if (x.dtype === 'complex64') {
const realPart = real$1({ inputs: { input: x }, backend });
const r = onesLike({ inputs: { x: realPart }, backend });
const imagPart = imag({ inputs: { input: x }, backend });
const i = zerosLike({ inputs: { x: imagPart }, backend });
const result = complex$1({ inputs: { real: r, imag: i }, backend });
backend.disposeIntermediateTensorInfo(realPart);
backend.disposeIntermediateTensorInfo(r);
backend.disposeIntermediateTensorInfo(imagPart);
backend.disposeIntermediateTensorInfo(i);
return result;
}
else {
return fill({ backend, attrs: { shape: x.shape, value: 1, dtype: x.dtype } });
}
}
const onesLikeConfig = {
kernelName: OnesLike,
backendName: 'cpu',
kernelFunc: onesLike
};
function pack(args) {
const { inputs, backend, attrs } = args;
const { axis } = attrs;
if (inputs.length === 1) {
return expandDims$1({ inputs: { input: inputs[0] }, backend, attrs: { dim: axis } });
}
const shape = inputs[0].shape;
const dtype = inputs[0].dtype;
inputs.forEach(t => {
assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes');
assert$1(dtype === t.dtype, () => 'All tensors passed to stack must have matching dtypes');
});
const intermediateTensorInfos = [];
const expandedTensors = inputs.map(t => {
const expandedT = expandDims$1({ inputs: { input: t }, backend, attrs: { dim: axis } });
intermediateTensorInfos.push(expandedT);
return expandedT;
});
const result = concat({ inputs: expandedTensors, backend, attrs: { axis } });
intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
return result;
}
const packConfig = {
kernelName: Pack,
backendName: 'cpu',
kernelFunc: pack
};
function padV2(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { paddings, constantValue } = attrs;
assertNotComplex(x, 'pad');
    const outShape = paddings.map((p, i) => p[0] + x.shape[i] + p[1]);
const start = paddings.map(p => p[0]);
const xVals = backend.data.get(x.dataId).values;
const xSize = sizeFromShape(x.shape);
const xRank = x.shape.length;
const xStrides = computeStrides(x.shape);
const resultSize = sizeFromShape(outShape);
const resultRank = outShape.length;
const resultStrides = computeStrides(outShape);
const resVals = getTypedArrayFromDType(x.dtype, resultSize);
if (constantValue !== 0) {
resVals.fill(constantValue);
}
for (let i = 0; i < xSize; i++) {
const coords = indexToLoc(i, xRank, xStrides);
const outCoords = coords.map((c, i) => c + start[i]);
const outIndex = locToIndex(outCoords, resultRank, resultStrides);
resVals[outIndex] = xVals[i];
}
const outId = backend.write(resVals, outShape, x.dtype);
return { dataId: outId, shape: outShape, dtype: x.dtype };
}
const padV2Config = {
kernelName: PadV2,
backendName: 'cpu',
kernelFunc: padV2
};
const powImpl = createSimpleBinaryKernelImpl((a, b) => Math.pow(a, b));
const pow = binaryKernelFunc$1(Pow, powImpl);
const powConfig = {
kernelName: Pow,
backendName: 'cpu',
kernelFunc: pow
};
function raggedGather(args) {
const { inputs, backend} = args;
const { paramsNestedSplits, paramsDenseValues, indices } = inputs;
const $paramsNestedSplits = paramsNestedSplits.map(t => backend.data.get(t.dataId).values);
const $paramsNestedSplitsShapes = paramsNestedSplits.map(t => t.shape);
const $paramsDenseValues = backend.data.get(paramsDenseValues.dataId).values;
const $indices = backend.data.get(indices.dataId).values;
const [outputNestedSplits, outputDenseValues, outputDenseValuesShape] = raggedGatherImpl($paramsNestedSplits, $paramsNestedSplitsShapes, $paramsDenseValues, paramsDenseValues.shape, paramsDenseValues.dtype, $indices, indices.shape);
const outputNestedSplitsTensors = outputNestedSplits.map((splits) => backend.makeTensorInfo([splits.length], 'int32', splits));
const outputDenseValuesTensor = backend.makeTensorInfo(outputDenseValuesShape, paramsDenseValues.dtype, outputDenseValues);
return outputNestedSplitsTensors.concat([outputDenseValuesTensor]);
}
const raggedGatherConfig = {
kernelName: RaggedGather,
backendName: 'cpu',
kernelFunc: raggedGather,
};
function raggedRange(args) {
const { inputs, backend } = args;
const { starts, limits, deltas } = inputs;
const $starts = backend.data.get(starts.dataId).values;
const $limits = backend.data.get(limits.dataId).values;
const $deltas = backend.data.get(deltas.dataId).values;
const [rtNestedSplitsData, rtDenseValuesData] = raggedRangeImpl($starts, starts.shape, starts.dtype, $limits, limits.shape, $deltas, deltas.shape);
const rtNestedSplits = backend.makeTensorInfo([rtNestedSplitsData.length], 'int32', rtNestedSplitsData);
const rtDenseValues = backend.makeTensorInfo([rtDenseValuesData.length], starts.dtype, rtDenseValuesData);
return [rtNestedSplits, rtDenseValues];
}
const raggedRangeConfig = {
kernelName: RaggedRange,
backendName: 'cpu',
kernelFunc: raggedRange,
};
function raggedTensorToTensor(args) {
const { inputs, backend, attrs } = args;
const { shape, values, defaultValue, rowPartitionTensors } = inputs;
const { rowPartitionTypes } = attrs;
const $shape = backend.data.get(shape.dataId).values;
const $values = backend.data.get(values.dataId).values;
const $defaultValue = backend.data.get(defaultValue.dataId).values;
const $rowPartitionValues = rowPartitionTensors.map(t => backend.data.get(t.dataId).values);
const rowPartitionValuesShapes = rowPartitionTensors.map(t => t.shape);
const [outputShape, output] = raggedTensorToTensorImpl($shape, shape.shape, $values, values.shape, values.dtype, $defaultValue, defaultValue.shape, $rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes);
return backend.makeTensorInfo(outputShape, values.dtype, output);
}
const raggedTensorToTensorConfig = {
kernelName: RaggedTensorToTensor,
backendName: 'cpu',
kernelFunc: raggedTensorToTensor,
};
function range$1(args) {
const { backend, attrs } = args;
const { start, stop, dtype, step } = attrs;
const values = rangeImpl(start, stop, step, dtype);
return backend.makeTensorInfo([values.length], dtype, values);
}
const rangeConfig = {
kernelName: Range,
backendName: 'cpu',
kernelFunc: range$1
};
const reciprocal = unaryKernelFunc$1(Reciprocal, (xi) => 1 / xi);
const reciprocalConfig = {
kernelName: Reciprocal,
backendName: 'cpu',
kernelFunc: reciprocal,
};
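/*
 * ResizeBilinear (NHWC): each output pixel maps to a fractional source
 * location and is computed by two-stage linear interpolation,
 *   top    = lerp(topLeft, topRight, colFrac)
 *   bottom = lerp(bottomLeft, bottomRight, colFrac)
 *   value  = lerp(top, bottom, rowFrac).
 * alignCorners and halfPixelCenters only change how the fractional source
 * coordinate is derived from the output coordinate.
 */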
function resizeBilinear(args) {
const { inputs, backend, attrs } = args;
const { images } = inputs;
const { alignCorners, halfPixelCenters, size } = attrs;
assertNotComplex(images, 'resizeBilinear');
const imagesStrides = computeStrides(images.shape);
const [newHeight, newWidth] = size;
const [batch, oldHeight, oldWidth, numChannels] = images.shape;
const xValues = backend.data.get(images.dataId).values;
const result = new Float32Array(sizeFromShape([batch, newHeight, newWidth, numChannels]));
const effectiveInputSize = [
(alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
(alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
];
const effectiveOutputSize = [
(alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
(alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
];
let outputIdx = 0;
const effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0];
const effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1];
for (let b = 0; b < batch; b++) {
for (let r = 0; r < newHeight; r++) {
let sourceFracRow;
if (halfPixelCenters) {
sourceFracRow = effectiveRowSizeRatio * (r + 0.5) - 0.5;
}
else {
sourceFracRow = effectiveRowSizeRatio * r;
}
const sourceRowFloor = Math.max(0, Math.floor(sourceFracRow));
const rowFrac = sourceFracRow - sourceRowFloor;
const sourceRowCeil = Math.min(oldHeight - 1, Math.ceil(sourceFracRow));
const topRowOffset = b * imagesStrides[0] + sourceRowFloor * imagesStrides[1];
const botRowOffset = b * imagesStrides[0] + sourceRowCeil * imagesStrides[1];
for (let c = 0; c < newWidth; c++) {
let sourceFracCol;
if (halfPixelCenters) {
sourceFracCol = effectiveColSizeRatio * (c + 0.5) - 0.5;
}
else {
sourceFracCol = effectiveColSizeRatio * c;
}
const sourceColFloor = Math.max(0, Math.floor(sourceFracCol));
const colFrac = sourceFracCol - sourceColFloor;
const sourceColCeil = Math.min(oldWidth - 1, Math.ceil(sourceFracCol));
                const topLeftOffset = topRowOffset + sourceColFloor * imagesStrides[2];
                const botLeftOffset = botRowOffset + sourceColFloor * imagesStrides[2];
                const topRightOffset = topRowOffset + sourceColCeil * imagesStrides[2];
                const botRightOffset = botRowOffset + sourceColCeil * imagesStrides[2];
                for (let d = 0; d < numChannels; d++) {
                    const topLeft = xValues[topLeftOffset + d];
                    const bottomLeft = xValues[botLeftOffset + d];
                    const topRight = xValues[topRightOffset + d];
                    const bottomRight = xValues[botRightOffset + d];
const top = topLeft + (topRight - topLeft) * colFrac;
const bottom = bottomLeft + (bottomRight - bottomLeft) * colFrac;
const newValue = top + (bottom - top) * rowFrac;
result[outputIdx++] = newValue;
}
}
}
}
return backend.makeTensorInfo([batch, newHeight, newWidth, numChannels], 'float32', result);
}
const resizeBilinearConfig = {
kernelName: ResizeBilinear,
backendName: 'cpu',
kernelFunc: resizeBilinear
};
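/*
 * ResizeBilinearGrad: every dy pixel is scattered back onto the four source
 * pixels that produced it, weighted by the same interpolation fractions as
 * the forward pass, so dx has the shape of the original images input.
 */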
function resizeBilinearGrad(args) {
const { inputs, backend, attrs } = args;
const { images, dy } = inputs;
const { alignCorners } = attrs;
assertNotComplex([dy, images], 'resizeBilinearGrad');
const imagesStrides = computeStrides(images.shape);
const [batch, xHeight, xWidth, depth] = images.shape;
const [, yHeight, yWidth] = dy.shape;
const output = new Float32Array(batch * xHeight * xWidth * depth);
const effectiveXSize = [
(alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
(alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
];
const effectiveYSize = [
(alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
(alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
];
const heightScale = effectiveXSize[0] / effectiveYSize[0];
const widthScale = effectiveXSize[1] / effectiveYSize[1];
const dyValues = backend.data.get(dy.dataId).values;
let offset = 0;
for (let b = 0; b < batch; b++) {
const bOffset = b * imagesStrides[0];
for (let r = 0; r < yHeight; r++) {
const dxR = r * heightScale;
const topDxRIndex = Math.floor(dxR);
const bottomDxRIndex = Math.min(Math.ceil(dxR), xHeight - 1);
const topDxROffset = bOffset + topDxRIndex * imagesStrides[1];
const bottomDxROffset = bOffset + bottomDxRIndex * imagesStrides[1];
const dxRLerp = dxR - topDxRIndex;
const inverseDxRLerp = 1.0 - dxRLerp;
for (let c = 0; c < yWidth; c++) {
const dxC = c * widthScale;
const leftDxCIndex = Math.floor(dxC);
const rightDxCIndex = Math.min(Math.ceil(dxC), xWidth - 1);
const dxCLerp = dxC - leftDxCIndex;
const inverseDxCLerp = 1.0 - dxCLerp;
const topLeftRCOffset = topDxROffset + leftDxCIndex * imagesStrides[2];
const topRightRCOffset = topDxROffset + rightDxCIndex * imagesStrides[2];
const bottomLeftRCOffset = bottomDxROffset + leftDxCIndex * imagesStrides[2];
const bottomRightRCOffset = bottomDxROffset + rightDxCIndex * imagesStrides[2];
const inverseDxRLerpTimesInverseDxCLerp = inverseDxRLerp * inverseDxCLerp;
const inverseDxRLerpTimesDxCLerp = inverseDxRLerp * dxCLerp;
const dxRLerpTimesInverseDxCLerp = dxRLerp * inverseDxCLerp;
const dxRLerpTimesDxCLerp = dxRLerp * dxCLerp;
for (let d = 0; d < depth; d++) {
const dyVal = dyValues[offset++];
output[topLeftRCOffset + d] +=
dyVal * inverseDxRLerpTimesInverseDxCLerp;
output[topRightRCOffset + d] += dyVal * inverseDxRLerpTimesDxCLerp;
output[bottomLeftRCOffset + d] += dyVal * dxRLerpTimesInverseDxCLerp;
output[bottomRightRCOffset + d] += dyVal * dxRLerpTimesDxCLerp;
}
}
}
}
    return backend.makeTensorInfo([batch, xHeight, xWidth, depth], 'float32', output);
}
const resizeBilinearGradConfig$1 = {
kernelName: ResizeBilinearGrad,
backendName: 'cpu',
kernelFunc: resizeBilinearGrad
};
function resizeNearestNeighbor(args) {
const { inputs, backend, attrs } = args;
const { images } = inputs;
const { alignCorners, halfPixelCenters, size } = attrs;
assertNotComplex(images, 'resizeNearestNeighbor');
const imagesStrides = computeStrides(images.shape);
const [newHeight, newWidth] = size;
const [batch, oldHeight, oldWidth, numChannels] = images.shape;
const xValues = backend.data.get(images.dataId).values;
const output = new Float32Array(batch * newHeight * newWidth * numChannels);
const effectiveInputSize = [
(alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
(alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
];
const effectiveOutputSize = [
(alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
(alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
];
const effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0];
const effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1];
let outputOffset = 0;
for (let b = 0; b < batch; b++) {
const batchOffset = b * imagesStrides[0];
for (let r = 0; r < newHeight; r++) {
const sourceFracRow = halfPixelCenters ?
effectiveRowSizeRatio * (r + 0.5) :
effectiveRowSizeRatio * r;
let sourceNearestRow = Math.min(oldHeight - 1, alignCorners ? Math.round(sourceFracRow) : Math.floor(sourceFracRow));
if (halfPixelCenters) {
sourceNearestRow = Math.max(0, sourceNearestRow);
}
const rowOffset = batchOffset + sourceNearestRow * imagesStrides[1];
for (let c = 0; c < newWidth; c++) {
const sourceFracCol = halfPixelCenters ?
effectiveColSizeRatio * (c + 0.5) :
effectiveColSizeRatio * c;
let sourceNearestCol = Math.min(oldWidth - 1, alignCorners ? Math.round(sourceFracCol) :
Math.floor(sourceFracCol));
if (halfPixelCenters) {
sourceNearestCol = Math.max(0, sourceNearestCol);
}
const colOffset = rowOffset + sourceNearestCol * imagesStrides[2];
for (let d = 0; d < numChannels; d++) {
const newVal = xValues[colOffset + d];
output[outputOffset++] = newVal;
}
}
}
}
return backend.makeTensorInfo([batch, newHeight, newWidth, numChannels], images.dtype, output);
}
const resizeNearestNeighborConfig = {
kernelName: ResizeNearestNeighbor,
backendName: 'cpu',
kernelFunc: resizeNearestNeighbor
};
function resizeNearestNeighborGrad(args) {
const { inputs, backend, attrs } = args;
const { images, dy } = inputs;
const { alignCorners } = attrs;
assertNotComplex([dy, images], 'resizeNearestNeighborGrad');
const imagesStrides = computeStrides(images.shape);
const dyStrides = computeStrides(dy.shape);
const [batch, xHeight, xWidth, depth] = images.shape;
const [, yHeight, yWidth] = dy.shape;
const output = new Float32Array(batch * xHeight * xWidth * depth);
const dyValues = backend.data.get(dy.dataId).values;
const effectiveXSize = [
(alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
(alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
];
const effectiveYSize = [
(alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
(alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
];
const heightScale = effectiveXSize[0] / effectiveYSize[0];
const widthScale = effectiveXSize[1] / effectiveYSize[1];
const invHeightScale = 1 / heightScale;
const invWidthScale = 1 / widthScale;
const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
for (let b = 0; b < batch; b++) {
const batchOffset = b * imagesStrides[0];
for (let r = 0; r < xHeight; r++) {
const rowOffset = batchOffset + r * imagesStrides[1];
const startRLerp = Math.floor(r * invHeightScale);
const startDyR = Math.floor(startRLerp - (winHeight / 2));
for (let c = 0; c < xWidth; c++) {
const colOffset = rowOffset + c * imagesStrides[2];
const startCLerp = Math.floor(c * invWidthScale);
const startDyC = Math.floor(startCLerp - (winWidth / 2));
for (let d = 0; d < depth; d++) {
let accum = 0;
for (let dyRIndex = 0; dyRIndex < winHeight; dyRIndex++) {
const dyR = dyRIndex + startDyR;
if (dyR < 0 || dyR >= yHeight) {
continue;
}
const dyROffset = batchOffset + dyR * dyStrides[1];
const sourceFracRow = dyR * heightScale;
const sourceNearestRow = Math.min(xHeight - 1, alignCorners ? Math.round(sourceFracRow) :
Math.floor(sourceFracRow));
if (r !== sourceNearestRow) {
continue;
}
for (let dyCIndex = 0; dyCIndex < winWidth; dyCIndex++) {
const dyC = dyCIndex + startDyC;
if (dyC < 0 || dyC >= yWidth) {
continue;
}
const dyCOffset = dyROffset + dyC * dyStrides[2];
const sourceFracCol = dyC * widthScale;
const sourceNearestCol = Math.min(xWidth - 1, alignCorners ? Math.round(sourceFracCol) :
Math.floor(sourceFracCol));
if (c === sourceNearestCol) {
accum += dyValues[dyCOffset + d];
}
}
}
output[colOffset + d] = accum;
}
}
}
}
return backend.makeTensorInfo(images.shape, images.dtype, output);
}
const resizeNearestNeighborGradConfig$1 = {
kernelName: ResizeNearestNeighborGrad,
backendName: 'cpu',
kernelFunc: resizeNearestNeighborGrad
};
function reverse(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { dims } = attrs;
assertNotComplex(x, 'reverse');
const xRank = x.shape.length;
const $dims = parseAxisParam(dims, x.shape);
if (xRank === 0) {
return identity$1({ inputs: { x }, backend });
}
const outBuf = new TensorBuffer(x.shape, x.dtype);
const xBuf = backend.bufferSync(x);
for (let i = 0; i < outBuf.size; i++) {
const outLoc = outBuf.indexToLoc(i);
const inLoc = outLoc.slice();
$dims.forEach(d => inLoc[d] = x.shape[d] - 1 - inLoc[d]);
outBuf.set(xBuf.get(...inLoc), ...outLoc);
}
return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
}
const reverseConfig = {
kernelName: Reverse,
backendName: 'cpu',
kernelFunc: reverse
};
const rotateWithOffsetConfig = {
kernelName: RotateWithOffset,
backendName: 'cpu',
kernelFunc: ({ inputs, attrs, backend }) => {
const { image } = inputs;
const { radians, fillValue, center } = attrs;
const cpuBackend = backend;
const output = getTypedArrayFromDType(image.dtype, sizeFromShape(image.shape));
const [batch, imageHeight, imageWidth, numChannels] = image.shape;
const [centerX, centerY] = getImageCenter(center, imageHeight, imageWidth);
const fullOpacityValue = 255;
const sinFactor = Math.sin(radians);
const cosFactor = Math.cos(radians);
const imageVals = cpuBackend.data.get(image.dataId).values;
for (let batchIdx = 0; batchIdx < batch; batchIdx++) {
const batchOffset = batchIdx * imageWidth * imageHeight * numChannels;
for (let row = 0; row < imageHeight; row++) {
const rowOffset = row * (imageWidth * numChannels);
for (let col = 0; col < imageWidth; col++) {
const colOffset = col * numChannels;
for (let channel = 0; channel < numChannels; channel++) {
                        const coords = [batchIdx, row, col, channel];
const x = coords[2];
const y = coords[1];
let coordX = (x - centerX) * cosFactor - (y - centerY) * sinFactor;
let coordY = (x - centerX) * sinFactor + (y - centerY) * cosFactor;
coordX = Math.round(coordX + centerX);
coordY = Math.round(coordY + centerY);
let outputValue = fillValue;
if (typeof fillValue !== 'number') {
if (channel === 3) {
outputValue = fullOpacityValue;
}
else {
outputValue = fillValue[channel];
}
}
if (coordX >= 0 && coordX < imageWidth && coordY >= 0 &&
coordY < imageHeight) {
const rotatedRowOffset = coordY * (imageWidth * numChannels);
const rotatedColOffset = coordX * numChannels;
const imageIdx = batchOffset + rotatedRowOffset + rotatedColOffset + channel;
outputValue = imageVals[imageIdx];
}
const outIdx = batchOffset + rowOffset + colOffset + channel;
output[outIdx] = outputValue;
}
}
}
}
const dataId = cpuBackend.write(output, image.shape, image.dtype);
return { dataId, shape: image.shape, dtype: image.dtype };
}
};
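/*
 * Round: halfway cases round to the nearest even integer (banker's rounding),
 * e.g. 0.5 -> 0, 1.5 -> 2, 2.5 -> 2; all other values round to the nearest
 * integer.
 */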
const round = unaryKernelFunc$1(Round, (xi) => {
const base = Math.floor(xi);
if (xi - base < 0.5) {
return Math.floor(xi);
}
else if (xi - base > 0.5) {
return Math.ceil(xi);
}
else {
if (base % 2.0 === 0.0) {
return base;
}
else {
return base + 1.0;
}
}
});
const roundConfig = {
kernelName: Round,
backendName: 'cpu',
kernelFunc: round,
};
function scatterNd(args) {
const { inputs, backend, attrs } = args;
const { indices, updates } = inputs;
const { shape } = attrs;
const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, shape);
const sumDupeIndices = true;
const indicesBuf = backend.bufferSync(indices);
const updatesBuf = backend.bufferSync(updates);
    const outBuf = scatterImpl(indicesBuf, updatesBuf, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, 0 /* defaultValue */, sumDupeIndices);
return backend.makeTensorInfo(shape, outBuf.dtype, outBuf.values);
}
const scatterNdConfig = {
kernelName: ScatterNd,
backendName: 'cpu',
kernelFunc: scatterNd
};
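/*
 * Binary-search helpers for SearchSorted: lowerBound returns the first index
 * whose element is >= value, upperBound the first index whose element is
 * > value. searchSortedImpl applies one of them per batch row, so
 * side: 'left' yields insertion points before equal elements and
 * side: 'right' after them.
 */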
function lowerBound(array, value) {
let left = 0;
let right = array.length;
let mid = 0;
while (left < right) {
mid = Math.floor((left + right) / 2);
if (array[mid] < value) {
left = mid + 1;
}
else {
right = mid;
}
}
return right;
}
function upperBound(array, value) {
let left = 0;
let right = array.length;
let mid = 0;
while (left < right) {
mid = Math.floor((left + right) / 2);
if (array[mid] <= value) {
left = mid + 1;
}
else {
right = mid;
}
}
return right;
}
function searchSortedImpl(sortedInputs, values, batchSize, numInputs, numValues, side) {
const output = getArrayFromDType('int32', batchSize * numValues);
for (let b = 0; b < batchSize; ++b) {
const sortedInputsSlice = sortedInputs.slice(b * numInputs, (b + 1) * numInputs);
const outputOffset = b * numValues;
for (let i = 0; i < numValues; ++i) {
output[outputOffset + i] = side === 'left' ?
lowerBound(sortedInputsSlice, values[i + outputOffset]) :
upperBound(sortedInputsSlice, values[i + outputOffset]);
}
}
return output;
}
function searchSorted(args) {
const { inputs, backend, attrs } = args;
const { sortedSequence, values } = inputs;
const { side } = attrs;
const $sortedSequence = backend.data.get(sortedSequence.dataId).values;
const $values = backend.data.get(values.dataId).values;
const output = searchSortedImpl($sortedSequence, $values, sortedSequence.shape[0], sortedSequence.shape[1], values.shape[1], side);
return backend.makeTensorInfo(values.shape, 'int32', output);
}
const searchSortedConfig = {
kernelName: SearchSorted,
backendName: 'cpu',
kernelFunc: searchSorted,
};
function select(args) {
const { inputs, backend } = args;
const { condition, t, e } = inputs;
assertNotComplex([condition, t, e], 'select');
const conditionRank = condition.shape.length;
const values = backend.data.get(condition.dataId).values;
const tValues = backend.data.get(t.dataId).values;
const eValues = backend.data.get(e.dataId).values;
const resultDtype = upcastType(t.dtype, e.dtype);
const newValues = makeZerosTypedArray(sizeFromShape(t.shape), resultDtype);
let index = 0;
const offset = conditionRank === 0 || conditionRank > 1 || t.shape.length === 1 ?
1 :
sizeFromShape(t.shape.slice(1));
for (let i = 0; i < values.length; i++) {
for (let j = 0; j < offset; j++) {
if (values[i] === 1) {
newValues[index++] = tValues[i];
}
else {
newValues[index++] = eValues[i];
}
}
}
return backend.makeTensorInfo(t.shape, resultDtype, newValues);
}
const selectConfig = {
kernelName: Select,
backendName: 'cpu',
kernelFunc: select
};
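/*
 * Selu: scale * x for x >= 0 and scaleAlpha * (exp(x) - 1) for x < 0, where
 * scale and scaleAlpha are the fixed self-normalizing constants exported as
 * SELU_SCALE and SELU_SCALEALPHA.
 */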
const scaleAlpha = SELU_SCALEALPHA;
const scale = SELU_SCALE;
const selu = unaryKernelFunc$1(Selu$1, (xi) => {
if (xi >= 0) {
return scale * xi;
}
else {
return scaleAlpha * (Math.exp(xi) - 1);
}
});
const seluConfig = {
kernelName: Selu$1,
backendName: 'cpu',
kernelFunc: selu,
};
const sign = unaryKernelFunc$1(Sign, (xi) => {
if (xi < 0) {
return -1;
}
else if (xi > 0) {
return 1;
}
else {
return 0;
}
});
const signConfig = {
kernelName: Sign,
backendName: 'cpu',
kernelFunc: sign,
};
const sin = unaryKernelFunc$1(Sin, (xi) => Math.sin(xi));
const sinConfig = {
kernelName: Sin,
backendName: 'cpu',
kernelFunc: sin,
};
const sinh = unaryKernelFunc$1(Sinh, (xi) => Math.sinh(xi));
const sinhConfig = {
kernelName: Sinh,
backendName: 'cpu',
kernelFunc: sinh,
};
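/*
 * Softplus log(1 + exp(x)) with overflow/underflow guards: for large x it
 * returns x directly (log(1 + exp(x)) ~ x), for very negative x it returns
 * exp(x) (log(1 + exp(x)) ~ exp(x)), and only the intermediate range
 * evaluates the full expression. The cutoff is derived from the float32
 * machine epsilon.
 */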
const epsilon$1 = 1.1920928955078125e-7;
const threshold = Math.log(epsilon$1) + 2.0;
const softplus = unaryKernelFunc$1(Softplus$1, (xi) => {
const tooLarge = xi > -threshold;
const tooSmall = xi < threshold;
const expX = Math.exp(xi);
let result;
if (tooSmall) {
result = expX;
}
else if (tooLarge) {
result = xi;
}
else {
result = Math.log(1.0 + expX);
}
return result;
});
const softplusConfig = {
kernelName: Softplus$1,
backendName: 'cpu',
kernelFunc: softplus,
};
function spaceToBatchND(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { blockShape, paddings } = attrs;
assertNotComplex([x], 'spaceToBatchND');
const prod = sizeFromShape(blockShape);
const completePaddings = [[0, 0]];
completePaddings.push(...paddings);
for (let i = 1 + blockShape.length; i < x.shape.length; ++i) {
completePaddings.push([0, 0]);
}
const paddedX = padV2Config.kernelFunc({
inputs: { x },
backend,
attrs: { paddings: completePaddings, constantValue: 0 }
});
const reshapedPaddedShape = getReshaped(paddedX.shape, blockShape, prod, false);
const permutedReshapedPaddedPermutation = getPermuted(reshapedPaddedShape.length, blockShape.length, false);
const flattenShape = getReshapedPermuted(paddedX.shape, blockShape, prod, false);
const reshapeInputs = { x: paddedX };
const reshapeAttrs = { shape: reshapedPaddedShape };
const paddedXReshaped = reshape({ inputs: reshapeInputs, backend, attrs: reshapeAttrs });
const transposeInputs = { x: paddedXReshaped };
const transposeAttrs = { perm: permutedReshapedPaddedPermutation };
const paddedXT = transpose$1({ inputs: transposeInputs, backend, attrs: transposeAttrs });
const resultReshapeInputs = { x: paddedXT };
const resultReshapeAttrs = { shape: flattenShape };
const result = reshape({ inputs: resultReshapeInputs, backend, attrs: resultReshapeAttrs });
backend.disposeIntermediateTensorInfo(paddedX);
backend.disposeIntermediateTensorInfo(paddedXReshaped);
backend.disposeIntermediateTensorInfo(paddedXT);
return result;
}
const spaceToBatchNDConfig = {
kernelName: SpaceToBatchND,
backendName: 'cpu',
kernelFunc: spaceToBatchND
};
function sparseFillEmptyRows(args) {
const { inputs, backend } = args;
const { indices, values, denseShape, defaultValue } = inputs;
if (denseShape.shape.length !== 1) {
throw new Error(`Dense shape must be a vector, saw:
${denseShape.shape}`);
}
if (indices.shape.length !== 2) {
throw new Error(`Indices must be a matrix, saw:
${indices.shape}`);
}
if (values.shape.length !== 1) {
throw new Error(`Values must be a vector, saw:
${values.shape}`);
}
if (defaultValue.shape.length !== 0) {
throw new Error(`Default value must be a scalar, saw:
${defaultValue.shape}`);
}
const $indices = backend.data.get(indices.dataId).values;
const $values = backend.data.get(values.dataId).values;
const $denseShape = backend.data.get(denseShape.dataId).values;
const $defaultValue = backend.data.get(defaultValue.dataId).values[0];
const [outputIndices, outputIndicesShape, outputValues, emptyRowIndicator, reverseIndexMap] = sparseFillEmptyRowsImpl($indices, indices.shape, indices.dtype, $values, values.dtype, $denseShape, $defaultValue);
return [
backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices),
backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues),
backend.makeTensorInfo([emptyRowIndicator.length], 'bool', new Uint8Array(emptyRowIndicator.map((value) => Number(value)))),
backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap)),
];
}
const sparseFillEmptyRowsConfig = {
kernelName: SparseFillEmptyRows,
backendName: 'cpu',
kernelFunc: sparseFillEmptyRows,
};
function sparseReshape(args) {
const { inputs, backend } = args;
const { inputIndices, inputShape, newShape } = inputs;
if (inputIndices.shape.length !== 2) {
throw new Error(`Input indices should be a matrix but received shape
${inputIndices.shape}`);
}
if (inputShape.shape.length !== 1) {
throw new Error(`Input shape should be a vector but received shape
${inputShape.shape}`);
}
if (newShape.shape.length !== 1) {
throw new Error(`Target shape should be a vector but received shape ${newShape.shape}`);
}
const $inputShape = Array.from(backend.data.get(inputShape.dataId).values);
const $inputIndices = backend.data.get(inputIndices.dataId).values;
const targetShape = Array.from(backend.data.get(newShape.dataId).values);
const [newIndices, indicesShape, outputShape] = sparseReshapeImpl($inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape, targetShape);
return [
backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices),
backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape)),
];
}
const sparseReshapeConfig = {
kernelName: SparseReshape,
backendName: 'cpu',
kernelFunc: sparseReshape,
};
function sparseSegmentMean(args) {
const { inputs, backend } = args;
const { data, indices, segmentIds } = inputs;
if (data.shape.length < 1) {
throw new Error(`Data should be at least 1 dimensional but received scalar`);
}
if (indices.shape.length !== 1) {
throw new Error(`Indices should be a vector but received shape
${indices.shape}`);
}
if (segmentIds.shape.length !== 1) {
throw new Error(`Segment ids should be a vector but received shape
${segmentIds.shape}`);
}
if (indices.shape[0] !== segmentIds.shape[0]) {
throw new Error(`segmentIds and indices should have same size.`);
}
const $data = backend.data.get(data.dataId).values;
const $indices = backend.data.get(indices.dataId).values;
const $segmentIds = backend.data.get(segmentIds.dataId).values;
const [outputData, outputDataShape] = sparseSegmentReductionImpl($data, data.shape, data.dtype, $indices, $segmentIds, true);
return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
}
const sparseSegmentMeanConfig = {
kernelName: SparseSegmentMean,
backendName: 'cpu',
kernelFunc: sparseSegmentMean,
};
function sparseSegmentSum(args) {
const { inputs, backend } = args;
const { data, indices, segmentIds } = inputs;
if (data.shape.length < 1) {
throw new Error(`Data should be at least 1 dimensional but received scalar`);
}
if (indices.shape.length !== 1) {
throw new Error(`Indices should be a vector but received shape
${indices.shape}`);
}
if (segmentIds.shape.length !== 1) {
throw new Error(`Segment ids should be a vector but received shape
${segmentIds.shape}`);
}
if (indices.shape[0] !== segmentIds.shape[0]) {
throw new Error(`segmentIds and indices should have same size.`);
}
const $data = backend.data.get(data.dataId).values;
const $indices = backend.data.get(indices.dataId).values;
const $segmentIds = backend.data.get(segmentIds.dataId).values;
const [outputData, outputDataShape] = sparseSegmentReductionImpl($data, data.shape, data.dtype, $indices, $segmentIds);
return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
}
const sparseSegmentSumConfig = {
kernelName: SparseSegmentSum,
backendName: 'cpu',
kernelFunc: sparseSegmentSum,
};
function sparseToDense(args) {
const { inputs, backend, attrs } = args;
const { sparseIndices, sparseValues, defaultValue } = inputs;
const { outputShape } = attrs;
const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(sparseValues, sparseIndices, outputShape);
const sumDupeIndices = false;
const indicesBuf = backend.bufferSync(sparseIndices);
let outBuf;
switch (sparseValues.dtype) {
case 'bool': {
const updatesBuf = backend.bufferSync(sparseValues);
const $defaultValue = Boolean(backend.data.get(defaultValue.dataId).values[0]);
outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
break;
}
case 'float32': {
const updatesBuf = backend.bufferSync(sparseValues);
const $defaultValue = backend.data.get(defaultValue.dataId).values[0];
outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
break;
}
case 'int32': {
const updatesBuf = backend.bufferSync(sparseValues);
const $defaultValue = backend.data.get(defaultValue.dataId).values[0];
outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
break;
}
case 'string': {
const updatesBuf = backend.bufferSync(sparseValues);
const $defaultValue = decodeString(backend.data.get(defaultValue.dataId).values[0]);
outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
break;
}
default:
throw new Error(`Unsupported type ${sparseValues.dtype}`);
}
return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values);
}
const sparseToDenseConfig = {
kernelName: SparseToDense,
backendName: 'cpu',
kernelFunc: sparseToDense
};
function splitV(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { numOrSizeSplits, axis } = attrs;
const $axis = parseAxisParam(axis, x.shape)[0];
const splitSizes = prepareSplitSize(x, numOrSizeSplits, $axis);
const begin = new Array(x.shape.length).fill(0);
const size = x.shape.slice();
return splitSizes.map(s => {
const sliceSize = [...size];
sliceSize[$axis] = s;
const sliceT = slice$1({ inputs: { x }, backend, attrs: { begin, size: sliceSize } });
begin[$axis] += s;
return sliceT;
});
}
const splitVConfig = {
kernelName: SplitV,
backendName: 'cpu',
kernelFunc: splitV
};
const squareConfig = {
kernelName: Square,
backendName: 'cpu',
kernelFunc: ({ inputs, backend }) => {
const { x } = inputs;
const cpuBackend = backend;
assertNotComplex(x, 'square');
const values = cpuBackend.data.get(x.dataId).values;
const newValues = new Float32Array(values.length);
for (let i = 0; i < values.length; ++i) {
const value = values[i];
newValues[i] = value * value;
}
const dataId = cpuBackend.write(newValues, x.shape, x.dtype);
return { dataId, shape: x.shape, dtype: x.dtype };
}
};
const step = unaryKernelFunc$1(Step, (xi, attrs) => {
const stepAttrs = attrs;
if (isNaN(xi)) {
return NaN;
}
else {
return xi > 0 ? 1 : stepAttrs.alpha;
}
});
const stepConfig = {
kernelName: Step,
backendName: 'cpu',
kernelFunc: step,
};
function stridedSlice(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask } = attrs;
assertNotComplex(x, 'stridedSlice');
const { finalShapeSparse, finalShape, isIdentity, sliceDim0, isSimpleSlice, begin: $begin, end: $end, strides: $strides } = sliceInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
let result;
if (isIdentity) {
result = reshape({ inputs: { x }, backend, attrs: { shape: finalShape } });
}
else if (sliceDim0 || isSimpleSlice) {
assert$1(x.shape.length >= 1, () => `Input must have rank at least 1, got: ${x.shape.length}`);
const size = computeOutShape$2($begin, $end, $strides);
const sliced = slice$1({ inputs: { x }, backend, attrs: { begin: $begin, size } });
result =
reshape({ inputs: { x: sliced }, backend, attrs: { shape: finalShape } });
backend.disposeIntermediateTensorInfo(sliced);
}
else {
const xBuf = backend.bufferSync(x);
const outBuf = stridedSliceImpl(finalShapeSparse, xBuf, $strides, $begin);
result = backend.makeTensorInfo(finalShape, outBuf.dtype, outBuf.values);
}
return result;
}
const stridedSliceConfig = {
kernelName: StridedSlice,
backendName: 'cpu',
kernelFunc: stridedSlice
};
function stringNGrams(args) {
const { inputs, backend, attrs } = args;
const { separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences } = attrs;
const { data, dataSplits } = inputs;
const $data = backend.data.get(data.dataId).values;
const $dataSplits = backend.data.get(dataSplits.dataId).values;
const [nGrams, nGramsSplits] = stringNGramsImpl($data, $dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences);
return [
backend.makeTensorInfo([nGrams.length], 'string', nGrams),
backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits),
];
}
const stringNGramsConfig = {
kernelName: StringNGrams,
backendName: 'cpu',
kernelFunc: stringNGrams,
};
function stringSplit(args) {
const { inputs, backend, attrs } = args;
const { skipEmpty } = attrs;
const { input, delimiter } = inputs;
if (input.dtype !== 'string') {
throw new Error('Input must be of datatype string');
}
if (input.shape.length !== 1) {
throw new Error(`Input must be a vector, got shape: ${input.shape}`);
}
if (delimiter.shape.length !== 0) {
throw new Error(`Delimiter must be a scalar, got shape: ${delimiter.shape}`);
}
const $input = backend.data.get(input.dataId).values;
const $delimiter = backend.data.get(delimiter.dataId).values[0];
const [indices, values, shape] = stringSplitImpl($input, $delimiter, skipEmpty);
const outputSize = values.length;
return [
backend.makeTensorInfo([outputSize, 2], 'int32', indices),
backend.makeTensorInfo([outputSize], 'string', values),
backend.makeTensorInfo([2], 'int32', new Int32Array(shape))
];
}
const stringSplitConfig = {
kernelName: StringSplit,
backendName: 'cpu',
kernelFunc: stringSplit,
};
function stringToHashBucketFast(args) {
const { inputs, backend, attrs } = args;
const { numBuckets } = attrs;
const { input } = inputs;
if (input.dtype !== 'string') {
throw new Error('Input must be of datatype string');
}
if (numBuckets <= 0) {
throw new Error(`Number of buckets must be at least 1`);
}
const $input = backend.data.get(input.dataId).values;
const output = stringToHashBucketFastImpl($input, numBuckets);
return backend.makeTensorInfo(input.shape, 'int32', output);
}
const stringToHashBucketFastConfig = {
kernelName: StringToHashBucketFast,
backendName: 'cpu',
kernelFunc: stringToHashBucketFast,
};
const tan = unaryKernelFunc$1(Tan, (xi) => Math.tan(xi));
const tanConfig = {
kernelName: Tan,
backendName: 'cpu',
kernelFunc: tan,
};
const tanh = unaryKernelFunc$1(Tanh$1, (xi) => Math.tanh(xi));
const tanhConfig = {
kernelName: Tanh$1,
backendName: 'cpu',
kernelFunc: tanh,
};
function tensorScatterUpdate(args) {
const { inputs, backend } = args;
const { tensor, indices, updates } = inputs;
const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, tensor.shape);
const sumDupeIndices = false;
const indicesBuf = backend.bufferSync(indices);
const updatesBuf = backend.bufferSync(updates);
const tensorBuf = backend.bufferSync(tensor);
const outBuf = scatterImpl(indicesBuf, updatesBuf, tensor.shape, outputSize, sliceSize, numUpdates, sliceRank, strides, tensorBuf, sumDupeIndices);
return backend.makeTensorInfo(tensor.shape, outBuf.dtype, outBuf.values);
}
const tensorScatterUpdateConfig = {
kernelName: TensorScatterUpdate,
backendName: 'cpu',
kernelFunc: tensorScatterUpdate
};
function tile$1(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { reps } = attrs;
assertNotComplex(x, 'tile');
const outBuf = tileImpl(backend.bufferSync(x), reps);
return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
}
const tileConfig = {
kernelName: Tile,
backendName: 'cpu',
kernelFunc: tile$1
};
function topK(args) {
const { inputs, backend, attrs } = args;
const { x } = inputs;
const { k, sorted } = attrs;
assertNotComplex(x, 'topk');
const xVals = backend.data.get(x.dataId).values;
const [allTopKVals, allTopKIndices] = topKImpl(xVals, x.shape, x.dtype, k, sorted);
return [
backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values),
backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)
];
}
const topKConfig = {
kernelName: TopK,
backendName: 'cpu',
kernelFunc: topK
};
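// Transform kernel: for each output pixel, maps its coordinates back through the
// per-image projective transform (8 coefficients per image; the ninth matrix
// entry is implicitly 1), then samples the input with 'nearest' or 'bilinear'
// interpolation, using the fill mode / fill value for out-of-range coordinates.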
function transform(args) {
const { inputs, attrs, backend } = args;
const { image, transforms } = inputs;
const { interpolation, fillMode, fillValue, outputShape } = attrs;
const [batch, imageHeight, imageWidth, numChannels] = image.shape;
const [outHeight, outWidth] = outputShape != null ? outputShape : [imageHeight, imageWidth];
const outShape = [batch, outHeight, outWidth, numChannels];
const inStrides = computeStrides(image.shape);
const batchInStride = inStrides[0];
const rowInStride = inStrides[1];
const colInStride = inStrides[2];
const outStrides = computeStrides(outShape);
const batchOutStride = outStrides[0];
const rowOutStride = outStrides[1];
const colOutStride = outStrides[2];
const outVals = getTypedArrayFromDType(image.dtype, sizeFromShape(outShape));
outVals.fill(fillValue);
const imageVals = backend.data.get(image.dataId).values;
const transformVals = backend.data.get(transforms.dataId).values;
for (let b = 0; b < batch; ++b) {
const transform = transforms.shape[0] === 1 ?
transformVals :
transformVals.subarray(b * 8, b * 8 + 8);
for (let outY = 0; outY < outHeight; ++outY) {
for (let outX = 0; outX < outWidth; ++outX) {
for (let channel = 0; channel < numChannels; ++channel) {
let val;
const projection = transform[6] * outX + transform[7] * outY + 1;
if (projection === 0) {
continue;
}
const inX = (transform[0] * outX + transform[1] * outY + transform[2]) /
projection;
const inY = (transform[3] * outX + transform[4] * outY + transform[5]) /
projection;
const x = mapCoord(inX, imageWidth, fillMode);
const y = mapCoord(inY, imageHeight, fillMode);
switch (interpolation) {
case 'nearest':
val = nearestInterpolation(imageVals, imageHeight, imageWidth, batchInStride, rowInStride, colInStride, b, y, x, channel, fillValue);
break;
case 'bilinear':
val = bilinearInterpolation(imageVals, imageHeight, imageWidth, batchInStride, rowInStride, colInStride, b, y, x, channel, fillValue);
break;
default:
throw new Error(`Error in Transform: Expect 'nearest' or ` +
`'bilinear', but got ${interpolation}`);
}
const ind = b * batchOutStride + outY * rowOutStride +
outX * colOutStride + channel;
outVals[ind] = val;
}
}
}
    }
    return backend.makeTensorInfo(outShape, image.dtype, outVals);
}
const transformConfig = {
kernelName: Transform,
backendName: 'cpu',
kernelFunc: transform
};
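// Maps a source coordinate according to the fill mode: 'reflect' mirrors at the
// image borders, 'wrap' tiles the image, 'nearest' clamps to the nearest edge,
// and 'constant' leaves the coordinate unchanged so out-of-range reads fall back
// to the fill value.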
function mapCoord(outCoord, len, mode) {
switch (mode) {
case 'reflect':
return mapCoordReflect(outCoord, len);
case 'wrap':
return mapCoordWrap(outCoord, len);
case 'nearest':
return mapCoordNearest(outCoord, len);
case 'constant':
default:
return mapCoordConstant(outCoord);
}
}
function mapCoordReflect(outCoord, len) {
let inCoord = outCoord;
if (inCoord < 0) {
if (len <= 1) {
inCoord = 0;
}
else {
const sz2 = 2 * len;
if (inCoord < sz2) {
inCoord = sz2 * Math.trunc(-inCoord / sz2) + inCoord;
}
inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1;
}
}
else if (inCoord > len - 1) {
if (len <= 1) {
inCoord = 0;
}
else {
const sz2 = 2 * len;
inCoord -= sz2 * Math.trunc(inCoord / sz2);
if (inCoord >= len) {
inCoord = sz2 - inCoord - 1;
}
}
}
return clamp(0, inCoord, len - 1);
}
function mapCoordWrap(outCoord, len) {
let inCoord = outCoord;
if (inCoord < 0) {
if (len <= 1) {
inCoord = 0;
}
else {
const sz = len - 1;
inCoord += len * (Math.trunc(-inCoord / sz) + 1);
}
}
else if (inCoord > len - 1) {
if (len <= 1) {
inCoord = 0;
}
else {
const sz = len - 1;
inCoord -= len * Math.trunc(inCoord / sz);
}
}
return clamp(0, inCoord, len - 1);
}
function mapCoordConstant(outCoord, len) {
return outCoord;
}
function mapCoordNearest(outCoord, len) {
return clamp(0, outCoord, len - 1);
}
function readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
const ind = batch * batchStride + y * rowStride + x * colStride + channel;
if (0 <= y && y < imageHeight && 0 <= x && x < imageWidth) {
return imageVals[ind];
}
else {
return fillValue;
}
}
function nearestInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
const $y = Math.round(y);
const $x = Math.round(x);
return readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, $y, $x, channel, fillValue);
}
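// Bilinear interpolation: blends the four integer neighbors of (y, x) by their
// fractional distances, substituting fillValue for neighbors outside the image.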
function bilinearInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
const yFloor = Math.floor(y);
const xFloor = Math.floor(x);
const yCeil = yFloor + 1;
const xCeil = xFloor + 1;
const valueYFloor = (xCeil - x) *
readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yFloor, xFloor, channel, fillValue) +
(x - xFloor) *
readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yFloor, xCeil, channel, fillValue);
const valueYCeil = (xCeil - x) *
readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yCeil, xFloor, channel, fillValue) +
(x - xFloor) *
readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yCeil, xCeil, channel, fillValue);
return (yCeil - y) * valueYFloor + (y - yFloor) * valueYCeil;
}
function unique$1(args) {
const { inputs, attrs, backend } = args;
const { axis } = attrs;
const { x } = inputs;
assertNotComplex(x, 'unique');
const values = backend.data.get(x.dataId).values;
const { outputValues, outputShape, indices } = uniqueImpl(values, axis, x.shape, x.dtype);
return [
backend.makeTensorInfo(outputShape, x.dtype, outputValues),
backend.makeTensorInfo([indices.length], 'int32', indices),
];
}
const uniqueConfig = {
kernelName: Unique,
backendName: 'cpu',
kernelFunc: unique$1,
};
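// Unpack: produces value.shape[axis] tensors by taking size-1 slices along the
// axis and reshaping each slice to drop that dimension.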
function unpack(args) {
const { inputs, backend, attrs } = args;
const { value } = inputs;
let { axis } = attrs;
if (axis < 0) {
axis += value.shape.length;
}
const valueRank = value.shape.length;
const num = value.shape[axis];
const outShape = new Array(valueRank - 1);
let outIndex = 0;
for (let i = 0; i < valueRank; i++) {
if (i !== axis) {
outShape[outIndex++] = value.shape[i];
}
}
const begin = new Array(valueRank).fill(0);
const size = value.shape.slice();
size[axis] = 1;
const res = new Array(num);
for (let i = 0; i < res.length; i++) {
begin[axis] = i;
const tempRes = slice$1({ inputs: { x: value }, backend, attrs: { begin, size } });
res[i] = reshape({ inputs: { x: tempRes }, backend, attrs: { shape: outShape } });
backend.disposeIntermediateTensorInfo(tempRes);
}
return res;
}
const unpackConfig = {
kernelName: Unpack,
backendName: 'cpu',
kernelFunc: unpack
};
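// UnsortedSegmentSum: expands the segment ids to the input rank, then for each
// segment id builds an equality mask, multiplies it with the input, sums over
// axis 0, and packs the per-segment sums into the output.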
function unsortedSegmentSum(args) {
const { inputs, backend, attrs } = args;
const { x, segmentIds } = inputs;
const { numSegments } = attrs;
assertNotComplex(x, 'unsortedSegmentSum');
const xRank = x.shape.length;
const segmentIdsRank = segmentIds.shape.length;
const res = [];
const intermediates = [];
const numIters = xRank - segmentIdsRank;
let $segmentIds = segmentIds;
for (let i = 0; i < numIters; ++i) {
const expanded = expandDims$1({ inputs: { input: $segmentIds }, backend, attrs: { dim: i + 1 } });
$segmentIds = expanded;
intermediates.push(expanded);
}
for (let i = 0; i < numSegments; ++i) {
const scalarValue = createScalarValue(i, 'int32');
const segmentId = backend.makeTensorInfo([], 'int32', scalarValue);
const mask = equal$1({ inputs: { a: segmentId, b: $segmentIds }, backend });
const maskCasted = cast$2({ inputs: { x: mask }, backend, attrs: { dtype: 'float32' } });
const mul = multiply$1({ inputs: { a: maskCasted, b: x }, backend });
const sumTensorInfo = sum({ inputs: { x: mul }, backend, attrs: { axis: 0, keepDims: false } });
res.push(sumTensorInfo);
intermediates.push(segmentId);
intermediates.push(mask);
intermediates.push(maskCasted);
intermediates.push(mul);
intermediates.push(sumTensorInfo);
}
const result = pack({ inputs: res, backend, attrs: { axis: 0 } });
intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
return result;
}
const unsortedSegmentSumConfig = {
kernelName: UnsortedSegmentSum,
backendName: 'cpu',
kernelFunc: unsortedSegmentSum
};
const kernelConfigs = [
_fusedMatMulConfig,
absConfig$1,
acosConfig,
acoshConfig,
addConfig$1,
addNConfig,
allConfig,
anyConfig,
argMaxConfig,
argMinConfig,
asinConfig,
asinhConfig,
atanConfig,
atan2Config,
atanhConfig,
avgPoolConfig,
avgPool3DConfig,
avgPool3DGradConfig$1,
avgPoolGradConfig$1,
batchMatMulConfig,
batchNormConfig,
batchToSpaceNDConfig,
bincountConfig,
bitwiseAndConfig$1,
broadcastArgsConfig,
castConfig$1,
ceilConfig$1,
clipByValueConfig,
complexConfig$1,
complexAbsConfig,
concatConfig,
conv2DConfig,
conv2DBackpropFilterConfig,
conv2DBackpropInputConfig,
conv3DConfig,
conv3DBackpropFilterV2Config,
conv3DBackpropInputV2Config,
cosConfig,
coshConfig,
cropAndResizeConfig,
cumprodConfig,
cumsumConfig,
denseBincountConfig,
depthToSpaceConfig,
depthwiseConv2dNativeConfig,
depthwiseConv2dNativeBackpropFilterConfig,
depthwiseConv2dNativeBackpropInputConfig,
diagConfig,
dilation2DConfig,
dilation2DBackpropFilterConfig,
dilation2DBackpropInputConfig,
drawConfig,
einsumConfig,
eluConfig,
eluGradConfig$1,
equalConfig$1,
erfConfig,
expConfig$1,
expandDimsConfig,
expm1Config$1,
fftConfig,
fillConfig,
flipLeftRightConfig,
floorConfig$1,
floorDivConfig$1,
fusedConv2DConfig,
fusedDepthwiseConv2DConfig,
gatherNdConfig,
gatherV2Config,
greaterConfig$1,
greaterEqualConfig$1,
identityConfig$1,
ifftConfig,
imagConfig,
isFiniteConfig,
isInfConfig,
isNaNConfig,
leakyReluConfig,
lessConfig$1,
lessEqualConfig$1,
linSpaceConfig,
logConfig$1,
log1pConfig,
logicalAndConfig,
logicalNotConfig,
logicalOrConfig,
LRNConfig,
LRNGradConfig,
maxConfig,
maximumConfig$1,
maxPoolConfig,
maxPool3DConfig,
maxPool3DGradConfig$1,
maxPoolGradConfig$1,
maxPoolWithArgmaxConfig,
meanConfig,
minConfig,
minimumConfig$1,
mirrorPadConfig,
modConfig,
multinomialConfig,
multiplyConfig$1,
negConfig$1,
nonMaxSuppressionV3Config,
nonMaxSuppressionV4Config,
nonMaxSuppressionV5Config,
notEqualConfig$1,
oneHotConfig,
onesLikeConfig,
packConfig,
padV2Config,
powConfig,
preluConfig,
prodConfig$1,
raggedGatherConfig,
raggedRangeConfig,
raggedTensorToTensorConfig,
rangeConfig,
realConfig$1,
realDivConfig,
reciprocalConfig,
reluConfig,
relu6Config,
reshapeConfig,
resizeBilinearConfig,
resizeBilinearGradConfig$1,
resizeNearestNeighborConfig,
resizeNearestNeighborGradConfig$1,
reverseConfig,
rotateWithOffsetConfig,
roundConfig,
rsqrtConfig$1,
scatterNdConfig,
searchSortedConfig,
selectConfig,
seluConfig,
sigmoidConfig$1,
signConfig,
sinConfig,
sinhConfig,
sliceConfig$1,
softmaxConfig,
softplusConfig,
spaceToBatchNDConfig,
sparseFillEmptyRowsConfig,
sparseReshapeConfig,
sparseSegmentMeanConfig,
sparseSegmentSumConfig,
sparseToDenseConfig,
splitVConfig,
sqrtConfig$1,
squareConfig,
squaredDifferenceConfig$1,
staticRegexReplaceConfig$1,
stepConfig,
stridedSliceConfig,
stringNGramsConfig,
stringSplitConfig,
stringToHashBucketFastConfig,
subConfig$1,
sumConfig,
tanConfig,
tanhConfig,
tensorScatterUpdateConfig,
tileConfig,
topKConfig,
transformConfig,
transposeConfig$1,
uniqueConfig,
unpackConfig,
unsortedSegmentSumConfig,
zerosLikeConfig
];
for (const kernelConfig of kernelConfigs) {
registerKernel(kernelConfig);
}
const absGradConfig = {
kernelName: Abs,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => mul(dy, step$2(cast$3(x, 'float32'), -1)) };
}
};
const acosGradConfig = {
kernelName: Acos,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return {
x: () => {
const a = square$2(cast$3(x, 'float32'));
const b = sqrt$2(sub$2(scalar(1), a));
return neg$2(div$1(dy, b));
}
};
}
};
const acoshGradConfig = {
kernelName: Acosh,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return {
x: () => {
const a = sqrt$2(sub$2(square$2(cast$3(x, 'float32')), 1));
return div$1(dy, a);
}
};
}
};
const addGradConfig = {
kernelName: Add,
inputsToSave: ['a', 'b'],
gradFunc: (dy, saved) => {
const [a, b] = saved;
const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
const derA = () => {
let res = dy;
const reduceAxes = getReductionAxes(a.shape, outShape);
if (reduceAxes.length > 0) {
res = sum$2(res, reduceAxes);
}
return reshape$2(res, a.shape);
};
const derB = () => {
let res = dy;
const reduceAxes = getReductionAxes(b.shape, outShape);
if (reduceAxes.length > 0) {
res = sum$2(res, reduceAxes);
}
return reshape$2(res, b.shape);
};
return { a: derA, b: derB };
}
};
const addNGradConfig = {
kernelName: AddN,
saveAllInputs: true,
gradFunc: (dy, saved) => {
const ders = {};
saved.forEach((_, i) => {
ders[i] = () => dy.clone();
});
return ders;
}
};
const argMaxGradConfig = {
kernelName: ArgMax,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => zerosLike$2(x) };
}
};
const argMinGradConfig = {
kernelName: ArgMin,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => zerosLike$2(x) };
}
};
const asinGradConfig = {
kernelName: Asin,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => div$1(dy, sqrt$2(sub$2(scalar(1), square$2(cast$3(x, 'float32'))))) };
}
};
const asinhGradConfig = {
kernelName: Asinh,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return {
x: () => {
const a = sqrt$2(add$1(scalar(1), square$2(cast$3(x, 'float32'))));
return div$1(dy, a);
}
};
}
};
const atan2GradConfig = {
kernelName: Atan2,
inputsToSave: ['a', 'b'],
gradFunc: (dy, saved) => {
const [a, b] = saved;
const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
const derA = () => {
const d = add$1(square$2(a), square$2(b));
let res = mul(dy, div$1(b, d));
const reduceAxes = getReductionAxes(a.shape, outShape);
if (reduceAxes.length > 0) {
res = sum$2(res, reduceAxes);
}
return reshape$2(res, a.shape);
};
const derB = () => {
const d = add$1(square$2(a), square$2(b));
let res = neg$2(mul(dy, div$1(a, d)));
const reduceAxes = getReductionAxes(b.shape, outShape);
if (reduceAxes.length > 0) {
res = sum$2(res, reduceAxes);
}
return reshape$2(res, b.shape);
};
return { a: derA, b: derB };
}
};
const atanGradConfig = {
kernelName: Atan,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => div$1(dy, add$1(square$2(cast$3(x, 'float32')), 1)) };
}
};
const atanhGradConfig = {
kernelName: Atanh,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => div$1(dy, sub$2(scalar(1), square$2(cast$3(x, 'float32')))) };
}
};
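// Gradient helper for AvgPool3D: reshapes rank-4 inputs to rank 5, runs the
// AvgPool3DGrad kernel, and reshapes the result back if needed.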
function avgPool3dGrad_(dy, input, filterSize, strides, pad, dimRoundingMode) {
const $dy = convertToTensor(dy, 'dy', 'avgPool3dGrad');
const $input = convertToTensor(input, 'input', 'avgPool3dGrad');
let dy5D = $dy;
let input5D = $input;
let reshapedTo5D = false;
if ($input.rank === 4) {
reshapedTo5D = true;
dy5D = reshape$2($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]);
input5D = reshape$2($input, [
1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]
]);
}
assert$1(dy5D.rank === 5, () => `Error in avgPool3dGrad: dy must be rank 5 but got rank ` +
`${dy5D.rank}.`);
assert$1(input5D.rank === 5, () => `Error in avgPool3dGrad: input must be rank 5 but got rank ` +
`${input5D.rank}.`);
checkPadOnDimRoundingMode('avgPool3dGrad', pad, dimRoundingMode);
const inputs = { dy: dy5D, input: input5D };
const attrs = { filterSize, strides, pad, dimRoundingMode };
const res = ENGINE.runKernel(AvgPool3DGrad, inputs, attrs);
if (reshapedTo5D) {
return reshape$2(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
}
return res;
}
const avgPool3dGrad = op({ avgPool3dGrad_ });
const avgPool3DGradConfig = {
kernelName: AvgPool3D,
inputsToSave: ['x'],
gradFunc: (dy, saved, attrs) => {
const [x] = saved;
const { filterSize, strides, pad, dimRoundingMode } = attrs;
return {
x: () => avgPool3dGrad(dy, x, filterSize, strides, pad, dimRoundingMode)
};
}
};
function avgPoolGrad_(dy, input, filterSize, strides, pad) {
const $dy = convertToTensor(dy, 'dy', 'avgPoolGrad');
const $input = convertToTensor(input, 'input', 'avgPoolGrad');
assert$1($input.rank === $dy.rank, () => `Rank of input (${$input.rank}) does not match rank of dy (${$dy.rank})`);
let input4D = $input;
let dy4D = $dy;
let reshapedTo4D = false;
if ($input.rank === 3) {
reshapedTo4D = true;
input4D =
reshape$2($input, [1, $input.shape[0], $input.shape[1], $input.shape[2]]);
dy4D = reshape$2($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2]]);
}
assert$1(dy4D.rank === 4, () => `Error in avgPoolGrad: dy must be rank 4 but got rank ` +
`${dy4D.rank}.`);
assert$1(input4D.rank === 4, () => `Error in avgPoolGrad: input must be rank 4 but got rank ` +
`${input4D.rank}.`);
const inputs = { dy: dy4D, input: input4D };
const attrs = { filterSize, strides, pad };
const res = ENGINE.runKernel(AvgPoolGrad, inputs, attrs);
if (reshapedTo4D) {
return reshape$2(res, [res.shape[1], res.shape[2], res.shape[3]]);
}
return res;
}
const avgPoolGrad = op({ avgPoolGrad_ });
const avgPoolGradConfig = {
kernelName: AvgPool,
inputsToSave: ['x'],
gradFunc: (dy, saved, attrs) => {
const [x] = saved;
const { filterSize, strides, pad } = attrs;
return { x: () => avgPoolGrad(dy, x, filterSize, strides, pad) };
}
};
const batchMatMulGradConfig = {
kernelName: BatchMatMul,
inputsToSave: ['a', 'b'],
gradFunc: (dy, saved, attrs) => {
const [a, b] = saved;
const { transposeA, transposeB } = attrs;
if (!transposeA && !transposeB) {
return {
a: () => matMul$1(dy, b, false, true),
b: () => matMul$1(a, dy, true, false)
};
}
else if (!transposeA && transposeB) {
return {
a: () => matMul$1(dy, b, false, false),
b: () => matMul$1(dy, a, true, false)
};
}
else if (transposeA && !transposeB) {
return {
a: () => matMul$1(b, dy, false, true),
b: () => matMul$1(a, dy, false, false)
};
}
else {
return {
a: () => matMul$1(b, dy, true, true),
b: () => matMul$1(dy, a, true, true)
};
}
}
};
const batchToSpaceNDGradConfig = {
kernelName: BatchToSpaceND,
gradFunc: (dy, saved, attrs) => {
const { blockShape, crops } = attrs;
return { x: () => spaceToBatchND$2(dy, blockShape, crops) };
}
};
const broadcastToGradConfig = {
kernelName: BroadcastTo,
gradFunc: (dy, saved, attrs) => {
const broadCastToAttrs = attrs;
const inputShape = broadCastToAttrs.inputShape;
const outputShape = broadCastToAttrs.shape;
const reps = Array.from(outputShape);
for (let i = inputShape.length - 1; i >= 0; i--) {
if (inputShape[i] === outputShape[i]) {
reps[i] = 1;
}
else if (inputShape[i] !== 1) {
throw new Error(`broadcastTo(): [${inputShape}] cannot be broadcast to [${outputShape}].`);
}
}
const axes = [];
for (let i = 0; i < reps.length; i++) {
if (reps[i] > 1) {
axes.push(i);
}
}
return { x: () => sum$2(dy, axes, true) };
}
};
const castGradConfig = {
kernelName: Cast,
gradFunc: (dy) => {
return { x: () => dy.clone() };
}
};
const ceilGradConfig = {
kernelName: Ceil,
gradFunc: (dy) => {
return { x: () => zerosLike$2(dy) };
}
};
const clipByValueGradConfig = {
kernelName: ClipByValue,
inputsToSave: ['x'],
gradFunc: (dy, saved, attrs) => {
const [x] = saved;
const { clipValueMin, clipValueMax } = attrs;
return {
x: () => where(logicalAnd$2(greaterEqual$2(x, clipValueMin), lessEqual$2(x, clipValueMax)), dy, zerosLike$2(dy)),
};
}
};
const complexAbsGradConfig = {
kernelName: ComplexAbs,
inputsToSave: ['x'],
gradFunc: absGradConfig.gradFunc,
};
const concatGradConfig = {
kernelName: Concat,
saveAllInputs: true,
gradFunc: (dy, saved, attrs) => {
const shapes = saved.map(t => t.shape);
const { axis } = attrs;
const $axis = parseAxisParam(axis, saved[0].shape)[0];
const sizeSplits = shapes.map(s => s[$axis]);
const derTensors = split$1(dy, sizeSplits, $axis);
return derTensors.map(t => () => t);
}
};
const conv2DGradConfig = {
kernelName: Conv2D,
inputsToSave: ['x', 'filter'],
gradFunc: (dy, saved, attrs) => {
const [x4D, $filter] = saved;
const { dilations, strides, pad, dataFormat } = attrs;
assert$1(tupleValuesAreOne(dilations), () => 'Error in gradient of conv2D: dilation rates greater than 1 ' +
`are not yet supported in gradients. Got dilations '${dilations}'`);
return {
x: () => conv2DBackpropInput$2(x4D.shape, dy, $filter, strides, pad, dataFormat),
filter: () => conv2DBackpropFilter$2(x4D, dy, $filter.shape, strides, pad, dataFormat)
};
}
};
const conv2DBackpropInputGradConfig = {
kernelName: Conv2DBackpropInput,
inputsToSave: ['dy', 'filter'],
gradFunc: (ddx, saved, attrs) => {
const [dy, filter] = saved;
const { strides, pad, dataFormat, dimRoundingMode } = attrs;
return {
dy: () => conv2d$1(ddx, filter, strides, pad, dataFormat, 1, dimRoundingMode),
filter: () => conv2DBackpropFilter$2(ddx, dy, filter.shape, strides, pad, dataFormat, dimRoundingMode)
};
}
};
function conv3DBackpropFilter_(x, dy, filterShape, strides, pad) {
let x5D = x;
if (x.rank === 4) {
x5D = reshape$2(x, [1, x.shape[0], x.shape[1], x.shape[2], x.shape[3]]);
}
let dy5D = dy;
if (dy5D.rank === 4) {
dy5D = reshape$2(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
}
assert$1(x5D.rank === 5, () => `Error in conv3dDerFilter: input must be rank 5, but got shape ` +
`${x5D.shape}.`);
assert$1(dy5D.rank === 5, () => `Error in conv3dDerFilter: dy must be rank 5, but got shape ` +
`${dy5D.shape}.`);
assert$1(filterShape.length === 5, () => `Error in conv3dDerFilter: filterShape must be length 5, but got ` +
`${filterShape}.`);
    assert$1(x5D.shape[4] === filterShape[3], () => `Error in conv3dDerFilter: depth of input (${x5D.shape[4]}) must ` +
        `match input depth in filter (${filterShape[3]}).`);
assert$1(dy5D.shape[4] === filterShape[4], () => `Error in conv3dDerFilter: depth of dy (${dy5D.shape[4]}) must ` +
`match output depth for filter (${filterShape[4]}).`);
const inputs = { x: x5D, dy: dy5D };
const attrs = { strides, pad, filterShape };
return ENGINE.runKernel(Conv3DBackpropFilterV2, inputs, attrs);
}
const conv3DBackpropFilter = op({ conv3DBackpropFilter_ });
const conv3DGradConfig = {
kernelName: Conv3D,
inputsToSave: ['x', 'filter'],
gradFunc: (dy, saved, attrs) => {
const { dilations, strides, pad } = attrs;
assert$1(tupleValuesAreOne(dilations), () => 'Error in gradient of conv3D: dilation rates greater than 1 are ' +
`not yet supported in gradients. Got dilations '${dilations}'`);
const [x5D, $filter] = saved;
return {
x: () => conv3DBackpropInput$1(x5D.shape, dy, $filter, strides, pad),
filter: () => conv3DBackpropFilter(x5D, dy, $filter.shape, strides, pad)
};
}
};
const cosGradConfig = {
kernelName: Cos,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => mul(neg$2(sin$2(cast$3(x, 'float32'))), dy) };
}
};
const coshGradConfig = {
kernelName: Cosh,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => mul(sinh$2(cast$3(x, 'float32')), dy) };
}
};
const cumsumGradConfig = {
kernelName: Cumsum,
inputsToSave: ['x'],
gradFunc: (dy, saved, attrs) => {
const [x] = saved;
const { axis, exclusive, reverse } = attrs;
return {
x: () => {
const permutation = getAxesPermutation([axis], x.rank);
let out = cumsum$2(dy, axis, exclusive, !reverse);
if (permutation != null) {
out = transpose$2(out, permutation);
}
return out;
}
};
}
};
const depthwiseConv2dNativeGradConfig = {
kernelName: DepthwiseConv2dNative,
inputsToSave: ['x', 'filter'],
gradFunc: (dy, saved, attrs) => {
const { dilations, strides, pad, dimRoundingMode } = attrs;
const $dilations = dilations == null ? [1, 1] : dilations;
assert$1(tupleValuesAreOne($dilations), () => 'Error in gradient of depthwiseConv2dNative: dilation rates ' +
`greater than 1 are not yet supported. Got dilations ` +
`'${$dilations}'`);
const [x, filter] = saved;
assert$1(x.rank === 4, () => `Error in gradient of depthwiseConv2dNative: input must be ` +
`rank 4, but got rank ${x.rank}.`);
assert$1(filter.rank === 4, () => `Error in gradient of depthwiseConv2dNative: filter must be ` +
`rank 4, but got rank ${filter.rank}.`);
assert$1(x.shape[3] === filter.shape[2], () => `Error in gradient of depthwiseConv2d: number of input ` +
`channels (${x.shape[3]}) must match the inChannels dimension ` +
`in filter ${filter.shape[2]}.`);
assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in gradient of depthwiseConv2d: Either strides or ' +
`dilations must be 1. Got strides ${strides} and dilations ` +
`'${$dilations}'.`);
checkPadOnDimRoundingMode('depthwiseConv2d', pad, dimRoundingMode);
return {
x: () => depthwiseConv2dNativeBackpropInput$2(x.shape, dy, filter, strides, pad, $dilations, dimRoundingMode),
filter: () => depthwiseConv2dNativeBackpropFilter$2(x, dy, filter.shape, strides, pad, $dilations, dimRoundingMode),
};
}
};
const dilation2dGradConfig = {
kernelName: Dilation2D,
inputsToSave: ['x', 'filter'],
gradFunc: (dy, saved, attrs) => {
const [x, filter] = saved;
const inputInputs = { x, filter, dy };
const filterInputs = { x, filter, dy };
return {
x: () => ENGINE.runKernel(Dilation2DBackpropInput, inputInputs, attrs),
filter: () => ENGINE.runKernel(Dilation2DBackpropFilter, filterInputs, attrs)
};
}
};
const eluGradConfig = {
kernelName: Elu$1,
outputsToSave: [true],
gradFunc: (dy, saved) => {
const [y] = saved;
const inputs = { dy, y };
return { x: () => ENGINE.runKernel(EluGrad, inputs) };
}
};
const erfGradConfig = {
kernelName: Erf,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
const a = mul(exp$2(neg$2(square$2(x))), 2 / Math.sqrt(Math.PI));
return { x: () => mul(dy, a) };
}
};
const expGradConfig = {
kernelName: Exp,
outputsToSave: [true],
gradFunc: (dy, saved) => {
const [y] = saved;
return { x: () => mul(dy, y) };
}
};
const expandDimsGradConfig = {
kernelName: ExpandDims,
inputsToSave: ['input'],
gradFunc: (dy, saved) => {
const [input] = saved;
return { input: () => reshape$2(dy, input.shape) };
}
};
const expm1GradConfig = {
kernelName: Expm1,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => mul(dy, exp$2(x)) };
}
};
const floorGradConfig = {
kernelName: Floor,
gradFunc: (dy) => {
return { x: () => zerosLike$2(dy) };
}
};
const floorDivGradConfig = {
kernelName: FloorDiv,
inputsToSave: ['a', 'b'],
gradFunc: (dy, saved) => {
const [a, b] = saved;
const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
const derA = () => {
const res = div$1(dy, cast$3(b, 'float32'));
const reduceAxes = getReductionAxes(a.shape, outShape);
if (reduceAxes.length > 0) {
return reshape$2(sum$2(res, reduceAxes), a.shape);
}
return res;
};
const derB = () => {
let res = mul(dy, cast$3(a, 'float32'));
const reduceAxes = getReductionAxes(b.shape, outShape);
if (reduceAxes.length > 0) {
res = reshape$2(sum$2(res, reduceAxes), b.shape);
}
const tmp = square$2(b);
return neg$2(div$1(res, cast$3(tmp, 'float32')));
};
return { a: derA, b: derB };
}
};
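// FusedBatchNorm gradient: computes derivatives with respect to x, mean,
// variance, scale and offset from the saved tensors, reducing over the broadcast
// axes when mean/variance are rank 1.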
const fusedBatchNormGradConfig = {
kernelName: FusedBatchNorm,
inputsToSave: ['x', 'mean', 'variance', 'scale'],
gradFunc: (dy, saved, attrs) => {
const { varianceEpsilon } = attrs;
const [x, mean, variance, scale] = saved;
const scaleValue = scale == null ? scalar(1) : scale;
const reductionAxes = getReductionAxes(mean.shape, x.shape);
const tileShape = [];
if (mean.rank === 1) {
for (let i = 0; i < x.shape.length - 1; ++i) {
tileShape.push(x.shape[i]);
}
tileShape.push(1);
}
const xMinusMean = sub$2(x, mean);
const dyTimesScaleValue = mul(dy, scaleValue);
const oneOverSqrtVariance = rsqrt$2(add$1(variance, scalar(varianceEpsilon)));
const minusHalfRCube = mul(mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance), scalar(-0.5));
const derX = () => {
if (mean.rank === 1) {
return reshape$2(mul(mul(dy, tile$3(reshape$2(oneOverSqrtVariance, [1, 1, 1, mean.shape[0]]), tileShape)), scaleValue), x.shape);
}
else {
return reshape$2(mul(mul(dy, oneOverSqrtVariance), scaleValue), x.shape);
}
};
const derMean = () => {
let meanDer = mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue);
if (mean.rank === 1) {
meanDer = sum$2(meanDer, reductionAxes);
}
return reshape$2(meanDer, mean.shape);
};
const derVariance = () => {
let varianceDer = mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue);
if (mean.rank === 1) {
varianceDer = sum$2(varianceDer, reductionAxes);
}
return reshape$2(varianceDer, mean.shape);
};
const derScale = () => {
const xMinusMean2TimesRsqrt = mul(xMinusMean, oneOverSqrtVariance);
let scaleDer = mul(dy, xMinusMean2TimesRsqrt);
if (mean.rank === 1) {
scaleDer = sum$2(scaleDer, reductionAxes);
}
return reshape$2(scaleDer, mean.shape);
};
const derOffset = () => {
let offsetDer = dy;
if (mean.rank === 1) {
offsetDer = sum$2(offsetDer, reductionAxes);
}
return reshape$2(offsetDer, mean.shape);
};
return {
x: derX,
mean: derMean,
variance: derVariance,
scale: derScale,
offset: derOffset
};
}
};
const gatherGradConfig = {
kernelName: GatherV2,
inputsToSave: ['x', 'indices'],
gradFunc: (dy, saved, attrs) => {
const [x, indices] = saved;
const { axis, batchDims } = attrs;
const parsedAxis = parseAxisParam(axis, x.shape)[0];
const derXBatch = (x, indices, dy) => {
return () => {
const paramsShape = x.shape;
const indicesSize = indices.size;
const outerShape = paramsShape.slice(0, parsedAxis);
const outerDims = outerShape.length;
const innerShape = paramsShape.slice(axis, paramsShape.length).slice(1);
const innerDims = innerShape.length;
const outerAxesIndices = arrayRange(0, outerDims);
const innerAxesIndices = arrayRange(outerDims + 1, outerDims + 1 + innerDims);
const valuesShape = arrayConcat([outerShape, [indicesSize],
innerShape]);
const values = reshape$2(dy, valuesShape);
const reshapedIndices = reshape$2(indices, [indicesSize]);
const transposeDims = arrayConcat([[outerDims], outerAxesIndices, innerAxesIndices]);
const valuesTranspose = transpose$2(values, transposeDims);
let paramsGrad = unsortedSegmentSum$2(valuesTranspose, reshapedIndices, x.shape[parsedAxis]);
const invertTransposeDims = getUndoAxesPermutation(transposeDims);
paramsGrad = transpose$2(paramsGrad, invertTransposeDims);
return paramsGrad;
};
};
if (batchDims === 1) {
const batchSize = x.shape[0];
const xBatch = x.split(batchSize, 0);
const derXBatched = () => {
const stacked = stack(xBatch.map((x, i) => {
return derXBatch(x, indices.slice(i, 1), dy.slice(i, 1))();
}));
return stacked.reshape(x.shape);
};
return { x: derXBatched, indices: () => indices };
}
else {
return { x: derXBatch(x, indices, dy), indices: () => indices };
}
}
};
function arrayRange(start, stop) {
const result = [];
for (let i = start; i < stop; ++i) {
result.push(i);
}
return result;
}
function arrayConcat(arrays) {
const result = [];
for (let i = 0; i < arrays.length; ++i) {
for (let j = 0; j < arrays[i].length; ++j) {
result.push(arrays[i][j]);
}
}
return result;
}
const greaterEqualGradConfig = {
kernelName: GreaterEqual,
inputsToSave: ['a', 'b'],
gradFunc: (dy, saved) => {
const [a, b] = saved;
return { a: () => zerosLike$2(a), b: () => zerosLike$2(b) };
}
};
const identityGradConfig = {
kernelName: Identity$1,
gradFunc: (dy) => {
return { x: () => cast$3(dy, 'float32') };
}
};
const isFiniteGradConfig = {
kernelName: IsFinite,
gradFunc: (dy) => {
return { x: () => zerosLike$2(dy) };
}
};
const isInfGradConfig = {
kernelName: IsInf,
gradFunc: (dy) => {
return { x: () => zerosLike$2(dy) };
}
};
const isNanGradConfig = {
kernelName: IsNan,
gradFunc: (dy) => {
return { x: () => zerosLike$2(dy) };
}
};
const leakyReluGradConfig = {
kernelName: LeakyRelu,
inputsToSave: ['x'],
gradFunc: (dy, saved, attrs) => {
const [x] = saved;
const { alpha } = attrs;
const mask = greater$2(x, 0);
return { x: () => where(mask, dy, mul(dy, alpha)) };
}
};
const log1pGradConfig = {
kernelName: Log1p,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => div$1(dy, add$1(x, 1)) };
}
};
const logGradConfig = {
kernelName: Log,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => div$1(dy, cast$3(x, 'float32')) };
}
};
const logSoftmaxGradConfig = {
kernelName: LogSoftmax$1,
inputsToSave: [],
outputsToSave: [true],
gradFunc: (dy, saved, attrs) => {
const [value] = saved;
const { axis } = attrs;
return {
logits: () => {
const keepDims = true;
const softmax = exp$2(value);
return sub$2(dy, mul(sum$2(dy, axis, keepDims), softmax));
}
};
}
};
function localResponseNormalizationBackprop_(x, y, dy, depthRadius = 5, bias = 1, alpha = 1, beta = 0.5) {
const inputs = { x, y, dy };
const attrs = { depthRadius, bias, alpha, beta };
return ENGINE.runKernel(LRNGrad, inputs, attrs);
}
const localResponseNormalizationBackprop = op({ localResponseNormalizationBackprop_ });
const lrnGradConfig = {
kernelName: LRN,
inputsToSave: ['x'],
outputsToSave: [true],
gradFunc: (dy, saved, attrs) => {
const [x, y] = saved;
const { depthRadius, bias, alpha, beta } = attrs;
return {
x: () => localResponseNormalizationBackprop(x, y, dy, depthRadius, bias, alpha, beta)
};
}
};
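// Shared gradient for Min and Max: the upstream gradient flows only to positions
// where the original input equals the reduced output; everything else gets zero.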
function gradForMinAndMax(dy, y, xOrig, origAxes) {
if (y.rank < xOrig.rank) {
y = reshape$2(y, expandShapeToKeepDim(y.shape, origAxes));
}
if (dy.rank < xOrig.rank) {
dy = reshape$2(dy, expandShapeToKeepDim(dy.shape, origAxes));
}
return {
x: () => {
const dx = mul(dy, cast$3(equal$2(xOrig, y), dy.dtype));
return dx;
}
};
}
const maxGradConfig = {
kernelName: Max,
inputsToSave: ['x'],
outputsToSave: [true],
gradFunc: (dy, saved, attrs) => {
const maxAttrs = attrs;
const { reductionIndices } = maxAttrs;
const x = saved[0];
const y = saved[1];
const origAxes = parseAxisParam(reductionIndices, x.shape);
const maxGrad = gradForMinAndMax(dy, y, x, origAxes);
return {
x: () => {
return maxGrad['x']();
}
};
}
};
const maximumGradConfig = {
kernelName: Maximum,
inputsToSave: ['a', 'b'],
gradFunc: (dy, saved) => {
const [a, b] = saved;
const derA = () => mul(dy, cast$3(greaterEqual$2(a, b), 'float32'));
const derB = () => mul(dy, cast$3(less$2(a, b), 'float32'));
return { a: derA, b: derB };
}
};
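// Gradient helper for MaxPool3D: like avgPool3dGrad above, but also passes the
// forward output so the kernel can route gradients to the max positions.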
function maxPool3dGrad_(dy, input, output, filterSize, strides, pad, dimRoundingMode) {
const $dy = convertToTensor(dy, 'dy', 'maxPool3dGrad');
const $input = convertToTensor(input, 'input', 'maxPool3dGrad');
const $output = convertToTensor(output, 'output', 'maxPool3dGrad');
let dy5D = $dy;
let input5D = $input;
let output5D = $output;
let reshapedTo5D = false;
if ($input.rank === 4) {
reshapedTo5D = true;
dy5D = reshape$2($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]);
input5D = reshape$2($input, [
1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]
]);
output5D = reshape$2($output, [
1, $output.shape[0], $output.shape[1], $output.shape[2], $output.shape[3]
]);
}
assert$1(dy5D.rank === 5, () => `Error in maxPool3dGrad: dy must be rank 5 but got rank ` +
`${dy5D.rank}.`);
assert$1(input5D.rank === 5, () => `Error in maxPool3dGrad: input must be rank 5 but got rank ` +
`${input5D.rank}.`);
assert$1(output5D.rank === 5, () => `Error in maxPool3dGrad: output must be rank 5 but got rank ` +
`${output5D.rank}.`);
checkPadOnDimRoundingMode('maxPool3dGrad', pad, dimRoundingMode);
const inputs = { dy: dy5D, input: input5D, output: output5D };
const attrs = { filterSize, strides, pad, dimRoundingMode };
const res = ENGINE.runKernel(MaxPool3DGrad, inputs, attrs);
if (reshapedTo5D) {
return reshape$2(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
}
return res;
}
const maxPool3dGrad = op({ maxPool3dGrad_ });
const maxPool3DGradConfig = {
kernelName: MaxPool3D,
inputsToSave: ['x'],
outputsToSave: [true],
gradFunc: (dy, saved, attrs) => {
const [x, y] = saved;
const { filterSize, strides, pad, dimRoundingMode } = attrs;
return {
x: () => maxPool3dGrad(dy, x, y, filterSize, strides, pad, dimRoundingMode)
};
}
};
function maxPoolGrad_(dy, input, output, filterSize, strides, pad, dimRoundingMode) {
const $dy = convertToTensor(dy, 'dy', 'maxPoolGrad');
const $input = convertToTensor(input, 'input', 'maxPoolGrad');
const $output = convertToTensor(output, 'output', 'maxPoolGrad');
assert$1($input.rank === $dy.rank, () => `Rank of input (${$input.rank}) does not match rank of dy ` +
`(${$dy.rank})`);
assert$1($dy.rank === 4, () => `Error in maxPoolGrad: dy must be rank 4 but got rank ` +
`${$dy.rank}.`);
assert$1($input.rank === 4, () => `Error in maxPoolGrad: input must be rank 4 but got rank ` +
`${$input.rank}.`);
checkPadOnDimRoundingMode('maxPoolGrad', pad, dimRoundingMode);
const inputs = { dy: $dy, input: $input, output: $output };
const attrs = { filterSize, strides, pad, dimRoundingMode };
return ENGINE.runKernel(MaxPoolGrad, inputs, attrs);
}
const maxPoolGrad = op({ maxPoolGrad_ });
const maxPoolGradConfig = {
kernelName: MaxPool,
inputsToSave: ['x'],
outputsToSave: [true],
gradFunc: (dy, saved, attrs) => {
const [x, y] = saved;
const { filterSize, strides, pad } = attrs;
return {
x: () => maxPoolGrad(dy, x, y, filterSize, strides, pad)
};
}
};
const meanGradConfig = {
kernelName: Mean,
inputsToSave: ['x'],
gradFunc: (dy, saved, attrs) => {
const [x] = saved;
const { axis } = attrs;
const axes = parseAxisParam(axis, x.shape);
const shapes = computeOutAndReduceShapes(x.shape, axes);
const reduceShape = shapes[1];
const reduceSize = sizeFromShape(reduceShape);
const derX = () => {
const expandedDyShape = x.shape.slice();
axes.forEach(axis => {
expandedDyShape[axis] = 1;
});
const expandedDy = reshape$2(dy, expandedDyShape);
const res = div$1(mul(expandedDy, ones(x.shape, 'float32')), reduceSize);
return res;
};
return { x: derX };
}
};
const minGradConfig = {
kernelName: Min,
inputsToSave: ['x'],
outputsToSave: [true],
gradFunc: (dy, saved, attrs) => {
const minAttrs = attrs;
const { axis } = minAttrs;
const [x, y] = saved;
const origAxes = parseAxisParam(axis, x.shape);
const minGrad = gradForMinAndMax(dy, y, x, origAxes);
return {
x: () => {
return minGrad['x']();
}
};
}
};
const minimumGradConfig = {
kernelName: Minimum,
inputsToSave: ['a', 'b'],
gradFunc: (dy, saved) => {
const [a, b] = saved;
const derA = () => mul(dy, cast$3(lessEqual$2(a, b), 'float32'));
const derB = () => mul(dy, cast$3(greater$2(a, b), 'float32'));
return { a: derA, b: derB };
}
};
const mirrorPadGradConfig = {
kernelName: MirrorPad,
inputsToSave: ['x'],
gradFunc: (dy, saved, attrs) => {
const x = saved[0];
const { paddings } = attrs;
const begin = paddings.map(p => p[0]);
return { x: () => slice$2(dy, begin, x.shape) };
}
};
const modGradConfig = {
kernelName: Mod,
inputsToSave: ['a', 'b'],
gradFunc: (dy, saved) => {
const [a, b] = saved;
const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
const derA = () => {
const reduceAxes = getReductionAxes(a.shape, outShape);
if (reduceAxes.length > 0) {
return reshape$2(sum$2(dy, reduceAxes), a.shape);
}
return dy;
};
const derB = () => {
const res = mul(dy, neg$2(floor$2(div$1(a, b))));
const reduceAxes = getReductionAxes(b.shape, outShape);
if (reduceAxes.length > 0) {
return reshape$2(sum$2(res, reduceAxes), b.shape);
}
return res;
};
return { a: derA, b: derB };
}
};
const multiplyGradConfig = {
kernelName: Multiply,
inputsToSave: ['a', 'b'],
gradFunc: (dy, saved) => {
const [a, b] = saved;
const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
const derA = () => {
const res = mul(dy, cast$3(b, 'float32'));
const reduceAxes = getReductionAxes(a.shape, outShape);
if (reduceAxes.length > 0) {
return reshape$2(sum$2(res, reduceAxes), a.shape);
}
return res;
};
const derB = () => {
const res = mul(dy, cast$3(a, 'float32'));
const reduceAxes = getReductionAxes(b.shape, outShape);
if (reduceAxes.length > 0) {
return reshape$2(sum$2(res, reduceAxes), b.shape);
}
return res;
};
return { a: derA, b: derB };
}
};
const negGradConfig = {
kernelName: Neg,
gradFunc: (dy) => {
return { x: () => neg$2(dy) };
}
};
const oneHotGradConfig = {
kernelName: OneHot,
inputsToSave: ['indices'],
gradFunc: (dy, saved) => {
const indices = saved[0];
return { indices: () => zeros$1(indices.shape, 'float32') };
}
};
const onesLikeGradConfig = {
kernelName: OnesLike,
gradFunc: (dy) => {
return { x: () => zerosLike$2(dy) };
}
};
const packGradConfig = {
kernelName: Pack,
saveAllInputs: true,
gradFunc: (dy, saved, attrs) => {
const { axis } = attrs;
const derTensors = unstack(dy, axis);
return derTensors.map(t => () => t);
}
};
const padV2GradConfig = {
kernelName: PadV2,
inputsToSave: ['x'],
gradFunc: (dy, saved, attrs) => {
const x = saved[0];
const { paddings } = attrs;
const begin = paddings.map(p => p[0]);
return { x: () => slice$2(dy, begin, x.shape) };
}
};
const powGradConfig = {
kernelName: Pow,
inputsToSave: ['a', 'b'],
outputsToSave: [true],
gradFunc: (dy, saved) => {
const [a, b, y] = saved;
const base = a;
const exp = b;
const outShape = assertAndGetBroadcastShape(base.shape, exp.shape);
const derBase = () => {
const expFloat = cast$3(exp, 'float32');
let res = mul(dy, mul(expFloat, pow$2(base, sub$2(expFloat, scalar(1)))));
const reduceAxes = getReductionAxes(base.shape, outShape);
if (reduceAxes.length > 0) {
res = sum$2(res, reduceAxes);
}
return reshape$2(res, base.shape);
};
const derExp = () => {
const condition = greater$2(base, 0);
const logBase = where(condition, log$2(base), zerosLike$2(base));
let res = mul(dy, mul(y, logBase));
const reduceAxes = getReductionAxes(exp.shape, outShape);
if (reduceAxes.length > 0) {
res = sum$2(res, reduceAxes);
}
return reshape$2(res, exp.shape);
};
return { a: derBase, b: derExp };
}
};
const preluGradConfig = {
kernelName: Prelu,
inputsToSave: ['x', 'alpha'],
gradFunc: (dy, saved) => {
const [x, alpha] = saved;
const mask = greater$2(x, 0);
return {
x: () => where(mask, dy, mul(dy, alpha)),
alpha: () => {
let res = where(mask, zerosLike$2(dy), mul(dy, x));
const reduceAxes = getReductionAxes(alpha.shape, dy.shape);
if (reduceAxes.length > 0) {
res = sum$2(res, reduceAxes);
}
return reshape$2(res, alpha.shape);
}
};
}
};
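// Prod gradient along a single axis: d(prod)/dx_i is the product of all other
// elements, computed as the exclusive cumulative product from the left times the
// exclusive cumulative product from the right.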
function prodGradFn_(x, dy, axis) {
const expandedYShape = x.shape.slice();
expandedYShape[axis] = 1;
const expandedDy = reshape$2(dy, expandedYShape);
const xCumProd = cumprod$2(x, axis, true, false);
const xCumRevProd = cumprod$2(x, axis, true, true);
const dx = mul(xCumProd, xCumRevProd);
return mul(expandedDy, dx);
}
function prodsGradFn_(x, dy, axis) {
const xRank = x.shape.length;
const finalProdAxis = xRank - axis.length;
const xPermutation = getAxesPermutation(axis, xRank);
let permutedX = x;
if (xPermutation != null) {
permutedX = transpose$2(x, xPermutation);
}
const newShape = permutedX.shape.slice();
const removedShape = newShape.splice(xRank - axis.length, axis.length);
const endPartShape = removedShape.reduce((p, c) => p * c, 1);
newShape.push(endPartShape);
const reshapedPermutedX = permutedX.reshape(newShape);
let prodGrad = prodGradFn_(reshapedPermutedX, dy, finalProdAxis);
prodGrad = prodGrad.reshape(permutedX.shape);
if (xPermutation != null) {
const undoPermutation = getUndoAxesPermutation(xPermutation);
prodGrad = transpose$2(prodGrad, undoPermutation);
}
return prodGrad;
}
const prodGradConfig = {
kernelName: Prod,
inputsToSave: ['x'],
gradFunc: (dy, saved, attrs) => {
const [x] = saved;
const { axis } = attrs;
let axisArr = [];
if (axis === undefined || axis === null) {
axisArr = x.shape.map((_, i) => i);
}
else if (typeof axis === 'number') {
axisArr = [axis];
}
else {
axisArr = axis;
}
return { x: () => prodsGradFn_(x, dy, axisArr) };
}
};
const divGradConfig = {
kernelName: RealDiv,
inputsToSave: ['a', 'b'],
gradFunc: (dy, saved) => {
const [a, b] = saved;
const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
const derA = () => {
const res = div$1(dy, cast$3(b, 'float32'));
const reduceAxes = getReductionAxes(a.shape, outShape);
if (reduceAxes.length > 0) {
return reshape$2(sum$2(res, reduceAxes), a.shape);
}
return res;
};
const derB = () => {
let res = mul(dy, cast$3(a, 'float32'));
const reduceAxes = getReductionAxes(b.shape, outShape);
if (reduceAxes.length > 0) {
res = reshape$2(sum$2(res, reduceAxes), b.shape);
}
const tmp = square$2(b);
return neg$2(div$1(res, cast$3(tmp, 'float32')));
};
return { a: derA, b: derB };
}
};
const reciprocalGradConfig = {
kernelName: Reciprocal,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => div$1(dy, neg$2(square$2(x))) };
}
};
const relu6GradConfig = {
kernelName: Relu6$1,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
const mask = mul(lessEqual$2(x, 6), step$2(x));
return { x: () => mul(dy, cast$3(mask, 'float32')) };
}
};
const reluGradConfig = {
kernelName: Relu$1,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => mul(dy, cast$3(step$2(x), 'float32')) };
}
};
const reshapeGradConfig = {
kernelName: Reshape$1,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => reshape$2(dy, x.shape) };
}
};
const resizeBilinearGradConfig = {
kernelName: ResizeBilinear,
inputsToSave: ['images'],
gradFunc: (dy, saved, attrs) => {
const [images] = saved;
const inputs = { dy, images };
const imagesDer = () =>
ENGINE.runKernel(ResizeBilinearGrad, inputs, attrs);
return { images: imagesDer };
}
};
const resizeNearestNeighborGradConfig = {
kernelName: ResizeNearestNeighbor,
inputsToSave: ['images'],
gradFunc: (dy, saved, attrs) => {
const [images] = saved;
const inputs = { dy, images };
const imagesDer = () =>
ENGINE.runKernel(ResizeNearestNeighborGrad, inputs, attrs);
return { images: imagesDer };
}
};
const reverseGradConfig = {
kernelName: Reverse,
gradFunc: (dy, saved, attrs) => {
const { dims } = attrs;
const axes = parseAxisParam(dims, dy.shape);
return { x: () => reverse$2(dy, axes) };
}
};
const roundGradConfig = {
kernelName: Round,
gradFunc: (dy) => {
return { x: () => zerosLike$2(dy) };
}
};
const rsqrtGradConfig = {
kernelName: Rsqrt,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => neg$2(div$1(dy, mul(pow$2(x, 1.5), 2))) };
}
};
const selectGradConfig = {
kernelName: Select,
inputsToSave: ['condition'],
gradFunc: (dy, saved) => {
const [condition] = saved;
return {
condition: () => cast$3(zerosLike$2(condition), 'float32'),
t: () => mul(dy, cast$3(condition, dy.dtype)),
e: () => mul(dy, cast$3(logicalNot$2(condition), dy.dtype))
};
}
};
const seluGradConfig = {
kernelName: Selu$1,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return {
x: () => {
const mask = greater$2(x, scalar(0));
const scaleAlpha = scalar(SELU_SCALEALPHA);
const scale = scalar(SELU_SCALE);
const greaterThanZeroDer = mul(dy, scale);
const lessEqualZeroDer = mul(mul(dy, scaleAlpha), exp$2(cast$3(x, 'float32')));
return where(mask, greaterThanZeroDer, lessEqualZeroDer);
}
};
}
};
const sigmoidGradConfig = {
kernelName: Sigmoid$1,
outputsToSave: [true],
gradFunc: (dy, saved) => {
const [y] = saved;
return { x: () => mul(dy, mul(y, sub$2(scalar(1), y))) };
}
};
const signGradConfig = {
kernelName: Sign,
gradFunc: (dy) => {
return { x: () => zerosLike$2(dy) };
}
};
const sinGradConfig = {
kernelName: Sin,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => mul(cos$2(cast$3(x, 'float32')), dy) };
}
};
const sinhGradConfig = {
kernelName: Sinh,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => mul(cosh$2(cast$3(x, 'float32')), dy) };
}
};
const sliceGradConfig = {
kernelName: Slice,
inputsToSave: ['x'],
gradFunc: (dy, saved, attrs) => {
const [x] = saved;
const { begin, size } = attrs;
const inputShape = x.shape;
const [begin_, size_] = parseSliceParams(x, begin, size);
const paddings = [];
for (let i = 0; i < dy.rank; i++) {
paddings.push([begin_[i], inputShape[i] - begin_[i] - size_[i]]);
}
return { x: () => pad(dy, paddings) };
}
};
const softmaxGradConfig = {
kernelName: Softmax$1,
outputsToSave: [true],
gradFunc: (dy, saved, attrs) => {
const [y] = saved;
const { dim } = attrs;
const keepDims = true;
const dyTimesY = mul(dy, y);
return {
logits: () => sub$2(dyTimesY, mul(sum$2(dyTimesY, [dim], keepDims), y))
};
}
};
const softplusGradConfig = {
kernelName: Softplus$1,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => mul(dy, sigmoid$2(x)) };
}
};
const spaceToBatchNDGradConfig = {
kernelName: SpaceToBatchND,
gradFunc: (dy, saved, attrs) => {
const { blockShape, paddings } = attrs;
return { x: () => batchToSpaceND$2(dy, blockShape, paddings) };
}
};
const splitVGradConfig = {
kernelName: SplitV,
gradFunc: (dy, saved, attrs) => {
const { axis } = attrs;
return { x: () => concat$2(dy, axis) };
}
};
const sqrtGradConfig = {
kernelName: Sqrt,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => div$1(dy, mul(sqrt$2(cast$3(x, 'float32')), 2)) };
}
};
const squareGradConfig = {
kernelName: Square,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => mul(dy, mul(cast$3(x, 'float32'), 2)) };
}
};
const squaredDifferenceGradConfig = {
kernelName: SquaredDifference,
inputsToSave: ['a', 'b'],
gradFunc: (dy, saved) => {
const [a, b] = saved;
const two = scalar(2);
const derA = () => mul(dy, mul(two, sub$2(a, b)));
const derB = () => mul(dy, mul(two, sub$2(b, a)));
return { a: derA, b: derB };
}
};
const stepGradConfig = {
kernelName: Step,
gradFunc: (dy) => {
return { x: () => zerosLike$2(dy) };
}
};
const subGradConfig = {
kernelName: Sub,
inputsToSave: ['a', 'b'],
gradFunc: (dy, saved) => {
const [a, b] = saved;
const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
const derA = () => {
let res = dy;
const reduceAxes = getReductionAxes(a.shape, outShape);
if (reduceAxes.length > 0) {
res = sum$2(res, reduceAxes);
}
return reshape$2(res, a.shape);
};
const derB = () => {
let res = dy;
const reduceAxes = getReductionAxes(b.shape, outShape);
if (reduceAxes.length > 0) {
res = sum$2(res, reduceAxes);
}
return reshape$2(neg$2(res), b.shape);
};
return { a: derA, b: derB };
}
};
const sumGradConfig = {
kernelName: Sum,
inputsToSave: ['x'],
gradFunc: (dy, saved, attrs) => {
const [x] = saved;
const expandedDyShape = x.shape.slice();
const { axis } = attrs;
const axes = parseAxisParam(axis, x.shape);
axes.forEach(axis => {
expandedDyShape[axis] = 1;
});
const expandedDy = reshape$2(dy, expandedDyShape);
const derX = mul(expandedDy, ones(x.shape, 'float32'));
return { x: () => derX };
}
};
const tanGradConfig = {
kernelName: Tan,
inputsToSave: ['x'],
gradFunc: (dy, saved) => {
const [x] = saved;
return { x: () => div$1(dy, square$2(cos$2(x))) };
}
};
const tanhGradConfig = {
kernelName: Tanh$1,
outputsToSave: [true],
gradFunc: (dy, saved) => {
const [y] = saved;
return { x: () => mul(sub$2(scalar(1), square$2(y)), dy) };
}
};
const tileGradConfig = {
kernelName: Tile,
inputsToSave: ['x'],
gradFunc: (dy, saved, attrs) => {
const [x] = saved;
const { reps } = attrs;
const derX = () => {
let xGrad = zerosLike$2(x);
if (x.rank === 1) {
for (let i = 0; i < reps[0]; ++i) {
xGrad = add$1(xGrad, slice$2(dy, [i * x.shape[0]], [x.shape[0]]));
}
}
else if (x.rank === 2) {
for (let i = 0; i < reps[0]; ++i) {
for (let j = 0; j < reps[1]; ++j) {
xGrad = add$1(xGrad, slice$2(dy, [i * x.shape[0], j * x.shape[1]], [
x.shape[0], x.shape[1]
]));
}
}
}
else if (x.rank === 3) {
for (let i = 0; i < reps[0]; ++i) {
for (let j = 0; j < reps[1]; ++j) {
for (let k = 0; k < reps[2]; ++k) {
xGrad =
add$1(xGrad, slice$2(dy, [i * x.shape[0], j * x.shape[1], k * x.shape[2]], [x.shape[0], x.shape[1], x.shape[2]]));
}
}
}
}
else if (x.rank === 4) {
for (let i = 0; i < reps[0]; ++i) {
for (let j = 0; j < reps[1]; ++j) {
for (let k = 0; k < reps[2]; ++k) {
for (let l = 0; l < reps[3]; ++l) {
xGrad =
add$1(xGrad, slice$2(dy, [
i * x.shape[0], j * x.shape[1], k * x.shape[2],
l * x.shape[3]
], [x.shape[0], x.shape[1], x.shape[2], x.shape[3]]));
}
}
}
}
}
else {
throw new Error(`Gradient for tile operation is not implemented for rank-` +
`${x.rank} tensors yet.`);
}
return xGrad;
};
return { x: derX };
},
};
const transposeGradConfig = {
kernelName: Transpose,
gradFunc: (dy, saved, attrs) => {
const transposeAttrs = attrs;
const { perm } = transposeAttrs;
const undoPerm = getUndoAxesPermutation(perm);
return { x: () => transpose$2(dy, undoPerm) };
}
};
const unpackGradConfig = {
kernelName: Unpack,
gradFunc: (dy, saved, attrs) => {
const unpackAttrs = attrs;
const { axis } = unpackAttrs;
return { value: () => stack(dy, axis) };
}
};
const unsortedSegmentSumGradConfig = {
kernelName: UnsortedSegmentSum,
inputsToSave: ['segmentIds'],
gradFunc: (dy, saved) => {
const [segmentIds] = saved;
const derX = () => {
return gatherDropNegatives(dy, segmentIds);
};
return { x: derX };
}
};
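// Gathers dy at the segment ids while dropping contributions from negative ids:
// negative indices are clamped to 0 for the gather and the corresponding rows
// are then masked out.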
function gatherDropNegatives(x, indices) {
const zeroClippedIndices = maximum$2(indices, zerosLike$2(indices));
const gathered = gather$1(x, zeroClippedIndices);
let isPositive = greaterEqual$2(indices, scalar(0, 'int32'));
const numIters = gathered.rank - isPositive.rank;
for (let i = 0; i < numIters; ++i) {
isPositive = expandDims$3(isPositive, i + 1);
}
isPositive = logicalAnd$2(isPositive, ones(gathered.shape, 'bool'));
const zeroSlice = zerosLike$2(gathered);
return where(isPositive, gathered, zeroSlice);
}
const zerosLikeGradConfig = {
kernelName: ZerosLike,
gradFunc: (dy) => {
return { x: () => zerosLike$2(dy) };
}
};
const gradConfigs = [
absGradConfig,
acosGradConfig,
acoshGradConfig,
addGradConfig,
addNGradConfig,
argMaxGradConfig,
argMinGradConfig,
asinGradConfig,
asinhGradConfig,
atan2GradConfig,
atanGradConfig,
atanhGradConfig,
avgPool3DGradConfig,
avgPoolGradConfig,
batchMatMulGradConfig,
batchToSpaceNDGradConfig,
broadcastToGradConfig,
castGradConfig,
ceilGradConfig,
clipByValueGradConfig,
complexAbsGradConfig,
concatGradConfig,
conv2DBackpropInputGradConfig,
conv2DGradConfig,
conv3DGradConfig,
cosGradConfig,
coshGradConfig,
cumsumGradConfig,
depthwiseConv2dNativeGradConfig,
dilation2dGradConfig,
divGradConfig,
eluGradConfig,
erfGradConfig,
expGradConfig,
expandDimsGradConfig,
expm1GradConfig,
floorDivGradConfig,
floorGradConfig,
fusedBatchNormGradConfig,
gatherGradConfig,
greaterEqualGradConfig,
identityGradConfig,
isFiniteGradConfig,
isInfGradConfig,
isNanGradConfig,
leakyReluGradConfig,
log1pGradConfig,
logGradConfig,
logSoftmaxGradConfig,
lrnGradConfig,
maxGradConfig,
maximumGradConfig,
maxPool3DGradConfig,
maxPoolGradConfig,
meanGradConfig,
minGradConfig,
minimumGradConfig,
mirrorPadGradConfig,
modGradConfig,
multiplyGradConfig,
negGradConfig,
oneHotGradConfig,
onesLikeGradConfig,
packGradConfig,
padV2GradConfig,
powGradConfig,
preluGradConfig,
prodGradConfig,
reciprocalGradConfig,
relu6GradConfig,
reluGradConfig,
reshapeGradConfig,
resizeBilinearGradConfig,
resizeNearestNeighborGradConfig,
reverseGradConfig,
roundGradConfig,
rsqrtGradConfig,
selectGradConfig,
seluGradConfig,
sigmoidGradConfig,
signGradConfig,
sinGradConfig,
sinhGradConfig,
sliceGradConfig,
softmaxGradConfig,
softplusGradConfig,
spaceToBatchNDGradConfig,
splitVGradConfig,
sqrtGradConfig,
squaredDifferenceGradConfig,
squareGradConfig,
stepGradConfig,
subGradConfig,
sumGradConfig,
tanGradConfig,
tanhGradConfig,
tileGradConfig,
transposeGradConfig,
unpackGradConfig,
unsortedSegmentSumGradConfig,
zerosLikeGradConfig
];
for (const gradientConfig of gradConfigs) {
registerGradient(gradientConfig);
}
class AttributeError extends Error {
constructor(message) {
super(message);
Object.setPrototypeOf(this, AttributeError.prototype);
}
}
class RuntimeError extends Error {
constructor(message) {
super(message);
Object.setPrototypeOf(this, RuntimeError.prototype);
}
}
class ValueError extends Error {
constructor(message) {
super(message);
Object.setPrototypeOf(this, ValueError.prototype);
}
}
class NotImplementedError extends Error {
constructor(message) {
super(message);
Object.setPrototypeOf(this, NotImplementedError.prototype);
}
}
class AssertionError extends Error {
constructor(message) {
super(message);
Object.setPrototypeOf(this, AssertionError.prototype);
}
}
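// Simple LRU cache backed by a Map: insertion order doubles as recency order, so
// get() re-inserts the entry and eviction removes the oldest key first.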
class LruCache {
constructor(maxEntries) {
this.maxEntries = maxEntries || 100;
this.cache = new Map();
}
get(key) {
let entry;
if (this.cache.has(key)) {
entry = this.cache.get(key);
this.cache.delete(key);
this.cache.set(key, entry);
}
return entry;
}
put(key, value) {
if (this.cache.has(key)) {
this.cache.delete(key);
}
else if (this.cache.size >= this.maxEntries) {
const keyToDelete = this.cache.keys().next().value;
this.cache.delete(keyToDelete);
}
this.cache.set(key, value);
}
getMaxEntries() {
return this.maxEntries;
}
setMaxEntries(maxEntries) {
if (maxEntries < 0) {
throw new Error(`The maxEntries of LRU caches must be at least 0, but got ${maxEntries}.`);
}
if (this.maxEntries > maxEntries) {
for (let i = 0; i < this.maxEntries - maxEntries; i++) {
const keyToDelete = this.cache.keys().next().value;
this.cache.delete(keyToDelete);
}
}
this.maxEntries = maxEntries;
}
}
function pyListRepeat(value, numValues) {
if (Array.isArray(value)) {
let newArray = [];
for (let i = 0; i < numValues; i++) {
newArray = newArray.concat(value);
}
return newArray;
}
else {
const newArray = new Array(numValues);
newArray.fill(value);
return newArray;
}
}
function assert(val, message) {
if (!val) {
throw new AssertionError(message);
}
}
function count(array, reference) {
    let counter = 0;
    for (const item of array) {
        if (item === reference) {
counter++;
}
}
return counter;
}
function singletonOrArray(xs) {
if (xs.length === 1) {
return xs[0];
}
return xs;
}
function toList(x) {
if (Array.isArray(x)) {
return x;
}
return [x];
}
function toSnakeCase(name) {
const intermediate = name.replace(/(.)([A-Z][a-z0-9]+)/g, '$1_$2');
const insecure = intermediate.replace(/([a-z])([A-Z])/g, '$1_$2').toLowerCase();
if (insecure[0] !== '_') {
return insecure;
}
return 'private' + insecure;
}
function toCamelCase(identifier) {
if (identifier.length <= 1) {
return identifier;
}
if (identifier.indexOf('_') === -1) {
return identifier;
}
return identifier.replace(/[_]+(\w|$)/g, (m, p1) => p1.toUpperCase());
}
let _GLOBAL_CUSTOM_OBJECTS = {};
function serializeKerasObject(instance) {
if (instance === null || instance === undefined) {
return null;
}
const dict = {};
dict['className'] = instance.getClassName();
dict['config'] = instance.getConfig();
return dict;
}
function convertNDArrayScalarsInConfig(config) {
if (config == null || typeof config !== 'object') {
return;
}
else if (Array.isArray(config)) {
config.forEach(configItem => convertNDArrayScalarsInConfig(configItem));
}
else {
const fields = Object.keys(config);
for (const field of fields) {
const value = config[field];
if (value != null && typeof value === 'object') {
if (!Array.isArray(value) && value['type'] === 'ndarray' &&
typeof value['value'] === 'number') {
config[field] = value['value'];
}
else {
convertNDArrayScalarsInConfig(value);
}
}
}
}
}
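/**
 * Resolves a serialized Keras object. A string identifier is looked up in
 * `customObjects`, then the module-level `_GLOBAL_CUSTOM_OBJECTS`, then
 * `moduleObjects`. A `{className, config}` dict is resolved to a registered
 * class; when the class provides a `fromConfig` factory, `customObjects` are
 * temporarily merged into the global registry, ndarray scalar wrappers in the
 * config are unwrapped to plain numbers, and the factory is invoked.
 */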
function deserializeKerasObject(identifier, moduleObjects = {}, customObjects = {}, printableModuleName = 'object', fastWeightInit = false) {
if (typeof identifier === 'string') {
const functionName = identifier;
let fn;
if (functionName in customObjects) {
fn = customObjects[functionName];
}
else if (functionName in _GLOBAL_CUSTOM_OBJECTS) {
fn = _GLOBAL_CUSTOM_OBJECTS[functionName];
}
else {
fn = moduleObjects[functionName];
if (fn == null) {
throw new ValueError(`Unknown ${printableModuleName}: ${identifier}. ` +
`This may be due to one of the following reasons:\n` +
`1. The ${printableModuleName} is defined in Python, in which ` +
`case it needs to be ported to TensorFlow.js or your JavaScript ` +
`code.\n` +
`2. The custom ${printableModuleName} is defined in JavaScript, ` +
`but is not registered properly with ` +
`tf.serialization.registerClass().`);
}
}
return fn;
}
else {
const config = identifier;
if (config['className'] == null || config['config'] == null) {
throw new ValueError(`${printableModuleName}: Improper config format: ` +
`${JSON.stringify(config)}.\n` +
                `'className' and 'config' must be set.`);
}
const className = config['className'];
let cls, fromConfig;
if (className in customObjects) {
[cls, fromConfig] = customObjects[className];
}
else if (className in _GLOBAL_CUSTOM_OBJECTS) {
            [cls, fromConfig] = _GLOBAL_CUSTOM_OBJECTS[className];
}
else if (className in moduleObjects) {
[cls, fromConfig] = moduleObjects[className];
}
if (cls == null) {
throw new ValueError(`Unknown ${printableModuleName}: ${className}. ` +
`This may be due to one of the following reasons:\n` +
`1. The ${printableModuleName} is defined in Python, in which ` +
`case it needs to be ported to TensorFlow.js or your JavaScript ` +
`code.\n` +
`2. The custom ${printableModuleName} is defined in JavaScript, ` +
`but is not registered properly with ` +
`tf.serialization.registerClass().`);
}
if (fromConfig != null) {
const customObjectsCombined = {};
for (const key of Object.keys(_GLOBAL_CUSTOM_OBJECTS)) {
customObjectsCombined[key] = _GLOBAL_CUSTOM_OBJECTS[key];
}
for (const key of Object.keys(customObjects)) {
customObjectsCombined[key] = customObjects[key];
}
const nestedConfig = config['config'];
nestedConfig['customObjects'] = customObjectsCombined;
const backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS);
for (const key of Object.keys(customObjects)) {
_GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];
}
convertNDArrayScalarsInConfig(config['config']);
const returnObj = fromConfig(cls, config['config'], customObjects, fastWeightInit);
_GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects);
return returnObj;
}
else {
const backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS);
for (const key of Object.keys(customObjects)) {
_GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];
}
const returnObj = new cls(config['config']);
_GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects);
return returnObj;
}
}
}
function numberCompare(a, b) {
return (a < b) ? -1 : ((a > b) ? 1 : 0);
}
function reverseNumberCompare(a, b) {
return -1 * numberCompare(a, b);
}
function unique(xs) {
if (xs == null) {
return xs;
}
const out = [];
for (const x of xs) {
if (out.indexOf(x) === -1) {
out.push(x);
}
}
return out;
}
function isObjectEmpty(obj) {
if (obj == null) {
throw new ValueError(`Invalid value in obj: ${JSON.stringify(obj)}`);
}
for (const key in obj) {
if (obj.hasOwnProperty(key)) {
return false;
}
}
return true;
}
function checkStringTypeUnionValue(values, label, value) {
if (value == null) {
return;
}
if (values.indexOf(value) < 0) {
throw new ValueError(`${value} is not a valid ${label}. Valid values are ${values} or null/undefined.`);
}
}
function assertPositiveInteger(value, name) {
if (Array.isArray(value)) {
assert$1(value.length > 0, () => `${name} is unexpectedly an empty array.`);
value.forEach((v, i) => assertPositiveInteger(v, `element ${i + 1} of ${name}`));
}
else {
assert$1(Number.isInteger(value) && value > 0, () => `Expected ${name} to be a positive integer, but got ` +
`${formatAsFriendlyString(value)}.`);
}
}
function formatAsFriendlyString(value) {
if (value === null) {
return 'null';
}
else if (Array.isArray(value)) {
return '[' + value.map(v => formatAsFriendlyString(v)).join(',') + ']';
}
else if (typeof value === 'string') {
return `"${value}"`;
}
else {
return `${value}`;
}
}
function debounce(f, waitMs, nowFunc) {
let lastTime = nowFunc != null ? nowFunc() : now();
let lastResult;
const f2 = (...args) => {
const now$1 = nowFunc != null ? nowFunc() : now();
if (now$1 - lastTime < waitMs) {
return lastResult;
}
lastTime = now$1;
lastResult = f(...args);
return lastResult;
};
return f2;
}
function mapActivationToFusedKernel(activationName) {
if (activationName === 'relu') {
return 'relu';
}
if (activationName === 'linear') {
return 'linear';
}
if (activationName === 'elu') {
return 'elu';
}
return null;
}
let _nextUniqueTensorId = 0;
function getNextUniqueTensorId() {
return _nextUniqueTensorId++;
}
const _uidPrefixes = {};
function getUid(prefix = '') {
if (!(prefix in _uidPrefixes)) {
_uidPrefixes[prefix] = 0;
}
_uidPrefixes[prefix] += 1;
return prefix + _uidPrefixes[prefix].toString();
}
const VALID_DATA_FORMAT_VALUES = ['channelsFirst', 'channelsLast'];
const nameMap = new Map();
function checkDataFormat(value) {
checkStringTypeUnionValue(VALID_DATA_FORMAT_VALUES, 'DataFormat', value);
}
const _nameScopeStack = [];
const _nameScopeDivider = '/';
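/**
 * Runs `fn` with `name` pushed onto the name-scope stack (popped again even if
 * `fn` throws), so tensors created inside receive a 'scope/name' prefix, e.g.
 * nameScope('dense_1', () => getScopedTensorName('kernel')) === 'dense_1/kernel'.
 */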
function nameScope(name, fn) {
_nameScopeStack.push(name);
try {
const val = fn();
_nameScopeStack.pop();
return val;
}
catch (e) {
_nameScopeStack.pop();
throw e;
}
}
function currentNameScopePrefix() {
if (_nameScopeStack.length === 0) {
return '';
}
else {
return _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider;
}
}
function getScopedTensorName(tensorName) {
if (!isValidTensorName(tensorName)) {
throw new Error('Not a valid tensor name: \'' + tensorName + '\'');
}
return currentNameScopePrefix() + tensorName;
}
function getUniqueTensorName(scopedName) {
if (!isValidTensorName(scopedName)) {
throw new Error('Not a valid tensor name: \'' + scopedName + '\'');
}
if (!nameMap.has(scopedName)) {
nameMap.set(scopedName, 0);
}
const index = nameMap.get(scopedName);
nameMap.set(scopedName, nameMap.get(scopedName) + 1);
if (index > 0) {
const result = `${scopedName}_${index}`;
nameMap.set(result, 1);
return result;
}
else {
return scopedName;
}
}
const tensorNameRegex = new RegExp(/^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/);
function isValidTensorName(name) {
return !!name.match(tensorNameRegex);
}
function arrayProd(array, begin, end) {
if (begin == null) {
begin = 0;
}
if (end == null) {
end = array.length;
}
let prod = 1;
for (let i = begin; i < end; ++i) {
prod *= array[i];
}
return prod;
}
function range(begin, end) {
if (end < begin) {
throw new ValueError(`end (${end}) < begin (${begin}) is forbidden.`);
}
const out = [];
for (let i = begin; i < end; ++i) {
out.push(i);
}
return out;
}
let _epsilon;
function epsilon() {
if (_epsilon == null) {
_epsilon = backend().epsilon();
}
return _epsilon;
}
function imageDataFormat() {
return 'channelsLast';
}
function cast(x, dtype) {
return cast$3(x, dtype);
}
function expandDims(x, axis = -1) {
const outShape = x.shape.slice();
if (axis < 0) {
axis = outShape.length + axis + 1;
}
outShape.splice(axis, 0, 1);
return reshape$2(x, outShape);
}
function repeat(x, n) {
return tidy(() => {
if (x.shape.length !== 2) {
throw new ValueError(`repeat() expects a rank-2 tensor, but received a ` +
`rank-${x.shape.length} tensor.`);
}
const y = expandDims(x, 1);
return tile(y, [1, n, 1]);
});
}
function flatten(x) {
const newShape = [arrayProd(x.shape)];
return reshape$2(x, newShape);
}
function batchFlatten(x) {
if (x.rank <= 1) {
throw new ValueError(`batchFlatten requires a minimum rank of 2. Got rank: ${x.rank}.`);
}
const newShape = [x.shape[0], arrayProd(x.shape, 1)];
return reshape$2(x, newShape);
}
function sliceAlongFirstAxis(array, start, size) {
return tidy(() => {
switch (array.rank) {
case 1:
return slice1d(array, start, size);
case 2:
return slice2d(array, [start, 0], [size, array.shape[1]]);
case 3:
return slice3d(array, [start, 0, 0], [size, array.shape[1], array.shape[2]]);
case 4:
return slice4d(array, [start, 0, 0, 0], [size, array.shape[1], array.shape[2], array.shape[3]]);
case 5:
return slice$2(array, [start, 0, 0, 0, 0], [
size, array.shape[1], array.shape[2], array.shape[3], array.shape[4]
]);
case 6:
return slice$2(array, [start, 0, 0, 0, 0, 0], [
size, array.shape[1], array.shape[2], array.shape[3], array.shape[4],
array.shape[5]
]);
default:
throw new ValueError(`sliceAlongFirstAxis() received an unsupported tensor rank: ` +
`${array.rank}`);
}
});
}
function tile(x, n) {
if (!Array.isArray(n)) {
n = [n];
}
if (x.rank !== n.length) {
throw new ValueError(`The length of input n (${n.length}) does not match ` +
`the number of dimensions in input x (${x.rank})`);
}
return tile$3(x, n);
}
function randomNormal(shape, mean = 0.0, stddev = 1.0, dtype, seed) {
return randomNormal$1(shape, mean, stddev, dtype, seed);
}
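/**
 * Keras-style dot product: contracts the last axis of `a` with the
 * second-to-last axis of `b`. Rank-2 x rank-2 delegates directly to the fused
 * matMul (with optional bias and activation); for higher-rank `b`, both inputs
 * are reshaped/transposed into a single 2D matMul and the result is reshaped
 * back to the combined output shape.
 */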
function dot(a, b, activation, bias) {
if ((a.rank < 2) || (b.rank < 2)) {
throw new NotImplementedError(`dot requires both inputs to be rank >= 2` +
` but got x shape = ${a.shape} and y shape = ${b.shape}`);
}
if (b.rank >= 3) {
const xLastDim = a.shape.slice(-1)[0];
const ySecondLastDim = b.shape.slice(-2)[0];
if (xLastDim !== ySecondLastDim) {
throw new NotImplementedError(`If rank y >= 3, then the second last dim` +
` of y must equal the last dim of x but got x shape = ${a.shape} and ` +
` y shape = ${b.shape}`);
}
}
if ((a.rank === 2) && (b.rank === 2)) {
const transposeA = false;
const transposeB = false;
return matMul({
a,
b: b,
transposeA,
transposeB,
bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
activation
});
}
else {
const aFirstDims = a.shape.slice();
const aLastDim = aFirstDims.pop();
a = reshape$2(a, [-1, aLastDim]);
const bShape = b.shape.slice();
const bLastDim = bShape.pop();
const ySecondLastDim = bShape.pop();
const yOtherDims = [...bShape, bLastDim];
const perm = Array.from({ length: b.rank }, (_, i) => {
if (i === 0) {
return b.rank - 2;
}
else if (i <= b.rank - 2) {
return i - 1;
}
return i;
});
b = reshape$2(transpose$2(b, perm), [ySecondLastDim, -1]);
const outputShape = [...aFirstDims, ...yOtherDims];
const transposeA = false;
const transposeB = false;
return reshape$2(matMul({
a,
b,
transposeA,
transposeB,
bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
activation
}), outputShape);
}
}
function gather(reference, indices, axis) {
return tidy(() => {
if (Array.isArray(indices)) {
indices = tensor1d(indices, 'int32');
}
else {
indices = cast$3(indices, 'int32');
}
return gather$1(reference, indices, axis);
});
}
function square(x) {
return mul(x, x);
}
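/**
 * Reshapes a bias tensor so that it broadcasts against an input of rank 3, 4
 * or 5 under the given data format ('channelsFirst' puts the channel axis at
 * index 1, 'channelsLast' at the end); inputs of rank < 3 get the bias back
 * unchanged. Used by biasAdd() below.
 */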
function reshapeBias(xRank, bias, dataFormat) {
const biasShape = bias.shape;
if (bias.rank !== 1 && bias.rank !== xRank) {
throw new ValueError(`Unexpected bias dimensions: ${bias.rank}` +
`; expected it to be 1 or ${xRank}`);
}
if (xRank === 5) {
if (dataFormat === 'channelsFirst') {
if (biasShape.length === 1) {
return reshape$2(bias, [1, biasShape[0], 1, 1, 1]);
}
else {
return reshape$2(bias, [1, biasShape[3], biasShape[0], biasShape[1], biasShape[2]]);
}
}
else if (dataFormat === 'channelsLast') {
if (biasShape.length === 1) {
return reshape$2(bias, [1, 1, 1, 1, biasShape[0]]);
}
else {
return reshape$2(bias, [1].concat(biasShape));
}
}
}
else if (xRank === 4) {
if (dataFormat === 'channelsFirst') {
if (biasShape.length === 1) {
return reshape$2(bias, [1, biasShape[0], 1, 1]);
}
else {
return reshape$2(bias, [1, biasShape[2], biasShape[0], biasShape[1]]);
}
}
else if (dataFormat === 'channelsLast') {
if (biasShape.length === 1) {
return reshape$2(bias, [1, 1, 1, biasShape[0]]);
}
else {
return reshape$2(bias, [1].concat(biasShape));
}
}
}
else if (xRank === 3) {
if (dataFormat === 'channelsFirst') {
if (biasShape.length === 1) {
return reshape$2(bias, [1, biasShape[0], 1]);
}
else {
return reshape$2(bias, [1, biasShape[1], biasShape[0]]);
}
}
else if (dataFormat === 'channelsLast') {
if (biasShape.length === 1) {
return reshape$2(bias, [1, 1, biasShape[0]]);
}
else {
return reshape$2(bias, [1].concat(biasShape));
}
}
}
else if (xRank < 3) {
return bias;
}
    throw new ValueError(`Unsupported input rank by biasAdd: ${xRank}`);
}
function biasAdd(x, bias, dataFormat) {
return tidy(() => {
if (dataFormat == null) {
dataFormat = imageDataFormat();
}
checkDataFormat(dataFormat);
return add$1(x, reshapeBias(x.rank, bias, dataFormat));
});
}
function elu(x, alpha = 1) {
if (alpha !== 1) {
throw new NotImplementedError(`Support for alpha values other than 1 (${alpha}) is not implemented ` +
`yet.`);
}
return elu$3(x);
}
function softsign(x) {
return tidy(() => div$1(x, add$1(abs$2(x), 1)));
}
function dropout$1(x, level, noiseShape, seed) {
return tidy(() => dropout$2(x, level, noiseShape, seed));
}
function hardSigmoid(x) {
return tidy(() => {
const y = add$1(.5, mul(.2, x));
return clipByValue$2(y, 0, 1);
});
}
function inTrainPhase(x, alt, training = false) {
return training ? x() : alt();
}
const VALID_FAN_MODE_VALUES = ['fanIn', 'fanOut', 'fanAvg'];
const VALID_DISTRIBUTION_VALUES = ['normal', 'uniform', 'truncatedNormal'];
function checkFanMode(value) {
checkStringTypeUnionValue(VALID_FAN_MODE_VALUES, 'FanMode', value);
}
function checkDistribution(value) {
checkStringTypeUnionValue(VALID_DISTRIBUTION_VALUES, 'Distribution', value);
}
class Initializer extends Serializable {
fromConfigUsesCustomObjects() {
return false;
}
getConfig() {
return {};
}
}
class Zeros extends Initializer {
apply(shape, dtype) {
return zeros$1(shape, dtype);
}
}
Zeros.className = 'Zeros';
registerClass(Zeros);
class Ones extends Initializer {
apply(shape, dtype) {
return ones(shape, dtype);
}
}
Ones.className = 'Ones';
registerClass(Ones);
class Constant extends Initializer {
constructor(args) {
super();
if (typeof args !== 'object') {
throw new ValueError(`Expected argument of type ConstantConfig but got ${args}`);
}
if (args.value === undefined) {
throw new ValueError(`config must have value set but got ${args}`);
}
this.value = args.value;
}
apply(shape, dtype) {
return tidy(() => mul(scalar(this.value), ones(shape, dtype)));
}
getConfig() {
return {
value: this.value,
};
}
}
Constant.className = 'Constant';
registerClass(Constant);
class RandomUniform extends Initializer {
constructor(args) {
super();
this.DEFAULT_MINVAL = -0.05;
this.DEFAULT_MAXVAL = 0.05;
this.minval = args.minval || this.DEFAULT_MINVAL;
this.maxval = args.maxval || this.DEFAULT_MAXVAL;
this.seed = args.seed;
}
apply(shape, dtype) {
return randomUniform(shape, this.minval, this.maxval, dtype, this.seed);
}
getConfig() {
return { minval: this.minval, maxval: this.maxval, seed: this.seed };
}
}
RandomUniform.className = 'RandomUniform';
registerClass(RandomUniform);
class RandomNormal extends Initializer {
constructor(args) {
super();
this.DEFAULT_MEAN = 0.;
this.DEFAULT_STDDEV = 0.05;
this.mean = args.mean || this.DEFAULT_MEAN;
this.stddev = args.stddev || this.DEFAULT_STDDEV;
this.seed = args.seed;
}
apply(shape, dtype) {
dtype = dtype || 'float32';
if (dtype !== 'float32' && dtype !== 'int32') {
throw new NotImplementedError(`randomNormal does not support dType ${dtype}.`);
}
return randomNormal(shape, this.mean, this.stddev, dtype, this.seed);
}
getConfig() {
return { mean: this.mean, stddev: this.stddev, seed: this.seed };
}
}
RandomNormal.className = 'RandomNormal';
registerClass(RandomNormal);
class TruncatedNormal extends Initializer {
constructor(args) {
super();
this.DEFAULT_MEAN = 0.;
this.DEFAULT_STDDEV = 0.05;
this.mean = args.mean || this.DEFAULT_MEAN;
this.stddev = args.stddev || this.DEFAULT_STDDEV;
this.seed = args.seed;
}
apply(shape, dtype) {
dtype = dtype || 'float32';
if (dtype !== 'float32' && dtype !== 'int32') {
throw new NotImplementedError(`truncatedNormal does not support dType ${dtype}.`);
}
return truncatedNormal(shape, this.mean, this.stddev, dtype, this.seed);
}
getConfig() {
return { mean: this.mean, stddev: this.stddev, seed: this.seed };
}
}
TruncatedNormal.className = 'TruncatedNormal';
registerClass(TruncatedNormal);
class Identity extends Initializer {
constructor(args) {
super();
this.gain = args.gain != null ? args.gain : 1.0;
}
apply(shape, dtype) {
return tidy(() => {
if (shape.length !== 2 || shape[0] !== shape[1]) {
throw new ValueError('Identity matrix initializer can only be used for' +
' 2D square matrices.');
}
else {
return mul(this.gain, eye(shape[0]));
}
});
}
getConfig() {
return { gain: this.gain };
}
}
Identity.className = 'Identity';
registerClass(Identity);
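/**
 * Computes [fanIn, fanOut] for a weight shape: rank-2 kernels use the two dims
 * directly, conv kernels (rank 3-5) multiply the input/output channel dims by
 * the receptive-field size, and any other rank falls back to
 * sqrt(prod(shape)) for both values.
 */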
function computeFans(shape, dataFormat = 'channelsLast') {
let fanIn;
let fanOut;
checkDataFormat(dataFormat);
if (shape.length === 2) {
fanIn = shape[0];
fanOut = shape[1];
}
else if ([3, 4, 5].indexOf(shape.length) !== -1) {
if (dataFormat === 'channelsFirst') {
const receptiveFieldSize = arrayProd(shape, 2);
fanIn = shape[1] * receptiveFieldSize;
fanOut = shape[0] * receptiveFieldSize;
}
else if (dataFormat === 'channelsLast') {
const receptiveFieldSize = arrayProd(shape, 0, shape.length - 2);
fanIn = shape[shape.length - 2] * receptiveFieldSize;
fanOut = shape[shape.length - 1] * receptiveFieldSize;
}
}
else {
const shapeProd = arrayProd(shape);
fanIn = Math.sqrt(shapeProd);
fanOut = Math.sqrt(shapeProd);
}
return [fanIn, fanOut];
}
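/**
 * Initializer that divides `scale` by fanIn, fanOut or their average and then
 * samples either from a truncated normal with stddev sqrt(scale) (the 'normal'
 * distribution, as in Keras) or from a uniform distribution on
 * [-sqrt(3*scale), sqrt(3*scale)]. The Glorot/He/LeCun classes below are
 * preconfigured subclasses that serialize under the VarianceScaling name.
 */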
class VarianceScaling extends Initializer {
constructor(args) {
super();
if (args.scale < 0.0) {
throw new ValueError(`scale must be a positive float. Got: ${args.scale}`);
}
this.scale = args.scale == null ? 1.0 : args.scale;
this.mode = args.mode == null ? 'fanIn' : args.mode;
checkFanMode(this.mode);
this.distribution =
args.distribution == null ? 'normal' : args.distribution;
checkDistribution(this.distribution);
this.seed = args.seed;
}
apply(shape, dtype) {
const fans = computeFans(shape);
const fanIn = fans[0];
const fanOut = fans[1];
let scale = this.scale;
if (this.mode === 'fanIn') {
scale /= Math.max(1, fanIn);
}
else if (this.mode === 'fanOut') {
scale /= Math.max(1, fanOut);
}
else {
scale /= Math.max(1, (fanIn + fanOut) / 2);
}
if (this.distribution === 'normal') {
const stddev = Math.sqrt(scale);
dtype = dtype || 'float32';
if (dtype !== 'float32' && dtype !== 'int32') {
throw new NotImplementedError(`${this.getClassName()} does not support dType ${dtype}.`);
}
return truncatedNormal(shape, 0, stddev, dtype, this.seed);
}
else {
const limit = Math.sqrt(3 * scale);
return randomUniform(shape, -limit, limit, dtype, this.seed);
}
}
getConfig() {
return {
scale: this.scale,
mode: this.mode,
distribution: this.distribution,
seed: this.seed
};
}
}
VarianceScaling.className = 'VarianceScaling';
registerClass(VarianceScaling);
class GlorotUniform extends VarianceScaling {
constructor(args) {
super({
scale: 1.0,
mode: 'fanAvg',
distribution: 'uniform',
seed: args == null ? null : args.seed
});
}
getClassName() {
return VarianceScaling.className;
}
}
GlorotUniform.className = 'GlorotUniform';
registerClass(GlorotUniform);
class GlorotNormal extends VarianceScaling {
constructor(args) {
super({
scale: 1.0,
mode: 'fanAvg',
distribution: 'normal',
seed: args == null ? null : args.seed
});
}
getClassName() {
return VarianceScaling.className;
}
}
GlorotNormal.className = 'GlorotNormal';
registerClass(GlorotNormal);
class HeNormal extends VarianceScaling {
constructor(args) {
super({
scale: 2.0,
mode: 'fanIn',
distribution: 'normal',
seed: args == null ? null : args.seed
});
}
getClassName() {
return VarianceScaling.className;
}
}
HeNormal.className = 'HeNormal';
registerClass(HeNormal);
class HeUniform extends VarianceScaling {
constructor(args) {
super({
scale: 2.0,
mode: 'fanIn',
distribution: 'uniform',
seed: args == null ? null : args.seed
});
}
getClassName() {
return VarianceScaling.className;
}
}
HeUniform.className = 'HeUniform';
registerClass(HeUniform);
class LeCunNormal extends VarianceScaling {
constructor(args) {
super({
scale: 1.0,
mode: 'fanIn',
distribution: 'normal',
seed: args == null ? null : args.seed
});
}
getClassName() {
return VarianceScaling.className;
}
}
LeCunNormal.className = 'LeCunNormal';
registerClass(LeCunNormal);
class LeCunUniform extends VarianceScaling {
constructor(args) {
super({
scale: 1.0,
mode: 'fanIn',
distribution: 'uniform',
seed: args == null ? null : args.seed
});
}
getClassName() {
return VarianceScaling.className;
}
}
LeCunUniform.className = 'LeCunUniform';
registerClass(LeCunUniform);
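/**
 * Orthogonal initializer: draws a random normal matrix, takes its QR
 * decomposition, and multiplies Q by the sign of R's diagonal so the result is
 * uniformly distributed over orthogonal matrices, then scales by `gain`.
 * Warns when the matrix has more than ELEMENTS_WARN_SLOW elements because the
 * QR step can be slow.
 */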
class Orthogonal extends Initializer {
constructor(args) {
super();
this.DEFAULT_GAIN = 1;
this.ELEMENTS_WARN_SLOW = 2000;
this.gain = args.gain == null ? this.DEFAULT_GAIN : args.gain;
this.seed = args.seed;
}
apply(shape, dtype) {
return tidy(() => {
if (shape.length < 2) {
throw new NotImplementedError('Shape must be at least 2D.');
}
if (dtype !== 'int32' && dtype !== 'float32' && dtype !== undefined) {
throw new TypeError(`Unsupported data type ${dtype}.`);
}
dtype = dtype;
const numRows = sizeFromShape(shape.slice(0, -1));
const numCols = shape[shape.length - 1];
const numElements = numRows * numCols;
if (numElements > this.ELEMENTS_WARN_SLOW) {
console.warn(`Orthogonal initializer is being called on a matrix with more ` +
`than ${this.ELEMENTS_WARN_SLOW} (${numElements}) elements: ` +
`Slowness may result.`);
}
const flatShape = [Math.max(numCols, numRows), Math.min(numCols, numRows)];
const randNormalMat = randomNormal(flatShape, 0, 1, dtype, this.seed);
const qr = linalg.qr(randNormalMat, false);
let qMat = qr[0];
const rMat = qr[1];
const diag = rMat.flatten().stridedSlice([0], [Math.min(numCols, numRows) * Math.min(numCols, numRows)], [Math.min(numCols, numRows) + 1]);
qMat = mul(qMat, diag.sign());
if (numRows < numCols) {
qMat = qMat.transpose();
}
return mul(scalar(this.gain), qMat.reshape(shape));
});
}
getConfig() {
return {
gain: this.gain,
seed: this.seed,
};
}
}
Orthogonal.className = 'Orthogonal';
registerClass(Orthogonal);
const INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
'constant': 'Constant',
'glorotNormal': 'GlorotNormal',
'glorotUniform': 'GlorotUniform',
'heNormal': 'HeNormal',
'heUniform': 'HeUniform',
'identity': 'Identity',
'leCunNormal': 'LeCunNormal',
'leCunUniform': 'LeCunUniform',
'ones': 'Ones',
'orthogonal': 'Orthogonal',
'randomNormal': 'RandomNormal',
'randomUniform': 'RandomUniform',
'truncatedNormal': 'TruncatedNormal',
'varianceScaling': 'VarianceScaling',
'zeros': 'Zeros'
};
function deserializeInitializer(config, customObjects = {}) {
return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'initializer');
}
function serializeInitializer(initializer) {
return serializeKerasObject(initializer);
}
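/**
 * Resolves an initializer from a string identifier (mapped through
 * INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP), an Initializer instance, or a
 * serialized config. The Glorot/He/LeCun names are constructed directly
 * because their serialized className is 'VarianceScaling'.
 */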
function getInitializer(identifier) {
if (typeof identifier === 'string') {
const className = identifier in INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
identifier;
if (className === 'GlorotNormal') {
return new GlorotNormal();
}
else if (className === 'GlorotUniform') {
return new GlorotUniform();
}
else if (className === 'HeNormal') {
return new HeNormal();
}
else if (className === 'HeUniform') {
return new HeUniform();
}
else if (className === 'LeCunNormal') {
return new LeCunNormal();
}
else if (className === 'LeCunUniform') {
return new LeCunUniform();
}
else {
const config = {};
config['className'] = className;
config['config'] = {};
return deserializeInitializer(config);
}
}
else if (identifier instanceof Initializer) {
return identifier;
}
else {
return deserializeInitializer(identifier);
}
}
function normalizeShapeList(x) {
if (x.length === 0) {
return [];
}
if (!Array.isArray(x[0])) {
return [x];
}
return x;
}
function getExactlyOneTensor(xs) {
let x;
if (Array.isArray(xs)) {
if (xs.length !== 1) {
throw new ValueError(`Expected Tensor length to be 1; got ${xs.length}`);
}
x = xs[0];
}
else {
x = xs;
}
return x;
}
function getExactlyOneShape(shapes) {
if (Array.isArray(shapes) && Array.isArray(shapes[0])) {
if (shapes.length === 1) {
shapes = shapes;
return shapes[0];
}
else {
throw new ValueError(`Expected exactly 1 Shape; got ${shapes.length}`);
}
}
else {
return shapes;
}
}
function countParamsInWeights(weights) {
let count = 0;
for (const weight of weights) {
if (weight.shape.length === 0) {
count += 1;
}
else {
count += weight.shape.reduce((a, b) => a * b);
}
}
return count;
}
const DEFAULT_VARIABLE_NAME_PREFIX = 'Variable';
class LayerVariable {
constructor(val, dtype = 'float32', name = DEFAULT_VARIABLE_NAME_PREFIX, trainable = true, constraint = null) {
this.dtype = dtype == null ? 'float32' : dtype;
this.shape = val.shape;
this.id = getNextUniqueTensorId();
name = name == null ? DEFAULT_VARIABLE_NAME_PREFIX : name;
this.originalName = getScopedTensorName(name);
this.name = getUniqueTensorName(this.originalName);
this.trainable_ = trainable;
this.constraint = constraint;
this.val = variable(val, this.trainable_, this.name, this.dtype);
}
read() {
this.assertNotDisposed();
return this.val;
}
write(newVal) {
this.assertNotDisposed();
checkShapesMatch(this.val, newVal);
if (this.val.id !== newVal.id) {
this.val.assign(newVal);
if (this.constraint != null) {
this.val.assign(this.constraint.apply(this.val));
}
}
return this;
}
dispose() {
this.assertNotDisposed();
this.val.dispose();
}
assertNotDisposed() {
if (this.val.isDisposed) {
throw new Error(`LayersVariable ${this.name} is already disposed.`);
}
}
get trainable() {
return this.trainable_;
}
set trainable(trainable) {
this.trainable_ = trainable;
this.val.trainable = trainable;
}
}
function checkShapesMatch(x, y) {
if (x.shape.toString() !== y.shape.toString()) {
throw new Error('Shape mismatch: ' + JSON.stringify(x.shape) + ' vs. ' +
JSON.stringify(y.shape));
}
}
function batchGetValue(xs) {
return xs.map(x => x.read());
}
function batchSetValue(variablesAndValues) {
variablesAndValues.forEach(variableAndValue => {
const variable = variableAndValue[0];
variable.write(variableAndValue[1]);
});
}
class InputSpec {
constructor(args) {
this.dtype = args.dtype;
this.shape = args.shape;
if (args.shape != null) {
this.ndim = args.shape.length;
}
else {
this.ndim = args.ndim;
}
this.maxNDim = args.maxNDim;
this.minNDim = args.minNDim;
this.axes = args.axes || {};
}
}
class SymbolicTensor {
constructor(dtype, shape, sourceLayer, inputs, callArgs, name, outputTensorIndex) {
this.dtype = dtype;
this.shape = shape;
this.sourceLayer = sourceLayer;
this.inputs = inputs;
this.callArgs = callArgs;
this.outputTensorIndex = outputTensorIndex;
this.id = getNextUniqueTensorId();
if (name != null) {
this.originalName = getScopedTensorName(name);
this.name = getUniqueTensorName(this.originalName);
}
this.rank = shape.length;
}
}
let _nextNodeID = 0;
class Node {
constructor(args,
callArgs) {
this.callArgs = callArgs;
this.id = _nextNodeID++;
this.outboundLayer = args.outboundLayer;
this.inboundLayers = args.inboundLayers;
this.nodeIndices = args.nodeIndices;
this.tensorIndices = args.tensorIndices;
this.inputTensors = args.inputTensors;
this.outputTensors = args.outputTensors;
this.inputMasks = args.inputMasks;
this.outputMasks = args.outputMasks;
this.inputShapes = args.inputShapes;
this.outputShapes = args.outputShapes;
for (const layer of args.inboundLayers) {
if (layer != null) {
layer.outboundNodes.push(this);
}
}
args.outboundLayer.inboundNodes.push(this);
}
getConfig() {
const inboundNames = [];
for (const layer of this.inboundLayers) {
if (layer != null) {
inboundNames.push(layer.name);
}
else {
inboundNames.push(null);
}
}
return {
outboundLayer: this.outboundLayer ? this.outboundLayer.name : null,
inboundLayers: inboundNames,
nodeIndices: this.nodeIndices,
tensorIndices: this.tensorIndices
};
}
}
let _nextLayerID = 0;
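/**
 * Base class for all layers. `apply()` accepts either concrete Tensors (eager
 * execution through `call()`) or SymbolicTensors (graph construction: the
 * output shape is computed, SymbolicTensor outputs are created and a Node is
 * recorded linking this layer to its inbound layers). The layer is built
 * lazily on first application, and `_refCount` tracks uses so `dispose()`
 * only frees weights once the last reference is released.
 */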
class Layer extends Serializable {
constructor(args = {}) {
super();
this._callHook = null;
this._addedWeightNames = [];
this._stateful = false;
this.id = _nextLayerID++;
this.activityRegularizer = null;
this.inputSpec = null;
this.supportsMasking = false;
this._trainableWeights = [];
this._nonTrainableWeights = [];
this._losses = [];
this._updates = [];
this._built = false;
this.inboundNodes = [];
this.outboundNodes = [];
let name = args.name;
if (!name) {
const prefix = this.getClassName();
name = toSnakeCase(prefix) + '_' + getUid(prefix);
}
this.name = name;
this.trainable_ = args.trainable == null ? true : args.trainable;
if (args.inputShape != null || args.batchInputShape != null) {
let batchInputShape;
if (args.batchInputShape != null) {
batchInputShape = args.batchInputShape;
}
else if (args.inputShape != null) {
let batchSize = null;
if (args.batchSize != null) {
batchSize = args.batchSize;
}
batchInputShape = [batchSize].concat(args.inputShape);
}
this.batchInputShape = batchInputShape;
let dtype = args.dtype;
if (dtype == null) {
dtype = args.inputDType;
}
if (dtype == null) {
dtype = 'float32';
}
this.dtype = dtype;
}
if (args.weights != null) {
this.initialWeights = args.weights;
}
else {
this.initialWeights = null;
}
this._refCount = null;
this.fastWeightInitDuringBuild = false;
}
static nodeKey(layer, nodeIndex) {
return layer.name + '_ib-' + nodeIndex.toString();
}
getNodeAtIndex(nodeIndex, attrName) {
if (this.inboundNodes.length === 0) {
throw new RuntimeError('The layer has never been called ' +
`and thus has no defined ${attrName}.`);
}
if (this.inboundNodes.length <= nodeIndex) {
throw new ValueError(`Asked to get ${attrName} at node ${nodeIndex}, ` +
`but the layer has only ${this.inboundNodes.length} inbound nodes.`);
}
return this.inboundNodes[nodeIndex];
}
getInputAt(nodeIndex) {
return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'input').inputTensors);
}
getOutputAt(nodeIndex) {
return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'output').outputTensors);
}
get input() {
if (this.inboundNodes.length > 1) {
throw new AttributeError(`Layer ${this.name}` +
' has multiple inbound nodes, ' +
'hence the notion of "layer input" ' +
'is ill-defined. ' +
'Use `getInputAt(nodeIndex)` instead.');
}
else if (this.inboundNodes.length === 0) {
throw new AttributeError(`Layer ${this.name}` +
' is not connected, no input to return.');
}
return singletonOrArray(this.getNodeAtIndex(0, 'input').inputTensors);
}
get output() {
if (this.inboundNodes.length === 0) {
throw new AttributeError(`Layer ${this.name}` +
' has no inbound nodes.');
}
if (this.inboundNodes.length > 1) {
throw new AttributeError(`Layer ${this.name}` +
' has multiple inbound nodes, ' +
'hence the notion of "layer output" ' +
'is ill-defined. ' +
'Use `getOutputAt(nodeIndex)` instead.');
}
return singletonOrArray(this.getNodeAtIndex(0, 'output').outputTensors);
}
get losses() {
return this._losses;
}
calculateLosses() {
return this.losses.map(lossFn => lossFn());
}
get updates() {
return this._updates;
}
get built() {
return this._built;
}
set built(built) {
this._built = built;
}
get trainable() {
return this.trainable_;
}
set trainable(trainable) {
this._trainableWeights.forEach(w => w.trainable = trainable);
this.trainable_ = trainable;
}
get trainableWeights() {
if (this.trainable_) {
return this._trainableWeights.filter(w => w.trainable);
}
else {
return [];
}
}
set trainableWeights(weights) {
this._trainableWeights = weights;
}
get nonTrainableWeights() {
if (this.trainable) {
return this._trainableWeights.filter(w => !w.trainable)
.concat(this._nonTrainableWeights);
}
else {
return this._trainableWeights.concat(this._nonTrainableWeights);
}
}
set nonTrainableWeights(weights) {
this._nonTrainableWeights = weights;
}
get weights() {
return this.trainableWeights.concat(this.nonTrainableWeights);
}
get stateful() {
return this._stateful;
}
resetStates() {
if (!this.stateful) {
throw new Error('Cannot call the resetStates() method of a non-stateful Layer ' +
'object.');
}
}
assertInputCompatibility(inputs) {
const inputsList = toList(inputs);
if (this.inputSpec == null || this.inputSpec.length === 0) {
return;
}
const inputSpec = toList(this.inputSpec);
if (inputsList.length !== inputSpec.length) {
throw new ValueError(`Layer ${this.name} expects ${inputSpec.length} inputs, ` +
`but it received ${inputsList.length} input tensors. ` +
`Input received: ${inputs}`);
}
for (let inputIndex = 0; inputIndex < inputsList.length; inputIndex++) {
const x = inputsList[inputIndex];
const spec = inputSpec[inputIndex];
if (spec == null) {
continue;
}
const ndim = x.rank;
if (spec.ndim != null) {
if (ndim !== spec.ndim) {
throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}: ` +
`expected ndim=${spec.ndim}, found ndim=${ndim}`);
}
}
if (spec.maxNDim != null) {
if (ndim > spec.maxNDim) {
throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}` +
`: expected max_ndim=${spec.maxNDim}, found ndim=${ndim}`);
}
}
if (spec.minNDim != null) {
if (ndim < spec.minNDim) {
throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}` +
`: expected min_ndim=${spec.minNDim}, found ndim=${ndim}.`);
}
}
if (spec.dtype != null) {
if (x.dtype !== spec.dtype) {
throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name} ` +
`: expected dtype=${spec.dtype}, found dtype=${x.dtype}.`);
}
}
if (spec.axes) {
const xShape = x.shape;
for (const key in spec.axes) {
const axis = Number(key);
const value = spec.axes[key];
const xShapeAtAxis = axis >= 0 ? xShape[axis] : xShape[xShape.length + axis];
if (value != null && [value, null].indexOf(xShapeAtAxis) === -1) {
throw new ValueError(`Input ${inputIndex} is incompatible with layer ` +
`${this.name}: expected axis ${axis} of input shape to ` +
`have value ${value} but got shape ${xShape}.`);
}
}
}
if (spec.shape != null) {
for (let i = 0; i < spec.shape.length; ++i) {
const specDim = spec.shape[i];
const dim = x.shape[i];
if (specDim != null && dim != null) {
if (specDim !== dim) {
throw new ValueError(`Input ${inputIndex} is incompatible with layer ` +
`${this.name}: expected shape=${spec.shape}, ` +
`found shape=${x.shape}.`);
}
}
}
}
}
}
call(inputs, kwargs) {
return inputs;
}
invokeCallHook(inputs, kwargs) {
if (this._callHook != null) {
this._callHook(inputs, kwargs);
}
}
setCallHook(callHook) {
this._callHook = callHook;
}
clearCallHook() {
this._callHook = null;
}
apply(inputs, kwargs) {
kwargs = kwargs || {};
this.assertNotDisposed();
const inputsList = toList(inputs);
const allAreSymbolic = checkAllSymbolic(inputs);
const noneAreSymbolic = checkNoneSymbolic(inputs);
if (allAreSymbolic === noneAreSymbolic) {
throw new ValueError('Arguments to apply() must be all ' +
'SymbolicTensors or all Tensors');
}
return nameScope(this.name, () => {
if (!this.built) {
this.assertInputCompatibility(inputs);
const inputShapes = [];
for (const xElem of toList(inputs)) {
inputShapes.push(xElem.shape);
}
this.build(singletonOrArray(inputShapes));
this.built = true;
if (this.initialWeights) {
this.setWeights(this.initialWeights);
}
if (this._refCount === null && noneAreSymbolic) {
this._refCount = 1;
}
}
this.assertInputCompatibility(inputs);
if (noneAreSymbolic) {
let output = this.call(inputs, kwargs);
if (this.supportsMasking) {
this.setMaskMetadata(inputs, output);
}
const outputList = toList(output);
const outputListCopy = [];
for (let x of outputList) {
if (inputsList.indexOf(x) !== -1) {
x = x.clone();
}
outputListCopy.push(x);
}
output = singletonOrArray(outputListCopy);
if (this.activityRegularizer != null) {
throw new NotImplementedError('Layer invocation in the presence of activity ' +
'regularizer(s) is not supported yet.');
}
return output;
}
else {
const inputShape = collectInputShape(inputs);
const outputShape = this.computeOutputShape(inputShape);
let output;
const outputDType = guessOutputDType();
this.warnOnIncompatibleInputShape(Array.isArray(inputs) ? inputShape[0] :
inputShape);
if (outputShape != null && outputShape.length > 0 &&
Array.isArray(outputShape[0])) {
output = outputShape
.map((shape, index) => new SymbolicTensor(outputDType, shape, this, toList(inputs), kwargs, this.name, index));
}
else {
output = new SymbolicTensor(outputDType, outputShape, this, toList(inputs), kwargs, this.name);
}
this.addInboundNode(inputs, output, null, null, inputShape, outputShape, kwargs);
this._refCount++;
if (this.activityRegularizer != null) {
throw new NotImplementedError('Layer invocation in the presence of activity ' +
'regularizer(s) is not supported yet.');
}
return output;
}
});
}
warnOnIncompatibleInputShape(inputShape) {
if (this.batchInputShape == null) {
return;
}
else if (inputShape.length !== this.batchInputShape.length) {
console.warn(`The rank of the input tensor provided (shape: ` +
`${JSON.stringify(inputShape)}) does not match that of the ` +
`batchInputShape (${JSON.stringify(this.batchInputShape)}) ` +
`of the layer ${this.name}`);
}
else {
let dimMismatch = false;
this.batchInputShape.forEach((dimension, i) => {
if (dimension != null && inputShape[i] != null &&
inputShape[i] !== dimension) {
dimMismatch = true;
}
});
if (dimMismatch) {
console.warn(`The shape of the input tensor ` +
`(${JSON.stringify(inputShape)}) does not ` +
`match the expectation of layer ${this.name}: ` +
`${JSON.stringify(this.batchInputShape)}`);
}
}
}
get outputShape() {
if (this.inboundNodes == null || this.inboundNodes.length === 0) {
throw new AttributeError(`The layer ${this.name} has never been called and thus has no ` +
`defined output shape.`);
}
const allOutputShapes = [];
for (const node of this.inboundNodes) {
const shapeString = JSON.stringify(node.outputShapes);
if (allOutputShapes.indexOf(shapeString) === -1) {
allOutputShapes.push(shapeString);
}
}
if (allOutputShapes.length === 1) {
const outputShapes = this.inboundNodes[0].outputShapes;
if (Array.isArray(outputShapes) && Array.isArray(outputShapes[0]) &&
outputShapes.length === 1) {
return outputShapes[0];
}
else {
return outputShapes;
}
}
else {
throw new AttributeError(`The layer ${this.name} has multiple inbound nodes with different ` +
`output shapes. Hence the notion of "output shape" is ill-defined ` +
`for the layer.`);
}
}
countParams() {
if (!this.built) {
throw new RuntimeError(`You tried to call countParams() on ${this.name}, ` +
`but the layer is not built yet. Build it first by calling ` +
`build(batchInputShape).`);
}
return countParamsInWeights(this.weights);
}
build(inputShape) {
this.built = true;
}
getWeights(trainableOnly = false) {
return batchGetValue(trainableOnly ? this.trainableWeights : this.weights);
}
setWeights(weights) {
tidy(() => {
const params = this.weights;
if (params.length !== weights.length) {
throw new ValueError(`You called setWeights(weights) on layer "${this.name}" ` +
`with a weight list of length ${weights.length}, ` +
`but the layer was expecting ${params.length} weights. ` +
`Provided weights: ${weights}...`);
}
if (params.length === 0) {
return;
}
const weightValueTuples = [];
const paramValues = batchGetValue(params);
for (let i = 0; i < paramValues.length; ++i) {
const pv = paramValues[i];
const p = params[i];
const w = weights[i];
if (!arraysEqual(pv.shape, w.shape)) {
throw new ValueError(`Layer weight shape ${pv.shape} ` +
`not compatible with provided weight shape ${w.shape}`);
}
weightValueTuples.push([p, w]);
}
batchSetValue(weightValueTuples);
});
}
addWeight(name, shape, dtype, initializer, regularizer, trainable, constraint, getInitializerFunc) {
if (this._addedWeightNames.indexOf(name) !== -1) {
throw new ValueError(`Duplicate weight name ${name} for layer ${this.name}`);
}
this._addedWeightNames.push(name);
if (dtype == null) {
dtype = 'float32';
}
if (this.fastWeightInitDuringBuild) {
initializer = getInitializerFunc != null ? getInitializerFunc() :
getInitializer('zeros');
}
const initValue = initializer.apply(shape, dtype);
const weight = new LayerVariable(initValue, dtype, name, trainable, constraint);
initValue.dispose();
if (regularizer != null) {
this.addLoss(() => regularizer.apply(weight.read()));
}
if (trainable == null) {
trainable = true;
}
if (trainable) {
this._trainableWeights.push(weight);
}
else {
this._nonTrainableWeights.push(weight);
}
return weight;
}
setFastWeightInitDuringBuild(value) {
this.fastWeightInitDuringBuild = value;
}
addLoss(losses) {
if (losses == null || Array.isArray(losses) && losses.length === 0) {
return;
}
losses = toList(losses);
if (this._losses !== undefined && this._losses !== null) {
this.losses.push(...losses);
}
}
computeOutputShape(inputShape) {
return inputShape;
}
computeMask(inputs, mask) {
if (!this.supportsMasking) {
if (mask != null) {
if (Array.isArray(mask)) {
mask.forEach(maskElement => {
if (maskElement != null) {
throw new TypeError(`Layer ${this.name} does not support masking, ` +
'but was passed an inputMask.');
}
});
}
else {
throw new TypeError(`Layer ${this.name} does not support masking, ` +
'but was passed an inputMask.');
}
}
return null;
}
return mask;
}
setMaskMetadata(inputs, outputs, previousMask) {
if (!this.supportsMasking) {
return;
}
const outputMasks = this.computeMask(inputs, previousMask);
const outputsList = toList(outputs);
const outputMasksList = toList(outputMasks);
if (outputsList.length !== outputMasksList.length) {
            throw new Error(`${this.name} outputs ${outputsList.length} tensors ` +
                `but ${outputMasksList.length} masks for those tensors`);
}
for (let i = 0; i < outputsList.length; i++) {
outputsList[i].kerasMask = outputMasksList[i];
}
}
addInboundNode(inputTensors, outputTensors, inputMasks, outputMasks, inputShapes, outputShapes, kwargs = null) {
const inputTensorList = toList(inputTensors);
outputTensors = toList(outputTensors);
inputMasks = toList(inputMasks);
outputMasks = toList(outputMasks);
inputShapes = normalizeShapeList(inputShapes);
outputShapes = normalizeShapeList(outputShapes);
const inboundLayers = [];
const nodeIndices = [];
const tensorIndices = [];
for (const x of inputTensorList) {
inboundLayers.push(x.sourceLayer);
nodeIndices.push(x.nodeIndex);
tensorIndices.push(x.tensorIndex);
}
new Node({
outboundLayer: this,
inboundLayers,
nodeIndices,
tensorIndices,
inputTensors: inputTensorList,
outputTensors,
inputMasks,
outputMasks,
inputShapes,
outputShapes
}, kwargs);
for (let i = 0; i < outputTensors.length; i++) {
outputTensors[i].sourceLayer = this;
outputTensors[i].nodeIndex = this.inboundNodes.length - 1;
outputTensors[i].tensorIndex = i;
}
}
getConfig() {
const config = { name: this.name, trainable: this.trainable };
if (this.batchInputShape != null) {
config['batchInputShape'] = this.batchInputShape;
}
if (this.dtype != null) {
config['dtype'] = this.dtype;
}
return config;
}
disposeWeights() {
this.weights.forEach(weight => weight.dispose());
return this.weights.length;
}
assertNotDisposed() {
if (this._refCount === 0) {
throw new Error(`Layer '${this.name}' is already disposed.`);
}
}
dispose() {
if (!this.built) {
throw new Error(`Cannot dispose Layer ${this.name} because it has not been ` +
`built yet.`);
}
if (this._refCount === null) {
throw new Error(`Cannot dispose Layer ${this.name} because it has not been used ` +
`yet.`);
}
this.assertNotDisposed();
let numDisposedVariables = 0;
if (--this._refCount === 0) {
numDisposedVariables = this.disposeWeights();
}
return { refCountAfterDispose: this._refCount, numDisposedVariables };
}
}
function collectInputShape(inputTensors) {
inputTensors =
toList(inputTensors);
const shapes = [];
for (const x of inputTensors) {
shapes.push(x.shape);
}
return singletonOrArray(shapes);
}
function guessOutputDType(inputTensors) {
return 'float32';
}
function getSourceInputs(tensor, layer, nodeIndex) {
if (layer == null || (nodeIndex != null && nodeIndex > 0)) {
layer = tensor.sourceLayer;
nodeIndex = tensor.nodeIndex;
}
if (layer.inboundNodes.length === 0) {
return [tensor];
}
else {
const node = layer.inboundNodes[nodeIndex];
if (node.inboundLayers.length === 0) {
return node.inputTensors;
}
else {
const sourceTensors = [];
for (let i = 0; i < node.inboundLayers.length; i++) {
const x = node.inputTensors[i];
const layer = node.inboundLayers[i];
const nodeIndex = node.nodeIndices[i];
const previousSources = getSourceInputs(x, layer, nodeIndex);
for (const x of previousSources) {
if (sourceTensors.indexOf(x) === -1) {
sourceTensors.push(x);
}
}
}
return sourceTensors;
}
}
}
function checkAllSymbolic(tensors) {
let allAreSymbolic = true;
for (const tensor of toList(tensors)) {
if (!(tensor instanceof SymbolicTensor)) {
allAreSymbolic = false;
break;
}
}
return allAreSymbolic;
}
function checkNoneSymbolic(tensors) {
let noneAreSymbolic = true;
for (const tensor of toList(tensors)) {
if (tensor instanceof SymbolicTensor) {
noneAreSymbolic = false;
break;
}
}
return noneAreSymbolic;
}
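/**
 * A layer with no inbound layers that simply exposes a SymbolicTensor with the
 * configured batchInputShape and dtype as its output; calling `apply()` on it
 * is an error. `Input()` below is the functional helper that constructs one
 * and returns its output tensor.
 */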
class InputLayer extends Layer {
constructor(args) {
super({
dtype: args.dtype,
name: args.name != null ? args.name : getUid('input').toString()
});
if (args.batchSize == null) {
args.batchSize = null;
}
if (args.sparse == null) {
args.sparse = false;
}
this.trainable = false;
this.built = true;
this.sparse = args.sparse;
if (args.inputShape != null && args.batchInputShape != null) {
throw new ValueError('Only provide the inputShape OR ' +
'batchInputShape argument to inputLayer, not both at the same time.');
}
let batchInputShape = args.batchInputShape;
if (batchInputShape == null) {
if (args.inputShape == null) {
throw new ValueError('An InputLayer should be passed either a ' +
'`batchInputShape` or an `inputShape`.');
}
else {
batchInputShape = [args.batchSize].concat(args.inputShape);
}
}
else {
if (args.batchSize != null) {
throw new ValueError('Cannot specify batchSize if batchInputShape is ' +
'specified when creating an InputLayer.');
}
}
const dtype = args.dtype || 'float32';
this.batchInputShape = batchInputShape;
this.dtype = dtype;
this.inputSpec = [{ shape: batchInputShape }];
const inputTensor = new SymbolicTensor(this.dtype, this.batchInputShape, this, [], {}, this.name);
inputTensor.nodeIndex = 0;
inputTensor.tensorIndex = 0;
new Node({
outboundLayer: this,
inboundLayers: [],
nodeIndices: [],
tensorIndices: [],
inputTensors: [inputTensor],
outputTensors: [inputTensor],
inputMasks: [null],
outputMasks: [null],
inputShapes: [batchInputShape],
outputShapes: [batchInputShape]
});
}
apply(inputs, kwargs) {
throw new ValueError('Cannot pass any input to an ' +
`InputLayer's apply() method. InputLayer name: ${this.name}`);
}
dispose() {
return { refCountAfterDispose: this._refCount, numDisposedVariables: 0 };
}
getConfig() {
return {
batchInputShape: this.batchInputShape,
dtype: this.dtype,
sparse: this.sparse,
name: this.name
};
}
}
InputLayer.className = 'InputLayer';
registerClass(InputLayer);
function Input(config) {
if (config.batchShape == null && config.shape == null) {
throw new Error('Please provide to Input either a `shape`' +
' or a `batchShape` argument. Note that ' +
'`shape` does not include the batch ' +
'dimension.');
}
if (config.batchShape != null && config.shape != null) {
throw new ValueError('Please provide either a `shape` or `batchShape` ' +
'argument to Input, but not both.');
}
let batchShape = config.batchShape;
if (config.shape != null && batchShape == null) {
batchShape = [null].concat(config.shape);
}
let dtype = config.dtype;
if (dtype == null) {
dtype = 'float32';
}
const inputLayer = new InputLayer({
batchInputShape: batchShape,
name: config.name,
dtype,
sparse: config.sparse
});
const outputs = inputLayer.inboundNodes[0].outputTensors;
return outputs[0];
}
function assertFeedCompatibility(key, val) {
if (key.dtype == null || key.dtype === val.dtype) {
return val;
}
try {
return cast$3(val, key.dtype);
}
catch (err) {
throw new ValueError(`The dtype of the feed (${val.dtype}) can not be cast to the dtype ` +
`of the key '${key.name}' (${key.dtype}).`);
}
}
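/**
 * Maps SymbolicTensors (by id and by name) to concrete Tensor values and
 * optional masks during graph execution. `add()` rejects duplicate keys and
 * casts the value to the key's dtype via assertFeedCompatibility().
 */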
class FeedDict {
constructor(feeds) {
this.id2Value = {};
this.id2Mask = {};
this.name2Id = {};
if (feeds instanceof FeedDict) {
for (const id in feeds.id2Value) {
this.id2Value[id] = feeds.id2Value[id];
if (id in feeds.id2Mask) {
this.id2Mask[id] = feeds.id2Mask[id];
}
}
}
else {
if (feeds == null) {
return;
}
for (const feed of feeds) {
this.add(feed.key, feed.value);
}
}
}
add(key, value, mask) {
if (this.id2Value[key.id] == null) {
this.id2Value[key.id] = assertFeedCompatibility(key, value);
this.name2Id[key.name] = key.id;
if (mask != null) {
this.id2Mask[key.id] = mask;
}
}
else {
throw new ValueError(`Duplicate key: name=${key.name}, id=${key.id}`);
}
return this;
}
addFeed(feed) {
this.add(feed.key, feed.value);
}
hasKey(key) {
return this.id2Value[key.id] != null;
}
names() {
return Object.keys(this.name2Id);
}
getValue(key) {
if (key instanceof SymbolicTensor) {
if (this.id2Value[key.id] == null) {
throw new ValueError(`Nonexistent key: ${key.name}`);
}
else {
return this.id2Value[key.id];
}
}
else {
const id = this.name2Id[key];
if (id == null) {
throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);
}
return this.id2Value[id];
}
}
getMask(key) {
if (key instanceof SymbolicTensor) {
if (this.id2Value[key.id] == null) {
throw new ValueError(`Nonexistent key: ${key.name}`);
}
else {
return this.id2Mask[key.id];
}
}
else {
const id = this.name2Id[key];
if (id == null) {
throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);
}
return this.id2Mask[id];
}
}
disposeMasks() {
if (this.id2Mask != null) {
dispose(this.id2Mask);
}
}
}
const cachedSorted = new LruCache();
const cachedRecipientCounts = new LruCache();
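/**
 * Executes the symbolic graph needed to compute `fetches` given the values in
 * `feedDict`. The topological sort and per-tensor recipient counts are cached
 * under a key built from the fetch and feed names; during inference (training
 * falsy) intermediate tensors are disposed as soon as their last recipient
 * has consumed them.
 */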
function execute(fetches, feedDict, kwargs, probe) {
const training = kwargs == null ? false : kwargs['training'];
const arrayFetches = Array.isArray(fetches);
const fetchArray = arrayFetches ? fetches : [fetches];
const outputNames = fetchArray.map(t => t.name);
const finalOutputs = [];
const feedNames = feedDict.names();
for (const outputName of outputNames) {
if (feedNames.indexOf(outputName) !== -1) {
finalOutputs.push(feedDict.getValue(outputName));
}
else {
finalOutputs.push(null);
}
}
const fetchAndFeedKey = outputNames.join(',') + '|' + feedDict.names().sort().join(',');
let sorted = cachedSorted.get(fetchAndFeedKey);
let recipientCounts;
if (sorted == null) {
const out = getTopologicalSortAndRecipientCounts(fetchArray, feedDict);
sorted = out.sorted;
recipientCounts = out.recipientCounts;
cachedSorted.put(fetchAndFeedKey, sorted);
cachedRecipientCounts.put(fetchAndFeedKey, recipientCounts);
}
recipientCounts = {};
if (!training) {
Object.assign(recipientCounts, cachedRecipientCounts.get(fetchAndFeedKey));
}
const internalFeedDict = new FeedDict(feedDict);
for (let i = 0; i < sorted.length; ++i) {
const symbolic = sorted[i];
const srcLayer = symbolic.sourceLayer;
if (srcLayer instanceof InputLayer) {
continue;
}
const inputValues = [];
const inputMasks = [];
const tensorsToDispose = [];
let maskExists = false;
for (const input of symbolic.inputs) {
const value = internalFeedDict.getValue(input);
const mask = internalFeedDict.getMask(input);
inputValues.push(value);
inputMasks.push(mask);
if (mask != null) {
maskExists = true;
}
if (!training) {
recipientCounts[input.name]--;
if (recipientCounts[input.name] === 0 && !feedDict.hasKey(input) &&
outputNames.indexOf(input.name) === -1 && !value.isDisposed &&
input.sourceLayer.stateful !== true) {
tensorsToDispose.push(value);
}
}
}
if (maskExists) {
kwargs = kwargs || {};
kwargs['mask'] = inputMasks[0];
}
const outputTensors = toList(srcLayer.apply(inputValues, kwargs));
let outputMask = null;
if (srcLayer.supportsMasking) {
outputMask = srcLayer.computeMask(inputValues, inputMasks);
}
const layerOutputs = getNodeOutputs(symbolic);
const outputSymbolicTensors = Array.isArray(layerOutputs) ? layerOutputs : [layerOutputs];
for (let i = 0; i < outputSymbolicTensors.length; ++i) {
if (!internalFeedDict.hasKey(outputSymbolicTensors[i])) {
internalFeedDict.add(outputSymbolicTensors[i], outputTensors[i], Array.isArray(outputMask) ? outputMask[0] : outputMask);
}
const index = outputNames.indexOf(outputSymbolicTensors[i].name);
if (index !== -1) {
finalOutputs[index] = outputTensors[i];
}
}
if (!training) {
dispose(tensorsToDispose);
}
}
internalFeedDict.disposeMasks();
return arrayFetches ? finalOutputs : finalOutputs[0];
}
function getTopologicalSortAndRecipientCounts(fetches, feedDict) {
assert$1(fetches != null && fetches.length > 0, () => `Expected at least one fetch, got none`);
let finalSorted = [];
let finalRecipientMap = {};
if (fetches.length === 1) {
const out = getTopologicalSortAndRecipientCountsForOneFetch(fetches[0], feedDict);
finalSorted = out.sorted;
finalRecipientMap = out.recipientMap;
}
else {
const visited = new Set();
for (const fetch of fetches) {
const { sorted, recipientMap } = getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict);
for (const symbolicTensor of sorted) {
if (!visited.has(symbolicTensor.name)) {
finalSorted.push(symbolicTensor);
visited.add(symbolicTensor.name);
}
}
for (const name in recipientMap) {
if (finalRecipientMap[name] == null) {
finalRecipientMap[name] = new Set();
}
recipientMap[name].forEach(recipient => finalRecipientMap[name].add(recipient));
}
}
}
return {
sorted: finalSorted,
recipientCounts: recipientMap2Counts(finalRecipientMap)
};
}
function recipientMap2Counts(recipientMap) {
const recipientCounts = {};
for (const name in recipientMap) {
recipientCounts[name] = recipientMap[name].size;
}
return recipientCounts;
}
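/**
 * Iterative depth-first topological sort for a single fetch: tensors already
 * present in the feed dict count as visited, a mark stack detects when all of
 * a tensor's inputs have been emitted, and `recipientMap` records which
 * downstream tensors consume each input so intermediates can be freed later.
 */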
function getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict) {
const visited = new Set();
const sorted = [];
const recipientMap = {};
for (const key of feedDict.names()) {
visited.add(key);
}
const stack = [];
const marks = [];
stack.push(fetch);
while (stack.length > 0) {
const top = stack[stack.length - 1];
if (visited.has(top.name)) {
stack.pop();
continue;
}
const topIsMarked = marks[marks.length - 1] === stack.length - 1;
if (top.inputs.length === 0 || topIsMarked) {
stack.pop();
sorted.push(top);
visited.add(top.name);
if (topIsMarked) {
marks.pop();
}
}
else {
marks.push(stack.length - 1);
for (const input of top.inputs) {
if (recipientMap[input.name] == null) {
recipientMap[input.name] = new Set();
}
recipientMap[input.name].add(top.name);
if (visited.has(input.name)) {
continue;
}
stack.push(input);
}
}
}
return { sorted, recipientMap };
}
function getNodeOutputs(fetch) {
let layerOutputs;
if (fetch.sourceLayer.inboundNodes.length === 1) {
layerOutputs = fetch.sourceLayer.output;
}
else {
let nodeIndex = null;
for (let i = 0; i < fetch.sourceLayer.inboundNodes.length; ++i) {
for (const outputTensor of fetch.sourceLayer.inboundNodes[i]
.outputTensors) {
if (outputTensor.id === fetch.id) {
nodeIndex = i;
break;
}
}
}
layerOutputs = fetch.sourceLayer.getOutputAt(nodeIndex);
}
return layerOutputs;
}
function calcL2Norms(w, axis) {
return tidy(() => sqrt$2(sum$2(mul(w, w), axis, true)));
}
class Constraint extends Serializable {
getConfig() {
return {};
}
}
class MaxNorm extends Constraint {
constructor(args) {
super();
this.defaultMaxValue = 2;
this.defaultAxis = 0;
this.maxValue =
args.maxValue != null ? args.maxValue : this.defaultMaxValue;
this.axis = args.axis != null ? args.axis : this.defaultAxis;
}
apply(w) {
return tidy(() => {
const norms = calcL2Norms(w, this.axis);
const desired = clipByValue$2(norms, 0, this.maxValue);
return mul(w, div$1(desired, add$1(epsilon(), norms)));
});
}
getConfig() {
return { maxValue: this.maxValue, axis: this.axis };
}
}
MaxNorm.className = 'MaxNorm';
registerClass(MaxNorm);
class UnitNorm extends Constraint {
constructor(args) {
super();
this.defaultAxis = 0;
this.axis = args.axis != null ? args.axis : this.defaultAxis;
}
apply(w) {
return tidy(() => div$1(w, add$1(epsilon(), calcL2Norms(w, this.axis))));
}
getConfig() {
return { axis: this.axis };
}
}
UnitNorm.className = 'UnitNorm';
registerClass(UnitNorm);
class NonNeg extends Constraint {
apply(w) {
return relu$2(w);
}
}
NonNeg.className = 'NonNeg';
registerClass(NonNeg);
class MinMaxNorm extends Constraint {
constructor(args) {
super();
this.defaultMinValue = 0.0;
this.defaultMaxValue = 1.0;
this.defaultRate = 1.0;
this.defaultAxis = 0;
this.minValue =
args.minValue != null ? args.minValue : this.defaultMinValue;
this.maxValue =
args.maxValue != null ? args.maxValue : this.defaultMaxValue;
this.rate = args.rate != null ? args.rate : this.defaultRate;
this.axis = args.axis != null ? args.axis : this.defaultAxis;
}
apply(w) {
return tidy(() => {
const norms = calcL2Norms(w, this.axis);
const desired = add$1(mul(this.rate, clipByValue$2(norms, this.minValue, this.maxValue)), mul(1.0 - this.rate, norms));
return mul(w, div$1(desired, add$1(epsilon(), norms)));
});
}
getConfig() {
return {
minValue: this.minValue,
maxValue: this.maxValue,
rate: this.rate,
axis: this.axis
};
}
}
MinMaxNorm.className = 'MinMaxNorm';
registerClass(MinMaxNorm);
const CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
'maxNorm': 'MaxNorm',
'minMaxNorm': 'MinMaxNorm',
'nonNeg': 'NonNeg',
'unitNorm': 'UnitNorm'
};
function serializeConstraint(constraint) {
return serializeKerasObject(constraint);
}
function deserializeConstraint(config, customObjects = {}) {
return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'constraint');
}
function getConstraint(identifier) {
if (identifier == null) {
return null;
}
if (typeof identifier === 'string') {
const className = identifier in CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
identifier;
const config = { className, config: {} };
return deserializeConstraint(config);
}
else if (identifier instanceof Constraint) {
return identifier;
}
else {
return deserializeConstraint(identifier);
}
}
function glorotUniform(args) {
return new GlorotUniform(args);
}
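// Replaces any scalar Tensors in a logs object with their numeric values
// (awaiting the data() downloads in parallel) and disposes the original scalars.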
async function resolveScalarsInLogs(logs) {
if (logs == null) {
return;
}
const promises = [];
const keys = [];
const scalarsToDispose = [];
for (const key in logs) {
const value = logs[key];
if (typeof value !== 'number') {
const valueScalar = value;
promises.push(valueScalar.data());
keys.push(key);
scalarsToDispose.push(valueScalar);
}
}
if (promises.length > 0) {
const values = await Promise.all(promises);
for (let i = 0; i < values.length; ++i) {
logs[keys[i]] = values[i][0];
}
dispose(scalarsToDispose);
}
}
function disposeTensorsInLogs(logs) {
if (logs == null) {
return;
}
for (const key in logs) {
const value = logs[key];
if (typeof value !== 'number') {
value.dispose();
}
}
}
var ModelLoggingVerbosity;
(function (ModelLoggingVerbosity) {
ModelLoggingVerbosity[ModelLoggingVerbosity["SILENT"] = 0] = "SILENT";
ModelLoggingVerbosity[ModelLoggingVerbosity["VERBOSE"] = 1] = "VERBOSE";
})(ModelLoggingVerbosity || (ModelLoggingVerbosity = {}));
const DEFAULT_YIELD_EVERY_MS = 125;
class BaseCallback {
constructor() {
this.validationData = null;
}
setParams(params) {
this.params = params;
}
async onEpochBegin(epoch, logs) { }
async onEpochEnd(epoch, logs) { }
async onBatchBegin(batch, logs) { }
async onBatchEnd(batch, logs) { }
async onTrainBegin(logs) { }
async onTrainEnd(logs) { }
setModel(model) {
}
}
class CallbackList {
constructor(callbacks, queueLength = 10) {
if (callbacks == null) {
callbacks = [];
}
this.callbacks = callbacks;
this.queueLength = queueLength;
}
append(callback) {
this.callbacks.push(callback);
}
setParams(params) {
for (const callback of this.callbacks) {
callback.setParams(params);
}
}
setModel(model) {
for (const callback of this.callbacks) {
callback.setModel(model);
}
}
async onEpochBegin(epoch, logs) {
if (logs == null) {
logs = {};
}
for (const callback of this.callbacks) {
await callback.onEpochBegin(epoch, logs);
}
}
async onEpochEnd(epoch, logs) {
if (logs == null) {
logs = {};
}
for (const callback of this.callbacks) {
await callback.onEpochEnd(epoch, logs);
}
}
async onBatchBegin(batch, logs) {
if (logs == null) {
logs = {};
}
for (const callback of this.callbacks) {
await callback.onBatchBegin(batch, logs);
}
}
async onBatchEnd(batch, logs) {
if (logs == null) {
logs = {};
}
for (const callback of this.callbacks) {
await callback.onBatchEnd(batch, logs);
}
}
async onTrainBegin(logs) {
if (logs == null) {
logs = {};
}
for (const callback of this.callbacks) {
await callback.onTrainBegin(logs);
}
}
async onTrainEnd(logs) {
if (logs == null) {
logs = {};
}
for (const callback of this.callbacks) {
await callback.onTrainEnd(logs);
}
}
}
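// Accumulates per-batch metric values weighted by batch size ('size' in the
// batch logs) and writes their averages into the epoch-end logs.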
class BaseLogger extends BaseCallback {
constructor() {
super();
}
async onEpochBegin(epoch) {
this.seen = 0;
this.totals = {};
}
async onBatchEnd(batch, logs) {
if (logs == null) {
logs = {};
}
const batchSize = logs['size'] == null ? 0 : logs['size'];
this.seen += batchSize;
for (const key in logs) {
const value = logs[key];
if (typeof value === 'number') {
if (!this.totals.hasOwnProperty(key)) {
this.totals[key] = 0;
}
this.totals[key] = this.totals[key] + value * batchSize;
}
else {
let oldTotalsToDispose;
if (key in this.totals) {
oldTotalsToDispose = this.totals[key];
}
else {
this.totals[key] = 0;
}
const total = tidy(() => add$1((this.totals[key]), mul(value, batchSize)));
this.totals[key] = total;
if (oldTotalsToDispose != null) {
oldTotalsToDispose.dispose();
}
}
}
}
async onEpochEnd(epoch, logs) {
if (logs != null) {
for (const key of this.params['metrics']) {
if (this.totals[key] == null) {
continue;
}
if (typeof this.totals[key] === 'number') {
logs[key] = this.totals[key] / this.seen;
}
else {
tidy(() => {
const log = mul(div$1(1, this.seen), this.totals[key]);
logs[key] = log;
this.totals[key].dispose();
keep(logs[key]);
});
}
}
}
}
}
class History extends BaseCallback {
async onTrainBegin(logs) {
this.epoch = [];
this.history = {};
}
async onEpochEnd(epoch, logs) {
if (logs == null) {
logs = {};
}
this.epoch.push(epoch);
for (const key in logs) {
if (this.history[key] == null) {
this.history[key] = [];
}
this.history[key].push(logs[key]);
}
}
async syncData() {
const promises = [];
const keys = [];
const indices = [];
for (const key in this.history) {
const valueArray = this.history[key];
for (let i = 0; i < valueArray.length; ++i) {
if (typeof valueArray[i] !== 'number') {
const valueScalar = valueArray[i];
promises.push(valueScalar.data());
keys.push(key);
indices.push(i);
}
}
}
const values = await Promise.all(promises);
for (let n = 0; n < values.length; ++n) {
const tensorToDispose = this.history[keys[n]][indices[n]];
tensorToDispose.dispose();
this.history[keys[n]][indices[n]] = values[n][0];
}
}
}
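// Adapts user-provided callback functions (onEpochEnd, onBatchEnd, onYield, ...)
// and optionally yields to the event loop every `yieldEvery` ms, batch, or epoch.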
class CustomCallback extends BaseCallback {
constructor(args, yieldEvery) {
super();
this.currentEpoch = 0;
this.nowFunc = args.nowFunc;
this.nextFrameFunc = args.nextFrameFunc || nextFrame;
this.yieldEvery = yieldEvery || 'auto';
if (this.yieldEvery === 'auto') {
this.yieldEvery = DEFAULT_YIELD_EVERY_MS;
}
if (this.yieldEvery === 'never' && args.onYield != null) {
throw new Error('yieldEvery is `never` but you provided an `onYield` callback. ' +
'Either change `yieldEvery` or remove the callback');
}
if (isNumber(this.yieldEvery)) {
this.maybeWait = debounce(this.maybeWait.bind(this), this.yieldEvery, this.nowFunc);
}
this.trainBegin = args.onTrainBegin;
this.trainEnd = args.onTrainEnd;
this.epochBegin = args.onEpochBegin;
this.epochEnd = args.onEpochEnd;
this.batchBegin = args.onBatchBegin;
this.batchEnd = args.onBatchEnd;
this.yield = args.onYield;
}
async maybeWait(epoch, batch, logs) {
const ps = [];
if (this.yield != null) {
await resolveScalarsInLogs(logs);
ps.push(this.yield(epoch, batch, logs));
}
ps.push(this.nextFrameFunc());
await Promise.all(ps);
}
async onEpochBegin(epoch, logs) {
this.currentEpoch = epoch;
if (this.epochBegin != null) {
await resolveScalarsInLogs(logs);
await this.epochBegin(epoch, logs);
}
}
async onEpochEnd(epoch, logs) {
const ps = [];
if (this.epochEnd != null) {
await resolveScalarsInLogs(logs);
ps.push(this.epochEnd(epoch, logs));
}
if (this.yieldEvery === 'epoch') {
ps.push(this.nextFrameFunc());
}
await Promise.all(ps);
}
async onBatchBegin(batch, logs) {
if (this.batchBegin != null) {
await resolveScalarsInLogs(logs);
await this.batchBegin(batch, logs);
}
}
async onBatchEnd(batch, logs) {
const ps = [];
if (this.batchEnd != null) {
await resolveScalarsInLogs(logs);
ps.push(this.batchEnd(batch, logs));
}
if (this.yieldEvery === 'batch') {
ps.push(this.nextFrameFunc());
}
else if (isNumber(this.yieldEvery)) {
ps.push(this.maybeWait(this.currentEpoch, batch, logs));
}
await Promise.all(ps);
}
async onTrainBegin(logs) {
if (this.trainBegin != null) {
await resolveScalarsInLogs(logs);
await this.trainBegin(logs);
}
}
async onTrainEnd(logs) {
if (this.trainEnd != null) {
await resolveScalarsInLogs(logs);
await this.trainEnd(logs);
}
}
}
function standardizeCallbacks(callbacks, yieldEvery) {
if (callbacks == null) {
callbacks = {};
}
if (callbacks instanceof BaseCallback) {
return [callbacks];
}
if (Array.isArray(callbacks) && callbacks[0] instanceof BaseCallback) {
return callbacks;
}
const callbackConfigs = toList(callbacks);
return callbackConfigs.map(callbackConfig => new CustomCallback(callbackConfig, yieldEvery));
}
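// Registry of callback constructors keyed by verbosity level; createCallbacks()
// instantiates every constructor registered at or below the requested level.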
class CallbackConstructorRegistry {
constructor() { }
static registerCallbackConstructor(verbosityLevel, callbackConstructor) {
assert$1(verbosityLevel >= 0 && Number.isInteger(verbosityLevel), () => `Verbosity level is expected to be an integer >= 0, ` +
`but got ${verbosityLevel}`);
CallbackConstructorRegistry.checkForDuplicate(callbackConstructor);
if (CallbackConstructorRegistry.constructors[verbosityLevel] == null) {
CallbackConstructorRegistry.constructors[verbosityLevel] = [];
}
CallbackConstructorRegistry.constructors[verbosityLevel].push(callbackConstructor);
}
static checkForDuplicate(callbackConstructor) {
for (const levelName in CallbackConstructorRegistry.constructors) {
const constructors = CallbackConstructorRegistry.constructors[+levelName];
constructors.forEach(ctor => {
if (ctor === callbackConstructor) {
throw new ValueError('Duplicate callback constructor.');
}
});
}
}
static clear() {
CallbackConstructorRegistry.constructors = {};
}
static createCallbacks(verbosityLevel) {
const constructors = [];
for (const levelName in CallbackConstructorRegistry.constructors) {
const level = +levelName;
if (verbosityLevel >= level) {
constructors.push(...CallbackConstructorRegistry.constructors[level]);
}
}
return constructors.map(ctor => new ctor());
}
}
CallbackConstructorRegistry.constructors = {};
function configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics) {
const history = new History();
const actualCallbacks = [
new BaseLogger(), ...CallbackConstructorRegistry.createCallbacks(verbose)
];
if (callbacks != null) {
actualCallbacks.push(...callbacks);
}
actualCallbacks.push(history);
const callbackList = new CallbackList(actualCallbacks);
callbackList.setParams({
epochs,
initialEpoch,
samples: numTrainSamples,
steps: stepsPerEpoch,
batchSize,
verbose,
doValidation,
metrics: callbackMetrics,
});
return { callbackList, history };
}
function deserialize(config, customObjects = {}, fastWeightInit = false) {
return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'layer', fastWeightInit);
}
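// L2-normalizes `x` along `axis`, clamping the squared sum with epsilon() to
// avoid division by zero.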
function l2Normalize(x, axis) {
return tidy(() => {
if (x.dtype !== 'float32') {
x = cast$3(x, 'float32');
}
const squareSum = sum$2(square(x), axis, true);
const epsilonTensor = fill$2(squareSum.shape, epsilon());
const norm = sqrt$2(maximum$2(squareSum, epsilonTensor));
return div$1(x, norm);
});
}
function meanSquaredError(yTrue, yPred) {
return tidy(() => mean$1(square(sub$2(yPred, yTrue)), -1));
}
function meanAbsoluteError(yTrue, yPred) {
return tidy(() => mean$1(abs$2(sub$2(yPred, yTrue)), -1));
}
function meanAbsolutePercentageError(yTrue, yPred) {
return tidy(() => {
const diff = sub$2(yTrue, yPred);
const clippedTrue = clipByValue$2(abs$2(yTrue), epsilon(), Number.MAX_VALUE);
const absResult = abs$2(div$1(diff, clippedTrue));
return mul(100, mean$1(absResult, -1));
});
}
function meanSquaredLogarithmicError(yTrue, yPred) {
return tidy(() => {
const clippedPred = clipByValue$2(yPred, epsilon(), Number.MAX_VALUE);
const firstLog = log$2(add$1(1, clippedPred));
const clippedTrue = clipByValue$2(yTrue, epsilon(), Number.MAX_VALUE);
const secondLog = log$2(add$1(1, clippedTrue));
return mean$1(square(sub$2(firstLog, secondLog)), -1);
});
}
function squaredHinge(yTrue, yPred) {
return tidy(() => {
const maxResult = maximum$2(0, sub$2(1, mul(yTrue, yPred)));
return mean$1(square(maxResult), -1);
});
}
function hinge(yTrue, yPred) {
return tidy(() => {
const maxResult = maximum$2(0, sub$2(1, mul(yTrue, yPred)));
return mean$1(maxResult, -1);
});
}
function categoricalHinge(yTrue, yPred) {
return tidy(() => {
const pos = sum$2(mul(yTrue, yPred), -1);
const neg = max$2(mul(sub$2(1, yTrue), yPred), -1);
return maximum$2(0, add$1(1, sub$2(neg, pos)));
});
}
function logcosh(yTrue, yPred) {
return tidy(() => {
const log2 = Math.log(2);
const predictionDiff = sub$2(yPred, yTrue);
const logcoshResult = sub$2(add$1(predictionDiff, softplus$2(mul(-2, predictionDiff))), log2);
return mean$1(logcoshResult, -1);
});
}
function categoricalCrossentropy$1(target, output, fromLogits = false) {
return tidy(() => {
if (fromLogits) {
output = softmax$2(output);
}
else {
const outputSum = sum$2(output, output.shape.length - 1, true);
output = div$1(output, outputSum);
}
output = clipByValue$2(output, epsilon(), 1 - epsilon());
return neg$2(sum$2(mul(cast$3(target, 'float32'), log$2(output)), output.shape.length - 1));
});
}
function sparseCategoricalCrossentropy$1(target, output, fromLogits = false) {
return tidy(() => {
const flatTarget = cast$3(floor$2(flatten(target)), 'int32');
output = clipByValue$2(output, epsilon(), 1 - epsilon());
const outputShape = output.shape;
const oneHotTarget = reshape$2(oneHot$2(flatTarget, outputShape[outputShape.length - 1]), outputShape);
return categoricalCrossentropy$1(oneHotTarget, output, fromLogits);
});
}
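// Numerically stable sigmoid cross-entropy on logits:
// max(x, 0) - x * z + log(1 + exp(-|x|)).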
function sigmoidCrossEntropyWithLogits(labels, logits) {
if (!arraysEqual(labels.shape, logits.shape)) {
throw new ValueError(`logits and labels must have the same shape, but got shapes ` +
`${JSON.stringify(labels.shape)} and ${JSON.stringify(logits.shape)}`);
}
return tidy(() => {
const reluLogits = relu$2(logits);
const negAbsLogits = neg$2(abs$2(logits));
return add$1(sub$2(reluLogits, mul(logits, labels)), log1p$2(exp$2(negAbsLogits)));
});
}
function binaryCrossentropy$1(yTrue, yPred) {
return tidy(() => {
let y;
y = clipByValue$2(yPred, epsilon(), 1 - epsilon());
y = log$2(div$1(y, sub$2(1, y)));
return mean$1(sigmoidCrossEntropyWithLogits(yTrue, y), -1);
});
}
function kullbackLeiblerDivergence(yTrue, yPred) {
return tidy(() => {
const clippedTrue = clipByValue$2(yTrue, epsilon(), 1);
const clippedPred = clipByValue$2(yPred, epsilon(), 1);
return sum$2(mul(yTrue, log$2(div$1(clippedTrue, clippedPred))), -1);
});
}
function poisson(yTrue, yPred) {
return tidy(() => {
const logPred = log$2(add$1(epsilon(), yPred));
return mean$1(sub$2(yPred, mul(yTrue, logPred)), -1);
});
}
function cosineProximity(yTrue, yPred) {
return tidy(() => {
const trueNormalized = l2Normalize(yTrue, -1);
const predNormalized = l2Normalize(yPred, -1);
const trueXPred = mul(trueNormalized, predNormalized);
return neg$2(sum$2(trueXPred, -1));
});
}
const lossesMap = {
meanSquaredError,
meanAbsoluteError,
meanAbsolutePercentageError,
meanSquaredLogarithmicError,
squaredHinge,
hinge,
categoricalHinge,
logcosh,
categoricalCrossentropy: categoricalCrossentropy$1,
sparseCategoricalCrossentropy: sparseCategoricalCrossentropy$1,
binaryCrossentropy: binaryCrossentropy$1,
kullbackLeiblerDivergence,
poisson,
cosineProximity
};
function get$1(identifierOrFn) {
if (typeof identifierOrFn === 'string') {
if (identifierOrFn in lossesMap) {
return lossesMap[identifierOrFn];
}
let errMsg = `Unknown loss ${identifierOrFn}`;
if (identifierOrFn.toLowerCase().includes('softmaxcrossentropy')) {
errMsg = `Unknown loss ${identifierOrFn}. ` +
'Use "categoricalCrossentropy" as the string name for ' +
'tf.losses.softmaxCrossEntropy';
}
throw new ValueError(errMsg);
}
else {
return identifierOrFn;
}
}
function binaryAccuracy(yTrue, yPred) {
return tidy(() => {
const threshold = mul(.5, onesLike$2(yPred));
const yPredThresholded = cast(greater$2(yPred, threshold), yTrue.dtype);
return mean$1(equal$2(yTrue, yPredThresholded), -1);
});
}
function categoricalAccuracy(yTrue, yPred) {
return tidy(() => cast(equal$2(argMax$2(yTrue, -1), argMax$2(yPred, -1)), 'float32'));
}
function truePositives(yTrue, yPred) {
return tidy(() => {
return cast$3(sum$2(logicalAnd$2(equal$2(yTrue, 1), equal$2(yPred, 1))), 'float32');
});
}
function falsePositives(yTrue, yPred) {
return tidy(() => {
return cast$3(sum$2(logicalAnd$2(equal$2(yTrue, 0), equal$2(yPred, 1))), 'float32');
});
}
function precision(yTrue, yPred) {
return tidy(() => {
const tp = truePositives(yTrue, yPred);
const fp = falsePositives(yTrue, yPred);
const denominator = add$1(tp, fp);
return cast$3(where(greater$2(denominator, 0), div$1(tp, denominator), 0), 'float32');
});
}
function binaryCrossentropy(yTrue, yPred) {
return binaryCrossentropy$1(yTrue, yPred);
}
function sparseCategoricalAccuracy(yTrue, yPred) {
if (yTrue.rank === yPred.rank) {
yTrue = squeeze(yTrue, [yTrue.rank - 1]);
}
yPred = argMax$2(yPred, -1);
if (yPred.dtype !== yTrue.dtype) {
yPred = cast$3(yPred, yTrue.dtype);
}
return cast$3(equal$2(yTrue, yPred), 'float32');
}
const mse = meanSquaredError;
const MSE = meanSquaredError;
const mae = meanAbsoluteError;
const MAE = meanAbsoluteError;
const mape = meanAbsolutePercentageError;
const MAPE = meanAbsolutePercentageError;
const categoricalCrossentropy = categoricalCrossentropy$1;
const cosine = cosineProximity;
const sparseCategoricalCrossentropy = sparseCategoricalCrossentropy$1;
const metricsMap = {
binaryAccuracy,
categoricalAccuracy,
precision,
categoricalCrossentropy,
sparseCategoricalCrossentropy,
mse,
MSE,
mae,
MAE,
mape,
MAPE,
cosine
};
function get(identifier) {
if (typeof identifier === 'string' && identifier in metricsMap) {
return metricsMap[identifier];
}
else if (typeof identifier !== 'string' && identifier != null) {
return identifier;
}
else {
throw new ValueError(`Unknown metric ${identifier}`);
}
}
function getLossOrMetricName(fn) {
assert(fn !== null, `Unknown LossOrMetricFn ${fn}`);
if (typeof fn === 'string') {
return fn;
}
else {
let fnName;
for (const key of Object.keys(lossesMap)) {
if (lossesMap[key] === fn) {
fnName = key;
break;
}
}
if (fnName !== undefined) {
return fnName;
}
for (const key of Object.keys(metricsMap)) {
if (metricsMap[key] === fn) {
fnName = key;
break;
}
}
if (fnName !== undefined) {
return fnName;
}
return fn.name;
}
}
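// Maps a string identifier (e.g. 'adam', 'rmsprop', 'sgd') to a freshly
// constructed optimizer with fixed default hyperparameters.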
function getOptimizer(identifier) {
const optimizerMap = {
'Adagrad': () => train.adagrad(0.01),
'Adadelta': () => train.adadelta(1, 0.95, epsilon()),
'Adam': () => train.adam(0.001, 0.9, 0.999, epsilon()),
'Adamax': () => train.adamax(0.002, 0.9, 0.999, epsilon(), 0),
'RMSProp': () => train.rmsprop(0.001, 0.9, 0, epsilon()),
'SGD': () => train.sgd(0.01)
};
optimizerMap['adagrad'] = optimizerMap['Adagrad'];
optimizerMap['adadelta'] = optimizerMap['Adadelta'];
optimizerMap['adam'] = optimizerMap['Adam'];
optimizerMap['adamax'] = optimizerMap['Adamax'];
optimizerMap['rmsprop'] = optimizerMap['RMSProp'];
optimizerMap['sgd'] = optimizerMap['SGD'];
if (identifier in optimizerMap) {
return optimizerMap[identifier]();
}
throw new ValueError(`Unknown Optimizer ${identifier}`);
}
const MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH = 1 * 1024 * 1024;
function checkUserDefinedMetadata(userDefinedMetadata, modelName, checkSize = false) {
if (userDefinedMetadata == null ||
typeof userDefinedMetadata !== 'object' ||
Object.getPrototypeOf(userDefinedMetadata) !== Object.prototype ||
!plainObjectCheck(userDefinedMetadata)) {
throw new Error('User-defined metadata is expected to be a JSON object, but is not.');
}
if (checkSize) {
const out = JSON.stringify(userDefinedMetadata);
if (out.length > MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH) {
console.warn(`User-defined metadata of model "${modelName}" is too large in ` +
`size (length=${out.length} when serialized). It is not ` +
`recommended to store such large objects in user-defined metadata. ` +
`Please make sure its serialized length is <= ` +
`${MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH}.`);
}
}
}
function plainObjectCheck(x) {
if (x === null) {
return true;
}
else if (typeof x === 'object') {
if (Object.getPrototypeOf(x) === Object.prototype) {
const keys = Object.keys(x);
for (const key of keys) {
if (typeof key !== 'string') {
return false;
}
if (!plainObjectCheck(x[key])) {
return false;
}
}
return true;
}
else {
if (Array.isArray(x)) {
for (const item of x) {
if (!plainObjectCheck(item)) {
return false;
}
}
return true;
}
else {
return false;
}
}
}
else {
const xType = typeof x;
return xType === 'string' || xType === 'number' || xType === 'boolean';
}
}
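// Prints a Keras-style model summary table; non-sequential (graph) models get
// an extra "Receives inputs" column listing each layer's inbound connections.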
function printSummary(model, lineLength, positions,
printFn = console.log) {
const sequentialLike = isModelSequentialLike(model);
    const toDisplay = ['Layer (type)', 'Input Shape', 'Output Shape', 'Param #'];
if (sequentialLike) {
lineLength = lineLength || 90;
positions = positions || [0.32, 0.61, 0.89, 1];
}
else {
lineLength = lineLength || 115;
positions = positions || [0.24, 0.48, 0.70, 0.80, 1];
}
if (positions[positions.length - 1] <= 1) {
positions = positions.map(p => Math.floor(lineLength * p));
}
let relevantNodes;
if (!sequentialLike) {
toDisplay.push('Receives inputs');
relevantNodes = [];
for (const depth in model.nodesByDepth) {
relevantNodes.push(...model.nodesByDepth[depth]);
}
}
printFn('_'.repeat(lineLength));
printRow(toDisplay, positions, printFn);
printFn('='.repeat(lineLength));
const layers = model.layers;
for (let i = 0; i < layers.length; ++i) {
if (sequentialLike) {
printLayerSummary(layers[i], positions, printFn);
}
else {
printLayerSummaryWithConnections(layers[i], positions, relevantNodes, printFn);
}
printFn((i === layers.length - 1 ? '=' : '_').repeat(lineLength));
}
model.checkTrainableWeightsConsistency();
const trainableCount = countTrainableParams(model);
const nonTrainableCount = countParamsInWeights(model.nonTrainableWeights);
printFn(`Total params: ${trainableCount + nonTrainableCount}`);
printFn(`Trainable params: ${trainableCount}`);
printFn(`Non-trainable params: ${nonTrainableCount}`);
printFn('_'.repeat(lineLength));
}
function countTrainableParams(model) {
let trainableCount;
if (model.collectedTrainableWeights != null) {
trainableCount =
countParamsInWeights(model.collectedTrainableWeights);
}
else {
trainableCount = countParamsInWeights(model.trainableWeights);
}
return trainableCount;
}
function isModelSequentialLike(model) {
let sequentialLike = true;
const nodesByDepth = [];
const nodes = [];
for (const depth in model.nodesByDepth) {
nodesByDepth.push(model.nodesByDepth[depth]);
}
for (const depthNodes of nodesByDepth) {
if (depthNodes.length > 1 ||
depthNodes.length === 1 && depthNodes[0].inboundLayers.length > 1) {
sequentialLike = false;
break;
}
nodes.push(...depthNodes);
}
if (sequentialLike) {
for (const layer of model.layers) {
let flag = false;
for (const node of layer.inboundNodes) {
if (nodes.indexOf(node) !== -1) {
if (flag) {
sequentialLike = false;
break;
}
else {
flag = true;
}
}
}
if (!sequentialLike) {
break;
}
}
}
return sequentialLike;
}
function printRow(fields, positions,
printFn = console.log) {
let line = '';
for (let i = 0; i < fields.length; ++i) {
if (i > 0) {
line = line.slice(0, line.length - 1) + ' ';
}
line += fields[i];
line = line.slice(0, positions[i]);
line += ' '.repeat(positions[i] - line.length);
}
printFn(line);
}
function printLayerSummary(layer, positions,
printFn) {
let outputShape;
let inputShape;
try {
inputShape = (layer.inboundNodes.map(x => JSON.stringify(x.inputShapes))).join(',');
}
catch (err) {
inputShape = 'multiple';
}
try {
outputShape = JSON.stringify(layer.outputShape);
}
catch (err) {
outputShape = 'multiple';
}
const name = layer.name;
const className = layer.getClassName();
const fields = [`${name} (${className})`, inputShape,
outputShape, layer.countParams().toString()];
printRow(fields, positions, printFn);
}
function printLayerSummaryWithConnections(layer, positions, relevantNodes,
printFn) {
let outputShape;
let inputShape;
try {
inputShape = (layer.inboundNodes.map(x => JSON.stringify(x.inputShapes))).join(',');
}
catch (err) {
inputShape = 'multiple';
}
try {
outputShape = JSON.stringify(layer.outputShape);
}
catch (err) {
outputShape = 'multiple';
}
const connections = [];
for (const node of layer.inboundNodes) {
if (relevantNodes != null && relevantNodes.length > 0 &&
relevantNodes.indexOf(node) === -1) {
continue;
}
for (let i = 0; i < node.inboundLayers.length; ++i) {
const inboundLayer = node.inboundLayers[i].name;
const inboundLayerIndex = node.nodeIndices[i];
const inboundTensorIndex = node.tensorIndices[i];
connections.push(`${inboundLayer}[${inboundLayerIndex}][${inboundTensorIndex}]`);
}
}
const name = layer.name;
const className = layer.getClassName();
const firstConnection = connections.length === 0 ? '' : connections[0];
const fields = [
`${name} (${className})`, inputShape,
outputShape, layer.countParams().toString(),
firstConnection
];
printRow(fields, positions, printFn);
for (let i = 1; i < connections.length; ++i) {
printRow(['', '', '', '', connections[i]], positions, printFn);
}
}
function isArrayItemInputOrOutputName(key, index, value) {
return (key === 'inboundNodes' || key === 'outputLayers' ||
key === 'inputLayers') &&
index === 0 && typeof value === 'string';
}
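// Recursively converts a Python-Keras config (snake_case) to the JS convention
// (camelCase), leaving layer/tensor name strings untouched.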
function convertPythonicToTs(pythonicConfig, key) {
if (pythonicConfig === null) {
return null;
}
else if (typeof pythonicConfig === 'string') {
return toCamelCase(pythonicConfig);
}
else if ((typeof pythonicConfig === 'number') ||
(typeof pythonicConfig === 'boolean')) {
return pythonicConfig;
}
else if (pythonicConfig instanceof Array) {
const tsArray = [];
const arrayLength = pythonicConfig.length;
for (let i = 0; i < arrayLength; ++i) {
const item = pythonicConfig[i];
if (isArrayItemInputOrOutputName(key, i, item)) {
tsArray.push(item);
}
else {
tsArray.push(convertPythonicToTs(item, key));
}
}
return tsArray;
}
else {
const tsDict = {};
for (const pythonicKey of Object.keys(pythonicConfig)) {
const pythonicValue = pythonicConfig[pythonicKey];
if (pythonicKey === 'name' && typeof pythonicValue === 'string') {
tsDict[pythonicKey] = pythonicValue;
}
else {
const tsKey = toCamelCase(pythonicKey);
tsDict[tsKey] = convertPythonicToTs(pythonicValue, tsKey);
}
}
return tsDict;
}
}
function convertTsToPythonic(tsConfig, key) {
if (tsConfig === null || tsConfig === undefined) {
return null;
}
else if (typeof tsConfig === 'string') {
return toSnakeCase(tsConfig);
}
else if ((typeof tsConfig === 'number') || (typeof tsConfig === 'boolean')) {
return tsConfig;
}
else if (tsConfig instanceof Array) {
const pyArray = [];
const arrayLength = tsConfig.length;
for (let i = 0; i < arrayLength; ++i) {
const item = tsConfig[i];
if (isArrayItemInputOrOutputName(key, i, item)) {
pyArray.push(item);
}
else {
pyArray.push(convertTsToPythonic(item, key));
}
}
return pyArray;
}
else {
const pyDict = {};
for (const tsKey of Object.keys(tsConfig)) {
const tsValue = tsConfig[tsKey];
const pyKey = toSnakeCase(tsKey);
if ((tsKey === 'name' || tsKey === 'className') &&
typeof tsValue === 'string') {
pyDict[pyKey] = tsValue;
}
else {
pyDict[pyKey] = convertTsToPythonic(tsValue, tsKey);
}
}
return pyDict;
}
}
const version = '4.22.0';
const isKerasSavedModelFormat = (weights) => {
const keys = Object.keys(weights);
if (keys.length === 0) {
return false;
}
const key = keys[0].split('/');
return !isNaN(parseInt(key[key.length - 1], 10));
};
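// Container: a directed acyclic graph of layers with designated input and output
// tensors. The constructor walks the graph backwards from the outputs, assigns
// each node a depth, and derives a topologically sorted `layers` list.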
class Container extends Layer {
constructor(args) {
super({});
this.containerNodes = new Set();
this.name = args.name;
if (this.name == null) {
const prefix = this.getClassName().toLowerCase();
this.name = getUid(prefix);
}
this.supportsMasking = false;
this.trainable_ = true;
if (Array.isArray(args.inputs)) {
this.inputs = args.inputs.slice();
}
else {
this.inputs = [args.inputs];
}
if (Array.isArray(args.outputs)) {
this.outputs = args.outputs.slice();
}
else {
this.outputs = [args.outputs];
}
if (unique(this.inputs).length !== this.inputs.length) {
throw new ValueError('The list of inputs passed to the model is ' +
'redundant. All inputs should only appear once. Found: ' +
`${this.inputs.map(x => x.name)}`);
}
if (unique(this.outputs).length !== this.outputs.length) {
console.warn('The list of outputs passed to the model is redundant. ' +
'All outputs should only appear once. Found: ' +
`${this.outputs.map(x => x.name)}`);
}
this.inputLayers = [];
this.inputLayersNodeIndices = [];
this.inputLayersTensorIndices = [];
this.outputLayers = [];
this.outputLayersNodeIndices = [];
this.outputLayersTensorIndices = [];
this.layers = [];
this.internalContainerRefs = [];
for (const x of this.outputs) {
const layer = x.sourceLayer;
const nodeIndex = x.nodeIndex;
const tensorIndex = x.tensorIndex;
this.outputLayers.push(layer);
this.outputLayersNodeIndices.push(nodeIndex);
this.outputLayersTensorIndices.push(tensorIndex);
}
for (const x of this.inputs) {
const layer = x.sourceLayer;
const nodeIndex = x.nodeIndex;
const tensorIndex = x.tensorIndex;
assert(nodeIndex === 0, 'input layer has >1 nodes');
assert(tensorIndex === 0, 'input layer has >1 tensors');
this.inputLayers.push(layer);
this.inputLayersNodeIndices.push(nodeIndex);
this.inputLayersTensorIndices.push(tensorIndex);
}
this.inputNames = [];
this.outputNames = [];
this.feedInputShapes = [];
this.feedInputNames = [];
this.feedOutputNames = [];
for (let i = 0; i < this.inputLayers.length; i++) {
const layer = this.inputLayers[i];
if (!(layer instanceof InputLayer)) {
throw new TypeError('Input layers to a LayersModel must be InputLayer objects. ' +
`Received inputs: ${args.inputs}. ` +
`Input ${i} (0-based) originates ` +
`from layer type ${layer.getClassName()}.`);
}
this.inputNames.push(layer.name);
this.feedInputShapes.push(layer.batchInputShape);
this.feedInputNames.push(layer.name);
}
for (const layer of this.outputLayers) {
this.outputNames.push(layer.name);
}
this.internalInputShapes = this.inputs.map(x => x.shape);
this.internalOutputShapes = this.outputs.map(x => x.shape);
const nodesDepths = {};
const nodeIDToNode = {};
const layersDepths = {};
const layerIDToLayer = {};
const layerIndices = {};
const nodesInDecreasingDepth = [];
const buildMapOfGraph = (tensor, finishedNodes, nodesInProgress, layer, nodeIndex, tensorIndex) => {
if (layer == null || nodeIndex == null || tensorIndex == null) {
layer = tensor.sourceLayer;
nodeIndex = tensor.nodeIndex;
tensorIndex = tensor.tensorIndex;
}
const node = layer.inboundNodes[nodeIndex];
if (nodesInProgress.indexOf(node) !== -1) {
throw new RuntimeError(`The tensor ${tensor.name} at layer "${layer.name}" ` +
'is part of a cycle.');
}
if (finishedNodes.indexOf(node) !== -1) {
return;
}
this.containerNodes.add(Container.nodeKey(layer, nodeIndex));
if (!(layer.id in layerIndices)) {
layerIndices[layer.id] = Object.keys(layerIndices).length;
}
if (nodesInProgress.indexOf(node) === -1) {
nodesInProgress.push(node);
}
const numInboundLayers = node.inboundLayers.length;
for (let i = 0; i < numInboundLayers; i++) {
const x = node.inputTensors[i];
const layer = node.inboundLayers[i];
const nodeIndex = node.nodeIndices[i];
const tensorIndex = node.tensorIndices[i];
buildMapOfGraph(x, finishedNodes, nodesInProgress, layer, nodeIndex, tensorIndex);
}
finishedNodes.push(node);
while (nodesInProgress.indexOf(node) >= 0) {
nodesInProgress.splice(nodesInProgress.indexOf(node), 1);
}
nodesInDecreasingDepth.push(node);
};
const finishedNodes = [];
const nodesInProgress = [];
for (const x of this.outputs) {
buildMapOfGraph(x, finishedNodes, nodesInProgress);
}
const reversedNodesInDecreasingDepth = nodesInDecreasingDepth.slice().reverse();
for (const node of reversedNodesInDecreasingDepth) {
nodeIDToNode[node.id] = node;
if (!(node.id in nodesDepths)) {
nodesDepths[node.id] = 0;
}
let depth = nodesDepths[node.id];
const previousDepth = (layersDepths[node.outboundLayer.id] == null ?
0 :
layersDepths[node.outboundLayer.id]);
depth = Math.max(depth, previousDepth);
layersDepths[node.outboundLayer.id] = depth;
layerIDToLayer[node.outboundLayer.id] = node.outboundLayer;
nodesDepths[node.id] = depth;
for (let i = 0; i < node.inboundLayers.length; i++) {
const inboundLayer = node.inboundLayers[i];
const nodeIndex = node.nodeIndices[i];
const inboundNode = inboundLayer.inboundNodes[nodeIndex];
const previousDepth = (nodesDepths[inboundNode.id] == null ? 0 :
nodesDepths[inboundNode.id]);
nodesDepths[inboundNode.id] = Math.max(depth + 1, previousDepth);
nodeIDToNode[inboundNode.id] = inboundNode;
}
}
const nodesByDepth = {};
for (const nodeID in nodesDepths) {
const depth = nodesDepths[nodeID];
if (!(depth in nodesByDepth)) {
nodesByDepth[depth] = [];
}
nodesByDepth[depth].push(nodeIDToNode[nodeID]);
}
const layersByDepth = {};
for (const layerID in layersDepths) {
const depth = layersDepths[layerID];
if (!(depth in layersByDepth)) {
layersByDepth[depth] = [];
}
layersByDepth[depth].push(layerIDToLayer[layerID]);
}
let depthKeys = Object.keys(layersByDepth)
.map(x => parseInt(x, 10))
.sort(reverseNumberCompare);
this.layers = [];
for (const depth of depthKeys) {
const layersForDepth = layersByDepth[depth];
layersForDepth.sort((a, b) => {
const aIndex = layerIndices[a.id];
const bIndex = layerIndices[b.id];
if (aIndex < bIndex) {
return -1;
}
if (aIndex > bIndex) {
return 1;
}
return 0;
});
for (const layer of layersForDepth) {
if (layer instanceof Container) {
this.internalContainerRefs.push(layer);
}
this.layers.push(layer);
}
}
this.layersByDepth = layersByDepth;
depthKeys = Object.keys(nodesByDepth)
.map(x => parseInt(x, 10))
.sort(reverseNumberCompare);
const computableTensors = this.inputs.slice();
const layersWithCompleteInput = [];
for (const depth of depthKeys) {
for (const node of nodesByDepth[depth]) {
const layer = node.outboundLayer;
if (layer != null) {
for (const x of node.inputTensors) {
if (computableTensors.indexOf(x) === -1) {
throw new RuntimeError(`Graph disconnected: cannot obtain value for tensor ${x}` +
` at layer "${layer.name}". ` +
'The following previous layers were accessed without ' +
`issue: ${layersWithCompleteInput}`);
}
}
for (const x of node.outputTensors) {
computableTensors.push(x);
}
layersWithCompleteInput.push(layer.name);
}
}
}
this.nodesByDepth = nodesByDepth;
const allNames = this.layers.map(x => x.name);
for (const name of allNames) {
const numOccurrences = allNames.filter(x => x === name).length;
if (numOccurrences !== 1) {
throw new RuntimeError(`The name "${name}" is used ${numOccurrences} times ` +
'in the model. All layer names should be unique. Layer names: ' +
JSON.stringify(allNames));
}
}
this.outboundNodes = [];
this.inboundNodes = [];
new Node({
outboundLayer: this,
inboundLayers: [],
nodeIndices: [],
tensorIndices: [],
inputTensors: this.inputs,
outputTensors: this.outputs,
inputMasks: this.inputs.map(x => null),
outputMasks: this.outputs.map(x => null),
inputShapes: this.inputs.map(x => x.shape),
outputShapes: this.outputs.map(x => x.shape)
});
this.built = true;
this._refCount = 1;
}
assertNotDisposed() {
if (this._refCount === 0) {
throw new Error(`Container '${this.name}' is already disposed.`);
}
}
dispose() {
this.assertNotDisposed();
const result = { refCountAfterDispose: null, numDisposedVariables: 0 };
if (--this._refCount === 0) {
for (const layer of this.layers) {
result.numDisposedVariables += layer.dispose().numDisposedVariables;
}
for (const container of this.internalContainerRefs) {
result.numDisposedVariables += container.dispose().numDisposedVariables;
}
}
result.refCountAfterDispose = this._refCount;
return result;
}
get trainable() {
return this.trainable_;
}
set trainable(trainable) {
this.layers.forEach(layer => {
layer._trainableWeights
.forEach(w => w.trainable = trainable);
});
this.trainable_ = trainable;
}
get trainableWeights() {
if (this._trainableWeights.length > 0) {
            throw new ValueError('Container instance unexpectedly contains _trainableWeights. ' +
                'The trainable weights of a Container are a union of the ' +
                'trainable weights of its constituent Layers. Its own ' +
                '_trainableWeights must remain an empty Array.');
}
if (!this.trainable) {
return [];
}
let weights = [];
for (const layer of this.layers) {
weights = weights.concat(layer.trainableWeights);
}
return weights;
}
get nonTrainableWeights() {
const weights = [];
for (const layer of this.layers) {
weights.push(...layer.nonTrainableWeights);
}
if (!this.trainable) {
const trainableWeights = [];
for (const layer of this.layers) {
trainableWeights.push(...layer.trainableWeights);
}
return trainableWeights.concat(weights);
}
return weights;
}
get weights() {
return this.trainableWeights.concat(this.nonTrainableWeights);
}
loadWeights(weights, strict = true) {
const nameToWeight = {};
let totalWeightsCount = 0;
const modelIsKerasSavedModelFormat = isKerasSavedModelFormat(weights);
if (modelIsKerasSavedModelFormat) {
this.parseWeights(weights);
}
for (const layer of this.layers) {
for (const [index, weight] of layer.weights.entries()) {
const parsedName = modelIsKerasSavedModelFormat ?
`${weight.name.split('/').slice(0, -1).join('/') + '/'}${index}` :
weight.originalName;
if (nameToWeight[parsedName] != null) {
throw new ValueError(`Duplicate weight name: ${parsedName}`);
}
nameToWeight[parsedName] = weight;
totalWeightsCount++;
}
}
const weightValueTuples = [];
for (const name in weights) {
let validatedName = name;
if (nameToWeight[name] == null) {
const tokens = name.split('/');
const shortenNameArray = tokens.slice(0, -2).concat([tokens[tokens.length - 1]]);
validatedName = shortenNameArray.join('/');
}
if (nameToWeight[validatedName] != null) {
weightValueTuples.push([nameToWeight[validatedName], weights[name]]);
}
else if (strict) {
throw new ValueError(`Provided weight data has no target variable: ${name}`);
}
delete nameToWeight[validatedName];
}
if (strict) {
const unsetNames = [];
for (const name in nameToWeight) {
unsetNames.push(name);
}
if (unsetNames.length > 0) {
throw new ValueError(`${unsetNames.length} of ${totalWeightsCount} weights are not set: ` +
`${unsetNames}`);
}
}
batchSetValue(weightValueTuples);
}
parseWeights(weights) {
        for (const key of Object.keys(weights)) {
const listParts = key.split('/');
const list = ['vars', 'layer_checkpoint_dependencies'];
const newKey = listParts
.map(str => {
if (str.startsWith('_')) {
return str.slice(1);
}
return str;
})
.filter(str => !list.includes(str))
.join('/');
if (newKey !== key) {
weights[newKey] = weights[key];
delete weights[key];
}
}
}
updatedConfig() {
const theConfig = this.getConfig();
const modelConfig = {};
modelConfig['className'] = this.getClassName();
modelConfig['config'] = theConfig;
modelConfig['kerasVersion'] = `tfjs-layers ${version}`;
modelConfig['backend'] = 'TensorFlow.js';
return modelConfig;
}
toJSON(unused, returnString = true) {
const modelConfig = convertTsToPythonic(this.updatedConfig());
return returnString ? JSON.stringify(modelConfig) : modelConfig;
}
call(inputs, kwargs) {
return tidy(() => {
inputs = toList(inputs);
const feedDict = new FeedDict();
for (let i = 0; i < this.inputs.length; ++i) {
feedDict.add(this.inputs[i], inputs[i]);
}
return execute(this.outputs, feedDict, kwargs);
});
}
computeMask(inputs, mask) {
return tidy(() => {
inputs = toList(inputs);
let masks;
if (mask == null) {
masks = pyListRepeat(null, inputs.length);
}
else {
masks = toList(mask);
}
return this.runInternalGraph(inputs, masks)[1];
});
}
computeOutputShape(inputShape) {
const inputShapes = normalizeShapeList(inputShape);
if (inputShapes.length !== this.inputLayers.length) {
throw new ValueError(`Invalid inputShape argument ${inputShape}: ` +
`model has ${this.inputLayers.length} tensor inputs.`);
}
const layersToOutputShapes = {};
for (let i = 0; i < inputShapes.length; i++) {
const layer = this.inputLayers[i];
const inputShape = inputShapes[i];
const shapeKey = layer.name + '_0_0';
layersToOutputShapes[shapeKey] = inputShape;
}
const depthKeys = Object.keys(this.nodesByDepth)
.map(x => parseInt(x, 10))
.sort(reverseNumberCompare);
if (depthKeys.length > 1) {
for (const depth of depthKeys) {
const nodes = this.nodesByDepth[depth];
for (const node of nodes) {
const layer = node.outboundLayer;
if (this.inputLayers.map(x => x.id).indexOf(layer.id) !== -1) {
continue;
}
const inputShapes = [];
for (let j = 0; j < node.inboundLayers.length; j++) {
const inboundLayer = node.inboundLayers[j];
const nodeIndex = node.nodeIndices[j];
const tensorIndex = node.tensorIndices[j];
const shapeKey = `${inboundLayer.name}_${nodeIndex}_${tensorIndex}`;
const inputShape = layersToOutputShapes[shapeKey];
inputShapes.push(inputShape);
}
const outputShape = layer.computeOutputShape(singletonOrArray(inputShapes));
const outputShapes = normalizeShapeList(outputShape);
const nodeIndex = layer.inboundNodes.indexOf(node);
for (let j = 0; j < outputShapes.length; j++) {
const shapeKey = `${layer.name}_${nodeIndex}_${j}`;
layersToOutputShapes[shapeKey] = outputShapes[j];
}
}
}
}
const outputShapes = [];
const outputShapeKeys = [];
for (let i = 0; i < this.outputLayers.length; i++) {
const layer = this.outputLayers[i];
const nodeIndex = this.outputLayersNodeIndices[i];
const tensorIndex = this.outputLayersTensorIndices[i];
const shapeKey = `${layer.name}_${nodeIndex}_${tensorIndex}`;
outputShapeKeys.push(shapeKey);
}
for (let i = 0; i < outputShapeKeys.length; i++) {
const key = outputShapeKeys[i];
assert(key in layersToOutputShapes);
outputShapes.push(layersToOutputShapes[key]);
}
return singletonOrArray(outputShapes);
}
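    // Executes the layer graph on concrete input tensors (and masks), visiting nodes
    // in decreasing depth so every layer's inputs are computed before the layer runs.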
runInternalGraph(inputs, masks) {
if (masks == null) {
masks = pyListRepeat(null, inputs.length);
}
const tensorMap = {};
for (let i = 0; i < this.inputs.length; ++i) {
const x = this.inputs[i];
const y = inputs[i];
const mask = masks[i];
tensorMap[x.id] = [y, mask];
}
const depthKeys = Object.keys(this.nodesByDepth)
.map(x => parseInt(x, 10))
.sort(reverseNumberCompare);
for (const depth of depthKeys) {
const nodes = this.nodesByDepth[depth];
for (const node of nodes) {
const layer = node.outboundLayer;
const referenceInputTensors = node.inputTensors;
const referenceOutputTensors = node.outputTensors;
const computedData = new Array();
for (const x of referenceInputTensors) {
if (x.id in tensorMap) {
computedData.push(tensorMap[x.id]);
}
}
if (computedData.length === referenceInputTensors.length) {
let kwargs = {};
let computedTensors;
let computedMasks;
let outputTensors;
let outputMasks;
if (node.callArgs != null) {
kwargs = node.callArgs;
}
if (computedData.length === 1) {
const [computedTensor, computedMask] = computedData[0];
if (kwargs['mask'] == null) {
kwargs['mask'] = computedMask;
}
outputTensors =
toList(layer.call(computedTensor, kwargs));
outputMasks = toList(layer.computeMask(computedTensor, computedMask));
computedTensors = [computedTensor];
computedMasks = [computedMask];
}
else {
computedTensors = computedData.map(x => x[0]);
computedMasks = computedData.map(x => x[1]);
if (kwargs['mask'] == null) {
kwargs['mask'] = computedMasks;
}
outputTensors =
toList(layer.call(computedTensors, kwargs));
outputMasks = toList(layer.computeMask(computedTensors, computedMasks));
}
if (layer.activityRegularizer) {
throw new NotImplementedError('LayersModel invocation with concrete Tensor value(s) in the ' +
'presence of activity regularizer(s) is not supported yet.');
}
for (let i = 0; i < referenceOutputTensors.length; ++i) {
const x = referenceOutputTensors[i];
const y = outputTensors[i];
const mask = outputMasks[i];
tensorMap[x.id] = [y, mask];
}
}
}
}
const outputTensors = [];
const outputMasks = [];
const outputShapes = [];
for (const x of this.outputs) {
assert(x.id in tensorMap, `Could not compute output ${x.name} : ${x.id}`);
const [tensor, mask] = tensorMap[x.id];
outputShapes.push(tensor.shape);
outputTensors.push(tensor);
outputMasks.push(mask);
}
return [outputTensors, outputMasks, outputShapes];
}
buildNodeConversionMap(layers) {
const nodeConversionMap = {};
let keptNodes;
for (const layer of this.layers) {
keptNodes = layer instanceof Container ? 1 : 0;
for (let originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) {
const nodeKey = Container.nodeKey(layer, originalNodeIndex);
if (this.containerNodes.has(nodeKey)) {
nodeConversionMap[nodeKey] = keptNodes;
keptNodes += 1;
}
}
}
return nodeConversionMap;
}
getLayer(nameOrIndex, index) {
if (index != null) {
return this.findLayer(index);
}
else {
if (nameOrIndex == null) {
throw new ValueError('Provide either a layer name or layer index');
}
if (typeof nameOrIndex === 'number') {
return this.findLayer(nameOrIndex);
}
}
for (const layer of this.layers) {
if (layer.name === nameOrIndex) {
return layer;
}
}
throw new ValueError(`No such layer: ${nameOrIndex}`);
}
findLayer(index) {
if (this.layers.length <= index) {
throw new ValueError(`Was asked to retrieve layer at index ${index}, but model only ` +
`has ${this.layers.length} layer(s).`);
}
else {
return this.layers[index];
}
}
calculateLosses() {
return tidy(() => {
const losses = [];
for (const layer of this.layers) {
for (let nodeIndex = 0; nodeIndex < layer.inboundNodes.length; ++nodeIndex) {
const nodeKey = Container.nodeKey(layer, nodeIndex);
if (this.containerNodes.has(nodeKey)) {
losses.push(...layer.calculateLosses());
}
}
}
return losses;
});
}
getConfig() {
const config = { name: this.name };
const nodeConversionMap = this.buildNodeConversionMap(this.layers);
const layerConfigs = [];
for (const layer of this.layers) {
const layerClassName = layer.getClassName();
const layerConfig = layer.getConfig();
const filteredInboundNodes = [];
for (let originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) {
const node = layer.inboundNodes[originalNodeIndex];
const nodeKey = Container.nodeKey(layer, originalNodeIndex);
let kwargs = {};
if (this.containerNodes.has(nodeKey)) {
if (node.callArgs) {
try {
JSON.stringify(node.callArgs);
kwargs = node.callArgs;
}
catch (err) {
console.warn(`Layer ${layer.name} was passed ` +
`non-serializable keyword arguments: ` +
`${node.callArgs}. They will not be included ` +
`in the serialized model (and thus will be ` +
`missing at deserialization time).`);
kwargs = {};
}
}
if (node.inboundLayers.length > 0) {
const nodeData = [];
for (let i = 0; i < node.inboundLayers.length; i++) {
const inboundLayer = node.inboundLayers[i];
const nodeIndex = node.nodeIndices[i];
const tensorIndex = node.tensorIndices[i];
const nodeKey = Container.nodeKey(inboundLayer, nodeIndex);
let newNodeIndex = nodeConversionMap[nodeKey];
if (newNodeIndex == null) {
newNodeIndex = 0;
}
nodeData.push([inboundLayer.name, newNodeIndex, tensorIndex, kwargs]);
}
filteredInboundNodes.push(nodeData);
}
}
}
const dict = {};
dict['name'] = layer.name;
dict['className'] = layerClassName;
dict['config'] = layerConfig;
dict['inboundNodes'] = filteredInboundNodes;
layerConfigs.push(dict);
}
config['layers'] = layerConfigs;
const modelInputs = [];
for (let i = 0; i < this.inputLayers.length; i++) {
const layer = this.inputLayers[i];
const nodeIndex = this.inputLayersNodeIndices[i];
const nodeKey = Container.nodeKey(layer, nodeIndex);
if (!this.containerNodes.has(nodeKey)) {
continue;
}
let newNodeIndex = nodeConversionMap[nodeKey];
if (newNodeIndex === null || newNodeIndex === undefined) {
newNodeIndex = 0;
}
const tensorIndex = this.inputLayersTensorIndices[i];
modelInputs.push([layer.name, newNodeIndex, tensorIndex]);
}
config['inputLayers'] = modelInputs;
const modelOutputs = [];
for (let i = 0; i < this.outputLayers.length; i++) {
const layer = this.outputLayers[i];
const nodeIndex = this.outputLayersNodeIndices[i];
const nodeKey = Container.nodeKey(layer, nodeIndex);
if (!this.containerNodes.has(nodeKey)) {
continue;
}
let newNodeIndex = nodeConversionMap[nodeKey];
if (newNodeIndex === null || newNodeIndex === undefined) {
newNodeIndex = 0;
}
const tensorIndex = this.outputLayersTensorIndices[i];
modelOutputs.push([layer.name, newNodeIndex, tensorIndex]);
}
config['outputLayers'] = modelOutputs;
return config;
}
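    // Rebuilds a Container from a serialized config: layers are deserialized first,
    // then their connections are replayed until every inbound-node reference resolves.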
static fromConfig(cls, config, customObjects = {}, fastWeightInit = false) {
const createdLayers = {};
const unprocessedNodes = {};
function addUnprocessedNode(layer, nodeData) {
if (!(layer.name in unprocessedNodes)) {
unprocessedNodes[layer.name] = [nodeData];
}
else {
unprocessedNodes[layer.name].push(nodeData);
}
}
function processNode(layer, nodeData) {
const inputTensors = [];
let kwargs;
for (const inputData of nodeData) {
const inboundLayerName = inputData[0];
const inboundNodeIndex = inputData[1];
const inboundTensorIndex = inputData[2];
kwargs = inputData[3] == null ?
{} :
inputData[3];
if (!(inboundLayerName in createdLayers)) {
addUnprocessedNode(layer, nodeData);
return;
}
const inboundLayer = createdLayers[inboundLayerName];
if (inboundLayer.inboundNodes.length <= inboundNodeIndex) {
addUnprocessedNode(layer, nodeData);
return;
}
const inboundNode = inboundLayer.inboundNodes[inboundNodeIndex];
inputTensors.push(inboundNode.outputTensors[inboundTensorIndex]);
}
if (inputTensors.length > 0) {
layer.apply(singletonOrArray(inputTensors), kwargs);
}
}
function processLayer(layerData) {
const layerName = layerData['name'];
const layer = deserialize(layerData, config['customObjects'] != null ?
config['customObjects'] :
{});
layer.setFastWeightInitDuringBuild(fastWeightInit);
createdLayers[layerName] = layer;
const inboundNodesData = layerData['inboundNodes'];
inboundNodesData.forEach(nodeData => {
if (!(nodeData instanceof Array)) {
throw new ValueError(`Corrupted configuration, expected array for nodeData: ${nodeData}`);
}
addUnprocessedNode(layer, nodeData);
});
}
const name = config['name'];
const layersFromConfig = config['layers'];
for (const layerData of layersFromConfig) {
processLayer(layerData);
}
while (!isObjectEmpty(unprocessedNodes)) {
for (const layerData of layersFromConfig) {
const layer = createdLayers[layerData['name']];
if (layer.name in unprocessedNodes) {
const currentUnprocessedNodesForLayer = unprocessedNodes[layer.name];
delete unprocessedNodes[layer.name];
for (const nodeData of currentUnprocessedNodesForLayer) {
processNode(layer, nodeData);
}
}
}
}
const inputTensors = [];
const outputTensors = [];
const inputLayersFromConfig = config['inputLayers'];
for (const layerData of inputLayersFromConfig) {
const layerName = layerData[0];
const nodeIndex = layerData[1];
const tensorIndex = layerData[2];
assert(layerName in createdLayers);
const layer = createdLayers[layerName];
const layerOutputTensors = layer.inboundNodes[nodeIndex].outputTensors;
inputTensors.push(layerOutputTensors[tensorIndex]);
}
const outputLayersFromConfig = config['outputLayers'];
for (const layerData of outputLayersFromConfig) {
const layerName = layerData[0];
const nodeIndex = layerData[1];
const tensorIndex = layerData[2];
assert(layerName in createdLayers);
const layer = createdLayers[layerName];
const layerOutputTensors = layer.inboundNodes[nodeIndex].outputTensors;
outputTensors.push(layerOutputTensors[tensorIndex]);
}
return new cls({ inputs: inputTensors, outputs: outputTensors, name });
}
get stateful() {
if (this._stateful) {
throw new ValueError('Container instance unexpectedly has _stateful = true. The ' +
'statefulness of a Container is determined by the Layers it ' +
'contains. Its _stateful property must remain the default false.');
}
for (const layer of this.layers) {
if (layer.stateful) {
return true;
}
}
return false;
}
resetStates() {
tidy(() => {
this.layers.forEach(layer => {
if (layer.stateful) {
layer.resetStates();
}
});
});
}
}
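// Normalizes user-provided sample/class weights into one entry per model output,
// accepting a single value, an array, or a map keyed by output name.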
function standardizeSampleOrClassWeights(xWeight, outputNames, weightType) {
const numOutputs = outputNames.length;
if (xWeight == null || (Array.isArray(xWeight) && xWeight.length === 0)) {
return outputNames.map(name => null);
}
if (numOutputs === 1) {
if (Array.isArray(xWeight) && xWeight.length === 1) {
return xWeight;
}
else if (typeof xWeight === 'object' && outputNames[0] in xWeight) {
return [xWeight[outputNames[0]]];
}
else {
return [xWeight];
}
}
if (Array.isArray(xWeight)) {
if (xWeight.length !== numOutputs) {
throw new Error(`Provided ${weightType} is an array of ${xWeight.length} ` +
`element(s), but the model has ${numOutputs} outputs. ` +
`Make sure a set of weights is provided for each model output.`);
}
return xWeight;
}
else if (typeof xWeight === 'object' && Object.keys(xWeight).length > 0 &&
typeof xWeight[Object.keys(xWeight)[0]] ===
'object') {
const output = [];
outputNames.forEach(outputName => {
if (outputName in xWeight) {
output.push(xWeight[outputName]);
}
else {
output.push(null);
}
});
return output;
}
else {
throw new Error(`The model has multiple (${numOutputs}) outputs, ` +
`so ${weightType} must be either an array with ` +
`${numOutputs} elements or an object with ${outputNames} keys. ` +
`Provided ${weightType} not understood: ${JSON.stringify(xWeight)}`);
}
}
function standardizeClassWeights(classWeight, outputNames) {
return standardizeSampleOrClassWeights(classWeight, outputNames, 'classWeight');
}
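// Converts a classWeight map into a per-example weight tensor by looking up each
// example's class index (derived from the targets `y`) in the map.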
async function standardizeWeights(y, sampleWeight, classWeight, sampleWeightMode) {
if (classWeight != null) {
const yClasses = tidy(() => {
if (y.shape.length === 1) {
return clone(y);
}
else if (y.shape.length === 2) {
if (y.shape[1] > 1) {
const axis = 1;
return argMax$2(y, axis);
}
else if (y.shape[1] === 1) {
return reshape$2(y, [y.shape[0]]);
}
else {
throw new Error(`Encountered unexpected last-dimension size (${y.shape[1]}) ` +
`during handling of class weights. The size is expected to be ` +
`>= 1.`);
}
}
else {
throw new Error(`Unexpected rank of target (y) tensor (${y.rank}) during ` +
`handling of class weights. The rank is expected to be 1 or 2.`);
}
});
const yClassIndices = Array.from(await yClasses.data());
dispose(yClasses);
const classSampleWeight = [];
yClassIndices.forEach(classIndex => {
if (classWeight[classIndex] == null) {
throw new Error(`classWeight must contain all classes in the training data. ` +
`The class ${classIndex} exists in the data but not in ` +
`classWeight`);
}
else {
classSampleWeight.push(classWeight[classIndex]);
}
});
return tensor1d(classSampleWeight, 'float32');
}
else {
return null;
}
}
function computeWeightedLoss(losses, sampleWeights) {
return mul(losses, sampleWeights);
}
const DEFAULT_VALIDATION_BATCH_SIZE = 32;
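// Validates one {xs, ys} element from a tf.data iterator against the model's
// named inputs/outputs and checks that all tensors share the same batch size.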
function standardizeDataIteratorOutput(
model, iteratorOut) {
let xs;
let ys;
const iteratorOutObj = iteratorOut;
xs = iteratorOutObj['xs'];
ys = iteratorOutObj['ys'];
assert$1(xs != null && ys != null, () => 'A Dataset iterator for fitDataset() is expected to generate ' +
'objects of the form `{xs: xVal, ys: yVal}`, where the two ' +
'values may be `tf.Tensor`, an array of Tensors, or a map of ' +
'string to Tensor. The provided Dataset instead generates ' +
`${iteratorOut}`);
const flattenedXs = flattenTensorOrArrayOrMap('input', model.inputNames, xs);
const flattenedYs = flattenTensorOrArrayOrMap('output', model.outputNames, ys);
const batchSize = flattenedXs[0].shape[0];
assert$1(flattenedXs.length === model.inputs.length, () => `LayersModel has ${model.inputs.length} inputs, but the dataset ` +
`provides ${flattenedXs.length} inputs. (Expected input keys: ` +
`${JSON.stringify(model.inputNames)})`);
assert$1(flattenedYs.length === model.outputs.length, () => `LayersModel has ${model.outputs.length} outputs, but the dataset ` +
`provides ${flattenedYs.length} outputs. (Expected output keys: ` +
`${JSON.stringify(model.outputNames)})`);
for (let xIndex = 0; xIndex < flattenedXs.length; xIndex++) {
assert$1(flattenedXs[xIndex].shape[0] === batchSize, () => `Batch size mismatch: input ` +
`${model.inputNames[xIndex]} has ${flattenedXs[xIndex].shape[0]}; ` +
`expected ${batchSize} based on input ${model.inputNames[0]}.`);
}
for (let yIndex = 0; yIndex < flattenedYs.length; yIndex++) {
assert$1(flattenedYs[yIndex].shape[0] === batchSize, () => `Batch size mismatch: output ` +
`${model.outputNames[yIndex]} has ${flattenedYs[yIndex].shape[0]}; ` +
`expected ${batchSize} based on input ${model.inputNames[0]}.`);
}
return { xs: flattenedXs, ys: flattenedYs };
}
function flattenTensorOrArrayOrMap(inputOrOutput, names, values) {
if (values instanceof Tensor) {
return [values];
}
else if (Array.isArray(values)) {
assert$1(values.length === names.length, () => `Received an array of ${values.length} Tensors, but expected ${names.length} to match the ${inputOrOutput} keys ${names}.`);
return values;
}
else {
const result = [];
for (const name of names) {
if (values[name] == null) {
throw new ValueError(`The feature data generated by the dataset lacks the required ` +
`${inputOrOutput} key '${name}'.`);
}
result.push(values[name]);
}
return result;
}
}
function standardizeTensorValidationData(data) {
if (data.length === 3) {
throw new NotImplementedError('Validation with sample weights is not implemented yet.');
}
return { xs: data[0], ys: data[1] };
}
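// fitDataset() is the dataset-based training loop: it pulls {xs, ys} elements from the
// dataset iterator, runs the model's train function per batch, reports batch/epoch logs to
// the callbacks, and optionally evaluates validation data at the end of every epoch.
// Minimal usage sketch (`model` and `ds` are placeholders for a compiled LayersModel and a
// tf.data.Dataset):
//   await model.fitDataset(ds, {epochs: 3, batchesPerEpoch: 100});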
async function fitDataset(
model, dataset, args) {
const hasBatchesPerEpoch = args.batchesPerEpoch != null;
assert$1(model.optimizer != null, () => 'You must compile a model before training/testing. Use ' +
'LayersModel.compile(modelCompileConfig).');
assert$1(args != null, () => `For fitDataset(), the 2nd argument (config) is required, ` +
`but it is not provided in this call.`);
assert$1(args.epochs != null && args.epochs > 0 && Number.isInteger(args.epochs), () => `For fitDataset(), config.epochs is expected to be a positive ` +
`integer, but got ${args.epochs}`);
assert$1(!hasBatchesPerEpoch ||
(args.batchesPerEpoch > 0 && Number.isInteger(args.batchesPerEpoch)), () => `For fitDataset(), config.batchesPerEpoch is expected to be a ` +
`positive integer if specified, but got ${args.batchesPerEpoch}`);
assert$1(
args['validationSplit'] == null, () => '`validationSplit` is not supported by `fitDataset()`. ' +
'Use validationData instead.');
if (model.isTraining) {
throw new Error('Cannot start training because another fit() call is ongoing.');
}
model.isTraining = true;
try {
const doValidation = args.validationData != null;
let valXs;
let valYs;
if (doValidation) {
if (isDatasetObject(args.validationData)) {
assert$1(args.validationBatches == null ||
(args.validationBatches > 0 &&
Number.isInteger(args.validationBatches)), () => `For fitDataset() with dataset-based validation, ` +
`config.validationBatches is expected not to be provided, ` +
`or to be a positive integer, ` +
`but got ${args.validationBatches}`);
}
else {
const validationData = standardizeTensorValidationData(args.validationData);
valXs = validationData.xs;
valYs = validationData.ys;
}
}
const trainFunction = model.makeTrainFunction();
const outLabels = model.getDedupedMetricsNames();
let callbackMetrics;
if (doValidation) {
callbackMetrics =
outLabels.slice().concat(outLabels.map(n => 'val_' + n));
}
else {
callbackMetrics = outLabels.slice();
}
const callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery);
const verbose = args.verbose == null ? 1 : args.verbose;
const { callbackList, history } = configureCallbacks(callbacks, verbose, args.epochs, null, null, getStepsPerEpoch(dataset, args), null,
doValidation, callbackMetrics);
callbackList.setModel(model);
model.history = history;
await callbackList.onTrainBegin();
model.stopTraining_ = false;
let epoch = args.initialEpoch == null ? 0 : args.initialEpoch;
let dataIterator = await dataset.iterator();
while (epoch < args.epochs) {
const epochLogs = {};
await callbackList.onEpochBegin(epoch);
let stepsDone = 0;
let batchIndex = 0;
if (!hasBatchesPerEpoch) {
dataIterator = await dataset.iterator();
}
while (hasBatchesPerEpoch ? stepsDone < args.batchesPerEpoch : true) {
const iteratorOut = await dataIterator.next();
if (hasBatchesPerEpoch && iteratorOut.done) {
console.warn('You provided `batchesPerEpoch` as ' +
`${args.batchesPerEpoch}, ` +
'but your dataset iterator ran out of data after ' +
`${stepsDone} batches; ` +
'interrupting training. Make sure that your ' +
'dataset can generate at least `batchesPerEpoch * epochs` ' +
'batches (in this case, ' +
`${args.batchesPerEpoch * args.epochs} batches). ` +
'You may need to use the repeat() function when building ' +
'your dataset.');
break;
}
if (iteratorOut.value != null) {
const { xs, ys } = standardizeDataIteratorOutput(model, iteratorOut.value);
const batchLogs = {};
batchLogs['batch'] = batchIndex;
batchLogs['size'] = xs[0].shape[0];
await callbackList.onBatchBegin(batchIndex, batchLogs);
const sampleWeights = [];
if (args.classWeight != null) {
const standardClassWeights = standardizeClassWeights(args.classWeight, model.outputNames);
for (let i = 0; i < standardClassWeights.length; ++i) {
sampleWeights.push(await standardizeWeights(ys[i], null, standardClassWeights[i]));
}
}
const ins = xs.concat(ys).concat(sampleWeights);
const outs = trainFunction(ins);
dispose(ins);
for (let i = 0; i < outLabels.length; ++i) {
const label = outLabels[i];
const out = outs[i];
batchLogs[label] = out;
keep(out);
}
await callbackList.onBatchEnd(batchIndex, batchLogs);
disposeTensorsInLogs(batchLogs);
batchIndex++;
stepsDone++;
}
if (hasBatchesPerEpoch ? stepsDone >= args.batchesPerEpoch :
iteratorOut.done) {
if (doValidation) {
let valOuts;
if (isDatasetObject(args.validationData)) {
valOuts = toList(await model.evaluateDataset(args.validationData, { batches: args.validationBatches }));
}
else {
valOuts = toList(model.evaluate(valXs, valYs, {
batchSize: args.validationBatchSize == null ?
DEFAULT_VALIDATION_BATCH_SIZE :
args.validationBatchSize,
verbose: 0
}));
}
for (let i = 0; i < model.metricsNames.length; ++i) {
epochLogs[`val_${model.metricsNames[i]}`] = valOuts[i];
}
}
break;
}
if (model.stopTraining_) {
break;
}
}
await callbackList.onEpochEnd(epoch, epochLogs);
epoch++;
if (model.stopTraining_) {
break;
}
}
await callbackList.onTrainEnd();
await model.history.syncData();
return model.history;
}
finally {
model.isTraining = false;
}
}
function getStepsPerEpoch(dataset, args) {
let stepsPerEpoch = null;
if (args.batchesPerEpoch != null) {
stepsPerEpoch = args.batchesPerEpoch;
}
else if (Number.isFinite(dataset.size)) {
stepsPerEpoch = dataset.size;
}
return stepsPerEpoch;
}
function isDatasetObject(dataset) {
return (typeof dataset.iterator === 'function');
}
function isLazyIteratorObject(iterator) {
return (typeof iterator.next === 'function');
}
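// evaluateDataset() accumulates the test function's outputs across batches, weighting each
// batch by its example count, and divides by the total number of examples at the end to get
// per-metric averages.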
async function evaluateDataset(
model, dataset, args) {
args = args || {};
const hasBatches = args.batches != null;
const f = model.testFunction;
let outs = [];
if (args.verbose > 0) {
throw new NotImplementedError('Verbose mode is not implemented yet.');
}
assert$1(!hasBatches || (args.batches > 0 && Number.isInteger(args.batches)), () => 'Test loop expects `batches` to be a positive integer, but ' +
`received ${JSON.stringify(args.batches)}`);
const dataIterator = isLazyIteratorObject(dataset) ?
dataset :
await dataset.iterator();
let numExamples = 0;
let batch = 0;
while (hasBatches ? batch < args.batches : true) {
const iteratorOut = await dataIterator.next();
outs = tidy(() => {
if (iteratorOut.value) {
const { xs, ys } = standardizeDataIteratorOutput(model, iteratorOut.value);
const xsAndYs = xs.concat(ys);
const batchOuts = tidy(() => f(xsAndYs));
dispose(xsAndYs);
if (batch === 0) {
for (let i = 0; i < batchOuts.length; ++i) {
outs.push(scalar(0));
}
}
const batchSize = xsAndYs[0].shape[0];
for (let i = 0; i < batchOuts.length; ++i) {
const batchOut = batchOuts[i];
const oldScalar = outs[i];
outs[i] =
tidy(() => add$1(outs[i], mul(batchSize, batchOut)));
if (batch > 0) {
dispose(oldScalar);
}
}
dispose(batchOuts);
numExamples += batchSize;
++batch;
}
return outs;
});
if (iteratorOut.done) {
if (hasBatches) {
console.warn('Your dataset iterator ran out of data during evaluateDataset(). ' +
'Interrupting evaluation. Make sure that your ' +
'dataset can generate at least `batches` ' +
`batches (in this case, ${args.batches} batches). ` +
'You may need to use the repeat() function when building ' +
'your dataset.');
}
break;
}
}
for (let i = 0; i < outs.length; ++i) {
const oldScalar = outs[i];
outs[i] = div$1(outs[i], numExamples);
dispose(oldScalar);
}
return singletonOrArray(outs);
}
function checkBatchSize(batchSize) {
assert$1(batchSize > 0 && Number.isInteger(batchSize), () => `batchSize is required to be a positive integer, but got ${batchSize}`);
}
function sliceArrays(arrays, start, stop) {
if (arrays == null) {
return [null];
}
else if (Array.isArray(arrays)) {
return arrays.map(array => sliceAlongFirstAxis(array, start, stop - start));
}
else {
return sliceAlongFirstAxis(arrays, start, stop - start);
}
}
function sliceArraysByIndices(arrays, indices) {
return tidy(() => {
if (arrays == null) {
return null;
}
else if (Array.isArray(arrays)) {
return arrays.map(array => sliceArraysByIndices(array, indices));
}
else {
return gather(arrays, indices.dtype === 'int32' ? indices : cast$3(indices, 'int32'));
}
});
}
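// makeBatches() splits `size` examples into [start, end) index pairs of at most `batchSize`
// elements, e.g. makeBatches(5, 2) -> [[0, 2], [2, 4], [4, 5]].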
function makeBatches(size, batchSize) {
const output = [];
let batchStart = 0;
let batchEnd = null;
while (batchStart < size) {
batchEnd = batchStart + batchSize;
if (batchEnd >= size) {
batchEnd = size;
}
output.push([batchStart, batchEnd]);
batchStart = batchEnd;
}
return output;
}
function ensureTensorsRank2OrHigher(tensors) {
const outs = [];
if (tensors instanceof Tensor) {
tensors = [tensors];
}
for (let i = 0; i < tensors.length; ++i) {
const tensor = tensors[i];
if (tensor.rank === 1) {
outs.push(expandDims(tensor, 1));
}
else if (tensor.rank === 0) {
throw new Error('Expected tensor to be at least 1D, but received a 0D tensor ' +
'(scalar).');
}
else {
outs.push(tensor);
}
}
return outs;
}
function disposeNewTensors(tensors, refTensors) {
if (tensors == null) {
return;
}
const oldTensorIds = [];
if (refTensors instanceof Tensor) {
oldTensorIds.push(refTensors.id);
}
else if (Array.isArray(refTensors)) {
refTensors.forEach(t => oldTensorIds.push(t.id));
}
else if (refTensors != null) {
for (const name in refTensors) {
const oldTensor = refTensors[name];
oldTensorIds.push(oldTensor.id);
}
}
const tensorsToDispose = [];
if (tensors instanceof Tensor) {
if (oldTensorIds.indexOf(tensors.id) === -1) {
tensorsToDispose.push(tensors);
}
}
else if (Array.isArray(tensors)) {
tensors.forEach(t => {
if (oldTensorIds.indexOf(t.id) === -1) {
tensorsToDispose.push(t);
}
});
}
else if (tensors != null) {
for (const name in tensors) {
const tensor = tensors[name];
if (oldTensorIds.indexOf(tensor.id) === -1) {
tensorsToDispose.push(tensor);
}
}
}
tensorsToDispose.forEach(t => {
if (!t.isDisposed) {
t.dispose();
}
});
}
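// standardizeInputData() normalizes user-provided inputs/targets (a single Tensor, an array
// of Tensors, or a name->Tensor map) into an ordered array matching `names`, promotes rank-1
// tensors to rank 2, and validates each tensor's shape against the expected `shapes`.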
function isDataTensor(x) {
return x instanceof Tensor;
}
function isDataArray(x) {
return Array.isArray(x);
}
function isDataDict(x) {
return !isDataTensor(x) && !isDataArray(x);
}
function standardizeInputData(data, names, shapes, checkBatchAxis = true, exceptionPrefix = '') {
if (names == null || names.length === 0) {
if (data != null) {
let gotUnexpectedData = false;
if (isDataArray(data) && data.length > 0) {
gotUnexpectedData = true;
}
else if (isDataDict(data)) {
for (const key in data) {
if (data.hasOwnProperty(key)) {
gotUnexpectedData = true;
break;
}
}
}
else {
gotUnexpectedData = true;
}
if (gotUnexpectedData) {
throw new ValueError(`Error when checking model ${exceptionPrefix}: expected no data, ` +
`but got ${data}`);
}
}
return [];
}
if (data == null) {
return names.map(name => null);
}
let arrays;
if (isDataDict(data)) {
data = data;
arrays = [];
for (const name of names) {
if (data[name] == null) {
throw new ValueError(`No data provided for "${name}". Need data for each key in: ` +
`${names}`);
}
arrays.push(data[name]);
}
}
else if (isDataArray(data)) {
data = data;
if (data.length !== names.length) {
throw new ValueError(`Error when checking model ${exceptionPrefix}: the Array of ` +
`Tensors that you are passing to your model is not the size the ` +
`model expected. Expected to see ${names.length} Tensor(s), but ` +
`instead got the following list of Tensor(s): ${data}`);
}
arrays = data;
}
else {
data = data;
if (names.length > 1) {
throw new ValueError(`The model ${exceptionPrefix} expects ${names.length} Tensor(s), ` +
`but only received one Tensor. Found: Tensor with shape ${data.shape}`);
}
arrays = [data];
}
arrays = ensureTensorsRank2OrHigher(arrays);
if (shapes != null) {
for (let i = 0; i < names.length; ++i) {
if (shapes[i] == null) {
continue;
}
const array = arrays[i];
if (array.shape.length !== shapes[i].length) {
throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
`to have ${shapes[i].length} dimension(s), but got array with ` +
`shape ${array.shape}`);
}
for (let j = 0; j < shapes[i].length; ++j) {
if (j === 0 && !checkBatchAxis) {
continue;
}
const dim = array.shape[j];
const refDim = shapes[i][j];
if (refDim != null && refDim >= 0 && dim !== refDim) {
throw new ValueError(`${exceptionPrefix} expected a batch of elements where each ` +
`example has shape [${shapes[i].slice(1, shapes[i].length)}] ` +
`(i.e., tensor shape [*,${shapes[i].slice(1, shapes[i].length)}])` +
` but the ${exceptionPrefix} received an input with ${array.shape[0]}` +
` examples, each with shape [${array.shape.slice(1, array.shape.length)}]` +
` (tensor shape [${array.shape}])`);
}
}
}
}
return arrays;
}
function checkArrayLengths(inputs, targets, weights) {
const setX = unique(inputs.map(input => input.shape[0]));
setX.sort();
const setY = unique(targets.map(target => target.shape[0]));
setY.sort();
if (setX.length > 1) {
throw new ValueError(`All input Tensors (x) should have the same number of samples. ` +
`Got array shapes: ` +
`${JSON.stringify(inputs.map(input => input.shape))}`);
}
if (setY.length > 1) {
throw new ValueError(`All target Tensors (y) should have the same number of samples. ` +
`Got array shapes: ` +
`${JSON.stringify(targets.map(target => target.shape))}`);
}
if (setX.length > 0 && setY.length > 0 && !arraysEqual(setX, setY)) {
throw new ValueError(`Input Tensors should have the same number of samples as target ` +
`Tensors. Found ${setX[0]} input sample(s) and ${setY[0]} target ` +
`sample(s).`);
}
}
function checkLossAndTargetCompatibility(targets, lossFns, outputShapes) {
const keyLosses = [
meanSquaredError, binaryCrossentropy$1,
categoricalCrossentropy$1
];
for (let i = 0; i < targets.length; ++i) {
const y = targets[i];
const loss = lossFns[i];
const shape = outputShapes[i];
if (loss == null) {
continue;
}
if (loss === categoricalCrossentropy$1) {
if (y.shape[y.shape.length - 1] === 1) {
throw new ValueError(`You are passing a target array of shape ${y.shape} while using ` +
`a loss 'categorical_crossentropy'. 'categorical_crossentropy'` +
` expects targets to be binary matrices (1s and 0s) of shape ` +
`[samples, classes].`);
}
}
if (keyLosses.indexOf(loss) !== -1) {
const slicedYShape = y.shape.slice(1);
const slicedShape = shape.slice(1);
for (let j = 0; j < slicedYShape.length; ++j) {
const targetDim = slicedYShape[j];
const outDim = slicedShape[j];
if (outDim != null && targetDim !== outDim) {
throw new ValueError(`A target Tensor with shape ${y.shape} was passed for an ` +
`output of shape ${shape}, while using a loss function that ` +
`expects targets to have the same shape as the output.`);
}
}
}
}
}
function checkInputData(data, names, shapes, checkBatchAxis = true, exceptionPrefix = '') {
let arrays;
if (Array.isArray(data)) {
if (data.length !== names.length) {
throw new ValueError(`Error when checking model ${exceptionPrefix}: the Array of ` +
`Tensors that you are passing to your model is not the size the ` +
`model expected. Expected to see ${names.length} Tensor(s),` +
` but instead got ${data.length} Tensor(s).`);
}
arrays = data;
}
else {
if (names.length > 1) {
throw new ValueError(`The model expects ${names.length} ${exceptionPrefix} Tensors, ` +
`but only received one Tensor. Found: array with shape ` +
`${JSON.stringify(data.shape)}.`);
}
arrays = [data];
}
if (shapes != null) {
for (let i = 0; i < names.length; ++i) {
if (shapes[i] == null) {
continue;
}
const array = arrays[i];
if (array.shape.length !== shapes[i].length) {
throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
`to have ${shapes[i].length} dimension(s), but got array with ` +
`shape ${JSON.stringify(array.shape)}`);
}
for (let j = 0; j < shapes[i].length; ++j) {
if (j === 0 && !checkBatchAxis) {
continue;
}
const dim = array.shape[j];
const refDim = shapes[i][j];
if (refDim != null) {
if (refDim !== dim) {
throw new ValueError(`Error when checking ${exceptionPrefix}: expected ` +
`${names[i]} to have shape ${JSON.stringify(shapes[i])} but ` +
`got array with shape ${JSON.stringify(array.shape)}.`);
}
}
}
}
}
}
function collectMetrics(metrics, outputNames) {
if (metrics == null || Array.isArray(metrics) && metrics.length === 0) {
return outputNames.map(name => []);
}
let wrappedMetrics;
if (typeof metrics === 'string' || typeof metrics === 'function') {
wrappedMetrics = [metrics];
}
else if (Array.isArray(metrics) || typeof metrics === 'object') {
wrappedMetrics = metrics;
}
else {
throw new TypeError('Type of metrics argument not understood. Expected a string, ' +
`function, Array, or Object, found: ${metrics}`);
}
if (Array.isArray(wrappedMetrics)) {
return outputNames.map(name => wrappedMetrics);
}
else {
const nestedMetrics = [];
for (const name of outputNames) {
let outputMetrics = wrappedMetrics.hasOwnProperty(name) ? wrappedMetrics[name] : [];
if (!Array.isArray(outputMetrics)) {
outputMetrics = [outputMetrics];
}
nestedMetrics.push(outputMetrics);
}
return nestedMetrics;
}
}
const LAYERS_MODEL_FORMAT_NAME = 'layers-model';
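// LayersModel is the trainable model class behind the functional API: compile() resolves the
// per-output loss and metric functions, fit()/fitDataset() drive the training loops defined
// above, evaluate()/predict() run batched forward passes, and save() serializes the topology,
// weights and (optionally) optimizer state through an IOHandler.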
class LayersModel extends Container {
constructor(args) {
super(args);
this.isTraining = false;
}
summary(lineLength, positions, printFn = console.log) {
if (!this.built) {
throw new ValueError(`This model has never been called, thus its weights have not been ` +
`created yet. So no summary can be displayed. Build the model ` +
`first (e.g., by calling it on some test data).`);
}
printSummary(this, lineLength, positions, printFn);
}
compile(args) {
if (args.loss == null) {
args.loss = [];
}
this.loss = args.loss;
if (typeof args.optimizer === 'string') {
this.optimizer_ = getOptimizer(args.optimizer);
this.isOptimizerOwned = true;
}
else {
if (!(args.optimizer instanceof Optimizer)) {
throw new ValueError(`User-defined optimizer must be an instance of tf.Optimizer.`);
}
this.optimizer_ = args.optimizer;
this.isOptimizerOwned = false;
}
let lossFunctions = [];
if (!Array.isArray(args.loss) && typeof args.loss !== 'string' &&
typeof args.loss !== 'function') {
args.loss = args.loss;
for (const name in args.loss) {
if (this.outputNames.indexOf(name) === -1) {
throw new ValueError(`Unknown entry in loss dictionary: "${name}". ` +
`Only expected the following keys: ${this.outputNames}`);
}
}
for (const name of this.outputNames) {
if (args.loss[name] == null) {
console.warn(`Output "${name}" is missing from loss dictionary. We assume ` +
`this was done on purpose, and we will not be expecting data ` +
`to be passed to ${name} during training.`);
}
lossFunctions.push(get$1(args.loss[name]));
}
}
else if (Array.isArray(args.loss)) {
if (args.loss.length !== this.outputs.length) {
throw new ValueError(`When passing an Array as loss, it should have one entry per ` +
`model output. The model has ${this.outputs.length} output(s), ` +
`but you passed loss=${args.loss}.`);
}
const theLosses = args.loss;
lossFunctions = theLosses.map(l => get$1(l));
}
else {
const lossFunction = get$1(args.loss);
this.outputs.forEach(_ => {
lossFunctions.push(lossFunction);
});
}
this.lossFunctions = lossFunctions;
this.feedOutputNames = [];
this.feedOutputShapes = [];
this.feedLossFns = [];
for (let i = 0; i < this.outputs.length; ++i) {
const shape = this.internalOutputShapes[i];
const name = this.outputNames[i];
this.feedOutputNames.push(name);
this.feedOutputShapes.push(shape);
this.feedLossFns.push(this.lossFunctions[i]);
}
const skipTargetIndices = [];
this.metrics = args.metrics;
this.metricsNames = ['loss'];
this.metricsTensors = [];
nameScope('loss', () => {
for (let i = 0; i < this.outputs.length; ++i) {
if (skipTargetIndices.indexOf(i) !== -1) {
continue;
}
const weightedLoss = this.lossFunctions[i];
if (this.outputs.length > 1) {
this.metricsTensors.push([weightedLoss, i]);
this.metricsNames.push(this.outputNames[i] + '_loss');
}
}
});
const nestedMetrics = collectMetrics(args.metrics, this.outputNames);
const appendMetric = (outputIndex, metricName, metricTensor) => {
if (this.outputNames.length > 1) {
metricName = this.outputNames[outputIndex] + '_' + metricName;
}
this.metricsNames.push(metricName);
this.metricsTensors.push([metricTensor, outputIndex]);
};
nameScope('metric', () => {
for (let i = 0; i < this.outputs.length; ++i) {
if (skipTargetIndices.indexOf(i) !== -1) {
continue;
}
const outputMetrics = nestedMetrics[i];
const handleMetrics = (metrics) => {
const metricNamePrefix = '';
let metricName;
let accFn;
let weightedMetricFn;
for (const metric of metrics) {
if (typeof metric === 'string' &&
['accuracy', 'acc', 'crossentropy', 'ce'].indexOf(metric) !==
-1) {
const outputShape = this.internalOutputShapes[i];
if (outputShape[outputShape.length - 1] === 1 ||
this.lossFunctions[i] === binaryCrossentropy$1) {
if (['accuracy', 'acc'].indexOf(metric) !== -1) {
accFn = binaryAccuracy;
}
else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
accFn = binaryCrossentropy;
}
}
else if (this.lossFunctions[i] ===
sparseCategoricalCrossentropy$1) {
if (['accuracy', 'acc'].indexOf(metric) !== -1) {
accFn = sparseCategoricalAccuracy;
}
else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
accFn = sparseCategoricalCrossentropy;
}
}
else {
if (['accuracy', 'acc'].indexOf(metric) !== -1) {
accFn = categoricalAccuracy;
}
else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
accFn = categoricalCrossentropy;
}
}
let suffix;
if (['accuracy', 'acc'].indexOf(metric) !== -1) {
suffix = 'acc';
}
else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
suffix = 'ce';
}
weightedMetricFn = accFn;
metricName = metricNamePrefix + suffix;
}
else {
const metricFn = get(metric);
weightedMetricFn = metricFn;
metricName =
metricNamePrefix + getLossOrMetricName(metric);
}
let metricResult;
nameScope(metricName, () => {
metricResult = weightedMetricFn;
});
appendMetric(i, metricName, metricResult);
}
};
handleMetrics(outputMetrics);
}
});
this.collectedTrainableWeights = this.trainableWeights;
}
checkTrainableWeightsConsistency() {
if (this.collectedTrainableWeights == null) {
return;
}
if (this.trainableWeights.length !==
this.collectedTrainableWeights.length) {
console.warn('Discrepancy between trainable weights and collected trainable ' +
'weights. Did you set `model.trainable` without calling ' +
'`model.compile()` afterwards?');
}
}
evaluate(x, y, args = {}) {
const batchSize = args.batchSize == null ? 32 : args.batchSize;
checkBatchSize(batchSize);
const checkBatchAxis = true;
const standardizedOuts = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
try {
const ins = standardizedOuts[0].concat(standardizedOuts[1]);
this.makeTestFunction();
const f = this.testFunction;
const testOuts = this.testLoop(f, ins, batchSize, args.verbose, args.steps);
return singletonOrArray(testOuts);
}
finally {
disposeNewTensors(standardizedOuts[0], x);
disposeNewTensors(standardizedOuts[1], y);
}
}
async evaluateDataset(dataset, args) {
this.makeTestFunction();
return evaluateDataset(this, dataset, args);
}
checkNumSamples(ins, batchSize, steps, stepsName = 'steps') {
let numSamples;
if (steps != null) {
numSamples = null;
if (batchSize != null) {
throw new ValueError(`If ${stepsName} is set, batchSize must be null or undefined. ` +
`Got batchSize = ${batchSize}`);
}
}
else if (ins != null) {
if (Array.isArray(ins)) {
numSamples = ins[0].shape[0];
}
else {
numSamples = ins.shape[0];
}
}
else {
throw new ValueError(`Either the input data should have a defined shape, or ` +
`${stepsName} should be specified.`);
}
return numSamples;
}
execute(inputs, outputs) {
if (Array.isArray(outputs) && outputs.length === 0) {
throw new ValueError('`outputs` is an empty Array, which is not allowed.');
}
const outputsIsArray = Array.isArray(outputs);
const outputNames = (outputsIsArray ? outputs : [outputs]);
const outputSymbolicTensors = this.retrieveSymbolicTensors(outputNames);
const feedDict = new FeedDict();
if (inputs instanceof Tensor) {
inputs = [inputs];
}
if (Array.isArray(inputs)) {
if (inputs.length !== this.inputs.length) {
throw new ValueError(`The number of inputs provided (${inputs.length}) ` +
`does not match the number of inputs of this model ` +
`(${this.inputs.length}).`);
}
for (let i = 0; i < this.inputs.length; ++i) {
feedDict.add(this.inputs[i], inputs[i]);
}
}
else {
for (const input of this.inputs) {
const tensorValue = inputs[input.name];
if (tensorValue == null) {
throw new ValueError(`No value is provided for the model's input ${input.name}`);
}
feedDict.add(input, tensorValue);
}
}
const executeOutputs = execute(outputSymbolicTensors, feedDict);
return outputsIsArray ? executeOutputs : executeOutputs[0];
}
retrieveSymbolicTensors(symbolicTensorNames) {
const outputSymbolicTensors = pyListRepeat(null, symbolicTensorNames.length);
let outputsRemaining = symbolicTensorNames.length;
for (const layer of this.layers) {
const layerOutputs = Array.isArray(layer.output) ? layer.output : [layer.output];
const layerOutputNames = layerOutputs.map(output => output.name);
for (let i = 0; i < symbolicTensorNames.length; ++i) {
const index = layerOutputNames.indexOf(symbolicTensorNames[i]);
if (index !== -1) {
outputSymbolicTensors[i] = layerOutputs[index];
outputsRemaining--;
}
if (outputsRemaining === 0) {
break;
}
}
if (outputsRemaining === 0) {
break;
}
}
if (outputsRemaining > 0) {
const remainingNames = [];
outputSymbolicTensors.forEach((tensor, i) => {
if (tensor == null) {
remainingNames.push(symbolicTensorNames[i]);
}
});
throw new ValueError(`Cannot find SymbolicTensors for output name(s): ` +
`${JSON.stringify(remainingNames)}`);
}
return outputSymbolicTensors;
}
predictLoop(ins, batchSize = 32, verbose = false) {
return tidy(() => {
const numSamples = this.checkNumSamples(ins);
if (verbose) {
throw new NotImplementedError('Verbose predictLoop() is not implemented yet.');
}
const batches = makeBatches(numSamples, batchSize);
const outsBatches = this.outputs.map(output => []);
for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
const batchOuts = tidy(() => {
const batchStart = batches[batchIndex][0];
const batchEnd = batches[batchIndex][1];
const insBatch = sliceArrays(ins, batchStart, batchEnd);
const feeds = [];
if (Array.isArray(insBatch)) {
for (let i = 0; i < insBatch.length; ++i) {
feeds.push({ key: this.inputs[i], value: insBatch[i] });
}
}
else {
feeds.push({ key: this.inputs[0], value: insBatch });
}
const feedDict = new FeedDict(feeds);
return execute(this.outputs, feedDict);
});
batchOuts.forEach((batchOut, i) => outsBatches[i].push(batchOut));
}
return singletonOrArray(outsBatches.map(batches => concat$2(batches, 0)));
});
}
predict(x, args = {}) {
const xsRank2OrHigher = ensureTensorsRank2OrHigher(x);
checkInputData(xsRank2OrHigher, this.inputNames, this.feedInputShapes, false);
try {
const batchSize = args.batchSize == null ? 32 : args.batchSize;
checkBatchSize(batchSize);
return this.predictLoop(xsRank2OrHigher, batchSize);
}
finally {
disposeNewTensors(xsRank2OrHigher, x);
}
}
predictOnBatch(x) {
checkInputData(x, this.inputNames, this.feedInputShapes, true);
const batchSize = (Array.isArray(x) ? x[0] : x).shape[0];
return this.predictLoop(x, batchSize);
}
standardizeUserDataXY(x, y, checkBatchAxis = true, batchSize) {
if (this.optimizer_ == null) {
throw new RuntimeError('You must compile a model before training/testing. Use ' +
'LayersModel.compile(modelCompileArgs).');
}
const outputShapes = [];
for (let i = 0; i < this.feedOutputShapes.length; ++i) {
const outputShape = this.feedOutputShapes[i];
const lossFn = this.feedLossFns[i];
if (lossFn === sparseCategoricalCrossentropy$1) {
outputShapes.push(outputShape.slice(0, outputShape.length - 1).concat([1]));
}
else {
outputShapes.push(outputShape);
}
}
x = standardizeInputData(x, this.feedInputNames, this.feedInputShapes, false, 'input');
y = standardizeInputData(y, this.feedOutputNames, outputShapes, false, 'target');
checkArrayLengths(x, y);
checkLossAndTargetCompatibility(y, this.feedLossFns, this.feedOutputShapes);
if (this.stateful && batchSize != null && batchSize > 0) {
if (x[0].shape[0] % batchSize !== 0) {
throw new ValueError(`In a stateful network, you should only pass inputs with a ` +
`number of samples that is divisible by the batch size ` +
`${batchSize}. Found: ${x[0].shape[0]} sample(s).`);
}
}
return [x, y];
}
async standardizeUserData(x, y, sampleWeight, classWeight, checkBatchAxis = true, batchSize) {
const [standardXs, standardYs] = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
if (sampleWeight != null) {
throw new Error('sample weight is not supported yet.');
}
let standardSampleWeights = null;
if (classWeight != null) {
const classWeights = standardizeClassWeights(classWeight, this.outputNames);
standardSampleWeights = [];
for (let i = 0; i < classWeights.length; ++i) {
standardSampleWeights.push(await standardizeWeights(standardYs[i], null, classWeights[i]));
}
}
return [standardXs, standardYs, standardSampleWeights];
}
testLoop(f, ins, batchSize, verbose = 0, steps) {
return tidy(() => {
const numSamples = this.checkNumSamples(ins, batchSize, steps, 'steps');
const outs = [];
if (verbose > 0) {
throw new NotImplementedError('Verbose mode is not implemented yet.');
}
if (steps != null) {
throw new NotImplementedError('steps mode in testLoop() is not implemented yet');
}
else {
const batches = makeBatches(numSamples, batchSize);
const indexArray = tensor1d(range(0, numSamples));
for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
const batchStart = batches[batchIndex][0];
const batchEnd = batches[batchIndex][1];
const batchIds = sliceAlongFirstAxis(indexArray, batchStart, batchEnd - batchStart);
const insBatch = sliceArraysByIndices(ins, batchIds);
const batchOuts = f(insBatch);
if (batchIndex === 0) {
for (let i = 0; i < batchOuts.length; ++i) {
outs.push(scalar(0));
}
}
for (let i = 0; i < batchOuts.length; ++i) {
const batchOut = batchOuts[i];
outs[i] =
add$1(outs[i], mul(batchEnd - batchStart, batchOut));
}
}
for (let i = 0; i < outs.length; ++i) {
outs[i] = div$1(outs[i], numSamples);
}
}
return outs;
});
}
getDedupedMetricsNames() {
const outLabels = this.metricsNames;
const dedupedOutLabels = [];
for (let i = 0; i < outLabels.length; ++i) {
const label = outLabels[i];
let newLabel = label;
if (count(outLabels, label) > 1) {
const dupIndex = count(outLabels.slice(0, i), label);
newLabel += `_${dupIndex}`;
}
dedupedOutLabels.push(newLabel);
}
return dedupedOutLabels;
}
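    // makeTrainFunction() returns a closure that feeds a batch, computes the (optionally
    // sample-weighted) total loss plus regularizer losses, lets the optimizer minimize it
    // against the collected trainable weights, and returns [totalLoss, ...metricValues].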
makeTrainFunction() {
return (data) => {
const lossValues = [];
const inputs = data.slice(0, this.inputs.length);
const targets = data.slice(this.inputs.length, this.inputs.length + this.outputs.length);
const sampleWeights = data.slice(this.inputs.length + this.outputs.length, this.inputs.length + this.outputs.length * 2);
const metricsValues = [];
const totalLossFunction = () => {
const feeds = [];
for (let i = 0; i < this.inputs.length; ++i) {
feeds.push({ key: this.inputs[i], value: inputs[i] });
}
const feedDict = new FeedDict(feeds);
const outputs = execute(this.outputs, feedDict, { 'training': true });
let totalLoss;
for (let i = 0; i < this.lossFunctions.length; ++i) {
const lossFunction = this.lossFunctions[i];
let loss = lossFunction(targets[i], outputs[i]);
if (sampleWeights[i] != null) {
loss = computeWeightedLoss(loss, sampleWeights[i]);
}
const meanLoss = mean$1(loss);
lossValues.push(meanLoss);
if (i === 0) {
totalLoss = loss;
}
else {
totalLoss = add$1(totalLoss, loss);
}
}
for (let i = 0; i < this.metricsTensors.length; ++i) {
let weightedMetric;
if (this.outputs.length > 1 && i < this.outputs.length) {
weightedMetric = lossValues[i];
}
else {
const metric = this.metricsTensors[i][0];
const outputIndex = this.metricsTensors[i][1];
weightedMetric =
mean$1(metric(targets[outputIndex], outputs[outputIndex]));
}
keep(weightedMetric);
metricsValues.push(weightedMetric);
}
totalLoss = mean$1(totalLoss);
this.calculateLosses().forEach(regularizerLoss => {
totalLoss = add$1(totalLoss, regularizerLoss);
});
return totalLoss;
};
const variables = this.collectedTrainableWeights.map(param => param.read());
const returnCost = true;
const totalLossValue = this.optimizer_.minimize(totalLossFunction, returnCost, variables);
return [totalLossValue].concat(metricsValues);
};
}
makeTestFunction() {
this.testFunction = (data) => {
return tidy(() => {
const valOutputs = [];
let totalLoss;
const inputs = data.slice(0, this.inputs.length);
const targets = data.slice(this.inputs.length, this.inputs.length + this.outputs.length);
const feeds = [];
for (let i = 0; i < this.inputs.length; ++i) {
feeds.push({ key: this.inputs[i], value: inputs[i] });
}
const feedDict = new FeedDict(feeds);
const outputs = execute(this.outputs, feedDict);
for (let i = 0; i < this.lossFunctions.length; ++i) {
const lossFunction = this.lossFunctions[i];
const loss = mean$1(lossFunction(targets[i], outputs[i]));
if (i === 0) {
totalLoss = loss;
}
else {
totalLoss = add$1(totalLoss, loss);
}
valOutputs.push(totalLoss);
}
for (let i = 0; i < this.metricsTensors.length; ++i) {
const metric = this.metricsTensors[i][0];
const outputIndex = this.metricsTensors[i][1];
const meanMetric = mean$1(metric(targets[outputIndex], outputs[outputIndex]));
valOutputs.push(meanMetric);
}
return valOutputs;
});
};
}
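    // fit() standardizes x/y (and optional class weights), carves out validation data when
    // validationData or validationSplit is given, and delegates the per-epoch/per-batch work
    // to fitLoop() together with the configured callbacks.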
async fit(x, y, args = {}) {
if (this.isTraining) {
throw new Error('Cannot start training because another fit() call is ongoing.');
}
this.isTraining = true;
let inputs;
let targets;
let originalInputs;
let originalTargets;
let inputValX;
let inputValY;
let valX;
let valY;
let sampleWeights;
try {
const batchSize = args.batchSize == null ? 32 : args.batchSize;
checkBatchSize(batchSize);
const checkBatchAxis = false;
const standardizedOuts = await this.standardizeUserData(x, y, args.sampleWeight, args.classWeight, checkBatchAxis, batchSize);
inputs = standardizedOuts[0];
targets = standardizedOuts[1];
sampleWeights = standardizedOuts[2];
let doValidation = false;
let valIns;
if (args.validationData != null && args.validationData.length > 0) {
doValidation = true;
if (args.validationData.length === 2) {
inputValX = args.validationData[0];
inputValY = args.validationData[1];
}
else if (args.validationData.length === 3) {
throw new NotImplementedError('validationData including sample weights is not supported yet.');
}
else {
throw new ValueError(`When passing validation data, it must contain 2 (valX, valY) ` +
`or 3 (valX, valY, valSampleWeight) items; ` +
`${args.validationData} is invalid.`);
}
const checkBatchAxis = true;
const valStandardized = await this.standardizeUserData(inputValX, inputValY, null, null, checkBatchAxis, batchSize);
valX = valStandardized[0];
valY = valStandardized[1];
valIns = valX.concat(valY);
}
else if (args.validationSplit != null && args.validationSplit > 0 &&
args.validationSplit < 1) {
doValidation = true;
const splitAt = Math.floor(inputs[0].shape[0] * (1 - args.validationSplit));
const originalBatchSize = inputs[0].shape[0];
valX = sliceArrays(inputs, splitAt, originalBatchSize);
originalInputs = inputs;
inputs = sliceArrays(inputs, 0, splitAt);
valY = sliceArrays(targets, splitAt, originalBatchSize);
originalTargets = targets;
targets = sliceArrays(targets, 0, splitAt);
valIns = valX.concat(valY);
}
else if (args.validationSteps != null) {
doValidation = true;
}
const ins = inputs.concat(targets).concat(sampleWeights);
this.checkTrainableWeightsConsistency();
const trainFunction = this.makeTrainFunction();
const outLabels = this.getDedupedMetricsNames();
let valFunction;
let callbackMetrics;
if (doValidation) {
this.makeTestFunction();
valFunction = this.testFunction;
callbackMetrics =
outLabels.slice().concat(outLabels.map(n => 'val_' + n));
}
else {
valFunction = null;
valIns = [];
callbackMetrics = outLabels.slice();
}
const callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery);
const out = await this.fitLoop(trainFunction, ins, outLabels, batchSize, args.epochs, args.verbose, callbacks, valFunction, valIns, args.shuffle, callbackMetrics, args.initialEpoch, null, null);
return out;
}
finally {
this.isTraining = false;
disposeNewTensors(inputs, x);
disposeNewTensors(targets, y);
disposeNewTensors(originalInputs, x);
disposeNewTensors(originalTargets, y);
disposeNewTensors(valX, inputValX);
disposeNewTensors(valY, inputValY);
if (sampleWeights != null) {
dispose(sampleWeights);
}
}
}
async fitLoop(f, ins, outLabels, batchSize, epochs, verbose, callbacks, valF, valIns, shuffle$1, callbackMetrics, initialEpoch, stepsPerEpoch, validationSteps) {
if (batchSize == null) {
batchSize = 32;
}
if (epochs == null) {
epochs = 1;
}
if (shuffle$1 == null) {
shuffle$1 = true;
}
if (initialEpoch == null) {
initialEpoch = 0;
}
let doValidation = false;
if (valF != null && valIns != null) {
doValidation = true;
}
if (validationSteps != null) {
doValidation = true;
if (stepsPerEpoch == null) {
throw new ValueError('Can only use `validationSteps` when doing step-wise training, ' +
'i.e., `stepsPerEpoch` must be set.');
}
}
const numTrainSamples = this.checkNumSamples(ins, batchSize, stepsPerEpoch, 'steps_per_epoch');
let indexArray;
if (numTrainSamples != null) {
indexArray = range(0, numTrainSamples);
}
if (verbose == null) {
verbose = 1;
}
const { callbackList, history } = configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics);
callbackList.setModel(this);
this.history = history;
await callbackList.onTrainBegin();
this.stopTraining_ = false;
for (let epoch = initialEpoch; epoch < epochs; ++epoch) {
await callbackList.onEpochBegin(epoch);
const epochLogs = {};
if (stepsPerEpoch != null) {
throw new NotImplementedError('stepsPerEpoch mode is not implemented yet.');
}
else {
if (shuffle$1 === 'batch') {
throw new NotImplementedError('batch shuffling is not implemented'
+ ' yet');
}
else if (shuffle$1) {
shuffle(indexArray);
}
const epochIndexArray1D = tensor1d(indexArray);
const batches = makeBatches(numTrainSamples, batchSize);
for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
const batchLogs = {};
await callbackList.onBatchBegin(batchIndex, batchLogs);
tidy(() => {
const batchStart = batches[batchIndex][0];
const batchEnd = batches[batchIndex][1];
const batchIds = sliceAlongFirstAxis(epochIndexArray1D, batchStart, batchEnd - batchStart);
batchLogs['batch'] = batchIndex;
batchLogs['size'] = batchEnd - batchStart;
const insBatch = sliceArraysByIndices(ins, batchIds);
const outs = f(insBatch);
for (let i = 0; i < outLabels.length; ++i) {
const label = outLabels[i];
const out = outs[i];
batchLogs[label] = out;
keep(out);
}
if (batchIndex === batches.length - 1) {
if (doValidation) {
const valOuts = this.testLoop(valF, valIns, batchSize);
for (let i = 0; i < outLabels.length; ++i) {
const label = outLabels[i];
const out = valOuts[i];
keep(out);
epochLogs['val_' + label] = out;
}
}
}
});
await callbackList.onBatchEnd(batchIndex, batchLogs);
disposeTensorsInLogs(batchLogs);
if (this.stopTraining_) {
break;
}
}
epochIndexArray1D.dispose();
}
await callbackList.onEpochEnd(epoch, epochLogs);
if (this.stopTraining_) {
break;
}
}
await callbackList.onTrainEnd();
await this.history.syncData();
return this.history;
}
async fitDataset(dataset, args) {
return fitDataset(this, dataset, args);
}
async trainOnBatch(x, y) {
const standardizeOut = await this.standardizeUserData(x, y);
const inputs = standardizeOut[0];
const targets = standardizeOut[1];
const trainFunction = this.makeTrainFunction();
const losses = trainFunction(inputs.concat(targets));
const lossValues = [];
for (const loss of losses) {
const v = await loss.data();
lossValues.push(v[0]);
}
dispose(losses);
disposeNewTensors(standardizeOut[0], x);
disposeNewTensors(standardizeOut[1], y);
return singletonOrArray(lossValues);
}
getNamedWeights(config) {
const namedWeights = [];
const trainableOnly = config != null && config.trainableOnly;
const weights = trainableOnly ? this.trainableWeights : this.weights;
const weightValues = this.getWeights(trainableOnly);
for (let i = 0; i < weights.length; ++i) {
if (trainableOnly && !weights[i].trainable) {
continue;
}
namedWeights.push({ name: weights[i].originalName, tensor: weightValues[i] });
}
return namedWeights;
}
set stopTraining(stop) {
this.stopTraining_ = stop;
}
get stopTraining() {
return this.stopTraining_;
}
get optimizer() {
return this.optimizer_;
}
set optimizer(optimizer) {
if (this.optimizer_ !== optimizer) {
this.optimizer_ = optimizer;
this.isOptimizerOwned = false;
}
}
dispose() {
const result = super.dispose();
if (result.refCountAfterDispose === 0 && this.optimizer != null &&
this.isOptimizerOwned) {
const numTensorsBeforeOptimizerDisposal = memory().numTensors;
this.optimizer_.dispose();
result.numDisposedVariables +=
numTensorsBeforeOptimizerDisposal - memory().numTensors;
}
return result;
}
getLossIdentifiers() {
let lossNames;
if (typeof this.loss === 'string') {
lossNames = toSnakeCase(this.loss);
}
else if (Array.isArray(this.loss)) {
for (const loss of this.loss) {
if (typeof loss !== 'string') {
throw new Error('Serialization of non-string loss is not supported.');
}
}
lossNames = this.loss.map(name => toSnakeCase(name));
}
else {
const outputNames = Object.keys(this.loss);
lossNames = {};
const losses = this.loss;
for (const outputName of outputNames) {
if (typeof losses[outputName] === 'string') {
lossNames[outputName] =
toSnakeCase(losses[outputName]);
}
else {
throw new Error('Serialization of non-string loss is not supported.');
}
}
}
return lossNames;
}
getMetricIdentifiers() {
if (typeof this.metrics === 'string' ||
typeof this.metrics === 'function') {
return [toSnakeCase(getLossOrMetricName(this.metrics))];
}
else if (Array.isArray(this.metrics)) {
return this.metrics.map(metric => toSnakeCase(getLossOrMetricName(metric)));
}
else {
const metricsIdentifiers = {};
for (const key in this.metrics) {
metricsIdentifiers[key] =
toSnakeCase(getLossOrMetricName(this.metrics[key]));
}
return metricsIdentifiers;
}
}
getTrainingConfig() {
return {
loss: this.getLossIdentifiers(),
metrics: this.getMetricIdentifiers(),
optimizer_config: {
class_name: this.optimizer.getClassName(),
config: this.optimizer.getConfig()
}
};
}
loadTrainingConfig(trainingConfig) {
if (trainingConfig.weighted_metrics != null) {
throw new Error('Loading weight_metrics is not supported yet.');
}
if (trainingConfig.loss_weights != null) {
throw new Error('Loading loss_weights is not supported yet.');
}
if (trainingConfig.sample_weight_mode != null) {
throw new Error('Loading sample_weight_mode is not supported yet.');
}
const tsConfig = convertPythonicToTs(trainingConfig.optimizer_config);
const optimizer = deserialize(tsConfig);
let loss;
if (typeof trainingConfig.loss === 'string') {
loss = toCamelCase(trainingConfig.loss);
}
else if (Array.isArray(trainingConfig.loss)) {
loss = trainingConfig.loss.map(lossEntry => toCamelCase(lossEntry));
}
else if (trainingConfig.loss != null) {
loss = {};
for (const key in trainingConfig.loss) {
loss[key] = toCamelCase(trainingConfig.loss[key]);
}
}
let metrics;
if (Array.isArray(trainingConfig.metrics)) {
metrics = trainingConfig.metrics.map(metric => toCamelCase(metric));
}
else if (trainingConfig.metrics != null) {
metrics = {};
for (const key in trainingConfig.metrics) {
metrics[key] = toCamelCase(trainingConfig.metrics[key]);
}
}
this.compile({ loss, metrics, optimizer });
}
async save(handlerOrURL, config) {
if (typeof handlerOrURL === 'string') {
const handlers = getSaveHandlers(handlerOrURL);
if (handlers.length === 0) {
throw new ValueError(`Cannot find any save handlers for URL '${handlerOrURL}'`);
}
else if (handlers.length > 1) {
throw new ValueError(`Found more than one (${handlers.length}) save handlers for ` +
`URL '${handlerOrURL}'`);
}
handlerOrURL = handlers[0];
}
if (handlerOrURL.save == null) {
throw new ValueError('LayersModel.save() cannot proceed because the IOHandler ' +
'provided does not have the `save` attribute defined.');
}
const weightDataAndSpecs = await encodeWeights(this.getNamedWeights(config));
const returnString = false;
const unusedArg = null;
const modelConfig = this.toJSON(unusedArg, returnString);
const modelArtifacts = {
modelTopology: modelConfig,
format: LAYERS_MODEL_FORMAT_NAME,
generatedBy: `TensorFlow.js tfjs-layers v${version}`,
convertedBy: null,
};
const includeOptimizer = config == null ? false : config.includeOptimizer;
if (includeOptimizer && this.optimizer != null) {
modelArtifacts.trainingConfig = this.getTrainingConfig();
const weightType = 'optimizer';
const { data: optimizerWeightData, specs: optimizerWeightSpecs } = await encodeWeights(await this.optimizer.getWeights(), weightType);
weightDataAndSpecs.specs.push(...optimizerWeightSpecs);
weightDataAndSpecs.data = concatenateArrayBuffers([weightDataAndSpecs.data, optimizerWeightData]);
}
if (this.userDefinedMetadata != null) {
const checkSize = true;
checkUserDefinedMetadata(this.userDefinedMetadata, this.name, checkSize);
modelArtifacts.userDefinedMetadata = this.userDefinedMetadata;
}
modelArtifacts.weightData = weightDataAndSpecs.data;
modelArtifacts.weightSpecs = weightDataAndSpecs.specs;
return handlerOrURL.save(modelArtifacts);
}
setUserDefinedMetadata(userDefinedMetadata) {
checkUserDefinedMetadata(userDefinedMetadata, this.name);
this.userDefinedMetadata = userDefinedMetadata;
}
getUserDefinedMetadata() {
return this.userDefinedMetadata;
}
}
LayersModel.className = 'Model';
registerClass(LayersModel);
class Functional extends LayersModel {
}
Functional.className = 'Functional';
registerClass(Functional);
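// loadLayersModelFromIOHandler() rebuilds a LayersModel from IOHandler artifacts: deserialize
// the topology, restore the training config if present, then split the decoded weights into
// model weights (loaded into the layers) and optimizer weights (handed to the optimizer).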
async function loadLayersModelFromIOHandler(handler, customObjects, options) {
if (options == null) {
options = {};
}
if (handler.load == null) {
throw new ValueError('Cannot proceed with model loading because the IOHandler provided ' +
'does not have the `load` method implemented.');
}
const artifacts = await handler.load();
let modelTopology = artifacts.modelTopology;
if (modelTopology['model_config'] != null) {
modelTopology = modelTopology['model_config'];
}
const strict = options.strict == null ? true : options.strict;
const fastWeightInit = artifacts.weightData != null && artifacts.weightSpecs != null && strict;
const model = deserialize(convertPythonicToTs(modelTopology), customObjects, fastWeightInit);
const trainingConfig = artifacts.trainingConfig;
if (trainingConfig != null) {
model.loadTrainingConfig(trainingConfig);
}
if (artifacts.userDefinedMetadata != null) {
model.setUserDefinedMetadata(artifacts.userDefinedMetadata);
}
if (artifacts.weightData != null) {
if (artifacts.weightSpecs == null) {
throw new ValueError('LayersModel artifacts contains weight data, but not weight specs. ' +
'Therefore loading of weights cannot proceed.');
}
const { modelWeights, optimizerWeights } = decodeModelAndOptimizerWeights(artifacts.weightData, artifacts.weightSpecs);
model.loadWeights(modelWeights, strict);
if (model.optimizer != null && optimizerWeights.length > 0) {
await model.optimizer.setWeights(optimizerWeights);
}
dispose(modelWeights);
dispose(optimizerWeights.map(w => w.tensor));
}
return model;
}
function decodeModelAndOptimizerWeights(weightData, specs) {
const name2Tensor = decodeWeights(weightData, specs);
const modelWeights = {};
const optimizerWeights = [];
specs.forEach(spec => {
if (spec.group === 'optimizer') {
optimizerWeights.push({ name: spec.name, tensor: name2Tensor[spec.name] });
}
else {
modelWeights[spec.name] = name2Tensor[spec.name];
}
});
return { modelWeights, optimizerWeights };
}
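// Sequential is a convenience wrapper around LayersModel for linear stacks of layers: add()
// wires each new layer to the previous output, and an internal LayersModel is built on demand
// for compile/fit/evaluate/predict.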
class Sequential extends LayersModel {
constructor(args) {
super({ inputs: [], outputs: [] });
args = args || {};
this.trainable = true;
this.built = false;
this.name = (args.name != null) ? args.name : getUid('sequential_');
if (args.layers != null) {
for (const layer of args.layers) {
this.add(layer);
}
}
}
checkShape(layer) {
const shape = layer.inboundNodes[0].outputTensors[0].shape;
if (shape.some(x => x < 0)) {
throw new ValueError('Negative dimension size caused by adding layer ' +
`${layer.name} with input shape [` +
`${layer.inboundNodes[0].inputTensors[0].shape}]`);
}
}
add(layer) {
const isLayerModelInstance = layer instanceof Sequential || layer instanceof LayersModel;
let modelLayer;
if (isLayerModelInstance) {
modelLayer = layer;
if (modelLayer.outputs.length !== 1) {
throw new ValueError('All layers in a Sequential model ' +
'should have a single output tensor. ' +
'For multi-output layers, ' +
'use the functional API.');
}
if (modelLayer.inputs.length !== 1) {
throw new ValueError('All layers in a Sequential model ' +
'should have a single input tensor. ' +
'For multi-input layers, ' +
'use the functional API.');
}
}
if (this.outputs.length === 0) {
if (layer.inboundNodes.length === 0) {
if (layer.batchInputShape == null) {
throw new ValueError('The first layer in a Sequential model must ' +
'get an `inputShape` or `batchInputShape` argument.');
}
const x = Input({
batchShape: layer.batchInputShape,
dtype: layer.dtype,
name: layer.name + '_input'
});
layer.apply(x);
}
if (isLayerModelInstance) {
this.outputs = modelLayer.outputs;
this.inputs = modelLayer.inputs;
}
else {
if (layer.inboundNodes.length !== 1) {
throw new ValueError('A layer added to a Sequential model must not already be ' +
`connected somewhere else. LayersModel received layer ${layer.name} ` +
`which has ${layer.inboundNodes.length} pre-existing inbound ` +
'connections.');
}
if (layer.inboundNodes[0].outputTensors.length !== 1) {
throw new ValueError('All layers in a Sequential model ' +
'should have a single output tensor. ' +
'For multi-output layers, ' +
'use the functional API.');
}
this.checkShape(layer);
this.outputs = [layer.inboundNodes[0].outputTensors[0]];
this.inputs = getSourceInputs(this.outputs[0]);
}
this.inboundNodes = [];
new Node({
outboundLayer: this,
inboundLayers: [],
nodeIndices: [],
tensorIndices: [],
inputTensors: this.inputs,
outputTensors: this.outputs,
inputMasks: pyListRepeat(null, this.inputs.length),
outputMasks: [null],
inputShapes: this.inputs.map(x => x.shape),
outputShapes: this.outputs[0].shape
});
}
else {
const outputTensor = layer.apply(this.outputs[0]);
if (Array.isArray(outputTensor)) {
throw new TypeError('All layers in a Sequential model ' +
'should have a single output tensor. ' +
'For multi-output layers, ' +
'use the functional API.');
}
this.checkShape(layer);
this.outputs = [outputTensor];
this.inboundNodes[0].outputTensors = this.outputs;
this.inboundNodes[0].outputShapes = [this.outputs[0].shape];
}
this.layers.push(layer);
this.built = false;
}
pop() {
if (this.layers.length === 0) {
throw new TypeError('There are no layers in the model.');
}
this.layers.pop();
if (this.layers.length === 0) {
this.outputs = [];
this.inboundNodes = [];
this.outboundNodes = [];
}
else {
const lastLayerIndex = this.layers.length - 1;
this.layers[lastLayerIndex].outboundNodes = [];
this.outputs = [this.layers[lastLayerIndex].output];
this.inboundNodes[0].outputTensors = this.outputs;
this.inboundNodes[0].outputShapes = [this.outputs[0].shape];
}
}
call(inputs, kwargs) {
if (this.model == null) {
this.build();
}
return this.model.call(inputs, kwargs);
}
build(inputShape) {
getExactlyOneShape(inputShape);
if (this.inputs.length === 0 || this.outputs.length === 0) {
throw new TypeError('Sequential model cannot be built: model is empty.' +
' Add some layers first.');
}
this.model = new LayersModel({
inputs: this.inputs,
outputs: this.outputs[0],
name: this.name + '_model'
});
this.model.trainable = this.trainable;
this.supportsMasking = this.model.supportsMasking;
this.inputLayers = this.model.inputLayers;
this.inputLayersNodeIndices = this.model.inputLayersNodeIndices;
this.inputLayersTensorIndices = this.model.inputLayersTensorIndices;
this.outputLayers = this.model.outputLayers;
this.outputLayersNodeIndices = this.model.outputLayersNodeIndices;
this.outputLayersTensorIndices = this.model.outputLayersTensorIndices;
this.nodesByDepth = this.model.nodesByDepth;
this.containerNodes = this.model.containerNodes;
this.outputNames = this.model.outputNames;
this.inputNames = this.model.inputNames;
this.built = true;
}
countParams() {
if (!this.built) {
this.build();
}
return super.countParams();
}
summary(lineLength, positions, printFn = console.log) {
if (!this.built) {
this.build();
}
super.summary(lineLength, positions, printFn);
}
setWeights(weights) {
if (this.model == null) {
this.build();
}
this.model.setWeights(weights);
}
evaluate(x, y, args = {}) {
if (!this.built) {
throw new RuntimeError('The model needs to be compiled before being used.');
}
return this.model.evaluate(x, y, args);
}
async evaluateDataset(dataset, args) {
if (!this.built) {
throw new RuntimeError('The model needs to be compiled before being used.');
}
return this.model.evaluateDataset(dataset, args);
}
predict(x, args = {}) {
if (this.model == null) {
this.build();
}
return this.model.predict(x, args);
}
predictOnBatch(x) {
if (this.model == null) {
this.build();
}
return this.model.predictOnBatch(x);
}
compile(args) {
this.build();
this.model.compile(args);
this.optimizer_ = this.model.optimizer;
this.isOptimizerOwned = this.model.isOptimizerOwned;
this.loss = this.model.loss;
this.metrics = this.model.metrics;
this.metricsTensors = this.model.metricsTensors;
this.metricsNames = this.model.metricsNames;
}
get optimizer() {
return this.model == null ? undefined : this.model.optimizer;
}
set optimizer(optimizer) {
this.model.optimizer = optimizer;
}
async fit(x, y, args = {}) {
if (!this.built) {
throw new RuntimeError('The model needs to be compiled before ' +
'being used.');
}
return this.model.fit(x, y, args);
}
async fitDataset(dataset, args) {
if (!this.built) {
throw new RuntimeError('The model needs to be compiled before ' +
'being used.');
}
return this.model.fitDataset(dataset, args);
}
async trainOnBatch(x, y) {
return this.model.trainOnBatch(x, y);
}
static fromConfig(cls, config, customObjects = {}, fastWeightInit = false) {
let configArray;
let extraModelConfig = {};
if (config instanceof Array) {
if (!(config[0].className != null) ||
config[0]['className'] === 'Merge') {
throw new ValueError('Legacy serialization format not supported yet.');
}
configArray = config;
}
else {
assert$1(config['layers'] != null, () => `When the config data for a Sequential model is not an Array, ` +
`it must be an Object that contains the 'layers' field.`);
configArray = config['layers'];
delete config['layers'];
extraModelConfig = config;
}
const model = new cls(extraModelConfig);
if (!(model instanceof Sequential)) {
throw new NotImplementedError(`Sequential.fromConfig called on non-Sequential input: ${model}`);
}
for (const conf of configArray) {
const customObjects = undefined;
const layer = deserialize(conf, customObjects, fastWeightInit);
if (fastWeightInit) {
layer.setFastWeightInitDuringBuild(true);
}
model.add(layer);
}
return model;
}
set stopTraining(stop) {
if (this.model == null) {
throw new ValueError('Cannot set the stopTraining property of a sequential model before ' +
'it is compiled.');
}
this.model.stopTraining = stop;
}
get stopTraining() {
if (this.model == null) {
throw new ValueError('Cannot get the stopTraining property of a sequential model before ' +
'it is compiled.');
}
return this.model.stopTraining;
}
getConfig() {
const layers = [];
for (const layer of this.layers) {
const dict = {};
dict['className'] = layer.getClassName();
dict['config'] = layer.getConfig();
layers.push(dict);
}
return { name: this.name, layers };
}
}
Sequential.className = 'Sequential';
registerClass(Sequential);
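// Factory for Sequential models. A minimal usage sketch (layer constructors such as
// tf.layers.dense come from the public tf namespace, not from this helper itself):
//   const model = tf.sequential();
//   model.add(tf.layers.dense({units: 1, inputShape: [4]}));
//   model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});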
function sequential(config) {
return new Sequential(config);
}
let Activation$1 = class Activation extends Serializable {
getConfig() {
return {};
}
};
class Elu extends Activation$1 {
apply(x, alpha = 1) {
return elu(x, alpha);
}
}
Elu.className = 'elu';
registerClass(Elu);
class Selu extends Activation$1 {
apply(x) {
return selu$2(x);
}
}
Selu.className = 'selu';
registerClass(Selu);
class Relu extends Activation$1 {
apply(x) {
return relu$2(x);
}
}
Relu.className = 'relu';
registerClass(Relu);
class Relu6 extends Activation$1 {
apply(x) {
return tidy(() => minimum$2(6.0, relu$2(x)));
}
}
Relu6.className = 'relu6';
registerClass(Relu6);
class Linear extends Activation$1 {
apply(x) {
return x;
}
}
Linear.className = 'linear';
registerClass(Linear);
class Sigmoid extends Activation$1 {
apply(x) {
return sigmoid$2(x);
}
}
Sigmoid.className = 'sigmoid';
registerClass(Sigmoid);
class HardSigmoid extends Activation$1 {
apply(x) {
return hardSigmoid(x);
}
}
HardSigmoid.className = 'hardSigmoid';
registerClass(HardSigmoid);
class Softplus extends Activation$1 {
apply(x) {
return softplus$2(x);
}
}
Softplus.className = 'softplus';
registerClass(Softplus);
class Softsign extends Activation$1 {
apply(x) {
return softsign(x);
}
}
Softsign.className = 'softsign';
registerClass(Softsign);
class Tanh extends Activation$1 {
apply(x) {
return tanh$2(x);
}
}
Tanh.className = 'tanh';
registerClass(Tanh);
class Softmax extends Activation$1 {
apply(x, axis = (-1)) {
return softmax$2(x, axis);
}
}
Softmax.className = 'softmax';
registerClass(Softmax);
class LogSoftmax extends Activation$1 {
apply(x, axis = (-1)) {
return logSoftmax(x, axis);
}
}
LogSoftmax.className = 'logSoftmax';
registerClass(LogSoftmax);
class Gelu extends Activation$1 {
apply(x) {
return tidy(() => {
// GELU(x) = x * Phi(x), with Phi the standard normal CDF computed via erf.
const sqrtTwo = Math.sqrt(2);
const cdf = mul(0.5, add$1(1, erf$2(div$1(x, sqrtTwo))));
return mul(x, cdf);
});
}
}
Gelu.className = 'gelu';
registerClass(Gelu);
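// Tanh-based GELU approximation:
//   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))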
class GeluNew extends Activation$1 {
apply(x) {
return tidy(() => {
return mul(0.5, mul(x, add$1(1, tanh$2(mul(sqrt$2(div$1(2, Math.PI)), add$1(x, mul(0.044715, pow$2(x, 3))))))));
});
}
}
GeluNew.className = 'gelu_new';
registerClass(GeluNew);
class Mish extends Activation$1 {
apply(x) {
return tidy(() => mul(x, tanh$2(softplus$2(x))));
}
}
Mish.className = 'mish';
registerClass(Mish);
class Swish extends Activation$1 {
apply(x, alpha = 1) {
return tidy(() => mul(sigmoid$2(mul(x, alpha)), x));
}
}
Swish.className = 'swish';
registerClass(Swish);
function serializeActivation(activation) {
return activation.getClassName();
}
function deserializeActivation(config, customObjects = {}) {
return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'activation');
}
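// Resolves an activation identifier: null maps to the linear activation, a
// string is treated as a registered className, an Activation instance is
// returned as-is, and any other object is deserialized as a config.
// Illustrative: getActivation('relu') and getActivation(new Relu()) both
// yield a Relu instance.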
function getActivation(identifier) {
if (identifier == null) {
const config = {};
config['className'] = 'linear';
config['config'] = {};
return deserializeActivation(config);
}
if (typeof identifier === 'string') {
const config = {};
config['className'] = identifier;
config['config'] = {};
return deserializeActivation(config);
}
else if (identifier instanceof Activation$1) {
return identifier;
}
else {
return deserializeActivation(identifier);
}
}
function assertObjectArgs(args) {
if (args != null && typeof args !== 'object') {
throw new Error(`Argument to L1L2 regularizer's constructor is expected to be an ` +
`object, but received: ${args}`);
}
}
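// Weight regularizers. L1L2 adds a penalty of l1 * sum(|x|) + l2 * sum(x^2)
// to the loss; when constructed without arguments both coefficients default
// to 0.01 (see the constructor below).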
class Regularizer extends Serializable {
}
class L1L2 extends Regularizer {
constructor(args) {
super();
assertObjectArgs(args);
this.l1 = args == null || args.l1 == null ? 0.01 : args.l1;
this.l2 = args == null || args.l2 == null ? 0.01 : args.l2;
this.hasL1 = this.l1 !== 0;
this.hasL2 = this.l2 !== 0;
}
apply(x) {
return tidy(() => {
let regularization = zeros$1([1]);
if (this.hasL1) {
regularization = add$1(regularization, sum$2(mul(this.l1, abs$2(x))));
}
if (this.hasL2) {
regularization =
add$1(regularization, sum$2(mul(this.l2, square(x))));
}
return reshape$2(regularization, []);
});
}
getConfig() {
return { 'l1': this.l1, 'l2': this.l2 };
}
static fromConfig(cls, config) {
return new cls({ l1: config['l1'], l2: config['l2'] });
}
}
L1L2.className = 'L1L2';
registerClass(L1L2);
const REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
'l1l2': 'L1L2'
};
function serializeRegularizer(constraint) {
return serializeKerasObject(constraint);
}
function deserializeRegularizer(config, customObjects = {}) {
return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'regularizer');
}
function getRegularizer(identifier) {
if (identifier == null) {
return null;
}
if (typeof identifier === 'string') {
const className = identifier in REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
identifier;
const config = { className, config: {} };
return deserializeRegularizer(config);
}
else if (identifier instanceof Regularizer) {
return identifier;
}
else {
return deserializeRegularizer(identifier);
}
}
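// Dropout layer: during training (inTrainPhase) the input is passed through
// dropout$1 with the configured rate (clamped to [0, 1]) and optional
// noiseShape; at inference time the input is returned unchanged.
// Illustrative usage: dropout({ rate: 0.25 }).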
class Dropout extends Layer {
constructor(args) {
super(args);
this.rate = Math.max(Math.min(args.rate, 1), 0);
this.noiseShape = args.noiseShape;
this.seed = args.seed;
this.supportsMasking = true;
}
getNoiseShape(input) {
if (this.noiseShape == null) {
return this.noiseShape;
}
const inputShape = input.shape;
const noiseShape = [];
for (let i = 0; i < this.noiseShape.length; ++i) {
noiseShape.push(this.noiseShape[i] == null ? inputShape[i] : this.noiseShape[i]);
}
return noiseShape;
}
call(inputs, kwargs) {
return tidy(() => {
this.invokeCallHook(inputs, kwargs);
const input = getExactlyOneTensor(inputs);
if (0 < this.rate && this.rate < 1) {
const training = kwargs['training'] == null ? false : kwargs['training'];
const noiseShape = this.getNoiseShape(input);
const output = inTrainPhase(() => dropout$1(input, this.rate, noiseShape, this.seed), () => input, training);
return output;
}
return inputs;
});
}
getConfig() {
const config = {
rate: this.rate,
noiseShape: this.noiseShape,
seed: this.seed,
};
const baseConfig = super.getConfig();
Object.assign(config, baseConfig);
return config;
}
dispose() {
return super.dispose();
}
}
Dropout.className = 'Dropout';
registerClass(Dropout);
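// SpatialDropout1D uses a noise shape of [batch, 1, channels], so entire
// feature channels are dropped across all timesteps of a 3D input.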
class SpatialDropout1D extends Dropout {
constructor(args) {
super(args);
this.inputSpec = [{ ndim: 3 }];
}
getNoiseShape(input) {
const inputShape = input.shape;
return [inputShape[0], 1, inputShape[2]];
}
}
SpatialDropout1D.className = 'SpatialDropout1D';
registerClass(SpatialDropout1D);
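// Dense (fully connected) layer: output = activation(dot(input, kernel) + bias),
// with kernel shape [inputLastDim, units]. When the activation maps to a fused
// kernel, dot() applies bias and activation in a single call.
// Illustrative usage: dense({ units: 10, activation: 'softmax' }).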
class Dense extends Layer {
constructor(args) {
super(args);
this.activation = null;
this.useBias = true;
this.kernel = null;
this.bias = null;
this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
this.DEFAULT_BIAS_INITIALIZER = 'zeros';
if (args.batchInputShape == null && args.inputShape == null &&
args.inputDim != null) {
let batchSize = null;
if (args.batchSize != null) {
batchSize = args.batchSize;
}
this.batchInputShape = [batchSize, args.inputDim];
}
this.units = args.units;
assertPositiveInteger(this.units, 'units');
this.activation = getActivation(args.activation);
if (args.useBias != null) {
this.useBias = args.useBias;
}
this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
this.biasInitializer =
getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
this.kernelConstraint = getConstraint(args.kernelConstraint);
this.biasConstraint = getConstraint(args.biasConstraint);
this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
this.biasRegularizer = getRegularizer(args.biasRegularizer);
this.activityRegularizer = getRegularizer(args.activityRegularizer);
this.supportsMasking = true;
this.inputSpec = [{ minNDim: 2 }];
}
build(inputShape) {
inputShape = getExactlyOneShape(inputShape);
const inputLastDim = inputShape[inputShape.length - 1];
if (this.kernel == null) {
this.kernel = this.addWeight('kernel', [inputLastDim, this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
if (this.useBias) {
this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
}
}
this.inputSpec = [{ minNDim: 2, axes: { [-1]: inputLastDim } }];
this.built = true;
}
computeOutputShape(inputShape) {
inputShape = getExactlyOneShape(inputShape);
const outputShape = inputShape.slice();
outputShape[outputShape.length - 1] = this.units;
return outputShape;
}
call(inputs, kwargs) {
return tidy(() => {
this.invokeCallHook(inputs, kwargs);
const input = getExactlyOneTensor(inputs);
const fusedActivationName = mapActivationToFusedKernel(this.activation.getClassName());
let output;
if (fusedActivationName != null) {
output = dot(input, this.kernel.read(), fusedActivationName, this.bias ? this.bias.read() : null);
}
else {
output = dot(input, this.kernel.read());
if (this.bias != null) {
output = biasAdd(output, this.bias.read());
}
if (this.activation != null) {
output = this.activation.apply(output);
}
}
return output;
});
}
getConfig() {
const config = {
units: this.units,
activation: serializeActivation(this.activation),
useBias: this.useBias,
kernelInitializer: serializeInitializer(this.kernelInitializer),
biasInitializer: serializeInitializer(this.biasInitializer),
kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
biasRegularizer: serializeRegularizer(this.biasRegularizer),
activityRegularizer: serializeRegularizer(this.activityRegularizer),
kernelConstraint: serializeConstraint(this.kernelConstraint),
biasConstraint: serializeConstraint(this.biasConstraint)
};
const baseConfig = super.getConfig();
Object.assign(config, baseConfig);
return config;
}
}
Dense.className = 'Dense';
registerClass(Dense);
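// Flatten collapses all non-batch dimensions into one; for 'channelsFirst'
// data the channel axis is moved last before flattening.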
class Flatten extends Layer {
constructor(args) {
args = args || {};
super(args);
this.inputSpec = [{ minNDim: 3 }];
this.dataFormat = args.dataFormat;
}
computeOutputShape(inputShape) {
inputShape = getExactlyOneShape(inputShape);
for (const dim of inputShape.slice(1)) {
if (dim == null) {
throw new ValueError(`The shape of the input to "Flatten" is not fully defined ` +
`(got ${inputShape.slice(1)}). Make sure to pass a complete ` +
`"input_shape" or "batch_input_shape" argument to the first ` +
`layer in your model.`);
}
}
return [inputShape[0], arrayProd(inputShape, 1)];
}
call(inputs, kwargs) {
return tidy(() => {
this.invokeCallHook(inputs, kwargs);
let input = getExactlyOneTensor(inputs);
if (this.dataFormat === 'channelsFirst' && input.rank > 1) {
const permutation = [0];
for (let i = 2; i < input.rank; ++i) {
permutation.push(i);
}
permutation.push(1);
input = transpose$2(input, permutation);
}
return batchFlatten(input);
});
}
getConfig() {
const config = {};
if (this.dataFormat != null) {
config['dataFormat'] = this.dataFormat;
}
const baseConfig = super.getConfig();
Object.assign(config, baseConfig);
return config;
}
}
Flatten.className = 'Flatten';
registerClass(Flatten);
class Activation extends Layer {
constructor(args) {
super(args);
this.supportsMasking = true;
this.activation = getActivation(args.activation);
}
call(inputs, kwargs) {
return tidy(() => {
this.invokeCallHook(inputs, kwargs);
const input = getExactlyOneTensor(inputs);
return this.activation.apply(input);
});
}
getConfig() {
const config = { activation: serializeActivation(this.activation) };
const baseConfig = super.getConfig();
Object.assign(config, baseConfig);
return config;
}
}
Activation.className = 'Activation';
registerClass(Activation);
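// RepeatVector turns a [batch, features] input into [batch, n, features] by
// repeating it n times along a new axis.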
class RepeatVector extends Layer {
constructor(args) {
super(args);
this.n = args.n;
this.inputSpec = [{ ndim: 2 }];
}
computeOutputShape(inputShape) {
return [inputShape[0], this.n, inputShape[1]];
}
call(inputs, kwargs) {
return tidy(() => {
inputs = getExactlyOneTensor(inputs);
return repeat(inputs, this.n);
});
}
getConfig() {
const config = {
n: this.n,
};
const baseConfig = super.getConfig();
Object.assign(config, baseConfig);
return config;
}
}
RepeatVector.className = 'RepeatVector';
registerClass(RepeatVector);
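// Reshape layer: at most one target dimension may be unknown (null or negative);
// fixUnknownDimension infers it from the input's total size. Illustrative:
// an input of shape [batch, 6] with targetShape [2, -1] becomes [batch, 2, 3].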
class Reshape extends Layer {
constructor(args) {
super(args);
this.targetShape = args.targetShape;
for (let i = 0; i < this.targetShape.length; ++i) {
if (this.isUnknown(this.targetShape[i])) {
this.targetShape[i] = null;
}
}
}
isUnknown(dim) {
return dim < 0 || dim == null;
}
fixUnknownDimension(inputShape, outputShape) {
const errorMsg = 'Total size of new array must be unchanged.';
const finalShape = outputShape.slice();
let known = 1;
let unknown = null;
for (let i = 0; i < finalShape.length; ++i) {
const dim = finalShape[i];
if (this.isUnknown(dim)) {
if (unknown === null) {
unknown = i;
}
else {
throw new ValueError('Can only specify one unknown dimension.');
}
}
else {
known *= dim;
}
}
const originalSize = arrayProd(inputShape);
if (unknown !== null) {
if (known === 0 || originalSize % known !== 0) {
throw new ValueError(errorMsg);
}
finalShape[unknown] = originalSize / known;
}
else if (originalSize !== known) {
throw new ValueError(errorMsg);
}
return finalShape;
}
computeOutputShape(inputShape) {
let anyUnknownDims = false;
for (let i = 0; i < inputShape.length; ++i) {
if (this.isUnknown(inputShape[i])) {
anyUnknownDims = true;
break;
}
}
if (anyUnknownDims) {
return inputShape.slice(0, 1).concat(this.targetShape);
}
else {
return inputShape.slice(0, 1).concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape));
}
}
call(inputs, kwargs) {
return tidy(() => {
this.invokeCallHook(inputs, kwargs);
const input = getExactlyOneTensor(inputs);
const inputShape = input.shape;
const outputShape = inputShape.slice(0, 1).concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape));
return reshape$2(input, outputShape);
});
}
getConfig() {
const config = {
targetShape: this.targetShape,
};
const baseConfig = super.getConfig();
Object.assign(config, baseConfig);
return config;
}
}
Reshape.className = 'Reshape';
registerClass(Reshape);
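// Permute reorders the non-batch axes. `dims` is 1-based over those axes
// (the batch axis is prepended internally), so dims: [2, 1] swaps the two
// non-batch dimensions of a 3D input.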
class Permute extends Layer {
constructor(args) {
super(args);
if (args.dims == null) {
throw new Error('Required configuration field `dims` is missing during Permute ' +
'constructor call.');
}
if (!Array.isArray(args.dims)) {
throw new Error('Permute constructor requires `dims` to be an Array, but received ' +
`${args.dims} instead.`);
}
const expectedSortedIndices = range(1, args.dims.length + 1);
if (!arraysEqual(args.dims.slice().sort(), expectedSortedIndices)) {
throw new Error('Invalid permutation `dims`: ' + JSON.stringify(args.dims) +
' `dims` must contain consecutive integers starting from 1.');
}
this.dims = args.dims;
this.dimsIncludingBatch = [0].concat(this.dims);
this.inputSpec = [new InputSpec({ ndim: this.dims.length + 1 })];
}
computeOutputShape(inputShape) {
inputShape = getExactlyOneShape(inputShape);
const outputShape = inputShape.slice();
this.dims.forEach((dim, i) => {
outputShape[i + 1] = inputShape[dim];
});
return outputShape;
}
call(inputs, kwargs) {
return transpose$2(getExactlyOneTensor(inputs), this.dimsIncludingBatch);
}
getConfig() {
const config = {
dims: this.dims,
};
const baseConfig = super.getConfig();
Object.assign(config, baseConfig);
return config;
}
}
Permute.className = 'Permute';
registerClass(Permute);
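// Masking zeroes out timesteps whose values all equal maskValue along the last
// axis and exposes the corresponding boolean mask via computeMask().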
class Masking extends Layer {
constructor(args) {
super(args == null ? {} : args);
this.supportsMasking = true;
if (args != null) {
this.maskValue = args.maskValue == null ? 0 : args.maskValue;
}
else {
this.maskValue = 0;
}
}
computeOutputShape(inputShape) {
return inputShape;
}
getConfig() {
const baseConfig = super.getConfig();
const config = { maskValue: this.maskValue };
Object.assign(config, baseConfig);
return config;
}
computeMask(inputs, mask) {
const input = getExactlyOneTensor(inputs);
const axis = -1;
return any$2(notEqual$2(input, this.maskValue), axis);
}
call(inputs, kwargs) {
return tidy(() => {
this.invokeCallHook(inputs, kwargs);
const input = getExactlyOneTensor(inputs);
const axis = -1;
const keepDims = true;
const booleanMask = any$2(notEqual$2(input, this.maskValue), axis, keepDims);
const output = mul(input, cast$3(booleanMask, input.dtype));
return output;
});
}
}
Masking.className = 'Masking';
registerClass(Masking);
function dense(args) {
return new Dense(args);
}
function dropout(args) {
return new Dropout(args);
}
export { LayersModel, PlatformStub, dense, dropout, enableProdMode, env, fromMemory, glorotUniform, loadLayersModelFromIOHandler, sequential, stringToHashBucketFast$2 as stringToHashBucketFast, tensor1d, tensor2d, withSaveHandler };