function _mergeNamespaces(n, m) { m.forEach(function (e) { e && typeof e !== 'string' && !Array.isArray(e) && Object.keys(e).forEach(function (k) { if (k !== 'default' && !(k in n)) { var d = Object.getOwnPropertyDescriptor(e, k); Object.defineProperty(n, k, d.get ? d : { enumerable: true, get: function () { return e[k]; } }); } }); }); return Object.freeze(n); } const EPSILON_FLOAT32$1 = 1e-7; const EPSILON_FLOAT16$1 = 1e-4; class DataStorage { constructor(backend, dataMover) { this.backend = backend; this.dataMover = dataMover; this.data = new WeakMap(); this.dataIdsCount = 0; } get(dataId) { if (!this.data.has(dataId)) { this.dataMover.moveData(this.backend, dataId); } return this.data.get(dataId); } set(dataId, value) { this.dataIdsCount++; this.data.set(dataId, value); } has(dataId) { return this.data.has(dataId); } delete(dataId) { this.dataIdsCount--; return this.data.delete(dataId); } numDataIds() { return this.dataIdsCount; } } class KernelBackend { refCount(dataId) { return notYetImplemented('refCount'); } incRef(dataId) { return notYetImplemented('incRef'); } timerAvailable() { return true; } time(f) { return notYetImplemented('time'); } read(dataId) { return notYetImplemented('read'); } readSync(dataId) { return notYetImplemented('readSync'); } readToGPU(dataId, options) { return notYetImplemented('readToGPU'); } numDataIds() { return notYetImplemented('numDataIds'); } disposeData(dataId, force) { return notYetImplemented('disposeData'); } write(values, shape, dtype) { return notYetImplemented('write'); } move(dataId, values, shape, dtype, refCount) { return notYetImplemented('move'); } createTensorFromGPUData(values, shape, dtype) { return notYetImplemented('createTensorFromGPUData'); } memory() { return notYetImplemented('memory'); } floatPrecision() { return notYetImplemented('floatPrecision'); } epsilon() { return this.floatPrecision() === 32 ? 
EPSILON_FLOAT32$1 : EPSILON_FLOAT16$1; } dispose() { return notYetImplemented('dispose'); } } function notYetImplemented(kernelName) { throw new Error(`'${kernelName}' not yet implemented or not found in the registry. ` + `This kernel may not be supported by the tfjs backend you have chosen`); } function shuffle(array) { let counter = array.length; let index = 0; while (counter > 0) { index = (Math.random() * counter) | 0; counter--; swap(array, counter, index); } } function clamp(min, x, max) { return Math.max(min, Math.min(x, max)); } function nearestLargerEven(val) { return val % 2 === 0 ? val : val + 1; } function swap(object, left, right) { const temp = object[left]; object[left] = object[right]; object[right] = temp; } function sum$2(arr) { let sum = 0; for (let i = 0; i < arr.length; i++) { sum += arr[i]; } return sum; } function assert$1(expr, msg) { if (!expr) { throw new Error(typeof msg === 'string' ? msg : msg()); } } function assertShapesMatch(shapeA, shapeB, errorMessagePrefix = '') { assert$1(arraysEqual(shapeA, shapeB), () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`); } function assertNonNull(a) { assert$1(a != null, () => `The input to the tensor constructor must be a non-null value.`); } function sizeFromShape(shape) { if (shape.length === 0) { return 1; } let size = shape[0]; for (let i = 1; i < shape.length; i++) { size *= shape[i]; } return size; } function arraysEqual(n1, n2) { if (n1 === n2) { return true; } if (n1 == null || n2 == null) { return false; } if (n1.length !== n2.length) { return false; } for (let i = 0; i < n1.length; i++) { if (n1[i] !== n2[i]) { return false; } } return true; } function isInt(a) { return a % 1 === 0; } function sizeToSquarishShape(size) { const width = Math.ceil(Math.sqrt(size)); return [width, Math.ceil(size / width)]; } function rightPad(a, size) { if (size <= a.length) { return a; } return a + ' '.repeat(size - a.length); } function repeatedTry(checkFn, delayFn = (counter) => 0, 
maxCounter, scheduleFn) { return new Promise((resolve, reject) => { let tryCount = 0; const tryFn = () => { if (checkFn()) { resolve(); return; } tryCount++; const nextBackoff = delayFn(tryCount); if (maxCounter != null && tryCount >= maxCounter) { reject(); return; } if (scheduleFn != null) { scheduleFn(tryFn, nextBackoff); } else { setTimeout(tryFn, nextBackoff); } }; tryFn(); }); } function inferFromImplicitShape(shape, size) { let shapeProd = 1; let implicitIdx = -1; for (let i = 0; i < shape.length; ++i) { if (shape[i] >= 0) { shapeProd *= shape[i]; } else if (shape[i] === -1) { if (implicitIdx !== -1) { throw Error(`Shapes can only have 1 implicit size. ` + `Found -1 at dim ${implicitIdx} and dim ${i}`); } implicitIdx = i; } else if (shape[i] < 0) { throw Error(`Shapes can not be < 0. Found ${shape[i]} at dim ${i}`); } } if (implicitIdx === -1) { if (size > 0 && size !== shapeProd) { throw Error(`Size(${size}) must match the product of shape ${shape}`); } return shape; } if (shapeProd === 0) { throw Error(`Cannot infer the missing size in [${shape}] when ` + `there are 0 elements`); } if (size % shapeProd !== 0) { throw Error(`The implicit shape can't be a fractional number. ` + `Got ${size} / ${shapeProd}`); } const newShape = shape.slice(); newShape[implicitIdx] = size / shapeProd; return newShape; } function parseAxisParam(axis, shape) { const rank = shape.length; axis = axis == null ? shape.map((s, i) => i) : [].concat(axis); assert$1(axis.every(ax => ax >= -rank && ax < rank), () => `All values in axis param must be in range [-${rank}, ${rank}) but ` + `got axis ${axis}`); assert$1(axis.every(ax => isInt(ax)), () => `All values in axis param must be integers but ` + `got axis ${axis}`); return axis.map(a => a < 0 ? rank + a : a); } function squeezeShape(shape, axis) { const newShape = []; const keptDims = []; const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0; const axes = (axis == null || isEmptyArray) ? 
null : parseAxisParam(axis, shape).sort(); let j = 0; for (let i = 0; i < shape.length; ++i) { if (axes != null) { if (axes[j] === i && shape[i] !== 1) { throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`); } if ((axes[j] == null || axes[j] > i) && shape[i] === 1) { newShape.push(shape[i]); keptDims.push(i); } if (axes[j] <= i) { j++; } } if (shape[i] !== 1) { newShape.push(shape[i]); keptDims.push(i); } } return { newShape, keptDims }; } function getTypedArrayFromDType(dtype, size) { return getArrayFromDType(dtype, size); } function getArrayFromDType(dtype, size) { let values = null; if (dtype == null || dtype === 'float32') { values = new Float32Array(size); } else if (dtype === 'int32') { values = new Int32Array(size); } else if (dtype === 'bool') { values = new Uint8Array(size); } else if (dtype === 'string') { values = new Array(size); } else { throw new Error(`Unknown data type ${dtype}`); } return values; } function checkConversionForErrors(vals, dtype) { for (let i = 0; i < vals.length; i++) { const num = vals[i]; if (isNaN(num) || !isFinite(num)) { throw Error(`A tensor of type ${dtype} being uploaded contains ${num}.`); } } } function isValidDtype(dtype) { return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' || dtype === 'int32' || dtype === 'string'; } function hasEncodingLoss(oldType, newType) { if (newType === 'complex64') { return false; } if (newType === 'float32' && oldType !== 'complex64') { return false; } if (newType === 'int32' && oldType !== 'float32' && oldType !== 'complex64') { return false; } if (newType === 'bool' && oldType === 'bool') { return false; } return true; } function bytesPerElement(dtype) { if (dtype === 'float32' || dtype === 'int32') { return 4; } else if (dtype === 'complex64') { return 8; } else if (dtype === 'bool') { return 1; } else { throw new Error(`Unknown dtype ${dtype}`); } } function bytesFromStringArray(arr) { if (arr == null) { return 0; } let bytes = 0; 
arr.forEach(x => bytes += x.length); return bytes; } function isString(value) { return typeof value === 'string' || value instanceof String; } function isBoolean(value) { return typeof value === 'boolean'; } function isNumber(value) { return typeof value === 'number'; } function inferDtype(values) { if (Array.isArray(values)) { return inferDtype(values[0]); } if (values instanceof Float32Array) { return 'float32'; } else if (values instanceof Int32Array || values instanceof Uint8Array || values instanceof Uint8ClampedArray) { return 'int32'; } else if (isNumber(values)) { return 'float32'; } else if (isString(values)) { return 'string'; } else if (isBoolean(values)) { return 'bool'; } return 'float32'; } function isFunction(f) { return !!(f && f.constructor && f.call && f.apply); } function nearestDivisor(size, start) { for (let i = start; i < size; ++i) { if (size % i === 0) { return i; } } return size; } function computeStrides(shape) { const rank = shape.length; if (rank < 2) { return []; } const strides = new Array(rank - 1); strides[rank - 2] = shape[rank - 1]; for (let i = rank - 3; i >= 0; --i) { strides[i] = strides[i + 1] * shape[i + 1]; } return strides; } function createNestedArray(offset, shape, a, isComplex = false) { const ret = new Array(); if (shape.length === 1) { const d = shape[0] * (isComplex ? 2 : 1); for (let i = 0; i < d; i++) { ret[i] = a[offset + i]; } } else { const d = shape[0]; const rest = shape.slice(1); const len = rest.reduce((acc, c) => acc * c) * (isComplex ? 2 : 1); for (let i = 0; i < d; i++) { ret[i] = createNestedArray(offset + i * len, rest, a, isComplex); } } return ret; } function toNestedArray(shape, a, isComplex = false) { if (shape.length === 0) { return a[0]; } const size = shape.reduce((acc, c) => acc * c) * (isComplex ? 2 : 1); if (size === 0) { return []; } if (size !== a.length) { throw new Error(`[${shape}] does not match the input size ${a.length}${isComplex ? 
' for a complex tensor' : ''}.`); } return createNestedArray(0, shape, a, isComplex); } function makeOnesTypedArray(size, dtype) { const array = makeZerosTypedArray(size, dtype); for (let i = 0; i < array.length; i++) { array[i] = 1; } return array; } function makeZerosTypedArray(size, dtype) { if (dtype == null || dtype === 'float32' || dtype === 'complex64') { return new Float32Array(size); } else if (dtype === 'int32') { return new Int32Array(size); } else if (dtype === 'bool') { return new Uint8Array(size); } else { throw new Error(`Unknown data type ${dtype}`); } } function assertNonNegativeIntegerDimensions(shape) { shape.forEach(dimSize => { assert$1(Number.isInteger(dimSize) && dimSize >= 0, () => `Tensor must have a shape comprised of positive integers but got ` + `shape [${shape}].`); }); } function locToIndex(locs, rank, strides) { if (rank === 0) { return 0; } else if (rank === 1) { return locs[0]; } let index = locs[locs.length - 1]; for (let i = 0; i < locs.length - 1; ++i) { index += strides[i] * locs[i]; } return index; } function indexToLoc(index, rank, strides) { if (rank === 0) { return []; } else if (rank === 1) { return [index]; } const locs = new Array(rank); for (let i = 0; i < locs.length - 1; ++i) { locs[i] = Math.floor(index / strides[i]); index -= locs[i] * strides[i]; } locs[locs.length - 1] = index; return locs; } function isPromise(object) { return object && object.then && typeof object.then === 'function'; } const TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags'; class Environment { constructor(global) { this.global = global; this.flags = {}; this.flagRegistry = {}; this.urlFlags = {}; this.getQueryParams = getQueryParams; this.populateURLFlags(); } setPlatform(platformName, platform) { if (this.platform != null) { if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) { console.warn(`Platform ${this.platformName} has already been set. 
` + `Overwriting the platform with ${platformName}.`); } } this.platformName = platformName; this.platform = platform; } registerFlag(flagName, evaluationFn, setHook) { this.flagRegistry[flagName] = { evaluationFn, setHook }; if (this.urlFlags[flagName] != null) { const flagValue = this.urlFlags[flagName]; if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) { console.warn(`Setting feature override from URL ${flagName}: ${flagValue}.`); } this.set(flagName, flagValue); } } async getAsync(flagName) { if (flagName in this.flags) { return this.flags[flagName]; } this.flags[flagName] = await this.evaluateFlag(flagName); return this.flags[flagName]; } get(flagName) { if (flagName in this.flags) { return this.flags[flagName]; } const flagValue = this.evaluateFlag(flagName); if (isPromise(flagValue)) { throw new Error(`Flag ${flagName} cannot be synchronously evaluated. ` + `Please use getAsync() instead.`); } this.flags[flagName] = flagValue; return this.flags[flagName]; } getNumber(flagName) { return this.get(flagName); } getBool(flagName) { return this.get(flagName); } getString(flagName) { return this.get(flagName); } getFlags() { return this.flags; } get features() { return this.flags; } set(flagName, value) { if (this.flagRegistry[flagName] == null) { throw new Error(`Cannot set flag ${flagName} as it has not been registered.`); } this.flags[flagName] = value; if (this.flagRegistry[flagName].setHook != null) { this.flagRegistry[flagName].setHook(value); } } evaluateFlag(flagName) { if (this.flagRegistry[flagName] == null) { throw new Error(`Cannot evaluate flag '${flagName}': no evaluation function found.`); } return this.flagRegistry[flagName].evaluationFn(); } setFlags(flags) { this.flags = Object.assign({}, flags); } reset() { this.flags = {}; this.urlFlags = {}; this.populateURLFlags(); } populateURLFlags() { if (typeof this.global === 'undefined' || typeof this.global.location === 'undefined' || typeof this.global.location.search === 'undefined') { return; 
} const urlParams = this.getQueryParams(this.global.location.search); if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) { const keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(','); keyValues.forEach(keyValue => { const [key, value] = keyValue.split(':'); this.urlFlags[key] = parseValue(key, value); }); } } } function getQueryParams(queryString) { const params = {}; queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (s, ...t) => { decodeParam(params, t[0], t[1]); return t.join('='); }); return params; } function decodeParam(params, name, value) { params[decodeURIComponent(name)] = decodeURIComponent(value || ''); } function parseValue(flagName, value) { const lowerCaseValue = value.toLowerCase(); if (lowerCaseValue === 'true' || lowerCaseValue === 'false') { return lowerCaseValue === 'true'; } else if (`${+lowerCaseValue}` === lowerCaseValue) { return +lowerCaseValue; } else { return value; } } function env() { return ENV$2; } let ENV$2 = null; function setEnvironmentGlobal(environment) { ENV$2 = environment; } let globalNameSpace; function getGlobalNamespace() { if (globalNameSpace == null) { let ns; if (typeof (window) !== 'undefined') { ns = window; } else if (typeof (global) !== 'undefined') { ns = global; } else if (typeof (process) !== 'undefined') { ns = process; } else if (typeof (self) !== 'undefined') { ns = self; } else { throw new Error('Could not find a global object'); } globalNameSpace = ns; } return globalNameSpace; } function getGlobalMap() { const ns = getGlobalNamespace(); if (ns._tfGlobals == null) { ns._tfGlobals = new Map(); } return ns._tfGlobals; } function getGlobal(key, init) { const globalMap = getGlobalMap(); if (globalMap.has(key)) { return globalMap.get(key); } else { const singleton = init(); globalMap.set(key, singleton); return globalMap.get(key); } } const Abs = 'Abs'; const Acos = 'Acos'; const Acosh = 'Acosh'; const Add = 'Add'; const AddN = 'AddN'; const All = 'All'; const Any = 'Any'; const ArgMax = 'ArgMax'; const ArgMin = 
'ArgMin'; const Asin = 'Asin'; const Asinh = 'Asinh'; const Atan = 'Atan'; const Atanh = 'Atanh'; const Atan2 = 'Atan2'; const AvgPool = 'AvgPool'; const AvgPoolGrad = 'AvgPoolGrad'; const AvgPool3D = 'AvgPool3D'; const AvgPool3DGrad = 'AvgPool3DGrad'; const BatchMatMul = 'BatchMatMul'; const BatchToSpaceND = 'BatchToSpaceND'; const Bincount = 'Bincount'; const BitwiseAnd = 'BitwiseAnd'; const BroadcastTo = 'BroadcastTo'; const BroadcastArgs = 'BroadcastArgs'; const Cast = 'Cast'; const Ceil = 'Ceil'; const ClipByValue = 'ClipByValue'; const Complex = 'Complex'; const ComplexAbs = 'ComplexAbs'; const Concat = 'Concat'; const Conv2D = 'Conv2D'; const Conv2DBackpropFilter = 'Conv2DBackpropFilter'; const Conv2DBackpropInput = 'Conv2DBackpropInput'; const Conv3D = 'Conv3D'; const Conv3DBackpropFilterV2 = 'Conv3DBackpropFilterV2'; const Conv3DBackpropInputV2 = 'Conv3DBackpropInputV2'; const Cos = 'Cos'; const Cosh = 'Cosh'; const Cumprod = 'Cumprod'; const Cumsum = 'Cumsum'; const CropAndResize = 'CropAndResize'; const DenseBincount = 'DenseBincount'; const DepthToSpace = 'DepthToSpace'; const DepthwiseConv2dNative = 'DepthwiseConv2dNative'; const DepthwiseConv2dNativeBackpropFilter = 'DepthwiseConv2dNativeBackpropFilter'; const DepthwiseConv2dNativeBackpropInput = 'DepthwiseConv2dNativeBackpropInput'; const Diag = 'Diag'; const Dilation2D = 'Dilation2D'; const Dilation2DBackpropInput = 'Dilation2DBackpropInput'; const Dilation2DBackpropFilter = 'Dilation2DBackpropFilter'; const RealDiv = 'RealDiv'; const Einsum = 'Einsum'; const Elu$1 = 'Elu'; const EluGrad = 'EluGrad'; const Erf = 'Erf'; const Equal = 'Equal'; const Exp = 'Exp'; const ExpandDims = 'ExpandDims'; const Expm1 = 'Expm1'; const FFT = 'FFT'; const Fill = 'Fill'; const FlipLeftRight = 'FlipLeftRight'; const Floor = 'Floor'; const FloorDiv = 'FloorDiv'; const FusedBatchNorm = 'FusedBatchNorm'; const GatherV2 = 'GatherV2'; const GatherNd = 'GatherNd'; const Greater = 'Greater'; const GreaterEqual = 
'GreaterEqual'; const Identity$1 = 'Identity'; const IFFT = 'IFFT'; const Imag = 'Imag'; const IsFinite = 'IsFinite'; const IsInf = 'IsInf'; const IsNan = 'IsNan'; const LeakyRelu = 'LeakyRelu'; const Less = 'Less'; const LessEqual = 'LessEqual'; const LinSpace = 'LinSpace'; const Log = 'Log'; const Log1p = 'Log1p'; const LogicalAnd = 'LogicalAnd'; const LogicalNot = 'LogicalNot'; const LogicalOr = 'LogicalOr'; const LogSoftmax$1 = 'LogSoftmax'; const LRN = 'LRN'; const LRNGrad = 'LRNGrad'; const Max = 'Max'; const Maximum = 'Maximum'; const MaxPool = 'MaxPool'; const MaxPoolGrad = 'MaxPoolGrad'; const MaxPool3D = 'MaxPool3D'; const MaxPool3DGrad = 'MaxPool3DGrad'; const MaxPoolWithArgmax = 'MaxPoolWithArgmax'; const Mean = 'Mean'; const Min = 'Min'; const Minimum = 'Minimum'; const MirrorPad = 'MirrorPad'; const Mod = 'Mod'; const Multinomial = 'Multinomial'; const Multiply = 'Multiply'; const Neg = 'Neg'; const NotEqual = 'NotEqual'; const NonMaxSuppressionV3 = 'NonMaxSuppressionV3'; const NonMaxSuppressionV4 = 'NonMaxSuppressionV4'; const NonMaxSuppressionV5 = 'NonMaxSuppressionV5'; const OnesLike = 'OnesLike'; const OneHot = 'OneHot'; const Pack = 'Pack'; const PadV2 = 'PadV2'; const Pow = 'Pow'; const Prelu = 'Prelu'; const Prod = 'Prod'; const RaggedGather = 'RaggedGather'; const RaggedRange = 'RaggedRange'; const RaggedTensorToTensor = 'RaggedTensorToTensor'; const Range = 'Range'; const Real = 'Real'; const Reciprocal = 'Reciprocal'; const Relu$1 = 'Relu'; const Reshape$1 = 'Reshape'; const ResizeNearestNeighbor = 'ResizeNearestNeighbor'; const ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad'; const ResizeBilinear = 'ResizeBilinear'; const ResizeBilinearGrad = 'ResizeBilinearGrad'; const Relu6$1 = 'Relu6'; const Reverse = 'Reverse'; const Round = 'Round'; const Rsqrt = 'Rsqrt'; const ScatterNd = 'ScatterNd'; const TensorScatterUpdate = 'TensorScatterUpdate'; const SearchSorted = 'SearchSorted'; const Select = 'Select'; const Selu$1 = 'Selu'; const 
Slice = 'Slice'; const Sin = 'Sin'; const Sinh = 'Sinh'; const Sign = 'Sign'; const Sigmoid$1 = 'Sigmoid'; const Softplus$1 = 'Softplus'; const Sqrt = 'Sqrt'; const Sum = 'Sum'; const SpaceToBatchND = 'SpaceToBatchND'; const SplitV = 'SplitV'; const Softmax$1 = 'Softmax'; const SparseFillEmptyRows = 'SparseFillEmptyRows'; const SparseReshape = 'SparseReshape'; const SparseSegmentMean = 'SparseSegmentMean'; const SparseSegmentSum = 'SparseSegmentSum'; const SparseToDense = 'SparseToDense'; const SquaredDifference = 'SquaredDifference'; const Square = 'Square'; const StaticRegexReplace = 'StaticRegexReplace'; const StridedSlice = 'StridedSlice'; const StringNGrams = 'StringNGrams'; const StringSplit = 'StringSplit'; const StringToHashBucketFast = 'StringToHashBucketFast'; const Sub = 'Sub'; const Tan = 'Tan'; const Tanh$1 = 'Tanh'; const Tile = 'Tile'; const TopK = 'TopK'; const Transform = 'Transform'; const Transpose = 'Transpose'; const Unique = 'Unique'; const Unpack = 'Unpack'; const UnsortedSegmentSum = 'UnsortedSegmentSum'; const ZerosLike = 'ZerosLike'; const Step = 'Step'; const FromPixels = 'FromPixels'; const RotateWithOffset = 'RotateWithOffset'; const _FusedMatMul = '_FusedMatMul'; const FusedConv2D = 'FusedConv2D'; const FusedDepthwiseConv2D = 'FusedDepthwiseConv2D'; function warn(...msg) { if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) { console.warn(...msg); } } function log$2(...msg) { if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) { console.log(...msg); } } const kernelRegistry = getGlobal('kernelRegistry', () => new Map()); const gradRegistry = getGlobal('gradRegistry', () => new Map()); function getKernel(kernelName, backendName) { const key = makeKey(kernelName, backendName); return kernelRegistry.get(key); } function getGradient(kernelName) { return gradRegistry.get(kernelName); } function getKernelsForBackend(backendName) { const it = kernelRegistry.entries(); const result = []; while (true) { const { done, value } = 
it.next(); if (done) { break; } const [key, config] = value; const [backend,] = key.split('_'); if (backend === backendName) { result.push(config); } } return result; } function registerKernel(config) { const { kernelName, backendName } = config; const key = makeKey(kernelName, backendName); if (kernelRegistry.has(key)) { warn(`The kernel '${kernelName}' for backend ` + `'${backendName}' is already registered`); } kernelRegistry.set(key, config); } function registerGradient(config) { const { kernelName } = config; if (gradRegistry.has(kernelName)) { if (env().getBool('DEBUG')) { warn(`Overriding the gradient for '${kernelName}'`); } } gradRegistry.set(kernelName, config); } function makeKey(kernelName, backendName) { return `${backendName}_${kernelName}`; } setTimeout(() => env().setPlatform('browser', new PlatformStub())); function isTypedArrayBrowser(a) { return a instanceof Float32Array || a instanceof Int32Array || a instanceof Uint8Array || a instanceof Uint8ClampedArray; } class PlatformStub { constructor() { } fetch(path, init) { throw new Error("fetch is not supported in this build."); } now() { return performance.now(); } encode(text, encoding) { if (encoding !== 'utf-8' && encoding !== 'utf8') { throw new Error(`Browser's encoder only supports utf-8, but got ${encoding}`); } if (this.textEncoder == null) { this.textEncoder = new TextEncoder(); } return this.textEncoder.encode(text); } decode(bytes, encoding) { return new TextDecoder(encoding).decode(bytes); } setTimeoutCustom(functionRef, delay) { if (typeof window === 'undefined' || !env().getBool('USE_SETTIMEOUTCUSTOM')) { setTimeout(functionRef, delay); return; } this.functionRefs.push(functionRef); setTimeout(() => { window.postMessage({name: this.messageName, index: this.functionRefs.length - 1}, location.origin); }, delay); if (!this.hasEventListener) { this.hasEventListener = true; window.addEventListener('message', (event) => { if (event.source === window && event.data.name === this.messageName) { 
event.stopPropagation(); const functionRef = this.functionRefs[event.data.index]; functionRef(); this.handledMessageCount++; if (this.handledMessageCount === this.functionRefs.length) { this.functionRefs = []; this.handledMessageCount = 0; } } }, true); } } isTypedArray(a) { return isTypedArrayBrowser(a) } } var commonjsGlobal = typeof globalThis !== 'undefined' ? globalThis : typeof window !== 'undefined' ? window : typeof global !== 'undefined' ? global : typeof self !== 'undefined' ? self : {}; function getDefaultExportFromCjs (x) { return x && x.__esModule && Object.prototype.hasOwnProperty.call(x, 'default') ? x['default'] : x; } var long = Long$1; var wasm = null; try { wasm = new WebAssembly.Instance(new WebAssembly.Module(new Uint8Array([ 0, 97, 115, 109, 1, 0, 0, 0, 1, 13, 2, 96, 0, 1, 127, 96, 4, 127, 127, 127, 127, 1, 127, 3, 7, 6, 0, 1, 1, 1, 1, 1, 6, 6, 1, 127, 1, 65, 0, 11, 7, 50, 6, 3, 109, 117, 108, 0, 1, 5, 100, 105, 118, 95, 115, 0, 2, 5, 100, 105, 118, 95, 117, 0, 3, 5, 114, 101, 109, 95, 115, 0, 4, 5, 114, 101, 109, 95, 117, 0, 5, 8, 103, 101, 116, 95, 104, 105, 103, 104, 0, 0, 10, 191, 1, 6, 4, 0, 35, 0, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 126, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 127, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 128, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 129, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 130, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11 ])), {}).exports; } catch (e) { } function Long$1(low, high, unsigned) { this.low = low | 0; 
this.high = high | 0; this.unsigned = !!unsigned; } Object.defineProperty(Long$1.prototype, "__isLong__", { value: true }); function isLong(obj) { return (obj && obj["__isLong__"]) === true; } Long$1.isLong = isLong; var INT_CACHE = {}; var UINT_CACHE = {}; function fromInt(value, unsigned) { var obj, cachedObj, cache; if (unsigned) { value >>>= 0; if (cache = (0 <= value && value < 256)) { cachedObj = UINT_CACHE[value]; if (cachedObj) return cachedObj; } obj = fromBits(value, (value | 0) < 0 ? -1 : 0, true); if (cache) UINT_CACHE[value] = obj; return obj; } else { value |= 0; if (cache = (-128 <= value && value < 128)) { cachedObj = INT_CACHE[value]; if (cachedObj) return cachedObj; } obj = fromBits(value, value < 0 ? -1 : 0, false); if (cache) INT_CACHE[value] = obj; return obj; } } Long$1.fromInt = fromInt; function fromNumber(value, unsigned) { if (isNaN(value)) return unsigned ? UZERO : ZERO; if (unsigned) { if (value < 0) return UZERO; if (value >= TWO_PWR_64_DBL) return MAX_UNSIGNED_VALUE; } else { if (value <= -TWO_PWR_63_DBL) return MIN_VALUE; if (value + 1 >= TWO_PWR_63_DBL) return MAX_VALUE; } if (value < 0) return fromNumber(-value, unsigned).neg(); return fromBits((value % TWO_PWR_32_DBL) | 0, (value / TWO_PWR_32_DBL) | 0, unsigned); } Long$1.fromNumber = fromNumber; function fromBits(lowBits, highBits, unsigned) { return new Long$1(lowBits, highBits, unsigned); } Long$1.fromBits = fromBits; var pow_dbl = Math.pow; function fromString(str, unsigned, radix) { if (str.length === 0) throw Error('empty string'); if (str === "NaN" || str === "Infinity" || str === "+Infinity" || str === "-Infinity") return ZERO; if (typeof unsigned === 'number') { radix = unsigned, unsigned = false; } else { unsigned = !! 
unsigned; } radix = radix || 10; if (radix < 2 || 36 < radix) throw RangeError('radix'); var p; if ((p = str.indexOf('-')) > 0) throw Error('interior hyphen'); else if (p === 0) { return fromString(str.substring(1), unsigned, radix).neg(); } var radixToPower = fromNumber(pow_dbl(radix, 8)); var result = ZERO; for (var i = 0; i < str.length; i += 8) { var size = Math.min(8, str.length - i), value = parseInt(str.substring(i, i + size), radix); if (size < 8) { var power = fromNumber(pow_dbl(radix, size)); result = result.mul(power).add(fromNumber(value)); } else { result = result.mul(radixToPower); result = result.add(fromNumber(value)); } } result.unsigned = unsigned; return result; } Long$1.fromString = fromString; function fromValue(val, unsigned) { if (typeof val === 'number') return fromNumber(val, unsigned); if (typeof val === 'string') return fromString(val, unsigned); return fromBits(val.low, val.high, typeof unsigned === 'boolean' ? unsigned : val.unsigned); } Long$1.fromValue = fromValue; var TWO_PWR_16_DBL = 1 << 16; var TWO_PWR_24_DBL = 1 << 24; var TWO_PWR_32_DBL = TWO_PWR_16_DBL * TWO_PWR_16_DBL; var TWO_PWR_64_DBL = TWO_PWR_32_DBL * TWO_PWR_32_DBL; var TWO_PWR_63_DBL = TWO_PWR_64_DBL / 2; var TWO_PWR_24 = fromInt(TWO_PWR_24_DBL); var ZERO = fromInt(0); Long$1.ZERO = ZERO; var UZERO = fromInt(0, true); Long$1.UZERO = UZERO; var ONE = fromInt(1); Long$1.ONE = ONE; var UONE = fromInt(1, true); Long$1.UONE = UONE; var NEG_ONE = fromInt(-1); Long$1.NEG_ONE = NEG_ONE; var MAX_VALUE = fromBits(0xFFFFFFFF|0, 0x7FFFFFFF|0, false); Long$1.MAX_VALUE = MAX_VALUE; var MAX_UNSIGNED_VALUE = fromBits(0xFFFFFFFF|0, 0xFFFFFFFF|0, true); Long$1.MAX_UNSIGNED_VALUE = MAX_UNSIGNED_VALUE; var MIN_VALUE = fromBits(0, 0x80000000|0, false); Long$1.MIN_VALUE = MIN_VALUE; var LongPrototype = Long$1.prototype; LongPrototype.toInt = function toInt() { return this.unsigned ? 
this.low >>> 0 : this.low;
};
/**
 * Converts this Long to a JS number. The result is a double, so very large
 * magnitudes can lose precision. TWO_PWR_32_DBL is defined earlier in the
 * file (presumably 2^32).
 * @returns {number}
 */
LongPrototype.toNumber = function toNumber() {
  /* Unsigned: both 32-bit words are reinterpreted as unsigned via `>>> 0`. */
  if (this.unsigned) return ((this.high >>> 0) * TWO_PWR_32_DBL) + (this.low >>> 0);
  /* Signed: the high word keeps its sign; the low word is always unsigned. */
  return this.high * TWO_PWR_32_DBL + (this.low >>> 0);
};
/**
 * Formats this Long as a string in the given radix.
 * @param {number=} radix - 2..36; a falsy value falls back to 10.
 * @returns {string}
 * @throws {RangeError} 'radix' when radix is outside [2, 36].
 */
LongPrototype.toString = function toString(radix) {
  radix = radix || 10;
  if (radix < 2 || 36 < radix) throw RangeError('radix');
  if (this.isZero()) return '0';
  if (this.isNegative()) {
    if (this.eq(MIN_VALUE)) {
      /* MIN_VALUE has no positive counterpart (negate() returns it unchanged),
         so split off the last digit with one division and print the rest. */
      var radixLong = fromNumber(radix),
        div = this.div(radixLong),
        rem1 = div.mul(radixLong).sub(this);
      return div.toString(radix) + rem1.toInt().toString(radix);
    } else return '-' + this.neg().toString(radix);
  }
  /* Emit digits six at a time: divide by radix^6 (pow_dbl is defined
     earlier, presumably Math.pow) and format each remainder separately. */
  var radixToPower = fromNumber(pow_dbl(radix, 6), this.unsigned),
    rem = this;
  var result = '';
  while (true) {
    var remDiv = rem.div(radixToPower),
      intval = rem.sub(remDiv.mul(radixToPower)).toInt() >>> 0,
      digits = intval.toString(radix);
    rem = remDiv;
    if (rem.isZero()) return digits + result;
    else {
      /* Inner chunks must be left-padded to exactly six digits. */
      while (digits.length < 6) digits = '0' + digits;
      result = '' + digits + result;
    }
  }
};
/** @returns {number} The high 32-bit word as stored. */
LongPrototype.getHighBits = function getHighBits() {
  return this.high;
};
/** @returns {number} The high 32-bit word reinterpreted as unsigned. */
LongPrototype.getHighBitsUnsigned = function getHighBitsUnsigned() {
  return this.high >>> 0;
};
/** @returns {number} The low 32-bit word as stored. */
LongPrototype.getLowBits = function getLowBits() {
  return this.low;
};
/** @returns {number} The low 32-bit word reinterpreted as unsigned. */
LongPrototype.getLowBitsUnsigned = function getLowBitsUnsigned() {
  return this.low >>> 0;
};
/**
 * Number of bits needed to represent the absolute value of this Long.
 * MIN_VALUE reports 64 because it cannot be negated.
 * @returns {number}
 */
LongPrototype.getNumBitsAbs = function getNumBitsAbs() {
  if (this.isNegative()) return this.eq(MIN_VALUE) ? 64 : this.neg().getNumBitsAbs();
  var val = this.high != 0 ? this.high : this.low;
  /* Scan down for the highest set bit of the most significant non-zero word. */
  for (var bit = 31; bit > 0; bit--)
    if ((val & (1 << bit)) != 0) break;
  /* Bits of the high word sit 32 positions above those of the low word. */
  return this.high != 0 ? bit + 33 : bit + 1;
};
/** @returns {boolean} Whether both 32-bit words are zero. */
LongPrototype.isZero = function isZero() {
  return this.high === 0 && this.low === 0;
};
LongPrototype.eqz = LongPrototype.isZero;
/** @returns {boolean} True only for signed values whose high word is negative. */
LongPrototype.isNegative = function isNegative() {
  return !this.unsigned && this.high < 0;
};
/** @returns {boolean} True for every unsigned value and for signed values >= 0. */
LongPrototype.isPositive = function isPositive() {
  return this.unsigned || this.high >= 0;
};
/** @returns {boolean} Whether the lowest bit is set. */
LongPrototype.isOdd = function isOdd() {
  return (this.low & 1) === 1;
};
/** @returns {boolean} Whether the lowest bit is clear. */
LongPrototype.isEven = function isEven() {
  return (this.low & 1) === 0;
};
/**
 * Value equality. Non-Long arguments are converted with fromValue (defined
 * earlier in the file).
 * @param {*} other
 * @returns {boolean}
 */
LongPrototype.equals = function equals(other) {
  if (!isLong(other)) other = fromValue(other);
  /* Mismatched signedness with both top bits set: one side is negative and
     the other a large unsigned value, so identical bits are not equal. */
  if (this.unsigned !== other.unsigned && (this.high >>> 31) === 1 && (other.high >>> 31) === 1) return false;
  return this.high === other.high && this.low === other.low;
};
LongPrototype.eq = LongPrototype.equals;
/** @returns {boolean} Negation of equals(). */
LongPrototype.notEquals = function notEquals(other) {
  return !this.eq(other);
};
LongPrototype.neq = LongPrototype.notEquals;
LongPrototype.ne = LongPrototype.notEquals;
/** @returns {boolean} this < other (via compare). */
LongPrototype.lessThan = function lessThan(other) {
  return this.comp(other) < 0;
};
LongPrototype.lt = LongPrototype.lessThan;
/** @returns {boolean} this <= other (via compare). */
LongPrototype.lessThanOrEqual = function lessThanOrEqual(other) {
  return this.comp(other) <= 0;
};
LongPrototype.lte = LongPrototype.lessThanOrEqual;
LongPrototype.le = LongPrototype.lessThanOrEqual;
/** @returns {boolean} this > other (via compare). */
LongPrototype.greaterThan = function greaterThan(other) {
  return this.comp(other) > 0;
};
LongPrototype.gt = LongPrototype.greaterThan;
/** @returns {boolean} this >= other (via compare). */
LongPrototype.greaterThanOrEqual = function greaterThanOrEqual(other) {
  return this.comp(other) >= 0;
};
LongPrototype.gte = LongPrototype.greaterThanOrEqual;
LongPrototype.ge = LongPrototype.greaterThanOrEqual;
/**
 * Three-way comparison.
 * @param {*} other - converted with fromValue when not already a Long.
 * @returns {number} 0 if equal, -1 if this < other, 1 if this > other.
 */
LongPrototype.compare = function compare(other) {
  if (!isLong(other)) other = fromValue(other);
  if (this.eq(other)) return 0;
  var thisNeg = this.isNegative(),
    otherNeg = other.isNegative();
  if (thisNeg && !otherNeg) return -1;
  if (!thisNeg && otherNeg) return 1;
  /* Same sign and this is signed: use the sign of the difference. */
  if (!this.unsigned) return this.sub(other).isNegative() ? -1 : 1;
  /* Unsigned: compare high words, then low words, as unsigned integers. */
  return (other.high >>> 0) > (this.high >>> 0) || (other.high === this.high && (other.low >>> 0) > (this.low >>> 0)) ? -1 : 1;
};
LongPrototype.comp = LongPrototype.compare;
/**
 * Two's-complement negation (bitwise NOT, then add one). A signed MIN_VALUE
 * is returned unchanged.
 * @returns {!Long}
 */
LongPrototype.negate = function negate() {
  if (!this.unsigned && this.eq(MIN_VALUE)) return MIN_VALUE;
  return this.not().add(ONE);
};
LongPrototype.neg = LongPrototype.negate;
/**
 * Wrapping 64-bit addition; the result keeps this Long's signedness.
 * Both operands are split into four 16-bit limbs (48/32/16/00) so limb sums
 * stay small; carries ripple upward and the carry out of the top limb is
 * masked away.
 * @param {*} addend - converted with fromValue when not already a Long.
 * @returns {!Long}
 */
LongPrototype.add = function add(addend) {
  if (!isLong(addend)) addend = fromValue(addend);
  var a48 = this.high >>> 16;
  var a32 = this.high & 0xFFFF;
  var a16 = this.low >>> 16;
  var a00 = this.low & 0xFFFF;
  var b48 = addend.high >>> 16;
  var b32 = addend.high & 0xFFFF;
  var b16 = addend.low >>> 16;
  var b00 = addend.low & 0xFFFF;
  var c48 = 0, c32 = 0, c16 = 0, c00 = 0;
  c00 += a00 + b00;
  c16 += c00 >>> 16;
  c00 &= 0xFFFF;
  c16 += a16 + b16;
  c32 += c16 >>> 16;
  c16 &= 0xFFFF;
  c32 += a32 + b32;
  c48 += c32 >>> 16;
  c32 &= 0xFFFF;
  c48 += a48 + b48;
  c48 &= 0xFFFF;
  return fromBits((c16 << 16) | c00, (c48 << 16) | c32, this.unsigned);
};
/**
 * Wrapping subtraction, computed as this + (-subtrahend).
 * @param {*} subtrahend - converted with fromValue when not already a Long.
 * @returns {!Long}
 */
LongPrototype.subtract = function subtract(subtrahend) {
  if (!isLong(subtrahend)) subtrahend = fromValue(subtrahend);
  return this.add(subtrahend.neg());
};
LongPrototype.sub = LongPrototype.subtract;
/**
 * Wrapping 64-bit multiplication; the result keeps this Long's signedness.
 * Uses the `wasm` helper (defined earlier) when it is available; otherwise it
 * handles zero / MIN_VALUE / sign cases, multiplies small operands as doubles,
 * and falls back to 16-bit-limb long multiplication.
 * @param {*} multiplier - converted with fromValue when not already a Long.
 * @returns {!Long}
 */
LongPrototype.multiply = function multiply(multiplier) {
  if (this.isZero()) return ZERO;
  if (!isLong(multiplier)) multiplier = fromValue(multiplier);
  if (wasm) {
    /* The helper returns the low word; the high word is fetched via get_high(). */
    var low = wasm.mul(this.low, this.high, multiplier.low, multiplier.high);
    return fromBits(low, wasm.get_high(), this.unsigned);
  }
  if (multiplier.isZero()) return ZERO;
  /* MIN_VALUE times an odd number wraps to MIN_VALUE, times an even one to zero. */
  if (this.eq(MIN_VALUE)) return multiplier.isOdd() ? MIN_VALUE : ZERO;
  if (multiplier.eq(MIN_VALUE)) return this.isOdd() ?
MIN_VALUE : ZERO; if (this.isNegative()) { if (multiplier.isNegative()) return this.neg().mul(multiplier.neg()); else return this.neg().mul(multiplier).neg(); } else if (multiplier.isNegative()) return this.mul(multiplier.neg()).neg(); if (this.lt(TWO_PWR_24) && multiplier.lt(TWO_PWR_24)) return fromNumber(this.toNumber() * multiplier.toNumber(), this.unsigned); var a48 = this.high >>> 16; var a32 = this.high & 0xFFFF; var a16 = this.low >>> 16; var a00 = this.low & 0xFFFF; var b48 = multiplier.high >>> 16; var b32 = multiplier.high & 0xFFFF; var b16 = multiplier.low >>> 16; var b00 = multiplier.low & 0xFFFF; var c48 = 0, c32 = 0, c16 = 0, c00 = 0; c00 += a00 * b00; c16 += c00 >>> 16; c00 &= 0xFFFF; c16 += a16 * b00; c32 += c16 >>> 16; c16 &= 0xFFFF; c16 += a00 * b16; c32 += c16 >>> 16; c16 &= 0xFFFF; c32 += a32 * b00; c48 += c32 >>> 16; c32 &= 0xFFFF; c32 += a16 * b16; c48 += c32 >>> 16; c32 &= 0xFFFF; c32 += a00 * b32; c48 += c32 >>> 16; c32 &= 0xFFFF; c48 += a48 * b00 + a32 * b16 + a16 * b32 + a00 * b48; c48 &= 0xFFFF; return fromBits((c16 << 16) | c00, (c48 << 16) | c32, this.unsigned); }; LongPrototype.mul = LongPrototype.multiply; LongPrototype.divide = function divide(divisor) { if (!isLong(divisor)) divisor = fromValue(divisor); if (divisor.isZero()) throw Error('division by zero'); if (wasm) { if (!this.unsigned && this.high === -2147483648 && divisor.low === -1 && divisor.high === -1) { return this; } var low = (this.unsigned ? wasm.div_u : wasm.div_s)( this.low, this.high, divisor.low, divisor.high ); return fromBits(low, wasm.get_high(), this.unsigned); } if (this.isZero()) return this.unsigned ? UZERO : ZERO; var approx, rem, res; if (!this.unsigned) { if (this.eq(MIN_VALUE)) { if (divisor.eq(ONE) || divisor.eq(NEG_ONE)) return MIN_VALUE; else if (divisor.eq(MIN_VALUE)) return ONE; else { var halfThis = this.shr(1); approx = halfThis.div(divisor).shl(1); if (approx.eq(ZERO)) { return divisor.isNegative() ? 
ONE : NEG_ONE; } else { rem = this.sub(divisor.mul(approx)); res = approx.add(rem.div(divisor)); return res; } } } else if (divisor.eq(MIN_VALUE)) return this.unsigned ? UZERO : ZERO; if (this.isNegative()) { if (divisor.isNegative()) return this.neg().div(divisor.neg()); return this.neg().div(divisor).neg(); } else if (divisor.isNegative()) return this.div(divisor.neg()).neg(); res = ZERO; } else { if (!divisor.unsigned) divisor = divisor.toUnsigned(); if (divisor.gt(this)) return UZERO; if (divisor.gt(this.shru(1))) return UONE; res = UZERO; } rem = this; while (rem.gte(divisor)) { approx = Math.max(1, Math.floor(rem.toNumber() / divisor.toNumber())); var log2 = Math.ceil(Math.log(approx) / Math.LN2), delta = (log2 <= 48) ? 1 : pow_dbl(2, log2 - 48), approxRes = fromNumber(approx), approxRem = approxRes.mul(divisor); while (approxRem.isNegative() || approxRem.gt(rem)) { approx -= delta; approxRes = fromNumber(approx, this.unsigned); approxRem = approxRes.mul(divisor); } if (approxRes.isZero()) approxRes = ONE; res = res.add(approxRes); rem = rem.sub(approxRem); } return res; }; LongPrototype.div = LongPrototype.divide; LongPrototype.modulo = function modulo(divisor) { if (!isLong(divisor)) divisor = fromValue(divisor); if (wasm) { var low = (this.unsigned ? 
wasm.rem_u : wasm.rem_s)(this.low, this.high, divisor.low, divisor.high);
    /* The helper returns the low word; the high word is fetched via get_high(). */
    return fromBits(low, wasm.get_high(), this.unsigned);
  }
  /* Without wasm: this - (this / divisor) * divisor. */
  return this.sub(this.div(divisor).mul(divisor));
};
LongPrototype.mod = LongPrototype.modulo;
LongPrototype.rem = LongPrototype.modulo;
/**
 * Bitwise NOT of both 32-bit words.
 * @returns {!Long} Result with this Long's signedness.
 */
LongPrototype.not = function not() {
  return fromBits(~this.low, ~this.high, this.unsigned);
};
/**
 * Bitwise AND, applied word by word. Non-Long arguments go through fromValue.
 * @returns {!Long} Result with this Long's signedness.
 */
LongPrototype.and = function and(other) {
  if (!isLong(other)) other = fromValue(other);
  return fromBits(this.low & other.low, this.high & other.high, this.unsigned);
};
/**
 * Bitwise OR, applied word by word. Non-Long arguments go through fromValue.
 * @returns {!Long} Result with this Long's signedness.
 */
LongPrototype.or = function or(other) {
  if (!isLong(other)) other = fromValue(other);
  return fromBits(this.low | other.low, this.high | other.high, this.unsigned);
};
/**
 * Bitwise XOR, applied word by word. Non-Long arguments go through fromValue.
 * @returns {!Long} Result with this Long's signedness.
 */
LongPrototype.xor = function xor(other) {
  if (!isLong(other)) other = fromValue(other);
  return fromBits(this.low ^ other.low, this.high ^ other.high, this.unsigned);
};
/**
 * Left shift. numBits may be a number or a Long and is taken modulo 64; bits
 * shifted past bit 63 are discarded.
 * @returns {!Long} Result with this Long's signedness (this itself for a 0 shift).
 */
LongPrototype.shiftLeft = function shiftLeft(numBits) {
  if (isLong(numBits)) numBits = numBits.toInt();
  if ((numBits &= 63) === 0) return this;
  /* Below 32: the top bits of the low word carry into the high word. */
  else if (numBits < 32) return fromBits(this.low << numBits, (this.high << numBits) | (this.low >>> (32 - numBits)), this.unsigned);
  /* 32 or more: the low word becomes zero and moves into the high word. */
  else return fromBits(0, this.low << (numBits - 32), this.unsigned);
};
LongPrototype.shl = LongPrototype.shiftLeft;
/**
 * Right shift that propagates the high word's sign bit (`>>`). numBits is
 * taken modulo 64. The code does not branch on `this.unsigned`; use
 * shiftRightUnsigned for a zero-filling shift.
 * @returns {!Long} Result with this Long's signedness (this itself for a 0 shift).
 */
LongPrototype.shiftRight = function shiftRight(numBits) {
  if (isLong(numBits)) numBits = numBits.toInt();
  if ((numBits &= 63) === 0) return this;
  else if (numBits < 32) return fromBits((this.low >>> numBits) | (this.high << (32 - numBits)), this.high >> numBits, this.unsigned);
  /* 32 or more: the new high word is filled with the sign (0 or -1). */
  else return fromBits(this.high >> (numBits - 32), this.high >= 0 ? 0 : -1, this.unsigned);
};
LongPrototype.shr = LongPrototype.shiftRight;
/**
 * Zero-filling right shift by numBits (a number or a Long, taken modulo 64).
 * @returns {!Long} Result with this Long's signedness (this itself for a 0 shift).
 */
LongPrototype.shiftRightUnsigned = function shiftRightUnsigned(numBits) {
  if (isLong(numBits)) numBits = numBits.toInt();
  numBits &= 63;
  if (numBits === 0) return this;
  else {
    var high = this.high;
    if (numBits < 32) {
      var low = this.low;
      return fromBits((low >>> numBits) | (high << (32 - numBits)), high >>> numBits, this.unsigned);
    } else if (numBits === 32) return fromBits(high, 0, this.unsigned);
    else return fromBits(high >>> (numBits - 32), 0, this.unsigned);
  }
};
LongPrototype.shru = LongPrototype.shiftRightUnsigned;
LongPrototype.shr_u = LongPrototype.shiftRightUnsigned;
/** @returns {!Long} this when already signed, otherwise a signed Long with the same bits. */
LongPrototype.toSigned = function toSigned() {
  if (!this.unsigned) return this;
  return fromBits(this.low, this.high, false);
};
/** @returns {!Long} this when already unsigned, otherwise an unsigned Long with the same bits. */
LongPrototype.toUnsigned = function toUnsigned() {
  if (this.unsigned) return this;
  return fromBits(this.low, this.high, true);
};
/**
 * @param {boolean=} le - true for little-endian order, otherwise big-endian.
 * @returns {!Array<number>} The 8 bytes of this Long.
 */
LongPrototype.toBytes = function toBytes(le) {
  return le ? this.toBytesLE() : this.toBytesBE();
};
/** @returns {!Array<number>} The 8 bytes, least significant first. */
LongPrototype.toBytesLE = function toBytesLE() {
  var hi = this.high,
    lo = this.low;
  return [
    lo & 0xff,
    lo >>> 8 & 0xff,
    lo >>> 16 & 0xff,
    lo >>> 24,
    hi & 0xff,
    hi >>> 8 & 0xff,
    hi >>> 16 & 0xff,
    hi >>> 24
  ];
};
/** @returns {!Array<number>} The 8 bytes, most significant first. */
LongPrototype.toBytesBE = function toBytesBE() {
  var hi = this.high,
    lo = this.low;
  return [
    hi >>> 24,
    hi >>> 16 & 0xff,
    hi >>> 8 & 0xff,
    hi & 0xff,
    lo >>> 24,
    lo >>> 16 & 0xff,
    lo >>> 8 & 0xff,
    lo & 0xff
  ];
};
/**
 * Builds a Long from 8 bytes.
 * @param {!Array<number>} bytes
 * @param {boolean=} unsigned
 * @param {boolean=} le - true for little-endian order, otherwise big-endian.
 * @returns {!Long}
 */
Long$1.fromBytes = function fromBytes(bytes, unsigned, le) {
  return le ?
Long$1.fromBytesLE(bytes, unsigned) : Long$1.fromBytesBE(bytes, unsigned);
};
/**
 * Builds a Long from little-endian bytes: bytes[0..3] form the first (low)
 * constructor word and bytes[4..7] the second (high) word. Missing entries
 * are undefined and contribute zero bits under the bitwise operators.
 */
Long$1.fromBytesLE = function fromBytesLE(bytes, unsigned) {
  return new Long$1(
    bytes[0] | bytes[1] << 8 | bytes[2] << 16 | bytes[3] << 24,
    bytes[4] | bytes[5] << 8 | bytes[6] << 16 | bytes[7] << 24,
    unsigned
  );
};
/**
 * Builds a Long from big-endian bytes: bytes[4..7] form the first (low)
 * constructor word and bytes[0..3] the second (high) word.
 */
Long$1.fromBytesBE = function fromBytesBE(bytes, unsigned) {
  return new Long$1(
    bytes[4] << 24 | bytes[5] << 16 | bytes[6] << 8 | bytes[7],
    bytes[0] << 24 | bytes[1] << 16 | bytes[2] << 8 | bytes[3],
    unsigned
  );
};
/* Resolve the bundled CommonJS `long` module into a Long constructor: use its
   default export when present (getDefaultExportFromCjs is defined elsewhere),
   otherwise the merged namespace object. */
var long$1 = getDefaultExportFromCjs(long);
var LongExports = _mergeNamespaces({ __proto__: null, default: long$1 }, [long]);
const Long = long$1 || LongExports;
/**
 * Parses a hex string with Long.fromString(hex, true, 16) — presumably an
 * unsigned, radix-16 parse.
 */
function hexToLong(hex) {
  return Long.fromString(hex, true, 16);
}
/* Hash mixing constants. They match CityHash/FarmHash's k0..k2;
   NOTE(review): presumably this code is a FarmHash Fingerprint64 port — confirm. */
const k0 = hexToLong('c3a5c85c97cb3127');
const k1 = hexToLong('b492b66fbe98f273');
const k2 = hexToLong('9ae16a3b2f90404f');
/** @returns {!Long} val ^ (val >>> 47). */
function shiftMix(val) {
  return val.xor(val.shru(47));
}
/**
 * Reads numBytes bytes of s starting at offset as an unsigned little-endian
 * Long. s is presumably a byte array (e.g. Uint8Array); it must support slice().
 * NOTE(review): this top-level `fetch` shadows the global `fetch` throughout
 * this module scope; confirm nothing in the bundle expects the network API.
 */
function fetch(s, offset, numBytes) {
  const bytes = s.slice(offset, offset + numBytes);
  return Long.fromBytes(Array.from(bytes), true, true);
}
/** Reads 8 bytes of s at offset as an unsigned little-endian Long. */
function fetch64(s, offset) {
  return fetch(s, offset, 8);
}
/**
 * Reads 4 bytes of s at offset; with the little-endian decoder the absent
 * bytes[4..7] are undefined, so the high word is presumably zero.
 */
function fetch32(s, offset) {
  return fetch(s, offset, 4);
}
/** Rotates the 64-bit value right by `shift` bits (shift 0 returns val). */
function rotate64(val, shift) {
  return shift === 0 ? val : val.shru(shift).or(val.shl(64 - shift));
}
/**
 * Mixes two 64-bit values into one: multiply, xor-shift by 47, repeat.
 * @param {!Long} u
 * @param {!Long} v
 * @param {!Long=} mul - multiplier, default 0x9ddfea08eb382d69.
 * @returns {!Long}
 */
function hashLen16(u, v, mul = hexToLong('9ddfea08eb382d69')) {
  let a = u.xor(v).mul(mul);
  a = a.xor(a.shru(47));
  let b = v.xor(a).mul(mul);
  b = b.xor(b.shru(47));
  b = b.mul(mul);
  return b;
}
/**
 * Mixes four 64-bit words (w, x, y, z) with two seeds (a, b).
 * @returns {!Array<!Long>} A pair of mixed values.
 */
function weakHashLen32WithSeeds(w, x, y, z, a, b) {
  a = a.add(w);
  b = rotate64(b.add(a).add(z), 21);
  const c = a;
  a = a.add(x);
  a = a.add(y);
  b = b.add(rotate64(a, 44));
  return [a.add(z), b.add(c)];
}
/** weakHashLen32WithSeeds over the 32 bytes of s starting at offset. */
function weakHashLen32WithSeedsStr(s, offset, a, b) {
  return weakHashLen32WithSeeds(fetch64(s, offset), fetch64(s, offset + 8), fetch64(s, offset + 16), fetch64(s, offset + 24), a, b);
}
/**
 * Hash for inputs of at most 16 bytes; returns k2 when len is 0.
 * @param {number=} len - bytes to hash, default s.length.
 * @returns {!Long}
 */
function hashLen0to16(s, len = s.length) {
  if (len >= 8) {
    /* 8..16 bytes: first and last 8 bytes (overlapping when len < 16). */
    const mul = k2.add(len * 2);
    const a = fetch64(s, 0).add(k2);
    const b = fetch64(s, len - 8);
    const c = rotate64(b, 37).mul(mul).add(a);
    const d = rotate64(a, 25).add(b).mul(mul);
    return hashLen16(c, d, mul);
  }
  if (len >= 4) {
    /* 4..7 bytes: first and last 4 bytes. */
    const mul = k2.add(len * 2);
    const a = fetch32(s, 0);
    return hashLen16(a.shl(3).add(len), fetch32(s, len - 4), mul);
  }
  if (len > 0) {
    /* 1..3 bytes: mix the first, middle and last byte values as numbers. */
    const a = s[0];
    const b = s[len >> 1];
    const c = s[len - 1];
    const y = a + (b << 8);
    const z = len + (c << 2);
    return shiftMix(k2.mul(y).xor(k0.mul(z))).mul(k2);
  }
  return k2;
}
/** Hash used for 17..32-byte inputs, mixing the first 16 and last 16 bytes. */
function hashLen17to32(s, len = s.length) {
  const mul = k2.add(len * 2);
  const a = fetch64(s, 0).mul(k1);
  const b = fetch64(s, 8);
  const c = fetch64(s, len - 8).mul(mul);
  const d = fetch64(s, len - 16).mul(k2);
  return hashLen16(rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d), a.add(rotate64(b.add(k2), 18)).add(c), mul);
}
/** Hash used for 33..64-byte inputs, mixing the first 32 and last 32 bytes. */
function hashLen33to64(s, len = s.length) {
  const mul = k2.add(len * 2);
  const a = fetch64(s, 0).mul(k2);
  const b = fetch64(s, 8);
  const c = fetch64(s, len - 8).mul(mul);
  const d = fetch64(s, len - 16).mul(k2);
  const y = rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d);
  const z = hashLen16(y, a.add(rotate64(b.add(k2), 18)).add(c), mul);
  const e = fetch64(s, 16).mul(mul);
  const f = fetch64(s, 24);
  const g = y.add(fetch64(s, len -
32)).mul(mul); const h = z.add(fetch64(s, len - 24)).mul(mul); return hashLen16(rotate64(e.add(f), 43).add(rotate64(g, 30)).add(h), e.add(rotate64(f.add(a), 18)).add(g), mul); } function fingerPrint64(s, len = s.length) { const seed = Long.fromNumber(81, true); if (len <= 32) { if (len <= 16) { return hashLen0to16(s, len); } else { return hashLen17to32(s, len); } } else if (len <= 64) { return hashLen33to64(s, len); } let x = seed; let y = seed.mul(k1).add(113); let z = shiftMix(y.mul(k2).add(113)).mul(k2); let v = [Long.UZERO, Long.UZERO]; let w = [Long.UZERO, Long.UZERO]; x = x.mul(k2).add(fetch64(s, 0)); let offset = 0; const end = ((len - 1) >> 6) * 64; const last64 = end + ((len - 1) & 63) - 63; do { x = rotate64(x.add(y).add(v[0]).add(fetch64(s, offset + 8)), 37).mul(k1); y = rotate64(y.add(v[1]).add(fetch64(s, offset + 48)), 42).mul(k1); x = x.xor(w[1]); y = y.add(v[0]).add(fetch64(s, offset + 40)); z = rotate64(z.add(w[0]), 33).mul(k1); v = weakHashLen32WithSeedsStr(s, offset, v[1].mul(k1), x.add(w[0])); w = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w[1]), y.add(fetch64(s, offset + 16))); [z, x] = [x, z]; offset += 64; } while (offset !== end); const mul = k1.add(z.and(0xff).shl(1)); offset = last64; w[0] = w[0].add((len - 1) & 63); v[0] = v[0].add(w[0]); w[0] = w[0].add(v[0]); x = rotate64(x.add(y).add(v[0]).add(fetch64(s, offset + 8)), 37).mul(mul); y = rotate64(y.add(v[1]).add(fetch64(s, offset + 48)), 42).mul(mul); x = x.xor(w[1].mul(9)); y = y.add(v[0].mul(9).add(fetch64(s, offset + 40))); z = rotate64(z.add(w[0]), 33).mul(mul); v = weakHashLen32WithSeedsStr(s, offset, v[1].mul(mul), x.add(w[0])); w = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w[1]), y.add(fetch64(s, offset + 16))); [z, x] = [x, z]; return hashLen16(hashLen16(v[0], w[0], mul).add(shiftMix(y).mul(k0)).add(z), hashLen16(v[1], w[1], mul).add(x), mul); } function createScalarValue(value, dtype) { if (dtype === 'string') { return encodeString(value); } return 
toTypedArray([value], dtype); } function noConversionNeeded(a, dtype) { return (a instanceof Float32Array && dtype === 'float32') || (a instanceof Int32Array && dtype === 'int32') || (a instanceof Uint8Array && dtype === 'bool'); } function toTypedArray(a, dtype) { if (dtype === 'string') { throw new Error('Cannot convert a string[] to a TypedArray'); } if (Array.isArray(a)) { a = flatten$1(a); } if (env().getBool('DEBUG')) { checkConversionForErrors(a, dtype); } if (noConversionNeeded(a, dtype)) { return a; } if (dtype == null || dtype === 'float32' || dtype === 'complex64') { return new Float32Array(a); } else if (dtype === 'int32') { return new Int32Array(a); } else if (dtype === 'bool') { const bool = new Uint8Array(a.length); for (let i = 0; i < bool.length; ++i) { if (Math.round(a[i]) !== 0) { bool[i] = 1; } } return bool; } else { throw new Error(`Unknown data type ${dtype}`); } } function now() { return env().platform.now(); } function encodeString(s, encoding = 'utf-8') { encoding = encoding || 'utf-8'; return env().platform.encode(s, encoding); } function decodeString(bytes, encoding = 'utf-8') { encoding = encoding || 'utf-8'; return env().platform.decode(bytes, encoding); } function isTypedArray(a) { if (env().platform.isTypedArray != null) { return env().platform.isTypedArray(a); } else { return isTypedArrayBrowser(a); } } function flatten$1(arr, result = [], skipTypedArray = false) { if (result == null) { result = []; } if (typeof arr === 'boolean' || typeof arr === 'number' || typeof arr === 'string' || isPromise(arr) || arr == null || isTypedArray(arr) && skipTypedArray) { result.push(arr); } else if (Array.isArray(arr) || isTypedArray(arr)) { for (let i = 0; i < arr.length; ++i) { flatten$1(arr[i], result, skipTypedArray); } } else { let maxIndex = -1; for (const key of Object.keys(arr)) { if (/^([1-9]+[0-9]*|0)$/.test(key)) { maxIndex = Math.max(maxIndex, Number(key)); } } for (let i = 0; i <= maxIndex; i++) { flatten$1(arr[i], result, 
skipTypedArray); } } return result; } class Profiler { constructor(backendTimer, logger) { this.backendTimer = backendTimer; this.logger = logger; if (logger == null) { this.logger = new Logger(); } } profileKernel(kernelName, inputs, f) { let outputs; const holdResultWrapperFn = () => { outputs = f(); }; let timer; const start = now(); if (this.backendTimer.timerAvailable()) { timer = this.backendTimer.time(holdResultWrapperFn); } else { holdResultWrapperFn(); for (const output of outputs) { output.dataSync(); } timer = Promise.resolve({ kernelMs: now() - start }); } if (env().getBool('CHECK_COMPUTATION_FOR_ERRORS')) { for (let i = 0; i < outputs.length; i++) { const output = outputs[i]; output.data().then(tensorVals => { checkComputationForErrors(tensorVals, output.dtype, kernelName); }); } } const kernelProfile = { kernelName, outputs, inputs, timeMs: timer.then(timing => timing.kernelMs), extraInfo: timer.then(timing => timing.getExtraProfileInfo != null ? timing.getExtraProfileInfo() : '') }; return kernelProfile; } logKernelProfile(kernelProfile) { const { kernelName, outputs, timeMs, inputs, extraInfo } = kernelProfile; outputs.forEach(result => { Promise.all([result.data(), timeMs, extraInfo]).then(valueContainer => { this.logger.logKernelProfile(kernelName, result, valueContainer[0], valueContainer[1], inputs, valueContainer[2]); }); }); } } function checkComputationForErrors(vals, dtype, kernelName) { if (dtype !== 'float32') { return false; } for (let i = 0; i < vals.length; i++) { const num = vals[i]; if (isNaN(num) || !isFinite(num)) { console.warn(`Found ${num} in the result of '${kernelName}'`); return true; } } return false; } class Logger { logKernelProfile(name, result, vals, timeMs, inputs, extraInfo) { const time = typeof timeMs === 'number' ? 
rightPad(`${timeMs}ms`, 9) : timeMs['error']; const paddedName = rightPad(name, 25); const rank = result.rank; const size = result.size; const shape = rightPad(result.shape.toString(), 14); let inputShapesDescription = ''; for (const name in inputs) { const input = inputs[name]; if (input != null) { const inputShape = input.shape || result.shape; const inputRank = inputShape.length; inputShapesDescription += `${name}: ${inputRank}D ${inputRank > 0 ? inputShape : ''} `; } } console.log(`%c${paddedName}\t%c${time}\t%c${rank}D ${shape}\t%c${size}\t%c${inputShapesDescription}\t%c${extraInfo}`, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue'); } } function getFilteredNodesXToY(tape, xs, y) { const tensorsFromX = {}; const nodesFromX = {}; for (let i = 0; i < xs.length; i++) { tensorsFromX[xs[i].id] = true; } for (let i = 0; i < tape.length; i++) { const node = tape[i]; const nodeInputs = node.inputs; for (const inputName in nodeInputs) { const input = nodeInputs[inputName]; let anyInputFromX = false; for (let j = 0; j < xs.length; j++) { if (tensorsFromX[input.id]) { node.outputs.forEach(output => tensorsFromX[output.id] = true); anyInputFromX = true; nodesFromX[node.id] = true; break; } } if (anyInputFromX) { break; } } } const tensorsLeadToY = {}; tensorsLeadToY[y.id] = true; const nodesToY = {}; for (let i = tape.length - 1; i >= 0; i--) { const node = tape[i]; const nodeInputs = node.inputs; for (let j = 0; j < node.outputs.length; j++) { if (tensorsLeadToY[node.outputs[j].id]) { for (const inputName in nodeInputs) { tensorsLeadToY[nodeInputs[inputName].id] = true; nodesToY[node.id] = true; } break; } } } const filteredTape = []; for (let i = 0; i < tape.length; i++) { const node = tape[i]; if (nodesFromX[node.id] && nodesToY[node.id]) { const prunedInputs = {}; for (const inputName in node.inputs) { const nodeInput = node.inputs[inputName]; if (tensorsFromX[nodeInput.id]) { prunedInputs[inputName] = nodeInput; } } 
const prunedNode = Object.assign({}, node); prunedNode.inputs = prunedInputs; prunedNode.outputs = node.outputs; filteredTape.push(prunedNode); } } return filteredTape; } function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy, add) { for (let i = filteredTape.length - 1; i >= 0; i--) { const node = filteredTape[i]; const dys = []; node.outputs.forEach(o => { const gradTensor = tensorAccumulatedGradientMap[o.id]; if (gradTensor != null) { dys.push(gradTensor); } else { dys.push(null); } }); if (node.gradient == null) { throw new Error(`Cannot compute gradient: gradient function not found ` + `for ${node.kernelName}.`); } const inputGradients = node.gradient(dys); for (const inputName in node.inputs) { if (!(inputName in inputGradients)) { throw new Error(`Cannot backprop through input ${inputName}. ` + `Available gradients found: ${Object.keys(inputGradients)}.`); } const dx = tidy(() => inputGradients[inputName]()); if (dx.dtype !== 'float32') { throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` + `${inputName} must have 'float32' dtype, but has '${dx.dtype}'`); } const x = node.inputs[inputName]; if (!arraysEqual(dx.shape, x.shape)) { throw new Error(`Error in gradient for op ${node.kernelName}. 
The gradient of input ` + `'${inputName}' has shape '${dx.shape}', which does not match ` + `the shape of the input '${x.shape}'`); } if (tensorAccumulatedGradientMap[x.id] == null) { tensorAccumulatedGradientMap[x.id] = dx; } else { const curGradient = tensorAccumulatedGradientMap[x.id]; tensorAccumulatedGradientMap[x.id] = add(curGradient, dx); curGradient.dispose(); } } } } const FORMAT_LIMIT_NUM_VALS = 20; const FORMAT_NUM_FIRST_LAST_VALS = 3; const FORMAT_NUM_SIG_DIGITS = 7; function tensorToString(vals, shape, dtype, verbose) { const strides = computeStrides(shape); const padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides); const rank = shape.length; const valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol); const lines = ['Tensor']; if (verbose) { lines.push(` dtype: ${dtype}`); lines.push(` rank: ${rank}`); lines.push(` shape: [${shape}]`); lines.push(` values:`); } lines.push(valsLines.map(l => ' ' + l).join('\n')); return lines.join('\n'); } function computeMaxSizePerColumn(vals, shape, dtype, strides) { const n = sizeFromShape(shape); const numCols = strides[strides.length - 1]; const padPerCol = new Array(numCols).fill(0); const rank = shape.length; const valuesOrTuples = dtype === 'complex64' ? 
createComplexTuples(vals) : vals; if (rank > 1) { for (let row = 0; row < n / numCols; row++) { const offset = row * numCols; for (let j = 0; j < numCols; j++) { padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length); } } } return padPerCol; } function valToString(val, pad, dtype) { let valStr; if (Array.isArray(val)) { valStr = `${parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS))} + ` + `${parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS))}j`; } else if (isString(val)) { valStr = `'${val}'`; } else if (dtype === 'bool') { valStr = boolNumToString(val); } else { valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString(); } return rightPad(valStr, pad); } function boolNumToString(v) { return v === 0 ? 'false' : 'true'; } function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast = true) { const storagePerElement = dtype === 'complex64' ? 2 : 1; const size = shape[0]; const rank = shape.length; if (rank === 0) { if (dtype === 'complex64') { const complexTuple = createComplexTuples(vals); return [valToString(complexTuple[0], 0, dtype)]; } if (dtype === 'bool') { return [boolNumToString(vals[0])]; } return [vals[0].toString()]; } if (rank === 1) { if (size > FORMAT_LIMIT_NUM_VALS) { const firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement; let firstVals = Array.from(vals.slice(0, firstValsSize)); let lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement)); if (dtype === 'complex64') { firstVals = createComplexTuples(firstVals); lastVals = createComplexTuples(lastVals); } return [ '[' + firstVals.map((x, i) => valToString(x, padPerCol[i], dtype)) .join(', ') + ', ..., ' + lastVals .map((x, i) => valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype)) .join(', ') + ']' ]; } const displayVals = dtype === 'complex64' ? 
createComplexTuples(vals) : Array.from(vals); return [ '[' + displayVals.map((x, i) => valToString(x, padPerCol[i], dtype)) .join(', ') + ']' ]; } const subshape = shape.slice(1); const substrides = strides.slice(1); const stride = strides[0] * storagePerElement; const lines = []; if (size > FORMAT_LIMIT_NUM_VALS) { for (let i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) { const start = i * stride; const end = start + stride; lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false )); } lines.push('...'); for (let i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) { const start = i * stride; const end = start + stride; lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 )); } } else { for (let i = 0; i < size; i++) { const start = i * stride; const end = start + stride; lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 )); } } const sep = rank === 2 ? ',' : ''; lines[0] = '[' + (size > 0 ? lines[0] + sep : ''); for (let i = 1; i < lines.length - 1; i++) { lines[i] = ' ' + lines[i] + sep; } let newLineSep = ',\n'; for (let i = 2; i < rank; i++) { newLineSep += '\n'; } lines[lines.length - 1] = ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep); return lines; } function createComplexTuples(vals) { const complexTuples = []; for (let i = 0; i < vals.length; i += 2) { complexTuples.push([vals[i], vals[i + 1]]); } return complexTuples; } class TensorBuffer { constructor(shape, dtype, values) { this.dtype = dtype; this.shape = shape.slice(); this.size = sizeFromShape(shape); if (values != null) { const n = values.length; assert$1(n === this.size, () => `Length of values '${n}' does not match the size ` + `inferred by the shape '${this.size}'.`); } if (dtype === 'complex64') { throw new Error(`complex64 dtype TensorBuffers are not supported. 
Please create ` + `a TensorBuffer for the real and imaginary parts separately and ` + `call tf.complex(real, imag).`); } this.values = values || getArrayFromDType(dtype, this.size); this.strides = computeStrides(shape); } set(value, ...locs) { if (locs.length === 0) { locs = [0]; } assert$1(locs.length === this.rank, () => `The number of provided coordinates (${locs.length}) must ` + `match the rank (${this.rank})`); const index = this.locToIndex(locs); this.values[index] = value; } get(...locs) { if (locs.length === 0) { locs = [0]; } let i = 0; for (const loc of locs) { if (loc < 0 || loc >= this.shape[i]) { const msg = `Requested out of range element at ${locs}. ` + ` Buffer shape=${this.shape}`; throw new Error(msg); } i++; } let index = locs[locs.length - 1]; for (let i = 0; i < locs.length - 1; ++i) { index += this.strides[i] * locs[i]; } return this.values[index]; } locToIndex(locs) { if (this.rank === 0) { return 0; } else if (this.rank === 1) { return locs[0]; } let index = locs[locs.length - 1]; for (let i = 0; i < locs.length - 1; ++i) { index += this.strides[i] * locs[i]; } return index; } indexToLoc(index) { if (this.rank === 0) { return []; } else if (this.rank === 1) { return [index]; } const locs = new Array(this.shape.length); for (let i = 0; i < locs.length - 1; ++i) { locs[i] = Math.floor(index / this.strides[i]); index -= locs[i] * this.strides[i]; } locs[locs.length - 1] = index; return locs; } get rank() { return this.shape.length; } toTensor() { return trackerFn().makeTensor(this.values, this.shape, this.dtype); } } let trackerFn = null; let opHandler$1 = null; function setTensorTracker(fn) { trackerFn = fn; } function setOpHandler(handler) { opHandler$1 = handler; } class Tensor { constructor(shape, dtype, dataId, id) { this.kept = false; this.isDisposedInternal = false; this.shape = shape.slice(); this.dtype = dtype || 'float32'; this.size = sizeFromShape(shape); this.strides = computeStrides(shape); this.dataId = dataId; this.id = id; 
this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher'); } get rank() { return this.shape.length; } async buffer() { const vals = await this.data(); return opHandler$1.buffer(this.shape, this.dtype, vals); } bufferSync() { return opHandler$1.buffer(this.shape, this.dtype, this.dataSync()); } async array() { const vals = await this.data(); return toNestedArray(this.shape, vals, this.dtype === 'complex64'); } arraySync() { return toNestedArray(this.shape, this.dataSync(), this.dtype === 'complex64'); } async data() { this.throwIfDisposed(); const data = trackerFn().read(this.dataId); if (this.dtype === 'string') { const bytes = await data; try { return bytes.map(b => decodeString(b)); } catch (_a) { throw new Error('Failed to decode the string bytes into utf-8. ' + 'To get the original bytes, call tensor.bytes().'); } } return data; } dataToGPU(options) { this.throwIfDisposed(); return trackerFn().readToGPU(this.dataId, options); } dataSync() { this.throwIfDisposed(); const data = trackerFn().readSync(this.dataId); if (this.dtype === 'string') { try { return data.map(b => decodeString(b)); } catch (_a) { throw new Error('Failed to decode the string bytes into utf-8. 
' + 'To get the original bytes, call tensor.bytes().'); } } return data; }
// --- Remaining methods of the Tensor class (its declaration starts earlier in the file) ---
// bytes(): throws if disposed, then reads the data asynchronously through the tensor tracker.
// String tensors resolve to the read result unchanged; other dtypes resolve to a Uint8Array over `data.buffer`.
async bytes() { this.throwIfDisposed(); const data = await trackerFn().read(this.dataId); if (this.dtype === 'string') { return data; } else { return new Uint8Array(data.buffer); } }
// dispose(): no-op when already disposed. Otherwise also disposes an attached `kerasMask`,
// lets the tracker release this handle's data, and marks the handle disposed.
dispose() { if (this.isDisposed) { return; } if (this.kerasMask) { this.kerasMask.dispose(); } trackerFn().disposeTensor(this); this.isDisposedInternal = true; }
// isDisposed: true once dispose() has run.
get isDisposed() { return this.isDisposedInternal; }
// throwIfDisposed(): guard used by methods that need live data.
throwIfDisposed() { if (this.isDisposed) { throw new Error(`Tensor is disposed.`); } }
// print/clone/cast delegate to the installed op handler (opHandler$1).
print(verbose = false) { return opHandler$1.print(this, verbose); }
clone() { this.throwIfDisposed(); return opHandler$1.clone(this); }
// toString(): reads the values synchronously (dataSync) and formats them with tensorToString.
toString(verbose = false) { const vals = this.dataSync(); return tensorToString(vals, this.shape, this.dtype, verbose); }
cast(dtype) { this.throwIfDisposed(); return opHandler$1.cast(this, dtype); }
// variable(): creates a Variable from this tensor via the tracker's makeVariable.
variable(trainable = true, name, dtype) { this.throwIfDisposed(); return trackerFn().makeVariable(this, trainable, name, dtype); } }
// Duck-typed `instanceof Tensor`: any truthy object exposing data, dataSync and throwIfDisposed passes.
// NOTE(review): presumably so tensors created by another copy of the library are still recognized.
Object.defineProperty(Tensor, Symbol.hasInstance, { value: (instance) => { return !!instance && instance.data != null && instance.dataSync != null && instance.throwIfDisposed != null; } });
// Returns the class stored under the global key 'Tensor' via getGlobal, passing a factory that supplies
// this module's Tensor (presumably used only when no global entry exists yet). Also called once at load.
function getGlobalTensorClass() { return getGlobal('Tensor', () => { return Tensor; }); } getGlobalTensorClass();
// Variable: a Tensor whose underlying data can be replaced with assign(); also carries `trainable` and `name`.
// The constructor shares initialValue's shape, dtype and dataId.
class Variable extends Tensor { constructor(initialValue, trainable, name, tensorId) { super(initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId); this.trainable = trainable; this.name = name; }
// assign(): newValue must have exactly the same dtype and shape (otherwise throws). The old data is
// released through the tracker, newValue's dataId is adopted, and the tracker increments its ref count.
assign(newValue) { if (newValue.dtype !== this.dtype) { throw new Error(`dtype of the new value (${newValue.dtype}) and ` + `previous value (${this.dtype}) must match`); } if (!arraysEqual(newValue.shape, this.shape)) { throw new Error(`shape of the new value (${newValue.shape}) and ` + `previous value (${this.shape}) must match`); } trackerFn().disposeTensor(this); this.dataId = newValue.dataId; trackerFn().incRef(this, null ); }
// dispose(): releases through the tracker's disposeVariable, which also unregisters the variable's name.
// Unlike Tensor.dispose there is no early return when already disposed.
dispose() { 
trackerFn().disposeVariable(this); this.isDisposedInternal = true; } }
// `instanceof Variable`: a Tensor (per the duck-typed check above) with a function-valued `assign`.
Object.defineProperty(Variable, Symbol.hasInstance, { value: (instance) => { return instance instanceof Tensor && instance.assign != null && instance.assign instanceof Function; } });
// Enum-style object of rank names R0..R6.
var Rank; (function (Rank) { Rank["R0"] = "R0"; Rank["R1"] = "R1"; Rank["R2"] = "R2"; Rank["R3"] = "R3"; Rank["R4"] = "R4"; Rank["R5"] = "R5"; Rank["R6"] = "R6"; })(Rank || (Rank = {}));
// Upcast tables: UpcastXAndMap[other] is the dtype that results from combining dtype X with `other`.
var UpcastInt32AndMap; (function (UpcastInt32AndMap) { UpcastInt32AndMap["float32"] = "float32"; UpcastInt32AndMap["int32"] = "int32"; UpcastInt32AndMap["bool"] = "int32"; UpcastInt32AndMap["complex64"] = "complex64"; })(UpcastInt32AndMap || (UpcastInt32AndMap = {}));
var UpcastBoolAndMap; (function (UpcastBoolAndMap) { UpcastBoolAndMap["float32"] = "float32"; UpcastBoolAndMap["int32"] = "int32"; UpcastBoolAndMap["bool"] = "bool"; UpcastBoolAndMap["complex64"] = "complex64"; })(UpcastBoolAndMap || (UpcastBoolAndMap = {}));
var UpcastFloat32AndMap; (function (UpcastFloat32AndMap) { UpcastFloat32AndMap["float32"] = "float32"; UpcastFloat32AndMap["int32"] = "float32"; UpcastFloat32AndMap["bool"] = "float32"; UpcastFloat32AndMap["complex64"] = "complex64"; })(UpcastFloat32AndMap || (UpcastFloat32AndMap = {}));
var UpcastComplex64AndMap; (function (UpcastComplex64AndMap) { UpcastComplex64AndMap["float32"] = "complex64"; UpcastComplex64AndMap["int32"] = "complex64"; UpcastComplex64AndMap["bool"] = "complex64"; UpcastComplex64AndMap["complex64"] = "complex64"; })(UpcastComplex64AndMap || (UpcastComplex64AndMap = {}));
// Two-level lookup used by upcastType as upcastTypeMap[typeA][typeB].
const upcastTypeMap = { 'float32': UpcastFloat32AndMap, 'int32': UpcastInt32AndMap, 'bool': UpcastBoolAndMap, 'complex64': UpcastComplex64AndMap };
// Common dtype of typeA and typeB. 'string' combines only with 'string' (anything else throws);
// every other pair is looked up in upcastTypeMap.
function upcastType(typeA, typeB) { if (typeA === 'string' || typeB === 'string') { if (typeA === 'string' && typeB === 'string') { return 'string'; } throw new Error(`Can not upcast ${typeA} with ${typeB}`); } return upcastTypeMap[typeA][typeB]; }
// Output dtype of a sum: `type` upcast with 'int32' (so bool -> int32, float32 stays float32).
function sumOutType(type) { 
return upcastType(type, 'int32'); }
// True for objects whose `texture` property is a WebGLTexture.
// NOTE(review): unlike isWebGPUData there is no `typeof WebGLTexture` guard, so this throws a
// ReferenceError where WebGLTexture is undefined (e.g. Node) if `values` has a `texture` key.
function isWebGLData(values) { return values != null && typeof values === 'object' && 'texture' in values && values.texture instanceof WebGLTexture; }
// True when GPUBuffer exists in this environment and `values.buffer` is a GPUBuffer.
function isWebGPUData(values) { return typeof GPUBuffer !== 'undefined' && values != null && typeof values === 'object' && 'buffer' in values && values.buffer instanceof GPUBuffer; }
// Returns [a, b] unchanged when dtypes already match; otherwise both cast to upcastType(a.dtype, b.dtype).
function makeTypesMatch(a, b) { if (a.dtype === b.dtype) { return [a, b]; } const dtype = upcastType(a.dtype, b.dtype); return [a.cast(dtype), b.cast(dtype)]; }
// Collects every Tensor reachable from `result` through nested arrays/objects (depth-first).
function getTensorsInContainer(result) { const list = []; const seen = new Set(); walkTensorContainer(result, list, seen); return list; }
// Recursive helper for getTensorsInContainer; `seen` stops a child value from being visited twice
// (which also prevents infinite recursion on cyclic containers).
function walkTensorContainer(container, list, seen) { if (container == null) { return; } if (container instanceof Tensor) { list.push(container); return; } if (!isIterable(container)) { return; } const iterable = container; for (const k in iterable) { const val = iterable[k]; if (!seen.has(val)) { seen.add(val); walkTensorContainer(val, list, seen); } } }
// True for arrays and any value with typeof 'object' (walkTensorContainer rejects null before calling this).
function isIterable(obj) { return Array.isArray(obj) || typeof obj === 'object'; }
// Registered-kernel invocations carry a kernelName; custom forwardFunc invocations do not.
function isRegisteredKernelInvocation(kernelInvocation) { return kernelInvocation.kernelName != null; }
// EngineState: mutable bookkeeping for Engine - variable registry, tensor/byte/buffer counters,
// scope and tape state, per-dataId info (tensorInfo WeakMap), and the active profile.
class EngineState { constructor() { this.registeredVariables = {}; this.nextTapeNodeId = 0; this.numBytes = 0; this.numTensors = 0; this.numStringTensors = 0; this.numDataBuffers = 0; this.gradientDepth = 0; this.kernelDepth = 0; this.scopeStack = []; this.numDataMovesStack = []; this.nextScopeId = 0; this.tensorInfo = new WeakMap(); this.profiling = false; this.activeProfile = { newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null, get kernelNames() { return Array.from(new Set(this.kernels.map(k => k.name))); } }; }
// dispose(): disposes every registered variable.
dispose() { for (const variableName in this.registeredVariables) { this.registeredVariables[variableName].dispose(); } } }
// Engine: backend registration/initialization, kernel execution, scope-based tensor tracking,
// gradient tapes, and memory/profiling bookkeeping.
// registry maps backend name -> initialized instance; registryFactory maps name -> { factory, priority }.
class Engine { constructor(ENV) { this.ENV = ENV; this.registry = {}; this.registryFactory = {}; this.pendingBackendInitId = 
0; this.state = new EngineState(); }
// ready(): waits for a pending async backend init if there is one; otherwise, if no backend is active,
// tries backends in priority order and activates the first that initializes. Throws if all fail.
async ready() { if (this.pendingBackendInit != null) { return this.pendingBackendInit.then(() => { }); } if (this.backendInstance != null) { return; } const sortedBackends = this.getSortedBackends(); for (let i = 0; i < sortedBackends.length; i++) { const backendName = sortedBackends[i]; const success = await this.initializeBackend(backendName).success; if (success) { await this.setBackend(backendName); return; } } throw new Error(`Could not initialize any backends, all backend initializations ` + `failed.`); }
// backend (getter): the active backend; if none is set, synchronously initializes and activates the
// best one. Throws while an async init is pending, or when the best backend needs async init.
// The setBackend(name) call is not awaited, but for a synchronously initialized backend the registry
// entry already exists, so setBackend assigns backendInstance before reaching any await.
get backend() { if (this.pendingBackendInit != null) { throw new Error(`Backend '${this.backendName}' has not yet been initialized. Make ` + `sure to await tf.ready() or await tf.setBackend() before calling ` + `other methods`); } if (this.backendInstance == null) { const { name, asyncInit } = this.initializeBackendsAndReturnBest(); if (asyncInit) { throw new Error(`The highest priority backend '${name}' has not yet been ` + `initialized. Make sure to await tf.ready() or ` + `await tf.setBackend() before calling other methods`); } this.setBackend(name); } return this.backendInstance; }
// Names of all registered backend factories.
backendNames() { return Object.keys(this.registryFactory); }
// Returns the initialized backend, initializing it synchronously if needed; null when the name is
// unregistered or its initialization is async.
findBackend(backendName) { if (!(backendName in this.registry)) { if (backendName in this.registryFactory) { const { asyncInit } = this.initializeBackend(backendName); if (asyncInit) { return null; } } else { return null; } } return this.registry[backendName]; }
// Registered factory for backendName, or null.
findBackendFactory(backendName) { if (!(backendName in this.registryFactory)) { return null; } return this.registryFactory[backendName].factory; }
// Registers a factory with a priority (higher sorts first in getSortedBackends).
// Warns and returns false if the name is already registered; true otherwise.
registerBackend(backendName, factory, priority = 1) { if (backendName in this.registryFactory) { warn(`${backendName} backend was already registered. 
` + `Reusing existing backend factory.`); return false; } this.registryFactory[backendName] = { factory, priority }; return true; }
// setBackend(): makes backendName active, initializing it first if needed. Throws if no factory is
// registered. Resolves false if initialization fails (backendName is still recorded and backendInstance
// stays null); otherwise runs kernel setup functions, creates a new Profiler, and resolves true.
async setBackend(backendName) { if (this.registryFactory[backendName] == null) { throw new Error(`Backend name '${backendName}' not found in registry`); } this.backendName = backendName; if (this.registry[backendName] == null) { this.backendInstance = null; const { success, asyncInit } = this.initializeBackend(backendName); const result = asyncInit ? await success : success; if (!result) { return false; } } this.backendInstance = this.registry[backendName]; this.setupRegisteredKernels(); this.profiler = new Profiler(this.backendInstance); return true; }
// Calls each registered kernel's setupFunc (if any) with the active backend instance.
setupRegisteredKernels() { const kernels = getKernelsForBackend(this.backendName); kernels.forEach(kernel => { if (kernel.setupFunc != null) { kernel.setupFunc(this.backendInstance); } }); }
// Calls each registered kernel's disposeFunc (if any) with the named backend instance.
disposeRegisteredKernels(backendName) { const kernels = getKernelsForBackend(backendName); kernels.forEach(kernel => { if (kernel.disposeFunc != null) { kernel.disposeFunc(this.registry[backendName]); } }); }
// initializeBackend(): runs backendName's factory and returns { success, asyncInit }.
// - Thenable (non-KernelBackend) result: success is a Promise<boolean>, tracked as pendingBackendInit.
//   promiseId is compared with pendingBackendInitId so that an init superseded by a newer init,
//   removeBackend or reset resolves false without touching engine state.
// - Synchronous result: stored in the registry and reported as success
//   (NOTE(review): this happens even if the factory returned a falsy value).
// - Factory exception: logged via warn and reported as { success: false, asyncInit: false }.
// Throws if backendName has no registered factory.
initializeBackend(backendName) { const registryFactoryEntry = this.registryFactory[backendName]; if (registryFactoryEntry == null) { throw new Error(`Cannot initialize backend ${backendName}, no registration found.`); } try { const backend = registryFactoryEntry.factory(); if (backend && !(backend instanceof KernelBackend) && typeof backend.then === 'function') { const promiseId = ++this.pendingBackendInitId; const success = backend .then(backendInstance => { if (promiseId < this.pendingBackendInitId) { return false; } this.registry[backendName] = backendInstance; this.pendingBackendInit = null; return true; }) .catch(err => { if (promiseId < this.pendingBackendInitId) { return false; } this.pendingBackendInit = null; warn(`Initialization of backend ${backendName} failed`); warn(err.stack || err.message); return false; }); this.pendingBackendInit = success; return { success, asyncInit: true }; 
} else { this.registry[backendName] = backend; return { success: true, asyncInit: false }; } } catch (err) { warn(`Initialization of backend ${backendName} failed`); warn(err.stack || err.message); return { success: false, asyncInit: false }; } }
// removeBackend(): unregisters backendName (throws if unknown), disposing its kernels and instance if it
// was initialized. If it is the active backend, any in-flight async init is invalidated (by bumping
// pendingBackendInitId) and the active-backend fields are cleared.
removeBackend(backendName) { if (!(backendName in this.registryFactory)) { throw new Error(`${backendName} backend not found in registry`); } if (this.backendName === backendName && this.pendingBackendInit != null) { this.pendingBackendInitId++; } if (backendName in this.registry) { this.disposeRegisteredKernels(backendName); this.registry[backendName].dispose(); delete this.registry[backendName]; } delete this.registryFactory[backendName]; if (this.backendName === backendName) { this.pendingBackendInit = null; this.backendName = null; this.backendInstance = null; } }
// Registered backend names sorted by descending priority; throws if none are registered.
getSortedBackends() { if (Object.keys(this.registryFactory).length === 0) { throw new Error('No backend found in registry.'); } return Object.keys(this.registryFactory).sort((a, b) => { return this.registryFactory[b].priority - this.registryFactory[a].priority; }); }
// Initializes backends in priority order and returns { name, asyncInit } for the first one that either
// succeeded synchronously or started an async init. Throws if every initialization fails.
initializeBackendsAndReturnBest() { const sortedBackends = this.getSortedBackends(); for (let i = 0; i < sortedBackends.length; i++) { const backendName = sortedBackends[i]; const { success, asyncInit } = this.initializeBackend(backendName); if (asyncInit || success) { return { name: backendName, asyncInit }; } } throw new Error(`Could not initialize any backends, all backend initializations ` + `failed.`); }
// moveData(): moves dataId's values from the backend that currently owns it to `backend`. It reads the
// values synchronously, force-disposes them on the source, carries the ref count over in move(), and
// updates ownership in tensorInfo. When leak checking is on, the move is counted so
// checkKernelForMemLeak can discount it.
moveData(backend, dataId) { const info = this.state.tensorInfo.get(dataId); const srcBackend = info.backend; const values = this.readSync(dataId); const refCount = srcBackend.refCount(dataId); srcBackend.disposeData(dataId, true); info.backend = backend; backend.move(dataId, values, info.shape, info.dtype, refCount); if (this.shouldCheckForMemLeaks()) { this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++; } }
// tidy(fn) / tidy(name, fn): runs fn in a new scope. When the scope ends, tensors created in it are
// disposed unless they are kept or appear in fn's result. Returning a Promise is only reported via
// console.error.
tidy(nameOrFn, fn) { let name = null; if (fn == null) { if 
(typeof nameOrFn !== 'function') { throw new Error('Please provide a function to tidy()'); } fn = nameOrFn; } else { if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) { throw new Error('When calling with two arguments, the first argument ' + 'to tidy() must be a string'); } if (typeof fn !== 'function') { throw new Error('When calling with two arguments, the 2nd argument ' + 'to tidy() must be a function'); } name = nameOrFn; } let result; return this.scopedRun(() => this.startScope(name), () => this.endScope(result), () => { result = fn(); if (result instanceof Promise) { console.error('Cannot return a Promise inside of tidy.'); } return result; }); }
// scopedRun(): calls start(), then f(); end() runs whether f returns or throws. Returns f's result.
scopedRun(start, end, f) { start(); try { const res = f(); end(); return res; } catch (ex) { end(); throw ex; } }
// Id counters stored as static fields on Engine (shared by all Engine instances).
nextTensorId() { return Engine.nextTensorId++; }
nextVariableId() { return Engine.nextVariableId++; }
// clone(): copies x via the Identity kernel and records a tape node whose gradient casts dy to float32.
// NOTE(review): reads this.state.activeScope.name and pushes onto activeTape unconditionally; the callers
// visible here (saveFunc / saveTensorsForBackwardMode) only invoke it while the tape is on.
clone(x) { const y = ENGINE.runKernel(Identity$1, { x }); const inputs = { x }; const grad = (dy) => ({ x: () => { const dtype = 'float32'; const gradInputs = { x: dy }; const attrs = { dtype }; return ENGINE.runKernel(Cast, gradInputs, attrs); } }); const saved = []; this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {}); return y; }
// runKernel(): runs a kernel registered by name for the active backend; throws if it is not registered.
runKernel(kernelName, inputs, attrs) { const hasKernel = getKernel(kernelName, this.backendName) != null; if (!hasKernel) { throw new Error(`Kernel '${kernelName}' not registered for backend '${this.backendName}'`); } return this.runKernelFunc({ kernelName, inputs, attrs }); }
// Leak checking is enabled only when the IS_TEST flag is true.
shouldCheckForMemLeaks() { return this.ENV.getBool('IS_TEST'); }
// Throws if the backend gained more data ids than the kernel's outputs explain. complex64 outputs are
// counted as 3 ids, and data moves recorded on top of numDataMovesStack are subtracted.
checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos) { const numDataIdsAfter = this.backend.numDataIds(); let numOutputDataIds = 0; outInfos.forEach(info => { numOutputDataIds += (info.dtype === 'complex64' ? 
3 : 1); }); const numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]; const dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves; if (dataIdsLeaked > 0) { throw new Error(`Backend '${this.backendName}' has an internal memory leak ` + `(${dataIdsLeaked} data ids) after running '${kernelName}'`); } }
// runKernelFunc(): executes either a registered kernel (by kernelName) or a custom forwardFunc with an
// optional backwardsFunc. Execution happens while kernelDepth is raised; profiling is used when DEBUG or
// profiling is on. When the tape is on, tensors needed for the gradient are saved and a tape node is
// added. Returns an array if the kernel produced an array, otherwise the single output tensor.
runKernelFunc(kernelParams) { let outputs; let saved = []; const isTapeOn = this.isTapeOn(); const startingBytecount = this.state.numBytes; const startingNumTensors = this.state.numTensors;
// NOTE(review): numDataMovesStack is pushed here, but no matching pop is visible in this section, so in
// IS_TEST mode the stack grows by one entry per kernel call (moveData/checkKernelForMemLeak use the top entry).
if (this.shouldCheckForMemLeaks()) { this.state.numDataMovesStack.push(0); } let kernelFunc; let out; const kernelOrScopeName = isRegisteredKernelInvocation(kernelParams) ? kernelParams.kernelName : this.state.activeScope != null ? this.state.activeScope.name : ''; if (isRegisteredKernelInvocation(kernelParams)) { const { kernelName, inputs, attrs } = kernelParams; const kernel = getKernel(kernelName, this.backendName); assert$1(kernel != null, () => `Cannot find registered kernel '${kernelName}' for backend '${this.backendName}'`); kernelFunc = () => { const numDataIdsBefore = this.backend.numDataIds(); out = kernel.kernelFunc({ inputs, attrs, backend: this.backend }); const outInfos = Array.isArray(out) ? 
out : [out]; if (this.shouldCheckForMemLeaks()) { this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos); } const outTensors = outInfos.map((outInfo) => { if (outInfo.rank != null) { return outInfo; } return this.makeTensorFromTensorInfo(outInfo); }); if (isTapeOn) { const tensorsToSave = this.getTensorsForGradient(kernelName, inputs, outTensors); saved = this.saveTensorsForBackwardMode(tensorsToSave); } return outTensors; }; } else { const { forwardFunc } = kernelParams; const saveFunc = (tensors) => { if (!isTapeOn) { return; } saved = tensors.map(tensor => this.keep(this.clone(tensor))); }; kernelFunc = () => { const numDataIdsBefore = this.backend.numDataIds(); out = this.tidy(() => forwardFunc(this.backend, saveFunc)); const outs = (Array.isArray(out) ? out : [out]); if (this.shouldCheckForMemLeaks()) { this.checkKernelForMemLeak(kernelOrScopeName, numDataIdsBefore, outs); } return outs; }; } const { inputs, attrs } = kernelParams; const backwardsFunc = isRegisteredKernelInvocation(kernelParams) ? null : kernelParams.backwardsFunc; let kernelProfile; this.scopedRun( () => this.state.kernelDepth++, () => this.state.kernelDepth--, () => { if (!this.ENV.getBool('DEBUG') && !this.state.profiling) { outputs = kernelFunc(); } else { kernelProfile = this.profiler.profileKernel(kernelOrScopeName, inputs, () => kernelFunc()); if (this.ENV.getBool('DEBUG')) { this.profiler.logKernelProfile(kernelProfile); } outputs = kernelProfile.outputs; } }); if (isTapeOn) { this.addTapeNode(kernelOrScopeName, inputs, outputs, backwardsFunc, saved, attrs); } if (this.state.profiling) { this.state.activeProfile.kernels.push({ name: kernelOrScopeName, bytesAdded: this.state.numBytes - startingBytecount, totalBytesSnapshot: this.state.numBytes, tensorsAdded: this.state.numTensors - startingNumTensors, totalTensorsSnapshot: this.state.numTensors, inputShapes: Object.keys(inputs).map(key => inputs[key] != null ? 
inputs[key].shape : null), outputShapes: outputs.map(item => item.shape), kernelTimeMs: kernelProfile.timeMs, extraInfo: kernelProfile.extraInfo }); } return (Array.isArray(out) ? outputs : outputs[0]); }
// Clones each tensor and marks the clone kept, so it survives scope disposal for the backward pass.
saveTensorsForBackwardMode(tensors) { const saved = tensors.map(tensor => this.keep(this.clone(tensor))); return saved; }
// Selects the tensors the kernel's registered gradient config asks to keep: the named inputs (or all
// inputs when saveAllInputs is set) followed by the outputs whose index is truthy in outputsToSave.
// Returns [] when no gradient is registered for kernelName.
getTensorsForGradient(kernelName, inputs, outputs) { const gradConfig = getGradient(kernelName); if (gradConfig != null) { const inputsToSave = gradConfig.inputsToSave || []; const outputsToSave = gradConfig.outputsToSave || []; let inputTensorsToSave; if (gradConfig.saveAllInputs) { assert$1(Array.isArray(inputs), () => 'saveAllInputs is true, expected inputs to be an array.'); inputTensorsToSave = Object.keys(inputs).map((key) => inputs[key]); } else { inputTensorsToSave = inputsToSave.map((inputName) => inputs[inputName]); } const outputTensorsToSave = outputs.filter((_, i) => outputsToSave[i]); return inputTensorsToSave.concat(outputTensorsToSave); } return []; }
// makeTensor(): writes values to `backend` (default: the active backend; dtype default 'float32') and
// returns a tracked Tensor. String values are encoded before the write, and the byte count recorded by
// trackTensor (0 for strings) is then replaced with the encoded size.
makeTensor(values, shape, dtype, backend) { if (values == null) { throw new Error('Values passed to engine.makeTensor() are null'); } dtype = dtype || 'float32'; backend = backend || this.backend; let backendVals = values; if (dtype === 'string' && isString(values[0])) { backendVals = values.map(d => encodeString(d)); } const dataId = backend.write(backendVals, shape, dtype); const t = new Tensor(shape, dtype, dataId, this.nextTensorId()); this.trackTensor(t, backend); if (dtype === 'string') { const info = this.state.tensorInfo.get(dataId); const newBytes = bytesFromStringArray(backendVals); this.state.numBytes += newBytes - info.bytes; info.bytes = newBytes; } return t; }
// Wraps an existing dataId in a new tracked Tensor (dtype default 'float32').
makeTensorFromDataId(dataId, shape, dtype, backend) { dtype = dtype || 'float32'; const tensorInfo = { dataId, shape, dtype }; return this.makeTensorFromTensorInfo(tensorInfo, backend); }
// Creates and tracks a Tensor over tensorInfo's existing dataId.
makeTensorFromTensorInfo(tensorInfo, backend) { const { dataId, shape, dtype } = tensorInfo; const t = new Tensor(shape, 
dtype, dataId, this.nextTensorId()); this.trackTensor(t, backend); return t; }
// makeVariable(): creates a Variable sharing initialValue's data, named by `name` or the next variable
// id, and cast first when dtype differs. Throws if the name is already registered; otherwise registers
// it and increments the data's ref count.
makeVariable(initialValue, trainable = true, name, dtype) { name = name || this.nextVariableId().toString(); if (dtype != null && dtype !== initialValue.dtype) { initialValue = initialValue.cast(dtype); } const v = new Variable(initialValue, trainable, name, this.nextTensorId()); if (this.state.registeredVariables[v.name] != null) { throw new Error(`Variable with name ${v.name} was already registered`); } this.state.registeredVariables[v.name] = v; this.incRef(v, this.backend); return v; }
// trackTensor(): updates tensor/byte counters for a new handle (complex64 and string count 0 bytes here).
// The dataId's info is recorded on first sight only. Non-Variable tensors are added to the active scope.
trackTensor(a, backend) { this.state.numTensors++; if (a.dtype === 'string') { this.state.numStringTensors++; } let bytes = 0; if (a.dtype !== 'complex64' && a.dtype !== 'string') { bytes = a.size * bytesPerElement(a.dtype); } this.state.numBytes += bytes; if (!this.state.tensorInfo.has(a.dataId)) { this.state.numDataBuffers++; this.state.tensorInfo.set(a.dataId, { backend: backend || this.backend, dtype: a.dtype, shape: a.shape, bytes }); } if (!(a instanceof Variable)) { this.track(a); } }
// incRef(): tracks `a` and increments its dataId's ref count.
// NOTE(review): the increment always goes to the active backend (this.backend), not the `backend` argument.
incRef(a, backend) { this.trackTensor(a, backend); this.backend.incRef(a.dataId); }
// Forgets dataId's info only if it is still owned by `backend` (e.g. not after moveData reassigned it).
removeDataId(dataId, backend) { if (this.state.tensorInfo.has(dataId) && this.state.tensorInfo.get(dataId).backend === backend) { this.state.tensorInfo.delete(dataId); this.state.numDataBuffers--; } }
// disposeTensor(): no-op for unknown dataIds. Otherwise decrements the counters for this handle and asks
// the owning backend to dispose the data; the dataId's info is dropped only when disposeData returns truthy.
disposeTensor(a) { if (!this.state.tensorInfo.has(a.dataId)) { return; } const info = this.state.tensorInfo.get(a.dataId); this.state.numTensors--; if (a.dtype === 'string') { this.state.numStringTensors--; this.state.numBytes -= info.bytes; } if (a.dtype !== 'complex64' && a.dtype !== 'string') { const bytes = a.size * bytesPerElement(a.dtype); this.state.numBytes -= bytes; } if (info.backend.disposeData(a.dataId)) { this.removeDataId(a.dataId, info.backend); } }
// Disposes (and unregisters) every registered variable.
disposeVariables() { for (const varName in this.state.registeredVariables) { const v = this.state.registeredVariables[varName]; this.disposeVariable(v); } } 
// Releases v's data handle and removes it from the variable registry.
disposeVariable(v) { this.disposeTensor(v); if (this.state.registeredVariables[v.name] != null) { delete this.state.registeredVariables[v.name]; } }
// memory(): the active backend's memory info plus engine counters; marked unreliable when string tensors exist.
memory() { const info = this.backend.memory(); info.numTensors = this.state.numTensors; info.numDataBuffers = this.state.numDataBuffers; info.numBytes = this.state.numBytes; if (this.state.numStringTensors > 0) { info.unreliable = true; if (info.reasons == null) { info.reasons = []; } info.reasons.push('Memory usage by string tensors is approximate ' + '(2 bytes per character)'); } return info; }
// profile(): runs query with profiling enabled and returns activeProfile (per-kernel records, peakBytes,
// newBytes, newTensors, result), awaiting each kernel's kernelTimeMs and extraInfo.
// NOTE(review): if no kernels ran, peakBytes is Math.max() === -Infinity; if query rejects, profiling stays true.
async profile(query) { this.state.profiling = true; const startBytes = this.state.numBytes; const startNumTensors = this.state.numTensors; this.state.activeProfile.kernels = []; this.state.activeProfile.result = await query(); this.state.profiling = false; this.state.activeProfile.peakBytes = Math.max(...this.state.activeProfile.kernels.map(d => d.totalBytesSnapshot)); this.state.activeProfile.newBytes = this.state.numBytes - startBytes; this.state.activeProfile.newTensors = this.state.numTensors - startNumTensors; for (const kernel of this.state.activeProfile.kernels) { kernel.kernelTimeMs = await kernel.kernelTimeMs; kernel.extraInfo = await kernel.extraInfo; } return this.state.activeProfile; }
// Operations are recorded only inside a gradient scope and not from within another kernel.
isTapeOn() { return this.state.gradientDepth > 0 && this.state.kernelDepth === 0; }
// addTapeNode(): appends a node to the active tape. A registered gradient config for kernelName overrides
// gradientsFunc. Missing (null) upstream gradients are replaced with zeros shaped like the matching output,
// and a single dy is passed unwrapped.
addTapeNode(kernelName, inputs, outputs, gradientsFunc, saved, attrs) { const tapeNode = { id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved }; const gradConfig = getGradient(kernelName); if (gradConfig != null) { gradientsFunc = gradConfig.gradFunc; } if (gradientsFunc != null) { tapeNode.gradient = (dys) => { dys = dys.map((dy, i) => { if (dy == null) { const output = outputs[i]; const vals = makeZerosTypedArray(output.size, output.dtype); return this.makeTensor(vals, output.shape, output.dtype); } return dy; }); return gradientsFunc(dys.length > 1 ? 
dys : dys[0], saved, attrs); }; } this.state.activeTape.push(tapeNode); }
// keep(): marks a tensor so endScope never disposes it; returns the same tensor.
keep(result) { result.kept = true; return result; }
// Enters a gradient-recording scope; only the outermost one starts a fresh tape.
startTape() { if (this.state.gradientDepth === 0) { this.state.activeTape = []; } this.state.gradientDepth++; }
endTape() { this.state.gradientDepth--; }
// Pushes a new tracking scope (named `name`, or 'unnamed scope') and makes it active.
startScope(name) { const scopeInfo = { track: [], name: 'unnamed scope', id: this.state.nextScopeId++ }; if (name) { scopeInfo.name = name; } this.state.scopeStack.push(scopeInfo); this.state.activeScope = scopeInfo; }
// endScope(): disposes the active scope's tracked tensors except kept ones and those found in `result`,
// pops the scope, then re-tracks result tensors that belonged to it in the new active scope (if any).
endScope(result) { const tensorsToTrackInParent = getTensorsInContainer(result); const tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(t => t.id)); for (let i = 0; i < this.state.activeScope.track.length; i++) { const tensor = this.state.activeScope.track[i]; if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) { tensor.dispose(); } } const oldScope = this.state.scopeStack.pop(); this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1]; tensorsToTrackInParent.forEach(tensor => { if (!tensor.kept && tensor.scopeId === oldScope.id) { this.track(tensor); } }); }
// gradients(): records f on a tape, then backpropagates from y (seeded with dy, or ones shaped like y)
// and returns { value: y, grads } with one entry per x. Throws on empty xs, non-float32 dy, or a
// non-Tensor y; also throws when no tape path links xs to y, unless allowNoGradients is set. The
// outermost call disposes the tape's saved tensors and clears the tape.
// NOTE(review): the error-message string below is split by a line break in this copy of the file,
// which is a syntax error inside a single-quoted JavaScript string.
gradients(f, xs, dy, allowNoGradients = false) { assert$1(xs.length > 0, () => 'gradients() received an empty list of xs.'); if (dy != null && dy.dtype !== 'float32') { throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`); } const y = this.scopedRun(() => this.startTape(), () => this.endTape(), () => this.tidy('forward', f)); assert$1(y instanceof Tensor, () => 'The result y returned by f() must be a tensor.'); const filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y); if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) { throw new Error('Cannot compute gradient of y=f(x) with respect to x. 
Make sure ' + 'that the f you passed encloses all operations that lead from x ' + 'to y.'); } return this.tidy('backward', () => { const accumulatedGradientMap = {}; accumulatedGradientMap[y.id] = (dy == null) ? ones$1(y.shape) : dy; backpropagateGradients(accumulatedGradientMap, filteredTape, f => this.tidy(f), add$1); const grads = xs.map(x => accumulatedGradientMap[x.id]); if (this.state.gradientDepth === 0) { this.state.activeTape.forEach(node => { for (const tensor of node.saved) { tensor.dispose(); } }); this.state.activeTape = null; } return { value: y, grads }; }); }
// customGrad(): wraps f, which must return { value, gradFunc }, into a function of tensors. The wrapper
// runs through runKernelFunc; gradFunc must return one tensor per input, and those tensors become the
// inputs' gradients.
customGrad(f) { assert$1(isFunction(f), () => 'The f passed in customGrad(f) must be a function.'); return (...inputs) => { assert$1(inputs.every(t => t instanceof Tensor), () => 'The args passed in customGrad(f)(x1, x2,...) must all be ' + 'tensors'); let res; const inputMap = {}; inputs.forEach((input, i) => { inputMap[i] = input; }); const forwardFunc = (_, save) => { res = f(...[...inputs, save]); assert$1(res.value instanceof Tensor, () => 'The function f passed in customGrad(f) must return an ' + 'object where `obj.value` is a tensor'); assert$1(isFunction(res.gradFunc), () => 'The function f passed in customGrad(f) must return an ' + 'object where `obj.gradFunc` is a function.'); return res.value; }; const backwardsFunc = (dy, saved) => { const gradRes = res.gradFunc(dy, saved); const grads = Array.isArray(gradRes) ? 
gradRes : [gradRes]; assert$1(grads.length === inputs.length, () => 'The function f passed in customGrad(f) must return an ' + 'object where `obj.gradFunc` is a function that returns ' + 'the same number of tensors as inputs passed to f(...).'); assert$1(grads.every(t => t instanceof Tensor), () => 'The function f passed in customGrad(f) must return an ' + 'object where `obj.gradFunc` is a function that returns ' + 'a list of only tensors.'); const gradMap = {}; grads.forEach((grad, i) => { gradMap[i] = () => grad; }); return gradMap; }; return this.runKernelFunc({ forwardFunc, backwardsFunc, inputs: inputMap, }); }; }
// Readers: delegate to whichever backend currently owns dataId (per tensorInfo).
readSync(dataId) { const info = this.state.tensorInfo.get(dataId); return info.backend.readSync(dataId); }
read(dataId) { const info = this.state.tensorInfo.get(dataId); return info.backend.read(dataId); }
readToGPU(dataId, options) { const info = this.state.tensorInfo.get(dataId); return info.backend.readToGPU(dataId, options); }
// time(): the active backend's timing info for query, plus wall-clock time as wallMs.
async time(query) { const start = now(); const timingInfo = await this.backend.time(query); timingInfo.wallMs = now() - start; return timingInfo; }
// track(): registers result with the active scope (if any) so endScope can dispose it; returns result.
track(result) { if (this.state.activeScope != null) { result.scopeId = this.state.activeScope.id; this.state.activeScope.track.push(result); } return result; }
get registeredVariables() { return this.state.registeredVariables; }
// reset(): invalidates pending async inits, disposes variables, resets env flags and engine state, and
// disposes every initialized backend along with its kernels.
reset() { this.pendingBackendInitId++; this.state.dispose(); this.ENV.reset(); this.state = new EngineState(); for (const backendName in this.registry) { this.disposeRegisteredKernels(backendName); this.registry[backendName].dispose(); delete this.registry[backendName]; } this.backendName = null; this.backendInstance = null; this.pendingBackendInit = null; } }
// Static id counters used by nextTensorId/nextVariableId.
Engine.nextTensorId = 0; Engine.nextVariableId = 0;
// float32 tensor of ones with the given shape.
function ones$1(shape) { const values = makeOnesTypedArray(sizeFromShape(shape), 'float32'); return ENGINE.makeTensor(values, shape, 'float32'); }
// Returns the Engine stored as `_tfengine` on the global namespace, creating it on first use. On every
// call it also re-installs the environment global and the tensor tracker.
function getOrMakeEngine() { const ns = getGlobalNamespace(); if (ns._tfengine == null) { const 
environment = new Environment(ns); ns._tfengine = new Engine(environment); } setEnvironmentGlobal(ns._tfengine.ENV); setTensorTracker(() => ns._tfengine); return ns._tfengine; }
// Module-level engine instance.
const ENGINE = getOrMakeEngine();
// Runs the Add kernel directly; passed to backpropagateGradients by Engine.gradients.
function add$1(a, b) { const inputs = { a, b }; return ENGINE.runKernel(Add, inputs); }
function _isNavigatorDefined() { return typeof navigator !== 'undefined' && navigator != null; }
// isMobile(): heuristic check on `nav` (default: the global navigator). Returns true for React Native;
// otherwise tests userAgent/vendor/window.opera against known mobile patterns (the full string and its
// first 4 characters), falling back to userAgentData.mobile when no such string exists. Returns false
// when no navigator is available.
// NOTE(review): the second regex literal below is split by a line break in this copy of the file;
// JavaScript regex literals cannot contain line terminators - check how the bundle was wrapped.
function isMobile(nav) { if (nav || _isNavigatorDefined()) { if (!nav) { nav = navigator; } if (nav.product === 'ReactNative') { return true; } const a = nav.userAgent || nav.vendor || (typeof window !== 'undefined' ? window.opera : ''); if (!a) { const navAny = nav; return navAny.userAgentData && navAny.userAgentData.mobile; } return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i .test(a) || /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v 
)|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i .test(a.substr(0, 4)); } return false; }
// True when a window with a document exists, or when running inside a web worker.
function isBrowser() { return (typeof window !== 'undefined' && window.document != null) || (typeof WorkerGlobalScope !== 'undefined'); }
// Environment flags and their default evaluators. Enabling DEBUG logs a performance warning.
// NOTE(review): the warning's single-quoted string below is split by a line break in this copy of the file.
const ENV$1 = env(); ENV$1.registerFlag('DEBUG', () => false, debugValue => { if (debugValue) { console.warn('Debugging mode is ON. The output of every math call will ' + 'be downloaded to CPU and checked for NaNs. 
' + 'This significantly impacts performance.'); } });
// Platform detection flags.
ENV$1.registerFlag('IS_BROWSER', () => isBrowser()); ENV$1.registerFlag('IS_NODE', () => (typeof process !== 'undefined') && (typeof process.versions !== 'undefined') && (typeof process.versions.node !== 'undefined')); ENV$1.registerFlag('IS_CHROME', () => typeof navigator !== 'undefined' && navigator != null && navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && /Google Inc/.test(navigator.vendor)); ENV$1.registerFlag('IS_SAFARI', () => typeof navigator !== 'undefined' && navigator != null && navigator.userAgent != null && /Safari/.test(navigator.userAgent) && /Apple/.test(navigator.vendor));
// Behavior flags; TENSORLIKE_CHECK_SHAPE_CONSISTENCY and CHECK_COMPUTATION_FOR_ERRORS default to DEBUG.
ENV$1.registerFlag('PROD', () => false); ENV$1.registerFlag('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', () => ENV$1.getBool('DEBUG')); ENV$1.registerFlag('DEPRECATION_WARNINGS_ENABLED', () => true); ENV$1.registerFlag('IS_TEST', () => false); ENV$1.registerFlag('CHECK_COMPUTATION_FOR_ERRORS', () => ENV$1.getBool('DEBUG')); ENV$1.registerFlag('WRAP_TO_IMAGEBITMAP', () => false); ENV$1.registerFlag('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU', () => false); ENV$1.registerFlag('USE_SETTIMEOUTCUSTOM', () => false);
// buffer(): validates the shape's dimensions, then creates a TensorBuffer (dtype 'float32' when omitted or null).
function buffer(shape, dtype = 'float32', values) { dtype = dtype || 'float32'; assertNonNegativeIntegerDimensions(shape); return new TensorBuffer(shape, dtype, values); }
// inferShape(): shape of a TensorLike.
// - Typed arrays are 1-D, or [] for dtype 'string'.
// - WebGL data is [height, width * channel count], with channels defaulting to 'RGBA'.
// - WebGPU data is [buffer.size / bytes per element], assuming 4 bytes when dtype is null.
// - Other non-arrays are scalars ([]).
// - Nested arrays are measured along their first elements, with an optional full consistency check.
function inferShape(val, dtype) { let firstElem = val; if (isTypedArray(val)) { return dtype === 'string' ? [] : [val.length]; } if (isWebGLData(val)) { const usedChannels = val.channels || 'RGBA'; return [val.height, val.width * usedChannels.length]; } else if (isWebGPUData(val)) { return [val.buffer.size / (dtype == null ? 
4 : bytesPerElement(dtype))]; } if (!Array.isArray(val)) { return []; } const shape = []; while (Array.isArray(firstElem) || isTypedArray(firstElem) && dtype !== 'string') { shape.push(firstElem.length); firstElem = firstElem[0]; } if (Array.isArray(val) && env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) { deepAssertShapeConsistency(val, shape, []); } return shape; }
// Recursively asserts that every nested element matches `shape`; `indices` is the current path, used in messages.
function deepAssertShapeConsistency(val, shape, indices) { indices = indices || []; if (!(Array.isArray(val)) && !isTypedArray(val)) { assert$1(shape.length === 0, () => `Element arr[${indices.join('][')}] is a primitive, ` + `but should be an array/TypedArray of ${shape[0]} elements`); return; } assert$1(shape.length > 0, () => `Element arr[${indices.join('][')}] should be a primitive, ` + `but is an array of ${val.length} elements`); assert$1(val.length === shape[0], () => `Element arr[${indices.join('][')}] should have ${shape[0]} ` + `elements, but has ${val.length} elements`); const subShape = shape.slice(1); for (let i = 0; i < val.length; ++i) { deepAssertShapeConsistency(val[i], subShape, indices.concat(i)); } }
// Throws unless actualDType satisfies expectedDtype: 'string_or_numeric' accepts anything, 'numeric'
// rejects 'string', and any other value must match exactly. A null expectedDtype throws.
function assertDtype(expectedDtype, actualDType, argName, functionName) { if (expectedDtype === 'string_or_numeric') { return; } if (expectedDtype == null) { throw new Error(`Expected dtype cannot be null.`); } if (expectedDtype !== 'numeric' && expectedDtype !== actualDType || expectedDtype === 'numeric' && actualDType === 'string') { throw new Error(`Argument '${argName}' passed to '${functionName}' must ` + `be ${expectedDtype} tensor, but got ${actualDType} tensor`); } }
// convertToTensor(): returns x unchanged if it is already a Tensor (after a dtype check). Otherwise it
// requires a typed array, array, number, boolean or string; infers dtype (non-string data is coerced to
// parseAsDtype when that is bool/int32/float32) and shape; and creates a tensor through the engine.
function convertToTensor(x, argName, functionName, parseAsDtype = 'numeric') { if (x instanceof getGlobalTensorClass()) { assertDtype(parseAsDtype, x.dtype, argName, functionName); return x; } let inferredDtype = inferDtype(x); if (inferredDtype !== 'string' && ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) { inferredDtype = parseAsDtype; } assertDtype(parseAsDtype, inferredDtype, argName, 
functionName); if ((x == null) || (!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' && typeof x !== 'boolean' && typeof x !== 'string')) { const type = x == null ? 'null' : x.constructor.name; throw new Error(`Argument '${argName}' passed to '${functionName}' must be a ` + `Tensor or TensorLike, but got '${type}'`); } const inferredShape = inferShape(x, inferredDtype); if (!isTypedArray(x) && !Array.isArray(x)) { x = [x]; } const skipTypedArray = true; const values = inferredDtype !== 'string' ? toTypedArray(x, inferredDtype) : flatten$1(x, [], skipTypedArray); return ENGINE.makeTensor(values, inferredShape, inferredDtype); }
// Applies convertToTensor to each element (argument names like `x[0]`); throws if arg is not an array.
function convertToTensorArray(arg, argName, functionName, parseAsDtype = 'numeric') { if (!Array.isArray(arg)) { throw new Error(`Argument ${argName} passed to ${functionName} must be a ` + '`Tensor[]` or `TensorLike[]`'); } const tensors = arg; return tensors.map((t, i) => convertToTensor(t, `${argName}[${i}]`, functionName, parseAsDtype)); }
// Suffix appended to scope names created by op().
const OP_SCOPE_SUFFIX = '__op';
// op(): takes an object with exactly one key (e.g. { add_ }) and returns a wrapper that runs the function
// inside an engine scope named after the key (trailing '_' removed, '__op' appended). Tensors created in
// that scope and not part of the result are disposed by endScope. The wrapper's `name` is the scope name.
function op(f) { const keys = Object.keys(f); if (keys.length !== 1) { throw new Error(`Please provide an object with a single key ` + `(operation name) mapping to a function. 
Got an object with ` + `${keys.length} keys.`); } let opName = keys[0]; const fn = f[opName]; if (opName.endsWith('_')) { opName = opName.substring(0, opName.length - 1); } opName = opName + OP_SCOPE_SUFFIX; const f2 = (...args) => { ENGINE.startScope(opName); try { const result = fn(...args); if (isPromise(result)) { console.error('Cannot return a Promise inside of tidy.'); } ENGINE.endScope(result); return result; } catch (ex) { ENGINE.endScope(null); throw ex; } }; Object.defineProperty(f2, 'name', { value: opName, configurable: true }); return f2; }
// Casts x to dtype via the Cast kernel; throws for unknown dtypes or for casts between string and non-string.
function cast_(x, dtype) { const $x = convertToTensor(x, 'x', 'cast'); if (!isValidDtype(dtype)) { throw new Error(`Failed to cast to unknown dtype ${dtype}`); } if (dtype === 'string' && $x.dtype !== 'string' || dtype !== 'string' && $x.dtype === 'string') { throw new Error('Only strings can be casted to strings'); } const inputs = { x: $x }; const attrs = { dtype }; return ENGINE.runKernel(Cast, inputs, attrs); } const cast$2 = op({ cast_ });
// Copies x (any dtype, including string) via the Identity kernel.
function clone_(x) { const $x = convertToTensor(x, 'x', 'clone', 'string_or_numeric'); const inputs = { x: $x }; return ENGINE.runKernel(Identity$1, inputs); } const clone = op({ clone_ });
// Logs x.toString(verbose) to the console.
function print(x, verbose = false) { console.log(x.toString(verbose)); }
// Installs the ops that Tensor methods delegate to (presumably what opHandler$1 refers to - TODO confirm in setOpHandler).
getOrMakeEngine(); const opHandler = { buffer, cast: cast$2, clone, print }; setOpHandler(opHandler);
// Public wrappers around ENGINE / env().
function enableProdMode() { env().set('PROD', true); } function engine() { return ENGINE; } function memory() { return ENGINE.memory(); } function tidy(nameOrFn, fn) { return ENGINE.tidy(nameOrFn, fn); }
// Disposes every tensor found (recursively) in container.
function dispose(container) { const tensors = getTensorsInContainer(container); tensors.forEach(tensor => tensor.dispose()); } function keep(result) { return ENGINE.keep(result); } function registerBackend(name, factory, priority = 1) { return ENGINE.registerBackend(name, factory, priority); } function backend() { return ENGINE.backend; }
// Binary arithmetic ops: both operands are converted to tensors and upcast to a common dtype first.
function add_(a, b) { let $a = convertToTensor(a, 'a', 'add'); let $b = 
convertToTensor(b, 'b', 'add'); [$a, $b] = makeTypesMatch($a, $b); const inputs = { a: $a, b: $b }; return ENGINE.runKernel(Add, inputs); } const add = op({ add_ });
function floorDiv_(a, b) { let $a = convertToTensor(a, 'a', 'floorDiv'); let $b = convertToTensor(b, 'b', 'floorDiv'); [$a, $b] = makeTypesMatch($a, $b); const inputs = { a: $a, b: $b }; return ENGINE.runKernel(FloorDiv, inputs); } const floorDiv$1 = op({ floorDiv_ });
// div: int32 / int32 is routed to floorDiv; every other dtype pair uses RealDiv.
function div_(a, b) { let $a = convertToTensor(a, 'a', 'div'); let $b = convertToTensor(b, 'b', 'div'); [$a, $b] = makeTypesMatch($a, $b); if ($a.dtype === 'int32' && $b.dtype === 'int32') { return floorDiv$1($a, $b); } const inputs = { a: $a, b: $b }; const attrs = {}; return ENGINE.runKernel(RealDiv, inputs, attrs); } const div = op({ div_ });
function mul_(a, b) { let $a = convertToTensor(a, 'a', 'mul'); let $b = convertToTensor(b, 'b', 'mul'); [$a, $b] = makeTypesMatch($a, $b); const inputs = { a: $a, b: $b }; return ENGINE.runKernel(Multiply, inputs); } const mul = op({ mul_ });
// abs: ComplexAbs kernel for complex64 input, Abs kernel otherwise.
function abs_(x) { const $x = convertToTensor(x, 'x', 'abs'); if ($x.dtype === 'complex64') { const inputs = { x: $x }; return ENGINE.runKernel(ComplexAbs, inputs); } else { const inputs = { x: $x }; return ENGINE.runKernel(Abs, inputs); } } const abs$1 = op({ abs_ });
// any: runs the Any kernel on x parsed as bool, with the given axis and keepDims attributes.
function any_(x, axis = null, keepDims = false) { const $x = convertToTensor(x, 'x', 'any', 'bool'); const inputs = { x: $x }; const attrs = { axis, keepDims }; return ENGINE.runKernel(Any, inputs, attrs); } const any$1 = op({ any_ });
// argMax: runs the ArgMax kernel along `axis` (default 0).
function argMax_(x, axis = 0) { const $x = convertToTensor(x, 'x', 'argMax'); const inputs = { x: $x }; const attrs = { axis }; return ENGINE.runKernel(ArgMax, inputs, attrs); } const argMax$1 = op({ argMax_ });
// Conv info for dilation2d: appends inputShape[3] (treated as the channel count) to filterShape, converts
// dataFormat, and delegates to computeConv2DInfo with null roundingMode and depthwise arguments.
function computeDilation2DInfo(inputShape, filterShape, strides, pad, dataFormat = 'NHWC', dilations) { const inputChannels = inputShape[3]; const $filterShape = [...filterShape, inputChannels]; const $dataFormat = convertConv2DDataFormat(dataFormat); 
return computeConv2DInfo(inputShape, $filterShape, strides, dilations, pad, null , null , $dataFormat); }
// Pool2D info expressed as a conv with filter [fh, fw, C, C]. C is inShape[3] for 'channelsLast' and
// inShape[1] for 'channelsFirst'; any other dataFormat throws.
function computePool2DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat = 'channelsLast') { const [filterHeight, filterWidth] = parseTupleParam(filterSize); let filterShape; if (dataFormat === 'channelsLast') { filterShape = [filterHeight, filterWidth, inShape[3], inShape[3]]; } else if (dataFormat === 'channelsFirst') { filterShape = [filterHeight, filterWidth, inShape[1], inShape[1]]; } else { throw new Error(`Unknown dataFormat ${dataFormat}`); } return computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, false, dataFormat); }
// Pool3D info expressed as a conv with filter [fd, fh, fw, C, C]. 'NDHWC' maps to channelsLast
// (C = inShape[4]) and 'NCDHW' to channelsFirst (C = inShape[1]); any other format throws.
// Note that computeConv3DInfo takes roundingMode as its last argument.
function computePool3DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat = 'NDHWC') { const [filterDepth, filterHeight, filterWidth] = parse3TupleParam(filterSize); let filterShape; let $dataFormat; if (dataFormat === 'NDHWC') { $dataFormat = 'channelsLast'; filterShape = [filterDepth, filterHeight, filterWidth, inShape[4], inShape[4]]; } else if (dataFormat === 'NCDHW') { $dataFormat = 'channelsFirst'; filterShape = [filterDepth, filterHeight, filterWidth, inShape[1], inShape[1]]; } else { throw new Error(`Unknown dataFormat ${dataFormat}`); } return computeConv3DInfo(inShape, filterShape, strides, dilations, pad, false, $dataFormat, roundingMode); }
function computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, depthwise = false, dataFormat = 'channelsLast') { let [batchSize, inHeight, inWidth, inChannels] = [-1, -1, -1, -1]; if (dataFormat === 'channelsLast') { [batchSize, inHeight, inWidth, inChannels] = inShape; } else if (dataFormat === 'channelsFirst') { [batchSize, inChannels, inHeight, inWidth] = inShape; } else { throw new Error(`Unknown dataFormat ${dataFormat}`); } const [filterHeight, filterWidth, , filterChannels] = filterShape; const [strideHeight, strideWidth] = parseTupleParam(strides); const [dilationHeight, 
dilationWidth] = parseTupleParam(dilations); const effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight); const effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth); const { padInfo, outHeight, outWidth } = getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight, effectiveFilterWidth, roundingMode, dataFormat); const outChannels = depthwise ? filterChannels * inChannels : filterChannels; let outShape; if (dataFormat === 'channelsFirst') { outShape = [batchSize, outChannels, outHeight, outWidth]; } else if (dataFormat === 'channelsLast') { outShape = [batchSize, outHeight, outWidth, outChannels]; } return { batchSize, dataFormat, inHeight, inWidth, inChannels, outHeight, outWidth, outChannels, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, effectiveFilterHeight, effectiveFilterWidth, dilationHeight, dilationWidth, inShape, outShape, filterShape }; } function computeConv3DInfo(inShape, filterShape, strides, dilations, pad, depthwise = false, dataFormat = 'channelsLast', roundingMode) { let [batchSize, inDepth, inHeight, inWidth, inChannels] = [-1, -1, -1, -1, -1]; if (dataFormat === 'channelsLast') { [batchSize, inDepth, inHeight, inWidth, inChannels] = inShape; } else if (dataFormat === 'channelsFirst') { [batchSize, inChannels, inDepth, inHeight, inWidth] = inShape; } else { throw new Error(`Unknown dataFormat ${dataFormat}`); } const [filterDepth, filterHeight, filterWidth, , filterChannels] = filterShape; const [strideDepth, strideHeight, strideWidth] = parse3TupleParam(strides); const [dilationDepth, dilationHeight, dilationWidth] = parse3TupleParam(dilations); const effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth); const effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight); const effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth); const { padInfo, outDepth, outHeight, outWidth } = 
get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth, roundingMode); const outChannels = depthwise ? filterChannels * inChannels : filterChannels; let outShape; if (dataFormat === 'channelsFirst') { outShape = [batchSize, outChannels, outDepth, outHeight, outWidth]; } else if (dataFormat === 'channelsLast') { outShape = [batchSize, outDepth, outHeight, outWidth, outChannels]; } return { batchSize, dataFormat, inDepth, inHeight, inWidth, inChannels, outDepth, outHeight, outWidth, outChannels, padInfo, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth, dilationDepth, dilationHeight, dilationWidth, inShape, outShape, filterShape }; } function computeOutputShape2D(inShape, fieldSize, stride, zeroPad, roundingMode) { if (zeroPad == null) { zeroPad = computeDefaultPad(inShape, fieldSize, stride); } const inputRows = inShape[0]; const inputCols = inShape[1]; const outputRows = round$1((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); const outputCols = round$1((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); return [outputRows, outputCols]; } function computeOutputShape4D(inShape, filterShape, outChannels, strides, zeroPad, roundingMode) { if (zeroPad == null) { zeroPad = computeDefaultPad(inShape, filterShape[0], strides[0]); } const outShape = [0, 0, 0, outChannels]; for (let index = 0; index < 3; index++) { if (inShape[index] + 2 * zeroPad >= filterShape[index]) { outShape[index] = round$1((inShape[index] - filterShape[index] + 2 * zeroPad) / strides[index] + 1, roundingMode); } } return outShape; } function computeDefaultPad(inputShape, fieldSize, stride, dilation = 1) { const effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation); return Math.floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2); } function 
parseTupleParam(param) { if (typeof param === 'number') { return [param, param, param]; } if (param.length === 2) { return [param[0], param[1], 1]; } return param; } function parse3TupleParam(param) { return typeof param === 'number' ? [param, param, param] : param; } function getEffectiveFilterSize(filterSize, dilation) { if (dilation <= 1) { return filterSize; } return filterSize + (filterSize - 1) * (dilation - 1); } function getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode, dataFormat) { let padInfo; let outHeight; let outWidth; if (typeof pad === 'number') { const padType = (pad === 0) ? 'VALID' : 'NUMBER'; padInfo = { top: pad, bottom: pad, left: pad, right: pad, type: padType }; const outShape = computeOutputShape2D([inHeight, inWidth], filterHeight, strideHeight, pad, roundingMode); outHeight = outShape[0]; outWidth = outShape[1]; } else if (pad === 'same') { outHeight = Math.ceil(inHeight / strideHeight); outWidth = Math.ceil(inWidth / strideWidth); const padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight); const padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth); const top = Math.floor(padAlongHeight / 2); const bottom = padAlongHeight - top; const left = Math.floor(padAlongWidth / 2); const right = padAlongWidth - left; padInfo = { top, bottom, left, right, type: 'SAME' }; } else if (pad === 'valid') { padInfo = { top: 0, bottom: 0, left: 0, right: 0, type: 'VALID' }; outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight); outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth); } else if (typeof pad === 'object') { const top = dataFormat === 'channelsLast' ? pad[1][0] : pad[2][0]; const bottom = dataFormat === 'channelsLast' ? pad[1][1] : pad[2][1]; const left = dataFormat === 'channelsLast' ? pad[2][0] : pad[3][0]; const right = dataFormat === 'channelsLast' ? 
pad[2][1] : pad[3][1]; const padType = (top === 0 && bottom === 0 && left === 0 && right === 0) ? 'VALID' : 'EXPLICIT'; padInfo = { top, bottom, left, right, type: padType }; outHeight = round$1((inHeight - filterHeight + top + bottom) / strideHeight + 1, roundingMode); outWidth = round$1((inWidth - filterWidth + left + right) / strideWidth + 1, roundingMode); } else { throw Error(`Unknown padding parameter: ${pad}`); } return { padInfo, outHeight, outWidth }; } function get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, roundingMode) { let padInfo; let outDepth; let outHeight; let outWidth; if (pad === 'valid') { pad = 0; } if (typeof pad === 'number') { const padType = (pad === 0) ? 'VALID' : 'NUMBER'; padInfo = { top: pad, bottom: pad, left: pad, right: pad, front: pad, back: pad, type: padType }; const outShape = computeOutputShape4D([inDepth, inHeight, inWidth, 1], [filterDepth, filterHeight, filterWidth], 1, [strideDepth, strideHeight, strideWidth], pad, roundingMode); outDepth = outShape[0]; outHeight = outShape[1]; outWidth = outShape[2]; } else if (pad === 'same') { outDepth = Math.ceil(inDepth / strideDepth); outHeight = Math.ceil(inHeight / strideHeight); outWidth = Math.ceil(inWidth / strideWidth); const padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth; const padAlongHeight = (outHeight - 1) * strideHeight + filterHeight - inHeight; const padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth; const front = Math.floor(padAlongDepth / 2); const back = padAlongDepth - front; const top = Math.floor(padAlongHeight / 2); const bottom = padAlongHeight - top; const left = Math.floor(padAlongWidth / 2); const right = padAlongWidth - left; padInfo = { top, bottom, left, right, front, back, type: 'SAME' }; } else { throw Error(`Unknown padding parameter: ${pad}`); } return { padInfo, outDepth, outHeight, outWidth }; } function round$1(value, 
roundingMode) { if (!roundingMode) { return Math.trunc(value); } switch (roundingMode) { case 'round': return Math.round(value); case 'ceil': return Math.ceil(value); case 'floor': return Math.floor(value); default: throw new Error(`Unknown roundingMode ${roundingMode}`); } } function tupleValuesAreOne(param) { const [dimA, dimB, dimC] = parseTupleParam(param); return dimA === 1 && dimB === 1 && dimC === 1; } function eitherStridesOrDilationsAreOne(strides, dilations) { return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations); } function stridesOrDilationsArePositive(values) { return parseTupleParam(values).every(value => value > 0); } function convertConv2DDataFormat(dataFormat) { if (dataFormat === 'NHWC') { return 'channelsLast'; } else if (dataFormat === 'NCHW') { return 'channelsFirst'; } else { throw new Error(`Unknown dataFormat ${dataFormat}`); } } function checkPadOnDimRoundingMode(opDesc, pad, dimRoundingMode) { if (dimRoundingMode != null) { if (typeof pad === 'string') { throw Error(`Error in ${opDesc}: pad must be an integer when using ` + `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); } else if (typeof pad === 'number') { assert$1(isInt(pad), () => `Error in ${opDesc}: pad must be an integer when using ` + `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`); } else if (typeof pad === 'object') { pad.forEach(p => { p.forEach(v => { assert$1(isInt(v), () => `Error in ${opDesc}: pad must be an integer when using ` + `dimRoundingMode ${dimRoundingMode} but got pad ${v}.`); }); }); } else { throw Error(`Error in ${opDesc}: Unknown padding parameter: ${pad}`); } } } function reshape_(x, shape) { const $x = convertToTensor(x, 'x', 'reshape', 'string_or_numeric'); const inputs = { x: $x }; const attrs = { shape }; return ENGINE.runKernel(Reshape$1, inputs, attrs); } const reshape$1 = op({ reshape_ }); function concat_(tensors, axis = 0) { assert$1(tensors.length >= 1, () => 'Pass at least one tensor to concat'); const $tensors = 
convertToTensorArray(tensors, 'tensors', 'concat', 'string_or_numeric'); if ($tensors[0].dtype === 'complex64') { $tensors.forEach(tensor => { if (tensor.dtype !== 'complex64') { throw new Error(`Cannot concatenate complex64 tensors with a tensor with dtype ${tensor.dtype}. `); } }); } if ($tensors.length === 1) { return clone($tensors[0]); } const inputs = $tensors; const attr = { axis }; return ENGINE.runKernel(Concat, inputs, attr); } const concat$1 = op({ concat_ }); function matMul_(a, b, transposeA = false, transposeB = false) { let $a = convertToTensor(a, 'a', 'matMul'); let $b = convertToTensor(b, 'b', 'matMul'); [$a, $b] = makeTypesMatch($a, $b); const inputs = { a: $a, b: $b }; const attrs = { transposeA, transposeB }; return ENGINE.runKernel(BatchMatMul, inputs, attrs); } const matMul$1 = op({ matMul_ }); function sigmoid_(x) { const $x = convertToTensor(x, 'x', 'sigmoid', 'float32'); const inputs = { x: $x }; return ENGINE.runKernel(Sigmoid$1, inputs); } const sigmoid$1 = op({ sigmoid_ }); function slice_(x, begin, size) { const $x = convertToTensor(x, 'x', 'slice', 'string_or_numeric'); if ($x.rank === 0) { throw new Error('Slicing scalar is not possible'); } const inputs = { x: $x }; const attrs = { begin, size }; return ENGINE.runKernel(Slice, inputs, attrs); } const slice$1 = op({ slice_ }); function tanh_(x) { const $x = convertToTensor(x, 'x', 'tanh', 'float32'); const inputs = { x: $x }; return ENGINE.runKernel(Tanh$1, inputs); } const tanh$1 = op({ tanh_ }); function batchToSpaceND_(x, blockShape, crops) { const $x = convertToTensor(x, 'x', 'batchToSpaceND'); const prod = blockShape.reduce((a, b) => a * b); assert$1($x.rank >= 1 + blockShape.length, () => `input rank is ${$x.rank} but should be > than blockShape.length ${blockShape.length}`); assert$1(crops.length === blockShape.length, () => `crops.length is ${crops.length} but should be equal to blockShape.length ${blockShape.length}`); assert$1($x.shape[0] % prod === 0, () => `input tensor 
batch is ${$x.shape[0]} but is not divisible by the product of ` + `the elements of blockShape ${blockShape.join(' * ')} === ${prod}`); const inputs = { x: $x }; const attrs = { blockShape, crops }; return ENGINE.runKernel(BatchToSpaceND, inputs, attrs); } const batchToSpaceND$1 = op({ batchToSpaceND_ }); function broadcastTo_(x, shape) { let input = convertToTensor(x, 'broadcastTo', 'x'); const xShape = input.shape; assertNonNegativeIntegerDimensions(shape); if (shape.length < input.rank) { throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${input.rank}.`); } if (shape.length > input.rank) { const newShape = input.shape.slice(); while (newShape.length < shape.length) { newShape.unshift(1); } input = reshape$1(input, newShape); } const inputShape = input.shape; const reps = Array.from(shape); for (let i = shape.length - 1; i >= 0; i--) { if (inputShape[i] === shape[i]) { reps[i] = 1; } else if (input.shape[i] !== 1) { throw new Error(`broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`); } } const axes = reps.map((n, i) => n > 1 ? 
i : -1).filter(i => i >= 0); if (axes.length === 0) { return clone(input); } const inputs = { x: input }; const attrs = { reps }; return ENGINE.runKernel(Tile, inputs, attrs); } const broadcastTo = op({ broadcastTo_ }); function fill$1(shape, value, dtype) { assertNonNegativeIntegerDimensions(shape); dtype = dtype || inferDtype(value); const attrs = { shape, value, dtype }; return ENGINE.runKernel(Fill, {}, attrs); } function clipByValue_(x, clipValueMin, clipValueMax) { const $x = convertToTensor(x, 'x', 'clipByValue'); assert$1((clipValueMin <= clipValueMax), () => `Error in clip: min (${clipValueMin}) must be ` + `less than or equal to max (${clipValueMax}).`); if (clipValueMin === clipValueMax) { return fill$1($x.shape, clipValueMin, $x.dtype); } const inputs = { x: $x }; const attrs = { clipValueMin, clipValueMax }; return ENGINE.runKernel(ClipByValue, inputs, attrs); } const clipByValue$1 = op({ clipByValue_ }); function complex_(real, imag) { const $real = convertToTensor(real, 'real', 'complex'); const $imag = convertToTensor(imag, 'imag', 'complex'); assertShapesMatch($real.shape, $imag.shape, `real and imag shapes, ${$real.shape} and ${$imag.shape}, ` + `must match in call to tf.complex().`); const inputs = { real: $real, imag: $imag }; return ENGINE.runKernel(Complex, inputs); } const complex$1 = op({ complex_ }); function conv2d_(x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode) { const $x = convertToTensor(x, 'x', 'conv2d', 'float32'); const $filter = convertToTensor(filter, 'filter', 'conv2d', 'float32'); let x4D = $x; let reshapedTo4D = false; if ($x.rank === 3) { reshapedTo4D = true; x4D = reshape$1($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]); } assert$1(x4D.rank === 4, () => `Error in conv2d: input must be rank 4, but got rank ${x4D.rank}.`); assert$1($filter.rank === 4, () => `Error in conv2d: filter must be rank 4, but got rank ` + `${$filter.rank}.`); checkPadOnDimRoundingMode('conv2d', pad, dimRoundingMode); 
const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1]; assert$1(inDepth === $filter.shape[2], () => `Error in conv2d: depth of input (${inDepth}) must match ` + `input depth for filter ${$filter.shape[2]}.`); assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in conv2D: Either strides or dilations must be 1. ' + `Got strides ${strides} and dilations '${dilations}'`); assert$1(stridesOrDilationsArePositive(dilations), () => 'Error in conv2D: Dilated rates should be larger than 0.'); assert$1(stridesOrDilationsArePositive(strides), () => 'Error in conv2D: Strides should be larger than 0.'); const inputs = { x: x4D, filter: $filter }; const attrs = { strides, pad, dataFormat, dilations, dimRoundingMode }; const res = ENGINE.runKernel(Conv2D, inputs, attrs); if (reshapedTo4D) { return reshape$1(res, [res.shape[1], res.shape[2], res.shape[3]]); } return res; } const conv2d$1 = op({ conv2d_ }); function conv2DBackpropInput_(xShape, dy, filter, strides, pad, dataFormat = 'NHWC', dimRoundingMode) { assert$1(xShape.length === dy.rank, () => `Length of inShape ` + `(${xShape.length}) and rank of dy (${dy.rank}) must match`); let xShape4D = xShape; let dy4D = dy; let reshapedTo4D = false; if (dy.rank === 3) { reshapedTo4D = true; dy4D = reshape$1(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]); xShape4D = [1, xShape[0], xShape[1], xShape[2]]; } assert$1(xShape4D.length === 4, () => `Error in conv2dDerInput: inShape must be length 4, but got length ` + `${xShape4D.length}.`); assert$1(dy4D.rank === 4, () => `Error in conv2dDerInput: dy must be rank 4, but got ` + `rank ${dy4D.rank}`); assert$1(filter.rank === 4, () => `Error in conv2dDerInput: filter must be rank 4, but got ` + `rank ${filter.rank}`); const inDepth = dataFormat === 'NHWC' ? xShape4D[3] : xShape4D[1]; const outDepth = dataFormat === 'NHWC' ? 
dy4D.shape[3] : dy4D.shape[1]; assert$1(inDepth === filter.shape[2], () => `Error in conv2dDerInput: depth of input (${inDepth}) must ` + `match input depth for filter ${filter.shape[2]}.`); assert$1(outDepth === filter.shape[3], () => `Error in conv2dDerInput: depth of output (${outDepth}) must ` + `match output depth for filter ${filter.shape[3]}.`); checkPadOnDimRoundingMode('conv2dDerInput', pad, dimRoundingMode); const inputs = { dy: dy4D, filter }; const attrs = { strides, pad, dataFormat, dimRoundingMode, inputShape: xShape4D }; const res = ENGINE.runKernel(Conv2DBackpropInput, inputs, attrs); if (reshapedTo4D) { return reshape$1(res, [res.shape[1], res.shape[2], res.shape[3]]); } return res; } const conv2DBackpropInput$1 = op({ conv2DBackpropInput_ }); function conv3DBackpropInput_(xShape, dy, filter, strides, pad) { assert$1(xShape.length === dy.rank, () => `Length of inShape ` + `(${xShape.length}) and rank of dy (${dy.rank}) must match`); let xShape5D = xShape; let dy5D = dy; let reshapedTo5D = false; if (dy.rank === 4) { reshapedTo5D = true; dy5D = reshape$1(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]); xShape5D = [1, xShape[0], xShape[1], xShape[2], xShape[3]]; } const inDepth = xShape5D[4]; const outDepth = dy5D.shape[4]; assert$1(xShape5D.length === 5, () => `Error in conv3dDerInput: inShape must be length 5, but got length ` + `${xShape5D.length}.`); assert$1(dy5D.rank === 5, () => `Error in conv3dDerInput: dy must be rank 5, but got ` + `rank ${dy5D.rank}`); assert$1(filter.rank === 5, () => `Error in conv3dDerInput: filter must be rank 5, but got ` + `rank ${filter.rank}`); assert$1(inDepth === filter.shape[3], () => `Error in conv3dDerInput: depth of input (${inDepth}) must ` + `match input depth for filter ${filter.shape[3]}.`); assert$1(outDepth === filter.shape[4], () => `Error in conv3dDerInput: depth of output (${outDepth}) must ` + `match output depth for filter ${filter.shape[4]}.`); const inputs = { dy: dy5D, filter }; 
const attrs = { pad, strides, inputShape: xShape5D }; const res = ENGINE.runKernel(Conv3DBackpropInputV2, inputs, attrs); if (reshapedTo5D) { return reshape$1(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]); } return res; } const conv3DBackpropInput$1 = op({ conv3DBackpropInput_ }); function cos_(x) { const $x = convertToTensor(x, 'x', 'cos', 'float32'); const inputs = { x: $x }; return ENGINE.runKernel(Cos, inputs); } const cos$1 = op({ cos_ }); function cosh_(x) { const $x = convertToTensor(x, 'x', 'cosh', 'float32'); const inputs = { x: $x }; return ENGINE.runKernel(Cosh, inputs); } const cosh$1 = op({ cosh_ }); function cumprod_(x, axis = 0, exclusive = false, reverse = false) { const $x = convertToTensor(x, 'x', 'cumprod'); const inputs = { x: $x }; const attrs = { axis, exclusive, reverse }; return ENGINE.runKernel(Cumprod, inputs, attrs); } const cumprod$1 = op({ cumprod_ }); function cumsum_(x, axis = 0, exclusive = false, reverse = false) { const $x = convertToTensor(x, 'x', 'cumsum'); const inputs = { x: $x }; const attrs = { axis, exclusive, reverse }; return ENGINE.runKernel(Cumsum, inputs, attrs); } const cumsum$1 = op({ cumsum_ }); function getBroadcastDims$1(inShape, outShape) { const inRank = inShape.length; const dims = []; for (let i = 0; i < inRank; i++) { const dim = inRank - 1 - i; const a = inShape[dim] || 1; const b = outShape[outShape.length - 1 - i] || 1; if (b > 1 && a === 1) { dims.unshift(dim); } } return dims; } function getReductionAxes(inShape, outShape) { const result = []; for (let i = 0; i < outShape.length; i++) { const inDim = inShape[inShape.length - i - 1]; const outAxis = outShape.length - i - 1; const outDim = outShape[outAxis]; if (inDim == null || (inDim === 1 && outDim > 1)) { result.unshift(outAxis); } } return result; } function assertAndGetBroadcastShape(shapeA, shapeB) { const l = Math.max(shapeA.length, shapeB.length); const result = new Array(l); for (let i = 0; i < l; i++) { let a = 
shapeA[shapeA.length - i - 1]; if (a == null) { a = 1; } let b = shapeB[shapeB.length - i - 1]; if (b == null) { b = 1; } if (a === 1) { result[l - i - 1] = b; } else if (b === 1) { result[l - i - 1] = a; } else if (a !== b) { const errMsg = `Operands could not be broadcast together with shapes ` + `${shapeA} and ${shapeB}.`; throw Error(errMsg); } else { result[l - i - 1] = a; } } return result; } function equal_(a, b) { let $a = convertToTensor(a, 'a', 'equal', 'string_or_numeric'); let $b = convertToTensor(b, 'b', 'equal', 'string_or_numeric'); [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); const inputs = { a: $a, b: $b }; return ENGINE.runKernel(Equal, inputs); } const equal$1 = op({ equal_ }); function where_(condition, a, b) { const $a = convertToTensor(a, 'a', 'where'); const $b = convertToTensor(b, 'b', 'where'); const $condition = convertToTensor(condition, 'condition', 'where', 'bool'); const broadcastShape = assertAndGetBroadcastShape(assertAndGetBroadcastShape($condition.shape, $a.shape), $b.shape); const $broadcastedCondition = broadcastTo($condition, broadcastShape); const $broadcastedA = broadcastTo($a, broadcastShape); const $broadcastedB = broadcastTo($b, broadcastShape); const inputs = { condition: $broadcastedCondition, t: $broadcastedA, e: $broadcastedB }; return ENGINE.runKernel(Select, inputs); } const where = op({ where_ }); function zerosLike_(x) { const $x = convertToTensor(x, 'x', 'zerosLike'); const inputs = { x: $x }; return ENGINE.runKernel(ZerosLike, inputs); } const zerosLike$1 = op({ zerosLike_ }); function elu_(x) { const $x = convertToTensor(x, 'x', 'elu', 'float32'); const inputs = { x: $x }; return ENGINE.runKernel(Elu$1, inputs); } const elu$2 = op({ elu_ }); function erf_(x) { let $x = convertToTensor(x, 'x', 'erf'); assert$1($x.dtype === 'int32' || $x.dtype === 'float32', () => 'Input dtype must be `int32` or `float32`.'); if ($x.dtype === 'int32') { $x = cast$2($x, 'float32'); } const 
inputs = { x: $x }; return ENGINE.runKernel(Erf, inputs); } const erf$1 = op({ erf_ }); function axesAreInnerMostDims(axes, rank) { for (let i = 0; i < axes.length; ++i) { if (axes[axes.length - i - 1] !== rank - 1 - i) { return false; } } return true; } function combineLocations(outputLoc, reduceLoc, axes) { const rank = outputLoc.length + reduceLoc.length; const loc = []; let outIdx = 0; let reduceIdx = 0; for (let dim = 0; dim < rank; dim++) { if (axes.indexOf(dim) === -1) { loc.push(outputLoc[outIdx++]); } else { loc.push(reduceLoc[reduceIdx++]); } } return loc; } function computeOutAndReduceShapes(aShape, axes) { const outShape = []; const rank = aShape.length; for (let dim = 0; dim < rank; dim++) { if (axes.indexOf(dim) === -1) { outShape.push(aShape[dim]); } } const reduceShape = axes.map(dim => aShape[dim]); return [outShape, reduceShape]; } function expandShapeToKeepDim(shape, axes) { const reduceSubShape = axes.map(x => 1); return combineLocations(shape, reduceSubShape, axes); } function assertAxesAreInnerMostDims(msg, axes, rank) { assert$1(axesAreInnerMostDims(axes, rank), () => `${msg} supports only inner-most axes for now. 
` + `Got axes ${axes} and rank-${rank} input.`); } function getAxesPermutation(axes, rank) { if (axesAreInnerMostDims(axes, rank)) { return null; } const result = []; for (let i = 0; i < rank; ++i) { if (axes.indexOf(i) === -1) { result.push(i); } } axes.forEach(axis => result.push(axis)); return result; } function getUndoAxesPermutation(axes) { return axes.map((axis, i) => [i, axis]) .sort((a, b) => a[1] - b[1]) .map(x => x[0]); } function getInnerMostAxes(numAxes, rank) { const res = []; for (let i = rank - numAxes; i < rank; ++i) { res.push(i); } return res; } function max_(x, axis = null, keepDims = false) { const $x = convertToTensor(x, 'x', 'max'); const inputs = { x: $x }; const attrs = { reductionIndices: axis, keepDims }; return ENGINE.runKernel(Max, inputs, attrs); } const max$1 = op({ max_ }); function min_(x, axis = null, keepDims = false) { const $x = convertToTensor(x, 'x', 'min'); const inputs = { x: $x }; const attrs = { axis, keepDims }; return ENGINE.runKernel(Min, inputs, attrs); } const min$1 = op({ min_ }); function pow_(base, exp) { let $base = convertToTensor(base, 'base', 'pow'); let $exp = convertToTensor(exp, 'exp', 'pow'); [$base, $exp] = makeTypesMatch($base, $exp); const inputs = { a: $base, b: $exp }; return ENGINE.runKernel(Pow, inputs); } const pow$1 = op({ pow_ }); function makeTensor(values, shape, inferredShape, dtype) { if (dtype == null) { dtype = inferDtype(values); } else if (dtype === 'complex64') { throw new Error(`Cannot construct a complex64 tensor directly. 
` + `Please use tf.complex(real, imag).`); } if (isWebGPUData(values) || isWebGLData(values)) { if (dtype !== 'float32' && dtype !== 'int32') { throw new Error(`Creating tensor from GPU data only supports ` + `'float32'|'int32' dtype, while the dtype is ${dtype}.`); } return ENGINE.backend.createTensorFromGPUData(values, shape || inferredShape, dtype); } if (!isTypedArray(values) && !Array.isArray(values) && typeof values !== 'number' && typeof values !== 'boolean' && typeof values !== 'string') { throw new Error('values passed to tensor(values) must be a number/boolean/string or ' + 'an array of numbers/booleans/strings, or a TypedArray'); } if (shape != null) { assertNonNegativeIntegerDimensions(shape); const providedSize = sizeFromShape(shape); const inferredSize = sizeFromShape(inferredShape); assert$1(providedSize === inferredSize, () => `Based on the provided shape, [${shape}], the tensor should have ` + `${providedSize} values but has ${inferredSize}`); for (let i = 0; i < inferredShape.length; ++i) { const inferred = inferredShape[i]; const flatDimsDontMatch = i === inferredShape.length - 1 ? inferred !== sizeFromShape(shape.slice(i)) : true; assert$1(inferredShape[i] === shape[i] || !flatDimsDontMatch, () => `Error creating a new Tensor. Inferred shape ` + `(${inferredShape}) does not match the provided ` + `shape (${shape}). `); } } if (!isTypedArray(values) && !Array.isArray(values)) { values = [values]; } shape = shape || inferredShape; values = dtype !== 'string' ? 
toTypedArray(values, dtype) : flatten$1(values, [], true); return ENGINE.makeTensor(values, shape, dtype); } function scalar(value, dtype) { if (((isTypedArray(value) && dtype !== 'string') || Array.isArray(value)) && dtype !== 'complex64') { throw new Error('Error creating a new Scalar: value must be a primitive ' + '(number|boolean|string)'); } if (dtype === 'string' && isTypedArray(value) && !(value instanceof Uint8Array)) { throw new Error('When making a scalar from encoded string, ' + 'the value must be `Uint8Array`.'); } const shape = []; const inferredShape = []; return makeTensor(value, shape, inferredShape, dtype); } function sqrt_(x) { const $x = convertToTensor(x, 'x', 'sqrt', 'float32'); const inputs = { x: $x }; return ENGINE.runKernel(Sqrt, inputs); } const sqrt$1 = op({ sqrt_ }); function square_(x) { const $x = convertToTensor(x, 'x', 'square'); const attrs = {}; return ENGINE.runKernel('Square', { x: $x }, attrs); } const square$2 = op({ square_ }); function sum_(x, axis = null, keepDims = false) { let $x = convertToTensor(x, 'x', 'sum'); if ($x.dtype === 'bool') { $x = cast$2($x, 'int32'); } const inputs = { x: $x }; const attrs = { axis, keepDims }; return ENGINE.runKernel(Sum, inputs, attrs); } const sum$1 = op({ sum_ }); function norm_(x, ord = 'euclidean', axis = null, keepDims = false) { x = convertToTensor(x, 'x', 'norm'); const norm = normImpl(x, ord, axis); let keepDimsShape = norm.shape; if (keepDims) { const axes = parseAxisParam(axis, x.shape); keepDimsShape = expandShapeToKeepDim(norm.shape, axes); } return reshape$1(norm, keepDimsShape); } function normImpl(x, p, axis = null) { if (x.rank === 0) { return abs$1(x); } if (x.rank !== 1 && axis === null) { return normImpl(reshape$1(x, [-1]), p, axis); } if (x.rank === 1 || typeof axis === 'number' || Array.isArray(axis) && axis.length === 1) { if (p === 1) { return sum$1(abs$1(x), axis); } if (p === Infinity) { return max$1(abs$1(x), axis); } if (p === -Infinity) { return min$1(abs$1(x), 
axis); } if (p === 'euclidean' || p === 2) { return sqrt$1(sum$1(pow$1(abs$1(x), scalar(2, 'int32')), axis)); } throw new Error(`Error in norm: invalid ord value: ${p}`); } if (Array.isArray(axis) && axis.length === 2) { if (p === 1) { return max$1(sum$1(abs$1(x), axis[0]), axis[1] - 1); } if (p === Infinity) { return max$1(sum$1(abs$1(x), axis[1]), axis[0]); } if (p === -Infinity) { return min$1(sum$1(abs$1(x), axis[1]), axis[0]); } if (p === 'fro' || p === 'euclidean') { return sqrt$1(sum$1(square$2(x), axis)); } throw new Error(`Error in norm: invalid ord value: ${p}`); } throw new Error(`Error in norm: invalid axis: ${axis}`); } const norm = op({ norm_ }); function exp_(x) { const $x = convertToTensor(x, 'x', 'exp'); const inputs = { x: $x }; return ENGINE.runKernel(Exp, inputs); } const exp$1 = op({ exp_ }); function expandDims_(x, axis = 0) { const $x = convertToTensor(x, 'x', 'expandDims', 'string_or_numeric'); assert$1(axis <= $x.rank, () => 'Axis must be <= rank of the tensor'); const inputs = { input: $x }; const attrs = { dim: axis }; return ENGINE.runKernel(ExpandDims, inputs, attrs); } const expandDims$2 = op({ expandDims_ }); function tile_(x, reps) { const $x = convertToTensor(x, 'x', 'tile', 'string_or_numeric'); assert$1($x.rank === reps.length, () => `Error in transpose: rank of input ${$x.rank} ` + `must match length of reps ${reps}.`); const inputs = { x: $x }; const attrs = { reps }; return ENGINE.runKernel(Tile, inputs, attrs); } const tile$2 = op({ tile_ }); function eye_(numRows, numColumns, batchShape, dtype = 'float32') { if (numColumns == null) { numColumns = numRows; } const buff = buffer([numRows, numColumns], dtype); const n = numRows <= numColumns ? 
numRows : numColumns; for (let i = 0; i < n; ++i) { buff.set(1, i, i); } const out = reshape$1(buff.toTensor(), [numRows, numColumns]); if (batchShape == null) { return out; } else { if (batchShape.length === 1) { return tile$2(expandDims$2(out, 0), [batchShape[0], 1, 1]); } else if (batchShape.length === 2) { return tile$2(expandDims$2(expandDims$2(out, 0), 0), [batchShape[0], batchShape[1], 1, 1]); } else if (batchShape.length === 3) { return tile$2(expandDims$2(expandDims$2(expandDims$2(out, 0), 0), 0), [ batchShape[0], batchShape[1], batchShape[2], 1, 1 ]); } else { throw new Error(`eye() currently supports only 1D and 2D ` + `batchShapes, but received ${batchShape.length}D.`); } } } const eye = op({ eye_ }); function floor_(x) { const $x = convertToTensor(x, 'x', 'floor', 'float32'); const inputs = { x: $x }; return ENGINE.runKernel(Floor, inputs); } const floor$1 = op({ floor_ }); function gather_(x, indices, axis = 0, batchDims = 0) { const $x = convertToTensor(x, 'x', 'gather'); const $indices = convertToTensor(indices, 'indices', 'gather', 'int32'); const inputs = { x: $x, indices: $indices }; const attrs = { axis, batchDims }; return ENGINE.runKernel(GatherV2, inputs, attrs); } const gather$1 = op({ gather_ }); function greater_(a, b) { let $a = convertToTensor(a, 'a', 'greater', 'string_or_numeric'); let $b = convertToTensor(b, 'b', 'greater', 'string_or_numeric'); [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); const inputs = { a: $a, b: $b }; return ENGINE.runKernel(Greater, inputs); } const greater$1 = op({ greater_ }); function greaterEqual_(a, b) { let $a = convertToTensor(a, 'a', 'greaterEqual', 'string_or_numeric'); let $b = convertToTensor(b, 'b', 'greaterEqual', 'string_or_numeric'); [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); const inputs = { a: $a, b: $b }; return ENGINE.runKernel(GreaterEqual, inputs); } const greaterEqual$1 = op({ greaterEqual_ }); function 
imag_(input) { const $input = convertToTensor(input, 'input', 'imag'); const inputs = { input: $input }; return ENGINE.runKernel(Imag, inputs); } const imag$1 = op({ imag_ }); function leakyRelu_(x, alpha = 0.2) { const $x = convertToTensor(x, 'x', 'leakyRelu'); const inputs = { x: $x }; const attrs = { alpha }; return ENGINE.runKernel(LeakyRelu, inputs, attrs); } const leakyRelu$1 = op({ leakyRelu_ }); function less_(a, b) { let $a = convertToTensor(a, 'a', 'less', 'string_or_numeric'); let $b = convertToTensor(b, 'b', 'less', 'string_or_numeric'); [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); const inputs = { a: $a, b: $b }; return ENGINE.runKernel(Less, inputs); } const less$1 = op({ less_ }); function lessEqual_(a, b) { let $a = convertToTensor(a, 'a', 'lessEqual', 'string_or_numeric'); let $b = convertToTensor(b, 'b', 'lessEqual', 'string_or_numeric'); [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); const inputs = { a: $a, b: $b }; return ENGINE.runKernel(LessEqual, inputs); } const lessEqual$1 = op({ lessEqual_ }); function log_(x) { const $x = convertToTensor(x, 'x', 'log', 'float32'); const inputs = { x: $x }; return ENGINE.runKernel(Log, inputs); } const log$1 = op({ log_ }); function log1p_(x) { const $x = convertToTensor(x, 'x', 'log1p'); const inputs = { x: $x }; return ENGINE.runKernel(Log1p, inputs); } const log1p$1 = op({ log1p_ }); function variableGrads(f, varList) { assert$1(isFunction(f), () => 'The f passed in variableGrads(f) must be a function'); assert$1(varList == null || Array.isArray(varList) && varList.every(v => v instanceof Variable), () => 'The varList passed in variableGrads(f, varList) must be an array ' + 'of variables'); const specifiedVarList = varList != null; if (!specifiedVarList) { varList = []; for (const varName in ENGINE.registeredVariables) { varList.push(ENGINE.registeredVariables[varName]); } } const specifiedNonTrainable = specifiedVarList ? 
varList.filter(variable => !variable.trainable) : null; const originalVarCount = varList.length; varList = varList.filter(variable => variable.trainable); assert$1(varList.length > 0, () => `variableGrads() expects at least one of the input variables to ` + `be trainable, but none of the ${originalVarCount} variables is ` + `trainable.`); const allowNoGradients = true; const { value, grads } = ENGINE.gradients(f, varList, null, allowNoGradients); assert$1(grads.some(g => g != null), () => 'Cannot find a connection between any variable and the result of ' + 'the loss function y=f(x). Please make sure the operations that ' + 'use variables are inside the function f passed to minimize().'); assert$1(value.rank === 0, () => `The f passed in variableGrads(f) must return a scalar, but it ` + `returned a rank-${value.rank} tensor`); const namedGrads = {}; varList.forEach((v, i) => { if (grads[i] != null) { namedGrads[v.name] = grads[i]; } }); if (specifiedNonTrainable != null) { specifiedNonTrainable.forEach(v => namedGrads[v.name] = null); } return { value, grads: namedGrads }; } function customGrad(f) { return ENGINE.customGrad(f); } function neg_(x) { const $x = convertToTensor(x, 'x', 'neg'); const inputs = { x: $x }; return ENGINE.runKernel(Neg, inputs); } const neg$1 = op({ neg_ }); function softplus_(x) { const $x = convertToTensor(x, 'x', 'softplus'); const inputs = { x: $x }; return ENGINE.runKernel(Softplus$1, inputs); } const softplus$1 = op({ softplus_ }); function sub_(a, b) { let $a = convertToTensor(a, 'a', 'sub'); let $b = convertToTensor(b, 'b', 'sub'); [$a, $b] = makeTypesMatch($a, $b); const inputs = { a: $a, b: $b }; return ENGINE.runKernel(Sub, inputs); } const sub$1 = op({ sub_ }); function logSoftmax_(logits, axis = -1) { const $logits = convertToTensor(logits, 'logits', 'logSoftmax'); if (axis === -1) { axis = $logits.rank - 1; } if (axis !== $logits.rank - 1) { throw Error('Log Softmax along a non-last dimension is not yet supported. 
' + `Logits was rank ${$logits.rank} and axis was ${axis}`); } const customOp = customGrad((logits, save) => { const keepDims = true; const xMax = max$1(logits, axis, true); const shifted = sub$1(logits, xMax); const value = sub$1(cast$2(shifted, 'float32'), log$1(sum$1(exp$1(shifted), axis, keepDims))); save([value]); const gradFunc = (dy, saved) => { const [value] = saved; const keepDims = true; const softmax = exp$1(value); return sub$1(dy, mul(sum$1(dy, axis, keepDims), softmax)); }; return { value, gradFunc }; }); return customOp($logits); } const logSoftmax = op({ logSoftmax_ }); function logicalAnd_(a, b) { const $a = convertToTensor(a, 'a', 'logicalAnd', 'bool'); const $b = convertToTensor(b, 'b', 'logicalAnd', 'bool'); assertAndGetBroadcastShape($a.shape, $b.shape); const inputs = { a: $a, b: $b }; return ENGINE.runKernel(LogicalAnd, inputs); } const logicalAnd$1 = op({ logicalAnd_ }); function logicalNot_(x) { const $x = convertToTensor(x, 'x', 'logicalNot', 'bool'); const inputs = { x: $x }; return ENGINE.runKernel(LogicalNot, inputs); } const logicalNot$1 = op({ logicalNot_ }); function maximum_(a, b) { let $a = convertToTensor(a, 'a', 'maximum'); let $b = convertToTensor(b, 'b', 'maximum'); [$a, $b] = makeTypesMatch($a, $b); if ($a.dtype === 'bool') { $a = cast$2($a, 'int32'); $b = cast$2($b, 'int32'); } assertAndGetBroadcastShape($a.shape, $b.shape); const inputs = { a: $a, b: $b }; return ENGINE.runKernel(Maximum, inputs); } const maximum$1 = op({ maximum_ }); function mean_(x, axis = null, keepDims = false) { const $x = convertToTensor(x, 'x', 'mean'); const inputs = { x: $x }; const attrs = { axis, keepDims }; return ENGINE.runKernel(Mean, inputs, attrs); } const mean = op({ mean_ }); function zeros(shape, dtype = 'float32') { assertNonNegativeIntegerDimensions(shape); if (dtype === 'complex64') { const real = zeros(shape, 'float32'); const imag = zeros(shape, 'float32'); return complex$1(real, imag); } const values = 
makeZerosTypedArray(sizeFromShape(shape), dtype); return ENGINE.makeTensor(values, shape, dtype); } function ones(shape, dtype = 'float32') { assertNonNegativeIntegerDimensions(shape); if (dtype === 'complex64') { const real = ones(shape, 'float32'); const imag = zeros(shape, 'float32'); return complex$1(real, imag); } const values = makeOnesTypedArray(sizeFromShape(shape), dtype); return ENGINE.makeTensor(values, shape, dtype); } function minimum_(a, b) { let $a = convertToTensor(a, 'a', 'minimum'); let $b = convertToTensor(b, 'b', 'minimum'); [$a, $b] = makeTypesMatch($a, $b); if ($a.dtype === 'bool') { $a = cast$2($a, 'int32'); $b = cast$2($b, 'int32'); } assertAndGetBroadcastShape($a.shape, $b.shape); const inputs = { a: $a, b: $b }; return ENGINE.runKernel(Minimum, inputs); } const minimum$1 = op({ minimum_ }); function notEqual_(a, b) { let $a = convertToTensor(a, 'a', 'notEqual', 'string_or_numeric'); let $b = convertToTensor(b, 'b', 'notEqual', 'string_or_numeric'); [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); const inputs = { a: $a, b: $b }; return ENGINE.runKernel(NotEqual, inputs); } const notEqual$1 = op({ notEqual_ }); function oneHot_(indices, depth, onValue = 1, offValue = 0, dtype = 'int32') { if (depth < 2) { throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`); } const $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32'); const inputs = { indices: $indices }; const attrs = { dtype, depth, onValue, offValue }; return ENGINE.runKernel(OneHot, inputs, attrs); } const oneHot$1 = op({ oneHot_ }); function onesLike_(x) { const $x = convertToTensor(x, 'x', 'onesLike'); const inputs = { x: $x }; return ENGINE.runKernel(OnesLike, inputs); } const onesLike$1 = op({ onesLike_ }); function pad_(x, paddings, constantValue = 0) { const $x = convertToTensor(x, 'x', 'pad'); if ($x.rank === 0) { throw new Error('pad(scalar) is not defined. 
Pass non-scalar to pad'); } const attrs = { paddings, constantValue }; const inputs = { x: $x }; return ENGINE.runKernel(PadV2, inputs, attrs); } const pad = op({ pad_ }); function spaceToBatchND_(x, blockShape, paddings) { const $x = convertToTensor(x, 'x', 'spaceToBatchND'); assert$1($x.rank >= 1 + blockShape.length, () => `input rank ${$x.rank} should be > than [blockShape] ${blockShape.length}`); assert$1(paddings.length === blockShape.length, () => `paddings.shape[0] ${paddings.length} must be equal to [blockShape] ${blockShape.length}`); assert$1($x.shape.reduce((a, b, i) => { if (i > 0 && i <= blockShape.length) { return a && ((b + paddings[i - 1][0] + paddings[i - 1][1]) % blockShape[i - 1] === 0); } return a; }, true), () => `input spatial dimensions ${$x.shape.slice(1)} with paddings ${paddings.toString()} must be divisible by blockShapes ${blockShape.toString()}`); const inputs = { x: $x }; const attrs = { blockShape, paddings }; return ENGINE.runKernel(SpaceToBatchND, inputs, attrs); } const spaceToBatchND$1 = op({ spaceToBatchND_ }); function prelu_(x, alpha) { const $x = convertToTensor(x, 'x', 'prelu'); const $alpha = convertToTensor(alpha, 'alpha', 'prelu'); const inputs = { x: $x, alpha: $alpha }; return ENGINE.runKernel(Prelu, inputs); } const prelu$1 = op({ prelu_ }); var alea$1 = {exports: {}}; (function (module) { (function(global, module, define) { function Alea(seed) { var me = this, mash = Mash(); me.next = function() { var t = 2091639 * me.s0 + me.c * 2.3283064365386963e-10; me.s0 = me.s1; me.s1 = me.s2; return me.s2 = t - (me.c = t | 0); }; me.c = 1; me.s0 = mash(' '); me.s1 = mash(' '); me.s2 = mash(' '); me.s0 -= mash(seed); if (me.s0 < 0) { me.s0 += 1; } me.s1 -= mash(seed); if (me.s1 < 0) { me.s1 += 1; } me.s2 -= mash(seed); if (me.s2 < 0) { me.s2 += 1; } mash = null; } function copy(f, t) { t.c = f.c; t.s0 = f.s0; t.s1 = f.s1; t.s2 = f.s2; return t; } function impl(seed, opts) { var xg = new Alea(seed), state = opts && opts.state, 
prng = xg.next;
      // int32(): scales next() by 2^32 and truncates to a signed 32-bit int.
      prng.int32 = function() { return (xg.next() * 0x100000000) | 0; };
      // double(): adds 21 more bits from a second draw (0x200000 = 2^21,
      // 1.1102230246251565e-16 = 2^-53).
      prng.double = function() {
        return prng() + (prng() * 0x200000 | 0) * 1.1102230246251565e-16;
      };
      prng.quick = prng;
      if (state) {
        // Restore from a snapshot object if given; expose state() to take one.
        if (typeof(state) == 'object') copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }
    // Mash: string hash used to seed Alea. Each call returns (n >>> 0) * 2^-32,
    // i.e. a value in [0, 1), and keeps mutating the shared accumulator n.
    function Mash() {
      var n = 0xefc8249d;
      var mash = function(data) {
        data = String(data);
        for (var i = 0; i < data.length; i++) {
          n += data.charCodeAt(i);
          var h = 0.02519603282416938 * n;
          n = h >>> 0;
          h -= n;
          h *= n;
          n = h >>> 0;
          h -= n;
          n += h * 0x100000000;
        }
        return (n >>> 0) * 2.3283064365386963e-10;
      };
      return mash;
    }
    if (module && module.exports) {
      module.exports = impl;
    } else {
      this.alea = impl;
    }
  })( commonjsGlobal, module);
} (alea$1));
var aleaExports = alea$1.exports;

// Bundled CommonJS module: xor128 generator (4-word xorshift state x, y, z, w).
var xor128$1 = {exports: {}};
(function (module) {
  (function(global, module, define) {
    function XorGen(seed) {
      var me = this, strseed = '';
      me.x = 0;
      me.y = 0;
      me.z = 0;
      me.w = 0;
      // One step: shift the words down and return the new w.
      me.next = function() {
        var t = me.x ^ (me.x << 11);
        me.x = me.y;
        me.y = me.z;
        me.z = me.w;
        return me.w ^= (me.w >>> 19) ^ t ^ (t >>> 8);
      };
      // 32-bit integer seeds go straight into x; anything else is mixed in as a string.
      if (seed === (seed | 0)) {
        me.x = seed;
      } else {
        strseed += seed;
      }
      // Mix the seed string into x, then run 64 extra steps.
      for (var k = 0; k < strseed.length + 64; k++) {
        me.x ^= strseed.charCodeAt(k) | 0;
        me.next();
      }
    }
    function copy(f, t) {
      t.x = f.x;
      t.y = f.y;
      t.z = f.z;
      t.w = f.w;
      return t;
    }
    function impl(seed, opts) {
      var xg = new XorGen(seed),
          state = opts && opts.state,
          prng = function() { return (xg.next() >>> 0) / 0x100000000; };
      // double(): top 21 bits of one draw plus a 32-bit fraction from another,
      // divided by 2^21; redrawn while the result is exactly 0.
      prng.double = function() {
        do {
          var top = xg.next() >>> 11,
              bot = (xg.next() >>> 0) / 0x100000000,
              result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (typeof(state) == 'object') copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }
    if (module && module.exports) {
      module.exports = impl;
    } else {
      this.xor128 = impl;
    }
  })( commonjsGlobal, module);
} (xor128$1));
var xor128Exports = xor128$1.exports;

// Bundled CommonJS module: xorwow generator (5-word xorshift state plus an
// additive counter d).
var xorwow$1 = {exports: {}};
(function (module) {
  (function(global, module, define) {
    function XorGen(seed) {
      var me = this, strseed = '';
      me.next = function() {
        var t = (me.x ^ (me.x >>> 2));
        me.x = me.y;
        me.y = me.z;
        me.z = me.w;
        me.w = me.v;
        return (me.d = (me.d + 362437 | 0)) + (me.v = (me.v ^ (me.v << 4)) ^ (t ^ (t << 1))) | 0;
      };
      me.x = 0;
      me.y = 0;
      me.z = 0;
      me.w = 0;
      me.v = 0;
      if (seed === (seed | 0)) {
        me.x = seed;
      } else {
        strseed += seed;
      }
      for (var k = 0; k < strseed.length + 64; k++) {
        me.x ^= strseed.charCodeAt(k) | 0;
        // The counter d is (re)initialised from x once k reaches the end of
        // the seed string.
        if (k == strseed.length) {
          me.d = me.x << 10 ^ me.x >>> 4;
        }
        me.next();
      }
    }
    function copy(f, t) {
      t.x = f.x;
      t.y = f.y;
      t.z = f.z;
      t.w = f.w;
      t.v = f.v;
      t.d = f.d;
      return t;
    }
    function impl(seed, opts) {
      var xg = new XorGen(seed),
          state = opts && opts.state,
          prng = function() { return (xg.next() >>> 0) / 0x100000000; };
      prng.double = function() {
        do {
          var top = xg.next() >>> 11,
              bot = (xg.next() >>> 0) / 0x100000000,
              result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (typeof(state) == 'object') copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }
    if (module && module.exports) {
      module.exports = impl;
    } else {
      this.xorwow = impl;
    }
  })( commonjsGlobal, module);
} (xorwow$1));
var xorwowExports = xorwow$1.exports;

// Bundled CommonJS module: xorshift7 generator (8-word state array X with a
// rotating index i).
var xorshift7$1 = {exports: {}};
(function (module) {
  (function(global, module, define) {
    function XorGen(seed) {
      var me = this;
      me.next = function() {
        var X = me.x, i = me.i, t, v;
        t = X[i];
        t ^= (t >>> 7);
        v = t ^ (t << 24);
        t = X[(i + 1) & 7];
        v ^= t ^ (t >>> 10);
        t = X[(i + 3) & 7];
        v ^= t ^ (t >>> 3);
        t = X[(i + 4) & 7];
        v ^= t ^ (t << 7);
        t = X[(i + 7) & 7];
        t = t ^ (t << 13);
        v ^= t ^ (t << 9);
        X[i] = v;
        me.i = (i + 1) & 7;
        return v;
      };
      function init(me, seed) {
        var j, X = [];
        if (seed === (seed | 0)) {
          X[0] = seed;
        } else {
          seed = '' + seed;
          for (j = 0; j < seed.length; ++j) {
            X[j & 7] = (X[j & 7] << 15) ^ (seed.charCodeAt(j) + X[(j + 1) & 7] << 13);
          }
        }
        // Pad to 8 words; if every word is zero, set X[7] = -1.
        while (X.length < 8) X.push(0);
        for (j = 0; j < 8 && X[j] === 0; ++j);
        if (j == 8) X[7] = -1;
        me.x = X;
        me.i = 0;
        // Discard the first 256 outputs.
        for (j = 256; j > 0; --j) {
          me.next();
        }
      }
      init(me, seed);
    }
    function copy(f, t) {
      t.x = f.x.slice();
      t.i = f.i;
      return t;
    }
    function impl(seed, opts) {
      // Without a seed, the current time in milliseconds is used.
      if (seed == null) seed = +(new Date);
      var xg = new XorGen(seed),
          state = opts && opts.state,
          prng = function() { return (xg.next() >>> 0) / 0x100000000; };
      prng.double = function() {
        do {
          var top = xg.next() >>> 11,
              bot = (xg.next() >>> 0) / 0x100000000,
              result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (state.x) copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }
    if (module && module.exports) {
      module.exports = impl;
    } else {
      this.xorshift7 = impl;
    }
  })( commonjsGlobal, module);
} (xorshift7$1));
var xorshift7Exports = xorshift7$1.exports;

// Bundled CommonJS module: xor4096 generator (128-word xorshift state X plus a
// counter w advanced by 0x61c88647 each step).
var xor4096$1 = {exports: {}};
(function (module) {
  (function(global, module, define) {
    function XorGen(seed) {
      var me = this;
      me.next = function() {
        var w = me.w, X = me.X, i = me.i, t, v;
        me.w = w = (w + 0x61c88647) | 0;
        v = X[(i + 34) & 127];
        t = X[i = ((i + 1) & 127)];
        v ^= v << 13;
        t ^= t << 17;
        v ^= v >>> 15;
        t ^= t >>> 12;
        v = X[i] = v ^ t;
        me.i = i;
        return (v + (w ^ (w >>> 16))) | 0;
      };
      function init(me, seed) {
        var t, v, i, j, w, X = [], limit = 128;
        if (seed === (seed | 0)) {
          v = seed;
          seed = null;
        } else {
          seed = seed + '\0';
          v = 0;
          limit = Math.max(limit, seed.length);
        }
        for (i = 0, j = -32; j < limit; ++j) {
          if (seed) v ^= seed.charCodeAt((j + 32) % seed.length);
          if (j === 0) w = v;
          v ^= v << 10;
          v ^= v >>> 15;
          v ^= v << 4;
          v ^= v >>> 13;
          if (j >= 0) {
            w = (w + 0x61c88647) | 0;
            t = (X[j & 127] ^= (v + w));
            // i counts consecutive zero words written.
            i = (0 == t) ? i + 1 : 0;
          }
        }
        // After 128+ consecutive zero words, force one word to -1
        // (presumably to avoid an all-zero state).
        if (i >= 128) {
          X[(seed && seed.length || 0) & 127] = -1;
        }
        i = 127;
        // Run 4 * 128 mixing steps before first use.
        for (j = 4 * 128; j > 0; --j) {
          v = X[(i + 34) & 127];
          t = X[i = ((i + 1) & 127)];
          v ^= v << 13;
          t ^= t << 17;
          v ^= v >>> 15;
          t ^= t >>> 12;
          X[i] = v ^ t;
        }
        me.w = w;
        me.X = X;
        me.i = i;
      }
      init(me, seed);
    }
    function copy(f, t) {
      t.i = f.i;
      t.w = f.w;
      t.X = f.X.slice();
      return t;
    }
    function impl(seed, opts) {
      // Without a seed, the current time in milliseconds is used.
      if (seed == null) seed = +(new Date);
      var xg = new XorGen(seed),
          state = opts && opts.state,
          prng = function() { return (xg.next() >>> 0) / 0x100000000; };
      prng.double = function() {
        do {
          var top = xg.next() >>> 11,
              bot = (xg.next() >>> 0) / 0x100000000,
              result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (state.X) copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }
    if (module && module.exports) {
      module.exports = impl;
    } else {
      this.xor4096 = impl;
    }
  })( commonjsGlobal, module);
} (xor4096$1));
var xor4096Exports = xor4096$1.exports;

// Bundled CommonJS module: tychei generator (4-word state a, b, c, d).
var tychei$1 = {exports: {}};
(function (module) {
  (function(global, module, define) {
    function XorGen(seed) {
      var me = this, strseed = '';
      me.next = function() {
        var b = me.b, c = me.c, d = me.d, a = me.a;
        b = (b << 25) ^ (b >>> 7) ^ c;
        c = (c - d) | 0;
        d = (d << 24) ^ (d >>> 8) ^ a;
        a = (a - b) | 0;
        me.b = b = (b << 20) ^ (b >>> 12) ^ c;
        me.c = c = (c - d) | 0;
        me.d = (d << 16) ^ (c >>> 16) ^ a;
        return me.a = (a - b) | 0;
      };
      me.a = 0;
      me.b = 0;
      me.c = 2654435769 | 0;
      me.d = 1367130551;
      // Integer seeds are split into a high (a) and low (b) 32-bit word;
      // anything else is mixed in as a string.
      if (seed === Math.floor(seed)) {
        me.a = (seed / 0x100000000) | 0;
        me.b = seed | 0;
      } else {
        strseed += seed;
      }
      for (var k = 0; k < strseed.length + 20; k++) {
        me.b ^= strseed.charCodeAt(k) | 0;
        me.next();
      }
    }
    function copy(f, t) {
      t.a = f.a;
      t.b = f.b;
      t.c = f.c;
      t.d = f.d;
      return t;
    }
    function impl(seed, opts) {
      var xg = new XorGen(seed),
          state = opts && opts.state,
          prng = function() { return (xg.next() >>> 0) / 0x100000000; };
      prng.double = function()
      {
        do {
          var top = xg.next() >>> 11,
              bot = (xg.next() >>> 0) / 0x100000000,
              result = (top + bot) / (1 << 21);
        } while (result === 0);
        return result;
      };
      prng.int32 = xg.next;
      prng.quick = prng;
      if (state) {
        if (typeof(state) == 'object') copy(state, xg);
        prng.state = function() { return copy(xg, {}); };
      }
      return prng;
    }
    if (module && module.exports) {
      module.exports = impl;
    } else {
      this.tychei = impl;
    }
  })( commonjsGlobal, module);
} (tychei$1));
var tycheiExports = tychei$1.exports;

// Bundled CommonJS module: seedrandom's default ARC4-based generator with a
// shared entropy `pool`.
var seedrandom$1 = {exports: {}};
(function (module) {
  (function (global, pool, math) {
    // ARC4 works on bytes (width 256). Each output starts from `chunks` (6)
    // bytes and is topped up until it has `digits` (52) significant bits.
    var width = 256,
        chunks = 6,
        digits = 52,
        rngname = 'random',
        startdenom = math.pow(width, chunks),
        significance = math.pow(2, digits),
        overflow = significance * 2,
        mask = width - 1,
        nodecrypto;
    function seedrandom(seed, options, callback) {
      var key = [];
      // A loosely-true `options` is shorthand for { entropy: true }.
      options = (options == true) ? { entropy: true } : (options || {});
      // Seed material: [seed, pool] when entropy is requested, autoseed() when
      // no seed was given, otherwise the seed itself (flattened to depth 3).
      var shortseed = mixkey(flatten(
          options.entropy ? [seed, tostring(pool)] :
          (seed == null) ? autoseed() : seed, 3), key);
      var arc4 = new ARC4(key);
      var prng = function() {
        var n = arc4.g(chunks),
            d = startdenom,
            x = 0;
        while (n < significance) {
          n = (n + x) * width;
          d *= width;
          x = arc4.g(1);
        }
        while (n >= overflow) {
          n /= 2;
          d /= 2;
          x >>>= 1;
        }
        return (n + x) / d;
      };
      prng.int32 = function() { return arc4.g(4) | 0; };
      prng.quick = function() { return arc4.g(4) / 0x100000000; };
      prng.double = prng;
      // Fold the generator's permutation back into the shared pool.
      mixkey(tostring(arc4.S), pool);
      // Hand off to options.pass / callback if provided. Otherwise the default
      // handler restores/exposes state and, for a Math-style call, installs
      // prng as math.random and returns the seed.
      return (options.pass || callback ||
          function(prng, seed, is_math_call, state) {
            if (state) {
              if (state.S) { copy(state, arc4); }
              prng.state = function() { return copy(arc4, {}); };
            }
            if (is_math_call) {
              math[rngname] = prng;
              return seed;
            }
            else return prng;
          })(
              prng,
              shortseed,
              'global' in options ? options.global : (this == math),
              options.state);
    }
    // ARC4 keystream; g(count) returns the next `count` bytes combined into
    // one number (base `width`).
    function ARC4(key) {
      var t, keylen = key.length,
          me = this, i = 0, j = me.i = me.j = 0, s = me.S = [];
      // An empty key is treated as [0].
      if (!keylen) { key = [keylen++]; }
      // Key schedule: identity permutation, then key-driven swaps.
      while (i < width) {
        s[i] = i++;
      }
      for (i = 0; i < width; i++) {
        s[i] = s[j = mask & (j + key[i % keylen] + (t = s[i]))];
        s[j] = t;
      }
      // Define g and immediately discard `width` bytes of output.
      (me.g = function(count) {
        var t, r = 0,
            i = me.i, j = me.j, s = me.S;
        while (count--) {
          t = s[i = mask & (i + 1)];
          r = r * width + s[mask & ((s[i] = s[j = mask & (j + t)]) + (s[j] = t))];
        }
        me.i = i;
        me.j = j;
        return r;
      })(width);
    }
    function copy(f, t) {
      t.i = f.i;
      t.j = f.j;
      t.S = f.S.slice();
      return t;
    }
    // Flattens `obj` into nested arrays up to `depth` levels. Leaves become
    // strings (non-strings get '\0' appended); properties that throw are skipped.
    function flatten(obj, depth) {
      var result = [], typ = (typeof obj), prop;
      if (depth && typ == 'object') {
        for (prop in obj) {
          try { result.push(flatten(obj[prop], depth - 1)); } catch (e) {}
        }
      }
      return (result.length ? result : typ == 'string' ? obj : obj + '\0');
    }
    // Mixes the string form of `seed` into `key` in place; returns key as a string.
    function mixkey(seed, key) {
      var stringseed = seed + '', smear, j = 0;
      while (j < stringseed.length) {
        key[mask & j] =
            mask & ((smear ^= key[mask & j] * 19) + stringseed.charCodeAt(j++));
      }
      return tostring(key);
    }
    // Seed bytes come from node's crypto.randomBytes or from
    // crypto/msCrypto.getRandomValues. On any error, falls back to weaker
    // sources: time, the global object, plugins, screen and the pool.
    function autoseed() {
      try {
        var out;
        if (nodecrypto && (out = nodecrypto.randomBytes)) {
          out = out(width);
        } else {
          out = new Uint8Array(width);
          (global.crypto || global.msCrypto).getRandomValues(out);
        }
        return tostring(out);
      } catch (e) {
        var browser = global.navigator,
            plugins = browser && browser.plugins;
        return [+new Date, global, plugins, global.screen, tostring(pool)];
      }
    }
    function tostring(a) {
      return String.fromCharCode.apply(0, a);
    }
    // Stir Math.random() into the pool at load time.
    mixkey(math.random(), pool);
    if (module.exports) {
      module.exports = seedrandom;
      // Use node's crypto for autoseed() when `require` can load it.
      try { nodecrypto = require('crypto'); } catch (ex) {}
    } else {
      math['seed' + rngname] = seedrandom;
    }
  })( (typeof self !== 'undefined') ?
self : commonjsGlobal, [], Math ); } (seedrandom$1)); var seedrandomExports = seedrandom$1.exports; var alea = aleaExports; var xor128 = xor128Exports; var xorwow = xorwowExports; var xorshift7 = xorshift7Exports; var xor4096 = xor4096Exports; var tychei = tycheiExports; var sr = seedrandomExports; sr.alea = alea; sr.xor128 = xor128; sr.xorwow = xorwow; sr.xorshift7 = xorshift7; sr.xor4096 = xor4096; sr.tychei = tychei; var seedrandom = sr; class MPRandGauss { constructor(mean, stdDeviation, dtype, truncated, seed) { this.mean = mean; this.stdDev = stdDeviation; this.dtype = dtype; this.nextVal = NaN; this.truncated = truncated; if (this.truncated) { this.upper = this.mean + this.stdDev * 2; this.lower = this.mean - this.stdDev * 2; } const seedValue = seed ? seed : Math.random(); this.random = seedrandom.alea(seedValue.toString()); } nextValue() { if (!isNaN(this.nextVal)) { const value = this.nextVal; this.nextVal = NaN; return value; } let resultX, resultY; let isValid = false; while (!isValid) { let v1, v2, s; do { v1 = 2 * this.random() - 1; v2 = 2 * this.random() - 1; s = v1 * v1 + v2 * v2; } while (s >= 1 || s === 0); const mul = Math.sqrt(-2 * Math.log(s) / s); resultX = this.mean + this.stdDev * v1 * mul; resultY = this.mean + this.stdDev * v2 * mul; if (!this.truncated || this.isValidTruncated(resultX)) { isValid = true; } } if (!this.truncated || this.isValidTruncated(resultY)) { this.nextVal = this.convertValue(resultY); } return this.convertValue(resultX); } convertValue(value) { if (this.dtype == null || this.dtype === 'float32') { return value; } return Math.round(value); } isValidTruncated(value) { return value <= this.upper && value >= this.lower; } } class UniformRandom { constructor(min = 0, max = 1, dtype, seed) { this.canReturnFloat = () => (this.dtype == null || this.dtype === 'float32'); this.min = min; this.range = max - min; this.dtype = dtype; if (seed == null) { seed = Math.random(); } if (typeof seed === 'number') { seed = 
seed.toString(); } if (!this.canReturnFloat() && this.range <= 1) { throw new Error(`The difference between ${min} - ${max} <= 1 and dtype is not float`); } this.random = seedrandom.alea(seed); } convertValue(value) { if (this.canReturnFloat()) { return value; } return Math.round(value); } nextValue() { return this.convertValue(this.min + this.range * this.random()); } } function randomNormal_(shape, mean = 0, stdDev = 1, dtype, seed) { assertNonNegativeIntegerDimensions(shape); if (dtype != null && dtype === 'bool') { throw new Error(`Unsupported data type ${dtype}`); } const randGauss = new MPRandGauss(mean, stdDev, dtype, false , seed); const res = buffer(shape, dtype); for (let i = 0; i < res.values.length; i++) { res.values[i] = randGauss.nextValue(); } return res.toTensor(); } const randomNormal$1 = op({ randomNormal_ }); function randomUniform_(shape, minval = 0, maxval = 1, dtype = 'float32', seed) { assertNonNegativeIntegerDimensions(shape); const res = buffer(shape, dtype); const random = new UniformRandom(minval, maxval, null, seed); for (let i = 0; i < res.values.length; i++) { res.values[i] = random.nextValue(); } return res.toTensor(); } const randomUniform = op({ randomUniform_ }); function range$2(start, stop, step = 1, dtype = 'float32') { if (step === 0) { throw new Error('Cannot have a step of zero'); } const attrs = { start, stop, step, dtype }; return ENGINE.runKernel(Range, {} , attrs); } function real_(input) { const $input = convertToTensor(input, 'input', 'real'); const inputs = { input: $input }; return ENGINE.runKernel(Real, inputs); } const real$1 = op({ real_ }); function relu_(x) { const $x = convertToTensor(x, 'x', 'relu'); const inputs = { x: $x }; return ENGINE.runKernel(Relu$1, inputs); } const relu$1 = op({ relu_ }); function relu6_(x) { const $x = convertToTensor(x, 'x', 'relu6'); const inputs = { x: $x }; return ENGINE.runKernel(Relu6$1, inputs); } const relu6$1 = op({ relu6_ }); function reverse_(x, axis) { const $x = 
convertToTensor(x, 'x', 'reverse'); const inputs = { x: $x }; const attrs = { dims: axis }; return ENGINE.runKernel(Reverse, inputs, attrs); } const reverse$1 = op({ reverse_ }); function rsqrt_(x) { const $x = convertToTensor(x, 'x', 'rsqrt', 'float32'); const inputs = { x: $x }; return ENGINE.runKernel(Rsqrt, inputs); } const rsqrt$1 = op({ rsqrt_ }); function selu_(x) { const $x = convertToTensor(x, 'x', 'selu'); const inputs = { x: $x }; return ENGINE.runKernel(Selu$1, inputs); } const selu$1 = op({ selu_ }); function sin_(x) { const $x = convertToTensor(x, 'x', 'sin', 'float32'); const inputs = { x: $x }; return ENGINE.runKernel(Sin, inputs); } const sin$1 = op({ sin_ }); function sinh_(x) { const $x = convertToTensor(x, 'x', 'sinh'); const inputs = { x: $x }; return ENGINE.runKernel(Sinh, inputs); } const sinh$1 = op({ sinh_ }); function slice1d_(x, begin, size) { const $x = convertToTensor(x, 'x', 'slice1d'); assert$1($x.rank === 1, () => `slice1d expects a rank-1 tensor, but got a rank-${$x.rank} tensor`); return slice$1($x, [begin], [size]); } const slice1d = op({ slice1d_ }); function slice2d_(x, begin, size) { const $x = convertToTensor(x, 'x', 'slice2d'); assert$1($x.rank === 2, () => `slice2d expects a rank-2 tensor, but got a rank-${$x.rank} tensor`); return slice$1($x, begin, size); } const slice2d = op({ slice2d_ }); function slice3d_(x, begin, size) { const $x = convertToTensor(x, 'x', 'slice3d'); assert$1($x.rank === 3, () => `slice3d expects a rank-3 tensor, but got a rank-${$x.rank} tensor`); return slice$1($x, begin, size); } const slice3d = op({ slice3d_ }); function slice4d_(x, begin, size) { const $x = convertToTensor(x, 'x', 'slice4d'); assert$1($x.rank === 4, () => `slice4d expects a rank-4 tensor, but got a rank-${$x.rank} tensor`); return slice$1($x, begin, size); } const slice4d = op({ slice4d_ }); function softmax_(logits, dim = -1) { const $logits = convertToTensor(logits, 'logits', 'softmax', 'float32'); if (dim === -1) { dim = 
$logits.rank - 1; } if (dim !== $logits.rank - 1) { throw Error('Softmax along a non-last dimension is not yet supported. ' + `Logits was rank ${$logits.rank} and dim was ${dim}`); } const inputs = { logits: $logits }; const attrs = { dim }; return ENGINE.runKernel(Softmax$1, inputs, attrs); } const softmax$1 = op({ softmax_ }); function split_(x, numOrSizeSplits, axis = 0) { const $x = convertToTensor(x, 'x', 'split'); const inputs = { x: $x }; const attr = { numOrSizeSplits, axis }; return ENGINE.runKernel(SplitV, inputs, attr); } const split$1 = op({ split_ }); function squeeze_(x, axis) { const $x = convertToTensor(x, 'x', 'squeeze', 'string_or_numeric'); return reshape$1($x, squeezeShape($x.shape, axis).newShape); } const squeeze = op({ squeeze_ }); function stack_(tensors, axis = 0) { const $tensors = convertToTensorArray(tensors, 'tensors', 'stack', 'string_or_numeric'); assert$1($tensors.length >= 1, () => 'Pass at least one tensor to tf.stack'); if ($tensors.length > 0) { assert$1(axis <= $tensors[0].rank, () => 'Axis must be <= rank of the tensor'); } const inputs = $tensors; const attrs = { axis }; return ENGINE.runKernel(Pack, inputs, attrs); } const stack = op({ stack_ }); function step_(x, alpha = 0.0) { const $x = convertToTensor(x, 'x', 'step'); const inputs = { x: $x }; const attrs = { alpha }; return ENGINE.runKernel(Step, inputs, attrs); } const step$1 = op({ step_ }); function tensor(values, shape, dtype) { const inferredShape = inferShape(values, dtype); return makeTensor(values, shape, inferredShape, dtype); } function tensor1d(values, dtype) { assertNonNull(values); const inferredShape = inferShape(values, dtype); if (inferredShape.length !== 1) { throw new Error('tensor1d() requires values to be a flat/TypedArray'); } const shape = null; return makeTensor(values, shape, inferredShape, dtype); } function tensor2d(values, shape, dtype) { assertNonNull(values); if (shape != null && shape.length !== 2) { throw new Error('tensor2d() requires 
shape to have two numbers'); } const inferredShape = inferShape(values, dtype); if (inferredShape.length !== 2 && inferredShape.length !== 1) { throw new Error('tensor2d() requires values to be number[][] or flat/TypedArray'); } if (inferredShape.length === 1 && shape == null) { throw new Error('tensor2d() requires shape to be provided when `values` ' + 'are a flat/TypedArray'); } return makeTensor(values, shape, inferredShape, dtype); } function validateUpdateShape(shape, indices, updates) { const sliceDim = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1; const batchDim = (indices.rank > 1) ? indices.rank - 1 : 1; const shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' + `shape[sliceDim:], got updates.shape: ${updates.shape}` + `, indices.shape: ${indices.shape}, shape: ${shape}` + `, sliceDim: ${sliceDim}, and batchDim: ${batchDim}.`; if (updates.rank < batchDim) { throw new Error(shapeError + ` update.rank < ${batchDim}. `); } if (shape.length < sliceDim + (updates.rank - batchDim)) { throw new Error(shapeError + ` Output shape length < ${sliceDim + (updates.rank - batchDim)}`); } if (updates.rank !== batchDim + shape.length - sliceDim) { throw new Error(shapeError + ` update.rank != ${batchDim + shape.length - sliceDim}`); } for (let d = 0; d < batchDim; ++d) { if (updates.shape[d] !== indices.shape[d]) { throw new Error(shapeError + ` updates.shape[${d}] (${updates.shape[d]}) != indices.shape[${d}] (${indices.shape[d]}).`); } } for (let d = 0; d < updates.rank - batchDim; ++d) { if (updates.shape[d + batchDim] !== shape[d + sliceDim]) { throw new Error(shapeError + ` updates.shape[${d + batchDim}] (${updates.shape[d + batchDim]}) != shape[${d + batchDim}] (${shape[d + batchDim]})`); } } } function validateInput(updates, indices, shape) { if (indices.rank < 1) { throw new Error('tf.scatterND() expects the indices to be rank 1 or higher,' + ` but the rank was ${indices.rank}.`); } if (updates.rank < 1) { throw new 
Error('tf.scatterND() expects the updates to be rank 1 or higher,' + ` but the rank was ${updates.rank}.`); } if (indices.dtype !== 'int32') { throw new Error(`The dtype of 'indices' should be int32, but got dtype: ${indices.dtype}`); } if (shape.length < 1) { throw new Error(`Output rank must be greater or equal to 1, but got shape: ${shape}`); } if (shape.length === 0) { if (indices.size === 0) { throw new Error(`Indices specified for empty output. indices shape: ${indices.shape}`); } if (updates.size === 0) { throw new Error(`Updates specified for empty output. updates shape: ${updates.shape}`); } } validateUpdateShape(shape, indices, updates); } function calculateShapes(updates, indices, shape) { const indicesRank = indices.shape.length; const sliceRank = (indicesRank > 1) ? indices.shape[indicesRank - 1] : 1; const totalNd = shape.length; let sliceSize = 1; for (let i = sliceRank; i < totalNd; ++i) { sliceSize *= shape[i]; } const safeSliceDim = (sliceRank < 1) ? 1 : sliceRank; const numUpdates = sizeFromShape(indices.shape) / safeSliceDim; const strides = [...computeStrides(shape.slice(0, sliceRank)), 1]; const outputSize = sizeFromShape(shape); return { sliceRank, numUpdates, sliceSize, strides, outputSize }; } function truncatedNormal_(shape, mean = 0, stdDev = 1, dtype, seed) { assertNonNegativeIntegerDimensions(shape); if (dtype != null && dtype === 'bool') { throw new Error(`Unsupported data type $ { dtype }`); } const randGauss = new MPRandGauss(mean, stdDev, dtype, true , seed); const res = buffer(shape, dtype); for (let i = 0; i < res.values.length; i++) { res.values[i] = randGauss.nextValue(); } return res.toTensor(); } const truncatedNormal = op({ truncatedNormal_ }); function unsortedSegmentSum_(x, segmentIds, numSegments) { const $x = convertToTensor(x, 'x', 'unsortedSegmentSum'); const $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'unsortedSegmentSum', 'int32'); assert$1(isInt(numSegments), () => 'numSegments must be of dtype int'); 
const inputs = { x: $x, segmentIds: $segmentIds }; const attrs = { numSegments }; return ENGINE.runKernel(UnsortedSegmentSum, inputs, attrs); } const unsortedSegmentSum$1 = op({ unsortedSegmentSum_ }); function unstack_(x, axis = 0) { const $x = convertToTensor(x, 'x', 'unstack', 'string_or_numeric'); assert$1(axis >= -$x.shape.length && axis < $x.shape.length, () => `Axis = ${axis} is not in [-${$x.shape.length}, ${$x.shape.length})`); const inputs = { value: $x }; const attrs = { axis }; return ENGINE.runKernel(Unpack, inputs, attrs); } const unstack = op({ unstack_ }); function variable(initialValue, trainable = true, name, dtype) { return ENGINE.makeVariable(initialValue, trainable, name, dtype); } function whereImpl$1(condShape, condVals) { const indices = []; for (let i = 0; i < condVals.length; i++) { if (condVals[i]) { indices.push(i); } } const inBuffer = buffer(condShape, 'int32'); const out = buffer([indices.length, condShape.length], 'int32'); for (let i = 0; i < indices.length; i++) { const loc = inBuffer.indexToLoc(indices[i]); const offset = i * condShape.length; out.values.set(loc, offset); } return out.toTensor(); } function transpose_(x, perm, conjugate) { const $x = convertToTensor(x, 'x', 'transpose'); if (perm == null) { perm = $x.shape.map((s, i) => i).reverse(); } assert$1($x.rank === perm.length, () => `Error in transpose: rank of input ${$x.rank} ` + `must match length of perm ${perm}.`); perm.forEach(axis => { assert$1(axis >= 0 && axis < $x.rank, () => `All entries in 'perm' must be between 0 and ${$x.rank - 1}` + ` but got ${perm}`); }); if ($x.rank <= 1) { return $x.clone(); } const inputs = { x: $x }; const attrs = { perm }; if ($x.dtype === 'complex64') { return tidy(() => { let $real = real$1($x); let $imag = imag$1($x); $real = ENGINE.runKernel(Transpose, { x: $real }, attrs); $imag = ENGINE.runKernel(Transpose, { x: $imag }, attrs); if (conjugate) { $imag = neg$1($imag); } return complex$1($real, $imag); }); } return 
ENGINE.runKernel(Transpose, inputs, attrs); } const transpose$1 = op({ transpose_ }); function getNoiseShape(x, noiseShape) { if (noiseShape == null) { return x.shape.slice(); } if (arraysEqual(x.shape, noiseShape)) { return noiseShape; } if (x.shape.length === noiseShape.length) { const newDimension = []; for (let i = 0; i < x.shape.length; i++) { if (noiseShape[i] == null && x.shape[i] != null) { newDimension.push(x.shape[i]); } else { newDimension.push(noiseShape[i]); } } return newDimension; } return noiseShape; } function dropout_(x, rate, noiseShape, seed) { const $x = convertToTensor(x, 'x', 'dropout'); assert$1($x.dtype === 'float32', () => `x has to be a floating point tensor since it's going to be ` + `scaled, but got a ${$x.dtype} tensor instead.`); assert$1(rate >= 0 && rate < 1, () => `rate must be a float in the range [0, 1), but got ${rate}.`); if (rate === 0) { return x instanceof Tensor ? $x.clone() : $x; } const $noiseShape = getNoiseShape($x, noiseShape); const keepProb = 1 - rate; const multiplier = div(floor$1(add(randomUniform($noiseShape, 0, 1, 'float32', seed), keepProb)), keepProb); return mul($x, multiplier); } const dropout$2 = op({ dropout_ }); function conv2DBackpropFilter_(x, dy, filterShape, strides, pad, dataFormat = 'NHWC', dimRoundingMode) { let x4D = x; if (x.rank === 3) { x4D = reshape$1(x, [1, x.shape[0], x.shape[1], x.shape[2]]); } let dy4D = dy; if (dy4D.rank === 3) { dy4D = reshape$1(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]); } assert$1(x4D.rank === 4, () => `Error in conv2dDerFilter: input must be rank 4, but got shape ` + `${x4D.shape}.`); assert$1(dy4D.rank === 4, () => `Error in conv2dDerFilter: dy must be rank 4, but got shape ` + `${dy4D.shape}.`); assert$1(filterShape.length === 4, () => `Error in conv2dDerFilter: filterShape must be length 4, but got ` + `${filterShape}.`); const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1]; const outDepth = dataFormat === 'NHWC' ? 
dy4D.shape[3] : dy4D.shape[1]; assert$1(inDepth === filterShape[2], () => `Error in conv2dDerFilter: depth of input ${inDepth}) must ` + `match input depth in filter (${filterShape[2]}.`); assert$1(outDepth === filterShape[3], () => `Error in conv2dDerFilter: depth of dy (${outDepth}) must ` + `match output depth for filter (${filterShape[3]}).`); checkPadOnDimRoundingMode('conv2dDerFilter', pad, dimRoundingMode); const inputs = { x: x4D, dy: dy4D }; const attrs = { strides, pad, dataFormat, dimRoundingMode, filterShape }; return ENGINE.runKernel(Conv2DBackpropFilter, inputs, attrs); } const conv2DBackpropFilter$1 = op({ conv2DBackpropFilter_ }); function getFusedDyActivation(dy, y, activation) { if (activation == null || activation === 'linear') { return dy; } if (activation === 'relu') { return mul(dy, step$1(y)); } throw new Error(`Cannot compute gradient for fused activation ${activation}.`); } function getFusedBiasGradient(bias, dyActivation) { let res = dyActivation; const reduceAxes = getReductionAxes(bias.shape, dyActivation.shape); if (reduceAxes.length > 0) { res = sum$1(res, reduceAxes); } return reshape$1(res, bias.shape); } function applyActivation(x, activation, preluActivationWeights, leakyreluAlpha) { if (activation === 'linear') { return x; } else if (activation === 'relu') { return relu$1(x); } else if (activation === 'elu') { return elu$2(x); } else if (activation === 'relu6') { return relu6$1(x); } else if (activation === 'prelu') { return prelu$1(x, preluActivationWeights); } else if (activation === 'leakyrelu') { return leakyRelu$1(x, leakyreluAlpha); } else if (activation === 'sigmoid') { return sigmoid$1(x); } throw new Error(`Unknown fused activation ${activation}.`); } const shouldFuse = (gradientDepth, activation) => { const gradientMode = gradientDepth > 0; return !gradientMode || activation === 'linear'; }; function depthwiseConv2dNativeBackpropFilter_(x, dy, filterShape, strides, pad, dilations = [1, 1], dimRoundingMode) { let x4D = x; 
/* (continued) depthwiseConv2dNativeBackpropFilter_: rank-3 `x` / `dy` are given a leading dimension of 1, then the DepthwiseConv2dNativeBackpropFilter kernel runs with the conv attributes and `filterShape`. */ if (x.rank === 3) { x4D = reshape$1(x, [1, x.shape[0], x.shape[1], x.shape[2]]); } let dy4D = dy; if (dy4D.rank === 3) { dy4D = reshape$1(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]); } const inputs = { x: x4D, dy: dy4D }; const attrs = { strides, pad, dimRoundingMode, dilations, filterShape }; return ENGINE.runKernel(DepthwiseConv2dNativeBackpropFilter, inputs, attrs); } const depthwiseConv2dNativeBackpropFilter$1 = op({ depthwiseConv2dNativeBackpropFilter_ });
/**
 * Runs the DepthwiseConv2dNativeBackpropInput kernel on `dy` and `filter`, passing `xShape` as
 * the `inputShape` attribute. A rank-3 `dy` is given a leading dimension of 1, and the kernel
 * result is then reshaped back to rank 3 by dropping its first dimension.
 */
function depthwiseConv2dNativeBackpropInput_(xShape, dy, filter, strides, pad, dilations = [1, 1], dimRoundingMode) { let dy4D = dy; let reshapedTo4D = false; if (dy.rank === 3) { reshapedTo4D = true; dy4D = reshape$1(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]); } const inputs = { dy: dy4D, filter }; const attrs = { strides, pad, dimRoundingMode, dilations, inputShape: xShape }; const res = ENGINE.runKernel(DepthwiseConv2dNativeBackpropInput, inputs, attrs); if (reshapedTo4D) { return reshape$1(res, [res.shape[1], res.shape[2], res.shape[3]]); } return res; } const depthwiseConv2dNativeBackpropInput$1 = op({ depthwiseConv2dNativeBackpropInput_ });
/**
 * Fused matMul(a, b) [+ bias] followed by `activation`, executed as the _FusedMatMul kernel.
 * - If shouldFuse(ENGINE.state.gradientDepth, activation) is false (gradient depth > 0 with a
 *   non-'linear' activation), falls back to unfused matMul$1 + add + applyActivation.
 * - Otherwise a/b are type-matched, reshaped to rank 3 with all leading axes flattened into
 *   one batch dimension, and run through customGrad so the `grad` function below is used.
 *   The result is reshaped to broadcast(batch dims of a, b) + [outerShapeA, outerShapeB].
 * Only 'linear'/'relu' (or null) activations have a gradient (see getFusedDyActivation).
 */
function fusedMatMul_({ a, b, transposeA = false, transposeB = false, bias, activation = 'linear', preluActivationWeights, leakyreluAlpha = 0.2, }) { if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) { /* unfused fallback */ let result = matMul$1(a, b, transposeA, transposeB); if (bias != null) { result = add(result, bias); } return applyActivation(result, activation, preluActivationWeights, leakyreluAlpha); } let $a = convertToTensor(a, 'a', 'fused matMul'); let $b = convertToTensor(b, 'b', 'fused matMul'); [$a, $b] = makeTypesMatch($a, $b); /* Inner (contracted) and outer sizes of the last two axes, after accounting for transposeA / transposeB. */ const innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1]; const innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2]; const outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2]; const outerShapeB = transposeB ? 
$b.shape[$b.rank - 2] : $b.shape[$b.rank - 1]; /* Leading (batch) axes are flattened into a single batch dimension for the rank-3 reshapes below. */ const outerDimsA = $a.shape.slice(0, -2); const outerDimsB = $b.shape.slice(0, -2); const batchDimA = sizeFromShape(outerDimsA); const batchDimB = sizeFromShape(outerDimsB); assert$1(innerShapeA === innerShapeB, () => `Error in fused matMul: inner shapes (${innerShapeA}) and (` + `${innerShapeB}) of Tensors with shapes ${$a.shape} and ` + `${$b.shape} and transposeA=${transposeA}` + ` and transposeB=${transposeB} must match.`); /* Output shape = broadcast of the two batch shapes + [outerShapeA, outerShapeB]. */ const outShapeOuterDims = assertAndGetBroadcastShape($a.shape.slice(0, -2), $b.shape.slice(0, -2)); const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]); const a3D = transposeA ? reshape$1($a, [batchDimA, innerShapeA, outerShapeA]) : reshape$1($a, [batchDimA, outerShapeA, innerShapeA]); const b3D = transposeB ? reshape$1($b, [batchDimB, outerShapeB, innerShapeB]) : reshape$1($b, [batchDimB, innerShapeB, outerShapeB]); /* Optional bias is type-matched to $a and must broadcast against outShape. */ let $bias; if (bias != null) { $bias = convertToTensor(bias, 'bias', 'fused matMul'); [$bias] = makeTypesMatch($bias, $a); assertAndGetBroadcastShape(outShape, $bias.shape); } let $preluActivationWeights; if (preluActivationWeights != null) { $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused matMul'); } /* Custom gradient: `saved` is [a3D, b3D, y] or [a3D, b3D, y, $bias] (matching the save() calls below). dy is reshaped to y.shape and passed through getFusedDyActivation, then the a/b derivatives are picked per transpose combination. */ const grad = (dy, saved) => { const [a3D, b3D, y, $bias] = saved; const dyActivation = getFusedDyActivation(reshape$1(dy, y.shape), y, activation); let aDer; let bDer; if (!transposeA && !transposeB) { aDer = matMul$1(dyActivation, b3D, false, true); bDer = matMul$1(a3D, dyActivation, true, false); } else if (!transposeA && transposeB) { aDer = matMul$1(dyActivation, b3D, false, false); bDer = matMul$1(dyActivation, a3D, true, false); } else if (transposeA && !transposeB) { aDer = matMul$1(b3D, dyActivation, false, true); bDer = matMul$1(a3D, dyActivation, false, false); } else { aDer = matMul$1(b3D, dyActivation, true, true); bDer = matMul$1(dyActivation, a3D, true, true); } if (bias != null) { const biasDer = 
getFusedBiasGradient($bias, dyActivation); return [aDer, bDer, biasDer]; } else { return [aDer, bDer]; } }; const inputs = { a: a3D, b: b3D, bias: $bias, preluActivationWeights: $preluActivationWeights }; const attrs = { transposeA, transposeB, activation, leakyreluAlpha }; /* Two customGrad variants: $bias is only a customGrad input (and only receives biasDer) when a bias was supplied. NOTE(review): $preluActivationWeights is a kernel input but not a customGrad input, so no gradient is returned for it. The callbacks' a3D/b3D parameters shadow the outer ones, but the kernel call uses the outer `inputs` object. */ if (bias == null) { const customOp = customGrad((a3D, b3D, save) => { const res = ENGINE.runKernel(_FusedMatMul, inputs, attrs); save([a3D, b3D, res]); return { value: reshape$1(res, outShape), gradFunc: grad }; }); return customOp(a3D, b3D); } else { const customOpWithBias = customGrad((a3D, b3D, $bias, save) => { const res = ENGINE.runKernel(_FusedMatMul, inputs, attrs); save([a3D, b3D, res, $bias]); return { value: reshape$1(res, outShape), gradFunc: grad }; }); return customOpWithBias(a3D, b3D, $bias); } }
/* The fused op is bound to the name `matMul` here; the plain op used above is `matMul$1`. */
const matMul = op({ fusedMatMul_ });
/**
 * Inserts `element` into `arr` in place, at the position reported by binarySearch.
 * A negative search result -(p + 1) is decoded to insertion point p.
 * Assumes `arr` is already sorted under `comparator`.
 */
function binaryInsert(arr, element, comparator) { const index = binarySearch(arr, element, comparator); const insertionPoint = index < 0 ? -(index + 1) : index; arr.splice(insertionPoint, 0, element); }
/** binarySearch_ with `defaultComparator` used when no comparator is given (any falsy value). */
function binarySearch(arr, target, comparator) { return binarySearch_(arr, target, comparator || defaultComparator); }
/** Three-way comparison using `>` / `<`: returns 1, -1, or 0. */
function defaultComparator(a, b) { return a > b ? 1 : a < b ? -1 : 0; }
/**
 * Lower-bound binary search. Returns the index of the leftmost element comparing equal to
 * `target` when one exists, otherwise -(insertionPoint) - 1.
 * `found` records whether the element at the final lower bound compared equal.
 */
function binarySearch_(arr, target, comparator) { let left = 0; let right = arr.length; let middle = 0; let found = false; while (left < right) { middle = left + ((right - left) >>> 1); const compareResult = comparator(target, arr[middle]); if (compareResult > 0) { left = middle + 1; } else { right = middle; found = !compareResult; } } return found ? 
left : -left - 1; }
/** NMS V3: hard suppression (softNmsSigma = 0); the result contains only selectedIndices. */
function nonMaxSuppressionV3Impl$1(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) { return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, 0 /* softNmsSigma */ ); }
/** NMS V4: hard suppression with optional zero padding; the result also reports validOutputs. */
function nonMaxSuppressionV4Impl$1(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize) { return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, 0 /* softNmsSigma */ , false /* returnScoresTensor */ , padToMaxOutputSize , true /* returnValidOutputs */ ); }
/** NMS V5: soft-NMS with the given softNmsSigma; the result also includes selectedScores. */
function nonMaxSuppressionV5Impl$1(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) { return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, true /* returnScoresTensor */ ); }
/**
 * Greedy (optionally soft) non-max suppression over a flat box array.
 * - Candidates are boxes whose score is strictly greater than scoreThreshold. They are sorted
 *   with ascendingComparator and popped from the end, so the highest score is processed first
 *   (for equal scores, the lower boxIndex first).
 * - A candidate is compared only against boxes selected since its suppressBeginIndex. Any
 *   IoU >= iouThreshold discards it. Otherwise its score is multiplied by
 *   suppressWeight(iouThreshold, scale, iou), with scale = -0.5 / softNmsSigma, or 0 (weight 1,
 *   i.e. plain hard NMS) when softNmsSigma <= 0.
 * - An unchanged score selects the box. A decayed score still above scoreThreshold is
 *   re-inserted into the sorted candidate list to be revisited later.
 * @returns {{selectedIndices: number[], selectedScores?: number[], validOutputs?: number}}
 *   selectedScores only when returnScoresTensor, validOutputs (count before padding) only when
 *   returnValidOutputs. With padToMaxOutputSize both arrays are zero-padded to maxOutputSize.
 */
function nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, returnScoresTensor = false, padToMaxOutputSize = false, returnValidOutputs = false) { const candidates = []; for (let i = 0; i < scores.length; i++) { if (scores[i] > scoreThreshold) { candidates.push({ score: scores[i], boxIndex: i, suppressBeginIndex: 0 }); } } candidates.sort(ascendingComparator); const scale = softNmsSigma > 0 ? 
(-0.5 / softNmsSigma) : 0.0; const selectedIndices = []; const selectedScores = []; while (selectedIndices.length < maxOutputSize && candidates.length > 0) { const candidate = candidates.pop(); const { score: originalScore, boxIndex, suppressBeginIndex } = candidate; if (originalScore < scoreThreshold) { break; } let ignoreCandidate = false; /* Only boxes selected since this candidate was last examined need to be checked. */ for (let j = selectedIndices.length - 1; j >= suppressBeginIndex; --j) { const iou = intersectionOverUnion(boxes, boxIndex, selectedIndices[j]); if (iou >= iouThreshold) { ignoreCandidate = true; break; } candidate.score = candidate.score * suppressWeight(iouThreshold, scale, iou); if (candidate.score <= scoreThreshold) { break; } } candidate.suppressBeginIndex = selectedIndices.length; if (!ignoreCandidate) { /* Unchanged score: select now. Decayed but still above threshold: re-queue in sorted order. */ if (candidate.score === originalScore) { selectedIndices.push(boxIndex); selectedScores.push(candidate.score); } else if (candidate.score > scoreThreshold) { binaryInsert(candidates, candidate, ascendingComparator); } } } const validOutputs = selectedIndices.length; const elemsToPad = maxOutputSize - validOutputs; if (padToMaxOutputSize && elemsToPad > 0) { selectedIndices.push(...new Array(elemsToPad).fill(0)); selectedScores.push(...new Array(elemsToPad).fill(0.0)); } const result = { selectedIndices }; if (returnScoresTensor) { result['selectedScores'] = selectedScores; } if (returnValidOutputs) { result['validOutputs'] = validOutputs; } return result; }
/**
 * IoU of boxes i and j in the flat `boxes` array, with 4 values per box. `boxes.subarray` is
 * used, so a TypedArray is presumably expected. Per the variable names, values 0/2 are y and
 * 1/3 are x; min/max make the corner order irrelevant. Returns 0 when either box has
 * non-positive area.
 */
function intersectionOverUnion(boxes, i, j) { const iCoord = boxes.subarray(i * 4, i * 4 + 4); const jCoord = boxes.subarray(j * 4, j * 4 + 4); const yminI = Math.min(iCoord[0], iCoord[2]); const xminI = Math.min(iCoord[1], iCoord[3]); const ymaxI = Math.max(iCoord[0], iCoord[2]); const xmaxI = Math.max(iCoord[1], iCoord[3]); const yminJ = Math.min(jCoord[0], jCoord[2]); const xminJ = Math.min(jCoord[1], jCoord[3]); const ymaxJ = Math.max(jCoord[0], jCoord[2]); const xmaxJ = Math.max(jCoord[1], jCoord[3]); const areaI = (ymaxI - yminI) * (xmaxI - xminI); const areaJ = 
(ymaxJ - yminJ) * (xmaxJ - xminJ); if (areaI <= 0 || areaJ <= 0) { return 0.0; } const intersectionYmin = Math.max(yminI, yminJ); const intersectionXmin = Math.max(xminI, xminJ); const intersectionYmax = Math.min(ymaxI, ymaxJ); const intersectionXmax = Math.min(xmaxI, xmaxJ); const intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) * Math.max(intersectionXmax - intersectionXmin, 0.0); return intersectionArea / (areaI + areaJ - intersectionArea); }
/** Soft-NMS score multiplier: exp(scale * iou^2) when iou <= iouThreshold, otherwise 0. */
function suppressWeight(iouThreshold, scale, iou) { const weight = Math.exp(scale * iou * iou); return iou <= iouThreshold ? weight : 0.0; }
/** Orders candidates by ascending score; for equal scores the lower boxIndex sorts later (so it is popped first). */
function ascendingComparator(c1, c2) { return (c1.score - c2.score) || ((c1.score === c2.score) && (c2.boxIndex - c1.boxIndex)); }
/**
 * Keeps a band of the innermost [M, N] matrices of `a` and zeroes the rest.
 * Entry (i, j) is kept when (i - j) <= numLower and (i - j) >= -numUpper. A negative numLower /
 * numUpper is replaced by M / N (keeping the whole lower / upper part). Number limits must be
 * integers no greater than M (resp. N); tensor limits must be int32.
 */
function bandPart_(a, numLower, numUpper) { const $a = convertToTensor(a, 'a', 'bandPart'); assert$1($a.rank >= 2, () => `bandPart(): Rank must be at least 2, got ${$a.rank}.`); const shape = $a.shape; const [M, N] = $a.shape.slice(-2); let $numLower; let $numUpper; if (typeof numLower === 'number') { assert$1(numLower % 1 === 0, () => `bandPart(): numLower must be an integer, got ${numLower}.`); assert$1(numLower <= M, () => `bandPart(): numLower (${numLower})` + ` must not be greater than the number of rows (${M}).`); $numLower = convertToTensor(numLower < 0 ? M : numLower, 'numLower', 'bandPart'); } else { assert$1(numLower.dtype === 'int32', () => `bandPart(): numLower's dtype must be an int32.`); $numLower = where(less$1(numLower, 0), M, minimum$1(numLower, M)); } if (typeof numUpper === 'number') { assert$1(numUpper % 1 === 0, () => `bandPart(): numUpper must be an integer, got ${numUpper}.`); assert$1(numUpper <= N, () => `bandPart(): numUpper (${numUpper})` + ` must not be greater than the number of columns (${N}).`); $numUpper = convertToTensor(numUpper < 0 ? 
N : numUpper, 'numUpper', 'bandPart'); } else { assert$1(numUpper.dtype === 'int32', () => `bandPart(): numUpper's dtype must be an int32.`); $numUpper = where(less$1(numUpper, 0), N, minimum$1(numUpper, N)); } const i = reshape$1(range$2(0, M, 1, 'int32'), [-1, 1]); const j = range$2(0, N, 1, 'int32'); const ij = sub$1(i, j); const inBand = logicalAnd$1(lessEqual$1(ij, $numLower), greaterEqual$1(ij, neg$1($numUpper))); const zero = zeros([M, N], $a.dtype); return reshape$1(stack(unstack(reshape$1($a, [-1, M, N])) .map(mat => where(inBand, mat, zero))), shape); } const bandPart = op({ bandPart_ }); function gramSchmidt_(xs) { let inputIsTensor2D; if (Array.isArray(xs)) { inputIsTensor2D = false; assert$1(xs != null && xs.length > 0, () => 'Gram-Schmidt process: input must not be null, undefined, or ' + 'empty'); const dim = xs[0].shape[0]; for (let i = 1; i < xs.length; ++i) { assert$1(xs[i].shape[0] === dim, () => 'Gram-Schmidt: Non-unique lengths found in the input vectors: ' + `(${xs[i].shape[0]} vs. 
${dim})`); } } else { inputIsTensor2D = true; xs = split$1(xs, xs.shape[0], 0).map(x => squeeze(x, [0])); } assert$1(xs.length <= xs[0].shape[0], () => `Gram-Schmidt: Number of vectors (${xs.length}) exceeds ` + `number of dimensions (${xs[0].shape[0]}).`); const ys = []; const xs1d = xs; for (let i = 0; i < xs.length; ++i) { ys.push(ENGINE.tidy(() => { let x = xs1d[i]; if (i > 0) { for (let j = 0; j < i; ++j) { const proj = mul(sum$1(mul(ys[j], x)), ys[j]); x = sub$1(x, proj); } } return div(x, norm(x, 'euclidean')); })); } if (inputIsTensor2D) { return stack(ys, 0); } else { return ys; } } const gramSchmidt = op({ gramSchmidt_ }); function qr_(x, fullMatrices = false) { assert$1(x.rank >= 2, () => `qr() requires input tensor to have a rank >= 2, but got rank ${x.rank}`); if (x.rank === 2) { return qr2d(x, fullMatrices); } else { const outerDimsProd = x.shape.slice(0, x.shape.length - 2) .reduce((value, prev) => value * prev); const x2ds = unstack(reshape$1(x, [ outerDimsProd, x.shape[x.shape.length - 2], x.shape[x.shape.length - 1] ]), 0); const q2ds = []; const r2ds = []; x2ds.forEach(x2d => { const [q2d, r2d] = qr2d(x2d, fullMatrices); q2ds.push(q2d); r2ds.push(r2d); }); const q = reshape$1(stack(q2ds, 0), x.shape); const r = reshape$1(stack(r2ds, 0), x.shape); return [q, r]; } } function qr2d(x, fullMatrices = false) { return ENGINE.tidy(() => { assert$1(x.shape.length === 2, () => `qr2d() requires a 2D Tensor, but got a ${x.shape.length}D Tensor.`); const m = x.shape[0]; const n = x.shape[1]; let q = eye(m); let r = clone(x); const one2D = tensor2d([[1]], [1, 1]); let w = clone(one2D); const iters = m >= n ? 
n : m; for (let j = 0; j < iters; ++j) { const rTemp = r; const wTemp = w; const qTemp = q; [w, r, q] = ENGINE.tidy(() => { const rjEnd1 = slice$1(r, [j, j], [m - j, 1]); const normX = norm(rjEnd1); const rjj = slice$1(r, [j, j], [1, 1]); const s = where(greater$1(rjj, 0), tensor2d([[-1]]), tensor2d([[1]])); const u1 = sub$1(rjj, mul(s, normX)); const wPre = div(rjEnd1, u1); if (wPre.shape[0] === 1) { w = clone(one2D); } else { w = concat$1([ one2D, slice$1(wPre, [1, 0], [wPre.shape[0] - 1, wPre.shape[1]]) ], 0); } const tau = neg$1(div(matMul$1(s, u1), normX)); const rjEndAll = slice$1(r, [j, 0], [m - j, n]); const tauTimesW = mul(tau, w); const wT = transpose$1(w); if (j === 0) { r = sub$1(rjEndAll, matMul$1(tauTimesW, matMul$1(wT, rjEndAll))); } else { const rTimesTau = sub$1(rjEndAll, matMul$1(tauTimesW, matMul$1(wT, rjEndAll))); r = concat$1([slice$1(r, [0, 0], [j, n]), rTimesTau], 0); } const tawTimesWT = transpose$1(tauTimesW); const qAllJEnd = slice$1(q, [0, j], [m, q.shape[1] - j]); if (j === 0) { q = sub$1(qAllJEnd, matMul$1(matMul$1(qAllJEnd, w), tawTimesWT)); } else { const qTimesTau = sub$1(qAllJEnd, matMul$1(matMul$1(qAllJEnd, w), tawTimesWT)); q = concat$1([slice$1(q, [0, 0], [m, j]), qTimesTau], 1); } return [w, r, q]; }); dispose([rTemp, wTemp, qTemp]); } if (!fullMatrices && m > n) { q = slice$1(q, [0, 0], [m, n]); r = slice$1(r, [0, 0], [n, n]); } return [q, r]; }); } const qr = op({ qr_ }); function stringToHashBucketFast_(input, numBuckets) { const $input = convertToTensor(input, 'input', 'stringToHashBucketFast', 'string'); const attrs = { numBuckets }; if (numBuckets <= 0) { throw new Error(`Number of buckets must be at least 1`); } const inputs = { input: $input }; return ENGINE.runKernel(StringToHashBucketFast, inputs, attrs); } const stringToHashBucketFast$1 = op({ stringToHashBucketFast_ }); const linalg = { bandPart, gramSchmidt, qr }; const GLOBAL_CUSTOM_OBJECT = new Map(); const GLOBAL_CUSTOM_NAMES = new Map(); class Serializable { 
getClassName() { return this.constructor .className; } static fromConfig(cls, config) { return new cls(config); } } class SerializationMap { constructor() { this.classNameMap = {}; } static getMap() { if (SerializationMap.instance == null) { SerializationMap.instance = new SerializationMap(); } return SerializationMap.instance; } static register(cls) { SerializationMap.getMap().classNameMap[cls.className] = [cls, cls.fromConfig]; } } function registerClass(cls, pkg, name) { assert$1(cls.className != null, () => `Class being registered does not have the static className ` + `property defined.`); assert$1(typeof cls.className === 'string', () => `className is required to be a string, but got type ` + typeof cls.className); assert$1(cls.className.length > 0, () => `Class being registered has an empty-string as its className, ` + `which is disallowed.`); if (typeof pkg === 'undefined') { pkg = 'Custom'; } if (typeof name === 'undefined') { name = cls.className; } const className = name; const registerName = pkg + '>' + className; SerializationMap.register(cls); GLOBAL_CUSTOM_OBJECT.set(registerName, cls); GLOBAL_CUSTOM_NAMES.set(cls, registerName); return cls; } class Optimizer extends Serializable { minimize(f, returnCost = false, varList) { const { value, grads } = this.computeGradients(f, varList); if (varList != null) { const gradArray = varList.map(v => ({ name: v.name, tensor: grads[v.name] })); this.applyGradients(gradArray); } else { this.applyGradients(grads); } dispose(grads); if (returnCost) { return value; } else { value.dispose(); return null; } } get iterations() { if (this.iterations_ == null) { this.iterations_ = 0; } return this.iterations_; } incrementIterations() { this.iterations_ = this.iterations + 1; } computeGradients(f, varList) { return variableGrads(f, varList); } dispose() { if (this.iterations_ != null) { dispose(this.iterations_); } } async saveIterations() { if (this.iterations_ == null) { this.iterations_ = 0; } return { name: 'iter', 
tensor: scalar(this.iterations_, 'int32') }; } async getWeights() { throw new Error('getWeights() is not implemented for this optimizer yet.'); } async setWeights(weightValues) { throw new Error(`setWeights() is not implemented for this optimizer class ` + `${this.getClassName()}`); } async extractIterations(weightValues) { this.iterations_ = (await weightValues[0].tensor.data())[0]; return weightValues.slice(1); } } Object.defineProperty(Optimizer, Symbol.hasInstance, { value: (instance) => { return instance.minimize != null && instance.computeGradients != null && instance.applyGradients != null; } }); class AdadeltaOptimizer extends Optimizer { static get className() { return 'Adadelta'; } constructor(learningRate, rho, epsilon = null) { super(); this.learningRate = learningRate; this.rho = rho; this.epsilon = epsilon; this.accumulatedGrads = []; this.accumulatedUpdates = []; if (epsilon == null) { this.epsilon = ENGINE.backend.epsilon(); } } applyGradients(variableGradients) { const variableNames = Array.isArray(variableGradients) ? variableGradients.map(item => item.name) : Object.keys(variableGradients); variableNames.forEach((name, i) => { const value = ENGINE.registeredVariables[name]; const trainable = false; if (this.accumulatedGrads[i] == null) { this.accumulatedGrads[i] = { originalName: `${name}/accum_grad`, variable: tidy(() => zerosLike$1(value).variable(trainable)) }; } if (this.accumulatedUpdates[i] == null) { this.accumulatedUpdates[i] = { originalName: `${name}/accum_var`, variable: tidy(() => zerosLike$1(value).variable(trainable)) }; } const gradient = Array.isArray(variableGradients) ? 
variableGradients[i].tensor : variableGradients[name]; if (gradient == null) { return; } const accumulatedGrad = this.accumulatedGrads[i].variable; const accumulatedUpdate = this.accumulatedUpdates[i].variable; tidy(() => { const newAccumulatedGrad = add(mul(accumulatedGrad, this.rho), mul(square$2(gradient), 1 - this.rho)); const updates = mul(div(sqrt$1(add(accumulatedUpdate, this.epsilon)), sqrt$1(add(accumulatedGrad, this.epsilon))), gradient); const newAccumulatedUpdate = add(mul(accumulatedUpdate, this.rho), mul(square$2(updates), 1 - this.rho)); accumulatedGrad.assign(newAccumulatedGrad); accumulatedUpdate.assign(newAccumulatedUpdate); const newValue = add(mul(updates, -this.learningRate), value); value.assign(newValue); }); }); this.incrementIterations(); } dispose() { if (this.accumulatedUpdates != null) { dispose(this.accumulatedGrads.map(v => v.variable)); dispose(this.accumulatedUpdates.map(v => v.variable)); } } async getWeights() { const variables = [...this.accumulatedGrads, ...this.accumulatedUpdates]; return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable }))); } async setWeights(weightValues) { weightValues = await this.extractIterations(weightValues); const variableCount = weightValues.length / 2; const trainable = false; this.accumulatedGrads = weightValues.slice(0, variableCount).map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) })); this.accumulatedUpdates = weightValues.slice(variableCount, variableCount * 2) .map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) })); } getConfig() { return { 'learningRate': this.learningRate, 'rho': this.rho, 'epsilon': this.epsilon }; } static fromConfig(cls, config) { return new cls(config['learningRate'], config['rho'], config['epsilon']); } } class AdagradOptimizer extends Optimizer { static get className() { return 'Adagrad'; } constructor(learningRate, initialAccumulatorValue = 0.1) { super(); 
this.learningRate = learningRate; this.initialAccumulatorValue = initialAccumulatorValue; this.accumulatedGrads = []; } applyGradients(variableGradients) { const variableNames = Array.isArray(variableGradients) ? variableGradients.map(item => item.name) : Object.keys(variableGradients); variableNames.forEach((name, i) => { const value = ENGINE.registeredVariables[name]; if (this.accumulatedGrads[i] == null) { const trainable = false; this.accumulatedGrads[i] = { originalName: `${name}/accumulator`, variable: tidy(() => fill$1(value.shape, this.initialAccumulatorValue) .variable(trainable)) }; } const gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name]; if (gradient == null) { return; } const accumulatedGrad = this.accumulatedGrads[i].variable; tidy(() => { const newAccumulatedGrad = add(accumulatedGrad, square$2(gradient)); accumulatedGrad.assign(newAccumulatedGrad); const newValue = add(mul(div(gradient, sqrt$1(add(newAccumulatedGrad, ENGINE.backend.epsilon()))), -this.learningRate), value); value.assign(newValue); }); }); this.incrementIterations(); } dispose() { if (this.accumulatedGrads != null) { dispose(this.accumulatedGrads.map(v => v.variable)); } } async getWeights() { return [await this.saveIterations()].concat(this.accumulatedGrads.map(v => ({ name: v.originalName, tensor: v.variable }))); } async setWeights(weightValues) { weightValues = await this.extractIterations(weightValues); const trainable = false; this.accumulatedGrads = weightValues.map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) })); } getConfig() { return { 'learningRate': this.learningRate, 'initialAccumulatorValue': this.initialAccumulatorValue, }; } static fromConfig(cls, config) { return new cls(config['learningRate'], config['initialAccumulatorValue']); } } class AdamOptimizer extends Optimizer { static get className() { return 'Adam'; } constructor(learningRate, beta1, beta2, epsilon = null) { super(); 
this.learningRate = learningRate; this.beta1 = beta1; this.beta2 = beta2; this.epsilon = epsilon; this.accumulatedFirstMoment = []; this.accumulatedSecondMoment = []; tidy(() => { this.accBeta1 = scalar(beta1).variable(); this.accBeta2 = scalar(beta2).variable(); }); if (epsilon == null) { this.epsilon = ENGINE.backend.epsilon(); } } applyGradients(variableGradients) { const varNames = Array.isArray(variableGradients) ? variableGradients.map(v => v.name) : Object.keys(variableGradients); tidy(() => { const oneMinusAccBeta1 = sub$1(1, this.accBeta1); const oneMinusAccBeta2 = sub$1(1, this.accBeta2); varNames.forEach((name, i) => { const value = ENGINE.registeredVariables[name]; const trainable = false; if (this.accumulatedFirstMoment[i] == null) { this.accumulatedFirstMoment[i] = { originalName: `${name}/m`, variable: tidy(() => zerosLike$1(value).variable(trainable)) }; } if (this.accumulatedSecondMoment[i] == null) { this.accumulatedSecondMoment[i] = { originalName: `${name}/v`, variable: tidy(() => zerosLike$1(value).variable(trainable)) }; } const gradient = Array.isArray(variableGradients) ? 
variableGradients[i].tensor : variableGradients[name]; if (gradient == null) { return; } const firstMoment = this.accumulatedFirstMoment[i].variable; const secondMoment = this.accumulatedSecondMoment[i].variable; const newFirstMoment = add(mul(firstMoment, this.beta1), mul(gradient, 1 - this.beta1)); const newSecondMoment = add(mul(secondMoment, this.beta2), mul(square$2(gradient), 1 - this.beta2)); const biasCorrectedFirstMoment = div(newFirstMoment, oneMinusAccBeta1); const biasCorrectedSecondMoment = div(newSecondMoment, oneMinusAccBeta2); firstMoment.assign(newFirstMoment); secondMoment.assign(newSecondMoment); const newValue = add(mul(div(biasCorrectedFirstMoment, add(sqrt$1(biasCorrectedSecondMoment), this.epsilon)), -this.learningRate), value); value.assign(newValue); }); this.accBeta1.assign(mul(this.accBeta1, this.beta1)); this.accBeta2.assign(mul(this.accBeta2, this.beta2)); }); this.incrementIterations(); } dispose() { this.accBeta1.dispose(); this.accBeta2.dispose(); if (this.accumulatedFirstMoment != null) { dispose(this.accumulatedFirstMoment.map(v => v.variable)); } if (this.accumulatedSecondMoment != null) { dispose(this.accumulatedSecondMoment.map(v => v.variable)); } } async getWeights() { const variables = [...this.accumulatedFirstMoment, ...this.accumulatedSecondMoment]; return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable }))); } async setWeights(weightValues) { weightValues = await this.extractIterations(weightValues); tidy(() => { this.accBeta1.assign(pow$1(this.beta1, this.iterations_ + 1)); this.accBeta2.assign(pow$1(this.beta2, this.iterations_ + 1)); }); const variableCount = weightValues.length / 2; const trainable = false; this.accumulatedFirstMoment = weightValues.slice(0, variableCount).map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) })); this.accumulatedSecondMoment = weightValues.slice(variableCount, variableCount * 2) .map(v => ({ originalName: v.name, 
variable: v.tensor.variable(trainable) })); } getConfig() { return { 'learningRate': this.learningRate, 'beta1': this.beta1, 'beta2': this.beta2, 'epsilon': this.epsilon, }; } static fromConfig(cls, config) { return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon']); } } class AdamaxOptimizer extends Optimizer { static get className() { return 'Adamax'; } constructor(learningRate, beta1, beta2, epsilon = null, decay = 0.0) { super(); this.learningRate = learningRate; this.beta1 = beta1; this.beta2 = beta2; this.epsilon = epsilon; this.decay = decay; this.accumulatedFirstMoment = []; this.accumulatedWeightedInfNorm = []; tidy(() => { this.iteration = scalar(0).variable(); this.accBeta1 = scalar(beta1).variable(); }); if (epsilon == null) { this.epsilon = ENGINE.backend.epsilon(); } } applyGradients(variableGradients) { const variableNames = Array.isArray(variableGradients) ? variableGradients.map(item => item.name) : Object.keys(variableGradients); tidy(() => { const oneMinusAccBeta1 = sub$1(1, this.accBeta1); const lr = div(-this.learningRate, add(mul(this.iteration, this.decay), 1)); variableNames.forEach((name, i) => { const value = ENGINE.registeredVariables[name]; const trainable = false; if (this.accumulatedFirstMoment[i] == null) { this.accumulatedFirstMoment[i] = { originalName: `${name}/m`, variable: zerosLike$1(value).variable(trainable) }; } if (this.accumulatedWeightedInfNorm[i] == null) { this.accumulatedWeightedInfNorm[i] = { originalName: `${name}/v`, variable: zerosLike$1(value).variable(trainable) }; } const gradient = Array.isArray(variableGradients) ? 
variableGradients[i].tensor : variableGradients[name]; if (gradient == null) { return; } const firstMoment = this.accumulatedFirstMoment[i].variable; const weightedInfNorm = this.accumulatedWeightedInfNorm[i].variable; const newFirstMoment = add(mul(firstMoment, this.beta1), mul(gradient, 1 - this.beta1)); const ut0 = mul(weightedInfNorm, this.beta2); const ut1 = abs$1(gradient); const newWeightedInfNorm = maximum$1(ut0, ut1); firstMoment.assign(newFirstMoment); weightedInfNorm.assign(newWeightedInfNorm); const newValue = add(mul(div(lr, oneMinusAccBeta1), div(newFirstMoment, add(newWeightedInfNorm, this.epsilon))), value); value.assign(newValue); }); this.iteration.assign(add(this.iteration, 1)); this.accBeta1.assign(mul(this.accBeta1, this.beta1)); }); this.incrementIterations(); } dispose() { this.accBeta1.dispose(); this.iteration.dispose(); if (this.accumulatedFirstMoment != null) { dispose(this.accumulatedFirstMoment.map(v => v.variable)); } if (this.accumulatedWeightedInfNorm != null) { dispose(this.accumulatedWeightedInfNorm.map(v => v.variable)); } } async getWeights() { throw new Error('getWeights() is not implemented for Adamax yet.'); } async setWeights(weightValues) { throw new Error('setWeights() is not implemented for Adamax yet.'); } getConfig() { return { 'learningRate': this.learningRate, 'beta1': this.beta1, 'beta2': this.beta2, 'epsilon': this.epsilon, 'decay': this.decay }; } static fromConfig(cls, config) { return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon'], config['decay']); } } class SGDOptimizer extends Optimizer { static get className() { return 'SGD'; } constructor(learningRate) { super(); this.learningRate = learningRate; this.setLearningRate(learningRate); } applyGradients(variableGradients) { const varNames = Array.isArray(variableGradients) ? variableGradients.map(v => v.name) : Object.keys(variableGradients); varNames.forEach((name, i) => { const gradient = Array.isArray(variableGradients) ? 
variableGradients[i].tensor : variableGradients[name]; if (gradient == null) { return; } const value = ENGINE.registeredVariables[name]; tidy(() => { const newValue = add(mul(this.c, gradient), value); value.assign(newValue); }); }); this.incrementIterations(); } setLearningRate(learningRate) { this.learningRate = learningRate; if (this.c != null) { this.c.dispose(); } this.c = keep(scalar(-learningRate)); } dispose() { this.c.dispose(); } async getWeights() { return [await this.saveIterations()]; } async setWeights(weightValues) { weightValues = await this.extractIterations(weightValues); if (weightValues.length !== 0) { throw new Error('SGD optimizer does not have settable weights.'); } } getConfig() { return { 'learningRate': this.learningRate }; } static fromConfig(cls, config) { return new cls(config['learningRate']); } } class MomentumOptimizer extends SGDOptimizer { static get className() { return 'Momentum'; } constructor(learningRate, momentum, useNesterov = false) { super(learningRate); this.learningRate = learningRate; this.momentum = momentum; this.useNesterov = useNesterov; this.accumulations = []; this.m = scalar(this.momentum); } applyGradients(variableGradients) { const variableNames = Array.isArray(variableGradients) ? variableGradients.map(item => item.name) : Object.keys(variableGradients); variableNames.forEach((name, i) => { const value = ENGINE.registeredVariables[name]; if (this.accumulations[i] == null) { const trainable = false; this.accumulations[i] = { originalName: `${name}/momentum`, variable: tidy(() => zerosLike$1(value).variable(trainable)) }; } const accumulation = this.accumulations[i].variable; const gradient = Array.isArray(variableGradients) ? 
variableGradients[i].tensor : variableGradients[name]; if (gradient == null) { return; } tidy(() => { let newValue; const newAccumulation = add(mul(this.m, accumulation), gradient); if (this.useNesterov) { newValue = add(mul(this.c, add(gradient, mul(newAccumulation, this.m))), value); } else { newValue = add(mul(this.c, newAccumulation), value); } accumulation.assign(newAccumulation); value.assign(newValue); }); }); this.incrementIterations(); } dispose() { this.m.dispose(); if (this.accumulations != null) { dispose(this.accumulations.map(v => v.variable)); } } setMomentum(momentum) { this.momentum = momentum; } async getWeights() { return [await this.saveIterations()].concat(this.accumulations.map(v => ({ name: v.originalName, tensor: v.variable }))); } async setWeights(weightValues) { weightValues = await this.extractIterations(weightValues); const trainable = false; this.accumulations = weightValues.map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) })); } getConfig() { return { 'learningRate': this.learningRate, 'momentum': this.momentum, 'useNesterov': this.useNesterov }; } static fromConfig(cls, config) { return new cls(config['learningRate'], config['momentum'], config['useNesterov']); } } class RMSPropOptimizer extends Optimizer { static get className() { return 'RMSProp'; } constructor(learningRate, decay = 0.9, momentum = 0.0, epsilon = null, centered = false) { super(); this.learningRate = learningRate; this.decay = decay; this.momentum = momentum; this.epsilon = epsilon; this.accumulatedMeanSquares = []; this.accumulatedMoments = []; this.accumulatedMeanGrads = []; this.centered = centered; if (epsilon == null) { this.epsilon = ENGINE.backend.epsilon(); } if (learningRate == null) { throw new Error(`learningRate for RMSPropOptimizer must be defined.`); } } applyGradients(variableGradients) { const variableNames = Array.isArray(variableGradients) ? 
variableGradients.map(item => item.name) : Object.keys(variableGradients); variableNames.forEach((name, i) => { const value = ENGINE.registeredVariables[name]; const trainable = false; if (this.accumulatedMeanSquares[i] == null) { this.accumulatedMeanSquares[i] = { originalName: `${name}/rms`, variable: tidy(() => zerosLike$1(value).variable(trainable)) }; } if (this.accumulatedMoments[i] == null) { this.accumulatedMoments[i] = { originalName: `${name}/momentum`, variable: tidy(() => zerosLike$1(value).variable(trainable)) }; } if (this.accumulatedMeanGrads[i] == null && this.centered) { this.accumulatedMeanGrads[i] = { originalName: `${name}/mg`, variable: tidy(() => zerosLike$1(value).variable(trainable)) }; } const gradient = Array.isArray(variableGradients) ? variableGradients[i].tensor : variableGradients[name]; if (gradient == null) { return; } const accumulatedMeanSquare = this.accumulatedMeanSquares[i].variable; const accumulatedMoments = this.accumulatedMoments[i].variable; tidy(() => { const newAccumulatedMeanSquare = add(mul(accumulatedMeanSquare, this.decay), mul(square$2(gradient), 1 - this.decay)); if (this.centered) { const accumulatedMeanGrad = this.accumulatedMeanGrads[i].variable; const newAccumulatedMeanGrad = add(mul(accumulatedMeanGrad, this.decay), mul(gradient, 1 - this.decay)); const gradContribution = div(mul(gradient, this.learningRate), sqrt$1(sub$1(newAccumulatedMeanSquare, add(square$2(newAccumulatedMeanGrad), this.epsilon)))); const newAccumulatedMoments = add(mul(accumulatedMoments, this.momentum), gradContribution); accumulatedMeanSquare.assign(newAccumulatedMeanSquare); accumulatedMeanGrad.assign(newAccumulatedMeanGrad); accumulatedMoments.assign(newAccumulatedMoments); const newValue = sub$1(value, newAccumulatedMoments); value.assign(newValue); } else { const newAccumulatedMeanSquare = add(mul(accumulatedMeanSquare, this.decay), mul(square$2(gradient), 1 - this.decay)); const newAccumulatedMoments = add(mul(accumulatedMoments, 
this.momentum), div(mul(gradient, this.learningRate), sqrt$1(add(newAccumulatedMeanSquare, this.epsilon)))); accumulatedMeanSquare.assign(newAccumulatedMeanSquare); accumulatedMoments.assign(newAccumulatedMoments); const newValue = sub$1(value, newAccumulatedMoments); value.assign(newValue); } }); }); this.incrementIterations(); } dispose() { if (this.accumulatedMeanSquares != null) { dispose(this.accumulatedMeanSquares.map(v => v.variable)); } if (this.accumulatedMeanGrads != null && this.centered) { dispose(this.accumulatedMeanGrads.map(v => v.variable)); } if (this.accumulatedMoments != null) { dispose(this.accumulatedMoments.map(v => v.variable)); } } async getWeights() { const variables = [...this.accumulatedMeanSquares, ...this.accumulatedMoments]; if (this.centered) { variables.push(...this.accumulatedMeanGrads); } return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable }))); } async setWeights(weightValues) { weightValues = await this.extractIterations(weightValues); const variableCount = this.centered ? 
weightValues.length / 3 : weightValues.length / 2; const trainable = false; this.accumulatedMeanSquares = weightValues.slice(0, variableCount).map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) })); this.accumulatedMoments = weightValues.slice(variableCount, variableCount * 2) .map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) })); if (this.centered) { this.accumulatedMeanGrads = weightValues.slice(variableCount * 2, variableCount * 3) .map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) })); } } getConfig() { return { 'learningRate': this.learningRate, 'decay': this.decay, 'momentum': this.momentum, 'epsilon': this.epsilon, 'centered': this.centered }; } static fromConfig(cls, config) { return new cls(config['learningRate'], config['decay'], config['momentum'], config['epsilon'], config['centered']); } } const OPTIMIZERS = [ AdadeltaOptimizer, AdagradOptimizer, AdamOptimizer, AdamaxOptimizer, MomentumOptimizer, RMSPropOptimizer, SGDOptimizer, ]; function registerOptimizers() { for (const optimizer of OPTIMIZERS) { registerClass(optimizer); } } const DTYPE_VALUE_SIZE_MAP = { 'float32': 4, 'float16': 2, 'int32': 4, 'uint16': 2, 'uint8': 1, 'bool': 1, 'complex64': 8 }; class CompositeArrayBuffer { static join(buffers) { return new CompositeArrayBuffer(buffers).slice(); } constructor(buffers) { this.shards = []; this.previousShardIndex = 0; if (buffers == null) { return; } if (!(buffers instanceof Array)) { buffers = [buffers]; } buffers = buffers.map((bufferOrTypedArray) => { if (isTypedArray(bufferOrTypedArray)) { return bufferOrTypedArray.buffer; } return bufferOrTypedArray; }); if (buffers.length === 0) { return; } this.bufferUniformSize = buffers[0].byteLength; let start = 0; for (let i = 0; i < buffers.length; i++) { const buffer = buffers[i]; if (i !== buffers.length - 1 && buffer.byteLength !== this.bufferUniformSize) { this.bufferUniformSize = undefined; } const end = start + buffer.byteLength; 
this.shards.push({ buffer, start, end }); start = end; } if (this.shards.length === 0) { this.byteLength = 0; } this.byteLength = this.shards[this.shards.length - 1].end; } slice(start = 0, end = this.byteLength) { if (this.shards.length === 0) { return new ArrayBuffer(0); } start = isNaN(Number(start)) ? 0 : start; end = isNaN(Number(end)) ? 0 : end; start = Math.max(0, start); end = Math.min(this.byteLength, end); if (end <= start) { return new ArrayBuffer(0); } const startShardIndex = this.findShardForByte(start); if (startShardIndex === -1) { throw new Error(`Could not find start shard for byte ${start}`); } const size = end - start; const outputBuffer = new ArrayBuffer(size); const outputArray = new Uint8Array(outputBuffer); let sliced = 0; for (let i = startShardIndex; i < this.shards.length; i++) { const shard = this.shards[i]; const globalStart = start + sliced; const localStart = globalStart - shard.start; const outputStart = sliced; const globalEnd = Math.min(end, shard.end); const localEnd = globalEnd - shard.start; const outputSlice = new Uint8Array(shard.buffer, localStart, localEnd - localStart); outputArray.set(outputSlice, outputStart); sliced += outputSlice.length; if (end < shard.end) { break; } } return outputBuffer; } findShardForByte(byteIndex) { if (this.shards.length === 0 || byteIndex < 0 || byteIndex >= this.byteLength) { return -1; } if (this.bufferUniformSize != null) { this.previousShardIndex = Math.floor(byteIndex / this.bufferUniformSize); return this.previousShardIndex; } function check(shard) { if (byteIndex < shard.start) { return -1; } if (byteIndex >= shard.end) { return 1; } return 0; } if (check(this.shards[this.previousShardIndex]) === 0) { return this.previousShardIndex; } const index = search(this.shards, check); if (index === -1) { return -1; } this.previousShardIndex = index; return this.previousShardIndex; } } function search(sortedArray, compare) { let min = 0; let max = sortedArray.length; while (min <= max) { const 
middle = Math.floor((max - min) / 2) + min; const side = compare(sortedArray[middle]); if (side === 0) { return middle; } else if (side < 0) { max = middle; } else { min = middle + 1; } } return -1; } const NUM_BYTES_STRING_LENGTH = 4; async function encodeWeights(tensors, group) { const specs = []; const dataPromises = []; const names = Array.isArray(tensors) ? tensors.map(tensor => tensor.name) : Object.keys(tensors); for (let i = 0; i < names.length; ++i) { const name = names[i]; const t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name]; if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' && t.dtype !== 'string' && t.dtype !== 'complex64') { throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`); } const spec = { name, shape: t.shape, dtype: t.dtype }; if (t.dtype === 'string') { const utf8bytes = new Promise(async (resolve) => { const vals = await t.bytes(); const totalNumBytes = vals.reduce((p, c) => p + c.length, 0) + NUM_BYTES_STRING_LENGTH * vals.length; const bytes = new Uint8Array(totalNumBytes); let offset = 0; for (let i = 0; i < vals.length; i++) { const val = vals[i]; const bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer); bytes.set(bytesOfLength, offset); offset += NUM_BYTES_STRING_LENGTH; bytes.set(val, offset); offset += val.length; } resolve(bytes); }); dataPromises.push(utf8bytes); } else { dataPromises.push(t.data()); } if (group != null) { spec.group = group; } specs.push(spec); } const tensorValues = await Promise.all(dataPromises); return { data: concatenateTypedArrays(tensorValues), specs }; } function decodeWeights(weightData, specs) { const compositeBuffer = new CompositeArrayBuffer(weightData); const out = {}; let offset = 0; for (const spec of specs) { const byteLength = getWeightBytelength(spec, (start, end) => { return compositeBuffer.slice(offset + start, offset + end); }); out[spec.name] = decodeWeight(spec, compositeBuffer .slice(offset, offset + byteLength)); offset += 
byteLength; } return out; } function getWeightBytelength(spec, slice) { const size = sizeFromShape(spec.shape); let bytesPerValue; if ('quantization' in spec) { const quantization = spec.quantization; bytesPerValue = DTYPE_VALUE_SIZE_MAP[quantization.dtype]; } else if (spec.dtype === 'string') { let byteLength = 0; for (let i = 0; i < size; i++) { byteLength += NUM_BYTES_STRING_LENGTH + new Uint32Array(slice(byteLength, byteLength + NUM_BYTES_STRING_LENGTH))[0]; } return byteLength; } else { bytesPerValue = DTYPE_VALUE_SIZE_MAP[spec.dtype]; } return size * bytesPerValue; } function decodeWeight(spec, byteBuffer) { const name = spec.name; const dtype = spec.dtype; const shape = spec.shape; const size = sizeFromShape(shape); let values; let offset = 0; if ('quantization' in spec) { const quantization = spec.quantization; if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') { if (!('min' in quantization && 'scale' in quantization)) { throw new Error(`Weight ${spec.name} with quantization ${quantization.dtype} ` + `doesn't have corresponding metadata min and scale.`); } } else if (quantization.dtype === 'float16') { if (dtype !== 'float32') { throw new Error(`Weight ${spec.name} is quantized with ${quantization.dtype} ` + `which only supports weights of type float32 not ${dtype}.`); } } else { throw new Error(`Weight ${spec.name} has unknown ` + `quantization dtype ${quantization.dtype}. ` + `Supported quantization dtypes are: ` + `'uint8', 'uint16', and 'float16'.`); } const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype]; const quantizedArray = (quantization.dtype === 'uint8') ? 
new Uint8Array(byteBuffer) : new Uint16Array(byteBuffer); if (dtype === 'float32') { if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') { values = new Float32Array(quantizedArray.length); for (let i = 0; i < quantizedArray.length; i++) { const v = quantizedArray[i]; values[i] = v * quantization.scale + quantization.min; } } else if (quantization.dtype === 'float16') { const float16Decode = getFloat16Decoder(); values = float16Decode(quantizedArray); } else { throw new Error(`Unsupported quantization type ${quantization.dtype} ` + `for weight type float32.`); } } else if (dtype === 'int32') { if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') { throw new Error(`Unsupported quantization type ${quantization.dtype} ` + `for weight type int32.`); } values = new Int32Array(quantizedArray.length); for (let i = 0; i < quantizedArray.length; i++) { const v = quantizedArray[i]; values[i] = Math.round(v * quantization.scale + quantization.min); } } else { throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); } offset += size * quantizationSizeFactor; } else if (dtype === 'string') { const size = sizeFromShape(spec.shape); values = []; for (let i = 0; i < size; i++) { const byteLength = new Uint32Array(byteBuffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0]; offset += NUM_BYTES_STRING_LENGTH; const bytes = new Uint8Array(byteBuffer.slice(offset, offset + byteLength)); values.push(bytes); offset += byteLength; } } else { const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype]; if (dtype === 'float32') { values = new Float32Array(byteBuffer); } else if (dtype === 'int32') { values = new Int32Array(byteBuffer); } else if (dtype === 'bool') { values = new Uint8Array(byteBuffer); } else if (dtype === 'complex64') { values = new Float32Array(byteBuffer); const real = new Float32Array(values.length / 2); const image = new Float32Array(values.length / 2); for (let i = 0; i < real.length; i++) { real[i] = values[i * 2]; image[i] = 
values[i * 2 + 1]; } const realTensor = tensor(real, shape, 'float32'); const imageTensor = tensor(image, shape, 'float32'); const complexTensor = complex$1(realTensor, imageTensor); realTensor.dispose(); imageTensor.dispose(); return complexTensor; } else { throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); } offset += size * dtypeFactor; } return tensor(values, shape, dtype); } function concatenateTypedArrays(xs) { if (xs === null) { throw new Error(`Invalid input value: ${JSON.stringify(xs)}`); } let totalByteLength = 0; const normalizedXs = []; xs.forEach((x) => { totalByteLength += x.byteLength; normalizedXs.push(x.byteLength === x.buffer.byteLength ? x : new x.constructor(x)); if (!(x instanceof Float32Array || x instanceof Int32Array || x instanceof Uint8Array)) { throw new Error(`Unsupported TypedArray subtype: ${x.constructor.name}`); } }); const y = new Uint8Array(totalByteLength); let offset = 0; normalizedXs.forEach((x) => { y.set(new Uint8Array(x.buffer), offset); offset += x.byteLength; }); return y.buffer; } const useNodeBuffer = typeof Buffer !== 'undefined' && (typeof Blob === 'undefined' || typeof atob === 'undefined' || typeof btoa === 'undefined'); function stringByteLength(str) { if (useNodeBuffer) { return Buffer.byteLength(str, 'utf8'); } return new Blob([str]).size; } function arrayBufferToBase64String(buffer) { if (useNodeBuffer) { return Buffer.from(buffer).toString('base64'); } const buf = new Uint8Array(buffer); let s = ''; for (let i = 0, l = buf.length; i < l; i++) { s += String.fromCharCode(buf[i]); } return btoa(s); } function base64StringToArrayBuffer(str) { if (useNodeBuffer) { const buf = Buffer.from(str, 'base64'); return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength); } const s = atob(str); const buffer = new Uint8Array(s.length); for (let i = 0; i < s.length; ++i) { buffer.set([s.charCodeAt(i)], i); } return buffer.buffer; } function concatenateArrayBuffers(buffers) { return 
CompositeArrayBuffer.join(buffers); } function getModelJSONForModelArtifacts(artifacts, manifest) { const result = { modelTopology: artifacts.modelTopology, format: artifacts.format, generatedBy: artifacts.generatedBy, convertedBy: artifacts.convertedBy, weightsManifest: manifest }; if (artifacts.signature != null) { result.signature = artifacts.signature; } if (artifacts.userDefinedMetadata != null) { result.userDefinedMetadata = artifacts.userDefinedMetadata; } if (artifacts.modelInitializer != null) { result.modelInitializer = artifacts.modelInitializer; } if (artifacts.initializerSignature != null) { result.initializerSignature = artifacts.initializerSignature; } if (artifacts.trainingConfig != null) { result.trainingConfig = artifacts.trainingConfig; } return result; } function getModelArtifactsInfoForJSON(modelArtifacts) { if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error('Expected JSON model topology, received ArrayBuffer.'); } return { dateSaved: new Date(), modelTopologyType: 'JSON', modelTopologyBytes: modelArtifacts.modelTopology == null ? 0 : stringByteLength(JSON.stringify(modelArtifacts.modelTopology)), weightSpecsBytes: modelArtifacts.weightSpecs == null ? 0 : stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)), weightDataBytes: modelArtifacts.weightData == null ? 
0 : new CompositeArrayBuffer(modelArtifacts.weightData).byteLength, }; } function computeFloat16MantisaTable() { const convertMantissa = (i) => { let m = i << 13; let e = 0; while ((m & 0x00800000) === 0) { e -= 0x00800000; m <<= 1; } m &= -8388609; e += 0x38800000; return m | e; }; const mantisaTable = new Uint32Array(2048); mantisaTable[0] = 0; for (let i = 1; i < 1024; i++) { mantisaTable[i] = convertMantissa(i); } for (let i = 1024; i < 2048; i++) { mantisaTable[i] = 0x38000000 + ((i - 1024) << 13); } return mantisaTable; } function computeFloat16ExponentTable() { const exponentTable = new Uint32Array(64); exponentTable[0] = 0; exponentTable[31] = 0x47800000; exponentTable[32] = 0x80000000; exponentTable[63] = 0xc7800000; for (let i = 1; i < 31; i++) { exponentTable[i] = i << 23; } for (let i = 33; i < 63; i++) { exponentTable[i] = 0x80000000 + ((i - 32) << 23); } return exponentTable; } function computeFloat16OffsetTable() { const offsetTable = new Uint32Array(64); for (let i = 0; i < 64; i++) { offsetTable[i] = 1024; } offsetTable[0] = offsetTable[32] = 0; return offsetTable; } function getFloat16Decoder() { const mantisaTable = computeFloat16MantisaTable(); const exponentTable = computeFloat16ExponentTable(); const offsetTable = computeFloat16OffsetTable(); return (quantizedArray) => { const buffer = new ArrayBuffer(4 * quantizedArray.length); const bufferUint32View = new Uint32Array(buffer); for (let index = 0; index < quantizedArray.length; index++) { const float16Bits = quantizedArray[index]; const float32Bits = mantisaTable[offsetTable[float16Bits >> 10] + (float16Bits & 0x3ff)] + exponentTable[float16Bits >> 10]; bufferUint32View[index] = float32Bits; } return new Float32Array(buffer); }; } class IORouterRegistry { constructor() { this.saveRouters = []; this.loadRouters = []; } static getInstance() { if (IORouterRegistry.instance == null) { IORouterRegistry.instance = new IORouterRegistry(); } return IORouterRegistry.instance; } static 
registerSaveRouter(saveRouter) { IORouterRegistry.getInstance().saveRouters.push(saveRouter); } static registerLoadRouter(loadRouter) { IORouterRegistry.getInstance().loadRouters.push(loadRouter); } static getSaveHandlers(url) { return IORouterRegistry.getHandlers(url, 'save'); } static getLoadHandlers(url, loadOptions) { return IORouterRegistry.getHandlers(url, 'load', loadOptions); } static getHandlers(url, handlerType, loadOptions) { const validHandlers = []; const routers = handlerType === 'load' ? IORouterRegistry.getInstance().loadRouters : IORouterRegistry.getInstance().saveRouters; routers.forEach(router => { const handler = router(url, loadOptions); if (handler !== null) { validHandlers.push(handler); } }); return validHandlers; } } const getSaveHandlers = (url) => IORouterRegistry.getSaveHandlers(url); const DATABASE_NAME = 'tensorflowjs'; const DATABASE_VERSION = 1; const MODEL_STORE_NAME = 'models_store'; const INFO_STORE_NAME = 'model_info_store'; function getIndexedDBFactory() { if (!env().getBool('IS_BROWSER')) { throw new Error('Failed to obtain IndexedDB factory because the current environment' + 'is not a web browser.'); } const theWindow = typeof window === 'undefined' ? 
self : window; const factory = theWindow.indexedDB || theWindow.mozIndexedDB || theWindow.webkitIndexedDB || theWindow.msIndexedDB || theWindow.shimIndexedDB; if (factory == null) { throw new Error('The current browser does not appear to support IndexedDB.'); } return factory; } function setUpDatabase(openRequest) { const db = openRequest.result; db.createObjectStore(MODEL_STORE_NAME, { keyPath: 'modelPath' }); db.createObjectStore(INFO_STORE_NAME, { keyPath: 'modelPath' }); } class BrowserIndexedDB { constructor(modelPath) { this.indexedDB = getIndexedDBFactory(); if (modelPath == null || !modelPath) { throw new Error('For IndexedDB, modelPath must not be null, undefined or empty.'); } this.modelPath = modelPath; } async save(modelArtifacts) { if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error('BrowserLocalStorage.save() does not support saving model topology ' + 'in binary formats yet.'); } return this.databaseAction(this.modelPath, modelArtifacts); } async load() { return this.databaseAction(this.modelPath); } databaseAction(modelPath, modelArtifacts) { return new Promise((resolve, reject) => { const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); openRequest.onupgradeneeded = () => setUpDatabase(openRequest); openRequest.onsuccess = () => { const db = openRequest.result; if (modelArtifacts == null) { const modelTx = db.transaction(MODEL_STORE_NAME, 'readonly'); const modelStore = modelTx.objectStore(MODEL_STORE_NAME); const getRequest = modelStore.get(this.modelPath); getRequest.onsuccess = () => { if (getRequest.result == null) { db.close(); return reject(new Error(`Cannot find model with path '${this.modelPath}' ` + `in IndexedDB.`)); } else { resolve(getRequest.result.modelArtifacts); } }; getRequest.onerror = error => { db.close(); return reject(getRequest.error); }; modelTx.oncomplete = () => db.close(); } else { modelArtifacts.weightData = CompositeArrayBuffer.join(modelArtifacts.weightData); const 
modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts); const infoTx = db.transaction(INFO_STORE_NAME, 'readwrite'); let infoStore = infoTx.objectStore(INFO_STORE_NAME); let putInfoRequest; try { putInfoRequest = infoStore.put({ modelPath: this.modelPath, modelArtifactsInfo }); } catch (error) { return reject(error); } let modelTx; putInfoRequest.onsuccess = () => { modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite'); const modelStore = modelTx.objectStore(MODEL_STORE_NAME); let putModelRequest; try { putModelRequest = modelStore.put({ modelPath: this.modelPath, modelArtifacts, modelArtifactsInfo }); } catch (error) { return reject(error); } putModelRequest.onsuccess = () => resolve({ modelArtifactsInfo }); putModelRequest.onerror = error => { infoStore = infoTx.objectStore(INFO_STORE_NAME); const deleteInfoRequest = infoStore.delete(this.modelPath); deleteInfoRequest.onsuccess = () => { db.close(); return reject(putModelRequest.error); }; deleteInfoRequest.onerror = error => { db.close(); return reject(putModelRequest.error); }; }; }; putInfoRequest.onerror = error => { db.close(); return reject(putInfoRequest.error); }; infoTx.oncomplete = () => { if (modelTx == null) { db.close(); } else { modelTx.oncomplete = () => db.close(); } }; } }; openRequest.onerror = error => reject(openRequest.error); }); } } BrowserIndexedDB.URL_SCHEME = 'indexeddb://'; const indexedDBRouter = (url) => { if (!env().getBool('IS_BROWSER')) { return null; } else { if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) { return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length)); } else { return null; } } }; IORouterRegistry.registerSaveRouter(indexedDBRouter); IORouterRegistry.registerLoadRouter(indexedDBRouter); function browserIndexedDB(modelPath) { return new BrowserIndexedDB(modelPath); } const PATH_SEPARATOR = '/'; const PATH_PREFIX = 'tensorflowjs_models'; const INFO_SUFFIX = 'info'; const MODEL_TOPOLOGY_SUFFIX = 'model_topology'; const 
WEIGHT_SPECS_SUFFIX = 'weight_specs'; const WEIGHT_DATA_SUFFIX = 'weight_data'; const MODEL_METADATA_SUFFIX = 'model_metadata';
/**
 * Builds the localStorage keys for the model saved at `path`. Each key has the
 * form `${PATH_PREFIX}/${path}/<suffix>`, e.g. 'tensorflowjs_models/my-model/info'.
 */
function getModelKeys(path) { return { info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR), topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR), weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR), weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR), modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR) }; }
/** Removes every key in `keys` from window.localStorage; BrowserLocalStorage.save() uses it to undo a partial write. */
function removeItems(keys) { for (const key of Object.values(keys)) { window.localStorage.removeItem(key); } }
/**
 * Save/load handler that persists model artifacts in window.localStorage.
 * Info, topology, weight specs and metadata are stored as JSON strings, and
 * weight data as a base64 string. Each part is stored under its own key from
 * getModelKeys().
 */
class BrowserLocalStorage { constructor(modelPath) { if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' || typeof window.localStorage === 'undefined') { throw new Error('The current environment does not support local storage.'); } this.LS = window.localStorage; if (modelPath == null || !modelPath) { throw new Error('For local storage, modelPath must not be null, undefined or empty.'); } this.modelPath = modelPath; this.keys = getModelKeys(this.modelPath); }
/**
 * Writes every part of `modelArtifacts` and resolves to `{modelArtifactsInfo}`.
 * If any setItem call throws (the error text names a size-quota overflow as a
 * possible cause), all keys for this model are removed and a new Error
 * reporting the byte sizes is thrown.
 * NOTE(review): the caught `err` is not attached to the thrown Error;
 * consider preserving it (e.g. as `cause`).
 */
async save(modelArtifacts) { if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error('BrowserLocalStorage.save() does not support saving model topology ' + 'in binary formats yet.'); } else { const topology = JSON.stringify(modelArtifacts.modelTopology); const weightSpecs = JSON.stringify(modelArtifacts.weightSpecs); const modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts); const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData); try { this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo)); this.LS.setItem(this.keys.topology, topology); this.LS.setItem(this.keys.weightSpecs, weightSpecs); this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(weightBuffer)); const metadata = { format: modelArtifacts.format, generatedBy: modelArtifacts.generatedBy, convertedBy: modelArtifacts.convertedBy, 
signature: modelArtifacts.signature != null ? modelArtifacts.signature : undefined, userDefinedMetadata: modelArtifacts.userDefinedMetadata != null ? modelArtifacts.userDefinedMetadata : undefined, modelInitializer: modelArtifacts.modelInitializer != null ? modelArtifacts.modelInitializer : undefined, initializerSignature: modelArtifacts.initializerSignature != null ? modelArtifacts.initializerSignature : undefined, trainingConfig: modelArtifacts.trainingConfig != null ? modelArtifacts.trainingConfig : undefined }; this.LS.setItem(this.keys.modelMetadata, JSON.stringify(metadata)); return { modelArtifactsInfo }; } catch (err) { removeItems(this.keys); throw new Error(`Failed to save model '${this.modelPath}' to local storage: ` + `size quota being exceeded is a possible cause of this failure: ` + `modelTopologyBytes=${modelArtifactsInfo.modelTopologyBytes}, ` + `weightSpecsBytes=${modelArtifactsInfo.weightSpecsBytes}, ` + `weightDataBytes=${modelArtifactsInfo.weightDataBytes}.`); } } }
/**
 * Reads back the parts written by save(). Throws if the info, topology,
 * weight specs or weight data entry is missing, or if the stored topology
 * type is not 'JSON'. The metadata entry is optional.
 */
async load() { const info = JSON.parse(this.LS.getItem(this.keys.info)); if (info == null) { throw new Error(`In local storage, there is no model with name '${this.modelPath}'`); } if (info.modelTopologyType !== 'JSON') { throw new Error('BrowserLocalStorage does not support loading non-JSON model ' + 'topology yet.'); } const out = {}; const topology = JSON.parse(this.LS.getItem(this.keys.topology)); if (topology == null) { throw new Error(`In local storage, the topology of model '${this.modelPath}' ` + `is missing.`); } out.modelTopology = topology; const weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs)); if (weightSpecs == null) { throw new Error(`In local storage, the weight specs of model '${this.modelPath}' ` + `are missing.`); } out.weightSpecs = weightSpecs; const metadataString = this.LS.getItem(this.keys.modelMetadata); if (metadataString != null) { const metadata = JSON.parse(metadataString); out.format = metadata.format; out.generatedBy = 
metadata.generatedBy; out.convertedBy = metadata.convertedBy; if (metadata.signature != null) { out.signature = metadata.signature; } if (metadata.userDefinedMetadata != null) { out.userDefinedMetadata = metadata.userDefinedMetadata; } if (metadata.modelInitializer != null) { out.modelInitializer = metadata.modelInitializer; } if (metadata.initializerSignature != null) { out.initializerSignature = metadata.initializerSignature; } if (metadata.trainingConfig != null) { out.trainingConfig = metadata.trainingConfig; } } const weightDataBase64 = this.LS.getItem(this.keys.weightData); if (weightDataBase64 == null) { throw new Error(`In local storage, the binary weight values of model ` + `'${this.modelPath}' are missing.`); } out.weightData = base64StringToArrayBuffer(weightDataBase64); return out; } } BrowserLocalStorage.URL_SCHEME = 'localstorage://'; const localStorageRouter = (url) => { if (!env().getBool('IS_BROWSER')) { return null; } else { if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) { return browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length)); } else { return null; } } }; IORouterRegistry.registerSaveRouter(localStorageRouter); IORouterRegistry.registerLoadRouter(localStorageRouter); function browserLocalStorage(modelPath) { return new BrowserLocalStorage(modelPath); } const DEFAULT_FILE_NAME_PREFIX = 'model'; const DEFAULT_JSON_EXTENSION_NAME = '.json'; const DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin'; function defer(f) { return new Promise(resolve => setTimeout(resolve)).then(f); } class BrowserDownloads { constructor(fileNamePrefix) { if (!env().getBool('IS_BROWSER')) { throw new Error('browserDownloads() cannot proceed because the current environment ' + 'is not a browser.'); } if (fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) { fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length); } if (fileNamePrefix == null || fileNamePrefix.length === 0) { fileNamePrefix = 
DEFAULT_FILE_NAME_PREFIX; } this.modelJsonFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME; this.weightDataFileName = fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME; } async save(modelArtifacts) { if (typeof (document) === 'undefined') { throw new Error('Browser downloads are not supported in ' + 'this environment since `document` is not present'); } const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData); const weightsURL = window.URL.createObjectURL(new Blob([weightBuffer], { type: 'application/octet-stream' })); if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error('BrowserDownloads.save() does not support saving model topology ' + 'in binary formats yet.'); } else { const weightsManifest = [{ paths: ['./' + this.weightDataFileName], weights: modelArtifacts.weightSpecs }]; const modelJSON = getModelJSONForModelArtifacts(modelArtifacts, weightsManifest); const modelJsonURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelJSON)], { type: 'application/json' })); const jsonAnchor = this.modelJsonAnchor == null ? document.createElement('a') : this.modelJsonAnchor; jsonAnchor.download = this.modelJsonFileName; jsonAnchor.href = modelJsonURL; await defer(() => jsonAnchor.dispatchEvent(new MouseEvent('click'))); if (modelArtifacts.weightData != null) { const weightDataAnchor = this.weightDataAnchor == null ? 
document.createElement('a') : this.weightDataAnchor; weightDataAnchor.download = this.weightDataFileName; weightDataAnchor.href = weightsURL; await defer(() => weightDataAnchor.dispatchEvent(new MouseEvent('click'))); } return { modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts) }; } } } BrowserDownloads.URL_SCHEME = 'downloads://'; const browserDownloadsRouter = (url) => { if (!env().getBool('IS_BROWSER')) { return null; } else { if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) { return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length)); } else { return null; } } }; IORouterRegistry.registerSaveRouter(browserDownloadsRouter); function browserDownloads(fileNamePrefix = 'model') { return new BrowserDownloads(fileNamePrefix); } class PassthroughLoader { constructor(modelArtifacts) { this.modelArtifacts = modelArtifacts; } load() { return this.modelArtifacts; } } class PassthroughSaver { constructor(saveHandler) { this.saveHandler = saveHandler; } save(modelArtifacts) { return this.saveHandler(modelArtifacts); } } class PassthroughAsync { constructor(handler) { if (handler.load) { this.load = () => Promise.resolve(handler.load()); } if (handler.save) { this.save = (modelArtifacts) => Promise.resolve(handler.save(modelArtifacts)); } } } function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) { const args = arguments; return new PassthroughAsync(fromMemorySync(...args)); } function fromMemorySync(modelArtifacts, weightSpecs, weightData, trainingConfig) { if (arguments.length === 1) { const isModelArtifacts = modelArtifacts.modelTopology != null || modelArtifacts.weightSpecs != null; if (isModelArtifacts) { return new PassthroughLoader(modelArtifacts); } else { console.warn('Please call tf.io.fromMemory() with only one argument. ' + 'The argument should be of type ModelArtifacts. 
' + 'The multi-argument signature of tf.io.fromMemory() has been ' + 'deprecated and will be removed in a future release.'); return new PassthroughLoader({ modelTopology: modelArtifacts }); } } else { console.warn('Please call tf.io.fromMemory() with only one argument. ' + 'The argument should be of type ModelArtifacts. ' + 'The multi-argument signature of tf.io.fromMemory() has been ' + 'deprecated and will be removed in a future release.'); return new PassthroughLoader({ modelTopology: modelArtifacts, weightSpecs, weightData, trainingConfig }); } } function withSaveHandler(saveHandler) { return new PassthroughSaver(saveHandler); } function prepareAndValidate(tensor, indices) { const tensorRank = tensor.shape.length; const indicesRank = indices.shape.length; if (tensorRank < 1) { throw new Error('tf.gatherND() expects the input to be rank 1 or higher,' + ` but the rank was ${tensorRank}.`); } if (indicesRank < 1) { throw new Error('tf.gatherND() expects the indices to be rank 1 or higher,' + ` but the rank was ${indicesRank}.`); } if (indices.dtype !== 'int32') { throw new Error('tf.gatherND() expects the indices to be int32 type,' + ` but the dtype was ${indices.dtype}.`); } if (indices.shape[indicesRank - 1] > tensorRank) { throw new Error('index innermost dimension length must be <= tensor rank; saw: ' + `${indices.shape[indicesRank - 1]} vs. ${tensorRank}`); } if (sizeFromShape(tensor.shape) === 0) { throw new Error('Requested more than 0 entries, but input is empty.' 
+ ` Input shape: ${tensor.shape}.`); } const indicesShape = indices.shape; const sliceRank = indicesShape[indicesShape.length - 1]; let nResult = 1; for (let i = 0; i < indicesShape.length - 1; ++i) { nResult *= indicesShape[i]; } const inputShape = tensor.shape; const resultShape = indicesShape.slice(); resultShape.pop(); let sliceSize = 1; for (let i = sliceRank; i < tensorRank; ++i) { sliceSize *= inputShape[i]; resultShape.push(inputShape[i]); } const strides = [...computeStrides(tensor.shape).map(stride => stride / sliceSize), 1].slice(0, sliceRank); return [resultShape, nResult, sliceSize, strides]; } const NEW_AXIS = -2; const SHRINK_AXIS = -1; function assertParamsValid(input, begin, size) { const inputRank = input.shape.length; assert$1(inputRank === begin.length, () => `Error in slice${inputRank}D: Length of begin ${begin} must ` + `match the rank of the array (${inputRank}).`); assert$1(inputRank === size.length, () => `Error in slice${inputRank}D: Length of size ${size} must ` + `match the rank of the array (${inputRank}).`); for (let i = 0; i < inputRank; ++i) { assert$1(begin[i] + size[i] <= input.shape[i], () => `Error in slice${inputRank}D: begin[${i}] + size[${i}] ` + `(${begin[i] + size[i]}) would overflow input.shape[${i}] (${input.shape[i]})`); } } function maskToAxes(mask) { const axes = []; let axis = 0; while (mask > 0) { if (mask & 1) { axes.push(axis); } mask /= 2; axis++; } return axes; } function computeOutShape$2(begin, end, strides) { const size = []; for (let axis = 0; axis < begin.length; axis++) { size[axis] = Math.ceil((end[axis] - begin[axis]) / strides[axis]); } return size; } function stridesWithElidedDims(strides, ellipsisInsertionIndex, numElidedAxes, inputShape) { const newStrides = [...strides]; for (let i = newStrides.length; i < inputShape.length; i++) { newStrides.push(1); } for (let i = 0; i < numElidedAxes; i++) { if (i === 0) { newStrides[ellipsisInsertionIndex] = 1; } else { newStrides.splice(ellipsisInsertionIndex, 0 
, 1 ); newStrides.pop(); } } return newStrides; }
/**
 * Maps an axis of the ellipsis-expanded (normalized) shape back to its index
 * in the caller's original begin/end arrays. Axes up to ellipsisInsertionIndex
 * are unchanged; later axes shift left by numElidedAxes - 1.
 */
function unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, normalizedAxis) { if (normalizedAxis <= ellipsisInsertionIndex) { return normalizedAxis; } return normalizedAxis - (numElidedAxes - 1); }
/** Returns the numElidedAxes consecutive axis indices starting at ellipsisInsertionIndex. */
function getElidedAxes(numElidedAxes, ellipsisInsertionIndex) { const elidedAxes = []; for (let i = 0; i < numElidedAxes; i++) { elidedAxes.push(ellipsisInsertionIndex + i); } return elidedAxes; }
/**
 * Normalizes begin/end/strides to one entry per input dimension.
 * If ellipsisAxes is non-empty and numInterpolatedAxes > 0, the first
 * ellipsis axis is expanded to numInterpolatedAxes + 1 axes via the
 * *WithElidedDims helpers. Otherwise each axis is resolved independently with
 * startForAxis / stopForAxis / stridesForAxis.
 * @returns {{begin: number[], end: number[], strides: number[]}}
 */
function getNormalizedAxes(inputShape, ellipsisAxes, numInterpolatedAxes, begin, end, strides, beginMask, endMask, ellipsisMask) { const inputRank = inputShape.length; let normalizedBegin = new Array(inputRank), normalizedEnd = new Array(inputRank), normalizedStrides = new Array(inputRank); if (ellipsisAxes.length && numInterpolatedAxes > 0) { const fullIndex = ellipsisAxes[0]; const numElidedAxes = numInterpolatedAxes + 1; normalizedBegin = startIndicesWithElidedDims(beginMask, fullIndex, numElidedAxes, begin, inputShape); normalizedEnd = stopIndicesWithElidedDims(endMask, fullIndex, numElidedAxes, end, inputShape); normalizedStrides = stridesWithElidedDims(strides, fullIndex, numElidedAxes, inputShape); } else { for (let axis = 0; axis < inputRank; axis++) { normalizedBegin[axis] = startForAxis(beginMask, begin, strides, inputShape, axis, ellipsisMask); normalizedEnd[axis] = stopForAxis(endMask, end, strides, inputShape, axis, ellipsisMask); normalizedStrides[axis] = stridesForAxis(strides, axis, ellipsisMask); } } return { begin: normalizedBegin, end: normalizedEnd, strides: normalizedStrides }; }
/**
 * Begin indices after ellipsis expansion. Elided axes start at 0. Other axes
 * take the original begin value, or 0 when that original axis's beginMask bit
 * is set.
 * NOTE(review): unlike stopIndicesWithElidedDims, negative values are neither
 * wrapped nor clamped here; confirm that callers handle this.
 */
function startIndicesWithElidedDims(beginMask, ellipsisInsertionIndex, numElidedAxes, originalBegin, inputShape) { const newIndices = [...inputShape]; const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex); for (let axis = 0; axis < newIndices.length; axis++) { if (elidedAxes.indexOf(axis) > -1) { newIndices[axis] = 0; } else { const originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis); 
let originalValue = originalBegin[originalAxis]; if (beginMask & 1 << originalAxis) { originalValue = 0; } newIndices[axis] = originalValue; } } return newIndices; }
/**
 * End indices after ellipsis expansion. Elided and end-masked axes use
 * Number.MAX_SAFE_INTEGER. Negative values wrap once by adding the axis size,
 * and every value is then clamped to [0, inputShape[i]].
 */
function stopIndicesWithElidedDims(endMask, ellipsisInsertionIndex, numElidedAxes, originalEnd, inputShape) { const newIndices = [...inputShape]; const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex); for (let axis = 0; axis < newIndices.length; axis++) { if (elidedAxes.indexOf(axis) > -1) { newIndices[axis] = Number.MAX_SAFE_INTEGER; } else { const originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis); let originalValue = originalEnd[originalAxis]; if (endMask & 1 << originalAxis) { originalValue = Number.MAX_SAFE_INTEGER; } newIndices[axis] = originalValue; } } for (let i = 0; i < newIndices.length; i++) { const axisSize = inputShape[i]; if (newIndices[i] < 0) { newIndices[i] += axisSize; } newIndices[i] = clamp(0, newIndices[i], inputShape[i]); } return newIndices; }
/** Stride for `axis`; 1 when the axis's ellipsisMask bit is set or no stride was given. */
function stridesForAxis(strides, axis, ellipsisMask) { let stride = strides[axis]; if (ellipsisMask & (1 << axis) || stride == null) { stride = 1; } return stride; }
/**
 * Resolved begin index for one axis. A masked, ellipsis-covered or missing
 * begin becomes the first element in the stride direction: 0 for positive
 * strides, the last element for negative strides. A negative begin wraps once.
 * The result is clamped to [0, axisSize - 1].
 */
function startForAxis(beginMask, startIndices, strides, inputShape, axis, ellipsisMask) { let start = startIndices[axis]; const stride = strides[axis] || 1; if (beginMask & 1 << axis || ellipsisMask & 1 << axis || start == null) { if (stride > 0) { start = Number.MIN_SAFE_INTEGER; } else { start = Number.MAX_SAFE_INTEGER; } } const axisSize = inputShape[axis]; if (start < 0) { start += axisSize; } start = clamp(0, start, axisSize - 1); return start; }
/**
 * Resolved (exclusive) end index for one axis. A masked, ellipsis-covered or
 * missing end runs to the far end in the stride direction. A negative end
 * wraps once. The result is clamped to [0, axisSize] for positive strides and
 * to [-1, axisSize - 1] for negative strides, so a reverse slice can include
 * element 0.
 */
function stopForAxis(endMask, stopIndices, strides, inputShape, axis, ellipsisMask) { let stop = stopIndices[axis]; const stride = strides[axis] || 1; if (endMask & (1 << axis) || ellipsisMask & (1 << axis) || stop == null) { if (stride > 0) { stop = Number.MAX_SAFE_INTEGER; } else { stop = Number.MIN_SAFE_INTEGER; } } const axisSize = inputShape[axis]; if (stop < 0) { stop += axisSize; 
} if (stride > 0) { stop = clamp(0, stop, axisSize); } else { stop = clamp(-1, stop, axisSize - 1); } return stop; }
/**
 * True when slicing `shape` at `begin` with `size` selects one contiguous run,
 * assuming row-major layout. Every axis after the first axis with size > 1
 * must be taken whole: begin 0 and size equal to the full extent.
 */
function isSliceContinous(shape, begin, size) { let firstNonOneAxis = size.length; for (let i = 0; i < size.length; i++) { if (size[i] > 1) { firstNonOneAxis = i; break; } } for (let i = firstNonOneAxis + 1; i < size.length; i++) { if (begin[i] > 0 || size[i] !== shape[i]) { return false; } } return true; }
/**
 * Flat element offset of `begin`: the last coordinate plus begin[i] * strides[i]
 * for each leading axis.
 * NOTE(review): returns 1 (not 0) when `begin` is empty; confirm this is intended.
 */
function computeFlatOffset(begin, strides) { let flatOffset = begin.length > 0 ? begin[begin.length - 1] : 1; for (let i = 0; i < begin.length - 1; i++) { flatOffset += begin[i] * strides[i]; } return flatOffset; }
/**
 * Expands slice() begin/size arguments to full-rank arrays.
 * A number or a short array is padded: begin with 0, size with -1. A size of
 * -1 means "to the end of the axis" (x.shape[i] - begin[i]).
 * Asserts that no begin entry is -1 and that no size is negative other than -1.
 * @returns {[number[], number[]]} [begin, size]
 */
function parseSliceParams(x, begin, size) { let begin_; const xRank = x.shape.length; if (typeof begin === 'number') { begin_ = [begin, ...new Array(xRank - 1).fill(0)]; } else if (begin.length < xRank) { begin_ = begin.concat(new Array(xRank - begin.length).fill(0)); } else { begin_ = begin.slice(); } begin_.forEach(d => { assert$1(d !== -1, () => 'slice() does not support negative begin indexing.'); }); let size_; if (size == null) { size_ = new Array(xRank).fill(-1); } else if (typeof size === 'number') { size_ = [size, ...new Array(xRank - 1).fill(-1)]; } else if (size.length < xRank) { size_ = size.concat(new Array(xRank - size.length).fill(-1)); } else { size_ = size; } size_ = size_.map((d, i) => { if (d >= 0) { return d; } else { assert$1(d === -1, () => `Negative size values should be exactly -1 but got ` + `${d} for the slice() size at index ${i}.`); return x.shape[i] - begin_[i]; } }); return [begin_, size_]; }
/**
 * Computes strided-slice metadata for an input of shape `xShape`.
 * Builds a sparse spec from the user's begin/end/strides and masks, expands it
 * to one entry per input dimension (buildDenseSpec), canonicalizes each
 * dimension's begin/end, and derives the output shape.
 * @returns {{finalShapeSparse: number[], finalShape: number[], isIdentity: boolean,
 *     sliceDim0: boolean, isSimpleSlice: boolean, begin: number[], end: number[],
 *     strides: number[]}}
 * @throws {Error} On multiple ellipses, a zero stride, a non-positive stride on
 *     a shrink axis, or an out-of-bounds shrink index.
 */
function sliceInfo(xShape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) { let stridesNonNull; if (strides == null) { stridesNonNull = new Array(begin.length); stridesNonNull.fill(1); } else { stridesNonNull = strides; } if (ellipsisMask != null && (ellipsisMask & (ellipsisMask - 1)) !== 0) { throw new Error('Multiple ellipses in slice is not allowed.'); } let ellipsisSeen = false; 
const sparseSpec = { dims: stridesNonNull.length, numAddAxisAfterEllipsis: 0, begin: begin.slice(), end: end.slice(), strides: stridesNonNull.slice(), beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask }; for (let i = 0; i < sparseSpec.dims; i++) { if (ellipsisSeen && ((1 << i) & newAxisMask) !== 0) { sparseSpec.numAddAxisAfterEllipsis++; } if ((1 << i) & ellipsisMask) { ellipsisSeen = true; } }
/* No explicit ellipsis: append an implicit one after the last sparse dim, so unspecified trailing dimensions are taken whole. */
if (!ellipsisSeen) { sparseSpec.ellipsisMask |= (1 << sparseSpec.dims); sparseSpec.dims++; } const denseSpec = { dims: xShape.length, beginMask: 0, endMask: 0, beginValid: false, endValid: false }; buildDenseSpec(sparseSpec, denseSpec); let isIdentity = true; let sliceDim0 = true; let isSimpleSlice = true; const processingShape = []; const finalShape = []; for (let i = 0; i < xShape.length; ++i) { if (denseSpec.strides[i] === 0) { throw Error(`strides[${i}] must be non-zero`); } const shrinkI = !!(denseSpec.shrinkAxisMask & (1 << i)); const dimI = xShape[i]; if (dimI === -1) { processingShape.push(shrinkI ? 1 : -1); continue; } const masks = [denseSpec.beginMask & (1 << i), denseSpec.endMask & (1 << i)]; const validRange = [ denseSpec.strides[i] > 0 ? 0 : -1, denseSpec.strides[i] > 0 ? dimI : dimI - 1 ]; if (shrinkI && denseSpec.strides[i] <= 0) { throw Error('only stride 1 allowed on non-range indexing.'); } isSimpleSlice = isSimpleSlice && (denseSpec.strides[i] === 1); const beginAndEndMasked = !!((denseSpec.beginMask & (1 << i)) && (denseSpec.endMask & (1 << i))); if (denseSpec.beginValid && denseSpec.endValid) { if (shrinkI) { const xFwd = denseSpec.begin[i] < 0 ? 
dimI + denseSpec.begin[i] : denseSpec.begin[i]; denseSpec.begin[i] = xFwd; denseSpec.end[i] = denseSpec.begin[i] + 1; if (xFwd < 0 || xFwd >= dimI) { throw Error(`slice index ${denseSpec.begin[i]} of dimension ${i} out of bounds.`); } } else { denseSpec.begin[i] = canonical(denseSpec.begin[i], 0, denseSpec.strides[i], dimI, masks, validRange); denseSpec.end[i] = canonical(denseSpec.end[i], 1, denseSpec.strides[i], dimI, masks, validRange); } const takeAllInDimension = denseSpec.strides[i] === 1 && denseSpec.begin[i] === 0 && denseSpec.end[i] === dimI; isIdentity = isIdentity && takeAllInDimension; sliceDim0 = sliceDim0 && ((i === 0 && denseSpec.strides[i] === 1) || takeAllInDimension); } else { isIdentity = isIdentity && ((denseSpec.strides[i] === 1) && beginAndEndMasked); sliceDim0 = sliceDim0 && ((i === 0 && denseSpec.strides[i] === 1) || beginAndEndMasked); } let intervalLength; let knownInterval = false; if (denseSpec.beginValid && denseSpec.endValid) { intervalLength = denseSpec.end[i] - denseSpec.begin[i]; knownInterval = true; } else if (shrinkI) { intervalLength = 1; knownInterval = true; } else if (beginAndEndMasked) { if (dimI >= 0) { if (denseSpec.strides[i] < 0) { intervalLength = -dimI; } else { intervalLength = dimI; } knownInterval = true; } } if (knownInterval) { let sizeI; if (intervalLength === 0 || ((intervalLength < 0) !== (denseSpec.strides[i] < 0))) { sizeI = 0; } else { sizeI = Math.trunc(intervalLength / denseSpec.strides[i]) + (intervalLength % denseSpec.strides[i] !== 0 ? 
1 : 0); } processingShape.push(sizeI); } else { processingShape.push(-1); } }
/* Map dense dimensions to the output shape: shrunk axes (SHRINK_AXIS) are dropped and new axes (NEW_AXIS) contribute 1. */
for (let denseDim = 0; denseDim < denseSpec.finalShapeGatherIndices.length; ++denseDim) { const gatherIndex = denseSpec.finalShapeGatherIndices[denseDim]; if (gatherIndex >= 0) { finalShape.push(processingShape[gatherIndex]); } else if (gatherIndex === NEW_AXIS) { finalShape.push(1); } }
/* NOTE(review): finalShape omits SHRINK_AXIS entries, so index i below may not line up with finalShapeGatherIndices[i] when a shrink axis precedes a new axis; confirm. */
const finalShapeSparse = finalShape.filter((dim, i) => denseSpec.finalShapeGatherIndices[i] !== NEW_AXIS); return { finalShapeSparse, finalShape, isIdentity, sliceDim0, isSimpleSlice, begin: denseSpec.begin, end: denseSpec.end, strides: denseSpec.strides }; }
/**
 * Expands the sparse spec into `dense`, which gets one begin/end/stride entry
 * per input dimension. The sparse spec holds user indices plus ellipsis,
 * new-axis and shrink masks. An ellipsis fills the dimensions it covers with
 * begin/end masked and stride 1. finalShapeGatherIndices records, per output
 * entry, the dense dimension index or the NEW_AXIS / SHRINK_AXIS marker.
 * Mutates `dense` in place.
 * @throws {Error} If the sparse spec addresses more dimensions than the input has.
 */
function buildDenseSpec(sparse, dense) { dense.beginMask = 0; dense.endMask = 0; dense.shrinkAxisMask = 0; let fullIndex = 0; dense.beginValid = sparse.begin != null; dense.endValid = sparse.end != null; dense.begin = new Array(dense.dims); dense.end = new Array(dense.dims); dense.strides = new Array(dense.dims); dense.finalShapeGatherIndices = []; dense.finalShapeGatherIndicesSparse = []; dense.inputShapeGatherIndicesSparse = new Array(dense.dims); for (let i = 0; i < sparse.dims; i++) { if ((1 << i) & sparse.ellipsisMask) { const nextIndex = Math.min(dense.dims - (sparse.dims - i) + 1 + sparse.numAddAxisAfterEllipsis, dense.dims); for (; fullIndex < nextIndex; fullIndex++) { dense.begin[fullIndex] = 0; dense.end[fullIndex] = 0; dense.strides[fullIndex] = 1; dense.beginMask |= (1 << fullIndex); dense.endMask |= (1 << fullIndex); dense.finalShapeGatherIndices.push(fullIndex); dense.finalShapeGatherIndicesSparse.push(-1); dense.inputShapeGatherIndicesSparse[fullIndex] = i; } } else if ((1 << i) & sparse.newAxisMask) { dense.finalShapeGatherIndices.push(NEW_AXIS); dense.finalShapeGatherIndicesSparse.push(-1); } else { if (fullIndex === dense.begin.length) { throw Error(`Index out of range using input dim ${fullIndex}; input ` + `has only ${dense.dims} dims, ${dense.begin.length}.`); } if (sparse.begin != null) { dense.begin[fullIndex] = 
sparse.begin[i]; } if (sparse.end != null) { dense.end[fullIndex] = sparse.end[i]; } dense.strides[fullIndex] = sparse.strides[i]; if (sparse.beginMask & (1 << i)) { dense.beginMask |= (1 << fullIndex); } if (sparse.endMask & (1 << i)) { dense.endMask |= (1 << fullIndex); } if (sparse.shrinkAxisMask & (1 << i)) { dense.finalShapeGatherIndices.push(SHRINK_AXIS); dense.finalShapeGatherIndicesSparse.push(-1); dense.shrinkAxisMask |= (1 << fullIndex); } else { dense.finalShapeGatherIndices.push(fullIndex); dense.finalShapeGatherIndicesSparse.push(i); } dense.inputShapeGatherIndicesSparse[fullIndex] = i; fullIndex++; } } }
/**
 * Canonicalizes a begin (c = 0) or end (c = 1) index for one dimension.
 * When masked, returns the validRange extreme for the stride direction.
 * Otherwise a negative x wraps once by adding dimI, and the result is clamped
 * to validRange.
 */
function canonical(x, c, strideI, dimI, masks, validRange) { if (masks[c]) { return strideI > 0 ? validRange[c] : validRange[(c + 1) & 1]; } else { const xFwd = x < 0 ? dimI + x : x; return xFwd < validRange[0] ? validRange[0] : xFwd > validRange[1] ? validRange[1] : xFwd; } }
/** Frozen namespace object exposing the slice helpers. */
var slice_util = Object.freeze({ __proto__: null, assertParamsValid: assertParamsValid, computeFlatOffset: computeFlatOffset, computeOutShape: computeOutShape$2, getNormalizedAxes: getNormalizedAxes, isSliceContinous: isSliceContinous, maskToAxes: maskToAxes, parseSliceParams: parseSliceParams, sliceInfo: sliceInfo, startForAxis: startForAxis, startIndicesWithElidedDims: startIndicesWithElidedDims, stopForAxis: stopForAxis, stopIndicesWithElidedDims: stopIndicesWithElidedDims, stridesForAxis: stridesForAxis, stridesWithElidedDims: stridesWithElidedDims });
/** Static factory methods for the built-in optimizers; each returns a new optimizer instance with the given hyperparameters. */
class OptimizerConstructors { static sgd(learningRate) { return new SGDOptimizer(learningRate); } static momentum(learningRate, momentum, useNesterov = false) { return new MomentumOptimizer(learningRate, momentum, useNesterov); } static rmsprop(learningRate, decay = .9, momentum = 0.0, epsilon = null, centered = false) { return new RMSPropOptimizer(learningRate, decay, momentum, epsilon, centered); } static adam(learningRate = 0.001, beta1 = 0.9, beta2 = 0.999, epsilon = null) { return new AdamOptimizer(learningRate, beta1, 
beta2, epsilon); } static adadelta(learningRate = .001, rho = .95, epsilon = null) { return new AdadeltaOptimizer(learningRate, rho, epsilon); } static adamax(learningRate = 0.002, beta1 = 0.9, beta2 = 0.999, epsilon = null, decay = 0.0) { return new AdamaxOptimizer(learningRate, beta1, beta2, epsilon, decay); } static adagrad(learningRate, initialAccumulatorValue = 0.1) { return new AdagradOptimizer(learningRate, initialAccumulatorValue); } } const train = OptimizerConstructors; const delayCallback = (() => { if (typeof requestAnimationFrame !== 'undefined') { return requestAnimationFrame; } else if (typeof setImmediate !== 'undefined') { return setImmediate; } return (f) => f(); })(); function nextFrame() { return new Promise(resolve => delayCallback(() => resolve())); } function assertParamsConsistent(shapes, axis) { const rank = shapes[0].length; shapes.forEach((shape, i) => { assert$1(shape.length === rank, () => `Error in concat${rank}D: rank of tensors[${i}] must be the same ` + `as the rank of the rest (${rank})`); }); assert$1(axis >= 0 && axis < rank, () => `Error in concat${rank}D: axis must be between 0 and ${rank - 1}.`); const firstShape = shapes[0]; shapes.forEach((shape, i) => { for (let r = 0; r < rank; r++) { assert$1((r === axis) || (shape[r] === firstShape[r]), () => `Error in concat${rank}D: Shape of tensors[${i}] (${shape}) ` + `does not match the shape of the rest (${firstShape}) ` + `along the non-concatenated axis ${i}.`); } }); } function computeOutShape$1(shapes, axis) { const outputShape = shapes[0].slice(); for (let i = 1; i < shapes.length; i++) { outputShape[axis] += shapes[i][axis]; } return outputShape; } var RowPartitionType$1; (function (RowPartitionType) { RowPartitionType[RowPartitionType["FIRST_DIM_SIZE"] = 0] = "FIRST_DIM_SIZE"; RowPartitionType[RowPartitionType["VALUE_ROWIDS"] = 1] = "VALUE_ROWIDS"; RowPartitionType[RowPartitionType["ROW_LENGTHS"] = 2] = "ROW_LENGTHS"; RowPartitionType[RowPartitionType["ROW_SPLITS"] = 3] = 
"ROW_SPLITS"; RowPartitionType[RowPartitionType["ROW_LIMITS"] = 4] = "ROW_LIMITS"; RowPartitionType[RowPartitionType["ROW_STARTS"] = 5] = "ROW_STARTS"; })(RowPartitionType$1 || (RowPartitionType$1 = {})); function combineRaggedTensorToTensorShapes(raggedRank, shape, valueShape) { let outputShape = new Array(); if (valueShape == null && shape == null) { return outputShape; } if (shape == null) { while (outputShape.length < raggedRank + valueShape.length) { outputShape.push(-1); } } else { outputShape = shape.slice(); } if (valueShape == null) { return outputShape; } if (raggedRank + valueShape.length !== outputShape.length) { throw new Error(`rt input.shape and shape=${shape} are incompatible: rt input.rank = ${raggedRank + valueShape.length}, but shape.rank = ${outputShape.length}`); } for (let i = 1; i < valueShape.length; ++i) { const valueDim = valueShape[i]; const outputShapeDimIndex = outputShape[outputShape.length - valueShape.length + i]; const outputShapeDim = outputShape[outputShapeDimIndex]; if (valueDim >= 0) { if (outputShapeDim >= 0) { if (outputShapeDim !== valueDim) { throw new Error(`rt input.shape and shape=${shape} are incompatible: rt input.shape[${i + raggedRank}] = ${valueDim} but shape[${i + raggedRank}] = ${outputShapeDim}`); } } else { outputShape[outputShapeDimIndex] = valueDim; } } } return outputShape; } function getRowPartitionTypesHelper(rowPartitionTypeStrings) { const stringToType = { 'FIRST_DIM_SIZE': RowPartitionType$1.FIRST_DIM_SIZE, 'VALUE_ROWIDS': RowPartitionType$1.VALUE_ROWIDS, 'ROW_LENGTHS': RowPartitionType$1.ROW_LENGTHS, 'ROW_SPLITS': RowPartitionType$1.ROW_SPLITS, 'ROW_LIMITS': RowPartitionType$1.ROW_LIMITS, 'ROW_STARTS': RowPartitionType$1.ROW_STARTS }; const result = []; for (const typeStr of rowPartitionTypeStrings) { if (typeStr in stringToType) { result.push(stringToType[typeStr]); } else { break; } } return result; } function getRaggedRank(rowPartitionTypes) { if (rowPartitionTypes.length === 0) { return 0; } if 
(rowPartitionTypes[0] === RowPartitionType$1.FIRST_DIM_SIZE) { return rowPartitionTypes.length - 1; } return rowPartitionTypes.length; }
/**
 * Checks that a ragged-to-dense default value shape is compatible with the
 * flat values shape. No-op when either shape is null. Throws unless
 * defaultValueShape has a lower rank than valueShape, and every known (>= 0)
 * defaultValueShape[i] is 1 or equals a known valueShape[i + 1].
 */
function validateDefaultValueShape(defaultValueShape, valueShape) { if (defaultValueShape == null || valueShape == null) { return; } const defaultNDims = defaultValueShape.length; const valuesNDims = valueShape.length; if (defaultNDims >= valuesNDims) { throw new Error(`defaultValue.shape=${defaultValueShape} and ragged tensor flatValues.shape=${valueShape}, are incompatible: defaultValue.rank = ${defaultNDims} must be less than ragged tensor input flatValues.rank = ${valuesNDims})`); } for (let i = 0; i < Math.min(defaultNDims, valuesNDims - 1); ++i) { const defaultDim = defaultValueShape[i]; const valueDim = valueShape[i + 1]; if (defaultDim >= 0 && valueDim >= 0 && defaultDim !== 1 && defaultDim !== valueDim) { throw new Error(`defaultValue.shape=${defaultValueShape}, and ragged tensor input flatValues.shape=${valueShape} are incompatible: defaultValue.shape[${i - defaultValueShape.length}] = ${defaultDim} but ragged tensor input.flatValues.shape[${i - defaultValueShape.length}] = ${valueDim}`); } } }
/** Sizes at or below this threshold use a single window covering the whole input (see computeOptimalWindowSize). */
const PARALLELIZE_THRESHOLD = 30;
/**
 * Window size for `inSize` elements: inSize itself when inSize <= PARALLELIZE_THRESHOLD,
 * otherwise nearestDivisor(inSize, floor(sqrt(inSize))).
 * nearestDivisor is defined elsewhere; presumably it returns a divisor of
 * inSize close to the given value. Confirm.
 */
function computeOptimalWindowSize(inSize) { if (inSize <= PARALLELIZE_THRESHOLD) { return inSize; } return nearestDivisor(inSize, Math.floor(Math.sqrt(inSize))); }
/**
 * Converts a center given as a fraction of the image size into pixel
 * coordinates [x, y]. A single number is used for both axes.
 */
function getImageCenter(center, imageHeight, imageWidth) { const centerX = imageWidth * (typeof center === 'number' ? center : center[0]); const centerY = imageHeight * (typeof center === 'number' ? 
center : center[1]); return [centerX, centerY]; }
/**
 * Intermediate shape for the batch/space reshapes.
 * With batchToSpace: the block dims, then inputShape[0] / prod, then the
 * remaining input dims.
 * Otherwise: the batch dim, then each spatial dim i split into
 * [inputShape[i + 1] / blockShape[i], blockShape[i]], then the remaining dims.
 * `prod` is presumably the product of blockShape; confirm at call sites.
 */
function getReshaped(inputShape, blockShape, prod, batchToSpace = true) { let reshaped = []; if (batchToSpace) { reshaped = reshaped.concat(blockShape.slice(0)); reshaped.push(inputShape[0] / prod); reshaped = reshaped.concat(inputShape.slice(1)); } else { reshaped = reshaped.concat(inputShape[0]); const spatialLength = blockShape.length; for (let i = 0; i < spatialLength; ++i) { reshaped = reshaped.concat([inputShape[i + 1] / blockShape[i], blockShape[i]]); } reshaped = reshaped.concat(inputShape.slice(spatialLength + 1)); } return reshaped; }
/**
 * Transpose permutation applied to the getReshaped() layout.
 * With batchToSpace: the batch axis (index blockShapeRank) first, then each
 * spatial axis paired with its block axis, then the remaining axes.
 * Otherwise: the block axes (even indices 2..2 * blockShapeRank), then batch
 * axis 0, then the odd-indexed and trailing axes.
 */
function getPermuted(reshapedRank, blockShapeRank, batchToSpace = true) { const permuted = []; if (batchToSpace) { permuted.push(blockShapeRank); for (let i = blockShapeRank + 1; i < reshapedRank; ++i) { if (i <= 2 * blockShapeRank) { permuted.push(i); permuted.push(i - (blockShapeRank + 1)); } else { permuted.push(i); } } } else { const permutedBeforeBatch = []; const permutedAfterBatch = []; for (let i = 1; i < reshapedRank; ++i) { if (i >= blockShapeRank * 2 + 1 || i % 2 === 1) { permutedAfterBatch.push(i); } else { permutedBeforeBatch.push(i); } } permuted.push(...permutedBeforeBatch); permuted.push(0); permuted.push(...permutedAfterBatch); } return permuted; }
/**
 * Shape after the permutation. The batch dim is divided by prod
 * (batchToSpace) or multiplied by it (otherwise). Each of the first
 * blockShape.length spatial dims is multiplied by (batchToSpace) or divided by
 * (otherwise) its block size. Other dims are unchanged.
 */
function getReshapedPermuted(inputShape, blockShape, prod, batchToSpace = true) { const reshapedPermuted = []; if (batchToSpace) { reshapedPermuted.push(inputShape[0] / prod); } else { reshapedPermuted.push(inputShape[0] * prod); } for (let i = 1; i < inputShape.length; ++i) { if (i <= blockShape.length) { if (batchToSpace) { reshapedPermuted.push(blockShape[i - 1] * inputShape[i]); } else { reshapedPermuted.push(inputShape[i] / blockShape[i - 1]); } } else { reshapedPermuted.push(inputShape[i]); } } return reshapedPermuted; }
/**
 * Crop begin coordinates: 0 for the batch dim, then crops[i][0] for each of
 * the first `blockShape` spatial dims.
 * NOTE: `blockShape` is used as a count (a number) here, not an array.
 */
function getSliceBeginCoords(crops, blockShape) { const sliceBeginCoords = [0]; for (let i = 0; i < blockShape; ++i) { sliceBeginCoords.push(crops[i][0]); } return sliceBeginCoords; } 
function getSliceSize(uncroppedShape, crops, blockShape) { const sliceSize = uncroppedShape.slice(0, 1); for (let i = 0; i < blockShape; ++i) { sliceSize.push(uncroppedShape[i + 1] - crops[i][0] - crops[i][1]); } return sliceSize; } const SELU_SCALEALPHA = 1.7580993408473768599402175208123; const SELU_SCALE = 1.0507009873554804934193349852946; const ERF_P = 0.3275911; const ERF_A1 = 0.254829592; const ERF_A2 = -0.284496736; const ERF_A3 = 1.421413741; const ERF_A4 = -1.453152027; const ERF_A5 = 1.061405429; function mergeRealAndImagArrays(real, imag) { if (real.length !== imag.length) { throw new Error(`Cannot merge real and imag arrays of different lengths. real:` + `${real.length}, imag: ${imag.length}.`); } const result = new Float32Array(real.length * 2); for (let i = 0; i < result.length; i += 2) { result[i] = real[i / 2]; result[i + 1] = imag[i / 2]; } return result; } function splitRealAndImagArrays(complex) { const real = new Float32Array(complex.length / 2); const imag = new Float32Array(complex.length / 2); for (let i = 0; i < complex.length; i += 2) { real[i / 2] = complex[i]; imag[i / 2] = complex[i + 1]; } return { real, imag }; } function complexWithEvenIndex(complex) { const len = Math.ceil(complex.length / 4); const real = new Float32Array(len); const imag = new Float32Array(len); for (let i = 0; i < complex.length; i += 4) { real[Math.floor(i / 4)] = complex[i]; imag[Math.floor(i / 4)] = complex[i + 1]; } return { real, imag }; } function complexWithOddIndex(complex) { const len = Math.floor(complex.length / 4); const real = new Float32Array(len); const imag = new Float32Array(len); for (let i = 2; i < complex.length; i += 4) { real[Math.floor(i / 4)] = complex[i]; imag[Math.floor(i / 4)] = complex[i + 1]; } return { real, imag }; } function getComplexWithIndex(complex, index) { const real = complex[index * 2]; const imag = complex[index * 2 + 1]; return { real, imag }; } function assignToTypedArray(data, real, imag, index) { data[index * 2] = 
real; data[index * 2 + 1] = imag; } function exponents(n, inverse) { const real = new Float32Array(n / 2); const imag = new Float32Array(n / 2); for (let i = 0; i < Math.ceil(n / 2); i++) { const x = (inverse ? 2 : -2) * Math.PI * (i / n); real[i] = Math.cos(x); imag[i] = Math.sin(x); } return { real, imag }; } function exponent(k, n, inverse) { const x = (inverse ? 2 : -2) * Math.PI * (k / n); const real = Math.cos(x); const imag = Math.sin(x); return { real, imag }; } const ARROW = '->'; const ARROW_REGEX = /->/g; const COMMA = ','; const ELLIPSIS = '...'; function decodeEinsumEquation(equation, numTensors) { equation = equation.replace(/\s/g, ''); const numArrows = (equation.length - equation.replace(ARROW_REGEX, '').length) / ARROW.length; if (numArrows < 1) { throw new Error('Equations without an arrow are not supported.'); } else if (numArrows > 1) { throw new Error(`Equation must contain exactly one arrow ("${ARROW}").`); } const [inputString, outputString] = equation.split(ARROW); assert$1(inputString.indexOf(ELLIPSIS) === -1, () => `The ellipsis notation ("${ELLIPSIS}") is not supported yet.`); const inputTerms = inputString.split(COMMA); const numInputs = inputTerms.length; if (numTensors !== numInputs) { throw new Error(`Expected ${numInputs} input tensors, received ${numTensors}`); } if (numInputs > 2) { throw new Error('Support for more than 2 input tensors is not implemented yet.'); } const allDims = []; for (let i = 0; i < outputString.length; ++i) { const dimName = outputString[i]; if (!inputTerms.some(inputTerm => inputTerm.indexOf(dimName) !== -1)) { throw new Error(`Output subscripts contain the label ${dimName} ` + `not present in the input subscripts.`); } if (allDims.indexOf(dimName) === -1) { allDims.push(dimName); } } for (let i = 0; i < inputString.length; ++i) { const dimName = inputString[i]; if (allDims.indexOf(dimName) === -1 && dimName !== COMMA) { allDims.push(dimName); } } const idDims = new Array(inputTerms.length); for (let i = 0; 
i < numInputs; ++i) { if (new Set(inputTerms[i].split('')).size !== inputTerms[i].length) { throw new Error(`Found duplicate axes in input component ${inputTerms[i]}. ` + `Support for duplicate axes in input is not implemented yet.`); } idDims[i] = []; for (let j = 0; j < inputTerms[i].length; ++j) { idDims[i].push(allDims.indexOf(inputTerms[i][j])); } } const numDims = allDims.length; const numOutDims = outputString.length; const summedDims = []; for (let i = numOutDims; i < numDims; ++i) { summedDims.push(i); } return { allDims, summedDims, idDims }; } function getEinsumPermutation(nDims, idDims) { let permutationIndices = new Array(nDims); permutationIndices.fill(-1); for (let i = 0; i < idDims.length; ++i) { permutationIndices[idDims[i]] = i; } const expandDims = []; for (let i = 0; i < nDims; ++i) { if (permutationIndices[i] === -1) { expandDims.push(i); } } permutationIndices = permutationIndices.filter(d => d !== -1); return { permutationIndices, expandDims }; } function checkEinsumDimSizes(nDims, idDims, tensors) { const dimSizes = new Array(nDims); for (let i = 0; i < tensors.length; ++i) { const shape = tensors[i].shape; for (let j = 0; j < idDims[i].length; ++j) { if (dimSizes[idDims[i][j]] === undefined) { dimSizes[idDims[i][j]] = shape[j]; } else { assert$1(dimSizes[idDims[i][j]] === shape[j], () => `Expected dimension ${dimSizes[idDims[i][j]]} at axis ${j} ` + `of input shaped ${JSON.stringify(shape)}, ` + `but got dimension ${shape[j]}`); } } } } function getEinsumComputePath(summedDims, idDims) { const path = summedDims; const steps = []; let nSteps = 0; if (summedDims.length === 0) { path.push(-1); } nSteps = summedDims.length + 1; for (let i = 0; i < nSteps; ++i) { steps.push([]); } const computedTermIndices = []; for (let i = 0; i < path.length; ++i) { const summedDim = path[i]; const termIndices = findTermsWithDim(idDims, summedDim); for (const termIndex of termIndices) { if (computedTermIndices.indexOf(termIndex) === -1) { 
steps[i].push(termIndex); computedTermIndices.push(termIndex); } } } return { path, steps }; } function isIdentityPermutation(perm) { return perm.every((dim, index) => dim === index); } function findTermsWithDim(idDims, dim) { const termIndices = []; for (let i = 0; i < idDims.length; ++i) { if (idDims[i].length === 0 || idDims[i].indexOf(dim) !== -1 || dim === -1) { termIndices.push(i); } } return termIndices; } function prepareSplitSize(x, numOrSizeSplits, axis = 0) { let splitSizes = []; if (typeof (numOrSizeSplits) === 'number') { assert$1(x.shape[axis] % numOrSizeSplits === 0, () => 'Number of splits must evenly divide the axis.'); splitSizes = new Array(numOrSizeSplits).fill(x.shape[axis] / numOrSizeSplits); } else { const numOfNegs = numOrSizeSplits.reduce((count, value) => { if (value === -1) { count += 1; } return count; }, 0); assert$1(numOfNegs <= 1, () => 'There should be only one negative value in split array.'); const negIndex = numOrSizeSplits.indexOf(-1); if (negIndex !== -1) { const total = numOrSizeSplits.reduce((a, b) => b > 0 ? 
a + b : a); numOrSizeSplits[negIndex] = x.shape[axis] - total; } assert$1(x.shape[axis] === numOrSizeSplits.reduce((a, b) => a + b), () => 'The sum of sizes must match the size of the axis dimension.'); splitSizes = numOrSizeSplits; } return splitSizes; } function getSparseFillEmptyRowsIndicesDenseShapeMismatch(indicesLength) { return `Received SparseTensor with denseShape[0] = 0 but indices.shape[0] = ${indicesLength}`; } function getSparseFillEmptyRowsNegativeIndexErrorMessage(index, value) { return `indices(${index}, 0) is invalid: ${value} < 0`; } function getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(index, value, limit) { return `indices(${index}, 0) is invalid: ${value} >= ${limit}`; } function getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(dim1, dim2) { return `only one output dimension may be -1, not both ${dim1} and ${dim2}`; } function getSparseReshapeNegativeOutputDimErrorMessage(dim, value) { return `size ${dim} must be non-negative, not ${value}`; } function getSparseReshapeEmptyTensorZeroOutputDimErrorMessage() { return 'reshape cannot infer the missing input size for an empty tensor ' + 'unless all specified input sizes are non-zero'; } function getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape) { const inputSize = sizeFromShape(inputShape); const outputSize = sizeFromShape(outputShape); return `Input to reshape is a SparseTensor with ${inputSize} dense values, but the requested shape requires a multiple of ${outputSize}. inputShape=${inputShape} outputShape= ${outputShape}`; } function getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape) { const inputSize = sizeFromShape(inputShape); const outputSize = sizeFromShape(outputShape); return `Input to reshape is a tensor with ${inputSize} dense values, but the requested shape has ${outputSize}. 
inputShape=${inputShape} outputShape=${outputShape}`; } function getSparseSegmentReductionNegativeSegmentIdsErrorMessage() { return `segment ids must be >= 0`; } function getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage() { return `segment ids are not increasing`; } function getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(segmentId, outputRows) { return `Segment id ${segmentId} out of range [0, ${outputRows}), possibly because segmentIds input is not sorted.`; } function getSparseSegmentReductionIndicesOutOfRangeErrorMessage(index, indexValue, inputRows) { return `Bad: indices[${index}] == ${indexValue} out of range [0, ${inputRows})`; } function segOpComputeOptimalWindowSize(inSize, numSegments) { let done = false; let res; if (inSize <= PARALLELIZE_THRESHOLD) { res = inSize; done = true; } else { res = nearestDivisor(inSize, Math.floor(Math.sqrt(inSize))); } while (!done) { if (res > numSegments || res === inSize) { done = true; } else { res = nearestDivisor(inSize, res + 1); } } return res; } function computeOutShape(aShape, axis, numSegments) { const outShape = []; const rank = aShape.length; for (let dim = 0; dim < rank; dim++) { if (dim !== axis) { outShape.push(aShape[dim]); } else { outShape.push(numSegments); } } return outShape; } function collectGatherOpShapeInfo(x, indices, axis, batchDims) { const indicesRank = indices.shape.length; const xRank = x.shape.length; if (batchDims !== 0) { if (batchDims < -indicesRank || batchDims > indicesRank) { throw new Error(`Expect batchDims in the range of [-${indicesRank}, ${indicesRank}], but got ${batchDims}`); } } if (batchDims < 0) { batchDims += indicesRank; } if (batchDims > xRank) { throw new Error(`batchDims (${batchDims}) must be less than rank(x) ( ${xRank}).`); } if (axis < batchDims) { throw new Error(`batchDims (${batchDims}) must be less than or equal to axis (${axis}).`); } for (let i = 0; i < batchDims; ++i) { if (x.shape[i] !== indices.shape[i]) { throw new Error(`x.shape[${i}]: 
${x.shape[i]} should be equal to indices.shape[${i}]: ${indices.shape[i]}.`); } } const dimSize = x.shape[axis]; const outputShape = []; let batchSize = 1; let outerSize = 1; let sliceSize = 1; for (let i = 0; i < batchDims; ++i) { outputShape.push(x.shape[i]); batchSize *= x.shape[i]; } for (let i = batchDims; i < axis; i++) { outputShape.push(x.shape[i]); outerSize *= x.shape[i]; } for (let i = batchDims; i < indicesRank; i++) { outputShape.push(indices.shape[i]); } for (let i = axis + 1; i < xRank; i++) { outputShape.push(x.shape[i]); sliceSize *= x.shape[i]; } return { batchSize, sliceSize, outerSize, dimSize, outputShape }; } var segment_util = Object.freeze({ __proto__: null, collectGatherOpShapeInfo: collectGatherOpShapeInfo, computeOutShape: computeOutShape, segOpComputeOptimalWindowSize: segOpComputeOptimalWindowSize }); function fromUint8ToStringArray(vals) { try { return vals.map(val => decodeString(val)); } catch (err) { throw new Error(`Failed to decode encoded string bytes into utf-8, error: ${err}`); } } function fromStringArrayToUint8(strings) { return strings.map(s => encodeString(s)); } var backend_util = Object.freeze({ __proto__: null, ERF_A1: ERF_A1, ERF_A2: ERF_A2, ERF_A3: ERF_A3, ERF_A4: ERF_A4, ERF_A5: ERF_A5, ERF_P: ERF_P, PARALLELIZE_THRESHOLD: PARALLELIZE_THRESHOLD, get RowPartitionType () { return RowPartitionType$1; }, SELU_SCALE: SELU_SCALE, SELU_SCALEALPHA: SELU_SCALEALPHA, applyActivation: applyActivation, assertAndGetBroadcastShape: assertAndGetBroadcastShape, assertAxesAreInnerMostDims: assertAxesAreInnerMostDims, assertParamsConsistent: assertParamsConsistent, assignToTypedArray: assignToTypedArray, axesAreInnerMostDims: axesAreInnerMostDims, calculateShapes: calculateShapes, checkEinsumDimSizes: checkEinsumDimSizes, checkPadOnDimRoundingMode: checkPadOnDimRoundingMode, combineLocations: combineLocations, combineRaggedTensorToTensorShapes: combineRaggedTensorToTensorShapes, complexWithEvenIndex: complexWithEvenIndex, 
complexWithOddIndex: complexWithOddIndex, computeConv2DInfo: computeConv2DInfo, computeConv3DInfo: computeConv3DInfo, computeDefaultPad: computeDefaultPad, computeDilation2DInfo: computeDilation2DInfo, computeOptimalWindowSize: computeOptimalWindowSize, computeOutAndReduceShapes: computeOutAndReduceShapes, computeOutShape: computeOutShape$1, computePool2DInfo: computePool2DInfo, computePool3DInfo: computePool3DInfo, convertConv2DDataFormat: convertConv2DDataFormat, decodeEinsumEquation: decodeEinsumEquation, eitherStridesOrDilationsAreOne: eitherStridesOrDilationsAreOne, expandShapeToKeepDim: expandShapeToKeepDim, exponent: exponent, exponents: exponents, fromStringArrayToUint8: fromStringArrayToUint8, fromUint8ToStringArray: fromUint8ToStringArray, getAxesPermutation: getAxesPermutation, getBroadcastDims: getBroadcastDims$1, getComplexWithIndex: getComplexWithIndex, getEinsumComputePath: getEinsumComputePath, getEinsumPermutation: getEinsumPermutation, getFusedBiasGradient: getFusedBiasGradient, getFusedDyActivation: getFusedDyActivation, getImageCenter: getImageCenter, getInnerMostAxes: getInnerMostAxes, getPermuted: getPermuted, getRaggedRank: getRaggedRank, getReductionAxes: getReductionAxes, getReshaped: getReshaped, getReshapedPermuted: getReshapedPermuted, getRowPartitionTypesHelper: getRowPartitionTypesHelper, getSliceBeginCoords: getSliceBeginCoords, getSliceSize: getSliceSize, getSparseFillEmptyRowsIndicesDenseShapeMismatch: getSparseFillEmptyRowsIndicesDenseShapeMismatch, getSparseFillEmptyRowsNegativeIndexErrorMessage: getSparseFillEmptyRowsNegativeIndexErrorMessage, getSparseFillEmptyRowsOutOfRangeIndexErrorMessage: getSparseFillEmptyRowsOutOfRangeIndexErrorMessage, getSparseReshapeEmptyTensorZeroOutputDimErrorMessage: getSparseReshapeEmptyTensorZeroOutputDimErrorMessage, getSparseReshapeInputOutputMismatchErrorMessage: getSparseReshapeInputOutputMismatchErrorMessage, getSparseReshapeInputOutputMultipleErrorMessage: 
getSparseReshapeInputOutputMultipleErrorMessage, getSparseReshapeMultipleNegativeOneOutputDimErrorMessage: getSparseReshapeMultipleNegativeOneOutputDimErrorMessage, getSparseReshapeNegativeOutputDimErrorMessage: getSparseReshapeNegativeOutputDimErrorMessage, getSparseSegmentReductionIndicesOutOfRangeErrorMessage: getSparseSegmentReductionIndicesOutOfRangeErrorMessage, getSparseSegmentReductionNegativeSegmentIdsErrorMessage: getSparseSegmentReductionNegativeSegmentIdsErrorMessage, getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage: getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage, getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage: getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage, getUndoAxesPermutation: getUndoAxesPermutation, isIdentityPermutation: isIdentityPermutation, log: log$2, mergeRealAndImagArrays: mergeRealAndImagArrays, prepareAndValidate: prepareAndValidate, prepareSplitSize: prepareSplitSize, segment_util: segment_util, shouldFuse: shouldFuse, slice_util: slice_util, splitRealAndImagArrays: splitRealAndImagArrays, stridesOrDilationsArePositive: stridesOrDilationsArePositive, tupleValuesAreOne: tupleValuesAreOne, upcastType: upcastType, validateDefaultValueShape: validateDefaultValueShape, validateInput: validateInput, validateUpdateShape: validateUpdateShape, warn: warn }); registerOptimizers(); const contexts = {}; const WEBGL_ATTRIBUTES = { alpha: false, antialias: false, premultipliedAlpha: false, preserveDrawingBuffer: false, depth: false, stencil: false, failIfMajorPerformanceCaveat: true }; function setWebGLContext(webGLVersion, gl) { contexts[webGLVersion] = gl; } function getWebGLContext(webGLVersion, customCanvas) { if (!(webGLVersion in contexts) || customCanvas != null) { const newCtx = getWebGLRenderingContext(webGLVersion, customCanvas); if (newCtx !== null) { contexts[webGLVersion] = newCtx; } else { console.log('Could not get context for WebGL version', webGLVersion); return null; } } const gl = 
contexts[webGLVersion]; if (gl == null || gl.isContextLost()) { delete contexts[webGLVersion]; return getWebGLContext(webGLVersion); } gl.disable(gl.DEPTH_TEST); gl.disable(gl.STENCIL_TEST); gl.disable(gl.BLEND); gl.disable(gl.DITHER); gl.disable(gl.POLYGON_OFFSET_FILL); gl.disable(gl.SAMPLE_COVERAGE); gl.enable(gl.SCISSOR_TEST); gl.enable(gl.CULL_FACE); gl.cullFace(gl.BACK); return contexts[webGLVersion]; } function createCanvas(webGLVersion) { if (!env().getBool('IS_SAFARI') && typeof OffscreenCanvas !== 'undefined' && webGLVersion === 2) { return new OffscreenCanvas(300, 150); } else if (typeof document !== 'undefined') { return document.createElement('canvas'); } else { throw new Error('Cannot create a canvas in this context'); } } function getWebGLRenderingContext(webGLVersion, customCanvas) { if (webGLVersion !== 1 && webGLVersion !== 2) { throw new Error('Cannot get WebGL rendering context, WebGL is disabled.'); } const canvas = customCanvas == null ? createCanvas(webGLVersion) : customCanvas; canvas.addEventListener('webglcontextlost', (ev) => { ev.preventDefault(); delete contexts[webGLVersion]; }, false); if (env().getBool('SOFTWARE_WEBGL_ENABLED')) { WEBGL_ATTRIBUTES.failIfMajorPerformanceCaveat = false; } if (webGLVersion === 1) { return ( canvas.getContext('webgl', WEBGL_ATTRIBUTES) || canvas .getContext('experimental-webgl', WEBGL_ATTRIBUTES)); } return canvas.getContext('webgl2', WEBGL_ATTRIBUTES); } var PackingScheme; (function (PackingScheme) { PackingScheme[PackingScheme["DENSE"] = 0] = "DENSE"; PackingScheme[PackingScheme["SHARED_BATCH"] = 1] = "SHARED_BATCH"; })(PackingScheme || (PackingScheme = {})); var TextureUsage; (function (TextureUsage) { TextureUsage[TextureUsage["RENDER"] = 0] = "RENDER"; TextureUsage[TextureUsage["UPLOAD"] = 1] = "UPLOAD"; TextureUsage[TextureUsage["PIXELS"] = 2] = "PIXELS"; TextureUsage[TextureUsage["DOWNLOAD"] = 3] = "DOWNLOAD"; })(TextureUsage || (TextureUsage = {})); var PhysicalTextureType; (function 
(PhysicalTextureType) { PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT16"] = 0] = "UNPACKED_FLOAT16"; PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT32"] = 1] = "UNPACKED_FLOAT32"; PhysicalTextureType[PhysicalTextureType["PACKED_4X1_UNSIGNED_BYTE"] = 2] = "PACKED_4X1_UNSIGNED_BYTE"; PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT32"] = 3] = "PACKED_2X2_FLOAT32"; PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT16"] = 4] = "PACKED_2X2_FLOAT16"; })(PhysicalTextureType || (PhysicalTextureType = {})); function getUnpackedMatrixTextureShapeWidthHeight(rows, columns) { return [columns, rows]; } function getUnpackedArraySizeFromMatrixSize(matrixSize, channelsPerTexture) { return matrixSize * channelsPerTexture; } function getDenseTexShape(shape) { const size = sizeFromShape(shape); const texelsNeeded = Math.ceil(size / 4); return sizeToSquarishShape(texelsNeeded); } function getPackedMatrixTextureShapeWidthHeight(rows, columns) { return [ Math.max(1, Math.ceil(columns / 2)), Math.max(1, Math.ceil(rows / 2)) ]; } function getPackedRGBAArraySizeFromMatrixShape(rows, columns) { const [w, h] = getPackedMatrixTextureShapeWidthHeight(rows, columns); return w * h * 4; } function getTextureConfig( gl, textureHalfFloatExtension) { const glany = gl; let internalFormatFloat; let internalFormatHalfFloat; let internalFormatPackedHalfFloat; let internalFormatPackedFloat; let textureFormatFloat; let downloadTextureFormat; let downloadUnpackNumChannels; let defaultNumChannels; let textureTypeHalfFloat; let textureTypeFloat; if (env().getNumber('WEBGL_VERSION') === 2) { internalFormatFloat = glany.R32F; internalFormatHalfFloat = glany.R16F; internalFormatPackedHalfFloat = glany.RGBA16F; internalFormatPackedFloat = glany.RGBA32F; textureFormatFloat = glany.RED; downloadUnpackNumChannels = 4; defaultNumChannels = 1; textureTypeHalfFloat = glany.HALF_FLOAT; textureTypeFloat = glany.FLOAT; downloadTextureFormat = glany.RGBA8; } else { internalFormatFloat = 
gl.RGBA; internalFormatHalfFloat = gl.RGBA; internalFormatPackedHalfFloat = gl.RGBA; internalFormatPackedFloat = glany.RGBA; textureFormatFloat = gl.RGBA; downloadUnpackNumChannels = 4; defaultNumChannels = 4; textureTypeHalfFloat = textureHalfFloatExtension != null ? textureHalfFloatExtension.HALF_FLOAT_OES : null; textureTypeFloat = gl.FLOAT; downloadTextureFormat = gl.RGBA; } return { internalFormatFloat, internalFormatHalfFloat, internalFormatPackedHalfFloat, internalFormatPackedFloat, textureFormatFloat, downloadTextureFormat, downloadUnpackNumChannels, defaultNumChannels, textureTypeHalfFloat, textureTypeFloat }; } function callAndCheck(gl, func) { const returnValue = func(); if (env().getBool('DEBUG')) { checkWebGLError(gl); } return returnValue; } function checkWebGLError(gl) { const error = gl.getError(); if (error !== gl.NO_ERROR) { throw new Error('WebGL Error: ' + getWebGLErrorMessage(gl, error)); } } const MIN_FLOAT16 = 5.96e-8; const MAX_FLOAT16 = 65504; function canBeRepresented(num) { if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0 || (MIN_FLOAT16 < Math.abs(num) && Math.abs(num) < MAX_FLOAT16)) { return true; } return false; } function getWebGLErrorMessage(gl, status) { switch (status) { case gl.NO_ERROR: return 'NO_ERROR'; case gl.INVALID_ENUM: return 'INVALID_ENUM'; case gl.INVALID_VALUE: return 'INVALID_VALUE'; case gl.INVALID_OPERATION: return 'INVALID_OPERATION'; case gl.INVALID_FRAMEBUFFER_OPERATION: return 'INVALID_FRAMEBUFFER_OPERATION'; case gl.OUT_OF_MEMORY: return 'OUT_OF_MEMORY'; case gl.CONTEXT_LOST_WEBGL: return 'CONTEXT_LOST_WEBGL'; default: return `Unknown error code ${status}`; } } function getExtensionOrThrow(gl, extensionName) { return throwIfNull(gl, () => gl.getExtension(extensionName), 'Extension "' + extensionName + '" not supported on this browser.'); } function createVertexShader$1(gl, vertexShaderSource) { const vertexShader = throwIfNull(gl, () => gl.createShader(gl.VERTEX_SHADER), 'Unable to create vertex 
WebGLShader.'); callAndCheck(gl, () => gl.shaderSource(vertexShader, vertexShaderSource)); callAndCheck(gl, () => gl.compileShader(vertexShader)); if (gl.getShaderParameter(vertexShader, gl.COMPILE_STATUS) === false) { console.log(gl.getShaderInfoLog(vertexShader)); throw new Error('Failed to compile vertex shader.'); } return vertexShader; } function createFragmentShader(gl, fragmentShaderSource) { const fragmentShader = throwIfNull(gl, () => gl.createShader(gl.FRAGMENT_SHADER), 'Unable to create fragment WebGLShader.'); callAndCheck(gl, () => gl.shaderSource(fragmentShader, fragmentShaderSource)); callAndCheck(gl, () => gl.compileShader(fragmentShader)); if (env().get('ENGINE_COMPILE_ONLY')) { return fragmentShader; } if (gl.getShaderParameter(fragmentShader, gl.COMPILE_STATUS) === false) { logShaderSourceAndInfoLog(fragmentShaderSource, gl.getShaderInfoLog(fragmentShader)); throw new Error('Failed to compile fragment shader.'); } return fragmentShader; } const lineNumberRegex = /ERROR: [0-9]+:([0-9]+):/g; function logShaderSourceAndInfoLog(shaderSource, shaderInfoLog) { const lineNumberRegexResult = lineNumberRegex.exec(shaderInfoLog); if (lineNumberRegexResult == null) { console.log(`Couldn't parse line number in error: ${shaderInfoLog}`); console.log(shaderSource); return; } const lineNumber = +lineNumberRegexResult[1]; const shaderLines = shaderSource.split('\n'); const pad = shaderLines.length.toString().length + 2; const linesWithLineNumbers = shaderLines.map((line, lineNumber) => rightPad((lineNumber + 1).toString(), pad) + line); let maxLineLength = 0; for (let i = 0; i < linesWithLineNumbers.length; i++) { maxLineLength = Math.max(linesWithLineNumbers[i].length, maxLineLength); } const beforeErrorLines = linesWithLineNumbers.slice(0, lineNumber - 1); const errorLine = linesWithLineNumbers.slice(lineNumber - 1, lineNumber); const afterErrorLines = linesWithLineNumbers.slice(lineNumber); console.log(beforeErrorLines.join('\n')); 
console.log(shaderInfoLog.split('\n')[0]); console.log(`%c ${rightPad(errorLine[0], maxLineLength)}`, 'border:1px solid red; background-color:#e3d2d2; color:#a61717'); console.log(afterErrorLines.join('\n')); } function createProgram(gl) { return throwIfNull(gl, () => gl.createProgram(), 'Unable to create WebGLProgram.'); } function linkProgram(gl, program) { callAndCheck(gl, () => gl.linkProgram(program)); if (env().get('ENGINE_COMPILE_ONLY')) { return; } if (gl.getProgramParameter(program, gl.LINK_STATUS) === false) { console.log(gl.getProgramInfoLog(program)); throw new Error('Failed to link vertex and fragment shaders.'); } } function validateProgram(gl, program) { callAndCheck(gl, () => gl.validateProgram(program)); if (gl.getProgramParameter(program, gl.VALIDATE_STATUS) === false) { console.log(gl.getProgramInfoLog(program)); throw new Error('Shader program validation failed.'); } } function createStaticVertexBuffer(gl, data) { const buffer = throwIfNull(gl, () => gl.createBuffer(), 'Unable to create WebGLBuffer'); callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer)); callAndCheck(gl, () => gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW)); return buffer; } function createStaticIndexBuffer(gl, data) { const buffer = throwIfNull(gl, () => gl.createBuffer(), 'Unable to create WebGLBuffer'); callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, buffer)); callAndCheck(gl, () => gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW)); return buffer; } function createTexture(gl) { return throwIfNull(gl, () => gl.createTexture(), 'Unable to create WebGLTexture.'); } function validateTextureSize(width, height) { const maxTextureSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); if ((width <= 0) || (height <= 0)) { const requested = `[${width}x${height}]`; throw new Error('Requested texture size ' + requested + ' is invalid.'); } if ((width > maxTextureSize) || (height > maxTextureSize)) { const requested = `[${width}x${height}]`; const max 
= `[${maxTextureSize}x${maxTextureSize}]`; throw new Error('Requested texture size ' + requested + ' greater than WebGL maximum on this browser / GPU ' + max + '.'); } } function createFramebuffer(gl) { return throwIfNull(gl, () => gl.createFramebuffer(), 'Unable to create WebGLFramebuffer.'); } function bindVertexBufferToProgramAttribute(gl, program, attribute, buffer, arrayEntriesPerItem, itemStrideInBytes, itemOffsetInBytes) { const loc = gl.getAttribLocation(program, attribute); if (loc === -1) { return false; } callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer)); callAndCheck(gl, () => gl.vertexAttribPointer(loc, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes, itemOffsetInBytes)); callAndCheck(gl, () => gl.enableVertexAttribArray(loc)); return true; } function bindTextureUnit(gl, texture, textureUnit) { validateTextureUnit(gl, textureUnit); callAndCheck(gl, () => gl.activeTexture(gl.TEXTURE0 + textureUnit)); callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture)); } function getProgramUniformLocationOrThrow(gl, program, uniformName) { return throwIfNull(gl, () => gl.getUniformLocation(program, uniformName), 'uniform "' + uniformName + '" not present in program.'); } function getProgramUniformLocation(gl, program, uniformName) { return gl.getUniformLocation(program, uniformName); } function bindTextureToProgramUniformSampler(gl, texture, uniformSamplerLocation, textureUnit) { callAndCheck(gl, () => bindTextureUnit(gl, texture, textureUnit)); callAndCheck(gl, () => gl.uniform1i(uniformSamplerLocation, textureUnit)); } function bindColorTextureToFramebuffer(gl, texture, framebuffer) { callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer)); callAndCheck(gl, () => gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0)); } function unbindColorTextureFromFramebuffer(gl, framebuffer) { callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer)); callAndCheck(gl, () => 
gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0)); } function validateFramebuffer(gl) { const status = gl.checkFramebufferStatus(gl.FRAMEBUFFER); if (status !== gl.FRAMEBUFFER_COMPLETE) { throw new Error('Error binding framebuffer: ' + getFramebufferErrorMessage(gl, status)); } } function getFramebufferErrorMessage(gl, status) { switch (status) { case gl.FRAMEBUFFER_INCOMPLETE_ATTACHMENT: return 'FRAMEBUFFER_INCOMPLETE_ATTACHMENT'; case gl.FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT: return 'FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT'; case gl.FRAMEBUFFER_INCOMPLETE_DIMENSIONS: return 'FRAMEBUFFER_INCOMPLETE_DIMENSIONS'; case gl.FRAMEBUFFER_UNSUPPORTED: return 'FRAMEBUFFER_UNSUPPORTED'; default: return `unknown error ${status}`; } } function throwIfNull(gl, returnTOrNull, failureMessage) { const tOrNull = callAndCheck(gl, () => returnTOrNull()); if (tOrNull == null) { throw new Error(failureMessage); } return tOrNull; } function validateTextureUnit(gl, textureUnit) { const maxTextureUnit = gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS - 1; const glTextureUnit = textureUnit + gl.TEXTURE0; if (glTextureUnit < gl.TEXTURE0 || glTextureUnit > maxTextureUnit) { const textureUnitRange = `[gl.TEXTURE0, gl.TEXTURE${maxTextureUnit}]`; throw new Error(`textureUnit must be in ${textureUnitRange}.`); } } function getBatchDim(shape, dimsToSkip = 2) { return sizeFromShape(shape.slice(0, shape.length - dimsToSkip)); } function getRowsCols(shape) { if (shape.length === 0) { throw Error('Cannot get rows and columns of an empty shape array.'); } return [ shape.length > 1 ? 
shape[shape.length - 2] : 1, shape[shape.length - 1] ]; } function getShapeAs3D(shape) { let shapeAs3D = [1, 1, 1]; const isScalar = shape.length === 0 || (shape.length === 1 && shape[0] === 1); if (!isScalar) { shapeAs3D = [getBatchDim(shape), ...getRowsCols(shape)]; } return shapeAs3D; } function getTextureShapeFromLogicalShape(logShape, isPacked = false) { let maxTexSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); let maxSizeForNarrowTex = env().getNumber('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE'); if (maxSizeForNarrowTex === Infinity && env().getBool('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE')) { maxSizeForNarrowTex = maxTexSize / 2; } if (isPacked) { maxTexSize = maxTexSize * 2; maxSizeForNarrowTex = maxSizeForNarrowTex * 2; logShape = logShape.map((d, i) => i >= logShape.length - 2 ? nearestLargerEven(logShape[i]) : logShape[i]); if (logShape.length === 1) { logShape = [2, logShape[0]]; } } if (logShape.length !== 2) { const squeezeResult = squeezeShape(logShape); logShape = squeezeResult.newShape; } let size = sizeFromShape(logShape); let textureShape = null; if (logShape.length <= 1 && size <= maxTexSize) { textureShape = [1, size]; } else if (logShape.length === 2 && logShape[0] <= maxTexSize && logShape[1] <= maxTexSize) { textureShape = logShape; } else if (logShape.length === 3 && logShape[0] * logShape[1] <= maxTexSize && logShape[2] <= maxTexSize) { textureShape = [logShape[0] * logShape[1], logShape[2]]; } else if (logShape.length === 3 && logShape[0] <= maxTexSize && logShape[1] * logShape[2] <= maxTexSize) { textureShape = [logShape[0], logShape[1] * logShape[2]]; } else if (logShape.length === 4 && logShape[0] * logShape[1] * logShape[2] <= maxTexSize && logShape[3] <= maxTexSize) { textureShape = [logShape[0] * logShape[1] * logShape[2], logShape[3]]; } else if (logShape.length === 4 && logShape[0] <= maxTexSize && logShape[1] * logShape[2] * logShape[3] <= maxTexSize) { textureShape = [logShape[0], logShape[1] * logShape[2] * logShape[3]]; } const 
isLongNarrowTex = textureShape != null && Math.max(...textureShape) > maxSizeForNarrowTex && Math.min(...textureShape) <= (isPacked ? 2 : 1) && Math.min(...textureShape) > 0; if (textureShape == null || isLongNarrowTex) { if (isPacked) { const batchDim = getBatchDim(logShape); let rows = 2, cols = 2; if (logShape.length) { [rows, cols] = getRowsCols(logShape); } size = batchDim * (rows / 2) * (cols / 2); textureShape = sizeToSquarishShape(size).map(d => d * 2); } else { textureShape = sizeToSquarishShape(size); } } return textureShape; } function isEven(n) { return n % 2 === 0; } function isReshapeFree(shape1, shape2) { shape1 = shape1.slice(-2); shape2 = shape2.slice(-2); if (arraysEqual(shape1, shape2)) { return true; } if (!shape1.length || !shape2.length) { return true; } if (shape1[0] === 0 || shape1[1] === 0 || shape2[0] === 0 || shape2[1] === 0) { return true; } if (shape1.length !== shape2.length) { const shape1Cols = shape1[shape1.length - 1]; const shape2Cols = shape2[shape2.length - 1]; if (shape1Cols === shape2Cols) { return true; } if (isEven(shape1Cols) && isEven(shape2Cols) && (shape1[0] === 1 || shape2[0] === 1)) { return true; } } return shape1[1] === shape2[1] && isEven(shape1[0]) && isEven(shape2[0]); } let MAX_TEXTURE_SIZE; let MAX_TEXTURES_IN_SHADER; function getWebGLMaxTextureSize(webGLVersion) { if (MAX_TEXTURE_SIZE == null) { const gl = getWebGLContext(webGLVersion); MAX_TEXTURE_SIZE = gl.getParameter(gl.MAX_TEXTURE_SIZE); } return MAX_TEXTURE_SIZE; } function getMaxTexturesInShader(webGLVersion) { if (MAX_TEXTURES_IN_SHADER == null) { const gl = getWebGLContext(webGLVersion); MAX_TEXTURES_IN_SHADER = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS); } return Math.min(16, MAX_TEXTURES_IN_SHADER); } function getWebGLDisjointQueryTimerVersion(webGLVersion) { if (webGLVersion === 0) { return 0; } let queryTimerVersion; const gl = getWebGLContext(webGLVersion); if (hasExtension(gl, 'EXT_disjoint_timer_query_webgl2') && webGLVersion === 2) { 
queryTimerVersion = 2; } else if (hasExtension(gl, 'EXT_disjoint_timer_query')) { queryTimerVersion = 1; } else { queryTimerVersion = 0; } return queryTimerVersion; } function hasExtension(gl, extensionName) { const ext = gl.getExtension(extensionName); return ext != null; } function isWebGLVersionEnabled(webGLVersion) { try { const gl = getWebGLContext(webGLVersion); if (gl != null) { return true; } } catch (e) { console.log('Error when getting WebGL context: ', e); return false; } return false; } function isCapableOfRenderingToFloatTexture(webGLVersion) { if (webGLVersion === 0) { return false; } const gl = getWebGLContext(webGLVersion); if (webGLVersion === 1) { if (!hasExtension(gl, 'OES_texture_float')) { return false; } } else { if (!hasExtension(gl, 'EXT_color_buffer_float')) { return false; } } const isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl); return isFrameBufferComplete; } function isDownloadFloatTextureEnabled(webGLVersion) { if (webGLVersion === 0) { return false; } const gl = getWebGLContext(webGLVersion); if (webGLVersion === 1) { if (!hasExtension(gl, 'OES_texture_float')) { return false; } if (!hasExtension(gl, 'WEBGL_color_buffer_float')) { return false; } } else { if (hasExtension(gl, 'EXT_color_buffer_float')) { return createFloatTextureAndBindToFramebuffer(gl); } const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float'; if (hasExtension(gl, COLOR_BUFFER_HALF_FLOAT)) { const textureHalfFloatExtension = gl.getExtension(COLOR_BUFFER_HALF_FLOAT); return createHalfFloatTextureAndBindToFramebuffer(gl, textureHalfFloatExtension); } return false; } const isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl); return isFrameBufferComplete; } function createFloatTextureAndBindToFramebuffer(gl) { const texConfig = getTextureConfig(gl); const texture = gl.createTexture(); gl.bindTexture(gl.TEXTURE_2D, texture); const width = 1; const height = 1; gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatFloat, width, 
height, 0, texConfig.textureFormatFloat, texConfig.textureTypeFloat, null); const frameBuffer = gl.createFramebuffer(); gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer); gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); const isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE; gl.bindTexture(gl.TEXTURE_2D, null); gl.bindFramebuffer(gl.FRAMEBUFFER, null); gl.deleteTexture(texture); gl.deleteFramebuffer(frameBuffer); return isFrameBufferComplete; } function createHalfFloatTextureAndBindToFramebuffer( gl, textureHalfFloatExtension) { const texConfig = getTextureConfig(gl, textureHalfFloatExtension); const texture = gl.createTexture(); gl.bindTexture(gl.TEXTURE_2D, texture); const width = 1; const height = 1; gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatHalfFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeHalfFloat, null); const frameBuffer = gl.createFramebuffer(); gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer); gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); const isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE; gl.bindTexture(gl.TEXTURE_2D, null); gl.bindFramebuffer(gl.FRAMEBUFFER, null); gl.deleteTexture(texture); gl.deleteFramebuffer(frameBuffer); return isFrameBufferComplete; } function isWebGLFenceEnabled(webGLVersion) { if (webGLVersion !== 2) { return false; } const gl = getWebGLContext(webGLVersion); const isEnabled = gl.fenceSync != null; return isEnabled; } function assertNotComplex(tensor, opName) { if (!Array.isArray(tensor)) { tensor = [tensor]; } tensor.forEach(t => { if (t != null) { assert$1(t.dtype !== 'complex64', () => `${opName} does not support complex64 tensors ` + 'in the WebGL backend.'); } }); } const ENV = env(); ENV.registerFlag('HAS_WEBGL', () => ENV.getNumber('WEBGL_VERSION') > 0); ENV.registerFlag('WEBGL_VERSION', () => { if 
(isWebGLVersionEnabled(2)) { return 2; } else if (isWebGLVersionEnabled(1)) { return 1; } return 0; }); ENV.registerFlag('WEBGL_CHECK_NUMERICAL_PROBLEMS', () => false); ENV.registerFlag('WEBGL_BUFFER_SUPPORTED', () => ENV.get('WEBGL_VERSION') === 2); ENV.registerFlag('WEBGL_CPU_FORWARD', () => true); ENV.registerFlag('WEBGL_FORCE_F16_TEXTURES', () => false); ENV.registerFlag('WEBGL_PACK', () => ENV.getBool('HAS_WEBGL')); ENV.registerFlag('WEBGL_PACK_NORMALIZATION', () => ENV.getBool('WEBGL_PACK')); ENV.registerFlag('WEBGL_PACK_CLIP', () => ENV.getBool('WEBGL_PACK')); ENV.registerFlag('WEBGL_PACK_DEPTHWISECONV', () => ENV.getBool('WEBGL_PACK')); ENV.registerFlag('WEBGL_PACK_BINARY_OPERATIONS', () => ENV.getBool('WEBGL_PACK')); ENV.registerFlag('WEBGL_PACK_UNARY_OPERATIONS', () => ENV.getBool('WEBGL_PACK')); ENV.registerFlag('WEBGL_PACK_ARRAY_OPERATIONS', () => ENV.getBool('WEBGL_PACK')); ENV.registerFlag('WEBGL_PACK_IMAGE_OPERATIONS', () => ENV.getBool('WEBGL_PACK')); ENV.registerFlag('WEBGL_PACK_REDUCE', () => ENV.getBool('WEBGL_PACK')); ENV.registerFlag('WEBGL_LAZILY_UNPACK', () => ENV.getBool('WEBGL_PACK')); ENV.registerFlag('WEBGL_CONV_IM2COL', () => ENV.getBool('WEBGL_PACK')); ENV.registerFlag('WEBGL_PACK_CONV2DTRANSPOSE', () => ENV.getBool('WEBGL_PACK')); ENV.registerFlag('WEBGL_MAX_TEXTURE_SIZE', () => getWebGLMaxTextureSize(ENV.getNumber('WEBGL_VERSION'))); ENV.registerFlag('WEBGL_MAX_TEXTURES_IN_SHADER', () => getMaxTexturesInShader(ENV.getNumber('WEBGL_VERSION'))); ENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', () => { const webGLVersion = ENV.getNumber('WEBGL_VERSION'); if (webGLVersion === 0) { return 0; } return getWebGLDisjointQueryTimerVersion(webGLVersion); }); ENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', () => ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0 && !isMobile()); ENV.registerFlag('WEBGL_RENDER_FLOAT32_CAPABLE', () => 
isCapableOfRenderingToFloatTexture(ENV.getNumber('WEBGL_VERSION'))); ENV.registerFlag('WEBGL_RENDER_FLOAT32_ENABLED', () => { return ENV.getBool('WEBGL_FORCE_F16_TEXTURES') ? false : ENV.getBool('WEBGL_RENDER_FLOAT32_CAPABLE'); }); ENV.registerFlag('WEBGL_DOWNLOAD_FLOAT_ENABLED', () => isDownloadFloatTextureEnabled(ENV.getNumber('WEBGL_VERSION'))); ENV.registerFlag('WEBGL_FENCE_API_ENABLED', () => isWebGLFenceEnabled(ENV.getNumber('WEBGL_VERSION'))); ENV.registerFlag('WEBGL_SIZE_UPLOAD_UNIFORM', () => { const useUniforms = ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED'); return useUniforms ? 4 : 0; }); ENV.registerFlag('WEBGL_DELETE_TEXTURE_THRESHOLD', () => { return -1; }, threshold => { if (!(typeof threshold === 'number')) { throw new Error('WEBGL_DELETE_TEXTURE_THRESHOLD must be a number but ' + `got ${threshold}.`); } if (threshold < 0 && threshold !== -1) { throw new Error(`WEBGL_DELETE_TEXTURE_THRESHOLD must be -1 (indicating never ` + `delete) or at least 0, but got ${threshold}.`); } }); ENV.registerFlag('WEBGL_FLUSH_THRESHOLD', () => { return isMobile() ? 
1 : -1; }, threshold => { if (!(typeof threshold === 'number')) { throw new Error('WEBGL_FLUSH_THRESHOLD must be a number but got ' + `${threshold}.`); } if (threshold < 0 && threshold !== -1) { throw new Error(`WEBGL_FLUSH_THRESHOLD must be -1 (indicating never ` + `manual flush) or at least 0, but got ${threshold}.`); } }); ENV.registerFlag('CPU_HANDOFF_SIZE_THRESHOLD', () => 128); ENV.registerFlag('WEBGL_USE_SHAPES_UNIFORMS', () => false); ENV.registerFlag('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD', () => 100000); ENV.registerFlag('TOPK_K_CPU_HANDOFF_THRESHOLD', () => 128); ENV.registerFlag('WEBGL_EXP_CONV', () => false); ENV.registerFlag('SOFTWARE_WEBGL_ENABLED', () => ENV.getBool('IS_TEST')); ENV.registerFlag('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', () => Infinity); ENV.registerFlag('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE', () => false); ENV.registerFlag('WEBGL2_ISNAN_CUSTOM', () => false); ENV.registerFlag('ENGINE_COMPILE_ONLY', () => false); function getGlslDifferences() { let version; let attribute; let varyingVs; let varyingFs; let texture2D; let output; let defineOutput; let defineSpecialNaN; let defineSpecialInf; let defineRound; if (env().getNumber('WEBGL_VERSION') === 2) { version = '#version 300 es'; attribute = 'in'; varyingVs = 'out'; varyingFs = 'in'; texture2D = 'texture'; output = 'outputColor'; defineOutput = 'out vec4 outputColor;'; defineSpecialNaN = env().getBool('WEBGL2_ISNAN_CUSTOM') ? 
` bool isnan_custom(float val) { uint floatToUint = floatBitsToUint(val); return (floatToUint & 0x7fffffffu) > 0x7f800000u; } bvec4 isnan_custom(vec4 val) { return bvec4(isnan_custom(val.x), isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w)); } #define isnan(value) isnan_custom(value) ` : ''; defineSpecialInf = ``; defineRound = ` #define round(value) newRound(value) int newRound(float value) { return int(floor(value + 0.5)); } ivec4 newRound(vec4 value) { return ivec4(floor(value + vec4(0.5))); } `; } else { version = ''; attribute = 'attribute'; varyingVs = 'varying'; varyingFs = 'varying'; texture2D = 'texture2D'; output = 'gl_FragColor'; defineOutput = ''; defineSpecialNaN = ` #define isnan(value) isnan_custom(value) bool isnan_custom(float val) { return (val > 0. || val < 1. || val == 0.) ? false : true; } bvec4 isnan_custom(vec4 val) { return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w)); } `; defineSpecialInf = ` uniform float INFINITY; bool isinf(float val) { return abs(val) == INFINITY; } bvec4 isinf(vec4 val) { return equal(abs(val), vec4(INFINITY)); } `; defineRound = ` int round(float value) { return int(floor(value + 0.5)); } ivec4 round(vec4 value) { return ivec4(floor(value + vec4(0.5))); } `; } return { version, attribute, varyingVs, varyingFs, texture2D, output, defineOutput, defineSpecialNaN, defineSpecialInf, defineRound }; } function getLogicalCoordinatesFromFlatIndex(coords, shape, index = 'index') { const strides = computeStrides(shape); return strides .map((stride, i) => { const line1 = `int ${coords[i]} = ${index} / ${stride}`; const line2 = i === strides.length - 1 ? 
`int ${coords[i + 1]} = ${index} - ${coords[i]} * ${stride}` : `index -= ${coords[i]} * ${stride}`; return `${line1}; ${line2};`; }) .join(''); } function getOutputLogicalCoordinatesFromFlatIndexByUniform(coords, shape, index = 'index') { const strides = computeStrides(shape); return strides .map((_, i) => { const line1 = `int ${coords[i]} = ${index} / outShapeStrides[${i}]`; const line2 = i === strides.length - 1 ? `int ${coords[i + 1]} = ${index} - ${coords[i]} * outShapeStrides[${i}]` : `index -= ${coords[i]} * outShapeStrides[${i}]`; return `${line1}; ${line2};`; }) .join(''); } function symbolicallyComputeStrides(indicesArr, variableName) { const numCoords = indicesArr.length; const shape = indicesArr.map(d => `${variableName}[${d}]`); const strides = new Array(numCoords - 1); strides[numCoords - 2] = shape[numCoords - 1]; for (let i = numCoords - 3; i >= 0; --i) { strides[i] = `(${strides[i + 1]} * ${shape[i + 1]})`; } return strides; } function getLogicalCoordinatesFromFlatIndexByUniform(coords, variableName, index = 'index') { const indicesArray = coords.map((_, i) => i); const strides = symbolicallyComputeStrides(indicesArray, variableName); return strides .map((_, i) => { const line1 = `int ${coords[i]} = ${index} / ${strides[i]}`; const line2 = i === strides.length - 1 ? 
`int ${coords[i + 1]} = ${index} - ${coords[i]} * ${strides[i]}` : `index -= ${coords[i]} * ${strides[i]}`; return `${line1}; ${line2};`; }) .join(''); } function getFlatIndexFrom3D(shape) { const strides = computeStrides(shape).map(d => d.toString()); return ` int getFlatIndex(ivec3 coords) { return coords.x * ${strides[0]} + coords.y * ${strides[1]} + coords.z; } `; } function getFlatIndexFrom3DOutput() { return ` int getFlatIndex(ivec3 coords) { return coords.x * outShapeStrides[0] + coords.y * outShapeStrides[1] + coords.z; } `; } const ENCODE_FLOAT_SNIPPET = ` const float FLOAT_MAX = 1.70141184e38; const float FLOAT_MIN = 1.17549435e-38; lowp vec4 encode_float(highp float v) { if (isnan(v)) { return vec4(255, 255, 255, 255); } highp float av = abs(v); if(av < FLOAT_MIN) { return vec4(0.0, 0.0, 0.0, 0.0); } else if(v > FLOAT_MAX) { return vec4(0.0, 0.0, 128.0, 127.0) / 255.0; } else if(v < -FLOAT_MAX) { return vec4(0.0, 0.0, 128.0, 255.0) / 255.0; } highp vec4 c = vec4(0,0,0,0); highp float e = floor(log2(av)); highp float m = exp2(fract(log2(av))) - 1.0; c[2] = floor(128.0 * m); m -= c[2] / 128.0; c[1] = floor(32768.0 * m); m -= c[1] / 32768.0; c[0] = floor(8388608.0 * m); highp float ebias = e + 127.0; c[3] = floor(ebias / 2.0); ebias -= c[3] * 2.0; c[2] += floor(ebias) * 128.0; c[3] += 128.0 * step(0.0, -v); return c / 255.0; } `; const { getBroadcastDims } = backend_util; function makeShader(inputsInfo, outputShape, program) { const prefixSnippets = []; inputsInfo.forEach(x => { const size = sizeFromShape(x.shapeInfo.logicalShape); if (x.shapeInfo.isUniform) { prefixSnippets.push(`uniform float ${x.name}${size > 1 ? 
`[${size}]` : ''};`); } else { prefixSnippets.push(`uniform sampler2D ${x.name};`); prefixSnippets.push(`uniform int offset${x.name};`); } if (program.enableShapeUniforms) { const { uniformShape } = getUniformInfoFromShape(program.packedInputs, x.shapeInfo.logicalShape, x.shapeInfo.texShape); switch (uniformShape.length) { case 1: prefixSnippets.push(`uniform int ${x.name}Shape;`); break; case 2: prefixSnippets.push(`uniform ivec2 ${x.name}Shape;`); break; case 3: prefixSnippets.push(`uniform ivec3 ${x.name}Shape;`); break; case 4: prefixSnippets.push(`uniform ivec4 ${x.name}Shape;`); break; } prefixSnippets.push(`uniform ivec2 ${x.name}TexShape;`); } }); if (program.enableShapeUniforms) { switch (outputShape.logicalShape.length) { case 1: prefixSnippets.push(`uniform int outShape;`); break; case 2: prefixSnippets.push(`uniform ivec2 outShape;`); prefixSnippets.push(`uniform int outShapeStrides;`); break; case 3: prefixSnippets.push(`uniform ivec3 outShape;`); prefixSnippets.push(`uniform ivec2 outShapeStrides;`); break; case 4: prefixSnippets.push(`uniform ivec4 outShape;`); prefixSnippets.push(`uniform ivec3 outShapeStrides;`); break; } prefixSnippets.push(`uniform ivec2 outTexShape;`); } if (program.customUniforms) { program.customUniforms.forEach((d) => { prefixSnippets.push(`uniform ${d.type} ${d.name}${d.arrayIndex ? 
`[${d.arrayIndex}]` : ''};`); }); } const inputPrefixSnippet = prefixSnippets.join('\n'); const inputSamplingSnippet = inputsInfo .map(x => getInputSamplingSnippet(x, outputShape, program.packedInputs, program.enableShapeUniforms)) .join('\n'); const outTexShape = outputShape.texShape; const glsl = getGlslDifferences(); const floatTextureSampleSnippet = getFloatTextureSampleSnippet(glsl); let outputSamplingSnippet; let floatTextureSetOutputSnippet; let shaderPrefix = getShaderPrefix(glsl); if (outputShape.isPacked) { outputSamplingSnippet = getPackedOutputSamplingSnippet(outputShape.logicalShape, outTexShape, program.enableShapeUniforms); floatTextureSetOutputSnippet = getFloatTextureSetRGBASnippet(glsl); } else { outputSamplingSnippet = getOutputSamplingSnippet(outputShape.logicalShape, outTexShape, program.enableShapeUniforms); floatTextureSetOutputSnippet = getFloatTextureSetRSnippet(glsl); } if (program.packedInputs) { shaderPrefix += SHADER_PACKED_PREFIX; } const source = [ shaderPrefix, floatTextureSampleSnippet, floatTextureSetOutputSnippet, inputPrefixSnippet, outputSamplingSnippet, inputSamplingSnippet, program.userCode ].join('\n'); return source; } function getSamplerFromInInfo(inInfo, enableShapeUniforms = false) { const shape = inInfo.shapeInfo.logicalShape; switch (shape.length) { case 0: return getSamplerScalar(inInfo, enableShapeUniforms); case 1: return getSampler1D(inInfo, enableShapeUniforms); case 2: return getSampler2D(inInfo, enableShapeUniforms); case 3: return getSampler3D(inInfo, enableShapeUniforms); case 4: return getSampler4D(inInfo, enableShapeUniforms); case 5: return getSampler5D(inInfo); case 6: return getSampler6D(inInfo); default: throw new Error(`${shape.length}-D input sampling` + ` is not yet supported`); } } function getPackedSamplerFromInInfo(inInfo, enableShapeUniforms) { const shape = inInfo.shapeInfo.logicalShape; switch (shape.length) { case 0: return getPackedSamplerScalar(inInfo); case 1: return 
getPackedSampler1D(inInfo, enableShapeUniforms); case 2: return getPackedSampler2D(inInfo, enableShapeUniforms); case 3: return getPackedSampler3D(inInfo, enableShapeUniforms); default: return getPackedSamplerND(inInfo, enableShapeUniforms); } } function getInputSamplingSnippet(inInfo, outShapeInfo, usesPackedTextures = false, enableShapeUniforms) { let res = ''; if (usesPackedTextures) { res += getPackedSamplerFromInInfo(inInfo, enableShapeUniforms); } else { res += getSamplerFromInInfo(inInfo, enableShapeUniforms); } const inShape = inInfo.shapeInfo.logicalShape; const outShape = outShapeInfo.logicalShape; if (inShape.length <= outShape.length) { if (usesPackedTextures) { res += getPackedSamplerAtOutputCoords(inInfo, outShapeInfo); } else { res += getSamplerAtOutputCoords(inInfo, outShapeInfo); } } return res; } function getPackedOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) { switch (outShape.length) { case 0: return getOutputScalarCoords(); case 1: return getOutputPacked1DCoords(outShape, outTexShape, enableShapeUniforms); case 2: return getOutputPacked2DCoords(outShape, outTexShape, enableShapeUniforms); case 3: return getOutputPacked3DCoords(outShape, outTexShape, enableShapeUniforms); default: return getOutputPackedNDCoords(outShape, outTexShape, enableShapeUniforms); } } function getOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) { switch (outShape.length) { case 0: return getOutputScalarCoords(); case 1: return getOutput1DCoords(outShape, outTexShape, enableShapeUniforms); case 2: return getOutput2DCoords(outShape, outTexShape, enableShapeUniforms); case 3: return getOutput3DCoords(outShape, outTexShape, enableShapeUniforms); case 4: return getOutput4DCoords(outShape, outTexShape, enableShapeUniforms); case 5: return getOutput5DCoords(outShape, outTexShape); case 6: return getOutput6DCoords(outShape, outTexShape); default: throw new Error(`${outShape.length}-D output sampling is not yet supported`); } } function 
getFloatTextureSampleSnippet(glsl) {
  // GLSL `sampleTexture(sampler, uv)`: the red channel of the texel at `uv`,
  // using the dialect's lookup function (`glsl.texture2D`).
  return ` float sampleTexture(sampler2D textureSampler, vec2 uv) { return ${glsl.texture2D}(textureSampler, uv).r; } `;
}

/**
 * GLSL `setOutput(float)` for unpacked outputs: writes `val` to the red channel
 * of the dialect's output variable (`glsl.output`) and zeroes the others.
 */
function getFloatTextureSetRSnippet(glsl) {
  return ` void setOutput(float val) { ${glsl.output} = vec4(val, 0, 0, 0); } `;
}

/** GLSL `setOutput(vec4)` for packed outputs: writes all four channels. */
function getFloatTextureSetRGBASnippet(glsl) {
  return ` void setOutput(vec4 val) { ${glsl.output} = val; } `;
}

/**
 * Common fragment-shader preamble. It contains:
 * - the version line and highp precision qualifiers;
 * - the `resultUV` varying and the ivec5/ivec6 structs;
 * - the `NAN` uniform and the dialect's isnan/isinf/round definitions;
 * - `imod`/`idiv` integer helpers and a `random(seed)` hash of `resultUV`;
 * - the 1D/2D/3D UV helper snippets.
 * NOTE(review): SAMPLE_*_SNIPPET are `const`s declared after this function, so
 * calling it before module initialisation finishes would hit the TDZ.
 */
function getShaderPrefix(glsl) {
  const SHADER_PREFIX = `${glsl.version} precision highp float; precision highp int; precision highp sampler2D; ${glsl.varyingFs} vec2 resultUV; ${glsl.defineOutput} const vec2 halfCR = vec2(0.5, 0.5); struct ivec5 { int x; int y; int z; int w; int u; }; struct ivec6 { int x; int y; int z; int w; int u; int v; }; uniform float NAN; ${glsl.defineSpecialNaN} ${glsl.defineSpecialInf} ${glsl.defineRound} int imod(int x, int y) { return x - y * (x / y); } int idiv(int a, int b, float sign) { int res = a / b; int mod = imod(a, b); if (sign < 0. && mod != 0) { res -= 1; } return res; } #define HASHSCALE1 443.8975 float random(float seed){ vec2 p = resultUV * seed; vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1); p3 += dot(p3, p3.yzx + 19.19); return fract((p3.x + p3.y) * p3.z); } ${SAMPLE_1D_SNIPPET} ${SAMPLE_2D_SNIPPET} ${SAMPLE_3D_SNIPPET} `;
  return SHADER_PREFIX;
}

// `uvFromFlat`: flat index -> UV of that texel's centre (+ halfCR) in a
// texNumR x texNumC texture. `packedUVfrom1D`: the same, but with two
// consecutive values per texel (index / 2).
const SAMPLE_1D_SNIPPET = ` vec2 uvFromFlat(int texNumR, int texNumC, int index) { int texR = index / texNumC; int texC = index - texR * texNumC; return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR); } vec2 packedUVfrom1D(int texNumR, int texNumC, int index) { int texelIndex = index / 2; int texR = texelIndex / texNumC; int texC = texelIndex - texR * texNumC; return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR); } `;
// `packedUVfrom2D`: (row, col) -> texel UV where each texel covers a 2x2 block
// (row / 2, col / 2). The template literal intentionally spans two lines.
const SAMPLE_2D_SNIPPET = ` vec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR, int texNumC, int row, int col) { int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2); int texR = texelIndex / texNumC; int texC = texelIndex - texR * texNumC; return (vec2(texC,
texR) + halfCR) / vec2(texNumC, texNumR); } `;
// `packedUVfrom3D`: like packedUVfrom2D, offset by `b * texelsInBatch`.
const SAMPLE_3D_SNIPPET = ` vec2 packedUVfrom3D(int texNumR, int texNumC, int texelsInBatch, int texelsInLogicalRow, int b, int row, int col) { int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2); int texR = index / texNumC; int texC = index - texR * texNumC; return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR); } `;
// Appended to the prefix for packed inputs. `getChannel` selects the channel of
// a packed texel by the parity of the coordinates: (even, even) -> r,
// (even, odd) -> g, (odd, even) -> b, (odd, odd) -> a. The int overload
// selects r or g by the parity of a single coordinate.
const SHADER_PACKED_PREFIX = ` float getChannel(vec4 frag, vec2 innerDims) { vec2 modCoord = mod(innerDims, 2.); return modCoord.x == 0. ? (modCoord.y == 0. ? frag.r : frag.g) : (modCoord.y == 0. ? frag.b : frag.a); } float getChannel(vec4 frag, int dim) { float modCoord = mod(float(dim), 2.); return modCoord == 0. ? frag.r : frag.g; } `;

/** GLSL `getOutputCoords()` for a scalar output: always 0. */
function getOutputScalarCoords() {
  return ` int getOutputCoords() { return 0; } `;
}

/**
 * GLSL `getOutputCoords()` for a packed rank-1 output: the (even) logical
 * index of the first value in the current output texel.
 * There are special cases for single-row and single-column packed textures.
 * With `enableShapeUniforms`, the shape is read from the `outTexShape` uniform
 * instead of being baked in.
 */
function getOutputPacked1DCoords(shape, texShape, enableShapeUniforms) {
  // Packed texture dimensions: each texel holds a 2x2 block of the texShape.
  const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
  if (packedTexShape[0] === 1) {
    if (enableShapeUniforms) {
      return ` int getOutputCoords() { return 2 * int(resultUV.x * ceil(float(outTexShape[1]) / 2.0)); } `;
    }
    return ` int getOutputCoords() { return 2 * int(resultUV.x * ${packedTexShape[1]}.0); } `;
  }
  if (packedTexShape[1] === 1) {
    if (enableShapeUniforms) {
      return ` int getOutputCoords() { return 2 * int(resultUV.y * ceil(float(outTexShape[0]) / 2.0)); } `;
    }
    return ` int getOutputCoords() { return 2 * int(resultUV.y * ${packedTexShape[0]}.0); } `;
  }
  if (enableShapeUniforms) {
    return ` int getOutputCoords() { ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0)); ivec2 resTexRC = ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1])); return 2 * (resTexRC.x * packedTexShape[1] + resTexRC.y); } `;
  }
  return ` int getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(${packedTexShape[0]}, ${packedTexShape[1]})); return 2 * (resTexRC.x * ${packedTexShape[1]} + resTexRC.y); } `;
}

function
getOutput1DCoords(shape, texShape, enableShapeUniforms) {
  // GLSL `getOutputCoords()` for an unpacked rank-1 output: the flat index of
  // the current output texel. There are special cases for single-row and
  // single-column textures. `enableShapeUniforms` reads the `outTexShape`
  // uniform instead of baking in `texShape`.
  if (texShape[0] === 1) {
    if (enableShapeUniforms) {
      return ` int getOutputCoords() { return int(resultUV.x * float(outTexShape[1])); } `;
    }
    return ` int getOutputCoords() { return int(resultUV.x * ${texShape[1]}.0); } `;
  }
  if (texShape[1] === 1) {
    if (enableShapeUniforms) {
      return ` int getOutputCoords() { return int(resultUV.y * float(outTexShape[0])); } `;
    }
    return ` int getOutputCoords() { return int(resultUV.y * ${texShape[0]}.0); } `;
  }
  if (enableShapeUniforms) {
    return ` int getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(outTexShape[0], outTexShape[1])); return resTexRC.x * outTexShape[1] + resTexRC.y; } `;
  }
  return ` int getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]})); return resTexRC.x * ${texShape[1]} + resTexRC.y; } `;
}

/**
 * GLSL `getOutputCoords()` for a packed rank-3 output, returning
 * `ivec3(b, r, c)` of the current texel's top-left value. Each texel covers
 * 2 rows x 2 cols, so r and c are even.
 * The non-uniform template literal intentionally spans two lines.
 */
function getOutputPacked3DCoords(shape, texShape, enableShapeUniforms) {
  if (enableShapeUniforms) {
    return ` ivec3 getOutputCoords() { ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0)); int texelsInLogicalRow = int(ceil(float(outShape[2]) / 2.0)); int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[1]) / 2.0)); ivec2 resTexRC = ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1])); int index = resTexRC.x * packedTexShape[1] + resTexRC.y; int b = index / texelsInBatch; index -= b * texelsInBatch; int r = 2 * (index / texelsInLogicalRow); int c = imod(index, texelsInLogicalRow) * 2; return ivec3(b, r, c); } `;
  }
  const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
  const texelsInLogicalRow = Math.ceil(shape[2] / 2);
  const texelsInBatch = texelsInLogicalRow * Math.ceil(shape[1] / 2);
  return ` ivec3 getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(${packedTexShape[0]}, ${packedTexShape[1]})); int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y; int b = index / ${texelsInBatch}; index -= b * ${texelsInBatch};
int r = 2 * (index / ${texelsInLogicalRow}); int c = imod(index, ${texelsInLogicalRow}) * 2; return ivec3(b, r, c); } `;
}

/**
 * GLSL `getOutputCoords()` for an unpacked rank-3 output, returning
 * `ivec3(r, c, d)` unflattened from the texel's flat index.
 */
function getOutput3DCoords(shape, texShape, enableShapeUniforms) {
  if (enableShapeUniforms) {
    const coordsFromIndexSnippet = getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], shape);
    return ` ivec3 getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(outTexShape[0], outTexShape[1])); int index = resTexRC.x * outTexShape[1] + resTexRC.y; ${coordsFromIndexSnippet} return ivec3(r, c, d); } `;
  }
  const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
  return ` ivec3 getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]})); int index = resTexRC.x * ${texShape[1]} + resTexRC.y; ${coordsFromIndexSnippet} return ivec3(r, c, d); } `;
}

/**
 * GLSL `getOutputCoords()` for a packed output of rank >= 4, returning
 * `ivec<rank>(b<rank-2>, ..., b2, b, r, c)` with r and c even.
 * The shape-uniform path always emits an ivec4 (it reads outShape[1..3]).
 * NOTE(review): that path therefore looks correct only for rank 4; confirm
 * callers never enable shape uniforms for packed rank > 4.
 */
function getOutputPackedNDCoords(shape, texShape, enableShapeUniforms) {
  if (enableShapeUniforms) {
    return ` ivec4 getOutputCoords() { ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0)); ivec2 resTexRC = ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1])); int index = resTexRC.x * packedTexShape[1] + resTexRC.y; int texelsInLogicalRow = int(ceil(float(outShape[3]) / 2.0)); int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[2]) / 2.0)); int texelsInBatchN = texelsInBatch * outShape[1]; int b2 = index / texelsInBatchN; index -= b2 * texelsInBatchN; int b = index / texelsInBatch; index -= b * texelsInBatch; int r = 2 * (index / texelsInLogicalRow); int c = imod(index, texelsInLogicalRow) * 2; return ivec4(b2, b, r, c); } `;
  }
  const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
  const texelsInLogicalRow = Math.ceil(shape[shape.length - 1] / 2);
  const texelsInBatch = texelsInLogicalRow * Math.ceil(shape[shape.length - 2] / 2);
  let texelsInBatchN = texelsInBatch;
  let batches = ``;
  let coords = 'b, r, c';
  // Walk outward over the leading batch dims. Each iteration prepends the GLSL
  // that peels off the next-outer batch coordinate `b<b>`.
  for (let b = 2; b < shape.length - 1; b++) {
    texelsInBatchN *= shape[shape.length - b - 1];
    batches = ` int b${b} = index / ${texelsInBatchN}; index -= b${b} * ${texelsInBatchN}; ` + batches;
    coords = `b${b}, ` + coords;
  }
  return ` ivec${shape.length} getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(${packedTexShape[0]}, ${packedTexShape[1]})); int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y; ${batches} int b = index / ${texelsInBatch}; index -= b * ${texelsInBatch}; int r = 2 * (index / ${texelsInLogicalRow}); int c = imod(index, ${texelsInLogicalRow}) * 2; return ivec${shape.length}(${coords}); } `;
}

/**
 * GLSL `getOutputCoords()` for an unpacked rank-4 output, returning
 * `ivec4(r, c, d, d2)`.
 */
function getOutput4DCoords(shape, texShape, enableShapeUniforms) {
  if (enableShapeUniforms) {
    const coordsFromIndexSnippet = getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd', 'd2'], shape);
    return ` ivec4 getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(outTexShape[0], outTexShape[1])); int index = resTexRC.x * outTexShape[1] + resTexRC.y; ${coordsFromIndexSnippet} return ivec4(r, c, d, d2); } `;
  }
  const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2'], shape);
  return ` ivec4 getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]})); int index = resTexRC.x * ${texShape[1]} + resTexRC.y; ${coordsFromIndexSnippet} return ivec4(r, c, d, d2); } `;
}

/**
 * GLSL `getOutputCoords()` for an unpacked rank-5 output, returning the
 * preamble's `ivec5` struct. There is no shape-uniform variant.
 */
function getOutput5DCoords(shape, texShape) {
  const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3'], shape);
  return ` ivec5 getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]})); int index = resTexRC.x * ${texShape[1]} + resTexRC.y; ${coordsFromIndexSnippet} ivec5 outShape = ivec5(r, c, d, d2, d3); return outShape; } `;
}

function getOutput6DCoords(shape, texShape) {
  const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3', 'd4'], shape);
  return ` ivec6 getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx *
vec2(${texShape[0]}, ${texShape[1]})); int index = resTexRC.x * ${texShape[1]} + resTexRC.y; ${coordsFromIndexSnippet} ivec6 result = ivec6(r, c, d, d2, d3, d4); return result; } `; } function getOutputPacked2DCoords(shape, texShape, enableShapeUniforms) { const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; if (arraysEqual(shape, texShape)) { if (enableShapeUniforms) { return ` ivec2 getOutputCoords() { ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0)); return 2 * ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1])); } `; } return ` ivec2 getOutputCoords() { return 2 * ivec2(resultUV.yx * vec2(${packedTexShape[0]}, ${packedTexShape[1]})); } `; } const texelsInLogicalRow = Math.ceil(shape[1] / 2); if (enableShapeUniforms) { return ` ivec2 getOutputCoords() { ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0)); int texelsInLogicalRow = int(ceil(float(outShape[1]) / 2.0)); ivec2 resTexRC = ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1])); int index = resTexRC.x * packedTexShape[1] + resTexRC.y; int r = 2 * (index / texelsInLogicalRow); int c = imod(index, texelsInLogicalRow) * 2; return ivec2(r, c); } `; } return ` ivec2 getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(${packedTexShape[0]}, ${packedTexShape[1]})); int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y; int r = 2 * (index / ${texelsInLogicalRow}); int c = imod(index, ${texelsInLogicalRow}) * 2; return ivec2(r, c); } `; } function getOutput2DCoords(shape, texShape, enableShapeUniforms) { if (arraysEqual(shape, texShape)) { if (enableShapeUniforms) { return ` ivec2 getOutputCoords() { return ivec2(resultUV.yx * vec2(outTexShape[0], outTexShape[1])); } `; } return ` ivec2 getOutputCoords() { return ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]})); } `; } if (shape[1] === 1) { if (enableShapeUniforms) { return ` ivec2 
getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(outTexShape[0], outTexShape[1])); int index = resTexRC.x * outTexShape[1] + resTexRC.y; return ivec2(index, 0); } `; } return ` ivec2 getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]})); int index = resTexRC.x * ${texShape[1]} + resTexRC.y; return ivec2(index, 0); } `; } if (shape[0] === 1) { if (enableShapeUniforms) { return ` ivec2 getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(outTexShape[0], outTexShape[1])); int index = resTexRC.x * outTexShape[1] + resTexRC.y; return ivec2(0, index); } `; } return ` ivec2 getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]})); int index = resTexRC.x * ${texShape[1]} + resTexRC.y; return ivec2(0, index); } `; } if (enableShapeUniforms) { return ` ivec2 getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(outTexShape[0], outTexShape[1])); int index = resTexRC.x * outTexShape[1] + resTexRC.y; int r = index / outShape[1]; int c = index - r * outShape[1]; return ivec2(r, c); } `; } return ` ivec2 getOutputCoords() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]})); int index = resTexRC.x * ${texShape[1]} + resTexRC.y; int r = index / ${shape[1]}; int c = index - r * ${shape[1]}; return ivec2(r, c); } `; } function getFlatOffsetUniformName(texName) { return `offset${texName}`; } function getPackedSamplerScalar(inputInfo) { const texName = inputInfo.name; const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); const glsl = getGlslDifferences(); return ` vec4 ${funcName}() { return ${glsl.texture2D}(${texName}, halfCR); } `; } function getSamplerScalar(inputInfo, enableShapeUniforms) { const texName = inputInfo.name; const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); if (inputInfo.shapeInfo.isUniform) { return `float ${funcName}() {return ${texName};}`; } const [texNumR, texNumC] = 
inputInfo.shapeInfo.texShape; if (texNumR === 1 && texNumC === 1) { return ` float ${funcName}() { return sampleTexture(${texName}, halfCR); } `; } const offset = getFlatOffsetUniformName(texName); if (enableShapeUniforms) { return ` float ${funcName}() { vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], ${offset}); return sampleTexture(${texName}, uv); } `; } const [tNumR, tNumC] = inputInfo.shapeInfo.texShape; return ` float ${funcName}() { vec2 uv = uvFromFlat(${tNumR}, ${tNumC}, ${offset}); return sampleTexture(${texName}, uv); } `; } function getPackedSampler1D(inputInfo, enableShapeUniforms) { const texName = inputInfo.name; const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); const texShape = inputInfo.shapeInfo.texShape; const glsl = getGlslDifferences(); if (enableShapeUniforms) { return ` vec4 ${funcName}(int index) { ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0)); vec2 uv = packedUVfrom1D( packedTexShape[0], packedTexShape[1], index); return ${glsl.texture2D}(${texName}, uv); } `; } const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; return ` vec4 ${funcName}(int index) { vec2 uv = packedUVfrom1D( ${packedTexShape[0]}, ${packedTexShape[1]}, index); return ${glsl.texture2D}(${texName}, uv); } `; } function getSampler1D(inputInfo, enableShapeUniforms) { const texName = inputInfo.name; const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); if (inputInfo.shapeInfo.isUniform) { return ` float ${funcName}(int index) { ${getUniformSampler(inputInfo)} } `; } const texShape = inputInfo.shapeInfo.texShape; const tNumR = texShape[0]; const tNumC = texShape[1]; if (tNumC === 1 && tNumR === 1) { return ` float ${funcName}(int index) { return sampleTexture(${texName}, halfCR); } `; } const offset = getFlatOffsetUniformName(texName); if (tNumC === 1) { if (enableShapeUniforms) { return ` float ${funcName}(int 
index) { vec2 uv = vec2(0.5, (float(index + ${offset}) + 0.5) / float(${texName}TexShape[0])); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int index) { vec2 uv = vec2(0.5, (float(index + ${offset}) + 0.5) / ${tNumR}.0); return sampleTexture(${texName}, uv); } `; } if (tNumR === 1) { if (enableShapeUniforms) { return ` float ${funcName}(int index) { vec2 uv = vec2((float(index + ${offset}) + 0.5) / float(${texName}TexShape[1]), 0.5); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int index) { vec2 uv = vec2((float(index + ${offset}) + 0.5) / ${tNumC}.0, 0.5); return sampleTexture(${texName}, uv); } `; } if (enableShapeUniforms) { return ` float ${funcName}(int index) { vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index + ${offset}); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int index) { vec2 uv = uvFromFlat(${tNumR}, ${tNumC}, index + ${offset}); return sampleTexture(${texName}, uv); } `; } function getPackedSampler2D(inputInfo, enableShapeUniforms) { const shape = inputInfo.shapeInfo.logicalShape; const texName = inputInfo.name; const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); const texShape = inputInfo.shapeInfo.texShape; const texNumR = texShape[0]; const texNumC = texShape[1]; const glsl = getGlslDifferences(); if (texShape != null && arraysEqual(shape, texShape)) { if (enableShapeUniforms) { return ` vec4 ${funcName}(int row, int col) { vec2 uv = (vec2(col, row) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]); return ${glsl.texture2D}(${texName}, uv); } `; } return ` vec4 ${funcName}(int row, int col) { vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return ${glsl.texture2D}(${texName}, uv); } `; } if (enableShapeUniforms) { return ` vec4 ${funcName}(int row, int col) { ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0)); int 
valuesPerRow = int(ceil(float(${texName}Shape[1]) / 2.0)); vec2 uv = packedUVfrom2D(valuesPerRow, packedTexShape[0], packedTexShape[1], row, col); return ${glsl.texture2D}(${texName}, uv); } `; } const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; const valuesPerRow = Math.ceil(shape[1] / 2); return ` vec4 ${funcName}(int row, int col) { vec2 uv = packedUVfrom2D(${valuesPerRow}, ${packedTexShape[0]}, ${packedTexShape[1]}, row, col); return ${glsl.texture2D}(${texName}, uv); } `; } function getSampler2D(inputInfo, enableShapeUniforms) { const shape = inputInfo.shapeInfo.logicalShape; const texName = inputInfo.name; const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); const texShape = inputInfo.shapeInfo.texShape; if (texShape != null && arraysEqual(shape, texShape)) { if (enableShapeUniforms) { return ` float ${funcName}(int row, int col) { vec2 uv = (vec2(col, row) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]); return sampleTexture(${texName}, uv); } `; } const texNumR = texShape[0]; const texNumC = texShape[1]; return ` float ${funcName}(int row, int col) { vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return sampleTexture(${texName}, uv); } `; } const { newShape, keptDims } = squeezeShape(shape); const squeezedShape = newShape; if (squeezedShape.length < shape.length) { const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape); const params = ['row', 'col']; return ` ${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)} float ${funcName}(int row, int col) { return ${funcName}(${getSqueezedParams(params, keptDims)}); } `; } if (inputInfo.shapeInfo.isUniform) { return ` float ${funcName}(int row, int col) { int index = round(dot(vec2(row, col), vec2(${shape[1]}, 1))); ${getUniformSampler(inputInfo)} } `; } const texNumR = texShape[0]; const texNumC = texShape[1]; const offset = getFlatOffsetUniformName(texName); if (texNumC === 1) { if (enableShapeUniforms) { 
return ` float ${funcName}(int row, int col) { float index = dot(vec3(row, col, ${offset}), vec3(${texName}Shape[1], 1, 1)); vec2 uv = vec2(0.5, (index + 0.5) / float(${texName}TexShape[0])); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int row, int col) { float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1)); vec2 uv = vec2(0.5, (index + 0.5) / ${texNumR}.0); return sampleTexture(${texName}, uv); } `; } if (texNumR === 1) { if (enableShapeUniforms) { return ` float ${funcName}(int row, int col) { float index = dot(vec3(row, col, ${offset}), vec3(${texName}Shape[1], 1, 1)); vec2 uv = vec2((index + 0.5) / float(${texName}TexShape[1]), 0.5); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int row, int col) { float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1)); vec2 uv = vec2((index + 0.5) / ${texNumC}.0, 0.5); return sampleTexture(${texName}, uv); } `; } if (enableShapeUniforms) { return ` float ${funcName}(int row, int col) { int index = row * ${texName}Shape[1] + col + ${offset}; vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int row, int col) { int index = row * ${shape[1]} + col + ${offset}; vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index); return sampleTexture(${texName}, uv); } `; } function getPackedSampler3D(inputInfo, enableShapeUniforms) { const shape = inputInfo.shapeInfo.logicalShape; const texName = inputInfo.name; const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); const texShape = inputInfo.shapeInfo.texShape; const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; if (shape[0] === 1) { const squeezedShape = shape.slice(1); const keptDims = [1, 2]; const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape); const params = ['b', 'row', 'col']; return ` ${getPackedSamplerFromInInfo(newInputInfo, 
enableShapeUniforms)} vec4 ${funcName}(int b, int row, int col) { return ${funcName}(${getSqueezedParams(params, keptDims)}); } `; } const glsl = getGlslDifferences(); if (enableShapeUniforms) { return ` vec4 ${funcName}(int b, int row, int col) { ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0)); int valuesPerRow = int(ceil(float(${texName}Shape[2]) / 2.0)); int texelsInBatch = valuesPerRow * int(ceil(float(${texName}Shape[1]) / 2.0)); vec2 uv = packedUVfrom3D( packedTexShape[0], packedTexShape[1], texelsInBatch, valuesPerRow, b, row, col); return ${glsl.texture2D}(${texName}, uv); } `; } const texNumR = packedTexShape[0]; const texNumC = packedTexShape[1]; const valuesPerRow = Math.ceil(shape[2] / 2); const texelsInBatch = valuesPerRow * Math.ceil(shape[1] / 2); return ` vec4 ${funcName}(int b, int row, int col) { vec2 uv = packedUVfrom3D( ${texNumR}, ${texNumC}, ${texelsInBatch}, ${valuesPerRow}, b, row, col); return ${glsl.texture2D}(${texName}, uv); } `; } function getSampler3D(inputInfo, enableShapeUniforms) { const shape = inputInfo.shapeInfo.logicalShape; const texName = inputInfo.name; const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); const stride0 = shape[1] * shape[2]; const stride1 = shape[2]; const { newShape, keptDims } = squeezeShape(shape); const squeezedShape = newShape; if (squeezedShape.length < shape.length) { const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape); const params = ['row', 'col', 'depth']; return ` ${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)} float ${funcName}(int row, int col, int depth) { return ${funcName}(${getSqueezedParams(params, keptDims)}); } `; } if (inputInfo.shapeInfo.isUniform) { return ` float ${funcName}(int row, int col, int depth) { int index = round(dot(vec3(row, col, depth), vec3(${stride0}, ${stride1}, 1))); ${getUniformSampler(inputInfo)} } `; } const texShape = inputInfo.shapeInfo.texShape; const 
texNumR = texShape[0]; const texNumC = texShape[1]; const flatOffset = inputInfo.shapeInfo.flatOffset; if (texNumC === stride0 && flatOffset == null) { if (enableShapeUniforms) { return ` float ${funcName}(int row, int col, int depth) { int stride1 = ${texName}Shape[2]; float texR = float(row); float texC = dot(vec2(col, depth), vec2(stride1, 1)); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int row, int col, int depth) { float texR = float(row); float texC = dot(vec2(col, depth), vec2(${stride1}, 1)); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return sampleTexture(${texName}, uv); } `; } if (texNumC === stride1 && flatOffset == null) { if (enableShapeUniforms) { return ` float ${funcName}(int row, int col, int depth) { float texR = dot(vec2(row, col), vec2(${texName}Shape[1], 1)); float texC = float(depth); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int row, int col, int depth) { float texR = dot(vec2(row, col), vec2(${shape[1]}, 1)); float texC = float(depth); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return sampleTexture(${texName}, uv); } `; } const offset = getFlatOffsetUniformName(texName); if (enableShapeUniforms) { return ` float ${funcName}(int row, int col, int depth) { int stride0 = ${texName}Shape[1] * ${texName}Shape[2]; int stride1 = ${texName}Shape[2]; int index = row * stride0 + col * stride1 + depth + ${offset}; vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index); return sampleTexture(${texName}, uv); } `; } return ` float ${funcName}(int row, int col, int depth) { int index = row * ${stride0} + col * ${stride1} + depth + ${offset}; vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index); return sampleTexture(${texName}, uv); } `; } 
/**
 * Builds the GLSL sampler `vec4 get<Name>(...)` for a packed input of rank >= 4.
 *
 * With `enableShapeUniforms` the emitted function is fixed to rank 4
 * (`b2, b, row, col`) and reads the `<name>Shape` / `<name>TexShape`
 * uniforms. Otherwise the logical and texture shapes are baked in as
 * constants, and one extra leading `b<k>` parameter is emitted per batch
 * dimension beyond the first.
 *
 * @param {object} inputInfo - input descriptor (`name`, `shapeInfo.logicalShape`, `shapeInfo.texShape`).
 * @param {boolean} enableShapeUniforms - read shapes from uniforms instead of constants.
 * @returns {string} GLSL source of the sampler.
 */
function getPackedSamplerND(inputInfo, enableShapeUniforms) {
  const texName = inputInfo.name;
  const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  const glsl = getGlslDifferences();
  if (enableShapeUniforms) {
    return ` vec4 ${funcName}(int b2, int b, int row, int col) { int valuesPerRow = int(ceil(float(${texName}Shape[3]) / 2.0)); int texelsInBatch = valuesPerRow * int(ceil(float(${texName}Shape[2]) / 2.0)); int index = b * texelsInBatch + (row / 2) * valuesPerRow + (col / 2); texelsInBatch *= ${texName}Shape[1]; index = b2 * texelsInBatch + index; ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0)); int texR = index / packedTexShape[1]; int texC = index - texR * packedTexShape[1]; vec2 uv = (vec2(texC, texR) + halfCR) / vec2(packedTexShape[1], packedTexShape[0]); return ${glsl.texture2D}(${texName}, uv); } `;
  }
  const shape = inputInfo.shapeInfo.logicalShape;
  const rank = shape.length;
  const texShape = inputInfo.shapeInfo.texShape;
  // The packed layout halves the texture and the two innermost logical dims
  // (rounding up), matching the (row / 2) and (col / 2) terms emitted below.
  const texNumR = Math.ceil(texShape[0] / 2);
  const texNumC = Math.ceil(texShape[1] / 2);
  const valuesPerRow = Math.ceil(shape[rank - 1] / 2);
  let texelsInBatch = valuesPerRow * Math.ceil(shape[rank - 2] / 2);
  let params = `int b, int row, int col`;
  let index = `b * ${texelsInBatch} + (row / 2) * ${valuesPerRow} + (col / 2)`;
  // Each additional outer batch dimension b<k> is prepended with its own stride.
  for (let b = 2; b < rank - 1; b++) {
    params = `int b${b}, ${params}`;
    texelsInBatch *= shape[rank - b - 1];
    index = `b${b} * ${texelsInBatch} + ${index}`;
  }
  return ` vec4 ${funcName}(${params}) { int index = ${index}; int texR = index / ${texNumC}; int texC = index - texR * ${texNumC}; vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}, ${texNumR}); return ${glsl.texture2D}(${texName}, uv); } `;
}

/**
 * Builds `float get<Name>(int row, int col, int depth, int depth2)` for an
 * unpacked rank-4 input.
 *
 * Cases, in order:
 * 1. squeezeShape reports fewer dims: delegate to the lower-rank sampler.
 * 2. Uniform input: compute an int flat index and scan via getUniformSampler.
 * 3. No flat offset and texture width === stride0: `row` selects the
 *    texture row.
 * 4. No flat offset and texture width === stride2: `depth2` selects the
 *    texture column.
 * 5. Otherwise: flat index + offset uniform through uvFromFlat.
 *
 * @param {object} inputInfo - input descriptor (`name`, `shapeInfo`).
 * @param {boolean} enableShapeUniforms - read shapes from uniforms instead of constants.
 * @returns {string} GLSL source of the sampler.
 */
function getSampler4D(inputInfo, enableShapeUniforms) {
  const shape = inputInfo.shapeInfo.logicalShape;
  const texName = inputInfo.name;
  const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  const stride2 = shape[3];
  const stride1 = shape[2] * stride2;
  const stride0 = shape[1] * stride1;
  const { newShape, keptDims } = squeezeShape(shape);
  if (newShape.length < shape.length) {
    const newInputInfo = squeezeInputInfo(inputInfo, newShape);
    const params = ['row', 'col', 'depth', 'depth2'];
    return ` ${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)} float ${funcName}(int row, int col, int depth, int depth2) { return ${funcName}(${getSqueezedParams(params, keptDims)}); } `;
  }
  if (inputInfo.shapeInfo.isUniform) {
    return ` float ${funcName}(int row, int col, int depth, int depth2) { int index = round(dot(vec4(row, col, depth, depth2), vec4(${stride0}, ${stride1}, ${stride2}, 1))); ${getUniformSampler(inputInfo)} } `;
  }
  const flatOffset = inputInfo.shapeInfo.flatOffset;
  const texShape = inputInfo.shapeInfo.texShape;
  const texNumR = texShape[0];
  const texNumC = texShape[1];
  const stride2Str = `int stride2 = ${texName}Shape[3];`;
  const stride1Str = `int stride1 = ${texName}Shape[2] * stride2;`;
  const stride0Str = `int stride0 = ${texName}Shape[1] * stride1;`;
  if (texNumC === stride0 && flatOffset == null) {
    if (enableShapeUniforms) {
      return ` float ${funcName}(int row, int col, int depth, int depth2) { ${stride2Str} ${stride1Str} float texR = float(row); float texC = dot(vec3(col, depth, depth2), vec3(stride1, stride2, 1)); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]); return sampleTexture(${texName}, uv); } `;
    }
    return ` float ${funcName}(int row, int col, int depth, int depth2) { float texR = float(row); float texC = dot(vec3(col, depth, depth2), vec3(${stride1}, ${stride2}, 1)); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return sampleTexture(${texName}, uv); } `;
  }
  if (texNumC === stride2 && flatOffset == null) {
    if (enableShapeUniforms) {
      return ` float ${funcName}(int row, int col, int depth, int depth2) { float texR = dot(vec3(row, col, depth), vec3(${texName}Shape[1] * ${texName}Shape[2], ${texName}Shape[2], 1)); float texC = float(depth2); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]); return sampleTexture(${texName}, uv); } `;
    }
    return ` float ${funcName}(int row, int col, int depth, int depth2) { float texR = dot(vec3(row, col, depth), vec3(${shape[1] * shape[2]}, ${shape[2]}, 1)); float texC = float(depth2); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return sampleTexture(${texName}, uv); } `;
  }
  const offset = getFlatOffsetUniformName(texName);
  if (enableShapeUniforms) {
    return ` float ${funcName}(int row, int col, int depth, int depth2) { ${stride2Str} ${stride1Str} ${stride0Str} int index = row * stride0 + col * stride1 + depth * stride2 + depth2; vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index + ${offset}); return sampleTexture(${texName}, uv); } `;
  }
  return ` float ${funcName}(int row, int col, int depth, int depth2) { int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} + depth2; vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index + ${offset}); return sampleTexture(${texName}, uv); } `;
}

/**
 * Builds `float get<Name>(int row, int col, int depth, int depth2, int depth3)`
 * for an unpacked rank-5 input. Shapes are always baked in; there is no
 * shape-uniform variant.
 *
 * The uniform path emits an `int` flat index through `round(...)`, exactly
 * like the 2D/3D/4D/6D samplers. getUniformSampler compares that index
 * against an `int` loop counter, and GLSL ES performs no implicit int/float
 * conversion. The previous `float index = dot(...) + depth3;` therefore added
 * an int to a float and produced an int/float comparison — both compile
 * errors.
 *
 * @param {object} inputInfo - input descriptor (`name`, `shapeInfo`).
 * @returns {string} GLSL source of the sampler.
 */
function getSampler5D(inputInfo) {
  const shape = inputInfo.shapeInfo.logicalShape;
  const texName = inputInfo.name;
  const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  const stride3 = shape[4];
  const stride2 = shape[3] * stride3;
  const stride1 = shape[2] * stride2;
  const stride0 = shape[1] * stride1;
  const { newShape, keptDims } = squeezeShape(shape);
  if (newShape.length < shape.length) {
    const newInputInfo = squeezeInputInfo(inputInfo, newShape);
    const params = ['row', 'col', 'depth', 'depth2', 'depth3'];
    return ` ${getSamplerFromInInfo(newInputInfo)} float ${funcName}(int row, int col, int depth, int depth2, int depth3) { return ${funcName}(${getSqueezedParams(params, keptDims)}); } `;
  }
  if (inputInfo.shapeInfo.isUniform) {
    return ` float ${funcName}(int row, int col, int depth, int depth2, int depth3) { int index = round(dot( vec4(row, col, depth, depth2), vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) + float(depth3)); ${getUniformSampler(inputInfo)} } `;
  }
  const flatOffset = inputInfo.shapeInfo.flatOffset;
  const texShape = inputInfo.shapeInfo.texShape;
  const texNumR = texShape[0];
  const texNumC = texShape[1];
  if (texNumC === stride0 && flatOffset == null) {
    return ` float ${funcName}(int row, int col, int depth, int depth2, int depth3) { int texR = row; float texC = dot(vec4(col, depth, depth2, depth3), vec4(${stride1}, ${stride2}, ${stride3}, 1)); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return sampleTexture(${texName}, uv); } `;
  }
  if (texNumC === stride3 && flatOffset == null) {
    return ` float ${funcName}(int row, int col, int depth, int depth2, int depth3) { float texR = dot( vec4(row, col, depth, depth2), vec4(${shape[1] * shape[2] * shape[3]}, ${shape[2] * shape[3]}, ${shape[3]}, 1)); int texC = depth3; vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return sampleTexture(${texName}, uv); } `;
  }
  const offset = getFlatOffsetUniformName(texName);
  return ` float ${funcName}(int row, int col, int depth, int depth2, int depth3) { int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} + depth2 * ${stride3} + depth3 + ${offset}; vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index); return sampleTexture(${texName}, uv); } `;
}

/**
 * Builds `float get<Name>(int row, int col, int depth, int depth2, int depth3, int depth4)`
 * for an unpacked rank-6 input. Shapes are always baked in; there is no
 * shape-uniform variant.
 *
 * Cases: squeeze and delegate; uniform input (int flat index); two
 * texture-width fast paths when no flat offset is set; otherwise uvFromFlat
 * with the offset uniform.
 *
 * @param {object} inputInfo - input descriptor (`name`, `shapeInfo`).
 * @returns {string} GLSL source of the sampler.
 */
function getSampler6D(inputInfo) {
  const shape = inputInfo.shapeInfo.logicalShape;
  const texName = inputInfo.name;
  const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
  const { newShape, keptDims } = squeezeShape(shape);
  if (newShape.length < shape.length) {
    const newInputInfo = squeezeInputInfo(inputInfo, newShape);
    const params = ['row', 'col', 'depth', 'depth2', 'depth3', 'depth4'];
    return ` ${getSamplerFromInInfo(newInputInfo)} float ${funcName}(int row, int col, int depth, int depth2, int depth3, int depth4) { return ${funcName}(${getSqueezedParams(params, keptDims)}); } `;
  }
  const stride4 = shape[5];
  const stride3 = shape[4] * stride4;
  const stride2 = shape[3] * stride3;
  const stride1 = shape[2] * stride2;
  const stride0 = shape[1] * stride1;
  if (inputInfo.shapeInfo.isUniform) {
    return ` float ${funcName}(int row, int col, int depth, int depth2, int depth3, int depth4) { int index = round(dot( vec4(row, col, depth, depth2), vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) + dot( vec2(depth3, depth4), vec2(${stride4}, 1))); ${getUniformSampler(inputInfo)} } `;
  }
  const flatOffset = inputInfo.shapeInfo.flatOffset;
  const texShape = inputInfo.shapeInfo.texShape;
  const texNumR = texShape[0];
  const texNumC = texShape[1];
  if (texNumC === stride0 && flatOffset == null) {
    return ` float ${funcName}(int row, int col, int depth, int depth2, int depth3, int depth4) { int texR = row; float texC = dot(vec4(col, depth, depth2, depth3), vec4(${stride1}, ${stride2}, ${stride3}, ${stride4})) + float(depth4); vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return sampleTexture(${texName}, uv); } `;
  }
  if (texNumC === stride4 && flatOffset == null) {
    return ` float ${funcName}(int row, int col, int depth, int depth2, int depth3, int depth4) { float texR = dot(vec4(row, col, depth, depth2), vec4(${shape[1] * shape[2] * shape[3] * shape[4]}, ${shape[2] * shape[3] * shape[4]}, ${shape[3] * shape[4]}, ${shape[4]})) + float(depth3); int texC = depth4; vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0); return sampleTexture(${texName}, uv); } `;
  }
  const offset = getFlatOffsetUniformName(texName);
  return ` float ${funcName}(int row, int col, int depth, int depth2, int depth3, int depth4) { int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} + depth2 * ${stride3} + depth3 * ${stride4} + depth4 + ${offset}; vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index); return sampleTexture(${texName}, uv); } `;
}
/**
 * Returns GLSL statements that read a uniform-backed input and `return` it.
 * Inputs with fewer than 2 elements are returned directly as `<name>`.
 * Larger inputs are returned as `<name>[i]` by looping over every element
 * and matching `i == index`, so the enclosing GLSL function must declare
 * `index`.
 * NOTE(review): the loop presumably exists because GLSL ES 1.0 only allows
 * uniform arrays to be indexed by constant/loop-index expressions — confirm.
 * @param {Object} inputInfo `{name, shapeInfo}` of the uniform input.
 * @returns {string} GLSL source fragment.
 */
function getUniformSampler(inputInfo) { const texName = inputInfo.name; const inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape); if (inSize < 2) { return `return ${texName};`; } return ` for (int i = 0; i < ${inSize}; i++) { if (i == index) { return ${texName}[i]; } } `; }
/**
 * Builds `vec4 get<Name>AtOutCoords()` for a packed input. The generated
 * function:
 * - reads the current output coordinates;
 * - zeroes the coordinates of broadcast dimensions;
 * - fetches the packed texel with `get<Name>(...)`, passing the trailing
 *   output coordinates aligned to the input rank;
 * - re-swizzles the texel to match the output.
 *
 * Swizzling cases:
 * - rank-1 non-scalar input into a non-scalar output: `.xy` is duplicated;
 * - scalar input into a non-scalar output: `.x` is broadcast (for a rank-1
 *   output, z/w are 0);
 * - broadcast over the input's last two dims: rows and/or columns are
 *   duplicated.
 *
 * @param {Object} inputInfo input `{name, shapeInfo}`.
 * @param {Object} outShapeInfo output shape info (`logicalShape` is read).
 * @returns {string} GLSL source.
 */
function getPackedSamplerAtOutputCoords(inputInfo, outShapeInfo) { const texName = inputInfo.name; const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1); const funcName = 'get' + texFuncSnippet + 'AtOutCoords'; const inRank = inputInfo.shapeInfo.logicalShape.length; const outRank = outShapeInfo.logicalShape.length; const broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape); const type = getCoordsDataType(outRank); const rankDiff = outRank - inRank; let coordsSnippet; const fields = ['x', 'y', 'z', 'w', 'u', 'v']; if (inRank === 0) { coordsSnippet = ''; } else if (outRank < 2 && broadcastDims.length >= 1) { coordsSnippet = 'coords = 0;'; } else { coordsSnippet = broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`) .join('\n'); } let unpackedCoordsSnippet = ''; if (outRank < 2 && inRank > 0) { unpackedCoordsSnippet = 'coords'; } else { unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape .map((s, i) => `coords.${fields[i + rankDiff]}`) .join(', '); } let output = `return outputValue;`; const inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape); const isInputScalar = inSize === 1; const outSize = sizeFromShape(outShapeInfo.logicalShape); const isOutputScalar = outSize === 1; if (inRank === 1 && !isInputScalar && !isOutputScalar) { output = ` return vec4(outputValue.xy, outputValue.xy); `; } else if (isInputScalar && !isOutputScalar) { if (outRank === 1) { output = ` return vec4(outputValue.x, outputValue.x, 0., 0.); `; } else { output = ` return vec4(outputValue.x); `; } } else if (broadcastDims.length) { const rows = inRank - 2; const cols = inRank - 1; if (broadcastDims.indexOf(rows) > -1 && 
broadcastDims.indexOf(cols) > -1) { output = `return vec4(outputValue.x);`; } else if (broadcastDims.indexOf(rows) > -1) { output = `return vec4(outputValue.x, outputValue.y, ` + `outputValue.x, outputValue.y);`; } else if (broadcastDims.indexOf(cols) > -1) { output = `return vec4(outputValue.xx, outputValue.zz);`; } } return ` vec4 ${funcName}() { ${type} coords = getOutputCoords(); ${coordsSnippet} vec4 outputValue = get${texFuncSnippet}(${unpackedCoordsSnippet}); ${output} } `; }
/**
 * Builds `float get<Name>AtOutCoords()` for an unpacked input.
 *
 * Fast path: the input is a non-uniform texture with the same rank as the
 * output, no flat offset, and the same texture shape as the output. In that
 * case it samples directly at `resultUV`.
 *
 * Otherwise it reads the output coordinates, zeroes the broadcast
 * dimensions, and calls `get<Name>(...)` with input dim `i` taken from
 * output coordinate `i + (outRank - inRank)`.
 *
 * @param {Object} inputInfo input `{name, shapeInfo}`.
 * @param {Object} outShapeInfo output shape info (`logicalShape`, `texShape`).
 * @returns {string} GLSL source.
 */
function getSamplerAtOutputCoords(inputInfo, outShapeInfo) { const texName = inputInfo.name; const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1); const funcName = 'get' + texFuncSnippet + 'AtOutCoords'; const outTexShape = outShapeInfo.texShape; const inTexShape = inputInfo.shapeInfo.texShape; const inRank = inputInfo.shapeInfo.logicalShape.length; const outRank = outShapeInfo.logicalShape.length; if (!inputInfo.shapeInfo.isUniform && inRank === outRank && inputInfo.shapeInfo.flatOffset == null && arraysEqual(inTexShape, outTexShape)) { return ` float ${funcName}() { return sampleTexture(${texName}, resultUV); } `; } const type = getCoordsDataType(outRank); const broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape); const rankDiff = outRank - inRank; let coordsSnippet; const fields = ['x', 'y', 'z', 'w', 'u', 'v']; if (inRank === 0) { coordsSnippet = ''; } else if (outRank < 2 && broadcastDims.length >= 1) { coordsSnippet = 'coords = 0;'; } else { coordsSnippet = broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`) .join('\n'); } let unpackedCoordsSnippet = ''; if (outRank < 2 && inRank > 0) { unpackedCoordsSnippet = 'coords'; } else { unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape .map((s, i) => `coords.${fields[i + rankDiff]}`) .join(', '); } return ` float ${funcName}() { ${type} coords = getOutputCoords(); ${coordsSnippet} return get${texFuncSnippet}(${unpackedCoordsSnippet}); } `; }
/**
 * getCoordsDataType(rank): maps an output rank to the GLSL type used for
 * coordinates — `int` for rank <= 1, `ivec2`..`ivec6` for ranks 2-6.
 * Throws for higher ranks.
 * NOTE(review): `ivec5`/`ivec6` are not built-in GLSL types; presumably the
 * generated shader prelude declares them — confirm.
 */
function 
getCoordsDataType(rank) { if (rank <= 1) { return 'int'; } else if (rank === 2) { return 'ivec2'; } else if (rank === 3) { return 'ivec3'; } else if (rank === 4) { return 'ivec4'; } else if (rank === 5) { return 'ivec5'; } else if (rank === 6) { return 'ivec6'; } else { throw Error(`GPU for rank ${rank} is not yet supported`); } }
/**
 * Decides which shape is uploaded as an input's `<name>Shape` uniform and
 * used in the shader key.
 * - Packed rank-3 shape with a leading 1: the leading dim is dropped.
 * - Unpacked shape of rank > 1 whose logical shape differs from its texture
 *   shape, and that `squeezeShape` can shrink: the squeezed shape.
 * - Otherwise: the shape itself.
 * @param {boolean} isPacked whether the program reads packed inputs.
 * @param {number[]} shape logical shape.
 * @param {number[]} texShape texture shape.
 * @returns {{useSqueezeShape: boolean, uniformShape: number[], keptDims: number[]}}
 *     `keptDims` always comes from `squeezeShape(shape)`, even when the
 *     packed rule is what triggered squeezing.
 */
function getUniformInfoFromShape(isPacked, shape, texShape) { const { newShape, keptDims } = squeezeShape(shape); const rank = shape.length; const useSqueezePackedShape = isPacked && rank === 3 && shape[0] === 1; const squeezeShape$1 = useSqueezePackedShape ? shape.slice(1) : newShape; const useSqueezeShape = (!isPacked && rank > 1 && !arraysEqual(shape, texShape) && newShape.length < rank) || useSqueezePackedShape; const uniformShape = useSqueezeShape ? squeezeShape$1 : shape; return { useSqueezeShape, uniformShape, keptDims }; }
/**
 * Returns a deep copy of `inInfo` whose `shapeInfo.logicalShape` is replaced
 * by `squeezedShape`. The copy is a JSON round-trip, so only JSON-safe fields
 * survive.
 */
function squeezeInputInfo(inInfo, squeezedShape) { const newInputInfo = JSON.parse(JSON.stringify(inInfo)); newInputInfo.shapeInfo.logicalShape = squeezedShape; return newInputInfo; }
/**
 * Joins the entries of `params` selected by `keptDims` with ', '. Used to
 * forward only the kept coordinates to a squeezed sampler.
 */
function getSqueezedParams(params, keptDims) { return keptDims.map(d => params[d]).join(', '); }
/**
 * Generates, compiles and links the fragment shader for `program`.
 *
 * Per-input shape info records:
 * - the logical shape;
 * - the texture shape (null for uniform inputs);
 * - whether the input is packed;
 * - `flatOffset`, when the input is a slice with a positive offset.
 *
 * Unless the `ENGINE_COMPILE_ONLY` flag is set, it also builds the VAO and
 * looks up uniform locations. In compile-only mode every location field is
 * null.
 *
 * @returns {Object} `{program, fragmentShader, source, webGLProgram,
 *     inShapeInfos, outShapeInfo}` plus the uniform-location fields.
 */
function compileProgram(gpgpu, program, inputs, output) { const inputInfos = inputs.map((input, i) => { const shapeInfo = { logicalShape: input.shape, texShape: input.isUniform ? null : input.texData.texShape, isUniform: input.isUniform, isPacked: input.isUniform ? 
false : input.texData.isPacked, flatOffset: null }; if (input.texData != null && input.texData.slice != null && input.texData.slice.flatOffset > 0) { shapeInfo.flatOffset = input.texData.slice.flatOffset; } return { name: program.variableNames[i], shapeInfo }; }); const inShapeInfos = inputInfos.map(x => x.shapeInfo); const outShapeInfo = { logicalShape: output.shape, texShape: output.texData.texShape, isUniform: false, isPacked: output.texData.isPacked, flatOffset: null }; const source = makeShader(inputInfos, outShapeInfo, program); const fragmentShader = createFragmentShader(gpgpu.gl, source); const webGLProgram = gpgpu.createProgram(fragmentShader);
/* Compile-only mode skips VAO setup and returns null uniform locations. */
if (!env().get('ENGINE_COMPILE_ONLY')) { gpgpu.buildVao(webGLProgram); return Object.assign({ program, fragmentShader, source, webGLProgram, inShapeInfos, outShapeInfo }, getUniformLocations(gpgpu, program, webGLProgram)); } else { return { program, fragmentShader, source, webGLProgram, inShapeInfos, outShapeInfo, variablesLocations: null, customUniformLocations: null, infLoc: null, nanLoc: null, outShapeLocation: null, outShapeStridesLocation: null, outTexShapeLocation: null }; } }
/**
 * Looks up every uniform location a compiled program may use:
 * - `NAN` always, and `INFINITY` on WebGL 1 only;
 * - per input, the value/sampler uniform `<name>` and `offset<name>`, plus
 *   `<name>Shape` and `<name>TexShape` when shape uniforms are enabled;
 * - `outShape`, `outShapeStrides` and `outTexShape` when shape uniforms are
 *   enabled;
 * - each custom uniform, by name, in declaration order.
 *
 * Every lookup passes `shouldThrow = false`. Presumably a missing uniform
 * (e.g. one optimized out by the GLSL compiler) then yields null — confirm in
 * `getUniformLocation`.
 */
function getUniformLocations(gpgpu, program, webGLProgram) { const variablesLocations = []; const customUniformLocations = []; let outShapeLocation; let outTexShapeLocation; let outShapeStridesLocation; let infLoc = null; let nanLoc = null; nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false); if (env().getNumber('WEBGL_VERSION') === 1) { infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false); } const shouldThrow = false; for (const varName of program.variableNames) { const varLocs = { name: varName, uniform: gpgpu.getUniformLocation(webGLProgram, varName, shouldThrow), offset: gpgpu.getUniformLocation(webGLProgram, `offset${varName}`, shouldThrow), }; if (program.enableShapeUniforms) { varLocs.shape = gpgpu.getUniformLocation(webGLProgram, `${varName}Shape`, shouldThrow); varLocs.texShape = 
gpgpu.getUniformLocation(webGLProgram, `${varName}TexShape`, shouldThrow); } variablesLocations.push(varLocs); } if (program.enableShapeUniforms) { outShapeLocation = gpgpu.getUniformLocation(webGLProgram, 'outShape', shouldThrow); outShapeStridesLocation = gpgpu.getUniformLocation(webGLProgram, 'outShapeStrides', shouldThrow); outTexShapeLocation = gpgpu.getUniformLocation(webGLProgram, 'outTexShape', shouldThrow); } if (program.customUniforms) { for (const d of program.customUniforms) { customUniformLocations.push(gpgpu.getUniformLocation(webGLProgram, d.name, shouldThrow)); } } return { variablesLocations, customUniformLocations, infLoc, nanLoc, outShapeLocation, outShapeStridesLocation, outTexShapeLocation }; }
/**
 * Throws if a compiled binary is reused with tensors that do not match the
 * ones it was compiled for. A mismatch is any of:
 * - a different input count;
 * - a different logical shape;
 * - a different texture shape (not checked when both sides are uniforms).
 * @param {Object[]} shapeInfos compile-time shape infos.
 * @param {Object[]} inputs current tensors (`shape`, `isUniform`, `texData`).
 */
function validateBinaryAndProgram(shapeInfos, inputs) { if (shapeInfos.length !== inputs.length) { throw Error(`Binary was compiled with ${shapeInfos.length} inputs, but ` + `was executed with ${inputs.length} inputs`); } shapeInfos.forEach((s, i) => { const shapeA = s.logicalShape; const input = inputs[i]; const shapeB = input.shape; if (!arraysEqual(shapeA, shapeB)) { throw Error(`Binary was compiled with different shapes than ` + `the current args. Shapes ${shapeA} and ${shapeB} must match`); } if (s.isUniform && input.isUniform) { return; } const texShapeA = s.texShape; const texShapeB = input.isUniform ? null : input.texData.texShape; if (!arraysEqual(texShapeA, texShapeB)) { throw Error(`Binary was compiled with different texture shapes than the` + ` current args. 
Shape ${texShapeA} and ${texShapeB} must match`); } }); }
/**
 * Runs a compiled `binary`:
 * - binds `output` as the render target;
 * - uploads all per-run uniforms;
 * - issues the draw via `gpgpu.executeProgram()`.
 *
 * Shapes are validated against the compile-time shapes only when the
 * program does not use shape uniforms. When it does, the shapes are
 * uploaded as uniforms below instead.
 *
 * Upload order:
 * - per input: the shape/texShape uniforms if located, then either the
 *   uniform value(s) or the texture binding plus the slice flat offset;
 * - output shape, strides and texture shape;
 * - `customUniformValues`, matched by index to `program.customUniforms`.
 *
 * @throws {Error} on a shape mismatch or an unsupported custom uniform type.
 */
function runProgram(gpgpu, binary, inputs, output, customUniformValues) { if (!binary.program.enableShapeUniforms) { validateBinaryAndProgram(binary.inShapeInfos, inputs); validateBinaryAndProgram([binary.outShapeInfo], [output]); } const outTex = output.texData.texture; const outTexShape = output.texData.texShape; if (output.texData.isPacked) { gpgpu.setOutputPackedMatrixTexture(outTex.texture, outTexShape[0], outTexShape[1]); } else { gpgpu.setOutputMatrixTexture(outTex.texture, outTexShape[0], outTexShape[1]); } gpgpu.setProgram(binary.webGLProgram); gpgpu.bindVertexArray(binary.webGLProgram.vao); if (env().getNumber('WEBGL_VERSION') === 1) { if (binary.infLoc !== null) { gpgpu.gl.uniform1f(binary.infLoc, Infinity); } } if (binary.nanLoc !== null) { gpgpu.gl.uniform1f(binary.nanLoc, NaN); } for (let i = 0; i < inputs.length; ++i) { const input = inputs[i]; const { uniform: varLoc, offset: varOffsetLoc, shape: varShapeLoc, texShape: varTexShapeLoc, } = binary.variablesLocations[i]; if (varShapeLoc) { const { uniformShape } = getUniformInfoFromShape(binary.program.packedInputs, input.shape, input.texData.texShape); switch (uniformShape.length) { case 1: gpgpu.gl.uniform1iv(varShapeLoc, new Int32Array(uniformShape)); break; case 2: gpgpu.gl.uniform2iv(varShapeLoc, new Int32Array(uniformShape)); break; case 3: gpgpu.gl.uniform3iv(varShapeLoc, new Int32Array(uniformShape)); break; case 4: gpgpu.gl.uniform4iv(varShapeLoc, new Int32Array(uniformShape)); break; } } if (varTexShapeLoc) { gpgpu.gl.uniform2i(varTexShapeLoc, input.texData.texShape[0], input.texData.texShape[1]); }
/* No value/sampler location for this input; presumably the shader does not use it. */
if (varLoc == null) { continue; }
/* Uniform inputs: a scalar via uniform1f, larger inputs as a float array. */
if (input.isUniform) { if (sizeFromShape(input.shape) < 2) { gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]); } else { let vals = input.uniformValues; if (!(vals instanceof Float32Array)) { vals = new Float32Array(vals); } gpgpu.gl.uniform1fv(varLoc, vals); } continue; } if (input.texData.slice != null 
&& varOffsetLoc != null) { gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset); } gpgpu.setInputMatrixTexture(input.texData.texture.texture, varLoc, i); } const outShapeLoc = binary.outShapeLocation; if (outShapeLoc) { switch (output.shape.length) { case 1: gpgpu.gl.uniform1iv(outShapeLoc, new Int32Array(output.shape)); break; case 2: gpgpu.gl.uniform2iv(outShapeLoc, new Int32Array(output.shape)); break; case 3: gpgpu.gl.uniform3iv(outShapeLoc, new Int32Array(output.shape)); break; case 4: gpgpu.gl.uniform4iv(outShapeLoc, new Int32Array(output.shape)); break; } }
/* Strides are uploaded only for output ranks 2-4, using rank-1 components
   (rank 2 -> uniform1iv ... rank 4 -> uniform3iv). Presumably computeStrides
   omits the innermost stride of 1 — confirm. */
if (binary.outShapeStridesLocation) { const strides = computeStrides(output.shape); switch (output.shape.length) { case 2: gpgpu.gl.uniform1iv(binary.outShapeStridesLocation, new Int32Array(strides)); break; case 3: gpgpu.gl.uniform2iv(binary.outShapeStridesLocation, new Int32Array(strides)); break; case 4: gpgpu.gl.uniform3iv(binary.outShapeStridesLocation, new Int32Array(strides)); break; } } if (binary.outTexShapeLocation) { gpgpu.gl.uniform2i(binary.outTexShapeLocation, output.texData.texShape[0], output.texData.texShape[1]); }
/* Custom uniforms always use the vector ("v") setters, so each value must be
   array-like; values are matched to declarations by index. */
if (binary.program.customUniforms && customUniformValues) { for (let i = 0; i < binary.program.customUniforms.length; ++i) { const d = binary.program.customUniforms[i]; const customLoc = binary.customUniformLocations[i]; const customValue = customUniformValues[i]; if (d.type === 'float') { gpgpu.gl.uniform1fv(customLoc, customValue); } else if (d.type === 'vec2') { gpgpu.gl.uniform2fv(customLoc, customValue); } else if (d.type === 'vec3') { gpgpu.gl.uniform3fv(customLoc, customValue); } else if (d.type === 'vec4') { gpgpu.gl.uniform4fv(customLoc, customValue); } else if (d.type === 'int') { gpgpu.gl.uniform1iv(customLoc, customValue); } else if (d.type === 'ivec2') { gpgpu.gl.uniform2iv(customLoc, customValue); } else if (d.type === 'ivec3') { gpgpu.gl.uniform3iv(customLoc, customValue); } else if (d.type === 'ivec4') { gpgpu.gl.uniform4iv(customLoc, customValue); } 
else { throw Error(`uniform type ${d.type} is not supported yet.`); } } } gpgpu.executeProgram(); }
/**
 * Builds the shader-cache key for running `program` on `inputs` with
 * `output`.
 *
 * With shape uniforms enabled, each non-uniform tensor contributes only
 * derived properties:
 * - rank and, when squeezing applies, `keptDims`;
 * - uniform-shape rank and scalar-ness;
 * - broadcast dims against the output;
 * - texture-shape relations;
 * - whether it has a slice offset.
 * Presumably this lets tensors of different concrete shapes share one
 * binary.
 *
 * Otherwise (or for uniform inputs) the full shape and texture shape are
 * keyed.
 *
 * The key also includes the program's class name, its `userCode` and the
 * WebGL version.
 * @returns {string} cache key.
 */
function makeShaderKey(program, inputs, output) { let keyInputs = ''; inputs.concat(output).forEach(x => { const hasOffset = x.texData != null && x.texData.slice != null && x.texData.slice.flatOffset > 0; if (program.enableShapeUniforms && !x.isUniform) { const xTexShape = x.texData.texShape; const { useSqueezeShape, uniformShape, keptDims } = getUniformInfoFromShape(program.packedInputs, x.shape, xTexShape); let rank1 = '', rank2 = '', rank34 = ''; if (uniformShape.length === 1 && program.packedInputs) { const packedTexShape = [Math.ceil(xTexShape[0] / 2), Math.ceil(xTexShape[1] / 2)]; rank1 = `${packedTexShape[0] > 1}_${packedTexShape[1] > 1}`; } else if (uniformShape.length === 2 && !program.packedInputs) { rank2 = `${uniformShape[0] > 1}_${uniformShape[1] > 1}`; } else if (uniformShape.length > 2 && !program.packedInputs) { const strides = computeStrides(uniformShape); rank34 = `${strides[0] === xTexShape[1]}_${strides[strides.length - 1] === xTexShape[1]}`; } const xRank = x.shape.length; const isLogicalShapTexShapeEqual = uniformShape.length === 2 && arraysEqual(x.shape, xTexShape); const isScalar = sizeFromShape(x.shape) === 1; const broadcastDims = getBroadcastDims$1(x.shape, output.shape); const isInOutTexShapeEqual = !program.packedInputs && xRank === output.shape.length && arraysEqual(xTexShape, output.texData.texShape); const isTexShapeGreaterThanOne = program.packedInputs || uniformShape.length > 2 ? '' : `${xTexShape[0] > 1}_${xTexShape[1] > 1}`; keyInputs += `${xRank}_${isInOutTexShapeEqual}_${useSqueezeShape ? keptDims : ''}_${uniformShape.length}_${isScalar}_${broadcastDims}_${isLogicalShapTexShapeEqual}_${rank1}_${rank2}_${rank34}_${isTexShapeGreaterThanOne}_${hasOffset}`; } else { const texShape = x.isUniform ? 
'uniform' : x.texData.texShape; keyInputs += `${x.shape}_${texShape}_${hasOffset}`; } }); const keyUserCode = program.userCode; let key = program.constructor.name; key += '_' + keyInputs + '_' + keyUserCode + `${env().getNumber('WEBGL_VERSION')}`; return key; }
/**
 * True when the `WEBGL_USE_SHAPES_UNIFORMS` flag is on and `rank <= 4`. The
 * shape-uniform upload paths in runProgram only handle 1 to 4 components.
 */
function useShapeUniforms(rank) { return env().getBool('WEBGL_USE_SHAPES_UNIFORMS') && rank <= 4; }
/**
 * Shader that reads the unpacked input `A` and writes a packed output using
 * the DENSE packing scheme.
 *
 * Each output texel holds 4 consecutive flat elements. Each flat index is
 * converted back to 3-D output coordinates (r, c, d) and fetched with
 * `getA`.
 *
 * The custom `texShape` uniform turns the fragment position into a flat
 * index; presumably it is the output texture's [rows, cols] — confirm
 * callers.
 * @param {number[]} outputShape logical output shape, indexed as (r, c, d).
 */
class DecodeMatrixProgram { constructor(outputShape) { this.variableNames = ['A']; this.packedInputs = false; this.packedOutput = true; this.outPackingScheme = PackingScheme.DENSE; this.customUniforms = [{ name: 'texShape', type: 'ivec2' }]; const glsl = getGlslDifferences(); this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); this.userCode = ` ivec3 outCoordsFromFlatIndex(int index) { ${this.enableShapeUniforms ? getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], outputShape) : getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape)} return ivec3(r, c, d); } void main() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1])); int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y); vec4 result = vec4(0.); for (int i=0; i<4; i++) { int flatIndex = index + i; ivec3 rc = outCoordsFromFlatIndex(flatIndex); result[i] = getA(rc.x, rc.y, rc.z); } ${glsl.output} = result; } `; } }
/**
 * Same as DecodeMatrixProgram but for a packed input `A`. Each element is
 * picked out of its packed texel with `getChannel(..., vec2(c, d))`.
 * @param {number[]} outputShape logical output shape, indexed as (r, c, d).
 */
class DecodeMatrixPackedProgram { constructor(outputShape) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; this.outPackingScheme = PackingScheme.DENSE; this.customUniforms = [{ name: 'texShape', type: 'ivec2' }]; const glsl = getGlslDifferences(); this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); this.userCode = ` ivec3 outCoordsFromFlatIndex(int index) { ${this.enableShapeUniforms ? 
getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], outputShape) : getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape)} return ivec3(r, c, d); } void main() { ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1])); int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y); vec4 result = vec4(0.); for (int i=0; i<4; i++) { int flatIndex = index + i; ivec3 rc = outCoordsFromFlatIndex(flatIndex); result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z)); } ${glsl.output} = result; } `; } }
/**
 * Download shader: writes `encode_float(x)` (from ENCODE_FLOAT_SNIPPET) for
 * every value of the unpacked input `A`. Its output texture is marked
 * `TextureUsage.DOWNLOAD`.
 * @param {number[]} outputShape logical output shape.
 */
class EncodeFloatProgram { constructor(outputShape) { this.variableNames = ['A']; this.outTexUsage = TextureUsage.DOWNLOAD; const glsl = getGlslDifferences(); this.outputShape = outputShape; this.userCode = ` ${ENCODE_FLOAT_SNIPPET} void main() { float x = getAAtOutCoords(); ${glsl.output} = encode_float(x); } `; } }
/**
 * Packed-input variant of EncodeFloatProgram with an unpacked output. It
 * selects the channel for (y, z) of the rank-3 output coordinates before
 * encoding.
 * @param {number[]} outputShape logical output shape (rank 3).
 */
class EncodeFloatPackedProgram { constructor(outputShape) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = false; this.outTexUsage = TextureUsage.DOWNLOAD; const glsl = getGlslDifferences(); this.outputShape = outputShape; this.userCode = ` ${ENCODE_FLOAT_SNIPPET} void main() { ivec3 coords = getOutputCoords(); float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z)); ${glsl.output} = encode_float(x); } `; } }
/** Maps a channel letter used in `usedChannels` to its vec4 component index. */
const CHANNEL_CHAR_TO_INDEX_MAP = { 'R': 0, 'G': 1, 'B': 2, 'A': 3 };
/**
 * Shader that re-lays out a dense texture `A` into an unpacked output, with
 * the value in the first component.
 *
 * Each texel of `A` holds `usedChannels.length` consecutive flat elements,
 * stored in the channels named by `usedChannels`.
 *
 * When `inputIsUnsignedByte` is true, the value is written as
 * `floor(result * 255. + 0.5)`. Positions whose source row is
 * `>= texShape[0]` produce 0.
 *
 * @param {number[]} outputShape logical output shape (rank 3).
 * @param {boolean} [inputIsUnsignedByte=false]
 * @param {string} [usedChannels='RGBA'] channel letters, in storage order.
 */
class EncodeMatrixProgram { constructor(outputShape, inputIsUnsignedByte = false, usedChannels = 'RGBA') { this.variableNames = ['A']; this.customUniforms = [{ name: 'texShape', type: 'ivec2' }]; const glsl = getGlslDifferences(); this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); let output = `result`; if (inputIsUnsignedByte) { output = `floor(result * 255. 
+ 0.5)`; } let mainLoop = ''; for (let usedChannelIndex = 0; usedChannelIndex < usedChannels.length; usedChannelIndex++) { const curChannel = usedChannels[usedChannelIndex]; mainLoop += ` if(offset == ${usedChannelIndex}) { result = values[${CHANNEL_CHAR_TO_INDEX_MAP[curChannel]}]; }`; } this.userCode = ` ${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() : getFlatIndexFrom3D(outputShape)} void main() { ivec3 coords = getOutputCoords(); int flatIndex = getFlatIndex(coords); float result = 0.; int offset = imod(flatIndex, ${usedChannels.length}); flatIndex = idiv(flatIndex, ${usedChannels.length}, 1.); int r = flatIndex / texShape[1]; if (r < texShape[0]) { int c = imod(flatIndex, texShape[1]); vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]); vec4 values = ${glsl.texture2D}(A, uv); ${mainLoop} } ${glsl.output} = vec4(${output}, 0., 0., 0.); } `; } }
/**
 * Packed-output counterpart of EncodeMatrixProgram.
 *
 * For each output texel it gathers up to 4 elements, at offsets
 * (row, col) in {0,1}x{0,1} within the last two output dims, from the dense
 * RGBA texture `A` (4 flat elements per texel). Each element is stored in
 * channel `row * 2 + col`; positions outside the output shape are skipped.
 *
 * When `inputIsUnsignedByte` is true, the whole vec4 is written as
 * `floor(result * 255. + 0.5)`.
 * @param {number[]} outputShape logical output shape (rank 3).
 * @param {boolean} [inputIsUnsignedByte=false]
 */
class EncodeMatrixPackedProgram { constructor(outputShape, inputIsUnsignedByte = false) { this.variableNames = ['A']; this.packedInputs = false; this.packedOutput = true; this.customUniforms = [{ name: 'texShape', type: 'ivec2' }]; const glsl = getGlslDifferences(); this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); let mainLoop = ''; let output = 'result'; if (inputIsUnsignedByte) { output = 'floor(result * 255. + 0.5)'; } for (let row = 0; row <= 1; row++) { for (let col = 0; col <= 1; col++) { const channel = row * 2 + col; mainLoop += ` localCoords = coords; if(localCoords[2] + ${col} < ${this.enableShapeUniforms ? 'outShape[2]' : `${outputShape[2]}`}) { localCoords[2] += ${col}; if (localCoords[1] + ${row} < ${this.enableShapeUniforms ? 
'outShape[1]' : `${outputShape[1]}`}) { localCoords[1] += ${row}; flatIndex = getFlatIndex(localCoords); offset = imod(flatIndex, 4); flatIndex = idiv(flatIndex, 4, 1.); int r = flatIndex / texShape[1]; int c = imod(flatIndex, texShape[1]); vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]); values = ${glsl.texture2D}(A, uv); if (offset == 0) { result[${channel}] = values[0]; } else if (offset == 1) { result[${channel}] = values[1]; } else if (offset == 2) { result[${channel}] = values[2]; } else { result[${channel}] = values[3]; } } } `; } } this.userCode = ` ${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() : getFlatIndexFrom3D(outputShape)} void main() { ivec3 coords = getOutputCoords(); vec4 result = vec4(0.); int flatIndex, r, c, offset; ivec3 localCoords; vec2 uv; vec4 values; ${mainLoop} ${glsl.output} = ${output}; } `; } }
/**
 * Compiles the shared pass-through vertex shader. Vertex positions come from
 * `clipSpacePos`, and `uv` is forwarded to the fragment stage as `resultUV`.
 * NOTE(review): the shader source here is a single line. If `glsl.version`
 * is a `#version` directive, the GLSL preprocessor ends a directive at the
 * newline, so the rest of the shader would be swallowed — confirm the real
 * file keeps newlines in this template literal.
 */
function createVertexShader(gl) { const glsl = getGlslDifferences(); const vertexShaderSource = `${glsl.version} precision highp float; ${glsl.attribute} vec3 clipSpacePos; ${glsl.attribute} vec2 uv; ${glsl.varyingVs} vec2 resultUV; void main() { gl_Position = vec4(clipSpacePos, 1); resultUV = uv; }`; return createVertexShader$1(gl, vertexShaderSource); }
/**
 * Uploads a static quad of 4 interleaved vertices `(x, y, z, u, v)`. The quad
 * covers clip space from -1 to 1, with UVs from 0 to 1. The layout matches
 * the 3 + 2 float stride used by bindVertexProgramAttributeStreams.
 */
function createVertexBuffer(gl) { const vertexArray = new Float32Array([-1, 1, 0, 0, 1, -1, -1, 0, 0, 0, 1, 1, 0, 1, 1, 1, -1, 0, 1, 0]); return createStaticVertexBuffer(gl, vertexArray); }
/** Index buffer that draws the quad as two triangles: (0, 1, 2) and (2, 1, 3). */
function createIndexBuffer(gl) { const triangleVertexIndices = new Uint16Array([0, 1, 2, 2, 1, 3]); return createStaticIndexBuffer(gl, triangleVertexIndices); }
/**
 * Creates a `width` x `height` 2-D texture with CLAMP_TO_EDGE wrapping and
 * NEAREST filtering, and allocates its storage:
 * - WebGL 1: `texImage2D` with null data;
 * - otherwise: `texStorage2D` with a single level.
 * The texture is unbound before returning.
 * @returns {{texture: Object, texShape: number[]}} `texShape` is
 *     `[height, width]` (rows first).
 */
function createAndConfigureTexture(gl, width, height, internalFormat, textureFormat, textureType) { validateTextureSize(width, height); const texture = createTexture(gl); const tex2d = gl.TEXTURE_2D; callAndCheck(gl, () => gl.bindTexture(tex2d, texture)); callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE)); callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE)); 
callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST)); callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST)); if (env().getNumber('WEBGL_VERSION') === 1) { callAndCheck(gl, () => gl.texImage2D(tex2d, 0, internalFormat, width, height, 0, textureFormat, textureType, null)); } else { callAndCheck(gl, () => gl .texStorage2D(tex2d, 1, internalFormat, width, height)); } callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null)); return { texture, texShape: [height, width] }; }
/** Internal format for unpacked float32 textures (`textureConfig.internalFormatFloat`). */
function getInternalFormatForFloat32MatrixTexture(textureConfig) { return textureConfig.internalFormatFloat; }
/**
 * Creates an unpacked FLOAT texture for a `rows` x `columns` matrix, sized by
 * getUnpackedMatrixTextureShapeWidthHeight.
 */
function createFloat32MatrixTexture(gl, rows, columns, textureConfig) { const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns); return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat32MatrixTexture(textureConfig), textureConfig.textureFormatFloat, gl.FLOAT); }
/** Internal format for unpacked half-float textures (`textureConfig.internalFormatHalfFloat`). */
function getInternalFormatForFloat16MatrixTexture(textureConfig) { return textureConfig.internalFormatHalfFloat; }
/**
 * Creates an unpacked half-float texture for a `rows` x `columns` matrix. It
 * uses `textureFormatFloat` as the format and `textureTypeHalfFloat` as the
 * texel type.
 */
function createFloat16MatrixTexture(gl, rows, columns, textureConfig) { const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns); return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16MatrixTexture(textureConfig), textureConfig.textureFormatFloat, textureConfig.textureTypeHalfFloat); }
/** Internal format for byte (download) textures (`textureConfig.downloadTextureFormat`). */
function getInternalFormatForUnsignedBytesMatrixTexture(textureConfig) { return textureConfig.downloadTextureFormat; }
/** Creates an unpacked RGBA / UNSIGNED_BYTE texture for a `rows` x `columns` matrix. */
function createUnsignedBytesMatrixTexture(gl, rows, columns, textureConfig) { const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns); return createAndConfigureTexture(gl, width, height, getInternalFormatForUnsignedBytesMatrixTexture(textureConfig), gl.RGBA, gl.UNSIGNED_BYTE); }
/** Internal format for packed float32 textures (`textureConfig.internalFormatPackedFloat`). */
function getInternalFormatForPackedMatrixTexture(textureConfig) { return textureConfig.internalFormatPackedFloat; }
/**
 * Creates a packed RGBA / FLOAT texture for a `rows` x `columns` matrix,
 * sized by getPackedMatrixTextureShapeWidthHeight.
 */
function createPackedMatrixTexture(gl, rows, columns, 
textureConfig) { const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns); return createAndConfigureTexture(gl, width, height, getInternalFormatForPackedMatrixTexture(textureConfig), gl.RGBA, gl.FLOAT); }
/** Internal format for packed half-float textures (`textureConfig.internalFormatPackedHalfFloat`). */
function getInternalFormatForFloat16PackedMatrixTexture(textureConfig) { return textureConfig.internalFormatPackedHalfFloat; }
/** Creates a packed RGBA half-float texture for a `rows` x `columns` matrix. */
function createFloat16PackedMatrixTexture(gl, rows, columns, textureConfig) { const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns); return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16PackedMatrixTexture(textureConfig), gl.RGBA, textureConfig.textureTypeHalfFloat); }
/**
 * Binds `vertexBuffer` to ARRAY_BUFFER and points two attributes at it, with
 * a 20-byte stride:
 * - `clipSpacePos`: 3 floats at byte 0;
 * - `uv`: 2 floats at byte 12.
 * If the `clipSpacePos` binding returns falsy, `uv` is not bound and that
 * falsy value is returned.
 */
function bindVertexProgramAttributeStreams(gl, program, vertexBuffer) { const posOffset = 0; const uvOffset = 3 * 4; const stride = (3 * 4) + (2 * 4); callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer)); const success = bindVertexBufferToProgramAttribute(gl, program, 'clipSpacePos', vertexBuffer, 3, stride, posOffset); return success && bindVertexBufferToProgramAttribute(gl, program, 'uv', vertexBuffer, 2, stride, uvOffset); }
/**
 * Uploads `data` into `texture` as RGBA texels.
 * - `Uint8Array` data is sent as UNSIGNED_BYTE.
 * - Anything else is converted to Float32Array and sent as FLOAT.
 *
 * The data is copied into a zero-filled buffer of `width * height * 4`
 * values. Shorter input is therefore zero-padded; longer input makes
 * `set` throw a RangeError.
 *
 * WebGL 2 writes into the existing storage with `texSubImage2D`. WebGL 1
 * (re)specifies the image with `texImage2D`.
 */
function uploadDenseMatrixToTexture(gl, texture, width, height, data, textureConfig) { callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture)); let dataForUpload, texelDataType, internalFormat; if (data instanceof Uint8Array) { dataForUpload = new Uint8Array(width * height * 4); texelDataType = gl.UNSIGNED_BYTE; internalFormat = gl.RGBA; } else { dataForUpload = new Float32Array(width * height * 4); texelDataType = gl.FLOAT; internalFormat = textureConfig.internalFormatPackedFloat; } dataForUpload.set(data); if (env().getNumber('WEBGL_VERSION') === 2) { callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, width, height, gl.RGBA, texelDataType, dataForUpload)); } else { callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA, texelDataType, dataForUpload)); } callAndCheck(gl, () => 
gl.bindTexture(gl.TEXTURE_2D, null)); }
/**
 * Uploads RGBA / UNSIGNED_BYTE pixels into `texture`.
 * - If `pixels.data` is a Uint8Array, `pixels.width`/`pixels.height` are
 *   passed explicitly along with the data.
 * - Otherwise `pixels` itself is handed to WebGL as the image source
 *   (presumably an ImageData, image, canvas or video — confirm callers).
 *
 * WebGL 2 uses `texSubImage2D` into the existing storage; WebGL 1 uses
 * `texImage2D`.
 */
function uploadPixelDataToTexture(gl, texture, pixels) { callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture)); if (pixels.data instanceof Uint8Array) { if (env().getNumber('WEBGL_VERSION') === 2) { callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, pixels.width, pixels.height, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data)); } else { callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data)); } } else { if (env().getNumber('WEBGL_VERSION') === 2) { callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels)); } else { callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, pixels)); } } callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null)); }
/**
 * WebGL 2: creates a PIXEL_PACK_BUFFER of `16 * rows * columns` bytes and
 * issues `readPixels` of the currently bound framebuffer (`columns` x `rows`,
 * RGBA FLOAT) into it. It then unbinds the buffer and returns it.
 *
 * The data is read back later with getBufferSubData (see
 * downloadFloat32MatrixFromBuffer). `textureConfig` is unused.
 */
function createBufferFromOutputTexture(gl2, rows, columns, textureConfig) { const buffer = gl2.createBuffer(); callAndCheck(gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer)); const bytesPerFloat = 4; const valuesPerTexel = 4; const bufferSizeBytes = bytesPerFloat * valuesPerTexel * rows * columns; callAndCheck(gl2, () => gl2.bufferData(gl2.PIXEL_PACK_BUFFER, bufferSizeBytes, gl2.STREAM_READ)); callAndCheck(gl2, () => gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0)); callAndCheck(gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null)); return buffer; }
/** WebGL 2: copies the first `size` floats of a PIXEL_PACK_BUFFER into a new Float32Array. */
function downloadFloat32MatrixFromBuffer(gl, buffer, size) { const gl2 = gl; const downloadTarget = new Float32Array(size); gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer); gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget); gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null); return downloadTarget; }
/**
 * Reads the bound framebuffer as unsigned bytes in
 * `textureConfig.downloadTextureFormat`, then reinterprets those bytes as a
 * Float32Array (4 bytes per float).
 * Presumably this pairs with a float-to-bytes encoding shader such as
 * EncodeFloatProgram — confirm.
 */
function downloadByteEncodedFloatMatrixFromOutputTexture(gl, rows, columns, textureConfig) { const [w, h] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns); const numChannels = 4; const downloadTarget = new 
Uint8Array(getUnpackedArraySizeFromMatrixSize(rows * columns, numChannels)); callAndCheck(gl, () => gl.readPixels(0, 0, w, h, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE, downloadTarget)); return new Float32Array(downloadTarget.buffer); }
/**
 * WebGL 2: copies a packed RGBA float PIXEL_PACK_BUFFER into a new
 * Float32Array. The array is sized by
 * getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols).
 * `batch`, `rows`, `cols` and `textureConfig` are unused.
 */
function downloadPackedMatrixFromBuffer(gl, buffer, batch, rows, cols, physicalRows, physicalCols, textureConfig) { const gl2 = gl; const downloadTarget = new Float32Array(getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols)); gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer); gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget); gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null); return downloadTarget; }
/**
 * Reads `physicalCols` x `physicalRows` RGBA FLOAT texels from the bound
 * framebuffer into a new Float32Array of `physicalRows * physicalCols * 4`
 * values.
 */
function downloadMatrixFromPackedOutputTexture(gl, physicalRows, physicalCols) { const packedRGBA = new Float32Array(physicalRows * physicalCols * 4); callAndCheck(gl, () => gl.readPixels(0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, packedRGBA)); return packedRGBA; } class GPGPUContext { constructor(gl) { this.outputTexture = null; this.program = null; this.disposed = false; this.itemsToPoll = []; const glVersion = env().getNumber('WEBGL_VERSION'); if (gl != null) { this.gl = gl; setWebGLContext(glVersion, gl); } else { this.gl = getWebGLContext(glVersion); } gl = this.gl; if (env().getNumber('WEBGL_VERSION') === 2) { const gl2 = gl; this.createVertexArray = () => { return callAndCheck(gl2, () => gl2.createVertexArray()); }; this.bindVertexArray = (vao) => { return callAndCheck(gl2, () => gl2.bindVertexArray(vao)); }; this.deleteVertexArray = (vao) => { return callAndCheck(gl2, () => gl2.deleteVertexArray(vao)); }; this.getVertexArray = () => { return callAndCheck(gl2, () => gl2.getParameter(gl2.VERTEX_ARRAY_BINDING)); }; } else if (gl != null) { const ext = gl.getExtension('OES_vertex_array_object'); if (ext == null) { throw new Error('All WebGL1 implementations are expected to offer' + ' OES_vertex_array_object.'); } this.createVertexArray = () => { return callAndCheck(gl, () => 
ext.createVertexArrayOES()); }; this.bindVertexArray = (vao) => { return callAndCheck(gl, () => ext.bindVertexArrayOES(vao)); }; this.deleteVertexArray = (vao) => { return callAndCheck(gl, () => ext.deleteVertexArrayOES(vao)); }; this.getVertexArray = () => { return callAndCheck(gl, () => gl.getParameter(ext.VERTEX_ARRAY_BINDING_OES)); }; } let COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float'; const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float'; this.parallelCompilationExtension = this.gl.getExtension('KHR_parallel_shader_compile'); if (env().getNumber('WEBGL_VERSION') === 1) { const TEXTURE_FLOAT = 'OES_texture_float'; const TEXTURE_HALF_FLOAT = 'OES_texture_half_float'; this.textureFloatExtension = getExtensionOrThrow(this.gl, TEXTURE_FLOAT); if (hasExtension(this.gl, TEXTURE_HALF_FLOAT)) { this.textureHalfFloatExtension = getExtensionOrThrow(this.gl, TEXTURE_HALF_FLOAT); } else if (env().get('WEBGL_FORCE_F16_TEXTURES')) { throw new Error('GL context does not support half float textures, yet the ' + 'environment flag WEBGL_FORCE_F16_TEXTURES is set to true.'); } this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT); if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) { this.colorBufferHalfFloatExtension = getExtensionOrThrow(this.gl, COLOR_BUFFER_HALF_FLOAT); } else if (env().get('WEBGL_FORCE_F16_TEXTURES')) { throw new Error('GL context does not support color renderable half floats, yet ' + 'the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.'); } } else { COLOR_BUFFER_FLOAT = 'EXT_color_buffer_float'; if (hasExtension(this.gl, COLOR_BUFFER_FLOAT)) { this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT); } else if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) { this.colorBufferHalfFloatExtension = this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT); } else { throw new Error('GL context does not support color renderable floats'); } } this.vertexBuffer = createVertexBuffer(this.gl); this.indexBuffer = 
createIndexBuffer(this.gl); this.framebuffer = createFramebuffer(this.gl); this.textureConfig = getTextureConfig(this.gl, this.textureHalfFloatExtension); } get debug() { return env().getBool('DEBUG'); } dispose() { if (this.disposed) { return; } if (this.program != null) { console.warn('Disposing a GPGPUContext that still has a bound WebGLProgram.' + ' This is probably a resource leak, delete the program with ' + 'GPGPUContext.deleteProgram before disposing.'); } if (this.outputTexture != null) { console.warn('Disposing a GPGPUContext that still has a bound output matrix ' + 'texture. This is probably a resource leak, delete the output ' + 'matrix texture with GPGPUContext.deleteMatrixTexture before ' + 'disposing.'); } const gl = this.gl; callAndCheck(gl, () => gl.finish()); callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null)); callAndCheck(gl, () => gl.deleteFramebuffer(this.framebuffer)); callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, null)); callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null)); callAndCheck(gl, () => gl.deleteBuffer(this.indexBuffer)); this.disposed = true; } createFloat32MatrixTexture(rows, columns) { this.throwIfDisposed(); return createFloat32MatrixTexture(this.gl, rows, columns, this.textureConfig); } createFloat16MatrixTexture(rows, columns) { this.throwIfDisposed(); return createFloat16MatrixTexture(this.gl, rows, columns, this.textureConfig); } createUnsignedBytesMatrixTexture(rows, columns) { this.throwIfDisposed(); return createUnsignedBytesMatrixTexture(this.gl, rows, columns, this.textureConfig); } uploadPixelDataToTexture(texture, pixels) { this.throwIfDisposed(); uploadPixelDataToTexture(this.gl, texture, pixels); } uploadDenseMatrixToTexture(texture, width, height, data) { this.throwIfDisposed(); uploadDenseMatrixToTexture(this.gl, texture, width, height, data, this.textureConfig); } createFloat16PackedMatrixTexture(rows, columns) { this.throwIfDisposed(); return 
createFloat16PackedMatrixTexture(this.gl, rows, columns, this.textureConfig); } createPackedMatrixTexture(rows, columns) { this.throwIfDisposed(); return createPackedMatrixTexture(this.gl, rows, columns, this.textureConfig); } deleteMatrixTexture(texture) { this.throwIfDisposed(); if (this.outputTexture === texture) { unbindColorTextureFromFramebuffer(this.gl, this.framebuffer); this.outputTexture = null; } callAndCheck(this.gl, () => this.gl.deleteTexture(texture)); } downloadByteEncodedFloatMatrixFromOutputTexture(texture, rows, columns) { return this.downloadMatrixDriver(texture, () => downloadByteEncodedFloatMatrixFromOutputTexture(this.gl, rows, columns, this.textureConfig)); } downloadPackedMatrixFromBuffer(buffer, batch, rows, columns, physicalRows, physicalCols) { return downloadPackedMatrixFromBuffer(this.gl, buffer, batch, rows, columns, physicalRows, physicalCols); } downloadFloat32MatrixFromBuffer(buffer, size) { return downloadFloat32MatrixFromBuffer(this.gl, buffer, size); } createBufferFromTexture(texture, rows, columns) { this.bindTextureToFrameBuffer(texture); const result = createBufferFromOutputTexture(this.gl, rows, columns); this.unbindTextureToFrameBuffer(); return result; } createAndWaitForFence() { const fenceContext = this.createFence(this.gl); return this.pollFence(fenceContext); } createFence(gl) { let query; let isFencePassed; if (env().getBool('WEBGL_FENCE_API_ENABLED')) { const gl2 = gl; const sync = gl2.fenceSync(gl2.SYNC_GPU_COMMANDS_COMPLETE, 0); gl.flush(); isFencePassed = () => { const status = gl2.clientWaitSync(sync, 0, 0); return status === gl2.ALREADY_SIGNALED || status === gl2.CONDITION_SATISFIED; }; query = sync; } else if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { query = this.beginQuery(); this.endQuery(); isFencePassed = () => this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); } else { isFencePassed = () => true; } return { query, isFencePassed }; } 
downloadMatrixFromPackedTexture(texture, physicalRows, physicalCols) { return this.downloadMatrixDriver(texture, () => downloadMatrixFromPackedOutputTexture(this.gl, physicalRows, physicalCols)); } createProgram(fragmentShader) { this.throwIfDisposed(); const gl = this.gl; if (this.vertexShader == null) { this.vertexShader = createVertexShader(gl); } const program = createProgram(gl); callAndCheck(gl, () => gl.attachShader(program, this.vertexShader)); callAndCheck(gl, () => gl.attachShader(program, fragmentShader)); linkProgram(gl, program); const program2 = Object.assign(program, { vao: this.createVertexArray() }); if (this.debug) { validateProgram(gl, program2); } return program2; } buildVao(program) { this.setProgram(program); this.bindVertexArray(program.vao); const gl = this.gl; callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, this.indexBuffer)); bindVertexProgramAttributeStreams(gl, program, this.vertexBuffer); } deleteProgram(program) { this.throwIfDisposed(); if (program === this.program) { this.program = null; } if (program != null) { callAndCheck(this.gl, () => this.gl.deleteProgram(program)); this.deleteVertexArray(program.vao); } } setProgram(program) { this.throwIfDisposed(); this.program = program; if (this.program != null) { if (this.debug) { validateProgram(this.gl, this.program); } } callAndCheck(this.gl, () => this.gl.useProgram(program)); } getUniformLocation(program, uniformName, shouldThrow = true) { this.throwIfDisposed(); if (shouldThrow) { return getProgramUniformLocationOrThrow(this.gl, program, uniformName); } else { return getProgramUniformLocation(this.gl, program, uniformName); } } getAttributeLocation(program, attribute) { this.throwIfDisposed(); return callAndCheck(this.gl, () => this.gl.getAttribLocation(program, attribute)); } getUniformLocationNoThrow(program, uniformName) { this.throwIfDisposed(); return this.gl.getUniformLocation(program, uniformName); } setInputMatrixTexture(inputMatrixTexture, uniformLocation, 
textureUnit) { this.throwIfDisposed(); this.throwIfNoProgram(); bindTextureToProgramUniformSampler(this.gl, inputMatrixTexture, uniformLocation, textureUnit); } setOutputMatrixTexture(outputMatrixTexture, rows, columns) { this.setOutputMatrixTextureDriver(outputMatrixTexture, columns, rows); } setOutputPackedMatrixTexture(outputPackedMatrixTexture, rows, columns) { this.throwIfDisposed(); const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns); this.setOutputMatrixTextureDriver(outputPackedMatrixTexture, width, height); } setOutputMatrixWriteRegion(startRow, numRows, startColumn, numColumns) { this.setOutputMatrixWriteRegionDriver(startColumn, startRow, numColumns, numRows); } setOutputPackedMatrixWriteRegion(startRow, numRows, startColumn, numColumns) { throw new Error('setOutputPackedMatrixWriteRegion not implemented.'); } debugValidate() { if (this.program != null) { validateProgram(this.gl, this.program); } validateFramebuffer(this.gl); } executeProgram() { this.throwIfDisposed(); this.throwIfNoProgram(); const gl = this.gl; if (this.debug) { const boundVao = this.getVertexArray(); console.assert(boundVao === this.program.vao, 'VAO changed between setProgram and executeProgram!'); this.debugValidate(); } callAndCheck(gl, () => gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0)); } blockUntilAllProgramsCompleted() { this.throwIfDisposed(); callAndCheck(this.gl, () => this.gl.finish()); } getQueryTimerExtension() { if (this.disjointQueryTimerExtension == null) { this.disjointQueryTimerExtension = getExtensionOrThrow(this.gl, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ? 
'EXT_disjoint_timer_query_webgl2' : 'EXT_disjoint_timer_query'); } return this.disjointQueryTimerExtension; } getQueryTimerExtensionWebGL2() { return this.getQueryTimerExtension(); } getQueryTimerExtensionWebGL1() { return this.getQueryTimerExtension(); } beginQuery() { if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { const gl2 = this.gl; const ext = this.getQueryTimerExtensionWebGL2(); const query = gl2.createQuery(); gl2.beginQuery(ext.TIME_ELAPSED_EXT, query); return query; } const ext = this.getQueryTimerExtensionWebGL1(); const query = ext.createQueryEXT(); ext.beginQueryEXT(ext.TIME_ELAPSED_EXT, query); return query; } endQuery() { if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { const gl2 = this.gl; const ext = this.getQueryTimerExtensionWebGL2(); gl2.endQuery(ext.TIME_ELAPSED_EXT); return; } const ext = this.getQueryTimerExtensionWebGL1(); ext.endQueryEXT(ext.TIME_ELAPSED_EXT); } async waitForQueryAndGetTime(query) { await repeatedTry(() => this.disposed || this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'))); return this.getQueryTime(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); } getQueryTime(query, queryTimerVersion) { if (queryTimerVersion === 0) { return null; } if (queryTimerVersion === 2) { const gl2 = this.gl; const timeElapsedNanos = gl2.getQueryParameter(query, gl2.QUERY_RESULT); return timeElapsedNanos / 1000000; } else { const ext = this.getQueryTimerExtensionWebGL1(); const timeElapsedNanos = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT); return timeElapsedNanos / 1000000; } } isQueryAvailable(query, queryTimerVersion) { if (queryTimerVersion === 0) { return true; } if (queryTimerVersion === 2) { const gl2 = this.gl; const ext = this.getQueryTimerExtensionWebGL2(); const available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE); if (this.disjoint == null) { this.disjoint = 
this.gl.getParameter(ext.GPU_DISJOINT_EXT); } return available && !this.disjoint; } else { const ext = this.getQueryTimerExtensionWebGL1(); const available = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_AVAILABLE_EXT); if (this.disjoint == null) { this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT); } return available && !this.disjoint; } } pollFence(fenceContext) { return new Promise(resolve => { this.addItemToPoll(() => fenceContext.isFencePassed(), () => resolve()); }); } pollItems() { const index = linearSearchLastTrue(this.itemsToPoll.map(x => x.isDoneFn)); for (let i = 0; i <= index; ++i) { const { resolveFn } = this.itemsToPoll[i]; resolveFn(); } this.itemsToPoll = this.itemsToPoll.slice(index + 1); } addItemToPoll(isDoneFn, resolveFn) { this.itemsToPoll.push({ isDoneFn, resolveFn }); if (this.itemsToPoll.length > 1) { return; } let scheduleFn = undefined; if ('setTimeoutCustom' in env().platform) { scheduleFn = env().platform.setTimeoutCustom.bind(env().platform); } repeatedTry(() => { this.pollItems(); return this.itemsToPoll.length === 0; }, () => 0, null, scheduleFn); } bindTextureToFrameBuffer(texture) { this.throwIfDisposed(); bindColorTextureToFramebuffer(this.gl, texture, this.framebuffer); if (this.debug) { validateFramebuffer(this.gl); } } unbindTextureToFrameBuffer() { if (this.outputTexture != null) { bindColorTextureToFramebuffer(this.gl, this.outputTexture, this.framebuffer); if (this.debug) { validateFramebuffer(this.gl); } } else { unbindColorTextureFromFramebuffer(this.gl, this.framebuffer); } } downloadMatrixDriver(texture, downloadAndDecode) { this.bindTextureToFrameBuffer(texture); const result = downloadAndDecode(); this.unbindTextureToFrameBuffer(); return result; } setOutputMatrixTextureDriver(outputMatrixTextureMaybePacked, width, height) { this.throwIfDisposed(); const gl = this.gl; bindColorTextureToFramebuffer(gl, outputMatrixTextureMaybePacked, this.framebuffer); if (this.debug) { validateFramebuffer(gl); } 
this.outputTexture = outputMatrixTextureMaybePacked; callAndCheck(gl, () => gl.viewport(0, 0, width, height)); callAndCheck(gl, () => gl.scissor(0, 0, width, height)); } setOutputMatrixWriteRegionDriver(x, y, width, height) { this.throwIfDisposed(); callAndCheck(this.gl, () => this.gl.scissor(x, y, width, height)); } throwIfDisposed() { if (this.disposed) { throw new Error('Attempted to use disposed GPGPUContext.'); } } throwIfNoProgram() { if (this.program == null) { throw new Error('No GPU program is currently set.'); } } } function linearSearchLastTrue(arr) { let i = 0; for (; i < arr.length; ++i) { const isDone = arr[i](); if (!isDone) { break; } } return i - 1; } function simpleAbsImpl(vals) { const resultValues = new Float32Array(vals.length); for (let i = 0; i < vals.length; ++i) { resultValues[i] = Math.abs(vals[i]); } return resultValues; } function createSimpleBinaryKernelImpl(op) { return (aShape, bShape, aVals, bVals, dtype) => { const newShape = assertAndGetBroadcastShape(aShape, bShape); const resultRank = newShape.length; const resultStrides = computeStrides(newShape); const resultSize = sizeFromShape(newShape); const result = getTypedArrayFromDType(dtype, resultSize); const aRank = aShape.length; const bRank = bShape.length; const aStrides = computeStrides(aShape); const bStrides = computeStrides(bShape); const aBroadcastDims = getBroadcastDims$1(aShape, newShape); const bBroadcastDims = getBroadcastDims$1(bShape, newShape); if (aBroadcastDims.length + bBroadcastDims.length === 0) { for (let i = 0; i < result.length; ++i) { result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]); } } else { for (let i = 0; i < result.length; ++i) { const loc = indexToLoc(i, resultRank, resultStrides); const aLoc = loc.slice(-aRank); aBroadcastDims.forEach(d => aLoc[d] = 0); const aIndex = locToIndex(aLoc, aRank, aStrides); const bLoc = loc.slice(-bRank); bBroadcastDims.forEach(d => bLoc[d] = 0); const bIndex = locToIndex(bLoc, bRank, bStrides); result[i] = 
op(aVals[aIndex], bVals[bIndex]); } } return [result, newShape]; }; } function castImpl(values, shape, inputType, dtype) { if (dtype === 'int32') { const resultValues = Int32Array.from(values); return [shape, 'int32', resultValues]; } if (dtype === 'bool') { const zero = toTypedArray([0], inputType); const [resultData, resultShape] = createSimpleBinaryKernelImpl((a, b) => (a !== b) ? 1 : 0)(shape, [], values, zero, 'bool'); return [resultShape, 'bool', resultData]; } throw new Error(`Error in Cast: failed to cast ${inputType} to ${dtype}`); } const addImpl = createSimpleBinaryKernelImpl(((a, b) => a + b)); function bincountImpl(xVals, weightsVals, weightsDtype, weightsShape, size) { const weightsSize = sizeFromShape(weightsShape); const outVals = makeZerosTypedArray(size, weightsDtype); for (let i = 0; i < xVals.length; i++) { const value = xVals[i]; if (value < 0) { throw new Error('Input x must be non-negative!'); } if (value >= size) { continue; } if (weightsSize > 0) { outVals[value] += weightsVals[i]; } else { outVals[value] += 1; } } return outVals; } function bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput = false) { const numRows = xBuf.shape[0]; const numCols = xBuf.shape[1]; const outBuf = buffer([numRows, size], weightsBuf.dtype); for (let i = 0; i < numRows; i++) { for (let j = 0; j < numCols; j++) { const value = xBuf.get(i, j); if (value < 0) { throw new Error('Input x must be non-negative!'); } if (value >= size) { continue; } if (binaryOutput) { outBuf.set(1, i, value); } else { if (weightsBuf.size > 0) { outBuf.set(outBuf.get(i, value) + weightsBuf.get(i, j), i, value); } else { outBuf.set(outBuf.get(i, value) + 1, i, value); } } } } return outBuf; } const bitwiseAndImpl = createSimpleBinaryKernelImpl(((a, b) => a & b)); function createSimpleUnaryImpl(op) { return (values, dtype, attrs) => { const newValues = getArrayFromDType(dtype, values.length); for (let i = 0; i < values.length; ++i) { newValues[i] = op(values[i], attrs); } return 
newValues; }; } const ceilImpl = createSimpleUnaryImpl((xi) => Math.ceil(xi)); function concatImpl$1(inputs, outShape, dtype, simplyConcat) { const outVals = getArrayFromDType(dtype, sizeFromShape(outShape)); if (simplyConcat && dtype !== 'string') { let offset = 0; inputs.forEach(input => { const size = sizeFromShape(input.shape); outVals.set(input.vals, offset); offset += size; }); } else { let colOffset = 0; inputs.forEach(input => { const decodedData = dtype === 'string' ? fromUint8ToStringArray(input.vals) : input.vals; let tIdx = 0; for (let row = 0; row < input.shape[0]; ++row) { const resIdx = row * outShape[1] + colOffset; for (let col = 0; col < input.shape[1]; ++col) { outVals[resIdx + col] = decodedData[tIdx++]; } } colOffset += input.shape[1]; }); } return outVals; } const equalImpl = createSimpleBinaryKernelImpl((a, b) => (a === b) ? 1 : 0); const expImpl = createSimpleUnaryImpl((xi) => Math.exp(xi)); const expm1Impl = createSimpleUnaryImpl((xi) => Math.expm1(xi)); const floorImpl = createSimpleUnaryImpl((xi) => Math.floor(xi)); const floorDivImpl = createSimpleBinaryKernelImpl((a, b) => Math.floor(a / b)); function gatherNdImpl(indicesData, paramsBuf, dtype, numSlices, sliceRank, sliceSize, strides, paramsShape, paramsSize) { const outBuf = buffer([numSlices, sliceSize], dtype); for (let i = 0; i < numSlices; i++) { const index = []; let flattenIndex = 0; for (let j = 0; j < sliceRank; j++) { const dim = indicesData[i * sliceRank + j]; flattenIndex += dim * strides[j]; index.push(dim); } if (flattenIndex < 0 || flattenIndex >= paramsSize / sliceSize) { throw new Error(`Invalid indices: ${index} does not index into ${paramsShape}`); } for (let k = 0; k < sliceSize; k++) { outBuf.values[i * sliceSize + k] = paramsBuf.get(...paramsBuf.indexToLoc(flattenIndex * sliceSize + k)); } } return outBuf; } function gatherV2Impl(xBuf, indicesBuf, flattenOutputShape) { const outBuf = buffer(flattenOutputShape, xBuf.dtype); for (let i = 0; i < outBuf.size; ++i) { 
const newLoc = outBuf.indexToLoc(i); const originalLoc = newLoc.slice(); const batchIdx = originalLoc[0]; const indicesIdx = originalLoc[2]; const indicesIndex = indicesBuf.locToIndex([batchIdx, indicesIdx]); originalLoc[2] = indicesBuf.values[indicesIndex]; const originalIndex = xBuf.locToIndex(originalLoc); if (0 <= originalIndex && originalIndex < xBuf.values.length) { outBuf.values[i] = xBuf.values[originalIndex]; } } return outBuf; } const greaterImpl = createSimpleBinaryKernelImpl((a, b) => (a > b) ? 1 : 0); const greaterEqualImpl = createSimpleBinaryKernelImpl((a, b) => (a >= b) ? 1 : 0); const lessImpl = createSimpleBinaryKernelImpl((a, b) => (a < b) ? 1 : 0); const lessEqualImpl = createSimpleBinaryKernelImpl((a, b) => (a <= b) ? 1 : 0); function linSpaceImpl(start, stop, num) { const step = (stop - start) / (num - 1); const values = makeZerosTypedArray(num, 'float32'); values[0] = start; for (let i = 1; i < values.length; i++) { values[i] = values[i - 1] + step; } return values; } const logImpl = createSimpleUnaryImpl((xi) => Math.log(xi)); function maxImpl$1(aVals, reduceSize, outShape, dtype) { const vals = getTypedArrayFromDType(dtype, sizeFromShape(outShape)); for (let i = 0; i < vals.length; ++i) { const offset = i * reduceSize; let max = aVals[offset]; for (let j = 0; j < reduceSize; ++j) { const value = aVals[offset + j]; if (Number.isNaN(value) || value > max) { max = value; } } vals[i] = max; } return vals; } const maximumImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => Math.max(aValue, bValue))); const minimumImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => Math.min(aValue, bValue))); const multiplyImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => aValue * bValue)); function negImpl(xVals, xShape, xDtype) { const minusOne = createScalarValue(-1, xDtype); return multiplyImpl([], xShape, minusOne, xVals, xDtype); } const notEqualImpl = createSimpleBinaryKernelImpl(((a, b) => (a !== b) ? 
1 : 0)); function transposeImpl$1(xVals, xShape, dtype, perm, newShape) { const xRank = xShape.length; const xSize = sizeFromShape(xShape); const xStrides = computeStrides(xShape); const newStrides = computeStrides(newShape); const result = getTypedArrayFromDType(dtype, sizeFromShape(newShape)); for (let i = 0; i < xSize; ++i) { const loc = indexToLoc(i, xRank, xStrides); const newLoc = new Array(loc.length); for (let i = 0; i < newLoc.length; i++) { newLoc[i] = loc[perm[i]]; } const newIndex = locToIndex(newLoc, xRank, newStrides); result[newIndex] = xVals[i]; } return result; } function prodImpl(xShape, xDtype, xVals, reductionAxes) { const [outShape, reduceShape] = computeOutAndReduceShapes(xShape, reductionAxes); const outDtype = upcastType(xDtype, 'int32'); const outVals = makeZerosTypedArray(sizeFromShape(outShape), outDtype); const reduceSize = sizeFromShape(reduceShape); for (let i = 0; i < outVals.length; ++i) { const offset = i * reduceSize; let prod = 1; for (let j = 0; j < reduceSize; ++j) { prod *= xVals[offset + j]; } outVals[i] = prod; } return { outVals, outShape, outDtype }; } function validateIndices(indices, indicesShape, numParams) { indices.forEach((index, i) => { if (index < 0 || index >= numParams) { const locString = indexToLoc(i, indicesShape.length, computeStrides(indicesShape)) .join(','); throw new Error(`indices[${locString}] = ${index} is not in [0, ${numParams})`); } }); } function validateSplits(paramsNestedSplits, numParamsDenseValues) { for (let dim = 0; dim < paramsNestedSplits.length; ++dim) { const splits = paramsNestedSplits[dim]; const lastSplit = (dim === paramsNestedSplits.length - 1) ? 
numParamsDenseValues : paramsNestedSplits[dim + 1].length; if (splits.length === 0) { throw new Error('Ragged splits may not be empty'); } if (splits[0] < 0) { throw new Error('Ragged splits must be non-negative'); } if (splits[splits.length - 1] > lastSplit) { throw new Error('Ragged splits must not point past values'); } for (let i = 1; i < splits.length; ++i) { if (splits[i - 1] > splits[i]) { throw new Error('Ragged splits must be sorted in ascending order'); } } } } function makeSplits(indices, indicesShape, paramsNestedSplits, numParamsDenseValues) { const valueSlices = []; let numValues = 0; const numSplits = indicesShape.length - 1 + paramsNestedSplits.length; const outSplits = new Array(numSplits).fill(null).map(() => [0]); validateSplits(paramsNestedSplits, numParamsDenseValues); let nrows = 1; for (let dim = 0; dim < indicesShape.length - 1; ++dim) { nrows *= indicesShape[dim]; const rowLength = indicesShape[dim + 1]; for (let i = 1; i < nrows + 1; ++i) { outSplits[dim].push(i * rowLength); } } for (let i = 0; i < indices.length; ++i) { let start = indices[i]; let limit = indices[i] + 1; for (let dim = 0; dim < paramsNestedSplits.length; ++dim) { const splits = paramsNestedSplits[dim]; const outDim = dim + indicesShape.length - 1; if (outDim >= 0) { const outSplitsOutDim = outSplits[outDim]; const delta = outSplitsOutDim[outSplitsOutDim.length - 1] - splits[start]; for (let j = start; j < limit; ++j) { outSplits[outDim].push(splits[j + 1] + delta); } } start = splits[start]; limit = splits[limit]; } if (limit !== start) { valueSlices.push([start, limit]); numValues += limit - start; } } return { outSplits, valueSlices, numValues }; } function getSplits(outSplits) { const splitsOut = []; for (let i = 0; i < outSplits.length; ++i) { const numSplits = outSplits[i].length; const splits = getArrayFromDType('int32', numSplits); splitsOut.push(splits); outSplits[i].forEach((value, j) => splits[j] = value); } return splitsOut; } function 
computeFlatOuterDims(orig, numOutDims) { const outDims = orig.slice(0, numOutDims); while (outDims.length < numOutDims) { outDims.push(1); } for (let inDim = numOutDims; inDim < orig.length; inDim++) { outDims[numOutDims - 1] *= orig[inDim]; } return outDims; } function writeValueSlices(paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, values, valuesShape) { const denseM = computeFlatOuterDims(paramsDenseValuesShape, 2)[1]; const valuesM = computeFlatOuterDims(valuesShape, 2)[1]; let outPos = 0; for (const slice of valueSlices) { for (let i = slice[0]; i < slice[1]; ++i) { for (let j = 0; j < valueSize; ++j) { values[outPos * valuesM + j] = paramsDenseValues[i * denseM + j]; } ++outPos; } } } function getValues(paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, valueSlices, numValues) { const valuesShape = paramsDenseValuesShape.slice(); valuesShape[0] = numValues; const valuesOut = getArrayFromDType(paramsDenseValuesDType, sizeFromShape(valuesShape)); const numElements = paramsDenseValues.length; const valueSize = numElements === 0 ? 
0 : (numElements / paramsDenseValuesShape[0]); writeValueSlices(paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, valuesOut, valuesShape); return [valuesOut, valuesShape]; } function raggedGatherImpl(paramsNestedSplits, paramsNestedSplitsShapes, paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, indices, indicesShape, outputRaggedRank) { if (paramsNestedSplits.length === 0) { throw new Error('paramsNestedSplits must be non empty'); } if (paramsNestedSplitsShapes[0].length === 0) { throw new Error('Split tensors must not be scalars'); } const numParams = paramsNestedSplitsShapes[0][0] - 1; validateIndices(indices, indicesShape, numParams); if (paramsDenseValuesShape.length === 0) { throw new Error('params.rank must be nonzero'); } const numParamsDenseValues = paramsDenseValuesShape[0]; const { outSplits, valueSlices, numValues } = makeSplits(indices, indicesShape, paramsNestedSplits, numParamsDenseValues); const outputNestedSplits = getSplits(outSplits); const outputDenseValues = getValues(paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, valueSlices, numValues); return [outputNestedSplits, outputDenseValues[0], outputDenseValues[1]]; } const INT32_MAX = 2147483647; function raggedRangeImpl(starts, startsShape, startsDType, limits, limitsShape, deltas, deltasShape) { if (startsShape.length > 1) { throw new Error('starts must be a scalar or vector'); } if (limitsShape.length > 1) { throw new Error('limits must be a scalar or vector'); } if (deltasShape.length > 1) { throw new Error('deltas must be a scalar or vector'); } const broadcastStarts = startsShape.length === 0; const broadcastLimits = limitsShape.length === 0; const broadcastDeltas = deltasShape.length === 0; const inSizes = []; if (!broadcastStarts) { inSizes.push(startsShape[0]); } if (!broadcastLimits) { inSizes.push(limitsShape[0]); } if (!broadcastDeltas) { inSizes.push(deltasShape[0]); } for (let i = 1; i < inSizes.length; ++i) { if (inSizes[i] !== 
inSizes[i - 1]) { throw new Error('starts, limits, and deltas must have the same shape'); } } const nRows = inSizes.length === 0 ? 1 : inSizes[0]; const rtNestedSplits = getArrayFromDType('int32', nRows + 1); rtNestedSplits[0] = 0; for (let row = 0; row < nRows; ++row) { const start = broadcastStarts ? starts[0] : starts[row]; const limit = broadcastLimits ? limits[0] : limits[row]; const delta = broadcastDeltas ? deltas[0] : deltas[row]; if (delta === 0) { throw new Error('Requires delta != 0'); } let size; if (((delta > 0) && (limit < start)) || ((delta < 0) && (limit > start))) { size = 0; } else { size = Math.ceil(Math.abs((limit - start) / delta)); if (size > INT32_MAX) { throw new Error(`Requires ((limit - start) / delta) <= ${INT32_MAX}`); } } rtNestedSplits[row + 1] = rtNestedSplits[row] + size; } const nVals = rtNestedSplits[nRows]; const rtDenseValues = getArrayFromDType(startsDType, nVals); let valueIndex = 0; for (let row = 0; row < nRows; ++row) { const rowSize = rtNestedSplits[row + 1] - rtNestedSplits[row]; let value = broadcastStarts ? starts[0] : starts[row]; const delta = broadcastDeltas ? 
deltas[0] : deltas[row]; for (let i = 0; i < rowSize; ++i) { rtDenseValues[valueIndex++] = value; value += delta; } } return [rtNestedSplits, rtDenseValues]; } var RowPartitionType = RowPartitionType$1; class RaggedTensorToTensorOp { constructor(shape, shapeShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypeStrings) { this.shape = shape; this.shapeShape = shapeShape; this.values = values; this.valuesShape = valuesShape; this.valuesDType = valuesDType; this.defaultValue = defaultValue; this.defaultValueShape = defaultValueShape; this.rowPartitionValues = rowPartitionValues; this.rowPartitionValuesShapes = rowPartitionValuesShapes; this.rowPartitionTypes = getRowPartitionTypesHelper(rowPartitionTypeStrings); this.raggedRank = getRaggedRank(this.rowPartitionTypes); } getRowPartitionTypeByDimension(dimension) { if (this.rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) { return this.rowPartitionTypes[dimension + 1]; } else { return this.rowPartitionTypes[dimension]; } } getRowPartitionTensor(dimension) { if (this.rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) { return this.rowPartitionValues[dimension + 1]; } else { return this.rowPartitionValues[dimension]; } } getMaxWidth(dimension) { const rowPartitionTensor = this.getRowPartitionTensor(dimension - 1); switch (this.getRowPartitionTypeByDimension(dimension - 1)) { case RowPartitionType.VALUE_ROWIDS: return RaggedTensorToTensorOp.getMaxWidthValueRowID(rowPartitionTensor); case RowPartitionType.ROW_SPLITS: return RaggedTensorToTensorOp.getMaxWidthRowSplit(rowPartitionTensor); default: throw new Error(`Cannot handle partition type ${RowPartitionType[this.getRowPartitionTypeByDimension(dimension - 1)]}`); } } static getMaxWidthRowSplit(rowSplit) { const tensorLength = rowSplit.length; if (tensorLength === 0 || tensorLength === 1) { return 0; } let maxWidth = 0; for (let i = 0; i < tensorLength - 1; ++i) { const 
currentWidth = rowSplit[i + 1] - rowSplit[i]; if (currentWidth > maxWidth) { maxWidth = currentWidth; } } return maxWidth; } static getMaxWidthValueRowID(valueRowIds) { const indexLength = valueRowIds.length; if (indexLength === 0) { return 0; } let firstEqualIndex = 0; let firstEqualIndexValue = valueRowIds[0]; let maxWidth = 0; for (let i = 1; i < indexLength; ++i) { const value = valueRowIds[i]; if (value !== firstEqualIndexValue) { firstEqualIndexValue = value; maxWidth = Math.max(i - firstEqualIndex, maxWidth); firstEqualIndex = i; } } return Math.max(indexLength - firstEqualIndex, maxWidth); } tensorShapeFromTensor(t, tShape, isPartial = true) { if (tShape.length === 0) { if (t[0] === -1) { return []; } throw new Error(`The only valid scalar shape tensor is the fully unknown shape specified as -1.`); } return makeShape(t, isPartial); } calculateOutputSize(firstDim) { const valueShape = this.valuesShape; const defaultValueShape = this.defaultValueShape; validateDefaultValueShape(defaultValueShape, valueShape); const shape = this.tensorShapeFromTensor(this.shape, this.shapeShape); const outputShape = combineRaggedTensorToTensorShapes(this.raggedRank, shape, valueShape); const result = outputShape; if (result[0] < 0) { result[0] = firstDim; } for (let i = 1; i <= this.raggedRank; ++i) { if (result[i] < 0) { result[i] = this.getMaxWidth(i); } } return result; } calculateFirstParentOutputIndex(firstDimension, outputIndexMultiplier, firstDimensionOutput) { const minDimension = Math.min(firstDimension, firstDimensionOutput); const result = []; let currentOutputIndex = 0; for (let i = 0; i < minDimension; ++i, currentOutputIndex += outputIndexMultiplier) { result.push(currentOutputIndex); } for (let i = minDimension; i < firstDimension; ++i) { result.push(-1); } assert$1(result.length === firstDimension, () => 'Final length of result must be equal to firstDimension.'); return result; } calculateOutputIndexRowSplit(rowSplit, parentOutputIndex, outputIndexMultiplier, 
outputSize) { const rowSplitSize = rowSplit.length; const result = []; for (let i = 0; i < rowSplitSize - 1; ++i) { const rowLength = rowSplit[i + 1] - rowSplit[i]; let realLength = Math.min(outputSize, rowLength); let parentOutputIndexCurrent = parentOutputIndex[i]; if (parentOutputIndexCurrent === -1) { realLength = 0; } for (let j = 0; j < realLength; ++j) { result.push(parentOutputIndexCurrent); parentOutputIndexCurrent += outputIndexMultiplier; } for (let j = 0; j < rowLength - realLength; ++j) { result.push(-1); } } if (rowSplitSize > 0 && result.length !== rowSplit[rowSplitSize - 1]) { throw new Error('Invalid row split size.'); } return result; } calculateOutputIndexValueRowID(valueRowIds, parentOutputIndex, outputIndexMultiplier, outputSize) { const indexSize = valueRowIds.length; const result = []; if (indexSize === 0) { return []; } let currentOutputColumn = 0; let currentValueRowId = valueRowIds[0]; if (currentValueRowId >= parentOutputIndex.length) { throw new Error(`Got currentValueRowId=${currentValueRowId}, which is not less than ${parentOutputIndex.length}`); } let currentOutputIndex = parentOutputIndex[currentValueRowId]; result.push(currentOutputIndex); for (let i = 1; i < indexSize; ++i) { const nextValueRowId = valueRowIds[i]; if (nextValueRowId === currentValueRowId) { if (currentOutputIndex >= 0) { ++currentOutputColumn; if (currentOutputColumn < outputSize) { currentOutputIndex += outputIndexMultiplier; } else { currentOutputIndex = -1; } } } else { currentOutputColumn = 0; currentValueRowId = nextValueRowId; if (nextValueRowId >= parentOutputIndex.length) { throw new Error(`Got nextValueRowId=${nextValueRowId} which is not less than ${parentOutputIndex.length}`); } currentOutputIndex = parentOutputIndex[nextValueRowId]; } result.push(currentOutputIndex); } if (result.length !== valueRowIds.length) { throw new Error('Invalid row ids.'); } return result; } calculateOutputIndex(dimension, parentOutputIndex, outputIndexMultiplier, outputSize) { 
const rowPartitionTensor = this.getRowPartitionTensor(dimension); const partitionType = this.getRowPartitionTypeByDimension(dimension); switch (partitionType) { case RowPartitionType.VALUE_ROWIDS: return this.calculateOutputIndexValueRowID(rowPartitionTensor, parentOutputIndex, outputIndexMultiplier, outputSize); case RowPartitionType.ROW_SPLITS: if (rowPartitionTensor.length - 1 > parentOutputIndex.length) { throw new Error(`Row partition size is greater than output size: ${rowPartitionTensor.length - 1} > ${parentOutputIndex.length}`); } return this.calculateOutputIndexRowSplit(rowPartitionTensor, parentOutputIndex, outputIndexMultiplier, outputSize); default: throw new Error(`Unsupported partition type: ${RowPartitionType[partitionType]}`); } } getFirstDimensionSize() { const firstPartitionTensor = this.rowPartitionValues[0]; if (this.rowPartitionTypes.length === 0) { throw new Error('No row_partition_types given.'); } const firstPartitionType = this.rowPartitionTypes[0]; switch (firstPartitionType) { case RowPartitionType.FIRST_DIM_SIZE: return firstPartitionTensor[0]; case RowPartitionType.VALUE_ROWIDS: throw new Error('Cannot handle VALUE_ROWIDS in first dimension.'); case RowPartitionType.ROW_SPLITS: return this.rowPartitionValuesShapes[0][0] - 1; default: throw new Error(`Cannot handle type ${RowPartitionType[firstPartitionType]}`); } } compute() { const firstPartitionTensor = this.rowPartitionValues[0]; if (firstPartitionTensor.length <= 0) { throw new Error('Invalid first partition input. 
' + 'Tensor requires at least one element.'); } const firstDimension = this.getFirstDimensionSize(); const outputSize = this.calculateOutputSize(firstDimension); const multiplier = new Array(this.raggedRank + 1); multiplier[multiplier.length - 1] = 1; for (let i = multiplier.length - 2; i >= 0; --i) { multiplier[i] = multiplier[i + 1] * outputSize[i + 1]; } const outputShape = makeShape(outputSize, false); const outputTensor = getArrayFromDType(this.valuesDType, sizeFromShape(outputShape)); const fullSize = multiplier[0] * outputSize[0]; if (fullSize > 0) { let outputIndex = this.calculateFirstParentOutputIndex(firstDimension, multiplier[0], outputSize[0]); for (let i = 1; i <= this.raggedRank; ++i) { const newOutputIndex = this.calculateOutputIndex(i - 1, outputIndex, multiplier[i], outputSize[i]); outputIndex = newOutputIndex; } this.setOutput(this.raggedRank, outputIndex, outputTensor, outputShape); } return [outputShape, outputTensor]; } setOutput(raggedRank, outputIndex, outputTensor, outputShape) { if (outputTensor.length === 0) { return; } const valuesBase = this.values; const outputBase = outputTensor; let elementShape = outputShape.slice(); elementShape = elementShape.slice(raggedRank + 1); const valueElementSize = sizeFromShape(elementShape); const outputIndexSize = outputIndex.length; let defaultValue = this.defaultValue; if (defaultValue.length !== valueElementSize && defaultValue.length !== 1) { const srcShape = this.defaultValueShape; tidy(() => { const defaultValueTensor = reshape$1(defaultValue, srcShape); const bCastDefault = broadcastTo(defaultValueTensor, elementShape); defaultValue = bCastDefault.dataSync(); }); } let srcStart = 0; let dstStart = 0; let dstEnd = 0; for (let srcI = 0; srcI <= outputIndexSize; ++srcI) { let dstI = srcI < outputIndexSize ? 
outputIndex[srcI] : -1; if (dstI === dstEnd) { ++dstEnd; continue; } if (dstStart < dstEnd) { const src = valuesBase.subarray(srcStart * valueElementSize); const dst = outputBase.subarray(dstStart * valueElementSize); const nVals = (dstEnd - dstStart) * valueElementSize; copyArray(dst, src, nVals); } if (srcI >= outputIndexSize) { const outputSize = outputTensor.length; dstI = Math.floor(outputSize / valueElementSize); } if (dstI > dstEnd) { if (this.defaultValue.length === 1) { outputBase .subarray(dstEnd * valueElementSize, dstI * valueElementSize) .fill(this.defaultValue[0]); dstEnd = dstI; } else { while (dstI > dstEnd) { const dst = outputBase.slice(dstEnd * valueElementSize); copyArray(dst, defaultValue, valueElementSize); ++dstEnd; } } } if (dstI < 0) { srcStart = srcI + 1; dstStart = dstEnd; } else { srcStart = srcI; dstStart = dstEnd; dstEnd = dstStart + 1; } } } } function copyArray(dst, src, size) { for (let i = 0; i < size; i++) { dst[i] = src[i]; } } function makeShape(shape, isPartial) { const out = []; for (let dim of shape) { if (dim < 0) { if (!isPartial) { throw new Error(`Dimension ${dim} must be >= 0`); } if (dim < -1) { throw new Error(`Dimension ${dim} must be >= -1`); } dim = -1; } out.push(dim); } return out; } function raggedTensorToTensorImpl(shape, shapesShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes) { return new RaggedTensorToTensorOp(shape, shapesShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes) .compute(); } function rangeImpl(start, stop, step, dtype) { const sameStartStop = start === stop; const increasingRangeNegativeStep = start < stop && step < 0; const decreasingRangePositiveStep = stop < start && step > 1; if (sameStartStop || increasingRangeNegativeStep || decreasingRangePositiveStep) { return makeZerosTypedArray(0, dtype); } const numElements = 
Math.abs(Math.ceil((stop - start) / step)); const values = makeZerosTypedArray(numElements, dtype); if (stop < start && step === 1) { step = -1; } values[0] = start; for (let i = 1; i < values.length; i++) { values[i] = values[i - 1] + step; } return values; } const rsqrtImpl = createSimpleUnaryImpl((xi) => 1 / Math.sqrt(xi)); function scatterImpl(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices) { const flattenShape = [outputSize / sliceSize, sliceSize]; const indicesData = indices.values; const updatesData = updates.values; if (outputSize === 0) { return buffer(shape, updates.dtype); } const outBuf = (defaultValue instanceof TensorBuffer) ? defaultValue : buffer(flattenShape, updates.dtype); if (typeof defaultValue === 'string') { outBuf.values.fill(defaultValue); } else if (typeof defaultValue === 'number') { outBuf.values.fill(defaultValue); } else if (typeof defaultValue === 'boolean') { outBuf.values.fill(+defaultValue); } for (let i = 0; i < numUpdates; i++) { const index = []; let flattenIndex = 0; for (let j = 0; j < sliceRank; j++) { const dim = indicesData[i * sliceRank + j]; index.push(dim); flattenIndex += dim * strides[j]; } if (flattenIndex < 0 || flattenIndex >= outputSize / sliceSize) { throw new Error(`Invalid indices: ${index} does not index into ${shape}`); } for (let k = 0; k < sliceSize; k++) { if (sumDupeIndices) { outBuf.values[flattenIndex * sliceSize + k] += updatesData[i * sliceSize + k]; } else { outBuf.values[flattenIndex * sliceSize + k] = updates.rank === 0 ? 
updatesData[0] : updatesData[i * sliceSize + k]; } } } return outBuf; } const sigmoidImpl = createSimpleUnaryImpl((xi) => 1 / (1 + Math.exp(-xi))); function sliceImpl(vals, begin, size, shape, dtype) { const isContinous = isSliceContinous(shape, begin, size); const length = sizeFromShape(size); const xStrides = computeStrides(shape); if (isContinous) { const flatOffset = computeFlatOffset(begin, xStrides); if (dtype === 'string') { return vals.slice(flatOffset, flatOffset + length); } return vals.subarray(flatOffset, flatOffset + length); } const decodedData = dtype === 'string' ? fromUint8ToStringArray(vals) : vals; const inBuf = buffer(shape, dtype, decodedData); const outBuf = buffer(size, dtype); for (let i = 0; i < outBuf.size; ++i) { const outLoc = outBuf.indexToLoc(i); const inLoc = outLoc.map((idx, j) => idx + begin[j]); outBuf.set(inBuf.get(...inLoc), ...outLoc); } if (dtype === 'string') { return fromStringArrayToUint8(outBuf.values); } return outBuf.values; } function sparseFillEmptyRowsImpl(indices, indicesShape, indicesDType, values, valuesDType, denseShape, defaultValue) { const indicesCount = indicesShape[0]; const denseRows = denseShape[0]; const emptyRowIndicator = new Array(denseRows); const reverseIndexMap = new Array(indicesCount); const rank = indicesShape[1]; if (denseRows === 0) { if (indicesCount !== 0) { throw new Error(getSparseFillEmptyRowsIndicesDenseShapeMismatch(indicesCount)); } const outputIndices = getArrayFromDType(indicesDType, 0); const outputValues = getArrayFromDType(valuesDType, 0); return [ outputIndices, [0, rank], outputValues, emptyRowIndicator, reverseIndexMap ]; } let rowsAreOrdered = true; let lastIndicesRow = 0; const csrOffset = new Array(denseRows).fill(0); for (let i = 0; i < indicesCount; ++i) { const row = indices[i * rank]; if (row < 0) { throw new Error(getSparseFillEmptyRowsNegativeIndexErrorMessage(i, row)); } if (row >= denseRows) { throw new Error(getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(i, row, 
denseRows)); } ++csrOffset[row]; rowsAreOrdered = rowsAreOrdered && (row >= lastIndicesRow); lastIndicesRow = row; }
// Record which rows are empty and turn csrOffset into cumulative end offsets.
// An empty row still reserves one slot for its default entry.
let allRowsFull = true; for (let row = 0; row < denseRows; ++row) { const rowEmpty = (csrOffset[row] === 0); emptyRowIndicator[row] = rowEmpty; allRowsFull = allRowsFull && !rowEmpty; csrOffset[row] = Math.max(csrOffset[row], 1); if (row > 0) { csrOffset[row] += csrOffset[row - 1]; } }
// Fast path: no row is empty and row ids never decrease, so the input
// indices and values are already in output order and are returned as-is.
if (allRowsFull && rowsAreOrdered) { const outputIndices = indices; const outputValues = values; for (let i = 0; i < indicesCount; ++i) { reverseIndexMap[i] = i; } return [ outputIndices, [indicesCount, rank], outputValues, emptyRowIndicator, reverseIndexMap ]; } else { const fullIndicesCount = csrOffset[denseRows - 1]; const outputIndices = getArrayFromDType(indicesDType, fullIndicesCount * rank); const outputValues = getArrayFromDType(valuesDType, fullIndicesCount); const filledCount = new Array(denseRows).fill(0);
// Place each input entry after the entries already written for its row
// (row r starts at csrOffset[r - 1], or 0 for r = 0) and record where it went.
for (let i = 0; i < indicesCount; ++i) { const row = indices[i * rank]; const offset = filledCount[row]; const outputI = ((row === 0) ? 0 : csrOffset[row - 1]) + offset; filledCount[row]++; for (let j = 0; j < rank; ++j) { outputIndices[outputI * rank + j] = indices[i * rank + j]; } outputValues[outputI] = values[i]; reverseIndexMap[i] = outputI; }
// Rows that received no input get a single default entry at [row, 0, ..., 0].
for (let row = 0; row < denseRows; ++row) { const rowCount = filledCount[row]; if (rowCount === 0) { const startingIndex = (row === 0) ?
0 : csrOffset[row - 1]; outputIndices[startingIndex * rank + 0] = row; for (let col = 1; col < rank; ++col) { outputIndices[startingIndex * rank + col] = 0; } outputValues[startingIndex] = defaultValue; } } return [ outputIndices, [fullIndicesCount, rank], outputValues, emptyRowIndicator, reverseIndexMap ]; } }
/**
 * Re-expresses the indices of a sparse tensor for a new dense shape.
 *
 * `inputIndices` holds one row of `inputShape.length` coordinates per
 * non-zero entry (`inputIndicesShape[0]` rows). At most one entry of
 * `targetShape` may be -1; it is inferred so that the dense element count is
 * unchanged.
 *
 * @returns {Array} `[newIndices, [nnz, outputRank], outputShape]`.
 * @throws {Error} On more than one -1, a negative target dimension, a -1
 *     that cannot be inferred, or a dense-size mismatch.
 */
function sparseReshapeImpl(inputIndices, inputIndicesShape, inputDType, inputShape, targetShape) { const denseSize = sizeFromShape(inputShape); const nnz = inputIndicesShape[0]; const outputRank = targetShape.length; const outputShape = []; let product = 1; let unknownIndex = -1; for (let d = 0; d < outputRank; ++d) { const size = targetShape[d]; if (size === -1) { if (unknownIndex !== -1) { throw new Error(getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(unknownIndex, d)); } unknownIndex = d; outputShape.push(1); } else { if (size < 0) { throw new Error(getSparseReshapeNegativeOutputDimErrorMessage(d, size)); } product *= size; outputShape.push(size); } }
// Infer the -1 dimension from the dense element count of the input.
if (unknownIndex !== -1) { if (product <= 0) { throw new Error(getSparseReshapeEmptyTensorZeroOutputDimErrorMessage()); } const missing = Math.trunc(denseSize / product); if (product * missing !== denseSize) { throw new Error(getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape)); } outputShape[unknownIndex] = missing; } const outputSize = sizeFromShape(outputShape); if (outputSize !== denseSize) { throw new Error(getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape)); }
// Row-major strides: each input index is linearized with inputStrides and
// split back into output coordinates with outputStrides.
const inputRank = inputShape.length; const inputStrides = []; if (inputRank > 0) { inputStrides[inputRank - 1] = 1; for (let d = inputRank - 2; d >= 0; --d) { inputStrides[d] = inputStrides[d + 1] * inputShape[d + 1]; } } const outputStrides = []; if (outputRank > 0) { outputStrides[outputRank - 1] = 1; for (let d = outputRank - 2; d >= 0; --d) { outputStrides[d] = outputStrides[d + 1] * outputShape[d + 1]; } } const newIndices = getArrayFromDType(inputDType, nnz * outputRank); for (let i = 0;
i < nnz; ++i) { let id = 0; for (let j = 0; j < inputRank; ++j) { id += inputIndices[i * inputRank + j] * inputStrides[j]; } for (let j = 0; j < outputRank; ++j) { newIndices[i * outputRank + j] = Math.trunc(id / outputStrides[j]); id %= outputStrides[j]; } } return [newIndices, [nnz, outputRank], outputShape]; }
/**
 * Segment-sums (or, with `isMean`, segment-averages) selected rows of `input`.
 *
 * `input` is viewed as [inputShape[0], input.length / inputShape[0]]. Row
 * `indices[i]` is added to output row `segmentIds[i]`; with `isMean`, each
 * segment is divided by the number of rows that fed it. The output has
 * `segmentIds[last] + 1` rows (0 when `indices` is empty). Output rows that
 * no segment id names are set to `defaultValue`.
 *
 * @returns {Array} `[output, outputShape]`.
 * @throws {Error} If segment ids decrease, a segment id is out of range, or
 *     an index is outside [0, inputShape[0]).
 */
function sparseSegmentReductionImpl(input, inputShape, inputDType, indices, segmentIds, isMean = false, defaultValue = 0) { const numIndices = indices.length; const inputFlat = [inputShape[0], input.length / inputShape[0]]; const numCol = inputFlat[1]; const lastSegmentIdPlusOne = numIndices > 0 ? segmentIds[numIndices - 1] + 1 : 0; const outputRows = lastSegmentIdPlusOne; if (outputRows < 0) { throw new Error(getSparseSegmentReductionNegativeSegmentIdsErrorMessage()); } const outputShape = inputShape.slice(); outputShape[0] = outputRows; const outputLength = outputShape.reduce((product, value) => product * value, 1); const output = getArrayFromDType(inputDType, outputLength); if (numIndices === 0) { if (outputRows > 0) { output.fill(defaultValue); } return [output, outputShape]; } if (outputRows <= 0) { throw new Error(getSparseSegmentReductionNegativeSegmentIdsErrorMessage()); }
// Walk runs of equal segment ids: [start, end) is the current run and
// outIndex its segment. Output rows skipped between runs get defaultValue.
let start = 0, end = 1; let uninitializedIndex = 0; let outIndex = segmentIds[start]; while (true) { let nextIndex = 0; if (end < numIndices) { nextIndex = segmentIds[end]; if (outIndex === nextIndex) { ++end; continue; } if (outIndex >= nextIndex) { throw new Error(getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage()); } } if (outIndex < 0 || outIndex >= outputRows) { throw new Error(getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(outIndex, outputRows)); } if (outIndex > uninitializedIndex) { output.fill(defaultValue, uninitializedIndex * numCol, outIndex * numCol); } for (let i = start; i < end; ++i) { const index = indices[i]; if (index < 0 || index >= inputFlat[0]) { throw new Error(getSparseSegmentReductionIndicesOutOfRangeErrorMessage(i, indices[i],
inputFlat[0])); } for (let j = 0; j < numCol; j++) { output[outIndex * numCol + j] += input[index * numCol + j]; } } if (isMean) { for (let j = 0; j < numCol; j++) { output[outIndex * numCol + j] /= end - start; } } start = end; ++end; uninitializedIndex = outIndex + 1; outIndex = nextIndex; if (end > numIndices) { break; } } if (uninitializedIndex < outputRows) { output.fill(defaultValue, uninitializedIndex * numCol, outputRows * numCol); } return [output, outputShape]; } const sqrtImpl = createSimpleUnaryImpl((xi) => Math.sqrt(xi)); const squaredDifferenceImpl = createSimpleBinaryKernelImpl(((a, b) => { const diff = a - b; return diff * diff; })); const staticRegexReplaceImpl = createSimpleUnaryImpl((x, attrs) => { const { pattern, replaceGlobal, rewrite } = attrs; return x.replace(new RegExp(pattern, replaceGlobal ? 'g' : ''), rewrite); }); function stridedSliceImpl(outShape, xBuf, strides, begin) { const outBuf = buffer(outShape, xBuf.dtype); for (let i = 0; i < outBuf.size; i++) { const loc = outBuf.indexToLoc(i); const newLoc = new Array(loc.length); for (let j = 0; j < newLoc.length; j++) { newLoc[j] = loc[j] * strides[j] + begin[j]; } outBuf.set(xBuf.get(...newLoc), ...loc); } return outBuf; } class StringNGramsOp { constructor(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) { this.separator = encodeString(separator); this.nGramWidths = nGramWidths; this.leftPad = encodeString(leftPad); this.rightPad = encodeString(rightPad); this.padWidth = padWidth; this.preserveShort = preserveShortSequences; } getPadWidth(nGramWidth) { return Math.min(this.padWidth < 0 ? 
nGramWidth - 1 : this.padWidth, nGramWidth - 1); } getNumNGrams(length, nGramWidth) { const padWidth = this.getPadWidth(nGramWidth); return Math.max(0, ((length + 2 * padWidth) - nGramWidth) + 1); } createNGrams(data, splitIndex, output, outputStartIndex, numNGrams, nGramWidth) { for (let nGramIndex = 0; nGramIndex < numNGrams; ++nGramIndex) { const padWidth = this.getPadWidth(nGramWidth); const leftPadding = Math.max(0, padWidth - nGramIndex); const rightPadding = Math.max(0, padWidth - (numNGrams - (nGramIndex + 1))); const numTokens = nGramWidth - (leftPadding + rightPadding); const dataStartIndex = splitIndex + (leftPadding > 0 ? 0 : nGramIndex - padWidth); let nGramSize = 0; nGramSize += leftPadding * this.leftPad.length; for (let n = 0; n < numTokens; ++n) { nGramSize += data[dataStartIndex + n].length; } nGramSize += rightPadding * this.rightPad.length; const numSeparators = leftPadding + rightPadding + numTokens - 1; nGramSize += numSeparators * this.separator.length; output[outputStartIndex + nGramIndex] = new Uint8Array(nGramSize); const nGram = output[outputStartIndex + nGramIndex]; let nextNGramIndex = 0; const appendToNGram = (str) => str.forEach((value) => nGram[nextNGramIndex++] = value); for (let n = 0; n < leftPadding; ++n) { appendToNGram(this.leftPad); appendToNGram(this.separator); } for (let n = 0; n < numTokens - 1; ++n) { appendToNGram(data[dataStartIndex + n]); appendToNGram(this.separator); } if (numTokens > 0) { appendToNGram(data[dataStartIndex + numTokens - 1]); for (let n = 0; n < rightPadding; ++n) { appendToNGram(this.separator); appendToNGram(this.rightPad); } } else { for (let n = 0; n < rightPadding - 1; ++n) { appendToNGram(this.rightPad); appendToNGram(this.separator); } appendToNGram(this.rightPad); } } } compute(data, splits) { const inputDataSize = data.length; const splitsSize = splits.length; if (splitsSize > 0) { let prevSplit = splits[0]; if (prevSplit !== 0) { throw new Error(`First split value must be 0, got 
${prevSplit}`); } for (let i = 1; i < splitsSize; ++i) { let validSplits = splits[i] >= prevSplit; validSplits = validSplits && (splits[i] <= inputDataSize); if (!validSplits) { throw new Error(`Invalid split value ${splits[i]}, must be in [${prevSplit}, ${inputDataSize}]`); } prevSplit = splits[i]; } if (prevSplit !== inputDataSize) { throw new Error(`Last split value must be data size. Expected ${inputDataSize}, got ${prevSplit}`); } } const numBatchItems = splitsSize - 1; const nGramsSplits = getArrayFromDType('int32', splitsSize); if (inputDataSize === 0 || splitsSize === 0) { const empty = new Array(inputDataSize); for (let i = 0; i <= numBatchItems; ++i) { nGramsSplits[i] = 0; } return [empty, nGramsSplits]; } nGramsSplits[0] = 0; for (let i = 1; i <= numBatchItems; ++i) { const length = splits[i] - splits[i - 1]; let numNGrams = 0; this.nGramWidths.forEach((nGramWidth) => { numNGrams += this.getNumNGrams(length, nGramWidth); }); if (this.preserveShort && length > 0 && numNGrams === 0) { numNGrams = 1; } nGramsSplits[i] = nGramsSplits[i - 1] + numNGrams; } const nGrams = new Array(nGramsSplits[numBatchItems]); for (let i = 0; i < numBatchItems; ++i) { const splitIndex = splits[i]; let outputStartIdx = nGramsSplits[i]; this.nGramWidths.forEach((nGramWidth) => { const length = splits[i + 1] - splits[i]; const numNGrams = this.getNumNGrams(length, nGramWidth); this.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth); outputStartIdx += numNGrams; }); if (this.preserveShort && outputStartIdx === nGramsSplits[i]) { const dataLength = splits[i + 1] - splits[i]; if (dataLength === 0) { continue; } const nGramWidth = dataLength + 2 * this.padWidth; const numNGrams = 1; this.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth); } } return [nGrams, nGramsSplits]; } } function stringNGramsImpl(data, dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) { return new 
StringNGramsOp(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) .compute(data, dataSplits); } function split(str, delimiters, skipEmpty, result) { if (!str.length) { return; } if (delimiters.length === 0) { for (let i = 0; i < str.length; ++i) { result.push(str.subarray(i, i + 1)); } return; } if (delimiters.length === 1) { const delimiter = delimiters[0]; let f = str.indexOf(delimiter); while (f !== -1) { const token = str.subarray(0, f); if (!skipEmpty || token.length !== 0) { result.push(token); } str = str.subarray(f + 1); f = str.indexOf(delimiter); } if (!skipEmpty || str.length !== 0) { result.push(str); } return; } let tokenStart = 0; for (let i = 0; i < str.length + 1; i++) { if ((i === str.length) || (delimiters.indexOf(str[i]) !== -1)) { const token = str.subarray(tokenStart, i); if (!skipEmpty || token.length !== 0) { result.push(token); } tokenStart = i + 1; } } } function stringSplitImpl(input, delimiter, skipEmpty) { const batchSize = input.length; const tokens = []; let outputSize = 0; let maxNumEntries = 0; const numIndices = new Array(batchSize); for (let i = 0; i < batchSize; ++i) { const prevTokensLength = tokens.length; split(input[i], delimiter, skipEmpty, tokens); const nEntries = tokens.length - prevTokensLength; numIndices[i] = nEntries; outputSize += nEntries; maxNumEntries = Math.max(maxNumEntries, nEntries); } const indices = getArrayFromDType('int32', outputSize * 2); const values = new Array(outputSize); const shape = [batchSize, maxNumEntries]; let c = 0; for (let i = 0; i < batchSize; ++i) { for (let j = 0; j < numIndices[i]; ++j) { indices[c * 2] = i; indices[c * 2 + 1] = j; values[c] = tokens[c]; ++c; } } return [indices, values, shape]; } function stringToHashBucketFastImpl(input, numBuckets) { const output = getArrayFromDType('int32', input.length); for (let i = 0; i < input.length; ++i) { output[i] = fingerPrint64(input[i]).modulo(numBuckets).getLowBitsUnsigned(); } return output; } const subImpl = 
createSimpleBinaryKernelImpl(((aValue, bValue) => aValue - bValue));
/**
 * Repeats `xBuf` `reps[i]` times along each axis i. The output element at
 * coordinate c reads the input at c % xBuf.shape on every axis.
 *
 * @returns New buffer of shape xBuf.shape[i] * reps[i].
 */
function tileImpl(xBuf, reps) { const newShape = new Array(xBuf.rank); for (let i = 0; i < newShape.length; i++) { newShape[i] = xBuf.shape[i] * reps[i]; } const result = buffer(newShape, xBuf.dtype); for (let i = 0; i < result.values.length; ++i) { const newLoc = result.indexToLoc(i); const originalLoc = new Array(xBuf.rank); for (let j = 0; j < originalLoc.length; j++) { originalLoc[j] = newLoc[j] % xBuf.shape[j]; } const originalIndex = xBuf.locToIndex(originalLoc); result.values[i] = xBuf.values[originalIndex]; } return result; }
// Orders {value, index} pairs by value descending, breaking ties by lower index.
const comparePair = (a, b) => { const valueDiff = b.value - a.value; return valueDiff === 0 ? a.index - b.index : valueDiff; };
/**
 * Floyd–Rivest selection. Partially reorders array[left..right] in place
 * (using comparePair) so that array[k] holds the element a full sort would
 * put there, and no element before index k orders after it.
 *
 * Ranges wider than 600 elements are first narrowed by recursing on a
 * sampled sub-range around k.
 */
function select$1(array, k, left = 0, right = array.length - 1) { while (right > left) { if (right - left > 600) { const n = right - left + 1; const i = k - left + 1; const z = Math.log(n); const s = 0.5 * Math.exp(2 * z / 3); const sd = 0.5 * Math.sqrt(z * s * (n - s) / n) * Math.sign(i - n / 2); const newLeft = Math.max(left, Math.floor(k - i * s / n + sd)); const newRight = Math.min(right, Math.floor(k + (n - i) * s / n + sd)); select$1(array, k, newLeft, newRight); } const t = array[k]; let i = left; let j = right; swap(array, left, k); if (comparePair(array[right], t) > 0) { swap(array, left, right); } while (i < j) { swap(array, i, j); i++; j--; while (comparePair(array[i], t) < 0) { i = i + 1; } while (comparePair(array[j], t) > 0) { j = j - 1; } } if (comparePair(array[left], t) === 0) { swap(array, left, j); } else { j = j + 1; swap(array, j, right); } if (j <= k) { left = j + 1; } if (k <= j) { right = j - 1; } } }
/**
 * Top-k along the last axis of `x`. For every row it keeps the k largest
 * values (ties prefer the lower index) and their positions. When `sorted`
 * is true, each row's results are ordered by comparePair.
 *
 * @returns {Array} `[values, indices]` buffers shaped like `xShape` with the
 *     last dimension replaced by k; indices are int32.
 */
function topKImpl(x, xShape, xDtype, k, sorted) { const lastDim = xShape[xShape.length - 1]; const [batch, size] = [x.length / lastDim, lastDim]; const allTopKVals = getTypedArrayFromDType(xDtype, batch * k); const allTopKIndices = getTypedArrayFromDType('int32', batch * k); for (let b = 0; b < batch; b++) { const offset
= b * size; const vals = x.subarray(offset, offset + size); let valAndInd = new Array(vals.length); vals.forEach((value, index) => valAndInd[index] = { value, index }); if (k < valAndInd.length) { select$1(valAndInd, k); valAndInd = valAndInd.slice(0, k); } if (sorted) { valAndInd.sort(comparePair); } const outOffset = b * k; const topKVals = allTopKVals.subarray(outOffset, outOffset + k); const topKIndices = allTopKIndices.subarray(outOffset, outOffset + k); for (let i = 0; i < k; i++) { topKVals[i] = valAndInd[i].value; topKIndices[i] = valAndInd[i].index; } } const outputShape = xShape.slice(); outputShape[outputShape.length - 1] = k; return [ buffer(outputShape, xDtype, allTopKVals), buffer(outputShape, 'int32', allTopKIndices) ]; }
/**
 * Finds the distinct slices of `values` along `axis`, in order of first
 * occurrence.
 *
 * The input is viewed as [outer, shape[axis], inner]. Each slice along the
 * axis is keyed by its values joined with ','. When outer and inner are both
 * 1, the key is `toString()` of the single value.
 *
 * @returns {{outputValues: *, outputShape: number[], indices: Int32Array}}
 *     The unique slices; their shape, with the axis length replaced by the
 *     unique count; and, for every input slice, the index of its unique slice.
 */
function uniqueImpl(values, axis, shape, dtype) { const $axis = parseAxisParam(axis, shape)[0];
// Collapse the shape to [product before axis, shape[axis], product after axis].
const newShape = [1, shape[0], 1]; for (let i = 0; i < $axis; i++) { newShape[0] *= shape[i]; } newShape[1] = shape[$axis]; for (let i = $axis + 1; i < shape.length; i++) { newShape[2] *= shape[i]; } const uniqueElements = new Map(); const indices = new Int32Array(shape[$axis]); const inputBuffer = new TensorBuffer(newShape, dtype, values); const uniqueIndices = []; const is1DTensor = newShape[0] === 1 && newShape[2] === 1;
// Key each slice by its values; the first occurrence of a key assigns its
// unique index and is remembered in uniqueIndices.
for (let i = 0; i < shape[$axis]; i++) { let element; if (is1DTensor) { element = values[i].toString(); } else { const axisValues = []; for (let m = 0; m < newShape[0]; m++) { for (let n = 0; n < newShape[2]; n++) { axisValues.push(inputBuffer.get(m, i, n)); } } element = axisValues.join(','); } const existingIndex = uniqueElements.get(element); if (existingIndex != null) { indices[i] = existingIndex; } else { const uniqueIndex = uniqueElements.size; uniqueElements.set(element, uniqueIndex); indices[i] = uniqueIndex; uniqueIndices.push(i); } } const outputTmpShape = newShape.slice(); outputTmpShape[1] = uniqueElements.size; const outputBuffer = new TensorBuffer(outputTmpShape, dtype); 
uniqueIndices.forEach((uniqueElementIndex, i) => { for (let m = 0; m < newShape[0]; m++) { for (let n = 0; n < newShape[2]; n++) { outputBuffer.set(inputBuffer.get(m, uniqueElementIndex, n), m, i, n); } } }); const outputShape = shape.slice(); outputShape[$axis] = outputTmpShape[1]; return { outputValues: outputBuffer.values, outputShape, indices, }; }
// Frozen, null-prototype namespace collecting the shared kernel
// implementations (several of them are defined earlier in this file).
var shared = Object.freeze({ __proto__: null, addImpl: addImpl, bincountImpl: bincountImpl, bincountReduceImpl: bincountReduceImpl, bitwiseAndImpl: bitwiseAndImpl, castImpl: castImpl, ceilImpl: ceilImpl, concatImpl: concatImpl$1, equalImpl: equalImpl, expImpl: expImpl, expm1Impl: expm1Impl, floorDivImpl: floorDivImpl, floorImpl: floorImpl, gatherNdImpl: gatherNdImpl, gatherV2Impl: gatherV2Impl, greaterEqualImpl: greaterEqualImpl, greaterImpl: greaterImpl, lessEqualImpl: lessEqualImpl, lessImpl: lessImpl, linSpaceImpl: linSpaceImpl, logImpl: logImpl, maxImpl: maxImpl$1, maximumImpl: maximumImpl, minimumImpl: minimumImpl, multiplyImpl: multiplyImpl, negImpl: negImpl, notEqualImpl: notEqualImpl, prodImpl: prodImpl, raggedGatherImpl: raggedGatherImpl, raggedRangeImpl: raggedRangeImpl, raggedTensorToTensorImpl: raggedTensorToTensorImpl, rangeImpl: rangeImpl, rsqrtImpl: rsqrtImpl, scatterImpl: scatterImpl, sigmoidImpl: sigmoidImpl, simpleAbsImpl: simpleAbsImpl, sliceImpl: sliceImpl, sparseFillEmptyRowsImpl: sparseFillEmptyRowsImpl, sparseReshapeImpl: sparseReshapeImpl, sparseSegmentReductionImpl: sparseSegmentReductionImpl, sqrtImpl: sqrtImpl, squaredDifferenceImpl: squaredDifferenceImpl, staticRegexReplaceImpl: staticRegexReplaceImpl, stridedSliceImpl: stridedSliceImpl, stringNGramsImpl: stringNGramsImpl, stringSplitImpl: stringSplitImpl, stringToHashBucketFastImpl: stringToHashBucketFastImpl, subImpl: subImpl, tileImpl: tileImpl, topKImpl: topKImpl, transposeImpl: transposeImpl$1, uniqueImpl: uniqueImpl });
// Local aliases, with a `CPU` suffix, for the shared implementations.
const { addImpl: addImplCPU, bincountImpl: bincountImplCPU, bincountReduceImpl: bincountReduceImplCPU, 
bitwiseAndImpl: bitwiseAndImplCPU, castImpl: castImplCPU, ceilImpl: ceilImplCPU, concatImpl: concatImplCPU, equalImpl: equalImplCPU, expImpl: expImplCPU, expm1Impl: expm1ImplCPU, floorImpl: floorImplCPU, gatherNdImpl: gatherNdImplCPU, gatherV2Impl: gatherV2ImplCPU, greaterImpl: greaterImplCPU, greaterEqualImpl: greaterEqualImplCPU, lessImpl: lessImplCPU, lessEqualImpl: lessEqualImplCPU, linSpaceImpl: linSpaceImplCPU, logImpl: logImplCPU, maxImpl: maxImplCPU, maximumImpl: maximumImplCPU, minimumImpl: minimumImplCPU, multiplyImpl: multiplyImplCPU, negImpl: negImplCPU, notEqualImpl: notEqualImplCPU, prodImpl: prodImplCPU, raggedGatherImpl: raggedGatherImplCPU, raggedRangeImpl: raggedRangeImplCPU, raggedTensorToTensorImpl: raggedTensorToTensorImplCPU, rangeImpl: rangeImplCPU, rsqrtImpl: rsqrtImplCPU, scatterImpl: scatterImplCPU, sigmoidImpl: sigmoidImplCPU, simpleAbsImpl: simpleAbsImplCPU, sliceImpl: sliceImplCPU, sparseFillEmptyRowsImpl: sparseFillEmptyRowsImplCPU, sparseReshapeImpl: sparseReshapeImplCPU, sparseSegmentReductionImpl: sparseSegmentReductionImplCPU, sqrtImpl: sqrtImplCPU, staticRegexReplaceImpl: staticRegexReplaceImplCPU, stridedSliceImpl: stridedSliceImplCPU, stringNGramsImpl: stringNGramsImplCPU, stringSplitImpl: stringSplitImplCPU, stringToHashBucketFastImpl: stringToHashBucketFastImplCPU, subImpl: subImplCPU, tileImpl: tileImplCPU, topKImpl: topKImplCPU, transposeImpl: transposeImplCPU, uniqueImpl: uniqueImplCPU, } = shared; function getVecChannels(name, rank) { return ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank).map(d => `${name}.${d}`); } function getChannels(name, rank) { if (rank === 1) { return [name]; } return getVecChannels(name, rank); } function getSourceCoords$2(rank, dims) { if (rank === 1) { return 'rc'; } let coords = ''; for (let i = 0; i < rank; i++) { coords += dims[i]; if (i < rank - 1) { coords += ','; } } return coords; } class PackProgram { constructor(outputShape) { this.variableNames = ['A']; this.packedInputs = false; 
this.packedOutput = true; this.outputShape = outputShape; this.rank = outputShape.length; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
// Rank 0 packs the single scalar into the first channel. Otherwise each
// output texel gathers up to four neighbouring input elements (a 2x2 block
// of the two innermost dimensions, or rc and rc + 1 for rank 1), writing 0
// past the output edge.
if (this.rank === 0) { this.userCode = ` void main() { setOutput(vec4(getA(), 0., 0., 0.)); } `; } else { const channels = getChannels('rc', this.rank); const dtype = getCoordsDataType(this.rank); const outOfBoundsCondition = this.getOutOfBoundsCondition(channels); const setup = this.getSetup(channels); const output = this.getOutput(channels); this.userCode = ` void main() { ${dtype} rc = getOutputCoords(); if(${outOfBoundsCondition}) { setOutput(vec4(0)); } else { ${setup} setOutput(vec4(${output})); } } `; } }
/**
 * Returns the four GLSL source coordinates of a packed texel in the order
 * (r, c), (r, cp1), (rp1, c), (rp1, cp1). Each is prefixed with the outer
 * dimension channels taken from `dims`.
 */
getSourceCoordsArr(dims) { const coords = []; for (let row = 0; row <= 1; row++) { for (let col = 0; col <= 1; col++) { let coord = `${row === 0 ? 'r' : 'rp1'}, ${col === 0 ? 'c' : 'cp1'}`; for (let d = 2; d < this.rank; d++) { coord = `${dims[dims.length - 1 - d]},` + coord; } coords.push(coord); } } return coords; }
/**
 * GLSL condition that is true when either of the texel's two innermost
 * coordinates is at or past the output extent.
 * NOTE(review): the rank-1 branch tests `rc > size`, while the general
 * branch uses `>=`; confirm this difference is intended.
 */
getOutOfBoundsCondition(dims) { if (this.rank === 1) { return `rc > ${this.enableShapeUniforms ? 'outShape' : this.outputShape[0]}`; } let cond = ''; for (let i = this.rank - 2; i < this.rank; i++) { cond += `${dims[i]} >= ${this.enableShapeUniforms ? `outShape[${i}]` : this.outputShape[i]}`; if (i < this.rank - 1) { cond += '||'; } } return cond; }
/**
 * GLSL declarations of r, c, rp1, cp1 and of the cEdge/rEdge flags, which
 * mark whether the next column/row lies outside the output. Empty for rank 1.
 */
getSetup(dims) { if (this.rank === 1) { return ''; } const innerDims = dims.slice(-2); const col = this.enableShapeUniforms ? `outShape[${this.rank} - 1]` : this.outputShape[this.rank - 1]; const row = this.enableShapeUniforms ? `outShape[${this.rank} - 2]` : this.outputShape[this.rank - 2]; return ` int r = ${innerDims[0]}; int c = ${innerDims[1]}; int rp1 = r + 1; int cp1 = c + 1; bool cEdge = cp1 >= ${col}; bool rEdge = rp1 >= ${row}; `; }
/**
 * GLSL argument list for the packed vec4, with 0 substituted for neighbours
 * that lie past the output edge.
 */
getOutput(dims) { const sourceCoords = this.getSourceCoordsArr(dims); if (this.rank === 1) { const outShape = this.enableShapeUniforms ?
'outShape' : this.outputShape[0]; return `getA(rc), (rc + 1 >= ${outShape} ? 0. : getA(rc + 1)), 0, 0`; } return `getA(${sourceCoords[0]}), cEdge ? 0. : getA(${sourceCoords[1]}), rEdge ? 0. : getA(${sourceCoords[2]}), rEdge || cEdge ? 0. : getA(${sourceCoords[3]})`; } }
/**
 * WebGL program that reshapes a packed input into a packed output addressed
 * as 3-D (ivec3) coordinates. For each of the four channels of an output
 * texel, it maps the output coordinate to a flat index and then to input
 * coordinates.
 */
class ReshapePackedProgram { constructor(outputShape, inputShape) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; this.customUniforms = [{ name: 'inputShape', type: 'ivec3' }]; this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
// One GLSL section per channel i. Channel i offsets the texel origin by
// (i > 1) rows and (i odd) columns; channels after the first are guarded
// by the output row/column bounds.
let mainLoop = ``; for (let i = 0; i < 4; i++) { let thisRC = `thisRC = rc;`; if (i % 2 === 1) { thisRC += `thisRC.z += 1;`; } if (i > 1) { thisRC += `thisRC.y += 1;`; } mainLoop += ` ${thisRC} ${i > 0 ? `if(thisRC.y < rows && thisRC.z < cols){` : ''} int flatIndex = getFlatIndex(thisRC); ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex); vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z)); result[${i}] = getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims); ${i > 0 ? '}' : ''} `; } this.userCode = ` ${getReshapedInputCoords(inputShape, this.enableShapeUniforms)} ${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() : getFlatIndexFrom3D(outputShape)} void main() { ivec3 rc = getOutputCoords(); vec4 result = vec4(0.); ivec3 thisRC; int rows = ${this.enableShapeUniforms ? 'outShape[1]' : outputShape[1]}; int cols = ${this.enableShapeUniforms ? 'outShape[2]' : outputShape[2]}; ${mainLoop} setOutput(result); } `; } }
/**
 * Returns GLSL for `ivec3 inputCoordsFromReshapedOutCoords(int index)`,
 * which converts a flat index into (r, c, d) input coordinates. It uses the
 * `inputShape` uniform when `enableShapeUniforms` is set, and the static
 * `shape` otherwise.
 */
function getReshapedInputCoords(shape, enableShapeUniforms) { const coordsFromIndexSnippet = enableShapeUniforms ?
getLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], 'inputShape') : getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape); return ` ivec3 inputCoordsFromReshapedOutCoords(int index) { ${coordsFromIndexSnippet} return ivec3(r, c, d); } `; } class TextureManager { constructor(gpgpu) { this.gpgpu = gpgpu; this.numUsedTextures = 0; this.numFreeTextures = 0; this._numBytesAllocated = 0; this._numBytesFree = 0; this.freeTextures = {}; this.usedTextures = {}; this.logEnabled = false; } acquireTexture(shapeRC, usage, isPacked) { const physicalTexType = getPhysicalFromLogicalTextureType(usage, isPacked); const shapeKey = getKeyFromTextureShape(shapeRC, physicalTexType, isPacked); if (!(shapeKey in this.freeTextures)) { this.freeTextures[shapeKey] = []; } if (!(shapeKey in this.usedTextures)) { this.usedTextures[shapeKey] = []; } const texBytes = computeBytes(shapeRC, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked); if (this.freeTextures[shapeKey].length > 0) { this.numFreeTextures--; this.numUsedTextures++; this._numBytesFree -= texBytes; this.log(); const newTexture = this.freeTextures[shapeKey].pop(); this.usedTextures[shapeKey].push(newTexture); return newTexture; } let newTexture; if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT32) { newTexture = this.gpgpu.createPackedMatrixTexture(shapeRC[0], shapeRC[1]); } else if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT16) { newTexture = this.gpgpu.createFloat16PackedMatrixTexture(shapeRC[0], shapeRC[1]); } else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT32) { newTexture = this.gpgpu.createFloat32MatrixTexture(shapeRC[0], shapeRC[1]); } else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT16) { newTexture = this.gpgpu.createFloat16MatrixTexture(shapeRC[0], shapeRC[1]); } else if (physicalTexType === PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE) { newTexture = this.gpgpu.createUnsignedBytesMatrixTexture(shapeRC[0], shapeRC[1]); } 
this.usedTextures[shapeKey].push(newTexture); this.numUsedTextures++; this._numBytesAllocated += texBytes; this.log(); return newTexture; } releaseTexture(texture, shape, logicalTexType, isPacked) { if (this.freeTextures == null) { return; } const physicalTexType = getPhysicalFromLogicalTextureType(logicalTexType, isPacked); const shapeKey = getKeyFromTextureShape(shape, physicalTexType, isPacked); if (!(shapeKey in this.freeTextures)) { this.freeTextures[shapeKey] = []; } const texBytes = computeBytes(shape, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked); const deleteTexThreshold = env() .getNumber('WEBGL_DELETE_TEXTURE_THRESHOLD'); if (deleteTexThreshold !== -1 && this._numBytesAllocated > deleteTexThreshold) { this.gpgpu.deleteMatrixTexture(texture.texture); this._numBytesAllocated -= texBytes; } else { this.freeTextures[shapeKey].push(texture); this.numFreeTextures++; this._numBytesFree += texBytes; } this.numUsedTextures--; const texList = this.usedTextures[shapeKey]; const texIndex = texList && texList.indexOf(texture); if (texIndex == null || texIndex < 0) { throw new Error('Cannot release a texture that was never provided by this ' + 'texture manager'); } texList[texIndex] = texList[texList.length - 1]; texList.pop(); this.log(); } log() { if (!this.logEnabled) { return; } const total = this.numFreeTextures + this.numUsedTextures; console.log('Free/Used', `${this.numFreeTextures} / ${this.numUsedTextures}`, `(${total})`); const freeRatio = this._numBytesFree / this._numBytesAllocated; console.log(`Bytes allocated: ${this._numBytesAllocated}`); console.log(`Bytes unused: ${this._numBytesFree} (${Math.round(100 * freeRatio)}%)`); } get numBytesAllocated() { return this._numBytesAllocated; } get numBytesFree() { return this._numBytesFree; } getNumUsedTextures() { return this.numUsedTextures; } getNumFreeTextures() { return this.numFreeTextures; } dispose() { if (this.freeTextures == null) { return; } for (const texShape in 
this.freeTextures) { this.freeTextures[texShape].forEach(tex => { this.gpgpu.deleteMatrixTexture(tex.texture); }); } for (const texShape in this.usedTextures) { this.usedTextures[texShape].forEach(tex => { this.gpgpu.deleteMatrixTexture(tex.texture); }); } this.freeTextures = null; this.usedTextures = null; this.numUsedTextures = 0; this.numFreeTextures = 0; this._numBytesAllocated = 0; this._numBytesFree = 0; } } function numBytesForInternalFormat(gl, internalFormat) { const glany = gl; if (internalFormat === glany.R32F) { return 4; } else if (internalFormat === glany.R16F) { return 2; } else if (internalFormat === glany.RGBA32F) { return 16; } else if (internalFormat === gl.RGBA) { return 16; } else if (internalFormat === glany.RGBA16F) { return 8; } else if (internalFormat === glany.RGBA8) { return 4; } throw new Error(`Unknown internal format ${internalFormat}`); } function computeBytes(shape, physicalTexType, gl, textureConfig, isPacked) { const internalFormat = internalFormatForPhysicalTexType(physicalTexType, textureConfig); let numElements; if (isPacked) { const [packedWidth, packedHeight] = getPackedMatrixTextureShapeWidthHeight(shape[0], shape[1]); numElements = packedWidth * packedHeight; } else { const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(shape[0], shape[1]); numElements = width * height; } const bytesPerElement = numBytesForInternalFormat(gl, internalFormat); return numElements * bytesPerElement; } function internalFormatForPhysicalTexType(physicalTexType, textureConfig) { switch (physicalTexType) { case PhysicalTextureType.PACKED_2X2_FLOAT32: return getInternalFormatForPackedMatrixTexture(textureConfig); case PhysicalTextureType.PACKED_2X2_FLOAT16: return getInternalFormatForFloat16PackedMatrixTexture(textureConfig); case PhysicalTextureType.UNPACKED_FLOAT32: return getInternalFormatForFloat32MatrixTexture(textureConfig); case PhysicalTextureType.UNPACKED_FLOAT16: return getInternalFormatForFloat16MatrixTexture(textureConfig); 
case PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE: return getInternalFormatForUnsignedBytesMatrixTexture(textureConfig); default: throw new Error(`Unknown physical texture type ${physicalTexType}`); } } function getPhysicalTextureForRendering(isPacked) { if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) { if (isPacked) { return PhysicalTextureType.PACKED_2X2_FLOAT32; } return PhysicalTextureType.UNPACKED_FLOAT32; } if (isPacked) { return PhysicalTextureType.PACKED_2X2_FLOAT16; } return PhysicalTextureType.UNPACKED_FLOAT16; } function getPhysicalFromLogicalTextureType(logicalTexType, isPacked) { if (logicalTexType === TextureUsage.UPLOAD) { return PhysicalTextureType.PACKED_2X2_FLOAT32; } else if (logicalTexType === TextureUsage.RENDER || logicalTexType == null) { return getPhysicalTextureForRendering(isPacked); } else if (logicalTexType === TextureUsage.DOWNLOAD || logicalTexType === TextureUsage.PIXELS) { return PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE; } throw new Error(`Unknown logical texture type ${logicalTexType}`); } function getKeyFromTextureShape(shapeRowsCol, physicalTexType, isPacked) { return `${shapeRowsCol[0]}_${shapeRowsCol[1]}_${physicalTexType}_${isPacked}`; } class UnaryOpProgram { constructor(aShape, opSnippet) { this.variableNames = ['A']; this.outputShape = aShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); this.userCode = ` float unaryOperation(float x) { ${opSnippet} } void main() { float x = getAAtOutCoords(); float y = unaryOperation(x); setOutput(y); } `; } } const CHECK_NAN_SNIPPET$1 = `if (isnan(x)) return x;`; const LINEAR$1 = `return x;`; const ABS$1 = `return abs(x);`; const ELU$2 = `return (x >= 0.0) ? x : (exp(x) - 1.0);`; const RELU$2 = CHECK_NAN_SNIPPET$1 + ` return (x < 0.0) ? 0.0 : x; `; const RELU6$2 = CHECK_NAN_SNIPPET$1 + ` return (x < 0.0) ? 
0.0 : min(6.0, x); `; const CLONE = 'return x;'; const SIGMOID$2 = `return 1.0 / (1.0 + exp(-1.0 * x));`; const LINEAR = `return x;`; const ELU$1 = ` vec4 result; result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0); result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0); result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0); result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0); return result; `; const RELU$1 = ` vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0))); bvec4 isNaN = isnan(x); result.r = isNaN.r ? x.r : result.r; result.g = isNaN.g ? x.g : result.g; result.b = isNaN.b ? x.b : result.b; result.a = isNaN.a ? x.a : result.a; return result; `; const RELU6$1 = ` vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0))); bvec4 isNaN = isnan(x); result.r = isNaN.r ? x.r : result.r; result.g = isNaN.g ? x.g : result.g; result.b = isNaN.b ? x.b : result.b; result.a = isNaN.a ? x.a : result.a; return result; `; const SIGMOID$1 = `return 1.0 / (1.0 + exp(-1.0 * x));`; class UnaryOpPackedProgram { constructor(aShape, opSnippet) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; this.outputShape = aShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); this.userCode = ` vec4 unaryOperation(vec4 x) { ${opSnippet} } void main() { vec4 x = getAAtOutCoords(); vec4 y = unaryOperation(x); setOutput(y); } `; } } class UnpackProgram { constructor(outputShape) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = false; this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); const rank = outputShape.length; const channels = getChannels('rc', rank); const dtype = getCoordsDataType(rank); const sourceCoords = getSourceCoords$2(rank, channels); const innerDims = channels.slice(-2); const coords = rank <= 1 ? 
'rc' : `vec2(${innerDims.join(',')})`; this.userCode = ` void main() { ${dtype} rc = getOutputCoords(); vec4 packedInput = getA(${sourceCoords}); setOutput(getChannel(packedInput, ${coords})); } `; } } const whereImpl = whereImpl$1; const EPSILON_FLOAT32 = 1e-7; const EPSILON_FLOAT16 = 1e-4; const binaryCaches = {}; function getBinaryCache(webGLVersion) { if (webGLVersion in binaryCaches) { return binaryCaches[webGLVersion]; } binaryCaches[webGLVersion] = {}; return binaryCaches[webGLVersion]; } const CPU_HANDOFF_SIZE_THRESHOLD = env().getNumber('CPU_HANDOFF_SIZE_THRESHOLD'); const BEFORE_PAGING_CONSTANT = 600; function numMBBeforeWarning() { if (env().global.screen == null) { return 1024; } return (env().global.screen.height * env().global.screen.width * window.devicePixelRatio) * BEFORE_PAGING_CONSTANT / 1024 / 1024; } class MathBackendWebGL extends KernelBackend { nextDataId() { return MathBackendWebGL.nextDataId++; } constructor(gpuResource) { super(); this.pendingRead = new WeakMap(); this.pendingDisposal = new WeakSet(); this.dataRefCount = new WeakMap(); this.numBytesInGPU = 0; this.uploadWaitMs = 0; this.downloadWaitMs = 0; this.lastGlFlushTime = 0; this.warnedAboutMemory = false; this.pendingDeletes = 0; this.disposed = false; if (!env().getBool('HAS_WEBGL')) { throw new Error('WebGL is not supported on this device'); } let newGPGPU; if (gpuResource != null) { if (gpuResource instanceof GPGPUContext) { newGPGPU = gpuResource; } else { const gl = getWebGLContext(env().getNumber('WEBGL_VERSION'), gpuResource); newGPGPU = new GPGPUContext(gl); } this.binaryCache = {}; this.gpgpuCreatedLocally = false; } else { const gl = getWebGLContext(env().getNumber('WEBGL_VERSION')); newGPGPU = new GPGPUContext(gl); this.binaryCache = getBinaryCache(env().getNumber('WEBGL_VERSION')); this.gpgpuCreatedLocally = true; } this.gpgpu = newGPGPU; this.canvas = this.gpgpu.gl.canvas; this.textureManager = new TextureManager(this.gpgpu); this.numMBBeforeWarning = 
numMBBeforeWarning(); this.texData = new DataStorage(this, engine()); } numDataIds() { return this.texData.numDataIds() - this.pendingDeletes; } writeTexture(texture, shape, dtype, texHeight, texWidth, channels) { const input = this.makeTensorInfo(shape, dtype); const inData = this.texData.get(input.dataId); inData.isPacked = false; inData.texture = { texture, texShape: [texHeight, texWidth] }; inData.texShape = [texHeight, texWidth]; const shapeAs3D = getShapeAs3D(shape); const program = new EncodeMatrixProgram(shapeAs3D, false , channels); const output = this.runWebGLProgram(program, [input], dtype, [[texHeight, texWidth]]); output.shape = shape; inData.texture = null; this.disposeIntermediateTensorInfo(input); return output.dataId; } write(values, shape, dtype) { if (env().getBool('WEBGL_CHECK_NUMERICAL_PROBLEMS') || env().getBool('DEBUG')) { this.checkNumericalProblems(values); } if (dtype === 'complex64' && values != null) { throw new Error(`Cannot write to a complex64 dtype. ` + `Please use tf.complex(real, imag).`); } const dataId = { id: this.nextDataId() }; this.texData.set(dataId, { shape, dtype, values, usage: TextureUsage.UPLOAD, refCount: 1 }); return dataId; } refCount(dataId) { if (this.texData.has(dataId)) { const tensorData = this.texData.get(dataId); return tensorData.refCount; } return 0; } incRef(dataId) { const texData = this.texData.get(dataId); texData.refCount++; } decRef(dataId) { if (this.texData.has(dataId)) { const texData = this.texData.get(dataId); texData.refCount--; } } move(dataId, values, shape, dtype, refCount) { if (env().getBool('DEBUG')) { this.checkNumericalProblems(values); } if (dtype === 'complex64') { throw new Error(`Cannot write to a complex64 dtype. 
` + `Please use tf.complex(real, imag).`); } this.texData.set(dataId, { shape, dtype, values, usage: TextureUsage.UPLOAD, refCount }); } disposeIntermediateTensorInfo(tensorInfo) { this.disposeData(tensorInfo.dataId); } readSync(dataId) { const texData = this.texData.get(dataId); const { values, dtype, complexTensorInfos, slice, shape, isPacked } = texData; if (slice != null) { let program; if (isPacked) { program = new UnaryOpPackedProgram(shape, CLONE); } else { program = new UnaryOpProgram(shape, CLONE); } const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype); const data = this.readSync(res.dataId); this.disposeIntermediateTensorInfo(res); return data; } if (values != null) { return this.convertAndCacheOnCPU(dataId); } if (dtype === 'string') { return values; } const shouldTimeProgram = this.activeTimers != null; let start; if (shouldTimeProgram) { start = now(); } let result; if (dtype === 'complex64') { const realValues = this.readSync(complexTensorInfos.real.dataId); const imagValues = this.readSync(complexTensorInfos.imag.dataId); result = mergeRealAndImagArrays(realValues, imagValues); } else { result = this.getValuesFromTexture(dataId); } if (shouldTimeProgram) { this.downloadWaitMs += now() - start; } return this.convertAndCacheOnCPU(dataId, result); } async read(dataId) { if (this.pendingRead.has(dataId)) { const subscribers = this.pendingRead.get(dataId); return new Promise(resolve => subscribers.push(resolve)); } const texData = this.texData.get(dataId); const { values, shape, slice, dtype, complexTensorInfos, isPacked } = texData; if (slice != null) { let program; if (isPacked) { program = new UnaryOpPackedProgram(shape, CLONE); } else { program = new UnaryOpProgram(shape, CLONE); } const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype); const data = this.read(res.dataId); this.disposeIntermediateTensorInfo(res); return data; } if (values != null) { return this.convertAndCacheOnCPU(dataId); } if 
(env().getBool('DEBUG')) { if (!env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') && env().getNumber('WEBGL_VERSION') === 2) { throw new Error(`tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and ` + `WEBGL_VERSION=2 not yet supported.`); } } let buffer = null; let tmpDownloadTarget; if (dtype !== 'complex64' && env().get('WEBGL_BUFFER_SUPPORTED')) { tmpDownloadTarget = this.decode(dataId); const tmpData = this.texData.get(tmpDownloadTarget.dataId); buffer = this.gpgpu.createBufferFromTexture(tmpData.texture.texture, ...getDenseTexShape(shape)); } this.pendingRead.set(dataId, []); if (dtype !== 'complex64') { await this.gpgpu.createAndWaitForFence(); } let vals; if (dtype === 'complex64') { const ps = await Promise.all([ this.read(complexTensorInfos.real.dataId), this.read(complexTensorInfos.imag.dataId) ]); const realValues = ps[0]; const imagValues = ps[1]; vals = mergeRealAndImagArrays(realValues, imagValues); } else if (buffer == null) { vals = this.getValuesFromTexture(dataId); } else { const size = sizeFromShape(shape); vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size); } if (tmpDownloadTarget != null) { this.disposeIntermediateTensorInfo(tmpDownloadTarget); } if (buffer != null) { const gl = this.gpgpu.gl; callAndCheck(gl, () => gl.deleteBuffer(buffer)); } const dTypeVals = this.convertAndCacheOnCPU(dataId, vals); const subscribers = this.pendingRead.get(dataId); this.pendingRead.delete(dataId); subscribers.forEach(resolve => resolve(dTypeVals)); if (this.pendingDisposal.has(dataId)) { this.pendingDisposal.delete(dataId); if (this.disposeData(dataId)) { engine().removeDataId(dataId, this); } this.pendingDeletes--; } return dTypeVals; } readToGPU(dataId, options = {}) { const texData = this.texData.get(dataId); const { values, shape, slice, dtype, isPacked, texture } = texData; if (dtype === 'complex64') { throw new Error('Does not support reading texture for complex64 dtype.'); } if (slice != null) { let program; if (isPacked) { program = new 
UnaryOpPackedProgram(shape, CLONE); } else { program = new UnaryOpProgram(shape, CLONE); } const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype); const gpuResouorce = this.readToGPU(res, options); this.disposeIntermediateTensorInfo(res); return gpuResouorce; } if (texture == null) { if (values != null) { throw new Error('Data is not on GPU but on CPU.'); } else { throw new Error('There is no data on GPU or CPU.'); } } const tmpTarget = this.decode(dataId, options.customTexShape); const tensorRef = engine().makeTensorFromTensorInfo(tmpTarget); const tmpData = this.texData.get(tmpTarget.dataId); return Object.assign({ tensorRef }, tmpData.texture); } bufferSync(t) { const data = this.readSync(t.dataId); if (t.dtype === 'string') { try { const strings = data.map(d => decodeString(d)); return buffer(t.shape, t.dtype, strings); } catch (_a) { throw new Error('Failed to decode encoded string bytes into utf-8'); } } return buffer(t.shape, t.dtype, data); } checkNumericalProblems(values) { if (values == null) { return; } for (let i = 0; i < values.length; i++) { const num = values[i]; if (!canBeRepresented(num)) { if (env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) { throw Error(`The value ${num} cannot be represented with your ` + `current settings. 
Consider enabling float32 rendering: ` + `'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'`); } throw Error(`The value ${num} cannot be represented on this device.`); } } } getValuesFromTexture(dataId) { const { shape, dtype, isPacked } = this.texData.get(dataId); const size = sizeFromShape(shape); if (env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) { const tmpTarget = this.decode(dataId); const tmpData = this.texData.get(tmpTarget.dataId); const vals = this.gpgpu .downloadMatrixFromPackedTexture(tmpData.texture.texture, ...getDenseTexShape(shape)) .subarray(0, size); this.disposeIntermediateTensorInfo(tmpTarget); return vals; } const shouldUsePackedProgram = env().getBool('WEBGL_PACK') && isPacked === true; const outputShape = shouldUsePackedProgram ? getShapeAs3D(shape) : shape; const program = shouldUsePackedProgram ? new EncodeFloatPackedProgram(outputShape) : new EncodeFloatProgram(outputShape); const output = this.runWebGLProgram(program, [{ shape: outputShape, dtype, dataId }], 'float32'); const tmpData = this.texData.get(output.dataId); const vals = this.gpgpu .downloadByteEncodedFloatMatrixFromOutputTexture(tmpData.texture.texture, tmpData.texShape[0], tmpData.texShape[1]) .subarray(0, size); this.disposeIntermediateTensorInfo(output); return vals; } timerAvailable() { return env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0; } time(f) { const oldActiveTimers = this.activeTimers; const newActiveTimers = []; let outerMostTime = false; if (this.programTimersStack == null) { this.programTimersStack = newActiveTimers; outerMostTime = true; } else { this.activeTimers.push(newActiveTimers); } this.activeTimers = newActiveTimers; f(); const flattenedActiveTimerQueries = flatten$1(this.activeTimers.map((d) => d.query)) .filter(d => d != null); const flattenedActiveTimerNames = flatten$1(this.activeTimers.map((d) => d.name)) .filter(d => d != null); this.activeTimers = oldActiveTimers; if (outerMostTime) { this.programTimersStack = null; } 
const res = { uploadWaitMs: this.uploadWaitMs, downloadWaitMs: this.downloadWaitMs, kernelMs: null, wallMs: null }; return (async () => { if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) { const kernelMs = await Promise.all(flattenedActiveTimerQueries); res['kernelMs'] = sum$2(kernelMs); res['getExtraProfileInfo'] = () => kernelMs .map((d, i) => ({ name: flattenedActiveTimerNames[i], ms: d })) .map(d => `${d.name}: ${d.ms}`) .join(', '); } else { res['kernelMs'] = { error: 'WebGL query timers are not supported in this environment.' }; } this.uploadWaitMs = 0; this.downloadWaitMs = 0; return res; })(); } memory() { return { unreliable: false, numBytesInGPU: this.numBytesInGPU, numBytesInGPUAllocated: this.textureManager.numBytesAllocated, numBytesInGPUFree: this.textureManager.numBytesFree }; } startTimer() { if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) { return this.gpgpu.beginQuery(); } return { startMs: now(), endMs: null }; } endTimer(query) { if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) { this.gpgpu.endQuery(); return query; } query.endMs = now(); return query; } async getQueryTime(query) { if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) { return this.gpgpu.waitForQueryAndGetTime(query); } const timerQuery = query; return timerQuery.endMs - timerQuery.startMs; } disposeData(dataId, force = false) { if (this.pendingDisposal.has(dataId)) { return false; } if (!this.texData.has(dataId)) { return true; } if (force) { this.texData.get(dataId).refCount = 0; } else { this.texData.get(dataId).refCount--; } if (!force && this.texData.get(dataId).refCount > 0) { return false; } if (this.pendingRead.has(dataId)) { this.pendingDisposal.add(dataId); this.pendingDeletes++; return false; } this.releaseGPUData(dataId); const { complexTensorInfos } = this.texData.get(dataId); if (complexTensorInfos != null) { this.disposeData(complexTensorInfos.real.dataId, force); 
this.disposeData(complexTensorInfos.imag.dataId, force); } this.texData.delete(dataId); return true; } releaseGPUData(dataId) { const { texture, dtype, texShape, usage, isPacked, slice } = this.texData.get(dataId); const key = slice && slice.origDataId || dataId; const refCount = this.dataRefCount.get(key); if (refCount > 1) { this.dataRefCount.set(key, refCount - 1); } else { this.dataRefCount.delete(key); if (texture != null) { this.numBytesInGPU -= this.computeBytes(texShape, dtype); this.textureManager.releaseTexture(texture, texShape, usage, isPacked); } } const texData = this.texData.get(dataId); texData.texture = null; texData.texShape = null; texData.isPacked = false; texData.slice = null; } getTexture(dataId) { this.uploadToGPU(dataId); return this.texData.get(dataId).texture.texture; } getDataInfo(dataId) { return this.texData.get(dataId); } shouldExecuteOnCPU(inputs, sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD) { return env().getBool('WEBGL_CPU_FORWARD') && inputs.every(input => this.texData.get(input.dataId).texture == null && sizeFromShape(input.shape) < sizeThreshold); } getGPGPUContext() { return this.gpgpu; } where(condition) { warn('tf.where() in webgl locks the UI thread. 
' + 'Call tf.whereAsync() instead'); const condVals = condition.dataSync(); return whereImpl(condition.shape, condVals); } packedUnaryOp(x, op, dtype) { const program = new UnaryOpPackedProgram(x.shape, op); const outInfo = this.compileAndRun(program, [x], dtype); return engine().makeTensorFromTensorInfo(outInfo); } abs(x) { if (this.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') { const outValues = simpleAbsImplCPU(this.texData.get(x.dataId).values); return this.makeOutput(x.shape, x.dtype, outValues); } if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { return this.packedUnaryOp(x, ABS$1, x.dtype); } const program = new UnaryOpProgram(x.shape, ABS$1); const outInfo = this.compileAndRun(program, [x]); return engine().makeTensorFromTensorInfo(outInfo); } makeTensorInfo(shape, dtype, values) { let dataId; if (dtype === 'string' && values != null && values.length > 0 && isString(values[0])) { const encodedValues = values.map(d => encodeString(d)); dataId = this.write(encodedValues, shape, dtype); } else { dataId = this.write(values, shape, dtype); } this.texData.get(dataId).usage = null; return { dataId, shape, dtype }; } makeOutput(shape, dtype, values) { return engine().makeTensorFromTensorInfo(this.makeTensorInfo(shape, dtype, values), this); } unpackTensor(input) { const program = new UnpackProgram(input.shape); return this.runWebGLProgram(program, [input], input.dtype); } packTensor(input) { const program = new PackProgram(input.shape); const preventEagerUnpackingOutput = true; return this.runWebGLProgram(program, [input], input.dtype, null , preventEagerUnpackingOutput); } packedReshape(input, afterShape) { const input3DShape = [ getBatchDim(input.shape), ...getRowsCols(input.shape) ]; const input3D = { dtype: input.dtype, shape: input3DShape, dataId: input.dataId }; const afterShapeAs3D = [ getBatchDim(afterShape), ...getRowsCols(afterShape) ]; const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape); const preventEagerUnpackingOfOutput = 
true; const customValues = [input3DShape]; const output = this.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput); return { dataId: output.dataId, shape: afterShape, dtype: output.dtype }; } decode(dataId, customTexShape) { const texData = this.texData.get(dataId); const { isPacked, shape, dtype } = texData; if (customTexShape != null) { const size = sizeFromShape(shape); const texSize = customTexShape[0] * customTexShape[1] * 4; assert$1(size <= texSize, () => 'customTexShape is too small. ' + 'Row * Column * 4 should be equal or larger than the ' + 'size of the tensor data.'); } const shapeAs3D = getShapeAs3D(shape); let program; if (isPacked) { program = new DecodeMatrixPackedProgram(shapeAs3D); } else { program = new DecodeMatrixProgram(shapeAs3D); } const preventEagerUnpackingOfOutput = true; const customValues = [customTexShape != null ? customTexShape : getDenseTexShape(shapeAs3D)]; const out = this.runWebGLProgram(program, [{ shape: shapeAs3D, dtype, dataId }], dtype, customValues, preventEagerUnpackingOfOutput, customTexShape); return { dtype, shape, dataId: out.dataId }; } runWebGLProgram(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput = false, customTexShape) { const output = this.makeTensorInfo(program.outputShape, outputDtype); const outData = this.texData.get(output.dataId); if (program.packedOutput) { outData.isPacked = true; } if (program.outPackingScheme === PackingScheme.DENSE) { const texelShape = customTexShape != null ? 
customTexShape : getDenseTexShape(program.outputShape); outData.texShape = texelShape.map(d => d * 2); } if (program.outTexUsage != null) { outData.usage = program.outTexUsage; } if (sizeFromShape(output.shape) === 0) { outData.values = getTypedArrayFromDType(output.dtype, 0); return output; } const dataToDispose = []; const inputsData = inputs.map(input => { if (input.dtype === 'complex64') { throw new Error(`GPGPUProgram does not support complex64 input. For complex64 ` + `dtypes, please separate the program into real and imaginary ` + `parts.`); } let texData = this.texData.get(input.dataId); if (texData.texture == null) { if (!program.packedInputs && sizeFromShape(input.shape) <= env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) { return { shape: input.shape, texData: null, isUniform: true, uniformValues: texData.values }; } if (program.packedInputs) { texData.isPacked = true; texData.shape = input.shape; } } this.uploadToGPU(input.dataId); if (!!texData.isPacked !== !!program.packedInputs) { input = texData.isPacked ? 
this.unpackTensor(input) : this.packTensor(input); dataToDispose.push(input); texData = this.texData.get(input.dataId); } else if (texData.isPacked && !isReshapeFree(texData.shape, input.shape)) { const savedInput = input; const targetShape = input.shape; input.shape = texData.shape; input = this.packedReshape(input, targetShape); dataToDispose.push(input); texData = this.texData.get(input.dataId); savedInput.shape = targetShape; } return { shape: input.shape, texData, isUniform: false }; }); this.uploadToGPU(output.dataId); const outputData = { shape: output.shape, texData: outData, isUniform: false }; const key = makeShaderKey(program, inputsData, outputData); const binary = this.getAndSaveBinary(key, () => { return compileProgram(this.gpgpu, program, inputsData, outputData); }); const shouldTimeProgram = this.activeTimers != null; let query; if (shouldTimeProgram) { query = this.startTimer(); } if (!env().get('ENGINE_COMPILE_ONLY')) { runProgram(this.gpgpu, binary, inputsData, outputData, customUniformValues); } dataToDispose.forEach(info => this.disposeIntermediateTensorInfo(info)); if (shouldTimeProgram) { query = this.endTimer(query); this.activeTimers.push({ name: program.constructor.name, query: this.getQueryTime(query) }); } const glFlushThreshold = env().getNumber('WEBGL_FLUSH_THRESHOLD'); if (glFlushThreshold > 0) { const time = now(); if ((time - this.lastGlFlushTime) > glFlushThreshold) { this.gpgpu.gl.flush(); this.lastGlFlushTime = time; } } if (!env().getBool('WEBGL_LAZILY_UNPACK') && outData.isPacked && preventEagerUnpackingOfOutput === false) { const unpacked = this.unpackTensor(output); this.disposeIntermediateTensorInfo(output); return unpacked; } return output; } compileAndRun(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput = false) { outputDtype = outputDtype || inputs[0].dtype; const outInfo = this.runWebGLProgram(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput); return 
outInfo; } getAndSaveBinary(key, getBinary) { if (!(key in this.binaryCache)) { this.binaryCache[key] = getBinary(); } return this.binaryCache[key]; } getTextureManager() { return this.textureManager; } dispose() { if (this.disposed) { return; } if (!env().getBool('IS_TEST')) { const allKeys = Object.keys(this.binaryCache); allKeys.forEach(key => { this.gpgpu.deleteProgram(this.binaryCache[key].webGLProgram); delete this.binaryCache[key]; }); } this.textureManager.dispose(); if (this.canvas != null && (typeof (HTMLCanvasElement) !== 'undefined' && this.canvas instanceof HTMLCanvasElement)) { this.canvas.remove(); } else { this.canvas = null; } if (this.gpgpuCreatedLocally) { this.gpgpu.program = null; this.gpgpu.dispose(); } this.disposed = true; } floatPrecision() { if (this.floatPrecisionValue == null) { this.floatPrecisionValue = tidy(() => { if (!env().get('WEBGL_RENDER_FLOAT32_ENABLED')) { const debugFlag = env().getBool('DEBUG'); env().set('DEBUG', false); const underflowCheckValue = this.abs(scalar(1e-8)).dataSync()[0]; env().set('DEBUG', debugFlag); if (underflowCheckValue > 0) { return 32; } } return 16; }); } return this.floatPrecisionValue; } epsilon() { return this.floatPrecision() === 32 ? 
EPSILON_FLOAT32 : EPSILON_FLOAT16; } /* uploadToGPU: ensures the entry for `dataId` has a GPU texture and returns immediately if it already has one. With CPU `values`, the data is uploaded into a temporary dense texture and re-encoded into this entry's layout by an Encode*Program; upload wall time is then added to uploadWaitMs while timers are active. Without values, a texture is obtained from this.acquireTexture() and no data is uploaded. */ uploadToGPU(dataId) { const texData = this.texData.get(dataId); const { shape, dtype, values, texture, usage, isPacked } = texData; if (texture != null) { return; } const shouldTimeProgram = this.activeTimers != null; let start; if (shouldTimeProgram) { start = now(); } let texShape = texData.texShape; /* Derive the texture shape from the logical shape on first upload and cache it on the entry. */ if (texShape == null) { texShape = getTextureShapeFromLogicalShape(shape, isPacked); texData.texShape = texShape; } if (values != null) { const shapeAs3D = getShapeAs3D(shape); let program; let width = texShape[1], height = texShape[0]; const isByteArray = values instanceof Uint8Array || values instanceof Uint8ClampedArray; /* Only unpacked byte arrays keep texShape as the upload size; all other inputs use the packed-matrix width/height derived from texShape. */ if (isPacked || !isByteArray) { [width, height] = getPackedMatrixTextureShapeWidthHeight(texShape[0], texShape[1]); } if (isPacked) { program = new EncodeMatrixPackedProgram(shapeAs3D, isByteArray); } else { program = new EncodeMatrixProgram(shapeAs3D, isByteArray); } const tempDenseInputTexShape = isByteArray ? [height, width] : texShape; /* Temporary input tensor that receives the raw upload: usage PIXELS for byte data, UPLOAD otherwise. */ const tempDenseInputHandle = this.makeTensorInfo(tempDenseInputTexShape, dtype); const tempDenseInputTexData = this.texData.get(tempDenseInputHandle.dataId); if (isByteArray) { tempDenseInputTexData.usage = TextureUsage.PIXELS; } else { tempDenseInputTexData.usage = TextureUsage.UPLOAD; } tempDenseInputTexData.texShape = tempDenseInputTexShape; /* getTexture() re-enters uploadToGPU for the temporary handle, which has no values and so only acquires a texture. */ this.gpgpu.uploadDenseMatrixToTexture(this.getTexture(tempDenseInputHandle.dataId), width, height, values); const customValues = [[height, width]]; const preventEagerUnpacking = true; const encodedOutputTarget = this.runWebGLProgram(program, [tempDenseInputHandle], dtype, customValues, preventEagerUnpacking); const outputTexData = this.texData.get(encodedOutputTarget.dataId); /* Copy the encoded output's layout metadata onto this entry. */ texData.texShape = outputTexData.texShape; texData.isPacked = outputTexData.isPacked; texData.usage = outputTexData.usage; if (!env().get('ENGINE_COMPILE_ONLY')) { /* Take over the encoded texture and drop the CPU copy. The output's own texData entry is then deleted directly, without releasing the texture. */ texData.texture = outputTexData.texture; texData.values = null; 
this.texData.delete(encodedOutputTarget.dataId); } else { this.disposeData(encodedOutputTarget.dataId); } this.disposeIntermediateTensorInfo(tempDenseInputHandle); if (shouldTimeProgram) { this.uploadWaitMs += now() - start; } } else { const newTexture = this.acquireTexture(texShape, usage, dtype, isPacked); texData.texture = newTexture; } } convertAndCacheOnCPU(dataId, float32Values) { const texData = this.texData.get(dataId); const { dtype } = texData; if (float32Values != null) { texData.values = float32ToTypedArray(float32Values, dtype); } return texData.values; } acquireTexture(texShape, texType, dtype, isPacked) { this.numBytesInGPU += this.computeBytes(texShape, dtype); if (!this.warnedAboutMemory && this.numBytesInGPU > this.numMBBeforeWarning * 1024 * 1024) { const mb = (this.numBytesInGPU / 1024 / 1024).toFixed(2); this.warnedAboutMemory = true; console.warn(`High memory usage in GPU: ${mb} MB, ` + `most likely due to a memory leak`); } return this.textureManager.acquireTexture(texShape, texType, isPacked); } computeBytes(shape, dtype) { return shape[0] * shape[1] * bytesPerElement(dtype); } checkCompileCompletion() { for (const [, binary] of Object.entries(this.binaryCache)) { this.checkCompletion_(binary); } } async checkCompileCompletionAsync() { const ps = []; if (this.gpgpu.parallelCompilationExtension) { for (const [, binary] of Object.entries(this.binaryCache)) { ps.push(this.checkCompletionAsync_(binary)); } return Promise.all(ps); } else { for (const [, binary] of Object.entries(this.binaryCache)) { const p = new Promise((resolve) => { try { this.checkCompletion_(binary); resolve(true); } catch (error) { throw error; } }); ps.push(p); } return Promise.all(ps); } } async checkCompletionAsync_(binary) { if (this.gpgpu.gl.getProgramParameter(binary.webGLProgram, this.gpgpu.parallelCompilationExtension.COMPLETION_STATUS_KHR)) { return this.checkCompletion_(binary); } else { await nextFrame(); return this.checkCompletionAsync_(binary); } } 
checkCompletion_(binary) {
  // Returns true if `binary.webGLProgram` linked. On link failure it logs the
  // program info log. If the fragment shader also failed to compile, it logs
  // the shader source plus info log and throws 'Failed to compile fragment
  // shader.'; otherwise it throws 'Failed to link vertex and fragment shaders.'.
  if (this.gpgpu.gl.getProgramParameter(binary.webGLProgram, this.gpgpu.gl.LINK_STATUS) === false) {
    console.log(this.gpgpu.gl.getProgramInfoLog(binary.webGLProgram));
    if (this.gpgpu.gl.getShaderParameter(binary.fragmentShader, this.gpgpu.gl.COMPILE_STATUS) === false) {
      logShaderSourceAndInfoLog(binary.source, this.gpgpu.gl.getShaderInfoLog(binary.fragmentShader));
      throw new Error('Failed to compile fragment shader.');
    }
    throw new Error('Failed to link vertex and fragment shaders.');
  }
  return true;
}
/**
 * For every cached binary: builds its VAO, then stores the uniform locations
 * returned by the free function `getUniformLocations` (defined elsewhere in
 * this file) on the binary.
 */
getUniformLocations() {
  for (const binary of Object.values(this.binaryCache)) {
    this.gpgpu.buildVao(binary.webGLProgram);
    const { variablesLocations, customUniformLocations, infLoc, nanLoc, outShapeLocation, outShapeStridesLocation, outTexShapeLocation } = getUniformLocations(this.gpgpu, binary.program, binary.webGLProgram);
    binary.variablesLocations = variablesLocations;
    binary.customUniformLocations = customUniformLocations;
    binary.infLoc = infLoc;
    binary.nanLoc = nanLoc;
    binary.outShapeLocation = outShapeLocation;
    binary.outShapeStridesLocation = outShapeStridesLocation;
    binary.outTexShapeLocation = outTexShapeLocation;
  }
}
/**
 * Wraps an existing WebGL texture (`values.texture` plus `height`, `width`,
 * `channels`) as a tensor on `engine().backend`, which is used rather than
 * `this`.
 * Note: a falsy `values.channels` is defaulted to 'RGBA' by mutating the
 * caller's object.
 * Throws if `gl.isTexture(texture)` is false for that backend's context.
 */
createTensorFromGPUData(values, shape, dtype) {
  values.channels = values.channels || 'RGBA';
  const { texture, height, width, channels } = values;
  const backend = engine().backend;
  if (!backend.gpgpu.gl.isTexture(texture)) {
    throw new Error(`The texture is invalid. Also, please make sure the texture and ` + `the TFJS WebGL backend are using the same canvas.
If you want to ` + `use your own custom canvas, you have to create and use the custom ` + `TFJS WebGL backend created from the canvas through ` + `'new tf.MathBackendWebGL(customCanvas)'.`);
  }
  const dataId = backend.writeTexture(texture, shape, dtype, height, width, channels);
  return engine().makeTensorFromDataId(dataId, shape, dtype, backend);
}
} // closes the enclosing class (presumably MathBackendWebGL; its header is not visible here)
MathBackendWebGL.nextDataId = 0;
/**
 * Converts float values to the array type for `dtype`:
 * - float32 and complex64 return `a` unchanged;
 * - int32 produces an Int32Array and bool a Uint8Array, each element
 *   rounded with Math.round.
 * Throws for any other dtype.
 */
function float32ToTypedArray(a, dtype) {
  if (dtype === 'float32' || dtype === 'complex64') {
    return a;
  } else if (dtype === 'int32' || dtype === 'bool') {
    const result = (dtype === 'int32') ? new Int32Array(a.length) : new Uint8Array(a.length);
    for (let i = 0; i < result.length; ++i) {
      result[i] = Math.round(a[i]);
    }
    return result;
  } else {
    throw new Error(`Unknown dtype ${dtype}`);
  }
}
// In browsers, register a factory for MathBackendWebGL under the name 'webgl'
// (the third argument, 2, is presumably the registration priority — confirm).
if (isBrowser()) {
  registerBackend('webgl', () => new MathBackendWebGL(), 2 );
}
// GLSL fragment that returns whichever scalar operand is NaN.
const CHECK_NAN_SNIPPET = ` if (isnan(a)) return a; if (isnan(b)) return b; `;
/**
 * Unpacked elementwise binary-op shader. `op` is GLSL spliced into the body
 * of `float binaryOperation(float a, float b)`. The output shape comes from
 * assertAndGetBroadcastShape(aShape, bShape).
 */
class BinaryOpProgram {
  constructor(op, aShape, bShape) {
    this.variableNames = ['A', 'B'];
    this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
    this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
    this.userCode = ` float binaryOperation(float a, float b) { ${op} } void main() { float a = getAAtOutCoords(); float b = getBAtOutCoords(); setOutput(binaryOperation(a, b)); } `;
  }
}
// GLSL fragment for packed ops: sets each channel of `result` to NAN where
// the matching component of `isNaN` is true.
const CHECK_NAN_SNIPPET_PACKED = ` result.r = isNaN.r ? NAN : result.r; result.g = isNaN.g ? NAN : result.g; result.b = isNaN.b ? NAN : result.b; result.a = isNaN.a ?
NAN : result.a; `;
/**
 * Packed (vec4) elementwise binary-op shader. `op` is GLSL spliced into
 * `vec4 binaryOperation(vec4 a, vec4 b)`. When `checkOutOfBounds` is true,
 * result channels whose next column/row index reaches the output's last one
 * or two dimensions are set to 0 before `setOutput`.
 */
class BinaryOpPackedProgram {
  constructor(op, aShape, bShape, checkOutOfBounds = false) {
    this.variableNames = ['A', 'B'];
    this.supportsBroadcasting = true;
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
    const rank = this.outputShape.length;
    this.enableShapeUniforms = useShapeUniforms(rank);
    let checkOutOfBoundsString = '';
    if (checkOutOfBounds) {
      if (rank === 0 || sizeFromShape(this.outputShape) === 1) {
        // Single-element output: zero channels y, z and w.
        checkOutOfBoundsString = ` result.y = 0.; result.z = 0.; result.w = 0.; `;
      } else {
        const dtype = getCoordsDataType(rank);
        checkOutOfBoundsString = ` ${dtype} coords = getOutputCoords(); `;
        // With shape uniforms the bound is read from the `outShape` uniform;
        // otherwise it is baked into the shader source as a literal.
        if (rank === 1) {
          if (this.enableShapeUniforms) {
            checkOutOfBoundsString += ` result.y = (coords + 1) >= outShape ? 0. : result.y; result.z = 0.; result.w = 0.; `;
          } else {
            checkOutOfBoundsString += ` result.y = (coords + 1) >= ${this.outputShape[0]} ? 0. : result.y; result.z = 0.; result.w = 0.; `;
          }
        } else {
          const channels = getChannels('coords', rank);
          if (this.enableShapeUniforms) {
            checkOutOfBoundsString += ` bool nextRowOutOfBounds = (${channels[rank - 2]} + 1) >= outShape[${rank} - 2]; bool nextColOutOfBounds = (${channels[rank - 1]} + 1) >= outShape[${rank} - 1]; result.y = nextColOutOfBounds ? 0. : result.y; result.z = nextRowOutOfBounds ? 0. : result.z; result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w; `;
          } else {
            checkOutOfBoundsString += ` bool nextRowOutOfBounds = (${channels[rank - 2]} + 1) >= ${this.outputShape[rank - 2]}; bool nextColOutOfBounds = (${channels[rank - 1]} + 1) >= ${this.outputShape[rank - 1]}; result.y = nextColOutOfBounds ? 0. : result.y; result.z = nextRowOutOfBounds ? 0. : result.z; result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0.
: result.w; `;
          }
        }
      }
    }
    this.userCode = ` vec4 binaryOperation(vec4 a, vec4 b) { ${op} } void main() { vec4 a = getAAtOutCoords(); vec4 b = getBAtOutCoords(); vec4 result = binaryOperation(a, b); ${checkOutOfBoundsString} setOutput(result); } `;
  }
}
/**
 * Identity kernel: increments the ref count of `x.dataId` and returns a
 * tensor info that shares that dataId, shape and dtype.
 */
function identity(args) {
  const { inputs, backend } = args;
  const { x } = inputs;
  backend.incRef(x.dataId);
  return { dataId: x.dataId, shape: x.shape, dtype: x.dtype };
}
const identityConfig = {
  kernelName: Identity$1,
  backendName: 'webgl',
  kernelFunc: identity
};
/**
 * Complex kernel: allocates a complex64 tensor info shaped like `real`.
 * Its texture data records identity (ref-counted) views of `real` and `imag`
 * as `complexTensorInfos`.
 */
function complex(args) {
  const { inputs, backend } = args;
  const { real, imag } = inputs;
  const complexInfo = backend.makeTensorInfo(real.shape, 'complex64');
  const complex = backend.texData.get(complexInfo.dataId);
  const realTensorInfo = identity({ inputs: { x: real }, backend });
  const imagTensorInfo = identity({ inputs: { x: imag }, backend });
  complex.complexTensorInfos = { real: realTensorInfo, imag: imagTensorInfo };
  return complexInfo;
}
const complexConfig = {
  kernelName: Complex,
  backendName: 'webgl',
  kernelFunc: complex
};
// GLSL bodies for leaky ReLU: `a` is the input value, `b` the alpha value.
const LEAKYRELU = `return (a < 0.) ? b * a : a;`;
const LEAKYRELU_PACKED = ` vec4 aLessThanZero = vec4(lessThan(a, vec4(0.))); return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a); `;
/**
 * LeakyRelu kernel:
 * 1. uploads `attrs.alpha` as a float32 scalar tensor;
 * 2. runs the packed or unpacked binary program, chosen by
 *    WEBGL_PACK_BINARY_OPERATIONS;
 * 3. disposes the temporary alpha tensor and returns the float32 result.
 */
function leakyRelu(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const { alpha } = attrs;
  const $alpha = backend.makeTensorInfo([], 'float32', createScalarValue(alpha, 'float32'));
  const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
    new BinaryOpPackedProgram(LEAKYRELU_PACKED, x.shape, $alpha.shape) :
    new BinaryOpProgram(LEAKYRELU, x.shape, $alpha.shape);
  const result = backend.runWebGLProgram(program, [x, $alpha], 'float32');
  backend.disposeIntermediateTensorInfo($alpha);
  return result;
}
const leakyReluConfig = {
  kernelName: LeakyRelu,
  backendName: 'webgl',
  kernelFunc: leakyRelu
};
// GLSL bodies for PReLU: `a` is the input value, `b` the `alpha` input tensor's value.
const PRELU = `return (a < 0.) ?
b * a : a;`;
const PRELU_PACKED = ` vec4 aLessThanZero = vec4(lessThan(a, vec4(0.))); return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a); `;
/**
 * Prelu kernel: binary shader over `x` and the `alpha` input tensor, packed
 * or unpacked per WEBGL_PACK_BINARY_OPERATIONS. The output dtype is float32.
 */
function prelu(args) {
  const { inputs, backend } = args;
  const { x, alpha } = inputs;
  const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
    new BinaryOpPackedProgram(PRELU_PACKED, x.shape, alpha.shape) :
    new BinaryOpProgram(PRELU, x.shape, alpha.shape);
  return backend.runWebGLProgram(program, [x, alpha], 'float32');
}
const preluConfig = {
  kernelName: Prelu,
  backendName: 'webgl',
  kernelFunc: prelu
};
// GLSL fragment that returns the unary input `x` when it is NaN.
const CHECK_NAN_SNIPPET_UNARY = `if (isnan(x)) return x;`;
/**
 * Builds a unary kernel function. The output dtype is `dtype`, defaulting to
 * `x.dtype`. The kernel picks one of three paths:
 * - `cpuKernelImpl`, when it is provided and the backend elects CPU execution;
 * - the packed shader, when WEBGL_PACK_UNARY_OPERATIONS is set and
 *   `packedOpSnippet` is given;
 * - otherwise the unpacked shader.
 */
function unaryKernelFunc({ opSnippet, packedOpSnippet, cpuKernelImpl, dtype }) {
  return ({ inputs, backend }) => {
    const { x } = inputs;
    const webglBackend = backend;
    const $dtype = dtype || x.dtype;
    if (webglBackend.shouldExecuteOnCPU([x]) && cpuKernelImpl != null) {
      const xData = webglBackend.texData.get(x.dataId);
      const outValues = cpuKernelImpl(xData.values, $dtype);
      return webglBackend.makeTensorInfo(x.shape, $dtype, outValues);
    }
    const shouldUsePackedProgram = env().getBool('WEBGL_PACK_UNARY_OPERATIONS') && packedOpSnippet != null;
    let program;
    if (shouldUsePackedProgram) {
      program = new UnaryOpPackedProgram(x.shape, packedOpSnippet);
    } else {
      program = new UnaryOpProgram(x.shape, opSnippet);
    }
    return webglBackend.runWebGLProgram(program, [x], $dtype);
  };
}
/**
 * Builds a binary kernel function (the definition continues past this
 * block). The kernel picks one of three paths:
 * - complex64 `a` with `supportsComplex`: the real and imaginary parts are
 *   computed separately and recombined;
 * - `cpuKernelImpl`, when it is provided and either input is a string or the
 *   backend elects CPU execution;
 * - otherwise a packed or unpacked shader.
 */
function binaryKernelFunc({ opSnippet, packedOpSnippet, checkOutOfBounds = false, supportsComplex = false, cpuKernelImpl, dtype }) {
  return ({ inputs, backend }) => {
    const { a, b } = inputs;
    const webglBackend = backend;
    if (supportsComplex && a.dtype === 'complex64') {
      const aData = webglBackend.texData.get(a.dataId);
      const bData = webglBackend.texData.get(b.dataId);
      const [real, imag] = [
        [aData.complexTensorInfos.real, bData.complexTensorInfos.real],
        [aData.complexTensorInfos.imag, bData.complexTensorInfos.imag]
      ].map(complexParts => {
        const [aPart, bPart] = complexParts;
        const
aHandle = { dataId: aPart.dataId, dtype: aPart.dtype, shape: a.shape }; const bHandle = { dataId: bPart.dataId, dtype: bPart.dtype, shape: b.shape }; const program = new BinaryOpProgram(opSnippet, a.shape, b.shape); return webglBackend.runWebGLProgram(program, [aHandle, bHandle], upcastType(aPart.dtype, bPart.dtype)); }); const complexOutput = complex({ inputs: { real, imag }, backend: webglBackend }); webglBackend.disposeIntermediateTensorInfo(real); webglBackend.disposeIntermediateTensorInfo(imag); return complexOutput; } const $dtype = dtype || upcastType(a.dtype, b.dtype); if ((a.dtype === 'string' || b.dtype === 'string' || webglBackend.shouldExecuteOnCPU([a, b])) && cpuKernelImpl != null) { const aVals = webglBackend.texData.get(a.dataId).values; const bVals = webglBackend.texData.get(b.dataId).values; const decodedAVals = a.dtype === 'string' ? fromUint8ToStringArray(aVals) : aVals; const decodedBVals = a.dtype === 'string' ? fromUint8ToStringArray(bVals) : bVals; const [outValues, outShape] = cpuKernelImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype); const out = webglBackend.makeTensorInfo(outShape, $dtype); const outData = webglBackend.texData.get(out.dataId); outData.values = outValues; return out; } const shouldUsePackedProgram = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') && packedOpSnippet != null; let program; if (shouldUsePackedProgram) { program = new BinaryOpPackedProgram(packedOpSnippet, a.shape, b.shape, checkOutOfBounds); } else { program = new BinaryOpProgram(opSnippet, a.shape, b.shape); } return webglBackend.runWebGLProgram(program, [a, b], $dtype); }; } function mapActivationToShaderProgram(activation, packed = false) { if (activation === 'linear') { if (packed) { return LINEAR; } return LINEAR$1; } else if (activation === 'relu') { if (packed) { return RELU$1; } return RELU$2; } else if (activation === 'elu') { if (packed) { return ELU$1; } return ELU$2; } else if (activation === 'relu6') { if (packed) { return RELU6$1; } return 
RELU6$2; } else if (activation === 'prelu') { if (packed) { return PRELU_PACKED; } return PRELU; } else if (activation === 'leakyrelu') { if (packed) { return LEAKYRELU_PACKED; } return LEAKYRELU; } else if (activation === 'sigmoid') { if (packed) { return SIGMOID$1; } return SIGMOID$2; } throw new Error(`Activation ${activation} has not been implemented for the WebGL backend.`); } class MatMulPackedProgram { constructor(aShape, bShape, outputShape, transposeA = false, transposeB = false, addBias = false, activation = null, hasPreluActivation = false, hasLeakyreluActivation = false) { this.variableNames = ['matrixA', 'matrixB']; this.packedInputs = true; this.packedOutput = true; this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); const sharedDim = transposeA ? aShape[1] : aShape[2]; const sharedDimensionPacked = Math.ceil(sharedDim / 2); const aSample = transposeA ? 'i * 2, rc.y' : 'rc.y, i * 2'; const bSample = transposeB ? 'rc.z, i * 2' : 'i * 2, rc.z'; const aSwizzle = transposeA ? ['a.xxyy', 'a.zzww'] : ['a.xxzz', 'a.yyww']; const bSwizzle = transposeB ? ['b.xzxz', 'b.ywyw'] : ['b.xyxy', 'b.zwzw']; let activationSnippet = '', applyActivationSnippet = ''; if (activation) { if (hasPreluActivation) { activationSnippet = `vec4 activation(vec4 a) { vec4 b = getPreluActivationWeightsAtOutCoords(); ${activation} }`; } else if (hasLeakyreluActivation) { activationSnippet = `vec4 activation(vec4 a) { vec4 b = getLeakyreluAlphaAtOutCoords(); ${activation} }`; } else { activationSnippet = `vec4 activation(vec4 x) { ${activation} }`; } applyActivationSnippet = `result = activation(result);`; } const addBiasSnippet = addBias ? 
'result += getBiasAtOutCoords();' : '';
    // Optional extra shader inputs are appended in this order: bias,
    // preluActivationWeights, leakyreluAlpha.
    if (addBias) {
      this.variableNames.push('bias');
    }
    if (hasPreluActivation) {
      this.variableNames.push('preluActivationWeights');
    }
    if (hasLeakyreluActivation) {
      this.variableNames.push('leakyreluAlpha');
    }
    // When the two batch sizes differ, the smaller operand's batch index is
    // wrapped with imod so that its batches are reused.
    let batchASnippet = 'rc.x';
    let batchBSnippet = 'rc.x';
    if (aShape[0] < bShape[0]) {
      batchASnippet = `imod(rc.x, ${aShape[0]})`;
    } else if (bShape[0] < aShape[0]) {
      batchBSnippet = `imod(rc.x, ${bShape[0]})`;
    }
    this.userCode = ` ${activationSnippet} const float sharedDimension = ${sharedDimensionPacked}.0; vec4 dot2x2ARowBCol(ivec3 rc) { vec4 result = vec4(0); int batchA = ${batchASnippet}; int batchB = ${batchBSnippet}; for (int i = 0; i < ${sharedDimensionPacked}; i++) { vec4 a = getMatrixA(batchA, ${aSample}); vec4 b = getMatrixB(batchB, ${bSample}); result += (${aSwizzle[0]} * ${bSwizzle[0]}); result += (${aSwizzle[1]} * ${bSwizzle[1]}); } return result; } void main() { ivec3 rc = getOutputCoords(); vec4 result = dot2x2ARowBCol(rc); ${addBiasSnippet} ${applyActivationSnippet} setOutput(result); } `;
  }
}
// GLSL bodies for the real and imaginary parts of
// (areal + i*aimag) * (breal + i*bimag).
const COMPLEX_MULTIPLY = {
  REAL: 'return areal * breal - aimag * bimag;',
  IMAG: 'return areal * bimag + aimag * breal;'
};
/**
 * Shader that evaluates `op` over the real/imaginary parts of two complex
 * operands. The output shape comes from
 * assertAndGetBroadcastShape(aShape, bShape).
 */
class BinaryOpComplexProgram {
  constructor(op, aShape, bShape) {
    this.variableNames = ['AReal', 'AImag', 'BReal', 'BImag'];
    this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
    this.userCode = ` float binaryOpComplex( float areal, float aimag, float breal, float bimag) { ${op} } void main() { float areal = getARealAtOutCoords(); float aimag = getAImagAtOutCoords(); float breal = getBRealAtOutCoords(); float bimag = getBImagAtOutCoords(); setOutput(binaryOpComplex(areal, aimag, breal, bimag)); } `;
  }
}
const MUL = 'return a * b;';
/**
 * Multiply kernel.
 * - complex64 `a`: runs separate float32 real and imaginary shaders over the
 *   parts of both operands and recombines them with `complex`.
 * - otherwise: uses multiplyImplCPU when the backend elects CPU execution,
 *   else a packed or unpacked shader (per WEBGL_PACK_BINARY_OPERATIONS).
 *   The output dtype is upcastType(a.dtype, b.dtype).
 */
function multiply(args) {
  const { inputs, backend } = args;
  const { a, b } = inputs;
  const dtype = upcastType(a.dtype, b.dtype);
  if (a.dtype === 'complex64') {
    const aData = backend.texData.get(a.dataId);
    const bData = backend.texData.get(b.dataId);
    const realProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.REAL, a.shape, b.shape);
    const imagProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.IMAG, a.shape, b.shape);
    // Order matches BinaryOpComplexProgram.variableNames: AReal, AImag, BReal, BImag.
    const inputs = [
      { dataId: aData.complexTensorInfos.real.dataId, dtype: aData.complexTensorInfos.real.dtype, shape: a.shape },
      { dataId: aData.complexTensorInfos.imag.dataId, dtype: aData.complexTensorInfos.imag.dtype, shape: a.shape },
      { dataId: bData.complexTensorInfos.real.dataId, dtype: bData.complexTensorInfos.real.dtype, shape: b.shape },
      { dataId: bData.complexTensorInfos.imag.dataId, dtype: bData.complexTensorInfos.imag.dtype, shape: b.shape }
    ];
    const realPart = backend.runWebGLProgram(realProgram, inputs, 'float32');
    const imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32');
    const complexOutput = complex({ inputs: { real: realPart, imag: imagPart }, backend });
    backend.disposeIntermediateTensorInfo(realPart);
    backend.disposeIntermediateTensorInfo(imagPart);
    return complexOutput;
  }
  if (backend.shouldExecuteOnCPU([a, b])) {
    const aData = backend.texData.get(a.dataId);
    const bData = backend.texData.get(b.dataId);
    const [outValues, outShape] = multiplyImplCPU(a.shape, b.shape, aData.values, bData.values, dtype);
    const out = backend.makeTensorInfo(outShape, dtype);
    const outData = backend.texData.get(out.dataId);
    outData.values = outValues;
    return out;
  }
  let program;
  if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
    program = new BinaryOpPackedProgram(MUL, a.shape, b.shape);
  } else {
    program = new BinaryOpProgram(MUL, a.shape, b.shape);
  }
  return backend.runWebGLProgram(program, [a, b], dtype);
}
const multiplyConfig = {
  kernelName: Multiply,
  backendName: 'webgl',
  kernelFunc: multiply
};
/**
 * GPU reshape for packed tensors (the definition continues past this block).
 * Input and target shapes are both viewed as 3-D
 * [getBatchDim(shape), ...getRowsCols(shape)], and a ReshapePackedProgram is
 * run.
 */
function packedReshape(input, afterShape, backend) {
  const input3DShape = [getBatchDim(input.shape), ...getRowsCols(input.shape)];
  const input3D = { dtype: input.dtype, shape: input3DShape, dataId: input.dataId };
  const afterShapeAs3D = [getBatchDim(afterShape), ...getRowsCols(afterShape)];
  const
program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);
  const preventEagerUnpackingOfOutput = true;
  const customValues = [input3DShape];
  const output = backend.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput);
  // The result carries the caller's requested (non-3-D) shape.
  return { dataId: output.dataId, shape: afterShape, dtype: output.dtype };
}
/**
 * Reshape kernel.
 * - Resolves the requested shape with inferFromImplicitShape and asserts the
 *   element count is unchanged.
 * - Packed inputs that fail both isReshapeFree checks go through
 *   packedReshape.
 * - Otherwise the data is shared under the new shape and its ref count is
 *   incremented.
 */
function reshape(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const { shape } = attrs;
  const webglBackend = backend;
  const xSize = sizeFromShape(x.shape);
  const $shape = inferFromImplicitShape(shape, xSize);
  const $xSize = sizeFromShape($shape);
  assert$1(xSize === $xSize, () => `The new shape (${$shape}) has ${$xSize} elements and the old ` +
    `shape (${x.shape}) has ${xSize} elements. The new shape and old ` +
    `shape must have the same number of elements.`);
  const xTexData = webglBackend.texData.get(x.dataId);
  // NOTE(review): `texture !== null` is also true when `texture` is
  // undefined — confirm this is the intended check.
  if (xTexData.isPacked && !isReshapeFree(x.shape, $shape) &&
    !(xTexData.texture !== null && isReshapeFree(xTexData.shape, $shape))) {
    return packedReshape(x, $shape, webglBackend);
  }
  webglBackend.incRef(x.dataId);
  return { dataId: x.dataId, shape: $shape, dtype: x.dtype };
}
const reshapeConfig = {
  kernelName: Reshape$1,
  backendName: 'webgl',
  kernelFunc: reshape
};
/**
 * Mean-reduction shader. It reads x as getX(batch, inIdx) and writes
 * [batchSize, outSize], summing each window of `windowSize` inputs four at
 * a time with dot().
 * - When `divisor` is given, values are multiplied by 1 / divisor before
 *   summing.
 * - When inSize is not a multiple of windowSize, out-of-range indices read
 *   as 0.0.
 */
class MeanProgram {
  constructor(reduceInfo, divisor) {
    this.variableNames = ['x'];
    const { windowSize, batchSize, inSize, outSize } = reduceInfo;
    this.outputShape = [batchSize, outSize];
    const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
    const windowSizeVec4Remainder = windowSize % 4;
    let updateSnippet = `sumValue += dot(values, ones);`;
    if (divisor != null) {
      const denominator = 1 / divisor;
      // Integral values go through toPrecision(2) (e.g. 1 -> "1.0"),
      // presumably so the emitted GLSL literal is a float, not an int.
      updateSnippet = `sumValue += dot(values * ${isInt(denominator) ? denominator.toPrecision(2) : denominator}, ones);`;
    }
    let checkOutOfBounds = '';
    if (inSize % windowSize > 0) {
      checkOutOfBounds = ` if (inIdx < 0 || inIdx >= ${inSize}) { return 0.0; } `;
    }
    this.userCode = ` const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0); float getValue(int batch, int inIdx) { ${checkOutOfBounds} return getX(batch, inIdx); } void main() { ivec2 coords = getOutputCoords(); int batch = coords[0]; int outIdx = coords[1]; int inOffset = outIdx * ${windowSize}; float sumValue = 0.0; for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) { int inIdx = inOffset + i; vec4 values = vec4( getValue(batch, inIdx), getValue(batch, inIdx + 1), getValue(batch, inIdx + 2), getValue(batch, inIdx + 3) ); ${updateSnippet} } int inIdx = inOffset + ${windowSizeNearestVec4}; if (${windowSizeVec4Remainder === 1}) { vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0); ${updateSnippet} } else if (${windowSizeVec4Remainder === 2}) { vec4 values = vec4( getValue(batch, inIdx), getValue(batch, inIdx + 1), 0.0, 0.0); ${updateSnippet} } else if (${windowSizeVec4Remainder === 3}) { vec4 values = vec4( getValue(batch, inIdx), getValue(batch, inIdx + 1), getValue(batch, inIdx + 2), 0.0); ${updateSnippet} } setOutput(sumValue); } `;
  }
}
/**
 * Reduction shader for reduceType 'sum', 'prod', 'min', 'max', 'all' or
 * 'any'. It reduces windows of `windowSize` inputs and writes
 * [batchSize, outSize].
 * - min/max write NAN when any input in a chunk is NaN.
 * - When inSize is not a multiple of windowSize, out-of-range indices read
 *   as `initializationValue`.
 */
class ReduceProgram {
  constructor(reduceInfo, reduceType) {
    this.variableNames = ['x'];
    const { windowSize, batchSize, inSize, outSize } = reduceInfo;
    this.outputShape = [batchSize, outSize];
    let initializationValue = '0.0';
    let compareOp = ``;
    if (reduceType === 'prod') {
      initializationValue = '1.0';
    } else if (reduceType === 'min') {
      initializationValue = '1.0 / 1e-20';
      compareOp = `min`;
    } else if (reduceType === 'max') {
      initializationValue = '-1.0 / 1e-20';
      compareOp = `max`;
    }
    let returnValue = `${reduceType}(${reduceType}(${reduceType}(` +
      'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
    if (reduceType === 'sum') {
      returnValue = `sumValue`;
    } else if (reduceType === 'prod') {
      returnValue = `prodValue`;
    } else if (reduceType === 'all') {
      returnValue = `allValue`;
    } else if (reduceType === 'any') {
      returnValue = `anyValue`;
    }
    const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
    const windowSizeVec4Remainder = windowSize % 4;
    // The interpolated conditions are JS booleans, so the GLSL sees literal
    // `true`/`false` and selects the branch for reduceType.
    // NOTE(review): the min/max branch applies `compareOp` twice (before and
    // inside the inner if); redundant but harmless, since min/max are idempotent.
    let updateSnippet = ` if (${reduceType === 'sum'}) { sumValue += dot(values, ones); } else if (${reduceType === 'prod'}) { vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]); prodValue *= tmp[0] * tmp[1]; } else { minMaxValue = ${compareOp}(values, minMaxValue); if (${reduceType === 'min'} || ${reduceType === 'max'}) { minMaxValue = ${compareOp}(values, minMaxValue); bvec4 isNaN = isnan(values); if (isNaN.r || isNaN.g || isNaN.b || isNaN.a) { minMaxValue = vec4(NAN); } } } `;
    let vecType = `vec4`;
    // 'all' / 'any' read their inputs as bvec4 and keep a 0.0/1.0 accumulator.
    if (reduceType === 'all') {
      initializationValue = '1.0';
      updateSnippet = ` bool reducedAllValue = all(values); float floatedReducedAllValue = float(reducedAllValue); allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0); `;
      vecType = `bvec4`;
    } else if (reduceType === 'any') {
      initializationValue = '0.0';
      updateSnippet = ` bool reducedAnyValue = any(values); float floatedReducedAnyValue = float(reducedAnyValue); anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0); `;
      vecType = `bvec4`;
    }
    let checkOutOfBounds = '';
    if (inSize % windowSize > 0) {
      checkOutOfBounds = ` if (inIdx < 0 || inIdx >= ${inSize}) { return initializationValue; } `;
    }
    this.userCode = ` const float initializationValue = ${initializationValue}; const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0); float getValue(int batch, int inIdx) { ${checkOutOfBounds} return getX(batch, inIdx); } void main() { ivec2 coords = getOutputCoords(); int batch = coords[0]; int outIdx = coords[1]; int inOffset = outIdx * ${windowSize}; vec4 minMaxValue = vec4(${initializationValue}); float prodValue = 1.0; float sumValue = 0.0; float allValue = 1.0; float anyValue = 0.0; for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) { int inIdx = inOffset + i; ${vecType} values =
${vecType}( getValue(batch, inIdx), getValue(batch, inIdx + 1), getValue(batch, inIdx + 2), getValue(batch, inIdx + 3) ); ${updateSnippet} } int inIdx = inOffset + ${windowSizeNearestVec4}; if (${windowSizeVec4Remainder === 1}) { ${vecType} values = ${vecType}( getValue(batch, inIdx), initializationValue, initializationValue, initializationValue ); ${updateSnippet} } else if (${windowSizeVec4Remainder === 2}) { ${vecType} values = ${vecType}( getValue(batch, inIdx), getValue(batch, inIdx + 1), initializationValue, initializationValue ); ${updateSnippet} } else if (${windowSizeVec4Remainder === 3}) { ${vecType} values = ${vecType}( getValue(batch, inIdx), getValue(batch, inIdx + 1), getValue(batch, inIdx + 2), initializationValue ); ${updateSnippet} } setOutput(${returnValue}); } `;
  }
}
/**
 * Plans a multi-pass reduction of inShape[1] values.
 * - Each stage reduces `inSize` values in windows of
 *   computeOptimalWindowSize(inSize), producing ceil(inSize / windowSize)
 *   outputs.
 * - Stages are appended until a stage's outSize is 1.
 */
function getReductionStages(inShape) {
  const stages = [];
  while (stages.length === 0 || stages[stages.length - 1].outSize !== 1) {
    const outSize = stages.length ? stages[stages.length - 1].outSize : inShape[1];
    const windowSize = computeOptimalWindowSize(outSize);
    stages.push({
      inSize: outSize,
      windowSize,
      outSize: Math.ceil(outSize / windowSize)
    });
  }
  return stages;
}
/**
 * Reduces each x.shape[0] row of x.shape[1] values to one value by running
 * the stages from getReductionStages (the definition continues past this
 * block).
 * - 'mean' uses MeanProgram, dividing by the first stage's inSize only.
 * - Other types use ReduceProgram.
 */
function reduce(x, dtype, reductionType, backend) {
  const reductionStages = getReductionStages(x.shape);
  let result = x;
  for (let i = 0; i < reductionStages.length; i++) {
    const { inSize, windowSize, outSize } = reductionStages[i];
    let program;
    let previousResult;
    if (reductionType === 'mean') {
      program = i === 0 ?
new MeanProgram({ windowSize, inSize, batchSize: x.shape[0], outSize }, inSize) : new MeanProgram({ windowSize, inSize, batchSize: x.shape[0], outSize }); } else { program = new ReduceProgram({ windowSize, inSize, batchSize: x.shape[0], outSize }, reductionType); } previousResult = result; result = backend.runWebGLProgram(program, [result], dtype); if (previousResult.dataId !== x.dataId) { backend.disposeIntermediateTensorInfo(previousResult); } } return result; } class TransposeProgram { constructor(aShape, newDim) { this.variableNames = ['A']; const outputShape = new Array(aShape.length); for (let i = 0; i < outputShape.length; i++) { outputShape[i] = aShape[newDim[i]]; } this.outputShape = outputShape; this.rank = outputShape.length; const dtype = getCoordsDataType(this.rank); const switched = getSwitchedCoords(newDim); this.userCode = ` void main() { ${dtype} resRC = getOutputCoords(); setOutput(getA(${switched})); } `; } } function getSwitchedCoords(newDim) { const rank = newDim.length; if (rank > 6) { throw Error(`Transpose for rank ${rank} is not yet supported`); } const originalOrder = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v']; const switchedCoords = new Array(rank); for (let i = 0; i < newDim.length; i++) { switchedCoords[newDim[i]] = originalOrder[i]; } return switchedCoords.join(); } class TransposePackedProgram { constructor(aShape, newDim) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; const outputShape = new Array(aShape.length); for (let i = 0; i < outputShape.length; i++) { outputShape[i] = aShape[newDim[i]]; } this.outputShape = outputShape; this.rank = outputShape.length; if (this.rank > 6) { throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`); } const dtype = getCoordsDataType(this.rank); const outputOrder = getVecChannels('rc', this.rank); const switchedOrder = new Array(this.rank); for (let i = 0; i < newDim.length; i++) { switchedOrder[newDim[i]] = 
outputOrder[i]; } const innerDims = `vec2(${switchedOrder.slice(-2).join()})`; const nextColumn = `++${outputOrder[this.rank - 1]} < ${outputShape[this.rank - 1]}`; const getc = `getChannel(getA(${switchedOrder.join()}), ${innerDims})`; this.userCode = ` void main() { ${dtype} rc = getOutputCoords(); vec4 result = vec4(0.); result[0] = ${getc}; if(${nextColumn}) { result[1] = ${getc}; } --${outputOrder[this.rank - 1]}; if(++${outputOrder[this.rank - 2]} < ${outputShape[this.rank - 2]}) { result[2] = ${getc}; if(${nextColumn}) { result[3] = ${getc}; } } setOutput(result); } `; } } function transposeImpl(x, perm, backend) { const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new TransposePackedProgram(x.shape, perm) : new TransposeProgram(x.shape, perm); return backend.runWebGLProgram(program, [x], x.dtype); } function sumImpl(x, axis, keepDims, backend) { const reductionIndices = axis; const xRank = x.shape.length; const origAxes = parseAxisParam(reductionIndices, x.shape); let axes = origAxes; const permutedAxes = getAxesPermutation(axes, xRank); const sumInputIsTransposed = permutedAxes != null; let sumInput = x; if (sumInputIsTransposed) { sumInput = transposeImpl(x, permutedAxes, backend); axes = getInnerMostAxes(axes.length, xRank); } assertAxesAreInnerMostDims('sum', axes, xRank); const [sumOutShape, reduceShape] = computeOutAndReduceShapes(sumInput.shape, axes); let outShape = sumOutShape; if (keepDims) { outShape = expandShapeToKeepDim(sumOutShape, origAxes); } const inSize = sizeFromShape(reduceShape); const xSize = sizeFromShape(x.shape); const batchSize = xSize / inSize; const reshapedInput = reshape({ inputs: { x: sumInput }, attrs: { shape: [batchSize, inSize] }, backend }); const outType = sumOutType(x.dtype); const reduced = reduce(reshapedInput, outType, 'sum', backend); const out = reshape({ inputs: { x: reduced }, attrs: { shape: outShape }, backend }); backend.disposeIntermediateTensorInfo(reshapedInput); 
backend.disposeIntermediateTensorInfo(reduced);
  if (sumInputIsTransposed) {
    backend.disposeIntermediateTensorInfo(sumInput);
  }
  return out;
}
/** Sum kernel: forwards `attrs.axis` and `attrs.keepDims` to sumImpl. */
function sum(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const { axis, keepDims } = attrs;
  return sumImpl(x, axis, keepDims, backend);
}
const sumConfig = {
  kernelName: Sum,
  backendName: 'webgl',
  kernelFunc: sum
};
/**
 * Transpose kernel. Output dim i has size x.shape[perm[i]]. Uses
 * transposeImplCPU when the backend elects CPU execution, otherwise
 * transposeImpl on the GPU.
 */
function transpose(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const { perm } = attrs;
  const webglBackend = backend;
  const xRank = x.shape.length;
  const newShape = new Array(xRank);
  for (let i = 0; i < newShape.length; i++) {
    newShape[i] = x.shape[perm[i]];
  }
  let out;
  if (webglBackend.shouldExecuteOnCPU([x])) {
    const xTexData = webglBackend.texData.get(x.dataId);
    const values = xTexData.values;
    const outValues = transposeImplCPU(values, x.shape, x.dtype, perm, newShape);
    out = webglBackend.makeTensorInfo(newShape, x.dtype);
    const outData = webglBackend.texData.get(out.dataId);
    outData.values = outValues;
  } else {
    out = transposeImpl(x, perm, webglBackend);
  }
  return out;
}
const transposeConfig = {
  kernelName: Transpose,
  backendName: 'webgl',
  kernelFunc: transpose
};
// Above this shared-dimension size, products where one outer dim is 1 and no
// ops are fused are computed as multiply + sum instead of MatMulPackedProgram.
const MATMUL_SHARED_DIM_THRESHOLD = 1000;
/**
 * Batched matmul with optional fused bias / activation.
 * - Both operands are reshaped to 3-D [batchDim, rows, cols], and their
 *   leading dims are checked with assertAndGetBroadcastShape.
 * - Inner dimensions must match, otherwise an assertion fails.
 * - The result is reshaped to [...broadcast outer dims, outerShapeA,
 *   outerShapeB].
 * - All intermediates are disposed before returning.
 */
function batchMatMulImpl({ a, b, transposeA, transposeB, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
  const aRank = a.shape.length;
  const bRank = b.shape.length;
  const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
  const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
  const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
  const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
  const outerDimsA = a.shape.slice(0, -2);
  const outerDimsB = b.shape.slice(0, -2);
  const batchDimA = sizeFromShape(outerDimsA);
  const batchDimB = sizeFromShape(outerDimsB);
  const outShapeOuterDims = assertAndGetBroadcastShape(a.shape.slice(0, -2), b.shape.slice(0, -2));
  const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
  assert$1(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
    `${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
    `${b.shape} and transposeA=${transposeA}` +
    ` and transposeB=${transposeB} must match.`);
  const a3dShape = transposeA ? [batchDimA, innerShapeA, outerShapeA] : [batchDimA, outerShapeA, innerShapeA];
  const b3dShape = transposeB ? [batchDimB, outerShapeB, innerShapeB] : [batchDimB, innerShapeB, outerShapeB];
  const a3d = reshape({ inputs: { x: a }, backend, attrs: { shape: a3dShape } });
  const b3d = reshape({ inputs: { x: b }, backend, attrs: { shape: b3dShape } });
  // Every tensor info pushed here is disposed at the end of this function.
  const intermediates = [a3d, b3d];
  const batchDim = Math.max(batchDimA, batchDimB);
  const sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
  const hasBias = bias != null;
  const hasPreluActivationWeights = preluActivationWeights != null;
  const hasLeakyreluAlpha = activation === 'leakyrelu';
  const fusedActivation = activation != null ? mapActivationToShaderProgram(activation, true) : null;
  const containsFusedOps = hasBias || hasPreluActivationWeights || hasLeakyreluAlpha || fusedActivation != null;
  let out;
  if ((outerShapeA === 1 || outerShapeB === 1) && sharedDim > MATMUL_SHARED_DIM_THRESHOLD && containsFusedOps === false) {
    // Vector path: elementwise multiply, then sum over the shared dimension
    // (axis 2 when outerShapeB is 1, else axis 1) with keepDims.
    let aVec = a3d;
    let bVec = b3d;
    if (transposeA) {
      aVec = transpose({ inputs: { x: a3d }, backend, attrs: { perm: [0, 2, 1] } });
      intermediates.push(aVec);
    }
    if (transposeB) {
      bVec = transpose({ inputs: { x: b3d }, backend, attrs: { perm: [0, 2, 1] } });
      intermediates.push(bVec);
    }
    const shouldReshapeA = outerShapeB !== 1;
    const shouldReshapeB = outerShapeB === 1;
    let aVec3d = aVec;
    if (shouldReshapeA) {
      aVec3d = reshape({ inputs: { x: aVec }, backend, attrs: { shape: [batchDim, sharedDim, 1] } });
      intermediates.push(aVec3d);
    }
    const axis = outerShapeB === 1 ? 2 : 1;
    let bVec3d = bVec;
    if (shouldReshapeB) {
      bVec3d = reshape({ inputs: { x: bVec }, backend, attrs: { shape: [batchDim, 1, sharedDim] } });
      intermediates.push(bVec3d);
    }
    const product = multiply({ inputs: { a: aVec3d, b: bVec3d }, backend });
    out = sum({ inputs: { x: product }, backend, attrs: { axis, keepDims: true } });
    intermediates.push(product);
  } else {
    const dtype = upcastType(a.dtype, b.dtype);
    const program = new MatMulPackedProgram(a3dShape, b3dShape, [batchDim, outerShapeA, outerShapeB], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
    // Extra inputs are pushed in the same order MatMulPackedProgram appends
    // its variableNames: bias, preluActivationWeights, leakyreluAlpha.
    const inputs = [a3d, b3d];
    if (bias != null) {
      inputs.push(bias);
    }
    if (hasPreluActivationWeights) {
      inputs.push(preluActivationWeights);
    }
    if (hasLeakyreluAlpha) {
      const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
      inputs.push($leakyreluAlpha);
      intermediates.push($leakyreluAlpha);
    }
    out = backend.runWebGLProgram(program, inputs, dtype);
  }
  const outReshaped = reshape({ inputs: { x: out }, backend, attrs: { shape: outShape } });
  intermediates.push(out);
  for (const i of intermediates) {
    backend.disposeIntermediateTensorInfo(i);
  }
  return outReshaped;
}
/** _FusedMatMul kernel: forwards its inputs and attrs to batchMatMulImpl. */
function _fusedMatMul(args) {
  const { inputs, backend, attrs } = args;
  const { a, b, bias, preluActivationWeights } = inputs;
  const { transposeA, transposeB, activation, leakyreluAlpha } = attrs;
  return batchMatMulImpl({
    a,
    b,
    transposeA,
    transposeB,
    backend,
    bias,
    preluActivationWeights,
    leakyreluAlpha,
    activation
  });
}
const _fusedMatMulConfig = {
  kernelName: _FusedMatMul,
  backendName: 'webgl',
  kernelFunc: _fusedMatMul,
};
const ABS = `return abs(x);`;
/**
 * Abs kernel. Non-complex inputs use simpleAbsImplCPU when the backend
 * elects CPU execution. Otherwise a packed or unpacked unary shader runs
 * (per WEBGL_PACK_UNARY_OPERATIONS). The output keeps x.dtype.
 */
function abs(args) {
  const { inputs, backend } = args;
  const { x } = inputs;
  if (backend.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') {
    const xData = backend.texData.get(x.dataId);
    const outValues = simpleAbsImplCPU(xData.values);
    return backend.makeTensorInfo(x.shape, x.dtype, outValues);
  }
  let program;
  if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
    program = new UnaryOpPackedProgram(x.shape, ABS);
  } else {
    program = new UnaryOpProgram(x.shape, ABS);
  }
  return backend.runWebGLProgram(program, [x], x.dtype);
}
const absConfig = {
  kernelName: Abs,
  backendName: 'webgl',
  kernelFunc: abs
};
const ACOS = CHECK_NAN_SNIPPET$1 + ` if (abs(x) > 1.)
{ return NAN; } return acos(x); `; const acos = unaryKernelFunc({ opSnippet: ACOS }); const acosConfig = { kernelName: Acos, backendName: 'webgl', kernelFunc: acos, }; const ACOSH = CHECK_NAN_SNIPPET$1 + ` if (x < 1.0) return NAN; return log(x + sqrt(x * x - 1.0));`; const acosh = unaryKernelFunc({ opSnippet: ACOSH }); const acoshConfig = { kernelName: Acosh, backendName: 'webgl', kernelFunc: acosh, }; const ADD = 'return a + b;'; const addKernelFunc = binaryKernelFunc({ opSnippet: ADD, packedOpSnippet: ADD, supportsComplex: true, cpuKernelImpl: addImplCPU }); const addConfig = { kernelName: Add, backendName: 'webgl', kernelFunc: addKernelFunc }; class AddNProgram { constructor(outputShape, shapes) { this.outputShape = []; this.outputShape = outputShape; this.variableNames = shapes.map((_, i) => `T${i}`); const snippets = []; this.variableNames.forEach(variable => { snippets.push(`float v${variable} = get${variable}AtOutCoords();`); }); const operation = this.variableNames .map(variable => { return `v${variable}`; }) .join(' + '); this.userCode = ` void main() { ${snippets.join('\n ')} float result = ${operation}; setOutput(result); } `; } } class AddNPackedProgram { constructor(outputShape, shapes) { this.outputShape = []; this.packedInputs = true; this.packedOutput = true; this.outputShape = outputShape; this.variableNames = shapes.map((_, i) => `T${i}`); const snippets = []; this.variableNames.forEach(variable => { snippets.push(`vec4 v${variable} = get${variable}AtOutCoords();`); }); const operation = this.variableNames .map(variable => { return `v${variable}`; }) .join(' + '); this.userCode = ` void main() { ${snippets.join('\n ')} vec4 result = ${operation}; setOutput(result); } `; } } function addN(args) { const { inputs, backend } = args; const tensors = inputs; if (tensors.length === 1) { return identity({ inputs: { x: tensors[0] }, backend }); } if (tensors.length > env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) { const midIndex = 
Math.floor(tensors.length / 2); const leftSide = addN({ inputs: tensors.slice(0, midIndex), backend }); const rightSide = addN({ inputs: tensors.slice(midIndex), backend }); return addN({ inputs: [leftSide, rightSide], backend }); } const dtype = tensors.map(t => t.dtype).reduce((d1, d2) => upcastType(d1, d2)); const shapes = tensors.map(t => t.shape); const usePackedOp = env().getBool('WEBGL_PACK'); const program = usePackedOp ? new AddNPackedProgram(tensors[0].shape, shapes) : new AddNProgram(tensors[0].shape, shapes); return backend.runWebGLProgram(program, tensors, dtype); } const addNConfig = { kernelName: AddN, backendName: 'webgl', kernelFunc: addN }; function all(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { axis, keepDims } = attrs; const xRank = x.shape.length; const origAxes = parseAxisParam(axis, x.shape); let axes = origAxes; const permutedAxes = getAxesPermutation(axes, xRank); let permutedX = x; if (permutedAxes != null) { permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } }); axes = getInnerMostAxes(axes.length, xRank); } assertAxesAreInnerMostDims('all', axes, xRank); const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes); const inSize = sizeFromShape(reduceShape); const a2D = reshape({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } }); const reduced = reduce(a2D, a2D.dtype, 'all', backend); let res; if (keepDims) { const newShape = expandShapeToKeepDim(outShape, origAxes); res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: newShape } }); } else { res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: outShape } }); } backend.disposeIntermediateTensorInfo(a2D); backend.disposeIntermediateTensorInfo(reduced); if (permutedAxes != null) { backend.disposeIntermediateTensorInfo(permutedX); } return res; } const allConfig = { kernelName: All, backendName: 'webgl', kernelFunc: all }; function any(args) { const { inputs, backend, 
attrs } = args;
  const { x } = inputs;
  const { axis, keepDims } = attrs;
  const xRank = x.shape.length;
  const origAxes = parseAxisParam(axis, x.shape);
  let axes = origAxes;
  // Move the reduced axes innermost so the reduction runs over the last dim.
  const permutedAxes = getAxesPermutation(axes, xRank);
  let permutedX = x;
  if (permutedAxes != null) {
    permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
    axes = getInnerMostAxes(axes.length, xRank);
  }
  assertAxesAreInnerMostDims('any', axes, xRank);
  const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes);
  const inSize = sizeFromShape(reduceShape);
  // Flatten to [rows, reduceSize] for the 2D reduce program.
  const a2D = reshape({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
  const reduced = reduce(a2D, a2D.dtype, 'any', backend);
  let res;
  if (keepDims) {
    const newShape = expandShapeToKeepDim(outShape, origAxes);
    res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: newShape } });
  } else {
    res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
  }
  backend.disposeIntermediateTensorInfo(a2D);
  backend.disposeIntermediateTensorInfo(reduced);
  if (permutedAxes != null) {
    backend.disposeIntermediateTensorInfo(permutedX);
  }
  return res;
}
const anyConfig = { kernelName: Any, backendName: 'webgl', kernelFunc: any };

/**
 * Unpacked shader for one pass of argMin/argMax over a 2D input.
 *
 * Output shape is [batchSize, outSize]. Each output cell scans `windowSize`
 * entries starting at outIdx * windowSize and writes the index of the best
 * value (as a float). The comparison is strict (`>` for 'max', `<` otherwise),
 * so an equal later candidate never replaces the current best.
 * On the first pass the candidate indices are the window positions; on later
 * passes they are looked up in the previous pass output `bestIndicesA`.
 */
class ArgMinMaxProgram {
  constructor(reduceInfo, op, firstPass) {
    this.variableNames = ['A'];
    const { windowSize, batchSize, outSize } = reduceInfo;
    if (!firstPass) {
      this.variableNames.push('bestIndicesA');
    }
    this.outputShape = [batchSize, outSize];
    const compOp = (op === 'max') ? '>' : '<';
    // Both snippets end in ';' and are spliced into `int inIdx = ${...};`,
    // so the generated line ends with a harmless empty statement (';;').
    const indexSnippet = firstPass ? 'inOffset + i;' : 'round(getBestIndicesA(batch, inOffset + i));';
    this.userCode = ` void main() { ivec2 coords = getOutputCoords(); int batch = coords[0]; int outIdx = coords[1]; int inOffset = outIdx * ${windowSize}; int bestIndex = inOffset; float bestValue = getA(batch, bestIndex); for (int i = 0; i < ${windowSize}; i++) { int inIdx = ${indexSnippet}; float candidate = getA(batch, inIdx); if (candidate ${compOp} bestValue) { bestValue = candidate; bestIndex = inIdx; } } setOutput(float(bestIndex)); } `;
  }
}

/**
 * Packed (2x2 texel) shader for one pass of argMin/argMax along the last axis.
 *
 * Requires an input of rank > 2 (asserted). The output shape is the input
 * shape without its last dimension, with `outSize` = ceil(inSize / windowSize)
 * appended only when it exceeds 1. The four channels of a texel (R, G, B, A)
 * are reduced independently. A NaN candidate never replaces the current best,
 * because the replace mask is multiplied by (1 - isnan(candidate)).
 */
class ArgMinMaxPackedProgram {
  constructor(shape, windowSize, op, firstPass) {
    this.variableNames = ['A'];
    this.packedInputs = true;
    this.packedOutput = true;
    assert$1(shape.length > 2, () => `Packed arg${op.charAt(0).toUpperCase() + op.slice(1)} supports only inputs with rank above 2.`);
    const inSize = shape[shape.length - 1];
    const outSize = Math.ceil(inSize / windowSize);
    this.outputShape = shape.slice(0, -1);
    if (outSize > 1) {
      this.outputShape.push(outSize);
    }
    if (!firstPass) {
      this.variableNames.push('bestIndicesA');
    }
    const outShape = this.outputShape;
    const rank = outShape.length;
    const dtype = getCoordsDataType(rank);
    const coords = getChannels('coords', rank);
    // GLSL that derives the source location of each of the four texel
    // channels (R, G, A, B order of construction) from the output coords.
    let sourceLocSetup;
    let sourceRank;
    if (outSize === 1) {
      // The reduced axis was dropped from the output, so source coordinates
      // have one extra trailing component (initialised to 0).
      sourceRank = rank + 1;
      const sourceLocDType = getCoordsDataType(sourceRank);
      sourceLocSetup = ` ${sourceLocDType} sourceLocR = ${sourceLocDType}(${coords.join()}, 0); ++${coords[rank - 1]}; ${sourceLocDType} sourceLocG = ${sourceLocDType}(${coords.join()}, 0); ++${coords[rank - 2]}; ${sourceLocDType} sourceLocA = ${sourceLocDType}(${coords.join()}, 0); --${coords[rank - 1]}; ${sourceLocDType} sourceLocB = ${sourceLocDType}(${coords.join()}, 0); --${coords[rank - 2]};`;
    } else {
      sourceRank = rank;
      sourceLocSetup = ` ${dtype} sourceLocR = coords; ++${coords[rank - 1]}; ${dtype} sourceLocG = coords; ++${coords[rank - 2]}; ${dtype} sourceLocA = coords; --${coords[rank - 1]}; ${dtype} sourceLocB = coords; --${coords[rank - 2]};`;
    }
    const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);
    const inChannel = '.' + channels[sourceRank - 1];
    const intChannels = channels.map(x => 'int ' + x);
    // Per-channel source coordinates: leading components from sourceLoc*, the
    // last one replaced by the matching component of inIdx.
    const srcRCoords = getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r');
    const srcGCoords = getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g');
    const srcBCoords = getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b');
    const srcACoords = getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a');
    const compOp = (op === 'max') ? 'greaterThan' : 'lessThan';
    const fetchCandidateIdx = firstPass ? '' : ` inIdx = round(vec4(getBestIndicesAChannel(${srcRCoords.join()}), getBestIndicesAChannel(${srcGCoords.join()}), getBestIndicesAChannel(${srcBCoords.join()}), getBestIndicesAChannel(${srcACoords.join()})));`;
    // Channels past the last output row/column read 0 instead of A.
    const fetchValue = `vec4( getAChannel(${srcRCoords.join()}), hasNextCol ? getAChannel(${srcGCoords.join()}) : 0., hasNextRow ? getAChannel(${srcBCoords.join()}) : 0., hasNextRow && hasNextCol ? getAChannel(${srcACoords.join()}) : 0.)`;
    const getBestIndicesAChannelSnippet = firstPass ? '' : ` float getBestIndicesAChannel(${intChannels.join()}) { return getChannel(getBestIndicesA(${channels.join()}), vec2(${channels.slice(-2).join()})); }`;
    this.userCode = ` float getAChannel(${intChannels.join()}) { return getChannel(getA(${channels.join()}), vec2(${channels.slice(-2).join()})); } ${getBestIndicesAChannelSnippet} void main() { ${dtype} coords = getOutputCoords(); bool hasNextCol = ${coords[rank - 1]} < ${outShape[rank - 1] - 1}; bool hasNextRow = ${coords[rank - 2]} < ${outShape[rank - 2] - 1}; ${sourceLocSetup} ivec4 srcIdx = ivec4(sourceLocR${inChannel}, sourceLocG${inChannel}, sourceLocB${inChannel}, sourceLocA${inChannel}) * ${windowSize}; ivec4 inIdx = srcIdx; vec4 bestIndex = vec4(inIdx); vec4 bestValue = ${fetchValue}; for (int i = 0; i < ${windowSize}; i++) { inIdx = srcIdx; ${fetchCandidateIdx} vec4 candidate = ${fetchValue}; bvec4 nan = isnan(candidate); bvec4 replace = bvec4( vec4(${compOp}(candidate, bestValue)) * (vec4(1.0) - vec4(nan))); bestValue = vec4(replace.x ? candidate.x : bestValue.x, replace.y ? candidate.y : bestValue.y, replace.z ? candidate.z : bestValue.z, replace.w ?
candidate.w : bestValue.w); bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace)); srcIdx++; } setOutput(bestIndex); } `;
  }
}

/**
 * Unpacked arg-reduction over dim 1 of a 2D tensor.
 *
 * Runs ArgMinMaxProgram passes (int32 output) recursively, feeding each pass
 * output back as `bestIndicesA`, until the output has a single column.
 * Intermediate pass outputs are disposed; the final [batch, 1] tensor is
 * returned.
 */
function argReduce(backend, x, reduceType, bestIndicesA = null) {
  let batchSize = x.shape[0];
  let inSize = x.shape[1];
  if (bestIndicesA != null) {
    batchSize = bestIndicesA.shape[0];
    inSize = bestIndicesA.shape[1];
  }
  const windowSize = computeOptimalWindowSize(inSize);
  const reduceInfo = { windowSize, inSize, batchSize, outSize: Math.ceil(inSize / windowSize) };
  const program = new ArgMinMaxProgram(reduceInfo, reduceType, bestIndicesA == null);
  const inputs = [x];
  if (bestIndicesA != null) {
    inputs.push(bestIndicesA);
  }
  const output = backend.runWebGLProgram(program, inputs, 'int32');
  if (output.shape[1] === 1) {
    return output;
  }
  const result = argReduce(backend, x, reduceType, output);
  backend.disposeIntermediateTensorInfo(output);
  return result;
}

/**
 * Packed arg-reduction along the last axis; the recursive counterpart of
 * argReduce built on ArgMinMaxPackedProgram.
 */
function argReducePacked(backend, x, reduceType, bestIndicesA = null) {
  const inShape = bestIndicesA != null ? bestIndicesA.shape : x.shape;
  const inSize = inShape[inShape.length - 1];
  const windowSize = computeOptimalWindowSize(inSize);
  const program = new ArgMinMaxPackedProgram(inShape, windowSize, reduceType, bestIndicesA == null);
  const inputs = bestIndicesA == null ?
[x] : [x, bestIndicesA]; const output = backend.runWebGLProgram(program, inputs, 'int32'); if (output.shape.length === x.shape.length) { const result = argReducePacked(backend, x, reduceType, output); backend.disposeIntermediateTensorInfo(output); return result; } return output; } function argMinMaxReduce(backend, x, axis, reduceType) { const axes = [axis]; assertAxesAreInnerMostDims('arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes, x.shape.length); if (!env().getBool('WEBGL_PACK_REDUCE') || x.shape.length <= 2) { const intermediateTensorInfos = []; const xtexData = backend.texData.get(x.dataId); const xIsPacked = xtexData !== null && xtexData.isPacked; let xUnPacked = x; if (xIsPacked) { xUnPacked = backend.unpackTensor(x); intermediateTensorInfos.push(xUnPacked); } const [outShape, reduceShape] = computeOutAndReduceShapes(xUnPacked.shape, axes); const inSize = sizeFromShape(reduceShape); const a2D = reshape({ inputs: { x: xUnPacked }, backend, attrs: { shape: [-1, inSize] } }); intermediateTensorInfos.push(a2D); const reduced = argReduce(backend, a2D, reduceType); intermediateTensorInfos.push(reduced); const reshaped = reshape({ inputs: { x: reduced }, backend, attrs: { shape: outShape } }); intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t)); return reshaped; } return argReducePacked(backend, x, reduceType); } function argMax(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { axis } = attrs; let axes = parseAxisParam(axis, x.shape); const permutedAxes = getAxesPermutation(axes, x.shape.length); let $x = x; const intermediateTensorInfos = []; if (permutedAxes != null) { $x = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } }); intermediateTensorInfos.push($x); axes = getInnerMostAxes(axes.length, $x.shape.length); } assertAxesAreInnerMostDims('argMax', [axes[0]], $x.shape.length); const out = argMinMaxReduce(backend, $x, axes[0], 'max'); intermediateTensorInfos.forEach(t 
=> backend.disposeIntermediateTensorInfo(t)); return out; } const argMaxConfig = { kernelName: ArgMax, backendName: 'webgl', kernelFunc: argMax }; function argMin(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { axis } = attrs; let axes = parseAxisParam(axis, x.shape); const permutedAxes = getAxesPermutation(axes, x.shape.length); let $x = x; const intermediateTensorInfos = []; if (permutedAxes != null) { $x = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } }); intermediateTensorInfos.push($x); axes = getInnerMostAxes(axes.length, $x.shape.length); } assertAxesAreInnerMostDims('argMin', [axes[0]], $x.shape.length); const out = argMinMaxReduce(backend, $x, axes[0], 'min'); intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t)); return out; } const argMinConfig = { kernelName: ArgMin, backendName: 'webgl', kernelFunc: argMin }; const ASIN = CHECK_NAN_SNIPPET$1 + ` if (abs(x) > 1.) { return NAN; } return asin(x); `; const asin = unaryKernelFunc({ opSnippet: ASIN }); const asinConfig = { kernelName: Asin, backendName: 'webgl', kernelFunc: asin, }; const ASINH = CHECK_NAN_SNIPPET$1 + `return log(x + sqrt(x * x + 1.0));`; const asinh = unaryKernelFunc({ opSnippet: ASINH }); const asinhConfig = { kernelName: Asinh, backendName: 'webgl', kernelFunc: asinh, }; const ATAN = CHECK_NAN_SNIPPET$1 + ` return atan(x); `; const atan = unaryKernelFunc({ opSnippet: ATAN }); const atanConfig = { kernelName: Atan, backendName: 'webgl', kernelFunc: atan, }; const ATAN2 = CHECK_NAN_SNIPPET + ` return atan(a, b); `; const ATAN2_PACKED = ` vec4 result = atan(a, b); bvec4 isNaNA = isnan(a); bvec4 isNaNB = isnan(b); bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w); ` + CHECK_NAN_SNIPPET_PACKED + ` return result; `; const atan2 = binaryKernelFunc({ opSnippet: ATAN2, packedOpSnippet: ATAN2_PACKED }); const atan2Config = { kernelName: Atan2, backendName: 'webgl', 
kernelFunc: atan2, };

// Unary atanh: NAN outside [-1, 1], otherwise (log(1 + x) - log(1 - x)) / 2.
const ATANH = CHECK_NAN_SNIPPET$1 + ` if ((x < -1.0) || (x > 1.0)) return NAN; return (log(1.0 + x) - log(1.0 - x)) / 2.0;`;
const atanh = unaryKernelFunc({ opSnippet: ATANH });
const atanhConfig = {
  kernelName: Atanh,
  backendName: 'webgl',
  kernelFunc: atanh,
};

/**
 * Builds a 2D pooling shader over input `x`, read as getX(batch, row, col, d).
 * The output shape is convInfo.outShape.
 *
 * - `poolType` is spliced into the GLSL as a function name on the value path
 *   (e.g. max(...)); 'avg' selects averaging.
 * - With `computePositions`, the shader outputs the position of the maximum in
 *   each window. The comparison is `>=`, so among equal values the last one
 *   scanned wins. The position is a flattened input offset when
 *   `flattenPositions` is set (including the batch when `includeBatchInIndex`
 *   is set); otherwise it is the in-window offset wR * effectiveFilterWidth + wC.
 * - Throws when asked to compute positions for an average pool.
 */
class Pool2DProgram {
  constructor(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) {
    this.variableNames = ['x'];
    if (poolType === 'avg' && computePositions) {
      throw new Error('Cannot compute positions for average pool.');
    }
    const filterWidth = convInfo.filterWidth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    this.outputShape = convInfo.outShape;
    const isAvgPool = poolType === 'avg';
    const batchFlattenPositionStr = `((batch * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;
    const flattenPositionStr = `(xR * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;
    // Neutral element: 0 for averaging; -1.0 / 1e-20 (= -1e20) acts as a
    // very small sentinel for max pooling.
    let initializationValue = '0.0';
    if (!isAvgPool) {
      initializationValue = '-1.0 / 1e-20';
    }
    if (computePositions) {
      const compareOp = '>=';
      this.userCode = ` const ivec2 strides = ivec2(${strideHeight}, ${strideWidth}); const ivec2 pads = ivec2(${padTop}, ${padLeft}); void main() { ivec4 coords = getOutputCoords(); int batch = coords[0]; int d = coords[3]; ivec2 xRCCorner = coords.yz * strides - pads; int xRCorner = xRCCorner.x; int xCCorner = xRCCorner.y; float minMaxValue = 0.0; float minMaxValueFound = 0.0; int minMaxPosition = 0; float avgValue = 0.0; for (int wR = 0; wR < ${effectiveFilterHeight}; wR += ${dilationHeight}) { int xR = xRCorner + wR; if (xR < 0 || xR >= ${convInfo.inHeight}) { continue; } for (int wC =
0; wC < ${effectiveFilterWidth}; wC += ${dilationWidth}) { int xC = xCCorner + wC; if (xC < 0 || xC >= ${convInfo.inWidth}) { continue; } float value = getX(batch, xR, xC, d); float currMinMaxValue = mix( value, minMaxValue, minMaxValueFound); if (value ${compareOp} currMinMaxValue) { minMaxValue = value; minMaxValueFound = 1.0; minMaxPosition = ${flattenPositions ? (includeBatchInIndex ? batchFlattenPositionStr : flattenPositionStr) : `wR * ${effectiveFilterWidth} + wC`}; } } } setOutput(float(minMaxPosition)); } `;
      return;
    }
    // Value path: filter columns are processed four at a time as a vec4, then
    // the remaining 1-3 columns are padded with initializationValue.
    const compareOp = 'max';
    let returnValue = `${poolType}(${poolType}(${poolType}(` + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
    if (poolType === 'avg') {
      // getValue() increments `count` only for in-bounds columns, and
      // out-of-bounds rows are skipped, so padding is excluded from the mean.
      returnValue = `avgValue / max(count, 1.0)`;
    }
    const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
    const filterWidthVec4Remainder = filterWidth % 4;
    const updateSnippet = ` if (${isAvgPool}) { avgValue += dot(values, ones); } else { minMaxValue = ${compareOp}(values, minMaxValue); } `;
    // NOTE(review): the remainder columns start at xCCorner +
    // filterWidthNearestVec4, not scaled by dilationWidth as the vec4 loop is;
    // this is only right when dilationWidth == 1 (avgPool in this file passes
    // dilations = 1) — confirm before reusing with dilation.
    this.userCode = ` const ivec2 strides = ivec2(${strideHeight}, ${strideWidth}); const ivec2 pads = ivec2(${padTop}, ${padLeft}); const float initializationValue = ${initializationValue}; const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0); float count = 0.0; float getValue(int batch, int xR, int xC, int d) { if (xC < 0 || xC >= ${convInfo.inWidth}) { return initializationValue; } count += 1.0; return getX(batch, xR, xC, d); } void main() { ivec4 coords = getOutputCoords(); int batch = coords[0]; int d = coords[3]; ivec2 xRCCorner = coords.yz * strides - pads; int xRCorner = xRCCorner.x; int xCCorner = xRCCorner.y; vec4 minMaxValue = vec4(${initializationValue}); float avgValue = 0.0; count = 0.0; for (int wR = 0; wR < ${effectiveFilterHeight}; wR += ${dilationHeight}) { int xR = xRCorner + wR; if (xR < 0 || xR >= ${convInfo.inHeight}) { continue; } for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) { int xC = xCCorner + wC * ${dilationWidth}; vec4 values = vec4(
getValue(batch, xR, xC, d), getValue(batch, xR, xC + ${dilationWidth}, d), getValue(batch, xR, xC + 2 * ${dilationWidth}, d), getValue(batch, xR, xC + 3 * ${dilationWidth}, d) ); ${updateSnippet} } int xC = xCCorner + ${filterWidthNearestVec4}; if (${filterWidthVec4Remainder === 1}) { vec4 values = vec4( getValue(batch, xR, xC, d), initializationValue, initializationValue, initializationValue ); ${updateSnippet} } else if (${filterWidthVec4Remainder === 2}) { vec4 values = vec4( getValue(batch, xR, xC, d), getValue(batch, xR, xC + ${dilationWidth}, d), initializationValue, initializationValue ); ${updateSnippet} } else if (${filterWidthVec4Remainder === 3}) { vec4 values = vec4( getValue(batch, xR, xC, d), getValue(batch, xR, xC + ${dilationWidth}, d), getValue(batch, xR, xC + 2 * ${dilationWidth}, d), initializationValue ); ${updateSnippet} } } setOutput(${returnValue}); } `;
  }
}

/**
 * 3D counterpart of Pool2DProgram: input `x` is read as
 * getX(batch, depth, row, col, channel) and the window also spans depth.
 * Same pool-type, position and error semantics as the 2D program.
 */
class Pool3DProgram {
  constructor(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) {
    this.variableNames = ['x'];
    if (poolType === 'avg' && computePositions) {
      throw new Error('Cannot compute positions for average pool.');
    }
    const filterWidth = convInfo.filterWidth;
    const strideDepth = convInfo.strideDepth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationDepth = convInfo.dilationDepth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterDepth = convInfo.effectiveFilterDepth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padFront = convInfo.padInfo.front;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    this.outputShape = convInfo.outShape;
    const isAvgPool = poolType === 'avg';
    // Same neutral elements as Pool2DProgram (-1.0 / 1e-20 = -1e20 for max).
    let initializationValue = '0.0';
    if (!isAvgPool) {
      initializationValue = '-1.0 / 1e-20';
    }
    if (computePositions) {
      const compareOp = '>=';
      this.userCode
= ` const ivec3 strides = ivec3(${strideDepth}, ${strideHeight}, ${strideWidth}); const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft}); void main() { ivec5 coords = getOutputCoords(); int batch = coords.x; int ch = coords.u; ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads; int xDCorner = xCorner.x; int xRCorner = xCorner.y; int xCCorner = xCorner.z; float minMaxValue = 0.0; float minMaxValueFound = 0.0; int minMaxPosition = 0; for (int wD = 0; wD < ${effectiveFilterDepth}; wD += ${dilationDepth}) { int xD = xDCorner + wD; if (xD < 0 || xD >= ${convInfo.inDepth}) { continue; } for (int wR = 0; wR < ${effectiveFilterHeight}; wR += ${dilationHeight}) { int xR = xRCorner + wR; if (xR < 0 || xR >= ${convInfo.inHeight}) { continue; } for (int wC = 0; wC < ${effectiveFilterWidth}; wC += ${dilationWidth}) { int xC = xCCorner + wC; if (xC < 0 || xC >= ${convInfo.inWidth}) { continue; } float value = getX(batch, xD, xR, xC, ch); float currMinMaxValue = mix( value, minMaxValue, minMaxValueFound); if (value ${compareOp} currMinMaxValue) { minMaxValue = value; minMaxValueFound = 1.0; minMaxPosition = ${flattenPositions ? (includeBatchInIndex ? 
`(((batch * ${convInfo.inDepth} + xD) * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch` : `((xD * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch`) : `wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} + wR * ${effectiveFilterWidth} + wC`}; } } } } setOutput(float(minMaxPosition)); } `; return; } const compareOp = 'max'; let returnValue = `${poolType}(${poolType}(${poolType}(` + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])'; if (poolType === 'avg') { returnValue = `avgValue / max(count, 1.0)`; } const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4; const filterWidthVec4Remainder = filterWidth % 4; const updateSnippet = ` if (${isAvgPool}) { avgValue += dot(values, ones); } else { minMaxValue = ${compareOp}(values, minMaxValue); } `; this.userCode = ` const ivec3 strides = ivec3(${strideDepth}, ${strideHeight}, ${strideWidth}); const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft}); const float initializationValue = ${initializationValue}; const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0); float count = 0.0; float getValue(int batch, int xD, int xR, int xC, int ch) { if (xC < 0 || xC >= ${convInfo.inWidth}) { return initializationValue; } count += 1.0; return getX(batch, xD, xR, xC, ch); } void main() { ivec5 coords = getOutputCoords(); int batch = coords.x; int ch = coords.u; ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads; int xDCorner = xCorner.x; int xRCorner = xCorner.y; int xCCorner = xCorner.z; vec4 minMaxValue = vec4(${initializationValue}); float avgValue = 0.0; count = 0.0; for (int wD = 0; wD < ${effectiveFilterDepth}; wD += ${dilationDepth}) { int xD = xDCorner + wD; if (xD < 0 || xD >= ${convInfo.inDepth}) { continue; } for (int wR = 0; wR < ${effectiveFilterHeight}; wR += ${dilationHeight}) { int xR = xRCorner + wR; if (xR < 0 || xR >= ${convInfo.inHeight}) { continue; } for (int wC = 0; wC < ${filterWidthNearestVec4}; wC 
+= 4) { int xC = xCCorner + wC * ${dilationWidth}; vec4 values = vec4( getValue(batch, xD, xR, xC, ch), getValue(batch, xD, xR, xC + ${dilationWidth}, ch), getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch), getValue(batch, xD, xR, xC + 3 * ${dilationWidth}, ch) ); ${updateSnippet} } int xC = xCCorner + ${filterWidthNearestVec4}; if (${filterWidthVec4Remainder === 1}) { vec4 values = vec4( getValue(batch, xD, xR, xC, ch), initializationValue, initializationValue, initializationValue ); ${updateSnippet} } else if (${filterWidthVec4Remainder === 2}) { vec4 values = vec4( getValue(batch, xD, xR, xC, ch), getValue(batch, xD, xR, xC + ${dilationWidth}, ch), initializationValue, initializationValue ); ${updateSnippet} } else if (${filterWidthVec4Remainder === 3}) { vec4 values = vec4( getValue(batch, xD, xR, xC, ch), getValue(batch, xD, xR, xC + ${dilationWidth}, ch), getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch), initializationValue ); ${updateSnippet} } } } setOutput(${returnValue}); } `; } } function avgPool(args) { const { inputs, backend, attrs } = args; const { x } = inputs; assertNotComplex(x, 'avgPool'); const { filterSize, strides, pad, dimRoundingMode } = attrs; const dilations = 1; assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool: Either strides or dilations must be 1. 
' + `Got strides ${strides} and dilations '${dilations}'`); const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode); if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && arraysEqual(convInfo.inShape, convInfo.outShape)) { return identity({ inputs: { x }, backend }); } const avgPoolProgram = new Pool2DProgram(convInfo, 'avg', false); return backend.runWebGLProgram(avgPoolProgram, [x], 'float32'); } const avgPoolConfig = { kernelName: AvgPool, backendName: 'webgl', kernelFunc: avgPool }; function avgPool3D(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs; const dilations = [1, 1, 1]; const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat); const avgPoolProgram = new Pool3DProgram(convInfo, 'avg', false); return backend.runWebGLProgram(avgPoolProgram, [x], 'float32'); } const avgPool3DConfig = { kernelName: AvgPool3D, backendName: 'webgl', kernelFunc: avgPool3D }; class AvgPool2DBackpropProgram { constructor(convInfo) { this.variableNames = ['dy']; this.outputShape = convInfo.inShape; const filterHeight = convInfo.filterHeight; const filterWidth = convInfo.filterWidth; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const dilationHeight = convInfo.dilationHeight; const dilationWidth = convInfo.dilationWidth; const effectiveFilterHeight = convInfo.effectiveFilterHeight; const effectiveFilterWidth = convInfo.effectiveFilterWidth; const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; const avgMultiplier = 1 / (filterHeight * filterWidth); this.userCode = ` const ivec2 pads = ivec2(${padTop}, ${padLeft}); const float avgMultiplier = float(${avgMultiplier}); void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int d = coords[3]; ivec2 dyRCCorner = 
coords.yz - pads; int dyRCorner = dyRCCorner.x; int dyCCorner = dyRCCorner.y; float dotProd = 0.0; for (int wR = 0; wR < ${effectiveFilterHeight}; wR += ${dilationHeight}) { float dyR = float(dyRCorner + wR) / ${strideHeight}.0; if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) { continue; } int idyR = int(dyR); for (int wC = 0; wC < ${effectiveFilterWidth}; wC+= ${dilationWidth}) { float dyC = float(dyCCorner + wC) / ${strideWidth}.0; if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 || fract(dyC) > 0.0) { continue; } int idyC = int(dyC); float dyValue = getDy(b, idyR, idyC, d); dotProd += dyValue * avgMultiplier; } } setOutput(dotProd); } `; } } class AvgPool3DBackpropProgram { constructor(convInfo) { this.variableNames = ['dy']; this.outputShape = convInfo.inShape; const filterDepth = convInfo.filterDepth; const filterHeight = convInfo.filterHeight; const filterWidth = convInfo.filterWidth; const strideDepth = convInfo.strideDepth; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const dilationDepth = convInfo.dilationDepth; const dilationHeight = convInfo.dilationHeight; const dilationWidth = convInfo.dilationWidth; const effectiveFilterDepth = convInfo.effectiveFilterDepth; const effectiveFilterHeight = convInfo.effectiveFilterHeight; const effectiveFilterWidth = convInfo.effectiveFilterWidth; const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front; const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; const avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth); this.userCode = ` const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft}); const float avgMultiplier = float(${avgMultiplier}); void main() { ivec5 coords = getOutputCoords(); int batch = coords.x; int ch = coords.u; ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads; int dyDCorner = dyCorner.x; int dyRCorner = dyCorner.y; int dyCCorner = 
dyCorner.z; float dotProd = 0.0; for (int wD = 0; wD < ${effectiveFilterDepth}; wD += ${dilationDepth}) { float dyD = float(dyDCorner + wD) / ${strideDepth}.0; if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) { continue; } int idyD = int(dyD); for (int wR = 0; wR < ${effectiveFilterHeight}; wR += ${dilationHeight}) { float dyR = float(dyRCorner + wR) / ${strideHeight}.0; if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) { continue; } int idyR = int(dyR); for (int wC = 0; wC < ${effectiveFilterWidth}; wC += ${dilationWidth}) { float dyC = float(dyCCorner + wC) / ${strideWidth}.0; if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 || fract(dyC) > 0.0) { continue; } int idyC = int(dyC); float dyValue = getDy(batch, idyD, idyR, idyC, ch); dotProd += dyValue * avgMultiplier; } } } setOutput(dotProd); } `; } } function avgPool3DGrad(args) { const { inputs, backend, attrs } = args; const { dy, input } = inputs; const x = input; const { filterSize, strides, pad, dimRoundingMode } = attrs; const dilations = [1, 1, 1]; const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode); const avgPoolBackpropProgram = new AvgPool3DBackpropProgram(convInfo); return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype); } const avgPool3DGradConfig$1 = { kernelName: AvgPool3DGrad, backendName: 'webgl', kernelFunc: avgPool3DGrad }; function avgPoolGrad$1(args) { const { inputs, backend, attrs } = args; const { dy, input } = inputs; const x = input; assertNotComplex([dy, input], 'avgPoolGrad'); const { filterSize, strides, pad } = attrs; const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 , pad); const avgPoolBackpropProgram = new AvgPool2DBackpropProgram(convInfo); return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype); } const avgPoolGradConfig$1 = { kernelName: AvgPoolGrad, backendName: 'webgl', kernelFunc: avgPoolGrad$1 }; function batchMatMul(args) { const { inputs, backend, 
attrs } = args; const { a, b } = inputs; const { transposeA, transposeB } = attrs; return batchMatMulImpl({ a, b, transposeA, transposeB, backend }); } const batchMatMulConfig = { kernelName: BatchMatMul, backendName: 'webgl', kernelFunc: batchMatMul, }; class BatchNormProgram { constructor(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) { this.outputShape = []; this.variableNames = ['x', 'mean', 'variance']; assertAndGetBroadcastShape(xShape, meanShape); assertAndGetBroadcastShape(xShape, varianceShape); let offsetSnippet = '0.0'; if (offsetShape != null) { assertAndGetBroadcastShape(xShape, offsetShape); this.variableNames.push('offset'); offsetSnippet = 'getOffsetAtOutCoords()'; } let scaleSnippet = '1.0'; if (scaleShape != null) { assertAndGetBroadcastShape(xShape, scaleShape); this.variableNames.push('scale'); scaleSnippet = 'getScaleAtOutCoords()'; } this.outputShape = xShape; this.userCode = ` void main() { float x = getXAtOutCoords(); float mean = getMeanAtOutCoords(); float variance = getVarianceAtOutCoords(); float offset = ${offsetSnippet}; float scale = ${scaleSnippet}; float inv = scale * inversesqrt(variance + float(${varianceEpsilon})); setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1))); } `; } } class BatchNormPackedProgram { constructor(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) { this.packedInputs = true; this.packedOutput = true; this.variableNames = ['x', 'mean', 'variance']; assertAndGetBroadcastShape(xShape, meanShape); assertAndGetBroadcastShape(xShape, varianceShape); let offsetSnippet = 'vec4(0.0)'; if (offsetShape != null) { assertAndGetBroadcastShape(xShape, offsetShape); this.variableNames.push('offset'); offsetSnippet = 'getOffsetAtOutCoords()'; } let scaleSnippet = 'vec4(1.0)'; if (scaleShape != null) { assertAndGetBroadcastShape(xShape, scaleShape); this.variableNames.push('scale'); scaleSnippet = 'getScaleAtOutCoords()'; } this.outputShape = xShape; this.userCode = ` 
void main() { vec4 offset = ${offsetSnippet}; vec4 scale = ${scaleSnippet}; vec4 x = getXAtOutCoords(); vec4 mean = getMeanAtOutCoords(); vec4 variance = getVarianceAtOutCoords(); vec4 inv = scale * inversesqrt(variance + vec4(${varianceEpsilon})); setOutput((x - mean) * inv + offset); } `; } } const batchNorm = ({ inputs, backend, attrs }) => { const { x, mean, variance, offset, scale } = inputs; assert$1(mean.shape.length === variance.shape.length, () => 'Batch normalization gradient requires mean and variance to have ' + 'equal ranks.'); assert$1(offset == null || mean.shape.length === offset.shape.length, () => 'Batch normalization gradient requires mean and offset to have ' + 'equal ranks.'); assert$1(scale == null || mean.shape.length === scale.shape.length, () => 'Batch normalization gradient requires mean and scale to have ' + 'equal ranks.'); let { varianceEpsilon } = attrs; if (varianceEpsilon == null) { varianceEpsilon = 0.001; } const finalInputs = [x, mean, variance]; let offsetShape = null; if (offset != null) { offsetShape = offset.shape; finalInputs.push(offset); } let scaleShape = null; if (scale != null) { scaleShape = scale.shape; finalInputs.push(scale); } const program = env().getBool('WEBGL_PACK_NORMALIZATION') ? 
new BatchNormPackedProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon) : new BatchNormProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon); const output = backend.runWebGLProgram(program, finalInputs, finalInputs[0].dtype); return output; }; const batchNormConfig = { kernelName: FusedBatchNorm, backendName: 'webgl', kernelFunc: batchNorm, }; class SliceProgram { constructor(destSize) { this.variableNames = ['source']; this.outputShape = destSize; this.rank = destSize.length; const dtype = getCoordsDataType(this.rank); this.customUniforms = [{ name: 'start', arrayIndex: this.rank, type: 'int' }]; const sourceCoords = getCoords$1(this.rank); let body; const coordSum = destSize.map((_, i) => { return `sourceLoc.${coords[i]} = start[${i}] + coords.${coords[i]};`; }); body = ` ${dtype} sourceLoc; ${dtype} coords = getOutputCoords(); ${coordSum.join('\n')} `; this.userCode = ` void main() { ${body} setOutput(getSource(${sourceCoords})); } `; } } const coords = ['x', 'y', 'z', 'w', 'u', 'v']; function getCoords$1(rank) { if (rank === 1) { return 'sourceLoc'; } else if (rank <= 6) { return coords.slice(0, rank).map(x => 'sourceLoc.' + x).join(','); } else { throw Error(`Slicing for rank ${rank} is not yet supported`); } } class SlicePackedProgram { constructor(destSize) { this.variableNames = ['source']; this.packedInputs = true; this.packedOutput = true; this.outputShape = destSize; this.rank = destSize.length; this.customUniforms = [{ name: 'start', arrayIndex: this.rank, type: 'int' }]; const dtype = getCoordsDataType(this.rank); const coords = getChannels('coords', this.rank); const sourceLoc = getChannels('sourceLoc', this.rank); const innerDims = this.rank === 1 ? 
'sourceLoc' : `vec2(${sourceLoc.slice(-2).join()})`; const getChannel = `getChannel(getSource(${sourceLoc.join()}), ${innerDims})`; const upperRow = ` result.x = ${getChannel}; if (++${coords[this.rank - 1]} < ${destSize[this.rank - 1]}) { ++${sourceLoc[this.rank - 1]}; result.y = ${getChannel}; --${sourceLoc[this.rank - 1]}; } `; const lowerRow = this.rank === 1 ? '' : ` --${coords[this.rank - 1]}; if (++${coords[this.rank - 2]} < ${destSize[this.rank - 2]}) { ++${sourceLoc[this.rank - 2]}; result.z = ${getChannel}; if (++${coords[this.rank - 1]} < ${destSize[this.rank - 1]}) { ++${sourceLoc[this.rank - 1]}; result.w = ${getChannel}; } } `; const sourceLocSetup = this.rank <= 4 ? `sourceLoc = coords + ${dtype}(${destSize.map((_, i) => `start[${i}]`).join()});` : destSize.map((_, i) => `${sourceLoc[i]} = ${coords[i]} + start[${i}];`) .join('\n'); this.userCode = ` void main() { ${dtype} coords = getOutputCoords(); ${dtype} sourceLoc; ${sourceLocSetup} vec4 result = vec4(0.); ${upperRow} ${lowerRow} setOutput(result); } `; } } function shallowSlice(x, begin, size, backend) { const xTexData = backend.texData.get(x.dataId); const t = backend.makeTensorInfo(size, x.dtype); const newTexData = backend.texData.get(t.dataId); Object.assign(newTexData, xTexData); newTexData.refCount = 1; newTexData.shape = size; newTexData.dtype = x.dtype; let flatOffset = computeFlatOffset(begin, computeStrides(x.shape)); if (xTexData.slice) { flatOffset += xTexData.slice.flatOffset; } newTexData.slice = { flatOffset, origDataId: xTexData.slice && xTexData.slice.origDataId || x.dataId }; const refCount = backend.dataRefCount.get(newTexData.slice.origDataId) || 1; backend.dataRefCount.set(newTexData.slice.origDataId, refCount + 1); return t; } function slice(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { begin, size } = attrs; const [$begin, $size] = parseSliceParams(x, begin, size); assertParamsValid(x, $begin, $size); if (sizeFromShape($size) === 0) { 
return backend.makeTensorInfo($size, x.dtype, []); } if (backend.shouldExecuteOnCPU([x]) || x.dtype === 'string') { const xTexData = backend.texData.get(x.dataId); const outValues = sliceImplCPU(xTexData.values, $begin, $size, x.shape, x.dtype); return backend.makeTensorInfo($size, x.dtype, outValues); } const { isPacked } = backend.texData.get(x.dataId); const isContinous = isSliceContinous(x.shape, $begin, $size); if (isPacked || !isContinous) { const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new SlicePackedProgram($size) : new SliceProgram($size); const customValues = [$begin]; return backend.runWebGLProgram(program, [x], x.dtype, customValues); } backend.uploadToGPU(x.dataId); return shallowSlice(x, $begin, $size, backend); } const sliceConfig = { kernelName: Slice, backendName: 'webgl', kernelFunc: slice }; const batchToSpaceND = (args) => { const { inputs, backend, attrs } = args; const { x } = inputs; const { blockShape, crops } = attrs; assert$1(x.shape.length <= 4, () => 'batchToSpaceND for rank > 4 with a WebGL backend not ' + 'implemented yet'); const prod = blockShape.reduce((a, b) => a * b); const reshaped = getReshaped(x.shape, blockShape, prod); const permuted = getPermuted(reshaped.length, blockShape.length); const reshapedPermuted = getReshapedPermuted(x.shape, blockShape, prod); const sliceBeginCoords = getSliceBeginCoords(crops, blockShape.length); const sliceSize = getSliceSize(reshapedPermuted, crops, blockShape.length); const toDispose = []; const reshapedIntermediate = reshape({ inputs: { x }, backend, attrs: { shape: reshaped } }); const transposedIntermediate = transpose({ inputs: { x: reshapedIntermediate }, backend, attrs: { perm: permuted } }); const reshapedIntermediate2 = reshape({ inputs: { x: transposedIntermediate }, backend, attrs: { shape: reshapedPermuted } }); const sliced = slice({ inputs: { x: reshapedIntermediate2 }, backend, attrs: { begin: sliceBeginCoords, size: sliceSize } }); 
toDispose.push(reshapedIntermediate); toDispose.push(transposedIntermediate); toDispose.push(reshapedIntermediate2); toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t)); return sliced; }; const batchToSpaceNDConfig = { kernelName: BatchToSpaceND, backendName: 'webgl', kernelFunc: batchToSpaceND }; function bincount(args) { const { inputs, backend, attrs } = args; const { x, weights } = inputs; const { size } = attrs; const xVals = backend.readSync(x.dataId); const weightsVals = backend.readSync(weights.dataId); const outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size); return backend.makeTensorInfo([size], weights.dtype, outVals); } const bincountConfig = { kernelName: Bincount, backendName: 'webgl', kernelFunc: bincount }; const BITWISEAND = ` int r = int(a.r) & int(b.r); int g = int(a.g) & int(b.g); int rb = int(a.b) & int(b.b); int ra = int(a.a) & int(b.a); return vec4(r, g, rb, ra); `; const BITWISEAND_UNPACKED = ` return float(int(a.r) & int(b.r)); `; function bitwiseAnd(args) { const { inputs, backend } = args; const { a, b } = inputs; const shouldUsePackedProgram = env().getBool('WEBGL_PACK_BINARY_OPERATIONS'); const versionNumber = env().getNumber('WEBGL_VERSION'); if ((backend.shouldExecuteOnCPU([a, b])) || versionNumber === 1) { const aVals = backend.texData.get(a.dataId).values; const bVals = backend.texData.get(b.dataId).values; const [outValues, outShape] = bitwiseAndImplCPU(a.shape, b.shape, aVals, bVals, a.dtype); const out = backend.makeTensorInfo(outShape, a.dtype); const outData = backend.texData.get(out.dataId); outData.values = outValues; return out; } let program; if (shouldUsePackedProgram) { program = new BinaryOpPackedProgram(BITWISEAND, a.shape, b.shape, false); } else { program = new BinaryOpProgram(BITWISEAND_UNPACKED, a.shape, b.shape); } return backend.runWebGLProgram(program, [a, b], a.dtype); } const bitwiseAndConfig = { kernelName: BitwiseAnd, backendName: 'webgl', kernelFunc: bitwiseAnd }; 
function broadcastArgs(args) { const { inputs, backend } = args; const { s0, s1 } = inputs; const s0Vals = backend.readSync(s0.dataId); const s1Vals = backend.readSync(s1.dataId); const broadcastShape = assertAndGetBroadcastShape(Array.from(s0Vals), Array.from(s1Vals)); return backend.makeTensorInfo([broadcastShape.length], 'int32', Int32Array.from(broadcastShape)); } const broadcastArgsConfig = { kernelName: BroadcastArgs, backendName: 'webgl', kernelFunc: broadcastArgs }; const NOT_EQUAL = `return float(a != b);`; const notEqual = binaryKernelFunc({ opSnippet: NOT_EQUAL, cpuKernelImpl: notEqualImplCPU, dtype: 'bool' }); const notEqualConfig = { kernelName: NotEqual, backendName: 'webgl', kernelFunc: notEqual, }; function real(args) { const { inputs, backend } = args; const { input } = inputs; const inputData = backend.texData.get(input.dataId); return identity({ inputs: { x: inputData.complexTensorInfos.real }, backend }); } const realConfig = { kernelName: Real, backendName: 'webgl', kernelFunc: real }; const TO_INT = `return float(int(x));`; function int(input, backend) { const program = new UnaryOpProgram(input.shape, TO_INT); const output = backend.runWebGLProgram(program, [input], 'int32'); return { dataId: output.dataId, shape: output.shape, dtype: output.dtype }; } function cast$1(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { dtype } = attrs; if (dtype === 'complex64') { if (x.dtype === 'complex64') { return identity({ inputs: { x }, backend }); } const zerosTensor = zeros(x.shape); const floatX = cast$1({ inputs: { x }, backend, attrs: { dtype: 'float32' } }); const result = complex({ inputs: { real: floatX, imag: zerosTensor }, backend }); zerosTensor.dispose(); backend.disposeIntermediateTensorInfo(floatX); return result; } if (x.dtype === 'complex64') { const realPart = real({ inputs: { input: x }, backend }); const result = cast$1({ inputs: { x: realPart }, backend, attrs: { dtype } }); 
backend.disposeIntermediateTensorInfo(realPart); return result; } if (!hasEncodingLoss(x.dtype, dtype)) { const result = identity({ inputs: { x }, backend }); return { dataId: result.dataId, shape: result.shape, dtype }; } if (backend.shouldExecuteOnCPU([x])) { const values = backend.texData.get(x.dataId).values; const [resultShape, resultType, resultData] = castImplCPU(values, x.shape, x.dtype, dtype); return backend.makeTensorInfo(resultShape, resultType, resultData); } if (dtype === 'int32') { return int(x, backend); } if (dtype === 'bool') { const zerosTensorInfo = backend.makeTensorInfo([], 'bool', getTypedArrayFromDType('bool', 1)); const binaryInputs = { a: x, b: zerosTensorInfo }; const result = notEqual({ inputs: binaryInputs, backend }); backend.disposeIntermediateTensorInfo(zerosTensorInfo); return result; } throw new Error(`Error in Cast: failed to cast ${x.dtype} to ${dtype}`); } const castConfig = { kernelName: Cast, backendName: 'webgl', kernelFunc: cast$1 }; const CEIL = `return ceil(x);`; const ceil = unaryKernelFunc({ opSnippet: CEIL, packedOpSnippet: CEIL, cpuKernelImpl: ceilImplCPU }); const ceilConfig = { kernelName: Ceil, backendName: 'webgl', kernelFunc: ceil }; class ClipProgram { constructor(aShape) { this.variableNames = ['A']; this.customUniforms = [ { name: 'minVal', type: 'float' }, { name: 'maxVal', type: 'float' } ]; this.outputShape = aShape; this.userCode = ` void main() { float value = getAAtOutCoords(); if (isnan(value)) { setOutput(value); return; } setOutput(clamp(value, minVal, maxVal)); } `; } } class ClipPackedProgram { constructor(aShape) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; this.customUniforms = [ { name: 'minVal', type: 'float' }, { name: 'maxVal', type: 'float' } ]; this.outputShape = aShape; this.userCode = ` void main() { vec4 value = getAAtOutCoords(); if (any(isnan(value))) { setOutput(value); return; } setOutput(clamp(value, vec4(minVal), vec4(maxVal))); } `; } } function 
clipByValue(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { clipValueMin, clipValueMax } = attrs; let program; if (env().getBool('WEBGL_PACK_CLIP')) { program = new ClipPackedProgram(x.shape); } else { program = new ClipProgram(x.shape); } const customValues = [[clipValueMin], [clipValueMax]]; return backend.runWebGLProgram(program, [x], x.dtype, customValues); } const clipByValueConfig = { kernelName: ClipByValue, backendName: 'webgl', kernelFunc: clipByValue }; class ComplexAbsProgram { constructor(shape) { this.variableNames = ['real', 'imag']; this.outputShape = shape; this.userCode = ` void main() { float re = abs(getRealAtOutCoords()); float im = abs(getImagAtOutCoords()); float mx = max(re, im); setOutput( mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx)) ); } `; } } function makeComplexComponentTensorInfo(complexTensor, complexPart) { return { dataId: complexPart.dataId, dtype: complexPart.dtype, shape: complexTensor.shape }; } function complexAbs(args) { const { inputs, backend } = args; const { x } = inputs; const xData = backend.texData.get(x.dataId); const program = new ComplexAbsProgram(x.shape); const programInputs = [ makeComplexComponentTensorInfo(x, xData.complexTensorInfos.real), makeComplexComponentTensorInfo(x, xData.complexTensorInfos.imag), ]; return backend.runWebGLProgram(program, programInputs, programInputs[0].dtype); } const complexAbsConfig = { kernelName: ComplexAbs, backendName: 'webgl', kernelFunc: complexAbs }; class ConcatProgram { constructor(shapes) { this.outputShape = []; this.outputShape = computeOutShape$1(shapes, 1 ); this.variableNames = shapes.map((_, i) => `T${i}`); const offsets = new Array(shapes.length - 1); offsets[0] = shapes[0][1]; for (let i = 1; i < offsets.length; i++) { offsets[i] = offsets[i - 1] + shapes[i][1]; } const snippets = [`if (yC < ${offsets[0]}) setOutput(getT0(yR, yC));`]; for (let i = 1; i < offsets.length; i++) { const shift = offsets[i - 1]; snippets.push(`else 
if (yC < ${offsets[i]}) ` + `setOutput(getT${i}(yR, yC-${shift}));`); } const lastIndex = offsets.length; const lastShift = offsets[offsets.length - 1]; snippets.push(`else setOutput(getT${lastIndex}(yR, yC-${lastShift}));`); this.userCode = ` void main() { ivec2 coords = getOutputCoords(); int yR = coords.x; int yC = coords.y; ${snippets.join('\n ')} } `; } } class ConcatPackedProgram { constructor(shapes, axis) { this.packedInputs = true; this.packedOutput = true; this.outputShape = []; this.outputShape = computeOutShape$1(shapes, axis); const shape = this.outputShape; const rank = shape.length; const dtype = getCoordsDataType(rank); const coords = getChannels('coords', rank); const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank); this.variableNames = shapes.map((_, i) => `T${i}`); const offsets = new Array(shapes.length - 1); offsets[0] = shapes[0][axis]; for (let i = 1; i < offsets.length; i++) { offsets[i] = offsets[i - 1] + shapes[i][axis]; } const channel = channels[axis]; const lastChannels = channels.slice(-2); const allChannels = channels.join(); let getValueSnippet = `if (${channel} < ${offsets[0]}) { return getChannel( getT0(${allChannels}), vec2(${lastChannels.join()})); }`; for (let i = 1; i < offsets.length; i++) { const shift = offsets[i - 1]; getValueSnippet += ` if (${channel} < ${offsets[i]} && ${channel} >= ${offsets[i - 1]}) { return getChannel( getT${i}(${shiftedChannels(channels, channel, shift)}), vec2(${shiftedChannels(lastChannels, channel, shift)})); }`; } const lastIndex = offsets.length; const shift = offsets[offsets.length - 1]; getValueSnippet += ` return getChannel( getT${lastIndex}(${shiftedChannels(channels, channel, shift)}), vec2(${shiftedChannels(lastChannels, channel, shift)}));`; this.userCode = ` float getValue(${channels.map(x => 'int ' + x)}) { ${getValueSnippet} } void main() { ${dtype} coords = getOutputCoords(); vec4 result = vec4(getValue(${coords}), 0., 0., 0.); ${coords[rank - 1]} = ${coords[rank - 1]} + 1; 
if (${coords[rank - 1]} < ${shape[rank - 1]}) { result.g = getValue(${coords}); } ${coords[rank - 2]} = ${coords[rank - 2]} + 1; if (${coords[rank - 2]} < ${shape[rank - 2]}) { result.a = getValue(${coords}); } ${coords[rank - 1]} = ${coords[rank - 1]} - 1; if (${coords[rank - 2]} < ${shape[rank - 2]} && ${coords[rank - 1]} < ${shape[rank - 1]}) { result.b = getValue(${coords}); } setOutput(result); } `; } } function shiftedChannels(channels, channel, shift) { const channelIdx = channels.indexOf(channel); const res = channels.map((c, idx) => { if (idx === channelIdx) { return `${c} - ${shift}`; } else { return c; } }); return res.join(); } function imag(args) { const { inputs, backend } = args; const { input } = inputs; const inputData = backend.texData.get(input.dataId); return identity({ inputs: { x: inputData.complexTensorInfos.imag }, backend }); } const imagConfig = { kernelName: Imag, backendName: 'webgl', kernelFunc: imag }; function concatImpl(inputs, axis, backend) { const dtype = inputs[0].dtype; if (dtype === 'complex64') { const reals = inputs.map((t) => real({ inputs: { input: t }, backend })); const imags = inputs.map((t) => imag({ inputs: { input: t }, backend })); const realConcated = concatImpl(reals, axis, backend); const imagConcated = concatImpl(imags, axis, backend); const result = complex({ inputs: { real: realConcated, imag: imagConcated }, backend }); reals.forEach(r => backend.disposeIntermediateTensorInfo(r)); imags.forEach(i => backend.disposeIntermediateTensorInfo(i)); backend.disposeIntermediateTensorInfo(realConcated); backend.disposeIntermediateTensorInfo(imagConcated); return result; } let runOnCpu = backend.shouldExecuteOnCPU(inputs); if (dtype === 'string') { runOnCpu = true; } if (runOnCpu) { const tensors2D = inputs.map(t => { const innerSize = sizeFromShape(t.shape.slice(axis)); const shape = [-1, innerSize]; return reshape({ inputs: { x: t }, backend, attrs: { shape } }); }); const inputsValShapes = tensors2D.map(t => { return 
{ vals: backend.readSync(t.dataId), shape: t.shape }; }); const outShape = computeOutShape$1(tensors2D.map(t => t.shape), 1 ); const simplyConcat = tensors2D[0].shape[0] === 1; const outVals = concatImplCPU(inputsValShapes, outShape, dtype, simplyConcat); const finalOutShape = computeOutShape$1(inputs.map(t => t.shape), axis); const outInfo = backend.makeTensorInfo(finalOutShape, dtype, outVals); tensors2D.forEach(t => backend.disposeIntermediateTensorInfo(t)); return outInfo; } const $inputs = inputs.filter(t => sizeFromShape(t.shape) > 0); const shouldPack = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') && $inputs[0].shape.length > 1; if ($inputs.length === 1) { const program = shouldPack ? new UnaryOpProgram(inputs[0].shape, CLONE) : new UnaryOpPackedProgram(inputs[0].shape, CLONE); return backend.runWebGLProgram(program, inputs, dtype); } const maxTexturesInShader = env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER'); if ($inputs.length > maxTexturesInShader) { const reducedInputs = []; for (let i = 0; i < $inputs.length; i += maxTexturesInShader) { const subArray = $inputs.slice(i, i + maxTexturesInShader); reducedInputs.push(concatImpl(subArray, axis, backend)); } const result = concatImpl(reducedInputs, axis, backend); for (const i of reducedInputs) { backend.disposeIntermediateTensorInfo(i); } return result; } if (shouldPack) { const program = new ConcatPackedProgram($inputs.map(t => t.shape), axis); return backend.runWebGLProgram(program, $inputs, dtype); } const { tensors2D, outShape } = computeTensors2D($inputs, axis, backend); const program = new ConcatProgram(tensors2D.map(t => t.shape)); const result = backend.runWebGLProgram(program, tensors2D, dtype); tensors2D.forEach(r => backend.disposeIntermediateTensorInfo(r)); const reshapedResult = reshape({ inputs: { x: result }, attrs: { shape: outShape }, backend }); backend.disposeIntermediateTensorInfo(result); return reshapedResult; } function computeTensors2D(inputs, axis, backend) { const outShape = 
computeOutShape$1(inputs.map(t => t.shape), axis); const tensors2D = inputs.map(x => reshape({ inputs: { x }, attrs: { shape: [-1, sizeFromShape(x.shape.slice(axis))] }, backend })); return { tensors2D, outShape }; } function concat(args) { const { inputs, backend, attrs } = args; const { axis } = attrs; const $axis = parseAxisParam(axis, inputs[0].shape)[0]; const shapes = inputs.map(t => t.shape); assertParamsConsistent(shapes, $axis); const outShape = computeOutShape$1(inputs.map(t => t.shape), $axis); if (sizeFromShape(outShape) === 0) { return backend.makeTensorInfo(outShape, inputs[0].dtype, []); } const $inputs = inputs.filter(t => sizeFromShape(t.shape) > 0); if ($inputs.length === 1) { return identity({ inputs: { x: $inputs[0] }, backend }); } return concatImpl($inputs, $axis, backend); } const concatConfig = { kernelName: Concat, backendName: 'webgl', kernelFunc: concat }; class Conv2DProgram { constructor(convInfo, addBias = false, activation = null, hasPreluActivationWeights = false, hasLeakyreluAlpha = false) { this.variableNames = ['x', 'W']; this.outputShape = convInfo.outShape; const padTop = convInfo.padInfo.top; const padLeft = convInfo.padInfo.left; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const dilationHeight = convInfo.dilationHeight; const dilationWidth = convInfo.dilationWidth; const filterHeight = convInfo.filterHeight; const filterWidth = convInfo.filterWidth; const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4; const inputDepthVec4Remainder = convInfo.inChannels % 4; const isChannelsLast = convInfo.dataFormat === 'channelsLast'; const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 2 : 3; const channelDim = isChannelsLast ? 
3 : 1; let activationSnippet = '', applyActivationSnippet = ''; if (activation) { if (hasPreluActivationWeights) { activationSnippet = `float activation(float a) { float b = getPreluActivationWeightsAtOutCoords(); ${activation} }`; } else if (hasLeakyreluAlpha) { activationSnippet = `float activation(float a) { float b = getLeakyreluAlphaAtOutCoords(); ${activation} }`; } else { activationSnippet = ` float activation(float x) { ${activation} } `; } applyActivationSnippet = `result = activation(result);`; } const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : ''; if (addBias) { this.variableNames.push('bias'); } if (hasPreluActivationWeights) { this.variableNames.push('preluActivationWeights'); } if (hasLeakyreluAlpha) { this.variableNames.push('leakyreluAlpha'); } this.userCode = ` ${activationSnippet} const ivec2 strides = ivec2(${strideHeight}, ${strideWidth}); const ivec2 pads = ivec2(${padTop}, ${padLeft}); void main() { ivec4 coords = getOutputCoords(); int batch = coords[0]; int d2 = coords[${channelDim}]; ivec2 xRCCorner = ivec2(coords[${rowDim}], coords[${colDim}]) * strides - pads; int xRCorner = xRCCorner.x; int xCCorner = xRCCorner.y; float dotProd = 0.0; for (int wR = 0; wR < ${filterHeight}; wR++) { int xR = xRCorner + wR * ${dilationHeight}; if (xR < 0 || xR >= ${convInfo.inHeight}) { continue; } for (int wC = 0; wC < ${filterWidth}; wC++) { int xC = xCCorner + wC * ${dilationWidth}; if (xC < 0 || xC >= ${convInfo.inWidth}) { continue; } for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) { vec4 wValues = vec4( getW(wR, wC, d1, d2), getW(wR, wC, d1 + 1, d2), getW(wR, wC, d1 + 2, d2), getW(wR, wC, d1 + 3, d2) ); if (${isChannelsLast}) { vec4 xValues = vec4( getX(batch, xR, xC, d1), getX(batch, xR, xC, d1 + 1), getX(batch, xR, xC, d1 + 2), getX(batch, xR, xC, d1 + 3) ); dotProd += dot(xValues, wValues); } else { vec4 xValues = vec4( getX(batch, d1, xR, xC), getX(batch, d1 + 1, xR, xC), getX(batch, d1 + 2, xR, xC), getX(batch, d1 + 
3, xR, xC) ); dotProd += dot(xValues, wValues); } } if (${inputDepthVec4Remainder === 1}) { if (${isChannelsLast}) { dotProd += getX(batch, xR, xC, ${inputDepthNearestVec4}) * getW(wR, wC, ${inputDepthNearestVec4}, d2); } else { dotProd += getX(batch, ${inputDepthNearestVec4}, xR, xC) * getW(wR, wC, ${inputDepthNearestVec4}, d2); } } else if (${inputDepthVec4Remainder === 2}) { vec2 wValues = vec2( getW(wR, wC, ${inputDepthNearestVec4}, d2), getW(wR, wC, ${inputDepthNearestVec4} + 1, d2) ); if (${isChannelsLast}) { vec2 xValues = vec2( getX(batch, xR, xC, ${inputDepthNearestVec4}), getX(batch, xR, xC, ${inputDepthNearestVec4} + 1) ); dotProd += dot(xValues, wValues); } else { vec2 xValues = vec2( getX(batch, ${inputDepthNearestVec4}, xR, xC), getX(batch, ${inputDepthNearestVec4} + 1, xR, xC) ); dotProd += dot(xValues, wValues); } } else if (${inputDepthVec4Remainder === 3}) { vec3 wValues = vec3( getW(wR, wC, ${inputDepthNearestVec4}, d2), getW(wR, wC, ${inputDepthNearestVec4} + 1, d2), getW(wR, wC, ${inputDepthNearestVec4} + 2, d2) ); if (${isChannelsLast}) { vec3 xValues = vec3( getX(batch, xR, xC, ${inputDepthNearestVec4}), getX(batch, xR, xC, ${inputDepthNearestVec4} + 1), getX(batch, xR, xC, ${inputDepthNearestVec4} + 2) ); dotProd += dot(xValues, wValues); } else { vec3 xValues = vec3( getX(batch, ${inputDepthNearestVec4}, xR, xC), getX(batch, ${inputDepthNearestVec4} + 1, xR, xC), getX(batch, ${inputDepthNearestVec4} + 2, xR, xC) ); dotProd += dot(xValues, wValues); } } } } float result = dotProd; ${addBiasSnippet} ${applyActivationSnippet} setOutput(result); } `; } }
/**
 * Unpacked WebGL program computing a 3-D convolution of input `x` with
 * filter `W` (no bias, no activation).
 *
 * One shader invocation per output element; the 5-D output coordinate is read
 * as batch = coords.x, spatial (depth, row, col) = coords.y/z/w and output
 * channel d2 = coords.u. Filter taps whose input position falls outside
 * inDepth/inHeight/inWidth are skipped, so they contribute nothing. Input
 * channels are reduced with vec4 dot products, followed by a scalar / vec2 /
 * vec3 tail for the last `inChannels % 4` channels.
 *
 * @param convInfo 3-D convolution metadata (outShape, padInfo.front/top/left,
 *     strides, dilations, filter sizes, inDepth/inHeight/inWidth, inChannels);
 *     `conv3D` builds it with computeConv3DInfo.
 */
class Conv3DProgram { constructor(convInfo) { this.variableNames = ['x', 'W']; this.outputShape = convInfo.outShape; const padFront = convInfo.padInfo.front; const padTop = convInfo.padInfo.top; const padLeft = convInfo.padInfo.left; const strideDepth = convInfo.strideDepth; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const dilationDepth = 
convInfo.dilationDepth; const dilationHeight = convInfo.dilationHeight; const dilationWidth = convInfo.dilationWidth; const filterDepth = convInfo.filterDepth; const filterHeight = convInfo.filterHeight; const filterWidth = convInfo.filterWidth;
// Largest multiple of 4 not above inChannels (reduced with vec4 dot products) and the 0-3 leftover channels handled by the tail branch of the shader.
const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4; const inputDepthVec4Remainder = convInfo.inChannels % 4; this.userCode = ` const ivec3 strides = ivec3(${strideDepth}, ${strideHeight}, ${strideWidth}); const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft}); void main() { ivec5 coords = getOutputCoords(); int batch = coords.x; int d2 = coords.u; ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads; int xFCorner = xFRCCorner.x; int xRCorner = xFRCCorner.y; int xCCorner = xFRCCorner.z; float dotProd = 0.0; for (int wF = 0; wF < ${filterDepth}; wF++) { int xF = xFCorner + wF * ${dilationDepth}; if (xF < 0 || xF >= ${convInfo.inDepth}) { continue; } for (int wR = 0; wR < ${filterHeight}; wR++) { int xR = xRCorner + wR * ${dilationHeight}; if (xR < 0 || xR >= ${convInfo.inHeight}) { continue; } for (int wC = 0; wC < ${filterWidth}; wC++) { int xC = xCCorner + wC * ${dilationWidth}; if (xC < 0 || xC >= ${convInfo.inWidth}) { continue; } for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) { vec4 xValues = vec4( getX(batch, xF, xR, xC, d1), getX(batch, xF, xR, xC, d1 + 1), getX(batch, xF, xR, xC, d1 + 2), getX(batch, xF, xR, xC, d1 + 3) ); vec4 wValues = vec4( getW(wF, wR, wC, d1, d2), getW(wF, wR, wC, d1 + 1, d2), getW(wF, wR, wC, d1 + 2, d2), getW(wF, wR, wC, d1 + 3, d2) ); dotProd += dot(xValues, wValues); } if (${inputDepthVec4Remainder === 1}) { dotProd += getX(batch, xF, xR, xC, ${inputDepthNearestVec4}) * getW(wF, wR, wC, ${inputDepthNearestVec4}, d2); } else if (${inputDepthVec4Remainder === 2}) { vec2 xValues = vec2( getX(batch, xF, xR, xC, ${inputDepthNearestVec4}), getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1) ); vec2 wValues = vec2( getW(wF, wR, wC, 
${inputDepthNearestVec4}, d2), getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2) ); dotProd += dot(xValues, wValues); } else if (${inputDepthVec4Remainder === 3}) { vec3 xValues = vec3( getX(batch, xF, xR, xC, ${inputDepthNearestVec4}), getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1), getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 2) ); vec3 wValues = vec3( getW(wF, wR, wC, ${inputDepthNearestVec4}, d2), getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2), getW(wF, wR, wC, ${inputDepthNearestVec4} + 2, d2) ); dotProd += dot(xValues, wValues); } } } } setOutput(dotProd); } `; } }
/**
 * Packed WebGL program for a 2-D convolution with optional fused bias,
 * activation, PReLU weights or leaky-relu alpha.
 *
 * The uniforms `pads`, `strides`, `dilations` and `inDims` are supplied at run
 * time; `conv2d` passes [padTop, padLeft], [strideH, strideW],
 * [dilationH, dilationW] and [inHeight, inWidth]. The generated `main` loops
 * over filter rows and over input channels two at a time (d1 += 2). The
 * per-column fetch code is specialised at construction time on strideWidth,
 * the parity of padLeft, and dilationWidth. `conv2d` only selects this
 * program for channelsLast data with strideWidth <= 2 when WEBGL_EXP_CONV is
 * enabled.
 */
class Conv2DPackedProgram { constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) { this.variableNames = ['x', 'W']; this.packedInputs = true; this.packedOutput = true; this.customUniforms = [ { name: 'pads', type: 'ivec2' }, { name: 'strides', type: 'ivec2' }, { name: 'dilations', type: 'ivec2' }, { name: 'inDims', type: 'ivec2' }, ]; this.outputShape = convInfo.outShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); const padLeft = convInfo.padInfo.left; const strideWidth = convInfo.strideWidth; const dilationWidth = convInfo.dilationWidth; const filterHeight = convInfo.filterHeight; const filterWidth = convInfo.filterWidth; const texelsAcross = filterWidth;
// GLSL declarations: for every filter column c, two cached input texels (each with a "Ready" flag so it is fetched at most once per (r, d1) iteration) and the assembled column value xC<c>. The caches are reset at the top of every (r, d1) iteration.
let mainLoop = ` int xR; int xC; int xCOffset; vec4 wTexel; vec4 previous; vec4 final;`; for (let c = 0; c < filterWidth; c++) { mainLoop += ` vec4 xTexelC${c * 2}; int xTexelC${c * 2}Ready; vec4 xTexelC${c * 2 + 1}; int xTexelC${c * 2 + 1}Ready; vec4 xC${c};`; } mainLoop += ` for (int r = 0; r < ${filterHeight}; r++) { for (int d1 = 0; d1 < ${convInfo.inChannels}; d1 += 2) { `; for (let c = 0; c < filterWidth; c++) { mainLoop += ` xTexelC${c * 2} = vec4(0.0); xTexelC${c * 2}Ready = 0; xTexelC${c * 2 + 1} = vec4(0.0); xTexelC${c * 2 + 1}Ready = 0; xC${c} = vec4(0.0);`; } mainLoop += ` xR = xRCorner + r * dilations[0]; if (xR >=0 && xR < 
inDims[0]) { `;
/* Emit the column code two filter columns at a time: colIndex and colIndex + 1. */
for (let texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) { const colIndex = texelC * 2; mainLoop += ` xC = xCCorner + ${colIndex * dilationWidth}; `;
// Unit horizontal stride: which texels are fetched, and how xC<col> is assembled from their halves, depends on the parity of padLeft and on dilationWidth.
if (strideWidth === 1) { if (colIndex < filterWidth) { if (padLeft % 2 === 1) { mainLoop += ` xCOffset = xC + 1; if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) { xTexelC${colIndex} = getX(batch, xR, xCOffset, d1); if (xCOffset + 1 >= inDims[1]) { xTexelC${colIndex}.zw = vec2(0.0); } xTexelC${colIndex}Ready = 1; } `; if (dilationWidth === 1 && colIndex > 0) { mainLoop += ` xC${colIndex} = vec4(xTexelC${colIndex - 2}.zw, xTexelC${colIndex}.xy); `; } else { mainLoop += ` xCOffset = xC + 1 - 2; if (xCOffset >= 0 && xCOffset < inDims[1]) { previous = getX(batch, xR, xCOffset, d1); if (xCOffset + 1 >= inDims[1]) { previous.zw = vec2(0.0); } xC${colIndex} = vec4(previous.zw, xTexelC${colIndex}.xy); } else { xC${colIndex} = vec4(0.0, 0.0, xTexelC${colIndex}.xy); } `; } } else { mainLoop += ` if (xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) { xTexelC${colIndex} = getX(batch, xR, xC, d1); if (xC + 1 >= inDims[1]) { xTexelC${colIndex}.zw = vec2(0.0); } xTexelC${colIndex}Ready = 1; } xC${colIndex} = xTexelC${colIndex}; `; } if (colIndex + 1 < filterWidth) { const nextTexelOffset = padLeft % 2 === 0 ? 
nearestLargerEven(dilationWidth) : dilationWidth; if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) || (dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) { mainLoop += ` xCOffset = xC + imod(pads[1], 2) + ${nextTexelOffset}; if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) { xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1); if (xCOffset + 1 >= inDims[1]) { xTexelC${colIndex + 1}.zw = vec2(0.0); } xTexelC${colIndex + 1}Ready = 1; } `; if (dilationWidth > 1) { mainLoop += ` xCOffset -= 2; if (xCOffset >= 0 && xCOffset < inDims[1]) { previous = getX(batch, xR, xCOffset, d1); xC${colIndex + 1} = vec4(previous.zw, xTexelC${colIndex + 1}.xy); } else { xC${colIndex + 1} = vec4(0.0, 0.0, xTexelC${colIndex + 1}.xy); } `; } else { mainLoop += ` xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.xy); `; } } else { if (nextTexelOffset === 1) { mainLoop += ` xC${colIndex + 1} = xTexelC${colIndex}; `; } else { mainLoop += ` xCOffset = xC + ${nextTexelOffset}; if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) { xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1); if (xCOffset + 1 >= inDims[1]) { xTexelC${colIndex + 1}.zw = vec2(0.0); } xTexelC${colIndex + 1}Ready = 1; } xC${colIndex + 1} = xTexelC${colIndex + 1}; `; } } } } }
// Non-unit horizontal stride (conv2d only routes strideWidth <= 2 here): fetch offsets are built from strides[1] and the parity of padLeft.
else { if (colIndex < filterWidth) { if (padLeft % 2 === 1) { mainLoop += ` xCOffset = xC + 1 - strides[1]; if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) { xTexelC${colIndex} = getX(batch, xR, xCOffset, d1); if (xCOffset + 1 >= inDims[1]) { xTexelC${colIndex}.zw = vec2(0.0); } xTexelC${colIndex}Ready = 1; } if(xC + 1 >= 0 && xC + 1 < inDims[1] && xTexelC${colIndex + 1}Ready == 0) { xTexelC${colIndex + 1} = getX(batch, xR, xC + 1, d1); if (xC + 2 >= inDims[1]) { xTexelC${colIndex + 1}.zw = vec2(0.0); } xTexelC${colIndex + 1}Ready = 1; } xC${colIndex} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw); `; if (colIndex + 1 < filterWidth) { 
mainLoop += ` final = vec4(0.0); xCOffset = xC + 1 + strides[1]; if(xCOffset >= 0 && xCOffset < inDims[1]) { final = getX(batch, xR, xCOffset, d1); } xC${colIndex + 1} = vec4(xTexelC${colIndex + 1}.xy, final.xy); `; } } else { mainLoop += ` if(xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) { xTexelC${colIndex} = getX(batch, xR, xC, d1); if (xC + 1 >= inDims[1]) { xTexelC${colIndex}.zw = vec2(0.0); } xTexelC${colIndex}Ready = 1; } xCOffset = xC + strides[1]; if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) { xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1); if (xCOffset + 1 >= inDims[1]) { xTexelC${colIndex + 1}.zw = vec2(0.); } xTexelC${colIndex + 1}Ready = 1; } xC${colIndex} = vec4( xTexelC${colIndex}.xy, xTexelC${colIndex + 1}.xy); `; if (colIndex + 1 < filterWidth) { mainLoop += ` xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw); `; } } } }
// Accumulate: each assembled xC<col> is multiplied with the packed filter texel getW(r, col, d1, d2); the .yyww half is only added when input channel d1 + 1 exists.
if (colIndex < filterWidth) { mainLoop += ` wTexel = getW(r, ${colIndex}, d1, d2); dotProd += xC${colIndex}.xxzz * vec4(wTexel.xy, wTexel.xy); if(d1 + 1 < ${convInfo.inChannels}) { dotProd += xC${colIndex}.yyww * vec4(wTexel.zw, wTexel.zw); } `; if (colIndex + 1 < filterWidth) { mainLoop += ` wTexel = getW(r, ${colIndex + 1}, d1, d2); dotProd += xC${colIndex + 1}.xxzz * vec4(wTexel.xy, wTexel.xy); if(d1 + 1 < ${convInfo.inChannels}) { dotProd += xC${colIndex + 1}.yyww * vec4(wTexel.zw, wTexel.zw); } `; } } }
// Close the GLSL `if (xR in range)`, `for d1` and `for r` blocks opened earlier.
mainLoop += ` } `; mainLoop += ` } `; mainLoop += ` } `;
// Optional activation: the PReLU / leaky-relu variants load their per-element parameter into `b` before the caller-supplied activation body runs.
let activationSnippet = '', applyActivationSnippet = ''; if (activation) { if (hasPreluActivation) { activationSnippet = `vec4 activation(vec4 a) { vec4 b = getPreluActivationWeightsAtOutCoords(); ${activation} }`; } else if (hasLeakyReluAlpha) { activationSnippet = `vec4 activation(vec4 a) { vec4 b = getLeakyreluAlphaAtOutCoords(); ${activation} }`; } else { activationSnippet = `vec4 activation(vec4 x) { ${activation} }`; } applyActivationSnippet = `result = activation(result);`; } const 
addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : ''; if (addBias) { this.variableNames.push('bias'); } if (hasPreluActivation) { this.variableNames.push('preluActivationWeights'); } if (hasLeakyReluAlpha) { this.variableNames.push('leakyreluAlpha'); }
// NOTE(review): dotProd is seeded with 1e-15 and the same constant is subtracted again before bias/activation; the reason is not visible in this file (presumably a numerical or driver workaround) - confirm before changing.
this.userCode = ` ${activationSnippet} void main() { ivec4 coords = getOutputCoords(); int batch = coords.x; ivec2 xRCCorner = coords.yz * strides - pads; int d2 = coords.w; int xRCorner = xRCCorner.x; int xCCorner = xRCCorner.y; vec4 dotProd = vec4(0.000000000000001); ${mainLoop} vec4 result = dotProd - vec4(0.000000000000001); ${addBiasSnippet} ${applyActivationSnippet} setOutput(result); } `; } }
/**
 * Packed im2col program used by `conv2dWithIm2Row`.
 *
 * Output coordinates are an ivec3 `rc`. Each invocation fills the four entries
 * of a 2x2 output texel: pos = rc.y + {0,1} and blockIndex = rc.z + {0,1}.
 * blockIndex selects the output spatial position (row blockIndex / outWidth,
 * column blockIndex % outWidth). pos selects the element inside the filter
 * window: filter row pos / itemsPerBlockRow, then filter column and input
 * channel within that row. Entries outside the input or the output shape stay 0.
 *
 * @param outputShape Shape of the im2col matrix; conv2dWithIm2Row passes
 *     [batchSize, filterHeight * filterWidth * inChannels, outHeight * outWidth].
 * @param convInfo Convolution metadata; only `dataFormat` is read here, the
 *     numeric parameters arrive through the custom uniforms at run time.
 */
class Im2ColPackedProgram { constructor(outputShape, convInfo) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; this.customUniforms = [ { name: 'inputShape', type: 'ivec4' }, { name: 'pad', type: 'ivec2' }, { name: 'stride', type: 'ivec2' }, { name: 'dilation', type: 'ivec2' }, { name: 'inChannels', type: 'int' }, { name: 'itemsPerBlockRow', type: 'int' }, { name: 'outWidth', type: 'int' }, ]; this.outputShape = outputShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); const { dataFormat } = convInfo; const glsl = getGlslDifferences(); const isChannelsLast = dataFormat === 'channelsLast'; const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 2 : 3; const boundsCheckingSnippet = this.enableShapeUniforms ? 
'if(blockIndex < outShape[2] && pos < outShape[1]) {' : `if(blockIndex < ${outputShape[2]} && pos < ${outputShape[1]}) {`; let unrolled = ``; for (let row = 0; row <= 1; row++) { for (let col = 0; col <= 1; col++) { unrolled += ` blockIndex = rc.z + ${col}; pos = rc.y + ${row}; ${boundsCheckingSnippet} offsetY = int(blockIndex / outWidth) * stride[0] - pad[0]; d0 = offsetY + dilation[0] * (pos / itemsPerBlockRow); if(d0 < inputShape[${rowDim}] && d0 >= 0) { offsetX = imod(blockIndex, outWidth) * stride[1] - pad[1]; d1 = offsetX + dilation[1] * (imod(pos, itemsPerBlockRow) / inChannels); if(d1 < inputShape[${colDim}] && d1 >= 0) { ch = imod(pos, inChannels); if (${isChannelsLast}) { innerDims = vec2(d1, ch); result[${row * 2 + col}] = getChannel( getA(rc.x, d0, int(innerDims.x), int(innerDims.y)), innerDims); } else { innerDims = vec2(d0, d1); result[${row * 2 + col}] = getChannel( getA(rc.x, ch, int(innerDims.x), int(innerDims.y)), innerDims); } } } } `; } } this.userCode = ` void main() { ivec3 rc = getOutputCoords(); vec4 result = vec4(0); int blockIndex, pos, offsetY, d0, offsetX, d1, ch; vec2 innerDims; ${unrolled} ${glsl.output} = result; } `; } } function getShapeForBatchMatMul(shape, isChannelsLast) { const length = shape.length; if (length >= 3) { return isChannelsLast ? 
[ ...shape.slice(0, -3) , shape[length - 3] * shape[length - 2] , shape[length - 1] ] : [ ...shape.slice(0, -3) , shape[length - 3] , shape[length - 2] * shape[length - 1] ]; } else if (!isChannelsLast && length === 1 && shape[0] > 1) { return [shape[0], 1]; } else { return null; } } function conv2dByMatMul({ x, filter, convInfo, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) { const xShape = x.shape; const xTexData = backend.texData.get(x.dataId); const sharedMatMulDim = convInfo.inChannels; const outerShapeX = xShape[0] * xShape[1] * xShape[2]; const outerShapeFilter = convInfo.outChannels; const isChannelsLast = convInfo.dataFormat === 'channelsLast'; const transposeA = false; const transposeB = false; let out; const intermediates = []; if (preluActivationWeights != null) { const targetShape = getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast); if (targetShape != null) { preluActivationWeights = reshape({ inputs: { x: preluActivationWeights }, backend, attrs: { shape: targetShape } }); intermediates.push(preluActivationWeights); } } if (bias != null) { const targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast); if (targetShape != null) { bias = reshape({ inputs: { x: bias }, backend, attrs: { shape: targetShape } }); intermediates.push(bias); } } const batchMatMulWillBeUnpacked = (outerShapeX === 1 || outerShapeFilter === 1) && sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD; const canOptimize = !batchMatMulWillBeUnpacked && xTexData.isPacked && isChannelsLast && xTexData.texture != null && xShape[2] % 2 !== 0 && arraysEqual(xTexData.shape.slice(-3), xShape.slice(-3)); if (canOptimize) { const targetShape = xShape[0] * xShape[1] * (xShape[2] + 1); const xReshaped = { dataId: x.dataId, shape: [1, targetShape, convInfo.inChannels], dtype: x.dtype }; const originalXTexDataShape = xTexData.shape; xTexData.shape = xTexData.shape.slice(); xTexData.shape[xTexData.shape.length - 2]++; 
assert$1(isReshapeFree(xTexData.shape, xReshaped.shape), () => `packed reshape ${xTexData.shape} to ${xReshaped.shape} isn't free`); const filterReshaped = reshape({ inputs: { x: filter }, backend, attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] } }); intermediates.push(filterReshaped); const pointwiseConv = batchMatMulImpl({ a: xReshaped, b: filterReshaped, backend, transposeA, transposeB, bias, activation, preluActivationWeights, leakyreluAlpha }); const pointwiseConvTexData = backend.texData.get(pointwiseConv.dataId); assert$1(pointwiseConvTexData.isPacked, () => 'batchMatMul result is expected to be packed'); xTexData.shape = originalXTexDataShape; pointwiseConvTexData.shape = convInfo.outShape; out = identity({ inputs: { x: pointwiseConv }, backend }); out.shape = convInfo.outShape; intermediates.push(pointwiseConv); } else { const numCols = convInfo.outHeight * convInfo.outWidth; const xReshaped = reshape({ inputs: { x }, backend, attrs: { shape: isChannelsLast ? [convInfo.batchSize, numCols, convInfo.inChannels] : [convInfo.batchSize, convInfo.inChannels, numCols] } }); const filterReshaped = reshape({ inputs: { x: filter }, backend, attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] } }); const result = batchMatMulImpl({ a: isChannelsLast ? xReshaped : filterReshaped, b: isChannelsLast ? 
filterReshaped : xReshaped, transposeA: !isChannelsLast, transposeB, backend, bias, activation, preluActivationWeights, leakyreluAlpha }); out = reshape({ inputs: { x: result }, backend, attrs: { shape: convInfo.outShape } }); intermediates.push(xReshaped); intermediates.push(filterReshaped); intermediates.push(result); } for (const i of intermediates) { backend.disposeIntermediateTensorInfo(i); } return out; } function conv2dWithIm2Row({ x, filter, convInfo, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) { const { filterWidth, filterHeight, inChannels, outWidth, outHeight, dataFormat } = convInfo; const isChannelsLast = dataFormat === 'channelsLast'; const sharedDim = filterWidth * filterHeight * inChannels; const numCols = outHeight * outWidth; const x2ColShape = [convInfo.batchSize, sharedDim, numCols]; const transposeA = true; const transposeB = false; const intermediates = []; if (preluActivationWeights != null) { const targetShape = getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast); if (targetShape != null) { preluActivationWeights = reshape({ inputs: { x: preluActivationWeights }, backend, attrs: { shape: targetShape } }); intermediates.push(preluActivationWeights); } } if (bias != null) { const targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast); if (targetShape != null) { bias = reshape({ inputs: { x: bias }, backend, attrs: { shape: targetShape } }); intermediates.push(bias); } } const w2Row = reshape({ inputs: { x: filter }, backend, attrs: { shape: [1, sharedDim, sizeFromShape(filter.shape) / sharedDim] } }); intermediates.push(w2Row); const im2ColProgram = new Im2ColPackedProgram(x2ColShape, convInfo); const customValues = [ x.shape, [convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inChannels], [convInfo.filterWidth * convInfo.inChannels], [convInfo.outWidth] ]; const 
im2Col = backend.runWebGLProgram(im2ColProgram, [x], 'float32', customValues); const im2ColReshaped = reshape({ inputs: { x: im2Col }, backend, attrs: { shape: x2ColShape } }); intermediates.push(im2Col); intermediates.push(im2ColReshaped); const hasBias = bias != null; const hasPreluActivationWeights = preluActivationWeights != null;
// A leaky-relu activation receives its alpha as an extra scalar input tensor, created below and disposed with the other intermediates.
const hasLeakyreluAlpha = activation === 'leakyrelu'; const fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null; const matmulProgram = new MatMulPackedProgram(isChannelsLast ? im2ColReshaped.shape : w2Row.shape, isChannelsLast ? w2Row.shape : im2ColReshaped.shape, isChannelsLast ? [convInfo.batchSize, numCols, convInfo.outChannels] : [convInfo.batchSize, convInfo.outChannels, numCols], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha); const inputs = isChannelsLast ? [im2ColReshaped, w2Row] : [w2Row, im2ColReshaped]; if (bias) { inputs.push(bias); } if (hasPreluActivationWeights) { inputs.push(preluActivationWeights); } if (hasLeakyreluAlpha) { const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32')); inputs.push($leakyreluAlpha); intermediates.push($leakyreluAlpha); } const product = backend.runWebGLProgram(matmulProgram, inputs, 'float32'); const out = reshape({ inputs: { x: product }, backend, attrs: { shape: convInfo.outShape } }); intermediates.push(product); for (const i of intermediates) { backend.disposeIntermediateTensorInfo(i); } return out; }
/**
 * WebGL kernel for Conv2D. Picks the first matching implementation:
 *  1. 1x1 filter, unit dilation and stride, SAME/VALID padding: conv2dByMatMul;
 *  2. strideWidth <= 2, channelsLast and WEBGL_EXP_CONV: Conv2DPackedProgram;
 *  3. WEBGL_CONV_IM2COL: conv2dWithIm2Row;
 *  4. otherwise: the unpacked Conv2DProgram.
 * The chosen result is reshaped to convInfo.outShape and the pre-reshape
 * tensor is disposed.
 */
function conv2d(args) { const { inputs, backend, attrs } = args; const { x, filter } = inputs; const { strides, pad, dataFormat, dilations, dimRoundingMode } = attrs; const $dataFormat = convertConv2DDataFormat(dataFormat); const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false , $dataFormat); let out; if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && convInfo.dilationHeight === 1 && 
convInfo.dilationWidth === 1 && convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) { out = conv2dByMatMul({ x, filter, convInfo, backend }); } else if (convInfo.strideWidth <= 2 && $dataFormat === 'channelsLast' && env().getBool('WEBGL_EXP_CONV')) { const program = new Conv2DPackedProgram(convInfo); const customValues = [ [convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inHeight, convInfo.inWidth] ]; out = backend.runWebGLProgram(program, [x, filter], 'float32', customValues); } else if (env().getBool('WEBGL_CONV_IM2COL')) { out = conv2dWithIm2Row({ x, filter, convInfo, backend }); } else { const program = new Conv2DProgram(convInfo); out = backend.runWebGLProgram(program, [x, filter], 'float32'); } const outReshaped = reshape({ inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } }); backend.disposeIntermediateTensorInfo(out); return outReshaped; } const conv2DConfig = { kernelName: Conv2D, backendName: 'webgl', kernelFunc: conv2d, };
/**
 * Unpacked WebGL program for the gradient of conv2d with respect to the filter.
 * The output has shape convInfo.filterShape. Element (wR, wC, d1, d2) sums
 * x * dy over the batch and all output positions (yR, yC), where the input
 * position is w + y * stride - pad. Input rows and columns out of range are
 * skipped. Both channelsLast and channelsFirst layouts of x and dy are
 * supported. No dilation term is applied; conv2DBackpropFilter builds convInfo
 * with dilations = 1.
 */
class Conv2DDerFilterProgram { constructor(convInfo) { this.variableNames = ['x', 'dy']; this.outputShape = convInfo.filterShape; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const padTop = convInfo.padInfo.top; const padLeft = convInfo.padInfo.left; const isChannelsLast = convInfo.dataFormat === 'channelsLast'; this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int wR = coords.x; int wC = coords.y; int d1 = coords.z; int d2 = coords.w; float dotProd = 0.0; for (int b = 0; b < ${convInfo.batchSize}; b++) { for (int yR = 0; yR < ${convInfo.outHeight}; yR++) { int xR = wR + yR * ${strideHeight} - ${padTop}; if (xR < 0 || xR >= ${convInfo.inHeight}) { continue; } for (int yC = 0; yC < ${convInfo.outWidth}; yC++) { int xC = wC + yC * ${strideWidth} - ${padLeft}; if (xC < 0 || xC >= 
${convInfo.inWidth}) { continue; } ${isChannelsLast ? `float dyValue = getDy(b, yR, yC, d2); float xValue = getX(b, xR, xC, d1); dotProd += (xValue * dyValue);` : `float dyValue = getDy(b, d2, yR, yC); float xValue = getX(b, d1, xR, xC); dotProd += (xValue * dyValue);`} } } } setOutput(dotProd); } `; } }
/**
 * Unpacked WebGL program for the gradient of conv2d with respect to its input.
 * The output has shape convInfo.inShape. Each input element accumulates
 * dy * W over the filter taps, using the flipped filter (wRPerm / wCPerm) and
 * padding filterSize - 1 - pad. A tap is skipped when (corner + w) / stride is
 * negative, beyond the output size, or not an integer. Both channelsLast and
 * channelsFirst dy are supported.
 */
class Conv2DDerInputProgram { constructor(convInfo) { this.variableNames = ['dy', 'W']; this.outputShape = convInfo.inShape; const filterHeight = convInfo.filterHeight; const filterWidth = convInfo.filterWidth; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const isChannelsLast = convInfo.dataFormat === 'channelsLast'; const padTop = filterHeight - 1 - convInfo.padInfo.top; const padLeft = filterWidth - 1 - convInfo.padInfo.left; const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 2 : 3; const channelDim = isChannelsLast ? 3 : 1; this.userCode = ` const ivec2 pads = ivec2(${padTop}, ${padLeft}); void main() { ivec4 coords = getOutputCoords(); int batch = coords[0]; int d1 = coords[${channelDim}]; ivec2 dyCorner = ivec2(coords[${rowDim}], coords[${colDim}]) - pads; int dyRCorner = dyCorner.x; int dyCCorner = dyCorner.y; float dotProd = 0.0; for (int wR = 0; wR < ${filterHeight}; wR++) { float dyR = float(dyRCorner + wR) / ${strideHeight}.0; if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) { continue; } int idyR = int(dyR); int wRPerm = ${filterHeight} - 1 - wR; for (int wC = 0; wC < ${filterWidth}; wC++) { float dyC = float(dyCCorner + wC) / ${strideWidth}.0; if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 || fract(dyC) > 0.0) { continue; } int idyC = int(dyC); int wCPerm = ${filterWidth} - 1 - wC; for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) { if (${isChannelsLast}) { float xValue = getDy(batch, idyR, idyC, d2); float wValue = getW(wRPerm, wCPerm, d1, d2); dotProd += xValue * wValue; } else { float xValue = getDy(batch, d2, idyR, idyC); float wValue = getW(wRPerm, wCPerm, d1, d2); 
dotProd += xValue * wValue; } } } } setOutput(dotProd); } `; } }
/**
 * Unpacked WebGL program for the gradient of conv3d with respect to the filter.
 * The output has shape convInfo.filterShape. Element (wF, wR, wC, d1, d2) sums
 * x * dy over the batch and all output positions. Input positions
 * (w + y * stride - pad) outside the input volume are skipped. No dilation
 * term is applied.
 */
class Conv3DDerFilterProgram { constructor(convInfo) { this.variableNames = ['x', 'dy']; this.outputShape = convInfo.filterShape; const strideDepth = convInfo.strideDepth; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const padFront = convInfo.padInfo.front; const padTop = convInfo.padInfo.top; const padLeft = convInfo.padInfo.left; this.userCode = ` void main() { ivec5 coords = getOutputCoords(); int wF = coords.x; int wR = coords.y; int wC = coords.z; int d1 = coords.w; int d2 = coords.u; float dotProd = 0.0; for (int b = 0; b < ${convInfo.batchSize}; b++) { for (int yF = 0; yF < ${convInfo.outDepth}; yF++) { int xF = wF + yF * ${strideDepth} - ${padFront}; if (xF < 0 || xF >= ${convInfo.inDepth}) { continue; } for (int yR = 0; yR < ${convInfo.outHeight}; yR++) { int xR = wR + yR * ${strideHeight} - ${padTop}; if (xR < 0 || xR >= ${convInfo.inHeight}) { continue; } for (int yC = 0; yC < ${convInfo.outWidth}; yC++) { int xC = wC + yC * ${strideWidth} - ${padLeft}; if (xC < 0 || xC >= ${convInfo.inWidth}) { continue; } float dyValue = getDy(b, yF, yR, yC, d2); float xValue = getX(b, xF, xR, xC, d1); dotProd += (xValue * dyValue); } } } } setOutput(dotProd); } `; } }
/**
 * Unpacked WebGL program for the gradient of conv3d with respect to its input.
 * The output has shape convInfo.inShape. It uses the same flipped-filter /
 * "filterSize - 1 - pad" scheme as Conv2DDerInputProgram, extended to the
 * depth axis. There is no dataFormat switch: dy is read as
 * (batch, depth, row, col, channel).
 */
class Conv3DDerInputProgram { constructor(convInfo) { this.variableNames = ['dy', 'W']; this.outputShape = convInfo.inShape; const filterDepth = convInfo.filterDepth; const filterHeight = convInfo.filterHeight; const filterWidth = convInfo.filterWidth; const strideDepth = convInfo.strideDepth; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const padFront = filterDepth - 1 - convInfo.padInfo.front; const padTop = filterHeight - 1 - convInfo.padInfo.top; const padLeft = filterWidth - 1 - convInfo.padInfo.left; this.userCode = ` const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft}); void main() { ivec5 coords = getOutputCoords(); int batch = coords.x; int d1 
= coords.u; ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads; int dyFCorner = dyCorner.x; int dyRCorner = dyCorner.y; int dyCCorner = dyCorner.z; float dotProd = 0.0; for (int wF = 0; wF < ${filterDepth}; wF++) { float dyF = float(dyFCorner + wF) / ${strideDepth}.0; if (dyF < 0.0 || dyF >= ${convInfo.outDepth}.0 || fract(dyF) > 0.0) { continue; } int idyF = int(dyF); int wFPerm = ${filterDepth} - 1 - wF; for (int wR = 0; wR < ${filterHeight}; wR++) { float dyR = float(dyRCorner + wR) / ${strideHeight}.0; if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) { continue; } int idyR = int(dyR); int wRPerm = ${filterHeight} - 1 - wR; for (int wC = 0; wC < ${filterWidth}; wC++) { float dyC = float(dyCCorner + wC) / ${strideWidth}.0; if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 || fract(dyC) > 0.0) { continue; } int idyC = int(dyC); int wCPerm = ${filterWidth} - 1 - wC; for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) { float xValue = getDy(batch, idyF, idyR, idyC, d2); float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2); dotProd += xValue * wValue; } } } } setOutput(dotProd); } `; } }
/**
 * WebGL kernel for Conv2DBackpropFilter. It builds convInfo with dilations
 * fixed to 1 from x.shape and the requested filterShape, then runs the
 * unpacked Conv2DDerFilterProgram on [x, dy].
 */
function conv2DBackpropFilter(args) { const { inputs, backend, attrs } = args; const { x, dy } = inputs; const { strides, pad, dataFormat, dimRoundingMode, filterShape } = attrs; const $dataFormat = convertConv2DDataFormat(dataFormat); const convInfo = computeConv2DInfo(x.shape, filterShape, strides, 1 , pad, dimRoundingMode, false , $dataFormat); const program = new Conv2DDerFilterProgram(convInfo); return backend.runWebGLProgram(program, [x, dy], 'float32'); } const conv2DBackpropFilterConfig = { kernelName: Conv2DBackpropFilter, backendName: 'webgl', kernelFunc: conv2DBackpropFilter, };
/**
 * Packed variant of Conv2DDerInputProgram. It reads coordinates in the
 * channelsLast order (batch, row, col, channel), and conv2DBackpropInput only
 * selects it for channelsLast data. The stride is a `vec2` uniform. Each
 * output texel accumulates two adjacent input columns: result.xy for column
 * dyCCorner + wC and result.zw for column dyCCorner + wC + 1. Output channels
 * d2 are looped two at a time.
 */
class Conv2DDerInputPackedProgram { constructor(convInfo) { this.variableNames = ['dy', 'W']; this.packedInputs = true; this.packedOutput = true; this.customUniforms = [ { name: 'strides', type: 'vec2' }, ]; this.outputShape = convInfo.inShape; this.enableShapeUniforms = 
useShapeUniforms(this.outputShape.length); const filterHeight = convInfo.filterHeight; const filterWidth = convInfo.filterWidth; const padTop = filterHeight - 1 - convInfo.padInfo.top; const padLeft = filterWidth - 1 - convInfo.padInfo.left;
// dy is packed two columns per texel: mod(idyC, 2) selects the .xy or .zw half, and dySample is reused when both columns fall in the same texel (idyC / 2 == idyC2 / 2).
this.userCode = ` const ivec2 pads = ivec2(${padTop}, ${padLeft}); void main() { ivec4 coords = getOutputCoords(); int batch = coords[0]; int d1 = coords[3]; ivec2 dyCorner = ivec2(coords[1], coords[2]) - pads; int dyRCorner = dyCorner.x; int dyCCorner = dyCorner.y; vec4 result = vec4(0.); for (int wR = 0; wR < ${filterHeight}; wR++) { float dyR = float(dyRCorner + wR) / strides[0]; if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) { continue; } int idyR = int(dyR); int wRPerm = ${filterHeight} - 1 - wR; for (int wC = 0; wC < ${filterWidth}; wC++) { int wCPerm = ${filterWidth} - 1 - wC; float dyC = float(dyCCorner + wC) / strides[1]; bool idyCVal = (dyC >= 0.0) && (dyC < ${convInfo.outWidth}.0) && (fract(dyC) == 0.0); int idyC = int(dyC); float dyC2 = float(dyCCorner + wC + 1) / strides[1]; bool idyCVal2 = (dyC2 >= 0.0) && (dyC2 < ${convInfo.outWidth}.0) && (fract(dyC2) == 0.0); int idyC2 = int(dyC2); if (idyCVal && idyCVal2) { for (int d2 = 0; d2 < ${convInfo.outChannels}; d2 += 2) { vec4 wValue = getW(wRPerm, wCPerm, d1, d2); vec4 dySample = getDy(batch, idyR, idyC, d2); vec4 dySample2 = (idyC / 2 == idyC2 / 2) ? dySample : getDy(batch, idyR, idyC2, d2); vec2 dyValue = mod(float(idyC), 2.) == 0. ? dySample.xy : dySample.zw; result.xy += vec2(dot(dyValue, wValue.xy), dot(dyValue, wValue.zw)); dyValue = mod(float(idyC2), 2.) == 0. ? dySample2.xy : dySample2.zw; result.zw += vec2(dot(dyValue, wValue.xy), dot(dyValue, wValue.zw)); } } else if (idyCVal) { for (int d2 = 0; d2 < ${convInfo.outChannels}; d2 += 2) { vec4 wValue = getW(wRPerm, wCPerm, d1, d2); vec4 dySample = getDy(batch, idyR, idyC, d2); vec2 dyValue = mod(float(idyC), 2.) == 0. ? 
dySample.xy : dySample.zw; result.xy += vec2(dot(dyValue, wValue.xy), dot(dyValue, wValue.zw)); } } else if (idyCVal2) { for (int d2 = 0; d2 < ${convInfo.outChannels}; d2 += 2) { vec4 wValue = getW(wRPerm, wCPerm, d1, d2); vec4 dySample = getDy(batch, idyR, idyC2, d2); vec2 dyValue = mod(float(idyC2), 2.) == 0. ? dySample.xy : dySample.zw; result.zw += vec2(dot(dyValue, wValue.xy), dot(dyValue, wValue.zw)); } } } } setOutput(result); } `; } }
/**
 * WebGL kernel for Conv2DBackpropInput (dilations fixed to 1). It uses the
 * packed program, with the strides passed as a uniform, when
 * WEBGL_PACK_CONV2DTRANSPOSE is enabled and the data is channelsLast.
 * Otherwise it uses the unpacked one.
 */
function conv2DBackpropInput(args) { const { inputs, backend, attrs } = args; const { dy, filter } = inputs; const { inputShape, strides, pad, dataFormat, dimRoundingMode } = attrs; const $dataFormat = convertConv2DDataFormat(dataFormat); const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, 1 , pad, dimRoundingMode, false, $dataFormat); if (env().getBool('WEBGL_PACK_CONV2DTRANSPOSE') && $dataFormat === 'channelsLast') { const customValues = [ [convInfo.strideHeight, convInfo.strideWidth], ]; const program = new Conv2DDerInputPackedProgram(convInfo); return backend.runWebGLProgram(program, [dy, filter], 'float32', customValues); } else { const program = new Conv2DDerInputProgram(convInfo); return backend.runWebGLProgram(program, [dy, filter], 'float32'); } } const conv2DBackpropInputConfig = { kernelName: Conv2DBackpropInput, backendName: 'webgl', kernelFunc: conv2DBackpropInput, };
/** WebGL kernel for Conv3D, using the unpacked Conv3DProgram. */
function conv3D(args) { const { inputs, backend, attrs } = args; const { x, filter } = inputs; const { strides, pad, dilations } = attrs; const convInfo = computeConv3DInfo(x.shape, filter.shape, strides, dilations, pad); const program = new Conv3DProgram(convInfo); return backend.runWebGLProgram(program, [x, filter], 'float32'); } const conv3DConfig = { kernelName: Conv3D, backendName: 'webgl', kernelFunc: conv3D, };
/** WebGL kernel for Conv3DBackpropFilterV2 (dilations fixed to 1), using Conv3DDerFilterProgram. */
function conv3DBackpropFilterV2(args) { const { inputs, backend, attrs } = args; const { x, dy } = inputs; const { strides, pad, filterShape } = attrs; const convInfo = computeConv3DInfo(x.shape, filterShape, strides, 
1 , pad); const program = new Conv3DDerFilterProgram(convInfo); return backend.runWebGLProgram(program, [x, dy], 'float32'); } const conv3DBackpropFilterV2Config = { kernelName: Conv3DBackpropFilterV2, backendName: 'webgl', kernelFunc: conv3DBackpropFilterV2 };
/** WebGL kernel for Conv3DBackpropInputV2 (dilations fixed to 1), using Conv3DDerInputProgram. */
function conv3DBackpropInput(args) { const { inputs, backend, attrs } = args; const { dy, filter } = inputs; const { pad, strides, inputShape } = attrs; const convInfo = computeConv3DInfo(inputShape, filter.shape, strides, 1 , pad); const program = new Conv3DDerInputProgram(convInfo); return backend.runWebGLProgram(program, [dy, filter], 'float32'); } const conv3DBackpropInputConfig = { kernelName: Conv3DBackpropInputV2, backendName: 'webgl', kernelFunc: conv3DBackpropInput, };
// Element-wise cos. The unpacked snippet is prefixed with CHECK_NAN_SNIPPET_UNARY. The packed one computes an `isNaN` mask and splices in CHECK_NAN_SNIPPET_PACKED (presumably so NaN inputs stay NaN - those snippets are defined elsewhere).
const COS = CHECK_NAN_SNIPPET_UNARY + ` return cos(x); `; const COS_PACKED = ` vec4 result = cos(x); bvec4 isNaN = isnan(x); ${CHECK_NAN_SNIPPET_PACKED} return result; `; const cos = unaryKernelFunc({ opSnippet: COS, packedOpSnippet: COS_PACKED }); const cosConfig = { kernelName: Cos, backendName: 'webgl', kernelFunc: cos, };
// Element-wise cosh(x) = (e^x + e^-x) / 2, unpacked only. Despite its name, `e2x` holds exp(-x), and 1.0 / e2x is exp(x).
const COSH = ` float e2x = exp(-x); return (e2x + 1.0 / e2x) / 2.0; `; const cosh = unaryKernelFunc({ opSnippet: COSH }); const coshConfig = { kernelName: Cosh, backendName: 'webgl', kernelFunc: cosh, };
/**
 * WebGL program for CropAndResize. The output has shape
 * [numBoxes, cropHeight, cropWidth, depth]. For each output pixel, the box
 * corners (y1, x1, y2, x2) are read from `Boxes` and scaled by
 * (imageHeight - 1) / (imageWidth - 1) to a fractional source position. With
 * a crop size of 1 along an axis, the box centre is sampled. A source position
 * outside [0, imageSize - 1] yields `extrapolationValue`. method === 'bilinear'
 * interpolates the four neighbouring pixels; any other method takes the
 * nearest pixel.
 * NOTE(review): when the box index is outside [0, batch) the shader returns
 * without calling setOutput - confirm what value that leaves in the output.
 */
class CropAndResizeProgram { constructor(imageShape, boxShape, cropSize, method, extrapolationValue) { this.variableNames = ['Image', 'Boxes', 'BoxInd']; this.outputShape = []; const [batch, imageHeight, imageWidth, depth] = imageShape; const [numBoxes,] = boxShape; const [cropHeight, cropWidth] = cropSize; this.outputShape = [numBoxes, cropHeight, cropWidth, depth]; const methodId = method === 'bilinear' ? 1 : 0; const [inputHeightFloat, inputWidthFloat] = [`${imageHeight - 1}.0`, `${imageWidth - 1}.0`]; const [heightRatio, heightScale, inY] = cropHeight > 1 ? 
[ `${(imageHeight - 1) / (cropHeight - 1)}`, '(y2-y1) * height_ratio', `y1*${inputHeightFloat} + float(y)*(height_scale)`, ] : [ '0.0', '0.0', `0.5 * (y1+y2) * ${inputHeightFloat}`, ]; const [widthRatio, widthScale, inX] = cropWidth > 1 ? [ `${(imageWidth - 1) / (cropWidth - 1)}`, '(x2-x1) * width_ratio', `x1*${inputWidthFloat} + float(x)*(width_scale)`, ] : [ '0.0', '0.0', `0.5 * (x1+x2) * ${inputWidthFloat}`, ]; this.userCode = ` const float height_ratio = float(${heightRatio}); const float width_ratio = float(${widthRatio}); void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int y = coords[1]; int x = coords[2]; int d = coords[3]; float y1 = getBoxes(b,0); float x1 = getBoxes(b,1); float y2 = getBoxes(b,2); float x2 = getBoxes(b,3); int bInd = round(getBoxInd(b)); if(bInd < 0 || bInd >= ${batch}) { return; } float height_scale = ${heightScale}; float width_scale = ${widthScale}; float in_y = ${inY}; if( in_y < 0.0 || in_y > ${inputHeightFloat} ) { setOutput(float(${extrapolationValue})); return; } float in_x = ${inX}; if( in_x < 0.0 || in_x > ${inputWidthFloat} ) { setOutput(float(${extrapolationValue})); return; } vec2 sourceFracIndexCR = vec2(in_x,in_y); if(${methodId} == 1) { ivec2 sourceFloorCR = ivec2(sourceFracIndexCR); ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR)); float topLeft = getImage(b, sourceFloorCR.y, sourceFloorCR.x, d); float bottomLeft = getImage(b, sourceCeilCR.y, sourceFloorCR.x, d); float topRight = getImage(b, sourceFloorCR.y, sourceCeilCR.x, d); float bottomRight = getImage(b, sourceCeilCR.y, sourceCeilCR.x, d); vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR); float top = topLeft + (topRight - topLeft) * fracCR.x; float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x; float newValue = top + (bottom - top) * fracCR.y; setOutput(newValue); } else { ivec2 sourceNearestCR = ivec2(floor( sourceFracIndexCR + vec2(0.5,0.5))); float newValue = getImage(b, sourceNearestCR.y, sourceNearestCR.x, d); 
setOutput(newValue); } } `; } }
/** WebGL kernel for CropAndResize: runs CropAndResizeProgram on [image, boxes, boxInd]. */
const cropAndResize = (args) => { const { inputs, backend, attrs } = args; const { image, boxes, boxInd } = inputs; const { cropSize, method, extrapolationValue } = attrs; const program = new CropAndResizeProgram(image.shape, boxes.shape, cropSize, method, extrapolationValue); return backend.runWebGLProgram(program, [image, boxes, boxInd], 'float32'); }; const cropAndResizeConfig = { kernelName: CropAndResize, backendName: 'webgl', kernelFunc: cropAndResize };
// GLSL operator used by CumProgram: '*' for cumprod and '+' for cumsum. It is spliced into `val <op>= ...` in the shader.
var CumOpType; (function (CumOpType) { CumOpType["Prod"] = "*"; CumOpType["Sum"] = "+"; })(CumOpType || (CumOpType = {}));
/**
 * One pass of the cumulative sum/product shader along the innermost axis
 * (ranks 1-4). The pass number arrives as the float uniform `index`.
 *
 * Inclusive pass: each element combines itself with the element 2^index
 * positions earlier, or later when `reverse`, if that element exists.
 * Exclusive pass: the identity value (1.0 for product, 0.0 for sum) is
 * combined with the neighbour at end - 1 (or end + 1); the first (or last)
 * position keeps the identity value.
 */
class CumProgram { constructor(op, outputShape, exclusive, reverse) { this.op = op; this.outputShape = outputShape; this.variableNames = ['x']; this.customUniforms = [{ name: 'index', type: 'float' }]; const rank = this.outputShape.length; const initVal = this.op === CumOpType.Prod ? '1.0' : '0.0'; const val = exclusive ? initVal : `getX(${getCoords(rank, 'coords', this.op)})`; const length = this.outputShape[this.outputShape.length - 1]; let condition = ''; let idxString = ''; if (exclusive) { condition = reverse ? `end != ${length - 1}` : 'end != 0'; idxString = reverse ? 'end + 1' : 'end - 1'; } else { condition = reverse ? `end + pow2 < ${length}` : 'end >= pow2'; idxString = (reverse ? 
'end + pow2' : 'end - pow2'); } this.userCode = ` void main() { ${getCoordsDataType(rank)} coords = getOutputCoords(); int end = ${getFinalCoord(rank, 'coords', this.op)}; float val = ${val}; int pow2 = int(pow(2.0, index)); if (${condition}) { int idx = ${idxString}; ${getFinalCoord(rank, 'coords', this.op)} = idx; val ${this.op}= getX(${getCoords(rank, 'coords', this.op)}); } setOutput(val); } `; } } function getCoords(rank, name, op) { if (rank === 1) { return `${name}`; } else if (rank === 2) { return `${name}.x, ${name}.y`; } else if (rank === 3) { return `${name}.x, ${name}.y, ${name}.z`; } else if (rank === 4) { return `${name}.x, ${name}.y, ${name}.z, ${name}.w`; } else { throw new Error(`Cumulative ${op} for rank ${rank} is not yet supported`); } } function getFinalCoord(rank, name, op) { if (rank === 1) { return `${name}`; } else if (rank === 2) { return `${name}.y`; } else if (rank === 3) { return `${name}.z`; } else if (rank === 4) { return `${name}.w`; } else { throw new Error(`Cumulative ${op} for rank ${rank} is not yet supported`); } } function cumImpl(op, x, backend, axis, exclusive, reverse) { const xRank = x.shape.length; const permutation = getAxesPermutation([axis], xRank); let permutedX = x; if (permutation != null) { permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutation } }); } const permutedAxis = getInnerMostAxes(1, xRank)[0]; if (permutedAxis !== xRank - 1) { throw new Error(`WebGL cumprod shader expects an inner-most axis=${x.shape.length - 1} ` + `but got axis=${axis}`); } const size = permutedX.shape[permutedAxis]; let result = identity({ inputs: { x: permutedX }, backend }); for (let i = 0; i <= Math.ceil(Math.log2(size)) - 1; i++) { const program = new CumProgram(op, permutedX.shape, false, reverse); const customValues = [[i]]; const prevResult = result; result = backend.runWebGLProgram(program, [result], result.dtype, customValues); backend.disposeIntermediateTensorInfo(prevResult); } if (exclusive) { const 
program = new CumProgram(op, permutedX.shape, exclusive, reverse); const prevResult = result; result = backend.runWebGLProgram(program, [result], result.dtype); backend.disposeIntermediateTensorInfo(prevResult); } if (permutation != null) { const reversePermutation = getUndoAxesPermutation(permutation); const reverseTransposedResult = transpose({ inputs: { x: result }, backend, attrs: { perm: reversePermutation } }); backend.disposeIntermediateTensorInfo(result); backend.disposeIntermediateTensorInfo(permutedX); return reverseTransposedResult; } return result; } function cumprod(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { axis, exclusive, reverse } = attrs; return cumImpl(CumOpType.Prod, x, backend, axis, exclusive, reverse); } const cumprodConfig = { kernelName: Cumprod, backendName: 'webgl', kernelFunc: cumprod }; function cumsum(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { axis, exclusive, reverse } = attrs; return cumImpl(CumOpType.Sum, x, backend, axis, exclusive, reverse); } const cumsumConfig = { kernelName: Cumsum, backendName: 'webgl', kernelFunc: cumsum }; function denseBincount(args) { const { inputs, backend, attrs } = args; const { x, weights } = inputs; const { size, binaryOutput } = attrs; if (x.shape.length === 1) { const xVals = backend.readSync(x.dataId); const weightsVals = backend.readSync(weights.dataId); const outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size); return backend.makeTensorInfo([size], weights.dtype, outVals); } else if (x.shape.length === 2) { const xBuf = backend.bufferSync(x); const weightsBuf = backend.bufferSync(weights); const outBuf = bincountReduceImplCPU(xBuf, weightsBuf, size, binaryOutput); return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values); } throw new Error(`Error in denseBincount: input must be at most rank 2, but got rank` + `${x.shape.length}.`); } const denseBincountConfig = { kernelName: 
DenseBincount, backendName: 'webgl', kernelFunc: denseBincount }; class DepthToSpaceProgram { constructor(outputShape, blockSize, dataFormat) { this.variableNames = ['x']; this.outputShape = []; this.outputShape = outputShape; this.blockSize = blockSize; this.dataFormat = dataFormat; this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int h = ${this.getHeightCoordString()}; int w = ${this.getWidthCoordString()}; int d = ${this.getDepthCoordString()}; int in_h = h / ${blockSize}; int offset_h = imod(h, ${blockSize}); int in_w = w / ${blockSize}; int offset_w = imod(w, ${blockSize}); int offset_d = (offset_h * ${blockSize} + offset_w) * ${this.getOutputDepthSize()}; int in_d = d + offset_d; float result = ${this.getInputSamplingString()}; setOutput(result); } `; } getHeightCoordString() { if (this.dataFormat === 'NHWC') { return `coords[1]`; } else { return `coords[2]`; } } getWidthCoordString() { if (this.dataFormat === 'NHWC') { return `coords[2]`; } else { return `coords[3]`; } } getDepthCoordString() { if (this.dataFormat === 'NHWC') { return `coords[3]`; } else { return `coords[1]`; } } getOutputDepthSize() { if (this.dataFormat === 'NHWC') { return this.outputShape[3]; } else { return this.outputShape[1]; } } getInputSamplingString() { if (this.dataFormat === 'NHWC') { return `getX(b, in_h, in_w, in_d)`; } else { return `getX(b, in_d, in_h, in_w)`; } } } function depthToSpace(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { blockSize, dataFormat } = attrs; const batchSize = x.shape[0]; const inputHeight = (dataFormat === 'NHWC') ? x.shape[1] : x.shape[2]; const inputWidth = (dataFormat === 'NHWC') ? x.shape[2] : x.shape[3]; const inputDepth = (dataFormat === 'NHWC') ? x.shape[3] : x.shape[1]; const outputHeight = inputHeight * blockSize; const outputWidth = inputWidth * blockSize; const outputDepth = inputDepth / (blockSize * blockSize); const outputShape = (dataFormat === 'NHWC') ? 
[batchSize, outputHeight, outputWidth, outputDepth] : [batchSize, outputDepth, outputHeight, outputWidth]; const program = new DepthToSpaceProgram(outputShape, blockSize, dataFormat); return backend.runWebGLProgram(program, [x], x.dtype); } const depthToSpaceConfig = { kernelName: DepthToSpace, backendName: 'webgl', kernelFunc: depthToSpace }; class DepthwiseConv2DProgram { constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) { this.variableNames = ['x', 'W']; this.customUniforms = [ { name: 'pads', type: 'ivec2' }, { name: 'strides', type: 'ivec2' }, { name: 'dilations', type: 'ivec2' }, { name: 'inDims', type: 'ivec2' }, ]; this.outputShape = convInfo.outShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); const filterHeight = convInfo.filterHeight; const filterWidth = convInfo.filterWidth; const channelMul = convInfo.outChannels / convInfo.inChannels; let activationSnippet = '', applyActivationSnippet = ''; if (activation) { if (hasPreluActivation) { activationSnippet = `float activation(float a) { float b = getPreluActivationWeightsAtOutCoords(); ${activation} }`; } else if (hasLeakyReluAlpha) { activationSnippet = `float activation(float a) { float b = getLeakyreluAlphaAtOutCoords(); ${activation} }`; } else { activationSnippet = ` float activation(float x) { ${activation} } `; } applyActivationSnippet = `result = activation(result);`; } const addBiasSnippet = addBias ? 
'result += getBiasAtOutCoords();' : ''; if (addBias) { this.variableNames.push('bias'); } if (hasPreluActivation) { this.variableNames.push('preluActivationWeights'); } if (hasLeakyReluAlpha) { this.variableNames.push('leakyreluAlpha'); } this.userCode = ` ${activationSnippet} void main() { ivec4 coords = getOutputCoords(); int batch = coords.x; ivec2 xRCCorner = coords.yz * strides - pads; int d2 = coords.w; int d1 = d2 / ${channelMul}; int q = d2 - d1 * ${channelMul}; int xRCorner = xRCCorner.x; int xCCorner = xRCCorner.y; float dotProd = 0.0; for (int wR = 0; wR < ${filterHeight}; wR++) { int xR = xRCorner + wR * dilations[0]; if (xR < 0 || xR >= inDims[0]) { continue; } for (int wC = 0; wC < ${filterWidth}; wC++) { int xC = xCCorner + wC * dilations[1]; if (xC < 0 || xC >= inDims[1]) { continue; } float xVal = getX(batch, xR, xC, d1); float wVal = getW(wR, wC, d1, q); dotProd += xVal * wVal; } } float result = dotProd; ${addBiasSnippet} ${applyActivationSnippet} setOutput(result); } `; } }
/**
 * Packed (vec4 texel) depthwise conv2d program with optional fused bias /
 * activation. Both inputs and output are packed (packedInputs/packedOutput).
 *
 * The shader's main loop is generated in JS, unrolled over the filter
 * columns. depthwiseConv2dNative only selects this program when:
 * - WEBGL_PACK_DEPTHWISECONV is enabled,
 * - strideWidth <= 2, and
 * - the channel multiplier is 1.
 *
 * pads, strides, dilations and inDims arrive as custom uniforms.
 */
class DepthwiseConvPacked2DProgram { constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) { this.variableNames = ['x', 'W']; this.packedInputs = true; this.packedOutput = true; this.customUniforms = [ { name: 'pads', type: 'ivec2' }, { name: 'strides', type: 'ivec2' }, { name: 'dilations', type: 'ivec2' }, { name: 'inDims', type: 'ivec2' }, ]; this.outputShape = convInfo.outShape; this.enableShapeUniforms = useShapeUniforms(this.outputShape.length); const channelMul = convInfo.outChannels / convInfo.inChannels; const padLeft = convInfo.padInfo.left; const strideWidth = convInfo.strideWidth; const dilationWidth = convInfo.dilationWidth; const filterHeight = convInfo.filterHeight; const filterWidth = convInfo.filterWidth; const texelsAcross = filterWidth;
// mainLoop accumulates GLSL source. First come the declarations: for each
// filter column c, two cached input texels (xTexelC{2c}, xTexelC{2c+1}), each
// with a "Ready" flag, plus the assembled operand xC{c}.
let mainLoop = ` int xR; int xC; int xCOffset; vec4 wTexel; vec4 previous; vec4 final;`; for (let c = 0; c < filterWidth; c++) { mainLoop += ` vec4 xTexelC${c * 2};
int xTexelC${c * 2}Ready; vec4 xTexelC${c * 2 + 1}; int xTexelC${c * 2 + 1}Ready; vec4 xC${c};`; }
// For each filter row r: reset all caches, compute the dilated input row xR,
// and skip rows that fall outside the input.
mainLoop += ` for (int r = 0; r < ${filterHeight}; r++) { `; for (let c = 0; c < filterWidth; c++) { mainLoop += ` xTexelC${c * 2} = vec4(0.0); xTexelC${c * 2}Ready = 0; xTexelC${c * 2 + 1} = vec4(0.0); xTexelC${c * 2 + 1}Ready = 0; xC${c} = vec4(0.0);`; } mainLoop += ` xR = xRCorner + r * dilations[0]; if (xR >=0 && xR < inDims[0]) { `;
// Filter columns are handled in pairs (colIndex, colIndex + 1).
for (let texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) { const colIndex = texelC * 2; mainLoop += ` xC = xCCorner + ${colIndex * dilationWidth}; `;
// Separate code paths build xC{colIndex} / xC{colIndex + 1} for strideWidth
// === 1 versus larger strides, and for odd versus even padLeft.
// Whenever column xCOffset + 1 (or xC + 1) is out of bounds, a fetched
// texel's .zw half is zeroed. This suggests .zw holds the next input column —
// confirm against the packed layout.
if (strideWidth === 1) { if (colIndex < filterWidth) { if (padLeft % 2 === 1) { mainLoop += ` xCOffset = xC + 1; if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) { xTexelC${colIndex} = getX(batch, xR, xCOffset, d1); if (xCOffset + 1 >= inDims[1]) { xTexelC${colIndex}.zw = vec2(0.0); } xTexelC${colIndex}Ready = 1; } `; if (dilationWidth === 1 && colIndex > 0) { mainLoop += ` xC${colIndex} = vec4(xTexelC${colIndex - 2}.zw, xTexelC${colIndex}.xy); `; } else { mainLoop += ` xCOffset = xC + 1 - 2; if (xCOffset >= 0 && xCOffset < inDims[1]) { previous = getX(batch, xR, xCOffset, d1); if (xCOffset + 1 >= inDims[1]) { previous.zw = vec2(0.0); } xC${colIndex} = vec4(previous.zw, xTexelC${colIndex}.xy); } else { xC${colIndex} = vec4(0.0, 0.0, xTexelC${colIndex}.xy); } `; } } else { mainLoop += ` if (xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) { xTexelC${colIndex} = getX(batch, xR, xC, d1); if (xC + 1 >= inDims[1]) { xTexelC${colIndex}.zw = vec2(0.0); } xTexelC${colIndex}Ready = 1; } xC${colIndex} = xTexelC${colIndex}; `; } if (colIndex + 1 < filterWidth) { const nextTexelOffset = padLeft % 2 === 0 ?
nearestLargerEven(dilationWidth) : dilationWidth; if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) || (dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) { mainLoop += ` xCOffset = xC + imod(pads[1], 2) + ${nextTexelOffset}; if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) { xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1); if (xCOffset + 1 >= inDims[1]) { xTexelC${colIndex + 1}.zw = vec2(0.0); } xTexelC${colIndex + 1}Ready = 1; } `; if (dilationWidth > 1) { mainLoop += ` xCOffset -= 2; if (xCOffset >= 0 && xCOffset < inDims[1]) { previous = getX(batch, xR, xCOffset, d1); xC${colIndex + 1} = vec4(previous.zw, xTexelC${colIndex + 1}.xy); } else { xC${colIndex + 1} = vec4(0.0, 0.0, xTexelC${colIndex + 1}.xy); } `; } else { mainLoop += ` xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.xy); `; } } else { if (nextTexelOffset === 1) { mainLoop += ` xC${colIndex + 1} = xTexelC${colIndex}; `; } else { mainLoop += ` xCOffset = xC + ${nextTexelOffset}; if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) { xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1); if (xCOffset + 1 >= inDims[1]) { xTexelC${colIndex + 1}.zw = vec2(0.0); } xTexelC${colIndex + 1}Ready = 1; } xC${colIndex + 1} = xTexelC${colIndex + 1}; `; } } } } } else { if (colIndex < filterWidth) { if (padLeft % 2 === 1) { mainLoop += ` xCOffset = xC + 1 - strides[1]; if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) { xTexelC${colIndex} = getX(batch, xR, xCOffset, d1); if (xCOffset + 1 >= inDims[1]) { xTexelC${colIndex}.zw = vec2(0.0); } xTexelC${colIndex}Ready = 1; } if(xC + 1 >= 0 && xC + 1 < inDims[1] && xTexelC${colIndex + 1}Ready == 0) { xTexelC${colIndex + 1} = getX(batch, xR, xC + 1, d1); if (xC + 2 >= inDims[1]) { xTexelC${colIndex + 1}.zw = vec2(0.0); } xTexelC${colIndex + 1}Ready = 1; } xC${colIndex} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw); `; if (colIndex + 1 < filterWidth) {
mainLoop += ` final = vec4(0.0); xCOffset = xC + 1 + strides[1]; if(xCOffset >= 0 && xCOffset < inDims[1]) { final = getX(batch, xR, xCOffset, d1); } xC${colIndex + 1} = vec4(xTexelC${colIndex + 1}.xy, final.xy); `; } } else { mainLoop += ` if(xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) { xTexelC${colIndex} = getX(batch, xR, xC, d1); if (xC + 1 >= inDims[1]) { xTexelC${colIndex}.zw = vec2(0.0); } xTexelC${colIndex}Ready = 1; } xCOffset = xC + strides[1]; if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) { xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1); if (xCOffset + 1 >= inDims[1]) { xTexelC${colIndex + 1}.zw = vec2(0.); } xTexelC${colIndex + 1}Ready = 1; } xC${colIndex} = vec4( xTexelC${colIndex}.xy, xTexelC${colIndex + 1}.xy); `; if (colIndex + 1 < filterWidth) { mainLoop += ` xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw); `; } } } }
// Accumulate: each assembled operand xC{c} is multiplied by the filter texel's
// .xz components, duplicated to fill a vec4.
if (colIndex < filterWidth) { mainLoop += ` wTexel = getW(r, ${colIndex}, d1, q); dotProd += xC${colIndex} * vec4(wTexel.xz, wTexel.xz); `; if (colIndex + 1 < filterWidth) { mainLoop += ` wTexel = getW(r, ${colIndex + 1}, d1, q); dotProd += xC${colIndex + 1} * vec4(wTexel.xz, wTexel.xz); `; } } }
// Close the GLSL row-bounds check and the filter-row loop opened above.
mainLoop += ` } `; mainLoop += ` } `;
// Optional fused activation operating on vec4.
let activationSnippet = '', applyActivationSnippet = ''; if (activation) { if (hasPreluActivation) { activationSnippet = `vec4 activation(vec4 a) { vec4 b = getPreluActivationWeightsAtOutCoords(); ${activation} }`; } else if (hasLeakyReluAlpha) { activationSnippet = `vec4 activation(vec4 a) { vec4 b = getLeakyreluAlphaAtOutCoords(); ${activation} }`; } else { activationSnippet = `vec4 activation(vec4 x) { ${activation} }`; } applyActivationSnippet = `result = activation(result);`; } const addBiasSnippet = addBias ?
'result += getBiasAtOutCoords();' : ''; if (addBias) { this.variableNames.push('bias'); } if (hasPreluActivation) { this.variableNames.push('preluActivationWeights'); } if (hasLeakyReluAlpha) { this.variableNames.push('leakyreluAlpha'); }
// NOTE(review): dotProd starts at vec4(1e-15), and that offset is subtracted
// again before output. This is presumably a driver/precision workaround; its
// purpose is not evident from this code.
this.userCode = ` ${activationSnippet} void main() { ivec4 coords = getOutputCoords(); int batch = coords.x; ivec2 xRCCorner = coords.yz * strides - pads; int d2 = coords.w; int d1 = d2 / ${channelMul}; int q = d2 - d1 * ${channelMul}; int xRCorner = xRCCorner.x; int xCCorner = xRCCorner.y; vec4 dotProd = vec4(0.000000000000001); ${mainLoop} vec4 result = dotProd - vec4(0.000000000000001); ${addBiasSnippet} ${applyActivationSnippet} setOutput(result); } `; } }
/**
 * WebGL kernel for DepthwiseConv2dNative.
 *
 * Dilations default to [1, 1], and either strides or dilations must be 1.
 * Uses the packed program when WEBGL_PACK_DEPTHWISECONV is enabled,
 * strideWidth <= 2 and the channel multiplier is 1; otherwise uses the
 * unpacked program. Produces a float32 tensor.
 *
 * @throws {Error} Via assert$1 when both strides and dilations differ from 1.
 */
function depthwiseConv2dNative(args) { const { inputs, backend, attrs } = args; const { x, filter } = inputs; const { strides, pad, dilations, dimRoundingMode } = attrs; let $dilations = dilations; if ($dilations == null) { $dilations = [1, 1]; } assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' + `1.
Got strides ${strides} and dilations '${$dilations}'`); const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true );
// Program selection; both programs declare the same custom uniforms.
let program; if (env().getBool('WEBGL_PACK_DEPTHWISECONV') && convInfo.strideWidth <= 2 && convInfo.outChannels / convInfo.inChannels === 1) { program = new DepthwiseConvPacked2DProgram(convInfo); } else { program = new DepthwiseConv2DProgram(convInfo); }
// Uniform values, in the same order as the programs' customUniforms:
// pads, strides, dilations, inDims.
const customValues = [ [convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inHeight, convInfo.inWidth] ]; return backend.runWebGLProgram(program, [x, filter], 'float32', customValues); } const depthwiseConv2dNativeConfig = { kernelName: DepthwiseConv2dNative, backendName: 'webgl', kernelFunc: depthwiseConv2dNative, };
/**
 * Gradient of depthwise conv2d with respect to the filter. The output shape is
 * convInfo.filterShape.
 *
 * Each filter element (wR, wC, d1, dm) sums x * dy over all batches and
 * output positions, where d2 = d1 * channelMul + dm.
 * NOTE(review): the generated shader uses strides and top/left padding but no
 * dilation terms.
 */
class DepthwiseConv2DDerFilterProgram { constructor(convInfo) { this.variableNames = ['x', 'dy']; this.outputShape = convInfo.filterShape; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const padTop = convInfo.padInfo.top; const padLeft = convInfo.padInfo.left; const channelMul = convInfo.outChannels / convInfo.inChannels; this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int wR = coords.x; int wC = coords.y; int d1 = coords.z; int dm = coords.w; int d2 = d1 * ${channelMul} + dm; float dotProd = 0.0; for (int b = 0; b < ${convInfo.batchSize}; b++) { for (int yR = 0; yR < ${convInfo.outHeight}; yR++) { int xR = wR + yR * ${strideHeight} - ${padTop}; if (xR < 0 || xR >= ${convInfo.inHeight}) { continue; } for (int yC = 0; yC < ${convInfo.outWidth}; yC++) { int xC = wC + yC * ${strideWidth} - ${padLeft}; if (xC < 0 || xC >= ${convInfo.inWidth}) { continue; } float dyValue = getDy(b, yR, yC, d2); float xValue = getX(b, xR, xC, d1); dotProd += (xValue * dyValue); } } } setOutput(dotProd); } `; } }
// Gradient of depthwise conv2d with respect to the input. The output shape is
// convInfo.inShape. For each filter tap, the shader maps back to a dy
// position by dividing by the stride, skips non-integral positions, and reads
// the spatially flipped filter.
class DepthwiseConv2DDerInputProgram { constructor(convInfo) { this.variableNames = ['dy',
'W']; this.outputShape = convInfo.inShape; const filterHeight = convInfo.filterHeight; const filterWidth = convInfo.filterWidth; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const padTop = filterHeight - 1 - convInfo.padInfo.top; const padLeft = filterWidth - 1 - convInfo.padInfo.left; const channelMul = convInfo.outChannels / convInfo.inChannels; this.userCode = ` const ivec2 pads = ivec2(${padTop}, ${padLeft}); void main() { ivec4 coords = getOutputCoords(); int batch = coords[0]; int d1 = coords[3]; ivec2 dyCorner = coords.yz - pads; int dyRCorner = dyCorner.x; int dyCCorner = dyCorner.y; float dotProd = 0.0; for (int wR = 0; wR < ${filterHeight}; wR++) { float dyR = float(dyRCorner + wR) / ${strideHeight}.0; if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) { continue; } int idyR = int(dyR); int wRPerm = ${filterHeight} - 1 - wR; for (int wC = 0; wC < ${filterWidth}; wC++) { float dyC = float(dyCCorner + wC) / ${strideWidth}.0; if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 || fract(dyC) > 0.0) { continue; } int idyC = int(dyC); int wCPerm = ${filterWidth} - 1 - wC; for (int dm = 0; dm < ${channelMul}; dm++) { int d2 = d1 * ${channelMul} + dm; float xValue = getDy(batch, idyR, idyC, d2); float wValue = getW(wRPerm, wCPerm, d1, dm); dotProd += xValue * wValue; } } } setOutput(dotProd); } `; } } function depthwiseConv2dNativeBackpropFilter(args) { const { inputs, backend, attrs } = args; const { x, dy } = inputs; const { strides, dilations, pad, dimRoundingMode, filterShape } = attrs; const convInfo = computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true ); const program = new DepthwiseConv2DDerFilterProgram(convInfo); return backend.runWebGLProgram(program, [x, dy], 'float32'); } const depthwiseConv2dNativeBackpropFilterConfig = { kernelName: DepthwiseConv2dNativeBackpropFilter, backendName: 'webgl', kernelFunc: depthwiseConv2dNativeBackpropFilter }; function 
depthwiseConv2dNativeBackpropInput(args) { const { inputs, backend, attrs } = args; const { dy, filter } = inputs; const { strides, dilations, pad, dimRoundingMode, inputShape } = attrs; const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, dilations, pad, dimRoundingMode, true ); const program = new DepthwiseConv2DDerInputProgram(convInfo); return backend.runWebGLProgram(program, [dy, filter], 'float32'); } const depthwiseConv2dNativeBackpropInputConfig = { kernelName: DepthwiseConv2dNativeBackpropInput, backendName: 'webgl', kernelFunc: depthwiseConv2dNativeBackpropInput }; class DiagProgram { constructor(size) { this.variableNames = ['X']; this.outputShape = [size, size]; this.userCode = ` void main() { ivec2 coords = getOutputCoords(); float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0; setOutput(val); } `; } } function diag(args) { const { inputs, backend } = args; const { x } = inputs; const outShape = [...x.shape, ...x.shape]; const xSize = sizeFromShape(x.shape); const flat = reshape({ inputs: { x }, backend, attrs: { shape: [xSize] } }); const program = new DiagProgram(xSize); const res = backend.runWebGLProgram(program, [flat], flat.dtype); const out = reshape({ inputs: { x: res }, backend, attrs: { shape: outShape } }); backend.disposeIntermediateTensorInfo(flat); backend.disposeIntermediateTensorInfo(res); return out; } const diagConfig = { kernelName: Diag, backendName: 'webgl', kernelFunc: diag }; class Dilation2DProgram { constructor(convInfo) { this.variableNames = ['x', 'W']; this.outputShape = convInfo.outShape; const { inHeight, inWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth } = convInfo; const { top: padTop, left: padLeft } = padInfo; this.userCode = ` const ivec2 strides = ivec2(${strideHeight}, ${strideWidth}); const ivec2 pads = ivec2(${padTop}, ${padLeft}); const float neg_infinity = -3.4e38; void main() { ivec4 coords = getOutputCoords(); int batch = coords.x; 
int d1 = coords.w; ivec2 outTopLeftCorner = coords.yz * strides - pads; int hBeg = outTopLeftCorner.x; int wBeg = outTopLeftCorner.y; float curVal = neg_infinity; for (int h = 0; h < ${filterHeight}; h++) { int hIn = hBeg + h * ${dilationHeight}; if (hIn >= 0 && hIn < ${inHeight}) { for (int w = 0; w < ${filterWidth}; w++) { int wIn = wBeg + w * ${dilationWidth}; if (wIn >= 0 && wIn < ${inWidth}) { float xVal = getX(batch, hIn, wIn, d1); float wVal = getW(h, w, d1); float val = xVal + wVal; if (val > curVal) { curVal = val; } } } } } float result = curVal; setOutput(result); } `; } } function dilation2D(args) { const { inputs, backend, attrs } = args; const { x, filter } = inputs; const { strides, pad, dilations } = attrs; const convInfo = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' , dilations); let out; const program = new Dilation2DProgram(convInfo); out = backend.runWebGLProgram(program, [x, filter], 'float32'); const outReshaped = reshape({ inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } }); backend.disposeIntermediateTensorInfo(out); return outReshaped; } const dilation2DConfig = { kernelName: Dilation2D, backendName: 'webgl', kernelFunc: dilation2D, }; function einsum(args) { const { inputs, backend, attrs } = args; const { equation } = attrs; const tensors = inputs; const { allDims, summedDims, idDims } = decodeEinsumEquation(equation, tensors.length); checkEinsumDimSizes(allDims.length, idDims, tensors); const { path, steps } = getEinsumComputePath(summedDims, idDims); const nSteps = steps.length; let out = null; let numDimsRemaining = allDims.length; const tensorsToDispose = []; for (let i = 0; i < nSteps; ++i) { for (const idTerm of steps[i]) { const { permutationIndices: perm, expandDims: dimsToExpand } = getEinsumPermutation(numDimsRemaining, idDims[idTerm]); let x; if (isIdentityPermutation(perm)) { x = tensors[idTerm]; } else { x = transpose({ inputs: { x: tensors[idTerm] }, backend, attrs: { perm } }); 
tensorsToDispose.push(x); } const targetShape = x.shape.slice(); for (let k = 0; k < dimsToExpand.length; ++k) { targetShape.splice(dimsToExpand[k], 0, 1); } if (!arraysEqual(x.shape, targetShape)) { x = reshape({ inputs: { x }, backend, attrs: { shape: targetShape } }); tensorsToDispose.push(x); } if (out === null) { out = x; } else { out = multiply({ inputs: { a: x, b: out }, backend }); tensorsToDispose.push(out); } } if (i < nSteps - 1) { if (path[i] >= 0) { out = sum({ inputs: { x: out }, backend, attrs: { axis: path[i] - (allDims.length - numDimsRemaining), keepDims: false } }); tensorsToDispose.push(out); } numDimsRemaining--; } } for (const tensorInfo of tensorsToDispose) { if (tensorInfo === out) { continue; } backend.disposeIntermediateTensorInfo(tensorInfo); } return out; } const einsumConfig = { kernelName: Einsum, backendName: 'webgl', kernelFunc: einsum }; const ELU = `return (x >= 0.0) ? x : (exp(x) - 1.0);`; const ELU_PACKED = ` vec4 result; result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0); result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0); result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0); result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0); return result; `; const elu$1 = unaryKernelFunc({ opSnippet: ELU, packedOpSnippet: ELU_PACKED }); const eluConfig = { kernelName: Elu$1, backendName: 'webgl', kernelFunc: elu$1 }; const ELU_DER = `return (b >= 0.0) ? a : a * (b + 1.0);`; const ELU_DER_PACKED = ` vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.))); return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0)))); `; const eluGrad = (args) => { const { inputs, backend } = args; const { dy, y } = inputs; const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? 
new BinaryOpPackedProgram(ELU_DER_PACKED, dy.shape, y.shape) : new BinaryOpProgram(ELU_DER, dy.shape, y.shape); return backend.runWebGLProgram(program, [dy, y], dy.dtype); }; const eluGradConfig$1 = { kernelName: EluGrad, backendName: 'webgl', kernelFunc: eluGrad }; const PACKED_EQUAL = ` return vec4(equal(a, b)); `; const EQUAL = `return float(a == b);`; const equal = binaryKernelFunc({ opSnippet: EQUAL, packedOpSnippet: PACKED_EQUAL, dtype: 'bool', cpuKernelImpl: equalImplCPU, }); const equalConfig = { kernelName: Equal, backendName: 'webgl', kernelFunc: equal }; const ERF = ` float p = ${ERF_P}; float a1 = ${ERF_A1}; float a2 = ${ERF_A2}; float a3 = ${ERF_A3}; float a4 = ${ERF_A4}; float a5 = ${ERF_A5}; float sign = sign(x); x = abs(x); float t = 1.0 / (1.0 + p * x); return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x)); `; const erf = unaryKernelFunc({ opSnippet: ERF }); const erfConfig = { kernelName: Erf, backendName: 'webgl', kernelFunc: erf, }; const EXP = CHECK_NAN_SNIPPET_UNARY + ` return exp(x); `; const EXP_PACKED = ` vec4 result = exp(x); bvec4 isNaN = isnan(x); result.r = isNaN.r ? x.r : result.r; result.g = isNaN.g ? x.g : result.g; result.b = isNaN.b ? x.b : result.b; result.a = isNaN.a ? 
x.a : result.a; return result; `; const exp = unaryKernelFunc({ opSnippet: EXP, packedOpSnippet: EXP_PACKED, cpuKernelImpl: expImplCPU, dtype: 'float32', }); const expConfig = { kernelName: Exp, backendName: 'webgl', kernelFunc: exp }; function expandDims$1(args) { const { inputs, attrs, backend } = args; const { dim } = attrs; const { input } = inputs; const inputRank = input.shape.length; const newShape = input.shape.slice(); let $dim = dim; if (dim < 0) { assert$1(-(inputRank + 1) <= dim, () => `Axis must be in the interval [${-(inputRank + 1)}, ${inputRank}]`); $dim = inputRank + dim + 1; } newShape.splice($dim, 0, 1); return reshape({ inputs: { x: input }, backend, attrs: { shape: newShape } }); } const expandDimsConfig = { kernelName: ExpandDims, backendName: 'webgl', kernelFunc: expandDims$1, }; const EXPM1 = `return exp(x) - 1.0;`; const expm1 = unaryKernelFunc({ opSnippet: EXPM1, packedOpSnippet: EXPM1, cpuKernelImpl: expm1ImplCPU }); const expm1Config = { kernelName: Expm1, backendName: 'webgl', kernelFunc: expm1 }; class FFTProgram { constructor(component, inputShape, inverse) { this.variableNames = ['real', 'imag']; const innerDim = inputShape[1]; this.outputShape = inputShape; const exponentMultiplierSnippet = inverse ? `2.0 * ${Math.PI}` : `-2.0 * ${Math.PI}`; const resultDenominator = inverse ? 
`${innerDim}.0` : '1.0'; let opString; if (component === 'real') { opString = 'return real * expR - imag * expI;'; } else if (component === 'imag') { opString = 'return real * expI + imag * expR;'; } else { throw new Error(`FFT component must be either "real" or "imag", got ${component}.`); } this.userCode = ` const float exponentMultiplier = ${exponentMultiplierSnippet}; float unaryOpComplex(float real, float expR, float imag, float expI) { ${opString} } float mulMatDFT(int batch, int index) { float indexRatio = float(index) / float(${innerDim}); float exponentMultiplierTimesIndexRatio = exponentMultiplier * indexRatio; float result = 0.0; for (int i = 0; i < ${innerDim}; i++) { float x = exponentMultiplierTimesIndexRatio * float(i); float expR = cos(x); float expI = sin(x); float real = getReal(batch, i); float imag = getImag(batch, i); result += unaryOpComplex(real, expR, imag, expI) / ${resultDenominator}; } return result; } void main() { ivec2 coords = getOutputCoords(); setOutput(mulMatDFT(coords[0], coords[1])); } `; } } function fftImpl(x, inverse, backend) { const xData = backend.texData.get(x.dataId); const inputSize = sizeFromShape(x.shape); const innerDimensionSize = x.shape[x.shape.length - 1]; const batch = inputSize / innerDimensionSize; const input2D = reshape({ inputs: { x }, backend, attrs: { shape: [batch, innerDimensionSize] } }); const xShape = input2D.shape; const realProgram = new FFTProgram('real', xShape, inverse); const imagProgram = new FFTProgram('imag', xShape, inverse); const inputs = [ { dataId: xData.complexTensorInfos.real.dataId, dtype: xData.complexTensorInfos.real.dtype, shape: xShape }, { dataId: xData.complexTensorInfos.imag.dataId, dtype: xData.complexTensorInfos.imag.dtype, shape: xShape } ]; const realPart = backend.runWebGLProgram(realProgram, inputs, 'float32'); const imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32'); const complexOutput = complex({ inputs: { real: realPart, imag: imagPart }, backend }); 
backend.disposeIntermediateTensorInfo(realPart); backend.disposeIntermediateTensorInfo(imagPart); const complexOutputReshaped = reshape({ inputs: { x: complexOutput }, backend, attrs: { shape: x.shape } }); backend.disposeIntermediateTensorInfo(input2D); backend.disposeIntermediateTensorInfo(complexOutput); return complexOutputReshaped; }
// FFT kernel (WebGL): forward transform of `input`, delegated to fftImpl with inverse = false.
function fft(args) { const { inputs, backend } = args; const { input } = inputs; return fftImpl(input, false , backend); }
const fftConfig = { kernelName: FFT, backendName: 'webgl', kernelFunc: fft };
// Shader program that writes the float uniform `value` to every element of an output of `shape`.
// NOTE(review): variableNames lists 'x', but fill() runs this program with an empty input list.
class FillProgram { constructor(shape, value) { this.outputShape = []; this.customUniforms = [{ name: 'value', type: 'float' }]; this.variableNames = ['x']; this.outputShape = shape; this.userCode = ` void main() { setOutput(value); } `; } }
// Fill kernel: returns a tensor of `shape` whose elements all equal `value`.
// dtype falls back to inferDtype(value) when not supplied. String tensors are built on the CPU
// (getArrayFromDType(...).fill(value)); every other dtype runs FillProgram with `value` passed as
// the program's custom uniform.
function fill(args) { const { backend, attrs } = args; const { shape, value } = attrs; let { dtype } = attrs; dtype = dtype || inferDtype(value); if (dtype === 'string') { const values = getArrayFromDType(dtype, sizeFromShape(shape)); values.fill(value); return backend.makeTensorInfo(shape, dtype, values); } else { const program = new FillProgram(shape, value); const customValues = [[value]]; return backend.runWebGLProgram(program, [], dtype, customValues); } }
const fillConfig = { kernelName: Fill, backendName: 'webgl', kernelFunc: fill };
// Shader that mirrors a rank-4 image along dimension 2 (imageShape[2], read as the width):
// output[.., x, ..] = Image[.., imageWidth - x - 1, ..].
// NOTE(review): for x in [0, imageWidth) coordX is always in range, so the else branch looks unreachable.
class FlipLeftRightProgram { constructor(imageShape) { this.variableNames = ['Image']; this.outputShape = []; const imageWidth = imageShape[2]; this.outputShape = imageShape; this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int x = coords[2]; int coordX = ${imageWidth} - x - 1; float outputValue; if(coordX >= 0 && coordX < ${imageWidth}) { outputValue = getImage(coords[0], coords[1], coordX, coords[3]); } else { outputValue = getImage(coords[0], coords[1], coords[2], coords[3]); } setOutput(outputValue); } `; } }
// FlipLeftRight kernel: runs FlipLeftRightProgram on `image`, keeping the input dtype.
const flipLeftRightConfig = { kernelName: FlipLeftRight, backendName: 'webgl', kernelFunc: ({ inputs, backend }) => { const { image 
} = inputs; const webglBackend = backend; const program = new FlipLeftRightProgram(image.shape); const output = webglBackend.runWebGLProgram(program, [image], image.dtype); return output; } };
// Floor kernel: GLSL floor() for both the unpacked and packed paths, floorImplCPU on the CPU path.
const FLOOR = `return floor(x);`;
const floor = unaryKernelFunc({ opSnippet: FLOOR, packedOpSnippet: FLOOR, cpuKernelImpl: floorImplCPU });
const floorConfig = { kernelName: Floor, backendName: 'webgl', kernelFunc: floor, };
// FloorDiv snippets: both operands are rounded to int and divided with idiv(), which also receives
// sign(a) * sign(b) (presumably so it can round toward negative infinity - confirm in idiv's definition).
// Unpacked: a rounded divisor of 0 yields NAN. Packed: lanes whose rounded divisor is 0 stay 0.
const INT_DIV = ` float s = sign(a) * sign(b); int ia = round(a); int ib = round(b); if (ib != 0) { return float(idiv(ia, ib, s)); } else { return NAN; } `;
const INT_DIV_PACKED = ` ivec4 ia = round(a); ivec4 ib = round(b); bvec4 cond = notEqual(ib, ivec4(0)); ivec4 result = ivec4(0); vec4 s = sign(a) * sign(b); if (cond[0]) { result[0] = idiv(ia[0], ib[0], s[0]); } if (cond[1]) { result[1] = idiv(ia[1], ib[1], s[1]); } if (cond[2]) { result[2] = idiv(ia[2], ib[2], s[2]); } if (cond[3]) { result[3] = idiv(ia[3], ib[3], s[3]); } return vec4(result); `;
// FloorDiv always produces an int32 output.
const floorDiv = binaryKernelFunc({ opSnippet: INT_DIV, packedOpSnippet: INT_DIV_PACKED, dtype: 'int32' });
const floorDivConfig = { kernelName: FloorDiv, backendName: 'webgl', kernelFunc: floorDiv };
// Unpacked FromPixels shader. Output is [height, width, depth]. For each element it samples
// texture A at (col, row) + halfCR (presumably a half-texel offset from the shader prelude - confirm),
// picks channel r/g/b/a by depth, and writes floor(value * 255.0 + 0.5), i.e. value*255 rounded
// (assumes normalized [0, 1] samples - TODO confirm).
class FromPixelsProgram { constructor(outputShape) { this.variableNames = ['A']; const glsl = getGlslDifferences(); const [height, width,] = outputShape; this.outputShape = outputShape; this.userCode = ` void main() { ivec3 coords = getOutputCoords(); int texR = coords[0]; int texC = coords[1]; int depth = coords[2]; vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0); vec4 values = ${glsl.texture2D}(A, uv); float value; if (depth == 0) { value = values.r; } else if (depth == 1) { value = values.g; } else if (depth == 2) { value = values.b; } else if (depth == 3) { value = values.a; } setOutput(floor(value * 255.0 + 0.5)); } `; } }
// Packed FromPixels shader: unpacked input, packed output. It samples like FromPixelsProgram, but each
// invocation fills a 2x2 block over (column, channel) = (coords[1] + row, coords[2] + col).
class FromPixelsPackedProgram { constructor(outputShape) { this.variableNames = ['A']; this.packedInputs = false; this.packedOutput = true; const glsl = 
getGlslDifferences(); const [height, width,] = outputShape; this.outputShape = outputShape; this.userCode = ` void main() { ivec3 coords = getOutputCoords(); int texR = coords[0]; int texC = coords[1]; int depth = coords[2]; vec4 result = vec4(0.); for(int row=0; row<=1; row++) { for(int col=0; col<=1; col++) { texC = coords[1] + row; depth = coords[2] + col; vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0); vec4 values = ${glsl.texture2D}(A, uv); float value; if (depth == 0) { value = values.r; } else if (depth == 1) { value = values.g; } else if (depth == 2) { value = values.b; } else if (depth == 3) { value = values.a; } result[row * 2 + col] = floor(value * 255.0 + 0.5); } } ${glsl.output} = result; } `; } }
// Kernel registration for FromPixels (fromPixels is a hoisted function declaration defined below).
const fromPixelsConfig = { kernelName: FromPixels, backendName: 'webgl', kernelFunc: fromPixels, };
// Module-level cache: the 2D canvas context used to rasterize image/video inputs, and the value of the
// CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU flag that context was created with.
let fromPixels2DContext;
let willReadFrequently = env().getBool('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU');
// FromPixels kernel: converts a browser pixel source into an int32 tensor of shape
// [height, width, numChannels].
//  - Video elements are sized by videoWidth/videoHeight; other sources by width/height.
//  - HTMLImageElement / HTMLVideoElement inputs are first drawn onto the cached 2D canvas. The context is
//    (re)created when it is missing or when the willReadFrequently flag has changed since creation.
//  - The pixels are uploaded into a temporary [height, width] int32 texture marked TextureUsage.PIXELS.
//  - That texture is converted by FromPixelsPackedProgram (when WEBGL_PACK is set) or FromPixelsProgram,
//    and the temporary is then disposed.
function fromPixels(args) { const { inputs, backend, attrs } = args; let { pixels } = inputs; const { numChannels } = attrs; const isVideo = typeof (HTMLVideoElement) !== 'undefined' && pixels instanceof HTMLVideoElement; const isImage = typeof (HTMLImageElement) !== 'undefined' && pixels instanceof HTMLImageElement; const [width, height] = isVideo ? 
[ pixels.videoWidth, pixels.videoHeight ] : [pixels.width, pixels.height]; const texShape = [height, width]; const outShape = [height, width, numChannels]; if (isImage || isVideo) { const newWillReadFrequently = env().getBool('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU'); if (fromPixels2DContext == null || newWillReadFrequently !== willReadFrequently) { willReadFrequently = newWillReadFrequently; fromPixels2DContext = document.createElement('canvas').getContext('2d', { willReadFrequently }); } fromPixels2DContext.canvas.width = width; fromPixels2DContext.canvas.height = height; fromPixels2DContext.drawImage(pixels, 0, 0, width, height); pixels = fromPixels2DContext.canvas; } const tempPixelHandle = backend.makeTensorInfo(texShape, 'int32'); backend.texData.get(tempPixelHandle.dataId).usage = TextureUsage.PIXELS; backend.gpgpu.uploadPixelDataToTexture(backend.getTexture(tempPixelHandle.dataId), pixels); const program = env().getBool('WEBGL_PACK') ? new FromPixelsPackedProgram(outShape) : new FromPixelsProgram(outShape); const res = backend.runWebGLProgram(program, [tempPixelHandle], 'int32'); backend.disposeData(tempPixelHandle.dataId); return res; }
// FusedConv2D kernel: conv2d plus optional bias and activation (including prelu weights / leakyrelu alpha).
// For NCHW data, 1-D bias / prelu weights of length != 1 are reshaped to [C, 1, 1] before use.
// Implementation choice, in order:
//   1x1 filter, unit stride and dilation, SAME/VALID padding -> conv2dByMatMul
//   strideWidth <= 2, channelsLast and WEBGL_EXP_CONV         -> Conv2DPackedProgram
//   WEBGL_CONV_IM2COL                                         -> conv2dWithIm2Row
//   otherwise                                                 -> Conv2DProgram
// The result is reshaped to convInfo.outShape and all intermediates are disposed.
function fusedConv2d(args) { const { inputs, backend, attrs } = args; const { x, filter, bias, preluActivationWeights } = inputs; const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs; const $dataFormat = convertConv2DDataFormat(dataFormat); const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false , $dataFormat); let out; const intermediates = []; const hasBias = bias != null; const hasPreluActivationWeights = preluActivationWeights != null; const hasLeakyreluAlpha = activation === 'leakyrelu'; const prepareInputs = () => { const inputs = [x, filter]; const alignInputWithDataFormat = (input, dataFormat) => { if (dataFormat === 'NCHW' && input.shape.length === 1 && input.shape[0] !== 1) { const alignedInput = reshape({ inputs: { x: 
input }, backend, attrs: { shape: [input.shape[0], 1, 1] } }); intermediates.push(alignedInput); return alignedInput; } return input; }; if (hasBias) { inputs.push(alignInputWithDataFormat(bias, dataFormat)); } if (hasPreluActivationWeights) { inputs.push(alignInputWithDataFormat(preluActivationWeights, dataFormat)); } if (hasLeakyreluAlpha) { const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32')); inputs.push($leakyreluAlpha); intermediates.push($leakyreluAlpha); } return inputs; };
// Choose the implementation: matmul for pointwise (1x1, unit stride/dilation) convs, the packed program
// when WEBGL_EXP_CONV is enabled, im2col when WEBGL_CONV_IM2COL is enabled, else the generic Conv2DProgram.
if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) { out = conv2dByMatMul({ x, filter, convInfo, backend, bias, activation, preluActivationWeights, leakyreluAlpha }); } else if (convInfo.strideWidth <= 2 && $dataFormat === 'channelsLast' && env().getBool('WEBGL_EXP_CONV')) { const fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null; const program = new Conv2DPackedProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha); const customValues = [ [convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inHeight, convInfo.inWidth] ]; const inputs = prepareInputs(); out = backend.runWebGLProgram(program, inputs, 'float32', customValues); } else if (env().getBool('WEBGL_CONV_IM2COL')) { out = conv2dWithIm2Row({ x, filter, convInfo, backend, bias, activation, preluActivationWeights, leakyreluAlpha }); } else { const fusedActivation = activation ? 
mapActivationToShaderProgram(activation, false) : null; const program = new Conv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha); const inputs = prepareInputs(); out = backend.runWebGLProgram(program, inputs, 'float32'); }
// Reshape to the declared output shape, then release `out` together with any reshaped inputs and the
// leakyrelu-alpha scalar collected in `intermediates`.
const outReshaped = reshape({ inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } }); intermediates.push(out); intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t)); return outReshaped; }
const fusedConv2DConfig = { kernelName: FusedConv2D, backendName: 'webgl', kernelFunc: fusedConv2d, };
// FusedDepthwiseConv2D kernel: depthwise conv2d plus optional bias and activation.
// dilations default to [1, 1], and either strides or dilations must be 1 (asserted).
// The packed program is used when WEBGL_PACK_DEPTHWISECONV is set, strideWidth <= 2 and the channel
// multiplier is 1 (outChannels / inChannels === 1). Pad, stride, dilation and input size are passed
// as custom uniforms. The leakyrelu-alpha scalar, if created, is disposed before returning.
function fusedDepthwiseConv2D(args) { const { inputs, backend, attrs } = args; const { x, filter, bias, preluActivationWeights } = inputs; const { strides, pad, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs; const intermediates = []; let $dilations = dilations; if ($dilations == null) { $dilations = [1, 1]; } assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' + `1. Got strides ${strides} and dilations '${$dilations}'`); const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true ); const shouldPackDepthwiseConv = env().getBool('WEBGL_PACK_DEPTHWISECONV') && convInfo.strideWidth <= 2 && convInfo.outChannels / convInfo.inChannels === 1; const fusedActivation = activation ? 
mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) : null; const programInputs = [x, filter]; const hasBias = bias != null; const hasPreluActivationWeights = preluActivationWeights != null; const hasLeakyreluAlpha = activation === 'leakyrelu'; if (hasBias) { programInputs.push(bias); } if (hasPreluActivationWeights) { programInputs.push(preluActivationWeights); } if (hasLeakyreluAlpha) { const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32')); programInputs.push($leakyreluAlpha); intermediates.push($leakyreluAlpha); } let program; if (shouldPackDepthwiseConv) { program = new DepthwiseConvPacked2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha); } else { program = new DepthwiseConv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha); } const customValues = [ [convInfo.padInfo.top, convInfo.padInfo.left], [convInfo.strideHeight, convInfo.strideWidth], [convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inHeight, convInfo.inWidth] ]; const result = backend.runWebGLProgram(program, programInputs, 'float32', customValues); intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t)); return result; }
const fusedDepthwiseConv2DConfig = { kernelName: FusedDepthwiseConv2D, backendName: 'webgl', kernelFunc: fusedDepthwiseConv2D, };
// GatherNd shader over 2-D views. The output has `shape` ([numSlices, sliceSize] as called from gatherNd).
// For output row coords[0], it reads sliceDim index components from `indices` and accumulates
// flattenIndex += index * strides[j], then outputs x[flattenIndex, coords[1]]. If any component is < 0
// or >= paramsShape[j], the element is 0.0 instead.
class GatherNDProgram { constructor(sliceDim, strides, shape, paramsShape) { this.sliceDim = sliceDim; this.strides = strides; this.paramsShape = paramsShape; this.variableNames = ['x', 'indices']; this.outputShape = shape; const dtype = getCoordsDataType(shape.length); let mainLoop = ` int index;`; for (let j = 0; j < this.sliceDim; j++) { mainLoop += ` index = round(getIndices(coords[0], ${j})); out_of_bounds = out_of_bounds || index < 0; out_of_bounds = out_of_bounds || index >= ${this.paramsShape[j]}; flattenIndex += index * ${this.strides[j]};`; } this.userCode = ` void main() { 
${dtype} coords = getOutputCoords(); int flattenIndex = 0; bool out_of_bounds = false; ${mainLoop} setOutput(out_of_bounds ? 0.0 : getX(flattenIndex, coords[1])); } `; } } function gatherNd(args) { const { inputs, backend } = args; const { params, indices } = inputs; const indicesShape = indices.shape; const sliceRank = indicesShape[indicesShape.length - 1]; const paramsSize = sizeFromShape(params.shape); const [resultShape, numSlices, sliceSize, strides] = prepareAndValidate(params, indices); const flattenIndices = reshape({ inputs: { x: indices }, backend, attrs: { shape: [numSlices, sliceRank] } }); const flattenX = reshape({ inputs: { x: params }, backend, attrs: { shape: [(sizeFromShape(params.shape) / sliceSize), sliceSize] } }); if (backend.shouldExecuteOnCPU([params, indices]) || params.dtype === 'string') { const indicesData = backend.readSync(indices.dataId); const paramsBuf = backend.bufferSync(params); const outValue = gatherNdImplCPU(indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize); return backend.makeTensorInfo(resultShape, params.dtype, outValue.values); } const program = new GatherNDProgram(sliceRank, strides, [numSlices, sliceSize], params.shape); const res = backend.runWebGLProgram(program, [flattenX, flattenIndices], flattenX.dtype); const reshaped = reshape({ inputs: { x: res }, backend, attrs: { shape: resultShape } }); backend.disposeIntermediateTensorInfo(flattenIndices); backend.disposeIntermediateTensorInfo(flattenX); backend.disposeIntermediateTensorInfo(res); return reshaped; } const gatherNdConfig = { kernelName: GatherNd, backendName: 'webgl', kernelFunc: gatherNd }; class GatherProgram { constructor(aShape, outputShape) { this.variableNames = ['A', 'indices']; this.outputShape = outputShape; this.rank = outputShape.length; const dtype = getCoordsDataType(this.rank); const sourceCoords = getSourceCoords$1(aShape); this.userCode = ` void main() { ${dtype} resRC = getOutputCoords(); 
int index = int(getIndices(resRC.x, resRC.z)); float inBounds = (index >= 0) && (index < ${aShape[2]}) ? 1.0 : 0.0; setOutput(inBounds * getA(${sourceCoords})); } `; } }
// Builds the comma-separated GLSL coordinate list GatherProgram uses to read A: the output coordinates
// resRC.x/.y/.z/.w in order, except position 2, which becomes the gathered `index`.
// NOTE(review): the `axis` parameter is unused.
function getSourceCoords$1(aShape, axis) { const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w']; const sourceCoords = []; for (let i = 0; i < aShape.length; i++) { if (i === 2) { sourceCoords.push('index'); } else { sourceCoords.push(`${currentCoords[i]}`); } } return sourceCoords.join(); }
// GatherV2 kernel: gathers slices of x along `axis` (with `batchDims` leading batch dimensions).
//  - When the DEBUG flag is set, the indices are read synchronously and each one is asserted to be
//    in [0, axisDim - 1].
//  - x is viewed as [batchSize, outerSize, dimSize, sliceSize] and indices as
//    [batchSize, indicesSize / batchSize] (sizes from collectGatherOpShapeInfo).
//  - Small inputs and string tensors use gatherV2ImplCPU; otherwise GatherProgram runs on the GPU.
//  - The result is reshaped to shapeInfo.outputShape. All reshape/program intermediates are disposed
//    on both paths.
function gatherV2(args) { const { inputs, backend, attrs } = args; const { x, indices } = inputs; const { axis, batchDims } = attrs; const parsedAxis = parseAxisParam(axis, x.shape)[0]; if (env().get('DEBUG')) { const indicesVals = backend.readSync(indices.dataId); const axisDim = x.shape[parsedAxis]; for (let i = 0; i < indicesVals.length; ++i) { const index = indicesVals[i]; assert$1(index <= axisDim - 1 && index >= 0, () => `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`); } } const shapeInfo = collectGatherOpShapeInfo(x, indices, parsedAxis, batchDims); const indicesSize = sizeFromShape(indices.shape); const toDispose = []; const flattenX = reshape({ inputs: { x }, backend, attrs: { shape: [ shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize, shapeInfo.sliceSize ] } }); const flattenIndex = reshape({ inputs: { x: indices }, backend, attrs: { shape: [shapeInfo.batchSize, indicesSize / shapeInfo.batchSize] } }); toDispose.push(flattenX); toDispose.push(flattenIndex); const flattenOutputShape = [ shapeInfo.batchSize, shapeInfo.outerSize, indicesSize / shapeInfo.batchSize, shapeInfo.sliceSize ]; if (backend.shouldExecuteOnCPU([x, indices]) || x.dtype === 'string') { const indicesBuf = backend.bufferSync(flattenIndex); const xBuf = backend.bufferSync(flattenX); const outBuf = gatherV2ImplCPU(xBuf, indicesBuf, flattenOutputShape); toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t)); return backend.makeTensorInfo(shapeInfo.outputShape, outBuf.dtype, 
outBuf.values); } const program = new GatherProgram(flattenX.shape, flattenOutputShape); const res = backend.runWebGLProgram(program, [flattenX, flattenIndex], flattenX.dtype); toDispose.push(res); const reshaped = reshape({ inputs: { x: res }, backend, attrs: { shape: shapeInfo.outputShape } }); toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t)); return reshaped; }
const gatherV2Config = { kernelName: GatherV2, backendName: 'webgl', kernelFunc: gatherV2 };
// Greater / GreaterEqual kernels: the snippets return 1.0 / 0.0 and the kernels declare dtype 'bool'.
// The *_PACKED variants compare all four lanes at once.
const GREATER = `return float(a > b);`;
const GREATER_PACKED = ` return vec4(greaterThan(a, b)); `;
const greater = binaryKernelFunc({ opSnippet: GREATER, packedOpSnippet: GREATER_PACKED, cpuKernelImpl: greaterImplCPU, dtype: 'bool' });
const greaterConfig = { kernelName: Greater, backendName: 'webgl', kernelFunc: greater };
const GREATER_EQUAL = `return float(a >= b);`;
const GREATER_EQUAL_PACKED = ` return vec4(greaterThanEqual(a, b)); `;
const greaterEqual = binaryKernelFunc({ opSnippet: GREATER_EQUAL, packedOpSnippet: GREATER_EQUAL_PACKED, dtype: 'bool', cpuKernelImpl: greaterEqualImplCPU });
const greaterEqualConfig = { kernelName: GreaterEqual, backendName: 'webgl', kernelFunc: greaterEqual };
// IFFT kernel (WebGL): inverse transform of `input`, delegated to fftImpl with inverse = true.
function ifft(args) { const { inputs, backend } = args; const { input } = inputs; return fftImpl(input, true , backend); }
const ifftConfig = { kernelName: IFFT, backendName: 'webgl', kernelFunc: ifft };
// Floating-point classification kernels: bool output; no packed snippet or CPU implementation supplied.
const IS_FINITE = `return float(!isnan(x) && !isinf(x));`;
const isFinite$1 = unaryKernelFunc({ opSnippet: IS_FINITE, dtype: 'bool' });
const isFiniteConfig = { kernelName: IsFinite, backendName: 'webgl', kernelFunc: isFinite$1, };
const IS_INF = `return float(isinf(x));`;
const isInf = unaryKernelFunc({ opSnippet: IS_INF, dtype: 'bool' });
const isInfConfig = { kernelName: IsInf, backendName: 'webgl', kernelFunc: isInf, };
const IS_NAN = `return float(isnan(x));`;
const isNaN$1 = unaryKernelFunc({ opSnippet: IS_NAN, dtype: 'bool' });
const isNaNConfig = { kernelName: IsNan, backendName: 'webgl', 
kernelFunc: isNaN$1, };
// Less / LessEqual kernels: the snippets return 1.0 / 0.0 and the kernels declare dtype 'bool'.
const LESS = `return float(a < b);`;
const LESS_PACKED = ` return vec4(lessThan(a, b)); `;
const less = binaryKernelFunc({ opSnippet: LESS, packedOpSnippet: LESS_PACKED, cpuKernelImpl: lessImplCPU, dtype: 'bool' });
const lessConfig = { kernelName: Less, backendName: 'webgl', kernelFunc: less };
const LESS_EQUAL = `return float(a <= b);`;
const LESS_EQUAL_PACKED = ` return vec4(lessThanEqual(a, b)); `;
const lessEqual = binaryKernelFunc({ opSnippet: LESS_EQUAL, packedOpSnippet: LESS_EQUAL_PACKED, cpuKernelImpl: lessEqualImplCPU, dtype: 'bool' });
const lessEqualConfig = { kernelName: LessEqual, backendName: 'webgl', kernelFunc: lessEqual };
// LinSpace kernel: values are computed on the CPU (linSpaceImplCPU) and stored as a 1-D float32 tensor.
function linSpace(args) { const { backend, attrs } = args; const { start, stop, num } = attrs; const outVals = linSpaceImplCPU(start, stop, num); return backend.makeTensorInfo([outVals.length], 'float32', outVals); }
const linSpaceConfig = { kernelName: LinSpace, backendName: 'webgl', kernelFunc: linSpace };
// Log snippets. Negative inputs yield 0./0. (NaN). The unpacked form is prefixed with
// CHECK_NAN_SNIPPET_UNARY (defined elsewhere); the packed form passes NaN lanes through explicitly.
const LOG = CHECK_NAN_SNIPPET_UNARY + ` return x < 0.0 ? 0./0. : log(x); `;
const LOG_PACKED = ` vec4 result = log(x); bvec4 isNaN = isnan(x); result.r = isNaN.r ? x.r : (x.r < 0.0 ? 0./0. : result.r); result.g = isNaN.g ? x.g : (x.g < 0.0 ? 0./0. : result.g); result.b = isNaN.b ? x.b : (x.b < 0.0 ? 0./0. : result.b); result.a = isNaN.a ? x.a : (x.a < 0.0 ? 0./0. 
: result.a); return result; `;
// Log kernel: packed/unpacked GLSL snippets, logImplCPU on the CPU path.
const log = unaryKernelFunc({ opSnippet: LOG, packedOpSnippet: LOG_PACKED, cpuKernelImpl: logImplCPU });
const logConfig = { kernelName: Log, backendName: 'webgl', kernelFunc: log };
// Log1p kernel: log(1.0 + x) behind the CHECK_NAN_SNIPPET_UNARY prefix; unpacked snippet only.
const LOG1P = CHECK_NAN_SNIPPET_UNARY + ` return log(1.0 + x); `;
const log1p = unaryKernelFunc({ opSnippet: LOG1P });
const log1pConfig = { kernelName: Log1p, backendName: 'webgl', kernelFunc: log1p, };
// Logical kernels: an input counts as true when it is >= 1.0, and results are 1.0 / 0.0.
// The packed AND multiplies the per-lane masks; the packed OR adds them and clamps to 1.0.
const LOGICAL_AND = `return float(a >= 1.0 && b >= 1.0);`;
const LOGICAL_AND_PACKED = ` return vec4( vec4(greaterThanEqual(a, vec4(1.0))) * vec4(greaterThanEqual(b, vec4(1.0)))); `;
const logicalAnd = binaryKernelFunc({ opSnippet: LOGICAL_AND, packedOpSnippet: LOGICAL_AND_PACKED, dtype: 'bool' });
const logicalAndConfig = { kernelName: LogicalAnd, backendName: 'webgl', kernelFunc: logicalAnd };
// NOTE(review): unlike logicalAnd / logicalOr, no output dtype is passed for logicalNot.
const LOGICAL_NOT = `return float(!(x >= 1.0));`;
const logicalNot = unaryKernelFunc({ opSnippet: LOGICAL_NOT });
const logicalNotConfig = { kernelName: LogicalNot, backendName: 'webgl', kernelFunc: logicalNot, };
const LOGICAL_OR = `return float(a >= 1.0 || b >= 1.0);`;
const LOGICAL_OR_PACKED = ` return min( vec4(greaterThanEqual(a, vec4(1.0))) + vec4(greaterThanEqual(b, vec4(1.0))), vec4(1.0)); `;
const logicalOr = binaryKernelFunc({ opSnippet: LOGICAL_OR, packedOpSnippet: LOGICAL_OR_PACKED, dtype: 'bool' });
const logicalOrConfig = { kernelName: LogicalOr, backendName: 'webgl', kernelFunc: logicalOr };
// Unpacked local-response-normalization shader over the last axis of a rank-4 input (d = coords[3]).
// Computation: out = x * (bias + alpha * sum)^(-beta), where sum adds x^2 over channels d-radius..d+radius
// clipped to [0, xShape[3] - 1]. beta 0.5 and 1.0 use inversesqrt / reciprocal; other values use
// exp(log(.) * -beta).
// NOTE(review): the general branch ends with ';', leaving a harmless empty GLSL statement. A negative beta
// would render as float(--<beta>), which looks like invalid GLSL; confirm whether negative beta can reach here.
class LRNProgram { constructor(xShape, radius, bias, alpha, beta) { this.variableNames = ['x']; this.outputShape = []; const rad = radius; const maxD = xShape[3] - 1; this.outputShape = xShape; let powOperator; const basis = `float(${bias}) + float(${alpha}) * sum`; if (beta === 0.5) { powOperator = `inversesqrt(${basis})`; } else if (beta === 1.0) { powOperator = `1.0/(${basis})`; } else { powOperator = `exp(log(${basis}) * float(-${beta}));`; } this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int r = coords[1]; int c = 
coords[2]; int d = coords[3]; float x = getX(b, r, c, d); float sum = 0.0; for (int j = -${rad}; j <= ${rad}; j++) { int idx = d + j; if (idx >= 0 && idx <= ${maxD}) { float z = getX(b, r, c, idx); sum += z * z; } } float val = x * ${powOperator}; setOutput(val); } `; } }
// Packed LRN shader with the same formula as LRNProgram. Each invocation produces a 2x2 block over
// (c, d) = rows c, c+1 and channels d, d+1, zero-filled past the output edges. Values read for channel
// idx.y are kept in `cache` and reused as channel idx.x on the next iteration, so each step fetches one
// new fragment.
class LRNPackedProgram { constructor(xShape, radius, bias, alpha, beta) { this.variableNames = ['x']; this.outputShape = []; this.packedInputs = true; this.packedOutput = true; const rad = radius; const maxD = xShape[3] - 1; this.outputShape = xShape; let powOperator; const basis = `float(${bias}) + float(${alpha}) * sum`; if (beta === 0.5) { powOperator = `inversesqrt(${basis})`; } else if (beta === 1.0) { powOperator = `1.0/(${basis})`; } else { powOperator = `exp(log(${basis}) * float(-${beta}));`; } this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int b = coords.x; int r = coords.y; int c = coords.z; int d = coords.w; bool hasNextCol = d < ${this.outputShape[3]}; bool hasNextRow = c < ${this.outputShape[2]}; vec4 sum = vec4(0.); vec4 xFragAtOutputCoords = getX(b, r, c, d); vec4 xAtOutputCoords = vec4( getChannel(xFragAtOutputCoords, vec2(c, d)), hasNextCol ? getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0, hasNextRow ? getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0, (hasNextRow && hasNextCol) ? 
getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0 ); int firstChannel = d - ${rad}; vec2 cache = vec2(0.); if(firstChannel >= 0){ vec4 firstChannelFrag = getX(b, r, c, firstChannel); cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel)); if(hasNextRow){ cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel)); } } ivec2 depth = ivec2(d, d + 1); for (int j = - ${rad}; j <= ${rad}; j++) { ivec2 idx = depth + j; bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0)); bvec2 belowUpperBound = lessThanEqual(idx, ivec2(${maxD})); bool depthInRange = aboveLowerBound.x && belowUpperBound.x; bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y; if(depthInRange || depthPlusOneInRange){ vec4 z = vec4(0.); vec4 xFragAtCurrentDepth; z.xz = cache.xy; if(depthPlusOneInRange && hasNextCol){ xFragAtCurrentDepth = idx.y != d ? getX(b, r, c, idx.y) : xFragAtOutputCoords; z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y)); if(hasNextRow){ z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y)); } } cache.xy = z.yw; sum += z * z; } } vec4 result = xAtOutputCoords * ${powOperator}; setOutput(result); } `; } }
// LRN kernel: uses LRNPackedProgram when WEBGL_PACK_NORMALIZATION is set, else LRNProgram; output dtype = x.dtype.
const lrn = (args) => { const { inputs, backend, attrs } = args; const { x } = inputs; const { depthRadius, bias, alpha, beta } = attrs; const program = env().getBool('WEBGL_PACK_NORMALIZATION') ? 
new LRNPackedProgram(x.shape, depthRadius, bias, alpha, beta) : new LRNProgram(x.shape, depthRadius, bias, alpha, beta); return backend.runWebGLProgram(program, [x], x.dtype); };
const LRNConfig = { kernelName: LRN, backendName: 'webgl', kernelFunc: lrn };
// LRN gradient shader for output position (b, r, c, coords[3]). For every channel d it:
//  - forms norm = alpha * sum(inputImage^2 over [max(0, d - depthRadius), min(depth, d + depthRadius + 1))) + bias;
//  - for k in that window, computes dyi = -2 * alpha * beta * inputImage[k] * outputImage[d] / norm,
//    plus norm^(-beta) when k == d;
//  - accumulates dyi * dy[d] only for k == coords[3].
// The inner loops run over the constant range [0, depth) and skip/break outside the window (presumably
// because GLSL ES requires constant loop bounds).
class LRNGradProgram { constructor(inputShape, depthRadius, bias, alpha, beta) { this.variableNames = ['inputImage', 'outputImage', 'dy']; this.outputShape = []; this.outputShape = inputShape; this.depth = inputShape[3]; this.depthRadius = depthRadius; this.bias = bias; this.alpha = alpha; this.beta = beta; this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int r = coords[1]; int c = coords[2]; float result = 0.0; for (int d = 0; d < ${this.depth}; ++d) { int depthBegin = int(max(0.0, float(d - ${depthRadius}))); int depthEnd = int(min(float(${this.depth}), float(d + ${depthRadius} + 1))); const int MIN_DEPTH_BEGIN = 0; const int MAX_DEPTH_END = ${this.depth}; float norm = 0.0; for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) { if (k < depthBegin){ continue; } else if (k >= depthBegin && k < depthEnd) { norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k); } else { break; } } norm = float(${alpha}) * norm + float(${bias}); for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){ if (k < depthBegin){ continue; } else if (k >= depthBegin && k < depthEnd){ float dyi = -2.0 * float(${alpha}) * float(${beta}) * getInputImage(b, r, c, k) * getOutputImage(b, r, c, d) / norm; if (k == d) { dyi += pow(norm, -1.0 * ${beta}); } if (k == coords[3]) { dyi *= getDy(b, r, c, d); result += dyi; } } else { break; } } } setOutput(result); } `; } }
// LRNGrad kernel: x is the LRN input, y its output and dy the incoming gradient; output dtype = x.dtype.
const lrnGrad = (args) => { const { inputs, backend, attrs } = args; const { x, y, dy } = inputs; const { depthRadius, bias, alpha, beta } = attrs; const program = new LRNGradProgram(x.shape, depthRadius, bias, alpha, beta); return backend.runWebGLProgram(program, [x, y, dy], x.dtype); };
const LRNGradConfig = { kernelName: LRNGrad, 
backendName: 'webgl', kernelFunc: lrnGrad };
// GPU max reduction over the innermost dims described by reduceShape. x is viewed as
// [batchSize, inSize] with inSize = size(reduceShape), reduced with 'max' (keeping x.dtype), and reshaped
// to outShape. The two intermediates are disposed.
function maxImpl(x, reduceShape, outShape, backend) { const inSize = sizeFromShape(reduceShape); const xSize = sizeFromShape(x.shape); const batchSize = xSize / inSize; const reshapedInput = reshape({ inputs: { x }, attrs: { shape: [batchSize, inSize] }, backend }); const reduced = reduce(reshapedInput, x.dtype, 'max', backend); const reshapedOutput = reshape({ inputs: { x: reduced }, attrs: { shape: outShape }, backend }); backend.disposeIntermediateTensorInfo(reshapedInput); backend.disposeIntermediateTensorInfo(reduced); return reshapedOutput; }
// Max kernel: reduces x with max over reductionIndices.
//  - If the reduction axes are not innermost, x is first transposed so they are: transposeImplCPU when
//    running on the CPU, transposeImpl otherwise.
//  - The reduction runs in maxImplCPU (CPU path) or maxImpl (GPU path).
//  - keepDims re-inserts the reduced axes (as size-1 dims) using the original axes.
//  - The transposed intermediate, if any, is disposed.
function max(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { reductionIndices, keepDims } = attrs; const xRank = x.shape.length; const origAxes = parseAxisParam(reductionIndices, x.shape); let axes = origAxes; const permutedAxes = getAxesPermutation(axes, xRank); const maxInputIsTransposed = permutedAxes != null; const shouldExecuteOnCPU = backend.shouldExecuteOnCPU([x]); let maxInput = x; if (maxInputIsTransposed) { if (shouldExecuteOnCPU) { const xTexData = backend.texData.get(maxInput.dataId); const values = xTexData.values; const newShape = new Array(xRank); for (let i = 0; i < newShape.length; i++) { newShape[i] = x.shape[permutedAxes[i]]; } const maxInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape); maxInput = backend.makeTensorInfo(newShape, x.dtype); const maxInputData = backend.texData.get(maxInput.dataId); maxInputData.values = maxInputValues; } else { maxInput = transposeImpl(x, permutedAxes, backend); } axes = getInnerMostAxes(axes.length, xRank); } assertAxesAreInnerMostDims('max', axes, xRank); const [maxOutShape, reduceShape] = computeOutAndReduceShapes(maxInput.shape, axes); let outShape = maxOutShape; if (keepDims) { outShape = expandShapeToKeepDim(maxOutShape, origAxes); } let out; if (shouldExecuteOnCPU) { const xTexData = backend.texData.get(maxInput.dataId); const values = xTexData.values; const 
outValues = maxImplCPU(values, sizeFromShape(reduceShape), outShape, x.dtype); out = backend.makeTensorInfo(outShape, x.dtype); const outData = backend.texData.get(out.dataId); outData.values = outValues; } else { out = maxImpl(maxInput, reduceShape, outShape, backend); } if (maxInputIsTransposed) { backend.disposeIntermediateTensorInfo(maxInput); } return out; }
const maxConfig = { kernelName: Max, backendName: 'webgl', kernelFunc: max };
// Maximum kernel. The unpacked snippet is prefixed with CHECK_NAN_SNIPPET. The packed snippet computes a
// per-lane isNaN = isnan(a) || isnan(b) before CHECK_NAN_SNIPPET_PACKED (defined elsewhere; presumably it
// writes NaN into those lanes - confirm).
const MAXIMUM = CHECK_NAN_SNIPPET + ` return max(a, b); `;
const MAXIMUM_PACKED = ` vec4 result = vec4(max(a, b)); bvec4 isNaNA = isnan(a); bvec4 isNaNB = isnan(b); bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w); ` + CHECK_NAN_SNIPPET_PACKED + ` return result; `;
const maximum = binaryKernelFunc({ opSnippet: MAXIMUM, packedOpSnippet: MAXIMUM_PACKED, cpuKernelImpl: maximumImplCPU });
const maximumConfig = { kernelName: Maximum, backendName: 'webgl', kernelFunc: maximum };
// MaxPool kernel (2-D). Complex inputs are rejected and the dilation is fixed to 1. A 1x1 filter whose
// output shape equals the input shape short-circuits to identity; otherwise Pool2DProgram('max') runs.
function maxPool(args) { const { inputs, backend, attrs } = args; const { x } = inputs; assertNotComplex(x, 'maxPool'); const { filterSize, strides, pad, dimRoundingMode } = attrs; const dilations = 1; assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. 
' + `Got strides ${strides} and dilations '${dilations}'`); const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode); if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && arraysEqual(convInfo.inShape, convInfo.outShape)) { return identity({ inputs: { x }, backend }); } const maxPoolProgram = new Pool2DProgram(convInfo, 'max', false); return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype); }
const maxPoolConfig = { kernelName: MaxPool, backendName: 'webgl', kernelFunc: maxPool };
// MaxPool3D kernel: dilations fixed to [1, 1, 1]; runs Pool3DProgram('max') and keeps x.dtype.
function maxPool3d(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { filterSize, strides, pad, dataFormat, dimRoundingMode } = attrs; const dilations = [1, 1, 1]; const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat); const maxPoolProgram = new Pool3DProgram(convInfo, 'max', false); return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype); }
const maxPool3DConfig = { kernelName: MaxPool3D, backendName: 'webgl', kernelFunc: maxPool3d };
// 2-D max-pool backprop shader; the output has convInfo.inShape. For each input position it walks the
// output windows that could cover it. A window contributes its dy value only when the recorded max position
// matches this offset. The recorded position is decoded as lastIndex - maxPos (presumably the positions
// program stores reversed offsets - confirm in Pool2DProgram).
// NOTE(review): rows step by dilationHeight but columns step by 1, so a width dilation other than 1 would
// be ignored. maxPoolGrad$1 computes convInfo with dilation 1.
class MaxPool2DBackpropProgram { constructor(convInfo) { this.variableNames = ['dy', 'maxPos']; this.outputShape = convInfo.inShape; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const dilationHeight = convInfo.dilationHeight; const effectiveFilterHeight = convInfo.effectiveFilterHeight; const effectiveFilterWidth = convInfo.effectiveFilterWidth; const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; const lastIndex = effectiveFilterHeight * effectiveFilterWidth - 1; this.userCode = ` const ivec2 pads = ivec2(${padTop}, ${padLeft}); void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int d = coords[3]; ivec2 dyRCCorner = coords.yz - pads; int dyRCorner = dyRCCorner.x; int dyCCorner = dyRCCorner.y; float dotProd = 0.0; for (int wR = 0; wR < ${effectiveFilterHeight}; wR += 
${dilationHeight}) { float dyR = float(dyRCorner + wR) / ${strideHeight}.0; if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) { continue; } int idyR = int(dyR); for (int wC = 0; wC < ${effectiveFilterWidth}; wC++) { float dyC = float(dyCCorner + wC) / ${strideWidth}.0; if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 || fract(dyC) > 0.0) { continue; } int idyC = int(dyC); float dyValue = getDy(b, idyR, idyC, d); int maxPosValue = ${lastIndex} - int(getMaxPos(b, idyR, idyC, d)); int curPosValue = wR * ${effectiveFilterWidth} + wC; float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0); dotProd += dyValue * mask; } } setOutput(dotProd); } `; } }
// 3-D max-pool backprop shader: the same scheme as the 2-D version over (depth, row, col) with a rank-5
// output (batch = coords.x, channel = coords.u). Here all three loops step by their dilation.
class MaxPool3DBackpropProgram { constructor(convInfo) { this.variableNames = ['dy', 'maxPos']; this.outputShape = convInfo.inShape; const strideDepth = convInfo.strideDepth; const strideHeight = convInfo.strideHeight; const strideWidth = convInfo.strideWidth; const dilationDepth = convInfo.dilationDepth; const dilationHeight = convInfo.dilationHeight; const dilationWidth = convInfo.dilationWidth; const effectiveFilterDepth = convInfo.effectiveFilterDepth; const effectiveFilterHeight = convInfo.effectiveFilterHeight; const effectiveFilterWidth = convInfo.effectiveFilterWidth; const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front; const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; const lastIndex = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1; this.userCode = ` const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft}); void main() { ivec5 coords = getOutputCoords(); int batch = coords.x; int ch = coords.u; ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads; int dyDCorner = dyCorner.x; int dyRCorner = dyCorner.y; int dyCCorner = dyCorner.z; float dotProd = 0.0; for (int wD = 0; wD < ${effectiveFilterDepth}; wD += ${dilationDepth}) { float dyD = float(dyDCorner + wD) / 
${strideDepth}.0; if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) { continue; } int idyD = int(dyD); for (int wR = 0; wR < ${effectiveFilterHeight}; wR += ${dilationHeight}) { float dyR = float(dyRCorner + wR) / ${strideHeight}.0; if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) { continue; } int idyR = int(dyR); for (int wC = 0; wC < ${effectiveFilterWidth}; wC += ${dilationWidth}) { float dyC = float(dyCCorner + wC) / ${strideWidth}.0; if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 || fract(dyC) > 0.0) { continue; } int idyC = int(dyC); float dyValue = getDy(batch, idyD, idyR, idyC, ch); int maxPosValue = ${lastIndex} - int(getMaxPos(batch, idyD, idyR, idyC, ch)); int curPosValue = wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} + wR * ${effectiveFilterWidth} + wC; float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0); dotProd += dyValue * mask; } } } setOutput(dotProd); } `; } }
// MaxPool3DGrad kernel: recomputes the max positions with Pool3DProgram(getPositions = true), runs
// MaxPool3DBackpropProgram on (dy, positions), and disposes the positions tensor.
function maxPool3DGrad(args) { const { inputs, backend, attrs } = args; const { dy, input } = inputs; const x = input; const { filterSize, strides, pad, dimRoundingMode } = attrs; const dilations = [1, 1, 1]; const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode); const maxPool3dPositionsProgram = new Pool3DProgram(convInfo, 'max', true ); const maxPool3dPositions = backend.runWebGLProgram(maxPool3dPositionsProgram, [x], x.dtype); const maxPoolBackpropProgram = new MaxPool3DBackpropProgram(convInfo); const result = backend.runWebGLProgram(maxPoolBackpropProgram, [dy, maxPool3dPositions], x.dtype); backend.disposeIntermediateTensorInfo(maxPool3dPositions); return result; }
const maxPool3DGradConfig$1 = { kernelName: MaxPool3DGrad, backendName: 'webgl', kernelFunc: maxPool3DGrad };
// MaxPoolGrad kernel (2-D, dilation 1). Complex inputs are rejected. It recomputes the max positions with
// Pool2DProgram(getPositions = true), runs MaxPool2DBackpropProgram, and disposes the positions tensor.
function maxPoolGrad$1(args) { const { inputs, backend, attrs } = args; const { dy, input, output } = inputs; const x = input; assertNotComplex([input, output], 'maxPoolGrad'); const { filterSize, strides, pad, 
dimRoundingMode } = attrs; const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 , pad, dimRoundingMode); const getPositions = true; const maxPoolPositionsProgram = new Pool2DProgram(convInfo, 'max', getPositions); const maxPoolPositions = backend.runWebGLProgram(maxPoolPositionsProgram, [x], x.dtype); const maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo); const result = backend.runWebGLProgram(maxPoolBackPropProgram, [dy, maxPoolPositions], x.dtype); backend.disposeIntermediateTensorInfo(maxPoolPositions); return result; }
const maxPoolGradConfig$1 = { kernelName: MaxPoolGrad, backendName: 'webgl', kernelFunc: maxPoolGrad$1 };
// Runs Pool2DProgram twice, both with float32 output. The first run produces the pooled max values. The
// second passes (true, true, includeBatchInIndex), presumably requesting flattened argmax positions
// (confirm in Pool2DProgram). Returns [values, indices].
function maxPoolWithArgmaxImpl(x, includeBatchInIndex, convInfo, backend) { let program = new Pool2DProgram(convInfo, 'max', false); const poolOutput = backend.runWebGLProgram(program, [x], 'float32'); program = new Pool2DProgram(convInfo, 'max', true, true, includeBatchInIndex); const indexOutput = backend.runWebGLProgram(program, [x], 'float32'); return [poolOutput, indexOutput]; }
// MaxPoolWithArgmax kernel: requires a rank-4 input and uses dilations [1, 1].
const maxPoolWithArgmaxConfig = { kernelName: MaxPoolWithArgmax, backendName: 'webgl', kernelFunc: ({ inputs, attrs, backend }) => { const { x } = inputs; const { filterSize, strides, pad, includeBatchInIndex } = attrs; const webglBackend = backend; assert$1(x.shape.length === 4, () => `Error in maxPool: input must be rank 4 but got rank ${x.shape.length}.`); const dilations = [1, 1]; assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. 
' + `Got strides ${strides} and dilations '${dilations}'`); const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad); const [result, indexes] = maxPoolWithArgmaxImpl(x, includeBatchInIndex, convInfo, webglBackend); return [result, indexes]; } }; function meanImpl(x, reduceShape, outShape, backend) { const inSize = sizeFromShape(reduceShape); const xSize = sizeFromShape(x.shape); const batchSize = xSize / inSize; const reshapedInput = reshape({ inputs: { x }, attrs: { shape: [batchSize, inSize] }, backend }); const reduced = reduce(reshapedInput, 'float32', 'mean', backend); const reshapedOutput = reshape({ inputs: { x: reduced }, attrs: { shape: outShape }, backend }); backend.disposeIntermediateTensorInfo(reshapedInput); backend.disposeIntermediateTensorInfo(reduced); return reshapedOutput; } const meanConfig = { kernelName: Mean, backendName: 'webgl', kernelFunc: ({ inputs, attrs, backend }) => { const { x } = inputs; const { keepDims, axis } = attrs; const webglBackend = backend; const xRank = x.shape.length; const origAxes = parseAxisParam(axis, x.shape); let axes = origAxes; const permutedAxes = getAxesPermutation(axes, xRank); const meanInputIsTransposed = permutedAxes != null; const shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]); const intermediates = []; let meanInput = x; if (meanInputIsTransposed) { if (shouldExecuteOnCPU) { const xTexData = webglBackend.texData.get(meanInput.dataId); const values = xTexData.values; const newShape = new Array(xRank); for (let i = 0; i < newShape.length; i++) { newShape[i] = x.shape[permutedAxes[i]]; } const meanInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape); meanInput = webglBackend.makeTensorInfo(newShape, x.dtype); const meanInputData = webglBackend.texData.get(meanInput.dataId); meanInputData.values = meanInputValues; } else { meanInput = transposeImpl(x, permutedAxes, webglBackend); } intermediates.push(meanInput); axes = 
getInnerMostAxes(axes.length, xRank); } assertAxesAreInnerMostDims('sum', axes, xRank); const [meanOutShape, reduceShape] = computeOutAndReduceShapes(meanInput.shape, axes); let outShape = meanOutShape; if (keepDims) { outShape = expandShapeToKeepDim(meanOutShape, origAxes); } const out = meanImpl(meanInput, reduceShape, outShape, webglBackend); for (const i of intermediates) { webglBackend.disposeIntermediateTensorInfo(i); } return out; } }; function min(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { axis, keepDims } = attrs; const xRank = x.shape.length; const origAxes = parseAxisParam(axis, x.shape); let axes = origAxes; const permutedAxes = getAxesPermutation(axes, xRank); let permutedX = x; if (permutedAxes != null) { permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } }); axes = getInnerMostAxes(axes.length, x.shape.length); } assertAxesAreInnerMostDims('min', axes, xRank); const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes); const inSize = sizeFromShape(reduceShape); const a2D = reshape({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } }); const reduced = reduce(a2D, a2D.dtype, 'min', backend); let res; if (keepDims) { const newShape = expandShapeToKeepDim(outShape, origAxes); res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: newShape } }); } else { res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: outShape } }); } backend.disposeIntermediateTensorInfo(a2D); backend.disposeIntermediateTensorInfo(reduced); if (permutedAxes != null) { backend.disposeIntermediateTensorInfo(permutedX); } return res; } const minConfig = { kernelName: Min, backendName: 'webgl', kernelFunc: min }; const MINIMUM = CHECK_NAN_SNIPPET + ` return min(a, b); `; const MINIMUM_PACKED = ` vec4 result = vec4(min(a, b)); bvec4 isNaNA = isnan(a); bvec4 isNaNB = isnan(b); bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, 
isNaNA.w || isNaNB.w); ` + CHECK_NAN_SNIPPET_PACKED + ` return result; `;

// Elementwise minimum kernel; minimumImplCPU is handed to binaryKernelFunc
// as the CPU implementation.
const minimum = binaryKernelFunc({ opSnippet: MINIMUM, packedOpSnippet: MINIMUM_PACKED, cpuKernelImpl: minimumImplCPU });
const minimumConfig = { kernelName: Minimum, backendName: 'webgl', kernelFunc: minimum };

/**
 * Unpacked WebGL program for MirrorPad.
 *
 * Output dim i has size paddings[i][0] + xShape[i] + paddings[i][1]. Output
 * coordinates in [start, end) read x directly; coordinates before `start` or
 * at/after `end` are reflected back into that range.
 *
 * @param {number[]} xShape - input shape (only coords[0..3] are listed, so
 *     rank <= 4 is assumed)
 * @param {Array<[number, number]>} paddings - [before, after] per dimension
 * @param {string} mode - 'reflect' mirrors without repeating the edge element;
 *     any other mode (offset 1) repeats the edge element
 */
class MirrorPadProgram {
  constructor(xShape, paddings, mode) {
    this.variableNames = ['x'];
    this.outputShape = paddings.map((p, i) => p[0] + xShape[i] + p[1] );
    const rank = xShape.length;
    const dtype = getCoordsDataType(rank);
    // GLSL literals bounding the region of the output that maps onto x.
    const start = paddings.map(p => p[0]).join(',');
    const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
    const unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
    const offset = mode === 'reflect' ? 0 : 1;
    if (rank === 1) {
      this.userCode = ` int start = ${start}; int end = ${end}; void main() { int outC = getOutputCoords(); if (outC < start) { outC = start * 2 - outC - ${offset}; } else if(outC >= end) { outC = (end - 1) * 2 - outC + ${offset}; } setOutput(getX(outC - start)); } `;
      return;
    }
    this.userCode = ` ${dtype} start = ${dtype}(${start}); ${dtype} end = ${dtype}(${end}); void main() { ${dtype} outC = getOutputCoords(); for (int i = 0; i < ${rank}; i++) { if (outC[i] < start[i]) { outC[i] = start[i] * 2 - outC[i] - ${offset}; } else if(outC[i] >= end[i]) { outC[i] = (end[i] - 1) * 2 - outC[i] + ${offset}; } } ${dtype} coords = outC - start; setOutput(getX(${unpackedCoords})); } `;
  }
}

/**
 * Packed WebGL program for MirrorPad (same reflection rule as
 * MirrorPadProgram). Each invocation fills one vec4 result: two texels along
 * the innermost dim for rank 1, otherwise a 2x2 patch over the two innermost
 * dims, each extra texel guarded by an output-bounds check.
 */
class MirrorPadPackedProgram {
  constructor(xShape, paddings, mode) {
    this.variableNames = ['x'];
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = paddings.map((p, i) => p[0] + xShape[i] + p[1] );
    const rank = xShape.length;
    const dtype = getCoordsDataType(rank);
    const start = paddings.map(p => p[0]).join(',');
    const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
    const coords = getChannels('rc', rank);
    const source = getChannels('source', rank);
    // Bounds check for the neighbouring texel in the innermost dim.
    const cLimit = `${coords[rank - 1]} < 
${this.outputShape[rank - 1]}`;
    const innerDims = rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`;
    const offset = mode === 'reflect' ? 0 : 1;
    let mainLoop = '';
    if (rank === 1) {
      // Scalar coordinate: reflect with explicit branches.
      const padSetup = ` ${dtype} source = rc; if (source < start) { source = start * 2 - source - ${offset}; } else if (source >= end) { source = (end - 1) * 2 - source + ${offset}; } source -= start; `;
      mainLoop = ` ${dtype} rc = outputLoc; ${padSetup} result[0] = getChannel(getX(${source.join()}), ${innerDims}); ${coords[rank - 1]} += 1; if(${cLimit}) { ${padSetup} result[1] = getChannel(getX(${source.join()}), ${innerDims}); } `;
    } else {
      // Vector coordinate: lt / gte are per-dimension 0/1 masks that select
      // the left-mirror, right-mirror or identity mapping without branching.
      const padSetup = ` ${dtype} source = rc; ${dtype} lt = ${dtype}(lessThan(source, start)); ${dtype} gte = ${dtype}(greaterThanEqual(source, end)); ${dtype} orig = 1 - (lt + gte); source = orig * source + lt * (start * 2 - source - ${offset}) + gte * ((end - 1) * 2 - source + ${offset}); source -= start; `;
      mainLoop = ` ${dtype} rc = outputLoc; ${padSetup} result[0] = getChannel(getX(${source.join()}), ${innerDims}); ${coords[rank - 1]} += 1; if(${cLimit}) { ${padSetup} result[1] = getChannel(getX(${source.join()}), ${innerDims}); } rc = outputLoc; ${coords[rank - 2]} += 1; if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) { ${padSetup} result[2] = getChannel(getX(${source.join()}), ${innerDims}); ${coords[rank - 1]} += 1; if(${cLimit}) { ${padSetup} result[3] = getChannel(getX(${source.join()}), ${innerDims}); } } `;
    }
    this.userCode = ` const ${dtype} start = ${dtype}(${start}); const ${dtype} end = ${dtype}(${end}); void main() { ${dtype} outputLoc = getOutputCoords(); vec4 result = vec4(0.); ${mainLoop} setOutput(result); } `;
  }
}

// MirrorPad kernel: uses the packed program when the
// WEBGL_PACK_ARRAY_OPERATIONS flag is on; output keeps x's dtype.
const mirrorPadKernelFunc = ({ inputs, backend, attrs }) => {
  const { x } = inputs;
  const { paddings, mode } = attrs;
  const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
    new MirrorPadPackedProgram(x.shape, paddings, mode) :
    new MirrorPadProgram(x.shape, paddings, mode);
  const output = backend.runWebGLProgram(program, [x], x.dtype);
  return output;
};
const mirrorPadConfig = { kernelName: MirrorPad, backendName: 'webgl', kernelFunc: mirrorPadKernelFunc, };

// Elementwise modulo; a zero divisor yields NaN (isNaN marks b == 0 lanes in
// the packed variant).
const MOD = `if (b == 0.0) return NAN; return mod(a, b);`;
const MOD_PACKED = ` vec4 result = mod(a, b); bvec4 isNaN = equal(b, vec4(0.0)); ` + CHECK_NAN_SNIPPET_PACKED + ` return result; `;
const mod = binaryKernelFunc({ opSnippet: MOD, packedOpSnippet: MOD_PACKED, });
const modConfig = { kernelName: Mod, backendName: 'webgl', kernelFunc: mod };

/**
 * Draws `numSamples` outcome indices per batch row by inverse-CDF sampling:
 * a value r = random(seed) is compared against the running sum of `probs`;
 * the first index whose cumulative probability exceeds r is emitted, falling
 * back to the last outcome. `seed` is supplied as a float custom uniform.
 */
class MultinomialProgram {
  constructor(batchSize, numOutcomes, numSamples) {
    this.variableNames = ['probs'];
    this.customUniforms = [{ name: 'seed', type: 'float' }];
    this.outputShape = [batchSize, numSamples];
    this.userCode = ` void main() { ivec2 coords = getOutputCoords(); int batch = coords[0]; float r = random(seed); float cdf = 0.0; for (int i = 0; i < ${numOutcomes - 1}; i++) { cdf += getProbs(batch, i); if (r < cdf) { setOutput(float(i)); return; } } setOutput(float(${numOutcomes - 1})); } `;
  }
}

// Real division; both snippets return exactly 1.0 wherever a == b.
const DIV = ` if (a == b) { return 1.0; }; return a / b;`;
const DIV_PACKED = ` vec4 result = a / b; if(a.x == b.x) { result.x = 1.; } if(a.y == b.y) { result.y = 1.; } if(a.z == b.z) { result.z = 1.; } if(a.w == b.w) { result.w = 1.; } return result; `;
const realDiv = binaryKernelFunc({ opSnippet: DIV, packedOpSnippet: DIV_PACKED, checkOutOfBounds: true });
const realDivConfig = { kernelName: RealDiv, backendName: 'webgl', kernelFunc: realDiv, };

// Elementwise subtraction (same snippet for packed and unpacked); declared as
// supporting complex inputs and backed by subImplCPU.
const SUB = 'return a - b;';
const sub = binaryKernelFunc({ opSnippet: SUB, packedOpSnippet: SUB, supportsComplex: true, cpuKernelImpl: subImplCPU });
const subConfig = { kernelName: Sub, backendName: 'webgl', kernelFunc: sub };

/**
 * Softmax kernel (WebGL) along `dim`.
 *
 * Computes exp(logits - max) / sum(exp(logits - max)). The max and the sum
 * are reshaped back to keep-dims shape so they broadcast against logits.
 * All intermediates are disposed.
 */
function softmax(args) {
  const { inputs, backend, attrs } = args;
  const { logits } = inputs;
  const { dim } = attrs;
  const axes = parseAxisParam([dim], logits.shape);
  // Subtracting the per-row max before exp keeps the exponentials bounded.
  const maxLogit = max({ inputs: { x: logits }, backend, attrs: { reductionIndices: axes, keepDims: false } });
  const expandedShape = expandShapeToKeepDim(maxLogit.shape, axes);
  const maxLogitsReshaped = reshape({ inputs: { x: maxLogit }, backend, attrs: { shape: expandedShape } });
  const a = sub({ inputs: { a: logits, b: maxLogitsReshaped }, backend });
  const b = exp({ inputs: { x: a }, backend });
  const sumExp = sum({ inputs: { x: b }, backend, attrs: { axis: axes, keepDims: false } });
  const sumExpReshaped = reshape({ inputs: { x: sumExp }, backend, attrs: { shape: expandedShape } });
  const res = realDiv({ inputs: { a: b, b: sumExpReshaped }, backend });
  backend.disposeIntermediateTensorInfo(maxLogit);
  backend.disposeIntermediateTensorInfo(maxLogitsReshaped);
  backend.disposeIntermediateTensorInfo(a);
  backend.disposeIntermediateTensorInfo(b);
  backend.disposeIntermediateTensorInfo(sumExp);
  backend.disposeIntermediateTensorInfo(sumExpReshaped);
  return res;
}
const softmaxConfig = { kernelName: Softmax$1, backendName: 'webgl', kernelFunc: softmax };

/**
 * Multinomial kernel (WebGL).
 *
 * Unless `normalized` is set, logits are first turned into probabilities with
 * softmax over the last dim (that intermediate is disposed afterwards). Reads
 * probs.shape[0] / probs.shape[1] as [batchSize, numOutcomes], so it assumes
 * 2D logits — TODO confirm callers always reshape. Returns an int32
 * [batchSize, numSamples] tensor.
 */
function multinomial(args) {
  const { inputs, backend, attrs } = args;
  const { logits } = inputs;
  const { numSamples, seed, normalized } = attrs;
  const probs = normalized ? logits : softmax({ inputs: { logits }, backend, attrs: { dim: logits.shape.length - 1 } });
  const batchSize = probs.shape[0];
  const numOutcomes = probs.shape[1];
  const program = new MultinomialProgram(batchSize, numOutcomes, numSamples);
  // Value for MultinomialProgram's `seed` custom uniform.
  const customValues = [[seed]];
  const res = backend.runWebGLProgram(program, [probs], 'int32', customValues);
  if (!normalized) {
    backend.disposeIntermediateTensorInfo(probs);
  }
  return res;
}
const multinomialConfig = { kernelName: Multinomial, backendName: 'webgl', kernelFunc: multinomial };

// Negation; the packed snippet copies NaN lanes of x through unchanged.
const NEG = CHECK_NAN_SNIPPET$1 + ` return -x; `;
const NEG_PACKED = ` vec4 result = -x; bvec4 isNaN = isnan(x); result.r = isNaN.r ? x.r : result.r; result.g = isNaN.g ? x.g : result.g; result.b = isNaN.b ? 
x.b : result.b; result.a = isNaN.a ? x.a : result.a; return result; `; function neg(args) { const { inputs, backend } = args; const { x } = inputs; if (backend.shouldExecuteOnCPU([x])) { const xData = backend.texData.get(x.dataId); const [outValues, newShape] = negImplCPU(xData.values, x.shape, x.dtype); return backend.makeTensorInfo(newShape, x.dtype, outValues); } let program; if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { program = new UnaryOpPackedProgram(x.shape, NEG_PACKED); } else { program = new UnaryOpProgram(x.shape, NEG); } return backend.runWebGLProgram(program, [x], x.dtype); } const negConfig = { kernelName: Neg, backendName: 'webgl', kernelFunc: neg }; const nonMaxSuppressionV3Impl = nonMaxSuppressionV3Impl$1; function nonMaxSuppressionV3(args) { warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' + 'Call tf.nonMaxSuppressionAsync() instead'); const { inputs, backend, attrs } = args; const { boxes, scores } = inputs; const { maxOutputSize, iouThreshold, scoreThreshold } = attrs; const boxesVals = backend.readSync(boxes.dataId); const scoresVals = backend.readSync(scores.dataId); const { selectedIndices } = nonMaxSuppressionV3Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); return backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)); } const nonMaxSuppressionV3Config = { kernelName: NonMaxSuppressionV3, backendName: 'webgl', kernelFunc: nonMaxSuppressionV3 }; const nonMaxSuppressionV4Impl = nonMaxSuppressionV4Impl$1; function nonMaxSuppressionV4(args) { warn('tf.nonMaxSuppression() in webgl locks the UI thread. 
' + 'Call tf.nonMaxSuppressionAsync() instead'); const { inputs, backend, attrs } = args; const { boxes, scores } = inputs; const { maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize } = attrs; const boxesVals = backend.readSync(boxes.dataId); const scoresVals = backend.readSync(scores.dataId); const { selectedIndices, validOutputs } = nonMaxSuppressionV4Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize); return [ backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)), backend.makeTensorInfo([], 'int32', new Int32Array([validOutputs])) ]; } const nonMaxSuppressionV4Config = { kernelName: NonMaxSuppressionV4, backendName: 'webgl', kernelFunc: nonMaxSuppressionV4 }; const nonMaxSuppressionV5Impl = nonMaxSuppressionV5Impl$1; function nonMaxSuppressionV5(args) { warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' + 'Call tf.nonMaxSuppressionAsync() instead'); const { inputs, backend, attrs } = args; const { boxes, scores } = inputs; const { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma } = attrs; const boxesVals = backend.readSync(boxes.dataId); const scoresVals = backend.readSync(scores.dataId); const maxOutputSizeVal = maxOutputSize; const iouThresholdVal = iouThreshold; const scoreThresholdVal = scoreThreshold; const softNmsSigmaVal = softNmsSigma; const { selectedIndices, selectedScores } = nonMaxSuppressionV5Impl(boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, scoreThresholdVal, softNmsSigmaVal); return [ backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)), backend.makeTensorInfo([selectedScores.length], 'float32', new Float32Array(selectedScores)) ]; } const nonMaxSuppressionV5Config = { kernelName: NonMaxSuppressionV5, backendName: 'webgl', kernelFunc: nonMaxSuppressionV5 }; class OneHotProgram { constructor(numIndices, depth, onValue, offValue) { this.variableNames = ['indices']; this.outputShape = 
[numIndices, depth]; this.userCode = ` void main() { ivec2 coords = getOutputCoords(); int index = round(getIndices(coords.x)); setOutput(mix(float(${offValue}), float(${onValue}), float(index == coords.y))); } `; } } const oneHot = (args) => { const { inputs, backend, attrs } = args; const { indices } = inputs; const { dtype, depth, onValue, offValue } = attrs; const indicesSize = sizeFromShape(indices.shape); const program = new OneHotProgram(indicesSize, depth, onValue, offValue); const reshaped = reshape({ inputs: { x: indices }, backend, attrs: { shape: [indicesSize] } }); const result = backend.runWebGLProgram(program, [reshaped], dtype); backend.disposeIntermediateTensorInfo(reshaped); const outShape = [...indices.shape, depth]; const out = reshape({ inputs: { x: result }, backend, attrs: { shape: outShape } }); backend.disposeIntermediateTensorInfo(result); return out; }; const oneHotConfig = { kernelName: OneHot, backendName: 'webgl', kernelFunc: oneHot }; function zerosLike(args) { const { inputs, backend } = args; const { x } = inputs; if (x.dtype === 'complex64') { const realPart = real({ inputs: { input: x }, backend }); const r = zerosLike({ inputs: { x: realPart }, backend }); const imagPart = imag({ inputs: { input: x }, backend }); const i = zerosLike({ inputs: { x: imagPart }, backend }); const result = complex({ inputs: { real: r, imag: i }, backend }); backend.disposeIntermediateTensorInfo(realPart); backend.disposeIntermediateTensorInfo(r); backend.disposeIntermediateTensorInfo(imagPart); backend.disposeIntermediateTensorInfo(i); return result; } else { return fill({ attrs: { shape: x.shape, dtype: x.dtype, value: x.dtype === 'string' ? 
'' : 0 }, backend }); } } const zerosLikeConfig = { kernelName: ZerosLike, backendName: 'webgl', kernelFunc: zerosLike }; function onesLike(args) { const { inputs, backend } = args; const { x } = inputs; if (x.dtype === 'string') { throw new Error('onesLike is not supported under string dtype'); } else if (x.dtype === 'complex64') { const realPart = real({ inputs: { input: x }, backend }); const r = onesLike({ inputs: { x: realPart }, backend }); const imagPart = imag({ inputs: { input: x }, backend }); const i = zerosLike({ inputs: { x: imagPart }, backend }); const result = complex({ inputs: { real: r, imag: i }, backend }); backend.disposeIntermediateTensorInfo(realPart); backend.disposeIntermediateTensorInfo(r); backend.disposeIntermediateTensorInfo(imagPart); backend.disposeIntermediateTensorInfo(i); return result; } else { return fill({ attrs: { shape: x.shape, dtype: x.dtype, value: 1 }, backend }); } } const onesLikeConfig = { kernelName: OnesLike, backendName: 'webgl', kernelFunc: onesLike }; function pack(args) { const { inputs, backend, attrs } = args; const { axis } = attrs; if (inputs.length === 1) { return expandDims$1({ inputs: { input: inputs[0] }, backend, attrs: { dim: axis } }); } const shape = inputs[0].shape; const dtype = inputs[0].dtype; inputs.forEach(t => { assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes'); assert$1(dtype === t.dtype, () => 'All tensors passed to stack must have matching dtypes'); }); const intermediateTensorInfos = []; const expandedTensors = inputs.map(t => { const expandedT = expandDims$1({ inputs: { input: t }, backend, attrs: { dim: axis } }); intermediateTensorInfos.push(expandedT); return expandedT; }); const result = concat({ inputs: expandedTensors, backend, attrs: { axis } }); intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t)); return result; } const packConfig = { kernelName: Pack, backendName: 'webgl', kernelFunc: pack }; class PadProgram { 
// Unpacked PadV2 shader. Output dim i has size
// paddings[i][0] + xShape[i] + paddings[i][1]; coordinates outside
// [start, end) receive the `value` custom uniform, others read x.
// `constantValue` is not used here — the fill value arrives through the
// uniform (padV2 passes customValues = [[constantValue]]).
constructor(xShape, paddings, constantValue) {
  this.variableNames = ['x'];
  this.customUniforms = [{ name: 'value', type: 'float' }];
  this.outputShape = paddings.map((p, i) => p[0] + xShape[i] + p[1] );
  const rank = xShape.length;
  const type = getCoordsDataType(rank);
  const start = paddings.map(p => p[0]).join(',');
  const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
  // Only coords[0..3] are listed, so rank <= 4 is assumed.
  const unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
  if (rank === 1) {
    this.userCode = ` int start = ${start}; int end = ${end}; void main() { int outC = getOutputCoords(); if (outC < start || outC >= end) { setOutput(value); } else { setOutput(getX(outC - start)); } } `;
    return;
  }
  this.userCode = ` ${type} start = ${type}(${start}); ${type} end = ${type}(${end}); void main() { ${type} outC = getOutputCoords(); if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) { setOutput(value); } else { ${type} coords = outC - start; setOutput(getX(${unpackedCoords})); } } `;
}
}

/**
 * Packed PadV2 shader. Each invocation fills one vec4: two texels for rank 1,
 * otherwise a 2x2 patch over the two innermost dims.
 */
class PadPackedProgram {
  constructor(xShape, paddings, constantValue) {
    this.variableNames = ['x'];
    this.packedInputs = true;
    this.packedOutput = true;
    this.customUniforms = [{ name: 'value', type: 'float' }];
    this.outputShape = paddings.map((p, i) => p[0] + xShape[i] + p[1] );
    const rank = xShape.length;
    const dtype = getCoordsDataType(rank);
    const start = paddings.map(p => p[0]).join(',');
    const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
    const coords = getChannels('rc', rank);
    const source = getChannels('source', rank);
    const cLimit = `${coords[rank - 1]} < ${this.outputShape[rank - 1]}`;
    const innerDims = rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`;
    // GLSL prologue emitted before component i. Entries 1-3 open bounds-check
    // `if` scopes (entry 2 also closes entry 1's); the `}` / `}}` appended
    // after the loop closes whatever remains open.
    const componentSetup = [
      `${dtype} rc = outputLoc;`,
      `${coords[rank - 1]} += 1; if(${cLimit}) { `,
      rank === 1 ? '' : `} rc = outputLoc; ${coords[rank - 2]} += 1; if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) {`,
      rank === 1 ? '' : ` ${coords[rank - 1]} += 1; if(${cLimit}) {`
    ];
    const paddingArea = rank === 1 ? 'rc < start || rc >= end' : 'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))';
    let mainLoop = '';
    for (let i = 0, j = rank === 1 ? 2 : 4; i < j; i++) {
      mainLoop += ` ${componentSetup[i]} if (${paddingArea}) { result[${i}] = float(value); } else { ${dtype} source = rc - start; result[${i}] = getChannel(getX(${source.join()}), ${innerDims}); } `;
    }
    mainLoop += (rank === 1 ? `} ` : `}}`);
    this.userCode = ` const ${dtype} start = ${dtype}(${start}); const ${dtype} end = ${dtype}(${end}); void main() { ${dtype} outputLoc = getOutputCoords(); vec4 result = vec4(0.); ${mainLoop} setOutput(result); } `;
  }
}

/**
 * PadV2 kernel (WebGL). An empty input short-circuits to a `fill` of the
 * padded shape with `constantValue`; otherwise the packed or unpacked pad
 * shader runs (WEBGL_PACK_ARRAY_OPERATIONS), receiving `constantValue` as its
 * `value` uniform.
 */
const padV2 = (args) => {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const { paddings, constantValue } = attrs;
  if (sizeFromShape(x.shape) === 0) {
    const outputShape = paddings.map((p, i) => p[0] + x.shape[i] + p[1] );
    return fill({ backend, attrs: { shape: outputShape, value: constantValue, dtype: x.dtype } });
  }
  const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
    new PadPackedProgram(x.shape, paddings, constantValue) :
    new PadProgram(x.shape, paddings, constantValue);
  const customValues = [[constantValue]];
  return backend.runWebGLProgram(program, [x], x.dtype, customValues);
};
const padV2Config = { kernelName: PadV2, backendName: 'webgl', kernelFunc: padV2 };

// Elementwise power: NaN for a negative base with a non-integral exponent,
// 1 for a zero exponent, and the base's sign kept when the exponent rounds to
// an odd value (mod 2 == 1).
const POW = ` if(a < 0.0 && floor(b) < b){ return NAN; } if (b == 0.0) { return 1.0; } return (round(mod(b, 2.0)) != 1) ? pow(abs(a), b) : sign(a) * pow(abs(a), b); `;
const POW_PACKED = ` vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1))); vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1); vec4 result = multiplier * pow(abs(a), b); bvec4 isExpZero = equal(b, vec4(0.0)); result.r = isExpZero.r ? 1.0 : result.r; result.g = isExpZero.g ? 1.0 : result.g; result.b = isExpZero.b ? 1.0 : result.b; result.a = isExpZero.a ? 
1.0 : result.a; bvec4 isNaN1 = lessThan(a, vec4(0.0)); bvec4 isNaN2 = lessThan(floor(b), b); bvec4 isNaN = bvec4(isNaN1.x && isNaN2.x, isNaN1.y && isNaN2.y, isNaN1.z && isNaN2.z, isNaN1.w && isNaN2.w); ` + CHECK_NAN_SNIPPET_PACKED + ` return result; `; const pow = binaryKernelFunc({ opSnippet: POW, packedOpSnippet: POW_PACKED }); const powConfig = { kernelName: Pow, backendName: 'webgl', kernelFunc: pow }; function prod(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { axis, keepDims } = attrs; const xRank = x.shape.length; const toDispose = []; const origAxes = parseAxisParam(axis, x.shape); let axes = origAxes; const permutedAxes = getAxesPermutation(axes, xRank); let permutedX = x; if (permutedAxes != null) { permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } }); axes = getInnerMostAxes(axes.length, xRank); toDispose.push(permutedX); } assertAxesAreInnerMostDims('prod', axes, xRank); let res; if (backend.shouldExecuteOnCPU([permutedX])) { const xVals = backend.texData.get(permutedX.dataId).values; const { outVals, outShape, outDtype } = prodImplCPU(permutedX.shape, permutedX.dtype, xVals, axes); res = backend.makeTensorInfo(outShape, outDtype, outVals); } else { const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes); const inSize = sizeFromShape(reduceShape); const a2D = reshape({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } }); const outputDType = sumOutType(x.dtype); const reduced = reduce(a2D, outputDType, 'prod', backend); res = reshape({ inputs: { x: reduced }, backend, attrs: { shape: outShape } }); toDispose.push(a2D); toDispose.push(reduced); } if (keepDims) { toDispose.push(res); const newShape = expandShapeToKeepDim(res.shape, origAxes); res = reshape({ inputs: { x: res }, backend, attrs: { shape: newShape } }); } toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t)); return res; } const prodConfig = { kernelName: Prod, backendName: 
'webgl', kernelFunc: prod }; function raggedGather(args) { const { inputs, backend, attrs } = args; const { paramsNestedSplits, paramsDenseValues, indices } = inputs; const { outputRaggedRank } = attrs; const $paramsNestedSplits = paramsNestedSplits.map(t => backend.readSync(t.dataId)); const $paramsNestedSplitsShapes = paramsNestedSplits.map(t => t.shape); const $paramsDenseValues = backend.readSync(paramsDenseValues.dataId); const $indices = backend.readSync(indices.dataId); const [outputNestedSplits, outputDenseValues, outputDenseValuesShape] = raggedGatherImplCPU($paramsNestedSplits, $paramsNestedSplitsShapes, $paramsDenseValues, paramsDenseValues.shape, paramsDenseValues.dtype, $indices, indices.shape, outputRaggedRank); const outputNestedSplitsTensors = outputNestedSplits.map((splits) => backend.makeTensorInfo([splits.length], 'int32', splits)); const outputDenseValuesTensor = backend.makeTensorInfo(outputDenseValuesShape, paramsDenseValues.dtype, outputDenseValues); return outputNestedSplitsTensors.concat([outputDenseValuesTensor]); } const raggedGatherConfig = { kernelName: RaggedGather, backendName: 'webgl', kernelFunc: raggedGather, }; function raggedRange(args) { const { inputs, backend } = args; const { starts, limits, deltas } = inputs; const $starts = backend.readSync(starts.dataId); const $limits = backend.readSync(limits.dataId); const $deltas = backend.readSync(deltas.dataId); const [rtNestedSplitsData, rtDenseValuesData] = raggedRangeImplCPU($starts, starts.shape, starts.dtype, $limits, limits.shape, $deltas, deltas.shape); const rtNestedSplits = backend.makeTensorInfo([rtNestedSplitsData.length], 'int32', rtNestedSplitsData); const rtDenseValues = backend.makeTensorInfo([rtDenseValuesData.length], starts.dtype, rtDenseValuesData); return [rtNestedSplits, rtDenseValues]; } const raggedRangeConfig = { kernelName: RaggedRange, backendName: 'webgl', kernelFunc: raggedRange, }; function raggedTensorToTensor(args) { const { inputs, backend, attrs } = 
args; const { shape, values, defaultValue, rowPartitionTensors } = inputs; const { rowPartitionTypes } = attrs; const $shape = backend.readSync(shape.dataId); const $values = backend.readSync(values.dataId); const $defaultValue = backend.readSync(defaultValue.dataId); const $rowPartitionValues = rowPartitionTensors.map(t => backend.readSync(t.dataId)); const rowPartitionValuesShapes = rowPartitionTensors.map(t => t.shape); const [outputShape, output] = raggedTensorToTensorImplCPU($shape, shape.shape, $values, values.shape, values.dtype, $defaultValue, defaultValue.shape, $rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes); return backend.makeTensorInfo(outputShape, values.dtype, output); } const raggedTensorToTensorConfig = { kernelName: RaggedTensorToTensor, backendName: 'webgl', kernelFunc: raggedTensorToTensor, }; const range$1 = (args) => { const { backend, attrs } = args; const { start, stop, step, dtype } = attrs; const values = rangeImplCPU(start, stop, step, dtype); return backend.makeTensorInfo([values.length], dtype, values); }; const rangeConfig = { kernelName: Range, backendName: 'webgl', kernelFunc: range$1 }; const RECIPROCAL = `return 1.0 / x;`; const reciprocal = unaryKernelFunc({ opSnippet: RECIPROCAL }); const reciprocalConfig = { kernelName: Reciprocal, backendName: 'webgl', kernelFunc: reciprocal, }; const RELU = CHECK_NAN_SNIPPET$1 + ` return (x < 0.0) ? 0.0 : x; `; const RELU_PACKED = ` vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0))); bvec4 isNaN = isnan(x); result.r = isNaN.r ? x.r : result.r; result.g = isNaN.g ? x.g : result.g; result.b = isNaN.b ? x.b : result.b; result.a = isNaN.a ? x.a : result.a; return result; `; const relu = unaryKernelFunc({ opSnippet: RELU, packedOpSnippet: RELU_PACKED }); const reluConfig = { kernelName: Relu$1, backendName: 'webgl', kernelFunc: relu }; const RELU6 = CHECK_NAN_SNIPPET$1 + ` return (x < 0.0) ? 
0.0 : min(6.0, x); `;

// Packed ReLU6: clamps each lane to [0, 6] via min() and a >= 0 mask, then
// copies NaN lanes of x through unchanged.
const RELU6_PACKED = ` vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0))); bvec4 isNaN = isnan(x); result.r = isNaN.r ? x.r : result.r; result.g = isNaN.g ? x.g : result.g; result.b = isNaN.b ? x.b : result.b; result.a = isNaN.a ? x.a : result.a; return result; `;
const relu6 = unaryKernelFunc({ opSnippet: RELU6, packedOpSnippet: RELU6_PACKED });
const relu6Config = { kernelName: Relu6$1, backendName: 'webgl', kernelFunc: relu6 };

/**
 * Unpacked bilinear resize. `inputShape` is destructured as
 * [batch, height, width, depth]; output is [batch, newHeight, newWidth, depth].
 *
 * Each output pixel maps to a fractional source position (optionally with
 * half-pixel centers), and the four surrounding source pixels (floor/ceil
 * clamped to the input) are blended by the fractional offsets.
 */
class ResizeBilinearProgram {
  constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
    this.variableNames = ['A'];
    this.outputShape = [];
    const [batch, oldHeight, oldWidth, depth] = inputShape;
    this.outputShape = [batch, newHeight, newWidth, depth];
    // With alignCorners (and more than one output pixel) the ratio uses
    // size - 1, so the first/last output pixels land on the first/last inputs.
    const effectiveInSize = [
      (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
      (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
    ];
    const effectiveOutSize = [
      (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
      (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
    ];
    let sourceFracIndexRC;
    if (halfPixelCenters) {
      sourceFracIndexRC = `(vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC` + ` - vec2(0.5)`;
    } else {
      sourceFracIndexRC = `vec2(yRC) * effectiveInputOverOutputRatioRC`;
    }
    this.userCode = ` const vec2 effectiveInputOverOutputRatioRC = vec2( ${effectiveInSize[0] / effectiveOutSize[0]}, ${effectiveInSize[1] / effectiveOutSize[1]}); const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0); void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int d = coords[3]; ivec2 yRC = coords.yz; vec2 sourceFracIndexRC = ${sourceFracIndexRC}; ivec2 sourceFloorRC = ivec2(max(sourceFracIndexRC, vec2(0.0))); ivec2 sourceCeilRC = ivec2( min(inputShapeRC - 1.0, ceil(sourceFracIndexRC))); float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d); float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d); float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d); float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d); vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC); float top = topLeft + (topRight - topLeft) * fracRC.y; float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y; float newValue = top + (bottom - top) * fracRC.x; setOutput(newValue); } `;
  }
}

/**
 * Packed bilinear resize (same mapping as ResizeBilinearProgram).
 *
 * Each invocation produces a vec4 covering channels d and d + 1 of output
 * columns coords.z and coords.z + 1 (yRC = coords.yzz + (0, 0, 1)). Despite
 * its name, `hasNextRow` guards the next output column (coords.z + 1 within
 * newWidth); `hasNextCol` guards channel d + 1. Missing neighbours are 0.
 */
class ResizeBilinearPackedProgram {
  constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
    this.variableNames = ['A'];
    this.packedInputs = true;
    this.packedOutput = true;
    this.outputShape = [];
    const [batch, oldHeight, oldWidth, depth] = inputShape;
    this.outputShape = [batch, newHeight, newWidth, depth];
    const effectiveInSize = [
      (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
      (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
    ];
    const effectiveOutSize = [
      (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
      (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
    ];
    let sourceFracIndexRC;
    if (halfPixelCenters) {
      sourceFracIndexRC = `(vec3(yRC) + vec3(0.5)) * ` + `effectiveInputOverOutputRatioRC - vec3(0.5)`;
    } else {
      sourceFracIndexRC = `vec3(yRC) * effectiveInputOverOutputRatioRC`;
    }
    this.userCode = ` const vec3 effectiveInputOverOutputRatioRC = vec3( ${effectiveInSize[0] / effectiveOutSize[0]}, ${effectiveInSize[1] / effectiveOutSize[1]}, ${effectiveInSize[1] / effectiveOutSize[1]}); const vec3 inputShapeRC = vec3(${oldHeight}.0, ${oldWidth}.0, ${oldWidth}.0); float getAValue(int b, int r, int c, int d) { return getChannel(getA(b, r, c, d), vec2(c, d)); } void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int d = coords[3]; ivec3 yRC = coords.yzz + ivec3(0, 0, 1); vec3 sourceFracIndexRC = ${sourceFracIndexRC}; ivec3 sourceFloorRC = ivec3(max(sourceFracIndexRC, vec3(0.0))); ivec3 sourceCeilRC = ivec3( min(inputShapeRC - 1.0, ceil(sourceFracIndexRC))); bool hasNextCol = d < ${depth - 1}; bool hasNextRow = coords.z < ${newWidth - 1}; vec4 topLeft = vec4( getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d), hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1) : 0.0, hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d) : 0.0, (hasNextRow && hasNextCol) ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0); vec4 bottomLeft = vec4( getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d), hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1) : 0.0, hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d) : 0.0, (hasNextRow && hasNextCol) ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0); vec4 topRight = vec4( getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d), hasNextCol ? getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1) : 0.0, hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d) : 0.0, (hasNextRow && hasNextCol) ? 
getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0); vec4 bottomRight = vec4( getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d), hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1) : 0.0, hasNextRow ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d) : 0.0, (hasNextRow && hasNextCol) ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0); vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC); vec4 top = mix(topLeft, topRight, fracRC.yyzz); vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz); vec4 newValue = mix(top, bottom, fracRC.x); setOutput(newValue); } `;
  }
}

/**
 * ResizeBilinear kernel (WebGL). `size` is [newHeight, newWidth]; uses the
 * packed program when WEBGL_PACK_IMAGE_OPERATIONS is on. Output is float32.
 */
function resizeBilinear(args) {
  const { inputs, backend, attrs } = args;
  const { images } = inputs;
  const { alignCorners, halfPixelCenters, size } = attrs;
  const [newHeight, newWidth] = size;
  const program = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ?
    new ResizeBilinearPackedProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters) :
    new ResizeBilinearProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters);
  return backend.runWebGLProgram(program, [images], 'float32');
}
const resizeBilinearConfig = { kernelName: ResizeBilinear, backendName: 'webgl', kernelFunc: resizeBilinear };
class ResizeBilinearBackpropProgram { constructor(dyShape, inputShape, alignCorners) { this.variableNames = ['dy']; this.outputShape = []; this.outputShape = inputShape; const [, xHeight, xWidth,] = inputShape; const [, yHeight, yWidth] = dyShape; const effectiveXSize = [ (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight, (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth ]; const effectiveYSize = [ (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight, (alignCorners && yWidth > 1) ? 
yWidth - 1 : yWidth ]; const heightScale = effectiveXSize[0] / effectiveYSize[0]; const widthScale = effectiveXSize[1] / effectiveYSize[1]; const invHeightScale = 1 / heightScale; const invWidthScale = 1 / widthScale; const winHeight = (Math.ceil(invHeightScale) * 2) + 2; const winWidth = (Math.ceil(invWidthScale) * 2) + 2; this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int d = coords[3]; int r = coords[1]; int c = coords[2]; float accumulator = 0.0; const float heightScale = float(${heightScale}); const float widthScale = float(${widthScale}); const float invHeightScale = float(${invHeightScale}); const float invWidthScale = float(${invWidthScale}); const int winHeight = int(${winHeight}); const int winWidth = int(${winWidth}); float startRLerp = floor(float(r) * invHeightScale); int startDyR = int(startRLerp - float(winHeight / 2)); float startCLerp = floor(float(c) * invWidthScale); int startDyC = int(startCLerp - float(winWidth / 2)); for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) { int dyR = dyROffset + startDyR; if (dyR < 0 || dyR >= ${yHeight}) { continue; } for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) { int dyC = dyCOffset + startDyC; if (dyC < 0 || dyC >= ${yWidth}) { continue; } float dxR = float(dyR) * heightScale; int topDxRIndex = int(floor(dxR)); int bottomDxRIndex = int(min(ceil(dxR), ${xHeight - 1}.0)); float dxRLerp = dxR - float(topDxRIndex); float inverseDxRLerp = 1.0 - dxRLerp; float dxC = float(dyC) * widthScale; int leftDxCIndex = int(floor(dxC)); int rightDxCIndex = int(min(ceil(dxC), ${xWidth - 1}.0)); float dxCLerp = dxC - float(leftDxCIndex); float inverseDxCLerp = 1.0 - dxCLerp; if (r == topDxRIndex && c == leftDxCIndex) { accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp; } if (r == topDxRIndex && c == rightDxCIndex) { accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp; } if (r == bottomDxRIndex && c == leftDxCIndex) { accumulator += 
getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp; } if (r == bottomDxRIndex && c == rightDxCIndex) { accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp; } } } setOutput(accumulator); } `; } } function resizeBilinearGrad(args) { const { inputs, backend, attrs } = args; const { images, dy } = inputs; const { alignCorners } = attrs; const program = new ResizeBilinearBackpropProgram(dy.shape, images.shape, alignCorners); return backend.runWebGLProgram(program, [dy], dy.dtype); } const resizeBilinearGradConfig$1 = { kernelName: ResizeBilinearGrad, backendName: 'webgl', kernelFunc: resizeBilinearGrad }; class ResizeNearestNeighborProgram { constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) { this.variableNames = ['A']; this.outputShape = []; const [batch, oldHeight, oldWidth, depth] = inputShape; this.outputShape = [batch, newHeight, newWidth, depth]; const effectiveInSize = [ (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth ]; const effectiveOutSize = [ (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth ]; const roundBase = alignCorners ? 
'0.5' : '0.0'; let sourceFracIndexRC; if (halfPixelCenters) { sourceFracIndexRC = `max((vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC` + `, vec2(0.0))`; } else { sourceFracIndexRC = `vec2(yRC) * effectiveInputOverOutputRatioRC`; } this.userCode = ` const vec2 effectiveInputOverOutputRatioRC = vec2( ${effectiveInSize[0] / effectiveOutSize[0]}, ${effectiveInSize[1] / effectiveOutSize[1]}); const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0); void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int d = coords[3]; ivec2 yRC = coords.yz; vec2 sourceFracIndexRC = ${sourceFracIndexRC}; ivec2 sourceNearestRC = ivec2( min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ${roundBase}))); float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d); setOutput(newValue); } `; } } class ResizeNearestNeighborPackedProgram { constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) { this.variableNames = ['A']; this.packedInputs = true; this.packedOutput = true; this.outputShape = []; const [batch, oldHeight, oldWidth, depth] = inputShape; this.outputShape = [batch, newHeight, newWidth, depth]; const effectiveInSize = [ (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth ]; const effectiveOutSize = [ (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth ]; const roundBase = alignCorners ? 
'0.5' : '0.0'; let sourceFracIndexRC; if (halfPixelCenters) { sourceFracIndexRC = `max((vec3(yRC) + vec3(0.5)) * ` + `effectiveInputOverOutputRatioRC, vec3(0.0))`; } else { sourceFracIndexRC = `vec3(yRC) * effectiveInputOverOutputRatioRC`; } this.userCode = ` const vec3 effectiveInputOverOutputRatioRC = vec3( ${effectiveInSize[0] / effectiveOutSize[0]}, ${effectiveInSize[1] / effectiveOutSize[1]}, ${effectiveInSize[1] / effectiveOutSize[1]}); const vec3 inputShapeRC = vec3(${oldHeight}.0, ${oldWidth}.0, ${oldWidth}.0); float getAValue(int b, int r, int c, int d) { return getChannel(getA(b, r, c, d), vec2(c, d)); } void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int d = coords[3]; ivec3 yRC = coords.yzz + ivec3(0, 0, 1); vec3 sourceFracIndexRC = ${sourceFracIndexRC}; ivec3 sourceNearestRC = ivec3( min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ${roundBase}))); bool hasNextCol = d < ${depth - 1}; bool hasNextRow = coords.z < ${newWidth - 1}; vec4 newValue = vec4( getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d), hasNextCol ? getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d + 1) : 0.0, hasNextRow ? getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d) : 0.0, (hasNextRow && hasNextCol) ? getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d + 1) : 0.0); setOutput(newValue); } `; } } function resizeNearestNeighbor(args) { const { inputs, backend, attrs } = args; const { images } = inputs; const { alignCorners, halfPixelCenters, size } = attrs; const [newHeight, newWidth] = size; const program = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ? 
new ResizeNearestNeighborPackedProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters) : new ResizeNearestNeighborProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters); return backend.runWebGLProgram(program, [images], images.dtype); } const resizeNearestNeighborConfig = { kernelName: ResizeNearestNeighbor, backendName: 'webgl', kernelFunc: resizeNearestNeighbor }; class ResizeNearestNeigborBackpropProgram { constructor(dyShape, inputShape, alignCorners) { this.variableNames = ['dy']; this.outputShape = []; this.outputShape = inputShape; const [, xHeight, xWidth,] = inputShape; const [, yHeight, yWidth] = dyShape; const effectiveXSize = [ (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight, (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth ]; const effectiveYSize = [ (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight, (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth ]; const heightScale = effectiveXSize[0] / effectiveYSize[0]; const widthScale = effectiveXSize[1] / effectiveYSize[1]; const invHeightScale = 1 / heightScale; const invWidthScale = 1 / widthScale; const winHeight = (Math.ceil(invHeightScale) * 2) + 2; const winWidth = (Math.ceil(invWidthScale) * 2) + 2; this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int b = coords[0]; int d = coords[3]; int r = coords[1]; int c = coords[2]; float accumulator = 0.0; const float heightScale = float(${heightScale}); const float widthScale = float(${widthScale}); const float invHeightScale = float(${invHeightScale}); const float invWidthScale = float(${invWidthScale}); const int winHeight = int(${winHeight}); const int winWidth = int(${winWidth}); float startRLerp = floor(float(r) * invHeightScale); int startDyR = int(floor(startRLerp - float(winHeight / 2))); float startCLerp = floor(float(c) * invWidthScale); int startDyC = int(floor(startCLerp - float(winWidth / 2))); for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) { int dyR = dyROffset + 
startDyR; if (dyR < 0 || dyR >= ${yHeight}) { continue; } for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) { int dyC = dyCOffset + startDyC; if (dyC < 0 || dyC >= ${yWidth}) { continue; } float sourceFracRow = float(${effectiveXSize[0]}) * (float(dyR) / float(${effectiveYSize[0]})); float sourceFracCol = float(${effectiveXSize[1]}) * (float(dyC) / float(${effectiveYSize[1]})); int sourceNearestRow = int(min( float(int(${xHeight}) - 1), ${alignCorners} ? float(round(sourceFracRow)) : float(floor(sourceFracRow)))); int sourceNearestCol = int(min( float(int(${xWidth}) - 1), ${alignCorners} ? float(round(sourceFracCol)) : float(floor(sourceFracCol)))); if (r == sourceNearestRow && c == sourceNearestCol) { accumulator += getDy(b, dyR, dyC, d); } } } setOutput(accumulator); } `; } } function resizeNearestNeighborGrad(args) { const { inputs, backend, attrs } = args; const { images, dy } = inputs; const { alignCorners } = attrs; const program = new ResizeNearestNeigborBackpropProgram(dy.shape, images.shape, alignCorners); return backend.runWebGLProgram(program, [dy], dy.dtype); } const resizeNearestNeighborGradConfig$1 = { kernelName: ResizeNearestNeighborGrad, backendName: 'webgl', kernelFunc: resizeNearestNeighborGrad }; class ReverseProgram { constructor(xShape, axis) { this.variableNames = ['x']; const rank = xShape.length; if (rank > 4) { throw new Error(`WebGL backend: Reverse of rank-${rank} tensor is not yet supported`); } this.outputShape = xShape; if (rank === 1) { this.userCode = ` void main() { int coord = getOutputCoords(); setOutput(getX(${xShape[0]} - coord - 1)); } `; return; } const getInCoord = (i) => { if (axis.indexOf(i) !== -1 && xShape[i] !== 1) { return `${xShape[i]} - coords[${i}] - 1`; } return `coords[${i}]`; }; const inCoords = xShape.map((_, i) => getInCoord(i)).join(','); const type = getCoordsDataType(rank); this.userCode = ` void main() { ${type} coords = getOutputCoords(); setOutput(getX(${inCoords})); } `; } } class 
ReversePackedProgram { constructor(xShape, axis) { this.variableNames = ['x']; this.packedInputs = true; this.packedOutput = true; const rank = xShape.length; if (rank > 4) { throw new Error(`WebGL backend: Reverse of rank-${rank} tensor is not yet supported`); } this.outputShape = xShape; const channels = getChannels('rc', rank); const nextColumn = `${channels[rank - 1]} + 1 < ${this.outputShape[rank - 1]}`; const nextRow = `${channels[rank - 2]} + 1 < ${this.outputShape[rank - 2]}`; const type = getCoordsDataType(rank); if (rank === 1) { this.userCode = ` void main(){ int rc = getOutputCoords(); vec4 result = vec4(0.); result.r = getChannel(getX(${xShape[0]} - rc - 1), ${xShape[0]} - rc - 1); if(${nextColumn}){ result.g = getChannel(getX(${xShape[0]} - (rc + 1) - 1), ${xShape[0]} - (rc + 1) - 1); } setOutput(result); } `; } else { this.userCode = ` void main() { ${type} rc = getOutputCoords(); vec4 result = vec4(0.); result.r = ${getR(channels.slice())}; if(${nextColumn}){ result.g = ${getG(channels.slice())}; } if(${nextRow}) { result.b = ${getB(channels.slice())}; if(${nextColumn}) { result.a = ${getA(channels.slice())}; } } setOutput(result); } `; } function getR(channels) { return getChannel(channels); } function getG(channels) { channels[rank - 1] = '(' + channels[rank - 1] + ` + 1)`; return getChannel(channels); } function getB(channels) { channels[rank - 2] = '(' + channels[rank - 2] + ` + 1)`; return getChannel(channels); } function getA(channels) { channels[rank - 1] = '(' + channels[rank - 1] + ` + 1)`; channels[rank - 2] = '(' + channels[rank - 2] + ` + 1)`; return getChannel(channels); } function getChannel(channels) { const inCoordsArray = xShape.map((_, i) => getInCoord(i, channels)); const inCoords = inCoordsArray.join(','); const innerDims = inCoordsArray.slice(-2).join(','); return `getChannel(getX(${inCoords}), vec2(${innerDims}))`; } function getInCoord(i, channels1) { if (axis.indexOf(i) !== -1 && xShape[i] !== 1) { return `${xShape[i]} - 
${channels1[i]} - 1`; } else { return `${channels1[i]}`; } } } } function reverse(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { dims } = attrs; const xRank = x.shape.length; const $dims = parseAxisParam(dims, x.shape); if (xRank === 0) { return identity({ inputs: { x }, backend }); } const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new ReversePackedProgram(x.shape, $dims) : new ReverseProgram(x.shape, $dims); return backend.runWebGLProgram(program, [x], x.dtype); } const reverseConfig = { kernelName: Reverse, backendName: 'webgl', kernelFunc: reverse }; class RotateProgram { constructor(imageShape, fillValue) { this.variableNames = ['Image']; this.outputShape = []; this.customUniforms = [{ name: 'params', type: 'vec4' }]; const imageHeight = imageShape[1]; const imageWidth = imageShape[2]; this.outputShape = imageShape; let fillSnippet = ''; if (typeof fillValue === 'number') { fillSnippet = `float outputValue = ${fillValue.toFixed(2)};`; } else { fillSnippet = ` vec3 fill = vec3(${fillValue.join(',')}); float outputValue = fill[coords[3]];`; } this.userCode = ` void main() { ivec4 coords = getOutputCoords(); int x = coords[2]; int y = coords[1]; float coordXFloat = (float(x) - params[0]) * params[3] - (float(y) - params[1]) * params[2]; float coordYFloat = (float(x) - params[0]) * params[2] + (float(y) - params[1]) * params[3]; int coordX = int(round(coordXFloat + params[0])); int coordY = int(round(coordYFloat + params[1])); ${fillSnippet} if(coordX >= 0 && coordX < ${imageWidth} && coordY >= 0 && coordY < ${imageHeight}) { outputValue = getImage(coords[0], coordY, coordX, coords[3]); } setOutput(outputValue); } `; } } const rotateWithOffsetConfig = { kernelName: RotateWithOffset, backendName: 'webgl', kernelFunc: ({ inputs, attrs, backend }) => { const { image } = inputs; const { radians, fillValue, center } = attrs; const webglBackend = backend; const program = new RotateProgram(image.shape, fillValue); const 
[centerX, centerY] = getImageCenter(center, image.shape[1], image.shape[2]); const customValues = [[centerX, centerY, Math.sin(radians), Math.cos(radians)]]; const output = webglBackend.runWebGLProgram(program, [image], image.dtype, customValues); return output; } }; const ROUND = ` float base = floor(x); if ((x - base) < 0.5) { return floor(x); } else if ((x - base) > 0.5) { return ceil(x); } else { if (mod(base, 2.0) == 0.0) { return base; } else { return base + 1.0; } } `; const round = unaryKernelFunc({ opSnippet: ROUND }); const roundConfig = { kernelName: Round, backendName: 'webgl', kernelFunc: round, }; const RSQRT = `return inversesqrt(x);`; const rsqrt = unaryKernelFunc({ opSnippet: RSQRT, cpuKernelImpl: rsqrtImplCPU }); const rsqrtConfig = { kernelName: Rsqrt, backendName: 'webgl', kernelFunc: rsqrt }; class ScatterProgram { constructor(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex = true, defaultIsTensor = false) { this.variableNames = ['updates', 'indices', 'defaultValue']; this.outputShape = shape; const stridesType = getCoordsDataType(strides.length); const dtype = getCoordsDataType(shape.length); let indicesString = ''; if (indicesRank === 1) { indicesString = 'i'; } else if (indicesRank === 2) { indicesString = 'i, j'; } const indicesSnippet = `getIndices(${indicesString})`; let updatesString = ''; if (updatesRank === 1) { updatesString = 'i'; } else if (updatesRank === 2) { updatesString = 'i, coords[1]'; } const updatesSnippet = `getUpdates(${updatesString})`; let defaultValuesString = ''; if (defaultIsTensor) { defaultValuesString = 'coords[0], coords[1]'; } const defaultValueSnippet = `getDefaultValue(${defaultValuesString})`; const strideString = sliceDim > 1 ? 
'strides[j]' : 'strides'; this.userCode = ` ${stridesType} strides = ${stridesType}(${strides}); void main() { ${dtype} coords = getOutputCoords(); float sum = 0.0; bool found = false; for (int i = 0; i < ${updateSize}; i++) { int flattenedIndex = 0; for (int j = 0; j < ${sliceDim}; j++) { int index = round(${indicesSnippet}); flattenedIndex += index * ${strideString}; } if (flattenedIndex == coords[0]) { sum += ${updatesSnippet}; found = true; } } setOutput(mix(${defaultValueSnippet}, sum, float(found))); } `; } } class ScatterPackedProgram { constructor(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex = true, defaultIsTensor = false) { this.variableNames = ['updates', 'indices', 'defaultValue']; this.packedInputs = true; this.packedOutput = true; this.outputShape = shape; const stridesType = getCoordsDataType(strides.length); const dtype = getCoordsDataType(shape.length); let indicesString = ''; if (indicesRank === 1) { indicesString = 'i'; } else if (indicesRank === 2) { indicesString = 'i, j'; } const indicesSnippet = `getIndices(${indicesString})`; let updatesString = ''; if (updatesRank === 1) { updatesString = 'i'; } else if (updatesRank === 2) { updatesString = 'i, coords[1]'; } const updatesSnippet = `getUpdates(${updatesString})`; let defaultValuesString = ''; if (defaultIsTensor) { defaultValuesString = 'coords[0], coords[1]'; } const defaultValueSnippet = `getDefaultValue(${defaultValuesString})`; const strideString = sliceDim > 1 ? 'strides[j]' : 'strides'; const strideString2 = sliceDim > 1 ? 
'strides[j + 1]' : 'strides'; this.userCode = ` ${stridesType} strides = ${stridesType}(${strides}); void main() { ${dtype} coords = getOutputCoords(); vec4 sum = vec4(0.); vec4 found = vec4(0.); for (int i = 0; i < ${updateSize}; i+=2) { ivec2 flattenedIndex = ivec2(0); for (int j = 0; j < ${sliceDim}; j+=2) { ivec4 index = round(${indicesSnippet}); flattenedIndex += index.xz * ${strideString}; if (j + 1 < ${sliceDim}) { flattenedIndex += index.yw * ${strideString2}; } } if (flattenedIndex[0] == coords[0] || flattenedIndex[1] == coords[0] || flattenedIndex[0] == coords[0] + 1 || flattenedIndex[1] == coords[0] + 1) { vec4 updVals = ${updatesSnippet}; if (flattenedIndex[0] == coords[0]) { sum.xy += updVals.xy; found.xy = vec2(1.); } else if (flattenedIndex[0] == coords[0] + 1) { sum.zw += updVals.xy; found.zw = vec2(1.); } if (flattenedIndex[1] == coords[0]) { sum.xy += updVals.zw; found.xy = vec2(1.); } else if (flattenedIndex[1] == coords[0] + 1) { sum.zw += updVals.zw; found.zw = vec2(1.); } } } setOutput(mix(${defaultValueSnippet}, sum, found)); } `; } } function scatterNd(args) { const { inputs, backend, attrs } = args; const { indices, updates } = inputs; const { shape } = attrs; const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, shape); const flattenShape = [outputSize / sliceSize, sliceSize]; if (outputSize === 0) { return backend.makeTensorInfo(shape, indices.dtype); } const flattenIndices = reshape({ inputs: { x: indices }, backend, attrs: { shape: [numUpdates, sliceRank] } }); const flattenX = reshape({ inputs: { x: updates }, backend, attrs: { shape: [numUpdates, sliceSize] } }); const defaultValue = backend.makeTensorInfo([], 'float32', new Float32Array([0])); let program; if (env().getBool('WEBGL_PACK')) { program = new ScatterPackedProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape); } else { program = new ScatterProgram(numUpdates, sliceRank, 
flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape); } const res = backend.runWebGLProgram(program, [flattenX, flattenIndices, defaultValue], flattenX.dtype); const reshaped = reshape({ inputs: { x: res }, backend, attrs: { shape } }); backend.disposeIntermediateTensorInfo(flattenIndices); backend.disposeIntermediateTensorInfo(flattenX); backend.disposeIntermediateTensorInfo(res); backend.disposeIntermediateTensorInfo(defaultValue); return reshaped; } const scatterNdConfig = { kernelName: ScatterNd, backendName: 'webgl', kernelFunc: scatterNd }; class SearchSortedProgram { constructor(batchSize, numInputs, numValues, side) { this.variableNames = ['sortedSequence', 'values']; this.customUniforms = [{ name: 'numInputs', type: 'int' }]; this.outputShape = [batchSize, numValues]; const webGL2LoopHead = 'while (left < right) {'; const webGL1LoopHead = `for (int i = 0; i < ${Math.ceil(Math.log2(numInputs + 1))}; ++i) { if (left >= right) break;`; const loopHead = env().getNumber('WEBGL_VERSION') === 2 ? webGL2LoopHead : webGL1LoopHead; const boundComparator = side === 'left' ? 
'<' : '<='; this.userCode = ` int findBound(int batch, float value) { int left = 0; int right = numInputs; int mid; ${loopHead} mid = (left + right) / 2; if (getSortedSequence(batch, mid) ${boundComparator} value) { left = mid + 1; } else { right = mid; } } return right; } void main() { ivec2 coords = getOutputCoords(); int batch = coords[0]; int valueIndex = coords[1]; float value = getValues(batch, valueIndex); setOutput(float(findBound(batch, value))); } `; } } function searchSorted(args) { const { inputs, backend, attrs } = args; const { sortedSequence, values } = inputs; const { side } = attrs; const program = new SearchSortedProgram(sortedSequence.shape[0], sortedSequence.shape[1], values.shape[1], side); const customValues = [[sortedSequence.shape[1]]]; return backend.runWebGLProgram(program, [sortedSequence, values], 'int32', customValues); } const searchSortedConfig = { kernelName: SearchSorted, backendName: 'webgl', kernelFunc: searchSorted, }; class SelectProgram { constructor(cRank, shape, rank) { this.variableNames = ['c', 'a', 'b']; this.outputShape = shape; let cCoords; let abCoords; if (rank > 4) { throw Error(`Where for rank ${rank} is not yet supported`); } if (rank === 1) { abCoords = `resRC`; cCoords = `resRC`; } else { const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w']; const cCoordVars = []; const abCoordVars = []; for (let i = 0; i < shape.length; i++) { abCoordVars.push(`${currentCoords[i]}`); if (i < cRank) { cCoordVars.push(`${currentCoords[i]}`); } } cCoords = cCoordVars.join(); abCoords = abCoordVars.join(); } const dtype = getCoordsDataType(rank); this.userCode = ` void main() { ${dtype} resRC = getOutputCoords(); float cVal = getC(${cCoords}); if (cVal >= 1.0) { setOutput(getA(${abCoords})); } else { setOutput(getB(${abCoords})); } } `; } } function select(args) { const { inputs, backend } = args; const { condition, t, e } = inputs; const program = new SelectProgram(condition.shape.length, t.shape, t.shape.length); 
return backend.runWebGLProgram(program, [condition, t, e], upcastType(t.dtype, e.dtype)); } const selectConfig = { kernelName: Select, backendName: 'webgl', kernelFunc: select }; const SELU = ` float scaleAlpha = ${SELU_SCALEALPHA}; float scale = ${SELU_SCALE}; return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0); `; const selu = unaryKernelFunc({ opSnippet: SELU }); const seluConfig = { kernelName: Selu$1, backendName: 'webgl', kernelFunc: selu, }; const SIGMOID = CHECK_NAN_SNIPPET_UNARY + ` return 1.0 / (1.0 + exp(-1.0 * x)); `; const SIGMOID_PACKED = ` vec4 result = 1.0 / (1.0 + exp(-1.0 * x)); bvec4 isNaN = isnan(x); result.r = isNaN.r ? x.r : result.r; result.g = isNaN.g ? x.g : result.g; result.b = isNaN.b ? x.b : result.b; result.a = isNaN.a ? x.a : result.a; return result; `; const sigmoid = unaryKernelFunc({ opSnippet: SIGMOID, packedOpSnippet: SIGMOID_PACKED, cpuKernelImpl: sigmoidImplCPU }); const sigmoidConfig = { kernelName: Sigmoid$1, backendName: 'webgl', kernelFunc: sigmoid, }; const SIGN = ` if (isnan(x)) { return 0.0; } return sign(x); `; const sign = unaryKernelFunc({ opSnippet: SIGN }); const signConfig = { kernelName: Sign, backendName: 'webgl', kernelFunc: sign, }; const SIN = CHECK_NAN_SNIPPET_UNARY + ` return sin(x); `; const SIN_PACKED = ` vec4 result = sin(x); bvec4 isNaN = isnan(x); ${CHECK_NAN_SNIPPET_PACKED} return result; `; const sin = unaryKernelFunc({ opSnippet: SIN, packedOpSnippet: SIN_PACKED }); const sinConfig = { kernelName: Sin, backendName: 'webgl', kernelFunc: sin, }; const SINH = ` float e2x = exp(x); return (e2x - 1.0 / e2x) / 2.0; `; const sinh = unaryKernelFunc({ opSnippet: SINH }); const sinhConfig = { kernelName: Sinh, backendName: 'webgl', kernelFunc: sinh, }; const SOFTPLUS = ` float epsilon = 1.1920928955078125e-7; float threshold = log(epsilon) + 2.0; bool too_large = x > -threshold; bool too_small = x < threshold; float result; float exp_x = exp(x); if (too_large){ result = x; } else if (too_small){ result 
= exp_x; } else{ result = log(exp_x + 1.0); } return result; `; const softplus = unaryKernelFunc({ opSnippet: SOFTPLUS }); const softplusConfig = { kernelName: Softplus$1, backendName: 'webgl', kernelFunc: softplus, }; const spaceToBatchND = (args) => { const { inputs, backend, attrs } = args; const { x } = inputs; const { blockShape, paddings } = attrs; assert$1(x.shape.length <= 4, () => 'spaceToBatchND for rank > 4 with a WebGL backend not ' + 'implemented yet'); const prod = blockShape.reduce((a, b) => a * b); const completePaddings = [[0, 0]]; completePaddings.push(...paddings); for (let i = 1 + blockShape.length; i < x.shape.length; ++i) { completePaddings.push([0, 0]); } const toDispose = []; const paddedX = padV2({ inputs: { x }, backend, attrs: { paddings: completePaddings, constantValue: 0 } }); const reshapedPaddedShape = getReshaped(paddedX.shape, blockShape, prod, false); const permutedReshapedPaddedPermutation = getPermuted(reshapedPaddedShape.length, blockShape.length, false); const flattenShape = getReshapedPermuted(paddedX.shape, blockShape, prod, false); const reshapedPaddedX = reshape({ inputs: { x: paddedX }, backend, attrs: { shape: reshapedPaddedShape } }); const paddedXT = transpose({ inputs: { x: reshapedPaddedX }, backend, attrs: { perm: permutedReshapedPaddedPermutation } }); const result = reshape({ inputs: { x: paddedXT }, backend, attrs: { shape: flattenShape } }); toDispose.push(paddedX); toDispose.push(reshapedPaddedX); toDispose.push(paddedXT); toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t)); return result; }; const spaceToBatchNDConfig = { kernelName: SpaceToBatchND, backendName: 'webgl', kernelFunc: spaceToBatchND }; function sparseFillEmptyRows(args) { const { inputs, backend } = args; const { indices, values, denseShape, defaultValue } = inputs; if (denseShape.shape.length !== 1) { throw new Error(`Dense shape must be a vector, saw: ${denseShape.shape}`); } if (indices.shape.length !== 2) { throw new 
Error(`Indices must be a matrix, saw: ${indices.shape}`); } if (values.shape.length !== 1) { throw new Error(`Values must be a vector, saw: ${values.shape}`); } if (defaultValue.shape.length !== 0) { throw new Error(`Default value must be a scalar, saw: ${defaultValue.shape}`); } const $indices = backend.readSync(indices.dataId); const $values = backend.readSync(values.dataId); const $denseShape = backend.readSync(denseShape.dataId); const $defaultValue = backend.readSync(defaultValue.dataId)[0]; const [outputIndices, outputIndicesShape, outputValues, emptyRowIndicator, reverseIndexMap] = sparseFillEmptyRowsImplCPU($indices, indices.shape, indices.dtype, $values, values.dtype, $denseShape, $defaultValue); return [ backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices), backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues), backend.makeTensorInfo([emptyRowIndicator.length], 'bool', new Uint8Array(emptyRowIndicator.map((value) => Number(value)))), backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap)), ]; } const sparseFillEmptyRowsConfig = { kernelName: SparseFillEmptyRows, backendName: 'webgl', kernelFunc: sparseFillEmptyRows, }; function sparseReshape(args) { const { inputs, backend } = args; const { inputIndices, inputShape, newShape } = inputs; if (inputIndices.shape.length !== 2) { throw new Error(`Input indices should be a matrix but received shape ${inputIndices.shape}`); } if (inputShape.shape.length !== 1) { throw new Error(`Input shape should be a vector but received shape ${inputShape.shape}`); } if (newShape.shape.length !== 1) { throw new Error(`Target shape should be a vector but received shape ${newShape.shape}`); } const $inputShape = Array.from(backend.readSync(inputShape.dataId)); const $inputIndices = backend.readSync(inputIndices.dataId); const targetShape = Array.from(backend.readSync(newShape.dataId)); const [newIndices, indicesShape, outputShape] = 
sparseReshapeImplCPU($inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape, targetShape); return [ backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices), backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape)), ]; } const sparseReshapeConfig = { kernelName: SparseReshape, backendName: 'webgl', kernelFunc: sparseReshape, }; function sparseSegmentMean(args) { const { inputs, backend } = args; const { data, indices, segmentIds } = inputs; if (data.shape.length < 1) { throw new Error(`Data should be at least 1 dimensional but received scalar`); } if (indices.shape.length !== 1) { throw new Error(`Indices should be a vector but received shape ${indices.shape}`); } if (segmentIds.shape.length !== 1) { throw new Error(`Segment ids should be a vector but received shape ${segmentIds.shape}`); } const $data = backend.readSync(data.dataId); const $indices = backend.readSync(indices.dataId); const $segmentIds = backend.readSync(segmentIds.dataId); const [outputData, outputDataShape] = sparseSegmentReductionImplCPU($data, data.shape, data.dtype, $indices, $segmentIds, true); return backend.makeTensorInfo(outputDataShape, data.dtype, outputData); } const sparseSegmentMeanConfig = { kernelName: SparseSegmentMean, backendName: 'webgl', kernelFunc: sparseSegmentMean, }; function sparseSegmentSum(args) { const { inputs, backend } = args; const { data, indices, segmentIds } = inputs; if (data.shape.length < 1) { throw new Error(`Data should be at least 1 dimensional but received scalar`); } if (indices.shape.length !== 1) { throw new Error(`Indices should be a vector but received shape ${indices.shape}`); } if (segmentIds.shape.length !== 1) { throw new Error(`Segment ids should be a vector but received shape ${segmentIds.shape}`); } const $data = backend.readSync(data.dataId); const $indices = backend.readSync(indices.dataId); const $segmentIds = backend.readSync(segmentIds.dataId); const [outputData, outputDataShape] = 
sparseSegmentReductionImplCPU($data, data.shape, data.dtype, $indices, $segmentIds); return backend.makeTensorInfo(outputDataShape, data.dtype, outputData); } const sparseSegmentSumConfig = { kernelName: SparseSegmentSum, backendName: 'webgl', kernelFunc: sparseSegmentSum, }; function sparseToDense(args) { const { inputs, backend, attrs } = args; const { sparseIndices, sparseValues, defaultValue } = inputs; const { outputShape } = attrs; const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(sparseValues, sparseIndices, outputShape); const sumDupeIndices = false; if (sparseValues.dtype === 'string') { const indicesBuf = backend.bufferSync(sparseIndices); const updatesBuf = backend.bufferSync(sparseValues); const $defaultValue = decodeString(backend.readSync(defaultValue.dataId)[0]); const outBuf = scatterImplCPU(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices); return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values); } const program = new ScatterProgram(numUpdates, sliceRank, sparseIndices.shape.length, sparseValues.shape.length, strides, [outputSize, 1], sumDupeIndices); const res = backend.runWebGLProgram(program, [sparseValues, sparseIndices, defaultValue], sparseValues.dtype); const reshaped = reshape({ inputs: { x: res }, backend, attrs: { shape: outputShape } }); backend.disposeIntermediateTensorInfo(res); return reshaped; } const sparseToDenseConfig = { kernelName: SparseToDense, backendName: 'webgl', kernelFunc: sparseToDense }; function splitV(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { numOrSizeSplits, axis } = attrs; const $axis = parseAxisParam(axis, x.shape)[0]; const splitSizes = prepareSplitSize(x, numOrSizeSplits, $axis); const xRank = x.shape.length; const begin = new Array(xRank).fill(0); const size = x.shape.slice(); return splitSizes.map(s => { const sliceSize = [...size]; sliceSize[$axis] = s; 
const sliceT = slice({ inputs: { x }, backend, attrs: { begin, size: sliceSize } });
    // `begin` is shared across iterations: advance it along the split axis
    // so the next piece starts where this one ended.
    begin[$axis] += s;
    return sliceT;
  });
}
const splitVConfig = { kernelName: SplitV, backendName: 'webgl', kernelFunc: splitV };
// Sqrt: unary kernel; the same snippet is used for the packed variant and
// sqrtImplCPU is passed as its cpuKernelImpl.
const SQRT = `return sqrt(x);`;
const sqrt = unaryKernelFunc({ opSnippet: SQRT, packedOpSnippet: SQRT, cpuKernelImpl: sqrtImplCPU });
const sqrtConfig = { kernelName: Sqrt, backendName: 'webgl', kernelFunc: sqrt };
// Square: unary kernel computing x * x (no packed snippet supplied).
const SQUARE = `return x * x;`;
const square$1 = unaryKernelFunc({ opSnippet: SQUARE });
const squareConfig = { kernelName: Square, backendName: 'webgl', kernelFunc: square$1, };
// SquaredDifference: binary kernel computing (a - b)^2; the same snippet is
// used for the packed variant.
const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);';
const squaredDifference = binaryKernelFunc({ opSnippet: SQUARED_DIFFERENCE, packedOpSnippet: SQUARED_DIFFERENCE });
const squaredDifferenceConfig = { kernelName: SquaredDifference, backendName: 'webgl', kernelFunc: squaredDifference, };
/**
 * StaticRegexReplace on WebGL: decodes the string tensor and applies the
 * replacement on the CPU via staticRegexReplaceImplCPU (attrs are passed
 * through unchanged).
 * @returns string TensorInfo with the same shape as x.
 * @throws Error if x is not a string tensor.
 */
function staticRegexReplace(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  if (x.dtype !== 'string') {
    throw new Error('Input must be of datatype string');
  }
  const $x = backend.readSync(x.dataId);
  const stringInput = fromUint8ToStringArray($x);
  const output = staticRegexReplaceImplCPU(stringInput, 'string', attrs);
  return backend.makeTensorInfo(x.shape, 'string', output);
}
const staticRegexReplaceConfig = { kernelName: StaticRegexReplace, backendName: 'webgl', kernelFunc: staticRegexReplace, };
/**
 * Step: outputs 1.0 where x > 0.0 and the `alpha` attribute otherwise, with
 * CHECK_NAN_SNIPPET$1 (defined elsewhere) prepended to the shader body.
 * NOTE: the snippet's template literal spans two physical lines; that line
 * break is part of the shader text.
 */
function step({ inputs, attrs, backend }) {
  const { x } = inputs;
  const opSnippet = CHECK_NAN_SNIPPET$1 + ` return x > 0.0 ?
1.0 : float(${attrs.alpha}); `;
  const program = new UnaryOpProgram(x.shape, opSnippet);
  return backend.runWebGLProgram(program, [x], x.dtype);
}
const stepConfig = { kernelName: Step, backendName: 'webgl', kernelFunc: step, };
/**
 * Strided-slice shader: output coordinate c along axis i reads the input at
 * c * strides[i] + begin[i]. `size` is the output shape; `begin` and
 * `strides` are baked into the shader as constants.
 */
class StridedSliceProgram {
  constructor(begin, strides, size) {
    this.variableNames = ['x'];
    this.outputShape = size;
    const rank = size.length;
    const inputDtype = getCoordsDataType(size.length);
    const dtype = getCoordsDataType(size.length);
    let newCoords = '';
    if (rank === 1) {
      newCoords = 'coords * strides + begin';
    } else {
      let outputAxis = 0;
      // NOTE(review): rank !== 1 in this branch, so the `size.length === 1`
      // arm below is never taken.
      newCoords = size.map((_, i) => {
        outputAxis++;
        return size.length === 1 ? `coords * strides[${i}] + begin[${i}]` : `coords[${outputAxis - 1}] * strides[${i}] + begin[${i}]`;
      })
        .join(',');
    }
    this.userCode = ` ${inputDtype} begin = ${inputDtype}(${begin}); ${inputDtype} strides = ${inputDtype}(${strides}); void main() { ${dtype} coords = getOutputCoords(); setOutput(getX(${newCoords})); } `;
  }
}
/**
 * StridedSlice on WebGL. After sliceInfo() resolves the masks:
 *  - identity slices become a reshape of x,
 *  - slices flagged sliceDim0 / isSimpleSlice go through the `slice` kernel,
 *  - everything else runs stridedSliceImplCPU when the backend asks for CPU
 *    execution, or StridedSliceProgram otherwise.
 * The intermediate result is always reshaped to finalShape and then disposed.
 */
function stridedSlice(args) {
  const { inputs, backend, attrs } = args;
  const { x } = inputs;
  const { begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask } = attrs;
  const { finalShapeSparse, finalShape, isIdentity, sliceDim0, isSimpleSlice, begin: $begin, end: $end, strides: $strides } = sliceInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
  let result;
  if (isIdentity) {
    result = reshape({ inputs: { x }, backend, attrs: { shape: finalShape } });
  } else if (sliceDim0 || isSimpleSlice) {
    assert$1(x.shape.length >= 1, () => `Input must have rank at least 1, got: ${x.shape.length}`);
    const size = computeOutShape$2($begin, $end, $strides);
    const sliced = slice({ inputs: { x }, backend, attrs: { begin: $begin, size } });
    result = reshape({ inputs: { x: sliced }, backend, attrs: { shape: finalShape } });
    backend.disposeIntermediateTensorInfo(sliced);
  } else {
    const shouldExecuteOnCPU = backend.shouldExecuteOnCPU([x]);
    if (shouldExecuteOnCPU) {
      const values = backend.readSync(x.dataId);
      const xBuf = buffer(x.shape, x.dtype, values);
      const resultValues = stridedSliceImplCPU(finalShapeSparse, xBuf, $strides, $begin);
      result = backend.makeTensorInfo(finalShape, x.dtype, resultValues.values);
    } else {
      const program = new StridedSliceProgram($begin, $strides, finalShapeSparse);
      result = backend.runWebGLProgram(program, [x], x.dtype);
    }
  }
  const resultReshaped = reshape({ inputs: { x: result }, backend, attrs: { shape: finalShape } });
  backend.disposeIntermediateTensorInfo(result);
  return resultReshaped;
}
const stridedSliceConfig = { kernelName: StridedSlice, backendName: 'webgl', kernelFunc: stridedSlice };
/**
 * StringNGrams on WebGL: reads data and dataSplits back and delegates to
 * stringNGramsImplCPU with the attrs.
 * @returns [nGrams (string vector), nGramsSplits (int32, shape of dataSplits)]
 */
function stringNGrams(args) {
  const { inputs, backend, attrs } = args;
  const { separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences } = attrs;
  const { data, dataSplits } = inputs;
  const $data = backend.readSync(data.dataId);
  const $dataSplits = backend.readSync(dataSplits.dataId);
  const [nGrams, nGramsSplits] = stringNGramsImplCPU($data, $dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences);
  return [
    backend.makeTensorInfo([nGrams.length], 'string', nGrams),
    backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits),
  ];
}
const stringNGramsConfig = { kernelName: StringNGrams, backendName: 'webgl', kernelFunc: stringNGrams, };
// StringSplit: validates a string vector input and scalar delimiter, then
// splits on the CPU via stringSplitImplCPU (definition continues past this
// block).
function stringSplit(args) {
  const { inputs, backend, attrs } = args;
  const { skipEmpty } = attrs;
  const { input, delimiter } = inputs;
  if (input.dtype !== 'string') {
    throw new Error('Input must be of datatype string');
  }
  if (input.shape.length !== 1) {
    throw new Error(`Input must be a vector, got shape: ${input.shape}`);
  }
  if (delimiter.shape.length !== 0) {
    throw new Error(`Delimiter must be a scalar, got shape: ${delimiter.shape}`);
  }
  const $input = backend.readSync(input.dataId);
  const $delimiter = backend.readSync(delimiter.dataId)[0];
  const [indices, values, shape] = stringSplitImplCPU($input, $delimiter,
skipEmpty); const outputSize = values.length; return [ backend.makeTensorInfo([outputSize, 2], 'int32', indices), backend.makeTensorInfo([outputSize], 'string', values), backend.makeTensorInfo([2], 'int32', new Int32Array(shape)) ]; } const stringSplitConfig = { kernelName: StringSplit, backendName: 'webgl', kernelFunc: stringSplit, }; function stringToHashBucketFast(args) { const { inputs, backend, attrs } = args; const { numBuckets } = attrs; const { input } = inputs; if (input.dtype !== 'string') { throw new Error('Input must be of datatype string'); } if (numBuckets <= 0) { throw new Error(`Number of buckets must be at least 1`); } const $input = backend.readSync(input.dataId); const output = stringToHashBucketFastImplCPU($input, numBuckets); return backend.makeTensorInfo(input.shape, 'int32', output); } const stringToHashBucketFastConfig = { kernelName: StringToHashBucketFast, backendName: 'webgl', kernelFunc: stringToHashBucketFast, }; const TAN = `return tan(x);`; const tan = unaryKernelFunc({ opSnippet: TAN }); const tanConfig = { kernelName: Tan, backendName: 'webgl', kernelFunc: tan, }; const TANH = ` float e2x = exp(-2.0 * abs(x)); return sign(x) * (1.0 - e2x) / (1.0 + e2x); `; const tanh = unaryKernelFunc({ opSnippet: TANH }); const tanhConfig = { kernelName: Tanh$1, backendName: 'webgl', kernelFunc: tanh, }; function tensorScatterUpdate(args) { const { inputs, backend} = args; const { tensor, indices, updates } = inputs; const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, tensor.shape); const flattenShape = [outputSize / sliceSize, sliceSize]; if (outputSize === 0) { return backend.makeTensorInfo(tensor.shape, indices.dtype); } const flattenIndices = reshape({ inputs: { x: indices }, backend, attrs: { shape: [numUpdates, sliceRank] } }); const flattenX = reshape({ inputs: { x: updates }, backend, attrs: { shape: [numUpdates, sliceSize] } }); const flattenTensor = reshape({ inputs: { x: tensor }, backend, 
attrs: { shape: flattenShape } }); const program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape, false, true); const res = backend.runWebGLProgram(program, [flattenX, flattenIndices, flattenTensor], flattenTensor.dtype); const reshaped = reshape({ inputs: { x: res }, backend, attrs: { shape: tensor.shape } }); backend.disposeIntermediateTensorInfo(flattenIndices); backend.disposeIntermediateTensorInfo(flattenX); backend.disposeIntermediateTensorInfo(flattenTensor); backend.disposeIntermediateTensorInfo(res); return reshaped; } const tensorScatterUpdateConfig = { kernelName: TensorScatterUpdate, backendName: 'webgl', kernelFunc: tensorScatterUpdate }; class TileProgram { constructor(aShape, reps) { this.variableNames = ['A']; const outputShape = new Array(aShape.length); for (let i = 0; i < outputShape.length; i++) { outputShape[i] = aShape[i] * reps[i]; } this.outputShape = outputShape; this.rank = outputShape.length; const dtype = getCoordsDataType(this.rank); const sourceCoords = getSourceCoords(aShape); this.userCode = ` void main() { ${dtype} resRC = getOutputCoords(); setOutput(getA(${sourceCoords})); } `; } } function getSourceCoords(aShape) { const rank = aShape.length; if (rank > 5) { throw Error(`Tile for rank ${rank} is not yet supported`); } if (rank === 1) { return `imod(resRC, ${aShape[0]})`; } const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u']; const sourceCoords = []; for (let i = 0; i < aShape.length; i++) { sourceCoords.push(`imod(${currentCoords[i]}, ${aShape[i]})`); } return sourceCoords.join(); } function tile$1(params) { const { inputs, backend, attrs } = params; const { x } = inputs; const { reps } = attrs; if (x.dtype === 'string' || x.shape.length > 5) { const data = backend.readSync(x.dataId); const value = x.dtype === 'string' ? 
data.map(d => decodeString(d)) : data;
    const buf = buffer(x.shape, x.dtype, value);
    const outBuf = tileImplCPU(buf, reps);
    return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
  }
  const program = new TileProgram(x.shape, reps);
  const output = backend.runWebGLProgram(program, [x], x.dtype);
  return output;
}
const tileConfig = { kernelName: Tile, backendName: 'webgl', kernelFunc: tile$1, };
/**
 * One compare-exchange pass used by topK over a [batch, width] index tensor.
 *
 * Output element elemIdx is paired with the element `inc` positions away.
 * On the first pass (firstPass == 1) positions are used directly as indices;
 * later passes read them from the `indices` input. Values at indices >= n
 * count as negativeInf, and ties are broken by index. Whether the larger
 * value goes first flips every `dir` elements (the `reverse` flag). The
 * pattern appears to be a bitonic-sort stage; confirm before relying on that.
 */
class SwapProgram {
  constructor(shape) {
    this.variableNames = ['x', 'indices'];
    // Uniform values are supplied per run, in this order, by topK.
    this.customUniforms = [ { name: 'n', type: 'int' }, { name: 'firstPass', type: 'int' }, { name: 'negativeInf', type: 'float' }, { name: 'dir', type: 'int' }, { name: 'inc', type: 'int' } ];
    this.outputShape = shape;
    this.userCode = ` void main() { ivec2 coords = getOutputCoords(); int batch = coords[0]; int elemIdx = coords[1]; bool isFirstInPair = imod(elemIdx, 2 * inc) < inc; int i = isFirstInPair ? elemIdx : elemIdx - inc; int i0 = firstPass == 1 ? i : int(getIndices(batch, i)); int i1 = firstPass == 1 ? i + inc : int(getIndices(batch, i + inc)); float x0 = i0 < n ? getX(batch, i0) : negativeInf; float x1 = i1 < n ? getX(batch, i1) : negativeInf; bool reverse = imod(elemIdx, 2 * dir) >= dir; bool isGreater = x0 > x1 || (x0 == x1 && i1 > i0); if (reverse == isGreater) { int iTemp = i0; i0 = i1; i1 = iTemp; } if (isFirstInPair) { setOutput(float(i0)); } else { setOutput(float(i1)); } } `;
  }
}
/**
 * Merge pass used by topK: the output has half the width of the input
 * indices. Each output element compares candidate i with i + k and keeps the
 * index of the larger value; x0 wins ties and wins outright when i1 >= n.
 * NOTE: the shader literal below continues past the end of this line.
 */
class MergeProgram {
  constructor(shape) {
    this.variableNames = ['x', 'indices'];
    // Uniform values are supplied per run, in this order, by topK.
    this.customUniforms = [ { name: 'n', type: 'int' }, { name: 'firstPass', type: 'int' }, { name: 'k', type: 'int' } ];
    this.outputShape = shape;
    this.userCode = ` void main() { ivec2 coords = getOutputCoords(); int batch = coords[0]; int elemIdx = coords[1]; int i = elemIdx < k ? elemIdx : (elemIdx * 2 - imod(elemIdx, k)); int i0 = firstPass == 1 ? i : int(getIndices(batch, i)); int i1 = firstPass == 1 ? i + k : int(getIndices(batch, i + k)); float x0 = getX(batch, i0); float x1 = i1 < n ?
getX(batch, i1) : x0; setOutput(x0 >= x1 ? float(i0) : float(i1)); } `; } } function disposeIntermediateTensorInfoOrNull(backend, tensorInfo) { if (tensorInfo !== null) { backend.disposeIntermediateTensorInfo(tensorInfo); } } function roundUpToPow2(num) { let pow2 = 1; while (pow2 < num) { pow2 *= 2; } return pow2; } function topK(args) { const { inputs, backend, attrs } = args; const { x } = inputs; const { k, sorted } = attrs; const TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD = env().getNumber('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD'); const TOPK_K_CPU_HANDOFF_THRESHOLD = env().getNumber('TOPK_K_CPU_HANDOFF_THRESHOLD'); const xShape = x.shape; const lastDim = xShape[xShape.length - 1]; if (backend.shouldExecuteOnCPU([x]) || lastDim < TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD || k > TOPK_K_CPU_HANDOFF_THRESHOLD) { const xVals = backend.readSync(x.dataId); const [allTopKVals, allTopKIndices] = topKImplCPU(xVals, xShape, x.dtype, k, sorted); return [ backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values), backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values) ]; } if (k === 0) { xShape[xShape.length - 1] = 0; return [ backend.makeTensorInfo(xShape, x.dtype, []), backend.makeTensorInfo(xShape, 'int32', []) ]; } if (lastDim === 1 ) { return [ x, fill({ attrs: { shape: xShape, dtype: 'int32', value: 0 }, backend }) ]; } const xtexData = backend.texData.get(x.dataId); const xIsPacked = xtexData !== null && xtexData.isPacked; const xUnPacked = xIsPacked ? backend.unpackTensor(x) : x; const xSize = sizeFromShape(xShape); const batch = xSize / lastDim; const x2D = reshape({ inputs: { x: xUnPacked }, attrs: { shape: [batch, lastDim] }, backend }); if (xIsPacked) { disposeIntermediateTensorInfoOrNull(backend, xUnPacked); } const kPow2 = roundUpToPow2(k); const lastDimPow2 = roundUpToPow2(lastDim); let indices = null; const getInputs = () => indices === null ? 
[x2D, x2D] : [x2D, indices]; const runSwap = (dir, inc, shape) => { const inputs = getInputs(); const program = new SwapProgram(shape); const fistPass = indices === null ? 1 : 0; const customValues = [[lastDim], [fistPass], [Number.NEGATIVE_INFINITY], [dir], [inc]]; const prevIndices = indices; indices = backend.runWebGLProgram(program, inputs, 'int32', customValues); disposeIntermediateTensorInfoOrNull(backend, prevIndices); }; for (let len = 1; len < kPow2; len *= 2) { const dir = len * 2; for (let inc = len; inc >= 1; inc /= 2) { runSwap(dir, inc, [batch, lastDimPow2]); } } for (let indicesSize = lastDimPow2; indicesSize > kPow2; indicesSize /= 2) { const inputs = getInputs(); const mergeProgram = new MergeProgram([batch, indicesSize / 2]); const firstPass = indices === null ? 1 : 0; const customValues = [[lastDim], [firstPass], [kPow2]]; const prevIndices = indices; indices = backend.runWebGLProgram(mergeProgram, inputs, 'int32', customValues); disposeIntermediateTensorInfoOrNull(backend, prevIndices); const len = kPow2 / 2; const dir = len * 2; for (let inc = len; inc >= 1; inc /= 2) { runSwap(dir, inc, indices.shape); } } let prevIndices = indices; indices = slice({ inputs: { x: indices }, backend, attrs: { begin: 0, size: [batch, k] } }); disposeIntermediateTensorInfoOrNull(backend, prevIndices); let values = gatherV2({ inputs: { x: x2D, indices }, backend, attrs: { axis: 1, batchDims: 1 } }); disposeIntermediateTensorInfoOrNull(backend, x2D); const newShape = xShape.slice(0, -1); newShape.push(k); prevIndices = indices; indices = reshape({ inputs: { x: indices }, attrs: { shape: newShape }, backend }); disposeIntermediateTensorInfoOrNull(backend, prevIndices); const prevValues = values; values = reshape({ inputs: { x: values }, attrs: { shape: newShape }, backend }); disposeIntermediateTensorInfoOrNull(backend, prevValues); return [values, indices]; } const topKConfig = { kernelName: TopK, backendName: 'webgl', kernelFunc: topK }; class TransformProgram { 
constructor(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape) {
    this.variableNames = ['Image', 'Transforms'];
    this.outputShape = outShape;
    // Interpolation id baked into the shader: 1 = 'nearest' (rounded single
    // sample); any other value selects the blended four-sample branch (2).
    const interpolationModeId = interpolation === 'nearest' ? 1 : 2;
    // Fill-mode id used by mapCoord(): 1 constant, 2 reflect, 3 wrap,
    // 4 nearest; unrecognised modes fall back to constant.
    let fillModeId;
    switch (fillMode) {
      case 'constant':
        fillModeId = 1;
        break;
      case 'reflect':
        fillModeId = 2;
        break;
      case 'wrap':
        fillModeId = 3;
        break;
      case 'nearest':
        fillModeId = 4;
        break;
      default:
        fillModeId = 1;
        break;
    }
    // Shader: for each output pixel, apply the per-batch 8-parameter
    // projective transform [a1 a2 a3 b1 b2 b3 c1 c2] to locate the source
    // point, remap it per fill mode, then sample. Reads outside the image and
    // a zero projection denominator produce fillValue.
    // NOTE: this literal spans two physical lines; the line break is part of
    // the shader text.
    this.userCode = ` float mapCoord(float outCoord, float len) { float inCoord = outCoord; if(${fillModeId} == 2) { if (inCoord < 0.0) { if (len <= 1.0) { inCoord = 0.0; } else { float sz2 = 2.0 * len; if (inCoord < sz2) { inCoord = sz2 * float(int(float(-inCoord / sz2))) + inCoord; } inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1.0; } } else if (inCoord > len - 1.0) { if (len <= 1.0) { inCoord = 0.0; } else { float sz2 = 2.0 * len; inCoord -= sz2 * float(int(float(inCoord / sz2))); if (inCoord >= len) { inCoord = sz2 - inCoord - 1.0; } } } return clamp(inCoord, 0.0, len - 1.0); } else if (${fillModeId} == 3) { if (inCoord < 0.0) { if (len <= 1.0) { inCoord = 0.0; } else { float sz = len - 1.0; inCoord += len * (float(int(float(-inCoord / sz))) + 1.0); } } else if (inCoord > len - 1.0) { if (len <= 1.0) { inCoord = 0.0; } else { float sz = len - 1.0; inCoord -= len * float(int(float(inCoord / sz))); } } return clamp(inCoord, 0.0, len - 1.0); } else if (${fillModeId} == 4) { return clamp(outCoord, 0.0, len - 1.0); } else { return outCoord; } } float readWithFillValue(int batch, int coordY, int coordX, int channel) { float outputValue; if (0 <= coordY && coordY < ${imageHeight} && 0 <= coordX && coordX < ${imageWidth}) { outputValue = getImage(batch, coordY, coordX, channel); } else { outputValue = float(${fillValue}); } return outputValue; } void main() { ivec4 coords = getOutputCoords(); float outputValue; int batch = coords[0]; int x = coords[2]; int y = coords[1]; int channel = coords[3]; float xf = float(x); float
yf = float(y); float a1 = getTransforms(batch, 0); float a2 = getTransforms(batch, 1); float a3 = getTransforms(batch, 2); float b1 = getTransforms(batch, 3); float b2 = getTransforms(batch, 4); float b3 = getTransforms(batch, 5); float c1 = getTransforms(batch, 6); float c2 = getTransforms(batch, 7); float projection = c1 * xf + c2 * yf + 1.0; if (projection == 0.0) { outputValue = float(${fillValue}); } else { float inX = (a1 * xf + a2 * yf + a3) / projection; float inY = (b1 * xf + b2 * yf + b3) / projection; float mapX = mapCoord(inX, float(${imageWidth})); float mapY = mapCoord(inY, float(${imageHeight})); if (${interpolationModeId} == 1) { int coordY = int(round(mapY)); int coordX = int(round(mapX)); outputValue = readWithFillValue(batch, coordY, coordX, channel); } else { float yFloor = floor(mapY); float xFloor = floor(mapX); float yCeil = yFloor + 1.0; float xCeil = xFloor + 1.0; float valueYFloor = (xCeil - mapX) * readWithFillValue(batch, int(yFloor), int(xFloor), channel) + (mapX - xFloor) * readWithFillValue(batch, int(yFloor), int(xCeil), channel); float valueYCeil = (xCeil - mapX) * readWithFillValue(batch, int(yCeil), int(xFloor), channel) + (mapX - xFloor) * readWithFillValue(batch, int(yCeil), int(xCeil), channel); outputValue = (yCeil - mapY) * valueYFloor + (mapY - yFloor) * valueYCeil; } } setOutput(outputValue); } `;
  }
}
/**
 * Transform on WebGL: warps `image` with per-batch `transforms` via
 * TransformProgram and always produces a float32 result.
 * image.shape is destructured as [batch, height, width, channels]; the
 * optional attrs.outputShape supplies [outHeight, outWidth] (defaults to
 * the input height/width).
 */
function transform(args) {
  const { inputs, backend, attrs } = args;
  const { image, transforms } = inputs;
  const { interpolation, fillMode, fillValue, outputShape } = attrs;
  const [batch, imageHeight, imageWidth, numChannels] = image.shape;
  const [outHeight, outWidth] = outputShape != null ?
    outputShape : [imageHeight, imageWidth];
  const outShape = [batch, outHeight, outWidth, numChannels];
  const program = new TransformProgram(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape);
  return backend.runWebGLProgram(program, [image, transforms], 'float32');
}
const transformConfig = { kernelName: Transform, backendName: 'webgl', kernelFunc: transform };
/**
 * Unique on WebGL: warns that the UI may block, reads x back synchronously
 * and runs uniqueImplCPU along `axis`.
 * @returns [unique values (x.dtype), int32 indices vector]
 */
function unique$1(args) {
  const { inputs, attrs, backend } = args;
  const { axis } = attrs;
  const { x } = inputs;
  assertNotComplex(x, 'unique');
  console.warn('WARNING: ', 'UI might be locked temporarily as data is being downloaded');
  const values = backend.readSync(x.dataId);
  const { outputValues, outputShape, indices } = uniqueImplCPU(values, axis, x.shape, x.dtype);
  return [
    backend.makeTensorInfo(outputShape, x.dtype, outputValues),
    backend.makeTensorInfo([indices.length], 'int32', indices),
  ];
}
const uniqueConfig = { kernelName: Unique, backendName: 'webgl', kernelFunc: unique$1, };
/**
 * Unpack on WebGL: splits `value` into value.shape[axis] tensors, each a
 * size-1 slice along `axis` reshaped to drop that axis. A negative axis
 * counts from the end.
 */
function unpack(args) {
  const { inputs, backend, attrs } = args;
  const { value } = inputs;
  let { axis } = attrs;
  if (axis < 0) {
    axis += value.shape.length;
  }
  const x = value;
  const xRank = x.shape.length;
  const num = value.shape[axis];
  // Output shape: x.shape without the unpacked axis.
  const outShape = new Array(xRank - 1);
  let outIndex = 0;
  for (let i = 0; i < xRank; i++) {
    if (i !== axis) {
      outShape[outIndex++] = x.shape[i];
    }
  }
  const toDispose = [];
  const begin = new Array(xRank).fill(0);
  const size = x.shape.slice();
  size[axis] = 1;
  const res = new Array(num);
  for (let i = 0; i < res.length; i++) {
    begin[axis] = i;
    const sliced = slice({ inputs: { x }, backend, attrs: { begin, size } });
    const reshaped = reshape({ inputs: { x: sliced }, backend, attrs: { shape: outShape } });
    res[i] = reshaped;
    toDispose.push(sliced);
  }
  toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
  return res;
}
const unpackConfig = { kernelName: Unpack, backendName: 'webgl', kernelFunc: unpack };
/**
 * Segment-sum shader used by unsortedSegmentSum (definition continues past
 * this block).
 */
class SegmentOpProgram {
  constructor(segOpInfo, segOpType) {
this.variableNames = ['x', 'segmentIds'];
    const windowSize = segOpInfo.windowSize;
    const batchSize = segOpInfo.batchSize;
    const inSize = segOpInfo.inSize;
    const numSegments = segOpInfo.numSegments;
    // One output column per (input window, segment) pair.
    const outSize = numSegments * Math.ceil(inSize / windowSize);
    this.outputShape = [batchSize, outSize];
    const initializationValue = '0.0';
    const returnValue = `sumValue`;
    // The window is consumed 4 elements at a time with vec4 dot products; the
    // remaining 1-3 elements are handled after the loop.
    const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
    const windowSizeVec4Remainder = windowSize % 4;
    const updateSnippet = ` sumValue += dot(values, segFilter); `;
    // Bounds checks are only emitted when the last window is partial.
    // Out-of-range values read as 0.0 and out-of-range segment ids as -1.0.
    let checkValueOutOfBounds = '';
    if (inSize % windowSize > 0) {
      checkValueOutOfBounds = ` if (inIdx < 0 || inIdx >= ${inSize}) { return initializationValue; } `;
    }
    let checkSegmentIdOutOfBounds = '';
    if (inSize % windowSize > 0) {
      checkSegmentIdOutOfBounds = ` if (inIdx < 0 || inIdx >= ${inSize}) { return -1.0; } `;
    }
    // NOTE: segOpType is not used by this constructor. The shader literal
    // below spans three physical lines; its line breaks are part of the text.
    this.userCode = ` const float initializationValue = ${initializationValue}; float getValue(int batch, int inIdx) { ${checkValueOutOfBounds} return getX(batch, inIdx); } float getSegmentIdAtIndex(int inIdx) { ${checkSegmentIdOutOfBounds} return getSegmentIds(inIdx); } void main() { ivec2 coords = getOutputCoords(); int batch = coords[0]; int outIdx = coords[1]; int inOffset = int(floor(float(outIdx) / float( ${numSegments})) * float(${windowSize})); int currentSeg = int(mod(float(outIdx), float(${numSegments}))); float sumValue = 0.0; for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) { int inIdx = inOffset + i; vec4 values = vec4( getValue(batch, inIdx), getValue(batch, inIdx + 1), getValue(batch, inIdx + 2), getValue(batch, inIdx + 3) ); vec4 segFilter = vec4( int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0, int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0, int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0, int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ?
1 : 0 ); ${updateSnippet} } int inIdx = inOffset + ${windowSizeNearestVec4}; if (${windowSizeVec4Remainder === 1}) { vec4 values = vec4( getValue(batch, inIdx), initializationValue, initializationValue, initializationValue ); int inIdxSeg = int(getSegmentIdAtIndex(inIdx)); vec4 segFilter = vec4( int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0, 0, 0, 0 ); ${updateSnippet} } else if (${windowSizeVec4Remainder === 2}) { vec4 values = vec4( getValue(batch, inIdx), getValue(batch, inIdx + 1), initializationValue, initializationValue ); vec4 segFilter = vec4( int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0, int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0, 0, 0 ); ${updateSnippet} } else if (${windowSizeVec4Remainder === 3}) { vec4 values = vec4( getValue(batch, inIdx), getValue(batch, inIdx + 1), getValue(batch, inIdx + 2), initializationValue ); vec4 segFilter = vec4( int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0, int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0, int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ?
1 : 0, 0 ); ${updateSnippet} } setOutput(${returnValue}); } `;
  }
}
/**
 * UnsortedSegmentSum on WebGL (definition continues past this block).
 *
 * If getAxesPermutation returns a permutation, x is transposed first and the
 * reduction axis becomes the innermost one. The data is then viewed as
 * [-1, inSize] and summed per segment by SegmentOpProgram. The result is
 * reshaped to computeOutShape(...) and, if x was permuted, transposed back.
 */
function unsortedSegmentSum(args) {
  const { inputs, backend, attrs } = args;
  const { x, segmentIds } = inputs;
  const { numSegments } = attrs;
  const xRank = x.shape.length;
  const toDispose = [];
  let axis = 0;
  const permutation = getAxesPermutation([axis], xRank);
  let permutedX = x;
  if (permutation != null) {
    permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutation } });
    toDispose.push(permutedX);
    axis = getInnerMostAxes(1, xRank)[0];
  }
  const outShape = computeOutShape(permutedX.shape, axis, numSegments);
  const inSize = sizeFromShape([permutedX.shape[axis]]);
  const a2D = reshape({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
  toDispose.push(a2D);
  const outputDType = sumOutType(x.dtype);
  // Runs SegmentOpProgram once. If the output still has more than
  // numSegments columns (one block per window), it recurses on those partial
  // sums, using a tiled 0..numSegments-1 range as their segment ids.
  const segOpCompute = (x, segOpType, segmentIds, dtype, numSegments) => {
    const batchSize = x.shape[0];
    const inSize = x.shape[1];
    const windowSize = segOpComputeOptimalWindowSize(inSize, numSegments);
    const segOpInfo = { windowSize, inSize, batchSize, numSegments };
    const program = new SegmentOpProgram(segOpInfo, segOpType);
    const output = backend.compileAndRun(program, [x, segmentIds], dtype);
    toDispose.push(output);
    if (output.shape[1] === numSegments) {
      return output;
    }
    const rangeInfo = range$1({ backend, attrs: { start: 0, stop: numSegments, step: 1, dtype: 'float32' } });
    const tileInfo = tile$1({ inputs: { x: rangeInfo }, backend, attrs: { reps: [inSize / windowSize] } });
    toDispose.push(rangeInfo);
    toDispose.push(tileInfo);
    const result = segOpCompute(output, segOpType, tileInfo, dtype, numSegments);
    return result;
  };
  const segOpResult = segOpCompute(a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments);
  const reshaped = reshape({ inputs: { x: segOpResult }, backend, attrs: { shape: outShape } });
  let result = reshaped;
  if (permutation != null) {
    toDispose.push(reshaped);
    const perm = getUndoAxesPermutation(permutation);
    result = transpose({ inputs: { x:
result }, backend, attrs: { perm } }); } toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t)); return result; } const unsortedSegmentSumConfig = { kernelName: UnsortedSegmentSum, backendName: 'webgl', kernelFunc: unsortedSegmentSum }; const kernelConfigs = [ _fusedMatMulConfig, absConfig, acosConfig, acoshConfig, addConfig, addNConfig, allConfig, anyConfig, argMaxConfig, argMinConfig, asinConfig, asinhConfig, atanConfig, atan2Config, atanhConfig, avgPoolConfig, avgPool3DConfig, avgPool3DGradConfig$1, avgPoolGradConfig$1, batchMatMulConfig, batchNormConfig, batchToSpaceNDConfig, bincountConfig, bitwiseAndConfig, broadcastArgsConfig, castConfig, ceilConfig, clipByValueConfig, complexConfig, complexAbsConfig, concatConfig, conv2DConfig, conv2DBackpropFilterConfig, conv2DBackpropInputConfig, conv3DConfig, conv3DBackpropFilterV2Config, conv3DBackpropInputConfig, cosConfig, coshConfig, cropAndResizeConfig, cumprodConfig, cumsumConfig, denseBincountConfig, depthToSpaceConfig, depthwiseConv2dNativeConfig, depthwiseConv2dNativeBackpropFilterConfig, depthwiseConv2dNativeBackpropInputConfig, diagConfig, dilation2DConfig, einsumConfig, eluConfig, eluGradConfig$1, equalConfig, erfConfig, expConfig, expandDimsConfig, expm1Config, fftConfig, fillConfig, flipLeftRightConfig, floorConfig, floorDivConfig, fromPixelsConfig, fusedConv2DConfig, fusedDepthwiseConv2DConfig, gatherNdConfig, gatherV2Config, greaterConfig, greaterEqualConfig, identityConfig, ifftConfig, imagConfig, isFiniteConfig, isInfConfig, isNaNConfig, leakyReluConfig, lessConfig, lessEqualConfig, linSpaceConfig, logConfig, log1pConfig, logicalAndConfig, logicalNotConfig, logicalOrConfig, LRNConfig, LRNGradConfig, maxConfig, maximumConfig, maxPoolConfig, maxPool3DConfig, maxPool3DGradConfig$1, maxPoolGradConfig$1, maxPoolWithArgmaxConfig, meanConfig, minConfig, minimumConfig, mirrorPadConfig, modConfig, multinomialConfig, multiplyConfig, negConfig, nonMaxSuppressionV3Config, nonMaxSuppressionV4Config, 
nonMaxSuppressionV5Config, notEqualConfig, oneHotConfig, onesLikeConfig, packConfig, padV2Config, powConfig, preluConfig, prodConfig, raggedGatherConfig, raggedRangeConfig, raggedTensorToTensorConfig, rangeConfig, realConfig, realDivConfig, reciprocalConfig, reluConfig, relu6Config, reshapeConfig, resizeBilinearConfig, resizeBilinearGradConfig$1, resizeNearestNeighborConfig, resizeNearestNeighborGradConfig$1, reverseConfig, rotateWithOffsetConfig, roundConfig, rsqrtConfig, scatterNdConfig, searchSortedConfig, selectConfig, seluConfig, sigmoidConfig, signConfig, sinConfig, sinhConfig, sliceConfig, softmaxConfig, softplusConfig, spaceToBatchNDConfig, sparseFillEmptyRowsConfig, sparseReshapeConfig, sparseSegmentMeanConfig, sparseSegmentSumConfig, sparseToDenseConfig, splitVConfig, sqrtConfig, squareConfig, squaredDifferenceConfig, staticRegexReplaceConfig, stepConfig, stridedSliceConfig, stringNGramsConfig, stringSplitConfig, stringToHashBucketFastConfig, subConfig, sumConfig, tanConfig, tanhConfig, tensorScatterUpdateConfig, tileConfig, topKConfig, transformConfig, transposeConfig, uniqueConfig, unpackConfig, unsortedSegmentSumConfig, zerosLikeConfig ]; for (const kernelConfig of kernelConfigs) { registerKernel(kernelConfig); } const absGradConfig = { kernelName: Abs, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => mul(dy, step$1(cast$2(x, 'float32'), -1)) }; } }; const acosGradConfig = { kernelName: Acos, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => { const a = square$2(cast$2(x, 'float32')); const b = sqrt$1(sub$1(scalar(1), a)); return neg$1(div(dy, b)); } }; } }; const acoshGradConfig = { kernelName: Acosh, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => { const a = sqrt$1(sub$1(square$2(cast$2(x, 'float32')), 1)); return div(dy, a); } }; } }; const addGradConfig = { kernelName: Add, inputsToSave: ['a', 'b'], gradFunc: (dy, saved) => { const [a, b] = 
saved; const outShape = assertAndGetBroadcastShape(a.shape, b.shape); const derA = () => { let res = dy; const reduceAxes = getReductionAxes(a.shape, outShape); if (reduceAxes.length > 0) { res = sum$1(res, reduceAxes); } return reshape$1(res, a.shape); }; const derB = () => { let res = dy; const reduceAxes = getReductionAxes(b.shape, outShape); if (reduceAxes.length > 0) { res = sum$1(res, reduceAxes); } return reshape$1(res, b.shape); }; return { a: derA, b: derB }; } }; const addNGradConfig = { kernelName: AddN, saveAllInputs: true, gradFunc: (dy, saved) => { const ders = {}; saved.forEach((_, i) => { ders[i] = () => dy.clone(); }); return ders; } }; const argMaxGradConfig = { kernelName: ArgMax, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => zerosLike$1(x) }; } }; const argMinGradConfig = { kernelName: ArgMin, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => zerosLike$1(x) }; } }; const asinGradConfig = { kernelName: Asin, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => div(dy, sqrt$1(sub$1(scalar(1), square$2(cast$2(x, 'float32'))))) }; } }; const asinhGradConfig = { kernelName: Asinh, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => { const a = sqrt$1(add(scalar(1), square$2(cast$2(x, 'float32')))); return div(dy, a); } }; } }; const atan2GradConfig = { kernelName: Atan2, inputsToSave: ['a', 'b'], gradFunc: (dy, saved) => { const [a, b] = saved; const outShape = assertAndGetBroadcastShape(a.shape, b.shape); const derA = () => { const d = add(square$2(a), square$2(b)); let res = mul(dy, div(b, d)); const reduceAxes = getReductionAxes(a.shape, outShape); if (reduceAxes.length > 0) { res = sum$1(res, reduceAxes); } return reshape$1(res, a.shape); }; const derB = () => { const d = add(square$2(a), square$2(b)); let res = neg$1(mul(dy, div(a, d))); const reduceAxes = getReductionAxes(b.shape, outShape); if 
(reduceAxes.length > 0) { res = sum$1(res, reduceAxes); } return reshape$1(res, b.shape); }; return { a: derA, b: derB }; } }; const atanGradConfig = { kernelName: Atan, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => div(dy, add(square$2(cast$2(x, 'float32')), 1)) }; } }; const atanhGradConfig = { kernelName: Atanh, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => div(dy, sub$1(scalar(1), square$2(cast$2(x, 'float32')))) }; } }; function avgPool3dGrad_(dy, input, filterSize, strides, pad, dimRoundingMode) { const $dy = convertToTensor(dy, 'dy', 'avgPool3dGrad'); const $input = convertToTensor(input, 'input', 'avgPool3dGrad'); let dy5D = $dy; let input5D = $input; let reshapedTo5D = false; if ($input.rank === 4) { reshapedTo5D = true; dy5D = reshape$1($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]); input5D = reshape$1($input, [ 1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3] ]); } assert$1(dy5D.rank === 5, () => `Error in avgPool3dGrad: dy must be rank 5 but got rank ` + `${dy5D.rank}.`); assert$1(input5D.rank === 5, () => `Error in avgPool3dGrad: input must be rank 5 but got rank ` + `${input5D.rank}.`); checkPadOnDimRoundingMode('avgPool3dGrad', pad, dimRoundingMode); const inputs = { dy: dy5D, input: input5D }; const attrs = { filterSize, strides, pad, dimRoundingMode }; const res = ENGINE.runKernel(AvgPool3DGrad, inputs, attrs); if (reshapedTo5D) { return reshape$1(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]); } return res; } const avgPool3dGrad = op({ avgPool3dGrad_ }); const avgPool3DGradConfig = { kernelName: AvgPool3D, inputsToSave: ['x'], gradFunc: (dy, saved, attrs) => { const [x] = saved; const { filterSize, strides, pad, dimRoundingMode } = attrs; return { x: () => avgPool3dGrad(dy, x, filterSize, strides, pad, dimRoundingMode) }; } }; function avgPoolGrad_(dy, input, filterSize, strides, pad) { const $dy = 
convertToTensor(dy, 'dy', 'avgPoolGrad'); const $input = convertToTensor(input, 'input', 'avgPoolGrad'); assert$1($input.rank === $dy.rank, () => `Rank of input (${$input.rank}) does not match rank of dy (${$dy.rank})`); let input4D = $input; let dy4D = $dy; let reshapedTo4D = false; if ($input.rank === 3) { reshapedTo4D = true; input4D = reshape$1($input, [1, $input.shape[0], $input.shape[1], $input.shape[2]]); dy4D = reshape$1($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2]]); } assert$1(dy4D.rank === 4, () => `Error in avgPoolGrad: dy must be rank 4 but got rank ` + `${dy4D.rank}.`); assert$1(input4D.rank === 4, () => `Error in avgPoolGrad: input must be rank 4 but got rank ` + `${input4D.rank}.`); const inputs = { dy: dy4D, input: input4D }; const attrs = { filterSize, strides, pad }; const res = ENGINE.runKernel(AvgPoolGrad, inputs, attrs); if (reshapedTo4D) { return reshape$1(res, [res.shape[1], res.shape[2], res.shape[3]]); } return res; } const avgPoolGrad = op({ avgPoolGrad_ }); const avgPoolGradConfig = { kernelName: AvgPool, inputsToSave: ['x'], gradFunc: (dy, saved, attrs) => { const [x] = saved; const { filterSize, strides, pad } = attrs; return { x: () => avgPoolGrad(dy, x, filterSize, strides, pad) }; } }; const batchMatMulGradConfig = { kernelName: BatchMatMul, inputsToSave: ['a', 'b'], gradFunc: (dy, saved, attrs) => { const [a, b] = saved; const { transposeA, transposeB } = attrs; if (!transposeA && !transposeB) { return { a: () => matMul$1(dy, b, false, true), b: () => matMul$1(a, dy, true, false) }; } else if (!transposeA && transposeB) { return { a: () => matMul$1(dy, b, false, false), b: () => matMul$1(dy, a, true, false) }; } else if (transposeA && !transposeB) { return { a: () => matMul$1(b, dy, false, true), b: () => matMul$1(a, dy, false, false) }; } else { return { a: () => matMul$1(b, dy, true, true), b: () => matMul$1(dy, a, true, true) }; } } }; const batchToSpaceNDGradConfig = { kernelName: BatchToSpaceND, gradFunc: (dy, saved, 
attrs) => { const { blockShape, crops } = attrs; return { x: () => spaceToBatchND$1(dy, blockShape, crops) }; } }; const broadcastToGradConfig = { kernelName: BroadcastTo, gradFunc: (dy, saved, attrs) => { const broadCastToAttrs = attrs; const inputShape = broadCastToAttrs.inputShape; const outputShape = broadCastToAttrs.shape; const reps = Array.from(outputShape); for (let i = inputShape.length - 1; i >= 0; i--) { if (inputShape[i] === outputShape[i]) { reps[i] = 1; } else if (inputShape[i] !== 1) { throw new Error(`broadcastTo(): [${inputShape}] cannot be broadcast to [${outputShape}].`); } } const axes = []; for (let i = 0; i < reps.length; i++) { if (reps[i] > 1) { axes.push(i); } } return { x: () => sum$1(dy, axes, true ) }; } }; const castGradConfig = { kernelName: Cast, gradFunc: (dy) => { return { x: () => dy.clone() }; } }; const ceilGradConfig = { kernelName: Ceil, gradFunc: (dy) => { return { x: () => zerosLike$1(dy) }; } }; const clipByValueGradConfig = { kernelName: ClipByValue, inputsToSave: ['x'], gradFunc: (dy, saved, attrs) => { const [x] = saved; const { clipValueMin, clipValueMax } = attrs; return { x: () => where(logicalAnd$1(greaterEqual$1(x, clipValueMin), lessEqual$1(x, clipValueMax)), dy, zerosLike$1(dy)), }; } }; const complexAbsGradConfig = { kernelName: ComplexAbs, inputsToSave: ['x'], gradFunc: absGradConfig.gradFunc, }; const concatGradConfig = { kernelName: Concat, saveAllInputs: true, gradFunc: (dy, saved, attrs) => { const shapes = saved.map(t => t.shape); const { axis } = attrs; const $axis = parseAxisParam(axis, saved[0].shape)[0]; const sizeSplits = shapes.map(s => s[$axis]); const derTensors = split$1(dy, sizeSplits, $axis); return derTensors.map(t => () => t); } }; const conv2DGradConfig = { kernelName: Conv2D, inputsToSave: ['x', 'filter'], gradFunc: (dy, saved, attrs) => { const [x4D, $filter] = saved; const { dilations, strides, pad, dataFormat } = attrs; assert$1(tupleValuesAreOne(dilations), () => 'Error in gradient of 
conv2D: dilation rates greater than 1 ' + `are not yet supported in gradients. Got dilations '${dilations}'`); return { x: () => conv2DBackpropInput$1(x4D.shape, dy, $filter, strides, pad, dataFormat), filter: () => conv2DBackpropFilter$1(x4D, dy, $filter.shape, strides, pad, dataFormat) }; } }; const conv2DBackpropInputGradConfig = { kernelName: Conv2DBackpropInput, inputsToSave: ['dy', 'filter'], gradFunc: (ddx, saved, attrs) => { const [dy, filter] = saved; const { strides, pad, dataFormat, dimRoundingMode } = attrs; return { dy: () => conv2d$1(ddx, filter, strides, pad, dataFormat, 1 , dimRoundingMode), filter: () => conv2DBackpropFilter$1(ddx, dy, filter.shape, strides, pad, dataFormat, dimRoundingMode) }; } }; function conv3DBackpropFilter_(x, dy, filterShape, strides, pad) { let x5D = x; if (x.rank === 4) { x5D = reshape$1(x, [1, x.shape[0], x.shape[1], x.shape[2], x.shape[3]]); } let dy5D = dy; if (dy5D.rank === 4) { dy5D = reshape$1(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]); } assert$1(x5D.rank === 5, () => `Error in conv3dDerFilter: input must be rank 5, but got shape ` + `${x5D.shape}.`); assert$1(dy5D.rank === 5, () => `Error in conv3dDerFilter: dy must be rank 5, but got shape ` + `${dy5D.shape}.`); assert$1(filterShape.length === 5, () => `Error in conv3dDerFilter: filterShape must be length 5, but got ` + `${filterShape}.`); assert$1(x5D.shape[4] === filterShape[3], () => `Error in conv3dDerFilter: depth of input ${x5D.shape[4]}) must ` + `match input depth in filter (${filterShape[3]}.`); assert$1(dy5D.shape[4] === filterShape[4], () => `Error in conv3dDerFilter: depth of dy (${dy5D.shape[4]}) must ` + `match output depth for filter (${filterShape[4]}).`); const inputs = { x: x5D, dy: dy5D }; const attrs = { strides, pad, filterShape }; return ENGINE.runKernel(Conv3DBackpropFilterV2, inputs, attrs); } const conv3DBackpropFilter = op({ conv3DBackpropFilter_ }); const conv3DGradConfig = { kernelName: Conv3D, inputsToSave: ['x', 
'filter'], gradFunc: (dy, saved, attrs) => { const { dilations, strides, pad } = attrs; assert$1(tupleValuesAreOne(dilations), () => 'Error in gradient of conv3D: dilation rates greater than 1 are ' + `not yet supported in gradients. Got dilations '${dilations}'`); const [x5D, $filter] = saved; return { x: () => conv3DBackpropInput$1(x5D.shape, dy, $filter, strides, pad), filter: () => conv3DBackpropFilter(x5D, dy, $filter.shape, strides, pad) }; } }; const cosGradConfig = { kernelName: Cos, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => mul(neg$1(sin$1(cast$2(x, 'float32'))), dy) }; } }; const coshGradConfig = { kernelName: Cosh, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => mul(sinh$1(cast$2(x, 'float32')), dy) }; } }; const cumsumGradConfig = { kernelName: Cumsum, inputsToSave: ['x'], gradFunc: (dy, saved, attrs) => { const [x] = saved; const { axis, exclusive, reverse } = attrs; return { x: () => { const permutation = getAxesPermutation([axis], x.rank); let out = cumsum$1(dy, axis, exclusive, !reverse); if (permutation != null) { out = transpose$1(out, permutation); } return out; } }; } }; const depthwiseConv2dNativeGradConfig = { kernelName: DepthwiseConv2dNative, inputsToSave: ['x', 'filter'], gradFunc: (dy, saved, attrs) => { const { dilations, strides, pad, dimRoundingMode } = attrs; const $dilations = dilations == null ? [1, 1] : dilations; assert$1(tupleValuesAreOne($dilations), () => 'Error in gradient of depthwiseConv2dNative: dilation rates ' + `greater than 1 are not yet supported. 
Got dilations ` + `'${$dilations}'`); const [x, filter] = saved; assert$1(x.rank === 4, () => `Error in gradient of depthwiseConv2dNative: input must be ` + `rank 4, but got rank ${x.rank}.`); assert$1(filter.rank === 4, () => `Error in gradient of depthwiseConv2dNative: filter must be ` + `rank 4, but got rank ${filter.rank}.`); assert$1(x.shape[3] === filter.shape[2], () => `Error in gradient of depthwiseConv2d: number of input ` + `channels (${x.shape[3]}) must match the inChannels dimension ` + `in filter ${filter.shape[2]}.`); assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in gradient of depthwiseConv2d: Either strides or ' + `dilations must be 1. Got strides ${strides} and dilations ` + `'${$dilations}'.`); checkPadOnDimRoundingMode('depthwiseConv2d', pad, dimRoundingMode); return { x: () => depthwiseConv2dNativeBackpropInput$1(x.shape, dy, filter, strides, pad, $dilations, dimRoundingMode), filter: () => depthwiseConv2dNativeBackpropFilter$1(x, dy, filter.shape, strides, pad, $dilations, dimRoundingMode), }; } }; const dilation2dGradConfig = { kernelName: Dilation2D, inputsToSave: ['x', 'filter'], gradFunc: (dy, saved, attrs) => { const [x, filter] = saved; const inputInputs = { x, filter, dy }; const filterInputs = { x, filter, dy }; return { x: () => ENGINE.runKernel(Dilation2DBackpropInput, inputInputs, attrs), filter: () => ENGINE.runKernel(Dilation2DBackpropFilter, filterInputs, attrs) }; } }; const eluGradConfig = { kernelName: Elu$1, outputsToSave: [true], gradFunc: (dy, saved) => { const [y] = saved; const inputs = { dy, y }; return { x: () => ENGINE.runKernel(EluGrad, inputs) }; } }; const erfGradConfig = { kernelName: Erf, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; const a = mul(exp$1(neg$1(square$2(x))), 2 / Math.sqrt(Math.PI)); return { x: () => mul(dy, a) }; } }; const expGradConfig = { kernelName: Exp, outputsToSave: [true], gradFunc: (dy, saved) => { const [y] = saved; return { x: () => 
mul(dy, y) }; } }; const expandDimsGradConfig = { kernelName: ExpandDims, inputsToSave: ['input'], gradFunc: (dy, saved) => { const [input] = saved; return { input: () => reshape$1(dy, input.shape) }; } }; const expm1GradConfig = { kernelName: Expm1, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => mul(dy, exp$1(x)) }; } }; const floorGradConfig = { kernelName: Floor, gradFunc: (dy) => { return { x: () => zerosLike$1(dy) }; } }; const floorDivGradConfig = { kernelName: FloorDiv, inputsToSave: ['a', 'b'], gradFunc: (dy, saved) => { const [a, b] = saved; const outShape = assertAndGetBroadcastShape(a.shape, b.shape); const derA = () => { const res = div(dy, cast$2(b, 'float32')); const reduceAxes = getReductionAxes(a.shape, outShape); if (reduceAxes.length > 0) { return reshape$1(sum$1(res, reduceAxes), a.shape); } return res; }; const derB = () => { let res = mul(dy, cast$2(a, 'float32')); const reduceAxes = getReductionAxes(b.shape, outShape); if (reduceAxes.length > 0) { res = reshape$1(sum$1(res, reduceAxes), b.shape); } const tmp = square$2(b); return neg$1(div(res, cast$2(tmp, 'float32'))); }; return { a: derA, b: derB }; } }; const fusedBatchNormGradConfig = { kernelName: FusedBatchNorm, inputsToSave: ['x', 'mean', 'variance', 'scale'], gradFunc: (dy, saved, attrs) => { const { varianceEpsilon } = attrs; const [x, mean, variance, scale] = saved; const scaleValue = scale == null ? 
scalar(1) : scale; const reductionAxes = getReductionAxes(mean.shape, x.shape); const tileShape = []; if (mean.rank === 1) { for (let i = 0; i < x.shape.length - 1; ++i) { tileShape.push(x.shape[i]); } tileShape.push(1); } const xMinusMean = sub$1(x, mean); const dyTimesScaleValue = mul(dy, scaleValue); const oneOverSqrtVariance = rsqrt$1(add(variance, scalar(varianceEpsilon))); const minusHalfRCube = mul(mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance), scalar(-0.5)); const derX = () => { if (mean.rank === 1) { return reshape$1(mul(mul(dy, tile$2(reshape$1(oneOverSqrtVariance, [1, 1, 1, mean.shape[0]]), tileShape)), scaleValue), x.shape); } else { return reshape$1(mul(mul(dy, oneOverSqrtVariance), scaleValue), x.shape); } }; const derMean = () => { let meanDer = mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue); if (mean.rank === 1) { meanDer = sum$1(meanDer, reductionAxes); } return reshape$1(meanDer, mean.shape); }; const derVariance = () => { let varianceDer = mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue); if (mean.rank === 1) { varianceDer = sum$1(varianceDer, reductionAxes); } return reshape$1(varianceDer, mean.shape); }; const derScale = () => { const xMinusMean2TimesRsqrt = mul(xMinusMean, oneOverSqrtVariance); let scaleDer = mul(dy, xMinusMean2TimesRsqrt); if (mean.rank === 1) { scaleDer = sum$1(scaleDer, reductionAxes); } return reshape$1(scaleDer, mean.shape); }; const derOffset = () => { let offsetDer = dy; if (mean.rank === 1) { offsetDer = sum$1(offsetDer, reductionAxes); } return reshape$1(offsetDer, mean.shape); }; return { x: derX, mean: derMean, variance: derVariance, scale: derScale, offset: derOffset }; } }; const gatherGradConfig = { kernelName: GatherV2, inputsToSave: ['x', 'indices'], gradFunc: (dy, saved, attrs) => { const [x, indices] = saved; const { axis, batchDims } = attrs; const parsedAxis = parseAxisParam(axis, x.shape)[0]; const derXBatch = (x, indices, dy) => { return () => { const 
paramsShape = x.shape; const indicesSize = indices.size; const outerShape = paramsShape.slice(0, parsedAxis); const outerDims = outerShape.length; const innerShape = paramsShape.slice(axis, paramsShape.length).slice(1); const innerDims = innerShape.length; const outerAxesIndices = arrayRange(0, outerDims); const innerAxesIndices = arrayRange(outerDims + 1, outerDims + 1 + innerDims); const valuesShape = arrayConcat([outerShape, [indicesSize], innerShape]); const values = reshape$1(dy, valuesShape); const reshapedIndices = reshape$1(indices, [indicesSize]); const transposeDims = arrayConcat([[outerDims], outerAxesIndices, innerAxesIndices]); const valuesTranspose = transpose$1(values, transposeDims); let paramsGrad = unsortedSegmentSum$1(valuesTranspose, reshapedIndices, x.shape[parsedAxis]); const invertTransposeDims = getUndoAxesPermutation(transposeDims); paramsGrad = transpose$1(paramsGrad, invertTransposeDims); return paramsGrad; }; }; if (batchDims === 1) { const batchSize = x.shape[0]; const xBatch = x.split(batchSize, 0); const derXBatched = () => { const stacked = stack(xBatch.map((x, i) => { return derXBatch(x, indices.slice(i, 1), dy.slice(i, 1))(); })); return stacked.reshape(x.shape); }; return { x: derXBatched, indices: () => indices }; } else { return { x: derXBatch(x, indices, dy), indices: () => indices }; } } }; function arrayRange(start, stop) { const result = []; for (let i = start; i < stop; ++i) { result.push(i); } return result; } function arrayConcat(arrays) { const result = []; for (let i = 0; i < arrays.length; ++i) { for (let j = 0; j < arrays[i].length; ++j) { result.push(arrays[i][j]); } } return result; } const greaterEqualGradConfig = { kernelName: GreaterEqual, inputsToSave: ['a', 'b'], gradFunc: (dy, saved) => { const [a, b] = saved; return { a: () => zerosLike$1(a), b: () => zerosLike$1(b) }; } }; const identityGradConfig = { kernelName: Identity$1, gradFunc: (dy) => { return { x: () => cast$2(dy, 'float32') }; } }; const 
isFiniteGradConfig = { kernelName: IsFinite, gradFunc: (dy) => { return { x: () => zerosLike$1(dy) }; } }; const isInfGradConfig = { kernelName: IsInf, gradFunc: (dy) => { return { x: () => zerosLike$1(dy) }; } }; const isNanGradConfig = { kernelName: IsNan, gradFunc: (dy) => { return { x: () => zerosLike$1(dy) }; } }; const leakyReluGradConfig = { kernelName: LeakyRelu, inputsToSave: ['x'], gradFunc: (dy, saved, attrs) => { const [x] = saved; const { alpha } = attrs; const mask = greater$1(x, 0); return { x: () => where(mask, dy, mul(dy, alpha)) }; } }; const log1pGradConfig = { kernelName: Log1p, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => div(dy, add(x, 1)) }; } }; const logGradConfig = { kernelName: Log, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => div(dy, cast$2(x, 'float32')) }; } }; const logSoftmaxGradConfig = { kernelName: LogSoftmax$1, inputsToSave: [], outputsToSave: [true], gradFunc: (dy, saved, attrs) => { const [value] = saved; const { axis } = attrs; return { logits: () => { const keepDims = true; const softmax = exp$1(value); return sub$1(dy, mul(sum$1(dy, axis, keepDims), softmax)); } }; } }; function localResponseNormalizationBackprop_(x, y, dy, depthRadius = 5, bias = 1, alpha = 1, beta = 0.5) { const inputs = { x, y, dy }; const attrs = { depthRadius, bias, alpha, beta }; return ENGINE.runKernel(LRNGrad, inputs, attrs); } const localResponseNormalizationBackprop = op({ localResponseNormalizationBackprop_ }); const lrnGradConfig = { kernelName: LRN, inputsToSave: ['x'], outputsToSave: [true], gradFunc: (dy, saved, attrs) => { const [x, y] = saved; const { depthRadius, bias, alpha, beta } = attrs; return { x: () => localResponseNormalizationBackprop(x, y, dy, depthRadius, bias, alpha, beta) }; } }; function gradForMinAndMax(dy, y, xOrig, origAxes) { if (y.rank < xOrig.rank) { y = reshape$1(y, expandShapeToKeepDim(y.shape, origAxes)); } if (dy.rank < xOrig.rank) { dy = 
reshape$1(dy, expandShapeToKeepDim(dy.shape, origAxes)); } return { x: () => { const dx = mul(dy, cast$2(equal$1(xOrig, y), dy.dtype)); return dx; } }; } const maxGradConfig = { kernelName: Max, inputsToSave: ['x'], outputsToSave: [true], gradFunc: (dy, saved, attrs) => { const maxAttrs = attrs; const { reductionIndices } = maxAttrs; const x = saved[0]; const y = saved[1]; const origAxes = parseAxisParam(reductionIndices, x.shape); const maxGrad = gradForMinAndMax(dy, y, x, origAxes); return { x: () => { return maxGrad['x'](); } }; } }; const maximumGradConfig = { kernelName: Maximum, inputsToSave: ['a', 'b'], gradFunc: (dy, saved) => { const [a, b] = saved; const derA = () => mul(dy, cast$2(greaterEqual$1(a, b), 'float32')); const derB = () => mul(dy, cast$2(less$1(a, b), 'float32')); return { a: derA, b: derB }; } }; function maxPool3dGrad_(dy, input, output, filterSize, strides, pad, dimRoundingMode) { const $dy = convertToTensor(dy, 'dy', 'maxPool3dGrad'); const $input = convertToTensor(input, 'input', 'maxPool3dGrad'); const $output = convertToTensor(output, 'output', 'maxPool3dGrad'); let dy5D = $dy; let input5D = $input; let output5D = $output; let reshapedTo5D = false; if ($input.rank === 4) { reshapedTo5D = true; dy5D = reshape$1($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]); input5D = reshape$1($input, [ 1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3] ]); output5D = reshape$1($output, [ 1, $output.shape[0], $output.shape[1], $output.shape[2], $output.shape[3] ]); } assert$1(dy5D.rank === 5, () => `Error in maxPool3dGrad: dy must be rank 5 but got rank ` + `${dy5D.rank}.`); assert$1(input5D.rank === 5, () => `Error in maxPool3dGrad: input must be rank 5 but got rank ` + `${input5D.rank}.`); assert$1(output5D.rank === 5, () => `Error in maxPool3dGrad: output must be rank 5 but got rank ` + `${output5D.rank}.`); checkPadOnDimRoundingMode('maxPool3dGrad', pad, dimRoundingMode); const inputs = { dy: dy5D, input: 
input5D, output: output5D }; const attrs = { filterSize, strides, pad, dimRoundingMode }; const res = ENGINE.runKernel(MaxPool3DGrad, inputs, attrs); if (reshapedTo5D) { return reshape$1(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]); } return res; } const maxPool3dGrad = op({ maxPool3dGrad_ }); const maxPool3DGradConfig = { kernelName: MaxPool3D, inputsToSave: ['x'], outputsToSave: [true], gradFunc: (dy, saved, attrs) => { const [x, y] = saved; const { filterSize, strides, pad, dimRoundingMode } = attrs; return { x: () => maxPool3dGrad(dy, x, y, filterSize, strides, pad, dimRoundingMode) }; } }; function maxPoolGrad_(dy, input, output, filterSize, strides, pad, dimRoundingMode) { const $dy = convertToTensor(dy, 'dy', 'maxPoolGrad'); const $input = convertToTensor(input, 'input', 'maxPoolGrad'); const $output = convertToTensor(output, 'output', 'maxPoolGrad'); assert$1($input.rank === $dy.rank, () => `Rank of input (${$input.rank}) does not match rank of dy ` + `(${$dy.rank})`); assert$1($dy.rank === 4, () => `Error in maxPoolGrad: dy must be rank 4 but got rank ` + `${$dy.rank}.`); assert$1($input.rank === 4, () => `Error in maxPoolGrad: input must be rank 4 but got rank ` + `${$input.rank}.`); checkPadOnDimRoundingMode('maxPoolGrad', pad, dimRoundingMode); const inputs = { dy: $dy, input: $input, output: $output }; const attrs = { filterSize, strides, pad, dimRoundingMode }; return ENGINE.runKernel(MaxPoolGrad, inputs, attrs); } const maxPoolGrad = op({ maxPoolGrad_ }); const maxPoolGradConfig = { kernelName: MaxPool, inputsToSave: ['x'], outputsToSave: [true], gradFunc: (dy, saved, attrs) => { const [x, y] = saved; const { filterSize, strides, pad } = attrs; return { x: () => maxPoolGrad(dy, x, y, filterSize, strides, pad) }; } }; const meanGradConfig = { kernelName: Mean, inputsToSave: ['x'], gradFunc: (dy, saved, attrs) => { const [x] = saved; const { axis } = attrs; const axes = parseAxisParam(axis, x.shape); const shapes = 
computeOutAndReduceShapes(x.shape, axes); const reduceShape = shapes[1]; const reduceSize = sizeFromShape(reduceShape); const derX = () => { const expandedDyShape = x.shape.slice(); axes.forEach(axis => { expandedDyShape[axis] = 1; }); const expandedDy = reshape$1(dy, expandedDyShape); const res = div(mul(expandedDy, ones(x.shape, 'float32')), reduceSize); return res; }; return { x: derX }; } }; const minGradConfig = { kernelName: Min, inputsToSave: ['x'], outputsToSave: [true], gradFunc: (dy, saved, attrs) => { const minAttrs = attrs; const { axis } = minAttrs; const [x, y] = saved; const origAxes = parseAxisParam(axis, x.shape); const minGrad = gradForMinAndMax(dy, y, x, origAxes); return { x: () => { return minGrad['x'](); } }; } }; const minimumGradConfig = { kernelName: Minimum, inputsToSave: ['a', 'b'], gradFunc: (dy, saved) => { const [a, b] = saved; const derA = () => mul(dy, cast$2(lessEqual$1(a, b), 'float32')); const derB = () => mul(dy, cast$2(greater$1(a, b), 'float32')); return { a: derA, b: derB }; } }; const mirrorPadGradConfig = { kernelName: MirrorPad, inputsToSave: ['x'], gradFunc: (dy, saved, attrs) => { const x = saved[0]; const { paddings } = attrs; const begin = paddings.map(p => p[0]); return { x: () => slice$1(dy, begin, x.shape) }; } }; const modGradConfig = { kernelName: Mod, inputsToSave: ['a', 'b'], gradFunc: (dy, saved) => { const [a, b] = saved; const outShape = assertAndGetBroadcastShape(a.shape, b.shape); const derA = () => { const reduceAxes = getReductionAxes(a.shape, outShape); if (reduceAxes.length > 0) { return reshape$1(sum$1(dy, reduceAxes), a.shape); } return dy; }; const derB = () => { const res = mul(dy, neg$1(floor$1(div(a, b)))); const reduceAxes = getReductionAxes(b.shape, outShape); if (reduceAxes.length > 0) { return reshape$1(sum$1(res, reduceAxes), b.shape); } return res; }; return { a: derA, b: derB }; } }; const multiplyGradConfig = { kernelName: Multiply, inputsToSave: ['a', 'b'], gradFunc: (dy, saved) => { const 
[a, b] = saved; const outShape = assertAndGetBroadcastShape(a.shape, b.shape); const derA = () => { const res = mul(dy, cast$2(b, 'float32')); const reduceAxes = getReductionAxes(a.shape, outShape); if (reduceAxes.length > 0) { return reshape$1(sum$1(res, reduceAxes), a.shape); } return res; }; const derB = () => { const res = mul(dy, cast$2(a, 'float32')); const reduceAxes = getReductionAxes(b.shape, outShape); if (reduceAxes.length > 0) { return reshape$1(sum$1(res, reduceAxes), b.shape); } return res; }; return { a: derA, b: derB }; } }; const negGradConfig = { kernelName: Neg, gradFunc: (dy) => { return { x: () => neg$1(dy) }; } }; const oneHotGradConfig = { kernelName: OneHot, inputsToSave: ['indices'], gradFunc: (dy, saved) => { const indices = saved[0]; return { indices: () => zeros(indices.shape, 'float32') }; } }; const onesLikeGradConfig = { kernelName: OnesLike, gradFunc: (dy) => { return { x: () => zerosLike$1(dy) }; } }; const packGradConfig = { kernelName: Pack, saveAllInputs: true, gradFunc: (dy, saved, attrs) => { const { axis } = attrs; const derTensors = unstack(dy, axis); return derTensors.map(t => () => t); } }; const padV2GradConfig = { kernelName: PadV2, inputsToSave: ['x'], gradFunc: (dy, saved, attrs) => { const x = saved[0]; const { paddings } = attrs; const begin = paddings.map(p => p[0]); return { x: () => slice$1(dy, begin, x.shape) }; } }; const powGradConfig = { kernelName: Pow, inputsToSave: ['a', 'b'], outputsToSave: [true], gradFunc: (dy, saved) => { const [a, b, y] = saved; const base = a; const exp = b; const outShape = assertAndGetBroadcastShape(base.shape, exp.shape); const derBase = () => { const expFloat = cast$2(exp, 'float32'); let res = mul(dy, mul(expFloat, pow$1(base, sub$1(expFloat, scalar(1))))); const reduceAxes = getReductionAxes(base.shape, outShape); if (reduceAxes.length > 0) { res = sum$1(res, reduceAxes); } return reshape$1(res, base.shape); }; const derExp = () => { const condition = greater$1(base, 0); const 
logBase = where(condition, log$1(base), zerosLike$1(base)); let res = mul(dy, mul(y, logBase)); const reduceAxes = getReductionAxes(exp.shape, outShape); if (reduceAxes.length > 0) { res = sum$1(res, reduceAxes); } return reshape$1(res, exp.shape); }; return { a: derBase, b: derExp }; } }; const preluGradConfig = { kernelName: Prelu, inputsToSave: ['x', 'alpha'], gradFunc: (dy, saved) => { const [x, alpha] = saved; const mask = greater$1(x, 0); return { x: () => where(mask, dy, mul(dy, alpha)), alpha: () => { let res = where(mask, zerosLike$1(dy), mul(dy, x)); const reduceAxes = getReductionAxes(alpha.shape, dy.shape); if (reduceAxes.length > 0) { res = sum$1(res, reduceAxes); } return reshape$1(res, alpha.shape); } }; } }; function prodGradFn_(x, dy, axis) { const expandedYShape = x.shape.slice(); expandedYShape[axis] = 1; const expandedDy = reshape$1(dy, expandedYShape); const xCumProd = cumprod$1(x, axis, true, false); const xCumRevProd = cumprod$1(x, axis, true, true); const dx = mul(xCumProd, xCumRevProd); return mul(expandedDy, dx); } function prodsGradFn_(x, dy, axis) { const xRank = x.shape.length; const finalProdAxis = xRank - axis.length; const xPermutation = getAxesPermutation(axis, xRank); let permutedX = x; if (xPermutation != null) { permutedX = transpose$1(x, xPermutation); } const newShape = permutedX.shape.slice(); const removedShape = newShape.splice(xRank - axis.length, axis.length); const endPartShape = removedShape.reduce((p, c) => p * c, 1); newShape.push(endPartShape); const reshapedPermutedX = permutedX.reshape(newShape); let prodGrad = prodGradFn_(reshapedPermutedX, dy, finalProdAxis); prodGrad = prodGrad.reshape(permutedX.shape); if (xPermutation != null) { const undoPermutation = getUndoAxesPermutation(xPermutation); prodGrad = transpose$1(prodGrad, undoPermutation); } return prodGrad; } const prodGradConfig = { kernelName: Prod, inputsToSave: ['x'], gradFunc: (dy, saved, attrs) => { const [x] = saved; const { axis } = attrs; let axisArr 
= []; if (axis === undefined || axis === null) { axisArr = x.shape.map((_, i) => i); } else if (typeof axis === 'number') { axisArr = [axis]; } else { axisArr = axis; } return { x: () => prodsGradFn_(x, dy, axisArr) }; } }; const divGradConfig = { kernelName: RealDiv, inputsToSave: ['a', 'b'], gradFunc: (dy, saved) => { const [a, b] = saved; const outShape = assertAndGetBroadcastShape(a.shape, b.shape); const derA = () => { const res = div(dy, cast$2(b, 'float32')); const reduceAxes = getReductionAxes(a.shape, outShape); if (reduceAxes.length > 0) { return reshape$1(sum$1(res, reduceAxes), a.shape); } return res; }; const derB = () => { let res = mul(dy, cast$2(a, 'float32')); const reduceAxes = getReductionAxes(b.shape, outShape); if (reduceAxes.length > 0) { res = reshape$1(sum$1(res, reduceAxes), b.shape); } const tmp = square$2(b); return neg$1(div(res, cast$2(tmp, 'float32'))); }; return { a: derA, b: derB }; } }; const reciprocalGradConfig = { kernelName: Reciprocal, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => div(dy, neg$1(square$2(x))) }; } }; const relu6GradConfig = { kernelName: Relu6$1, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; const mask = mul(lessEqual$1(x, 6), step$1(x)); return { x: () => mul(dy, cast$2(mask, 'float32')) }; } }; const reluGradConfig = { kernelName: Relu$1, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => mul(dy, cast$2(step$1(x), 'float32')) }; } }; const reshapeGradConfig = { kernelName: Reshape$1, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => reshape$1(dy, x.shape) }; } }; const resizeBilinearGradConfig = { kernelName: ResizeBilinear, inputsToSave: ['images'], gradFunc: (dy, saved, attrs) => { const [images] = saved; const inputs = { dy, images }; const imagesDer = () => ENGINE.runKernel(ResizeBilinearGrad, inputs, attrs); return { images: imagesDer }; } }; const 
resizeNearestNeighborGradConfig = { kernelName: ResizeNearestNeighbor, inputsToSave: ['images'], gradFunc: (dy, saved, attrs) => { const [images] = saved; const inputs = { dy, images }; const imagesDer = () => ENGINE.runKernel(ResizeNearestNeighborGrad, inputs, attrs); return { images: imagesDer }; } }; const reverseGradConfig = { kernelName: Reverse, gradFunc: (dy, saved, attrs) => { const { dims } = attrs; const axes = parseAxisParam(dims, dy.shape); return { x: () => reverse$1(dy, axes) }; } }; const roundGradConfig = { kernelName: Round, gradFunc: (dy) => { return { x: () => zerosLike$1(dy) }; } }; const rsqrtGradConfig = { kernelName: Rsqrt, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => neg$1(div(dy, mul(pow$1(x, 1.5), 2))) }; } }; const selectGradConfig = { kernelName: Select, inputsToSave: ['condition'], gradFunc: (dy, saved) => { const [condition] = saved; return { condition: () => cast$2(zerosLike$1(condition), 'float32'), t: () => mul(dy, cast$2(condition, dy.dtype)), e: () => mul(dy, cast$2(logicalNot$1(condition), dy.dtype)) }; } }; const seluGradConfig = { kernelName: Selu$1, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => { const mask = greater$1(x, scalar(0)); const scaleAlpha = scalar(SELU_SCALEALPHA); const scale = scalar(SELU_SCALE); const greaterThanZeroDer = mul(dy, scale); const lessEqualZeroDer = mul(mul(dy, scaleAlpha), exp$1(cast$2(x, 'float32'))); return where(mask, greaterThanZeroDer, lessEqualZeroDer); } }; } }; const sigmoidGradConfig = { kernelName: Sigmoid$1, outputsToSave: [true], gradFunc: (dy, saved) => { const [y] = saved; return { x: () => mul(dy, mul(y, sub$1(scalar(1), y))) }; } }; const signGradConfig = { kernelName: Sign, gradFunc: (dy) => { return { x: () => zerosLike$1(dy) }; } }; const sinGradConfig = { kernelName: Sin, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => mul(cos$1(cast$2(x, 'float32')), dy) }; } }; 
const sinhGradConfig = { kernelName: Sinh, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => mul(cosh$1(cast$2(x, 'float32')), dy) }; } }; const sliceGradConfig = { kernelName: Slice, inputsToSave: ['x'], gradFunc: (dy, saved, attrs) => { const [x] = saved; const { begin, size } = attrs; const inputShape = x.shape; const [begin_, size_] = parseSliceParams(x, begin, size); const paddings = []; for (let i = 0; i < dy.rank; i++) { paddings.push([begin_[i], inputShape[i] - begin_[i] - size_[i]]); } return { x: () => pad(dy, paddings) }; } }; const softmaxGradConfig = { kernelName: Softmax$1, outputsToSave: [true], gradFunc: (dy, saved, attrs) => { const [y] = saved; const { dim } = attrs; const keepDims = true; const dyTimesY = mul(dy, y); return { logits: () => sub$1(dyTimesY, mul(sum$1(dyTimesY, [dim], keepDims), y)) }; } }; const softplusGradConfig = { kernelName: Softplus$1, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => mul(dy, sigmoid$1(x)) }; } }; const spaceToBatchNDGradConfig = { kernelName: SpaceToBatchND, gradFunc: (dy, saved, attrs) => { const { blockShape, paddings } = attrs; return { x: () => batchToSpaceND$1(dy, blockShape, paddings) }; } }; const splitVGradConfig = { kernelName: SplitV, gradFunc: (dy, saved, attrs) => { const { axis } = attrs; return { x: () => concat$1(dy, axis) }; } }; const sqrtGradConfig = { kernelName: Sqrt, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => div(dy, mul(sqrt$1(cast$2(x, 'float32')), 2)) }; } }; const squareGradConfig = { kernelName: Square, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => mul(dy, mul(cast$2(x, 'float32'), 2)) }; } }; const squaredDifferenceGradConfig = { kernelName: SquaredDifference, inputsToSave: ['a', 'b'], gradFunc: (dy, saved) => { const [a, b] = saved; const two = scalar(2); const derA = () => mul(dy, mul(two, sub$1(a, b))); const derB = () => mul(dy, 
mul(two, sub$1(b, a))); return { a: derA, b: derB }; } }; const stepGradConfig = { kernelName: Step, gradFunc: (dy) => { return { x: () => zerosLike$1(dy) }; } }; const subGradConfig = { kernelName: Sub, inputsToSave: ['a', 'b'], gradFunc: (dy, saved) => { const [a, b] = saved; const outShape = assertAndGetBroadcastShape(a.shape, b.shape); const derA = () => { let res = dy; const reduceAxes = getReductionAxes(a.shape, outShape); if (reduceAxes.length > 0) { res = sum$1(res, reduceAxes); } return reshape$1(res, a.shape); }; const derB = () => { let res = dy; const reduceAxes = getReductionAxes(b.shape, outShape); if (reduceAxes.length > 0) { res = sum$1(res, reduceAxes); } return reshape$1(neg$1(res), b.shape); }; return { a: derA, b: derB }; } }; const sumGradConfig = { kernelName: Sum, inputsToSave: ['x'], gradFunc: (dy, saved, attrs) => { const [x] = saved; const expandedDyShape = x.shape.slice(); const { axis } = attrs; const axes = parseAxisParam(axis, x.shape); axes.forEach(axis => { expandedDyShape[axis] = 1; }); const expandedDy = reshape$1(dy, expandedDyShape); const derX = mul(expandedDy, ones(x.shape, 'float32')); return { x: () => derX }; } }; const tanGradConfig = { kernelName: Tan, inputsToSave: ['x'], gradFunc: (dy, saved) => { const [x] = saved; return { x: () => div(dy, square$2(cos$1(x))) }; } }; const tanhGradConfig = { kernelName: Tanh$1, outputsToSave: [true], gradFunc: (dy, saved) => { const [y] = saved; return { x: () => mul(sub$1(scalar(1), square$2(y)), dy) }; } }; const tileGradConfig = { kernelName: Tile, inputsToSave: ['x'], gradFunc: (dy, saved, attrs) => { const [x] = saved; const { reps } = attrs; const derX = () => { let xGrad = zerosLike$1(x); if (x.rank === 1) { for (let i = 0; i < reps[0]; ++i) { xGrad = add(xGrad, slice$1(dy, [i * x.shape[0]], [x.shape[0]])); } } else if (x.rank === 2) { for (let i = 0; i < reps[0]; ++i) { for (let j = 0; j < reps[1]; ++j) { xGrad = add(xGrad, slice$1(dy, [i * x.shape[0], j * x.shape[1]], [ 
x.shape[0], x.shape[1] ])); } } } else if (x.rank === 3) { for (let i = 0; i < reps[0]; ++i) { for (let j = 0; j < reps[1]; ++j) { for (let k = 0; k < reps[2]; ++k) { xGrad = add(xGrad, slice$1(dy, [i * x.shape[0], j * x.shape[1], k * x.shape[2]], [x.shape[0], x.shape[1], x.shape[2]])); } } } } else if (x.rank === 4) { for (let i = 0; i < reps[0]; ++i) { for (let j = 0; j < reps[1]; ++j) { for (let k = 0; k < reps[2]; ++k) { for (let l = 0; l < reps[3]; ++l) { xGrad = add(xGrad, slice$1(dy, [ i * x.shape[0], j * x.shape[1], k * x.shape[2], l * x.shape[3] ], [x.shape[0], x.shape[1], x.shape[2], x.shape[3]])); } } } } } else { throw new Error(`Gradient for tile operation is not implemented for rank-` + `${x.rank} tensors yet.`); } return xGrad; }; return { x: derX }; }, }; const transposeGradConfig = { kernelName: Transpose, gradFunc: (dy, saved, attrs) => { const transposeAttrs = attrs; const { perm } = transposeAttrs; const undoPerm = getUndoAxesPermutation(perm); return { x: () => transpose$1(dy, undoPerm) }; } }; const unpackGradConfig = { kernelName: Unpack, gradFunc: (dy, saved, attrs) => { const unpackAttrs = attrs; const { axis } = unpackAttrs; return { value: () => stack(dy, axis) }; } }; const unsortedSegmentSumGradConfig = { kernelName: UnsortedSegmentSum, inputsToSave: ['segmentIds'], gradFunc: (dy, saved) => { const [segmentIds] = saved; const derX = () => { return gatherDropNegatives(dy, segmentIds); }; return { x: derX }; } }; function gatherDropNegatives(x, indices) { const zeroClippedIndices = maximum$1(indices, zerosLike$1(indices)); const gathered = gather$1(x, zeroClippedIndices); let isPositive = greaterEqual$1(indices, scalar(0, 'int32')); const numIters = gathered.rank - isPositive.rank; for (let i = 0; i < numIters; ++i) { isPositive = expandDims$2(isPositive, i + 1); } isPositive = logicalAnd$1(isPositive, ones(gathered.shape, 'bool')); const zeroSlice = zerosLike$1(gathered); return where(isPositive, gathered, zeroSlice); } const 
zerosLikeGradConfig = { kernelName: ZerosLike, gradFunc: (dy) => { return { x: () => zerosLike$1(dy) }; } }; const gradConfigs = [ absGradConfig, acosGradConfig, acoshGradConfig, addGradConfig, addNGradConfig, argMaxGradConfig, argMinGradConfig, asinGradConfig, asinhGradConfig, atan2GradConfig, atanGradConfig, atanhGradConfig, avgPool3DGradConfig, avgPoolGradConfig, batchMatMulGradConfig, batchToSpaceNDGradConfig, broadcastToGradConfig, castGradConfig, ceilGradConfig, clipByValueGradConfig, complexAbsGradConfig, concatGradConfig, conv2DBackpropInputGradConfig, conv2DGradConfig, conv3DGradConfig, cosGradConfig, coshGradConfig, cumsumGradConfig, depthwiseConv2dNativeGradConfig, dilation2dGradConfig, divGradConfig, eluGradConfig, erfGradConfig, expGradConfig, expandDimsGradConfig, expm1GradConfig, floorDivGradConfig, floorGradConfig, fusedBatchNormGradConfig, gatherGradConfig, greaterEqualGradConfig, identityGradConfig, isFiniteGradConfig, isInfGradConfig, isNanGradConfig, leakyReluGradConfig, log1pGradConfig, logGradConfig, logSoftmaxGradConfig, lrnGradConfig, maxGradConfig, maxGradConfig, maximumGradConfig, maxPool3DGradConfig, maxPoolGradConfig, meanGradConfig, minGradConfig, minimumGradConfig, mirrorPadGradConfig, modGradConfig, multiplyGradConfig, negGradConfig, oneHotGradConfig, onesLikeGradConfig, packGradConfig, padV2GradConfig, padV2GradConfig, powGradConfig, preluGradConfig, prodGradConfig, reciprocalGradConfig, relu6GradConfig, reluGradConfig, reshapeGradConfig, resizeBilinearGradConfig, resizeNearestNeighborGradConfig, reverseGradConfig, roundGradConfig, rsqrtGradConfig, selectGradConfig, seluGradConfig, sigmoidGradConfig, signGradConfig, sinGradConfig, sinhGradConfig, sliceGradConfig, softmaxGradConfig, softplusGradConfig, spaceToBatchNDGradConfig, spaceToBatchNDGradConfig, splitVGradConfig, splitVGradConfig, sqrtGradConfig, squaredDifferenceGradConfig, squareGradConfig, stepGradConfig, subGradConfig, sumGradConfig, tanGradConfig, tanhGradConfig, 
tileGradConfig, transposeGradConfig, unpackGradConfig, unsortedSegmentSumGradConfig, zerosLikeGradConfig ]; for (const gradientConfig of gradConfigs) { registerGradient(gradientConfig); } class AttributeError extends Error { constructor(message) { super(message); Object.setPrototypeOf(this, AttributeError.prototype); } } class RuntimeError extends Error { constructor(message) { super(message); Object.setPrototypeOf(this, RuntimeError.prototype); } } class ValueError extends Error { constructor(message) { super(message); Object.setPrototypeOf(this, ValueError.prototype); } } class NotImplementedError extends Error { constructor(message) { super(message); Object.setPrototypeOf(this, NotImplementedError.prototype); } } class AssertionError extends Error { constructor(message) { super(message); Object.setPrototypeOf(this, AssertionError.prototype); } } class LruCache { constructor(maxEntries) { this.maxEntries = maxEntries || 100; this.cache = new Map(); } get(key) { let entry; if (this.cache.has(key)) { entry = this.cache.get(key); this.cache.delete(key); this.cache.set(key, entry); } return entry; } put(key, value) { if (this.cache.has(key)) { this.cache.delete(key); } else if (this.cache.size >= this.maxEntries) { const keyToDelete = this.cache.keys().next().value; this.cache.delete(keyToDelete); } this.cache.set(key, value); } getMaxEntries() { return this.maxEntries; } setMaxEntries(maxEntries) { if (maxEntries < 0) { throw new Error(`The maxEntries of LRU caches must be at least 0, but got ${maxEntries}.`); } if (this.maxEntries > maxEntries) { for (let i = 0; i < this.maxEntries - maxEntries; i++) { const keyToDelete = this.cache.keys().next().value; this.cache.delete(keyToDelete); } } this.maxEntries = maxEntries; } } function pyListRepeat(value, numValues) { if (Array.isArray(value)) { let newArray = []; for (let i = 0; i < numValues; i++) { newArray = newArray.concat(value); } return newArray; } else { const newArray = new Array(numValues); 
newArray.fill(value); return newArray; } } function assert(val, message) { if (!val) { throw new AssertionError(message); } } function count(array, refernce) { let counter = 0; for (const item of array) { if (item === refernce) { counter++; } } return counter; } function singletonOrArray(xs) { if (xs.length === 1) { return xs[0]; } return xs; } function toList(x) { if (Array.isArray(x)) { return x; } return [x]; } function toSnakeCase(name) { const intermediate = name.replace(/(.)([A-Z][a-z0-9]+)/g, '$1_$2'); const insecure = intermediate.replace(/([a-z])([A-Z])/g, '$1_$2').toLowerCase(); if (insecure[0] !== '_') { return insecure; } return 'private' + insecure; } function toCamelCase(identifier) { if (identifier.length <= 1) { return identifier; } if (identifier.indexOf('_') === -1) { return identifier; } return identifier.replace(/[_]+(\w|$)/g, (m, p1) => p1.toUpperCase()); } let _GLOBAL_CUSTOM_OBJECTS = {}; function serializeKerasObject(instance) { if (instance === null || instance === undefined) { return null; } const dict = {}; dict['className'] = instance.getClassName(); dict['config'] = instance.getConfig(); return dict; } function convertNDArrayScalarsInConfig(config) { if (config == null || typeof config !== 'object') { return; } else if (Array.isArray(config)) { config.forEach(configItem => convertNDArrayScalarsInConfig(configItem)); } else { const fields = Object.keys(config); for (const field of fields) { const value = config[field]; if (value != null && typeof value === 'object') { if (!Array.isArray(value) && value['type'] === 'ndarray' && typeof value['value'] === 'number') { config[field] = value['value']; } else { convertNDArrayScalarsInConfig(value); } } } } } function deserializeKerasObject(identifier, moduleObjects = {}, customObjects = {}, printableModuleName = 'object', fastWeightInit = false) { if (typeof identifier === 'string') { const functionName = identifier; let fn; if (functionName in customObjects) { fn = customObjects[functionName]; 
} else if (functionName in _GLOBAL_CUSTOM_OBJECTS) { fn = _GLOBAL_CUSTOM_OBJECTS[functionName]; } else { fn = moduleObjects[functionName]; if (fn == null) { throw new ValueError(`Unknown ${printableModuleName}: ${identifier}. ` + `This may be due to one of the following reasons:\n` + `1. The ${printableModuleName} is defined in Python, in which ` + `case it needs to be ported to TensorFlow.js or your JavaScript ` + `code.\n` + `2. The custom ${printableModuleName} is defined in JavaScript, ` + `but is not registered properly with ` + `tf.serialization.registerClass().`); } } return fn; } else { const config = identifier; if (config['className'] == null || config['config'] == null) { throw new ValueError(`${printableModuleName}: Improper config format: ` + `${JSON.stringify(config)}.\n` + `'className' and 'config' must set.`); } const className = config['className']; let cls, fromConfig; if (className in customObjects) { [cls, fromConfig] = customObjects[className]; } else if (className in _GLOBAL_CUSTOM_OBJECTS) { [cls, fromConfig] = _GLOBAL_CUSTOM_OBJECTS['className']; } else if (className in moduleObjects) { [cls, fromConfig] = moduleObjects[className]; } if (cls == null) { throw new ValueError(`Unknown ${printableModuleName}: ${className}. ` + `This may be due to one of the following reasons:\n` + `1. The ${printableModuleName} is defined in Python, in which ` + `case it needs to be ported to TensorFlow.js or your JavaScript ` + `code.\n` + `2. 
The custom ${printableModuleName} is defined in JavaScript, ` + `but is not registered properly with ` + `tf.serialization.registerClass().`); } if (fromConfig != null) { const customObjectsCombined = {}; for (const key of Object.keys(_GLOBAL_CUSTOM_OBJECTS)) { customObjectsCombined[key] = _GLOBAL_CUSTOM_OBJECTS[key]; } for (const key of Object.keys(customObjects)) { customObjectsCombined[key] = customObjects[key]; } const nestedConfig = config['config']; nestedConfig['customObjects'] = customObjectsCombined; const backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS); for (const key of Object.keys(customObjects)) { _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key]; } convertNDArrayScalarsInConfig(config['config']); const returnObj = fromConfig(cls, config['config'], customObjects, fastWeightInit); _GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects); return returnObj; } else { const backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS); for (const key of Object.keys(customObjects)) { _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key]; } const returnObj = new cls(config['config']); _GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects); return returnObj; } } } function numberCompare(a, b) { return (a < b) ? -1 : ((a > b) ? 1 : 0); } function reverseNumberCompare(a, b) { return -1 * numberCompare(a, b); } function unique(xs) { if (xs == null) { return xs; } const out = []; for (const x of xs) { if (out.indexOf(x) === -1) { out.push(x); } } return out; } function isObjectEmpty(obj) { if (obj == null) { throw new ValueError(`Invalid value in obj: ${JSON.stringify(obj)}`); } for (const key in obj) { if (obj.hasOwnProperty(key)) { return false; } } return true; } function checkStringTypeUnionValue(values, label, value) { if (value == null) { return; } if (values.indexOf(value) < 0) { throw new ValueError(`${value} is not a valid ${label}. 
Valid values are ${values} or null/undefined.`); } } function assertPositiveInteger(value, name) { if (Array.isArray(value)) { assert$1(value.length > 0, () => `${name} is unexpectedly an empty array.`); value.forEach((v, i) => assertPositiveInteger(v, `element ${i + 1} of ${name}`)); } else { assert$1(Number.isInteger(value) && value > 0, () => `Expected ${name} to be a positive integer, but got ` + `${formatAsFriendlyString(value)}.`); } } function formatAsFriendlyString(value) { if (value === null) { return 'null'; } else if (Array.isArray(value)) { return '[' + value.map(v => formatAsFriendlyString(v)).join(',') + ']'; } else if (typeof value === 'string') { return `"${value}"`; } else { return `${value}`; } } function debounce(f, waitMs, nowFunc) { let lastTime = nowFunc != null ? nowFunc() : now(); let lastResult; const f2 = (...args) => { const now$1 = nowFunc != null ? nowFunc() : now(); if (now$1 - lastTime < waitMs) { return lastResult; } lastTime = now$1; lastResult = f(...args); return lastResult; }; return f2; } function mapActivationToFusedKernel(activationName) { if (activationName === 'relu') { return 'relu'; } if (activationName === 'linear') { return 'linear'; } if (activationName === 'elu') { return 'elu'; } return null; } let _nextUniqueTensorId = 0; function getNextUniqueTensorId() { return _nextUniqueTensorId++; } const _uidPrefixes = {}; function getUid(prefix = '') { if (!(prefix in _uidPrefixes)) { _uidPrefixes[prefix] = 0; } _uidPrefixes[prefix] += 1; return prefix + _uidPrefixes[prefix].toString(); } const VALID_DATA_FORMAT_VALUES = ['channelsFirst', 'channelsLast']; const nameMap = new Map(); function checkDataFormat(value) { checkStringTypeUnionValue(VALID_DATA_FORMAT_VALUES, 'DataFormat', value); } const _nameScopeStack = []; const _nameScopeDivider = '/'; function nameScope(name, fn) { _nameScopeStack.push(name); try { const val = fn(); _nameScopeStack.pop(); return val; } catch (e) { _nameScopeStack.pop(); throw e; } } function 
currentNameScopePrefix() { if (_nameScopeStack.length === 0) { return ''; } else { return _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider; } } function getScopedTensorName(tensorName) { if (!isValidTensorName(tensorName)) { throw new Error('Not a valid tensor name: \'' + tensorName + '\''); } return currentNameScopePrefix() + tensorName; } function getUniqueTensorName(scopedName) { if (!isValidTensorName(scopedName)) { throw new Error('Not a valid tensor name: \'' + scopedName + '\''); } if (!nameMap.has(scopedName)) { nameMap.set(scopedName, 0); } const index = nameMap.get(scopedName); nameMap.set(scopedName, nameMap.get(scopedName) + 1); if (index > 0) { const result = `${scopedName}_${index}`; nameMap.set(result, 1); return result; } else { return scopedName; } } const tensorNameRegex = new RegExp(/^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/); function isValidTensorName(name) { return !!name.match(tensorNameRegex); } function arrayProd(array, begin, end) { if (begin == null) { begin = 0; } if (end == null) { end = array.length; } let prod = 1; for (let i = begin; i < end; ++i) { prod *= array[i]; } return prod; } function range(begin, end) { if (end < begin) { throw new ValueError(`end (${end}) < begin (${begin}) is forbidden.`); } const out = []; for (let i = begin; i < end; ++i) { out.push(i); } return out; } let _epsilon; function epsilon() { if (_epsilon == null) { _epsilon = backend().epsilon(); } return _epsilon; } function imageDataFormat() { return 'channelsLast'; } function cast(x, dtype) { return cast$2(x, dtype); } function expandDims(x, axis = -1) { const outShape = x.shape.slice(); if (axis < 0) { axis = outShape.length + axis + 1; } outShape.splice(axis, 0, 1); return reshape$1(x, outShape); } function repeat(x, n) { return tidy(() => { if (x.shape.length !== 2) { throw new ValueError(`repeat() expects a rank-2 tensor, but received a ` + `rank-${x.shape.length} tensor.`); } const y = expandDims(x, 1); return tile(y, [1, n, 1]); }); } function 
flatten(x) { const newShape = [arrayProd(x.shape)]; return reshape$1(x, newShape); } function batchFlatten(x) { if (x.rank <= 1) { throw new ValueError(`batchFlatten requires a minimum rank of 2. Got rank: ${x.rank}.`); } const newShape = [x.shape[0], arrayProd(x.shape, 1)]; return reshape$1(x, newShape); } function sliceAlongFirstAxis(array, start, size) { return tidy(() => { switch (array.rank) { case 1: return slice1d(array, start, size); case 2: return slice2d(array, [start, 0], [size, array.shape[1]]); case 3: return slice3d(array, [start, 0, 0], [size, array.shape[1], array.shape[2]]); case 4: return slice4d(array, [start, 0, 0, 0], [size, array.shape[1], array.shape[2], array.shape[3]]); case 5: return slice$1(array, [start, 0, 0, 0, 0], [ size, array.shape[1], array.shape[2], array.shape[3], array.shape[4] ]); case 6: return slice$1(array, [start, 0, 0, 0, 0, 0], [ size, array.shape[1], array.shape[2], array.shape[3], array.shape[4], array.shape[5] ]); default: throw new ValueError(`sliceAlongFirstAxis() received an unsupported tensor rank: ` + `${array.rank}`); } }); } function tile(x, n) { if (!Array.isArray(n)) { n = [n]; } if (x.rank !== n.length) { throw new ValueError(`The length of input n (${n.length}) does not match ` + `the number of dimensions in input x (${x.rank})`); } return tile$2(x, n); } function randomNormal(shape, mean = 0.0, stddev = 1.0, dtype, seed) { return randomNormal$1(shape, mean, stddev, dtype, seed); } function dot(a, b, activation, bias) { if ((a.rank < 2) || (b.rank < 2)) { throw new NotImplementedError(`dot requires both inputs to be rank >= 2` + ` but got x shape = ${a.shape} and y shape = ${b.shape}`); } if (b.rank >= 3) { const xLastDim = a.shape.slice(-1)[0]; const ySecondLastDim = b.shape.slice(-2)[0]; if (xLastDim !== ySecondLastDim) { throw new NotImplementedError(`If rank y >= 3, then the second last dim` + ` of y must equal the last dim of x but got x shape = ${a.shape} and ` + ` y shape = ${b.shape}`); } } if 
((a.rank === 2) && (b.rank === 2)) { const transposeA = false; const transposeB = false; return matMul({ a, b: b, transposeA, transposeB, bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null, activation }); } else { const aFirstDims = a.shape.slice(); const aLastDim = aFirstDims.pop(); a = reshape$1(a, [-1, aLastDim]); const bShape = b.shape.slice(); const bLastDim = bShape.pop(); const ySecondLastDim = bShape.pop(); const yOtherDims = [...bShape, bLastDim]; const perm = Array.from({ length: b.rank }, (_, i) => { if (i === 0) { return b.rank - 2; } else if (i <= b.rank - 2) { return i - 1; } return i; }); b = reshape$1(transpose$1(b, perm), [ySecondLastDim, -1]); const outputShape = [...aFirstDims, ...yOtherDims]; const transposeA = false; const transposeB = false; return reshape$1(matMul({ a, b, transposeA, transposeB, bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null, activation }), outputShape); } } function gather(reference, indices, axis) { return tidy(() => { if (Array.isArray(indices)) { indices = tensor1d(indices, 'int32'); } else { indices = cast$2(indices, 'int32'); } return gather$1(reference, indices, axis); }); } function square(x) { return mul(x, x); } function reshapeBias(xRank, bias, dataFormat) { const biasShape = bias.shape; if (bias.rank !== 1 && bias.rank !== xRank) { throw new ValueError(`Unexpected bias dimensions: ${bias.rank}` + `; expected it to be 1 or ${xRank}`); } if (xRank === 5) { if (dataFormat === 'channelsFirst') { if (biasShape.length === 1) { return reshape$1(bias, [1, biasShape[0], 1, 1, 1]); } else { return reshape$1(bias, [1, biasShape[3], biasShape[0], biasShape[1], biasShape[2]]); } } else if (dataFormat === 'channelsLast') { if (biasShape.length === 1) { return reshape$1(bias, [1, 1, 1, 1, biasShape[0]]); } else { return reshape$1(bias, [1].concat(biasShape)); } } } else if (xRank === 4) { if (dataFormat === 'channelsFirst') { if (biasShape.length === 1) { return reshape$1(bias, [1, biasShape[0], 
1, 1]); } else { return reshape$1(bias, [1, biasShape[2], biasShape[0], biasShape[1]]); } } else if (dataFormat === 'channelsLast') { if (biasShape.length === 1) { return reshape$1(bias, [1, 1, 1, biasShape[0]]); } else { return reshape$1(bias, [1].concat(biasShape)); } } } else if (xRank === 3) { if (dataFormat === 'channelsFirst') { if (biasShape.length === 1) { return reshape$1(bias, [1, biasShape[0], 1]); } else { return reshape$1(bias, [1, biasShape[1], biasShape[0]]); } } else if (dataFormat === 'channelsLast') { if (biasShape.length === 1) { return reshape$1(bias, [1, 1, biasShape[0]]); } else { return reshape$1(bias, [1].concat(biasShape)); } } } else if (xRank < 3) { return bias; } throw new ValueError(`Unsupported input rank by biasAdd: ${bias.rank}`); } function biasAdd(x, bias, dataFormat) { return tidy(() => { if (dataFormat == null) { dataFormat = imageDataFormat(); } checkDataFormat(dataFormat); return add(x, reshapeBias(x.rank, bias, dataFormat)); }); } function elu(x, alpha = 1) { if (alpha !== 1) { throw new NotImplementedError(`Support for alpha values other than 1 (${alpha}) is not implemented ` + `yet.`); } return elu$2(x); } function softsign(x) { return tidy(() => div(x, add(abs$1(x), 1))); } function dropout$1(x, level, noiseShape, seed) { return tidy(() => dropout$2(x, level, noiseShape, seed)); } function hardSigmoid(x) { return tidy(() => { const y = add(.5, mul(.2, x)); return clipByValue$1(y, 0, 1); }); } function inTrainPhase(x, alt, training = false) { return training ? 
x() : alt(); } const VALID_FAN_MODE_VALUES = ['fanIn', 'fanOut', 'fanAvg']; const VALID_DISTRIBUTION_VALUES = ['normal', 'uniform', 'truncatedNormal']; function checkFanMode(value) { checkStringTypeUnionValue(VALID_FAN_MODE_VALUES, 'FanMode', value); } function checkDistribution(value) { checkStringTypeUnionValue(VALID_DISTRIBUTION_VALUES, 'Distribution', value); } class Initializer extends Serializable { fromConfigUsesCustomObjects() { return false; } getConfig() { return {}; } } class Zeros extends Initializer { apply(shape, dtype) { return zeros(shape, dtype); } } Zeros.className = 'Zeros'; registerClass(Zeros); class Ones extends Initializer { apply(shape, dtype) { return ones(shape, dtype); } } Ones.className = 'Ones'; registerClass(Ones); class Constant extends Initializer { constructor(args) { super(); if (typeof args !== 'object') { throw new ValueError(`Expected argument of type ConstantConfig but got ${args}`); } if (args.value === undefined) { throw new ValueError(`config must have value set but got ${args}`); } this.value = args.value; } apply(shape, dtype) { return tidy(() => mul(scalar(this.value), ones(shape, dtype))); } getConfig() { return { value: this.value, }; } } Constant.className = 'Constant'; registerClass(Constant); class RandomUniform extends Initializer { constructor(args) { super(); this.DEFAULT_MINVAL = -0.05; this.DEFAULT_MAXVAL = 0.05; this.minval = args.minval || this.DEFAULT_MINVAL; this.maxval = args.maxval || this.DEFAULT_MAXVAL; this.seed = args.seed; } apply(shape, dtype) { return randomUniform(shape, this.minval, this.maxval, dtype, this.seed); } getConfig() { return { minval: this.minval, maxval: this.maxval, seed: this.seed }; } } RandomUniform.className = 'RandomUniform'; registerClass(RandomUniform); class RandomNormal extends Initializer { constructor(args) { super(); this.DEFAULT_MEAN = 0.; this.DEFAULT_STDDEV = 0.05; this.mean = args.mean || this.DEFAULT_MEAN; this.stddev = args.stddev || this.DEFAULT_STDDEV; this.seed = 
args.seed; } apply(shape, dtype) { dtype = dtype || 'float32'; if (dtype !== 'float32' && dtype !== 'int32') { throw new NotImplementedError(`randomNormal does not support dType ${dtype}.`); } return randomNormal(shape, this.mean, this.stddev, dtype, this.seed); } getConfig() { return { mean: this.mean, stddev: this.stddev, seed: this.seed }; } } RandomNormal.className = 'RandomNormal'; registerClass(RandomNormal); class TruncatedNormal extends Initializer { constructor(args) { super(); this.DEFAULT_MEAN = 0.; this.DEFAULT_STDDEV = 0.05; this.mean = args.mean || this.DEFAULT_MEAN; this.stddev = args.stddev || this.DEFAULT_STDDEV; this.seed = args.seed; } apply(shape, dtype) { dtype = dtype || 'float32'; if (dtype !== 'float32' && dtype !== 'int32') { throw new NotImplementedError(`truncatedNormal does not support dType ${dtype}.`); } return truncatedNormal(shape, this.mean, this.stddev, dtype, this.seed); } getConfig() { return { mean: this.mean, stddev: this.stddev, seed: this.seed }; } } TruncatedNormal.className = 'TruncatedNormal'; registerClass(TruncatedNormal); class Identity extends Initializer { constructor(args) { super(); this.gain = args.gain != null ? 
args.gain : 1.0; } apply(shape, dtype) { return tidy(() => { if (shape.length !== 2 || shape[0] !== shape[1]) { throw new ValueError('Identity matrix initializer can only be used for' + ' 2D square matrices.'); } else { return mul(this.gain, eye(shape[0])); } }); } getConfig() { return { gain: this.gain }; } } Identity.className = 'Identity'; registerClass(Identity); function computeFans(shape, dataFormat = 'channelsLast') { let fanIn; let fanOut; checkDataFormat(dataFormat); if (shape.length === 2) { fanIn = shape[0]; fanOut = shape[1]; } else if ([3, 4, 5].indexOf(shape.length) !== -1) { if (dataFormat === 'channelsFirst') { const receptiveFieldSize = arrayProd(shape, 2); fanIn = shape[1] * receptiveFieldSize; fanOut = shape[0] * receptiveFieldSize; } else if (dataFormat === 'channelsLast') { const receptiveFieldSize = arrayProd(shape, 0, shape.length - 2); fanIn = shape[shape.length - 2] * receptiveFieldSize; fanOut = shape[shape.length - 1] * receptiveFieldSize; } } else { const shapeProd = arrayProd(shape); fanIn = Math.sqrt(shapeProd); fanOut = Math.sqrt(shapeProd); } return [fanIn, fanOut]; } class VarianceScaling extends Initializer { constructor(args) { super(); if (args.scale < 0.0) { throw new ValueError(`scale must be a positive float. Got: ${args.scale}`); } this.scale = args.scale == null ? 1.0 : args.scale; this.mode = args.mode == null ? 'fanIn' : args.mode; checkFanMode(this.mode); this.distribution = args.distribution == null ? 
'normal' : args.distribution; checkDistribution(this.distribution); this.seed = args.seed; } apply(shape, dtype) { const fans = computeFans(shape); const fanIn = fans[0]; const fanOut = fans[1]; let scale = this.scale; if (this.mode === 'fanIn') { scale /= Math.max(1, fanIn); } else if (this.mode === 'fanOut') { scale /= Math.max(1, fanOut); } else { scale /= Math.max(1, (fanIn + fanOut) / 2); } if (this.distribution === 'normal') { const stddev = Math.sqrt(scale); dtype = dtype || 'float32'; if (dtype !== 'float32' && dtype !== 'int32') { throw new NotImplementedError(`${this.getClassName()} does not support dType ${dtype}.`); } return truncatedNormal(shape, 0, stddev, dtype, this.seed); } else { const limit = Math.sqrt(3 * scale); return randomUniform(shape, -limit, limit, dtype, this.seed); } } getConfig() { return { scale: this.scale, mode: this.mode, distribution: this.distribution, seed: this.seed }; } } VarianceScaling.className = 'VarianceScaling'; registerClass(VarianceScaling); class GlorotUniform extends VarianceScaling { constructor(args) { super({ scale: 1.0, mode: 'fanAvg', distribution: 'uniform', seed: args == null ? null : args.seed }); } getClassName() { return VarianceScaling.className; } } GlorotUniform.className = 'GlorotUniform'; registerClass(GlorotUniform); class GlorotNormal extends VarianceScaling { constructor(args) { super({ scale: 1.0, mode: 'fanAvg', distribution: 'normal', seed: args == null ? null : args.seed }); } getClassName() { return VarianceScaling.className; } } GlorotNormal.className = 'GlorotNormal'; registerClass(GlorotNormal); class HeNormal extends VarianceScaling { constructor(args) { super({ scale: 2.0, mode: 'fanIn', distribution: 'normal', seed: args == null ? 
null : args.seed }); } getClassName() { return VarianceScaling.className; } } HeNormal.className = 'HeNormal'; registerClass(HeNormal); class HeUniform extends VarianceScaling { constructor(args) { super({ scale: 2.0, mode: 'fanIn', distribution: 'uniform', seed: args == null ? null : args.seed }); } getClassName() { return VarianceScaling.className; } } HeUniform.className = 'HeUniform'; registerClass(HeUniform); class LeCunNormal extends VarianceScaling { constructor(args) { super({ scale: 1.0, mode: 'fanIn', distribution: 'normal', seed: args == null ? null : args.seed }); } getClassName() { return VarianceScaling.className; } } LeCunNormal.className = 'LeCunNormal'; registerClass(LeCunNormal); class LeCunUniform extends VarianceScaling { constructor(args) { super({ scale: 1.0, mode: 'fanIn', distribution: 'uniform', seed: args == null ? null : args.seed }); } getClassName() { return VarianceScaling.className; } } LeCunUniform.className = 'LeCunUniform'; registerClass(LeCunUniform); class Orthogonal extends Initializer { constructor(args) { super(); this.DEFAULT_GAIN = 1; this.ELEMENTS_WARN_SLOW = 2000; this.gain = args.gain == null ? 
this.DEFAULT_GAIN : args.gain; this.seed = args.seed; } apply(shape, dtype) { return tidy(() => { if (shape.length < 2) { throw new NotImplementedError('Shape must be at least 2D.'); } if (dtype !== 'int32' && dtype !== 'float32' && dtype !== undefined) { throw new TypeError(`Unsupported data type ${dtype}.`); } dtype = dtype; const numRows = sizeFromShape(shape.slice(0, -1)); const numCols = shape[shape.length - 1]; const numElements = numRows * numCols; if (numElements > this.ELEMENTS_WARN_SLOW) { console.warn(`Orthogonal initializer is being called on a matrix with more ` + `than ${this.ELEMENTS_WARN_SLOW} (${numElements}) elements: ` + `Slowness may result.`); } const flatShape = [Math.max(numCols, numRows), Math.min(numCols, numRows)]; const randNormalMat = randomNormal(flatShape, 0, 1, dtype, this.seed); const qr = linalg.qr(randNormalMat, false); let qMat = qr[0]; const rMat = qr[1]; const diag = rMat.flatten().stridedSlice([0], [Math.min(numCols, numRows) * Math.min(numCols, numRows)], [Math.min(numCols, numRows) + 1]); qMat = mul(qMat, diag.sign()); if (numRows < numCols) { qMat = qMat.transpose(); } return mul(scalar(this.gain), qMat.reshape(shape)); }); } getConfig() { return { gain: this.gain, seed: this.seed, }; } } Orthogonal.className = 'Orthogonal'; registerClass(Orthogonal); const INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = { 'constant': 'Constant', 'glorotNormal': 'GlorotNormal', 'glorotUniform': 'GlorotUniform', 'heNormal': 'HeNormal', 'heUniform': 'HeUniform', 'identity': 'Identity', 'leCunNormal': 'LeCunNormal', 'leCunUniform': 'LeCunUniform', 'ones': 'Ones', 'orthogonal': 'Orthogonal', 'randomNormal': 'RandomNormal', 'randomUniform': 'RandomUniform', 'truncatedNormal': 'TruncatedNormal', 'varianceScaling': 'VarianceScaling', 'zeros': 'Zeros' }; function deserializeInitializer(config, customObjects = {}) { return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'initializer'); } function 
serializeInitializer(initializer) { return serializeKerasObject(initializer); } function getInitializer(identifier) { if (typeof identifier === 'string') { const className = identifier in INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ? INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] : identifier; if (className === 'GlorotNormal') { return new GlorotNormal(); } else if (className === 'GlorotUniform') { return new GlorotUniform(); } else if (className === 'HeNormal') { return new HeNormal(); } else if (className === 'HeUniform') { return new HeUniform(); } else if (className === 'LeCunNormal') { return new LeCunNormal(); } else if (className === 'LeCunUniform') { return new LeCunUniform(); } else { const config = {}; config['className'] = className; config['config'] = {}; return deserializeInitializer(config); } } else if (identifier instanceof Initializer) { return identifier; } else { return deserializeInitializer(identifier); } } function normalizeShapeList(x) { if (x.length === 0) { return []; } if (!Array.isArray(x[0])) { return [x]; } return x; } function getExactlyOneTensor(xs) { let x; if (Array.isArray(xs)) { if (xs.length !== 1) { throw new ValueError(`Expected Tensor length to be 1; got ${xs.length}`); } x = xs[0]; } else { x = xs; } return x; } function getExactlyOneShape(shapes) { if (Array.isArray(shapes) && Array.isArray(shapes[0])) { if (shapes.length === 1) { shapes = shapes; return shapes[0]; } else { throw new ValueError(`Expected exactly 1 Shape; got ${shapes.length}`); } } else { return shapes; } } function countParamsInWeights(weights) { let count = 0; for (const weight of weights) { if (weight.shape.length === 0) { count += 1; } else { count += weight.shape.reduce((a, b) => a * b); } } return count; } const DEFAULT_VARIABLE_NAME_PREFIX = 'Variable'; class LayerVariable { constructor(val, dtype = 'float32', name = DEFAULT_VARIABLE_NAME_PREFIX, trainable = true, constraint = null) { this.dtype = dtype == null ? 
'float32' : dtype;
    this.shape = val.shape;
    this.id = getNextUniqueTensorId();
    name = name == null ? DEFAULT_VARIABLE_NAME_PREFIX : name;
    this.originalName = getScopedTensorName(name);
    this.name = getUniqueTensorName(this.originalName);
    this.trainable_ = trainable;
    this.constraint = constraint;
    this.val = variable(val, this.trainable_, this.name, this.dtype);
  }
  /** Returns the wrapped variable; throws if it has been disposed. */
  read() {
    this.assertNotDisposed();
    return this.val;
  }
  /**
   * Assigns `newVal` into the variable (shapes must match) and, when a
   * constraint is set, re-assigns the constrained value. Writing the very
   * same tensor (equal `id`) is a no-op. Returns `this`.
   */
  write(newVal) {
    this.assertNotDisposed();
    checkShapesMatch(this.val, newVal);
    if (this.val.id !== newVal.id) {
      this.val.assign(newVal);
      if (this.constraint != null) {
        this.val.assign(this.constraint.apply(this.val));
      }
    }
    return this;
  }
  /** Disposes the wrapped variable; throws if it is already disposed. */
  dispose() {
    this.assertNotDisposed();
    this.val.dispose();
  }
  assertNotDisposed() {
    if (this.val.isDisposed) {
      throw new Error(`LayersVariable ${this.name} is already disposed.`);
    }
  }
  get trainable() {
    return this.trainable_;
  }
  /* Keeps the wrapper's flag and the underlying variable's flag in sync. */
  set trainable(trainable) {
    this.trainable_ = trainable;
    this.val.trainable = trainable;
  }
}
/** Throws an Error if `x.shape` and `y.shape` differ (compared via their string form). */
function checkShapesMatch(x, y) {
  if (x.shape.toString() !== y.shape.toString()) {
    throw new Error('Shape mismatch: ' + JSON.stringify(x.shape) + ' vs. ' + JSON.stringify(y.shape));
  }
}
/** Reads every LayerVariable in `xs`. */
function batchGetValue(xs) {
  return xs.map(x => x.read());
}
/** Writes each `[variable, value]` pair in order. */
function batchSetValue(variablesAndValues) {
  variablesAndValues.forEach(variableAndValue => {
    const variable = variableAndValue[0];
    variable.write(variableAndValue[1]);
  });
}
/**
 * Constraints a layer places on one input: dtype, shape, rank
 * (ndim / maxNDim / minNDim) and per-axis sizes. `ndim` is taken from
 * `shape.length` when a shape is given, otherwise from `args.ndim`;
 * `axes` defaults to {}.
 */
class InputSpec {
  constructor(args) {
    this.dtype = args.dtype;
    this.shape = args.shape;
    if (args.shape != null) {
      this.ndim = args.shape.length;
    } else {
      this.ndim = args.ndim;
    }
    this.maxNDim = args.maxNDim;
    this.minNDim = args.minNDim;
    this.axes = args.axes || {};
  }
}
/**
 * Value-less tensor used when building a graph: records dtype, shape, the
 * source layer, its inputs and call args. `name`/`originalName` are only set
 * when a name is given; `rank` is `shape.length`.
 */
class SymbolicTensor {
  constructor(dtype, shape, sourceLayer, inputs, callArgs, name, outputTensorIndex) {
    this.dtype = dtype;
    this.shape = shape;
    this.sourceLayer = sourceLayer;
    this.inputs = inputs;
    this.callArgs = callArgs;
    this.outputTensorIndex = outputTensorIndex;
    this.id = getNextUniqueTensorId();
    if (name != null) {
      this.originalName = getScopedTensorName(name);
      this.name = getUniqueTensorName(this.originalName);
    }
    this.rank = shape.length;
  }
}
let _nextNodeID = 0;
/**
 * Connection from inbound layers to `outboundLayer`. Constructing a Node
 * registers it in each non-null inbound layer's `outboundNodes` and in the
 * outbound layer's `inboundNodes`.
 */
class Node {
  constructor(args, callArgs) {
    this.callArgs = callArgs;
    this.id = _nextNodeID++;
    this.outboundLayer = args.outboundLayer;
    this.inboundLayers = args.inboundLayers;
    this.nodeIndices = args.nodeIndices;
    this.tensorIndices = args.tensorIndices;
    this.inputTensors = args.inputTensors;
    this.outputTensors = args.outputTensors;
    this.inputMasks = args.inputMasks;
    this.outputMasks = args.outputMasks;
    this.inputShapes = args.inputShapes;
    this.outputShapes = args.outputShapes;
    for (const layer of args.inboundLayers) {
      if (layer != null) {
        layer.outboundNodes.push(this);
      }
    }
    args.outboundLayer.inboundNodes.push(this);
  }
  /** Serializable summary: layer names (null for missing inbound layers) plus node/tensor indices. */
  getConfig() {
    const inboundNames = [];
    for (const layer of this.inboundLayers) {
      if (layer != null) {
        inboundNames.push(layer.name);
      } else {
        inboundNames.push(null);
      }
    }
    return {
      outboundLayer: this.outboundLayer ?
this.outboundLayer.name : null,
      inboundLayers: inboundNames,
      nodeIndices: this.nodeIndices,
      tensorIndices: this.tensorIndices
    };
  }
}
let _nextLayerID = 0;
/**
 * Base class for layers: tracks trainable / non-trainable weights, losses,
 * inbound/outbound graph nodes, and a reference count consumed by dispose().
 */
class Layer extends Serializable {
  constructor(args = {}) {
    super();
    this._callHook = null;
    this._addedWeightNames = [];
    this._stateful = false;
    this.id = _nextLayerID++;
    this.activityRegularizer = null;
    this.inputSpec = null;
    this.supportsMasking = false;
    this._trainableWeights = [];
    this._nonTrainableWeights = [];
    this._losses = [];
    this._updates = [];
    this._built = false;
    this.inboundNodes = [];
    this.outboundNodes = [];
    let name = args.name;
    /* Default name: snake_case class name plus a per-prefix unique id. */
    if (!name) {
      const prefix = this.getClassName();
      name = toSnakeCase(prefix) + '_' + getUid(prefix);
    }
    this.name = name;
    this.trainable_ = args.trainable == null ? true : args.trainable;
    /*
     * batchInputShape/dtype are only recorded when an input shape is given.
     * With `inputShape`, the leading (batch) dim is args.batchSize or null.
     * dtype falls back from args.dtype to args.inputDType to 'float32'.
     */
    if (args.inputShape != null || args.batchInputShape != null) {
      let batchInputShape;
      if (args.batchInputShape != null) {
        batchInputShape = args.batchInputShape;
      } else if (args.inputShape != null) {
        let batchSize = null;
        if (args.batchSize != null) {
          batchSize = args.batchSize;
        }
        batchInputShape = [batchSize].concat(args.inputShape);
      }
      this.batchInputShape = batchInputShape;
      let dtype = args.dtype;
      if (dtype == null) {
        dtype = args.inputDType;
      }
      if (dtype == null) {
        dtype = 'float32';
      }
      this.dtype = dtype;
    }
    if (args.weights != null) {
      this.initialWeights = args.weights;
    } else {
      this.initialWeights = null;
    }
    /*
     * null until first use. apply() sets it to 1 when first built from
     * concrete tensors and increments it on every symbolic call; dispose()
     * decrements it.
     */
    this._refCount = null;
    this.fastWeightInitDuringBuild = false;
  }
  /** Key identifying inbound node `nodeIndex` of `layer`. */
  static nodeKey(layer, nodeIndex) {
    return layer.name + '_ib-' + nodeIndex.toString();
  }
  /** Returns inbound node `nodeIndex`; throws if the layer was never called or the index is out of range. */
  getNodeAtIndex(nodeIndex, attrName) {
    if (this.inboundNodes.length === 0) {
      throw new RuntimeError('The layer has never been called ' + `and thus has no defined ${attrName}.`);
    }
    if (this.inboundNodes.length <= nodeIndex) {
      throw new ValueError(`Asked to get ${attrName} at node ${nodeIndex}, ` + `but the layer has only ${this.inboundNodes.length} inbound nodes.`);
    }
    return this.inboundNodes[nodeIndex];
  }
  getInputAt(nodeIndex) {
    return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'input').inputTensors);
  }
  getOutputAt(nodeIndex) {
    return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'output').outputTensors);
  }
  /** Input tensor(s) of the only inbound node; throws AttributeError with zero or several nodes. */
  get input() {
    if (this.inboundNodes.length > 1) {
      throw new AttributeError(`Layer ${this.name}` + ' has multiple inbound nodes, ' + 'hence the notion of "layer input" ' + 'is ill-defined. ' + 'Use `getInputAt(nodeIndex)` instead.');
    } else if (this.inboundNodes.length === 0) {
      throw new AttributeError(`Layer ${this.name}` + ' is not connected, no input to return.');
    }
    return singletonOrArray(this.getNodeAtIndex(0, 'input').inputTensors);
  }
  /** Output tensor(s) of the only inbound node; throws AttributeError with zero or several nodes. */
  get output() {
    if (this.inboundNodes.length === 0) {
      throw new AttributeError(`Layer ${this.name}` + ' has no inbound nodes.');
    }
    if (this.inboundNodes.length > 1) {
      throw new AttributeError(`Layer ${this.name}` + ' has multiple inbound nodes, ' + 'hence the notion of "layer output" ' + 'is ill-defined. ' + 'Use `getOutputAt(nodeIndex)` instead.');
    }
    return singletonOrArray(this.getNodeAtIndex(0, 'output').outputTensors);
  }
  get losses() {
    return this._losses;
  }
  /* Each stored loss is a thunk; calling it produces the loss value. */
  calculateLosses() {
    return this.losses.map(lossFn => lossFn());
  }
  get updates() {
    return this._updates;
  }
  get built() {
    return this._built;
  }
  set built(built) {
    this._built = built;
  }
  get trainable() {
    return this.trainable_;
  }
  /* Propagates the flag to every weight in _trainableWeights. */
  set trainable(trainable) {
    this._trainableWeights.forEach(w => w.trainable = trainable);
    this.trainable_ = trainable;
  }
  /** Trainable weights whose own flag is set; empty when the layer itself is not trainable. */
  get trainableWeights() {
    if (this.trainable_) {
      return this._trainableWeights.filter(w => w.trainable);
    } else {
      return [];
    }
  }
  set trainableWeights(weights) {
    this._trainableWeights = weights;
  }
  /** Non-trainable weights plus any "trainable" weights that are currently frozen (all of them if the layer is frozen). */
  get nonTrainableWeights() {
    if (this.trainable) {
      return this._trainableWeights.filter(w => !w.trainable).concat(this._nonTrainableWeights);
    } else {
      return this._trainableWeights.concat(this._nonTrainableWeights);
    }
  }
  set nonTrainableWeights(weights) {
    this._nonTrainableWeights = weights;
  }
  /** trainableWeights followed by nonTrainableWeights. */
  get weights() {
    return this.trainableWeights.concat(this.nonTrainableWeights);
  }
  get stateful() {
    return this._stateful;
  }
  /* Base implementation only rejects non-stateful layers; it resets nothing itself. */
  resetStates() {
    if (!this.stateful) {
      throw new Error('Cannot call the resetStates() method of a non-stateful Layer ' + 'object.');
    }
  }
  /**
   * Validates `inputs` against `this.inputSpec` (count, ndim, max/min ndim,
   * dtype, per-axis sizes, full shape). Null spec entries and null dims are
   * skipped. Throws ValueError on the first mismatch.
   */
  assertInputCompatibility(inputs) {
    const inputsList = toList(inputs);
    if (this.inputSpec == null || this.inputSpec.length === 0) {
      return;
    }
    const inputSpec = toList(this.inputSpec);
    if (inputsList.length !== inputSpec.length) {
      throw new ValueError(`Layer ${this.name} expects ${inputSpec.length} inputs, ` + `but it received ${inputsList.length} input tensors. ` + `Input received: ${inputs}`);
    }
    for (let inputIndex = 0; inputIndex < inputsList.length; inputIndex++) {
      const x = inputsList[inputIndex];
      const spec = inputSpec[inputIndex];
      if (spec == null) {
        continue;
      }
      const ndim = x.rank;
      if (spec.ndim != null) {
        if (ndim !== spec.ndim) {
          throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}: ` + `expected ndim=${spec.ndim}, found ndim=${ndim}`);
        }
      }
      if (spec.maxNDim != null) {
        if (ndim > spec.maxNDim) {
          throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}` + `: expected max_ndim=${spec.maxNDim}, found ndim=${ndim}`);
        }
      }
      if (spec.minNDim != null) {
        if (ndim < spec.minNDim) {
          throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}` + `: expected min_ndim=${spec.minNDim}, found ndim=${ndim}.`);
        }
      }
      if (spec.dtype != null) {
        if (x.dtype !== spec.dtype) {
          throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name} ` + `: expected dtype=${spec.dtype}, found dtype=${x.dtype}.`);
        }
      }
      if (spec.axes) {
        const xShape = x.shape;
        for (const key in spec.axes) {
          const axis = Number(key);
          const value = spec.axes[key];
          /* Negative axis keys count from the end of the shape. */
          const xShapeAtAxis = axis >= 0 ? xShape[axis] : xShape[xShape.length + axis];
          if (value != null && [value, null].indexOf(xShapeAtAxis) === -1) {
            throw new ValueError(`Input ${inputIndex} is incompatible with layer ` + `${this.name}: expected axis ${axis} of input shape to ` + `have value ${value} but got shape ${xShape}.`);
          }
        }
      }
      if (spec.shape != null) {
        for (let i = 0; i < spec.shape.length; ++i) {
          const specDim = spec.shape[i];
          const dim = x.shape[i];
          if (specDim != null && dim != null) {
            if (specDim !== dim) {
              throw new ValueError(`Input ${inputIndex} is incompatible with layer ` + `${this.name}: expected shape=${spec.shape}, ` + `found shape=${x.shape}.`);
            }
          }
        }
      }
    }
  }
  /* Identity by default; subclasses override. */
  call(inputs, kwargs) {
    return inputs;
  }
  invokeCallHook(inputs, kwargs) {
    if (this._callHook != null) {
      this._callHook(inputs, kwargs);
    }
  }
  setCallHook(callHook) {
    this._callHook = callHook;
  }
  clearCallHook() {
    this._callHook = null;
  }
  /**
   * Applies the layer. Inputs must be all SymbolicTensors or all concrete
   * tensors. On first use the layer is built from the input shapes and
   * `initialWeights` are loaded. Concrete inputs run call() eagerly; symbolic
   * inputs yield new SymbolicTensors and register an inbound Node.
   */
  apply(inputs, kwargs) {
    kwargs = kwargs || {};
    this.assertNotDisposed();
    const inputsList = toList(inputs);
    const allAreSymbolic = checkAllSymbolic(inputs);
    const noneAreSymbolic = checkNoneSymbolic(inputs);
    if (allAreSymbolic === noneAreSymbolic) {
      throw new ValueError('Arguments to apply() must be all ' + 'SymbolicTensors or all Tensors');
    }
    return nameScope(this.name, () => {
      if (!this.built) {
        this.assertInputCompatibility(inputs);
        const inputShapes = [];
        for (const xElem of toList(inputs)) {
          inputShapes.push(xElem.shape);
        }
        this.build(singletonOrArray(inputShapes));
        this.built = true;
        if (this.initialWeights) {
          this.setWeights(this.initialWeights);
        }
        if (this._refCount === null && noneAreSymbolic) {
          this._refCount = 1;
        }
      }
      this.assertInputCompatibility(inputs);
      if (noneAreSymbolic) {
        let output = this.call(inputs, kwargs);
        if (this.supportsMasking) {
          this.setMaskMetadata(inputs, output);
        }
        const outputList = toList(output);
        const outputListCopy = [];
        /* Outputs that are one of the inputs are cloned so the result never aliases an input. */
        for (let x of outputList) {
          if (inputsList.indexOf(x) !== -1) {
            x = x.clone();
          }
          outputListCopy.push(x);
        }
        output =
singletonOrArray(outputListCopy); if (this.activityRegularizer != null) { throw new NotImplementedError('Layer invocation in the presence of activity ' + 'regularizer(s) is not supported yet.'); } return output; } else { const inputShape = collectInputShape(inputs); const outputShape = this.computeOutputShape(inputShape); let output; const outputDType = guessOutputDType(); this.warnOnIncompatibleInputShape(Array.isArray(inputs) ? inputShape[0] : inputShape); if (outputShape != null && outputShape.length > 0 && Array.isArray(outputShape[0])) { output = outputShape .map((shape, index) => new SymbolicTensor(outputDType, shape, this, toList(inputs), kwargs, this.name, index)); } else { output = new SymbolicTensor(outputDType, outputShape, this, toList(inputs), kwargs, this.name); } this.addInboundNode(inputs, output, null, null, inputShape, outputShape, kwargs); this._refCount++; if (this.activityRegularizer != null) { throw new NotImplementedError('Layer invocation in the presence of activity ' + 'regularizer(s) is not supported yet.'); } return output; } }); } warnOnIncompatibleInputShape(inputShape) { if (this.batchInputShape == null) { return; } else if (inputShape.length !== this.batchInputShape.length) { console.warn(`The rank of the input tensor provided (shape: ` + `${JSON.stringify(inputShape)}) does not match that of the ` + `batchInputShape (${JSON.stringify(this.batchInputShape)}) ` + `of the layer ${this.name}`); } else { let dimMismatch = false; this.batchInputShape.forEach((dimension, i) => { if (dimension != null && inputShape[i] != null && inputShape[i] !== dimension) { dimMismatch = true; } }); if (dimMismatch) { console.warn(`The shape of the input tensor ` + `(${JSON.stringify(inputShape)}) does not ` + `match the expectation of layer ${this.name}: ` + `${JSON.stringify(this.batchInputShape)}`); } } } get outputShape() { if (this.inboundNodes == null || this.inboundNodes.length === 0) { throw new AttributeError(`The layer ${this.name} has never been 
called and thus has no ` + `defined output shape.`); } const allOutputShapes = []; for (const node of this.inboundNodes) { const shapeString = JSON.stringify(node.outputShapes); if (allOutputShapes.indexOf(shapeString) === -1) { allOutputShapes.push(shapeString); } } if (allOutputShapes.length === 1) { const outputShapes = this.inboundNodes[0].outputShapes; if (Array.isArray(outputShapes) && Array.isArray(outputShapes[0]) && outputShapes.length === 1) { return outputShapes[0]; } else { return outputShapes; } } else { throw new AttributeError(`The layer ${this.name} has multiple inbound nodes with different ` + `output shapes. Hence the notion of "output shape" is ill-defined ` + `for the layer.`); } } countParams() { if (!this.built) { throw new RuntimeError(`You tried to call countParams() on ${this.name}, ` + `but the layer is not built yet. Build it first by calling ` + `build(batchInputShape).`); } return countParamsInWeights(this.weights); } build(inputShape) { this.built = true; } getWeights(trainableOnly = false) { return batchGetValue(trainableOnly ? this.trainableWeights : this.weights); } setWeights(weights) { tidy(() => { const params = this.weights; if (params.length !== weights.length) { throw new ValueError(`You called setWeights(weights) on layer "${this.name}" ` + `with a weight list of length ${weights.length}, ` + `but the layer was expecting ${params.length} weights. 
` + `Provided weights: ${weights}...`); } if (params.length === 0) { return; } const weightValueTuples = []; const paramValues = batchGetValue(params); for (let i = 0; i < paramValues.length; ++i) { const pv = paramValues[i]; const p = params[i]; const w = weights[i]; if (!arraysEqual(pv.shape, w.shape)) { throw new ValueError(`Layer weight shape ${pv.shape} ` + `not compatible with provided weight shape ${w.shape}`); } weightValueTuples.push([p, w]); } batchSetValue(weightValueTuples); }); } addWeight(name, shape, dtype, initializer, regularizer, trainable, constraint, getInitializerFunc) { if (this._addedWeightNames.indexOf(name) !== -1) { throw new ValueError(`Duplicate weight name ${name} for layer ${this.name}`); } this._addedWeightNames.push(name); if (dtype == null) { dtype = 'float32'; } if (this.fastWeightInitDuringBuild) { initializer = getInitializerFunc != null ? getInitializerFunc() : getInitializer('zeros'); } const initValue = initializer.apply(shape, dtype); const weight = new LayerVariable(initValue, dtype, name, trainable, constraint); initValue.dispose(); if (regularizer != null) { this.addLoss(() => regularizer.apply(weight.read())); } if (trainable == null) { trainable = true; } if (trainable) { this._trainableWeights.push(weight); } else { this._nonTrainableWeights.push(weight); } return weight; } setFastWeightInitDuringBuild(value) { this.fastWeightInitDuringBuild = value; } addLoss(losses) { if (losses == null || Array.isArray(losses) && losses.length === 0) { return; } losses = toList(losses); if (this._losses !== undefined && this._losses !== null) { this.losses.push(...losses); } } computeOutputShape(inputShape) { return inputShape; } computeMask(inputs, mask) { if (!this.supportsMasking) { if (mask != null) { if (Array.isArray(mask)) { mask.forEach(maskElement => { if (maskElement != null) { throw new TypeError(`Layer ${this.name} does not support masking, ` + 'but was passed an inputMask.'); } }); } else { throw new TypeError(`Layer 
${this.name} does not support masking, ` + 'but was passed an inputMask.'); } } return null; } return mask; } setMaskMetadata(inputs, outputs, previousMask) { if (!this.supportsMasking) { return; } const outputMasks = this.computeMask(inputs, previousMask); const outputsList = toList(outputs); const outputMasksList = toList(outputMasks); if (outputsList.length !== outputMasksList.length) { throw new Error(`${this.name} outputs ${outputsList.length} tensors ` + `but ${outputsList.length} masks for those tensors`); } for (let i = 0; i < outputsList.length; i++) { outputsList[i].kerasMask = outputMasksList[i]; } } addInboundNode(inputTensors, outputTensors, inputMasks, outputMasks, inputShapes, outputShapes, kwargs = null) { const inputTensorList = toList(inputTensors); outputTensors = toList(outputTensors); inputMasks = toList(inputMasks); outputMasks = toList(outputMasks); inputShapes = normalizeShapeList(inputShapes); outputShapes = normalizeShapeList(outputShapes); const inboundLayers = []; const nodeIndices = []; const tensorIndices = []; for (const x of inputTensorList) { inboundLayers.push(x.sourceLayer); nodeIndices.push(x.nodeIndex); tensorIndices.push(x.tensorIndex); } new Node({ outboundLayer: this, inboundLayers, nodeIndices, tensorIndices, inputTensors: inputTensorList, outputTensors, inputMasks, outputMasks, inputShapes, outputShapes }, kwargs); for (let i = 0; i < outputTensors.length; i++) { outputTensors[i].sourceLayer = this; outputTensors[i].nodeIndex = this.inboundNodes.length - 1; outputTensors[i].tensorIndex = i; } } getConfig() { const config = { name: this.name, trainable: this.trainable }; if (this.batchInputShape != null) { config['batchInputShape'] = this.batchInputShape; } if (this.dtype != null) { config['dtype'] = this.dtype; } return config; } disposeWeights() { this.weights.forEach(weight => weight.dispose()); return this.weights.length; } assertNotDisposed() { if (this._refCount === 0) { throw new Error(`Layer '${this.name}' is already 
disposed.`); } } dispose() { if (!this.built) { throw new Error(`Cannot dispose Layer ${this.name} because it has not been ` + `built yet.`); } if (this._refCount === null) { throw new Error(`Cannot dispose Layer ${this.name} because it has not been used ` + `yet.`); } this.assertNotDisposed(); let numDisposedVariables = 0; if (--this._refCount === 0) { numDisposedVariables = this.disposeWeights(); } return { refCountAfterDispose: this._refCount, numDisposedVariables }; } } function collectInputShape(inputTensors) { inputTensors = toList(inputTensors); const shapes = []; for (const x of inputTensors) { shapes.push(x.shape); } return singletonOrArray(shapes); } function guessOutputDType(inputTensors) { return 'float32'; } function getSourceInputs(tensor, layer, nodeIndex) { if (layer == null || (nodeIndex != null && nodeIndex > 0)) { layer = tensor.sourceLayer; nodeIndex = tensor.nodeIndex; } if (layer.inboundNodes.length === 0) { return [tensor]; } else { const node = layer.inboundNodes[nodeIndex]; if (node.inboundLayers.length === 0) { return node.inputTensors; } else { const sourceTensors = []; for (let i = 0; i < node.inboundLayers.length; i++) { const x = node.inputTensors[i]; const layer = node.inboundLayers[i]; const nodeIndex = node.nodeIndices[i]; const previousSources = getSourceInputs(x, layer, nodeIndex); for (const x of previousSources) { if (sourceTensors.indexOf(x) === -1) { sourceTensors.push(x); } } } return sourceTensors; } } } function checkAllSymbolic(tensors) { let allAreSymbolic = true; for (const tensor of toList(tensors)) { if (!(tensor instanceof SymbolicTensor)) { allAreSymbolic = false; break; } } return allAreSymbolic; } function checkNoneSymbolic(tensors) { let noneAreSymbolic = true; for (const tensor of toList(tensors)) { if (tensor instanceof SymbolicTensor) { noneAreSymbolic = false; break; } } return noneAreSymbolic; } class InputLayer extends Layer { constructor(args) { super({ dtype: args.dtype, name: args.name != null ? 
args.name : getUid('input').toString() }); if (args.batchSize == null) { args.batchSize = null; } if (args.sparse == null) { args.sparse = false; } this.trainable = false; this.built = true; this.sparse = args.sparse; if (args.inputShape != null && args.batchInputShape != null) { throw new ValueError('Only provide the inputShape OR ' + 'batchInputShape argument to inputLayer, not both at the same time.'); } let batchInputShape = args.batchInputShape; if (batchInputShape == null) { if (args.inputShape == null) { throw new ValueError('An InputLayer should be passed either a ' + '`batchInputShape` or an `inputShape`.'); } else { batchInputShape = [args.batchSize].concat(args.inputShape); } } else { if (args.batchSize != null) { throw new ValueError('Cannot specify batchSize if batchInputShape is ' + 'specified when creating an InputLayer.'); } } const dtype = args.dtype || 'float32'; this.batchInputShape = batchInputShape; this.dtype = dtype; this.inputSpec = [{ shape: batchInputShape }]; const inputTensor = new SymbolicTensor(this.dtype, this.batchInputShape, this, [], {}, this.name); inputTensor.nodeIndex = 0; inputTensor.tensorIndex = 0; new Node({ outboundLayer: this, inboundLayers: [], nodeIndices: [], tensorIndices: [], inputTensors: [inputTensor], outputTensors: [inputTensor], inputMasks: [null], outputMasks: [null], inputShapes: [batchInputShape], outputShapes: [batchInputShape] }); } apply(inputs, kwargs) { throw new ValueError('Cannot pass any input to an ' + `InputLayer's apply() method. 
InputLayer name: ${this.name}`); } dispose() { return { refCountAfterDispose: this._refCount, numDisposedVariables: 0 }; } getConfig() { return { batchInputShape: this.batchInputShape, dtype: this.dtype, sparse: this.sparse, name: this.name }; } } InputLayer.className = 'InputLayer'; registerClass(InputLayer); function Input(config) { if (config.batchShape == null && config.shape == null) { throw new Error('Please provide to Input either a `shape`' + ' or a `batchShape` argument. Note that ' + '`shape` does not include the batch ' + 'dimension.'); } if (config.batchShape != null && config.shape != null) { throw new ValueError('Please provide either a `shape` or `batchShape` ' + 'argument to Input, but not both.'); } let batchShape = config.batchShape; if (config.shape != null && batchShape == null) { batchShape = [null].concat(config.shape); } let dtype = config.dtype; if (dtype == null) { dtype = 'float32'; } const inputLayer = new InputLayer({ batchInputShape: batchShape, name: config.name, dtype, sparse: config.sparse }); const outputs = inputLayer.inboundNodes[0].outputTensors; return outputs[0]; } function assertFeedCompatibility(key, val) { if (key.dtype == null || key.dtype === val.dtype) { return val; } try { return cast$2(val, key.dtype); } catch (err) { throw new ValueError(`The dtype of the feed (${val.dtype}) can not be cast to the dtype ` + `of the key '${key.name}' (${key.dtype}).`); } } class FeedDict { constructor(feeds) { this.id2Value = {}; this.id2Mask = {}; this.name2Id = {}; if (feeds instanceof FeedDict) { for (const id in feeds.id2Value) { this.id2Value[id] = feeds.id2Value[id]; if (id in feeds.id2Mask) { this.id2Mask[id] = feeds.id2Mask[id]; } } } else { if (feeds == null) { return; } for (const feed of feeds) { this.add(feed.key, feed.value); } } } add(key, value, mask) { if (this.id2Value[key.id] == null) { this.id2Value[key.id] = assertFeedCompatibility(key, value); this.name2Id[key.name] = key.id; if (mask != null) { this.id2Mask[key.id] 
= mask;
      }
    } else {
      throw new ValueError(`Duplicate key: name=${key.name}, id=${key.id}`);
    }
    return this;
  }
  /* Masks are not taken from `feed` here. */
  addFeed(feed) {
    this.add(feed.key, feed.value);
  }
  hasKey(key) {
    return this.id2Value[key.id] != null;
  }
  names() {
    return Object.keys(this.name2Id);
  }
  /** Looks up a value by SymbolicTensor (via its id) or by tensor name; throws ValueError when absent. */
  getValue(key) {
    if (key instanceof SymbolicTensor) {
      if (this.id2Value[key.id] == null) {
        throw new ValueError(`Nonexistent key: ${key.name}`);
      } else {
        return this.id2Value[key.id];
      }
    } else {
      const id = this.name2Id[key];
      if (id == null) {
        throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);
      }
      return this.id2Value[id];
    }
  }
  /** Like getValue but returns the stored mask (undefined if none); presence is checked against the values. */
  getMask(key) {
    if (key instanceof SymbolicTensor) {
      if (this.id2Value[key.id] == null) {
        throw new ValueError(`Nonexistent key: ${key.name}`);
      } else {
        return this.id2Mask[key.id];
      }
    } else {
      const id = this.name2Id[key];
      if (id == null) {
        throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);
      }
      return this.id2Mask[id];
    }
  }
  /* Disposes every stored mask, including masks copied in from another FeedDict by the constructor. */
  disposeMasks() {
    if (this.id2Mask != null) {
      dispose(this.id2Mask);
    }
  }
}
/* Per-query caches of the topological order and recipient counts, keyed by "<fetch names>|<sorted feed names>". */
const cachedSorted = new LruCache();
const cachedRecipientCounts = new LruCache();
/**
 * Evaluates `fetches` (one SymbolicTensor or an array) from the concrete
 * values in `feedDict` by applying each layer in a cached topological order.
 * Fetches that are themselves fed are returned from `feedDict`. When not
 * training, an intermediate tensor is disposed once its last recipient has
 * consumed it, unless it is fed, fetched, already disposed, or produced by a
 * stateful layer. `probe` is not used in this body.
 *
 * @returns An array of values if `fetches` is an array, else a single value.
 */
function execute(fetches, feedDict, kwargs, probe) {
  const training = kwargs == null ? false : kwargs['training'];
  const arrayFetches = Array.isArray(fetches);
  const fetchArray = arrayFetches ? fetches : [fetches];
  const outputNames = fetchArray.map(t => t.name);
  const finalOutputs = [];
  const feedNames = feedDict.names();
  for (const outputName of outputNames) {
    if (feedNames.indexOf(outputName) !== -1) {
      finalOutputs.push(feedDict.getValue(outputName));
    } else {
      finalOutputs.push(null);
    }
  }
  const fetchAndFeedKey = outputNames.join(',') + '|' + feedDict.names().sort().join(',');
  let sorted = cachedSorted.get(fetchAndFeedKey);
  let recipientCounts;
  if (sorted == null) {
    const out = getTopologicalSortAndRecipientCounts(fetchArray, feedDict);
    sorted = out.sorted;
    recipientCounts = out.recipientCounts;
    cachedSorted.put(fetchAndFeedKey, sorted);
    cachedRecipientCounts.put(fetchAndFeedKey, recipientCounts);
  }
  /*
   * Counts are decremented below, so work on a fresh copy and leave the cached
   * object intact. In training mode the copy stays empty and disposal is skipped.
   */
  recipientCounts = {};
  if (!training) {
    Object.assign(recipientCounts, cachedRecipientCounts.get(fetchAndFeedKey));
  }
  const internalFeedDict = new FeedDict(feedDict);
  for (let i = 0; i < sorted.length; ++i) {
    const symbolic = sorted[i];
    const srcLayer = symbolic.sourceLayer;
    /* InputLayer tensors have no computation; their values come from the feeds. */
    if (srcLayer instanceof InputLayer) {
      continue;
    }
    const inputValues = [];
    const inputMasks = [];
    const tensorsToDispose = [];
    let maskExists = false;
    for (const input of symbolic.inputs) {
      const value = internalFeedDict.getValue(input);
      const mask = internalFeedDict.getMask(input);
      inputValues.push(value);
      inputMasks.push(mask);
      if (mask != null) {
        maskExists = true;
      }
      if (!training) {
        recipientCounts[input.name]--;
        if (recipientCounts[input.name] === 0 && !feedDict.hasKey(input) && outputNames.indexOf(input.name) === -1 && !value.isDisposed && input.sourceLayer.stateful !== true) {
          tensorsToDispose.push(value);
        }
      }
    }
    /*
     * NOTE(review): this writes into the shared `kwargs` object (the caller's,
     * if one was passed), so the mask stays set for every later layer, including
     * layers whose inputs carry no mask. Confirm this is intended.
     */
    if (maskExists) {
      kwargs = kwargs || {};
      kwargs['mask'] = inputMasks[0];
    }
    const outputTensors = toList(srcLayer.apply(inputValues, kwargs));
    let outputMask = null;
    if (srcLayer.supportsMasking) {
      outputMask = srcLayer.computeMask(inputValues, inputMasks);
    }
    const layerOutputs = getNodeOutputs(symbolic);
    const outputSymbolicTensors = Array.isArray(layerOutputs) ? layerOutputs : [layerOutputs];
    for (let i = 0; i < outputSymbolicTensors.length; ++i) {
      if (!internalFeedDict.hasKey(outputSymbolicTensors[i])) {
        /* NOTE(review): when computeMask returns an array, every output receives outputMask[0]; confirm for multi-output layers. */
        internalFeedDict.add(outputSymbolicTensors[i], outputTensors[i], Array.isArray(outputMask) ? outputMask[0] : outputMask);
      }
      const index = outputNames.indexOf(outputSymbolicTensors[i].name);
      if (index !== -1) {
        finalOutputs[index] = outputTensors[i];
      }
    }
    if (!training) {
      dispose(tensorsToDispose);
    }
  }
  internalFeedDict.disposeMasks();
  return arrayFetches ? finalOutputs : finalOutputs[0];
}
/**
 * Topologically sorts the SymbolicTensors needed for `fetches`, stopping at
 * fed tensors, and counts the distinct consumers of each tensor name. With
 * several fetches, the per-fetch orders are concatenated with tensors
 * de-duplicated by name, and the recipient sets are unioned.
 */
function getTopologicalSortAndRecipientCounts(fetches, feedDict) {
  assert$1(fetches != null && fetches.length > 0, () => `Expected at least one fetch, got none`);
  let finalSorted = [];
  let finalRecipientMap = {};
  if (fetches.length === 1) {
    const out = getTopologicalSortAndRecipientCountsForOneFetch(fetches[0], feedDict);
    finalSorted = out.sorted;
    finalRecipientMap = out.recipientMap;
  } else {
    const visited = new Set();
    for (const fetch of fetches) {
      const { sorted, recipientMap } = getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict);
      for (const symbolicTensor of sorted) {
        if (!visited.has(symbolicTensor.name)) {
          finalSorted.push(symbolicTensor);
          visited.add(symbolicTensor.name);
        }
      }
      for (const name in recipientMap) {
        if (finalRecipientMap[name] == null) {
          finalRecipientMap[name] = new Set();
        }
        recipientMap[name].forEach(recipient => finalRecipientMap[name].add(recipient));
      }
    }
  }
  return { sorted: finalSorted, recipientCounts: recipientMap2Counts(finalRecipientMap) };
}
/** Converts name -> Set(recipient names) into name -> number of recipients. */
function recipientMap2Counts(recipientMap) {
  const recipientCounts = {};
  for (const name in recipientMap) {
    recipientCounts[name] = recipientMap[name].size;
  }
  return recipientCounts;
}
/**
 * Iterative post-order DFS from `fetch`. Fed tensor names start out visited,
 * so the traversal stops at them. `marks` holds the stack positions whose
 * inputs have already been pushed. Also records, for each input name, the set
 * of tensor names that consume it.
 */
function getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict) {
  const visited = new Set();
  const sorted = [];
  const recipientMap = {};
  for (const key of feedDict.names()) {
    visited.add(key);
  }
  const stack = [];
  const marks = [];
  stack.push(fetch);
  while (stack.length
> 0) { const top = stack[stack.length - 1]; if (visited.has(top.name)) { stack.pop(); continue; } const topIsMarked = marks[marks.length - 1] === stack.length - 1; if (top.inputs.length === 0 || topIsMarked) { stack.pop(); sorted.push(top); visited.add(top.name); if (topIsMarked) { marks.pop(); } } else { marks.push(stack.length - 1); for (const input of top.inputs) { if (recipientMap[input.name] == null) { recipientMap[input.name] = new Set(); } recipientMap[input.name].add(top.name); if (visited.has(input.name)) { continue; } stack.push(input); } } } return { sorted, recipientMap }; } function getNodeOutputs(fetch) { let layerOutputs; if (fetch.sourceLayer.inboundNodes.length === 1) { layerOutputs = fetch.sourceLayer.output; } else { let nodeIndex = null; for (let i = 0; i < fetch.sourceLayer.inboundNodes.length; ++i) { for (const outputTensor of fetch.sourceLayer.inboundNodes[i] .outputTensors) { if (outputTensor.id === fetch.id) { nodeIndex = i; break; } } } layerOutputs = fetch.sourceLayer.getOutputAt(nodeIndex); } return layerOutputs; } function calcL2Norms(w, axis) { return tidy(() => sqrt$1(sum$1(mul(w, w), axis, true))); } class Constraint extends Serializable { getConfig() { return {}; } } class MaxNorm extends Constraint { constructor(args) { super(); this.defaultMaxValue = 2; this.defaultAxis = 0; this.maxValue = args.maxValue != null ? args.maxValue : this.defaultMaxValue; this.axis = args.axis != null ? args.axis : this.defaultAxis; } apply(w) { return tidy(() => { const norms = calcL2Norms(w, this.axis); const desired = clipByValue$1(norms, 0, this.maxValue); return mul(w, div(desired, add(epsilon(), norms))); }); } getConfig() { return { maxValue: this.maxValue, axis: this.axis }; } } MaxNorm.className = 'MaxNorm'; registerClass(MaxNorm); class UnitNorm extends Constraint { constructor(args) { super(); this.defaultAxis = 0; this.axis = args.axis != null ? 
args.axis : this.defaultAxis; } apply(w) { return tidy(() => div(w, add(epsilon(), calcL2Norms(w, this.axis)))); } getConfig() { return { axis: this.axis }; } } UnitNorm.className = 'UnitNorm'; registerClass(UnitNorm); class NonNeg extends Constraint { apply(w) { return relu$1(w); } } NonNeg.className = 'NonNeg'; registerClass(NonNeg); class MinMaxNorm extends Constraint { constructor(args) { super(); this.defaultMinValue = 0.0; this.defaultMaxValue = 1.0; this.defaultRate = 1.0; this.defaultAxis = 0; this.minValue = args.minValue != null ? args.minValue : this.defaultMinValue; this.maxValue = args.maxValue != null ? args.maxValue : this.defaultMaxValue; this.rate = args.rate != null ? args.rate : this.defaultRate; this.axis = args.axis != null ? args.axis : this.defaultAxis; } apply(w) { return tidy(() => { const norms = calcL2Norms(w, this.axis); const desired = add(mul(this.rate, clipByValue$1(norms, this.minValue, this.maxValue)), mul(1.0 - this.rate, norms)); return mul(w, div(desired, add(epsilon(), norms))); }); } getConfig() { return { minValue: this.minValue, maxValue: this.maxValue, rate: this.rate, axis: this.axis }; } } MinMaxNorm.className = 'MinMaxNorm'; registerClass(MinMaxNorm); const CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP = { 'maxNorm': 'MaxNorm', 'minMaxNorm': 'MinMaxNorm', 'nonNeg': 'NonNeg', 'unitNorm': 'UnitNorm' }; function serializeConstraint(constraint) { return serializeKerasObject(constraint); } function deserializeConstraint(config, customObjects = {}) { return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'constraint'); } function getConstraint(identifier) { if (identifier == null) { return null; } if (typeof identifier === 'string') { const className = identifier in CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP ? 
CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] : identifier; const config = { className, config: {} }; return deserializeConstraint(config); } else if (identifier instanceof Constraint) { return identifier; } else { return deserializeConstraint(identifier); } } function glorotUniform(args) { return new GlorotUniform(args); } async function resolveScalarsInLogs(logs) { if (logs == null) { return; } const promises = []; const keys = []; const scalarsToDispose = []; for (const key in logs) { const value = logs[key]; if (typeof value !== 'number') { const valueScalar = value; promises.push(valueScalar.data()); keys.push(key); scalarsToDispose.push(valueScalar); } } if (promises.length > 0) { const values = await Promise.all(promises); for (let i = 0; i < values.length; ++i) { logs[keys[i]] = values[i][0]; } dispose(scalarsToDispose); } } function disposeTensorsInLogs(logs) { if (logs == null) { return; } for (const key in logs) { const value = logs[key]; if (typeof value !== 'number') { value.dispose(); } } } var ModelLoggingVerbosity; (function (ModelLoggingVerbosity) { ModelLoggingVerbosity[ModelLoggingVerbosity["SILENT"] = 0] = "SILENT"; ModelLoggingVerbosity[ModelLoggingVerbosity["VERBOSE"] = 1] = "VERBOSE"; })(ModelLoggingVerbosity || (ModelLoggingVerbosity = {})); const DEFAULT_YIELD_EVERY_MS = 125; class BaseCallback { constructor() { this.validationData = null; } setParams(params) { this.params = params; } async onEpochBegin(epoch, logs) { } async onEpochEnd(epoch, logs) { } async onBatchBegin(batch, logs) { } async onBatchEnd(batch, logs) { } async onTrainBegin(logs) { } async onTrainEnd(logs) { } setModel(model) { } } class CallbackList { constructor(callbacks, queueLength = 10) { if (callbacks == null) { callbacks = []; } this.callbacks = callbacks; this.queueLength = queueLength; } append(callback) { this.callbacks.push(callback); } setParams(params) { for (const callback of this.callbacks) { callback.setParams(params); } } setModel(model) { for 
(const callback of this.callbacks) { callback.setModel(model); } } async onEpochBegin(epoch, logs) { if (logs == null) { logs = {}; } for (const callback of this.callbacks) { await callback.onEpochBegin(epoch, logs); } } async onEpochEnd(epoch, logs) { if (logs == null) { logs = {}; } for (const callback of this.callbacks) { await callback.onEpochEnd(epoch, logs); } } async onBatchBegin(batch, logs) { if (logs == null) { logs = {}; } for (const callback of this.callbacks) { await callback.onBatchBegin(batch, logs); } } async onBatchEnd(batch, logs) { if (logs == null) { logs = {}; } for (const callback of this.callbacks) { await callback.onBatchEnd(batch, logs); } } async onTrainBegin(logs) { if (logs == null) { logs = {}; } for (const callback of this.callbacks) { await callback.onTrainBegin(logs); } } async onTrainEnd(logs) { if (logs == null) { logs = {}; } for (const callback of this.callbacks) { await callback.onTrainEnd(logs); } } } class BaseLogger extends BaseCallback { constructor() { super(); } async onEpochBegin(epoch) { this.seen = 0; this.totals = {}; } async onBatchEnd(batch, logs) { if (logs == null) { logs = {}; } const batchSize = logs['size'] == null ? 
0 : logs['size']; this.seen += batchSize; for (const key in logs) { const value = logs[key]; if (typeof value === 'number') { if (!this.totals.hasOwnProperty(key)) { this.totals[key] = 0; } this.totals[key] = this.totals[key] + value * batchSize; } else { let oldTotalsToDispose; if (key in this.totals) { oldTotalsToDispose = this.totals[key]; } else { this.totals[key] = 0; } const total = tidy(() => add((this.totals[key]), mul(value, batchSize))); this.totals[key] = total; if (oldTotalsToDispose != null) { oldTotalsToDispose.dispose(); } } } } async onEpochEnd(epoch, logs) { if (logs != null) { for (const key of this.params['metrics']) { if (this.totals[key] == null) { continue; } if (typeof this.totals[key] === 'number') { logs[key] = this.totals[key] / this.seen; } else { tidy(() => { const log = mul(div(1, this.seen), this.totals[key]); logs[key] = log; this.totals[key].dispose(); keep(logs[key]); }); } } } } } class History extends BaseCallback { async onTrainBegin(logs) { this.epoch = []; this.history = {}; } async onEpochEnd(epoch, logs) { if (logs == null) { logs = {}; } this.epoch.push(epoch); for (const key in logs) { if (this.history[key] == null) { this.history[key] = []; } this.history[key].push(logs[key]); } } async syncData() { const promises = []; const keys = []; const indices = []; for (const key in this.history) { const valueArray = this.history[key]; for (let i = 0; i < valueArray.length; ++i) { if (typeof valueArray[i] !== 'number') { const valueScalar = valueArray[i]; promises.push(valueScalar.data()); keys.push(key); indices.push(i); } } } const values = await Promise.all(promises); for (let n = 0; n < values.length; ++n) { const tensorToDispose = this.history[keys[n]][indices[n]]; tensorToDispose.dispose(); this.history[keys[n]][indices[n]] = values[n][0]; } } } class CustomCallback extends BaseCallback { constructor(args, yieldEvery) { super(); this.currentEpoch = 0; this.nowFunc = args.nowFunc; this.nextFrameFunc = args.nextFrameFunc || 
nextFrame; this.yieldEvery = yieldEvery || 'auto'; if (this.yieldEvery === 'auto') { this.yieldEvery = DEFAULT_YIELD_EVERY_MS; } if (this.yieldEvery === 'never' && args.onYield != null) { throw new Error('yieldEvery is `never` but you provided an `onYield` callback. ' + 'Either change `yieldEvery` or remove the callback'); } if (isNumber(this.yieldEvery)) { this.maybeWait = debounce(this.maybeWait.bind(this), this.yieldEvery, this.nowFunc); } this.trainBegin = args.onTrainBegin; this.trainEnd = args.onTrainEnd; this.epochBegin = args.onEpochBegin; this.epochEnd = args.onEpochEnd; this.batchBegin = args.onBatchBegin; this.batchEnd = args.onBatchEnd; this.yield = args.onYield; } async maybeWait(epoch, batch, logs) { const ps = []; if (this.yield != null) { await resolveScalarsInLogs(logs); ps.push(this.yield(epoch, batch, logs)); } ps.push(this.nextFrameFunc()); await Promise.all(ps); } async onEpochBegin(epoch, logs) { this.currentEpoch = epoch; if (this.epochBegin != null) { await resolveScalarsInLogs(logs); await this.epochBegin(epoch, logs); } } async onEpochEnd(epoch, logs) { const ps = []; if (this.epochEnd != null) { await resolveScalarsInLogs(logs); ps.push(this.epochEnd(epoch, logs)); } if (this.yieldEvery === 'epoch') { ps.push(this.nextFrameFunc()); } await Promise.all(ps); } async onBatchBegin(batch, logs) { if (this.batchBegin != null) { await resolveScalarsInLogs(logs); await this.batchBegin(batch, logs); } } async onBatchEnd(batch, logs) { const ps = []; if (this.batchEnd != null) { await resolveScalarsInLogs(logs); ps.push(this.batchEnd(batch, logs)); } if (this.yieldEvery === 'batch') { ps.push(this.nextFrameFunc()); } else if (isNumber(this.yieldEvery)) { ps.push(this.maybeWait(this.currentEpoch, batch, logs)); } await Promise.all(ps); } async onTrainBegin(logs) { if (this.trainBegin != null) { await resolveScalarsInLogs(logs); await this.trainBegin(logs); } } async onTrainEnd(logs) { if (this.trainEnd != null) { await resolveScalarsInLogs(logs); 
await this.trainEnd(logs); } } } function standardizeCallbacks(callbacks, yieldEvery) { if (callbacks == null) { callbacks = {}; } if (callbacks instanceof BaseCallback) { return [callbacks]; } if (Array.isArray(callbacks) && callbacks[0] instanceof BaseCallback) { return callbacks; } const callbackConfigs = toList(callbacks); return callbackConfigs.map(callbackConfig => new CustomCallback(callbackConfig, yieldEvery)); } class CallbackConstructorRegistry { constructor() { } static registerCallbackConstructor(verbosityLevel, callbackConstructor) { assert$1(verbosityLevel >= 0 && Number.isInteger(verbosityLevel), () => `Verbosity level is expected to be an integer >= 0, ` + `but got ${verbosityLevel}`); CallbackConstructorRegistry.checkForDuplicate(callbackConstructor); if (CallbackConstructorRegistry.constructors[verbosityLevel] == null) { CallbackConstructorRegistry.constructors[verbosityLevel] = []; } CallbackConstructorRegistry.constructors[verbosityLevel].push(callbackConstructor); } static checkForDuplicate(callbackConstructor) { for (const levelName in CallbackConstructorRegistry.constructors) { const constructors = CallbackConstructorRegistry.constructors[+levelName]; constructors.forEach(ctor => { if (ctor === callbackConstructor) { throw new ValueError('Duplicate callback constructor.'); } }); } } static clear() { CallbackConstructorRegistry.constructors = {}; } static createCallbacks(verbosityLevel) { const constructors = []; for (const levelName in CallbackConstructorRegistry.constructors) { const level = +levelName; if (verbosityLevel >= level) { constructors.push(...CallbackConstructorRegistry.constructors[level]); } } return constructors.map(ctor => new ctor()); } } CallbackConstructorRegistry.constructors = {}; function configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics) { const history = new History(); const actualCallbacks = [ new BaseLogger(), 
...CallbackConstructorRegistry.createCallbacks(verbose) ]; if (callbacks != null) { actualCallbacks.push(...callbacks); } actualCallbacks.push(history); const callbackList = new CallbackList(actualCallbacks); callbackList.setParams({ epochs, initialEpoch, samples: numTrainSamples, steps: stepsPerEpoch, batchSize, verbose, doValidation, metrics: callbackMetrics, }); return { callbackList, history }; } function deserialize(config, customObjects = {}, fastWeightInit = false) { return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'layer', fastWeightInit); } function l2Normalize(x, axis) { return tidy(() => { if (x.dtype !== 'float32') { x = cast$2(x, 'float32'); } const squareSum = sum$1(square(x), axis, true); const epsilonTensor = fill$1(squareSum.shape, epsilon()); const norm = sqrt$1(maximum$1(squareSum, epsilonTensor)); return div(x, norm); }); } function meanSquaredError(yTrue, yPred) { return tidy(() => mean(square(sub$1(yPred, yTrue)), -1)); } function meanAbsoluteError(yTrue, yPred) { return tidy(() => mean(abs$1(sub$1(yPred, yTrue)), -1)); } function meanAbsolutePercentageError(yTrue, yPred) { return tidy(() => { const diff = sub$1(yTrue, yPred); const clippedTrue = clipByValue$1(abs$1(yTrue), epsilon(), Number.MAX_VALUE); const absResult = abs$1(div(diff, clippedTrue)); return mul(100, mean(absResult, -1)); }); } function meanSquaredLogarithmicError(yTrue, yPred) { return tidy(() => { const clippedPred = clipByValue$1(yPred, epsilon(), Number.MAX_VALUE); const firstLog = log$1(add(1, clippedPred)); const clippedTrue = clipByValue$1(yTrue, epsilon(), Number.MAX_VALUE); const secondLog = log$1(add(1, clippedTrue)); return mean(square(sub$1(firstLog, secondLog)), -1); }); } function squaredHinge(yTrue, yPred) { return tidy(() => { const maxResult = maximum$1(0, sub$1(1, mul(yTrue, yPred))); return mean(square(maxResult), -1); }); } function hinge(yTrue, yPred) { return tidy(() => { const maxResult = maximum$1(0, sub$1(1, 
mul(yTrue, yPred))); return mean(maxResult, -1); }); } function categoricalHinge(yTrue, yPred) { return tidy(() => { const pos = sum$1(mul(yTrue, yPred), -1); const neg = max$1(mul(sub$1(1, yTrue), yPred), -1); return maximum$1(0, add(1, sub$1(neg, pos))); }); } function logcosh(yTrue, yPred) { return tidy(() => { const log2 = Math.log(2); const predictionDiff = sub$1(yPred, yTrue); const logcoshResult = sub$1(add(predictionDiff, softplus$1(mul(-2, predictionDiff))), log2); return mean(logcoshResult, -1); }); } function categoricalCrossentropy$1(target, output, fromLogits = false) { return tidy(() => { if (fromLogits) { output = softmax$1(output); } else { const outputSum = sum$1(output, output.shape.length - 1, true); output = div(output, outputSum); } output = clipByValue$1(output, epsilon(), 1 - epsilon()); return neg$1(sum$1(mul(cast$2(target, 'float32'), log$1(output)), output.shape.length - 1)); }); } function sparseCategoricalCrossentropy$1(target, output, fromLogits = false) { return tidy(() => { const flatTarget = cast$2(floor$1(flatten(target)), 'int32'); output = clipByValue$1(output, epsilon(), 1 - epsilon()); const outputShape = output.shape; const oneHotTarget = reshape$1(oneHot$1(flatTarget, outputShape[outputShape.length - 1]), outputShape); return categoricalCrossentropy$1(oneHotTarget, output, fromLogits); }); } function sigmoidCrossEntropyWithLogits(labels, logits) { if (!arraysEqual(labels.shape, logits.shape)) { throw new ValueError(`logits and labels must have the same shape, but got shapes ` + `${JSON.stringify(labels.shape)} and ${JSON.stringify(logits.shape)}`); } return tidy(() => { const reluLogits = relu$1(logits); const negAbsLogits = neg$1(abs$1(logits)); return add(sub$1(reluLogits, mul(logits, labels)), log1p$1(exp$1(negAbsLogits))); }); } function binaryCrossentropy$1(yTrue, yPred) { return tidy(() => { let y; y = clipByValue$1(yPred, epsilon(), 1 - epsilon()); y = log$1(div(y, sub$1(1, y))); return 
mean(sigmoidCrossEntropyWithLogits(yTrue, y), -1); }); } function kullbackLeiblerDivergence(yTrue, yPred) { return tidy(() => { const clippedTrue = clipByValue$1(yTrue, epsilon(), 1); const clippedPred = clipByValue$1(yPred, epsilon(), 1); return sum$1(mul(yTrue, log$1(div(clippedTrue, clippedPred))), -1); }); } function poisson(yTrue, yPred) { return tidy(() => { const logPred = log$1(add(epsilon(), yPred)); return mean(sub$1(yPred, mul(yTrue, logPred)), -1); }); } function cosineProximity(yTrue, yPred) { return tidy(() => { const trueNormalized = l2Normalize(yTrue, -1); const predNormalized = l2Normalize(yPred, -1); const trueXPred = mul(trueNormalized, predNormalized); return neg$1(sum$1(trueXPred, -1)); }); } const lossesMap = { meanSquaredError, meanAbsoluteError, meanAbsolutePercentageError, meanSquaredLogarithmicError, squaredHinge, hinge, categoricalHinge, logcosh, categoricalCrossentropy: categoricalCrossentropy$1, sparseCategoricalCrossentropy: sparseCategoricalCrossentropy$1, binaryCrossentropy: binaryCrossentropy$1, kullbackLeiblerDivergence, poisson, cosineProximity }; function get$1(identifierOrFn) { if (typeof identifierOrFn === 'string') { if (identifierOrFn in lossesMap) { return lossesMap[identifierOrFn]; } let errMsg = `Unknown loss ${identifierOrFn}`; if (identifierOrFn.toLowerCase().includes('softmaxcrossentropy')) { errMsg = `Unknown loss ${identifierOrFn}. 
` + 'Use "categoricalCrossentropy" as the string name for ' + 'tf.losses.softmaxCrossEntropy'; } throw new ValueError(errMsg); } else { return identifierOrFn; } } function binaryAccuracy(yTrue, yPred) { return tidy(() => { const threshold = mul(.5, onesLike$1(yPred)); const yPredThresholded = cast(greater$1(yPred, threshold), yTrue.dtype); return mean(equal$1(yTrue, yPredThresholded), -1); }); } function categoricalAccuracy(yTrue, yPred) { return tidy(() => cast(equal$1(argMax$1(yTrue, -1), argMax$1(yPred, -1)), 'float32')); } function truePositives(yTrue, yPred) { return tidy(() => { return cast$2(sum$1(logicalAnd$1(equal$1(yTrue, 1), equal$1(yPred, 1))), 'float32'); }); } function falsePositives(yTrue, yPred) { return tidy(() => { return cast$2(sum$1(logicalAnd$1(equal$1(yTrue, 0), equal$1(yPred, 1))), 'float32'); }); } function precision(yTrue, yPred) { return tidy(() => { const tp = truePositives(yTrue, yPred); const fp = falsePositives(yTrue, yPred); const denominator = add(tp, fp); return cast$2(where(greater$1(denominator, 0), div(tp, denominator), 0), 'float32'); }); } function binaryCrossentropy(yTrue, yPred) { return binaryCrossentropy$1(yTrue, yPred); } function sparseCategoricalAccuracy(yTrue, yPred) { if (yTrue.rank === yPred.rank) { yTrue = squeeze(yTrue, [yTrue.rank - 1]); } yPred = argMax$1(yPred, -1); if (yPred.dtype !== yTrue.dtype) { yPred = cast$2(yPred, yTrue.dtype); } return cast$2(equal$1(yTrue, yPred), 'float32'); } const mse = meanSquaredError; const MSE = meanSquaredError; const mae = meanAbsoluteError; const MAE = meanAbsoluteError; const mape = meanAbsolutePercentageError; const MAPE = meanAbsolutePercentageError; const categoricalCrossentropy = categoricalCrossentropy$1; const cosine = cosineProximity; const sparseCategoricalCrossentropy = sparseCategoricalCrossentropy$1; const metricsMap = { binaryAccuracy, categoricalAccuracy, precision, categoricalCrossentropy, sparseCategoricalCrossentropy, mse, MSE, mae, MAE, mape, MAPE, cosine }; 
function get(identifier) { if (typeof identifier === 'string' && identifier in metricsMap) { return metricsMap[identifier]; } else if (typeof identifier !== 'string' && identifier != null) { return identifier; } else { throw new ValueError(`Unknown metric ${identifier}`); } } function getLossOrMetricName(fn) { assert(fn !== null, `Unknown LossOrMetricFn ${fn}`); if (typeof fn === 'string') { return fn; } else { let fnName; for (const key of Object.keys(lossesMap)) { if (lossesMap[key] === fn) { fnName = key; break; } } if (fnName !== undefined) { return fnName; } for (const key of Object.keys(metricsMap)) { if (metricsMap[key] === fn) { fnName = key; break; } } if (fnName !== undefined) { return fnName; } return fn.name; } } function getOptimizer(identifier) { const optimizerMap = { 'Adagrad': () => train.adagrad(0.01), 'Adadelta': () => train.adadelta(1, 0.95, epsilon()), 'Adam': () => train.adam(0.001, 0.9, 0.999, epsilon()), 'Adamax': () => train.adamax(0.002, 0.9, 0.999, epsilon(), 0), 'RMSProp': () => train.rmsprop(0.001, 0.9, 0, epsilon()), 'SGD': () => train.sgd(0.01) }; optimizerMap['adagrad'] = optimizerMap['Adagrad']; optimizerMap['adadelta'] = optimizerMap['Adadelta']; optimizerMap['adam'] = optimizerMap['Adam']; optimizerMap['adamax'] = optimizerMap['Adamax']; optimizerMap['rmsprop'] = optimizerMap['RMSProp']; optimizerMap['sgd'] = optimizerMap['SGD']; if (identifier in optimizerMap) { return optimizerMap[identifier](); } throw new ValueError(`Unknown Optimizer ${identifier}`); } const MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH = 1 * 1024 * 1024; function checkUserDefinedMetadata(userDefinedMetadata, modelName, checkSize = false) { if (userDefinedMetadata == null || typeof userDefinedMetadata !== 'object' || Object.getPrototypeOf(userDefinedMetadata) !== Object.prototype || !plainObjectCheck(userDefinedMetadata)) { throw new Error('User-defined metadata is expected to be a JSON object, but is not.'); } if (checkSize) { const out = 
JSON.stringify(userDefinedMetadata); if (out.length > MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH) { console.warn(`User-defined metadata of model "${modelName}" is too large in ` + `size (length=${out.length} when serialized). It is not ` + `recommended to store such large objects in user-defined metadata. ` + `Please make sure its serialized length is <= ` + `${MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH}.`); } } } function plainObjectCheck(x) { if (x === null) { return true; } else if (typeof x === 'object') { if (Object.getPrototypeOf(x) === Object.prototype) { const keys = Object.keys(x); for (const key of keys) { if (typeof key !== 'string') { return false; } if (!plainObjectCheck(x[key])) { return false; } } return true; } else { if (Array.isArray(x)) { for (const item of x) { if (!plainObjectCheck(item)) { return false; } } return true; } else { return false; } } } else { const xType = typeof x; return xType === 'string' || xType === 'number' || xType === 'boolean'; } } function printSummary(model, lineLength, positions, printFn = console.log) { const sequentialLike = isModelSequentialLike(model); const toDisplay = ['Layer (type)', 'Input Shape', 'Output shape', 'Param #']; if (sequentialLike) { lineLength = lineLength || 90; positions = positions || [0.32, 0.61, 0.89, 1]; } else { lineLength = lineLength || 115; positions = positions || [0.24, 0.48, 0.70, 0.80, 1]; } if (positions[positions.length - 1] <= 1) { positions = positions.map(p => Math.floor(lineLength * p)); } let relevantNodes; if (!sequentialLike) { toDisplay.push('Receives inputs'); relevantNodes = []; for (const depth in model.nodesByDepth) { relevantNodes.push(...model.nodesByDepth[depth]); } } printFn('_'.repeat(lineLength)); printRow(toDisplay, positions, printFn); printFn('='.repeat(lineLength)); const layers = model.layers; for (let i = 0; i < layers.length; ++i) { if (sequentialLike) { printLayerSummary(layers[i], positions, printFn); } else { printLayerSummaryWithConnections(layers[i], 
positions, relevantNodes, printFn); } printFn((i === layers.length - 1 ? '=' : '_').repeat(lineLength)); } model.checkTrainableWeightsConsistency(); const trainableCount = countTrainableParams(model); const nonTrainableCount = countParamsInWeights(model.nonTrainableWeights); printFn(`Total params: ${trainableCount + nonTrainableCount}`); printFn(`Trainable params: ${trainableCount}`); printFn(`Non-trainable params: ${nonTrainableCount}`); printFn('_'.repeat(lineLength)); } function countTrainableParams(model) { let trainableCount; if (model.collectedTrainableWeights != null) { trainableCount = countParamsInWeights(model.collectedTrainableWeights); } else { trainableCount = countParamsInWeights(model.trainableWeights); } return trainableCount; } function isModelSequentialLike(model) { let sequentialLike = true; const nodesByDepth = []; const nodes = []; for (const depth in model.nodesByDepth) { nodesByDepth.push(model.nodesByDepth[depth]); } for (const depthNodes of nodesByDepth) { if (depthNodes.length > 1 || depthNodes.length === 1 && depthNodes[0].inboundLayers.length > 1) { sequentialLike = false; break; } nodes.push(...depthNodes); } if (sequentialLike) { for (const layer of model.layers) { let flag = false; for (const node of layer.inboundNodes) { if (nodes.indexOf(node) !== -1) { if (flag) { sequentialLike = false; break; } else { flag = true; } } } if (!sequentialLike) { break; } } } return sequentialLike; } function printRow(fields, positions, printFn = console.log) { let line = ''; for (let i = 0; i < fields.length; ++i) { if (i > 0) { line = line.slice(0, line.length - 1) + ' '; } line += fields[i]; line = line.slice(0, positions[i]); line += ' '.repeat(positions[i] - line.length); } printFn(line); } function printLayerSummary(layer, positions, printFn) { let outputShape; let inputShape; try { inputShape = (layer.inboundNodes.map(x => JSON.stringify(x.inputShapes))).join(','); } catch (err) { inputShape = 'multiple'; } try { outputShape = 
JSON.stringify(layer.outputShape); } catch (err) { outputShape = 'multiple'; } const name = layer.name; const className = layer.getClassName(); const fields = [`${name} (${className})`, inputShape, outputShape, layer.countParams().toString()]; printRow(fields, positions, printFn); } function printLayerSummaryWithConnections(layer, positions, relevantNodes, printFn) { let outputShape; let inputShape; try { inputShape = (layer.inboundNodes.map(x => JSON.stringify(x.inputShapes))).join(','); } catch (err) { inputShape = 'multiple'; } try { outputShape = JSON.stringify(layer.outputShape); } catch (err) { outputShape = 'multiple'; } const connections = []; for (const node of layer.inboundNodes) { if (relevantNodes != null && relevantNodes.length > 0 && relevantNodes.indexOf(node) === -1) { continue; } for (let i = 0; i < node.inboundLayers.length; ++i) { const inboundLayer = node.inboundLayers[i].name; const inboundLayerIndex = node.nodeIndices[i]; const inboundTensorIndex = node.tensorIndices[i]; connections.push(`${inboundLayer}[${inboundLayerIndex}][${inboundTensorIndex}]`); } } const name = layer.name; const className = layer.getClassName(); const firstConnection = connections.length === 0 ? 
'' : connections[0]; const fields = [ `${name} (${className})`, inputShape, outputShape, layer.countParams().toString(), firstConnection ]; printRow(fields, positions, printFn); for (let i = 1; i < connections.length; ++i) { printRow(['', '', '', '', connections[i]], positions, printFn); } } function isArrayItemInputOrOutputName(key, index, value) { return (key === 'inboundNodes' || key === 'outputLayers' || key === 'inputLayers') && index === 0 && typeof value === 'string'; } function convertPythonicToTs(pythonicConfig, key) { if (pythonicConfig === null) { return null; } else if (typeof pythonicConfig === 'string') { return toCamelCase(pythonicConfig); } else if ((typeof pythonicConfig === 'number') || (typeof pythonicConfig === 'boolean')) { return pythonicConfig; } else if (pythonicConfig instanceof Array) { const tsArray = []; const arrayLength = pythonicConfig.length; for (let i = 0; i < arrayLength; ++i) { const item = pythonicConfig[i]; if (isArrayItemInputOrOutputName(key, i, item)) { tsArray.push(item); } else { tsArray.push(convertPythonicToTs(item, key)); } } return tsArray; } else { const tsDict = {}; for (const pythonicKey of Object.keys(pythonicConfig)) { const pythonicValue = pythonicConfig[pythonicKey]; if (pythonicKey === 'name' && typeof pythonicValue === 'string') { tsDict[pythonicKey] = pythonicValue; } else { const tsKey = toCamelCase(pythonicKey); tsDict[tsKey] = convertPythonicToTs(pythonicValue, tsKey); } } return tsDict; } } function convertTsToPythonic(tsConfig, key) { if (tsConfig === null || tsConfig === undefined) { return null; } else if (typeof tsConfig === 'string') { return toSnakeCase(tsConfig); } else if ((typeof tsConfig === 'number') || (typeof tsConfig === 'boolean')) { return tsConfig; } else if (tsConfig instanceof Array) { const pyArray = []; const arrayLength = tsConfig.length; for (let i = 0; i < arrayLength; ++i) { const item = tsConfig[i]; if (isArrayItemInputOrOutputName(key, i, item)) { pyArray.push(item); } else { 
pyArray.push(convertTsToPythonic(item, key));
      }
    }
    return pyArray;
  } else {
    // Plain-object branch: every key is converted with toSnakeCase(); string
    // values stored under 'name' or 'className' are copied unchanged, every
    // other value is converted recursively.
    const pyDict = {};
    for (const tsKey of Object.keys(tsConfig)) {
      const tsValue = tsConfig[tsKey];
      const pyKey = toSnakeCase(tsKey);
      if ((tsKey === 'name' || tsKey === 'className') && typeof tsValue === 'string') {
        pyDict[pyKey] = tsValue;
      } else {
        pyDict[pyKey] = convertTsToPythonic(tsValue, tsKey);
      }
    }
    return pyDict;
  }
}

// Library version; Container.updatedConfig() reports it as the
// `kerasVersion` field ("tfjs-layers <version>") of serialized models.
const version = '4.22.0';

/**
 * Heuristic check for Keras SavedModel-style weight names: true when the
 * last '/'-separated segment of the FIRST key parses as a base-10 integer
 * (e.g. a key ending in '/0'). Returns false for an empty weight map.
 *
 * @param weights map from weight name to value; only its first key (in
 *     Object.keys order) is inspected.
 * @returns {boolean}
 */
const isKerasSavedModelFormat = (weights) => {
  const keys = Object.keys(weights);
  if (keys.length === 0) {
    return false;
  }
  const key = keys[0].split('/');
  return !isNaN(parseInt(key[key.length - 1], 10));
};

/**
 * A graph of layers defined by symbolic input and output tensors.
 *
 * The constructor:
 * - walks the graph backwards from `args.outputs`;
 * - assigns a depth to every node and layer;
 * - orders `this.layers` by depth;
 * - checks that the graph has no cycles, is connected to `args.inputs`, and
 *   uses unique layer names.
 */
class Container extends Layer {
  /**
   * @param args object with `inputs` and `outputs` (a symbolic tensor or an
   *     array of them) and an optional `name`. When `name` is null/undefined,
   *     a uid is generated from the lower-cased class name.
   * @throws ValueError if an input tensor is listed more than once.
   * @throws TypeError if an input does not originate from an InputLayer.
   * @throws RuntimeError on cycles, on a disconnected graph, or on duplicate
   *     layer names.
   */
  constructor(args) {
    super({});
    // Container.nodeKey(layer, nodeIndex) of every node reached while walking
    // back from the outputs. getConfig(), calculateLosses() and
    // buildNodeConversionMap() use it to tell which nodes belong here.
    this.containerNodes = new Set();
    this.name = args.name;
    if (this.name == null) {
      const prefix = this.getClassName().toLowerCase();
      this.name = getUid(prefix);
    }
    this.supportsMasking = false;
    this.trainable_ = true;
    // Normalize inputs/outputs to (copied) arrays.
    if (Array.isArray(args.inputs)) {
      this.inputs = args.inputs.slice();
    } else {
      this.inputs = [args.inputs];
    }
    if (Array.isArray(args.outputs)) {
      this.outputs = args.outputs.slice();
    } else {
      this.outputs = [args.outputs];
    }
    // Duplicate inputs are an error; duplicate outputs only log a warning.
    if (unique(this.inputs).length !== this.inputs.length) {
      throw new ValueError('The list of inputs passed to the model is ' + 'redundant. All inputs should only appear once. Found: ' + `${this.inputs.map(x => x.name)}`);
    }
    if (unique(this.outputs).length !== this.outputs.length) {
      console.warn('The list of outputs passed to the model is redundant. ' + 'All outputs should only appear once. Found: ' + `${this.outputs.map(x => x.name)}`);
    }
    this.inputLayers = [];
    this.inputLayersNodeIndices = [];
    this.inputLayersTensorIndices = [];
    this.outputLayers = [];
    this.outputLayersNodeIndices = [];
    this.outputLayersTensorIndices = [];
    this.layers = [];
    this.internalContainerRefs = [];
    // Record the (layer, nodeIndex, tensorIndex) coordinates of each output.
    for (const x of this.outputs) {
      const layer = x.sourceLayer;
      const nodeIndex = x.nodeIndex;
      const tensorIndex = x.tensorIndex;
      this.outputLayers.push(layer);
      this.outputLayersNodeIndices.push(nodeIndex);
      this.outputLayersTensorIndices.push(tensorIndex);
    }
    // Same for inputs. Each input must be tensor 0 of node 0 of its source layer.
    for (const x of this.inputs) {
      const layer = x.sourceLayer;
      const nodeIndex = x.nodeIndex;
      const tensorIndex = x.tensorIndex;
      assert(nodeIndex === 0, 'input layer has >1 nodes');
      assert(tensorIndex === 0, 'input layer has >1 tensors');
      this.inputLayers.push(layer);
      this.inputLayersNodeIndices.push(nodeIndex);
      this.inputLayersTensorIndices.push(tensorIndex);
    }
    this.inputNames = [];
    this.outputNames = [];
    this.feedInputShapes = [];
    this.feedInputNames = [];
    this.feedOutputNames = [];
    for (let i = 0; i < this.inputLayers.length; i++) {
      const layer = this.inputLayers[i];
      if (!(layer instanceof InputLayer)) {
        throw new TypeError('Input layers to a LayersModel must be InputLayer objects. ' + `Received inputs: ${args.inputs}. ` + `Input ${i} (0-based) originates ` + `from layer type ${layer.getClassName()}.`);
      }
      this.inputNames.push(layer.name);
      this.feedInputShapes.push(layer.batchInputShape);
      this.feedInputNames.push(layer.name);
    }
    for (const layer of this.outputLayers) {
      this.outputNames.push(layer.name);
    }
    this.internalInputShapes = this.inputs.map(x => x.shape);
    this.internalOutputShapes = this.outputs.map(x => x.shape);
    // Bookkeeping for the graph walk, keyed by node id / layer id.
    const nodesDepths = {};
    const nodeIDToNode = {};
    const layersDepths = {};
    const layerIDToLayer = {};
    // layer id -> the order in which the walk first reached that layer.
    // Used to break ties between layers of equal depth.
    const layerIndices = {};
    // Nodes in DFS post-order: each node is appended only after all the
    // nodes feeding it.
    const nodesInDecreasingDepth = [];
    // Depth-first walk from `tensor` back through the inbound nodes that
    // produced it.
    // - Nodes still on the DFS stack are kept in `nodesInProgress`; reaching
    //   one again means the graph has a cycle.
    // - Nodes in `finishedNodes` are skipped.
    // - When layer/nodeIndex/tensorIndex are not all supplied, they are
    //   read from the tensor.
    const buildMapOfGraph = (tensor, finishedNodes, nodesInProgress, layer, nodeIndex, tensorIndex) => {
      if (layer == null || nodeIndex == null || tensorIndex == null) {
        layer = tensor.sourceLayer;
        nodeIndex = tensor.nodeIndex;
        tensorIndex = tensor.tensorIndex;
      }
      const node = layer.inboundNodes[nodeIndex];
      if (nodesInProgress.indexOf(node) !== -1) {
        throw new RuntimeError(`The tensor ${tensor.name} at layer "${layer.name}" ` + 'is part of a cycle.');
      }
      if (finishedNodes.indexOf(node) !== -1) {
        return;
      }
      this.containerNodes.add(Container.nodeKey(layer, nodeIndex));
      if (!(layer.id in layerIndices)) {
        layerIndices[layer.id] = Object.keys(layerIndices).length;
      }
      if (nodesInProgress.indexOf(node) === -1) {
        nodesInProgress.push(node);
      }
      const numInboundLayers = node.inboundLayers.length;
      for (let i = 0; i < numInboundLayers; i++) {
        const x = node.inputTensors[i];
        const layer = node.inboundLayers[i];
        const nodeIndex = node.nodeIndices[i];
        const tensorIndex = node.tensorIndices[i];
        buildMapOfGraph(x, finishedNodes, nodesInProgress, layer, nodeIndex, tensorIndex);
      }
      finishedNodes.push(node);
      while (nodesInProgress.indexOf(node) >= 0) {
        nodesInProgress.splice(nodesInProgress.indexOf(node), 1);
      }
      nodesInDecreasingDepth.push(node);
    };
    const finishedNodes = [];
    const nodesInProgress = [];
    for (const x of this.outputs) {
      buildMapOfGraph(x, finishedNodes, nodesInProgress);
    }
    // Assign depths, visiting nodes from the output side backwards:
    // - a node not seen yet gets depth 0;
    // - a node's depth is raised to at least its layer's current depth, and
    //   the layer keeps the maximum;
    // - every inbound node ends up at least one deeper than the node it feeds.
    const reversedNodesInDecreasingDepth = nodesInDecreasingDepth.slice().reverse();
    for (const node of reversedNodesInDecreasingDepth) {
      nodeIDToNode[node.id] = node;
      if (!(node.id in nodesDepths)) {
        nodesDepths[node.id] = 0;
      }
      let depth = nodesDepths[node.id];
      const previousDepth = (layersDepths[node.outboundLayer.id] == null ? 0 : layersDepths[node.outboundLayer.id]);
      depth = Math.max(depth, previousDepth);
      layersDepths[node.outboundLayer.id] = depth;
      layerIDToLayer[node.outboundLayer.id] = node.outboundLayer;
      nodesDepths[node.id] = depth;
      for (let i = 0; i < node.inboundLayers.length; i++) {
        const inboundLayer = node.inboundLayers[i];
        const nodeIndex = node.nodeIndices[i];
        const inboundNode = inboundLayer.inboundNodes[nodeIndex];
        const previousDepth = (nodesDepths[inboundNode.id] == null ? 0 : nodesDepths[inboundNode.id]);
        nodesDepths[inboundNode.id] = Math.max(depth + 1, previousDepth);
        nodeIDToNode[inboundNode.id] = inboundNode;
      }
    }
    // Group nodes and layers by depth.
    const nodesByDepth = {};
    for (const nodeID in nodesDepths) {
      const depth = nodesDepths[nodeID];
      if (!(depth in nodesByDepth)) {
        nodesByDepth[depth] = [];
      }
      nodesByDepth[depth].push(nodeIDToNode[nodeID]);
    }
    const layersByDepth = {};
    for (const layerID in layersDepths) {
      const depth = layersDepths[layerID];
      if (!(depth in layersByDepth)) {
        layersByDepth[depth] = [];
      }
      layersByDepth[depth].push(layerIDToLayer[layerID]);
    }
    // Order depths with reverseNumberCompare (defined elsewhere; presumably
    // descending, i.e. input side first — confirm). Within one depth, layers
    // keep the order in which the walk first reached them.
    let depthKeys = Object.keys(layersByDepth)
      .map(x => parseInt(x, 10))
      .sort(reverseNumberCompare);
    this.layers = [];
    for (const depth of depthKeys) {
      const layersForDepth = layersByDepth[depth];
      layersForDepth.sort((a, b) => {
        const aIndex = layerIndices[a.id];
        const bIndex = layerIndices[b.id];
        if (aIndex < bIndex) {
          return -1;
        }
        if (aIndex > bIndex) {
          return 1;
        }
        return 0;
      });
      for (const layer of layersForDepth) {
        if (layer instanceof Container) {
          this.internalContainerRefs.push(layer);
        }
        this.layers.push(layer);
      }
    }
    this.layersByDepth = layersByDepth;
    depthKeys = Object.keys(nodesByDepth)
      .map(x => parseInt(x, 10))
      .sort(reverseNumberCompare);
    // Connectivity check, visiting nodes in depthKeys order. Every input
    // tensor of a node must already be computable: either a model input or
    // an output of a node visited earlier.
    const computableTensors = this.inputs.slice();
    const layersWithCompleteInput = [];
    for (const depth of depthKeys) {
      for (const node of nodesByDepth[depth]) {
        const layer = node.outboundLayer;
        if (layer != null) {
          for (const x of node.inputTensors) {
            if (computableTensors.indexOf(x) === -1) {
              throw new RuntimeError(`Graph disconnected: cannot obtain value for tensor ${x}` + ` at layer "${layer.name}". ` + 'The following previous layers were accessed without ' + `issue: ${layersWithCompleteInput}`);
            }
          }
          for (const x of node.outputTensors) {
            computableTensors.push(x);
          }
          layersWithCompleteInput.push(layer.name);
        }
      }
    }
    this.nodesByDepth = nodesByDepth;
    // Layer names must be unique within the container.
    const allNames = this.layers.map(x => x.name);
    for (const name of allNames) {
      const numOccurrences = allNames.filter(x => x === name).length;
      if (numOccurrences !== 1) {
        throw new RuntimeError(`The name "${name}" is used ${numOccurrences} times ` + 'in the model. All layer names should be unique. Layer names: ' + JSON.stringify(allNames));
      }
    }
    this.outboundNodes = [];
    this.inboundNodes = [];
    // Create the node linking this container's inputs to its outputs.
    // NOTE(review): the result is not stored here; the Node constructor
    // (defined elsewhere) presumably registers it on this.inboundNodes.
    // Confirm.
    new Node({
      outboundLayer: this,
      inboundLayers: [],
      nodeIndices: [],
      tensorIndices: [],
      inputTensors: this.inputs,
      outputTensors: this.outputs,
      inputMasks: this.inputs.map(x => null),
      outputMasks: this.outputs.map(x => null),
      inputShapes: this.inputs.map(x => x.shape),
      outputShapes: this.outputs.map(x => x.shape)
    });
    this.built = true;
    // Reference count decremented by dispose().
    this._refCount = 1;
  }

  /** Throws if dispose() has already brought the reference count to zero. */
  assertNotDisposed() {
    if (this._refCount === 0) {
      throw new Error(`Container '${this.name}' is already disposed.`);
    }
  }

  /**
   * Decrements the reference count. When the count reaches zero, disposes
   * every layer and then every nested container.
   *
   * @returns {{refCountAfterDispose: number, numDisposedVariables: number}}
   * @throws Error if the container is already disposed.
   */
  dispose() {
    this.assertNotDisposed();
    const result = { refCountAfterDispose: null, numDisposedVariables: 0 };
    if (--this._refCount === 0) {
      for (const layer of this.layers) {
        result.numDisposedVariables += layer.dispose().numDisposedVariables;
      }
      // Nested containers are also in this.layers, so they receive a second
      // dispose() call here.
      // NOTE(review): that second call passes their assertNotDisposed() only
      // if their _refCount was > 1 beforehand. Presumably it was incremented
      // when they were used inside this container; confirm.
      for (const container of this.internalContainerRefs) {
        result.numDisposedVariables += container.dispose().numDisposedVariables;
      }
    }
    result.refCountAfterDispose = this._refCount;
    return result;
  }

  get trainable() {
    return this.trainable_;
  }

  /** Sets the flag on this container and on each layer's own trainable weights. */
  set trainable(trainable) {
    this.layers.forEach(layer => {
      layer._trainableWeights
        .forEach(w => w.trainable = trainable);
    });
    this.trainable_ = trainable;
  }

  /**
   * Concatenation of all layers' trainable weights, or [] when this
   * container is not trainable.
   * @throws ValueError if the container itself holds _trainableWeights.
   */
  get trainableWeights() {
    if (this._trainableWeights.length > 0) {
      throw new ValueError('Container instance unexpectedly contains _trainableWeights.' + 'The trainable weights of a Container are a union of the ' + 'trainable weights of its consituent Layers. Its own ' + '_trainableWeights must remain an empty Array.');
    }
    if (!this.trainable) {
      return [];
    }
    let weights = [];
    for (const layer of this.layers) {
      weights = weights.concat(layer.trainableWeights);
    }
    return weights;
  }

  /**
   * Non-trainable weights of all layers. When the container is not
   * trainable, the layers' trainable weights are listed first as well.
   */
  get nonTrainableWeights() {
    const weights = [];
    for (const layer of this.layers) {
      weights.push(...layer.nonTrainableWeights);
    }
    if (!this.trainable) {
      const trainableWeights = [];
      for (const layer of this.layers) {
        trainableWeights.push(...layer.trainableWeights);
      }
      return trainableWeights.concat(weights);
    }
    return weights;
  }

  /** All weights: trainable ones first, then non-trainable ones. */
  get weights() {
    return this.trainableWeights.concat(this.nonTrainableWeights);
  }

  /**
   * Assigns values from `weights` to this model's variables via
   * batchSetValue().
   *
   * @param weights map from weight name to value. In Keras SavedModel format
   *     it is passed to parseWeights(), which may rename entries in place.
   * @param strict when true (the default), throws ValueError for provided
   *     names with no target variable, and for variables left unset.
   * @throws ValueError on duplicate target weight names.
   */
  loadWeights(weights, strict = true) {
    // Canonical name -> target variable. In SavedModel format the canonical
    // name is the weight name with its last segment replaced by the weight's
    // index within layer.weights. Otherwise it is weight.originalName.
    const nameToWeight = {};
    let totalWeightsCount = 0;
    const modelIsKerasSavedModelFormat = isKerasSavedModelFormat(weights);
    if (modelIsKerasSavedModelFormat) {
      this.parseWeights(weights);
    }
    for (const layer of this.layers) {
      for (const [index, weight] of layer.weights.entries()) {
        const parsedName = modelIsKerasSavedModelFormat ?
`${weight.name.split('/').slice(0, -1).join('/') + '/'}${index}` : weight.originalName; if (nameToWeight[parsedName] != null) { throw new ValueError(`Duplicate weight name: ${parsedName}`); } nameToWeight[parsedName] = weight; totalWeightsCount++; } } const weightValueTuples = []; for (const name in weights) { let validatedName = name; if (nameToWeight[name] == null) { const tokens = name.split('/'); const shortenNameArray = tokens.slice(0, -2).concat([tokens[tokens.length - 1]]); validatedName = shortenNameArray.join('/'); } if (nameToWeight[validatedName] != null) { weightValueTuples.push([nameToWeight[validatedName], weights[name]]); } else if (strict) { throw new ValueError(`Provided weight data has no target variable: ${name}`); } delete nameToWeight[validatedName]; } if (strict) { const unsetNames = []; for (const name in nameToWeight) { unsetNames.push(name); } if (unsetNames.length > 0) { throw new ValueError(`${unsetNames.length} of ${totalWeightsCount} weights are not set: ` + `${unsetNames}`); } } batchSetValue(weightValueTuples); } parseWeights(weights) { for (const key in Object.keys(weights)) { const listParts = key.split('/'); const list = ['vars', 'layer_checkpoint_dependencies']; const newKey = listParts .map(str => { if (str.startsWith('_')) { return str.slice(1); } return str; }) .filter(str => !list.includes(str)) .join('/'); if (newKey !== key) { weights[newKey] = weights[key]; delete weights[key]; } } } updatedConfig() { const theConfig = this.getConfig(); const modelConfig = {}; modelConfig['className'] = this.getClassName(); modelConfig['config'] = theConfig; modelConfig['kerasVersion'] = `tfjs-layers ${version}`; modelConfig['backend'] = 'TensorFlow.js'; return modelConfig; } toJSON(unused, returnString = true) { const modelConfig = convertTsToPythonic(this.updatedConfig()); return returnString ? 
JSON.stringify(modelConfig) : modelConfig;
  }

  /**
   * Runs the graph on concrete inputs inside tidy(). Each symbolic tensor in
   * this.inputs is fed the input at the same position, then execute()
   * evaluates this.outputs.
   *
   * @param inputs a tensor or an array of tensors (normalized by toList).
   * @param kwargs forwarded to execute().
   */
  call(inputs, kwargs) {
    return tidy(() => {
      inputs = toList(inputs);
      const feedDict = new FeedDict();
      for (let i = 0; i < this.inputs.length; ++i) {
        feedDict.add(this.inputs[i], inputs[i]);
      }
      return execute(this.outputs, feedDict, kwargs);
    });
  }

  /**
   * Computes output masks by running runInternalGraph() and returning the
   * masks element (index 1) of its result. When `mask` is null/undefined,
   * pyListRepeat(null, inputs.length) is used instead.
   */
  computeMask(inputs, mask) {
    return tidy(() => {
      inputs = toList(inputs);
      let masks;
      if (mask == null) {
        masks = pyListRepeat(null, inputs.length);
      } else {
        masks = toList(mask);
      }
      return this.runInternalGraph(inputs, masks)[1];
    });
  }

  /**
   * Infers output shapes from input shapes by calling each layer's
   * computeOutputShape() in depth order.
   * - Intermediate shapes are keyed by "<layerName>_<nodeIndex>_<tensorIndex>".
   * - Input layers are seeded under "<name>_0_0".
   *
   * @throws ValueError if the number of input shapes differs from the number
   *     of input layers.
   */
  computeOutputShape(inputShape) {
    const inputShapes = normalizeShapeList(inputShape);
    if (inputShapes.length !== this.inputLayers.length) {
      throw new ValueError(`Invalid inputShape argument ${inputShape}: ` + `model has ${this.inputLayers.length} tensor inputs.`);
    }
    const layersToOutputShapes = {};
    for (let i = 0; i < inputShapes.length; i++) {
      const layer = this.inputLayers[i];
      const inputShape = inputShapes[i];
      const shapeKey = layer.name + '_0_0';
      layersToOutputShapes[shapeKey] = inputShape;
    }
    const depthKeys = Object.keys(this.nodesByDepth)
      .map(x => parseInt(x, 10))
      .sort(reverseNumberCompare);
    // Shapes are only propagated when the graph spans more than one depth level.
    if (depthKeys.length > 1) {
      for (const depth of depthKeys) {
        const nodes = this.nodesByDepth[depth];
        for (const node of nodes) {
          const layer = node.outboundLayer;
          // Input layers were already seeded above.
          if (this.inputLayers.map(x => x.id).indexOf(layer.id) !== -1) {
            continue;
          }
          const inputShapes = [];
          for (let j = 0; j < node.inboundLayers.length; j++) {
            const inboundLayer = node.inboundLayers[j];
            const nodeIndex = node.nodeIndices[j];
            const tensorIndex = node.tensorIndices[j];
            const shapeKey = `${inboundLayer.name}_${nodeIndex}_${tensorIndex}`;
            const inputShape = layersToOutputShapes[shapeKey];
            inputShapes.push(inputShape);
          }
          const outputShape = layer.computeOutputShape(singletonOrArray(inputShapes));
          const outputShapes = normalizeShapeList(outputShape);
          const nodeIndex = layer.inboundNodes.indexOf(node);
          for (let j = 0; j < outputShapes.length; j++) {
            const shapeKey = `${layer.name}_${nodeIndex}_${j}`;
            layersToOutputShapes[shapeKey] = outputShapes[j];
          }
        }
      }
    }
    const outputShapes = [];
    const outputShapeKeys = [];
    for (let i = 0; i < this.outputLayers.length; i++) {
      const layer = this.outputLayers[i];
      const nodeIndex = this.outputLayersNodeIndices[i];
      const tensorIndex = this.outputLayersTensorIndices[i];
      const shapeKey = `${layer.name}_${nodeIndex}_${tensorIndex}`;
      outputShapeKeys.push(shapeKey);
    }
    for (let i = 0; i < outputShapeKeys.length; i++) {
      const key = outputShapeKeys[i];
      assert(key in layersToOutputShapes);
      outputShapes.push(layersToOutputShapes[key]);
    }
    return singletonOrArray(outputShapes);
  }

  /**
   * Evaluates the graph on concrete tensors, node by node in depth order.
   * A node runs only once all of its input tensors have been computed.
   *
   * @param inputs concrete tensors, matched by position to this.inputs.
   * @param masks optional masks, one per input; when null/undefined,
   *     pyListRepeat(null, inputs.length) is used.
   * @returns [outputTensors, outputMasks, outputShapes] for this.outputs.
   * @throws NotImplementedError if an executed layer has an activity
   *     regularizer.
   */
  runInternalGraph(inputs, masks) {
    if (masks == null) {
      masks = pyListRepeat(null, inputs.length);
    }
    // Symbolic tensor id -> [concrete tensor, mask].
    const tensorMap = {};
    for (let i = 0; i < this.inputs.length; ++i) {
      const x = this.inputs[i];
      const y = inputs[i];
      const mask = masks[i];
      tensorMap[x.id] = [y, mask];
    }
    const depthKeys = Object.keys(this.nodesByDepth)
      .map(x => parseInt(x, 10))
      .sort(reverseNumberCompare);
    for (const depth of depthKeys) {
      const nodes = this.nodesByDepth[depth];
      for (const node of nodes) {
        const layer = node.outboundLayer;
        const referenceInputTensors = node.inputTensors;
        const referenceOutputTensors = node.outputTensors;
        const computedData = new Array();
        for (const x of referenceInputTensors) {
          if (x.id in tensorMap) {
            computedData.push(tensorMap[x.id]);
          }
        }
        if (computedData.length === referenceInputTensors.length) {
          let kwargs = {};
          let computedTensors;
          let computedMasks;
          let outputTensors;
          let outputMasks;
          // NOTE(review): kwargs aliases node.callArgs, so setting
          // kwargs['mask'] below also writes into the node's stored call args.
          if (node.callArgs != null) {
            kwargs = node.callArgs;
          }
          if (computedData.length === 1) {
            const [computedTensor, computedMask] = computedData[0];
            if (kwargs['mask'] == null) {
              kwargs['mask'] = computedMask;
            }
            outputTensors = toList(layer.call(computedTensor, kwargs));
            outputMasks = toList(layer.computeMask(computedTensor, computedMask));
            computedTensors = [computedTensor];
            computedMasks = [computedMask];
          } else {
            computedTensors = computedData.map(x => x[0]);
            computedMasks = computedData.map(x => x[1]);
            if (kwargs['mask'] == null) {
              kwargs['mask'] = computedMasks;
            }
            outputTensors = toList(layer.call(computedTensors, kwargs));
            outputMasks = toList(layer.computeMask(computedTensors, computedMasks));
          }
          if (layer.activityRegularizer) {
            throw new NotImplementedError('LayersModel invocation with concrete Tensor value(s) in the ' + 'presence of activity regularizer(s) is not supported yet.');
          }
          for (let i = 0; i < referenceOutputTensors.length; ++i) {
            const x = referenceOutputTensors[i];
            const y = outputTensors[i];
            const mask = outputMasks[i];
            tensorMap[x.id] = [y, mask];
          }
        }
      }
    }
    const outputTensors = [];
    const outputMasks = [];
    const outputShapes = [];
    for (const x of this.outputs) {
      assert(x.id in tensorMap, `Could not compute output ${x.name} : ${x.id}`);
      const [tensor, mask] = tensorMap[x.id];
      outputShapes.push(tensor.shape);
      outputTensors.push(tensor);
      outputMasks.push(mask);
    }
    return [outputTensors, outputMasks, outputShapes];
  }

  /**
   * Maps Container.nodeKey(layer, originalNodeIndex) to a renumbered node
   * index that counts only nodes belonging to this container. Numbering
   * starts at 1 for nested containers and at 0 for other layers.
   * NOTE(review): the `layers` argument is unused; this.layers is iterated.
   */
  buildNodeConversionMap(layers) {
    const nodeConversionMap = {};
    let keptNodes;
    for (const layer of this.layers) {
      keptNodes = layer instanceof Container ?
1 : 0;
      for (let originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) {
        const nodeKey = Container.nodeKey(layer, originalNodeIndex);
        if (this.containerNodes.has(nodeKey)) {
          nodeConversionMap[nodeKey] = keptNodes;
          keptNodes += 1;
        }
      }
    }
    return nodeConversionMap;
  }

  /**
   * Retrieves a layer by index or by name. The index form is used when
   * `index` is given, or when `nameOrIndex` is a number.
   * @throws ValueError if neither is provided, or if no layer has the name.
   */
  getLayer(nameOrIndex, index) {
    if (index != null) {
      return this.findLayer(index);
    } else {
      if (nameOrIndex == null) {
        throw new ValueError('Provide either a layer name or layer index');
      }
      if (typeof nameOrIndex === 'number') {
        return this.findLayer(nameOrIndex);
      }
    }
    for (const layer of this.layers) {
      if (layer.name === nameOrIndex) {
        return layer;
      }
    }
    throw new ValueError(`No such layer: ${nameOrIndex}`);
  }

  /**
   * Returns this.layers[index].
   * @throws ValueError if index >= the number of layers.
   * NOTE(review): negative or non-integer indices are not rejected and
   * return undefined.
   */
  findLayer(index) {
    if (this.layers.length <= index) {
      throw new ValueError(`Was asked to retrieve layer at index ${index}, but model only ` + `has ${this.layers.length} layer(s).`);
    } else {
      return this.layers[index];
    }
  }

  /**
   * Collects layer.calculateLosses() inside tidy(). It is called once for
   * every node of a layer that belongs to this container, so a layer with
   * several such nodes contributes its losses several times.
   */
  calculateLosses() {
    return tidy(() => {
      const losses = [];
      for (const layer of this.layers) {
        for (let nodeIndex = 0; nodeIndex < layer.inboundNodes.length; ++nodeIndex) {
          const nodeKey = Container.nodeKey(layer, nodeIndex);
          if (this.containerNodes.has(nodeKey)) {
            losses.push(...layer.calculateLosses());
          }
        }
      }
      return losses;
    });
  }

  /**
   * Builds a serializable description of the graph:
   * - `name`;
   * - `layers`, each with name, className, config and inboundNodes;
   * - `inputLayers` and `outputLayers` as [layerName, nodeIndex, tensorIndex]
   *   triples.
   * Node indices are renumbered through buildNodeConversionMap(), so only
   * nodes belonging to this container are counted; a missing entry becomes 0.
   */
  getConfig() {
    const config = { name: this.name };
    const nodeConversionMap = this.buildNodeConversionMap(this.layers);
    const layerConfigs = [];
    for (const layer of this.layers) {
      const layerClassName = layer.getClassName();
      const layerConfig = layer.getConfig();
      const filteredInboundNodes = [];
      for (let originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) {
        const node = layer.inboundNodes[originalNodeIndex];
        const nodeKey = Container.nodeKey(layer, originalNodeIndex);
        let kwargs = {};
        if (this.containerNodes.has(nodeKey)) {
          // Keep call args only if JSON.stringify() accepts them; otherwise
          // warn and serialize {} instead.
          if (node.callArgs) {
            try {
              JSON.stringify(node.callArgs);
              kwargs = node.callArgs;
            } catch (err) {
              console.warn(`Layer ${layer.name} was passed ` + `non-serializable keyword arguments: ` + `${node.callArgs}. They will not be included ` + `in the serialized model (and thus will be ` + `missing at deserialization time).`);
              kwargs = {};
            }
          }
          // Each inbound connection is serialized as
          // [inboundLayerName, newNodeIndex, tensorIndex, kwargs].
          if (node.inboundLayers.length > 0) {
            const nodeData = [];
            for (let i = 0; i < node.inboundLayers.length; i++) {
              const inboundLayer = node.inboundLayers[i];
              const nodeIndex = node.nodeIndices[i];
              const tensorIndex = node.tensorIndices[i];
              const nodeKey = Container.nodeKey(inboundLayer, nodeIndex);
              let newNodeIndex = nodeConversionMap[nodeKey];
              if (newNodeIndex == null) {
                newNodeIndex = 0;
              }
              nodeData.push([inboundLayer.name, newNodeIndex, tensorIndex, kwargs]);
            }
            filteredInboundNodes.push(nodeData);
          }
        }
      }
      const dict = {};
      dict['name'] = layer.name;
      dict['className'] = layerClassName;
      dict['config'] = layerConfig;
      dict['inboundNodes'] = filteredInboundNodes;
      layerConfigs.push(dict);
    }
    config['layers'] = layerConfigs;
    // Inputs and outputs whose node does not belong to this container are skipped.
    const modelInputs = [];
    for (let i = 0; i < this.inputLayers.length; i++) {
      const layer = this.inputLayers[i];
      const nodeIndex = this.inputLayersNodeIndices[i];
      const nodeKey = Container.nodeKey(layer, nodeIndex);
      if (!this.containerNodes.has(nodeKey)) {
        continue;
      }
      let newNodeIndex = nodeConversionMap[nodeKey];
      if (newNodeIndex === null || newNodeIndex === undefined) {
        newNodeIndex = 0;
      }
      const tensorIndex = this.inputLayersTensorIndices[i];
      modelInputs.push([layer.name, newNodeIndex, tensorIndex]);
    }
    config['inputLayers'] = modelInputs;
    const modelOutputs = [];
    for (let i = 0; i < this.outputLayers.length; i++) {
      const layer = this.outputLayers[i];
      const nodeIndex = this.outputLayersNodeIndices[i];
      const nodeKey = Container.nodeKey(layer, nodeIndex);
      if (!this.containerNodes.has(nodeKey)) {
        continue;
      }
      let newNodeIndex = nodeConversionMap[nodeKey];
      if (newNodeIndex === null || newNodeIndex === undefined) {
        newNodeIndex = 0;
      }
      const tensorIndex = this.outputLayersTensorIndices[i];
      modelOutputs.push([layer.name, newNodeIndex, tensorIndex]);
    }
    config['outputLayers'] = modelOutputs;
return config;
  }

  /**
   * Rebuilds a container of class `cls` from a config produced by getConfig().
   *
   * 1. Every layer is deserialized, and all of its inbound nodes are queued.
   * 2. Queued nodes are re-created by applying each layer to the outputs of
   *    the inbound nodes they reference. A node whose inputs do not exist yet
   *    is re-queued for a later pass.
   *
   * @param cls constructor, called as `new cls({inputs, outputs, name})`.
   * @param config the serialized container config.
   * @param customObjects NOTE(review): unused here; deserialization reads
   *     config['customObjects'] instead.
   * @param fastWeightInit forwarded to layer.setFastWeightInitDuringBuild().
   */
  static fromConfig(cls, config, customObjects = {}, fastWeightInit = false) {
    const createdLayers = {};
    // Layer name -> list of nodeData entries still waiting to be applied.
    const unprocessedNodes = {};
    function addUnprocessedNode(layer, nodeData) {
      if (!(layer.name in unprocessedNodes)) {
        unprocessedNodes[layer.name] = [nodeData];
      } else {
        unprocessedNodes[layer.name].push(nodeData);
      }
    }
    // Applies `layer` to the tensors referenced by nodeData, whose entries are
    // [inboundLayerName, inboundNodeIndex, inboundTensorIndex, kwargs].
    // The kwargs of the last entry are the ones used. If any referenced layer
    // or node does not exist yet, the whole nodeData is re-queued instead.
    function processNode(layer, nodeData) {
      const inputTensors = [];
      let kwargs;
      for (const inputData of nodeData) {
        const inboundLayerName = inputData[0];
        const inboundNodeIndex = inputData[1];
        const inboundTensorIndex = inputData[2];
        kwargs = inputData[3] == null ? {} : inputData[3];
        if (!(inboundLayerName in createdLayers)) {
          addUnprocessedNode(layer, nodeData);
          return;
        }
        const inboundLayer = createdLayers[inboundLayerName];
        if (inboundLayer.inboundNodes.length <= inboundNodeIndex) {
          addUnprocessedNode(layer, nodeData);
          return;
        }
        const inboundNode = inboundLayer.inboundNodes[inboundNodeIndex];
        inputTensors.push(inboundNode.outputTensors[inboundTensorIndex]);
      }
      if (inputTensors.length > 0) {
        layer.apply(singletonOrArray(inputTensors), kwargs);
      }
    }
    // Deserializes one layer and queues all of its inbound nodes.
    function processLayer(layerData) {
      const layerName = layerData['name'];
      const layer = deserialize(layerData, config['customObjects'] != null ? config['customObjects'] : {});
      layer.setFastWeightInitDuringBuild(fastWeightInit);
      createdLayers[layerName] = layer;
      const inboundNodesData = layerData['inboundNodes'];
      inboundNodesData.forEach(nodeData => {
        if (!(nodeData instanceof Array)) {
          throw new ValueError(`Corrupted configuration, expected array for nodeData: ${nodeData}`);
        }
        addUnprocessedNode(layer, nodeData);
      });
    }
    const name = config['name'];
    const layersFromConfig = config['layers'];
    for (const layerData of layersFromConfig) {
      processLayer(layerData);
    }
    // Keep sweeping the layers in config order until every queued node has
    // been applied.
    // NOTE(review): a node that can never be resolved (e.g. one referencing
    // a layer name absent from the config) is re-queued forever, so this
    // loop would never end.
    while (!isObjectEmpty(unprocessedNodes)) {
      for (const layerData of layersFromConfig) {
        const layer = createdLayers[layerData['name']];
        if (layer.name in unprocessedNodes) {
          const currentUnprocessedNodesForLayer = unprocessedNodes[layer.name];
          delete unprocessedNodes[layer.name];
          for (const nodeData of currentUnprocessedNodesForLayer) {
            processNode(layer, nodeData);
          }
        }
      }
    }
    // Resolve model inputs and outputs from [layerName, nodeIndex, tensorIndex] triples.
    const inputTensors = [];
    const outputTensors = [];
    const inputLayersFromConfig = config['inputLayers'];
    for (const layerData of inputLayersFromConfig) {
      const layerName = layerData[0];
      const nodeIndex = layerData[1];
      const tensorIndex = layerData[2];
      assert(layerName in createdLayers);
      const layer = createdLayers[layerName];
      const layerOutputTensors = layer.inboundNodes[nodeIndex].outputTensors;
      inputTensors.push(layerOutputTensors[tensorIndex]);
    }
    const outputLayersFromConfig = config['outputLayers'];
    for (const layerData of outputLayersFromConfig) {
      const layerName = layerData[0];
      const nodeIndex = layerData[1];
      const tensorIndex = layerData[2];
      assert(layerName in createdLayers);
      const layer = createdLayers[layerName];
      const layerOutputTensors = layer.inboundNodes[nodeIndex].outputTensors;
      outputTensors.push(layerOutputTensors[tensorIndex]);
    }
    return new cls({ inputs: inputTensors, outputs: outputTensors, name });
  }

  /**
   * True if any contained layer is stateful.
   * @throws ValueError if the container's own _stateful flag is set.
   */
  get stateful() {
    if (this._stateful) {
      throw new ValueError('Container instance unexpectedly has _stateful = true. The ' + 'statefulness of a Container is determined by the Layers it ' + 'contains. Its _stateful property must remain the default false.');
    }
    for (const layer of this.layers) {
      if (layer.stateful) {
        return true;
      }
    }
    return false;
  }

  /** Calls resetStates() on every stateful layer, inside tidy(). */
  resetStates() {
    tidy(() => {
      this.layers.forEach(layer => {
        if (layer.stateful) {
          layer.resetStates();
        }
      });
    });
  }
}

/**
 * Normalizes per-output sample or class weights into an array aligned with
 * `outputNames`.
 *
 * - null/undefined or an empty array: an array of nulls, one per output.
 * - Single output:
 *   - a 1-element array is returned as is;
 *   - an object containing the output name yields [xWeight[name]];
 *   - anything else is wrapped as [xWeight].
 * - Multiple outputs:
 *   - an array must contain exactly one entry per output;
 *   - an object whose first value is an object is looked up by output name,
 *     with null for names it lacks.
 *
 * @param weightType label used in error messages (e.g. 'classWeight').
 * @throws Error for arrays of the wrong length, or for unrecognized values
 *     when there are multiple outputs.
 */
function standardizeSampleOrClassWeights(xWeight, outputNames, weightType) {
  const numOutputs = outputNames.length;
  if (xWeight == null || (Array.isArray(xWeight) && xWeight.length === 0)) {
    return outputNames.map(name => null);
  }
  if (numOutputs === 1) {
    if (Array.isArray(xWeight) && xWeight.length === 1) {
      return xWeight;
    } else if (typeof xWeight === 'object' && outputNames[0] in xWeight) {
      return [xWeight[outputNames[0]]];
    } else {
      return [xWeight];
    }
  }
  if (Array.isArray(xWeight)) {
    if (xWeight.length !== numOutputs) {
      throw new Error(`Provided ${weightType} is an array of ${xWeight.length} ` + `element(s), but the model has ${numOutputs} outputs. ` + `Make sure a set of weights is provided for each model output.`);
    }
    return xWeight;
  } else if (typeof xWeight === 'object' && Object.keys(xWeight).length > 0 && typeof xWeight[Object.keys(xWeight)[0]] === 'object') {
    const output = [];
    outputNames.forEach(outputName => {
      if (outputName in xWeight) {
        output.push(xWeight[outputName]);
      } else {
        output.push(null);
      }
    });
    return output;
  } else {
    throw new Error(`The model has multiple (${numOutputs}) outputs, ` + `so ${weightType} must be either an array with ` + `${numOutputs} elements or an object with ${outputNames} keys. ` + `Provided ${weightType} not understood: ${JSON.stringify(xWeight)}`);
  }
}

/** standardizeSampleOrClassWeights() with the 'classWeight' label. */
function standardizeClassWeights(classWeight, outputNames) {
  return standardizeSampleOrClassWeights(classWeight, outputNames, 'classWeight');
}

/**
 * Turns class weights into per-sample weights for the target tensor `y`.
 * Class indices are derived from `y` as follows:
 * - rank 1: y itself (cloned);
 * - rank 2 with width > 1: argMax along axis 1;
 * - rank 2 with width 1: y reshaped to [y.shape[0]].
 *
 * @param classWeight map from class index to weight. When null/undefined,
 *     the function resolves to null.
 * @param sampleWeight NOTE(review): unused.
 * @param sampleWeightMode NOTE(review): unused.
 * @returns a float32 tensor1d with one weight per sample, or null.
 * @throws Error for unsupported ranks or widths, or for a class index that
 *     is missing from classWeight.
 */
async function standardizeWeights(y, sampleWeight, classWeight, sampleWeightMode) {
  if (classWeight != null) {
    const yClasses = tidy(() => {
      if (y.shape.length === 1) {
        return clone(y);
      } else if (y.shape.length === 2) {
        if (y.shape[1] > 1) {
          const axis = 1;
          return argMax$1(y, axis);
        } else if (y.shape[1] === 1) {
          return reshape$1(y, [y.shape[0]]);
        } else {
          throw new Error(`Encountered unexpected last-dimension size (${y.shape[1]}) ` + `during handling of class weights. The size is expected to be ` + `>= 1.`);
        }
      } else {
        throw new Error(`Unexpected rank of target (y) tensor (${y.rank}) during ` + `handling of class weights. The rank is expected to be 1 or 2.`);
      }
    });
    const yClassIndices = Array.from(await yClasses.data());
    dispose(yClasses);
    const classSampleWeight = [];
    yClassIndices.forEach(classIndex => {
      if (classWeight[classIndex] == null) {
        throw new Error(`classWeight must contain all classes in the training data. ` + `The class ${classIndex} exists in the data but not in ` + `classWeight`);
      } else {
        classSampleWeight.push(classWeight[classIndex]);
      }
    });
    return tensor1d(classSampleWeight, 'float32');
  } else {
    return null;
  }
}

/** Multiplies per-sample losses by sample weights, via mul(). */
function computeWeightedLoss(losses, sampleWeights) {
  return mul(losses, sampleWeights);
}

// Batch size fitDataset() uses for tensor-based validation when
// config.validationBatchSize is not set.
const DEFAULT_VALIDATION_BATCH_SIZE = 32;

/**
 * Validates one element produced by a fitDataset() iterator. Its `xs` and
 * `ys` (Tensor, Tensor[] or name->Tensor map) are flattened into arrays
 * ordered like model.inputNames and model.outputNames.
 *
 * @throws if xs or ys is missing, if their counts differ from the model's
 *     inputs or outputs, or if their first dimensions (batch sizes) disagree.
 */
function standardizeDataIteratorOutput(model, iteratorOut) {
  let xs;
  let ys;
  const iteratorOutObj = iteratorOut;
  xs = iteratorOutObj['xs'];
  ys = iteratorOutObj['ys'];
  assert$1(xs != null && ys != null, () => 'A Dataset iterator for fitDataset() is expected to generate ' + 'objects of the form `{xs: xVal, ys: yVal}`, where the two ' + 'values may be `tf.Tensor`, an array of Tensors, or a map of ' + 'string to Tensor. The provided Dataset instead generates ' + `${iteratorOut}`);
  const flattenedXs = flattenTensorOrArrayOrMap('input', model.inputNames, xs);
  const flattenedYs = flattenTensorOrArrayOrMap('output', model.outputNames, ys);
  const batchSize = flattenedXs[0].shape[0];
  assert$1(flattenedXs.length === model.inputs.length, () => `LayersModel has ${model.inputs.length} inputs, but the dataset ` + `provides ${flattenedXs.length} inputs. (Expected input keys: ` + `${JSON.stringify(model.inputNames)})`);
  assert$1(flattenedYs.length === model.outputs.length, () => `LayersModel has ${model.outputs.length} outputs, but the dataset ` + `provides ${flattenedYs.length} outputs. (Expected output keys: ` + `${JSON.stringify(model.outputNames)})`);
  for (let xIndex = 0; xIndex < flattenedXs.length; xIndex++) {
    assert$1(flattenedXs[xIndex].shape[0] === batchSize, () => `Batch size mismatch: input ` + `${model.inputNames[xIndex]} has ${flattenedXs[xIndex].shape[0]}; ` + `expected ${batchSize} based on input ${model.inputNames[0]}.`);
  }
  for (let yIndex = 0; yIndex < flattenedYs.length; yIndex++) {
    assert$1(flattenedYs[yIndex].shape[0] === batchSize, () => `Batch size mismatch: output ` + `${model.outputNames[yIndex]} has ${flattenedYs[yIndex].shape[0]}; ` + `expected ${batchSize} based on input ${model.inputNames[0]}.`);
  }
  return { xs: flattenedXs, ys: flattenedYs };
}

/**
 * Normalizes a Tensor, Tensor[] or name->Tensor map into a Tensor[].
 * - A single Tensor becomes [values].
 * - An array is returned unchanged.
 * - A map is read in the order of `names`.
 *
 * @param inputOrOutput 'input' or 'output'; used only in error messages.
 * @throws if an array's length differs from names.length, or if a map
 *     lacks one of the names.
 */
function flattenTensorOrArrayOrMap(inputOrOutput, names, values) {
  if (values instanceof Tensor) {
    return [values];
  } else if (Array.isArray(values)) {
    assert$1(values.length === names.length, () => `Received an array of ${values.length} Tensors, but expected ${names.length} to match the ${inputOrOutput} keys ${names}.`);
    return values;
  } else {
    const result = [];
    for (const name of names) {
      if (values[name] == null) {
        throw new ValueError(`The feature data generated by the dataset lacks the required ` + `${inputOrOutput} key '${name}'.`);
      }
      result.push(values[name]);
    }
    return result;
  }
}

// Declaration of standardizeTensorValidationData (its name and body follow).
function
standardizeTensorValidationData(data) { if (data.length === 3) { throw new NotImplementedError('Validation with sample weights is not implemented yet.'); } return { xs: data[0], ys: data[1] }; } async function fitDataset( model, dataset, args) { const hasBatchesPerEpoch = args.batchesPerEpoch != null; assert$1(model.optimizer != null, () => 'You must compile a model before training/testing. Use ' + 'LayersModel.compile(modelCompileConfig).'); assert$1(args != null, () => `For fitDataset(), the 2nd argument (config) is required, ` + `but it is not provided in this call.`); assert$1(args.epochs != null && args.epochs > 0 && Number.isInteger(args.epochs), () => `For fitDataset(), config.epochs is expected to be a positive ` + `integer, but got ${args.epochs}`); assert$1(!hasBatchesPerEpoch || (args.batchesPerEpoch > 0 && Number.isInteger(args.batchesPerEpoch)), () => `For fitDataset(), config.batchesPerEpoch is expected to be a ` + `positive integer if specified, but got ${args.batchesPerEpoch}`); assert$1( args['validationSplit'] == null, () => '`validationSplit` is not supported by `fitDataset()`. 
' + 'Use validationData instead.'); if (model.isTraining) { throw new Error('Cannot start training because another fit() call is ongoing.'); } model.isTraining = true; try { const doValidation = args.validationData != null; let valXs; let valYs; if (doValidation) { if (isDatasetObject(args.validationData)) { assert$1(args.validationBatches == null || (args.validationBatches > 0 && Number.isInteger(args.validationBatches)), () => `For fitDataset() with dataset-based validation, ` + `config.validationBatches is expected not to be provided, ` + `or to be a positive integer, ` + `but got ${args.validationBatches}`); } else { const validationData = standardizeTensorValidationData(args.validationData); valXs = validationData.xs; valYs = validationData.ys; } } const trainFunction = model.makeTrainFunction(); const outLabels = model.getDedupedMetricsNames(); let callbackMetrics; if (doValidation) { callbackMetrics = outLabels.slice().concat(outLabels.map(n => 'val_' + n)); } else { callbackMetrics = outLabels.slice(); } const callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery); const verbose = args.verbose == null ? 1 : args.verbose; const { callbackList, history } = configureCallbacks(callbacks, verbose, args.epochs, null, null, getStepsPerEpoch(dataset, args), null, doValidation, callbackMetrics); callbackList.setModel(model); model.history = history; await callbackList.onTrainBegin(); model.stopTraining_ = false; let epoch = args.initialEpoch == null ? 0 : args.initialEpoch; let dataIterator = await dataset.iterator(); while (epoch < args.epochs) { const epochLogs = {}; await callbackList.onEpochBegin(epoch); let stepsDone = 0; let batchIndex = 0; if (!hasBatchesPerEpoch) { dataIterator = await dataset.iterator(); } while (hasBatchesPerEpoch ? 
stepsDone < args.batchesPerEpoch : true) { const iteratorOut = await dataIterator.next(); if (hasBatchesPerEpoch && iteratorOut.done) { console.warn('You provided `batchesPerEpoch` as ' + `${args.batchesPerEpoch}, ` + 'but your dataset iterator ran out of data after ' + `${stepsDone} batches; ` + 'interrupting training. Make sure that your ' + 'dataset can generate at least `batchesPerEpoch * epochs` ' + 'batches (in this case, ' + `${args.batchesPerEpoch * args.epochs} batches). ` + 'You may need to use the repeat() function when building ' + 'your dataset.'); break; } if (iteratorOut.value != null) { const { xs, ys } = standardizeDataIteratorOutput(model, iteratorOut.value); const batchLogs = {}; batchLogs['batch'] = batchIndex; batchLogs['size'] = xs[0].shape[0]; await callbackList.onBatchBegin(batchIndex, batchLogs); const sampleWeights = []; if (args.classWeight != null) { const standardClassWeights = standardizeClassWeights(args.classWeight, model.outputNames); for (let i = 0; i < standardClassWeights.length; ++i) { sampleWeights.push(await standardizeWeights(ys[i], null, standardClassWeights[i])); } } const ins = xs.concat(ys).concat(sampleWeights); const outs = trainFunction(ins); dispose(ins); for (let i = 0; i < outLabels.length; ++i) { const label = outLabels[i]; const out = outs[i]; batchLogs[label] = out; keep(out); } await callbackList.onBatchEnd(batchIndex, batchLogs); disposeTensorsInLogs(batchLogs); batchIndex++; stepsDone++; } if (hasBatchesPerEpoch ? stepsDone >= args.batchesPerEpoch : iteratorOut.done) { if (doValidation) { let valOuts; if (isDatasetObject(args.validationData)) { valOuts = toList(await model.evaluateDataset(args.validationData, { batches: args.validationBatches })); } else { valOuts = toList(model.evaluate(valXs, valYs, { batchSize: args.validationBatchSize == null ? 
DEFAULT_VALIDATION_BATCH_SIZE : args.validationBatchSize, verbose: 0 })); } for (let i = 0; i < model.metricsNames.length; ++i) { epochLogs[`val_${model.metricsNames[i]}`] = valOuts[i]; } } break; } if (model.stopTraining_) { break; } } await callbackList.onEpochEnd(epoch, epochLogs); epoch++; if (model.stopTraining_) { break; } } await callbackList.onTrainEnd(); await model.history.syncData(); return model.history; } finally { model.isTraining = false; } } function getStepsPerEpoch(dataset, args) { let stepsPerEpoch = null; if (args.batchesPerEpoch != null) { stepsPerEpoch = args.batchesPerEpoch; } else if (Number.isFinite(dataset.size)) { stepsPerEpoch = dataset.size; } return stepsPerEpoch; } function isDatasetObject(dataset) { return (typeof dataset.iterator === 'function'); } function isLazyIteratorObject(iterator) { return (typeof iterator.next === 'function'); } async function evaluateDataset( model, dataset, args) { args = args || {}; const hasBatches = args.batches != null; const f = model.testFunction; let outs = []; if (args.verbose > 0) { throw new NotImplementedError('Verbose mode is not implemented yet.'); } assert$1(!hasBatches || (args.batches > 0 && Number.isInteger(args.batches)), () => 'Test loop expects `batches` to be a positive integer, but ' + `received ${JSON.stringify(args.batches)}`); const dataIterator = isLazyIteratorObject(dataset) ? dataset : await dataset.iterator(); let numExamples = 0; let batch = 0; while (hasBatches ? 
batch < args.batches : true) { const iteratorOut = await dataIterator.next(); outs = tidy(() => { if (iteratorOut.value) { const { xs, ys } = standardizeDataIteratorOutput(model, iteratorOut.value); const xsAndYs = xs.concat(ys); const batchOuts = tidy(() => f(xsAndYs)); dispose(xsAndYs); if (batch === 0) { for (let i = 0; i < batchOuts.length; ++i) { outs.push(scalar(0)); } } const batchSize = xsAndYs[0].shape[0]; for (let i = 0; i < batchOuts.length; ++i) { const batchOut = batchOuts[i]; const oldScalar = outs[i]; outs[i] = tidy(() => add(outs[i], mul(batchSize, batchOut))); if (batch > 0) { dispose(oldScalar); } } dispose(batchOuts); numExamples += batchSize; ++batch; } return outs; }); if (iteratorOut.done) { if (hasBatches) { console.warn('Your dataset iterator ran out of data during evaluateDataset(). ' + 'Interrupting evalution. Make sure that your ' + 'dataset can generate at least `batches` ' + `batches (in this case, ${args.batches} batches). ` + 'You may need to use the repeat() function when building ' + 'your dataset.'); } break; } } for (let i = 0; i < outs.length; ++i) { const oldScalar = outs[i]; outs[i] = div(outs[i], numExamples); dispose(oldScalar); } return singletonOrArray(outs); } function checkBatchSize(batchSize) { assert$1(batchSize > 0 && Number.isInteger(batchSize), () => `batchSize is required to be a positive integer, but got ${batchSize}`); } function sliceArrays(arrays, start, stop) { if (arrays == null) { return [null]; } else if (Array.isArray(arrays)) { return arrays.map(array => sliceAlongFirstAxis(array, start, stop - start)); } else { return sliceAlongFirstAxis(arrays, start, stop - start); } } function sliceArraysByIndices(arrays, indices) { return tidy(() => { if (arrays == null) { return null; } else if (Array.isArray(arrays)) { return arrays.map(array => sliceArraysByIndices(array, indices)); } else { return gather(arrays, indices.dtype === 'int32' ? 
indices : cast$2(indices, 'int32')); } }); } function makeBatches(size, batchSize) { const output = []; let batchStart = 0; let batchEnd = null; while (batchStart < size) { batchEnd = batchStart + batchSize; if (batchEnd >= size) { batchEnd = size; } output.push([batchStart, batchEnd]); batchStart = batchEnd; } return output; } function ensureTensorsRank2OrHigher(tensors) { const outs = []; if (tensors instanceof Tensor) { tensors = [tensors]; } for (let i = 0; i < tensors.length; ++i) { const tensor = tensors[i]; if (tensor.rank === 1) { outs.push(expandDims(tensor, 1)); } else if (tensor.rank === 0) { throw new Error('Expected tensor to be at least 1D, but received a 0D tensor ' + '(scalar).'); } else { outs.push(tensor); } } return outs; } function disposeNewTensors(tensors, refTensors) { if (tensors == null) { return; } const oldTensorIds = []; if (refTensors instanceof Tensor) { oldTensorIds.push(refTensors.id); } else if (Array.isArray(refTensors)) { refTensors.forEach(t => oldTensorIds.push(t.id)); } else if (refTensors != null) { for (const name in refTensors) { const oldTensor = refTensors[name]; oldTensorIds.push(oldTensor.id); } } const tensorsToDispose = []; if (tensors instanceof Tensor) { if (oldTensorIds.indexOf(tensors.id) === -1) { tensorsToDispose.push(tensors); } } else if (Array.isArray(tensors)) { tensors.forEach(t => { if (oldTensorIds.indexOf(t.id) === -1) { tensorsToDispose.push(t); } }); } else if (tensors != null) { for (const name in tensors) { const tensor = tensors[name]; if (oldTensorIds.indexOf(tensor.id) === -1) { tensorsToDispose.push(tensor); } } } tensorsToDispose.forEach(t => { if (!t.isDisposed) { t.dispose(); } }); } function isDataTensor(x) { return x instanceof Tensor; } function isDataArray(x) { return Array.isArray(x); } function isDataDict(x) { return !isDataTensor(x) && !isDataArray(x); } function standardizeInputData(data, names, shapes, checkBatchAxis = true, exceptionPrefix = '') { if (names == null || names.length === 
0) { if (data != null) { let gotUnexpectedData = false; if (isDataArray(data) && data.length > 0) { gotUnexpectedData = true; } else if (isDataDict(data)) { for (const key in data) { if (data.hasOwnProperty(key)) { gotUnexpectedData = true; break; } } } else { gotUnexpectedData = true; } if (gotUnexpectedData) { throw new ValueError(`Error when checking model ${exceptionPrefix} expected no data, ` + `but got ${data}`); } } return []; } if (data == null) { return names.map(name => null); } let arrays; if (isDataDict(data)) { data = data; arrays = []; for (const name of names) { if (data[name] == null) { throw new ValueError(`No data provided for "${name}". Need data for each key in: ` + `${names}`); } arrays.push(data[name]); } } else if (isDataArray(data)) { data = data; if (data.length !== names.length) { throw new ValueError(`Error when checking model ${exceptionPrefix}: the Array of ` + `Tensors that you are passing to your model is not the size the ` + `model expected. Expected to see ${names.length} Tensor(s), but ` + `instead got the following list of Tensor(s): ${data}`); } arrays = data; } else { data = data; if (names.length > 1) { throw new ValueError(`The model ${exceptionPrefix} expects ${names.length} Tensor(s), ` + `but only received one Tensor. Found: Tensor with shape ${data.shape}`); } arrays = [data]; } arrays = ensureTensorsRank2OrHigher(arrays); if (shapes != null) { for (let i = 0; i < names.length; ++i) { if (shapes[i] == null) { continue; } const array = arrays[i]; if (array.shape.length !== shapes[i].length) { throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` + `to have ${shapes[i].length} dimension(s). 
but got array with ` + `shape ${array.shape}`); } for (let j = 0; j < shapes[i].length; ++j) { if (j === 0 && !checkBatchAxis) { continue; } const dim = array.shape[j]; const refDim = shapes[i][j]; if (refDim != null && refDim >= 0 && dim !== refDim) { throw new ValueError(`${exceptionPrefix} expected a batch of elements where each ` + `example has shape [${shapes[i].slice(1, shapes[i].length)}] ` + `(i.e.,tensor shape [*,${shapes[i].slice(1, shapes[i].length)}])` + ` but the ${exceptionPrefix} received an input with ${array.shape[0]}` + ` examples, each with shape [${array.shape.slice(1, array.shape.length)}]` + ` (tensor shape [${array.shape}])`); } } } } return arrays; } function checkArrayLengths(inputs, targets, weights) { const setX = unique(inputs.map(input => input.shape[0])); setX.sort(); const setY = unique(targets.map(target => target.shape[0])); setY.sort(); if (setX.length > 1) { throw new ValueError(`All input Tensors (x) should have the same number of samples. ` + `Got array shapes: ` + `${JSON.stringify(inputs.map(input => input.shape))}`); } if (setY.length > 1) { throw new ValueError(`All target Tensors (y) should have the same number of samples. ` + `Got array shapes: ` + `${JSON.stringify(targets.map(target => target.shape))}`); } if (setX.length > 0 && setY.length > 0 && !arraysEqual(setX, setY)) { throw new ValueError(`Input Tensors should have the same number of samples as target ` + `Tensors. 
Found ${setX[0]} input sample(s) and ${setY[0]} target ` + `sample(s).`); } } function checkLossAndTargetCompatibility(targets, lossFns, outputShapes) { const keyLosses = [ meanSquaredError, binaryCrossentropy$1, categoricalCrossentropy$1 ]; for (let i = 0; i < targets.length; ++i) { const y = targets[i]; const loss = lossFns[i]; const shape = outputShapes[i]; if (loss == null) { continue; } if (loss === categoricalCrossentropy$1) { if (y.shape[y.shape.length - 1] === 1) { throw new ValueError(`You are passing a target array of shape ${y.shape} while using ` + `a loss 'categorical_crossentropy'. 'categorical_crossentropy'` + `expects targets to be binary matrices (1s and 0s) of shape ` + `[samples, classes].`); } } if (keyLosses.indexOf(loss) !== -1) { const slicedYShape = y.shape.slice(1); const slicedShape = shape.slice(1); for (let j = 0; j < slicedYShape.length; ++j) { const targetDim = slicedYShape[j]; const outDim = slicedShape[j]; if (outDim != null && targetDim !== outDim) { throw new ValueError(`A target Tensor with shape ${y.shape} was passed for an ` + `output of shape ${shape}, while using a loss function that ` + `expects targets to have the same shape as the output.`); } } } } } function checkInputData(data, names, shapes, checkBatchAxis = true, exceptionPrefix = '') { let arrays; if (Array.isArray(data)) { if (data.length !== names.length) { throw new ValueError(`Error when checking model ${exceptionPrefix}: the Array of ` + `Tensors that you are passing to your model is not the size the ` + `the model expected. Expected to see ${names.length} Tensor(s),` + ` but instead got ${data.length} Tensors(s).`); } arrays = data; } else { if (names.length > 1) { throw new ValueError(`The model expects ${names.length} ${exceptionPrefix} Tensors, ` + `but only received one Tensor. 
Found: array with shape ` + `${JSON.stringify(data.shape)}.`); } arrays = [data]; } if (shapes != null) { for (let i = 0; i < names.length; ++i) { if (shapes[i] == null) { continue; } const array = arrays[i]; if (array.shape.length !== shapes[i].length) { throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` + `to have ${shapes[i].length} dimension(s), but got array with ` + `shape ${JSON.stringify(array.shape)}`); } for (let j = 0; j < shapes[i].length; ++j) { if (j === 0 && !checkBatchAxis) { continue; } const dim = array.shape[j]; const refDim = shapes[i][j]; if (refDim != null) { if (refDim !== dim) { throw new ValueError(`Error when checking ${exceptionPrefix}: expected ` + `${names[i]} to have shape ${JSON.stringify(shapes[i])} but ` + `got array with shape ${JSON.stringify(array.shape)}.`); } } } } } } function collectMetrics(metrics, outputNames) { if (metrics == null || Array.isArray(metrics) && metrics.length === 0) { return outputNames.map(name => []); } let wrappedMetrics; if (typeof metrics === 'string' || typeof metrics === 'function') { wrappedMetrics = [metrics]; } else if (Array.isArray(metrics) || typeof metrics === 'object') { wrappedMetrics = metrics; } else { throw new TypeError('Type of metrics argument not understood. Expected an string,' + `function, Array, or Object, found: ${metrics}`); } if (Array.isArray(wrappedMetrics)) { return outputNames.map(name => wrappedMetrics); } else { const nestedMetrics = []; for (const name of outputNames) { let outputMetrics = wrappedMetrics.hasOwnProperty(name) ? 
wrappedMetrics[name] : []; if (!Array.isArray(outputMetrics)) { outputMetrics = [outputMetrics]; } nestedMetrics.push(outputMetrics); } return nestedMetrics; } } const LAYERS_MODEL_FORMAT_NAME = 'layers-model'; class LayersModel extends Container { constructor(args) { super(args); this.isTraining = false; } summary(lineLength, positions, printFn = console.log) { if (!this.built) { throw new ValueError(`This model has never been called, thus its weights have not been ` + `created yet. So no summary can be displayed. Build the model ` + `first (e.g., by calling it on some test data).`); } printSummary(this, lineLength, positions, printFn); } compile(args) { if (args.loss == null) { args.loss = []; } this.loss = args.loss; if (typeof args.optimizer === 'string') { this.optimizer_ = getOptimizer(args.optimizer); this.isOptimizerOwned = true; } else { if (!(args.optimizer instanceof Optimizer)) { throw new ValueError(`User-defined optimizer must be an instance of tf.Optimizer.`); } this.optimizer_ = args.optimizer; this.isOptimizerOwned = false; } let lossFunctions = []; if (!Array.isArray(args.loss) && typeof args.loss !== 'string' && typeof args.loss !== 'function') { args.loss = args.loss; for (const name in args.loss) { if (this.outputNames.indexOf(name) === -1) { throw new ValueError(`Unknown entry in loss dictionary: "${name}". ` + `Only expected the following keys: ${this.outputNames}`); } } for (const name of this.outputNames) { if (args.loss[name] == null) { console.warn(`Output "${name}" is missing from loss dictionary. We assume ` + `this was done on purpose, and we will not be expecting data ` + `to be passed to ${name} during training`); } lossFunctions.push(get$1(args.loss[name])); } } else if (Array.isArray(args.loss)) { if (args.loss.length !== this.outputs.length) { throw new ValueError(`When passing an Array as loss, it should have one entry per ` + `model output. 
The model has ${this.outputs.length} output(s), ` + `but you passed loss=${args.loss}.`); } const theLosses = args.loss; lossFunctions = theLosses.map(l => get$1(l)); } else { const lossFunction = get$1(args.loss); this.outputs.forEach(_ => { lossFunctions.push(lossFunction); }); } this.lossFunctions = lossFunctions; this.feedOutputNames = []; this.feedOutputShapes = []; this.feedLossFns = []; for (let i = 0; i < this.outputs.length; ++i) { const shape = this.internalOutputShapes[i]; const name = this.outputNames[i]; this.feedOutputNames.push(name); this.feedOutputShapes.push(shape); this.feedLossFns.push(this.lossFunctions[i]); } const skipTargetIndices = []; this.metrics = args.metrics; this.metricsNames = ['loss']; this.metricsTensors = []; nameScope('loss', () => { for (let i = 0; i < this.outputs.length; ++i) { if (skipTargetIndices.indexOf(i) !== -1) { continue; } const weightedLoss = this.lossFunctions[i]; if (this.outputs.length > 1) { this.metricsTensors.push([weightedLoss, i]); this.metricsNames.push(this.outputNames[i] + '_loss'); } } }); const nestedMetrics = collectMetrics(args.metrics, this.outputNames); const appendMetric = (outputIndex, metricName, metricTensor) => { if (this.outputNames.length > 1) { metricName = this.outputNames[outputIndex] + '_' + metricName; } this.metricsNames.push(metricName); this.metricsTensors.push([metricTensor, outputIndex]); }; nameScope('metric', () => { for (let i = 0; i < this.outputs.length; ++i) { if (skipTargetIndices.indexOf(i) !== -1) { continue; } const outputMetrics = nestedMetrics[i]; const handleMetrics = (metrics) => { const metricNamePrefix = ''; let metricName; let accFn; let weightedMetricFn; for (const metric of metrics) { if (typeof metric === 'string' && ['accuracy', 'acc', 'crossentropy', 'ce'].indexOf(metric) !== -1) { const outputShape = this.internalOutputShapes[i]; if (outputShape[outputShape.length - 1] === 1 || this.lossFunctions[i] === binaryCrossentropy$1) { if (['accuracy', 
'acc'].indexOf(metric) !== -1) { accFn = binaryAccuracy; } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) { accFn = binaryCrossentropy; } } else if (this.lossFunctions[i] === sparseCategoricalCrossentropy$1) { if (['accuracy', 'acc'].indexOf(metric) !== -1) { accFn = sparseCategoricalAccuracy; } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) { accFn = sparseCategoricalCrossentropy; } } else { if (['accuracy', 'acc'].indexOf(metric) !== -1) { accFn = categoricalAccuracy; } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) { accFn = categoricalCrossentropy; } } let suffix; if (['accuracy', 'acc'].indexOf(metric) !== -1) { suffix = 'acc'; } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) { suffix = 'ce'; } weightedMetricFn = accFn; metricName = metricNamePrefix + suffix; } else { const metricFn = get(metric); weightedMetricFn = metricFn; metricName = metricNamePrefix + getLossOrMetricName(metric); } let metricResult; nameScope(metricName, () => { metricResult = weightedMetricFn; }); appendMetric(i, metricName, metricResult); } }; handleMetrics(outputMetrics); } }); this.collectedTrainableWeights = this.trainableWeights; } checkTrainableWeightsConsistency() { if (this.collectedTrainableWeights == null) { return; } if (this.trainableWeights.length !== this.collectedTrainableWeights.length) { console.warn('Discrepancy between trainableweights and collected trainable ' + 'weights. Did you set `model.trainable` without calling ' + '`model.compile()` afterwards?'); } } evaluate(x, y, args = {}) { const batchSize = args.batchSize == null ? 
32 : args.batchSize; checkBatchSize(batchSize); const checkBatchAxis = true; const standardizedOuts = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize); try { const ins = standardizedOuts[0].concat(standardizedOuts[1]); this.makeTestFunction(); const f = this.testFunction; const testOuts = this.testLoop(f, ins, batchSize, args.verbose, args.steps); return singletonOrArray(testOuts); } finally { disposeNewTensors(standardizedOuts[0], x); disposeNewTensors(standardizedOuts[1], y); } } async evaluateDataset(dataset, args) { this.makeTestFunction(); return evaluateDataset(this, dataset, args); } checkNumSamples(ins, batchSize, steps, stepsName = 'steps') { let numSamples; if (steps != null) { numSamples = null; if (batchSize != null) { throw new ValueError(`If ${stepsName} is set, batchSize must be null or undefined.` + `Got batchSize = ${batchSize}`); } } else if (ins != null) { if (Array.isArray(ins)) { numSamples = ins[0].shape[0]; } else { numSamples = ins.shape[0]; } } else { throw new ValueError(`Either the input data should have a defined shape, or ` + `${stepsName} shoud be specified.`); } return numSamples; } execute(inputs, outputs) { if (Array.isArray(outputs) && outputs.length === 0) { throw new ValueError('`outputs` is an empty Array, which is not allowed.'); } const outputsIsArray = Array.isArray(outputs); const outputNames = (outputsIsArray ? 
outputs : [outputs]); const outputSymbolicTensors = this.retrieveSymbolicTensors(outputNames); const feedDict = new FeedDict(); if (inputs instanceof Tensor) { inputs = [inputs]; } if (Array.isArray(inputs)) { if (inputs.length !== this.inputs.length) { throw new ValueError(`The number of inputs provided (${inputs.length}) ` + `does not match the number of inputs of this model ` + `(${this.inputs.length}).`); } for (let i = 0; i < this.inputs.length; ++i) { feedDict.add(this.inputs[i], inputs[i]); } } else { for (const input of this.inputs) { const tensorValue = inputs[input.name]; if (tensorValue == null) { throw new ValueError(`No value is provided for the model's input ${input.name}`); } feedDict.add(input, tensorValue); } } const executeOutputs = execute(outputSymbolicTensors, feedDict); return outputsIsArray ? executeOutputs : executeOutputs[0]; } retrieveSymbolicTensors(symbolicTensorNames) { const outputSymbolicTensors = pyListRepeat(null, symbolicTensorNames.length); let outputsRemaining = symbolicTensorNames.length; for (const layer of this.layers) { const layerOutputs = Array.isArray(layer.output) ? 
layer.output : [layer.output]; const layerOutputNames = layerOutputs.map(output => output.name); for (let i = 0; i < symbolicTensorNames.length; ++i) { const index = layerOutputNames.indexOf(symbolicTensorNames[i]); if (index !== -1) { outputSymbolicTensors[i] = layerOutputs[index]; outputsRemaining--; } if (outputsRemaining === 0) { break; } } if (outputsRemaining === 0) { break; } } if (outputsRemaining > 0) { const remainingNames = []; outputSymbolicTensors.forEach((tensor, i) => { if (tensor == null) { remainingNames.push(symbolicTensorNames[i]); } }); throw new ValueError(`Cannot find SymbolicTensors for output name(s): ` + `${JSON.stringify(remainingNames)}`); } return outputSymbolicTensors; } predictLoop(ins, batchSize = 32, verbose = false) { return tidy(() => { const numSamples = this.checkNumSamples(ins); if (verbose) { throw new NotImplementedError('Verbose predictLoop() is not implemented yet.'); } const batches = makeBatches(numSamples, batchSize); const outsBatches = this.outputs.map(output => []); for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) { const batchOuts = tidy(() => { const batchStart = batches[batchIndex][0]; const batchEnd = batches[batchIndex][1]; const insBatch = sliceArrays(ins, batchStart, batchEnd); const feeds = []; if (Array.isArray(insBatch)) { for (let i = 0; i < insBatch.length; ++i) { feeds.push({ key: this.inputs[i], value: insBatch[i] }); } } else { feeds.push({ key: this.inputs[0], value: insBatch }); } const feedDict = new FeedDict(feeds); return execute(this.outputs, feedDict); }); batchOuts.forEach((batchOut, i) => outsBatches[i].push(batchOut)); } return singletonOrArray(outsBatches.map(batches => concat$1(batches, 0))); }); } predict(x, args = {}) { const xsRank2OrHigher = ensureTensorsRank2OrHigher(x); checkInputData(xsRank2OrHigher, this.inputNames, this.feedInputShapes, false); try { const batchSize = args.batchSize == null ? 
32 : args.batchSize; checkBatchSize(batchSize); return this.predictLoop(xsRank2OrHigher, batchSize); } finally { disposeNewTensors(xsRank2OrHigher, x); } } predictOnBatch(x) { checkInputData(x, this.inputNames, this.feedInputShapes, true); const batchSize = (Array.isArray(x) ? x[0] : x).shape[0]; return this.predictLoop(x, batchSize); } standardizeUserDataXY(x, y, checkBatchAxis = true, batchSize) { if (this.optimizer_ == null) { throw new RuntimeError('You must compile a model before training/testing. Use ' + 'LayersModel.compile(modelCompileArgs).'); } const outputShapes = []; for (let i = 0; i < this.feedOutputShapes.length; ++i) { const outputShape = this.feedOutputShapes[i]; const lossFn = this.feedLossFns[i]; if (lossFn === sparseCategoricalCrossentropy$1) { outputShapes.push(outputShape.slice(0, outputShape.length - 1).concat([1])); } else { outputShapes.push(outputShape); } } x = standardizeInputData(x, this.feedInputNames, this.feedInputShapes, false, 'input'); y = standardizeInputData(y, this.feedOutputNames, outputShapes, false, 'target'); checkArrayLengths(x, y); checkLossAndTargetCompatibility(y, this.feedLossFns, this.feedOutputShapes); if (this.stateful && batchSize != null && batchSize > 0) { if (x[0].shape[0] % batchSize !== 0) { throw new ValueError(`In a stateful network, you should only pass inputs with a ` + `number of samples that is divisible by the batch size ` + `${batchSize}. 
Found: ${x[0].shape[0]} sample(s).`); } } return [x, y]; } async standardizeUserData(x, y, sampleWeight, classWeight, checkBatchAxis = true, batchSize) { const [standardXs, standardYs] = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize); if (sampleWeight != null) { throw new Error('sample weight is not supported yet.'); } let standardSampleWeights = null; if (classWeight != null) { const classWeights = standardizeClassWeights(classWeight, this.outputNames); standardSampleWeights = []; for (let i = 0; i < classWeights.length; ++i) { standardSampleWeights.push(await standardizeWeights(standardYs[i], null, classWeights[i])); } } return [standardXs, standardYs, standardSampleWeights]; } testLoop(f, ins, batchSize, verbose = 0, steps) { return tidy(() => { const numSamples = this.checkNumSamples(ins, batchSize, steps, 'steps'); const outs = []; if (verbose > 0) { throw new NotImplementedError('Verbose mode is not implemented yet.'); } if (steps != null) { throw new NotImplementedError('steps mode in testLoop() is not implemented yet'); } else { const batches = makeBatches(numSamples, batchSize); const indexArray = tensor1d(range(0, numSamples)); for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) { const batchStart = batches[batchIndex][0]; const batchEnd = batches[batchIndex][1]; const batchIds = sliceAlongFirstAxis(indexArray, batchStart, batchEnd - batchStart); const insBatch = sliceArraysByIndices(ins, batchIds); const batchOuts = f(insBatch); if (batchIndex === 0) { for (let i = 0; i < batchOuts.length; ++i) { outs.push(scalar(0)); } } for (let i = 0; i < batchOuts.length; ++i) { const batchOut = batchOuts[i]; outs[i] = add(outs[i], mul(batchEnd - batchStart, batchOut)); } } for (let i = 0; i < outs.length; ++i) { outs[i] = div(outs[i], numSamples); } } return outs; }); } getDedupedMetricsNames() { const outLabels = this.metricsNames; const dedupedOutLabels = []; for (let i = 0; i < outLabels.length; ++i) { const label = outLabels[i]; let 
newLabel = label; if (count(outLabels, label) > 1) { const dupIndex = count(outLabels.slice(0, i), label); newLabel += `_${dupIndex}`; } dedupedOutLabels.push(newLabel); } return dedupedOutLabels; } makeTrainFunction() { return (data) => { const lossValues = []; const inputs = data.slice(0, this.inputs.length); const targets = data.slice(this.inputs.length, this.inputs.length + this.outputs.length); const sampleWeights = data.slice(this.inputs.length + this.outputs.length, this.inputs.length + this.outputs.length * 2); const metricsValues = []; const totalLossFunction = () => { const feeds = []; for (let i = 0; i < this.inputs.length; ++i) { feeds.push({ key: this.inputs[i], value: inputs[i] }); } const feedDict = new FeedDict(feeds); const outputs = execute(this.outputs, feedDict, { 'training': true }); let totalLoss; for (let i = 0; i < this.lossFunctions.length; ++i) { const lossFunction = this.lossFunctions[i]; let loss = lossFunction(targets[i], outputs[i]); if (sampleWeights[i] != null) { loss = computeWeightedLoss(loss, sampleWeights[i]); } const meanLoss = mean(loss); lossValues.push(meanLoss); if (i === 0) { totalLoss = loss; } else { totalLoss = add(totalLoss, loss); } } for (let i = 0; i < this.metricsTensors.length; ++i) { let weightedMetric; if (this.outputs.length > 1 && i < this.outputs.length) { weightedMetric = lossValues[i]; } else { const metric = this.metricsTensors[i][0]; const outputIndex = this.metricsTensors[i][1]; weightedMetric = mean(metric(targets[outputIndex], outputs[outputIndex])); } keep(weightedMetric); metricsValues.push(weightedMetric); } totalLoss = mean(totalLoss); this.calculateLosses().forEach(regularizerLoss => { totalLoss = add(totalLoss, regularizerLoss); }); return totalLoss; }; const variables = this.collectedTrainableWeights.map(param => param.read()); const returnCost = true; const totalLossValue = this.optimizer_.minimize(totalLossFunction, returnCost, variables); return [totalLossValue].concat(metricsValues); }; } 
makeTestFunction() { this.testFunction = (data) => { return tidy(() => { const valOutputs = []; let totalLoss; const inputs = data.slice(0, this.inputs.length); const targets = data.slice(this.inputs.length, this.inputs.length + this.outputs.length); const feeds = []; for (let i = 0; i < this.inputs.length; ++i) { feeds.push({ key: this.inputs[i], value: inputs[i] }); } const feedDict = new FeedDict(feeds); const outputs = execute(this.outputs, feedDict); for (let i = 0; i < this.lossFunctions.length; ++i) { const lossFunction = this.lossFunctions[i]; const loss = mean(lossFunction(targets[i], outputs[i])); if (i === 0) { totalLoss = loss; } else { totalLoss = add(totalLoss, loss); } valOutputs.push(totalLoss); } for (let i = 0; i < this.metricsTensors.length; ++i) { const metric = this.metricsTensors[i][0]; const outputIndex = this.metricsTensors[i][1]; const meanMetric = mean(metric(targets[outputIndex], outputs[outputIndex])); valOutputs.push(meanMetric); } return valOutputs; }); }; } async fit(x, y, args = {}) { if (this.isTraining) { throw new Error('Cannot start training because another fit() call is ongoing.'); } this.isTraining = true; let inputs; let targets; let originalInputs; let originalTargets; let inputValX; let inputValY; let valX; let valY; let sampleWeights; try { const batchSize = args.batchSize == null ? 
32 : args.batchSize; checkBatchSize(batchSize); const checkBatchAxis = false; const standardizedOuts = await this.standardizeUserData(x, y, args.sampleWeight, args.classWeight, checkBatchAxis, batchSize); inputs = standardizedOuts[0]; targets = standardizedOuts[1]; sampleWeights = standardizedOuts[2]; let doValidation = false; let valIns; if (args.validationData != null && args.validationData.length > 0) { doValidation = true; if (args.validationData.length === 2) { inputValX = args.validationData[0]; inputValY = args.validationData[1]; } else if (args.validationData.length === 3) { throw new NotImplementedError('validationData including sample weights is not supported yet.'); } else { throw new ValueError(`When passing validation data, it must contain 2 (valX, valY) ` + `or 3 (valX, valY, valSampleWeight) items; ` + `${args.validationData} is invalid.`); } const checkBatchAxis = true; const valStandardized = await this.standardizeUserData(inputValX, inputValY, null, null, checkBatchAxis, batchSize); valX = valStandardized[0]; valY = valStandardized[1]; valIns = valX.concat(valY); } else if (args.validationSplit != null && args.validationSplit > 0 && args.validationSplit < 1) { doValidation = true; const splitAt = Math.floor(inputs[0].shape[0] * (1 - args.validationSplit)); const originalBatchSize = inputs[0].shape[0]; valX = sliceArrays(inputs, splitAt, originalBatchSize); originalInputs = inputs; inputs = sliceArrays(inputs, 0, splitAt); valY = sliceArrays(targets, splitAt, originalBatchSize); originalTargets = targets; targets = sliceArrays(targets, 0, splitAt); valIns = valX.concat(valY); } else if (args.validationSteps != null) { doValidation = true; } const ins = inputs.concat(targets).concat(sampleWeights); this.checkTrainableWeightsConsistency(); const trainFunction = this.makeTrainFunction(); const outLabels = this.getDedupedMetricsNames(); let valFunction; let callbackMetrics; if (doValidation) { this.makeTestFunction(); valFunction = this.testFunction; 
callbackMetrics = outLabels.slice().concat(outLabels.map(n => 'val_' + n)); } else { valFunction = null; valIns = []; callbackMetrics = outLabels.slice(); } const callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery); const out = await this.fitLoop(trainFunction, ins, outLabels, batchSize, args.epochs, args.verbose, callbacks, valFunction, valIns, args.shuffle, callbackMetrics, args.initialEpoch, null, null); return out; } finally { this.isTraining = false; disposeNewTensors(inputs, x); disposeNewTensors(targets, y); disposeNewTensors(originalInputs, x); disposeNewTensors(originalTargets, y); disposeNewTensors(valX, inputValX); disposeNewTensors(valY, inputValY); if (sampleWeights != null) { dispose(sampleWeights); } } } async fitLoop(f, ins, outLabels, batchSize, epochs, verbose, callbacks, valF, valIns, shuffle$1, callbackMetrics, initialEpoch, stepsPerEpoch, validationSteps) { if (batchSize == null) { batchSize = 32; } if (epochs == null) { epochs = 1; } if (shuffle$1 == null) { shuffle$1 = true; } if (initialEpoch == null) { initialEpoch = 0; } let doValidation = false; if (valF != null && valIns != null) { doValidation = true; } if (validationSteps != null) { doValidation = true; if (stepsPerEpoch == null) { throw new ValueError('Can only use `validationSteps` when doing step-wise training, ' + 'i.e., `stepsPerEpoch` must be set.'); } } const numTrainSamples = this.checkNumSamples(ins, batchSize, stepsPerEpoch, 'steps_per_epoch'); let indexArray; if (numTrainSamples != null) { indexArray = range(0, numTrainSamples); } if (verbose == null) { verbose = 1; } const { callbackList, history } = configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics); callbackList.setModel(this); this.history = history; await callbackList.onTrainBegin(); this.stopTraining_ = false; for (let epoch = initialEpoch; epoch < epochs; ++epoch) { await callbackList.onEpochBegin(epoch); const epochLogs 
= {}; if (stepsPerEpoch != null) { throw new NotImplementedError('stepsPerEpoch mode is not implemented yet.'); } else { if (shuffle$1 === 'batch') { throw new NotImplementedError('batch shuffling is not implemneted' + ' yet'); } else if (shuffle$1) { shuffle(indexArray); } const epochIndexArray1D = tensor1d(indexArray); const batches = makeBatches(numTrainSamples, batchSize); for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) { const batchLogs = {}; await callbackList.onBatchBegin(batchIndex, batchLogs); tidy(() => { const batchStart = batches[batchIndex][0]; const batchEnd = batches[batchIndex][1]; const batchIds = sliceAlongFirstAxis(epochIndexArray1D, batchStart, batchEnd - batchStart); batchLogs['batch'] = batchIndex; batchLogs['size'] = batchEnd - batchStart; const insBatch = sliceArraysByIndices(ins, batchIds); const outs = f(insBatch); for (let i = 0; i < outLabels.length; ++i) { const label = outLabels[i]; const out = outs[i]; batchLogs[label] = out; keep(out); } if (batchIndex === batches.length - 1) { if (doValidation) { const valOuts = this.testLoop(valF, valIns, batchSize); for (let i = 0; i < outLabels.length; ++i) { const label = outLabels[i]; const out = valOuts[i]; keep(out); epochLogs['val_' + label] = out; } } } }); await callbackList.onBatchEnd(batchIndex, batchLogs); disposeTensorsInLogs(batchLogs); if (this.stopTraining_) { break; } } epochIndexArray1D.dispose(); } await callbackList.onEpochEnd(epoch, epochLogs); if (this.stopTraining_) { break; } } await callbackList.onTrainEnd(); await this.history.syncData(); return this.history; } async fitDataset(dataset, args) { return fitDataset(this, dataset, args); } async trainOnBatch(x, y) { const standardizeOut = await this.standardizeUserData(x, y); const inputs = standardizeOut[0]; const targets = standardizeOut[1]; const trainFunction = this.makeTrainFunction(); const losses = trainFunction(inputs.concat(targets)); const lossValues = []; for (const loss of losses) { const v = 
await loss.data(); lossValues.push(v[0]); } dispose(losses); disposeNewTensors(standardizeOut[0], x); disposeNewTensors(standardizeOut[1], y); return singletonOrArray(lossValues); } getNamedWeights(config) { const namedWeights = []; const trainableOnly = config != null && config.trainableOnly; const weights = trainableOnly ? this.trainableWeights : this.weights; const weightValues = this.getWeights(trainableOnly); for (let i = 0; i < weights.length; ++i) { if (trainableOnly && !weights[i].trainable) { continue; } namedWeights.push({ name: weights[i].originalName, tensor: weightValues[i] }); } return namedWeights; } set stopTraining(stop) { this.stopTraining_ = stop; } get stopTraining() { return this.stopTraining_; } get optimizer() { return this.optimizer_; } set optimizer(optimizer) { if (this.optimizer_ !== optimizer) { this.optimizer_ = optimizer; this.isOptimizerOwned = false; } } dispose() { const result = super.dispose(); if (result.refCountAfterDispose === 0 && this.optimizer != null && this.isOptimizerOwned) { const numTensorsBeforeOptmizerDisposal = memory().numTensors; this.optimizer_.dispose(); result.numDisposedVariables += numTensorsBeforeOptmizerDisposal - memory().numTensors; } return result; } getLossIdentifiers() { let lossNames; if (typeof this.loss === 'string') { lossNames = toSnakeCase(this.loss); } else if (Array.isArray(this.loss)) { for (const loss of this.loss) { if (typeof loss !== 'string') { throw new Error('Serialization of non-string loss is not supported.'); } } lossNames = this.loss.map(name => toSnakeCase(name)); } else { const outputNames = Object.keys(this.loss); lossNames = {}; const losses = this.loss; for (const outputName of outputNames) { if (typeof losses[outputName] === 'string') { lossNames[outputName] = toSnakeCase(losses[outputName]); } else { throw new Error('Serialization of non-string loss is not supported.'); } } } return lossNames; } getMetricIdentifiers() { if (typeof this.metrics === 'string' || typeof 
this.metrics === 'function') { return [toSnakeCase(getLossOrMetricName(this.metrics))]; } else if (Array.isArray(this.metrics)) { return this.metrics.map(metric => toSnakeCase(getLossOrMetricName(metric))); } else { const metricsIdentifiers = {}; for (const key in this.metrics) { metricsIdentifiers[key] = toSnakeCase(getLossOrMetricName(this.metrics[key])); } return metricsIdentifiers; } } getTrainingConfig() { return { loss: this.getLossIdentifiers(), metrics: this.getMetricIdentifiers(), optimizer_config: { class_name: this.optimizer.getClassName(), config: this.optimizer.getConfig() } }; } loadTrainingConfig(trainingConfig) { if (trainingConfig.weighted_metrics != null) { throw new Error('Loading weight_metrics is not supported yet.'); } if (trainingConfig.loss_weights != null) { throw new Error('Loading loss_weights is not supported yet.'); } if (trainingConfig.sample_weight_mode != null) { throw new Error('Loading sample_weight_mode is not supported yet.'); } const tsConfig = convertPythonicToTs(trainingConfig.optimizer_config); const optimizer = deserialize(tsConfig); let loss; if (typeof trainingConfig.loss === 'string') { loss = toCamelCase(trainingConfig.loss); } else if (Array.isArray(trainingConfig.loss)) { loss = trainingConfig.loss.map(lossEntry => toCamelCase(lossEntry)); } else if (trainingConfig.loss != null) { loss = {}; for (const key in trainingConfig.loss) { loss[key] = toCamelCase(trainingConfig.loss[key]); } } let metrics; if (Array.isArray(trainingConfig.metrics)) { metrics = trainingConfig.metrics.map(metric => toCamelCase(metric)); } else if (trainingConfig.metrics != null) { metrics = {}; for (const key in trainingConfig.metrics) { metrics[key] = toCamelCase(trainingConfig.metrics[key]); } } this.compile({ loss, metrics, optimizer }); } async save(handlerOrURL, config) { if (typeof handlerOrURL === 'string') { const handlers = getSaveHandlers(handlerOrURL); if (handlers.length === 0) { throw new ValueError(`Cannot find any save handlers 
for URL '${handlerOrURL}'`); } else if (handlers.length > 1) { throw new ValueError(`Found more than one (${handlers.length}) save handlers for ` + `URL '${handlerOrURL}'`); } handlerOrURL = handlers[0]; } if (handlerOrURL.save == null) { throw new ValueError('LayersModel.save() cannot proceed because the IOHandler ' + 'provided does not have the `save` attribute defined.'); } const weightDataAndSpecs = await encodeWeights(this.getNamedWeights(config)); const returnString = false; const unusedArg = null; const modelConfig = this.toJSON(unusedArg, returnString); const modelArtifacts = { modelTopology: modelConfig, format: LAYERS_MODEL_FORMAT_NAME, generatedBy: `TensorFlow.js tfjs-layers v${version}`, convertedBy: null, }; const includeOptimizer = config == null ? false : config.includeOptimizer; if (includeOptimizer && this.optimizer != null) { modelArtifacts.trainingConfig = this.getTrainingConfig(); const weightType = 'optimizer'; const { data: optimizerWeightData, specs: optimizerWeightSpecs } = await encodeWeights(await this.optimizer.getWeights(), weightType); weightDataAndSpecs.specs.push(...optimizerWeightSpecs); weightDataAndSpecs.data = concatenateArrayBuffers([weightDataAndSpecs.data, optimizerWeightData]); } if (this.userDefinedMetadata != null) { const checkSize = true; checkUserDefinedMetadata(this.userDefinedMetadata, this.name, checkSize); modelArtifacts.userDefinedMetadata = this.userDefinedMetadata; } modelArtifacts.weightData = weightDataAndSpecs.data; modelArtifacts.weightSpecs = weightDataAndSpecs.specs; return handlerOrURL.save(modelArtifacts); } setUserDefinedMetadata(userDefinedMetadata) { checkUserDefinedMetadata(userDefinedMetadata, this.name); this.userDefinedMetadata = userDefinedMetadata; } getUserDefinedMetadata() { return this.userDefinedMetadata; } } LayersModel.className = 'Model'; registerClass(LayersModel); class Functional extends LayersModel { } Functional.className = 'Functional'; registerClass(Functional); async function 
loadLayersModelFromIOHandler(handler, customObjects, options) { if (options == null) { options = {}; } if (handler.load == null) { throw new ValueError('Cannot proceed with model loading because the IOHandler provided ' + 'does not have the `load` method implemented.'); } const artifacts = await handler.load(); let modelTopology = artifacts.modelTopology; if (modelTopology['model_config'] != null) { modelTopology = modelTopology['model_config']; } const strict = options.strict == null ? true : options.strict; const fastWeightInit = artifacts.weightData != null && artifacts.weightSpecs != null && strict; const model = deserialize(convertPythonicToTs(modelTopology), customObjects, fastWeightInit); const trainingConfig = artifacts.trainingConfig; if (trainingConfig != null) { model.loadTrainingConfig(trainingConfig); } if (artifacts.userDefinedMetadata != null) { model.setUserDefinedMetadata(artifacts.userDefinedMetadata); } if (artifacts.weightData != null) { if (artifacts.weightSpecs == null) { throw new ValueError('LayersModel artifacts contains weight data, but not weight specs. 
' + 'Therefore loading of weights cannot proceed.'); } const { modelWeights, optimizerWeights } = decodeModelAndOptimizerWeights(artifacts.weightData, artifacts.weightSpecs); model.loadWeights(modelWeights, strict); if (model.optimizer != null && optimizerWeights.length > 0) { await model.optimizer.setWeights(optimizerWeights); } dispose(modelWeights); dispose(optimizerWeights.map(w => w.tensor)); } return model; } function decodeModelAndOptimizerWeights(weightData, specs) { const name2Tensor = decodeWeights(weightData, specs); const modelWeights = {}; const optimizerWeights = []; specs.forEach(spec => { if (spec.group === 'optimizer') { optimizerWeights.push({ name: spec.name, tensor: name2Tensor[spec.name] }); } else { modelWeights[spec.name] = name2Tensor[spec.name]; } }); return { modelWeights, optimizerWeights }; } class Sequential extends LayersModel { constructor(args) { super({ inputs: [], outputs: [] }); args = args || {}; this.trainable = true; this.built = false; this.name = (args.name != null) ? args.name : getUid('sequential_'); if (args.layers != null) { for (const layer of args.layers) { this.add(layer); } } } checkShape(layer) { const shape = layer.inboundNodes[0].outputTensors[0].shape; if (shape.some(x => x < 0)) { throw new ValueError('Negative dimension size caused by adding layer ' + `${layer.name} with input shape [` + `${layer.inboundNodes[0].inputTensors[0].shape}]`); } } add(layer) { const isLayerModelInstance = layer instanceof Sequential || layer instanceof LayersModel; let modelLayer; if (isLayerModelInstance) { modelLayer = layer; if (modelLayer.outputs.length !== 1) { throw new ValueError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.'); } if (modelLayer.inputs.length !== 1) { throw new ValueError('All layers in a Sequential model ' + 'should have a single input tensor. 
' + 'For multi-input layers, ' + 'use the functional API.'); } } if (this.outputs.length === 0) { if (layer.inboundNodes.length === 0) { if (layer.batchInputShape == null) { throw new ValueError('The first layer in a Sequential model must ' + 'get an `inputShape` or `batchInputShape` argument.'); } const x = Input({ batchShape: layer.batchInputShape, dtype: layer.dtype, name: layer.name + '_input' }); layer.apply(x); } if (isLayerModelInstance) { this.outputs = modelLayer.outputs; this.inputs = modelLayer.inputs; } else { if (layer.inboundNodes.length !== 1) { throw new ValueError('A layer added to a Sequential model must not already be ' + `connected somewhere else. LayersModel received layer ${layer.name} ` + `which has ${layer.inboundNodes.length} pre-existing inbound ` + 'connections.'); } if (layer.inboundNodes[0].outputTensors.length !== 1) { throw new ValueError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.'); } this.checkShape(layer); this.outputs = [layer.inboundNodes[0].outputTensors[0]]; this.inputs = getSourceInputs(this.outputs[0]); } this.inboundNodes = []; new Node({ outboundLayer: this, inboundLayers: [], nodeIndices: [], tensorIndices: [], inputTensors: this.inputs, outputTensors: this.outputs, inputMasks: pyListRepeat(null, this.inputs.length), outputMasks: [null], inputShapes: this.inputs.map(x => x.shape), outputShapes: this.outputs[0].shape }); } else { const outputTensor = layer.apply(this.outputs[0]); if (Array.isArray(outputTensor)) { throw new TypeError('All layers in a Sequential model ' + 'should have a single output tensor. 
' + 'For multi-output layers, ' + 'use the functional API.'); } this.checkShape(layer); this.outputs = [outputTensor]; this.inboundNodes[0].outputTensors = this.outputs; this.inboundNodes[0].outputShapes = [this.outputs[0].shape]; } this.layers.push(layer); this.built = false; } pop() { if (this.layers.length === 0) { throw new TypeError('There are no layers in the model.'); } this.layers.pop(); if (this.layers.length === 0) { this.outputs = []; this.inboundNodes = []; this.outboundNodes = []; } else { const lastLayerIndex = this.layers.length - 1; this.layers[lastLayerIndex].outboundNodes = []; this.outputs = [this.layers[lastLayerIndex].output]; this.inboundNodes[0].outputTensors = this.outputs; this.inboundNodes[0].outputShapes = [this.outputs[0].shape]; } } call(inputs, kwargs) { if (this.model == null) { this.build(); } return this.model.call(inputs, kwargs); } build(inputShape) { getExactlyOneShape(inputShape); if (this.inputs.length === 0 || this.outputs.length === 0) { throw new TypeError('Sequential model cannot be built: model is empty.' 
+ ' Add some layers first.'); } this.model = new LayersModel({ inputs: this.inputs, outputs: this.outputs[0], name: this.name + '_model' }); this.model.trainable = this.trainable; this.supportsMasking = this.model.supportsMasking; this.inputLayers = this.model.inputLayers; this.inputLayersNodeIndices = this.model.inputLayersNodeIndices; this.inputLayersTensorIndices = this.model.inputLayersTensorIndices; this.outputLayers = this.model.outputLayers; this.outputLayersNodeIndices = this.model.outputLayersNodeIndices; this.outputLayersTensorIndices = this.model.outputLayersTensorIndices; this.nodesByDepth = this.model.nodesByDepth; this.containerNodes = this.model.containerNodes; this.outputNames = this.model.outputNames; this.inputNames = this.model.inputNames; this.built = true; } countParams() { if (!this.built) { this.build(); } return super.countParams(); } summary(lineLength, positions, printFn = console.log) { if (!this.built) { this.build(); } super.summary(lineLength, positions, printFn); } setWeights(weights) { if (this.model == null) { this.build(); } this.model.setWeights(weights); } evaluate(x, y, args = {}) { if (!this.built) { throw new RuntimeError('The model needs to be compiled before being used.'); } return this.model.evaluate(x, y, args); } async evaluateDataset(dataset, args) { if (!this.built) { throw new RuntimeError('The model needs to be compiled before being used.'); } return this.model.evaluateDataset(dataset, args); } predict(x, args = {}) { if (this.model == null) { this.build(); } return this.model.predict(x, args); } predictOnBatch(x) { if (this.model == null) { this.build(); } return this.model.predictOnBatch(x); } compile(args) { this.build(); this.model.compile(args); this.optimizer_ = this.model.optimizer; this.isOptimizerOwned = this.model.isOptimizerOwned; this.loss = this.model.loss; this.metrics = this.model.metrics; this.metricsTensors = this.model.metricsTensors; this.metricsNames = this.model.metricsNames; } get optimizer() { 
return this.model == null ? undefined : this.model.optimizer; } set optimizer(optimizer) { this.model.optimizer = optimizer; } async fit(x, y, args = {}) { if (!this.built) { throw new RuntimeError('The model needs to be compiled before ' + 'being used.'); } return this.model.fit(x, y, args); } async fitDataset(dataset, args) { if (!this.built) { throw new RuntimeError('The model needs to be compiled before ' + 'being used.'); } return this.model.fitDataset(dataset, args); } async trainOnBatch(x, y) { return this.model.trainOnBatch(x, y); } static fromConfig(cls, config, customObjects = {}, fastWeightInit = false) { let configArray; let extraModelConfig = {}; if (config instanceof Array) { if (!(config[0].className != null) || config[0]['className'] === 'Merge') { throw new ValueError('Legacy serialization format not supported yet.'); } configArray = config; } else { assert$1(config['layers'] != null, () => `When the config data for a Sequential model is not an Array, ` + `it must be an Object that contains the 'layers' field.`); configArray = config['layers']; delete config['layers']; extraModelConfig = config; } const model = new cls(extraModelConfig); if (!(model instanceof Sequential)) { throw new NotImplementedError(`Sequential.fromConfig called on non-Sequential input: ${model}`); } for (const conf of configArray) { const customObjects = undefined; const layer = deserialize(conf, customObjects, fastWeightInit); if (fastWeightInit) { layer.setFastWeightInitDuringBuild(true); } model.add(layer); } return model; } set stopTraining(stop) { if (this.model == null) { throw new ValueError('Cannot set the stopTraining property of a sequential model before ' + 'it is compiled.'); } this.model.stopTraining = stop; } get stopTraining() { if (this.model == null) { throw new ValueError('Cannot get the stopTraining property of a sequential model before ' + 'it is compiled.'); } return this.model.stopTraining; } getConfig() { const layers = []; for (const layer of 
this.layers) { const dict = {}; dict['className'] = layer.getClassName(); dict['config'] = layer.getConfig(); layers.push(dict); } return { name: this.name, layers }; } } Sequential.className = 'Sequential'; registerClass(Sequential); function sequential(config) { return new Sequential(config); } let Activation$1 = class Activation extends Serializable { getConfig() { return {}; } }; class Elu extends Activation$1 { apply(x, alpha = 1) { return elu(x, alpha); } } Elu.className = 'elu'; registerClass(Elu); class Selu extends Activation$1 { apply(x) { return selu$1(x); } } Selu.className = 'selu'; registerClass(Selu); class Relu extends Activation$1 { apply(x) { return relu$1(x); } } Relu.className = 'relu'; registerClass(Relu); class Relu6 extends Activation$1 { apply(x) { return tidy(() => minimum$1(6.0, relu$1(x))); } } Relu6.className = 'relu6'; registerClass(Relu6); class Linear extends Activation$1 { apply(x) { return x; } } Linear.className = 'linear'; registerClass(Linear); class Sigmoid extends Activation$1 { apply(x) { return sigmoid$1(x); } } Sigmoid.className = 'sigmoid'; registerClass(Sigmoid); class HardSigmoid extends Activation$1 { apply(x) { return hardSigmoid(x); } } HardSigmoid.className = 'hardSigmoid'; registerClass(HardSigmoid); class Softplus extends Activation$1 { apply(x) { return softplus$1(x); } } Softplus.className = 'softplus'; registerClass(Softplus); class Softsign extends Activation$1 { apply(x) { return softsign(x); } } Softsign.className = 'softsign'; registerClass(Softsign); class Tanh extends Activation$1 { apply(x) { return tanh$1(x); } } Tanh.className = 'tanh'; registerClass(Tanh); class Softmax extends Activation$1 { apply(x, axis = (-1)) { return softmax$1(x, axis); } } Softmax.className = 'softmax'; registerClass(Softmax); class LogSoftmax extends Activation$1 { apply(x, axis = (-1)) { return logSoftmax(x, axis); } } LogSoftmax.className = 'logSoftmax'; registerClass(LogSoftmax); class Gelu extends Activation$1 { apply(x) { 
return tidy(() => { return tidy(() => { const sqrtTwo = Math.sqrt(2); const cdf = mul(0.5, add(1, erf$1(div(x, sqrtTwo)))); return mul(x, cdf); }); }); } } Gelu.className = 'gelu'; registerClass(Gelu); class GeluNew extends Activation$1 { apply(x) { return tidy(() => { return mul(0.5, mul(x, add(1, tanh$1(mul(sqrt$1(div(2, Math.PI)), add(x, mul(0.044715, pow$1(x, 3)))))))); }); } } GeluNew.className = 'gelu_new'; registerClass(GeluNew); class Mish extends Activation$1 { apply(x) { return tidy(() => mul(x, tanh$1(softplus$1(x)))); } } Mish.className = 'mish'; registerClass(Mish); class Swish extends Activation$1 { apply(x, alpha = 1) { return tidy(() => mul(sigmoid$1(mul(x, alpha)), x)); } } Swish.className = 'swish'; registerClass(Swish); function serializeActivation(activation) { return activation.getClassName(); } function deserializeActivation(config, customObjects = {}) { return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'activation'); } function getActivation(identifier) { if (identifier == null) { const config = {}; config['className'] = 'linear'; config['config'] = {}; return deserializeActivation(config); } if (typeof identifier === 'string') { const config = {}; config['className'] = identifier; config['config'] = {}; return deserializeActivation(config); } else if (identifier instanceof Activation$1) { return identifier; } else { return deserializeActivation(identifier); } } function assertObjectArgs(args) { if (args != null && typeof args !== 'object') { throw new Error(`Argument to L1L2 regularizer's constructor is expected to be an ` + `object, but received: ${args}`); } } class Regularizer extends Serializable { } class L1L2 extends Regularizer { constructor(args) { super(); assertObjectArgs(args); this.l1 = args == null || args.l1 == null ? 0.01 : args.l1; this.l2 = args == null || args.l2 == null ? 
0.01 : args.l2; this.hasL1 = this.l1 !== 0; this.hasL2 = this.l2 !== 0; } apply(x) { return tidy(() => { let regularization = zeros([1]); if (this.hasL1) { regularization = add(regularization, sum$1(mul(this.l1, abs$1(x)))); } if (this.hasL2) { regularization = add(regularization, sum$1(mul(this.l2, square(x)))); } return reshape$1(regularization, []); }); } getConfig() { return { 'l1': this.l1, 'l2': this.l2 }; } static fromConfig(cls, config) { return new cls({ l1: config['l1'], l2: config['l2'] }); } } L1L2.className = 'L1L2'; registerClass(L1L2); const REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = { 'l1l2': 'L1L2' }; function serializeRegularizer(constraint) { return serializeKerasObject(constraint); } function deserializeRegularizer(config, customObjects = {}) { return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'regularizer'); } function getRegularizer(identifier) { if (identifier == null) { return null; } if (typeof identifier === 'string') { const className = identifier in REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ? REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] : identifier; const config = { className, config: {} }; return deserializeRegularizer(config); } else if (identifier instanceof Regularizer) { return identifier; } else { return deserializeRegularizer(identifier); } } class Dropout extends Layer { constructor(args) { super(args); this.rate = Math.max(Math.min(args.rate, 1), 0); this.noiseShape = args.noiseShape; this.seed = args.seed; this.supportsMasking = true; } getNoiseShape(input) { if (this.noiseShape == null) { return this.noiseShape; } const inputShape = input.shape; const noiseShape = []; for (let i = 0; i < this.noiseShape.length; ++i) { noiseShape.push(this.noiseShape[i] == null ? 
inputShape[i] : this.noiseShape[i]); } return noiseShape; } call(inputs, kwargs) { return tidy(() => { this.invokeCallHook(inputs, kwargs); const input = getExactlyOneTensor(inputs); if (0 < this.rate && this.rate < 1) { const training = kwargs['training'] == null ? false : kwargs['training']; const noiseShape = this.getNoiseShape(input); const output = inTrainPhase(() => dropout$1(input, this.rate, noiseShape, this.seed), () => input, training); return output; } return inputs; }); } getConfig() { const config = { rate: this.rate, noiseShape: this.noiseShape, seed: this.seed, }; const baseConfig = super.getConfig(); Object.assign(config, baseConfig); return config; } dispose() { return super.dispose(); } } Dropout.className = 'Dropout'; registerClass(Dropout); class SpatialDropout1D extends Dropout { constructor(args) { super(args); this.inputSpec = [{ ndim: 3 }]; } getNoiseShape(input) { const inputShape = input.shape; return [inputShape[0], 1, inputShape[2]]; } } SpatialDropout1D.className = 'SpatialDropout1D'; registerClass(SpatialDropout1D); class Dense extends Layer { constructor(args) { super(args); this.activation = null; this.useBias = true; this.kernel = null; this.bias = null; this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal'; this.DEFAULT_BIAS_INITIALIZER = 'zeros'; if (args.batchInputShape == null && args.inputShape == null && args.inputDim != null) { let batchSize = null; if (args.batchSize != null) { batchSize = args.batchSize; } this.batchInputShape = [batchSize, args.inputDim]; } this.units = args.units; assertPositiveInteger(this.units, 'units'); this.activation = getActivation(args.activation); if (args.useBias != null) { this.useBias = args.useBias; } this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER); this.biasInitializer = getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER); this.kernelConstraint = getConstraint(args.kernelConstraint); this.biasConstraint = 
getConstraint(args.biasConstraint); this.kernelRegularizer = getRegularizer(args.kernelRegularizer); this.biasRegularizer = getRegularizer(args.biasRegularizer); this.activityRegularizer = getRegularizer(args.activityRegularizer); this.supportsMasking = true; this.inputSpec = [{ minNDim: 2 }]; } build(inputShape) { inputShape = getExactlyOneShape(inputShape); const inputLastDim = inputShape[inputShape.length - 1]; if (this.kernel == null) { this.kernel = this.addWeight('kernel', [inputLastDim, this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint); if (this.useBias) { this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint); } } this.inputSpec = [{ minNDim: 2, axes: { [-1]: inputLastDim } }]; this.built = true; } computeOutputShape(inputShape) { inputShape = getExactlyOneShape(inputShape); const outputShape = inputShape.slice(); outputShape[outputShape.length - 1] = this.units; return outputShape; } call(inputs, kwargs) { return tidy(() => { this.invokeCallHook(inputs, kwargs); const input = getExactlyOneTensor(inputs); const fusedActivationName = mapActivationToFusedKernel(this.activation.getClassName()); let output; if (fusedActivationName != null) { output = dot(input, this.kernel.read(), fusedActivationName, this.bias ? 
this.bias.read() : null); } else { output = dot(input, this.kernel.read()); if (this.bias != null) { output = biasAdd(output, this.bias.read()); } if (this.activation != null) { output = this.activation.apply(output); } } return output; }); } getConfig() { const config = { units: this.units, activation: serializeActivation(this.activation), useBias: this.useBias, kernelInitializer: serializeInitializer(this.kernelInitializer), biasInitializer: serializeInitializer(this.biasInitializer), kernelRegularizer: serializeRegularizer(this.kernelRegularizer), biasRegularizer: serializeRegularizer(this.biasRegularizer), activityRegularizer: serializeRegularizer(this.activityRegularizer), kernelConstraint: serializeConstraint(this.kernelConstraint), biasConstraint: serializeConstraint(this.biasConstraint) }; const baseConfig = super.getConfig(); Object.assign(config, baseConfig); return config; } } Dense.className = 'Dense'; registerClass(Dense); class Flatten extends Layer { constructor(args) { args = args || {}; super(args); this.inputSpec = [{ minNDim: 3 }]; this.dataFormat = args.dataFormat; } computeOutputShape(inputShape) { inputShape = getExactlyOneShape(inputShape); for (const dim of inputShape.slice(1)) { if (dim == null) { throw new ValueError(`The shape of the input to "Flatten" is not fully defined ` + `(got ${inputShape.slice(1)}). 
Make sure to pass a complete ` + `"input_shape" or "batch_input_shape" argument to the first ` + `layer in your model.`); } } return [inputShape[0], arrayProd(inputShape, 1)]; } call(inputs, kwargs) { return tidy(() => { this.invokeCallHook(inputs, kwargs); let input = getExactlyOneTensor(inputs); if (this.dataFormat === 'channelsFirst' && input.rank > 1) { const permutation = [0]; for (let i = 2; i < input.rank; ++i) { permutation.push(i); } permutation.push(1); input = transpose$1(input, permutation); } return batchFlatten(input); }); } getConfig() { const config = {}; if (this.dataFormat != null) { config['dataFormat'] = this.dataFormat; } const baseConfig = super.getConfig(); Object.assign(config, baseConfig); return config; } } Flatten.className = 'Flatten'; registerClass(Flatten); class Activation extends Layer { constructor(args) { super(args); this.supportsMasking = true; this.activation = getActivation(args.activation); } call(inputs, kwargs) { return tidy(() => { this.invokeCallHook(inputs, kwargs); const input = getExactlyOneTensor(inputs); return this.activation.apply(input); }); } getConfig() { const config = { activation: serializeActivation(this.activation) }; const baseConfig = super.getConfig(); Object.assign(config, baseConfig); return config; } } Activation.className = 'Activation'; registerClass(Activation); class RepeatVector extends Layer { constructor(args) { super(args); this.n = args.n; this.inputSpec = [{ ndim: 2 }]; } computeOutputShape(inputShape) { return [inputShape[0], this.n, inputShape[1]]; } call(inputs, kwargs) { return tidy(() => { inputs = getExactlyOneTensor(inputs); return repeat(inputs, this.n); }); } getConfig() { const config = { n: this.n, }; const baseConfig = super.getConfig(); Object.assign(config, baseConfig); return config; } } RepeatVector.className = 'RepeatVector'; registerClass(RepeatVector); class Reshape extends Layer { constructor(args) { super(args); this.targetShape = args.targetShape; for (let i = 0; i < 
this.targetShape.length; ++i) { if (this.isUnknown(this.targetShape[i])) { this.targetShape[i] = null; } } } isUnknown(dim) { return dim < 0 || dim == null; } fixUnknownDimension(inputShape, outputShape) { const errorMsg = 'Total size of new array must be unchanged.'; const finalShape = outputShape.slice(); let known = 1; let unknown = null; for (let i = 0; i < finalShape.length; ++i) { const dim = finalShape[i]; if (this.isUnknown(dim)) { if (unknown === null) { unknown = i; } else { throw new ValueError('Can only specifiy one unknown dimension.'); } } else { known *= dim; } } const originalSize = arrayProd(inputShape); if (unknown !== null) { if (known === 0 || originalSize % known !== 0) { throw new ValueError(errorMsg); } finalShape[unknown] = originalSize / known; } else if (originalSize !== known) { throw new ValueError(errorMsg); } return finalShape; } computeOutputShape(inputShape) { let anyUnknownDims = false; for (let i = 0; i < inputShape.length; ++i) { if (this.isUnknown(inputShape[i])) { anyUnknownDims = true; break; } } if (anyUnknownDims) { return inputShape.slice(0, 1).concat(this.targetShape); } else { return inputShape.slice(0, 1).concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape)); } } call(inputs, kwargs) { return tidy(() => { this.invokeCallHook(inputs, kwargs); const input = getExactlyOneTensor(inputs); const inputShape = input.shape; const outputShape = inputShape.slice(0, 1).concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape)); return reshape$1(input, outputShape); }); } getConfig() { const config = { targetShape: this.targetShape, }; const baseConfig = super.getConfig(); Object.assign(config, baseConfig); return config; } } Reshape.className = 'Reshape'; registerClass(Reshape); class Permute extends Layer { constructor(args) { super(args); if (args.dims == null) { throw new Error('Required configuration field `dims` is missing during Permute ' + 'constructor call.'); } if (!Array.isArray(args.dims)) { 
throw new Error('Permute constructor requires `dims` to be an Array, but received ' + `${args.dims} instead.`); } const expectedSortedIndices = range(1, args.dims.length + 1); if (!arraysEqual(args.dims.slice().sort(), expectedSortedIndices)) { throw new Error('Invalid permutation `dims`: ' + JSON.stringify(args.dims) + ' `dims` must contain consecutive integers starting from 1.'); } this.dims = args.dims; this.dimsIncludingBatch = [0].concat(this.dims); this.inputSpec = [new InputSpec({ ndim: this.dims.length + 1 })]; } computeOutputShape(inputShape) { inputShape = getExactlyOneShape(inputShape); const outputShape = inputShape.slice(); this.dims.forEach((dim, i) => { outputShape[i + 1] = inputShape[dim]; }); return outputShape; } call(inputs, kwargs) { return transpose$1(getExactlyOneTensor(inputs), this.dimsIncludingBatch); } getConfig() { const config = { dims: this.dims, }; const baseConfig = super.getConfig(); Object.assign(config, baseConfig); return config; } } Permute.className = 'Permute'; registerClass(Permute); class Masking extends Layer { constructor(args) { super(args == null ? {} : args); this.supportsMasking = true; if (args != null) { this.maskValue = args.maskValue == null ? 
0 : args.maskValue; } else { this.maskValue = 0; } } computeOutputShape(inputShape) { return inputShape; } getConfig() { const baseConfig = super.getConfig(); const config = { maskValue: this.maskValue }; Object.assign(config, baseConfig); return config; } computeMask(inputs, mask) { const input = getExactlyOneTensor(inputs); const axis = -1; return any$1(notEqual$1(input, this.maskValue), axis); } call(inputs, kwargs) { return tidy(() => { this.invokeCallHook(inputs, kwargs); const input = getExactlyOneTensor(inputs); const axis = -1; const keepDims = true; const booleanMask = any$1(notEqual$1(input, this.maskValue), axis, keepDims); const output = mul(input, cast$2(booleanMask, input.dtype)); return output; }); } } Masking.className = 'Masking'; registerClass(Masking); function dense(args) { return new Dense(args); } function dropout(args) { return new Dropout(args); } export { LayersModel, dense, dropout, enableProdMode, fromMemory, glorotUniform, loadLayersModelFromIOHandler, sequential, stringToHashBucketFast$1 as stringToHashBucketFast, tensor1d, tensor2d, withSaveHandler };