mirror of https://github.com/tutao/tutanota.git
synced 2025-12-07 13:49:47 +00:00
When the WebGL backend is not available or unsupported, we fall back to the TensorFlow CPU backend. The TensorFlow CPU backend library was reviewed by abp and jhm. Co-authored-by: jomapp <17314077+jomapp@users.noreply.github.com>
45262 lines · 1.5 MiB · Vendored
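// A minimal sketch of the fallback described in the commit message above. It is
// not part of the vendored bundle and the function name is illustrative; it
// assumes the public tf.setBackend()/tf.ready() API of @tensorflow/tfjs, where
// setBackend resolves to false when the requested backend fails to initialize.
async function initBackendWithCpuFallback(tf) {
    // Prefer the WebGL backend; fall back to the CPU backend if it is
    // unavailable or unsupported on this device.
    if (!(await tf.setBackend('webgl'))) {
        await tf.setBackend('cpu');
    }
    await tf.ready();
    return tf.getBackend();
}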
function _mergeNamespaces(n, m) {
    m.forEach(function (e) {
        e && typeof e !== 'string' && !Array.isArray(e) && Object.keys(e).forEach(function (k) {
            if (k !== 'default' && !(k in n)) {
                var d = Object.getOwnPropertyDescriptor(e, k);
                Object.defineProperty(n, k, d.get ? d : {
                    enumerable: true,
                    get: function () { return e[k]; }
                });
            }
        });
    });
    return Object.freeze(n);
}

const EPSILON_FLOAT32$1 = 1e-7;
const EPSILON_FLOAT16$1 = 1e-4;

class DataStorage {
    constructor(backend, dataMover) {
        this.backend = backend;
        this.dataMover = dataMover;
        this.data = new WeakMap();
        this.dataIdsCount = 0;
    }
    get(dataId) {
        if (!this.data.has(dataId)) {
            this.dataMover.moveData(this.backend, dataId);
        }
        return this.data.get(dataId);
    }
    set(dataId, value) {
        this.dataIdsCount++;
        this.data.set(dataId, value);
    }
    has(dataId) {
        return this.data.has(dataId);
    }
    delete(dataId) {
        this.dataIdsCount--;
        return this.data.delete(dataId);
    }
    numDataIds() {
        return this.dataIdsCount;
    }
}

class KernelBackend {
    refCount(dataId) {
        return notYetImplemented('refCount');
    }
    incRef(dataId) {
        return notYetImplemented('incRef');
    }
    timerAvailable() {
        return true;
    }
    time(f) {
        return notYetImplemented('time');
    }
    read(dataId) {
        return notYetImplemented('read');
    }
    readSync(dataId) {
        return notYetImplemented('readSync');
    }
    readToGPU(dataId, options) {
        return notYetImplemented('readToGPU');
    }
    numDataIds() {
        return notYetImplemented('numDataIds');
    }
    disposeData(dataId, force) {
        return notYetImplemented('disposeData');
    }
    write(values, shape, dtype) {
        return notYetImplemented('write');
    }
    move(dataId, values, shape, dtype, refCount) {
        return notYetImplemented('move');
    }
    createTensorFromGPUData(values, shape, dtype) {
        return notYetImplemented('createTensorFromGPUData');
    }
    memory() {
        return notYetImplemented('memory');
    }
    floatPrecision() {
        return notYetImplemented('floatPrecision');
    }
    epsilon() {
        return this.floatPrecision() === 32 ? EPSILON_FLOAT32$1 : EPSILON_FLOAT16$1;
    }
    dispose() {
        return notYetImplemented('dispose');
    }
}

function notYetImplemented(kernelName) {
    throw new Error(`'${kernelName}' not yet implemented or not found in the registry. ` +
        `This kernel may not be supported by the tfjs backend you have chosen`);
}

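// Illustrative use of DataStorage with hypothetical backend/dataMover objects:
// a dataId is an opaque object reference, and get() asks the dataMover to pull
// the data onto this backend if it currently lives elsewhere.
//
//   const storage = new DataStorage(backend, dataMover);
//   const dataId = {};
//   storage.set(dataId, { values: new Float32Array([1, 2, 3]) });
//   storage.get(dataId).values;   // -> Float32Array [1, 2, 3]
//   storage.numDataIds();         // -> 1
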
function shuffle(array) {
    let counter = array.length;
    let index = 0;
    while (counter > 0) {
        index = (Math.random() * counter) | 0;
        counter--;
        swap(array, counter, index);
    }
}

function clamp(min, x, max) {
    return Math.max(min, Math.min(x, max));
}

function nearestLargerEven(val) {
    return val % 2 === 0 ? val : val + 1;
}

function swap(object, left, right) {
    const temp = object[left];
    object[left] = object[right];
    object[right] = temp;
}

function sum$3(arr) {
    let sum = 0;
    for (let i = 0; i < arr.length; i++) {
        sum += arr[i];
    }
    return sum;
}

function assert$1(expr, msg) {
    if (!expr) {
        throw new Error(typeof msg === 'string' ? msg : msg());
    }
}

function assertShapesMatch(shapeA, shapeB, errorMessagePrefix = '') {
    assert$1(arraysEqual(shapeA, shapeB), () => errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`);
}

function assertNonNull(a) {
    assert$1(a != null, () => `The input to the tensor constructor must be a non-null value.`);
}

function sizeFromShape(shape) {
    if (shape.length === 0) {
        return 1;
    }
    let size = shape[0];
    for (let i = 1; i < shape.length; i++) {
        size *= shape[i];
    }
    return size;
}

function arraysEqual(n1, n2) {
    if (n1 === n2) {
        return true;
    }
    if (n1 == null || n2 == null) {
        return false;
    }
    if (n1.length !== n2.length) {
        return false;
    }
    for (let i = 0; i < n1.length; i++) {
        if (n1[i] !== n2[i]) {
            return false;
        }
    }
    return true;
}

function isInt(a) {
    return a % 1 === 0;
}

function sizeToSquarishShape(size) {
    const width = Math.ceil(Math.sqrt(size));
    return [width, Math.ceil(size / width)];
}

function rightPad(a, size) {
    if (size <= a.length) {
        return a;
    }
    return a + ' '.repeat(size - a.length);
}

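// Illustrative: sizeToSquarishShape(7) -> [3, 3] (a near-square shape that can
// hold at least 7 elements), and rightPad('ab', 5) -> 'ab   '.
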
function repeatedTry(checkFn, delayFn = (counter) => 0, maxCounter, scheduleFn) {
    return new Promise((resolve, reject) => {
        let tryCount = 0;
        const tryFn = () => {
            if (checkFn()) {
                resolve();
                return;
            }
            tryCount++;
            const nextBackoff = delayFn(tryCount);
            if (maxCounter != null && tryCount >= maxCounter) {
                reject();
                return;
            }
            if (scheduleFn != null) {
                scheduleFn(tryFn, nextBackoff);
            }
            else {
                setTimeout(tryFn, nextBackoff);
            }
        };
        tryFn();
    });
}

function inferFromImplicitShape(shape, size) {
    let shapeProd = 1;
    let implicitIdx = -1;
    for (let i = 0; i < shape.length; ++i) {
        if (shape[i] >= 0) {
            shapeProd *= shape[i];
        }
        else if (shape[i] === -1) {
            if (implicitIdx !== -1) {
                throw Error(`Shapes can only have 1 implicit size. ` +
                    `Found -1 at dim ${implicitIdx} and dim ${i}`);
            }
            implicitIdx = i;
        }
        else if (shape[i] < 0) {
            throw Error(`Shapes can not be < 0. Found ${shape[i]} at dim ${i}`);
        }
    }
    if (implicitIdx === -1) {
        if (size > 0 && size !== shapeProd) {
            throw Error(`Size(${size}) must match the product of shape ${shape}`);
        }
        return shape;
    }
    if (shapeProd === 0) {
        throw Error(`Cannot infer the missing size in [${shape}] when ` +
            `there are 0 elements`);
    }
    if (size % shapeProd !== 0) {
        throw Error(`The implicit shape can't be a fractional number. ` +
            `Got ${size} / ${shapeProd}`);
    }
    const newShape = shape.slice();
    newShape[implicitIdx] = size / shapeProd;
    return newShape;
}

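// Illustrative: inferFromImplicitShape([-1, 4], 12) -> [3, 4]; the single -1
// dimension is inferred so that the shape's product equals the given size.
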
function parseAxisParam(axis, shape) {
    const rank = shape.length;
    axis = axis == null ? shape.map((s, i) => i) : [].concat(axis);
    assert$1(axis.every(ax => ax >= -rank && ax < rank), () => `All values in axis param must be in range [-${rank}, ${rank}) but ` +
        `got axis ${axis}`);
    assert$1(axis.every(ax => isInt(ax)), () => `All values in axis param must be integers but ` +
        `got axis ${axis}`);
    return axis.map(a => a < 0 ? rank + a : a);
}

function squeezeShape(shape, axis) {
    const newShape = [];
    const keptDims = [];
    const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
    const axes = (axis == null || isEmptyArray) ?
        null :
        parseAxisParam(axis, shape).sort();
    let j = 0;
    for (let i = 0; i < shape.length; ++i) {
        if (axes != null) {
            if (axes[j] === i && shape[i] !== 1) {
                throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`);
            }
            if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
                newShape.push(shape[i]);
                keptDims.push(i);
            }
            if (axes[j] <= i) {
                j++;
            }
        }
        if (shape[i] !== 1) {
            newShape.push(shape[i]);
            keptDims.push(i);
        }
    }
    return { newShape, keptDims };
}

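// Illustrative: parseAxisParam(-1, [2, 3, 4]) -> [2] (negative axes count from
// the end), and squeezeShape([1, 2, 1, 3]) -> { newShape: [2, 3], keptDims: [1, 3] }.
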
function getTypedArrayFromDType(dtype, size) {
    return getArrayFromDType(dtype, size);
}

function getArrayFromDType(dtype, size) {
    let values = null;
    if (dtype == null || dtype === 'float32') {
        values = new Float32Array(size);
    }
    else if (dtype === 'int32') {
        values = new Int32Array(size);
    }
    else if (dtype === 'bool') {
        values = new Uint8Array(size);
    }
    else if (dtype === 'string') {
        values = new Array(size);
    }
    else {
        throw new Error(`Unknown data type ${dtype}`);
    }
    return values;
}

function checkConversionForErrors(vals, dtype) {
    for (let i = 0; i < vals.length; i++) {
        const num = vals[i];
        if (isNaN(num) || !isFinite(num)) {
            throw Error(`A tensor of type ${dtype} being uploaded contains ${num}.`);
        }
    }
}

function isValidDtype(dtype) {
    return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' ||
        dtype === 'int32' || dtype === 'string';
}

function hasEncodingLoss(oldType, newType) {
    if (newType === 'complex64') {
        return false;
    }
    if (newType === 'float32' && oldType !== 'complex64') {
        return false;
    }
    if (newType === 'int32' && oldType !== 'float32' && oldType !== 'complex64') {
        return false;
    }
    if (newType === 'bool' && oldType === 'bool') {
        return false;
    }
    return true;
}

function bytesPerElement(dtype) {
    if (dtype === 'float32' || dtype === 'int32') {
        return 4;
    }
    else if (dtype === 'complex64') {
        return 8;
    }
    else if (dtype === 'bool') {
        return 1;
    }
    else {
        throw new Error(`Unknown dtype ${dtype}`);
    }
}

function bytesFromStringArray(arr) {
    if (arr == null) {
        return 0;
    }
    let bytes = 0;
    arr.forEach(x => bytes += x.length);
    return bytes;
}

function isString(value) {
    return typeof value === 'string' || value instanceof String;
}

function isBoolean(value) {
    return typeof value === 'boolean';
}

function isNumber(value) {
    return typeof value === 'number';
}

function inferDtype(values) {
    if (Array.isArray(values)) {
        return inferDtype(values[0]);
    }
    if (values instanceof Float32Array) {
        return 'float32';
    }
    else if (values instanceof Int32Array || values instanceof Uint8Array ||
        values instanceof Uint8ClampedArray) {
        return 'int32';
    }
    else if (isNumber(values)) {
        return 'float32';
    }
    else if (isString(values)) {
        return 'string';
    }
    else if (isBoolean(values)) {
        return 'bool';
    }
    return 'float32';
}

function isFunction(f) {
    return !!(f && f.constructor && f.call && f.apply);
}

function nearestDivisor(size, start) {
    for (let i = start; i < size; ++i) {
        if (size % i === 0) {
            return i;
        }
    }
    return size;
}

function computeStrides(shape) {
    const rank = shape.length;
    if (rank < 2) {
        return [];
    }
    const strides = new Array(rank - 1);
    strides[rank - 2] = shape[rank - 1];
    for (let i = rank - 3; i >= 0; --i) {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    return strides;
}

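// Illustrative: computeStrides([2, 3, 4]) -> [12, 4], so the flat index of
// element [i, j, k] in a row-major layout is i * 12 + j * 4 + k.
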
function createNestedArray(offset, shape, a, isComplex = false) {
    const ret = new Array();
    if (shape.length === 1) {
        const d = shape[0] * (isComplex ? 2 : 1);
        for (let i = 0; i < d; i++) {
            ret[i] = a[offset + i];
        }
    }
    else {
        const d = shape[0];
        const rest = shape.slice(1);
        const len = rest.reduce((acc, c) => acc * c) * (isComplex ? 2 : 1);
        for (let i = 0; i < d; i++) {
            ret[i] = createNestedArray(offset + i * len, rest, a, isComplex);
        }
    }
    return ret;
}

function toNestedArray(shape, a, isComplex = false) {
    if (shape.length === 0) {
        return a[0];
    }
    const size = shape.reduce((acc, c) => acc * c) * (isComplex ? 2 : 1);
    if (size === 0) {
        return [];
    }
    if (size !== a.length) {
        throw new Error(`[${shape}] does not match the input size ${a.length}${isComplex ? ' for a complex tensor' : ''}.`);
    }
    return createNestedArray(0, shape, a, isComplex);
}

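// Illustrative: toNestedArray([2, 2], new Float32Array([1, 2, 3, 4])) ->
// [[1, 2], [3, 4]].
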
function convertBackendValuesAndArrayBuffer(data, dtype) {
    if (Array.isArray(data)) {
        return data;
    }
    if (dtype === 'float32') {
        return data instanceof Float32Array ? data : new Float32Array(data);
    }
    else if (dtype === 'int32') {
        return data instanceof Int32Array ? data : new Int32Array(data);
    }
    else if (dtype === 'bool' || dtype === 'string') {
        return Uint8Array.from(new Int32Array(data));
    }
    else {
        throw new Error(`Unknown dtype ${dtype}`);
    }
}

function makeOnesTypedArray(size, dtype) {
    const array = makeZerosTypedArray(size, dtype);
    for (let i = 0; i < array.length; i++) {
        array[i] = 1;
    }
    return array;
}

function makeZerosTypedArray(size, dtype) {
    if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
        return new Float32Array(size);
    }
    else if (dtype === 'int32') {
        return new Int32Array(size);
    }
    else if (dtype === 'bool') {
        return new Uint8Array(size);
    }
    else {
        throw new Error(`Unknown data type ${dtype}`);
    }
}

function makeZerosNestedTypedArray(shape, dtype) {
    const size = shape.reduce((prev, curr) => prev * curr, 1);
    if (dtype == null || dtype === 'float32') {
        return toNestedArray(shape, new Float32Array(size));
    }
    else if (dtype === 'int32') {
        return toNestedArray(shape, new Int32Array(size));
    }
    else if (dtype === 'bool') {
        return toNestedArray(shape, new Uint8Array(size));
    }
    else {
        throw new Error(`Unknown data type ${dtype}`);
    }
}

function assertNonNegativeIntegerDimensions(shape) {
    shape.forEach(dimSize => {
        assert$1(Number.isInteger(dimSize) && dimSize >= 0, () => `Tensor must have a shape comprised of positive integers but got ` +
            `shape [${shape}].`);
    });
}

function locToIndex(locs, rank, strides) {
    if (rank === 0) {
        return 0;
    }
    else if (rank === 1) {
        return locs[0];
    }
    let index = locs[locs.length - 1];
    for (let i = 0; i < locs.length - 1; ++i) {
        index += strides[i] * locs[i];
    }
    return index;
}

function indexToLoc(index, rank, strides) {
    if (rank === 0) {
        return [];
    }
    else if (rank === 1) {
        return [index];
    }
    const locs = new Array(rank);
    for (let i = 0; i < locs.length - 1; ++i) {
        locs[i] = Math.floor(index / strides[i]);
        index -= locs[i] * strides[i];
    }
    locs[locs.length - 1] = index;
    return locs;
}

function isPromise(object) {
    return object && object.then && typeof object.then === 'function';
}

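// Illustrative: with shape [2, 3] (strides [3]), locToIndex([1, 2], 2, [3]) -> 5
// and indexToLoc(5, 2, [3]) -> [1, 2]; the two functions are inverses.
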
const TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';

class Environment {
    constructor(global) {
        this.global = global;
        this.flags = {};
        this.flagRegistry = {};
        this.urlFlags = {};
        this.getQueryParams = getQueryParams;
        this.populateURLFlags();
    }
    setPlatform(platformName, platform) {
        if (this.platform != null) {
            if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
                console.warn(`Platform ${this.platformName} has already been set. ` +
                    `Overwriting the platform with ${platformName}.`);
            }
        }
        this.platformName = platformName;
        this.platform = platform;
    }
    registerFlag(flagName, evaluationFn, setHook) {
        this.flagRegistry[flagName] = { evaluationFn, setHook };
        if (this.urlFlags[flagName] != null) {
            const flagValue = this.urlFlags[flagName];
            if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
                console.warn(`Setting feature override from URL ${flagName}: ${flagValue}.`);
            }
            this.set(flagName, flagValue);
        }
    }
    async getAsync(flagName) {
        if (flagName in this.flags) {
            return this.flags[flagName];
        }
        this.flags[flagName] = await this.evaluateFlag(flagName);
        return this.flags[flagName];
    }
    get(flagName) {
        if (flagName in this.flags) {
            return this.flags[flagName];
        }
        const flagValue = this.evaluateFlag(flagName);
        if (isPromise(flagValue)) {
            throw new Error(`Flag ${flagName} cannot be synchronously evaluated. ` +
                `Please use getAsync() instead.`);
        }
        this.flags[flagName] = flagValue;
        return this.flags[flagName];
    }
    getNumber(flagName) {
        return this.get(flagName);
    }
    getBool(flagName) {
        return this.get(flagName);
    }
    getString(flagName) {
        return this.get(flagName);
    }
    getFlags() {
        return this.flags;
    }
    get features() {
        return this.flags;
    }
    set(flagName, value) {
        if (this.flagRegistry[flagName] == null) {
            throw new Error(`Cannot set flag ${flagName} as it has not been registered.`);
        }
        this.flags[flagName] = value;
        if (this.flagRegistry[flagName].setHook != null) {
            this.flagRegistry[flagName].setHook(value);
        }
    }
    evaluateFlag(flagName) {
        if (this.flagRegistry[flagName] == null) {
            throw new Error(`Cannot evaluate flag '${flagName}': no evaluation function found.`);
        }
        return this.flagRegistry[flagName].evaluationFn();
    }
    setFlags(flags) {
        this.flags = Object.assign({}, flags);
    }
    reset() {
        this.flags = {};
        this.urlFlags = {};
        this.populateURLFlags();
    }
    populateURLFlags() {
        if (typeof this.global === 'undefined' ||
            typeof this.global.location === 'undefined' ||
            typeof this.global.location.search === 'undefined') {
            return;
        }
        const urlParams = this.getQueryParams(this.global.location.search);
        if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) {
            const keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',');
            keyValues.forEach(keyValue => {
                const [key, value] = keyValue.split(':');
                this.urlFlags[key] = parseValue(key, value);
            });
        }
    }
}

function getQueryParams(queryString) {
    const params = {};
    queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, (s, ...t) => {
        decodeParam(params, t[0], t[1]);
        return t.join('=');
    });
    return params;
}

function decodeParam(params, name, value) {
    params[decodeURIComponent(name)] = decodeURIComponent(value || '');
}

function parseValue(flagName, value) {
    const lowerCaseValue = value.toLowerCase();
    if (lowerCaseValue === 'true' || lowerCaseValue === 'false') {
        return lowerCaseValue === 'true';
    }
    else if (`${+lowerCaseValue}` === lowerCaseValue) {
        return +lowerCaseValue;
    }
    else {
        return value;
    }
}

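// Illustrative: loading a page with ?tfjsflags=DEBUG:true,WEBGL_VERSION:1 makes
// populateURLFlags() record { DEBUG: true, WEBGL_VERSION: 1 }; parseValue()
// coerces 'true'/'false' to booleans and numeric strings to numbers.
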
function env() {
    return ENV$2;
}

let ENV$2 = null;

function setEnvironmentGlobal(environment) {
    ENV$2 = environment;
}

let globalNameSpace;

function getGlobalNamespace() {
    if (globalNameSpace == null) {
        let ns;
        if (typeof (window) !== 'undefined') {
            ns = window;
        }
        else if (typeof (global) !== 'undefined') {
            ns = global;
        }
        else if (typeof (process) !== 'undefined') {
            ns = process;
        }
        else if (typeof (self) !== 'undefined') {
            ns = self;
        }
        else {
            throw new Error('Could not find a global object');
        }
        globalNameSpace = ns;
    }
    return globalNameSpace;
}

function getGlobalMap() {
    const ns = getGlobalNamespace();
    if (ns._tfGlobals == null) {
        ns._tfGlobals = new Map();
    }
    return ns._tfGlobals;
}

function getGlobal(key, init) {
    const globalMap = getGlobalMap();
    if (globalMap.has(key)) {
        return globalMap.get(key);
    }
    else {
        const singleton = init();
        globalMap.set(key, singleton);
        return globalMap.get(key);
    }
}

const Abs = 'Abs';
const Acos = 'Acos';
const Acosh = 'Acosh';
const Add = 'Add';
const AddN = 'AddN';
const All = 'All';
const Any = 'Any';
const ArgMax = 'ArgMax';
const ArgMin = 'ArgMin';
const Asin = 'Asin';
const Asinh = 'Asinh';
const Atan = 'Atan';
const Atanh = 'Atanh';
const Atan2 = 'Atan2';
const AvgPool = 'AvgPool';
const AvgPoolGrad = 'AvgPoolGrad';
const AvgPool3D = 'AvgPool3D';
const AvgPool3DGrad = 'AvgPool3DGrad';
const BatchMatMul = 'BatchMatMul';
const BatchToSpaceND = 'BatchToSpaceND';
const Bincount = 'Bincount';
const BitwiseAnd = 'BitwiseAnd';
const BroadcastTo = 'BroadcastTo';
const BroadcastArgs = 'BroadcastArgs';
const Cast = 'Cast';
const Ceil = 'Ceil';
const ClipByValue = 'ClipByValue';
const Complex = 'Complex';
const ComplexAbs = 'ComplexAbs';
const Concat = 'Concat';
const Conv2D = 'Conv2D';
const Conv2DBackpropFilter = 'Conv2DBackpropFilter';
const Conv2DBackpropInput = 'Conv2DBackpropInput';
const Conv3D = 'Conv3D';
const Conv3DBackpropFilterV2 = 'Conv3DBackpropFilterV2';
const Conv3DBackpropInputV2 = 'Conv3DBackpropInputV2';
const Cos = 'Cos';
const Cosh = 'Cosh';
const Cumprod = 'Cumprod';
const Cumsum = 'Cumsum';
const CropAndResize = 'CropAndResize';
const DenseBincount = 'DenseBincount';
const DepthToSpace = 'DepthToSpace';
const DepthwiseConv2dNative = 'DepthwiseConv2dNative';
const DepthwiseConv2dNativeBackpropFilter = 'DepthwiseConv2dNativeBackpropFilter';
const DepthwiseConv2dNativeBackpropInput = 'DepthwiseConv2dNativeBackpropInput';
const Diag = 'Diag';
const Dilation2D = 'Dilation2D';
const Dilation2DBackpropInput = 'Dilation2DBackpropInput';
const Dilation2DBackpropFilter = 'Dilation2DBackpropFilter';
const Draw = 'Draw';
const RealDiv = 'RealDiv';
const Einsum = 'Einsum';
const Elu$1 = 'Elu';
const EluGrad = 'EluGrad';
const Erf = 'Erf';
const Equal = 'Equal';
const Exp = 'Exp';
const ExpandDims = 'ExpandDims';
const Expm1 = 'Expm1';
const FFT = 'FFT';
const Fill = 'Fill';
const FlipLeftRight = 'FlipLeftRight';
const Floor = 'Floor';
const FloorDiv = 'FloorDiv';
const FusedBatchNorm = 'FusedBatchNorm';
const GatherV2 = 'GatherV2';
const GatherNd = 'GatherNd';
const Greater = 'Greater';
const GreaterEqual = 'GreaterEqual';
const Identity$1 = 'Identity';
const IFFT = 'IFFT';
const Imag = 'Imag';
const IsFinite = 'IsFinite';
const IsInf = 'IsInf';
const IsNan = 'IsNan';
const LeakyRelu = 'LeakyRelu';
const Less = 'Less';
const LessEqual = 'LessEqual';
const LinSpace = 'LinSpace';
const Log = 'Log';
const Log1p = 'Log1p';
const LogicalAnd = 'LogicalAnd';
const LogicalNot = 'LogicalNot';
const LogicalOr = 'LogicalOr';
const LogSoftmax$1 = 'LogSoftmax';
const LRN = 'LRN';
const LRNGrad = 'LRNGrad';
const Max = 'Max';
const Maximum = 'Maximum';
const MaxPool = 'MaxPool';
const MaxPoolGrad = 'MaxPoolGrad';
const MaxPool3D = 'MaxPool3D';
const MaxPool3DGrad = 'MaxPool3DGrad';
const MaxPoolWithArgmax = 'MaxPoolWithArgmax';
const Mean = 'Mean';
const Min = 'Min';
const Minimum = 'Minimum';
const MirrorPad = 'MirrorPad';
const Mod = 'Mod';
const Multinomial = 'Multinomial';
const Multiply = 'Multiply';
const Neg = 'Neg';
const NotEqual = 'NotEqual';
const NonMaxSuppressionV3 = 'NonMaxSuppressionV3';
const NonMaxSuppressionV4 = 'NonMaxSuppressionV4';
const NonMaxSuppressionV5 = 'NonMaxSuppressionV5';
const OnesLike = 'OnesLike';
const OneHot = 'OneHot';
const Pack = 'Pack';
const PadV2 = 'PadV2';
const Pow = 'Pow';
const Prelu = 'Prelu';
const Prod = 'Prod';
const RaggedGather = 'RaggedGather';
const RaggedRange = 'RaggedRange';
const RaggedTensorToTensor = 'RaggedTensorToTensor';
const Range = 'Range';
const Real = 'Real';
const Reciprocal = 'Reciprocal';
const Relu$1 = 'Relu';
const Reshape$1 = 'Reshape';
const ResizeNearestNeighbor = 'ResizeNearestNeighbor';
const ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad';
const ResizeBilinear = 'ResizeBilinear';
const ResizeBilinearGrad = 'ResizeBilinearGrad';
const Relu6$1 = 'Relu6';
const Reverse = 'Reverse';
const Round = 'Round';
const Rsqrt = 'Rsqrt';
const ScatterNd = 'ScatterNd';
const TensorScatterUpdate = 'TensorScatterUpdate';
const SearchSorted = 'SearchSorted';
const Select = 'Select';
const Selu$1 = 'Selu';
const Slice = 'Slice';
const Sin = 'Sin';
const Sinh = 'Sinh';
const Sign = 'Sign';
const Sigmoid$1 = 'Sigmoid';
const Softplus$1 = 'Softplus';
const Sqrt = 'Sqrt';
const Sum = 'Sum';
const SpaceToBatchND = 'SpaceToBatchND';
const SplitV = 'SplitV';
const Softmax$1 = 'Softmax';
const SparseFillEmptyRows = 'SparseFillEmptyRows';
const SparseReshape = 'SparseReshape';
const SparseSegmentMean = 'SparseSegmentMean';
const SparseSegmentSum = 'SparseSegmentSum';
const SparseToDense = 'SparseToDense';
const SquaredDifference = 'SquaredDifference';
const Square = 'Square';
const StaticRegexReplace = 'StaticRegexReplace';
const StridedSlice = 'StridedSlice';
const StringNGrams = 'StringNGrams';
const StringSplit = 'StringSplit';
const StringToHashBucketFast = 'StringToHashBucketFast';
const Sub = 'Sub';
const Tan = 'Tan';
const Tanh$1 = 'Tanh';
const Tile = 'Tile';
const TopK = 'TopK';
const Transform = 'Transform';
const Transpose = 'Transpose';
const Unique = 'Unique';
const Unpack = 'Unpack';
const UnsortedSegmentSum = 'UnsortedSegmentSum';
const ZerosLike = 'ZerosLike';

const Step = 'Step';
const FromPixels = 'FromPixels';
const RotateWithOffset = 'RotateWithOffset';
const _FusedMatMul = '_FusedMatMul';
const FusedConv2D = 'FusedConv2D';
const FusedDepthwiseConv2D = 'FusedDepthwiseConv2D';

function warn(...msg) {
    if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
        console.warn(...msg);
    }
}

function log$3(...msg) {
    if (!(env().getBool('IS_TEST') || env().getBool('PROD'))) {
        console.log(...msg);
    }
}

const kernelRegistry = getGlobal('kernelRegistry', () => new Map());
const gradRegistry = getGlobal('gradRegistry', () => new Map());

function getKernel(kernelName, backendName) {
    const key = makeKey(kernelName, backendName);
    return kernelRegistry.get(key);
}

function getGradient(kernelName) {
    return gradRegistry.get(kernelName);
}

function getKernelsForBackend(backendName) {
    const it = kernelRegistry.entries();
    const result = [];
    while (true) {
        const { done, value } = it.next();
        if (done) {
            break;
        }
        const [key, config] = value;
        const [backend,] = key.split('_');
        if (backend === backendName) {
            result.push(config);
        }
    }
    return result;
}

function registerKernel(config) {
    const { kernelName, backendName } = config;
    const key = makeKey(kernelName, backendName);
    if (kernelRegistry.has(key)) {
        warn(`The kernel '${kernelName}' for backend ` +
            `'${backendName}' is already registered`);
    }
    kernelRegistry.set(key, config);
}

function registerGradient(config) {
    const { kernelName } = config;
    if (gradRegistry.has(kernelName)) {
        if (env().getBool('DEBUG')) {
            warn(`Overriding the gradient for '${kernelName}'`);
        }
    }
    gradRegistry.set(kernelName, config);
}

function makeKey(kernelName, backendName) {
    return `${backendName}_${kernelName}`;
}

setTimeout(() => env().setPlatform('browser', new PlatformStub()));

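// Illustrative: registerKernel({ kernelName: Square, backendName: 'cpu',
// kernelFunc: /* ... */ }) stores the config under the key 'cpu_Square'
// (see makeKey), and getKernel(Square, 'cpu') retrieves it.
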
function isTypedArrayBrowser(a) {
    return a instanceof Float32Array || a instanceof Int32Array ||
        a instanceof Uint8Array || a instanceof Uint8ClampedArray;
}

class PlatformStub {
    constructor() {
        // State for setTimeoutCustom(); mirrored from the upstream tfjs browser
        // platform, which initializes these fields in its constructor. Without
        // them, setTimeoutCustom() would fail when USE_SETTIMEOUTCUSTOM is on.
        this.messageName = 'setTimeoutCustom' + Math.random();
        this.functionRefs = [];
        this.handledMessageCount = 0;
        this.hasEventListener = false;
    }
    fetch(path, init) {
        throw new Error("fetch is not supported in this build.");
    }
    now() {
        return performance.now();
    }
    encode(text, encoding) {
        if (encoding !== 'utf-8' && encoding !== 'utf8') {
            throw new Error(`Browser's encoder only supports utf-8, but got ${encoding}`);
        }
        if (this.textEncoder == null) {
            this.textEncoder = new TextEncoder();
        }
        return this.textEncoder.encode(text);
    }
    decode(bytes, encoding) {
        return new TextDecoder(encoding).decode(bytes);
    }
    setTimeoutCustom(functionRef, delay) {
        if (typeof window === 'undefined' ||
            !env().getBool('USE_SETTIMEOUTCUSTOM')) {
            setTimeout(functionRef, delay);
            return;
        }
        this.functionRefs.push(functionRef);
        setTimeout(() => {
            window.postMessage({ name: this.messageName, index: this.functionRefs.length - 1 }, location.origin);
        }, delay);
        if (!this.hasEventListener) {
            this.hasEventListener = true;
            window.addEventListener('message', (event) => {
                if (event.source === window && event.data.name === this.messageName) {
                    event.stopPropagation();
                    const functionRef = this.functionRefs[event.data.index];
                    functionRef();
                    this.handledMessageCount++;
                    if (this.handledMessageCount === this.functionRefs.length) {
                        this.functionRefs = [];
                        this.handledMessageCount = 0;
                    }
                }
            }, true);
        }
    }
    isTypedArray(a) {
        return isTypedArrayBrowser(a);
    }
}

var commonjsGlobal = typeof globalThis !== 'undefined' ? globalThis : typeof window !== 'undefined' ? window : typeof global !== 'undefined' ? global : typeof self !== 'undefined' ? self : {};

function getDefaultExportFromCjs(x) {
    return x && x.__esModule && Object.prototype.hasOwnProperty.call(x, 'default') ? x['default'] : x;
}

var long = Long$1;

var wasm = null;

try {
    wasm = new WebAssembly.Instance(new WebAssembly.Module(new Uint8Array([
        0, 97, 115, 109, 1, 0, 0, 0, 1, 13, 2, 96, 0, 1, 127, 96, 4, 127, 127, 127, 127, 1, 127, 3, 7, 6, 0, 1, 1, 1, 1, 1, 6, 6, 1, 127, 1, 65, 0, 11, 7, 50, 6, 3, 109, 117, 108, 0, 1, 5, 100, 105, 118, 95, 115, 0, 2, 5, 100, 105, 118, 95, 117, 0, 3, 5, 114, 101, 109, 95, 115, 0, 4, 5, 114, 101, 109, 95, 117, 0, 5, 8, 103, 101, 116, 95, 104, 105, 103, 104, 0, 0, 10, 191, 1, 6, 4, 0, 35, 0, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 126, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 127, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 128, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 129, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 130, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11
    ])), {}).exports;
} catch (e) {
    // WebAssembly is unavailable; the pure-JS implementations below are used instead.
}

function Long$1(low, high, unsigned) {
    this.low = low | 0;
    this.high = high | 0;
    this.unsigned = !!unsigned;
}

Object.defineProperty(Long$1.prototype, "__isLong__", { value: true });

function isLong(obj) {
    return (obj && obj["__isLong__"]) === true;
}

Long$1.isLong = isLong;

var INT_CACHE = {};
var UINT_CACHE = {};

function fromInt(value, unsigned) {
    var obj, cachedObj, cache;
    if (unsigned) {
        value >>>= 0;
        if (cache = (0 <= value && value < 256)) {
            cachedObj = UINT_CACHE[value];
            if (cachedObj)
                return cachedObj;
        }
        obj = fromBits(value, (value | 0) < 0 ? -1 : 0, true);
        if (cache)
            UINT_CACHE[value] = obj;
        return obj;
    } else {
        value |= 0;
        if (cache = (-128 <= value && value < 128)) {
            cachedObj = INT_CACHE[value];
            if (cachedObj)
                return cachedObj;
        }
        obj = fromBits(value, value < 0 ? -1 : 0, false);
        if (cache)
            INT_CACHE[value] = obj;
        return obj;
    }
}

Long$1.fromInt = fromInt;

function fromNumber(value, unsigned) {
    if (isNaN(value))
        return unsigned ? UZERO : ZERO;
    if (unsigned) {
        if (value < 0)
            return UZERO;
        if (value >= TWO_PWR_64_DBL)
            return MAX_UNSIGNED_VALUE;
    } else {
        if (value <= -TWO_PWR_63_DBL)
            return MIN_VALUE;
        if (value + 1 >= TWO_PWR_63_DBL)
            return MAX_VALUE;
    }
    if (value < 0)
        return fromNumber(-value, unsigned).neg();
    return fromBits((value % TWO_PWR_32_DBL) | 0, (value / TWO_PWR_32_DBL) | 0, unsigned);
}

Long$1.fromNumber = fromNumber;

function fromBits(lowBits, highBits, unsigned) {
    return new Long$1(lowBits, highBits, unsigned);
}

Long$1.fromBits = fromBits;

var pow_dbl = Math.pow;

function fromString(str, unsigned, radix) {
    if (str.length === 0)
        throw Error('empty string');
    if (str === "NaN" || str === "Infinity" || str === "+Infinity" || str === "-Infinity")
        return ZERO;
    if (typeof unsigned === 'number') {
        radix = unsigned;
        unsigned = false;
    } else {
        unsigned = !!unsigned;
    }
    radix = radix || 10;
    if (radix < 2 || 36 < radix)
        throw RangeError('radix');
    var p;
    if ((p = str.indexOf('-')) > 0)
        throw Error('interior hyphen');
    else if (p === 0) {
        return fromString(str.substring(1), unsigned, radix).neg();
    }
    var radixToPower = fromNumber(pow_dbl(radix, 8));
    var result = ZERO;
    for (var i = 0; i < str.length; i += 8) {
        var size = Math.min(8, str.length - i),
            value = parseInt(str.substring(i, i + size), radix);
        if (size < 8) {
            var power = fromNumber(pow_dbl(radix, size));
            result = result.mul(power).add(fromNumber(value));
        } else {
            result = result.mul(radixToPower);
            result = result.add(fromNumber(value));
        }
    }
    result.unsigned = unsigned;
    return result;
}

Long$1.fromString = fromString;

function fromValue(val, unsigned) {
    if (typeof val === 'number')
        return fromNumber(val, unsigned);
    if (typeof val === 'string')
        return fromString(val, unsigned);
    return fromBits(val.low, val.high, typeof unsigned === 'boolean' ? unsigned : val.unsigned);
}

Long$1.fromValue = fromValue;

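// Illustrative: fromString('ff', true, 16).toNumber() -> 255, and
// fromString('-10', false, 2).toNumber() -> -2.
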
var TWO_PWR_16_DBL = 1 << 16;
var TWO_PWR_24_DBL = 1 << 24;
var TWO_PWR_32_DBL = TWO_PWR_16_DBL * TWO_PWR_16_DBL;
var TWO_PWR_64_DBL = TWO_PWR_32_DBL * TWO_PWR_32_DBL;
var TWO_PWR_63_DBL = TWO_PWR_64_DBL / 2;
var TWO_PWR_24 = fromInt(TWO_PWR_24_DBL);

var ZERO = fromInt(0);
Long$1.ZERO = ZERO;
var UZERO = fromInt(0, true);
Long$1.UZERO = UZERO;
var ONE = fromInt(1);
Long$1.ONE = ONE;
var UONE = fromInt(1, true);
Long$1.UONE = UONE;
var NEG_ONE = fromInt(-1);
Long$1.NEG_ONE = NEG_ONE;
var MAX_VALUE = fromBits(0xFFFFFFFF | 0, 0x7FFFFFFF | 0, false);
Long$1.MAX_VALUE = MAX_VALUE;
var MAX_UNSIGNED_VALUE = fromBits(0xFFFFFFFF | 0, 0xFFFFFFFF | 0, true);
Long$1.MAX_UNSIGNED_VALUE = MAX_UNSIGNED_VALUE;
var MIN_VALUE = fromBits(0, 0x80000000 | 0, false);
Long$1.MIN_VALUE = MIN_VALUE;

var LongPrototype = Long$1.prototype;

LongPrototype.toInt = function toInt() {
    return this.unsigned ? this.low >>> 0 : this.low;
};

LongPrototype.toNumber = function toNumber() {
    if (this.unsigned)
        return ((this.high >>> 0) * TWO_PWR_32_DBL) + (this.low >>> 0);
    return this.high * TWO_PWR_32_DBL + (this.low >>> 0);
};

LongPrototype.toString = function toString(radix) {
    radix = radix || 10;
    if (radix < 2 || 36 < radix)
        throw RangeError('radix');
    if (this.isZero())
        return '0';
    if (this.isNegative()) {
        if (this.eq(MIN_VALUE)) {
            var radixLong = fromNumber(radix),
                div = this.div(radixLong),
                rem1 = div.mul(radixLong).sub(this);
            return div.toString(radix) + rem1.toInt().toString(radix);
        } else
            return '-' + this.neg().toString(radix);
    }
    var radixToPower = fromNumber(pow_dbl(radix, 6), this.unsigned),
        rem = this;
    var result = '';
    while (true) {
        var remDiv = rem.div(radixToPower),
            intval = rem.sub(remDiv.mul(radixToPower)).toInt() >>> 0,
            digits = intval.toString(radix);
        rem = remDiv;
        if (rem.isZero())
            return digits + result;
        else {
            while (digits.length < 6)
                digits = '0' + digits;
            result = '' + digits + result;
        }
    }
};

LongPrototype.getHighBits = function getHighBits() {
    return this.high;
};

LongPrototype.getHighBitsUnsigned = function getHighBitsUnsigned() {
    return this.high >>> 0;
};

LongPrototype.getLowBits = function getLowBits() {
    return this.low;
};

LongPrototype.getLowBitsUnsigned = function getLowBitsUnsigned() {
    return this.low >>> 0;
};

LongPrototype.getNumBitsAbs = function getNumBitsAbs() {
    if (this.isNegative())
        return this.eq(MIN_VALUE) ? 64 : this.neg().getNumBitsAbs();
    var val = this.high != 0 ? this.high : this.low;
    for (var bit = 31; bit > 0; bit--)
        if ((val & (1 << bit)) != 0)
            break;
    return this.high != 0 ? bit + 33 : bit + 1;
};

LongPrototype.isZero = function isZero() {
    return this.high === 0 && this.low === 0;
};

LongPrototype.eqz = LongPrototype.isZero;

LongPrototype.isNegative = function isNegative() {
    return !this.unsigned && this.high < 0;
};

LongPrototype.isPositive = function isPositive() {
    return this.unsigned || this.high >= 0;
};

LongPrototype.isOdd = function isOdd() {
    return (this.low & 1) === 1;
};

LongPrototype.isEven = function isEven() {
    return (this.low & 1) === 0;
};

LongPrototype.equals = function equals(other) {
    if (!isLong(other))
        other = fromValue(other);
    if (this.unsigned !== other.unsigned && (this.high >>> 31) === 1 && (other.high >>> 31) === 1)
        return false;
    return this.high === other.high && this.low === other.low;
};

LongPrototype.eq = LongPrototype.equals;

LongPrototype.notEquals = function notEquals(other) {
    return !this.eq(other);
};

LongPrototype.neq = LongPrototype.notEquals;
LongPrototype.ne = LongPrototype.notEquals;

LongPrototype.lessThan = function lessThan(other) {
    return this.comp(other) < 0;
};

LongPrototype.lt = LongPrototype.lessThan;

LongPrototype.lessThanOrEqual = function lessThanOrEqual(other) {
    return this.comp(other) <= 0;
};

LongPrototype.lte = LongPrototype.lessThanOrEqual;
LongPrototype.le = LongPrototype.lessThanOrEqual;

LongPrototype.greaterThan = function greaterThan(other) {
    return this.comp(other) > 0;
};

LongPrototype.gt = LongPrototype.greaterThan;

LongPrototype.greaterThanOrEqual = function greaterThanOrEqual(other) {
    return this.comp(other) >= 0;
};

LongPrototype.gte = LongPrototype.greaterThanOrEqual;
LongPrototype.ge = LongPrototype.greaterThanOrEqual;

LongPrototype.compare = function compare(other) {
    if (!isLong(other))
        other = fromValue(other);
    if (this.eq(other))
        return 0;
    var thisNeg = this.isNegative(),
        otherNeg = other.isNegative();
    if (thisNeg && !otherNeg)
        return -1;
    if (!thisNeg && otherNeg)
        return 1;
    if (!this.unsigned)
        return this.sub(other).isNegative() ? -1 : 1;
    return (other.high >>> 0) > (this.high >>> 0) || (other.high === this.high && (other.low >>> 0) > (this.low >>> 0)) ? -1 : 1;
};

LongPrototype.comp = LongPrototype.compare;

LongPrototype.negate = function negate() {
    if (!this.unsigned && this.eq(MIN_VALUE))
        return MIN_VALUE;
    return this.not().add(ONE);
};

LongPrototype.neg = LongPrototype.negate;

LongPrototype.add = function add(addend) {
    if (!isLong(addend))
        addend = fromValue(addend);
    var a48 = this.high >>> 16;
    var a32 = this.high & 0xFFFF;
    var a16 = this.low >>> 16;
    var a00 = this.low & 0xFFFF;
    var b48 = addend.high >>> 16;
    var b32 = addend.high & 0xFFFF;
    var b16 = addend.low >>> 16;
    var b00 = addend.low & 0xFFFF;
    var c48 = 0, c32 = 0, c16 = 0, c00 = 0;
    c00 += a00 + b00;
    c16 += c00 >>> 16;
    c00 &= 0xFFFF;
    c16 += a16 + b16;
    c32 += c16 >>> 16;
    c16 &= 0xFFFF;
    c32 += a32 + b32;
    c48 += c32 >>> 16;
    c32 &= 0xFFFF;
    c48 += a48 + b48;
    c48 &= 0xFFFF;
    return fromBits((c16 << 16) | c00, (c48 << 16) | c32, this.unsigned);
};

LongPrototype.subtract = function subtract(subtrahend) {
    if (!isLong(subtrahend))
        subtrahend = fromValue(subtrahend);
    return this.add(subtrahend.neg());
};

LongPrototype.sub = LongPrototype.subtract;

LongPrototype.multiply = function multiply(multiplier) {
    if (this.isZero())
        return ZERO;
    if (!isLong(multiplier))
        multiplier = fromValue(multiplier);
    if (wasm) {
        var low = wasm.mul(this.low,
            this.high,
            multiplier.low,
            multiplier.high);
        return fromBits(low, wasm.get_high(), this.unsigned);
    }
    if (multiplier.isZero())
        return ZERO;
    if (this.eq(MIN_VALUE))
        return multiplier.isOdd() ? MIN_VALUE : ZERO;
    if (multiplier.eq(MIN_VALUE))
        return this.isOdd() ? MIN_VALUE : ZERO;
    if (this.isNegative()) {
        if (multiplier.isNegative())
            return this.neg().mul(multiplier.neg());
        else
            return this.neg().mul(multiplier).neg();
    } else if (multiplier.isNegative())
        return this.mul(multiplier.neg()).neg();
    if (this.lt(TWO_PWR_24) && multiplier.lt(TWO_PWR_24))
        return fromNumber(this.toNumber() * multiplier.toNumber(), this.unsigned);
    var a48 = this.high >>> 16;
    var a32 = this.high & 0xFFFF;
    var a16 = this.low >>> 16;
    var a00 = this.low & 0xFFFF;
    var b48 = multiplier.high >>> 16;
    var b32 = multiplier.high & 0xFFFF;
    var b16 = multiplier.low >>> 16;
    var b00 = multiplier.low & 0xFFFF;
    var c48 = 0, c32 = 0, c16 = 0, c00 = 0;
    c00 += a00 * b00;
    c16 += c00 >>> 16;
    c00 &= 0xFFFF;
    c16 += a16 * b00;
    c32 += c16 >>> 16;
    c16 &= 0xFFFF;
    c16 += a00 * b16;
    c32 += c16 >>> 16;
    c16 &= 0xFFFF;
    c32 += a32 * b00;
    c48 += c32 >>> 16;
    c32 &= 0xFFFF;
    c32 += a16 * b16;
    c48 += c32 >>> 16;
    c32 &= 0xFFFF;
    c32 += a00 * b32;
    c48 += c32 >>> 16;
    c32 &= 0xFFFF;
    c48 += a48 * b00 + a32 * b16 + a16 * b32 + a00 * b48;
    c48 &= 0xFFFF;
    return fromBits((c16 << 16) | c00, (c48 << 16) | c32, this.unsigned);
};

LongPrototype.mul = LongPrototype.multiply;

LongPrototype.divide = function divide(divisor) {
    if (!isLong(divisor))
        divisor = fromValue(divisor);
    if (divisor.isZero())
        throw Error('division by zero');
    if (wasm) {
        if (!this.unsigned &&
            this.high === -2147483648 &&
            divisor.low === -1 && divisor.high === -1) {
            return this;
        }
        var low = (this.unsigned ? wasm.div_u : wasm.div_s)(
            this.low,
            this.high,
            divisor.low,
            divisor.high
        );
        return fromBits(low, wasm.get_high(), this.unsigned);
    }
    if (this.isZero())
        return this.unsigned ? UZERO : ZERO;
    var approx, rem, res;
    if (!this.unsigned) {
        if (this.eq(MIN_VALUE)) {
            if (divisor.eq(ONE) || divisor.eq(NEG_ONE))
                return MIN_VALUE;
            else if (divisor.eq(MIN_VALUE))
                return ONE;
            else {
                var halfThis = this.shr(1);
                approx = halfThis.div(divisor).shl(1);
                if (approx.eq(ZERO)) {
                    return divisor.isNegative() ? ONE : NEG_ONE;
                } else {
                    rem = this.sub(divisor.mul(approx));
                    res = approx.add(rem.div(divisor));
                    return res;
                }
            }
        } else if (divisor.eq(MIN_VALUE))
            return this.unsigned ? UZERO : ZERO;
        if (this.isNegative()) {
            if (divisor.isNegative())
                return this.neg().div(divisor.neg());
            return this.neg().div(divisor).neg();
        } else if (divisor.isNegative())
            return this.div(divisor.neg()).neg();
        res = ZERO;
    } else {
        if (!divisor.unsigned)
            divisor = divisor.toUnsigned();
        if (divisor.gt(this))
            return UZERO;
        if (divisor.gt(this.shru(1)))
            return UONE;
        res = UZERO;
    }
    rem = this;
    while (rem.gte(divisor)) {
        approx = Math.max(1, Math.floor(rem.toNumber() / divisor.toNumber()));
        var log2 = Math.ceil(Math.log(approx) / Math.LN2),
            delta = (log2 <= 48) ? 1 : pow_dbl(2, log2 - 48),
            approxRes = fromNumber(approx),
            approxRem = approxRes.mul(divisor);
        while (approxRem.isNegative() || approxRem.gt(rem)) {
            approx -= delta;
            approxRes = fromNumber(approx, this.unsigned);
            approxRem = approxRes.mul(divisor);
        }
        if (approxRes.isZero())
            approxRes = ONE;
        res = res.add(approxRes);
        rem = rem.sub(approxRem);
    }
    return res;
};

LongPrototype.div = LongPrototype.divide;

LongPrototype.modulo = function modulo(divisor) {
    if (!isLong(divisor))
        divisor = fromValue(divisor);
    if (wasm) {
        var low = (this.unsigned ? wasm.rem_u : wasm.rem_s)(
            this.low,
            this.high,
            divisor.low,
            divisor.high
        );
        return fromBits(low, wasm.get_high(), this.unsigned);
    }
    return this.sub(this.div(divisor).mul(divisor));
};

LongPrototype.mod = LongPrototype.modulo;
LongPrototype.rem = LongPrototype.modulo;

LongPrototype.not = function not() {
    return fromBits(~this.low, ~this.high, this.unsigned);
};

LongPrototype.and = function and(other) {
    if (!isLong(other))
        other = fromValue(other);
    return fromBits(this.low & other.low, this.high & other.high, this.unsigned);
};

LongPrototype.or = function or(other) {
    if (!isLong(other))
        other = fromValue(other);
    return fromBits(this.low | other.low, this.high | other.high, this.unsigned);
};

LongPrototype.xor = function xor(other) {
    if (!isLong(other))
        other = fromValue(other);
    return fromBits(this.low ^ other.low, this.high ^ other.high, this.unsigned);
};

LongPrototype.shiftLeft = function shiftLeft(numBits) {
    if (isLong(numBits))
        numBits = numBits.toInt();
    if ((numBits &= 63) === 0)
        return this;
    else if (numBits < 32)
        return fromBits(this.low << numBits, (this.high << numBits) | (this.low >>> (32 - numBits)), this.unsigned);
    else
        return fromBits(0, this.low << (numBits - 32), this.unsigned);
};

LongPrototype.shl = LongPrototype.shiftLeft;

LongPrototype.shiftRight = function shiftRight(numBits) {
    if (isLong(numBits))
        numBits = numBits.toInt();
    if ((numBits &= 63) === 0)
        return this;
    else if (numBits < 32)
        return fromBits((this.low >>> numBits) | (this.high << (32 - numBits)), this.high >> numBits, this.unsigned);
    else
        return fromBits(this.high >> (numBits - 32), this.high >= 0 ? 0 : -1, this.unsigned);
};

LongPrototype.shr = LongPrototype.shiftRight;

LongPrototype.shiftRightUnsigned = function shiftRightUnsigned(numBits) {
    if (isLong(numBits))
        numBits = numBits.toInt();
    numBits &= 63;
    if (numBits === 0)
        return this;
    else {
        var high = this.high;
        if (numBits < 32) {
            var low = this.low;
            return fromBits((low >>> numBits) | (high << (32 - numBits)), high >>> numBits, this.unsigned);
        } else if (numBits === 32)
            return fromBits(high, 0, this.unsigned);
        else
            return fromBits(high >>> (numBits - 32), 0, this.unsigned);
    }
};

LongPrototype.shru = LongPrototype.shiftRightUnsigned;
LongPrototype.shr_u = LongPrototype.shiftRightUnsigned;

LongPrototype.toSigned = function toSigned() {
    if (!this.unsigned)
        return this;
    return fromBits(this.low, this.high, false);
};

LongPrototype.toUnsigned = function toUnsigned() {
    if (this.unsigned)
        return this;
    return fromBits(this.low, this.high, true);
};

LongPrototype.toBytes = function toBytes(le) {
    return le ? this.toBytesLE() : this.toBytesBE();
};

LongPrototype.toBytesLE = function toBytesLE() {
    var hi = this.high,
        lo = this.low;
    return [
        lo & 0xff,
        lo >>> 8 & 0xff,
        lo >>> 16 & 0xff,
        lo >>> 24,
        hi & 0xff,
        hi >>> 8 & 0xff,
        hi >>> 16 & 0xff,
        hi >>> 24
    ];
};

LongPrototype.toBytesBE = function toBytesBE() {
    var hi = this.high,
        lo = this.low;
    return [
        hi >>> 24,
        hi >>> 16 & 0xff,
        hi >>> 8 & 0xff,
        hi & 0xff,
        lo >>> 24,
        lo >>> 16 & 0xff,
        lo >>> 8 & 0xff,
        lo & 0xff
    ];
};

Long$1.fromBytes = function fromBytes(bytes, unsigned, le) {
    return le ? Long$1.fromBytesLE(bytes, unsigned) : Long$1.fromBytesBE(bytes, unsigned);
};

Long$1.fromBytesLE = function fromBytesLE(bytes, unsigned) {
    return new Long$1(
        bytes[0] |
        bytes[1] << 8 |
        bytes[2] << 16 |
        bytes[3] << 24,
        bytes[4] |
        bytes[5] << 8 |
        bytes[6] << 16 |
        bytes[7] << 24,
        unsigned
    );
};

Long$1.fromBytesBE = function fromBytesBE(bytes, unsigned) {
    return new Long$1(
        bytes[4] << 24 |
        bytes[5] << 16 |
        bytes[6] << 8 |
        bytes[7],
        bytes[0] << 24 |
        bytes[1] << 16 |
        bytes[2] << 8 |
        bytes[3],
        unsigned
    );
};

var long$1 = getDefaultExportFromCjs(long);

var LongExports = _mergeNamespaces({
    __proto__: null,
    default: long$1
}, [long]);

const Long =
    long$1 || LongExports;

function hexToLong(hex) {
    return Long.fromString(hex, true, 16);
}

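// The constants and helpers below implement a 64-bit string fingerprint in the
// style of FarmHash's Fingerprint64 (k0, k1, k2 match the FarmHash magic
// constants); tfjs-core uses this kind of hash for string tensors, e.g. in the
// StringToHashBucketFast kernel.
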
const k0 = hexToLong('c3a5c85c97cb3127');
const k1 = hexToLong('b492b66fbe98f273');
const k2 = hexToLong('9ae16a3b2f90404f');

function shiftMix(val) {
    return val.xor(val.shru(47));
}

function fetch(s, offset, numBytes) {
    const bytes = s.slice(offset, offset + numBytes);
    return Long.fromBytes(Array.from(bytes), true, true);
}

function fetch64(s, offset) {
    return fetch(s, offset, 8);
}

function fetch32(s, offset) {
    return fetch(s, offset, 4);
}

function rotate64(val, shift) {
    return shift === 0 ? val : val.shru(shift).or(val.shl(64 - shift));
}

function hashLen16(u, v, mul = hexToLong('9ddfea08eb382d69')) {
    let a = u.xor(v).mul(mul);
    a = a.xor(a.shru(47));
    let b = v.xor(a).mul(mul);
    b = b.xor(b.shru(47));
    b = b.mul(mul);
    return b;
}

function weakHashLen32WithSeeds(w, x, y, z, a, b) {
    a = a.add(w);
    b = rotate64(b.add(a).add(z), 21);
    const c = a;
    a = a.add(x);
    a = a.add(y);
    b = b.add(rotate64(a, 44));
    return [a.add(z), b.add(c)];
}

function weakHashLen32WithSeedsStr(s, offset, a, b) {
    return weakHashLen32WithSeeds(fetch64(s, offset), fetch64(s, offset + 8), fetch64(s, offset + 16), fetch64(s, offset + 24), a, b);
}

function hashLen0to16(s, len = s.length) {
    if (len >= 8) {
        const mul = k2.add(len * 2);
        const a = fetch64(s, 0).add(k2);
        const b = fetch64(s, len - 8);
        const c = rotate64(b, 37).mul(mul).add(a);
        const d = rotate64(a, 25).add(b).mul(mul);
        return hashLen16(c, d, mul);
    }
    if (len >= 4) {
        const mul = k2.add(len * 2);
        const a = fetch32(s, 0);
        return hashLen16(a.shl(3).add(len), fetch32(s, len - 4), mul);
    }
    if (len > 0) {
        const a = s[0];
        const b = s[len >> 1];
        const c = s[len - 1];
        const y = a + (b << 8);
        const z = len + (c << 2);
        return shiftMix(k2.mul(y).xor(k0.mul(z))).mul(k2);
    }
    return k2;
}

function hashLen17to32(s, len = s.length) {
    const mul = k2.add(len * 2);
    const a = fetch64(s, 0).mul(k1);
    const b = fetch64(s, 8);
    const c = fetch64(s, len - 8).mul(mul);
    const d = fetch64(s, len - 16).mul(k2);
    return hashLen16(rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d), a.add(rotate64(b.add(k2), 18)).add(c), mul);
}

function hashLen33to64(s, len = s.length) {
    const mul = k2.add(len * 2);
    const a = fetch64(s, 0).mul(k2);
    const b = fetch64(s, 8);
    const c = fetch64(s, len - 8).mul(mul);
    const d = fetch64(s, len - 16).mul(k2);
    const y = rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d);
    const z = hashLen16(y, a.add(rotate64(b.add(k2), 18)).add(c), mul);
    const e = fetch64(s, 16).mul(mul);
    const f = fetch64(s, 24);
    const g = y.add(fetch64(s, len - 32)).mul(mul);
    const h = z.add(fetch64(s, len - 24)).mul(mul);
    return hashLen16(rotate64(e.add(f), 43).add(rotate64(g, 30)).add(h), e.add(rotate64(f.add(a), 18)).add(g), mul);
}

function fingerPrint64(s, len = s.length) {
    const seed = Long.fromNumber(81, true);
    if (len <= 32) {
        if (len <= 16) {
            return hashLen0to16(s, len);
        }
        else {
            return hashLen17to32(s, len);
        }
    }
    else if (len <= 64) {
        return hashLen33to64(s, len);
    }
    let x = seed;
    let y = seed.mul(k1).add(113);
    let z = shiftMix(y.mul(k2).add(113)).mul(k2);
    let v = [Long.UZERO, Long.UZERO];
    let w = [Long.UZERO, Long.UZERO];
    x = x.mul(k2).add(fetch64(s, 0));
    let offset = 0;
    const end = ((len - 1) >> 6) * 64;
    const last64 = end + ((len - 1) & 63) - 63;
    do {
        x = rotate64(x.add(y).add(v[0]).add(fetch64(s, offset + 8)), 37).mul(k1);
        y = rotate64(y.add(v[1]).add(fetch64(s, offset + 48)), 42).mul(k1);
        x = x.xor(w[1]);
        y = y.add(v[0]).add(fetch64(s, offset + 40));
        z = rotate64(z.add(w[0]), 33).mul(k1);
        v = weakHashLen32WithSeedsStr(s, offset, v[1].mul(k1), x.add(w[0]));
        w = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w[1]), y.add(fetch64(s, offset + 16)));
        [z, x] = [x, z];
        offset += 64;
    } while (offset !== end);
    const mul = k1.add(z.and(0xff).shl(1));
    offset = last64;
    w[0] = w[0].add((len - 1) & 63);
    v[0] = v[0].add(w[0]);
    w[0] = w[0].add(v[0]);
    x = rotate64(x.add(y).add(v[0]).add(fetch64(s, offset + 8)), 37).mul(mul);
    y = rotate64(y.add(v[1]).add(fetch64(s, offset + 48)), 42).mul(mul);
    x = x.xor(w[1].mul(9));
    y = y.add(v[0].mul(9).add(fetch64(s, offset + 40)));
    z = rotate64(z.add(w[0]), 33).mul(mul);
    v = weakHashLen32WithSeedsStr(s, offset, v[1].mul(mul), x.add(w[0]));
    w = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w[1]), y.add(fetch64(s, offset + 16)));
    [z, x] = [x, z];
    return hashLen16(hashLen16(v[0], w[0], mul).add(shiftMix(y).mul(k0)).add(z), hashLen16(v[1], w[1], mul).add(x), mul);
}

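// Illustrative: fingerPrint64 takes a byte array (e.g. a UTF-8 encoded string)
// and returns an unsigned 64-bit Long, e.g.
// fingerPrint64(new TextEncoder().encode('hello')).toString(10).
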
function createScalarValue(value, dtype) {
    if (dtype === 'string') {
        return encodeString(value);
    }
    return toTypedArray([value], dtype);
}

function noConversionNeeded(a, dtype) {
    return (a instanceof Float32Array && dtype === 'float32') ||
        (a instanceof Int32Array && dtype === 'int32') ||
        (a instanceof Uint8Array && dtype === 'bool');
}

function toTypedArray(a, dtype) {
    if (dtype === 'string') {
        throw new Error('Cannot convert a string[] to a TypedArray');
    }
    if (Array.isArray(a)) {
        a = flatten$1(a);
    }
    if (env().getBool('DEBUG')) {
        checkConversionForErrors(a, dtype);
    }
    if (noConversionNeeded(a, dtype)) {
        return a;
    }
    if (dtype == null || dtype === 'float32' || dtype === 'complex64') {
        return new Float32Array(a);
    }
    else if (dtype === 'int32') {
        return new Int32Array(a);
    }
    else if (dtype === 'bool') {
        const bool = new Uint8Array(a.length);
        for (let i = 0; i < bool.length; ++i) {
            if (Math.round(a[i]) !== 0) {
                bool[i] = 1;
            }
        }
        return bool;
    }
    else {
        throw new Error(`Unknown data type ${dtype}`);
    }
}

function now() {
    return env().platform.now();
}

function encodeString(s, encoding = 'utf-8') {
    encoding = encoding || 'utf-8';
    return env().platform.encode(s, encoding);
}

function decodeString(bytes, encoding = 'utf-8') {
    encoding = encoding || 'utf-8';
    return env().platform.decode(bytes, encoding);
}

function isTypedArray(a) {
    if (env().platform.isTypedArray != null) {
        return env().platform.isTypedArray(a);
    }
    else {
        return isTypedArrayBrowser(a);
    }
}

function flatten$1(arr, result = [], skipTypedArray = false) {
|
|
if (result == null) {
|
|
result = [];
|
|
}
|
|
if (typeof arr === 'boolean' || typeof arr === 'number' ||
|
|
typeof arr === 'string' || isPromise(arr) || arr == null ||
|
|
isTypedArray(arr) && skipTypedArray) {
|
|
result.push(arr);
|
|
}
|
|
else if (Array.isArray(arr) || isTypedArray(arr)) {
|
|
for (let i = 0; i < arr.length; ++i) {
|
|
flatten$1(arr[i], result, skipTypedArray);
|
|
}
|
|
}
|
|
else {
|
|
let maxIndex = -1;
|
|
for (const key of Object.keys(arr)) {
|
|
|
|
if (/^([1-9]+[0-9]*|0)$/.test(key)) {
|
|
maxIndex = Math.max(maxIndex, Number(key));
|
|
}
|
|
}
|
|
for (let i = 0; i <= maxIndex; i++) {
|
|
|
|
flatten$1(arr[i], result, skipTypedArray);
|
|
}
|
|
}
|
|
return result;
|
|
}
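/*
 * Illustrative examples (not part of the bundle):
 *
 *   flatten$1([[1, 2], [3, [4, 5]]]);                // -> [1, 2, 3, 4, 5]
 *   flatten$1([new Float32Array([1, 2])], [], true);
 *   // -> [Float32Array [1, 2]]  (skipTypedArray keeps typed arrays whole)
 *
 * Array-like objects with numeric keys are walked up to their largest index.
 */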

class Profiler {
    constructor(backendTimer, logger) {
        this.backendTimer = backendTimer;
        this.logger = logger;
        if (logger == null) {
            this.logger = new Logger();
        }
    }
    profileKernel(kernelName, inputs, f) {
        let outputs;
        const holdResultWrapperFn = () => {
            outputs = f();
        };
        let timer;
        const start = now();
        if (this.backendTimer.timerAvailable()) {
            timer = this.backendTimer.time(holdResultWrapperFn);
        }
        else {
            holdResultWrapperFn();
            for (const output of outputs) {
                output.dataSync();
            }
            timer = Promise.resolve({ kernelMs: now() - start });
        }
        if (env().getBool('CHECK_COMPUTATION_FOR_ERRORS')) {
            for (let i = 0; i < outputs.length; i++) {
                const output = outputs[i];
                output.data().then(tensorVals => {
                    checkComputationForErrors(tensorVals, output.dtype, kernelName);
                });
            }
        }
        const kernelProfile = {
            kernelName,
            outputs,
            inputs,
            timeMs: timer.then(timing => timing.kernelMs),
            extraInfo: timer.then(timing => timing.getExtraProfileInfo != null ?
                timing.getExtraProfileInfo() :
                '')
        };
        return kernelProfile;
    }
    logKernelProfile(kernelProfile) {
        const { kernelName, outputs, timeMs, inputs, extraInfo } = kernelProfile;
        outputs.forEach(result => {
            Promise.all([result.data(), timeMs, extraInfo]).then(valueContainer => {
                this.logger.logKernelProfile(kernelName, result, valueContainer[0], valueContainer[1], inputs, valueContainer[2]);
            });
        });
    }
}
function checkComputationForErrors(vals, dtype, kernelName) {
    if (dtype !== 'float32') {
        return false;
    }
    for (let i = 0; i < vals.length; i++) {
        const num = vals[i];
        if (isNaN(num) || !isFinite(num)) {
            console.warn(`Found ${num} in the result of '${kernelName}'`);
            return true;
        }
    }
    return false;
}
class Logger {
    logKernelProfile(name, result, vals, timeMs, inputs, extraInfo) {
        const time = typeof timeMs === 'number' ? rightPad(`${timeMs}ms`, 9) :
            timeMs['error'];
        const paddedName = rightPad(name, 25);
        const rank = result.rank;
        const size = result.size;
        const shape = rightPad(result.shape.toString(), 14);
        let inputShapesDescription = '';
        for (const name in inputs) {
            const input = inputs[name];
            if (input != null) {
                const inputShape = input.shape || result.shape;
                const inputRank = inputShape.length;
                inputShapesDescription +=
                    `${name}: ${inputRank}D ${inputRank > 0 ? inputShape : ''} `;
            }
        }
        console.log(`%c${paddedName}\t%c${time}\t%c${rank}D ${shape}\t%c${size}\t%c${inputShapesDescription}\t%c${extraInfo}`, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue');
    }
}
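/*
 * Hedged note (not part of the bundle): the Profiler/Logger pair above only
 * runs while the engine is profiling or the 'DEBUG' flag is set; every kernel
 * then logs one line (name, kernel time, output rank/shape/size, input
 * shapes). A sketch of turning it on via the env helper used in this file:
 *
 *   env().set('DEBUG', true);  // warns that every result is checked for NaNs
 */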

function getFilteredNodesXToY(tape, xs, y) {
    const tensorsFromX = {};
    const nodesFromX = {};
    for (let i = 0; i < xs.length; i++) {
        tensorsFromX[xs[i].id] = true;
    }
    for (let i = 0; i < tape.length; i++) {
        const node = tape[i];
        const nodeInputs = node.inputs;
        for (const inputName in nodeInputs) {
            const input = nodeInputs[inputName];
            let anyInputFromX = false;
            for (let j = 0; j < xs.length; j++) {
                if (tensorsFromX[input.id]) {
                    node.outputs.forEach(output => tensorsFromX[output.id] = true);
                    anyInputFromX = true;
                    nodesFromX[node.id] = true;
                    break;
                }
            }
            if (anyInputFromX) {
                break;
            }
        }
    }
    const tensorsLeadToY = {};
    tensorsLeadToY[y.id] = true;
    const nodesToY = {};
    for (let i = tape.length - 1; i >= 0; i--) {
        const node = tape[i];
        const nodeInputs = node.inputs;
        for (let j = 0; j < node.outputs.length; j++) {
            if (tensorsLeadToY[node.outputs[j].id]) {
                for (const inputName in nodeInputs) {
                    tensorsLeadToY[nodeInputs[inputName].id] = true;
                    nodesToY[node.id] = true;
                }
                break;
            }
        }
    }
    const filteredTape = [];
    for (let i = 0; i < tape.length; i++) {
        const node = tape[i];
        if (nodesFromX[node.id] && nodesToY[node.id]) {
            const prunedInputs = {};
            for (const inputName in node.inputs) {
                const nodeInput = node.inputs[inputName];
                if (tensorsFromX[nodeInput.id]) {
                    prunedInputs[inputName] = nodeInput;
                }
            }
            const prunedNode = Object.assign({}, node);
            prunedNode.inputs = prunedInputs;
            prunedNode.outputs = node.outputs;
            filteredTape.push(prunedNode);
        }
    }
    return filteredTape;
}

function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy, add) {
    for (let i = filteredTape.length - 1; i >= 0; i--) {
        const node = filteredTape[i];
        const dys = [];
        node.outputs.forEach(o => {
            const gradTensor = tensorAccumulatedGradientMap[o.id];
            if (gradTensor != null) {
                dys.push(gradTensor);
            }
            else {
                dys.push(null);
            }
        });
        if (node.gradient == null) {
            throw new Error(`Cannot compute gradient: gradient function not found ` +
                `for ${node.kernelName}.`);
        }
        const inputGradients = node.gradient(dys);
        for (const inputName in node.inputs) {
            if (!(inputName in inputGradients)) {
                throw new Error(`Cannot backprop through input ${inputName}. ` +
                    `Available gradients found: ${Object.keys(inputGradients)}.`);
            }
            const dx = tidy(() => inputGradients[inputName]());
            if (dx.dtype !== 'float32') {
                throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` +
                    `${inputName} must have 'float32' dtype, but has '${dx.dtype}'`);
            }
            const x = node.inputs[inputName];
            if (!arraysEqual(dx.shape, x.shape)) {
                throw new Error(`Error in gradient for op ${node.kernelName}. The gradient of input ` +
                    `'${inputName}' has shape '${dx.shape}', which does not match ` +
                    `the shape of the input '${x.shape}'`);
            }
            if (tensorAccumulatedGradientMap[x.id] == null) {
                tensorAccumulatedGradientMap[x.id] = dx;
            }
            else {
                const curGradient = tensorAccumulatedGradientMap[x.id];
                tensorAccumulatedGradientMap[x.id] = add(curGradient, dx);
                curGradient.dispose();
            }
        }
    }
}
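/*
 * Hedged note (not part of the bundle): the reverse pass above is plain
 * reverse-mode autodiff. For each tape node with inputs x and outputs y_k it
 * evaluates, per input,
 *
 *   dL/dx = sum_k (dL/dy_k) * (dy_k/dx)
 *
 * where node.gradient(dys) returns one lazy thunk per input; contributions
 * from multiple consumers are combined with `add` and the stale accumulator
 * is disposed to keep backend memory flat.
 */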

const FORMAT_LIMIT_NUM_VALS = 20;

const FORMAT_NUM_FIRST_LAST_VALS = 3;

const FORMAT_NUM_SIG_DIGITS = 7;
function tensorToString(vals, shape, dtype, verbose) {
    const strides = computeStrides(shape);
    const padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides);
    const rank = shape.length;
    const valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol);
    const lines = ['Tensor'];
    if (verbose) {
        lines.push(`  dtype: ${dtype}`);
        lines.push(`  rank: ${rank}`);
        lines.push(`  shape: [${shape}]`);
        lines.push(`  values:`);
    }
    lines.push(valsLines.map(l => '    ' + l).join('\n'));
    return lines.join('\n');
}
function computeMaxSizePerColumn(vals, shape, dtype, strides) {
    const n = sizeFromShape(shape);
    const numCols = strides[strides.length - 1];
    const padPerCol = new Array(numCols).fill(0);
    const rank = shape.length;
    const valuesOrTuples = dtype === 'complex64' ? createComplexTuples(vals) : vals;
    if (rank > 1) {
        for (let row = 0; row < n / numCols; row++) {
            const offset = row * numCols;
            for (let j = 0; j < numCols; j++) {
                padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length);
            }
        }
    }
    return padPerCol;
}
function valToString(val, pad, dtype) {
    let valStr;
    if (Array.isArray(val)) {
        valStr = `${parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS))} + ` +
            `${parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS))}j`;
    }
    else if (isString(val)) {
        valStr = `'${val}'`;
    }
    else if (dtype === 'bool') {
        valStr = boolNumToString(val);
    }
    else {
        valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString();
    }
    return rightPad(valStr, pad);
}
function boolNumToString(v) {
    return v === 0 ? 'false' : 'true';
}
function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast = true) {
    const storagePerElement = dtype === 'complex64' ? 2 : 1;
    const size = shape[0];
    const rank = shape.length;
    if (rank === 0) {
        if (dtype === 'complex64') {
            const complexTuple = createComplexTuples(vals);
            return [valToString(complexTuple[0], 0, dtype)];
        }
        if (dtype === 'bool') {
            return [boolNumToString(vals[0])];
        }
        return [vals[0].toString()];
    }
    if (rank === 1) {
        if (size > FORMAT_LIMIT_NUM_VALS) {
            const firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement;
            let firstVals = Array.from(vals.slice(0, firstValsSize));
            let lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement));
            if (dtype === 'complex64') {
                firstVals = createComplexTuples(firstVals);
                lastVals = createComplexTuples(lastVals);
            }
            return [
                '[' +
                    firstVals.map((x, i) => valToString(x, padPerCol[i], dtype))
                        .join(', ') +
                    ', ..., ' +
                    lastVals
                        .map((x, i) => valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype))
                        .join(', ') +
                    ']'
            ];
        }
        const displayVals = dtype === 'complex64' ? createComplexTuples(vals) :
            Array.from(vals);
        return [
            '[' +
                displayVals.map((x, i) => valToString(x, padPerCol[i], dtype))
                    .join(', ') +
                ']'
        ];
    }
    const subshape = shape.slice(1);
    const substrides = strides.slice(1);
    const stride = strides[0] * storagePerElement;
    const lines = [];
    if (size > FORMAT_LIMIT_NUM_VALS) {
        for (let i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) {
            const start = i * stride;
            const end = start + stride;
            lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false));
        }
        lines.push('...');
        for (let i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) {
            const start = i * stride;
            const end = start + stride;
            lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1));
        }
    }
    else {
        for (let i = 0; i < size; i++) {
            const start = i * stride;
            const end = start + stride;
            lines.push(...subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1));
        }
    }
    const sep = rank === 2 ? ',' : '';
    lines[0] = '[' + (size > 0 ? lines[0] + sep : '');
    for (let i = 1; i < lines.length - 1; i++) {
        lines[i] = ' ' + lines[i] + sep;
    }
    let newLineSep = ',\n';
    for (let i = 2; i < rank; i++) {
        newLineSep += '\n';
    }
    lines[lines.length - 1] =
        ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep);
    return lines;
}
function createComplexTuples(vals) {
    const complexTuples = [];
    for (let i = 0; i < vals.length; i += 2) {
        complexTuples.push([vals[i], vals[i + 1]]);
    }
    return complexTuples;
}

class TensorBuffer {
    constructor(shape, dtype, values) {
        this.dtype = dtype;
        this.shape = shape.slice();
        this.size = sizeFromShape(shape);
        if (values != null) {
            const n = values.length;
            assert$1(n === this.size, () => `Length of values '${n}' does not match the size ` +
                `inferred by the shape '${this.size}'.`);
        }
        if (dtype === 'complex64') {
            throw new Error(`complex64 dtype TensorBuffers are not supported. Please create ` +
                `a TensorBuffer for the real and imaginary parts separately and ` +
                `call tf.complex(real, imag).`);
        }
        this.values = values || getArrayFromDType(dtype, this.size);
        this.strides = computeStrides(shape);
    }
    set(value, ...locs) {
        if (locs.length === 0) {
            locs = [0];
        }
        assert$1(locs.length === this.rank, () => `The number of provided coordinates (${locs.length}) must ` +
            `match the rank (${this.rank})`);
        const index = this.locToIndex(locs);
        this.values[index] = value;
    }
    get(...locs) {
        if (locs.length === 0) {
            locs = [0];
        }
        let i = 0;
        for (const loc of locs) {
            if (loc < 0 || loc >= this.shape[i]) {
                const msg = `Requested out of range element at ${locs}. ` +
                    ` Buffer shape=${this.shape}`;
                throw new Error(msg);
            }
            i++;
        }
        let index = locs[locs.length - 1];
        for (let i = 0; i < locs.length - 1; ++i) {
            index += this.strides[i] * locs[i];
        }
        return this.values[index];
    }
    locToIndex(locs) {
        if (this.rank === 0) {
            return 0;
        }
        else if (this.rank === 1) {
            return locs[0];
        }
        let index = locs[locs.length - 1];
        for (let i = 0; i < locs.length - 1; ++i) {
            index += this.strides[i] * locs[i];
        }
        return index;
    }
    indexToLoc(index) {
        if (this.rank === 0) {
            return [];
        }
        else if (this.rank === 1) {
            return [index];
        }
        const locs = new Array(this.shape.length);
        for (let i = 0; i < locs.length - 1; ++i) {
            locs[i] = Math.floor(index / this.strides[i]);
            index -= locs[i] * this.strides[i];
        }
        locs[locs.length - 1] = index;
        return locs;
    }
    get rank() {
        return this.shape.length;
    }
    toTensor() {
        return trackerFn().makeTensor(this.values, this.shape, this.dtype);
    }
}
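/*
 * Illustrative usage sketch (not part of the bundle):
 *
 *   const buf = new TensorBuffer([2, 2], 'float32');
 *   buf.set(3, 0, 1);          // write row 0, column 1
 *   buf.get(0, 1);             // -> 3
 *   const t = buf.toTensor();  // hands the backing array to the tracked engine
 */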

let trackerFn = null;

let opHandler$1 = null;

function setTensorTracker(fn) {
    trackerFn = fn;
}

function setOpHandler(handler) {
    opHandler$1 = handler;
}

class Tensor {
    constructor(shape, dtype, dataId, id) {
        this.kept = false;
        this.isDisposedInternal = false;
        this.shape = shape.slice();
        this.dtype = dtype || 'float32';
        this.size = sizeFromShape(shape);
        this.strides = computeStrides(shape);
        this.dataId = dataId;
        this.id = id;
        this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher');
    }
    get rank() {
        return this.shape.length;
    }
    async buffer() {
        const vals = await this.data();
        return opHandler$1.buffer(this.shape, this.dtype, vals);
    }
    bufferSync() {
        return opHandler$1.buffer(this.shape, this.dtype, this.dataSync());
    }
    async array() {
        const vals = await this.data();
        return toNestedArray(this.shape, vals, this.dtype === 'complex64');
    }
    arraySync() {
        return toNestedArray(this.shape, this.dataSync(), this.dtype === 'complex64');
    }
    async data() {
        this.throwIfDisposed();
        const data = trackerFn().read(this.dataId);
        if (this.dtype === 'string') {
            const bytes = await data;
            try {
                return bytes.map(b => decodeString(b));
            }
            catch (_a) {
                throw new Error('Failed to decode the string bytes into utf-8. ' +
                    'To get the original bytes, call tensor.bytes().');
            }
        }
        return data;
    }
    dataToGPU(options) {
        this.throwIfDisposed();
        return trackerFn().readToGPU(this.dataId, options);
    }
    dataSync() {
        this.throwIfDisposed();
        const data = trackerFn().readSync(this.dataId);
        if (this.dtype === 'string') {
            try {
                return data.map(b => decodeString(b));
            }
            catch (_a) {
                throw new Error('Failed to decode the string bytes into utf-8. ' +
                    'To get the original bytes, call tensor.bytes().');
            }
        }
        return data;
    }
    async bytes() {
        this.throwIfDisposed();
        const data = await trackerFn().read(this.dataId);
        if (this.dtype === 'string') {
            return data;
        }
        else {
            return new Uint8Array(data.buffer);
        }
    }
    dispose() {
        if (this.isDisposed) {
            return;
        }
        if (this.kerasMask) {
            this.kerasMask.dispose();
        }
        trackerFn().disposeTensor(this);
        this.isDisposedInternal = true;
    }
    get isDisposed() {
        return this.isDisposedInternal;
    }
    throwIfDisposed() {
        if (this.isDisposed) {
            throw new Error(`Tensor is disposed.`);
        }
    }
    print(verbose = false) {
        return opHandler$1.print(this, verbose);
    }
    clone() {
        this.throwIfDisposed();
        return opHandler$1.clone(this);
    }
    toString(verbose = false) {
        const vals = this.dataSync();
        return tensorToString(vals, this.shape, this.dtype, verbose);
    }
    cast(dtype) {
        this.throwIfDisposed();
        return opHandler$1.cast(this, dtype);
    }
    variable(trainable = true, name, dtype) {
        this.throwIfDisposed();
        return trackerFn().makeVariable(this, trainable, name, dtype);
    }
}
Object.defineProperty(Tensor, Symbol.hasInstance, {
    value: (instance) => {
        return !!instance && instance.data != null && instance.dataSync != null &&
            instance.throwIfDisposed != null;
    }
});
function getGlobalTensorClass() {
    return getGlobal('Tensor', () => {
        return Tensor;
    });
}

getGlobalTensorClass();

class Variable extends Tensor {
    constructor(initialValue, trainable, name, tensorId) {
        super(initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId);
        this.trainable = trainable;
        this.name = name;
    }
    assign(newValue) {
        if (newValue.dtype !== this.dtype) {
            throw new Error(`dtype of the new value (${newValue.dtype}) and ` +
                `previous value (${this.dtype}) must match`);
        }
        if (!arraysEqual(newValue.shape, this.shape)) {
            throw new Error(`shape of the new value (${newValue.shape}) and ` +
                `previous value (${this.shape}) must match`);
        }
        trackerFn().disposeTensor(this);
        this.dataId = newValue.dataId;
        trackerFn().incRef(this, null);
    }
    dispose() {
        trackerFn().disposeVariable(this);
        this.isDisposedInternal = true;
    }
}
Object.defineProperty(Variable, Symbol.hasInstance, {
    value: (instance) => {
        return instance instanceof Tensor && instance.assign != null &&
            instance.assign instanceof Function;
    }
});

var Rank;
(function (Rank) {
    Rank["R0"] = "R0";
    Rank["R1"] = "R1";
    Rank["R2"] = "R2";
    Rank["R3"] = "R3";
    Rank["R4"] = "R4";
    Rank["R5"] = "R5";
    Rank["R6"] = "R6";
})(Rank || (Rank = {}));

var UpcastInt32AndMap;
(function (UpcastInt32AndMap) {
    UpcastInt32AndMap["float32"] = "float32";
    UpcastInt32AndMap["int32"] = "int32";
    UpcastInt32AndMap["bool"] = "int32";
    UpcastInt32AndMap["complex64"] = "complex64";
})(UpcastInt32AndMap || (UpcastInt32AndMap = {}));
var UpcastBoolAndMap;
(function (UpcastBoolAndMap) {
    UpcastBoolAndMap["float32"] = "float32";
    UpcastBoolAndMap["int32"] = "int32";
    UpcastBoolAndMap["bool"] = "bool";
    UpcastBoolAndMap["complex64"] = "complex64";
})(UpcastBoolAndMap || (UpcastBoolAndMap = {}));
var UpcastFloat32AndMap;
(function (UpcastFloat32AndMap) {
    UpcastFloat32AndMap["float32"] = "float32";
    UpcastFloat32AndMap["int32"] = "float32";
    UpcastFloat32AndMap["bool"] = "float32";
    UpcastFloat32AndMap["complex64"] = "complex64";
})(UpcastFloat32AndMap || (UpcastFloat32AndMap = {}));
var UpcastComplex64AndMap;
(function (UpcastComplex64AndMap) {
    UpcastComplex64AndMap["float32"] = "complex64";
    UpcastComplex64AndMap["int32"] = "complex64";
    UpcastComplex64AndMap["bool"] = "complex64";
    UpcastComplex64AndMap["complex64"] = "complex64";
})(UpcastComplex64AndMap || (UpcastComplex64AndMap = {}));
const upcastTypeMap = {
    'float32': UpcastFloat32AndMap,
    'int32': UpcastInt32AndMap,
    'bool': UpcastBoolAndMap,
    'complex64': UpcastComplex64AndMap
};
function upcastType(typeA, typeB) {
    if (typeA === 'string' || typeB === 'string') {
        if (typeA === 'string' && typeB === 'string') {
            return 'string';
        }
        throw new Error(`Can not upcast ${typeA} with ${typeB}`);
    }
    return upcastTypeMap[typeA][typeB];
}

function sumOutType(type) {
    return upcastType(type, 'int32');
}
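/*
 * Illustrative examples (not part of the bundle) of the upcast lattice above:
 *
 *   upcastType('bool', 'int32');      // -> 'int32'
 *   upcastType('int32', 'float32');   // -> 'float32'
 *   upcastType('bool', 'complex64');  // -> 'complex64'
 *   upcastType('string', 'int32');    // throws: strings only pair with strings
 *   sumOutType('bool');               // -> 'int32' (sums over bool widen)
 */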

function isWebGLData(values) {
    return values != null && typeof values === 'object' && 'texture' in values &&
        values.texture instanceof WebGLTexture;
}
function isWebGPUData(values) {
    return typeof GPUBuffer !== 'undefined' && values != null &&
        typeof values === 'object' && 'buffer' in values &&
        values.buffer instanceof GPUBuffer;
}

function makeTypesMatch(a, b) {
    if (a.dtype === b.dtype) {
        return [a, b];
    }
    const dtype = upcastType(a.dtype, b.dtype);
    return [a.cast(dtype), b.cast(dtype)];
}

function getTensorsInContainer(result) {
    const list = [];
    const seen = new Set();
    walkTensorContainer(result, list, seen);
    return list;
}
function walkTensorContainer(container, list, seen) {
    if (container == null) {
        return;
    }
    if (container instanceof Tensor) {
        list.push(container);
        return;
    }
    if (!isIterable(container)) {
        return;
    }
    const iterable = container;
    for (const k in iterable) {
        const val = iterable[k];
        if (!seen.has(val)) {
            seen.add(val);
            walkTensorContainer(val, list, seen);
        }
    }
}

function isIterable(obj) {
    return Array.isArray(obj) || typeof obj === 'object';
}

function isRegisteredKernelInvocation(kernelInvocation) {
    return kernelInvocation.kernelName != null;
}
class EngineState {
    constructor() {
        this.registeredVariables = {};
        this.nextTapeNodeId = 0;
        this.numBytes = 0;
        this.numTensors = 0;
        this.numStringTensors = 0;
        this.numDataBuffers = 0;
        this.gradientDepth = 0;
        this.kernelDepth = 0;
        this.scopeStack = [];
        this.numDataMovesStack = [];
        this.nextScopeId = 0;
        this.tensorInfo = new WeakMap();
        this.profiling = false;
        this.activeProfile = {
            newBytes: 0,
            newTensors: 0,
            peakBytes: 0,
            kernels: [],
            result: null,
            get kernelNames() {
                return Array.from(new Set(this.kernels.map(k => k.name)));
            }
        };
    }
    dispose() {
        for (const variableName in this.registeredVariables) {
            this.registeredVariables[variableName].dispose();
        }
    }
}
class Engine {
    constructor(ENV) {
        this.ENV = ENV;
        this.registry = {};
        this.registryFactory = {};
        this.pendingBackendInitId = 0;
        this.state = new EngineState();
    }
    async ready() {
        if (this.pendingBackendInit != null) {
            return this.pendingBackendInit.then(() => { });
        }
        if (this.backendInstance != null) {
            return;
        }
        const sortedBackends = this.getSortedBackends();
        for (let i = 0; i < sortedBackends.length; i++) {
            const backendName = sortedBackends[i];
            const success = await this.initializeBackend(backendName).success;
            if (success) {
                await this.setBackend(backendName);
                return;
            }
        }
        throw new Error(`Could not initialize any backends, all backend initializations ` +
            `failed.`);
    }
    get backend() {
        if (this.pendingBackendInit != null) {
            throw new Error(`Backend '${this.backendName}' has not yet been initialized. Make ` +
                `sure to await tf.ready() or await tf.setBackend() before calling ` +
                `other methods`);
        }
        if (this.backendInstance == null) {
            const { name, asyncInit } = this.initializeBackendsAndReturnBest();
            if (asyncInit) {
                throw new Error(`The highest priority backend '${name}' has not yet been ` +
                    `initialized. Make sure to await tf.ready() or ` +
                    `await tf.setBackend() before calling other methods`);
            }
            this.setBackend(name);
        }
        return this.backendInstance;
    }
    backendNames() {
        return Object.keys(this.registryFactory);
    }
    findBackend(backendName) {
        if (!(backendName in this.registry)) {
            if (backendName in this.registryFactory) {
                const { asyncInit } = this.initializeBackend(backendName);
                if (asyncInit) {
                    return null;
                }
            }
            else {
                return null;
            }
        }
        return this.registry[backendName];
    }
    findBackendFactory(backendName) {
        if (!(backendName in this.registryFactory)) {
            return null;
        }
        return this.registryFactory[backendName].factory;
    }
    registerBackend(backendName, factory, priority = 1) {
        if (backendName in this.registryFactory) {
            warn(`${backendName} backend was already registered. ` +
                `Reusing existing backend factory.`);
            return false;
        }
        this.registryFactory[backendName] = { factory, priority };
        return true;
    }
    async setBackend(backendName) {
        if (this.registryFactory[backendName] == null) {
            throw new Error(`Backend name '${backendName}' not found in registry`);
        }
        this.backendName = backendName;
        if (this.registry[backendName] == null) {
            this.backendInstance = null;
            const { success, asyncInit } = this.initializeBackend(backendName);
            const result = asyncInit ? await success : success;
            if (!result) {
                return false;
            }
        }
        this.backendInstance = this.registry[backendName];
        this.setupRegisteredKernels();
        this.profiler = new Profiler(this.backendInstance);
        return true;
    }
    setupRegisteredKernels() {
        const kernels = getKernelsForBackend(this.backendName);
        kernels.forEach(kernel => {
            if (kernel.setupFunc != null) {
                kernel.setupFunc(this.backendInstance);
            }
        });
    }
    disposeRegisteredKernels(backendName) {
        const kernels = getKernelsForBackend(backendName);
        kernels.forEach(kernel => {
            if (kernel.disposeFunc != null) {
                kernel.disposeFunc(this.registry[backendName]);
            }
        });
    }
    initializeBackend(backendName) {
        const registryFactoryEntry = this.registryFactory[backendName];
        if (registryFactoryEntry == null) {
            throw new Error(`Cannot initialize backend ${backendName}, no registration found.`);
        }
        try {
            const backend = registryFactoryEntry.factory();
            if (backend && !(backend instanceof KernelBackend) &&
                typeof backend.then === 'function') {
                const promiseId = ++this.pendingBackendInitId;
                const success = backend
                    .then(backendInstance => {
                        if (promiseId < this.pendingBackendInitId) {
                            return false;
                        }
                        this.registry[backendName] = backendInstance;
                        this.pendingBackendInit = null;
                        return true;
                    })
                    .catch(err => {
                        if (promiseId < this.pendingBackendInitId) {
                            return false;
                        }
                        this.pendingBackendInit = null;
                        warn(`Initialization of backend ${backendName} failed`);
                        warn(err.stack || err.message);
                        return false;
                    });
                this.pendingBackendInit = success;
                return { success, asyncInit: true };
            }
            else {
                this.registry[backendName] = backend;
                return { success: true, asyncInit: false };
            }
        }
        catch (err) {
            warn(`Initialization of backend ${backendName} failed`);
            warn(err.stack || err.message);
            return { success: false, asyncInit: false };
        }
    }
    removeBackend(backendName) {
        if (!(backendName in this.registryFactory)) {
            throw new Error(`${backendName} backend not found in registry`);
        }
        if (this.backendName === backendName && this.pendingBackendInit != null) {
            this.pendingBackendInitId++;
        }
        if (backendName in this.registry) {
            this.disposeRegisteredKernels(backendName);
            this.registry[backendName].dispose();
            delete this.registry[backendName];
        }
        delete this.registryFactory[backendName];
        if (this.backendName === backendName) {
            this.pendingBackendInit = null;
            this.backendName = null;
            this.backendInstance = null;
        }
    }
    getSortedBackends() {
        if (Object.keys(this.registryFactory).length === 0) {
            throw new Error('No backend found in registry.');
        }
        return Object.keys(this.registryFactory).sort((a, b) => {
            return this.registryFactory[b].priority -
                this.registryFactory[a].priority;
        });
    }
    initializeBackendsAndReturnBest() {
        const sortedBackends = this.getSortedBackends();
        for (let i = 0; i < sortedBackends.length; i++) {
            const backendName = sortedBackends[i];
            const { success, asyncInit } = this.initializeBackend(backendName);
            if (asyncInit || success) {
                return { name: backendName, asyncInit };
            }
        }
        throw new Error(`Could not initialize any backends, all backend initializations ` +
            `failed.`);
    }
    moveData(backend, dataId) {
        const info = this.state.tensorInfo.get(dataId);
        const srcBackend = info.backend;
        const values = this.readSync(dataId);
        const refCount = srcBackend.refCount(dataId);
        srcBackend.disposeData(dataId, true);
        info.backend = backend;
        backend.move(dataId, values, info.shape, info.dtype, refCount);
        if (this.shouldCheckForMemLeaks()) {
            this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
        }
    }
    tidy(nameOrFn, fn) {
        let name = null;
        if (fn == null) {
            if (typeof nameOrFn !== 'function') {
                throw new Error('Please provide a function to tidy()');
            }
            fn = nameOrFn;
        }
        else {
            if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
                throw new Error('When calling with two arguments, the first argument ' +
                    'to tidy() must be a string');
            }
            if (typeof fn !== 'function') {
                throw new Error('When calling with two arguments, the 2nd argument ' +
                    'to tidy() must be a function');
            }
            name = nameOrFn;
        }
        let result;
        return this.scopedRun(() => this.startScope(name), () => this.endScope(result), () => {
            result = fn();
            if (result instanceof Promise) {
                console.error('Cannot return a Promise inside of tidy.');
            }
            return result;
        });
    }
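    /*
     * Illustrative usage sketch (not part of the bundle; `someTensor` is
     * hypothetical). tidy() opens a scope, tracks every tensor made inside it,
     * and disposes the ones that do not escape via the return value:
     *
     *   const kept = ENGINE.tidy('example', () => {
     *     const tmp = someTensor.clone(); // tracked, disposed at scope end
     *     return tmp.clone();             // returned tensors survive
     *   });
     *
     * Returning a Promise is an error: scopes end synchronously.
     */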
    scopedRun(start, end, f) {
        start();
        try {
            const res = f();
            end();
            return res;
        }
        catch (ex) {
            end();
            throw ex;
        }
    }
    nextTensorId() {
        return Engine.nextTensorId++;
    }
    nextVariableId() {
        return Engine.nextVariableId++;
    }
    clone(x) {
        const y = ENGINE.runKernel(Identity$1, { x });
        const inputs = { x };
        const grad = (dy) => ({
            x: () => {
                const dtype = 'float32';
                const gradInputs = { x: dy };
                const attrs = { dtype };
                return ENGINE.runKernel(Cast, gradInputs, attrs);
            }
        });
        const saved = [];
        this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});
        return y;
    }
    runKernel(kernelName, inputs, attrs) {
        const hasKernel = getKernel(kernelName, this.backendName) != null;
        if (!hasKernel) {
            throw new Error(`Kernel '${kernelName}' not registered for backend '${this.backendName}'`);
        }
        return this.runKernelFunc({ kernelName, inputs, attrs });
    }
    shouldCheckForMemLeaks() {
        return this.ENV.getBool('IS_TEST');
    }
    checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos) {
        const numDataIdsAfter = this.backend.numDataIds();
        let numOutputDataIds = 0;
        outInfos.forEach(info => {
            numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1);
        });
        const numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1];
        const dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves;
        if (dataIdsLeaked > 0) {
            throw new Error(`Backend '${this.backendName}' has an internal memory leak ` +
                `(${dataIdsLeaked} data ids) after running '${kernelName}'`);
        }
    }
    runKernelFunc(kernelParams) {
        let outputs;
        let saved = [];
        const isTapeOn = this.isTapeOn();
        const startingBytecount = this.state.numBytes;
        const startingNumTensors = this.state.numTensors;
        if (this.shouldCheckForMemLeaks()) {
            this.state.numDataMovesStack.push(0);
        }
        let kernelFunc;
        let out;
        const kernelOrScopeName = isRegisteredKernelInvocation(kernelParams) ?
            kernelParams.kernelName :
            this.state.activeScope != null ? this.state.activeScope.name : '';
        if (isRegisteredKernelInvocation(kernelParams)) {
            const { kernelName, inputs, attrs } = kernelParams;
            const kernel = getKernel(kernelName, this.backendName);
            assert$1(kernel != null, () => `Cannot find registered kernel '${kernelName}' for backend '${this.backendName}'`);
            kernelFunc = () => {
                const numDataIdsBefore = this.backend.numDataIds();
                out = kernel.kernelFunc({ inputs, attrs, backend: this.backend });
                const outInfos = Array.isArray(out) ? out : [out];
                if (this.shouldCheckForMemLeaks()) {
                    this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
                }
                const outTensors = outInfos.map((outInfo) => {
                    if (outInfo.rank != null) {
                        return outInfo;
                    }
                    return this.makeTensorFromTensorInfo(outInfo);
                });
                if (isTapeOn) {
                    const tensorsToSave = this.getTensorsForGradient(kernelName, inputs, outTensors);
                    saved = this.saveTensorsForBackwardMode(tensorsToSave);
                }
                return outTensors;
            };
        }
        else {
            const { forwardFunc } = kernelParams;
            const saveFunc = (tensors) => {
                if (!isTapeOn) {
                    return;
                }
                saved = tensors.map(tensor => this.keep(this.clone(tensor)));
            };
            kernelFunc = () => {
                const numDataIdsBefore = this.backend.numDataIds();
                out = this.tidy(() => forwardFunc(this.backend, saveFunc));
                const outs = (Array.isArray(out) ? out : [out]);
                if (this.shouldCheckForMemLeaks()) {
                    this.checkKernelForMemLeak(kernelOrScopeName, numDataIdsBefore, outs);
                }
                return outs;
            };
        }
        const { inputs, attrs } = kernelParams;
        const backwardsFunc = isRegisteredKernelInvocation(kernelParams) ?
            null :
            kernelParams.backwardsFunc;
        let kernelProfile;
        this.scopedRun(
            () => this.state.kernelDepth++, () => this.state.kernelDepth--, () => {
                if (!this.ENV.getBool('DEBUG') && !this.state.profiling) {
                    outputs = kernelFunc();
                }
                else {
                    kernelProfile = this.profiler.profileKernel(kernelOrScopeName, inputs, () => kernelFunc());
                    if (this.ENV.getBool('DEBUG')) {
                        this.profiler.logKernelProfile(kernelProfile);
                    }
                    outputs = kernelProfile.outputs;
                }
            });
        if (isTapeOn) {
            this.addTapeNode(kernelOrScopeName, inputs, outputs, backwardsFunc, saved, attrs);
        }
        if (this.state.profiling) {
            this.state.activeProfile.kernels.push({
                name: kernelOrScopeName,
                bytesAdded: this.state.numBytes - startingBytecount,
                totalBytesSnapshot: this.state.numBytes,
                tensorsAdded: this.state.numTensors - startingNumTensors,
                totalTensorsSnapshot: this.state.numTensors,
                inputShapes: Object.keys(inputs).map(key => inputs[key] != null ? inputs[key].shape : null),
                outputShapes: outputs.map(item => item.shape),
                kernelTimeMs: kernelProfile.timeMs,
                extraInfo: kernelProfile.extraInfo
            });
        }
        return (Array.isArray(out) ? outputs : outputs[0]);
    }
    saveTensorsForBackwardMode(tensors) {
        const saved = tensors.map(tensor => this.keep(this.clone(tensor)));
        return saved;
    }
    getTensorsForGradient(kernelName, inputs, outputs) {
        const gradConfig = getGradient(kernelName);
        if (gradConfig != null) {
            const inputsToSave = gradConfig.inputsToSave || [];
            const outputsToSave = gradConfig.outputsToSave || [];
            let inputTensorsToSave;
            if (gradConfig.saveAllInputs) {
                assert$1(Array.isArray(inputs), () => 'saveAllInputs is true, expected inputs to be an array.');
                inputTensorsToSave = Object.keys(inputs).map((key) => inputs[key]);
            }
            else {
                inputTensorsToSave = inputsToSave.map((inputName) => inputs[inputName]);
            }
            const outputTensorsToSave = outputs.filter((_, i) => outputsToSave[i]);
            return inputTensorsToSave.concat(outputTensorsToSave);
        }
        return [];
    }
    makeTensor(values, shape, dtype, backend) {
        if (values == null) {
            throw new Error('Values passed to engine.makeTensor() are null');
        }
        dtype = dtype || 'float32';
        backend = backend || this.backend;
        let backendVals = values;
        if (dtype === 'string' && isString(values[0])) {
            backendVals = values.map(d => encodeString(d));
        }
        const dataId = backend.write(backendVals, shape, dtype);
        const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
        this.trackTensor(t, backend);
        if (dtype === 'string') {
            const info = this.state.tensorInfo.get(dataId);
            const newBytes = bytesFromStringArray(backendVals);
            this.state.numBytes += newBytes - info.bytes;
            info.bytes = newBytes;
        }
        return t;
    }
    makeTensorFromDataId(dataId, shape, dtype, backend) {
        dtype = dtype || 'float32';
        const tensorInfo = { dataId, shape, dtype };
        return this.makeTensorFromTensorInfo(tensorInfo, backend);
    }
    makeTensorFromTensorInfo(tensorInfo, backend) {
        const { dataId, shape, dtype } = tensorInfo;
        const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
        this.trackTensor(t, backend);
        return t;
    }
    makeVariable(initialValue, trainable = true, name, dtype) {
        name = name || this.nextVariableId().toString();
        if (dtype != null && dtype !== initialValue.dtype) {
            initialValue = initialValue.cast(dtype);
        }
        const v = new Variable(initialValue, trainable, name, this.nextTensorId());
        if (this.state.registeredVariables[v.name] != null) {
            throw new Error(`Variable with name ${v.name} was already registered`);
        }
        this.state.registeredVariables[v.name] = v;
        this.incRef(v, this.backend);
        return v;
    }
    trackTensor(a, backend) {
        this.state.numTensors++;
        if (a.dtype === 'string') {
            this.state.numStringTensors++;
        }
        let bytes = 0;
        if (a.dtype !== 'complex64' && a.dtype !== 'string') {
            bytes = a.size * bytesPerElement(a.dtype);
        }
        this.state.numBytes += bytes;
        if (!this.state.tensorInfo.has(a.dataId)) {
            this.state.numDataBuffers++;
            this.state.tensorInfo.set(a.dataId, {
                backend: backend || this.backend,
                dtype: a.dtype,
                shape: a.shape,
                bytes
            });
        }
        if (!(a instanceof Variable)) {
            this.track(a);
        }
    }
    incRef(a, backend) {
        this.trackTensor(a, backend);
        this.backend.incRef(a.dataId);
    }
    removeDataId(dataId, backend) {
        if (this.state.tensorInfo.has(dataId) &&
            this.state.tensorInfo.get(dataId).backend === backend) {
            this.state.tensorInfo.delete(dataId);
            this.state.numDataBuffers--;
        }
    }
    disposeTensor(a) {
        if (!this.state.tensorInfo.has(a.dataId)) {
            return;
        }
        const info = this.state.tensorInfo.get(a.dataId);
        this.state.numTensors--;
        if (a.dtype === 'string') {
            this.state.numStringTensors--;
            this.state.numBytes -= info.bytes;
        }
        if (a.dtype !== 'complex64' && a.dtype !== 'string') {
            const bytes = a.size * bytesPerElement(a.dtype);
            this.state.numBytes -= bytes;
        }
        if (info.backend.disposeData(a.dataId)) {
            this.removeDataId(a.dataId, info.backend);
        }
    }
    disposeVariables() {
        for (const varName in this.state.registeredVariables) {
            const v = this.state.registeredVariables[varName];
            this.disposeVariable(v);
        }
    }
    disposeVariable(v) {
        this.disposeTensor(v);
        if (this.state.registeredVariables[v.name] != null) {
            delete this.state.registeredVariables[v.name];
        }
    }
    memory() {
        const info = this.backend.memory();
        info.numTensors = this.state.numTensors;
        info.numDataBuffers = this.state.numDataBuffers;
        info.numBytes = this.state.numBytes;
        if (this.state.numStringTensors > 0) {
            info.unreliable = true;
            if (info.reasons == null) {
                info.reasons = [];
            }
            info.reasons.push('Memory usage by string tensors is approximate ' +
                '(2 bytes per character)');
        }
        return info;
    }
    async profile(query) {
        this.state.profiling = true;
        const startBytes = this.state.numBytes;
        const startNumTensors = this.state.numTensors;
        this.state.activeProfile.kernels = [];
        this.state.activeProfile.result = await query();
        this.state.profiling = false;
        this.state.activeProfile.peakBytes = Math.max(...this.state.activeProfile.kernels.map(d => d.totalBytesSnapshot));
        this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
        this.state.activeProfile.newTensors =
            this.state.numTensors - startNumTensors;
        for (const kernel of this.state.activeProfile.kernels) {
            kernel.kernelTimeMs = await kernel.kernelTimeMs;
            kernel.extraInfo = await kernel.extraInfo;
        }
        return this.state.activeProfile;
    }
    isTapeOn() {
        return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
    }
    addTapeNode(kernelName, inputs, outputs, gradientsFunc, saved, attrs) {
        const tapeNode = { id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved };
        const gradConfig = getGradient(kernelName);
        if (gradConfig != null) {
            gradientsFunc = gradConfig.gradFunc;
        }
        if (gradientsFunc != null) {
            tapeNode.gradient = (dys) => {
                dys = dys.map((dy, i) => {
                    if (dy == null) {
                        const output = outputs[i];
                        const vals = makeZerosTypedArray(output.size, output.dtype);
                        return this.makeTensor(vals, output.shape, output.dtype);
                    }
                    return dy;
                });
                return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs);
            };
        }
        this.state.activeTape.push(tapeNode);
    }
    keep(result) {
        result.kept = true;
        return result;
    }
    startTape() {
        if (this.state.gradientDepth === 0) {
            this.state.activeTape = [];
        }
        this.state.gradientDepth++;
    }
    endTape() {
        this.state.gradientDepth--;
    }
    startScope(name) {
        const scopeInfo = {
            track: [],
            name: 'unnamed scope',
            id: this.state.nextScopeId++
        };
        if (name) {
            scopeInfo.name = name;
        }
        this.state.scopeStack.push(scopeInfo);
        this.state.activeScope = scopeInfo;
    }
    endScope(result) {
        const tensorsToTrackInParent = getTensorsInContainer(result);
        const tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(t => t.id));
        for (let i = 0; i < this.state.activeScope.track.length; i++) {
            const tensor = this.state.activeScope.track[i];
            if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) {
                tensor.dispose();
            }
        }
        const oldScope = this.state.scopeStack.pop();
        this.state.activeScope = this.state.scopeStack.length === 0 ?
            null :
            this.state.scopeStack[this.state.scopeStack.length - 1];
        tensorsToTrackInParent.forEach(tensor => {
            if (!tensor.kept && tensor.scopeId === oldScope.id) {
                this.track(tensor);
            }
        });
    }
    gradients(f, xs, dy, allowNoGradients = false) {
        assert$1(xs.length > 0, () => 'gradients() received an empty list of xs.');
        if (dy != null && dy.dtype !== 'float32') {
            throw new Error(`dy must have 'float32' dtype, but has '${dy.dtype}'`);
        }
        const y = this.scopedRun(() => this.startTape(), () => this.endTape(), () => this.tidy('forward', f));
        assert$1(y instanceof Tensor, () => 'The result y returned by f() must be a tensor.');
        const filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
        if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
            throw new Error('Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
                'that the f you passed encloses all operations that lead from x ' +
                'to y.');
        }
        return this.tidy('backward', () => {
            const accumulatedGradientMap = {};
            accumulatedGradientMap[y.id] = (dy == null) ? ones$1(y.shape) : dy;
            backpropagateGradients(accumulatedGradientMap, filteredTape,
                f => this.tidy(f),
                add$2);
            const grads = xs.map(x => accumulatedGradientMap[x.id]);
            if (this.state.gradientDepth === 0) {
                this.state.activeTape.forEach(node => {
                    for (const tensor of node.saved) {
                        tensor.dispose();
                    }
                });
                this.state.activeTape = null;
            }
            return { value: y, grads };
        });
    }
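    /*
     * Illustrative sketch (not part of the bundle; `f` and `x` are
     * hypothetical). This is what tf.grad-style wrappers drive:
     *
     *   const { value, grads } = ENGINE.gradients(() => f(x), [x]);
     *   // value = y = f(x); grads[0] = dy/dx, seeded with ones when dy is null
     */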
    customGrad(f) {
        assert$1(isFunction(f), () => 'The f passed in customGrad(f) must be a function.');
        return (...inputs) => {
            assert$1(inputs.every(t => t instanceof Tensor), () => 'The args passed in customGrad(f)(x1, x2,...) must all be ' +
                'tensors');
            let res;
            const inputMap = {};
            inputs.forEach((input, i) => {
                inputMap[i] = input;
            });
            const forwardFunc = (_, save) => {
                res = f(...[...inputs, save]);
                assert$1(res.value instanceof Tensor, () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.value` is a tensor');
                assert$1(isFunction(res.gradFunc), () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function.');
                return res.value;
            };
            const backwardsFunc = (dy, saved) => {
                const gradRes = res.gradFunc(dy, saved);
                const grads = Array.isArray(gradRes) ? gradRes : [gradRes];
                assert$1(grads.length === inputs.length, () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function that returns ' +
                    'the same number of tensors as inputs passed to f(...).');
                assert$1(grads.every(t => t instanceof Tensor), () => 'The function f passed in customGrad(f) must return an ' +
                    'object where `obj.gradFunc` is a function that returns ' +
                    'a list of only tensors.');
                const gradMap = {};
                grads.forEach((grad, i) => {
                    gradMap[i] = () => grad;
                });
                return gradMap;
            };
            return this.runKernelFunc({
                forwardFunc,
                backwardsFunc,
                inputs: inputMap,
            });
        };
    }
    readSync(dataId) {
        const info = this.state.tensorInfo.get(dataId);
        return info.backend.readSync(dataId);
    }
    read(dataId) {
        const info = this.state.tensorInfo.get(dataId);
        return info.backend.read(dataId);
    }
    readToGPU(dataId, options) {
        const info = this.state.tensorInfo.get(dataId);
        return info.backend.readToGPU(dataId, options);
    }
    async time(query) {
        const start = now();
        const timingInfo = await this.backend.time(query);
        timingInfo.wallMs = now() - start;
        return timingInfo;
    }
    track(result) {
        if (this.state.activeScope != null) {
            result.scopeId = this.state.activeScope.id;
            this.state.activeScope.track.push(result);
        }
        return result;
    }
    get registeredVariables() {
        return this.state.registeredVariables;
    }
    reset() {
        this.pendingBackendInitId++;
        this.state.dispose();
        this.ENV.reset();
        this.state = new EngineState();
        for (const backendName in this.registry) {
            this.disposeRegisteredKernels(backendName);
            this.registry[backendName].dispose();
            delete this.registry[backendName];
        }
        this.backendName = null;
        this.backendInstance = null;
        this.pendingBackendInit = null;
    }
}
Engine.nextTensorId = 0;
Engine.nextVariableId = 0;
function ones$1(shape) {
    const values = makeOnesTypedArray(sizeFromShape(shape), 'float32');
    return ENGINE.makeTensor(values, shape, 'float32');
}
function getOrMakeEngine() {
    const ns = getGlobalNamespace();
    if (ns._tfengine == null) {
        const environment = new Environment(ns);
        ns._tfengine = new Engine(environment);
    }
    setEnvironmentGlobal(ns._tfengine.ENV);
    setTensorTracker(() => ns._tfengine);
    return ns._tfengine;
}
const ENGINE = getOrMakeEngine();

function add$2(a, b) {
    const inputs = { a, b };
    return ENGINE.runKernel(Add, inputs);
}

function _isNavigatorDefined() {
    return typeof navigator !== 'undefined' && navigator != null;
}
function isMobile(nav) {
    if (nav || _isNavigatorDefined()) {
        if (!nav) {
            nav = navigator;
        }
        if (nav.product === 'ReactNative') {
            return true;
        }
        const a = nav.userAgent || nav.vendor ||
            (typeof window !== 'undefined' ? window.opera : '');
        if (!a) {
            const navAny = nav;
            return navAny.userAgentData && navAny.userAgentData.mobile;
        }
        return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i
            .test(a) ||
            /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i
            .test(a.substr(0, 4));
    }
    return false;
}
function isBrowser() {
    return (typeof window !== 'undefined' && window.document != null) ||
        (typeof WorkerGlobalScope !== 'undefined');
}

const ENV$1 = env();

ENV$1.registerFlag('DEBUG', () => false, debugValue => {
    if (debugValue) {
        console.warn('Debugging mode is ON. The output of every math call will ' +
            'be downloaded to CPU and checked for NaNs. ' +
            'This significantly impacts performance.');
    }
});

ENV$1.registerFlag('IS_BROWSER', () => isBrowser());

ENV$1.registerFlag('IS_NODE', () => (typeof process !== 'undefined') &&
    (typeof process.versions !== 'undefined') &&
    (typeof process.versions.node !== 'undefined'));

ENV$1.registerFlag('IS_CHROME', () => typeof navigator !== 'undefined' && navigator != null &&
    navigator.userAgent != null && /Chrome/.test(navigator.userAgent) &&
    /Google Inc/.test(navigator.vendor));

ENV$1.registerFlag('IS_SAFARI', () => typeof navigator !== 'undefined' && navigator != null &&
    navigator.userAgent != null && /Safari/.test(navigator.userAgent) &&
    /Apple/.test(navigator.vendor));

ENV$1.registerFlag('PROD', () => false);

ENV$1.registerFlag('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', () => ENV$1.getBool('DEBUG'));

ENV$1.registerFlag('DEPRECATION_WARNINGS_ENABLED', () => true);

ENV$1.registerFlag('IS_TEST', () => false);

ENV$1.registerFlag('CHECK_COMPUTATION_FOR_ERRORS', () => ENV$1.getBool('DEBUG'));

ENV$1.registerFlag('WRAP_TO_IMAGEBITMAP', () => false);

ENV$1.registerFlag('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU', () => false);

ENV$1.registerFlag('USE_SETTIMEOUTCUSTOM', () => false);

function buffer(shape, dtype = 'float32', values) {
    dtype = dtype || 'float32';
    assertNonNegativeIntegerDimensions(shape);
    return new TensorBuffer(shape, dtype, values);
}
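/*
 * Illustrative usage sketch (not part of the bundle):
 *
 *   const b = buffer([2], 'int32', new Int32Array([1, 2]));
 *   b.get(1);      // -> 2
 *   b.toTensor();  // -> int32 tensor of shape [2]
 */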

function inferShape(val, dtype) {
    let firstElem = val;
    if (isTypedArray(val)) {
        return dtype === 'string' ? [] : [val.length];
    }
    if (isWebGLData(val)) {
        const usedChannels = val.channels || 'RGBA';
        return [val.height, val.width * usedChannels.length];
    }
    else if (isWebGPUData(val)) {
        return [val.buffer.size / (dtype == null ? 4 : bytesPerElement(dtype))];
    }
    if (!Array.isArray(val)) {
        return [];
    }
    const shape = [];
    while (Array.isArray(firstElem) ||
        isTypedArray(firstElem) && dtype !== 'string') {
        shape.push(firstElem.length);
        firstElem = firstElem[0];
    }
    if (Array.isArray(val) &&
        env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) {
        deepAssertShapeConsistency(val, shape, []);
    }
    return shape;
}
function deepAssertShapeConsistency(val, shape, indices) {
    indices = indices || [];
    if (!(Array.isArray(val)) && !isTypedArray(val)) {
        assert$1(shape.length === 0, () => `Element arr[${indices.join('][')}] is a primitive, ` +
            `but should be an array/TypedArray of ${shape[0]} elements`);
        return;
    }
    assert$1(shape.length > 0, () => `Element arr[${indices.join('][')}] should be a primitive, ` +
        `but is an array of ${val.length} elements`);
    assert$1(val.length === shape[0], () => `Element arr[${indices.join('][')}] should have ${shape[0]} ` +
        `elements, but has ${val.length} elements`);
    const subShape = shape.slice(1);
    for (let i = 0; i < val.length; ++i) {
        deepAssertShapeConsistency(val[i], subShape, indices.concat(i));
    }
}
function assertDtype(expectedDtype, actualDType, argName, functionName) {
    if (expectedDtype === 'string_or_numeric') {
        return;
    }
    if (expectedDtype == null) {
        throw new Error(`Expected dtype cannot be null.`);
    }
    if (expectedDtype !== 'numeric' && expectedDtype !== actualDType ||
        expectedDtype === 'numeric' && actualDType === 'string') {
        throw new Error(`Argument '${argName}' passed to '${functionName}' must ` +
            `be ${expectedDtype} tensor, but got ${actualDType} tensor`);
    }
}
function convertToTensor(x, argName, functionName, parseAsDtype = 'numeric') {
    if (x instanceof getGlobalTensorClass()) {
        assertDtype(parseAsDtype, x.dtype, argName, functionName);
        return x;
    }
    let inferredDtype = inferDtype(x);
    if (inferredDtype !== 'string' &&
        ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) {
        inferredDtype = parseAsDtype;
    }
    assertDtype(parseAsDtype, inferredDtype, argName, functionName);
    if ((x == null) ||
        (!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' &&
            typeof x !== 'boolean' && typeof x !== 'string')) {
        const type = x == null ? 'null' : x.constructor.name;
        throw new Error(`Argument '${argName}' passed to '${functionName}' must be a ` +
            `Tensor or TensorLike, but got '${type}'`);
    }
    const inferredShape = inferShape(x, inferredDtype);
    if (!isTypedArray(x) && !Array.isArray(x)) {
        x = [x];
    }
    const skipTypedArray = true;
    const values = inferredDtype !== 'string' ?
        toTypedArray(x, inferredDtype) :
        flatten$1(x, [], skipTypedArray);
    return ENGINE.makeTensor(values, inferredShape, inferredDtype);
}
|
|
function convertToTensorArray(arg, argName, functionName, parseAsDtype = 'numeric') {
|
|
if (!Array.isArray(arg)) {
|
|
throw new Error(`Argument ${argName} passed to ${functionName} must be a ` +
|
|
'`Tensor[]` or `TensorLike[]`');
|
|
}
|
|
const tensors = arg;
|
|
return tensors.map((t, i) => convertToTensor(t, `${argName}[${i}]`, functionName, parseAsDtype));
|
|
}
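
// Conversion rules in brief ('myOp' below is an illustrative name, not an API):
//   convertToTensor(3, 'x', 'myOp')               -> float32 scalar
//   convertToTensor([1, 2], 'x', 'myOp', 'int32') -> int32 tensor of shape [2]
// A Tensor input is passed through after a dtype check; plain JS values are
// shape-inferred, flattened into a typed array and handed to ENGINE.makeTensor.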

const OP_SCOPE_SUFFIX = '__op';

function op(f) {
    const keys = Object.keys(f);
    if (keys.length !== 1) {
        throw new Error(`Please provide an object with a single key ` +
            `(operation name) mapping to a function. Got an object with ` +
            `${keys.length} keys.`);
    }
    let opName = keys[0];
    const fn = f[opName];
    if (opName.endsWith('_')) {
        opName = opName.substring(0, opName.length - 1);
    }
    opName = opName + OP_SCOPE_SUFFIX;
    const f2 = (...args) => {
        ENGINE.startScope(opName);
        try {
            const result = fn(...args);
            if (isPromise(result)) {
                console.error('Cannot return a Promise inside of tidy.');
            }
            ENGINE.endScope(result);
            return result;
        }
        catch (ex) {
            ENGINE.endScope(null);
            throw ex;
        }
    };
    Object.defineProperty(f2, 'name', { value: opName, configurable: true });
    return f2;
}
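
// op() relies on the single-key object so the kernel scope can be named after
// the wrapped function, e.g. op({ cast_ }) runs inside a scope named
// 'cast__op'. The trailing underscore marks the raw implementation; op()
// strips it, appends OP_SCOPE_SUFFIX, and brackets the call with
// startScope/endScope so intermediate tensors are tracked and released as in
// tidy().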

function cast_(x, dtype) {
    const $x = convertToTensor(x, 'x', 'cast');
    if (!isValidDtype(dtype)) {
        throw new Error(`Failed to cast to unknown dtype ${dtype}`);
    }
    if (dtype === 'string' && $x.dtype !== 'string' ||
        dtype !== 'string' && $x.dtype === 'string') {
        throw new Error('Only strings can be casted to strings');
    }
    const inputs = { x: $x };
    const attrs = { dtype };
    return ENGINE.runKernel(Cast, inputs, attrs);
}
const cast$3 = op({ cast_ });

function clone_(x) {
    const $x = convertToTensor(x, 'x', 'clone', 'string_or_numeric');
    const inputs = { x: $x };
    return ENGINE.runKernel(Identity$1, inputs);
}
const clone = op({ clone_ });

function print(x, verbose = false) {
    console.log(x.toString(verbose));
}

getOrMakeEngine();
const opHandler = {
    buffer,
    cast: cast$3,
    clone,
    print
};
setOpHandler(opHandler);

function enableProdMode() {
    env().set('PROD', true);
}

function engine() {
    return ENGINE;
}

function memory() {
    return ENGINE.memory();
}

function tidy(nameOrFn, fn) {
    return ENGINE.tidy(nameOrFn, fn);
}

function dispose(container) {
    const tensors = getTensorsInContainer(container);
    tensors.forEach(tensor => tensor.dispose());
}

function keep(result) {
    return ENGINE.keep(result);
}

function registerBackend(name, factory, priority = 1) {
    return ENGINE.registerBackend(name, factory, priority);
}

function backend() {
    return ENGINE.backend;
}

function add_(a, b) {
    let $a = convertToTensor(a, 'a', 'add');
    let $b = convertToTensor(b, 'b', 'add');
    [$a, $b] = makeTypesMatch($a, $b);
    const inputs = { a: $a, b: $b };
    return ENGINE.runKernel(Add, inputs);
}
const add$1 = op({ add_ });

function floorDiv_(a, b) {
    let $a = convertToTensor(a, 'a', 'floorDiv');
    let $b = convertToTensor(b, 'b', 'floorDiv');
    [$a, $b] = makeTypesMatch($a, $b);
    const inputs = { a: $a, b: $b };
    return ENGINE.runKernel(FloorDiv, inputs);
}
const floorDiv$2 = op({ floorDiv_ });

function div_(a, b) {
    let $a = convertToTensor(a, 'a', 'div');
    let $b = convertToTensor(b, 'b', 'div');
    [$a, $b] = makeTypesMatch($a, $b);
    if ($a.dtype === 'int32' && $b.dtype === 'int32') {
        return floorDiv$2($a, $b);
    }
    const inputs = { a: $a, b: $b };
    const attrs = {};
    return ENGINE.runKernel(RealDiv, inputs, attrs);
}
const div$1 = op({ div_ });
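
// div() dispatches on dtype: two int32 operands lower to floorDiv (integer
// division, e.g. 5 / 2 -> 2), while any float operand routes to the RealDiv
// kernel (5 / 2 -> 2.5).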

function mul_(a, b) {
    let $a = convertToTensor(a, 'a', 'mul');
    let $b = convertToTensor(b, 'b', 'mul');
    [$a, $b] = makeTypesMatch($a, $b);
    const inputs = { a: $a, b: $b };
    return ENGINE.runKernel(Multiply, inputs);
}
const mul = op({ mul_ });

function abs_(x) {
    const $x = convertToTensor(x, 'x', 'abs');
    if ($x.dtype === 'complex64') {
        const inputs = { x: $x };
        return ENGINE.runKernel(ComplexAbs, inputs);
    }
    else {
        const inputs = { x: $x };
        return ENGINE.runKernel(Abs, inputs);
    }
}
const abs$2 = op({ abs_ });

function any_(x, axis = null, keepDims = false) {
    const $x = convertToTensor(x, 'x', 'any', 'bool');
    const inputs = { x: $x };
    const attrs = { axis, keepDims };
    return ENGINE.runKernel(Any, inputs, attrs);
}
const any$2 = op({ any_ });

function argMax_(x, axis = 0) {
    const $x = convertToTensor(x, 'x', 'argMax');
    const inputs = { x: $x };
    const attrs = { axis };
    return ENGINE.runKernel(ArgMax, inputs, attrs);
}
const argMax$2 = op({ argMax_ });

function computeDilation2DInfo(inputShape, filterShape, strides, pad, dataFormat = 'NHWC', dilations) {
    const inputChannels = inputShape[3];
    const $filterShape = [...filterShape, inputChannels];
    const $dataFormat = convertConv2DDataFormat(dataFormat);
    return computeConv2DInfo(inputShape, $filterShape, strides, dilations, pad, null, null, $dataFormat);
}

function computePool2DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat = 'channelsLast') {
    const [filterHeight, filterWidth] = parseTupleParam(filterSize);
    let filterShape;
    if (dataFormat === 'channelsLast') {
        filterShape = [filterHeight, filterWidth, inShape[3], inShape[3]];
    }
    else if (dataFormat === 'channelsFirst') {
        filterShape = [filterHeight, filterWidth, inShape[1], inShape[1]];
    }
    else {
        throw new Error(`Unknown dataFormat ${dataFormat}`);
    }
    return computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, false, dataFormat);
}

function computePool3DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat = 'NDHWC') {
    const [filterDepth, filterHeight, filterWidth] = parse3TupleParam(filterSize);
    let filterShape;
    let $dataFormat;
    if (dataFormat === 'NDHWC') {
        $dataFormat = 'channelsLast';
        filterShape =
            [filterDepth, filterHeight, filterWidth, inShape[4], inShape[4]];
    }
    else if (dataFormat === 'NCDHW') {
        $dataFormat = 'channelsFirst';
        filterShape =
            [filterDepth, filterHeight, filterWidth, inShape[1], inShape[1]];
    }
    else {
        throw new Error(`Unknown dataFormat ${dataFormat}`);
    }
    return computeConv3DInfo(inShape, filterShape, strides, dilations, pad, false, $dataFormat, roundingMode);
}

function computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, depthwise = false, dataFormat = 'channelsLast') {
    let [batchSize, inHeight, inWidth, inChannels] = [-1, -1, -1, -1];
    if (dataFormat === 'channelsLast') {
        [batchSize, inHeight, inWidth, inChannels] = inShape;
    }
    else if (dataFormat === 'channelsFirst') {
        [batchSize, inChannels, inHeight, inWidth] = inShape;
    }
    else {
        throw new Error(`Unknown dataFormat ${dataFormat}`);
    }
    const [filterHeight, filterWidth, , filterChannels] = filterShape;
    const [strideHeight, strideWidth] = parseTupleParam(strides);
    const [dilationHeight, dilationWidth] = parseTupleParam(dilations);
    const effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
    const effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
    const { padInfo, outHeight, outWidth } = getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight, effectiveFilterWidth, roundingMode, dataFormat);
    const outChannels = depthwise ? filterChannels * inChannels : filterChannels;
    let outShape;
    if (dataFormat === 'channelsFirst') {
        outShape = [batchSize, outChannels, outHeight, outWidth];
    }
    else if (dataFormat === 'channelsLast') {
        outShape = [batchSize, outHeight, outWidth, outChannels];
    }
    return {
        batchSize,
        dataFormat,
        inHeight,
        inWidth,
        inChannels,
        outHeight,
        outWidth,
        outChannels,
        padInfo,
        strideHeight,
        strideWidth,
        filterHeight,
        filterWidth,
        effectiveFilterHeight,
        effectiveFilterWidth,
        dilationHeight,
        dilationWidth,
        inShape,
        outShape,
        filterShape
    };
}

function computeConv3DInfo(inShape, filterShape, strides, dilations, pad, depthwise = false, dataFormat = 'channelsLast', roundingMode) {
    let [batchSize, inDepth, inHeight, inWidth, inChannels] = [-1, -1, -1, -1, -1];
    if (dataFormat === 'channelsLast') {
        [batchSize, inDepth, inHeight, inWidth, inChannels] = inShape;
    }
    else if (dataFormat === 'channelsFirst') {
        [batchSize, inChannels, inDepth, inHeight, inWidth] = inShape;
    }
    else {
        throw new Error(`Unknown dataFormat ${dataFormat}`);
    }
    const [filterDepth, filterHeight, filterWidth, , filterChannels] = filterShape;
    const [strideDepth, strideHeight, strideWidth] = parse3TupleParam(strides);
    const [dilationDepth, dilationHeight, dilationWidth] = parse3TupleParam(dilations);
    const effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth);
    const effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight);
    const effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth);
    const { padInfo, outDepth, outHeight, outWidth } = get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth, roundingMode);
    const outChannels = depthwise ? filterChannels * inChannels : filterChannels;
    let outShape;
    if (dataFormat === 'channelsFirst') {
        outShape = [batchSize, outChannels, outDepth, outHeight, outWidth];
    }
    else if (dataFormat === 'channelsLast') {
        outShape = [batchSize, outDepth, outHeight, outWidth, outChannels];
    }
    return {
        batchSize,
        dataFormat,
        inDepth,
        inHeight,
        inWidth,
        inChannels,
        outDepth,
        outHeight,
        outWidth,
        outChannels,
        padInfo,
        strideDepth,
        strideHeight,
        strideWidth,
        filterDepth,
        filterHeight,
        filterWidth,
        effectiveFilterDepth,
        effectiveFilterHeight,
        effectiveFilterWidth,
        dilationDepth,
        dilationHeight,
        dilationWidth,
        inShape,
        outShape,
        filterShape
    };
}

function computeOutputShape2D(inShape, fieldSize, stride, zeroPad, roundingMode) {
    if (zeroPad == null) {
        zeroPad = computeDefaultPad(inShape, fieldSize, stride);
    }
    const inputRows = inShape[0];
    const inputCols = inShape[1];
    const outputRows = round$2((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
    const outputCols = round$2((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode);
    return [outputRows, outputCols];
}

function computeOutputShape4D(inShape, filterShape, outChannels, strides, zeroPad, roundingMode) {
    if (zeroPad == null) {
        zeroPad = computeDefaultPad(inShape, filterShape[0], strides[0]);
    }
    const outShape = [0, 0, 0, outChannels];
    for (let index = 0; index < 3; index++) {
        if (inShape[index] + 2 * zeroPad >= filterShape[index]) {
            outShape[index] = round$2((inShape[index] - filterShape[index] + 2 * zeroPad) / strides[index] +
                1, roundingMode);
        }
    }
    return outShape;
}

function computeDefaultPad(inputShape, fieldSize, stride, dilation = 1) {
    const effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation);
    return Math.floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2);
}

function parseTupleParam(param) {
    if (typeof param === 'number') {
        return [param, param, param];
    }
    if (param.length === 2) {
        return [param[0], param[1], 1];
    }
    return param;
}

function parse3TupleParam(param) {
    return typeof param === 'number' ? [param, param, param] : param;
}

function getEffectiveFilterSize(filterSize, dilation) {
    if (dilation <= 1) {
        return filterSize;
    }
    return filterSize + (filterSize - 1) * (dilation - 1);
}
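
// Effective filter size with dilation d inserts (d - 1) gaps between taps:
//   effective = filterSize + (filterSize - 1) * (d - 1)
// e.g. a 3-tap filter with dilation 2 covers 3 + 2 * 1 = 5 input positions.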

function getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode, dataFormat) {
    let padInfo;
    let outHeight;
    let outWidth;
    if (typeof pad === 'number') {
        const padType = (pad === 0) ? 'VALID' : 'NUMBER';
        padInfo = { top: pad, bottom: pad, left: pad, right: pad, type: padType };
        const outShape = computeOutputShape2D([inHeight, inWidth], filterHeight, strideHeight, pad, roundingMode);
        outHeight = outShape[0];
        outWidth = outShape[1];
    }
    else if (pad === 'same') {
        outHeight = Math.ceil(inHeight / strideHeight);
        outWidth = Math.ceil(inWidth / strideWidth);
        const padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight);
        const padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth);
        const top = Math.floor(padAlongHeight / 2);
        const bottom = padAlongHeight - top;
        const left = Math.floor(padAlongWidth / 2);
        const right = padAlongWidth - left;
        padInfo = { top, bottom, left, right, type: 'SAME' };
    }
    else if (pad === 'valid') {
        padInfo = { top: 0, bottom: 0, left: 0, right: 0, type: 'VALID' };
        outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight);
        outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth);
    }
    else if (typeof pad === 'object') {
        const top = dataFormat === 'channelsLast' ? pad[1][0] : pad[2][0];
        const bottom = dataFormat === 'channelsLast' ? pad[1][1] : pad[2][1];
        const left = dataFormat === 'channelsLast' ? pad[2][0] : pad[3][0];
        const right = dataFormat === 'channelsLast' ? pad[2][1] : pad[3][1];
        const padType = (top === 0 && bottom === 0 && left === 0 && right === 0) ?
            'VALID' :
            'EXPLICIT';
        padInfo = { top, bottom, left, right, type: padType };
        outHeight = round$2((inHeight - filterHeight + top + bottom) / strideHeight + 1, roundingMode);
        outWidth = round$2((inWidth - filterWidth + left + right) / strideWidth + 1, roundingMode);
    }
    else {
        throw Error(`Unknown padding parameter: ${pad}`);
    }
    return { padInfo, outHeight, outWidth };
}
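
// Worked 'same' example: inHeight = 5, stride = 2, filter = 3 gives
// outHeight = ceil(5 / 2) = 3 and padAlongHeight = max(0, (3 - 1) * 2 + 3 - 5)
// = 2, split as top = 1, bottom = 1. When the padding is odd, the extra pixel
// goes to the bottom/right, matching TensorFlow's convention.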

function get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, roundingMode) {
    let padInfo;
    let outDepth;
    let outHeight;
    let outWidth;
    if (pad === 'valid') {
        pad = 0;
    }
    if (typeof pad === 'number') {
        const padType = (pad === 0) ? 'VALID' : 'NUMBER';
        padInfo = {
            top: pad,
            bottom: pad,
            left: pad,
            right: pad,
            front: pad,
            back: pad,
            type: padType
        };
        const outShape = computeOutputShape4D([inDepth, inHeight, inWidth, 1], [filterDepth, filterHeight, filterWidth], 1, [strideDepth, strideHeight, strideWidth], pad, roundingMode);
        outDepth = outShape[0];
        outHeight = outShape[1];
        outWidth = outShape[2];
    }
    else if (pad === 'same') {
        outDepth = Math.ceil(inDepth / strideDepth);
        outHeight = Math.ceil(inHeight / strideHeight);
        outWidth = Math.ceil(inWidth / strideWidth);
        const padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth;
        const padAlongHeight = (outHeight - 1) * strideHeight + filterHeight - inHeight;
        const padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth;
        const front = Math.floor(padAlongDepth / 2);
        const back = padAlongDepth - front;
        const top = Math.floor(padAlongHeight / 2);
        const bottom = padAlongHeight - top;
        const left = Math.floor(padAlongWidth / 2);
        const right = padAlongWidth - left;
        padInfo = { top, bottom, left, right, front, back, type: 'SAME' };
    }
    else {
        throw Error(`Unknown padding parameter: ${pad}`);
    }
    return { padInfo, outDepth, outHeight, outWidth };
}

function round$2(value, roundingMode) {
    if (!roundingMode) {
        return Math.trunc(value);
    }
    switch (roundingMode) {
        case 'round':
            return Math.round(value);
        case 'ceil':
            return Math.ceil(value);
        case 'floor':
            return Math.floor(value);
        default:
            throw new Error(`Unknown roundingMode ${roundingMode}`);
    }
}

function tupleValuesAreOne(param) {
    const [dimA, dimB, dimC] = parseTupleParam(param);
    return dimA === 1 && dimB === 1 && dimC === 1;
}

function eitherStridesOrDilationsAreOne(strides, dilations) {
    return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations);
}

function stridesOrDilationsArePositive(values) {
    return parseTupleParam(values).every(value => value > 0);
}

function convertConv2DDataFormat(dataFormat) {
    if (dataFormat === 'NHWC') {
        return 'channelsLast';
    }
    else if (dataFormat === 'NCHW') {
        return 'channelsFirst';
    }
    else {
        throw new Error(`Unknown dataFormat ${dataFormat}`);
    }
}

function checkPadOnDimRoundingMode(opDesc, pad, dimRoundingMode) {
    if (dimRoundingMode != null) {
        if (typeof pad === 'string') {
            throw Error(`Error in ${opDesc}: pad must be an integer when using ` +
                `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
        }
        else if (typeof pad === 'number') {
            assert$1(isInt(pad), () => `Error in ${opDesc}: pad must be an integer when using ` +
                `dimRoundingMode ${dimRoundingMode} but got pad ${pad}.`);
        }
        else if (typeof pad === 'object') {
            pad.forEach(p => {
                p.forEach(v => {
                    assert$1(isInt(v), () => `Error in ${opDesc}: pad must be an integer when using ` +
                        `dimRoundingMode ${dimRoundingMode} but got pad ${v}.`);
                });
            });
        }
        else {
            throw Error(`Error in ${opDesc}: Unknown padding parameter: ${pad}`);
        }
    }
}

function reshape_(x, shape) {
    const $x = convertToTensor(x, 'x', 'reshape', 'string_or_numeric');
    const inputs = { x: $x };
    const attrs = { shape };
    return ENGINE.runKernel(Reshape$1, inputs, attrs);
}
const reshape$2 = op({ reshape_ });

function concat_(tensors, axis = 0) {
    assert$1(tensors.length >= 1, () => 'Pass at least one tensor to concat');
    const $tensors = convertToTensorArray(tensors, 'tensors', 'concat', 'string_or_numeric');
    if ($tensors[0].dtype === 'complex64') {
        $tensors.forEach(tensor => {
            if (tensor.dtype !== 'complex64') {
                throw new Error(`Cannot concatenate complex64 tensors with a tensor
with dtype ${tensor.dtype}. `);
            }
        });
    }
    if ($tensors.length === 1) {
        return clone($tensors[0]);
    }
    const inputs = $tensors;
    const attr = { axis };
    return ENGINE.runKernel(Concat, inputs, attr);
}
const concat$2 = op({ concat_ });

function matMul_(a, b, transposeA = false, transposeB = false) {
    let $a = convertToTensor(a, 'a', 'matMul');
    let $b = convertToTensor(b, 'b', 'matMul');
    [$a, $b] = makeTypesMatch($a, $b);
    const inputs = { a: $a, b: $b };
    const attrs = { transposeA, transposeB };
    return ENGINE.runKernel(BatchMatMul, inputs, attrs);
}
const matMul$1 = op({ matMul_ });

function sigmoid_(x) {
    const $x = convertToTensor(x, 'x', 'sigmoid', 'float32');
    const inputs = { x: $x };
    return ENGINE.runKernel(Sigmoid$1, inputs);
}
const sigmoid$2 = op({ sigmoid_ });

function slice_(x, begin, size) {
    const $x = convertToTensor(x, 'x', 'slice', 'string_or_numeric');
    if ($x.rank === 0) {
        throw new Error('Slicing scalar is not possible');
    }
    const inputs = { x: $x };
    const attrs = { begin, size };
    return ENGINE.runKernel(Slice, inputs, attrs);
}
const slice$2 = op({ slice_ });

function tanh_(x) {
    const $x = convertToTensor(x, 'x', 'tanh', 'float32');
    const inputs = { x: $x };
    return ENGINE.runKernel(Tanh$1, inputs);
}
const tanh$2 = op({ tanh_ });

function batchToSpaceND_(x, blockShape, crops) {
    const $x = convertToTensor(x, 'x', 'batchToSpaceND');
    const prod = blockShape.reduce((a, b) => a * b);
    assert$1($x.rank >= 1 + blockShape.length, () => `input rank is ${$x.rank} but should be > than blockShape.length ${blockShape.length}`);
    assert$1(crops.length === blockShape.length, () => `crops.length is ${crops.length} but should be equal to blockShape.length ${blockShape.length}`);
    assert$1($x.shape[0] % prod === 0, () => `input tensor batch is ${$x.shape[0]} but is not divisible by the product of ` +
        `the elements of blockShape ${blockShape.join(' * ')} === ${prod}`);
    const inputs = { x: $x };
    const attrs = { blockShape, crops };
    return ENGINE.runKernel(BatchToSpaceND, inputs, attrs);
}
const batchToSpaceND$2 = op({ batchToSpaceND_ });

function broadcastTo_(x, shape) {
    // Note: the original called convertToTensor(x, 'broadcastTo', 'x'), i.e.
    // with argName and functionName swapped; corrected here so error messages
    // name the right argument and op.
    let input = convertToTensor(x, 'x', 'broadcastTo');
    const xShape = input.shape;
    assertNonNegativeIntegerDimensions(shape);
    if (shape.length < input.rank) {
        throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${input.rank}.`);
    }
    if (shape.length > input.rank) {
        const newShape = input.shape.slice();
        while (newShape.length < shape.length) {
            newShape.unshift(1);
        }
        input = reshape$2(input, newShape);
    }
    const inputShape = input.shape;
    const reps = Array.from(shape);
    for (let i = shape.length - 1; i >= 0; i--) {
        if (inputShape[i] === shape[i]) {
            reps[i] = 1;
        }
        else if (input.shape[i] !== 1) {
            throw new Error(`broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);
        }
    }
    const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0);
    if (axes.length === 0) {
        return clone(input);
    }
    const inputs = { x: input };
    const attrs = { reps };
    return ENGINE.runKernel(Tile, inputs, attrs);
}
const broadcastTo = op({ broadcastTo_ });
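
// broadcastTo is implemented as a reshape to matching rank followed by Tile:
// each dimension that must grow from 1 to shape[i] gets reps[i] = shape[i],
// every already-matching dimension gets reps[i] = 1, e.g. [1, 3] -> [4, 3]
// uses reps = [4, 1]. If no axis actually changes, a clone is returned.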

function fill$2(shape, value, dtype) {
    assertNonNegativeIntegerDimensions(shape);
    dtype = dtype || inferDtype(value);
    const attrs = { shape, value, dtype };
    return ENGINE.runKernel(Fill, {}, attrs);
}

function clipByValue_(x, clipValueMin, clipValueMax) {
    const $x = convertToTensor(x, 'x', 'clipByValue');
    assert$1((clipValueMin <= clipValueMax), () => `Error in clip: min (${clipValueMin}) must be ` +
        `less than or equal to max (${clipValueMax}).`);
    if (clipValueMin === clipValueMax) {
        return fill$2($x.shape, clipValueMin, $x.dtype);
    }
    const inputs = { x: $x };
    const attrs = { clipValueMin, clipValueMax };
    return ENGINE.runKernel(ClipByValue, inputs, attrs);
}
const clipByValue$2 = op({ clipByValue_ });

function complex_(real, imag) {
    const $real = convertToTensor(real, 'real', 'complex');
    const $imag = convertToTensor(imag, 'imag', 'complex');
    assertShapesMatch($real.shape, $imag.shape, `real and imag shapes, ${$real.shape} and ${$imag.shape}, ` +
        `must match in call to tf.complex().`);
    const inputs = { real: $real, imag: $imag };
    return ENGINE.runKernel(Complex, inputs);
}
const complex$2 = op({ complex_ });

function conv2d_(x, filter, strides, pad, dataFormat = 'NHWC', dilations = [1, 1], dimRoundingMode) {
    const $x = convertToTensor(x, 'x', 'conv2d', 'float32');
    const $filter = convertToTensor(filter, 'filter', 'conv2d', 'float32');
    let x4D = $x;
    let reshapedTo4D = false;
    if ($x.rank === 3) {
        reshapedTo4D = true;
        x4D = reshape$2($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]);
    }
    assert$1(x4D.rank === 4, () => `Error in conv2d: input must be rank 4, but got rank ${x4D.rank}.`);
    assert$1($filter.rank === 4, () => `Error in conv2d: filter must be rank 4, but got rank ` +
        `${$filter.rank}.`);
    checkPadOnDimRoundingMode('conv2d', pad, dimRoundingMode);
    const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
    assert$1(inDepth === $filter.shape[2], () => `Error in conv2d: depth of input (${inDepth}) must match ` +
        `input depth for filter ${$filter.shape[2]}.`);
    assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in conv2D: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    assert$1(stridesOrDilationsArePositive(dilations), () => 'Error in conv2D: Dilated rates should be larger than 0.');
    assert$1(stridesOrDilationsArePositive(strides), () => 'Error in conv2D: Strides should be larger than 0.');
    const inputs = { x: x4D, filter: $filter };
    const attrs = { strides, pad, dataFormat, dilations, dimRoundingMode };
    const res = ENGINE.runKernel(Conv2D, inputs, attrs);
    if (reshapedTo4D) {
        return reshape$2(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    return res;
}
const conv2d$1 = op({ conv2d_ });
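
// conv2d accepts a rank-3 (unbatched) input as a convenience: it is reshaped
// to rank 4 with batch size 1 before the kernel runs, and the result is
// squeezed back to rank 3 on the way out. The backprop ops below apply the
// same treatment to dy.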

function conv2DBackpropInput_(xShape, dy, filter, strides, pad, dataFormat = 'NHWC', dimRoundingMode) {
    assert$1(xShape.length === dy.rank, () => `Length of inShape ` +
        `(${xShape.length}) and rank of dy (${dy.rank}) must match`);
    let xShape4D = xShape;
    let dy4D = dy;
    let reshapedTo4D = false;
    if (dy.rank === 3) {
        reshapedTo4D = true;
        dy4D = reshape$2(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
        xShape4D = [1, xShape[0], xShape[1], xShape[2]];
    }
    assert$1(xShape4D.length === 4, () => `Error in conv2dDerInput: inShape must be length 4, but got length ` +
        `${xShape4D.length}.`);
    assert$1(dy4D.rank === 4, () => `Error in conv2dDerInput: dy must be rank 4, but got ` +
        `rank ${dy4D.rank}`);
    assert$1(filter.rank === 4, () => `Error in conv2dDerInput: filter must be rank 4, but got ` +
        `rank ${filter.rank}`);
    const inDepth = dataFormat === 'NHWC' ? xShape4D[3] : xShape4D[1];
    const outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];
    assert$1(inDepth === filter.shape[2], () => `Error in conv2dDerInput: depth of input (${inDepth}) must ` +
        `match input depth for filter ${filter.shape[2]}.`);
    assert$1(outDepth === filter.shape[3], () => `Error in conv2dDerInput: depth of output (${outDepth}) must ` +
        `match output depth for filter ${filter.shape[3]}.`);
    checkPadOnDimRoundingMode('conv2dDerInput', pad, dimRoundingMode);
    const inputs = { dy: dy4D, filter };
    const attrs = { strides, pad, dataFormat, dimRoundingMode, inputShape: xShape4D };
    const res = ENGINE.runKernel(Conv2DBackpropInput, inputs, attrs);
    if (reshapedTo4D) {
        return reshape$2(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    return res;
}
const conv2DBackpropInput$2 = op({ conv2DBackpropInput_ });

function conv3DBackpropInput_(xShape, dy, filter, strides, pad) {
    assert$1(xShape.length === dy.rank, () => `Length of inShape ` +
        `(${xShape.length}) and rank of dy (${dy.rank}) must match`);
    let xShape5D = xShape;
    let dy5D = dy;
    let reshapedTo5D = false;
    if (dy.rank === 4) {
        reshapedTo5D = true;
        dy5D = reshape$2(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
        xShape5D = [1, xShape[0], xShape[1], xShape[2], xShape[3]];
    }
    const inDepth = xShape5D[4];
    const outDepth = dy5D.shape[4];
    assert$1(xShape5D.length === 5, () => `Error in conv3dDerInput: inShape must be length 5, but got length ` +
        `${xShape5D.length}.`);
    assert$1(dy5D.rank === 5, () => `Error in conv3dDerInput: dy must be rank 5, but got ` +
        `rank ${dy5D.rank}`);
    assert$1(filter.rank === 5, () => `Error in conv3dDerInput: filter must be rank 5, but got ` +
        `rank ${filter.rank}`);
    assert$1(inDepth === filter.shape[3], () => `Error in conv3dDerInput: depth of input (${inDepth}) must ` +
        `match input depth for filter ${filter.shape[3]}.`);
    assert$1(outDepth === filter.shape[4], () => `Error in conv3dDerInput: depth of output (${outDepth}) must ` +
        `match output depth for filter ${filter.shape[4]}.`);
    const inputs = { dy: dy5D, filter };
    const attrs = { pad, strides, inputShape: xShape5D };
    const res = ENGINE.runKernel(Conv3DBackpropInputV2, inputs, attrs);
    if (reshapedTo5D) {
        return reshape$2(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
    }
    return res;
}
const conv3DBackpropInput$1 = op({ conv3DBackpropInput_ });

function cos_(x) {
    const $x = convertToTensor(x, 'x', 'cos', 'float32');
    const inputs = { x: $x };
    return ENGINE.runKernel(Cos, inputs);
}
const cos$2 = op({ cos_ });

function cosh_(x) {
    const $x = convertToTensor(x, 'x', 'cosh', 'float32');
    const inputs = { x: $x };
    return ENGINE.runKernel(Cosh, inputs);
}
const cosh$2 = op({ cosh_ });

function cumprod_(x, axis = 0, exclusive = false, reverse = false) {
    const $x = convertToTensor(x, 'x', 'cumprod');
    const inputs = { x: $x };
    const attrs = { axis, exclusive, reverse };
    return ENGINE.runKernel(Cumprod, inputs, attrs);
}
const cumprod$2 = op({ cumprod_ });

function cumsum_(x, axis = 0, exclusive = false, reverse = false) {
    const $x = convertToTensor(x, 'x', 'cumsum');
    const inputs = { x: $x };
    const attrs = { axis, exclusive, reverse };
    return ENGINE.runKernel(Cumsum, inputs, attrs);
}
const cumsum$2 = op({ cumsum_ });

function getBroadcastDims$1(inShape, outShape) {
    const inRank = inShape.length;
    const dims = [];
    for (let i = 0; i < inRank; i++) {
        const dim = inRank - 1 - i;
        const a = inShape[dim] || 1;
        const b = outShape[outShape.length - 1 - i] || 1;
        if (b > 1 && a === 1) {
            dims.unshift(dim);
        }
    }
    return dims;
}

function getReductionAxes(inShape, outShape) {
    const result = [];
    for (let i = 0; i < outShape.length; i++) {
        const inDim = inShape[inShape.length - i - 1];
        const outAxis = outShape.length - i - 1;
        const outDim = outShape[outAxis];
        if (inDim == null || (inDim === 1 && outDim > 1)) {
            result.unshift(outAxis);
        }
    }
    return result;
}

function assertAndGetBroadcastShape(shapeA, shapeB) {
    const l = Math.max(shapeA.length, shapeB.length);
    const result = new Array(l);
    for (let i = 0; i < l; i++) {
        let a = shapeA[shapeA.length - i - 1];
        if (a == null) {
            a = 1;
        }
        let b = shapeB[shapeB.length - i - 1];
        if (b == null) {
            b = 1;
        }
        if (a === 1) {
            result[l - i - 1] = b;
        }
        else if (b === 1) {
            result[l - i - 1] = a;
        }
        else if (a !== b) {
            const errMsg = `Operands could not be broadcast together with shapes ` +
                `${shapeA} and ${shapeB}.`;
            throw Error(errMsg);
        }
        else {
            result[l - i - 1] = a;
        }
    }
    return result;
}
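
// Right-aligned NumPy-style broadcasting, for example:
//   assertAndGetBroadcastShape([8, 1, 3], [7, 1]) -> [8, 7, 3]
//   getReductionAxes([1, 4], [5, 4])              -> [0]
// Pairs where neither dimension is 1 and the sizes differ throw.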

function equal_(a, b) {
    let $a = convertToTensor(a, 'a', 'equal', 'string_or_numeric');
    let $b = convertToTensor(b, 'b', 'equal', 'string_or_numeric');
    [$a, $b] = makeTypesMatch($a, $b);
    assertAndGetBroadcastShape($a.shape, $b.shape);
    const inputs = { a: $a, b: $b };
    return ENGINE.runKernel(Equal, inputs);
}
const equal$2 = op({ equal_ });

function where_(condition, a, b) {
    const $a = convertToTensor(a, 'a', 'where');
    const $b = convertToTensor(b, 'b', 'where');
    const $condition = convertToTensor(condition, 'condition', 'where', 'bool');
    const broadcastShape = assertAndGetBroadcastShape(assertAndGetBroadcastShape($condition.shape, $a.shape), $b.shape);
    const $broadcastedCondition = broadcastTo($condition, broadcastShape);
    const $broadcastedA = broadcastTo($a, broadcastShape);
    const $broadcastedB = broadcastTo($b, broadcastShape);
    const inputs = {
        condition: $broadcastedCondition,
        t: $broadcastedA,
        e: $broadcastedB
    };
    return ENGINE.runKernel(Select, inputs);
}
const where = op({ where_ });

function zerosLike_(x) {
    const $x = convertToTensor(x, 'x', 'zerosLike');
    const inputs = { x: $x };
    return ENGINE.runKernel(ZerosLike, inputs);
}
const zerosLike$2 = op({ zerosLike_ });

function elu_(x) {
    const $x = convertToTensor(x, 'x', 'elu', 'float32');
    const inputs = { x: $x };
    return ENGINE.runKernel(Elu$1, inputs);
}
const elu$3 = op({ elu_ });

function erf_(x) {
    let $x = convertToTensor(x, 'x', 'erf');
    assert$1($x.dtype === 'int32' || $x.dtype === 'float32', () => 'Input dtype must be `int32` or `float32`.');
    if ($x.dtype === 'int32') {
        $x = cast$3($x, 'float32');
    }
    const inputs = { x: $x };
    return ENGINE.runKernel(Erf, inputs);
}
const erf$2 = op({ erf_ });

function axesAreInnerMostDims(axes, rank) {
    for (let i = 0; i < axes.length; ++i) {
        if (axes[axes.length - i - 1] !== rank - 1 - i) {
            return false;
        }
    }
    return true;
}

function combineLocations(outputLoc, reduceLoc, axes) {
    const rank = outputLoc.length + reduceLoc.length;
    const loc = [];
    let outIdx = 0;
    let reduceIdx = 0;
    for (let dim = 0; dim < rank; dim++) {
        if (axes.indexOf(dim) === -1) {
            loc.push(outputLoc[outIdx++]);
        }
        else {
            loc.push(reduceLoc[reduceIdx++]);
        }
    }
    return loc;
}

function computeOutAndReduceShapes(aShape, axes) {
    const outShape = [];
    const rank = aShape.length;
    for (let dim = 0; dim < rank; dim++) {
        if (axes.indexOf(dim) === -1) {
            outShape.push(aShape[dim]);
        }
    }
    const reduceShape = axes.map(dim => aShape[dim]);
    return [outShape, reduceShape];
}

function expandShapeToKeepDim(shape, axes) {
    const reduceSubShape = axes.map(x => 1);
    return combineLocations(shape, reduceSubShape, axes);
}

function assertAxesAreInnerMostDims(msg, axes, rank) {
    assert$1(axesAreInnerMostDims(axes, rank), () => `${msg} supports only inner-most axes for now. ` +
        `Got axes ${axes} and rank-${rank} input.`);
}

function getAxesPermutation(axes, rank) {
    if (axesAreInnerMostDims(axes, rank)) {
        return null;
    }
    const result = [];
    for (let i = 0; i < rank; ++i) {
        if (axes.indexOf(i) === -1) {
            result.push(i);
        }
    }
    axes.forEach(axis => result.push(axis));
    return result;
}

function getUndoAxesPermutation(axes) {
    return axes.map((axis, i) => [i, axis])
        .sort((a, b) => a[1] - b[1])
        .map(x => x[0]);
}

function getInnerMostAxes(numAxes, rank) {
    const res = [];
    for (let i = rank - numAxes; i < rank; ++i) {
        res.push(i);
    }
    return res;
}
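
// For backends that only reduce over trailing axes, getAxesPermutation moves
// the reduced axes to the end, e.g. axes = [0] on a rank-3 tensor -> [1, 2, 0];
// getUndoAxesPermutation inverts that permutation ([1, 2, 0] -> [2, 0, 1]) to
// restore the original layout afterwards.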

function max_(x, axis = null, keepDims = false) {
    const $x = convertToTensor(x, 'x', 'max');
    const inputs = { x: $x };
    const attrs = { reductionIndices: axis, keepDims };
    return ENGINE.runKernel(Max, inputs, attrs);
}
const max$2 = op({ max_ });

function min_(x, axis = null, keepDims = false) {
    const $x = convertToTensor(x, 'x', 'min');
    const inputs = { x: $x };
    const attrs = { axis, keepDims };
    return ENGINE.runKernel(Min, inputs, attrs);
}
const min$2 = op({ min_ });

function pow_(base, exp) {
    let $base = convertToTensor(base, 'base', 'pow');
    let $exp = convertToTensor(exp, 'exp', 'pow');
    [$base, $exp] = makeTypesMatch($base, $exp);
    const inputs = { a: $base, b: $exp };
    return ENGINE.runKernel(Pow, inputs);
}
const pow$2 = op({ pow_ });

function makeTensor(values, shape, inferredShape, dtype) {
    if (dtype == null) {
        dtype = inferDtype(values);
    }
    else if (dtype === 'complex64') {
        throw new Error(`Cannot construct a complex64 tensor directly. ` +
            `Please use tf.complex(real, imag).`);
    }
    if (isWebGPUData(values) || isWebGLData(values)) {
        if (dtype !== 'float32' && dtype !== 'int32') {
            throw new Error(`Creating tensor from GPU data only supports ` +
                `'float32'|'int32' dtype, while the dtype is ${dtype}.`);
        }
        return ENGINE.backend.createTensorFromGPUData(values, shape || inferredShape, dtype);
    }
    if (!isTypedArray(values) && !Array.isArray(values) &&
        typeof values !== 'number' && typeof values !== 'boolean' &&
        typeof values !== 'string') {
        throw new Error('values passed to tensor(values) must be a number/boolean/string or ' +
            'an array of numbers/booleans/strings, or a TypedArray');
    }
    if (shape != null) {
        assertNonNegativeIntegerDimensions(shape);
        const providedSize = sizeFromShape(shape);
        const inferredSize = sizeFromShape(inferredShape);
        assert$1(providedSize === inferredSize, () => `Based on the provided shape, [${shape}], the tensor should have ` +
            `${providedSize} values but has ${inferredSize}`);
        for (let i = 0; i < inferredShape.length; ++i) {
            const inferred = inferredShape[i];
            const flatDimsDontMatch = i === inferredShape.length - 1 ?
                inferred !== sizeFromShape(shape.slice(i)) :
                true;
            assert$1(inferredShape[i] === shape[i] || !flatDimsDontMatch, () => `Error creating a new Tensor. Inferred shape ` +
                `(${inferredShape}) does not match the provided ` +
                `shape (${shape}). `);
        }
    }
    if (!isTypedArray(values) && !Array.isArray(values)) {
        values = [values];
    }
    shape = shape || inferredShape;
    values = dtype !== 'string' ?
        toTypedArray(values, dtype) :
        flatten$1(values, [], true);
    return ENGINE.makeTensor(values, shape, dtype);
}

function scalar(value, dtype) {
    if (((isTypedArray(value) && dtype !== 'string') || Array.isArray(value)) &&
        dtype !== 'complex64') {
        throw new Error('Error creating a new Scalar: value must be a primitive ' +
            '(number|boolean|string)');
    }
    if (dtype === 'string' && isTypedArray(value) &&
        !(value instanceof Uint8Array)) {
        throw new Error('When making a scalar from encoded string, ' +
            'the value must be `Uint8Array`.');
    }
    const shape = [];
    const inferredShape = [];
    return makeTensor(value, shape, inferredShape, dtype);
}

function sqrt_(x) {
    const $x = convertToTensor(x, 'x', 'sqrt', 'float32');
    const inputs = { x: $x };
    return ENGINE.runKernel(Sqrt, inputs);
}
const sqrt$2 = op({ sqrt_ });

function square_(x) {
    const $x = convertToTensor(x, 'x', 'square');
    const attrs = {};
    return ENGINE.runKernel('Square', { x: $x }, attrs);
}
const square$2 = op({ square_ });

function sum_(x, axis = null, keepDims = false) {
    let $x = convertToTensor(x, 'x', 'sum');
    if ($x.dtype === 'bool') {
        $x = cast$3($x, 'int32');
    }
    const inputs = { x: $x };
    const attrs = { axis, keepDims };
    return ENGINE.runKernel(Sum, inputs, attrs);
}
const sum$2 = op({ sum_ });

function norm_(x, ord = 'euclidean', axis = null, keepDims = false) {
    x = convertToTensor(x, 'x', 'norm');
    const norm = normImpl(x, ord, axis);
    let keepDimsShape = norm.shape;
    if (keepDims) {
        const axes = parseAxisParam(axis, x.shape);
        keepDimsShape = expandShapeToKeepDim(norm.shape, axes);
    }
    return reshape$2(norm, keepDimsShape);
}

function normImpl(x, p, axis = null) {
    if (x.rank === 0) {
        return abs$2(x);
    }
    if (x.rank !== 1 && axis === null) {
        return normImpl(reshape$2(x, [-1]), p, axis);
    }
    if (x.rank === 1 || typeof axis === 'number' ||
        Array.isArray(axis) && axis.length === 1) {
        if (p === 1) {
            return sum$2(abs$2(x), axis);
        }
        if (p === Infinity) {
            return max$2(abs$2(x), axis);
        }
        if (p === -Infinity) {
            return min$2(abs$2(x), axis);
        }
        if (p === 'euclidean' || p === 2) {
            return sqrt$2(sum$2(pow$2(abs$2(x), scalar(2, 'int32')), axis));
        }
        throw new Error(`Error in norm: invalid ord value: ${p}`);
    }
    if (Array.isArray(axis) && axis.length === 2) {
        if (p === 1) {
            return max$2(sum$2(abs$2(x), axis[0]), axis[1] - 1);
        }
        if (p === Infinity) {
            return max$2(sum$2(abs$2(x), axis[1]), axis[0]);
        }
        if (p === -Infinity) {
            return min$2(sum$2(abs$2(x), axis[1]), axis[0]);
        }
        if (p === 'fro' || p === 'euclidean') {
            return sqrt$2(sum$2(square$2(x), axis));
        }
        throw new Error(`Error in norm: invalid ord value: ${p}`);
    }
    throw new Error(`Error in norm: invalid axis: ${axis}`);
}
const norm = op({ norm_ });
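
// normImpl dispatches on ord: for vectors, 1 -> sum|x|, Infinity -> max|x|,
// -Infinity -> min|x|, and 2/'euclidean' -> sqrt(sum(x^2)); for matrices (two
// axes), 1 is the maximum absolute column sum, Infinity the maximum absolute
// row sum, and 'fro'/'euclidean' the Frobenius norm sqrt(sum(x^2)).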

function exp_(x) {
    const $x = convertToTensor(x, 'x', 'exp');
    const inputs = { x: $x };
    return ENGINE.runKernel(Exp, inputs);
}
const exp$2 = op({ exp_ });

function expandDims_(x, axis = 0) {
    const $x = convertToTensor(x, 'x', 'expandDims', 'string_or_numeric');
    assert$1(axis <= $x.rank, () => 'Axis must be <= rank of the tensor');
    const inputs = { input: $x };
    const attrs = { dim: axis };
    return ENGINE.runKernel(ExpandDims, inputs, attrs);
}
const expandDims$3 = op({ expandDims_ });

function tile_(x, reps) {
    const $x = convertToTensor(x, 'x', 'tile', 'string_or_numeric');
    // The original message read 'Error in transpose', an upstream copy-paste
    // slip; corrected to name the actual op.
    assert$1($x.rank === reps.length, () => `Error in tile: rank of input ${$x.rank} ` +
        `must match length of reps ${reps}.`);
    const inputs = { x: $x };
    const attrs = { reps };
    return ENGINE.runKernel(Tile, inputs, attrs);
}
const tile$3 = op({ tile_ });

function eye_(numRows, numColumns, batchShape, dtype = 'float32') {
    if (numColumns == null) {
        numColumns = numRows;
    }
    const buff = buffer([numRows, numColumns], dtype);
    const n = numRows <= numColumns ? numRows : numColumns;
    for (let i = 0; i < n; ++i) {
        buff.set(1, i, i);
    }
    const out = reshape$2(buff.toTensor(), [numRows, numColumns]);
    if (batchShape == null) {
        return out;
    }
    else {
        if (batchShape.length === 1) {
            return tile$3(expandDims$3(out, 0), [batchShape[0], 1, 1]);
        }
        else if (batchShape.length === 2) {
            return tile$3(expandDims$3(expandDims$3(out, 0), 0), [batchShape[0], batchShape[1], 1, 1]);
        }
        else if (batchShape.length === 3) {
            return tile$3(expandDims$3(expandDims$3(expandDims$3(out, 0), 0), 0), [
                batchShape[0], batchShape[1], batchShape[2], 1, 1
            ]);
        }
        else {
            // The original message claimed only 1D and 2D batchShapes even
            // though the branches above also handle 3D; corrected here.
            throw new Error(`eye() currently supports only 1D, 2D and 3D ` +
                `batchShapes, but received ${batchShape.length}D.`);
        }
    }
}
const eye = op({ eye_ });

function floor_(x) {
    const $x = convertToTensor(x, 'x', 'floor', 'float32');
    const inputs = { x: $x };
    return ENGINE.runKernel(Floor, inputs);
}
const floor$2 = op({ floor_ });

function gather_(x, indices, axis = 0, batchDims = 0) {
    const $x = convertToTensor(x, 'x', 'gather');
    const $indices = convertToTensor(indices, 'indices', 'gather', 'int32');
    const inputs = { x: $x, indices: $indices };
    const attrs = { axis, batchDims };
    return ENGINE.runKernel(GatherV2, inputs, attrs);
}
const gather$1 = op({ gather_ });

function greater_(a, b) {
    let $a = convertToTensor(a, 'a', 'greater', 'string_or_numeric');
    let $b = convertToTensor(b, 'b', 'greater', 'string_or_numeric');
    [$a, $b] = makeTypesMatch($a, $b);
    assertAndGetBroadcastShape($a.shape, $b.shape);
    const inputs = { a: $a, b: $b };
    return ENGINE.runKernel(Greater, inputs);
}
const greater$2 = op({ greater_ });

function greaterEqual_(a, b) {
    let $a = convertToTensor(a, 'a', 'greaterEqual', 'string_or_numeric');
    let $b = convertToTensor(b, 'b', 'greaterEqual', 'string_or_numeric');
    [$a, $b] = makeTypesMatch($a, $b);
    assertAndGetBroadcastShape($a.shape, $b.shape);
    const inputs = { a: $a, b: $b };
    return ENGINE.runKernel(GreaterEqual, inputs);
}
const greaterEqual$2 = op({ greaterEqual_ });

function imag_(input) {
    const $input = convertToTensor(input, 'input', 'imag');
    const inputs = { input: $input };
    return ENGINE.runKernel(Imag, inputs);
}
const imag$2 = op({ imag_ });

function leakyRelu_(x, alpha = 0.2) {
    const $x = convertToTensor(x, 'x', 'leakyRelu');
    const inputs = { x: $x };
    const attrs = { alpha };
    return ENGINE.runKernel(LeakyRelu, inputs, attrs);
}
const leakyRelu$2 = op({ leakyRelu_ });

function less_(a, b) {
    let $a = convertToTensor(a, 'a', 'less', 'string_or_numeric');
    let $b = convertToTensor(b, 'b', 'less', 'string_or_numeric');
    [$a, $b] = makeTypesMatch($a, $b);
    assertAndGetBroadcastShape($a.shape, $b.shape);
    const inputs = { a: $a, b: $b };
    return ENGINE.runKernel(Less, inputs);
}
const less$2 = op({ less_ });

function lessEqual_(a, b) {
    let $a = convertToTensor(a, 'a', 'lessEqual', 'string_or_numeric');
    let $b = convertToTensor(b, 'b', 'lessEqual', 'string_or_numeric');
    [$a, $b] = makeTypesMatch($a, $b);
    assertAndGetBroadcastShape($a.shape, $b.shape);
    const inputs = { a: $a, b: $b };
    return ENGINE.runKernel(LessEqual, inputs);
}
const lessEqual$2 = op({ lessEqual_ });

function log_(x) {
    const $x = convertToTensor(x, 'x', 'log', 'float32');
    const inputs = { x: $x };
    return ENGINE.runKernel(Log, inputs);
}
const log$2 = op({ log_ });

function log1p_(x) {
    const $x = convertToTensor(x, 'x', 'log1p');
    const inputs = { x: $x };
    return ENGINE.runKernel(Log1p, inputs);
}
const log1p$2 = op({ log1p_ });

function variableGrads(f, varList) {
    assert$1(isFunction(f), () => 'The f passed in variableGrads(f) must be a function');
    assert$1(varList == null ||
        Array.isArray(varList) && varList.every(v => v instanceof Variable), () => 'The varList passed in variableGrads(f, varList) must be an array ' +
        'of variables');
    const specifiedVarList = varList != null;
    if (!specifiedVarList) {
        varList = [];
        for (const varName in ENGINE.registeredVariables) {
            varList.push(ENGINE.registeredVariables[varName]);
        }
    }
    const specifiedNonTrainable = specifiedVarList ? varList.filter(variable => !variable.trainable) : null;
    const originalVarCount = varList.length;
    varList = varList.filter(variable => variable.trainable);
    assert$1(varList.length > 0, () => `variableGrads() expects at least one of the input variables to ` +
        `be trainable, but none of the ${originalVarCount} variables is ` +
        `trainable.`);
    const allowNoGradients = true;
    const { value, grads } = ENGINE.gradients(f, varList, null, allowNoGradients);
    assert$1(grads.some(g => g != null), () => 'Cannot find a connection between any variable and the result of ' +
        'the loss function y=f(x). Please make sure the operations that ' +
        'use variables are inside the function f passed to minimize().');
    assert$1(value.rank === 0, () => `The f passed in variableGrads(f) must return a scalar, but it ` +
        `returned a rank-${value.rank} tensor`);
    const namedGrads = {};
    varList.forEach((v, i) => {
        if (grads[i] != null) {
            namedGrads[v.name] = grads[i];
        }
    });
    if (specifiedNonTrainable != null) {
        specifiedNonTrainable.forEach(v => namedGrads[v.name] = null);
    }
    return { value, grads: namedGrads };
}

function customGrad(f) {
    return ENGINE.customGrad(f);
}

function neg_(x) {
    const $x = convertToTensor(x, 'x', 'neg');
    const inputs = { x: $x };
    return ENGINE.runKernel(Neg, inputs);
}
const neg$2 = op({ neg_ });

function softplus_(x) {
    const $x = convertToTensor(x, 'x', 'softplus');
    const inputs = { x: $x };
    return ENGINE.runKernel(Softplus$1, inputs);
}
const softplus$2 = op({ softplus_ });

function sub_(a, b) {
    let $a = convertToTensor(a, 'a', 'sub');
    let $b = convertToTensor(b, 'b', 'sub');
    [$a, $b] = makeTypesMatch($a, $b);
    const inputs = { a: $a, b: $b };
    return ENGINE.runKernel(Sub, inputs);
}
const sub$2 = op({ sub_ });

function logSoftmax_(logits, axis = -1) {
    const $logits = convertToTensor(logits, 'logits', 'logSoftmax');
    if (axis === -1) {
        axis = $logits.rank - 1;
    }
    if (axis !== $logits.rank - 1) {
        throw Error('Log Softmax along a non-last dimension is not yet supported. ' +
            `Logits was rank ${$logits.rank} and axis was ${axis}`);
    }
    const customOp = customGrad((logits, save) => {
        const keepDims = true;
        const xMax = max$2(logits, axis, true);
        const shifted = sub$2(logits, xMax);
        const value = sub$2(cast$3(shifted, 'float32'), log$2(sum$2(exp$2(shifted), axis, keepDims)));
        save([value]);
        const gradFunc = (dy, saved) => {
            const [value] = saved;
            const keepDims = true;
            const softmax = exp$2(value);
            return sub$2(dy, mul(sum$2(dy, axis, keepDims), softmax));
        };
        return { value, gradFunc };
    });
    return customOp($logits);
}
const logSoftmax = op({ logSoftmax_ });
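
// The custom gradient above implements the numerically stable identity
//   logSoftmax(x) = (x - max(x)) - log(sum(exp(x - max(x))))
// with gradient dy - sum(dy) * softmax(x), where max and sum run over `axis`.
// Subtracting max(x) first keeps exp() from overflowing for large logits.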

function logicalAnd_(a, b) {
    const $a = convertToTensor(a, 'a', 'logicalAnd', 'bool');
    const $b = convertToTensor(b, 'b', 'logicalAnd', 'bool');
    assertAndGetBroadcastShape($a.shape, $b.shape);
    const inputs = { a: $a, b: $b };
    return ENGINE.runKernel(LogicalAnd, inputs);
}
const logicalAnd$2 = op({ logicalAnd_ });

function logicalNot_(x) {
    const $x = convertToTensor(x, 'x', 'logicalNot', 'bool');
    const inputs = { x: $x };
    return ENGINE.runKernel(LogicalNot, inputs);
}
const logicalNot$2 = op({ logicalNot_ });

function maximum_(a, b) {
    let $a = convertToTensor(a, 'a', 'maximum');
    let $b = convertToTensor(b, 'b', 'maximum');
    [$a, $b] = makeTypesMatch($a, $b);
    if ($a.dtype === 'bool') {
        $a = cast$3($a, 'int32');
        $b = cast$3($b, 'int32');
    }
    assertAndGetBroadcastShape($a.shape, $b.shape);
    const inputs = { a: $a, b: $b };
    return ENGINE.runKernel(Maximum, inputs);
}
const maximum$2 = op({ maximum_ });

function mean_(x, axis = null, keepDims = false) {
    const $x = convertToTensor(x, 'x', 'mean');
    const inputs = { x: $x };
    const attrs = { axis, keepDims };
    return ENGINE.runKernel(Mean, inputs, attrs);
}
const mean$1 = op({ mean_ });

function zeros$1(shape, dtype = 'float32') {
    assertNonNegativeIntegerDimensions(shape);
    if (dtype === 'complex64') {
        const real = zeros$1(shape, 'float32');
        const imag = zeros$1(shape, 'float32');
        return complex$2(real, imag);
    }
    const values = makeZerosTypedArray(sizeFromShape(shape), dtype);
    return ENGINE.makeTensor(values, shape, dtype);
}

function ones(shape, dtype = 'float32') {
    assertNonNegativeIntegerDimensions(shape);
    if (dtype === 'complex64') {
        const real = ones(shape, 'float32');
        const imag = zeros$1(shape, 'float32');
        return complex$2(real, imag);
    }
    const values = makeOnesTypedArray(sizeFromShape(shape), dtype);
    return ENGINE.makeTensor(values, shape, dtype);
}

function minimum_(a, b) {
    let $a = convertToTensor(a, 'a', 'minimum');
    let $b = convertToTensor(b, 'b', 'minimum');
    [$a, $b] = makeTypesMatch($a, $b);
    if ($a.dtype === 'bool') {
        $a = cast$3($a, 'int32');
        $b = cast$3($b, 'int32');
    }
    assertAndGetBroadcastShape($a.shape, $b.shape);
    const inputs = { a: $a, b: $b };
    return ENGINE.runKernel(Minimum, inputs);
}
const minimum$2 = op({ minimum_ });

function notEqual_(a, b) {
    let $a = convertToTensor(a, 'a', 'notEqual', 'string_or_numeric');
    let $b = convertToTensor(b, 'b', 'notEqual', 'string_or_numeric');
    [$a, $b] = makeTypesMatch($a, $b);
    assertAndGetBroadcastShape($a.shape, $b.shape);
    const inputs = { a: $a, b: $b };
    return ENGINE.runKernel(NotEqual, inputs);
}
const notEqual$2 = op({ notEqual_ });

function oneHot_(indices, depth, onValue = 1, offValue = 0, dtype = 'int32') {
    if (depth < 2) {
        throw new Error(`Error in oneHot: depth must be >=2, but it is ${depth}`);
    }
    const $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32');
    const inputs = { indices: $indices };
    const attrs = { dtype, depth, onValue, offValue };
    return ENGINE.runKernel(OneHot, inputs, attrs);
}
const oneHot$2 = op({ oneHot_ });

function onesLike_(x) {
    const $x = convertToTensor(x, 'x', 'onesLike');
    const inputs = { x: $x };
    return ENGINE.runKernel(OnesLike, inputs);
}
const onesLike$2 = op({ onesLike_ });

function pad_(x, paddings, constantValue = 0) {
    const $x = convertToTensor(x, 'x', 'pad');
    if ($x.rank === 0) {
        throw new Error('pad(scalar) is not defined. Pass non-scalar to pad');
    }
    const attrs = { paddings, constantValue };
    const inputs = { x: $x };
    return ENGINE.runKernel(PadV2, inputs, attrs);
}
const pad = op({ pad_ });

function spaceToBatchND_(x, blockShape, paddings) {
    const $x = convertToTensor(x, 'x', 'spaceToBatchND');
    assert$1($x.rank >= 1 + blockShape.length, () => `input rank ${$x.rank} should be > than [blockShape] ${blockShape.length}`);
    assert$1(paddings.length === blockShape.length, () => `paddings.shape[0] ${paddings.length} must be equal to [blockShape] ${blockShape.length}`);
    assert$1($x.shape.reduce((a, b, i) => {
        if (i > 0 && i <= blockShape.length) {
            return a &&
                ((b + paddings[i - 1][0] + paddings[i - 1][1]) %
                    blockShape[i - 1] ===
                    0);
        }
        return a;
    }, true), () => `input spatial dimensions ${$x.shape.slice(1)} with paddings ${paddings.toString()} must be divisible by blockShapes ${blockShape.toString()}`);
    const inputs = { x: $x };
    const attrs = { blockShape, paddings };
    return ENGINE.runKernel(SpaceToBatchND, inputs, attrs);
}
const spaceToBatchND$2 = op({ spaceToBatchND_ });

function prelu_(x, alpha) {
    const $x = convertToTensor(x, 'x', 'prelu');
    const $alpha = convertToTensor(alpha, 'alpha', 'prelu');
    const inputs = { x: $x, alpha: $alpha };
    return ENGINE.runKernel(Prelu, inputs);
}
const prelu$2 = op({ prelu_ });

var alea$1 = {exports: {}};

(function (module) {
    (function(global, module, define) {

        function Alea(seed) {
            var me = this, mash = Mash();

            me.next = function() {
                var t = 2091639 * me.s0 + me.c * 2.3283064365386963e-10;
                me.s0 = me.s1;
                me.s1 = me.s2;
                return me.s2 = t - (me.c = t | 0);
            };

            me.c = 1;
            me.s0 = mash(' ');
            me.s1 = mash(' ');
            me.s2 = mash(' ');
            me.s0 -= mash(seed);
            if (me.s0 < 0) { me.s0 += 1; }
            me.s1 -= mash(seed);
            if (me.s1 < 0) { me.s1 += 1; }
            me.s2 -= mash(seed);
            if (me.s2 < 0) { me.s2 += 1; }
            mash = null;
        }

        function copy(f, t) {
            t.c = f.c;
            t.s0 = f.s0;
            t.s1 = f.s1;
            t.s2 = f.s2;
            return t;
        }

        function impl(seed, opts) {
            var xg = new Alea(seed),
                state = opts && opts.state,
                prng = xg.next;
            prng.int32 = function() { return (xg.next() * 0x100000000) | 0; };
            prng.double = function() {
                return prng() + (prng() * 0x200000 | 0) * 1.1102230246251565e-16;
            };
            prng.quick = prng;
            if (state) {
                if (typeof(state) == 'object') copy(state, xg);
                prng.state = function() { return copy(xg, {}); };
            }
            return prng;
        }

        function Mash() {
            var n = 0xefc8249d;

            var mash = function(data) {
                data = String(data);
                for (var i = 0; i < data.length; i++) {
                    n += data.charCodeAt(i);
                    var h = 0.02519603282416938 * n;
                    n = h >>> 0;
                    h -= n;
                    h *= n;
                    n = h >>> 0;
                    h -= n;
                    n += h * 0x100000000;
                }
                return (n >>> 0) * 2.3283064365386963e-10;
            };

            return mash;
        }

        if (module && module.exports) {
            module.exports = impl;
        } else {
            this.alea = impl;
        }

    })(
        commonjsGlobal,
        module);
} (alea$1));

var aleaExports = alea$1.exports;
|
|
|
|
var xor128$1 = {exports: {}};

(function (module) {
    (function (global, module, define) {
        function XorGen(seed) {
            var me = this, strseed = '';
            me.x = 0;
            me.y = 0;
            me.z = 0;
            me.w = 0;
            me.next = function () {
                var t = me.x ^ (me.x << 11);
                me.x = me.y;
                me.y = me.z;
                me.z = me.w;
                return me.w ^= (me.w >>> 19) ^ t ^ (t >>> 8);
            };
            if (seed === (seed | 0)) {
                me.x = seed;
            } else {
                strseed += seed;
            }
            for (var k = 0; k < strseed.length + 64; k++) {
                me.x ^= strseed.charCodeAt(k) | 0;
                me.next();
            }
        }

        function copy(f, t) {
            t.x = f.x;
            t.y = f.y;
            t.z = f.z;
            t.w = f.w;
            return t;
        }

        function impl(seed, opts) {
            var xg = new XorGen(seed),
                state = opts && opts.state,
                prng = function () { return (xg.next() >>> 0) / 0x100000000; };
            prng.double = function () {
                do {
                    var top = xg.next() >>> 11,
                        bot = (xg.next() >>> 0) / 0x100000000,
                        result = (top + bot) / (1 << 21);
                } while (result === 0);
                return result;
            };
            prng.int32 = xg.next;
            prng.quick = prng;
            if (state) {
                if (typeof (state) == 'object') copy(state, xg);
                prng.state = function () { return copy(xg, {}); };
            }
            return prng;
        }

        if (module && module.exports) {
            module.exports = impl;
        } else {
            this.xor128 = impl;
        }
    })(
        commonjsGlobal,
        module);
} (xor128$1));

var xor128Exports = xor128$1.exports;

var xorwow$1 = {exports: {}};

(function (module) {
    (function (global, module, define) {
        function XorGen(seed) {
            var me = this, strseed = '';
            me.next = function () {
                var t = (me.x ^ (me.x >>> 2));
                me.x = me.y; me.y = me.z; me.z = me.w; me.w = me.v;
                return (me.d = (me.d + 362437 | 0)) +
                    (me.v = (me.v ^ (me.v << 4)) ^ (t ^ (t << 1))) | 0;
            };
            me.x = 0;
            me.y = 0;
            me.z = 0;
            me.w = 0;
            me.v = 0;
            if (seed === (seed | 0)) {
                me.x = seed;
            } else {
                strseed += seed;
            }
            for (var k = 0; k < strseed.length + 64; k++) {
                me.x ^= strseed.charCodeAt(k) | 0;
                if (k == strseed.length) {
                    me.d = me.x << 10 ^ me.x >>> 4;
                }
                me.next();
            }
        }

        function copy(f, t) {
            t.x = f.x;
            t.y = f.y;
            t.z = f.z;
            t.w = f.w;
            t.v = f.v;
            t.d = f.d;
            return t;
        }

        function impl(seed, opts) {
            var xg = new XorGen(seed),
                state = opts && opts.state,
                prng = function () { return (xg.next() >>> 0) / 0x100000000; };
            prng.double = function () {
                do {
                    var top = xg.next() >>> 11,
                        bot = (xg.next() >>> 0) / 0x100000000,
                        result = (top + bot) / (1 << 21);
                } while (result === 0);
                return result;
            };
            prng.int32 = xg.next;
            prng.quick = prng;
            if (state) {
                if (typeof (state) == 'object') copy(state, xg);
                prng.state = function () { return copy(xg, {}); };
            }
            return prng;
        }

        if (module && module.exports) {
            module.exports = impl;
        } else {
            this.xorwow = impl;
        }
    })(
        commonjsGlobal,
        module);
} (xorwow$1));

var xorwowExports = xorwow$1.exports;

var xorshift7$1 = {exports: {}};

(function (module) {
    (function (global, module, define) {
        function XorGen(seed) {
            var me = this;
            me.next = function () {
                var X = me.x, i = me.i, t, v;
                t = X[i]; t ^= (t >>> 7); v = t ^ (t << 24);
                t = X[(i + 1) & 7]; v ^= t ^ (t >>> 10);
                t = X[(i + 3) & 7]; v ^= t ^ (t >>> 3);
                t = X[(i + 4) & 7]; v ^= t ^ (t << 7);
                t = X[(i + 7) & 7]; t = t ^ (t << 13); v ^= t ^ (t << 9);
                X[i] = v;
                me.i = (i + 1) & 7;
                return v;
            };

            function init(me, seed) {
                var j, X = [];
                if (seed === (seed | 0)) {
                    X[0] = seed;
                } else {
                    seed = '' + seed;
                    for (j = 0; j < seed.length; ++j) {
                        X[j & 7] = (X[j & 7] << 15) ^
                            (seed.charCodeAt(j) + X[(j + 1) & 7] << 13);
                    }
                }
                while (X.length < 8) X.push(0);
                for (j = 0; j < 8 && X[j] === 0; ++j);
                if (j == 8) X[7] = -1;
                me.x = X;
                me.i = 0;
                for (j = 256; j > 0; --j) {
                    me.next();
                }
            }

            init(me, seed);
        }

        function copy(f, t) {
            t.x = f.x.slice();
            t.i = f.i;
            return t;
        }

        function impl(seed, opts) {
            if (seed == null) seed = +(new Date);
            var xg = new XorGen(seed),
                state = opts && opts.state,
                prng = function () { return (xg.next() >>> 0) / 0x100000000; };
            prng.double = function () {
                do {
                    var top = xg.next() >>> 11,
                        bot = (xg.next() >>> 0) / 0x100000000,
                        result = (top + bot) / (1 << 21);
                } while (result === 0);
                return result;
            };
            prng.int32 = xg.next;
            prng.quick = prng;
            if (state) {
                if (state.x) copy(state, xg);
                prng.state = function () { return copy(xg, {}); };
            }
            return prng;
        }

        if (module && module.exports) {
            module.exports = impl;
        } else {
            this.xorshift7 = impl;
        }
    })(
        commonjsGlobal,
        module);
} (xorshift7$1));

var xorshift7Exports = xorshift7$1.exports;

var xor4096$1 = {exports: {}};

(function (module) {
    (function (global, module, define) {
        function XorGen(seed) {
            var me = this;
            me.next = function () {
                var w = me.w,
                    X = me.X, i = me.i, t, v;
                me.w = w = (w + 0x61c88647) | 0;
                v = X[(i + 34) & 127];
                t = X[i = ((i + 1) & 127)];
                v ^= v << 13;
                t ^= t << 17;
                v ^= v >>> 15;
                t ^= t >>> 12;
                v = X[i] = v ^ t;
                me.i = i;
                return (v + (w ^ (w >>> 16))) | 0;
            };

            function init(me, seed) {
                var t, v, i, j, w, X = [], limit = 128;
                if (seed === (seed | 0)) {
                    v = seed;
                    seed = null;
                } else {
                    seed = seed + '\0';
                    v = 0;
                    limit = Math.max(limit, seed.length);
                }
                for (i = 0, j = -32; j < limit; ++j) {
                    if (seed) v ^= seed.charCodeAt((j + 32) % seed.length);
                    if (j === 0) w = v;
                    v ^= v << 10;
                    v ^= v >>> 15;
                    v ^= v << 4;
                    v ^= v >>> 13;
                    if (j >= 0) {
                        w = (w + 0x61c88647) | 0;
                        t = (X[j & 127] ^= (v + w));
                        i = (0 == t) ? i + 1 : 0;
                    }
                }
                if (i >= 128) {
                    X[(seed && seed.length || 0) & 127] = -1;
                }
                i = 127;
                for (j = 4 * 128; j > 0; --j) {
                    v = X[(i + 34) & 127];
                    t = X[i = ((i + 1) & 127)];
                    v ^= v << 13;
                    t ^= t << 17;
                    v ^= v >>> 15;
                    t ^= t >>> 12;
                    X[i] = v ^ t;
                }
                me.w = w;
                me.X = X;
                me.i = i;
            }

            init(me, seed);
        }

        function copy(f, t) {
            t.i = f.i;
            t.w = f.w;
            t.X = f.X.slice();
            return t;
        }

        function impl(seed, opts) {
            if (seed == null) seed = +(new Date);
            var xg = new XorGen(seed),
                state = opts && opts.state,
                prng = function () { return (xg.next() >>> 0) / 0x100000000; };
            prng.double = function () {
                do {
                    var top = xg.next() >>> 11,
                        bot = (xg.next() >>> 0) / 0x100000000,
                        result = (top + bot) / (1 << 21);
                } while (result === 0);
                return result;
            };
            prng.int32 = xg.next;
            prng.quick = prng;
            if (state) {
                if (state.X) copy(state, xg);
                prng.state = function () { return copy(xg, {}); };
            }
            return prng;
        }

        if (module && module.exports) {
            module.exports = impl;
        } else {
            this.xor4096 = impl;
        }
    })(
        commonjsGlobal,
        module);
} (xor4096$1));

var xor4096Exports = xor4096$1.exports;

var tychei$1 = {exports: {}};

(function (module) {
    (function (global, module, define) {
        function XorGen(seed) {
            var me = this, strseed = '';
            me.next = function () {
                var b = me.b, c = me.c, d = me.d, a = me.a;
                b = (b << 25) ^ (b >>> 7) ^ c;
                c = (c - d) | 0;
                d = (d << 24) ^ (d >>> 8) ^ a;
                a = (a - b) | 0;
                me.b = b = (b << 20) ^ (b >>> 12) ^ c;
                me.c = c = (c - d) | 0;
                me.d = (d << 16) ^ (c >>> 16) ^ a;
                return me.a = (a - b) | 0;
            };
            me.a = 0;
            me.b = 0;
            me.c = 2654435769 | 0;
            me.d = 1367130551;
            if (seed === Math.floor(seed)) {
                me.a = (seed / 0x100000000) | 0;
                me.b = seed | 0;
            } else {
                strseed += seed;
            }
            for (var k = 0; k < strseed.length + 20; k++) {
                me.b ^= strseed.charCodeAt(k) | 0;
                me.next();
            }
        }

        function copy(f, t) {
            t.a = f.a;
            t.b = f.b;
            t.c = f.c;
            t.d = f.d;
            return t;
        }

        function impl(seed, opts) {
            var xg = new XorGen(seed),
                state = opts && opts.state,
                prng = function () { return (xg.next() >>> 0) / 0x100000000; };
            prng.double = function () {
                do {
                    var top = xg.next() >>> 11,
                        bot = (xg.next() >>> 0) / 0x100000000,
                        result = (top + bot) / (1 << 21);
                } while (result === 0);
                return result;
            };
            prng.int32 = xg.next;
            prng.quick = prng;
            if (state) {
                if (typeof (state) == 'object') copy(state, xg);
                prng.state = function () { return copy(xg, {}); };
            }
            return prng;
        }

        if (module && module.exports) {
            module.exports = impl;
        } else {
            this.tychei = impl;
        }
    })(
        commonjsGlobal,
        module);
} (tychei$1));

var tycheiExports = tychei$1.exports;

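// Default seedrandom generator: an ARC4 (RC4) keystream seeded via mixkey().
// prng() builds a double in [0, 1) from ARC4 output bytes, topping up until it
// carries 52 bits of significance; autoseed() prefers crypto randomness and
// falls back to coarse environment entropy.
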
var seedrandom$1 = {exports: {}};

(function (module) {
    (function (global, pool, math) {
        var width = 256,
            chunks = 6,
            digits = 52,
            rngname = 'random',
            startdenom = math.pow(width, chunks),
            significance = math.pow(2, digits),
            overflow = significance * 2,
            mask = width - 1,
            nodecrypto;

        function seedrandom(seed, options, callback) {
            var key = [];
            options = (options == true) ? { entropy: true } : (options || {});

            var shortseed = mixkey(flatten(
                options.entropy ? [seed, tostring(pool)] :
                    (seed == null) ? autoseed() : seed, 3), key);

            var arc4 = new ARC4(key);

            var prng = function () {
                var n = arc4.g(chunks),
                    d = startdenom,
                    x = 0;
                while (n < significance) {
                    n = (n + x) * width;
                    d *= width;
                    x = arc4.g(1);
                }
                while (n >= overflow) {
                    n /= 2;
                    d /= 2;
                    x >>>= 1;
                }
                return (n + x) / d;
            };

            prng.int32 = function () { return arc4.g(4) | 0; };
            prng.quick = function () { return arc4.g(4) / 0x100000000; };
            prng.double = prng;

            mixkey(tostring(arc4.S), pool);

            return (options.pass || callback ||
                function (prng, seed, is_math_call, state) {
                    if (state) {
                        if (state.S) { copy(state, arc4); }
                        prng.state = function () { return copy(arc4, {}); };
                    }
                    if (is_math_call) { math[rngname] = prng; return seed; }
                    else return prng;
                })(
                    prng,
                    shortseed,
                    'global' in options ? options.global : (this == math),
                    options.state);
        }

        function ARC4(key) {
            var t, keylen = key.length,
                me = this, i = 0, j = me.i = me.j = 0, s = me.S = [];
            if (!keylen) { key = [keylen++]; }
            while (i < width) {
                s[i] = i++;
            }
            for (i = 0; i < width; i++) {
                s[i] = s[j = mask & (j + key[i % keylen] + (t = s[i]))];
                s[j] = t;
            }
            (me.g = function (count) {
                var t, r = 0,
                    i = me.i, j = me.j, s = me.S;
                while (count--) {
                    t = s[i = mask & (i + 1)];
                    r = r * width + s[mask & ((s[i] = s[j = mask & (j + t)]) + (s[j] = t))];
                }
                me.i = i; me.j = j;
                return r;
            })(width);
        }

        function copy(f, t) {
            t.i = f.i;
            t.j = f.j;
            t.S = f.S.slice();
            return t;
        }

        function flatten(obj, depth) {
            var result = [], typ = (typeof obj), prop;
            if (depth && typ == 'object') {
                for (prop in obj) {
                    try { result.push(flatten(obj[prop], depth - 1)); } catch (e) {}
                }
            }
            return (result.length ? result : typ == 'string' ? obj : obj + '\0');
        }

        function mixkey(seed, key) {
            var stringseed = seed + '', smear, j = 0;
            while (j < stringseed.length) {
                key[mask & j] =
                    mask & ((smear ^= key[mask & j] * 19) + stringseed.charCodeAt(j++));
            }
            return tostring(key);
        }

        function autoseed() {
            try {
                var out;
                if (nodecrypto && (out = nodecrypto.randomBytes)) {
                    out = out(width);
                } else {
                    out = new Uint8Array(width);
                    (global.crypto || global.msCrypto).getRandomValues(out);
                }
                return tostring(out);
            } catch (e) {
                var browser = global.navigator,
                    plugins = browser && browser.plugins;
                return [+new Date, global, plugins, global.screen, tostring(pool)];
            }
        }

        function tostring(a) {
            return String.fromCharCode.apply(0, a);
        }

        mixkey(math.random(), pool);

        if (module.exports) {
            module.exports = seedrandom;
            try {
                nodecrypto = require('crypto');
            } catch (ex) {}
        } else {
            math['seed' + rngname] = seedrandom;
        }
    })(
        (typeof self !== 'undefined') ? self : commonjsGlobal,
        [],
        Math
    );
} (seedrandom$1));

var seedrandomExports = seedrandom$1.exports;

var alea = aleaExports;
var xor128 = xor128Exports;
var xorwow = xorwowExports;
var xorshift7 = xorshift7Exports;
var xor4096 = xor4096Exports;
var tychei = tycheiExports;
var sr = seedrandomExports;

sr.alea = alea;
sr.xor128 = xor128;
sr.xorwow = xorwow;
sr.xorshift7 = xorshift7;
sr.xor4096 = xor4096;
sr.tychei = tychei;

var seedrandom = sr;

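// MPRandGauss draws Gaussian samples with the Marsaglia polar method: each
// rejection loop produces a pair of normal deviates and the second one is
// cached in nextVal for the next call. With truncated=true, samples outside
// mean +/- 2*stdDev are rejected and redrawn.
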
class MPRandGauss {
    constructor(mean, stdDeviation, dtype, truncated, seed) {
        this.mean = mean;
        this.stdDev = stdDeviation;
        this.dtype = dtype;
        this.nextVal = NaN;
        this.truncated = truncated;
        if (this.truncated) {
            this.upper = this.mean + this.stdDev * 2;
            this.lower = this.mean - this.stdDev * 2;
        }
        const seedValue = seed ? seed : Math.random();
        this.random = seedrandom.alea(seedValue.toString());
    }

    nextValue() {
        if (!isNaN(this.nextVal)) {
            const value = this.nextVal;
            this.nextVal = NaN;
            return value;
        }
        let resultX, resultY;
        let isValid = false;
        while (!isValid) {
            let v1, v2, s;
            do {
                v1 = 2 * this.random() - 1;
                v2 = 2 * this.random() - 1;
                s = v1 * v1 + v2 * v2;
            } while (s >= 1 || s === 0);
            const mul = Math.sqrt(-2 * Math.log(s) / s);
            resultX = this.mean + this.stdDev * v1 * mul;
            resultY = this.mean + this.stdDev * v2 * mul;
            if (!this.truncated || this.isValidTruncated(resultX)) {
                isValid = true;
            }
        }
        if (!this.truncated || this.isValidTruncated(resultY)) {
            this.nextVal = this.convertValue(resultY);
        }
        return this.convertValue(resultX);
    }

    convertValue(value) {
        if (this.dtype == null || this.dtype === 'float32') {
            return value;
        }
        return Math.round(value);
    }

    isValidTruncated(value) {
        return value <= this.upper && value >= this.lower;
    }
}

class UniformRandom {
    constructor(min = 0, max = 1, dtype, seed) {
        this.canReturnFloat = () => (this.dtype == null || this.dtype === 'float32');
        this.min = min;
        this.range = max - min;
        this.dtype = dtype;
        if (seed == null) {
            seed = Math.random();
        }
        if (typeof seed === 'number') {
            seed = seed.toString();
        }
        if (!this.canReturnFloat() && this.range <= 1) {
            throw new Error(`The difference between ${max} and ${min} must be > 1 when dtype is not float, but it is ${this.range}.`);
        }
        this.random = seedrandom.alea(seed);
    }
    convertValue(value) {
        if (this.canReturnFloat()) {
            return value;
        }
        return Math.round(value);
    }
    nextValue() {
        return this.convertValue(this.min + this.range * this.random());
    }
}

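// Sketch (hypothetical, assuming the ops below are exported as tf.randomNormal
// and tf.randomUniform): tf.randomNormal([2, 2]) fills a 2x2 float32 tensor
// with N(0, 1) samples; tf.randomUniform([2, 2], 0, 10, 'int32', 42) fills one
// with seeded uniform draws from [0, 10), rounded to integers.
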
function randomNormal_(shape, mean = 0, stdDev = 1, dtype, seed) {
    assertNonNegativeIntegerDimensions(shape);
    if (dtype != null && dtype === 'bool') {
        throw new Error(`Unsupported data type ${dtype}`);
    }
    const randGauss = new MPRandGauss(mean, stdDev, dtype, false, seed);
    const res = buffer(shape, dtype);
    for (let i = 0; i < res.values.length; i++) {
        res.values[i] = randGauss.nextValue();
    }
    return res.toTensor();
}
const randomNormal$1 = op({ randomNormal_ });

function randomUniform_(shape, minval = 0, maxval = 1, dtype = 'float32', seed) {
    assertNonNegativeIntegerDimensions(shape);
    const res = buffer(shape, dtype);
    const random = new UniformRandom(minval, maxval, null, seed);
    for (let i = 0; i < res.values.length; i++) {
        res.values[i] = random.nextValue();
    }
    return res.toTensor();
}
const randomUniform = op({ randomUniform_ });

function range$3(start, stop, step = 1, dtype = 'float32') {
    if (step === 0) {
        throw new Error('Cannot have a step of zero');
    }
    const attrs = { start, stop, step, dtype };
    return ENGINE.runKernel(Range, {}, attrs);
}

function real_(input) {
    const $input = convertToTensor(input, 'input', 'real');
    const inputs = { input: $input };
    return ENGINE.runKernel(Real, inputs);
}
const real$2 = op({ real_ });

function relu_(x) {
    const $x = convertToTensor(x, 'x', 'relu');
    const inputs = { x: $x };
    return ENGINE.runKernel(Relu$1, inputs);
}
const relu$2 = op({ relu_ });

function relu6_(x) {
    const $x = convertToTensor(x, 'x', 'relu6');
    const inputs = { x: $x };
    return ENGINE.runKernel(Relu6$1, inputs);
}
const relu6$2 = op({ relu6_ });

function reverse_(x, axis) {
    const $x = convertToTensor(x, 'x', 'reverse');
    const inputs = { x: $x };
    const attrs = { dims: axis };
    return ENGINE.runKernel(Reverse, inputs, attrs);
}
const reverse$2 = op({ reverse_ });

function rsqrt_(x) {
    const $x = convertToTensor(x, 'x', 'rsqrt', 'float32');
    const inputs = { x: $x };
    return ENGINE.runKernel(Rsqrt, inputs);
}
const rsqrt$2 = op({ rsqrt_ });

function selu_(x) {
    const $x = convertToTensor(x, 'x', 'selu');
    const inputs = { x: $x };
    return ENGINE.runKernel(Selu$1, inputs);
}
const selu$2 = op({ selu_ });

function sin_(x) {
    const $x = convertToTensor(x, 'x', 'sin', 'float32');
    const inputs = { x: $x };
    return ENGINE.runKernel(Sin, inputs);
}
const sin$2 = op({ sin_ });

function sinh_(x) {
    const $x = convertToTensor(x, 'x', 'sinh');
    const inputs = { x: $x };
    return ENGINE.runKernel(Sinh, inputs);
}
const sinh$2 = op({ sinh_ });

function slice1d_(x, begin, size) {
    const $x = convertToTensor(x, 'x', 'slice1d');
    assert$1($x.rank === 1, () => `slice1d expects a rank-1 tensor, but got a rank-${$x.rank} tensor`);
    return slice$2($x, [begin], [size]);
}
const slice1d = op({ slice1d_ });

function slice2d_(x, begin, size) {
    const $x = convertToTensor(x, 'x', 'slice2d');
    assert$1($x.rank === 2, () => `slice2d expects a rank-2 tensor, but got a rank-${$x.rank} tensor`);
    return slice$2($x, begin, size);
}
const slice2d = op({ slice2d_ });

function slice3d_(x, begin, size) {
    const $x = convertToTensor(x, 'x', 'slice3d');
    assert$1($x.rank === 3, () => `slice3d expects a rank-3 tensor, but got a rank-${$x.rank} tensor`);
    return slice$2($x, begin, size);
}
const slice3d = op({ slice3d_ });

function slice4d_(x, begin, size) {
    const $x = convertToTensor(x, 'x', 'slice4d');
    assert$1($x.rank === 4, () => `slice4d expects a rank-4 tensor, but got a rank-${$x.rank} tensor`);
    return slice$2($x, begin, size);
}
const slice4d = op({ slice4d_ });

function softmax_(logits, dim = -1) {
    const $logits = convertToTensor(logits, 'logits', 'softmax', 'float32');
    if (dim === -1) {
        dim = $logits.rank - 1;
    }
    if (dim !== $logits.rank - 1) {
        throw Error('Softmax along a non-last dimension is not yet supported. ' +
            `Logits was rank ${$logits.rank} and dim was ${dim}`);
    }
    const inputs = { logits: $logits };
    const attrs = { dim };
    return ENGINE.runKernel(Softmax$1, inputs, attrs);
}
const softmax$2 = op({ softmax_ });

function split_(x, numOrSizeSplits, axis = 0) {
    const $x = convertToTensor(x, 'x', 'split');
    const inputs = { x: $x };
    const attr = { numOrSizeSplits, axis };
    return ENGINE.runKernel(SplitV, inputs, attr);
}
const split$1 = op({ split_ });

function squeeze_(x, axis) {
    const $x = convertToTensor(x, 'x', 'squeeze', 'string_or_numeric');
    return reshape$2($x, squeezeShape($x.shape, axis).newShape);
}
const squeeze = op({ squeeze_ });

function stack_(tensors, axis = 0) {
    const $tensors = convertToTensorArray(tensors, 'tensors', 'stack', 'string_or_numeric');
    assert$1($tensors.length >= 1, () => 'Pass at least one tensor to tf.stack');
    if ($tensors.length > 0) {
        assert$1(axis <= $tensors[0].rank, () => 'Axis must be <= rank of the tensor');
    }
    const inputs = $tensors;
    const attrs = { axis };
    return ENGINE.runKernel(Pack, inputs, attrs);
}
const stack = op({ stack_ });

function step_(x, alpha = 0.0) {
    const $x = convertToTensor(x, 'x', 'step');
    const inputs = { x: $x };
    const attrs = { alpha };
    return ENGINE.runKernel(Step, inputs, attrs);
}
const step$2 = op({ step_ });

function tensor(values, shape, dtype) {
    const inferredShape = inferShape(values, dtype);
    return makeTensor(values, shape, inferredShape, dtype);
}

function tensor1d(values, dtype) {
    assertNonNull(values);
    const inferredShape = inferShape(values, dtype);
    if (inferredShape.length !== 1) {
        throw new Error('tensor1d() requires values to be a flat/TypedArray');
    }
    const shape = null;
    return makeTensor(values, shape, inferredShape, dtype);
}

function tensor2d(values, shape, dtype) {
    assertNonNull(values);
    if (shape != null && shape.length !== 2) {
        throw new Error('tensor2d() requires shape to have two numbers');
    }
    const inferredShape = inferShape(values, dtype);
    if (inferredShape.length !== 2 && inferredShape.length !== 1) {
        throw new Error('tensor2d() requires values to be number[][] or flat/TypedArray');
    }
    if (inferredShape.length === 1 && shape == null) {
        throw new Error('tensor2d() requires shape to be provided when `values` ' +
            'are a flat/TypedArray');
    }
    return makeTensor(values, shape, inferredShape, dtype);
}

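// scatterND shape validation: with indices of shape [..., sliceDim] addressing
// an output of rank shape.length, updates must have shape
// indices.shape[:batchDim] + shape[sliceDim:].
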
function validateUpdateShape(shape, indices, updates) {
    const sliceDim = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1;
    const batchDim = (indices.rank > 1) ? indices.rank - 1 : 1;
    const shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' +
        `shape[sliceDim:], got updates.shape: ${updates.shape}` +
        `, indices.shape: ${indices.shape}, shape: ${shape}` +
        `, sliceDim: ${sliceDim}, and batchDim: ${batchDim}.`;
    if (updates.rank < batchDim) {
        throw new Error(shapeError + ` update.rank < ${batchDim}. `);
    }
    if (shape.length < sliceDim + (updates.rank - batchDim)) {
        throw new Error(shapeError +
            ` Output shape length < ${sliceDim + (updates.rank - batchDim)}`);
    }
    if (updates.rank !== batchDim + shape.length - sliceDim) {
        throw new Error(shapeError + ` update.rank != ${batchDim + shape.length - sliceDim}`);
    }
    for (let d = 0; d < batchDim; ++d) {
        if (updates.shape[d] !== indices.shape[d]) {
            throw new Error(shapeError +
                ` updates.shape[${d}] (${updates.shape[d]}) != indices.shape[${d}] (${indices.shape[d]}).`);
        }
    }
    for (let d = 0; d < updates.rank - batchDim; ++d) {
        if (updates.shape[d + batchDim] !== shape[d + sliceDim]) {
            throw new Error(shapeError +
                ` updates.shape[${d + batchDim}] (${updates.shape[d + batchDim]}) != shape[${d + sliceDim}] (${shape[d + sliceDim]})`);
        }
    }
}

function validateInput(updates, indices, shape) {
    if (indices.rank < 1) {
        throw new Error('tf.scatterND() expects the indices to be rank 1 or higher,' +
            ` but the rank was ${indices.rank}.`);
    }
    if (updates.rank < 1) {
        throw new Error('tf.scatterND() expects the updates to be rank 1 or higher,' +
            ` but the rank was ${updates.rank}.`);
    }
    if (indices.dtype !== 'int32') {
        throw new Error(`The dtype of 'indices' should be int32, but got dtype: ${indices.dtype}`);
    }
    if (shape.length < 1) {
        throw new Error(`Output rank must be greater or equal to 1, but got shape: ${shape}`);
    }
    if (shape.length === 0) {
        if (indices.size === 0) {
            throw new Error(`Indices specified for empty output. indices shape: ${indices.shape}`);
        }
        if (updates.size === 0) {
            throw new Error(`Updates specified for empty output. updates shape: ${updates.shape}`);
        }
    }
    validateUpdateShape(shape, indices, updates);
}

function calculateShapes(updates, indices, shape) {
    const indicesRank = indices.shape.length;
    const sliceRank = (indicesRank > 1) ? indices.shape[indicesRank - 1] : 1;
    const totalNd = shape.length;
    let sliceSize = 1;
    for (let i = sliceRank; i < totalNd; ++i) {
        sliceSize *= shape[i];
    }
    const safeSliceDim = (sliceRank < 1) ? 1 : sliceRank;
    const numUpdates = sizeFromShape(indices.shape) / safeSliceDim;
    const strides = [...computeStrides(shape.slice(0, sliceRank)), 1];
    const outputSize = sizeFromShape(shape);
    return { sliceRank, numUpdates, sliceSize, strides, outputSize };
}

function truncatedNormal_(shape, mean = 0, stdDev = 1, dtype, seed) {
    assertNonNegativeIntegerDimensions(shape);
    if (dtype != null && dtype === 'bool') {
        throw new Error(`Unsupported data type ${dtype}`);
    }
    const randGauss = new MPRandGauss(mean, stdDev, dtype, true, seed);
    const res = buffer(shape, dtype);
    for (let i = 0; i < res.values.length; i++) {
        res.values[i] = randGauss.nextValue();
    }
    return res.toTensor();
}
const truncatedNormal = op({ truncatedNormal_ });

function unsortedSegmentSum_(x, segmentIds, numSegments) {
    const $x = convertToTensor(x, 'x', 'unsortedSegmentSum');
    const $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'unsortedSegmentSum', 'int32');
    assert$1(isInt(numSegments), () => 'numSegments must be of dtype int');
    const inputs = { x: $x, segmentIds: $segmentIds };
    const attrs = { numSegments };
    return ENGINE.runKernel(UnsortedSegmentSum, inputs, attrs);
}
const unsortedSegmentSum$2 = op({ unsortedSegmentSum_ });

function unstack_(x, axis = 0) {
    const $x = convertToTensor(x, 'x', 'unstack', 'string_or_numeric');
    assert$1(axis >= -$x.shape.length && axis < $x.shape.length, () => `Axis = ${axis} is not in [-${$x.shape.length}, ${$x.shape.length})`);
    const inputs = { value: $x };
    const attrs = { axis };
    return ENGINE.runKernel(Unpack, inputs, attrs);
}
const unstack = op({ unstack_ });

function variable(initialValue, trainable = true, name, dtype) {
    return ENGINE.makeVariable(initialValue, trainable, name, dtype);
}

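// whereImpl$2 returns the coordinates of all truthy entries of a boolean
// tensor as an int32 tensor of shape [numTrueElements, condShape.length].
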
function whereImpl$2(condShape, condVals) {
    const indices = [];
    for (let i = 0; i < condVals.length; i++) {
        if (condVals[i]) {
            indices.push(i);
        }
    }
    const inBuffer = buffer(condShape, 'int32');
    const out = buffer([indices.length, condShape.length], 'int32');
    for (let i = 0; i < indices.length; i++) {
        const loc = inBuffer.indexToLoc(indices[i]);
        const offset = i * condShape.length;
        out.values.set(loc, offset);
    }
    return out.toTensor();
}

function transpose_(x, perm, conjugate) {
    const $x = convertToTensor(x, 'x', 'transpose');
    if (perm == null) {
        perm = $x.shape.map((s, i) => i).reverse();
    }
    assert$1($x.rank === perm.length, () => `Error in transpose: rank of input ${$x.rank} ` +
        `must match length of perm ${perm}.`);
    perm.forEach(axis => {
        assert$1(axis >= 0 && axis < $x.rank, () => `All entries in 'perm' must be between 0 and ${$x.rank - 1}` +
            ` but got ${perm}`);
    });
    if ($x.rank <= 1) {
        return $x.clone();
    }
    const inputs = { x: $x };
    const attrs = { perm };
    if ($x.dtype === 'complex64') {
        return tidy(() => {
            let $real = real$2($x);
            let $imag = imag$2($x);
            $real = ENGINE.runKernel(Transpose, { x: $real }, attrs);
            $imag = ENGINE.runKernel(Transpose, { x: $imag }, attrs);
            if (conjugate) {
                $imag = neg$2($imag);
            }
            return complex$2($real, $imag);
        });
    }
    return ENGINE.runKernel(Transpose, inputs, attrs);
}
const transpose$2 = op({ transpose_ });

function getNoiseShape(x, noiseShape) {
    if (noiseShape == null) {
        return x.shape.slice();
    }
    if (arraysEqual(x.shape, noiseShape)) {
        return noiseShape;
    }
    if (x.shape.length === noiseShape.length) {
        const newDimension = [];
        for (let i = 0; i < x.shape.length; i++) {
            if (noiseShape[i] == null && x.shape[i] != null) {
                newDimension.push(x.shape[i]);
            } else {
                newDimension.push(noiseShape[i]);
            }
        }
        return newDimension;
    }
    return noiseShape;
}

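// Inverted dropout: each element is kept with probability keepProb = 1 - rate
// (floor(uniform[0, 1) + keepProb) is 1 with probability keepProb, else 0) and
// the kept values are scaled by 1 / keepProb so the expected sum is unchanged.
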
function dropout_(x, rate, noiseShape, seed) {
    const $x = convertToTensor(x, 'x', 'dropout');
    assert$1($x.dtype === 'float32', () => `x has to be a floating point tensor since it's going to be ` +
        `scaled, but got a ${$x.dtype} tensor instead.`);
    assert$1(rate >= 0 && rate < 1, () => `rate must be a float in the range [0, 1), but got ${rate}.`);
    if (rate === 0) {
        return x instanceof Tensor ? $x.clone() : $x;
    }
    const $noiseShape = getNoiseShape($x, noiseShape);
    const keepProb = 1 - rate;
    const multiplier = div$1(floor$2(add$1(randomUniform($noiseShape, 0, 1, 'float32', seed), keepProb)), keepProb);
    return mul($x, multiplier);
}
const dropout$2 = op({ dropout_ });

function conv2DBackpropFilter_(x, dy, filterShape, strides, pad, dataFormat = 'NHWC', dimRoundingMode) {
    let x4D = x;
    if (x.rank === 3) {
        x4D = reshape$2(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
    }
    let dy4D = dy;
    if (dy4D.rank === 3) {
        dy4D = reshape$2(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
    }
    assert$1(x4D.rank === 4, () => `Error in conv2dDerFilter: input must be rank 4, but got shape ` +
        `${x4D.shape}.`);
    assert$1(dy4D.rank === 4, () => `Error in conv2dDerFilter: dy must be rank 4, but got shape ` +
        `${dy4D.shape}.`);
    assert$1(filterShape.length === 4, () => `Error in conv2dDerFilter: filterShape must be length 4, but got ` +
        `${filterShape}.`);
    const inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1];
    const outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1];
    assert$1(inDepth === filterShape[2], () => `Error in conv2dDerFilter: depth of input (${inDepth}) must ` +
        `match input depth in filter (${filterShape[2]}).`);
    assert$1(outDepth === filterShape[3], () => `Error in conv2dDerFilter: depth of dy (${outDepth}) must ` +
        `match output depth for filter (${filterShape[3]}).`);
    checkPadOnDimRoundingMode('conv2dDerFilter', pad, dimRoundingMode);
    const inputs = { x: x4D, dy: dy4D };
    const attrs = { strides, pad, dataFormat, dimRoundingMode, filterShape };
    return ENGINE.runKernel(Conv2DBackpropFilter, inputs, attrs);
}
const conv2DBackpropFilter$2 = op({ conv2DBackpropFilter_ });

function getFusedDyActivation(dy, y, activation) {
    if (activation == null || activation === 'linear') {
        return dy;
    }
    if (activation === 'relu') {
        return mul(dy, step$2(y));
    }
    throw new Error(`Cannot compute gradient for fused activation ${activation}.`);
}

function getFusedBiasGradient(bias, dyActivation) {
    let res = dyActivation;
    const reduceAxes = getReductionAxes(bias.shape, dyActivation.shape);
    if (reduceAxes.length > 0) {
        res = sum$2(res, reduceAxes);
    }
    return reshape$2(res, bias.shape);
}

function applyActivation$1(x, activation, preluActivationWeights, leakyreluAlpha) {
    if (activation === 'linear') {
        return x;
    } else if (activation === 'relu') {
        return relu$2(x);
    } else if (activation === 'elu') {
        return elu$3(x);
    } else if (activation === 'relu6') {
        return relu6$2(x);
    } else if (activation === 'prelu') {
        return prelu$2(x, preluActivationWeights);
    } else if (activation === 'leakyrelu') {
        return leakyRelu$2(x, leakyreluAlpha);
    } else if (activation === 'sigmoid') {
        return sigmoid$2(x);
    }
    throw new Error(`Unknown fused activation ${activation}.`);
}

const shouldFuse = (gradientDepth, activation) => {
    const gradientMode = gradientDepth > 0;
    return !gradientMode || activation === 'linear';
};

function depthwiseConv2dNativeBackpropFilter_(x, dy, filterShape, strides, pad, dilations = [1, 1], dimRoundingMode) {
    let x4D = x;
    if (x.rank === 3) {
        x4D = reshape$2(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
    }
    let dy4D = dy;
    if (dy4D.rank === 3) {
        dy4D = reshape$2(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
    }
    const inputs = { x: x4D, dy: dy4D };
    const attrs = { strides, pad, dimRoundingMode, dilations, filterShape };
    return ENGINE.runKernel(DepthwiseConv2dNativeBackpropFilter, inputs, attrs);
}
const depthwiseConv2dNativeBackpropFilter$2 = op({ depthwiseConv2dNativeBackpropFilter_ });

function depthwiseConv2dNativeBackpropInput_(xShape, dy, filter, strides, pad, dilations = [1, 1], dimRoundingMode) {
    let dy4D = dy;
    let reshapedTo4D = false;
    if (dy.rank === 3) {
        reshapedTo4D = true;
        dy4D = reshape$2(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2]]);
    }
    const inputs = { dy: dy4D, filter };
    const attrs = { strides, pad, dimRoundingMode, dilations, inputShape: xShape };
    const res =
        ENGINE.runKernel(DepthwiseConv2dNativeBackpropInput, inputs, attrs);
    if (reshapedTo4D) {
        return reshape$2(res, [res.shape[1], res.shape[2], res.shape[3]]);
    }
    return res;
}
const depthwiseConv2dNativeBackpropInput$2 = op({ depthwiseConv2dNativeBackpropInput_ });

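// Fused matMul: when gradients are being recorded and the activation is not
// linear, the op falls back to unfused matMul + add + activation. Otherwise it
// runs a single _FusedMatMul kernel and registers a custom gradient that
// applies the standard matmul derivatives (dA = dY * B^T, dB = A^T * dY,
// adjusted for the transpose flags) after undoing the activation via
// getFusedDyActivation.
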
function fusedMatMul_({ a, b, transposeA = false, transposeB = false, bias, activation = 'linear', preluActivationWeights, leakyreluAlpha = 0.2, }) {
    if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) {
        let result = matMul$1(a, b, transposeA, transposeB);
        if (bias != null) {
            result = add$1(result, bias);
        }
        return applyActivation$1(result, activation, preluActivationWeights, leakyreluAlpha);
    }
    let $a = convertToTensor(a, 'a', 'fused matMul');
    let $b = convertToTensor(b, 'b', 'fused matMul');
    [$a, $b] = makeTypesMatch($a, $b);
    const innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];
    const innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2];
    const outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2];
    const outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1];
    const outerDimsA = $a.shape.slice(0, -2);
    const outerDimsB = $b.shape.slice(0, -2);
    const batchDimA = sizeFromShape(outerDimsA);
    const batchDimB = sizeFromShape(outerDimsB);
    assert$1(innerShapeA === innerShapeB, () => `Error in fused matMul: inner shapes (${innerShapeA}) and (` +
        `${innerShapeB}) of Tensors with shapes ${$a.shape} and ` +
        `${$b.shape} and transposeA=${transposeA}` +
        ` and transposeB=${transposeB} must match.`);
    const outShapeOuterDims = assertAndGetBroadcastShape($a.shape.slice(0, -2), $b.shape.slice(0, -2));
    const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
    const a3D = transposeA ?
        reshape$2($a, [batchDimA, innerShapeA, outerShapeA]) :
        reshape$2($a, [batchDimA, outerShapeA, innerShapeA]);
    const b3D = transposeB ?
        reshape$2($b, [batchDimB, outerShapeB, innerShapeB]) :
        reshape$2($b, [batchDimB, innerShapeB, outerShapeB]);
    let $bias;
    if (bias != null) {
        $bias = convertToTensor(bias, 'bias', 'fused matMul');
        [$bias] = makeTypesMatch($bias, $a);
        assertAndGetBroadcastShape(outShape, $bias.shape);
    }
    let $preluActivationWeights;
    if (preluActivationWeights != null) {
        $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused matMul');
    }
    const grad = (dy, saved) => {
        const [a3D, b3D, y, $bias] = saved;
        const dyActivation = getFusedDyActivation(reshape$2(dy, y.shape), y, activation);
        let aDer;
        let bDer;
        if (!transposeA && !transposeB) {
            aDer = matMul$1(dyActivation, b3D, false, true);
            bDer = matMul$1(a3D, dyActivation, true, false);
        } else if (!transposeA && transposeB) {
            aDer = matMul$1(dyActivation, b3D, false, false);
            bDer = matMul$1(dyActivation, a3D, true, false);
        } else if (transposeA && !transposeB) {
            aDer = matMul$1(b3D, dyActivation, false, true);
            bDer = matMul$1(a3D, dyActivation, false, false);
        } else {
            aDer = matMul$1(b3D, dyActivation, true, true);
            bDer = matMul$1(dyActivation, a3D, true, true);
        }
        if (bias != null) {
            const biasDer = getFusedBiasGradient($bias, dyActivation);
            return [aDer, bDer, biasDer];
        } else {
            return [aDer, bDer];
        }
    };
    const inputs = {
        a: a3D,
        b: b3D,
        bias: $bias,
        preluActivationWeights: $preluActivationWeights
    };
    const attrs = { transposeA, transposeB, activation, leakyreluAlpha };
    if (bias == null) {
        const customOp = customGrad((a3D, b3D, save) => {
            const res =
                ENGINE.runKernel(_FusedMatMul, inputs, attrs);
            save([a3D, b3D, res]);
            return { value: reshape$2(res, outShape), gradFunc: grad };
        });
        return customOp(a3D, b3D);
    } else {
        const customOpWithBias = customGrad((a3D, b3D, $bias, save) => {
            const res =
                ENGINE.runKernel(_FusedMatMul, inputs, attrs);
            save([a3D, b3D, res, $bias]);
            return { value: reshape$2(res, outShape), gradFunc: grad };
        });
        return customOpWithBias(a3D, b3D, $bias);
    }
}
const matMul = op({ fusedMatMul_ });

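// binarySearch follows the java.util.Arrays convention: it returns the index
// of the target when found, otherwise -(insertionPoint + 1), which
// binaryInsert decodes to keep the array sorted.
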
function binaryInsert(arr, element, comparator) {
    const index = binarySearch(arr, element, comparator);
    const insertionPoint = index < 0 ? -(index + 1) : index;
    arr.splice(insertionPoint, 0, element);
}

function binarySearch(arr, target, comparator) {
    return binarySearch_(arr, target, comparator || defaultComparator);
}

function defaultComparator(a, b) {
    return a > b ? 1 : a < b ? -1 : 0;
}

function binarySearch_(arr, target, comparator) {
    let left = 0;
    let right = arr.length;
    let middle = 0;
    let found = false;
    while (left < right) {
        middle = left + ((right - left) >>> 1);
        const compareResult = comparator(target, arr[middle]);
        if (compareResult > 0) {
            left = middle + 1;
        } else {
            right = middle;
            found = !compareResult;
        }
    }
    return found ? left : -left - 1;
}

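// Non-max suppression with optional Soft-NMS (Bodla et al.): candidates are
// kept sorted by ascending score so pop() yields the best remaining box. With
// softNmsSigma > 0, overlapping boxes are rescored by a Gaussian penalty
// instead of being dropped outright, and re-inserted if still above the score
// threshold.
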
function nonMaxSuppressionV3Impl$2(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) {
    return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, 0 /* softNmsSigma */);
}

function nonMaxSuppressionV4Impl$2(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize) {
    return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, 0 /* softNmsSigma */, false /* returnScoresTensor */, padToMaxOutputSize, true /* returnValidOutputs */);
}

function nonMaxSuppressionV5Impl$2(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma) {
    return nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, true /* returnScoresTensor */);
}

function nonMaxSuppressionImpl_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma, returnScoresTensor = false, padToMaxOutputSize = false, returnValidOutputs = false) {
    const candidates = [];
    for (let i = 0; i < scores.length; i++) {
        if (scores[i] > scoreThreshold) {
            candidates.push({ score: scores[i], boxIndex: i, suppressBeginIndex: 0 });
        }
    }
    candidates.sort(ascendingComparator);
    const scale = softNmsSigma > 0 ? (-0.5 / softNmsSigma) : 0.0;
    const selectedIndices = [];
    const selectedScores = [];
    while (selectedIndices.length < maxOutputSize && candidates.length > 0) {
        const candidate = candidates.pop();
        const { score: originalScore, boxIndex, suppressBeginIndex } = candidate;
        if (originalScore < scoreThreshold) {
            break;
        }
        let ignoreCandidate = false;
        for (let j = selectedIndices.length - 1; j >= suppressBeginIndex; --j) {
            const iou = intersectionOverUnion(boxes, boxIndex, selectedIndices[j]);
            if (iou >= iouThreshold) {
                ignoreCandidate = true;
                break;
            }
            candidate.score =
                candidate.score * suppressWeight(iouThreshold, scale, iou);
            if (candidate.score <= scoreThreshold) {
                break;
            }
        }
        candidate.suppressBeginIndex = selectedIndices.length;
        if (!ignoreCandidate) {
            if (candidate.score === originalScore) {
                selectedIndices.push(boxIndex);
                selectedScores.push(candidate.score);
            } else if (candidate.score > scoreThreshold) {
                binaryInsert(candidates, candidate, ascendingComparator);
            }
        }
    }
    const validOutputs = selectedIndices.length;
    const elemsToPad = maxOutputSize - validOutputs;
    if (padToMaxOutputSize && elemsToPad > 0) {
        selectedIndices.push(...new Array(elemsToPad).fill(0));
        selectedScores.push(...new Array(elemsToPad).fill(0.0));
    }
    const result = { selectedIndices };
    if (returnScoresTensor) {
        result['selectedScores'] = selectedScores;
    }
    if (returnValidOutputs) {
        result['validOutputs'] = validOutputs;
    }
    return result;
}

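// Intersection over union of boxes i and j, with boxes stored as flat
// [ymin, xmin, ymax, xmax] quadruples; the min/max normalization below means
// either corner ordering is accepted.
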
function intersectionOverUnion(boxes, i, j) {
    const iCoord = boxes.subarray(i * 4, i * 4 + 4);
    const jCoord = boxes.subarray(j * 4, j * 4 + 4);
    const yminI = Math.min(iCoord[0], iCoord[2]);
    const xminI = Math.min(iCoord[1], iCoord[3]);
    const ymaxI = Math.max(iCoord[0], iCoord[2]);
    const xmaxI = Math.max(iCoord[1], iCoord[3]);
    const yminJ = Math.min(jCoord[0], jCoord[2]);
    const xminJ = Math.min(jCoord[1], jCoord[3]);
    const ymaxJ = Math.max(jCoord[0], jCoord[2]);
    const xmaxJ = Math.max(jCoord[1], jCoord[3]);
    const areaI = (ymaxI - yminI) * (xmaxI - xminI);
    const areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ);
    if (areaI <= 0 || areaJ <= 0) {
        return 0.0;
    }
    const intersectionYmin = Math.max(yminI, yminJ);
    const intersectionXmin = Math.max(xminI, xminJ);
    const intersectionYmax = Math.min(ymaxI, ymaxJ);
    const intersectionXmax = Math.min(xmaxI, xmaxJ);
    const intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) *
        Math.max(intersectionXmax - intersectionXmin, 0.0);
    return intersectionArea / (areaI + areaJ - intersectionArea);
}

function suppressWeight(iouThreshold, scale, iou) {
    const weight = Math.exp(scale * iou * iou);
    return iou <= iouThreshold ? weight : 0.0;
}

function ascendingComparator(c1, c2) {
    return (c1.score - c2.score) ||
        ((c1.score === c2.score) && (c2.boxIndex - c1.boxIndex));
}

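// bandPart keeps entries m[i][j] with (i - j) <= numLower and (j - i) <=
// numUpper and zeroes the rest; a negative bound keeps everything on that
// side of the diagonal.
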
function bandPart_(a, numLower, numUpper) {
    const $a = convertToTensor(a, 'a', 'bandPart');
    assert$1($a.rank >= 2, () => `bandPart(): Rank must be at least 2, got ${$a.rank}.`);
    const shape = $a.shape;
    const [M, N] = $a.shape.slice(-2);
    let $numLower;
    let $numUpper;
    if (typeof numLower === 'number') {
        assert$1(numLower % 1 === 0, () => `bandPart(): numLower must be an integer, got ${numLower}.`);
        assert$1(numLower <= M, () => `bandPart(): numLower (${numLower})` +
            ` must not be greater than the number of rows (${M}).`);
        $numLower =
            convertToTensor(numLower < 0 ? M : numLower, 'numLower', 'bandPart');
    } else {
        assert$1(numLower.dtype === 'int32', () => `bandPart(): numLower's dtype must be an int32.`);
        $numLower = where(less$2(numLower, 0), M, minimum$2(numLower, M));
    }
    if (typeof numUpper === 'number') {
        assert$1(numUpper % 1 === 0, () => `bandPart(): numUpper must be an integer, got ${numUpper}.`);
        assert$1(numUpper <= N, () => `bandPart(): numUpper (${numUpper})` +
            ` must not be greater than the number of columns (${N}).`);
        $numUpper =
            convertToTensor(numUpper < 0 ? N : numUpper, 'numUpper', 'bandPart');
    } else {
        assert$1(numUpper.dtype === 'int32', () => `bandPart(): numUpper's dtype must be an int32.`);
        $numUpper = where(less$2(numUpper, 0), N, minimum$2(numUpper, N));
    }
    const i = reshape$2(range$3(0, M, 1, 'int32'), [-1, 1]);
    const j = range$3(0, N, 1, 'int32');
    const ij = sub$2(i, j);
    const inBand = logicalAnd$2(lessEqual$2(ij, $numLower), greaterEqual$2(ij, neg$2($numUpper)));
    const zero = zeros$1([M, N], $a.dtype);
    return reshape$2(stack(unstack(reshape$2($a, [-1, M, N]))
        .map(mat => where(inBand, mat, zero))), shape);
}
const bandPart = op({ bandPart_ });

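// Classical Gram-Schmidt: each vector has its projections onto the previously
// produced orthonormal vectors subtracted, then is normalized to unit length.
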
function gramSchmidt_(xs) {
    let inputIsTensor2D;
    if (Array.isArray(xs)) {
        inputIsTensor2D = false;
        assert$1(xs != null && xs.length > 0, () => 'Gram-Schmidt process: input must not be null, undefined, or ' +
            'empty');
        const dim = xs[0].shape[0];
        for (let i = 1; i < xs.length; ++i) {
            assert$1(xs[i].shape[0] === dim, () => 'Gram-Schmidt: Non-unique lengths found in the input vectors: ' +
                `(${xs[i].shape[0]} vs. ${dim})`);
        }
    } else {
        inputIsTensor2D = true;
        xs = split$1(xs, xs.shape[0], 0).map(x => squeeze(x, [0]));
    }
    assert$1(xs.length <= xs[0].shape[0], () => `Gram-Schmidt: Number of vectors (${xs.length}) exceeds ` +
        `number of dimensions (${xs[0].shape[0]}).`);
    const ys = [];
    const xs1d = xs;
    for (let i = 0; i < xs.length; ++i) {
        ys.push(ENGINE.tidy(() => {
            let x = xs1d[i];
            if (i > 0) {
                for (let j = 0; j < i; ++j) {
                    const proj = mul(sum$2(mul(ys[j], x)), ys[j]);
                    x = sub$2(x, proj);
                }
            }
            return div$1(x, norm(x, 'euclidean'));
        }));
    }
    if (inputIsTensor2D) {
        return stack(ys, 0);
    } else {
        return ys;
    }
}
const gramSchmidt = op({ gramSchmidt_ });

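// QR decomposition via Householder reflections: for each column j, qr2d builds
// a reflector from r[j:, j] that zeroes the subdiagonal entries of that column
// and accumulates the reflections into q. Higher-rank inputs are decomposed as
// a batch of 2D matrices.
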
function qr_(x, fullMatrices = false) {
    assert$1(x.rank >= 2, () => `qr() requires input tensor to have a rank >= 2, but got rank ${x.rank}`);
    if (x.rank === 2) {
        return qr2d(x, fullMatrices);
    } else {
        const outerDimsProd = x.shape.slice(0, x.shape.length - 2)
            .reduce((value, prev) => value * prev);
        const x2ds = unstack(reshape$2(x, [
            outerDimsProd, x.shape[x.shape.length - 2],
            x.shape[x.shape.length - 1]
        ]), 0);
        const q2ds = [];
        const r2ds = [];
        x2ds.forEach(x2d => {
            const [q2d, r2d] = qr2d(x2d, fullMatrices);
            q2ds.push(q2d);
            r2ds.push(r2d);
        });
        const q = reshape$2(stack(q2ds, 0), x.shape);
        const r = reshape$2(stack(r2ds, 0), x.shape);
        return [q, r];
    }
}

function qr2d(x, fullMatrices = false) {
    return ENGINE.tidy(() => {
        assert$1(x.shape.length === 2, () => `qr2d() requires a 2D Tensor, but got a ${x.shape.length}D Tensor.`);
        const m = x.shape[0];
        const n = x.shape[1];
        let q = eye(m);
        let r = clone(x);
        const one2D = tensor2d([[1]], [1, 1]);
        let w = clone(one2D);
        const iters = m >= n ? n : m;
        for (let j = 0; j < iters; ++j) {
            const rTemp = r;
            const wTemp = w;
            const qTemp = q;
            [w, r, q] = ENGINE.tidy(() => {
                const rjEnd1 = slice$2(r, [j, j], [m - j, 1]);
                const normX = norm(rjEnd1);
                const rjj = slice$2(r, [j, j], [1, 1]);
                const s = where(greater$2(rjj, 0), tensor2d([[-1]]), tensor2d([[1]]));
                const u1 = sub$2(rjj, mul(s, normX));
                const wPre = div$1(rjEnd1, u1);
                if (wPre.shape[0] === 1) {
                    w = clone(one2D);
                } else {
                    w = concat$2([
                        one2D,
                        slice$2(wPre, [1, 0], [wPre.shape[0] - 1, wPre.shape[1]])
                    ], 0);
                }
                const tau = neg$2(div$1(matMul$1(s, u1), normX));
                const rjEndAll = slice$2(r, [j, 0], [m - j, n]);
                const tauTimesW = mul(tau, w);
                const wT = transpose$2(w);
                if (j === 0) {
                    r = sub$2(rjEndAll, matMul$1(tauTimesW, matMul$1(wT, rjEndAll)));
                } else {
                    const rTimesTau = sub$2(rjEndAll, matMul$1(tauTimesW, matMul$1(wT, rjEndAll)));
                    r = concat$2([slice$2(r, [0, 0], [j, n]), rTimesTau], 0);
                }
                const tawTimesWT = transpose$2(tauTimesW);
                const qAllJEnd = slice$2(q, [0, j], [m, q.shape[1] - j]);
                if (j === 0) {
                    q = sub$2(qAllJEnd, matMul$1(matMul$1(qAllJEnd, w), tawTimesWT));
                } else {
                    const qTimesTau = sub$2(qAllJEnd, matMul$1(matMul$1(qAllJEnd, w), tawTimesWT));
                    q = concat$2([slice$2(q, [0, 0], [m, j]), qTimesTau], 1);
                }
                return [w, r, q];
            });
            dispose([rTemp, wTemp, qTemp]);
        }
        if (!fullMatrices && m > n) {
            q = slice$2(q, [0, 0], [m, n]);
            r = slice$2(r, [0, 0], [n, n]);
        }
        return [q, r];
    });
}
const qr = op({ qr_ });

function stringToHashBucketFast_(input, numBuckets) {
    const $input = convertToTensor(input, 'input', 'stringToHashBucketFast', 'string');
    const attrs = { numBuckets };
    if (numBuckets <= 0) {
        throw new Error(`Number of buckets must be at least 1`);
    }
    const inputs = { input: $input };
    return ENGINE.runKernel(StringToHashBucketFast, inputs, attrs);
}
const stringToHashBucketFast$2 = op({ stringToHashBucketFast_ });

const linalg = {
    bandPart,
    gramSchmidt,
    qr
};

const GLOBAL_CUSTOM_OBJECT = new Map();
const GLOBAL_CUSTOM_NAMES = new Map();

class Serializable {
    getClassName() {
        return this.constructor.className;
    }

    static fromConfig(cls, config) {
        return new cls(config);
    }
}

class SerializationMap {
    constructor() {
        this.classNameMap = {};
    }

    static getMap() {
        if (SerializationMap.instance == null) {
            SerializationMap.instance = new SerializationMap();
        }
        return SerializationMap.instance;
    }

    static register(cls) {
        SerializationMap.getMap().classNameMap[cls.className] =
            [cls, cls.fromConfig];
    }
}

function registerClass(cls, pkg, name) {
    assert$1(cls.className != null, () => `Class being registered does not have the static className ` +
        `property defined.`);
    assert$1(typeof cls.className === 'string', () => `className is required to be a string, but got type ` +
        typeof cls.className);
    assert$1(cls.className.length > 0, () => `Class being registered has an empty-string as its className, ` +
        `which is disallowed.`);
    if (typeof pkg === 'undefined') {
        pkg = 'Custom';
    }
    if (typeof name === 'undefined') {
        name = cls.className;
    }
    const className = name;
    const registerName = pkg + '>' + className;
    SerializationMap.register(cls);
    GLOBAL_CUSTOM_OBJECT.set(registerName, cls);
    GLOBAL_CUSTOM_NAMES.set(cls, registerName);
    return cls;
}

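// Base Optimizer: minimize() computes gradients of f with respect to the
// trainable variables (or an explicit varList), applies them via the
// subclass's applyGradients(), and disposes the intermediate gradient tensors.
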
class Optimizer extends Serializable {
    minimize(f, returnCost = false, varList) {
        const { value, grads } = this.computeGradients(f, varList);
        if (varList != null) {
            const gradArray = varList.map(v => ({ name: v.name, tensor: grads[v.name] }));
            this.applyGradients(gradArray);
        } else {
            this.applyGradients(grads);
        }
        dispose(grads);
        if (returnCost) {
            return value;
        } else {
            value.dispose();
            return null;
        }
    }

    get iterations() {
        if (this.iterations_ == null) {
            this.iterations_ = 0;
        }
        return this.iterations_;
    }

    incrementIterations() {
        this.iterations_ = this.iterations + 1;
    }

    computeGradients(f, varList) {
        return variableGrads(f, varList);
    }

    dispose() {
        if (this.iterations_ != null) {
            dispose(this.iterations_);
        }
    }

    async saveIterations() {
        if (this.iterations_ == null) {
            this.iterations_ = 0;
        }
        return {
            name: 'iter',
            tensor: scalar(this.iterations_, 'int32')
        };
    }

    async getWeights() {
        throw new Error('getWeights() is not implemented for this optimizer yet.');
    }

    async setWeights(weightValues) {
        throw new Error(`setWeights() is not implemented for this optimizer class ` +
            `${this.getClassName()}`);
    }

    async extractIterations(weightValues) {
        this.iterations_ = (await weightValues[0].tensor.data())[0];
        return weightValues.slice(1);
    }
}
Object.defineProperty(Optimizer, Symbol.hasInstance, {
    value: (instance) => {
        return instance.minimize != null && instance.computeGradients != null &&
            instance.applyGradients != null;
    }
});

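// Adadelta (Zeiler 2012): accumGrad <- rho*accumGrad + (1-rho)*g^2;
// update = sqrt(accumUpdate + eps) / sqrt(accumGrad + eps) * g;
// accumUpdate <- rho*accumUpdate + (1-rho)*update^2;
// value <- value - learningRate*update.
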
class AdadeltaOptimizer extends Optimizer {
|
|
|
|
static get className() {
|
|
|
|
|
|
|
|
return 'Adadelta';
|
|
}
|
|
constructor(learningRate, rho, epsilon = null) {
|
|
super();
|
|
this.learningRate = learningRate;
|
|
this.rho = rho;
|
|
this.epsilon = epsilon;
|
|
this.accumulatedGrads = [];
|
|
this.accumulatedUpdates = [];
|
|
if (epsilon == null) {
|
|
this.epsilon = ENGINE.backend.epsilon();
|
|
}
|
|
}
|
|
applyGradients(variableGradients) {
|
|
const variableNames = Array.isArray(variableGradients) ?
|
|
variableGradients.map(item => item.name) :
|
|
Object.keys(variableGradients);
|
|
variableNames.forEach((name, i) => {
|
|
const value = ENGINE.registeredVariables[name];
|
|
const trainable = false;
|
|
if (this.accumulatedGrads[i] == null) {
|
|
this.accumulatedGrads[i] = {
|
|
originalName: `${name}/accum_grad`,
|
|
variable: tidy(() => zerosLike$2(value).variable(trainable))
|
|
};
|
|
}
|
|
if (this.accumulatedUpdates[i] == null) {
|
|
this.accumulatedUpdates[i] = {
|
|
originalName: `${name}/accum_var`,
|
|
variable: tidy(() => zerosLike$2(value).variable(trainable))
|
|
};
|
|
}
|
|
const gradient = Array.isArray(variableGradients) ?
|
|
variableGradients[i].tensor :
|
|
variableGradients[name];
|
|
if (gradient == null) {
|
|
return;
|
|
}
|
|
const accumulatedGrad = this.accumulatedGrads[i].variable;
|
|
const accumulatedUpdate = this.accumulatedUpdates[i].variable;
|
|
tidy(() => {
|
|
const newAccumulatedGrad = add$1(mul(accumulatedGrad, this.rho), mul(square$2(gradient), 1 - this.rho));
|
|
const updates = mul(div$1(sqrt$2(add$1(accumulatedUpdate, this.epsilon)), sqrt$2(add$1(accumulatedGrad, this.epsilon))), gradient);
|
|
const newAccumulatedUpdate = add$1(mul(accumulatedUpdate, this.rho), mul(square$2(updates), 1 - this.rho));
|
|
accumulatedGrad.assign(newAccumulatedGrad);
|
|
accumulatedUpdate.assign(newAccumulatedUpdate);
|
|
const newValue = add$1(mul(updates, -this.learningRate), value);
|
|
value.assign(newValue);
|
|
});
|
|
});
|
|
this.incrementIterations();
|
|
}
|
|
dispose() {
|
|
if (this.accumulatedUpdates != null) {
|
|
dispose(this.accumulatedGrads.map(v => v.variable));
|
|
dispose(this.accumulatedUpdates.map(v => v.variable));
|
|
}
|
|
}
|
|
async getWeights() {
|
|
|
|
const variables = [...this.accumulatedGrads, ...this.accumulatedUpdates];
|
|
return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable })));
|
|
}
|
|
async setWeights(weightValues) {
|
|
weightValues = await this.extractIterations(weightValues);
|
|
const variableCount = weightValues.length / 2;
|
|
const trainable = false;
|
|
this.accumulatedGrads =
|
|
weightValues.slice(0, variableCount).map(v => ({
|
|
originalName: v.name,
|
|
variable: v.tensor.variable(trainable)
|
|
}));
|
|
this.accumulatedUpdates =
|
|
weightValues.slice(variableCount, variableCount * 2)
|
|
.map(v => ({
|
|
originalName: v.name,
|
|
variable: v.tensor.variable(trainable)
|
|
}));
|
|
}
|
|
getConfig() {
|
|
return {
|
|
'learningRate': this.learningRate,
|
|
'rho': this.rho,
|
|
'epsilon': this.epsilon
|
|
};
|
|
}
|
|
|
|
static fromConfig(cls, config) {
|
|
return new cls(config['learningRate'], config['rho'], config['epsilon']);
|
|
}
|
|
}
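// Adagrad (Duchi et al., 2011). Accumulates the running sum of squared
// gradients per variable and scales each step by its inverse square root:
//   accum += grad^2;  value -= learningRate * grad / sqrt(accum + eps)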
class AdagradOptimizer extends Optimizer {
|
|
|
|
static get className() {
return 'Adagrad';
|
|
}
|
|
constructor(learningRate, initialAccumulatorValue = 0.1) {
|
|
super();
|
|
this.learningRate = learningRate;
|
|
this.initialAccumulatorValue = initialAccumulatorValue;
|
|
this.accumulatedGrads = [];
|
|
}
|
|
applyGradients(variableGradients) {
|
|
const variableNames = Array.isArray(variableGradients) ?
|
|
variableGradients.map(item => item.name) :
|
|
Object.keys(variableGradients);
|
|
variableNames.forEach((name, i) => {
|
|
const value = ENGINE.registeredVariables[name];
|
|
if (this.accumulatedGrads[i] == null) {
|
|
const trainable = false;
|
|
this.accumulatedGrads[i] = {
|
|
originalName: `${name}/accumulator`,
|
|
variable: tidy(() => fill$2(value.shape, this.initialAccumulatorValue)
|
|
.variable(trainable))
|
|
};
|
|
}
|
|
const gradient = Array.isArray(variableGradients) ?
|
|
variableGradients[i].tensor :
|
|
variableGradients[name];
|
|
if (gradient == null) {
|
|
return;
|
|
}
|
|
const accumulatedGrad = this.accumulatedGrads[i].variable;
|
|
tidy(() => {
|
|
const newAccumulatedGrad = add$1(accumulatedGrad, square$2(gradient));
|
|
accumulatedGrad.assign(newAccumulatedGrad);
|
|
const newValue = add$1(mul(div$1(gradient, sqrt$2(add$1(newAccumulatedGrad, ENGINE.backend.epsilon()))), -this.learningRate), value);
|
|
value.assign(newValue);
|
|
});
|
|
});
|
|
this.incrementIterations();
|
|
}
|
|
dispose() {
|
|
if (this.accumulatedGrads != null) {
|
|
dispose(this.accumulatedGrads.map(v => v.variable));
|
|
}
|
|
}
|
|
async getWeights() {
|
|
|
|
return [await this.saveIterations()].concat(this.accumulatedGrads.map(v => ({ name: v.originalName, tensor: v.variable })));
|
|
}
|
|
async setWeights(weightValues) {
|
|
weightValues = await this.extractIterations(weightValues);
|
|
const trainable = false;
|
|
this.accumulatedGrads = weightValues.map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) }));
|
|
}
|
|
getConfig() {
|
|
return {
|
|
'learningRate': this.learningRate,
|
|
'initialAccumulatorValue': this.initialAccumulatorValue,
|
|
};
|
|
}
|
|
|
|
static fromConfig(cls, config) {
|
|
return new cls(config['learningRate'], config['initialAccumulatorValue']);
|
|
}
|
|
}
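// Adam (Kingma & Ba, 2015). Tracks exponential moving averages of the
// gradient (m) and the squared gradient (v); accBeta1/accBeta2 hold
// beta1^t and beta2^t for bias correction:
//   mHat = m / (1 - beta1^t);  vHat = v / (1 - beta2^t)
//   value -= learningRate * mHat / (sqrt(vHat) + eps)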
class AdamOptimizer extends Optimizer {
|
|
|
|
static get className() {
return 'Adam';
|
|
}
|
|
constructor(learningRate, beta1, beta2, epsilon = null) {
|
|
super();
|
|
this.learningRate = learningRate;
|
|
this.beta1 = beta1;
|
|
this.beta2 = beta2;
|
|
this.epsilon = epsilon;
|
|
this.accumulatedFirstMoment = [];
|
|
this.accumulatedSecondMoment = [];
|
|
tidy(() => {
|
|
|
|
this.accBeta1 = scalar(beta1).variable();
|
|
this.accBeta2 = scalar(beta2).variable();
|
|
});
|
|
if (epsilon == null) {
|
|
this.epsilon = ENGINE.backend.epsilon();
|
|
}
|
|
}
|
|
applyGradients(variableGradients) {
|
|
const varNames = Array.isArray(variableGradients) ?
|
|
variableGradients.map(v => v.name) :
|
|
Object.keys(variableGradients);
|
|
tidy(() => {
|
|
const oneMinusAccBeta1 = sub$2(1, this.accBeta1);
|
|
const oneMinusAccBeta2 = sub$2(1, this.accBeta2);
|
|
varNames.forEach((name, i) => {
|
|
const value = ENGINE.registeredVariables[name];
|
|
const trainable = false;
|
|
if (this.accumulatedFirstMoment[i] == null) {
|
|
this.accumulatedFirstMoment[i] = {
|
|
originalName: `${name}/m`,
|
|
variable: tidy(() => zerosLike$2(value).variable(trainable))
|
|
};
|
|
}
|
|
if (this.accumulatedSecondMoment[i] == null) {
|
|
this.accumulatedSecondMoment[i] = {
|
|
originalName: `${name}/v`,
|
|
variable: tidy(() => zerosLike$2(value).variable(trainable))
|
|
};
|
|
}
|
|
const gradient = Array.isArray(variableGradients) ?
|
|
variableGradients[i].tensor :
|
|
variableGradients[name];
|
|
if (gradient == null) {
|
|
return;
|
|
}
|
|
const firstMoment = this.accumulatedFirstMoment[i].variable;
|
|
const secondMoment = this.accumulatedSecondMoment[i].variable;
|
|
const newFirstMoment = add$1(mul(firstMoment, this.beta1), mul(gradient, 1 - this.beta1));
|
|
const newSecondMoment = add$1(mul(secondMoment, this.beta2), mul(square$2(gradient), 1 - this.beta2));
|
|
const biasCorrectedFirstMoment = div$1(newFirstMoment, oneMinusAccBeta1);
|
|
const biasCorrectedSecondMoment = div$1(newSecondMoment, oneMinusAccBeta2);
|
|
firstMoment.assign(newFirstMoment);
|
|
secondMoment.assign(newSecondMoment);
|
|
const newValue = add$1(mul(div$1(biasCorrectedFirstMoment, add$1(sqrt$2(biasCorrectedSecondMoment), this.epsilon)), -this.learningRate), value);
|
|
value.assign(newValue);
|
|
});
|
|
this.accBeta1.assign(mul(this.accBeta1, this.beta1));
|
|
this.accBeta2.assign(mul(this.accBeta2, this.beta2));
|
|
});
|
|
this.incrementIterations();
|
|
}
|
|
dispose() {
|
|
this.accBeta1.dispose();
|
|
this.accBeta2.dispose();
|
|
if (this.accumulatedFirstMoment != null) {
|
|
dispose(this.accumulatedFirstMoment.map(v => v.variable));
|
|
}
|
|
if (this.accumulatedSecondMoment != null) {
|
|
dispose(this.accumulatedSecondMoment.map(v => v.variable));
|
|
}
|
|
}
|
|
async getWeights() {
|
|
|
|
const variables = [...this.accumulatedFirstMoment, ...this.accumulatedSecondMoment];
|
|
return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable })));
|
|
}
|
|
async setWeights(weightValues) {
|
|
weightValues = await this.extractIterations(weightValues);
|
|
tidy(() => {
|
|
this.accBeta1.assign(pow$2(this.beta1, this.iterations_ + 1));
|
|
this.accBeta2.assign(pow$2(this.beta2, this.iterations_ + 1));
|
|
});
|
|
const variableCount = weightValues.length / 2;
|
|
const trainable = false;
|
|
this.accumulatedFirstMoment =
|
|
weightValues.slice(0, variableCount).map(v => ({
|
|
originalName: v.name,
|
|
variable: v.tensor.variable(trainable)
|
|
}));
|
|
this.accumulatedSecondMoment =
|
|
weightValues.slice(variableCount, variableCount * 2)
|
|
.map(v => ({
|
|
originalName: v.name,
|
|
variable: v.tensor.variable(trainable)
|
|
}));
|
|
}
|
|
getConfig() {
|
|
return {
|
|
'learningRate': this.learningRate,
|
|
'beta1': this.beta1,
|
|
'beta2': this.beta2,
|
|
'epsilon': this.epsilon,
|
|
};
|
|
}
|
|
|
|
static fromConfig(cls, config) {
|
|
return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon']);
|
|
}
|
|
}
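// Adamax, the infinity-norm variant of Adam: the second moment is replaced
// by a running max, u = max(beta2 * u, |grad|), and the step is
//   value -= lr * m / ((1 - beta1^t) * (u + eps))
// where lr itself decays as learningRate / (1 + iteration * decay).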
class AdamaxOptimizer extends Optimizer {
|
|
|
|
static get className() {
return 'Adamax';
|
|
}
|
|
constructor(learningRate, beta1, beta2, epsilon = null, decay = 0.0) {
|
|
super();
|
|
this.learningRate = learningRate;
|
|
this.beta1 = beta1;
|
|
this.beta2 = beta2;
|
|
this.epsilon = epsilon;
|
|
this.decay = decay;
|
|
this.accumulatedFirstMoment = [];
|
|
this.accumulatedWeightedInfNorm = [];
|
|
tidy(() => {
|
|
this.iteration = scalar(0).variable();
|
|
this.accBeta1 = scalar(beta1).variable();
|
|
});
|
|
if (epsilon == null) {
|
|
this.epsilon = ENGINE.backend.epsilon();
|
|
}
|
|
}
|
|
applyGradients(variableGradients) {
|
|
const variableNames = Array.isArray(variableGradients) ?
|
|
variableGradients.map(item => item.name) :
|
|
Object.keys(variableGradients);
|
|
tidy(() => {
|
|
const oneMinusAccBeta1 = sub$2(1, this.accBeta1);
|
|
const lr = div$1(-this.learningRate, add$1(mul(this.iteration, this.decay), 1));
|
|
variableNames.forEach((name, i) => {
|
|
const value = ENGINE.registeredVariables[name];
|
|
const trainable = false;
|
|
if (this.accumulatedFirstMoment[i] == null) {
|
|
this.accumulatedFirstMoment[i] = {
|
|
originalName: `${name}/m`,
|
|
variable: zerosLike$2(value).variable(trainable)
|
|
};
|
|
}
|
|
if (this.accumulatedWeightedInfNorm[i] == null) {
|
|
this.accumulatedWeightedInfNorm[i] = {
|
|
originalName: `${name}/v`,
|
|
variable: zerosLike$2(value).variable(trainable)
|
|
};
|
|
}
|
|
const gradient = Array.isArray(variableGradients) ?
|
|
variableGradients[i].tensor :
|
|
variableGradients[name];
|
|
if (gradient == null) {
|
|
return;
|
|
}
|
|
const firstMoment = this.accumulatedFirstMoment[i].variable;
|
|
const weightedInfNorm = this.accumulatedWeightedInfNorm[i].variable;
|
|
const newFirstMoment = add$1(mul(firstMoment, this.beta1), mul(gradient, 1 - this.beta1));
|
|
const ut0 = mul(weightedInfNorm, this.beta2);
|
|
const ut1 = abs$2(gradient);
|
|
const newWeightedInfNorm = maximum$2(ut0, ut1);
|
|
firstMoment.assign(newFirstMoment);
|
|
weightedInfNorm.assign(newWeightedInfNorm);
|
|
const newValue = add$1(mul(div$1(lr, oneMinusAccBeta1), div$1(newFirstMoment, add$1(newWeightedInfNorm, this.epsilon))), value);
|
|
value.assign(newValue);
|
|
});
|
|
this.iteration.assign(add$1(this.iteration, 1));
|
|
this.accBeta1.assign(mul(this.accBeta1, this.beta1));
|
|
});
|
|
this.incrementIterations();
|
|
}
|
|
dispose() {
|
|
this.accBeta1.dispose();
|
|
this.iteration.dispose();
|
|
if (this.accumulatedFirstMoment != null) {
|
|
dispose(this.accumulatedFirstMoment.map(v => v.variable));
|
|
}
|
|
if (this.accumulatedWeightedInfNorm != null) {
|
|
dispose(this.accumulatedWeightedInfNorm.map(v => v.variable));
|
|
}
|
|
}
|
|
async getWeights() {
|
|
throw new Error('getWeights() is not implemented for Adamax yet.');
|
|
}
|
|
async setWeights(weightValues) {
|
|
throw new Error('setWeights() is not implemented for Adamax yet.');
|
|
}
|
|
getConfig() {
|
|
return {
|
|
'learningRate': this.learningRate,
|
|
'beta1': this.beta1,
|
|
'beta2': this.beta2,
|
|
'epsilon': this.epsilon,
|
|
'decay': this.decay
|
|
};
|
|
}
|
|
|
|
static fromConfig(cls, config) {
|
|
return new cls(config['learningRate'], config['beta1'], config['beta2'], config['epsilon'], config['decay']);
|
|
}
|
|
}
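// Plain stochastic gradient descent: value -= learningRate * grad. The
// negated learning rate is cached in the kept scalar `c` so each step is a
// single multiply-add.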
class SGDOptimizer extends Optimizer {
|
|
|
|
static get className() {
return 'SGD';
|
|
}
|
|
constructor(learningRate) {
|
|
super();
|
|
this.learningRate = learningRate;
|
|
this.setLearningRate(learningRate);
|
|
}
|
|
applyGradients(variableGradients) {
|
|
const varNames = Array.isArray(variableGradients) ?
|
|
variableGradients.map(v => v.name) :
|
|
Object.keys(variableGradients);
|
|
varNames.forEach((name, i) => {
|
|
const gradient = Array.isArray(variableGradients) ?
|
|
variableGradients[i].tensor :
|
|
variableGradients[name];
|
|
if (gradient == null) {
|
|
return;
|
|
}
|
|
const value = ENGINE.registeredVariables[name];
|
|
tidy(() => {
|
|
const newValue = add$1(mul(this.c, gradient), value);
|
|
value.assign(newValue);
|
|
});
|
|
});
|
|
this.incrementIterations();
|
|
}
|
|
|
|
setLearningRate(learningRate) {
|
|
this.learningRate = learningRate;
|
|
if (this.c != null) {
|
|
this.c.dispose();
|
|
}
|
|
this.c = keep(scalar(-learningRate));
|
|
}
|
|
dispose() {
|
|
this.c.dispose();
|
|
}
|
|
async getWeights() {
|
|
return [await this.saveIterations()];
|
|
}
|
|
async setWeights(weightValues) {
|
|
weightValues = await this.extractIterations(weightValues);
|
|
if (weightValues.length !== 0) {
|
|
throw new Error('SGD optimizer does not have settable weights.');
|
|
}
|
|
}
|
|
getConfig() {
|
|
return { 'learningRate': this.learningRate };
|
|
}
|
|
|
|
static fromConfig(cls, config) {
|
|
return new cls(config['learningRate']);
|
|
}
|
|
}
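// SGD with (optionally Nesterov) momentum:
//   accum = momentum * accum + grad
//   value -= lr * accum                        (classic)
//   value -= lr * (grad + momentum * accum)    (Nesterov)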
class MomentumOptimizer extends SGDOptimizer {
static get className() {
return 'Momentum';
|
|
}
|
|
constructor(learningRate, momentum, useNesterov = false) {
|
|
super(learningRate);
|
|
this.learningRate = learningRate;
|
|
this.momentum = momentum;
|
|
this.useNesterov = useNesterov;
|
|
this.accumulations = [];
|
|
this.m = scalar(this.momentum);
|
|
}
|
|
applyGradients(variableGradients) {
|
|
const variableNames = Array.isArray(variableGradients) ?
|
|
variableGradients.map(item => item.name) :
|
|
Object.keys(variableGradients);
|
|
variableNames.forEach((name, i) => {
|
|
const value = ENGINE.registeredVariables[name];
|
|
if (this.accumulations[i] == null) {
|
|
const trainable = false;
|
|
this.accumulations[i] = {
|
|
originalName: `${name}/momentum`,
|
|
variable: tidy(() => zerosLike$2(value).variable(trainable))
|
|
};
|
|
}
|
|
const accumulation = this.accumulations[i].variable;
|
|
const gradient = Array.isArray(variableGradients) ?
|
|
variableGradients[i].tensor :
|
|
variableGradients[name];
|
|
if (gradient == null) {
|
|
return;
|
|
}
|
|
tidy(() => {
|
|
let newValue;
|
|
const newAccumulation = add$1(mul(this.m, accumulation), gradient);
|
|
if (this.useNesterov) {
|
|
newValue = add$1(mul(this.c, add$1(gradient, mul(newAccumulation, this.m))), value);
|
|
}
|
|
else {
|
|
newValue = add$1(mul(this.c, newAccumulation), value);
|
|
}
|
|
accumulation.assign(newAccumulation);
|
|
value.assign(newValue);
|
|
});
|
|
});
|
|
this.incrementIterations();
|
|
}
|
|
dispose() {
|
|
this.m.dispose();
|
|
if (this.accumulations != null) {
|
|
dispose(this.accumulations.map(v => v.variable));
|
|
}
|
|
}
|
|
|
|
setMomentum(momentum) {
|
|
this.momentum = momentum;
|
|
}
|
|
async getWeights() {
|
|
|
|
return [await this.saveIterations()].concat(this.accumulations.map(v => ({ name: v.originalName, tensor: v.variable })));
|
|
}
|
|
async setWeights(weightValues) {
|
|
weightValues = await this.extractIterations(weightValues);
|
|
const trainable = false;
|
|
this.accumulations = weightValues.map(v => ({ originalName: v.name, variable: v.tensor.variable(trainable) }));
|
|
}
|
|
getConfig() {
|
|
return {
|
|
'learningRate': this.learningRate,
|
|
'momentum': this.momentum,
|
|
'useNesterov': this.useNesterov
|
|
};
|
|
}
|
|
|
|
static fromConfig(cls, config) {
|
|
return new cls(config['learningRate'], config['momentum'], config['useNesterov']);
|
|
}
|
|
}
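// RMSProp. Keeps a decayed mean of squared gradients and divides each step
// by its square root; with `centered: true` it also tracks the decayed mean
// gradient and uses the variance estimate (mean square minus squared mean)
// in the denominator. Momentum is folded into accumulatedMoments.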
class RMSPropOptimizer extends Optimizer {
|
|
|
|
static get className() {
return 'RMSProp';
|
|
}
|
|
constructor(learningRate, decay = 0.9, momentum = 0.0, epsilon = null, centered = false) {
|
|
super();
|
|
this.learningRate = learningRate;
|
|
this.decay = decay;
|
|
this.momentum = momentum;
|
|
this.epsilon = epsilon;
|
|
this.accumulatedMeanSquares = [];
|
|
this.accumulatedMoments = [];
|
|
this.accumulatedMeanGrads = [];
|
|
this.centered = centered;
|
|
if (epsilon == null) {
|
|
this.epsilon = ENGINE.backend.epsilon();
|
|
}
|
|
if (learningRate == null) {
|
|
throw new Error(`learningRate for RMSPropOptimizer must be defined.`);
|
|
}
|
|
}
|
|
applyGradients(variableGradients) {
|
|
const variableNames = Array.isArray(variableGradients) ?
|
|
variableGradients.map(item => item.name) :
|
|
Object.keys(variableGradients);
|
|
variableNames.forEach((name, i) => {
|
|
const value = ENGINE.registeredVariables[name];
|
|
const trainable = false;
|
|
if (this.accumulatedMeanSquares[i] == null) {
|
|
this.accumulatedMeanSquares[i] = {
|
|
originalName: `${name}/rms`,
|
|
variable: tidy(() => zerosLike$2(value).variable(trainable))
|
|
};
|
|
}
|
|
if (this.accumulatedMoments[i] == null) {
|
|
this.accumulatedMoments[i] = {
|
|
originalName: `${name}/momentum`,
|
|
variable: tidy(() => zerosLike$2(value).variable(trainable))
|
|
};
|
|
}
|
|
if (this.accumulatedMeanGrads[i] == null && this.centered) {
|
|
this.accumulatedMeanGrads[i] = {
|
|
originalName: `${name}/mg`,
|
|
variable: tidy(() => zerosLike$2(value).variable(trainable))
|
|
};
|
|
}
|
|
const gradient = Array.isArray(variableGradients) ?
|
|
variableGradients[i].tensor :
|
|
variableGradients[name];
|
|
if (gradient == null) {
|
|
return;
|
|
}
|
|
const accumulatedMeanSquare = this.accumulatedMeanSquares[i].variable;
|
|
const accumulatedMoments = this.accumulatedMoments[i].variable;
|
|
tidy(() => {
|
|
const newAccumulatedMeanSquare = add$1(mul(accumulatedMeanSquare, this.decay), mul(square$2(gradient), 1 - this.decay));
|
|
if (this.centered) {
|
|
const accumulatedMeanGrad = this.accumulatedMeanGrads[i].variable;
|
|
|
|
const newAccumulatedMeanGrad = add$1(mul(accumulatedMeanGrad, this.decay), mul(gradient, 1 - this.decay));
|
|
const gradContribution = div$1(mul(gradient, this.learningRate), sqrt$2(sub$2(newAccumulatedMeanSquare, add$1(square$2(newAccumulatedMeanGrad), this.epsilon))));
|
|
const newAccumulatedMoments = add$1(mul(accumulatedMoments, this.momentum), gradContribution);
|
|
accumulatedMeanSquare.assign(newAccumulatedMeanSquare);
|
|
accumulatedMeanGrad.assign(newAccumulatedMeanGrad);
|
|
accumulatedMoments.assign(newAccumulatedMoments);
|
|
const newValue = sub$2(value, newAccumulatedMoments);
|
|
value.assign(newValue);
|
|
}
|
|
else {
|
|
|
|
const newAccumulatedMeanSquare = add$1(mul(accumulatedMeanSquare, this.decay), mul(square$2(gradient), 1 - this.decay));
|
|
const newAccumulatedMoments = add$1(mul(accumulatedMoments, this.momentum), div$1(mul(gradient, this.learningRate), sqrt$2(add$1(newAccumulatedMeanSquare, this.epsilon))));
|
|
accumulatedMeanSquare.assign(newAccumulatedMeanSquare);
|
|
accumulatedMoments.assign(newAccumulatedMoments);
|
|
const newValue = sub$2(value, newAccumulatedMoments);
|
|
value.assign(newValue);
|
|
}
|
|
});
|
|
});
|
|
this.incrementIterations();
|
|
}
|
|
dispose() {
|
|
if (this.accumulatedMeanSquares != null) {
|
|
dispose(this.accumulatedMeanSquares.map(v => v.variable));
|
|
}
|
|
if (this.accumulatedMeanGrads != null && this.centered) {
|
|
dispose(this.accumulatedMeanGrads.map(v => v.variable));
|
|
}
|
|
if (this.accumulatedMoments != null) {
|
|
dispose(this.accumulatedMoments.map(v => v.variable));
|
|
}
|
|
}
|
|
async getWeights() {
|
|
|
|
const variables = [...this.accumulatedMeanSquares, ...this.accumulatedMoments];
|
|
if (this.centered) {
|
|
variables.push(...this.accumulatedMeanGrads);
|
|
}
|
|
return [await this.saveIterations()].concat(variables.map(v => ({ name: v.originalName, tensor: v.variable })));
|
|
}
|
|
async setWeights(weightValues) {
|
|
weightValues = await this.extractIterations(weightValues);
|
|
const variableCount = this.centered ? weightValues.length / 3 : weightValues.length / 2;
|
|
const trainable = false;
|
|
this.accumulatedMeanSquares =
|
|
weightValues.slice(0, variableCount).map(v => ({
|
|
originalName: v.name,
|
|
variable: v.tensor.variable(trainable)
|
|
}));
|
|
this.accumulatedMoments =
|
|
weightValues.slice(variableCount, variableCount * 2)
|
|
.map(v => ({
|
|
originalName: v.name,
|
|
variable: v.tensor.variable(trainable)
|
|
}));
|
|
if (this.centered) {
|
|
this.accumulatedMeanGrads =
|
|
weightValues.slice(variableCount * 2, variableCount * 3)
|
|
.map(v => ({
|
|
originalName: v.name,
|
|
variable: v.tensor.variable(trainable)
|
|
}));
|
|
}
|
|
}
|
|
getConfig() {
|
|
return {
|
|
'learningRate': this.learningRate,
|
|
'decay': this.decay,
|
|
'momentum': this.momentum,
|
|
'epsilon': this.epsilon,
|
|
'centered': this.centered
|
|
};
|
|
}
|
|
|
|
static fromConfig(cls, config) {
|
|
return new cls(config['learningRate'], config['decay'], config['momentum'], config['epsilon'], config['centered']);
|
|
}
|
|
}
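// The built-in optimizers are registered with the serialization registry so
// a saved training config can be revived through each class's fromConfig.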
const OPTIMIZERS = [
|
|
AdadeltaOptimizer,
|
|
AdagradOptimizer,
|
|
AdamOptimizer,
|
|
AdamaxOptimizer,
|
|
MomentumOptimizer,
|
|
RMSPropOptimizer,
|
|
SGDOptimizer,
|
|
];
|
|
function registerOptimizers() {
|
|
for (const optimizer of OPTIMIZERS) {
|
|
registerClass(optimizer);
|
|
}
|
|
}
const DTYPE_VALUE_SIZE_MAP = {
|
|
'float32': 4,
|
|
'float16': 2,
|
|
'int32': 4,
|
|
'uint16': 2,
|
|
'uint8': 1,
|
|
'bool': 1,
|
|
'complex64': 8
|
|
};
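// Presents one or more ArrayBuffers / TypedArrays as a single logical byte
// range, so sharded weight files never need to be copied into one big
// allocation up front. slice(start, end) copies across shard boundaries.
// Sketch:
//   const composite = new CompositeArrayBuffer(
//       [new Uint8Array([1, 2]).buffer, new Uint8Array([3, 4]).buffer]);
//   new Uint8Array(composite.slice(1, 3));  // -> Uint8Array [2, 3]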
class CompositeArrayBuffer {
|
|
|
|
static join(buffers) {
|
|
return new CompositeArrayBuffer(buffers).slice();
|
|
}
|
|
constructor(buffers) {
|
|
this.shards = [];
|
|
this.previousShardIndex = 0;
|
|
if (buffers == null) {
|
|
return;
|
|
}
|
|
|
|
if (!(buffers instanceof Array)) {
|
|
buffers = [buffers];
|
|
}
|
|
buffers = buffers.map((bufferOrTypedArray) => {
|
|
if (isTypedArray(bufferOrTypedArray)) {
|
|
return bufferOrTypedArray.buffer;
|
|
}
|
|
return bufferOrTypedArray;
|
|
});
|
|
|
|
if (buffers.length === 0) {
|
|
return;
|
|
}
|
|
this.bufferUniformSize = buffers[0].byteLength;
|
|
let start = 0;
|
|
for (let i = 0; i < buffers.length; i++) {
|
|
const buffer = buffers[i];
|
|
|
|
if (i !== buffers.length - 1 &&
|
|
buffer.byteLength !== this.bufferUniformSize) {
this.bufferUniformSize = undefined;
|
|
}
|
|
|
|
const end = start + buffer.byteLength;
|
|
this.shards.push({ buffer, start, end });
|
|
start = end;
|
|
}
|
|
|
|
// Record a zero length for the empty case and stop before indexing shards.
if (this.shards.length === 0) {
this.byteLength = 0;
return;
}
this.byteLength = this.shards[this.shards.length - 1].end;
|
|
}
|
|
slice(start = 0, end = this.byteLength) {
if (this.shards.length === 0) {
|
|
return new ArrayBuffer(0);
|
|
}
|
|
|
|
start = isNaN(Number(start)) ? 0 : start;
|
|
end = isNaN(Number(end)) ? 0 : end;
|
|
|
|
start = Math.max(0, start);
|
|
end = Math.min(this.byteLength, end);
|
|
if (end <= start) {
|
|
return new ArrayBuffer(0);
|
|
}
|
|
const startShardIndex = this.findShardForByte(start);
|
|
if (startShardIndex === -1) {
throw new Error(`Could not find start shard for byte ${start}`);
|
|
}
|
|
const size = end - start;
|
|
const outputBuffer = new ArrayBuffer(size);
|
|
const outputArray = new Uint8Array(outputBuffer);
|
|
let sliced = 0;
|
|
for (let i = startShardIndex; i < this.shards.length; i++) {
|
|
const shard = this.shards[i];
|
|
const globalStart = start + sliced;
|
|
const localStart = globalStart - shard.start;
|
|
const outputStart = sliced;
|
|
const globalEnd = Math.min(end, shard.end);
|
|
const localEnd = globalEnd - shard.start;
|
|
const outputSlice = new Uint8Array(shard.buffer, localStart, localEnd - localStart);
|
|
outputArray.set(outputSlice, outputStart);
|
|
sliced += outputSlice.length;
|
|
if (end < shard.end) {
|
|
break;
|
|
}
|
|
}
|
|
return outputBuffer;
|
|
}
|
|
|
|
findShardForByte(byteIndex) {
|
|
if (this.shards.length === 0 || byteIndex < 0 ||
|
|
byteIndex >= this.byteLength) {
|
|
return -1;
|
|
}
|
|
|
|
if (this.bufferUniformSize != null) {
|
|
this.previousShardIndex = Math.floor(byteIndex / this.bufferUniformSize);
|
|
return this.previousShardIndex;
|
|
}
function check(shard) {
|
|
if (byteIndex < shard.start) {
|
|
return -1;
|
|
}
|
|
if (byteIndex >= shard.end) {
|
|
return 1;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
if (check(this.shards[this.previousShardIndex]) === 0) {
|
|
return this.previousShardIndex;
|
|
}
const index = search(this.shards, check);
|
|
if (index === -1) {
|
|
return -1;
|
|
}
|
|
this.previousShardIndex = index;
|
|
return this.previousShardIndex;
|
|
}
|
|
}
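// Generic binary search over the sorted shard list: `compare` returns 0 on a
// match, a negative number when the target sorts before the probed element,
// and a positive number when it sorts after.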
function search(sortedArray, compare) {
|
|
|
|
let min = 0;
|
|
let max = sortedArray.length;
|
|
// Exclusive upper bound, so the comparator never runs past the array when
// the target is absent.
while (min < max) {
|
|
const middle = Math.floor((max - min) / 2) + min;
|
|
const side = compare(sortedArray[middle]);
|
|
if (side === 0) {
|
|
return middle;
|
|
}
|
|
else if (side < 0) {
|
|
max = middle;
|
|
}
|
|
else {
|
|
min = middle + 1;
|
|
}
|
|
}
|
|
return -1;
|
|
}
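// Weight (de)serialization. Tensors are flattened into one contiguous byte
// stream in spec order. String tensors are encoded element by element as a
// 4-byte Uint32 length prefix followed by the UTF-8 bytes, which is why
// getWeightBytelength has to walk string weights instead of multiplying the
// element count by a fixed per-value width.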
const NUM_BYTES_STRING_LENGTH = 4;
|
|
|
|
async function encodeWeights(tensors, group) {
|
|
|
|
const specs = [];
|
|
const dataPromises = [];
|
|
const names = Array.isArray(tensors) ?
|
|
tensors.map(tensor => tensor.name) :
|
|
Object.keys(tensors);
|
|
for (let i = 0; i < names.length; ++i) {
|
|
const name = names[i];
|
|
const t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name];
|
|
if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' &&
|
|
t.dtype !== 'string' && t.dtype !== 'complex64') {
|
|
throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`);
|
|
}
|
|
const spec = { name, shape: t.shape, dtype: t.dtype };
|
|
if (t.dtype === 'string') {
|
|
const utf8bytes = new Promise(async (resolve) => {
|
|
const vals = await t.bytes();
|
|
const totalNumBytes = vals.reduce((p, c) => p + c.length, 0) +
|
|
NUM_BYTES_STRING_LENGTH * vals.length;
|
|
const bytes = new Uint8Array(totalNumBytes);
|
|
let offset = 0;
|
|
for (let i = 0; i < vals.length; i++) {
|
|
const val = vals[i];
|
|
const bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer);
|
|
bytes.set(bytesOfLength, offset);
|
|
offset += NUM_BYTES_STRING_LENGTH;
|
|
bytes.set(val, offset);
|
|
offset += val.length;
|
|
}
|
|
resolve(bytes);
|
|
});
|
|
dataPromises.push(utf8bytes);
|
|
}
|
|
else {
|
|
dataPromises.push(t.data());
|
|
}
|
|
if (group != null) {
|
|
spec.group = group;
|
|
}
|
|
specs.push(spec);
|
|
}
|
|
const tensorValues = await Promise.all(dataPromises);
|
|
return { data: concatenateTypedArrays(tensorValues), specs };
|
|
}
|
|
|
|
function decodeWeights(weightData, specs) {
|
|
|
|
const compositeBuffer = new CompositeArrayBuffer(weightData);
|
|
const out = {};
|
|
let offset = 0;
|
|
for (const spec of specs) {
|
|
const byteLength = getWeightBytelength(spec, (start, end) => {
|
|
return compositeBuffer.slice(offset + start, offset + end);
|
|
});
|
|
out[spec.name] = decodeWeight(spec, compositeBuffer
|
|
.slice(offset, offset + byteLength));
|
|
offset += byteLength;
|
|
}
|
|
return out;
|
|
}
|
|
function getWeightBytelength(spec, slice) {
|
|
const size = sizeFromShape(spec.shape);
|
|
let bytesPerValue;
|
|
if ('quantization' in spec) {
|
|
const quantization = spec.quantization;
|
|
bytesPerValue = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
|
|
}
|
|
else if (spec.dtype === 'string') {
|
|
|
|
let byteLength = 0;
|
|
for (let i = 0; i < size; i++) {
|
|
byteLength += NUM_BYTES_STRING_LENGTH + new Uint32Array(slice(byteLength, byteLength + NUM_BYTES_STRING_LENGTH))[0];
|
|
}
|
|
return byteLength;
|
|
}
|
|
else {
|
|
bytesPerValue = DTYPE_VALUE_SIZE_MAP[spec.dtype];
|
|
}
|
|
return size * bytesPerValue;
|
|
}
|
|
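// Reverses the encoding above for a single spec. Quantized uint8/uint16
// weights are restored as value * scale + min (rounded for int32 targets),
// float16 weights go through the table-based decoder further down, and
// complex64 data is de-interleaved into real and imaginary tensors.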
function decodeWeight(spec, byteBuffer) {
|
|
const name = spec.name;
|
|
const dtype = spec.dtype;
|
|
const shape = spec.shape;
|
|
const size = sizeFromShape(shape);
|
|
let values;
|
|
let offset = 0;
|
|
if ('quantization' in spec) {
|
|
const quantization = spec.quantization;
|
|
if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
|
|
if (!('min' in quantization && 'scale' in quantization)) {
|
|
throw new Error(`Weight ${spec.name} with quantization ${quantization.dtype} ` +
|
|
`doesn't have corresponding metadata min and scale.`);
|
|
}
|
|
}
|
|
else if (quantization.dtype === 'float16') {
|
|
if (dtype !== 'float32') {
|
|
throw new Error(`Weight ${spec.name} is quantized with ${quantization.dtype} ` +
|
|
`which only supports weights of type float32 not ${dtype}.`);
|
|
}
|
|
}
|
|
else {
|
|
throw new Error(`Weight ${spec.name} has unknown ` +
|
|
`quantization dtype ${quantization.dtype}. ` +
|
|
`Supported quantization dtypes are: ` +
|
|
`'uint8', 'uint16', and 'float16'.`);
|
|
}
|
|
const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
|
|
const quantizedArray = (quantization.dtype === 'uint8') ?
|
|
new Uint8Array(byteBuffer) :
|
|
new Uint16Array(byteBuffer);
|
|
if (dtype === 'float32') {
|
|
if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
|
|
values = new Float32Array(quantizedArray.length);
|
|
for (let i = 0; i < quantizedArray.length; i++) {
|
|
const v = quantizedArray[i];
|
|
values[i] = v * quantization.scale + quantization.min;
|
|
}
|
|
}
|
|
else if (quantization.dtype === 'float16') {
|
|
|
|
const float16Decode = getFloat16Decoder();
|
|
values = float16Decode(quantizedArray);
|
|
}
|
|
else {
|
|
throw new Error(`Unsupported quantization type ${quantization.dtype} ` +
|
|
`for weight type float32.`);
|
|
}
|
|
}
|
|
else if (dtype === 'int32') {
|
|
if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') {
|
|
throw new Error(`Unsupported quantization type ${quantization.dtype} ` +
|
|
`for weight type int32.`);
|
|
}
|
|
values = new Int32Array(quantizedArray.length);
|
|
for (let i = 0; i < quantizedArray.length; i++) {
|
|
const v = quantizedArray[i];
|
|
values[i] = Math.round(v * quantization.scale + quantization.min);
|
|
}
|
|
}
|
|
else {
|
|
throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
|
|
}
|
|
offset += size * quantizationSizeFactor;
|
|
}
|
|
else if (dtype === 'string') {
|
|
const size = sizeFromShape(spec.shape);
|
|
values = [];
|
|
for (let i = 0; i < size; i++) {
|
|
const byteLength = new Uint32Array(byteBuffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];
|
|
offset += NUM_BYTES_STRING_LENGTH;
|
|
const bytes = new Uint8Array(byteBuffer.slice(offset, offset + byteLength));
|
|
values.push(bytes);
|
|
offset += byteLength;
|
|
}
|
|
}
|
|
else {
|
|
const dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype];
|
|
if (dtype === 'float32') {
|
|
values = new Float32Array(byteBuffer);
|
|
}
|
|
else if (dtype === 'int32') {
|
|
values = new Int32Array(byteBuffer);
|
|
}
|
|
else if (dtype === 'bool') {
|
|
values = new Uint8Array(byteBuffer);
|
|
}
|
|
else if (dtype === 'complex64') {
|
|
values = new Float32Array(byteBuffer);
|
|
const real = new Float32Array(values.length / 2);
|
|
const image = new Float32Array(values.length / 2);
|
|
for (let i = 0; i < real.length; i++) {
|
|
real[i] = values[i * 2];
|
|
image[i] = values[i * 2 + 1];
|
|
}
|
|
const realTensor = tensor(real, shape, 'float32');
|
|
const imageTensor = tensor(image, shape, 'float32');
|
|
const complexTensor = complex$2(realTensor, imageTensor);
|
|
realTensor.dispose();
|
|
imageTensor.dispose();
|
|
return complexTensor;
|
|
}
|
|
else {
|
|
throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`);
|
|
}
|
|
offset += size * dtypeFactor;
|
|
}
|
|
return tensor(values, shape, dtype);
|
|
}
|
|
|
|
function concatenateTypedArrays(xs) {
|
|
|
|
if (xs == null) {
|
|
throw new Error(`Invalid input value: ${JSON.stringify(xs)}`);
|
|
}
|
|
let totalByteLength = 0;
const normalizedXs = [];
|
|
xs.forEach((x) => {
|
|
totalByteLength += x.byteLength;
|
|
|
|
normalizedXs.push(x.byteLength === x.buffer.byteLength ? x :
|
|
new x.constructor(x));
|
|
if (!(x instanceof Float32Array || x instanceof Int32Array ||
|
|
x instanceof Uint8Array)) {
|
|
throw new Error(`Unsupported TypedArray subtype: ${x.constructor.name}`);
|
|
}
|
|
|
|
});
|
|
const y = new Uint8Array(totalByteLength);
|
|
let offset = 0;
|
|
normalizedXs.forEach((x) => {
|
|
y.set(new Uint8Array(x.buffer), offset);
|
|
offset += x.byteLength;
|
|
});
|
|
return y.buffer;
|
|
}
|
|
|
|
const useNodeBuffer = typeof Buffer !== 'undefined' &&
|
|
(typeof Blob === 'undefined' || typeof atob === 'undefined' ||
|
|
typeof btoa === 'undefined');
|
|
|
|
function stringByteLength(str) {
|
|
if (useNodeBuffer) {
|
|
return Buffer.byteLength(str, 'utf8');
|
|
}
|
|
return new Blob([str]).size;
|
|
}
|
|
|
|
function arrayBufferToBase64String(buffer) {
|
|
if (useNodeBuffer) {
|
|
return Buffer.from(buffer).toString('base64');
|
|
}
|
|
const buf = new Uint8Array(buffer);
|
|
let s = '';
|
|
for (let i = 0, l = buf.length; i < l; i++) {
|
|
s += String.fromCharCode(buf[i]);
|
|
}
|
|
return btoa(s);
|
|
}
|
|
|
|
function base64StringToArrayBuffer(str) {
|
|
if (useNodeBuffer) {
|
|
const buf = Buffer.from(str, 'base64');
|
|
return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);
|
|
}
|
|
const s = atob(str);
|
|
const buffer = new Uint8Array(s.length);
|
|
for (let i = 0; i < s.length; ++i) {
|
|
buffer.set([s.charCodeAt(i)], i);
|
|
}
|
|
return buffer.buffer;
|
|
}
|
|
|
|
function concatenateArrayBuffers(buffers) {
|
|
return CompositeArrayBuffer.join(buffers);
|
|
}
|
|
|
|
function getModelJSONForModelArtifacts(artifacts, manifest) {
|
|
const result = {
|
|
modelTopology: artifacts.modelTopology,
|
|
format: artifacts.format,
|
|
generatedBy: artifacts.generatedBy,
|
|
convertedBy: artifacts.convertedBy,
|
|
weightsManifest: manifest
|
|
};
|
|
if (artifacts.signature != null) {
|
|
result.signature = artifacts.signature;
|
|
}
|
|
if (artifacts.userDefinedMetadata != null) {
|
|
result.userDefinedMetadata = artifacts.userDefinedMetadata;
|
|
}
|
|
if (artifacts.modelInitializer != null) {
|
|
result.modelInitializer = artifacts.modelInitializer;
|
|
}
|
|
if (artifacts.initializerSignature != null) {
|
|
result.initializerSignature = artifacts.initializerSignature;
|
|
}
|
|
if (artifacts.trainingConfig != null) {
|
|
result.trainingConfig = artifacts.trainingConfig;
|
|
}
|
|
return result;
|
|
}
|
|
|
|
function getModelArtifactsInfoForJSON(modelArtifacts) {
|
|
if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
|
|
throw new Error('Expected JSON model topology, received ArrayBuffer.');
|
|
}
|
|
return {
|
|
dateSaved: new Date(),
|
|
modelTopologyType: 'JSON',
|
|
modelTopologyBytes: modelArtifacts.modelTopology == null ?
|
|
0 :
|
|
stringByteLength(JSON.stringify(modelArtifacts.modelTopology)),
|
|
weightSpecsBytes: modelArtifacts.weightSpecs == null ?
|
|
0 :
|
|
stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)),
|
|
weightDataBytes: modelArtifacts.weightData == null ?
|
|
0 :
|
|
new CompositeArrayBuffer(modelArtifacts.weightData).byteLength,
|
|
};
|
|
}
|
|
|
|
function computeFloat16MantisaTable() {
|
|
const convertMantissa = (i) => {
|
|
let m = i << 13;
|
|
let e = 0;
|
|
while ((m & 0x00800000) === 0) {
|
|
e -= 0x00800000;
|
|
m <<= 1;
|
|
}
|
|
m &= ~0x00800000;  // clear bit 23, the now-implicit leading one (same bits as -8388609)
|
|
e += 0x38800000;
|
|
return m | e;
|
|
};
|
|
const mantisaTable = new Uint32Array(2048);
|
|
mantisaTable[0] = 0;
|
|
for (let i = 1; i < 1024; i++) {
|
|
mantisaTable[i] = convertMantissa(i);
|
|
}
|
|
for (let i = 1024; i < 2048; i++) {
|
|
mantisaTable[i] = 0x38000000 + ((i - 1024) << 13);
|
|
}
|
|
return mantisaTable;
|
|
}
|
|
|
|
function computeFloat16ExponentTable() {
|
|
const exponentTable = new Uint32Array(64);
|
|
exponentTable[0] = 0;
|
|
exponentTable[31] = 0x47800000;
|
|
exponentTable[32] = 0x80000000;
|
|
exponentTable[63] = 0xc7800000;
|
|
for (let i = 1; i < 31; i++) {
|
|
exponentTable[i] = i << 23;
|
|
}
|
|
for (let i = 33; i < 63; i++) {
|
|
exponentTable[i] = 0x80000000 + ((i - 32) << 23);
|
|
}
|
|
return exponentTable;
|
|
}
|
|
|
|
function computeFloat16OffsetTable() {
|
|
const offsetTable = new Uint32Array(64);
|
|
for (let i = 0; i < 64; i++) {
|
|
offsetTable[i] = 1024;
|
|
}
|
|
offsetTable[0] = offsetTable[32] = 0;
|
|
return offsetTable;
|
|
}
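// Table-driven float16 -> float32 conversion: the classic mantissa /
// exponent / offset lookup scheme. The three tables are built once per
// decoder, after which each half-precision bit pattern expands with two
// lookups and an add.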
function getFloat16Decoder() {
const mantisaTable = computeFloat16MantisaTable();
|
|
const exponentTable = computeFloat16ExponentTable();
|
|
const offsetTable = computeFloat16OffsetTable();
|
|
return (quantizedArray) => {
|
|
const buffer = new ArrayBuffer(4 * quantizedArray.length);
|
|
const bufferUint32View = new Uint32Array(buffer);
|
|
for (let index = 0; index < quantizedArray.length; index++) {
|
|
const float16Bits = quantizedArray[index];
|
|
const float32Bits = mantisaTable[offsetTable[float16Bits >> 10] + (float16Bits & 0x3ff)] +
|
|
exponentTable[float16Bits >> 10];
|
|
bufferUint32View[index] = float32Bits;
|
|
}
|
|
return new Float32Array(buffer);
|
|
};
|
|
}
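// Singleton registry mapping URL-like strings ('indexeddb://...',
// 'localstorage://...', 'downloads://...', plain HTTP paths) to IO handlers.
// Routers return null for schemes they do not recognize, so resolution is
// simply "every router that answers".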
class IORouterRegistry {
|
|
constructor() {
|
|
this.saveRouters = [];
|
|
this.loadRouters = [];
|
|
}
|
|
static getInstance() {
|
|
if (IORouterRegistry.instance == null) {
|
|
IORouterRegistry.instance = new IORouterRegistry();
|
|
}
|
|
return IORouterRegistry.instance;
|
|
}
|
|
|
|
static registerSaveRouter(saveRouter) {
|
|
IORouterRegistry.getInstance().saveRouters.push(saveRouter);
|
|
}
|
|
|
|
static registerLoadRouter(loadRouter) {
|
|
IORouterRegistry.getInstance().loadRouters.push(loadRouter);
|
|
}
|
|
|
|
static getSaveHandlers(url) {
|
|
return IORouterRegistry.getHandlers(url, 'save');
|
|
}
|
|
|
|
static getLoadHandlers(url, loadOptions) {
|
|
return IORouterRegistry.getHandlers(url, 'load', loadOptions);
|
|
}
|
|
static getHandlers(url, handlerType, loadOptions) {
|
|
const validHandlers = [];
|
|
const routers = handlerType === 'load' ?
|
|
IORouterRegistry.getInstance().loadRouters :
|
|
IORouterRegistry.getInstance().saveRouters;
|
|
routers.forEach(router => {
|
|
const handler = router(url, loadOptions);
|
|
if (handler !== null) {
|
|
validHandlers.push(handler);
|
|
}
|
|
});
|
|
return validHandlers;
|
|
}
|
|
}
|
|
const getSaveHandlers = (url) => IORouterRegistry.getSaveHandlers(url);
const DATABASE_NAME = 'tensorflowjs';
|
|
const DATABASE_VERSION = 1;
const MODEL_STORE_NAME = 'models_store';
const INFO_STORE_NAME = 'model_info_store';
|
|
function getIndexedDBFactory() {
|
|
if (!env().getBool('IS_BROWSER')) {
throw new Error('Failed to obtain IndexedDB factory because the current environment ' +
'is not a web browser.');
|
|
}
|
|
|
|
const theWindow = typeof window === 'undefined' ? self : window;
|
|
const factory = theWindow.indexedDB || theWindow.mozIndexedDB ||
|
|
theWindow.webkitIndexedDB || theWindow.msIndexedDB ||
|
|
theWindow.shimIndexedDB;
|
|
if (factory == null) {
|
|
throw new Error('The current browser does not appear to support IndexedDB.');
|
|
}
|
|
return factory;
|
|
}
|
|
function setUpDatabase(openRequest) {
|
|
const db = openRequest.result;
|
|
db.createObjectStore(MODEL_STORE_NAME, { keyPath: 'modelPath' });
|
|
db.createObjectStore(INFO_STORE_NAME, { keyPath: 'modelPath' });
|
|
}
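// Saves/loads models under 'indexeddb://<modelPath>'. On save, the metadata
// goes into model_info_store first and the artifacts into models_store; if
// the second put fails, the info entry is deleted again so the two stores
// stay consistent.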
class BrowserIndexedDB {
|
|
constructor(modelPath) {
|
|
this.indexedDB = getIndexedDBFactory();
|
|
if (modelPath == null || !modelPath) {
|
|
throw new Error('For IndexedDB, modelPath must not be null, undefined or empty.');
|
|
}
|
|
this.modelPath = modelPath;
|
|
}
|
|
async save(modelArtifacts) {
|
|
|
|
if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
|
|
throw new Error('BrowserIndexedDB.save() does not support saving model topology ' +
'in binary formats yet.');
|
|
}
|
|
return this.databaseAction(this.modelPath, modelArtifacts);
|
|
}
|
|
async load() {
|
|
return this.databaseAction(this.modelPath);
|
|
}
|
|
|
|
databaseAction(modelPath, modelArtifacts) {
|
|
return new Promise((resolve, reject) => {
|
|
const openRequest = this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION);
|
|
openRequest.onupgradeneeded = () => setUpDatabase(openRequest);
|
|
openRequest.onsuccess = () => {
|
|
const db = openRequest.result;
|
|
if (modelArtifacts == null) {
|
|
|
|
const modelTx = db.transaction(MODEL_STORE_NAME, 'readonly');
|
|
const modelStore = modelTx.objectStore(MODEL_STORE_NAME);
|
|
const getRequest = modelStore.get(this.modelPath);
|
|
getRequest.onsuccess = () => {
|
|
if (getRequest.result == null) {
|
|
db.close();
|
|
return reject(new Error(`Cannot find model with path '${this.modelPath}' ` +
|
|
`in IndexedDB.`));
|
|
}
|
|
else {
|
|
resolve(getRequest.result.modelArtifacts);
|
|
}
|
|
};
|
|
getRequest.onerror = error => {
|
|
db.close();
|
|
return reject(getRequest.error);
|
|
};
|
|
modelTx.oncomplete = () => db.close();
|
|
}
|
|
else {
modelArtifacts.weightData = CompositeArrayBuffer.join(modelArtifacts.weightData);
|
|
const modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts);
|
|
|
|
const infoTx = db.transaction(INFO_STORE_NAME, 'readwrite');
|
|
let infoStore = infoTx.objectStore(INFO_STORE_NAME);
|
|
let putInfoRequest;
|
|
try {
|
|
putInfoRequest =
|
|
infoStore.put({ modelPath: this.modelPath, modelArtifactsInfo });
|
|
}
|
|
catch (error) {
|
|
return reject(error);
|
|
}
|
|
let modelTx;
|
|
putInfoRequest.onsuccess = () => {
|
|
|
|
modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite');
|
|
const modelStore = modelTx.objectStore(MODEL_STORE_NAME);
|
|
let putModelRequest;
|
|
try {
|
|
putModelRequest = modelStore.put({
|
|
modelPath: this.modelPath,
|
|
modelArtifacts,
|
|
modelArtifactsInfo
|
|
});
|
|
}
|
|
catch (error) {
|
|
|
|
return reject(error);
|
|
}
|
|
putModelRequest.onsuccess = () => resolve({ modelArtifactsInfo });
|
|
putModelRequest.onerror = error => {
infoStore = infoTx.objectStore(INFO_STORE_NAME);
|
|
const deleteInfoRequest = infoStore.delete(this.modelPath);
|
|
deleteInfoRequest.onsuccess = () => {
|
|
db.close();
|
|
return reject(putModelRequest.error);
|
|
};
|
|
deleteInfoRequest.onerror = error => {
|
|
db.close();
|
|
return reject(putModelRequest.error);
|
|
};
|
|
};
|
|
};
|
|
putInfoRequest.onerror = error => {
|
|
db.close();
|
|
return reject(putInfoRequest.error);
|
|
};
|
|
infoTx.oncomplete = () => {
|
|
if (modelTx == null) {
|
|
db.close();
|
|
}
|
|
else {
|
|
modelTx.oncomplete = () => db.close();
|
|
}
|
|
};
|
|
}
|
|
};
|
|
openRequest.onerror = error => reject(openRequest.error);
|
|
});
|
|
}
|
|
}
|
|
BrowserIndexedDB.URL_SCHEME = 'indexeddb://';
|
|
const indexedDBRouter = (url) => {
|
|
if (!env().getBool('IS_BROWSER')) {
|
|
return null;
|
|
}
|
|
else {
|
|
if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) {
|
|
return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length));
|
|
}
|
|
else {
|
|
return null;
|
|
}
|
|
}
|
|
};
|
|
IORouterRegistry.registerSaveRouter(indexedDBRouter);
|
|
IORouterRegistry.registerLoadRouter(indexedDBRouter);
|
|
|
|
function browserIndexedDB(modelPath) {
|
|
return new BrowserIndexedDB(modelPath);
|
|
}
|
|
|
|
|
|
const PATH_SEPARATOR = '/';
|
|
const PATH_PREFIX = 'tensorflowjs_models';
|
|
const INFO_SUFFIX = 'info';
|
|
const MODEL_TOPOLOGY_SUFFIX = 'model_topology';
|
|
const WEIGHT_SPECS_SUFFIX = 'weight_specs';
|
|
const WEIGHT_DATA_SUFFIX = 'weight_data';
|
|
const MODEL_METADATA_SUFFIX = 'model_metadata';
|
|
function getModelKeys(path) {
|
|
return {
|
|
info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR),
|
|
topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR),
|
|
weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR),
|
|
weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR),
|
|
modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR)
|
|
};
|
|
}
|
|
function removeItems(keys) {
|
|
for (const key of Object.values(keys)) {
|
|
window.localStorage.removeItem(key);
|
|
}
|
|
}
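// Saves/loads models under 'localstorage://<modelPath>'. Topology, weight
// specs and metadata are stored as JSON strings and the weight bytes as
// base64; if any setItem fails (typically the storage quota), the partially
// written keys are removed. Usage sketch, assuming a model object with the
// public save() API:
//   await model.save('localstorage://my-model');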
class BrowserLocalStorage {
|
|
constructor(modelPath) {
|
|
if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' ||
|
|
typeof window.localStorage === 'undefined') {
throw new Error('The current environment does not support local storage.');
|
|
}
|
|
this.LS = window.localStorage;
|
|
if (modelPath == null || !modelPath) {
|
|
throw new Error('For local storage, modelPath must not be null, undefined or empty.');
|
|
}
|
|
this.modelPath = modelPath;
|
|
this.keys = getModelKeys(this.modelPath);
|
|
}
|
|
|
|
async save(modelArtifacts) {
|
|
if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
|
|
throw new Error('BrowserLocalStorage.save() does not support saving model topology ' +
|
|
'in binary formats yet.');
|
|
}
|
|
else {
|
|
const topology = JSON.stringify(modelArtifacts.modelTopology);
|
|
const weightSpecs = JSON.stringify(modelArtifacts.weightSpecs);
|
|
const modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts);
const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);
|
|
try {
|
|
this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo));
|
|
this.LS.setItem(this.keys.topology, topology);
|
|
this.LS.setItem(this.keys.weightSpecs, weightSpecs);
|
|
this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(weightBuffer));
const metadata = {
|
|
format: modelArtifacts.format,
|
|
generatedBy: modelArtifacts.generatedBy,
|
|
convertedBy: modelArtifacts.convertedBy,
|
|
signature: modelArtifacts.signature != null ?
|
|
modelArtifacts.signature :
|
|
undefined,
|
|
userDefinedMetadata: modelArtifacts.userDefinedMetadata != null ?
|
|
modelArtifacts.userDefinedMetadata :
|
|
undefined,
|
|
modelInitializer: modelArtifacts.modelInitializer != null ?
|
|
modelArtifacts.modelInitializer :
|
|
undefined,
|
|
initializerSignature: modelArtifacts.initializerSignature != null ?
|
|
modelArtifacts.initializerSignature :
|
|
undefined,
|
|
trainingConfig: modelArtifacts.trainingConfig != null ?
|
|
modelArtifacts.trainingConfig :
|
|
undefined
|
|
};
|
|
this.LS.setItem(this.keys.modelMetadata, JSON.stringify(metadata));
|
|
return { modelArtifactsInfo };
|
|
}
|
|
catch (err) {
|
|
|
|
removeItems(this.keys);
|
|
throw new Error(`Failed to save model '${this.modelPath}' to local storage: ` +
|
|
`size quota being exceeded is a possible cause of this failure: ` +
|
|
`modelTopologyBytes=${modelArtifactsInfo.modelTopologyBytes}, ` +
|
|
`weightSpecsBytes=${modelArtifactsInfo.weightSpecsBytes}, ` +
|
|
`weightDataBytes=${modelArtifactsInfo.weightDataBytes}.`);
|
|
}
|
|
}
|
|
}
|
|
|
|
async load() {
|
|
const info = JSON.parse(this.LS.getItem(this.keys.info));
|
|
if (info == null) {
|
|
throw new Error(`In local storage, there is no model with name '${this.modelPath}'`);
|
|
}
|
|
if (info.modelTopologyType !== 'JSON') {
|
|
throw new Error('BrowserLocalStorage does not support loading non-JSON model ' +
|
|
'topology yet.');
|
|
}
|
|
const out = {};
|
|
|
|
const topology = JSON.parse(this.LS.getItem(this.keys.topology));
|
|
if (topology == null) {
|
|
throw new Error(`In local storage, the topology of model '${this.modelPath}' ` +
|
|
`is missing.`);
|
|
}
|
|
out.modelTopology = topology;
|
|
|
|
const weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs));
|
|
if (weightSpecs == null) {
|
|
throw new Error(`In local storage, the weight specs of model '${this.modelPath}' ` +
|
|
`are missing.`);
|
|
}
|
|
out.weightSpecs = weightSpecs;
|
|
|
|
const metadataString = this.LS.getItem(this.keys.modelMetadata);
|
|
if (metadataString != null) {
|
|
const metadata = JSON.parse(metadataString);
|
|
out.format = metadata.format;
|
|
out.generatedBy = metadata.generatedBy;
|
|
out.convertedBy = metadata.convertedBy;
|
|
if (metadata.signature != null) {
|
|
out.signature = metadata.signature;
|
|
}
|
|
if (metadata.userDefinedMetadata != null) {
|
|
out.userDefinedMetadata = metadata.userDefinedMetadata;
|
|
}
|
|
if (metadata.modelInitializer != null) {
|
|
out.modelInitializer = metadata.modelInitializer;
|
|
}
|
|
if (metadata.initializerSignature != null) {
|
|
out.initializerSignature = metadata.initializerSignature;
|
|
}
|
|
if (metadata.trainingConfig != null) {
|
|
out.trainingConfig = metadata.trainingConfig;
|
|
}
|
|
}
|
|
|
|
const weightDataBase64 = this.LS.getItem(this.keys.weightData);
|
|
if (weightDataBase64 == null) {
|
|
throw new Error(`In local storage, the binary weight values of model ` +
|
|
`'${this.modelPath}' are missing.`);
|
|
}
|
|
out.weightData = base64StringToArrayBuffer(weightDataBase64);
|
|
return out;
|
|
}
|
|
}
|
|
BrowserLocalStorage.URL_SCHEME = 'localstorage://';
|
|
const localStorageRouter = (url) => {
|
|
if (!env().getBool('IS_BROWSER')) {
|
|
return null;
|
|
}
|
|
else {
|
|
if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) {
|
|
return browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length));
|
|
}
|
|
else {
|
|
return null;
|
|
}
|
|
}
|
|
};
|
|
IORouterRegistry.registerSaveRouter(localStorageRouter);
|
|
IORouterRegistry.registerLoadRouter(localStorageRouter);
|
|
|
|
function browserLocalStorage(modelPath) {
|
|
return new BrowserLocalStorage(modelPath);
|
|
}
const DEFAULT_FILE_NAME_PREFIX = 'model';
|
|
const DEFAULT_JSON_EXTENSION_NAME = '.json';
|
|
const DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin';
|
|
function defer(f) {
|
|
return new Promise(resolve => setTimeout(resolve)).then(f);
|
|
}
|
|
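// Saves a model as two browser downloads: '<prefix>.json' (topology plus a
// weights manifest) and '<prefix>.weights.bin'. The anchor clicks are
// dispatched through defer() so the second download fires on a fresh task.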
class BrowserDownloads {
|
|
constructor(fileNamePrefix) {
|
|
if (!env().getBool('IS_BROWSER')) {
throw new Error('browserDownloads() cannot proceed because the current environment ' +
|
|
'is not a browser.');
|
|
}
|
|
if (fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) {
|
|
fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length);
|
|
}
|
|
if (fileNamePrefix == null || fileNamePrefix.length === 0) {
|
|
fileNamePrefix = DEFAULT_FILE_NAME_PREFIX;
|
|
}
|
|
this.modelJsonFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME;
|
|
this.weightDataFileName =
|
|
fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME;
|
|
}
|
|
async save(modelArtifacts) {
|
|
if (typeof (document) === 'undefined') {
|
|
throw new Error('Browser downloads are not supported in ' +
|
|
'this environment since `document` is not present');
|
|
}
const weightBuffer = CompositeArrayBuffer.join(modelArtifacts.weightData);
|
|
const weightsURL = window.URL.createObjectURL(new Blob([weightBuffer], { type: 'application/octet-stream' }));
|
|
if (modelArtifacts.modelTopology instanceof ArrayBuffer) {
|
|
throw new Error('BrowserDownloads.save() does not support saving model topology ' +
|
|
'in binary formats yet.');
|
|
}
|
|
else {
|
|
const weightsManifest = [{
|
|
paths: ['./' + this.weightDataFileName],
|
|
weights: modelArtifacts.weightSpecs
|
|
}];
|
|
const modelJSON = getModelJSONForModelArtifacts(modelArtifacts, weightsManifest);
|
|
const modelJsonURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelJSON)], { type: 'application/json' }));
const jsonAnchor = this.modelJsonAnchor == null ?
|
|
document.createElement('a') :
|
|
this.modelJsonAnchor;
|
|
jsonAnchor.download = this.modelJsonFileName;
|
|
jsonAnchor.href = modelJsonURL;
await defer(() => jsonAnchor.dispatchEvent(new MouseEvent('click')));
|
|
if (modelArtifacts.weightData != null) {
|
|
const weightDataAnchor = this.weightDataAnchor == null ?
|
|
document.createElement('a') :
|
|
this.weightDataAnchor;
|
|
weightDataAnchor.download = this.weightDataFileName;
|
|
weightDataAnchor.href = weightsURL;
|
|
await defer(() => weightDataAnchor.dispatchEvent(new MouseEvent('click')));
|
|
}
|
|
return { modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts) };
|
|
}
|
|
}
|
|
}
|
|
BrowserDownloads.URL_SCHEME = 'downloads://';
|
|
const browserDownloadsRouter = (url) => {
|
|
if (!env().getBool('IS_BROWSER')) {
|
|
return null;
|
|
}
|
|
else {
|
|
if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) {
|
|
return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length));
|
|
}
|
|
else {
|
|
return null;
|
|
}
|
|
}
|
|
};
|
|
IORouterRegistry.registerSaveRouter(browserDownloadsRouter);
|
|
|
|
function browserDownloads(fileNamePrefix = 'model') {
|
|
return new BrowserDownloads(fileNamePrefix);
|
|
}
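// Passthrough handlers back tf.io.fromMemory / tf.io.withSaveHandler: they
// hold model artifacts (or a save callback) directly instead of touching any
// storage medium, which is useful for tests and custom transports.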
class PassthroughLoader {
|
|
constructor(modelArtifacts) {
|
|
this.modelArtifacts = modelArtifacts;
|
|
}
|
|
load() {
|
|
return this.modelArtifacts;
|
|
}
|
|
}
|
|
class PassthroughSaver {
|
|
constructor(saveHandler) {
|
|
this.saveHandler = saveHandler;
|
|
}
|
|
save(modelArtifacts) {
|
|
return this.saveHandler(modelArtifacts);
|
|
}
|
|
}
|
|
class PassthroughAsync {
|
|
constructor(handler) {
|
|
if (handler.load) {
|
|
this.load = () => Promise.resolve(handler.load());
|
|
}
|
|
if (handler.save) {
|
|
this.save = (modelArtifacts) => Promise.resolve(handler.save(modelArtifacts));
|
|
}
|
|
}
|
|
}
|
|
|
|
function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) {
|
|
const args = arguments;
|
|
return new PassthroughAsync(fromMemorySync(...args));
|
|
}
|
|
|
|
function fromMemorySync(modelArtifacts, weightSpecs, weightData, trainingConfig) {
|
|
if (arguments.length === 1) {
|
|
const isModelArtifacts = modelArtifacts.modelTopology != null ||
|
|
modelArtifacts.weightSpecs != null;
|
|
if (isModelArtifacts) {
|
|
return new PassthroughLoader(modelArtifacts);
|
|
}
|
|
else {
console.warn('Please call tf.io.fromMemory() with only one argument. ' +
|
|
'The argument should be of type ModelArtifacts. ' +
|
|
'The multi-argument signature of tf.io.fromMemory() has been ' +
|
|
'deprecated and will be removed in a future release.');
|
|
return new PassthroughLoader({ modelTopology: modelArtifacts });
|
|
}
|
|
}
|
|
else {
console.warn('Please call tf.io.fromMemory() with only one argument. ' +
|
|
'The argument should be of type ModelArtifacts. ' +
|
|
'The multi-argument signature of tf.io.fromMemory() has been ' +
|
|
'deprecated and will be removed in a future release.');
|
|
return new PassthroughLoader({
|
|
modelTopology: modelArtifacts,
|
|
weightSpecs,
|
|
weightData,
|
|
trainingConfig
|
|
});
|
|
}
|
|
}
|
|
|
|
function withSaveHandler(saveHandler) {
|
|
return new PassthroughSaver(saveHandler);
|
|
}
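// gatherND helper: validates ranks and index dtype, then precomputes the
// result shape, the number of gathered slices, the slice size, and the
// strides used to turn an index row into a flat offset.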
function prepareAndValidate(tensor, indices) {
|
|
const tensorRank = tensor.shape.length;
|
|
const indicesRank = indices.shape.length;
|
|
if (tensorRank < 1) {
|
|
throw new Error('tf.gatherND() expects the input to be rank 1 or higher,' +
|
|
` but the rank was ${tensorRank}.`);
|
|
}
|
|
if (indicesRank < 1) {
|
|
throw new Error('tf.gatherND() expects the indices to be rank 1 or higher,' +
|
|
` but the rank was ${indicesRank}.`);
|
|
}
|
|
if (indices.dtype !== 'int32') {
|
|
throw new Error('tf.gatherND() expects the indices to be int32 type,' +
|
|
` but the dtype was ${indices.dtype}.`);
|
|
}
|
|
if (indices.shape[indicesRank - 1] > tensorRank) {
|
|
throw new Error('index innermost dimension length must be <= tensor rank; saw: ' +
|
|
`${indices.shape[indicesRank - 1]} vs. ${tensorRank}`);
|
|
}
|
|
if (sizeFromShape(tensor.shape) === 0) {
|
|
throw new Error('Requested more than 0 entries, but input is empty.' +
|
|
` Input shape: ${tensor.shape}.`);
|
|
}
|
|
const indicesShape = indices.shape;
|
|
const sliceRank = indicesShape[indicesShape.length - 1];
let nResult = 1;
|
|
for (let i = 0; i < indicesShape.length - 1; ++i) {
|
|
nResult *= indicesShape[i];
|
|
}
|
|
const inputShape = tensor.shape;
|
|
const resultShape = indicesShape.slice();
|
|
resultShape.pop();
|
|
let sliceSize = 1;
|
|
for (let i = sliceRank; i < tensorRank; ++i) {
|
|
sliceSize *= inputShape[i];
|
|
resultShape.push(inputShape[i]);
|
|
}
|
|
const strides = [...computeStrides(tensor.shape).map(stride => stride / sliceSize),
|
|
1].slice(0, sliceRank);
|
|
return [resultShape, nResult, sliceSize, strides];
|
|
}
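// Strided-slice utilities. The begin/end/ellipsis masks are bitfields over
// axes, e.g. maskToAxes(0b101) -> [0, 2]. The "elided dims" helpers expand
// an ellipsis into explicit begin/end/stride entries for the axes it covers.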
const NEW_AXIS = -2;
|
|
const SHRINK_AXIS = -1;
|
|
function assertParamsValid(input, begin, size) {
|
|
const inputRank = input.shape.length;
|
|
assert$1(inputRank === begin.length, () => `Error in slice${inputRank}D: Length of begin ${begin} must ` +
|
|
`match the rank of the array (${inputRank}).`);
|
|
assert$1(inputRank === size.length, () => `Error in slice${inputRank}D: Length of size ${size} must ` +
|
|
`match the rank of the array (${inputRank}).`);
|
|
for (let i = 0; i < inputRank; ++i) {
|
|
assert$1(begin[i] + size[i] <= input.shape[i], () => `Error in slice${inputRank}D: begin[${i}] + size[${i}] ` +
|
|
`(${begin[i] + size[i]}) would overflow input.shape[${i}] (${input.shape[i]})`);
|
|
}
|
|
}
|
|
|
|
function maskToAxes(mask) {
|
|
const axes = [];
|
|
let axis = 0;
|
|
while (mask > 0) {
|
|
if (mask & 1) {
|
|
axes.push(axis);
|
|
}
|
|
mask /= 2;
|
|
axis++;
|
|
}
|
|
return axes;
|
|
}
|
|
|
|
function computeOutShape$2(begin, end, strides) {
|
|
const size = [];
|
|
for (let axis = 0; axis < begin.length; axis++) {
|
|
size[axis] = Math.ceil((end[axis] - begin[axis]) / strides[axis]);
|
|
}
|
|
return size;
|
|
}
function stridesWithElidedDims(strides, ellipsisInsertionIndex, numElidedAxes, inputShape) {
|
|
const newStrides = [...strides];
|
|
for (let i = newStrides.length; i < inputShape.length; i++) {
|
|
newStrides.push(1);
|
|
}
|
|
for (let i = 0; i < numElidedAxes; i++) {
|
|
if (i === 0) {
|
|
newStrides[ellipsisInsertionIndex] = 1;
|
|
}
|
|
else {
|
|
newStrides.splice(ellipsisInsertionIndex, 0, 1);
|
|
newStrides.pop();
|
|
}
|
|
}
|
|
return newStrides;
|
|
}
|
|
function unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, normalizedAxis) {
|
|
if (normalizedAxis <= ellipsisInsertionIndex) {
|
|
return normalizedAxis;
|
|
}
|
|
return normalizedAxis - (numElidedAxes - 1);
|
|
}
|
|
function getElidedAxes(numElidedAxes, ellipsisInsertionIndex) {
|
|
const elidedAxes = [];
|
|
for (let i = 0; i < numElidedAxes; i++) {
|
|
elidedAxes.push(ellipsisInsertionIndex + i);
|
|
}
|
|
return elidedAxes;
|
|
}
|
|
|
|
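// Resolves the user-supplied begin/end/strides of a strided slice into one
// concrete value per input axis. When an ellipsis is present, the elided
// axes are filled in by the ...WithElidedDims helpers; otherwise each axis
// is normalized independently via startForAxis, stopForAxis and
// stridesForAxis.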
function getNormalizedAxes(inputShape, ellipsisAxes, numInterpolatedAxes, begin, end, strides, beginMask, endMask, ellipsisMask) {
|
|
const inputRank = inputShape.length;
|
|
let normalizedBegin = new Array(inputRank), normalizedEnd = new Array(inputRank), normalizedStrides = new Array(inputRank);
|
|
if (ellipsisAxes.length && numInterpolatedAxes > 0) {
|
|
const fullIndex = ellipsisAxes[0];
|
|
|
|
|
|
const numElidedAxes = numInterpolatedAxes + 1;
|
|
normalizedBegin = startIndicesWithElidedDims(beginMask, fullIndex, numElidedAxes, begin, inputShape);
|
|
normalizedEnd = stopIndicesWithElidedDims(endMask, fullIndex, numElidedAxes, end, inputShape);
|
|
normalizedStrides =
|
|
stridesWithElidedDims(strides, fullIndex, numElidedAxes, inputShape);
|
|
}
|
|
else {
|
|
for (let axis = 0; axis < inputRank; axis++) {
|
|
normalizedBegin[axis] = startForAxis(beginMask, begin, strides, inputShape, axis, ellipsisMask);
|
|
normalizedEnd[axis] =
|
|
stopForAxis(endMask, end, strides, inputShape, axis, ellipsisMask);
|
|
normalizedStrides[axis] = stridesForAxis(strides, axis, ellipsisMask);
|
|
}
|
|
}
|
|
return {
|
|
begin: normalizedBegin,
|
|
end: normalizedEnd,
|
|
strides: normalizedStrides
|
|
};
|
|
}
|
|
|
|
|
|
function startIndicesWithElidedDims(beginMask, ellipsisInsertionIndex, numElidedAxes, originalBegin, inputShape) {
|
|
const newIndices = [...inputShape];
|
|
const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
|
|
for (let axis = 0; axis < newIndices.length; axis++) {
|
|
if (elidedAxes.indexOf(axis) > -1) {
|
|
newIndices[axis] = 0;
|
|
}
|
|
else {
|
|
const originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
|
|
let originalValue = originalBegin[originalAxis];
|
|
if (beginMask & 1 << originalAxis) {
|
|
originalValue = 0;
|
|
}
|
|
newIndices[axis] = originalValue;
|
|
}
|
|
}
|
|
return newIndices;
|
|
}
|
|
|
|
|
|
function stopIndicesWithElidedDims(endMask, ellipsisInsertionIndex, numElidedAxes, originalEnd, inputShape) {
|
|
const newIndices = [...inputShape];
|
|
const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);
|
|
for (let axis = 0; axis < newIndices.length; axis++) {
|
|
if (elidedAxes.indexOf(axis) > -1) {
|
|
newIndices[axis] = Number.MAX_SAFE_INTEGER;
|
|
}
|
|
else {
|
|
const originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
|
|
let originalValue = originalEnd[originalAxis];
|
|
if (endMask & 1 << originalAxis) {
|
|
originalValue = Number.MAX_SAFE_INTEGER;
|
|
}
|
|
newIndices[axis] = originalValue;
|
|
}
|
|
}
|
|
for (let i = 0; i < newIndices.length; i++) {
|
|
|
|
const axisSize = inputShape[i];
|
|
if (newIndices[i] < 0) {
|
|
newIndices[i] += axisSize;
|
|
}
|
|
newIndices[i] = clamp(0, newIndices[i], inputShape[i]);
|
|
}
|
|
return newIndices;
|
|
}
|
|
function stridesForAxis(strides, axis, ellipsisMask) {
|
|
let stride = strides[axis];
|
|
if (ellipsisMask & (1 << axis) || stride == null) {
|
|
stride = 1;
|
|
}
|
|
return stride;
|
|
}
|
|
function startForAxis(beginMask, startIndices, strides, inputShape, axis, ellipsisMask) {
|
|
|
|
let start = startIndices[axis];
|
|
const stride = strides[axis] || 1;
|
|
|
|
|
|
if (beginMask & 1 << axis || ellipsisMask & 1 << axis || start == null) {
|
|
if (stride > 0) {
|
|
|
|
|
|
|
|
start = Number.MIN_SAFE_INTEGER;
|
|
}
|
|
else {
|
|
|
|
start = Number.MAX_SAFE_INTEGER;
|
|
}
|
|
}
|
|
|
|
const axisSize = inputShape[axis];
|
|
if (start < 0) {
|
|
start += axisSize;
|
|
}
|
|
|
|
start = clamp(0, start, axisSize - 1);
|
|
return start;
|
|
}
|
|
function stopForAxis(endMask, stopIndices, strides, inputShape, axis, ellipsisMask) {
|
|
|
|
let stop = stopIndices[axis];
|
|
const stride = strides[axis] || 1;
|
|
|
|
|
|
if (endMask & (1 << axis) || ellipsisMask & (1 << axis) || stop == null) {
|
|
if (stride > 0) {
|
|
|
|
|
|
stop = Number.MAX_SAFE_INTEGER;
|
|
}
|
|
else {
|
|
|
|
stop = Number.MIN_SAFE_INTEGER;
|
|
}
|
|
}
|
|
|
|
const axisSize = inputShape[axis];
|
|
if (stop < 0) {
|
|
stop += axisSize;
|
|
}
|
|
|
|
|
|
|
|
if (stride > 0) {
|
|
|
|
stop = clamp(0, stop, axisSize);
|
|
}
|
|
else {
|
|
|
|
stop = clamp(-1, stop, axisSize - 1);
|
|
}
|
|
return stop;
|
|
}
|
|
|
|
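// isSliceContinous (spelling kept: it is the name exported via slice_util):
// a slice is memory-contiguous when every axis after the first axis with
// size > 1 is taken whole. Illustrative example: shape=[4, 3],
// begin=[1, 0], size=[2, 3] => true (rows 1..2 form one contiguous block).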
function isSliceContinous(shape, begin, size) {
|
|
|
|
let firstNonOneAxis = size.length;
|
|
for (let i = 0; i < size.length; i++) {
|
|
if (size[i] > 1) {
|
|
firstNonOneAxis = i;
|
|
break;
|
|
}
|
|
}
|
|
for (let i = firstNonOneAxis + 1; i < size.length; i++) {
|
|
if (begin[i] > 0 || size[i] !== shape[i]) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
function computeFlatOffset(begin, strides) {
|
|
let flatOffset = begin.length > 0 ? begin[begin.length - 1] : 1;
|
|
for (let i = 0; i < begin.length - 1; i++) {
|
|
flatOffset += begin[i] * strides[i];
|
|
}
|
|
return flatOffset;
|
|
}
|
|
function parseSliceParams(x, begin, size) {
|
|
|
|
let begin_;
|
|
const xRank = x.shape.length;
|
|
if (typeof begin === 'number') {
|
|
begin_ = [begin, ...new Array(xRank - 1).fill(0)];
|
|
}
|
|
else if (begin.length < xRank) {
|
|
begin_ = begin.concat(new Array(xRank - begin.length).fill(0));
|
|
}
|
|
else {
|
|
begin_ = begin.slice();
|
|
}
|
|
begin_.forEach(d => {
|
|
assert$1(d !== -1, () => 'slice() does not support negative begin indexing.');
|
|
});
|
|
let size_;
|
|
if (size == null) {
|
|
size_ = new Array(xRank).fill(-1);
|
|
}
|
|
else if (typeof size === 'number') {
|
|
size_ = [size, ...new Array(xRank - 1).fill(-1)];
|
|
}
|
|
else if (size.length < xRank) {
|
|
size_ = size.concat(new Array(xRank - size.length).fill(-1));
|
|
}
|
|
else {
|
|
size_ = size;
|
|
}
|
|
size_ = size_.map((d, i) => {
|
|
if (d >= 0) {
|
|
return d;
|
|
}
|
|
else {
|
|
assert$1(d === -1, () => `Negative size values should be exactly -1 but got ` +
|
|
`${d} for the slice() size at index ${i}.`);
|
|
return x.shape[i] - begin_[i];
|
|
}
|
|
});
|
|
return [begin_, size_];
|
|
}
|
|
|
|
|
|
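// Mirrors TensorFlow's StridedSlice shape function: expands the sparse,
// user-facing begin/end/strides spec (with its begin/end/ellipsis/newAxis/
// shrinkAxis masks) into a dense per-axis spec, then derives the processing
// shape and the final output shape. buildDenseSpec and canonical below
// implement the two main steps.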
function sliceInfo(xShape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) {
|
|
let stridesNonNull;
|
|
if (strides == null) {
|
|
stridesNonNull = new Array(begin.length);
|
|
stridesNonNull.fill(1);
|
|
}
|
|
else {
|
|
stridesNonNull = strides;
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (ellipsisMask != null && (ellipsisMask & (ellipsisMask - 1)) !== 0) {
|
|
throw new Error('Multiple ellipses in slice is not allowed.');
|
|
}
|
|
|
|
|
|
let ellipsisSeen = false;
|
|
const sparseSpec = {
|
|
dims: stridesNonNull.length,
|
|
numAddAxisAfterEllipsis: 0,
|
|
begin: begin.slice(),
|
|
end: end.slice(),
|
|
strides: stridesNonNull.slice(),
|
|
beginMask,
|
|
endMask,
|
|
ellipsisMask,
|
|
newAxisMask,
|
|
shrinkAxisMask
|
|
};
|
|
for (let i = 0; i < sparseSpec.dims; i++) {
|
|
if (ellipsisSeen && ((1 << i) & newAxisMask) !== 0) {
|
|
sparseSpec.numAddAxisAfterEllipsis++;
|
|
}
|
|
if ((1 << i) & ellipsisMask) {
|
|
ellipsisSeen = true;
|
|
}
|
|
}
|
|
|
|
if (!ellipsisSeen) {
|
|
sparseSpec.ellipsisMask |= (1 << sparseSpec.dims);
|
|
sparseSpec.dims++;
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const denseSpec = {
|
|
dims: xShape.length,
|
|
beginMask: 0,
|
|
endMask: 0,
|
|
beginValid: false,
|
|
endValid: false
|
|
};
|
|
buildDenseSpec(sparseSpec, denseSpec);
|
|
|
|
|
|
let isIdentity = true;
|
|
let sliceDim0 = true;
|
|
let isSimpleSlice = true;
|
|
const processingShape = [];
|
|
const finalShape = [];
|
|
for (let i = 0; i < xShape.length; ++i) {
|
|
if (denseSpec.strides[i] === 0) {
|
|
throw Error(`strides[${i}] must be non-zero`);
|
|
}
|
|
const shrinkI = !!(denseSpec.shrinkAxisMask & (1 << i));
|
|
const dimI = xShape[i];
|
|
if (dimI === -1) {
|
|
processingShape.push(shrinkI ? 1 : -1);
|
|
continue;
|
|
}
|
|
const masks = [denseSpec.beginMask & (1 << i), denseSpec.endMask & (1 << i)];
|
|
const validRange = [
|
|
denseSpec.strides[i] > 0 ? 0 : -1,
|
|
denseSpec.strides[i] > 0 ? dimI : dimI - 1
|
|
];
|
|
if (shrinkI && denseSpec.strides[i] <= 0) {
|
|
throw Error('only stride 1 allowed on non-range indexing.');
|
|
}
|
|
isSimpleSlice = isSimpleSlice && (denseSpec.strides[i] === 1);
|
|
const beginAndEndMasked = !!((denseSpec.beginMask & (1 << i)) && (denseSpec.endMask & (1 << i)));
|
|
if (denseSpec.beginValid && denseSpec.endValid) {
|
|
if (shrinkI) {
|
|
|
|
|
|
|
|
|
|
const xFwd = denseSpec.begin[i] < 0 ? dimI + denseSpec.begin[i] :
|
|
denseSpec.begin[i];
|
|
denseSpec.begin[i] = xFwd;
|
|
denseSpec.end[i] = denseSpec.begin[i] + 1;
|
|
if (xFwd < 0 || xFwd >= dimI) {
|
|
throw Error(`slice index ${denseSpec.begin[i]} of dimension ${i} out of bounds.`);
|
|
}
|
|
}
|
|
else {
|
|
denseSpec.begin[i] = canonical(denseSpec.begin[i], 0, denseSpec.strides[i], dimI, masks, validRange);
|
|
denseSpec.end[i] = canonical(denseSpec.end[i], 1, denseSpec.strides[i], dimI, masks, validRange);
|
|
}
|
|
|
|
const takeAllInDimension = denseSpec.strides[i] === 1 &&
|
|
denseSpec.begin[i] === 0 && denseSpec.end[i] === dimI;
|
|
isIdentity = isIdentity && takeAllInDimension;
|
|
sliceDim0 = sliceDim0 &&
|
|
((i === 0 && denseSpec.strides[i] === 1) || takeAllInDimension);
|
|
}
|
|
else {
|
|
isIdentity =
|
|
isIdentity && ((denseSpec.strides[i] === 1) && beginAndEndMasked);
|
|
sliceDim0 = sliceDim0 &&
|
|
((i === 0 && denseSpec.strides[i] === 1) || beginAndEndMasked);
|
|
}
|
|
|
|
let intervalLength;
|
|
let knownInterval = false;
|
|
if (denseSpec.beginValid && denseSpec.endValid) {
|
|
intervalLength = denseSpec.end[i] - denseSpec.begin[i];
|
|
knownInterval = true;
|
|
}
|
|
else if (shrinkI) {
|
|
|
|
|
|
intervalLength = 1;
|
|
knownInterval = true;
|
|
}
|
|
else if (beginAndEndMasked) {
|
|
|
|
|
|
|
|
if (dimI >= 0) {
|
|
if (denseSpec.strides[i] < 0) {
|
|
intervalLength = -dimI;
|
|
}
|
|
else {
|
|
intervalLength = dimI;
|
|
}
|
|
knownInterval = true;
|
|
}
|
|
}
|
|
if (knownInterval) {
|
|
let sizeI;
|
|
|
|
|
|
if (intervalLength === 0 ||
|
|
((intervalLength < 0) !== (denseSpec.strides[i] < 0))) {
|
|
sizeI = 0;
|
|
}
|
|
else {
|
|
sizeI = Math.trunc(intervalLength / denseSpec.strides[i]) +
|
|
(intervalLength % denseSpec.strides[i] !== 0 ? 1 : 0);
|
|
}
|
|
processingShape.push(sizeI);
|
|
}
|
|
else {
|
|
processingShape.push(-1);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for (let denseDim = 0; denseDim < denseSpec.finalShapeGatherIndices.length; ++denseDim) {
|
|
const gatherIndex = denseSpec.finalShapeGatherIndices[denseDim];
|
|
if (gatherIndex >= 0) {
|
|
finalShape.push(processingShape[gatherIndex]);
|
|
}
|
|
else if (gatherIndex === NEW_AXIS) {
|
|
finalShape.push(1);
|
|
}
|
|
}
|
|
const finalShapeSparse = finalShape.filter((dim, i) => denseSpec.finalShapeGatherIndices[i] !== NEW_AXIS);
|
|
return {
|
|
finalShapeSparse,
|
|
finalShape,
|
|
isIdentity,
|
|
sliceDim0,
|
|
isSimpleSlice,
|
|
begin: denseSpec.begin,
|
|
end: denseSpec.end,
|
|
strides: denseSpec.strides
|
|
};
|
|
}
|
|
function buildDenseSpec(sparse, dense) {
|
|
dense.beginMask = 0;
|
|
dense.endMask = 0;
|
|
dense.shrinkAxisMask = 0;
|
|
let fullIndex = 0;
|
|
dense.beginValid = sparse.begin != null;
|
|
dense.endValid = sparse.end != null;
|
|
dense.begin = new Array(dense.dims);
|
|
dense.end = new Array(dense.dims);
|
|
dense.strides = new Array(dense.dims);
|
|
dense.finalShapeGatherIndices = [];
|
|
dense.finalShapeGatherIndicesSparse = [];
|
|
dense.inputShapeGatherIndicesSparse = new Array(dense.dims);
|
|
for (let i = 0; i < sparse.dims; i++) {
|
|
if ((1 << i) & sparse.ellipsisMask) {
|
|
|
|
|
|
|
|
const nextIndex = Math.min(dense.dims - (sparse.dims - i) + 1 + sparse.numAddAxisAfterEllipsis, dense.dims);
|
|
for (; fullIndex < nextIndex; fullIndex++) {
|
|
|
|
dense.begin[fullIndex] = 0;
|
|
dense.end[fullIndex] = 0;
|
|
dense.strides[fullIndex] = 1;
|
|
dense.beginMask |= (1 << fullIndex);
|
|
dense.endMask |= (1 << fullIndex);
|
|
dense.finalShapeGatherIndices.push(fullIndex);
|
|
dense.finalShapeGatherIndicesSparse.push(-1);
|
|
dense.inputShapeGatherIndicesSparse[fullIndex] = i;
|
|
}
|
|
}
|
|
else if ((1 << i) & sparse.newAxisMask) {
|
|
|
|
dense.finalShapeGatherIndices.push(NEW_AXIS);
|
|
dense.finalShapeGatherIndicesSparse.push(-1);
|
|
}
|
|
else {
|
|
if (fullIndex === dense.begin.length) {
|
|
throw Error(`Index out of range using input dim ${fullIndex}; input ` +
|
|
`has only ${dense.dims} dims (begin has length ${dense.begin.length}).`);
|
|
}
|
|
|
|
if (sparse.begin != null) {
|
|
dense.begin[fullIndex] = sparse.begin[i];
|
|
}
|
|
if (sparse.end != null) {
|
|
dense.end[fullIndex] = sparse.end[i];
|
|
}
|
|
dense.strides[fullIndex] = sparse.strides[i];
|
|
if (sparse.beginMask & (1 << i)) {
|
|
dense.beginMask |= (1 << fullIndex);
|
|
}
|
|
if (sparse.endMask & (1 << i)) {
|
|
dense.endMask |= (1 << fullIndex);
|
|
}
|
|
|
|
|
|
|
|
if (sparse.shrinkAxisMask & (1 << i)) {
|
|
dense.finalShapeGatherIndices.push(SHRINK_AXIS);
|
|
dense.finalShapeGatherIndicesSparse.push(-1);
|
|
dense.shrinkAxisMask |= (1 << fullIndex);
|
|
}
|
|
else {
|
|
dense.finalShapeGatherIndices.push(fullIndex);
|
|
|
|
dense.finalShapeGatherIndicesSparse.push(i);
|
|
}
|
|
dense.inputShapeGatherIndicesSparse[fullIndex] = i;
|
|
fullIndex++;
|
|
}
|
|
}
|
|
}
|
|
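// Canonicalizes one begin (c === 0) or end (c === 1) index: masked indices
// snap to the extreme of validRange for the stride direction, negative
// indices count from the end, and everything is clamped into validRange.
// Illustrative example: canonical(-1, 0, 1, 5, [0, 0], [0, 5]) => 4.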
function canonical(x, c, strideI, dimI, masks, validRange) {
|
|
if (masks[c]) {
|
|
return strideI > 0 ? validRange[c] : validRange[(c + 1) & 1];
|
|
}
|
|
else {
|
|
const xFwd = x < 0 ? dimI + x : x;
|
|
return xFwd < validRange[0] ? validRange[0] :
|
|
xFwd > validRange[1] ? validRange[1] : xFwd;
|
|
}
|
|
}
|
|
|
|
var slice_util = Object.freeze({
|
|
__proto__: null,
|
|
assertParamsValid: assertParamsValid,
|
|
computeFlatOffset: computeFlatOffset,
|
|
computeOutShape: computeOutShape$2,
|
|
getNormalizedAxes: getNormalizedAxes,
|
|
isSliceContinous: isSliceContinous,
|
|
maskToAxes: maskToAxes,
|
|
parseSliceParams: parseSliceParams,
|
|
sliceInfo: sliceInfo,
|
|
startForAxis: startForAxis,
|
|
startIndicesWithElidedDims: startIndicesWithElidedDims,
|
|
stopForAxis: stopForAxis,
|
|
stopIndicesWithElidedDims: stopIndicesWithElidedDims,
|
|
stridesForAxis: stridesForAxis,
|
|
stridesWithElidedDims: stridesWithElidedDims
|
|
});
|
|
|
|
|
|
class OptimizerConstructors {
|
|
|
|
static sgd(learningRate) {
|
|
return new SGDOptimizer(learningRate);
|
|
}
|
|
|
|
static momentum(learningRate, momentum, useNesterov = false) {
|
|
return new MomentumOptimizer(learningRate, momentum, useNesterov);
|
|
}
|
|
|
|
static rmsprop(learningRate, decay = .9, momentum = 0.0, epsilon = null, centered = false) {
|
|
return new RMSPropOptimizer(learningRate, decay, momentum, epsilon, centered);
|
|
}
|
|
|
|
static adam(learningRate = 0.001, beta1 = 0.9, beta2 = 0.999, epsilon = null) {
|
|
return new AdamOptimizer(learningRate, beta1, beta2, epsilon);
|
|
}
|
|
|
|
static adadelta(learningRate = .001, rho = .95, epsilon = null) {
|
|
return new AdadeltaOptimizer(learningRate, rho, epsilon);
|
|
}
|
|
|
|
static adamax(learningRate = 0.002, beta1 = 0.9, beta2 = 0.999, epsilon = null, decay = 0.0) {
|
|
return new AdamaxOptimizer(learningRate, beta1, beta2, epsilon, decay);
|
|
}
|
|
|
|
static adagrad(learningRate, initialAccumulatorValue = 0.1) {
|
|
return new AdagradOptimizer(learningRate, initialAccumulatorValue);
|
|
}
|
|
}
|
|
|
|
|
|
const train = OptimizerConstructors;
|
|
|
|
|
|
const delayCallback = (() => {
|
|
if (typeof requestAnimationFrame !== 'undefined') {
|
|
return requestAnimationFrame;
|
|
}
|
|
else if (typeof setImmediate !== 'undefined') {
|
|
return setImmediate;
|
|
}
|
|
return (f) => f();
|
|
})();
|
|
|
|
function nextFrame() {
|
|
return new Promise(resolve => delayCallback(() => resolve()));
|
|
}
|
|
|
|
|
|
function assertParamsConsistent(shapes, axis) {
|
|
const rank = shapes[0].length;
|
|
shapes.forEach((shape, i) => {
|
|
assert$1(shape.length === rank, () => `Error in concat${rank}D: rank of tensors[${i}] must be the same ` +
|
|
`as the rank of the rest (${rank})`);
|
|
});
|
|
assert$1(axis >= 0 && axis < rank, () => `Error in concat${rank}D: axis must be between 0 and ${rank - 1}.`);
|
|
const firstShape = shapes[0];
|
|
shapes.forEach((shape, i) => {
|
|
for (let r = 0; r < rank; r++) {
|
|
assert$1((r === axis) || (shape[r] === firstShape[r]), () => `Error in concat${rank}D: Shape of tensors[${i}] (${shape}) ` +
|
|
`does not match the shape of the rest (${firstShape}) ` +
|
|
`along the non-concatenated axis ${r}.`);
|
|
}
|
|
});
|
|
}
|
|
function computeOutShape$1(shapes, axis) {
|
|
const outputShape = shapes[0].slice();
|
|
for (let i = 1; i < shapes.length; i++) {
|
|
outputShape[axis] += shapes[i][axis];
|
|
}
|
|
return outputShape;
|
|
}
|
|
|
|
|
|
var RowPartitionType$1;
|
|
(function (RowPartitionType) {
|
|
RowPartitionType[RowPartitionType["FIRST_DIM_SIZE"] = 0] = "FIRST_DIM_SIZE";
|
|
RowPartitionType[RowPartitionType["VALUE_ROWIDS"] = 1] = "VALUE_ROWIDS";
|
|
RowPartitionType[RowPartitionType["ROW_LENGTHS"] = 2] = "ROW_LENGTHS";
|
|
RowPartitionType[RowPartitionType["ROW_SPLITS"] = 3] = "ROW_SPLITS";
|
|
RowPartitionType[RowPartitionType["ROW_LIMITS"] = 4] = "ROW_LIMITS";
|
|
RowPartitionType[RowPartitionType["ROW_STARTS"] = 5] = "ROW_STARTS";
|
|
})(RowPartitionType$1 || (RowPartitionType$1 = {}));
|
|
function combineRaggedTensorToTensorShapes(raggedRank, shape, valueShape) {
|
|
|
|
|
|
|
|
let outputShape = [];
|
|
if (valueShape == null && shape == null) {
|
|
return outputShape;
|
|
}
|
|
if (shape == null) {
|
|
|
|
while (outputShape.length < raggedRank + valueShape.length) {
|
|
outputShape.push(-1);
|
|
}
|
|
}
|
|
else {
|
|
outputShape = shape.slice();
|
|
}
|
|
if (valueShape == null) {
|
|
return outputShape;
|
|
}
|
|
|
|
if (raggedRank + valueShape.length !== outputShape.length) {
|
|
throw new Error(`rt input.shape and shape=${shape} are incompatible: rt input.rank = ${raggedRank +
|
|
valueShape.length}, but shape.rank = ${outputShape.length}`);
|
|
}
|
|
for (let i = 1; i < valueShape.length; ++i) {
|
|
const valueDim = valueShape[i];
|
|
const outputShapeDimIndex = outputShape.length - valueShape.length + i;
|
|
const outputShapeDim = outputShape[outputShapeDimIndex];
|
|
if (valueDim >= 0) {
|
|
if (outputShapeDim >= 0) {
|
|
if (outputShapeDim !== valueDim) {
|
|
throw new Error(`rt input.shape and shape=${shape} are incompatible: rt input.shape[${i + raggedRank}] = ${valueDim} but shape[${i + raggedRank}] = ${outputShapeDim}`);
|
|
}
|
|
}
|
|
else {
|
|
outputShape[outputShapeDimIndex] = valueDim;
|
|
}
|
|
}
|
|
}
|
|
return outputShape;
|
|
}
|
|
function getRowPartitionTypesHelper(rowPartitionTypeStrings) {
|
|
const stringToType = {
|
|
'FIRST_DIM_SIZE': RowPartitionType$1.FIRST_DIM_SIZE,
|
|
'VALUE_ROWIDS': RowPartitionType$1.VALUE_ROWIDS,
|
|
'ROW_LENGTHS': RowPartitionType$1.ROW_LENGTHS,
|
|
'ROW_SPLITS': RowPartitionType$1.ROW_SPLITS,
|
|
'ROW_LIMITS': RowPartitionType$1.ROW_LIMITS,
|
|
'ROW_STARTS': RowPartitionType$1.ROW_STARTS
|
|
};
|
|
const result = [];
|
|
for (const typeStr of rowPartitionTypeStrings) {
|
|
if (typeStr in stringToType) {
|
|
result.push(stringToType[typeStr]);
|
|
}
|
|
else {
|
|
break;
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
function getRaggedRank(rowPartitionTypes) {
|
|
if (rowPartitionTypes.length === 0) {
|
|
return 0;
|
|
}
|
|
if (rowPartitionTypes[0] === RowPartitionType$1.FIRST_DIM_SIZE) {
|
|
return rowPartitionTypes.length - 1;
|
|
}
|
|
return rowPartitionTypes.length;
|
|
}
|
|
function validateDefaultValueShape(defaultValueShape, valueShape) {
|
|
if (defaultValueShape == null || valueShape == null) {
|
|
return;
|
|
}
|
|
const defaultNDims = defaultValueShape.length;
|
|
const valuesNDims = valueShape.length;
|
|
if (defaultNDims >= valuesNDims) {
|
|
throw new Error(`defaultValue.shape=${defaultValueShape} and ragged tensor flatValues.shape=${valueShape} are incompatible: defaultValue.rank = ${defaultNDims} must be less than ragged tensor input flatValues.rank = ${valuesNDims}`);
|
|
}
|
|
for (let i = 0; i < Math.min(defaultNDims, valuesNDims - 1); ++i) {
|
|
const defaultDim = defaultValueShape[i];
|
|
const valueDim = valueShape[i + 1];
|
|
if (defaultDim >= 0 && valueDim >= 0 && defaultDim !== 1 &&
|
|
defaultDim !== valueDim) {
|
|
throw new Error(`defaultValue.shape=${defaultValueShape}, and ragged tensor input flatValues.shape=${valueShape} are incompatible: defaultValue.shape[${i - defaultValueShape.length}] = ${defaultDim} but ragged tensor input.flatValues.shape[${i - defaultValueShape.length}] = ${valueDim}`);
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
const PARALLELIZE_THRESHOLD = 30;
|
|
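// Reductions are evaluated in windows; above PARALLELIZE_THRESHOLD the
// window is the smallest divisor of inSize at or above floor(sqrt(inSize)),
// balancing the number of passes against per-pass work. Illustrative
// example: computeOptimalWindowSize(100) => 10.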
function computeOptimalWindowSize(inSize) {
|
|
if (inSize <= PARALLELIZE_THRESHOLD) {
|
|
return inSize;
|
|
}
|
|
return nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
|
|
}
|
|
|
|
|
|
|
|
function getImageCenter(center, imageHeight, imageWidth) {
|
|
const centerX = imageWidth * (typeof center === 'number' ? center : center[0]);
|
|
const centerY = imageHeight * (typeof center === 'number' ? center : center[1]);
|
|
return [centerX, centerY];
|
|
}
|
|
|
|
|
|
|
|
function getReshaped(inputShape, blockShape, prod, batchToSpace = true) {
|
|
let reshaped = [];
|
|
if (batchToSpace) {
|
|
reshaped = reshaped.concat(blockShape.slice(0));
|
|
reshaped.push(inputShape[0] / prod);
|
|
reshaped = reshaped.concat(inputShape.slice(1));
|
|
}
|
|
else {
|
|
reshaped = reshaped.concat(inputShape[0]);
|
|
const spatialLength = blockShape.length;
|
|
for (let i = 0; i < spatialLength; ++i) {
|
|
reshaped =
|
|
reshaped.concat([inputShape[i + 1] / blockShape[i], blockShape[i]]);
|
|
}
|
|
reshaped = reshaped.concat(inputShape.slice(spatialLength + 1));
|
|
}
|
|
return reshaped;
|
|
}
|
|
|
|
function getPermuted(reshapedRank, blockShapeRank, batchToSpace = true) {
|
|
const permuted = [];
|
|
if (batchToSpace) {
|
|
permuted.push(blockShapeRank);
|
|
for (let i = blockShapeRank + 1; i < reshapedRank; ++i) {
|
|
if (i <= 2 * blockShapeRank) {
|
|
permuted.push(i);
|
|
permuted.push(i - (blockShapeRank + 1));
|
|
}
|
|
else {
|
|
permuted.push(i);
|
|
}
|
|
}
|
|
}
|
|
else {
|
|
const permutedBeforeBatch = [];
|
|
const permutedAfterBatch = [];
|
|
for (let i = 1; i < reshapedRank; ++i) {
|
|
if (i >= blockShapeRank * 2 + 1 || i % 2 === 1) {
|
|
permutedAfterBatch.push(i);
|
|
}
|
|
else {
|
|
permutedBeforeBatch.push(i);
|
|
}
|
|
}
|
|
permuted.push(...permutedBeforeBatch);
|
|
permuted.push(0);
|
|
permuted.push(...permutedAfterBatch);
|
|
}
|
|
return permuted;
|
|
}
|
|
|
|
function getReshapedPermuted(inputShape, blockShape, prod, batchToSpace = true) {
|
|
const reshapedPermuted = [];
|
|
if (batchToSpace) {
|
|
reshapedPermuted.push(inputShape[0] / prod);
|
|
}
|
|
else {
|
|
reshapedPermuted.push(inputShape[0] * prod);
|
|
}
|
|
for (let i = 1; i < inputShape.length; ++i) {
|
|
if (i <= blockShape.length) {
|
|
if (batchToSpace) {
|
|
reshapedPermuted.push(blockShape[i - 1] * inputShape[i]);
|
|
}
|
|
else {
|
|
reshapedPermuted.push(inputShape[i] / blockShape[i - 1]);
|
|
}
|
|
}
|
|
else {
|
|
reshapedPermuted.push(inputShape[i]);
|
|
}
|
|
}
|
|
return reshapedPermuted;
|
|
}
|
|
|
|
function getSliceBeginCoords(crops, blockShape) {
|
|
const sliceBeginCoords = [0];
|
|
for (let i = 0; i < blockShape; ++i) {
|
|
sliceBeginCoords.push(crops[i][0]);
|
|
}
|
|
return sliceBeginCoords;
|
|
}
|
|
|
|
function getSliceSize(uncroppedShape, crops, blockShape) {
|
|
const sliceSize = uncroppedShape.slice(0, 1);
|
|
for (let i = 0; i < blockShape; ++i) {
|
|
sliceSize.push(uncroppedShape[i + 1] - crops[i][0] - crops[i][1]);
|
|
}
|
|
return sliceSize;
|
|
}
|
|
|
|
|
|
const SELU_SCALEALPHA = 1.7580993408473768599402175208123;
|
|
const SELU_SCALE = 1.0507009873554804934193349852946;
|
|
|
|
|
|
const ERF_P = 0.3275911;
|
|
const ERF_A1 = 0.254829592;
|
|
const ERF_A2 = -0.284496736;
|
|
const ERF_A3 = 1.421413741;
|
|
const ERF_A4 = -1.453152027;
|
|
const ERF_A5 = 1.061405429;
|
|
|
|
|
|
|
|
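// FFT helpers. mergeRealAndImagArrays interleaves two equal-length arrays
// into [re0, im0, re1, im1, ...]. Illustrative example:
//   mergeRealAndImagArrays([1, 2], [3, 4]) => Float32Array [1, 3, 2, 4]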
function mergeRealAndImagArrays(real, imag) {
|
|
if (real.length !== imag.length) {
|
|
throw new Error(`Cannot merge real and imag arrays of different lengths. real:` +
|
|
`${real.length}, imag: ${imag.length}.`);
|
|
}
|
|
const result = new Float32Array(real.length * 2);
|
|
for (let i = 0; i < result.length; i += 2) {
|
|
result[i] = real[i / 2];
|
|
result[i + 1] = imag[i / 2];
|
|
}
|
|
return result;
|
|
}
|
|
|
|
function splitRealAndImagArrays(complex) {
|
|
const real = new Float32Array(complex.length / 2);
|
|
const imag = new Float32Array(complex.length / 2);
|
|
for (let i = 0; i < complex.length; i += 2) {
|
|
real[i / 2] = complex[i];
|
|
imag[i / 2] = complex[i + 1];
|
|
}
|
|
return { real, imag };
|
|
}
|
|
|
|
function complexWithEvenIndex(complex) {
|
|
const len = Math.ceil(complex.length / 4);
|
|
const real = new Float32Array(len);
|
|
const imag = new Float32Array(len);
|
|
for (let i = 0; i < complex.length; i += 4) {
|
|
real[Math.floor(i / 4)] = complex[i];
|
|
imag[Math.floor(i / 4)] = complex[i + 1];
|
|
}
|
|
return { real, imag };
|
|
}
|
|
|
|
function complexWithOddIndex(complex) {
|
|
const len = Math.floor(complex.length / 4);
|
|
const real = new Float32Array(len);
|
|
const imag = new Float32Array(len);
|
|
for (let i = 2; i < complex.length; i += 4) {
|
|
real[Math.floor(i / 4)] = complex[i];
|
|
imag[Math.floor(i / 4)] = complex[i + 1];
|
|
}
|
|
return { real, imag };
|
|
}
|
|
|
|
function getComplexWithIndex(complex, index) {
|
|
const real = complex[index * 2];
|
|
const imag = complex[index * 2 + 1];
|
|
return { real, imag };
|
|
}
|
|
|
|
function assignToTypedArray(data, real, imag, index) {
|
|
data[index * 2] = real;
|
|
data[index * 2 + 1] = imag;
|
|
}
|
|
|
|
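// Twiddle factors for an n-point FFT: entry k holds e^(-2*pi*i*k/n)
// (or the conjugate when inverse === true), split into cosine (real) and
// sine (imag) parts for k < n / 2.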
function exponents(n, inverse) {
|
|
const real = new Float32Array(n / 2);
|
|
const imag = new Float32Array(n / 2);
|
|
for (let i = 0; i < Math.ceil(n / 2); i++) {
|
|
const x = (inverse ? 2 : -2) * Math.PI * (i / n);
|
|
real[i] = Math.cos(x);
|
|
imag[i] = Math.sin(x);
|
|
}
|
|
return { real, imag };
|
|
}
|
|
|
|
function exponent(k, n, inverse) {
|
|
const x = (inverse ? 2 : -2) * Math.PI * (k / n);
|
|
const real = Math.cos(x);
|
|
const imag = Math.sin(x);
|
|
return { real, imag };
|
|
}
|
|
|
|
|
|
const ARROW = '->';
|
|
const ARROW_REGEX = /->/g;
|
|
const COMMA = ',';
|
|
const ELLIPSIS = '...';
|
|
|
|
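// Parses an einsum equation against the number of input tensors.
// Illustrative example (hypothetical call): for 'ij,jk->ik' with 2 tensors,
//   allDims = ['i', 'k', 'j']   (output labels first, then summed labels)
//   summedDims = [2]            (the 'j' axis is contracted away)
//   idDims = [[0, 2], [2, 1]]   (per-input label -> allDims index)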
function decodeEinsumEquation(equation, numTensors) {
|
|
equation = equation.replace(/\s/g, '');
|
|
const numArrows = (equation.length - equation.replace(ARROW_REGEX, '').length) /
|
|
ARROW.length;
|
|
if (numArrows < 1) {
|
|
throw new Error('Equations without an arrow are not supported.');
|
|
}
|
|
else if (numArrows > 1) {
|
|
throw new Error(`Equation must contain exactly one arrow ("${ARROW}").`);
|
|
}
|
|
const [inputString, outputString] = equation.split(ARROW);
|
|
assert$1(inputString.indexOf(ELLIPSIS) === -1, () => `The ellipsis notation ("${ELLIPSIS}") is not supported yet.`);
|
|
const inputTerms = inputString.split(COMMA);
|
|
const numInputs = inputTerms.length;
|
|
if (numTensors !== numInputs) {
|
|
throw new Error(`Expected ${numInputs} input tensors, received ${numTensors}`);
|
|
}
|
|
if (numInputs > 2) {
|
|
throw new Error('Support for more than 2 input tensors is not implemented yet.');
|
|
}
|
|
const allDims = [];
|
|
for (let i = 0; i < outputString.length; ++i) {
|
|
const dimName = outputString[i];
|
|
if (!inputTerms.some(inputTerm => inputTerm.indexOf(dimName) !== -1)) {
|
|
throw new Error(`Output subscripts contain the label ${dimName} ` +
|
|
`not present in the input subscripts.`);
|
|
}
|
|
if (allDims.indexOf(dimName) === -1) {
|
|
allDims.push(dimName);
|
|
}
|
|
}
|
|
for (let i = 0; i < inputString.length; ++i) {
|
|
const dimName = inputString[i];
|
|
if (allDims.indexOf(dimName) === -1 && dimName !== COMMA) {
|
|
allDims.push(dimName);
|
|
}
|
|
}
|
|
const idDims = new Array(inputTerms.length);
|
|
for (let i = 0; i < numInputs; ++i) {
|
|
if (new Set(inputTerms[i].split('')).size !== inputTerms[i].length) {
|
|
throw new Error(`Found duplicate axes in input component ${inputTerms[i]}. ` +
|
|
`Support for duplicate axes in input is not implemented yet.`);
|
|
}
|
|
idDims[i] = [];
|
|
for (let j = 0; j < inputTerms[i].length; ++j) {
|
|
idDims[i].push(allDims.indexOf(inputTerms[i][j]));
|
|
}
|
|
}
|
|
const numDims = allDims.length;
|
|
const numOutDims = outputString.length;
|
|
const summedDims = [];
|
|
for (let i = numOutDims; i < numDims; ++i) {
|
|
summedDims.push(i);
|
|
}
|
|
return { allDims, summedDims, idDims };
|
|
}
|
|
|
|
function getEinsumPermutation(nDims, idDims) {
|
|
let permutationIndices = new Array(nDims);
|
|
permutationIndices.fill(-1);
|
|
for (let i = 0; i < idDims.length; ++i) {
|
|
permutationIndices[idDims[i]] = i;
|
|
}
|
|
const expandDims = [];
|
|
for (let i = 0; i < nDims; ++i) {
|
|
if (permutationIndices[i] === -1) {
|
|
expandDims.push(i);
|
|
}
|
|
}
|
|
permutationIndices = permutationIndices.filter(d => d !== -1);
|
|
return { permutationIndices, expandDims };
|
|
}
|
|
|
|
function checkEinsumDimSizes(nDims, idDims, tensors) {
|
|
const dimSizes = new Array(nDims);
|
|
for (let i = 0; i < tensors.length; ++i) {
|
|
const shape = tensors[i].shape;
|
|
for (let j = 0; j < idDims[i].length; ++j) {
|
|
if (dimSizes[idDims[i][j]] === undefined) {
|
|
dimSizes[idDims[i][j]] = shape[j];
|
|
}
|
|
else {
|
|
assert$1(dimSizes[idDims[i][j]] === shape[j], () => `Expected dimension ${dimSizes[idDims[i][j]]} at axis ${j} ` +
|
|
`of input shaped ${JSON.stringify(shape)}, ` +
|
|
`but got dimension ${shape[j]}`);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
function getEinsumComputePath(summedDims, idDims) {
|
|
const path = summedDims;
|
|
const steps = [];
|
|
let nSteps = 0;
|
|
if (summedDims.length === 0) {
|
|
|
|
path.push(-1);
|
|
}
|
|
nSteps = summedDims.length + 1;
|
|
for (let i = 0; i < nSteps; ++i) {
|
|
steps.push([]);
|
|
}
|
|
const computedTermIndices = [];
|
|
for (let i = 0; i < path.length; ++i) {
|
|
const summedDim = path[i];
|
|
const termIndices = findTermsWithDim(idDims, summedDim);
|
|
for (const termIndex of termIndices) {
|
|
if (computedTermIndices.indexOf(termIndex) === -1) {
|
|
steps[i].push(termIndex);
|
|
computedTermIndices.push(termIndex);
|
|
}
|
|
}
|
|
}
|
|
return { path, steps };
|
|
}
|
|
|
|
function isIdentityPermutation(perm) {
|
|
return perm.every((dim, index) => dim === index);
|
|
}
|
|
function findTermsWithDim(idDims, dim) {
|
|
const termIndices = [];
|
|
for (let i = 0; i < idDims.length; ++i) {
|
|
if (idDims[i].length === 0 || idDims[i].indexOf(dim) !== -1 || dim === -1) {
|
|
termIndices.push(i);
|
|
}
|
|
}
|
|
return termIndices;
|
|
}
|
|
|
|
|
|
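// Normalizes the numOrSizeSplits argument of tf.split into explicit sizes.
// Illustrative examples (hypothetical shapes): with x.shape[axis] === 8,
// numOrSizeSplits = 2 => [4, 4], and numOrSizeSplits = [3, -1] => [3, 5]
// (a single -1 absorbs whatever the other sizes leave over).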
function prepareSplitSize(x, numOrSizeSplits, axis = 0) {
|
|
let splitSizes = [];
|
|
if (typeof (numOrSizeSplits) === 'number') {
|
|
assert$1(x.shape[axis] % numOrSizeSplits === 0, () => 'Number of splits must evenly divide the axis.');
|
|
splitSizes =
|
|
new Array(numOrSizeSplits).fill(x.shape[axis] / numOrSizeSplits);
|
|
}
|
|
else {
|
|
const numOfNegs = numOrSizeSplits.reduce((count, value) => {
|
|
if (value === -1) {
|
|
count += 1;
|
|
}
|
|
return count;
|
|
}, 0);
|
|
assert$1(numOfNegs <= 1, () => 'There should be only one negative value in split array.');
|
|
const negIndex = numOrSizeSplits.indexOf(-1);
|
|
|
|
|
|
if (negIndex !== -1) {
|
|
const total = numOrSizeSplits.reduce((a, b) => b > 0 ? a + b : a);
|
|
numOrSizeSplits[negIndex] = x.shape[axis] - total;
|
|
}
|
|
assert$1(x.shape[axis] === numOrSizeSplits.reduce((a, b) => a + b), () => 'The sum of sizes must match the size of the axis dimension.');
|
|
splitSizes = numOrSizeSplits;
|
|
}
|
|
return splitSizes;
|
|
}
|
|
|
|
|
|
|
|
function getSparseFillEmptyRowsIndicesDenseShapeMismatch(indicesLength) {
|
|
return `Received SparseTensor with denseShape[0] = 0 but indices.shape[0] = ${indicesLength}`;
|
|
}
|
|
|
|
function getSparseFillEmptyRowsNegativeIndexErrorMessage(index, value) {
|
|
return `indices(${index}, 0) is invalid: ${value} < 0`;
|
|
}
|
|
|
|
function getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(index, value, limit) {
|
|
return `indices(${index}, 0) is invalid: ${value} >= ${limit}`;
|
|
}
|
|
|
|
|
|
|
|
function getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(dim1, dim2) {
|
|
return `only one output dimension may be -1, not both ${dim1} and ${dim2}`;
|
|
}
|
|
|
|
function getSparseReshapeNegativeOutputDimErrorMessage(dim, value) {
|
|
return `size ${dim} must be non-negative, not ${value}`;
|
|
}
|
|
|
|
function getSparseReshapeEmptyTensorZeroOutputDimErrorMessage() {
|
|
return 'reshape cannot infer the missing input size for an empty tensor ' +
|
|
'unless all specified input sizes are non-zero';
|
|
}
|
|
|
|
function getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape) {
|
|
const inputSize = sizeFromShape(inputShape);
|
|
const outputSize = sizeFromShape(outputShape);
|
|
return `Input to reshape is a SparseTensor with ${inputSize} dense values, but the requested shape requires a multiple of ${outputSize}. inputShape=${inputShape} outputShape=${outputShape}`;
|
|
}
|
|
|
|
function getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape) {
|
|
const inputSize = sizeFromShape(inputShape);
|
|
const outputSize = sizeFromShape(outputShape);
|
|
return `Input to reshape is a tensor with ${inputSize} dense values, but the requested shape has ${outputSize}. inputShape=${inputShape} outputShape=${outputShape}`;
|
|
}
|
|
|
|
|
|
|
|
function getSparseSegmentReductionNegativeSegmentIdsErrorMessage() {
|
|
return `segment ids must be >= 0`;
|
|
}
|
|
|
|
function getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage() {
|
|
return `segment ids are not increasing`;
|
|
}
|
|
|
|
function getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(segmentId, outputRows) {
|
|
return `Segment id ${segmentId} out of range [0, ${outputRows}), possibly because segmentIds input is not sorted.`;
|
|
}
|
|
|
|
function getSparseSegmentReductionIndicesOutOfRangeErrorMessage(index, indexValue, inputRows) {
|
|
return `Bad: indices[${index}] == ${indexValue} out of range [0, ${inputRows})`;
|
|
}
|
|
|
|
|
|
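// Window sizing for segment reductions: starts from the generic sqrt-based
// window, then keeps moving to the next larger divisor of inSize until the
// window exceeds numSegments (or the whole input is one window).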
function segOpComputeOptimalWindowSize(inSize, numSegments) {
|
|
let done = false;
|
|
let res;
|
|
if (inSize <= PARALLELIZE_THRESHOLD) {
|
|
res = inSize;
|
|
done = true;
|
|
}
|
|
else {
|
|
res = nearestDivisor(inSize, Math.floor(Math.sqrt(inSize)));
|
|
}
|
|
while (!done) {
|
|
if (res > numSegments || res === inSize) {
|
|
done = true;
|
|
}
|
|
else {
|
|
res = nearestDivisor(inSize, res + 1);
|
|
}
|
|
}
|
|
return res;
|
|
}
|
|
function computeOutShape(aShape, axis, numSegments) {
|
|
const outShape = [];
|
|
const rank = aShape.length;
|
|
for (let dim = 0; dim < rank; dim++) {
|
|
if (dim !== axis) {
|
|
outShape.push(aShape[dim]);
|
|
}
|
|
else {
|
|
outShape.push(numSegments);
|
|
}
|
|
}
|
|
return outShape;
|
|
}
|
|
function collectGatherOpShapeInfo(x, indices, axis, batchDims) {
|
|
const indicesRank = indices.shape.length;
|
|
const xRank = x.shape.length;
|
|
if (batchDims !== 0) {
|
|
if (batchDims < -indicesRank || batchDims > indicesRank) {
|
|
throw new Error(`Expected batchDims in the range [-${indicesRank}, ${indicesRank}], but got ${batchDims}`);
|
|
}
|
|
}
|
|
if (batchDims < 0) {
|
|
batchDims += indicesRank;
|
|
}
|
|
if (batchDims > xRank) {
|
|
throw new Error(`batchDims (${batchDims}) must be less than rank(x) (${xRank}).`);
|
|
}
|
|
if (axis < batchDims) {
|
|
throw new Error(`batchDims (${batchDims}) must be less than or equal to axis (${axis}).`);
|
|
}
|
|
for (let i = 0; i < batchDims; ++i) {
|
|
if (x.shape[i] !== indices.shape[i]) {
|
|
throw new Error(`x.shape[${i}]: ${x.shape[i]} should be equal to indices.shape[${i}]: ${indices.shape[i]}.`);
|
|
}
|
|
}
|
|
const dimSize = x.shape[axis];
|
|
const outputShape = [];
|
|
let batchSize = 1;
|
|
let outerSize = 1;
|
|
let sliceSize = 1;
|
|
for (let i = 0; i < batchDims; ++i) {
|
|
outputShape.push(x.shape[i]);
|
|
batchSize *= x.shape[i];
|
|
}
|
|
for (let i = batchDims; i < axis; i++) {
|
|
outputShape.push(x.shape[i]);
|
|
outerSize *= x.shape[i];
|
|
}
|
|
for (let i = batchDims; i < indicesRank; i++) {
|
|
outputShape.push(indices.shape[i]);
|
|
}
|
|
for (let i = axis + 1; i < xRank; i++) {
|
|
outputShape.push(x.shape[i]);
|
|
sliceSize *= x.shape[i];
|
|
}
|
|
return { batchSize, sliceSize, outerSize, dimSize, outputShape };
|
|
}
|
|
|
|
var segment_util = Object.freeze({
|
|
__proto__: null,
|
|
collectGatherOpShapeInfo: collectGatherOpShapeInfo,
|
|
computeOutShape: computeOutShape,
|
|
segOpComputeOptimalWindowSize: segOpComputeOptimalWindowSize
|
|
});
|
|
|
|
|
|
function fromUint8ToStringArray(vals) {
|
|
try {
|
|
|
|
return vals.map(val => decodeString(val));
|
|
}
|
|
catch (err) {
|
|
throw new Error(`Failed to decode encoded string bytes into utf-8, error: ${err}`);
|
|
}
|
|
}
|
|
function fromStringArrayToUint8(strings) {
|
|
return strings.map(s => encodeString(s));
|
|
}
|
|
|
|
var backend_util = Object.freeze({
|
|
__proto__: null,
|
|
ERF_A1: ERF_A1,
|
|
ERF_A2: ERF_A2,
|
|
ERF_A3: ERF_A3,
|
|
ERF_A4: ERF_A4,
|
|
ERF_A5: ERF_A5,
|
|
ERF_P: ERF_P,
|
|
PARALLELIZE_THRESHOLD: PARALLELIZE_THRESHOLD,
|
|
get RowPartitionType () { return RowPartitionType$1; },
|
|
SELU_SCALE: SELU_SCALE,
|
|
SELU_SCALEALPHA: SELU_SCALEALPHA,
|
|
applyActivation: applyActivation$1,
|
|
assertAndGetBroadcastShape: assertAndGetBroadcastShape,
|
|
assertAxesAreInnerMostDims: assertAxesAreInnerMostDims,
|
|
assertParamsConsistent: assertParamsConsistent,
|
|
assignToTypedArray: assignToTypedArray,
|
|
axesAreInnerMostDims: axesAreInnerMostDims,
|
|
calculateShapes: calculateShapes,
|
|
checkEinsumDimSizes: checkEinsumDimSizes,
|
|
checkPadOnDimRoundingMode: checkPadOnDimRoundingMode,
|
|
combineLocations: combineLocations,
|
|
combineRaggedTensorToTensorShapes: combineRaggedTensorToTensorShapes,
|
|
complexWithEvenIndex: complexWithEvenIndex,
|
|
complexWithOddIndex: complexWithOddIndex,
|
|
computeConv2DInfo: computeConv2DInfo,
|
|
computeConv3DInfo: computeConv3DInfo,
|
|
computeDefaultPad: computeDefaultPad,
|
|
computeDilation2DInfo: computeDilation2DInfo,
|
|
computeOptimalWindowSize: computeOptimalWindowSize,
|
|
computeOutAndReduceShapes: computeOutAndReduceShapes,
|
|
computeOutShape: computeOutShape$1,
|
|
computePool2DInfo: computePool2DInfo,
|
|
computePool3DInfo: computePool3DInfo,
|
|
convertConv2DDataFormat: convertConv2DDataFormat,
|
|
decodeEinsumEquation: decodeEinsumEquation,
|
|
eitherStridesOrDilationsAreOne: eitherStridesOrDilationsAreOne,
|
|
expandShapeToKeepDim: expandShapeToKeepDim,
|
|
exponent: exponent,
|
|
exponents: exponents,
|
|
fromStringArrayToUint8: fromStringArrayToUint8,
|
|
fromUint8ToStringArray: fromUint8ToStringArray,
|
|
getAxesPermutation: getAxesPermutation,
|
|
getBroadcastDims: getBroadcastDims$1,
|
|
getComplexWithIndex: getComplexWithIndex,
|
|
getEinsumComputePath: getEinsumComputePath,
|
|
getEinsumPermutation: getEinsumPermutation,
|
|
getFusedBiasGradient: getFusedBiasGradient,
|
|
getFusedDyActivation: getFusedDyActivation,
|
|
getImageCenter: getImageCenter,
|
|
getInnerMostAxes: getInnerMostAxes,
|
|
getPermuted: getPermuted,
|
|
getRaggedRank: getRaggedRank,
|
|
getReductionAxes: getReductionAxes,
|
|
getReshaped: getReshaped,
|
|
getReshapedPermuted: getReshapedPermuted,
|
|
getRowPartitionTypesHelper: getRowPartitionTypesHelper,
|
|
getSliceBeginCoords: getSliceBeginCoords,
|
|
getSliceSize: getSliceSize,
|
|
getSparseFillEmptyRowsIndicesDenseShapeMismatch: getSparseFillEmptyRowsIndicesDenseShapeMismatch,
|
|
getSparseFillEmptyRowsNegativeIndexErrorMessage: getSparseFillEmptyRowsNegativeIndexErrorMessage,
|
|
getSparseFillEmptyRowsOutOfRangeIndexErrorMessage: getSparseFillEmptyRowsOutOfRangeIndexErrorMessage,
|
|
getSparseReshapeEmptyTensorZeroOutputDimErrorMessage: getSparseReshapeEmptyTensorZeroOutputDimErrorMessage,
|
|
getSparseReshapeInputOutputMismatchErrorMessage: getSparseReshapeInputOutputMismatchErrorMessage,
|
|
getSparseReshapeInputOutputMultipleErrorMessage: getSparseReshapeInputOutputMultipleErrorMessage,
|
|
getSparseReshapeMultipleNegativeOneOutputDimErrorMessage: getSparseReshapeMultipleNegativeOneOutputDimErrorMessage,
|
|
getSparseReshapeNegativeOutputDimErrorMessage: getSparseReshapeNegativeOutputDimErrorMessage,
|
|
getSparseSegmentReductionIndicesOutOfRangeErrorMessage: getSparseSegmentReductionIndicesOutOfRangeErrorMessage,
|
|
getSparseSegmentReductionNegativeSegmentIdsErrorMessage: getSparseSegmentReductionNegativeSegmentIdsErrorMessage,
|
|
getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage: getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage,
|
|
getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage: getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage,
|
|
getUndoAxesPermutation: getUndoAxesPermutation,
|
|
isIdentityPermutation: isIdentityPermutation,
|
|
log: log$3,
|
|
mergeRealAndImagArrays: mergeRealAndImagArrays,
|
|
prepareAndValidate: prepareAndValidate,
|
|
prepareSplitSize: prepareSplitSize,
|
|
segment_util: segment_util,
|
|
shouldFuse: shouldFuse,
|
|
slice_util: slice_util,
|
|
splitRealAndImagArrays: splitRealAndImagArrays,
|
|
stridesOrDilationsArePositive: stridesOrDilationsArePositive,
|
|
tupleValuesAreOne: tupleValuesAreOne,
|
|
upcastType: upcastType,
|
|
validateDefaultValueShape: validateDefaultValueShape,
|
|
validateInput: validateInput,
|
|
validateUpdateShape: validateUpdateShape,
|
|
warn: warn
|
|
});
|
|
|
|
|
|
|
|
registerOptimizers();
|
|
|
|
|
|
const contexts = {};
|
|
const WEBGL_ATTRIBUTES = {
|
|
alpha: false,
|
|
antialias: false,
|
|
premultipliedAlpha: false,
|
|
preserveDrawingBuffer: false,
|
|
depth: false,
|
|
stencil: false,
|
|
failIfMajorPerformanceCaveat: true
|
|
};
|
|
function setWebGLContext(webGLVersion, gl) {
|
|
contexts[webGLVersion] = gl;
|
|
}
|
|
function getWebGLContext(webGLVersion, customCanvas) {
|
|
if (!(webGLVersion in contexts) || customCanvas != null) {
|
|
const newCtx = getWebGLRenderingContext(webGLVersion, customCanvas);
|
|
if (newCtx !== null) {
|
|
contexts[webGLVersion] = newCtx;
|
|
}
|
|
else {
|
|
console.log('Could not get context for WebGL version', webGLVersion);
|
|
return null;
|
|
}
|
|
}
|
|
const gl = contexts[webGLVersion];
|
|
if (gl == null || gl.isContextLost()) {
|
|
delete contexts[webGLVersion];
|
|
return getWebGLContext(webGLVersion);
|
|
}
|
|
gl.disable(gl.DEPTH_TEST);
|
|
gl.disable(gl.STENCIL_TEST);
|
|
gl.disable(gl.BLEND);
|
|
gl.disable(gl.DITHER);
|
|
gl.disable(gl.POLYGON_OFFSET_FILL);
|
|
gl.disable(gl.SAMPLE_COVERAGE);
|
|
gl.enable(gl.SCISSOR_TEST);
|
|
gl.enable(gl.CULL_FACE);
|
|
gl.cullFace(gl.BACK);
|
|
return contexts[webGLVersion];
|
|
}
|
|
function createCanvas(webGLVersion) {
|
|
|
|
|
|
if (!env().getBool('IS_SAFARI') && typeof OffscreenCanvas !== 'undefined' &&
|
|
webGLVersion === 2) {
|
|
return new OffscreenCanvas(300, 150);
|
|
}
|
|
else if (typeof document !== 'undefined') {
|
|
return document.createElement('canvas');
|
|
}
|
|
else {
|
|
throw new Error('Cannot create a canvas in this context');
|
|
}
|
|
}
|
|
function getWebGLRenderingContext(webGLVersion, customCanvas) {
|
|
if (webGLVersion !== 1 && webGLVersion !== 2) {
|
|
throw new Error('Cannot get WebGL rendering context, WebGL is disabled.');
|
|
}
|
|
const canvas = customCanvas == null ? createCanvas(webGLVersion) : customCanvas;
|
|
canvas.addEventListener('webglcontextlost', (ev) => {
|
|
ev.preventDefault();
|
|
delete contexts[webGLVersion];
|
|
}, false);
|
|
if (env().getBool('SOFTWARE_WEBGL_ENABLED')) {
|
|
WEBGL_ATTRIBUTES.failIfMajorPerformanceCaveat = false;
|
|
}
|
|
if (webGLVersion === 1) {
|
|
return (
|
|
|
|
canvas.getContext('webgl', WEBGL_ATTRIBUTES) ||
canvas.getContext('experimental-webgl', WEBGL_ATTRIBUTES));
|
|
}
|
|
return canvas.getContext('webgl2', WEBGL_ATTRIBUTES);
|
|
}
|
|
|
|
|
|
var PackingScheme;
|
|
(function (PackingScheme) {
|
|
|
|
PackingScheme[PackingScheme["DENSE"] = 0] = "DENSE";
|
|
|
|
PackingScheme[PackingScheme["SHARED_BATCH"] = 1] = "SHARED_BATCH";
|
|
})(PackingScheme || (PackingScheme = {}));
|
|
var TextureUsage;
|
|
(function (TextureUsage) {
|
|
TextureUsage[TextureUsage["RENDER"] = 0] = "RENDER";
|
|
TextureUsage[TextureUsage["UPLOAD"] = 1] = "UPLOAD";
|
|
TextureUsage[TextureUsage["PIXELS"] = 2] = "PIXELS";
|
|
TextureUsage[TextureUsage["DOWNLOAD"] = 3] = "DOWNLOAD";
|
|
})(TextureUsage || (TextureUsage = {}));
|
|
var PhysicalTextureType;
|
|
(function (PhysicalTextureType) {
|
|
PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT16"] = 0] = "UNPACKED_FLOAT16";
|
|
PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT32"] = 1] = "UNPACKED_FLOAT32";
|
|
PhysicalTextureType[PhysicalTextureType["PACKED_4X1_UNSIGNED_BYTE"] = 2] = "PACKED_4X1_UNSIGNED_BYTE";
|
|
PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT32"] = 3] = "PACKED_2X2_FLOAT32";
|
|
PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT16"] = 4] = "PACKED_2X2_FLOAT16";
|
|
})(PhysicalTextureType || (PhysicalTextureType = {}));
|
|
function getUnpackedMatrixTextureShapeWidthHeight(rows, columns) {
|
|
return [columns, rows];
|
|
}
|
|
function getUnpackedArraySizeFromMatrixSize(matrixSize, channelsPerTexture) {
|
|
return matrixSize * channelsPerTexture;
|
|
}
|
|
|
|
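// Dense packing stores 4 floats per RGBA texel; the texture is the
// near-square shape that fits ceil(size / 4) texels. Illustrative example:
// a [10, 10] tensor has 100 values => 25 texels => a [5, 5] texture.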
function getDenseTexShape(shape) {
|
|
const size = sizeFromShape(shape);
|
|
const texelsNeeded = Math.ceil(size / 4);
|
|
return sizeToSquarishShape(texelsNeeded);
|
|
}
|
|
function getPackedMatrixTextureShapeWidthHeight(rows, columns) {
|
|
return [
|
|
Math.max(1, Math.ceil(columns / 2)), Math.max(1, Math.ceil(rows / 2))
|
|
];
|
|
}
|
|
function getPackedRGBAArraySizeFromMatrixShape(rows, columns) {
|
|
const [w, h] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
|
|
return w * h * 4;
|
|
}
|
|
function getTextureConfig(
|
|
|
|
gl, textureHalfFloatExtension) {
|
|
|
|
const glany = gl;
|
|
let internalFormatFloat;
|
|
let internalFormatHalfFloat;
|
|
let internalFormatPackedHalfFloat;
|
|
let internalFormatPackedFloat;
|
|
let textureFormatFloat;
|
|
let downloadTextureFormat;
|
|
let downloadUnpackNumChannels;
|
|
let defaultNumChannels;
|
|
let textureTypeHalfFloat;
|
|
let textureTypeFloat;
|
|
if (env().getNumber('WEBGL_VERSION') === 2) {
|
|
internalFormatFloat = glany.R32F;
|
|
internalFormatHalfFloat = glany.R16F;
|
|
internalFormatPackedHalfFloat = glany.RGBA16F;
|
|
internalFormatPackedFloat = glany.RGBA32F;
|
|
textureFormatFloat = glany.RED;
|
|
downloadUnpackNumChannels = 4;
|
|
defaultNumChannels = 1;
|
|
textureTypeHalfFloat = glany.HALF_FLOAT;
|
|
textureTypeFloat = glany.FLOAT;
|
|
downloadTextureFormat = glany.RGBA8;
|
|
}
|
|
else {
|
|
internalFormatFloat = gl.RGBA;
|
|
internalFormatHalfFloat = gl.RGBA;
|
|
internalFormatPackedHalfFloat = gl.RGBA;
|
|
internalFormatPackedFloat = glany.RGBA;
|
|
textureFormatFloat = gl.RGBA;
|
|
downloadUnpackNumChannels = 4;
|
|
defaultNumChannels = 4;
|
|
textureTypeHalfFloat = textureHalfFloatExtension != null ?
|
|
textureHalfFloatExtension.HALF_FLOAT_OES :
|
|
null;
|
|
textureTypeFloat = gl.FLOAT;
|
|
downloadTextureFormat = gl.RGBA;
|
|
}
|
|
return {
|
|
internalFormatFloat,
|
|
internalFormatHalfFloat,
|
|
internalFormatPackedHalfFloat,
|
|
internalFormatPackedFloat,
|
|
textureFormatFloat,
|
|
downloadTextureFormat,
|
|
downloadUnpackNumChannels,
|
|
defaultNumChannels,
|
|
textureTypeHalfFloat,
|
|
textureTypeFloat
|
|
};
|
|
}
|
|
|
|
|
|
function callAndCheck(gl, func) {
|
|
const returnValue = func();
|
|
if (env().getBool('DEBUG')) {
|
|
checkWebGLError(gl);
|
|
}
|
|
return returnValue;
|
|
}
|
|
function checkWebGLError(gl) {
|
|
const error = gl.getError();
|
|
if (error !== gl.NO_ERROR) {
|
|
throw new Error('WebGL Error: ' + getWebGLErrorMessage(gl, error));
|
|
}
|
|
}
|
|
|
|
const MIN_FLOAT16 = 5.96e-8;
|
|
const MAX_FLOAT16 = 65504;
|
|
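// Whether num survives a float16 round-trip on devices without float32
// render support: zero and magnitudes strictly inside
// (MIN_FLOAT16, MAX_FLOAT16) are representable; anything else would flush
// to zero or overflow, so it is rejected up front.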
function canBeRepresented(num) {
|
|
if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0 ||
|
|
(MIN_FLOAT16 < Math.abs(num) && Math.abs(num) < MAX_FLOAT16)) {
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
function getWebGLErrorMessage(gl, status) {
|
|
switch (status) {
|
|
case gl.NO_ERROR:
|
|
return 'NO_ERROR';
|
|
case gl.INVALID_ENUM:
|
|
return 'INVALID_ENUM';
|
|
case gl.INVALID_VALUE:
|
|
return 'INVALID_VALUE';
|
|
case gl.INVALID_OPERATION:
|
|
return 'INVALID_OPERATION';
|
|
case gl.INVALID_FRAMEBUFFER_OPERATION:
|
|
return 'INVALID_FRAMEBUFFER_OPERATION';
|
|
case gl.OUT_OF_MEMORY:
|
|
return 'OUT_OF_MEMORY';
|
|
case gl.CONTEXT_LOST_WEBGL:
|
|
return 'CONTEXT_LOST_WEBGL';
|
|
default:
|
|
return `Unknown error code ${status}`;
|
|
}
|
|
}
|
|
function getExtensionOrThrow(gl, extensionName) {
|
|
return throwIfNull(gl, () => gl.getExtension(extensionName), 'Extension "' + extensionName + '" not supported on this browser.');
|
|
}
|
|
function createVertexShader$1(gl, vertexShaderSource) {
|
|
const vertexShader = throwIfNull(gl, () => gl.createShader(gl.VERTEX_SHADER), 'Unable to create vertex WebGLShader.');
|
|
callAndCheck(gl, () => gl.shaderSource(vertexShader, vertexShaderSource));
|
|
callAndCheck(gl, () => gl.compileShader(vertexShader));
|
|
if (gl.getShaderParameter(vertexShader, gl.COMPILE_STATUS) === false) {
|
|
console.log(gl.getShaderInfoLog(vertexShader));
|
|
throw new Error('Failed to compile vertex shader.');
|
|
}
|
|
return vertexShader;
|
|
}
|
|
function createFragmentShader(gl, fragmentShaderSource) {
|
|
const fragmentShader = throwIfNull(gl, () => gl.createShader(gl.FRAGMENT_SHADER), 'Unable to create fragment WebGLShader.');
|
|
callAndCheck(gl, () => gl.shaderSource(fragmentShader, fragmentShaderSource));
|
|
callAndCheck(gl, () => gl.compileShader(fragmentShader));
|
|
if (env().get('ENGINE_COMPILE_ONLY')) {
|
|
return fragmentShader;
|
|
}
|
|
if (gl.getShaderParameter(fragmentShader, gl.COMPILE_STATUS) === false) {
|
|
logShaderSourceAndInfoLog(fragmentShaderSource, gl.getShaderInfoLog(fragmentShader));
|
|
throw new Error('Failed to compile fragment shader.');
|
|
}
|
|
return fragmentShader;
|
|
}
|
|
const lineNumberRegex = /ERROR: [0-9]+:([0-9]+):/g;
|
|
function logShaderSourceAndInfoLog(shaderSource, shaderInfoLog) {
|
|
const lineNumberRegexResult = lineNumberRegex.exec(shaderInfoLog);
|
|
if (lineNumberRegexResult == null) {
|
|
console.log(`Couldn't parse line number in error: ${shaderInfoLog}`);
|
|
console.log(shaderSource);
|
|
return;
|
|
}
|
|
const lineNumber = +lineNumberRegexResult[1];
|
|
const shaderLines = shaderSource.split('\n');
|
|
const pad = shaderLines.length.toString().length + 2;
|
|
const linesWithLineNumbers = shaderLines.map((line, i) => rightPad((i + 1).toString(), pad) + line);
|
|
let maxLineLength = 0;
|
|
for (let i = 0; i < linesWithLineNumbers.length; i++) {
|
|
maxLineLength = Math.max(linesWithLineNumbers[i].length, maxLineLength);
|
|
}
|
|
const beforeErrorLines = linesWithLineNumbers.slice(0, lineNumber - 1);
|
|
const errorLine = linesWithLineNumbers.slice(lineNumber - 1, lineNumber);
|
|
const afterErrorLines = linesWithLineNumbers.slice(lineNumber);
|
|
console.log(beforeErrorLines.join('\n'));
|
|
console.log(shaderInfoLog.split('\n')[0]);
|
|
console.log(`%c ${rightPad(errorLine[0], maxLineLength)}`, 'border:1px solid red; background-color:#e3d2d2; color:#a61717');
|
|
console.log(afterErrorLines.join('\n'));
|
|
}
|
|
function createProgram(gl) {
|
|
return throwIfNull(gl, () => gl.createProgram(), 'Unable to create WebGLProgram.');
|
|
}
|
|
function linkProgram(gl, program) {
|
|
callAndCheck(gl, () => gl.linkProgram(program));
|
|
if (env().get('ENGINE_COMPILE_ONLY')) {
|
|
return;
|
|
}
|
|
if (gl.getProgramParameter(program, gl.LINK_STATUS) === false) {
|
|
console.log(gl.getProgramInfoLog(program));
|
|
throw new Error('Failed to link vertex and fragment shaders.');
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
function validateProgram(gl, program) {
|
|
callAndCheck(gl, () => gl.validateProgram(program));
|
|
if (gl.getProgramParameter(program, gl.VALIDATE_STATUS) === false) {
|
|
console.log(gl.getProgramInfoLog(program));
|
|
throw new Error('Shader program validation failed.');
|
|
}
|
|
}
|
|
function createStaticVertexBuffer(gl, data) {
|
|
const buffer = throwIfNull(gl, () => gl.createBuffer(), 'Unable to create WebGLBuffer');
|
|
callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer));
|
|
callAndCheck(gl, () => gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW));
|
|
return buffer;
|
|
}
|
|
function createStaticIndexBuffer(gl, data) {
|
|
const buffer = throwIfNull(gl, () => gl.createBuffer(), 'Unable to create WebGLBuffer');
|
|
callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, buffer));
|
|
callAndCheck(gl, () => gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW));
|
|
return buffer;
|
|
}
|
|
function createTexture(gl) {
    return throwIfNull(gl, () => gl.createTexture(), 'Unable to create WebGLTexture.');
}
function validateTextureSize(width, height) {
    const maxTextureSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
    if ((width <= 0) || (height <= 0)) {
        const requested = `[${width}x${height}]`;
        throw new Error('Requested texture size ' + requested + ' is invalid.');
    }
    if ((width > maxTextureSize) || (height > maxTextureSize)) {
        const requested = `[${width}x${height}]`;
        const max = `[${maxTextureSize}x${maxTextureSize}]`;
        throw new Error('Requested texture size ' + requested +
            ' greater than WebGL maximum on this browser / GPU ' + max + '.');
    }
}
function createFramebuffer(gl) {
    return throwIfNull(gl, () => gl.createFramebuffer(), 'Unable to create WebGLFramebuffer.');
}
function bindVertexBufferToProgramAttribute(gl, program, attribute, buffer, arrayEntriesPerItem, itemStrideInBytes, itemOffsetInBytes) {
    const loc = gl.getAttribLocation(program, attribute);
    if (loc === -1) {
        return false;
    }
    callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer));
    callAndCheck(gl, () => gl.vertexAttribPointer(loc, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes, itemOffsetInBytes));
    callAndCheck(gl, () => gl.enableVertexAttribArray(loc));
    return true;
}
function bindTextureUnit(gl, texture, textureUnit) {
    validateTextureUnit(gl, textureUnit);
    callAndCheck(gl, () => gl.activeTexture(gl.TEXTURE0 + textureUnit));
    callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
}
function getProgramUniformLocationOrThrow(gl, program, uniformName) {
    return throwIfNull(gl, () => gl.getUniformLocation(program, uniformName), 'uniform "' + uniformName + '" not present in program.');
}
function getProgramUniformLocation(gl, program, uniformName) {
    return gl.getUniformLocation(program, uniformName);
}
function bindTextureToProgramUniformSampler(gl, texture, uniformSamplerLocation, textureUnit) {
    callAndCheck(gl, () => bindTextureUnit(gl, texture, textureUnit));
    callAndCheck(gl, () => gl.uniform1i(uniformSamplerLocation, textureUnit));
}
function bindColorTextureToFramebuffer(gl, texture, framebuffer) {
    callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer));
    callAndCheck(gl, () => gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0));
}
function unbindColorTextureFromFramebuffer(gl, framebuffer) {
    callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer));
    callAndCheck(gl, () => gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0));
}
function validateFramebuffer(gl) {
    const status = gl.checkFramebufferStatus(gl.FRAMEBUFFER);
    if (status !== gl.FRAMEBUFFER_COMPLETE) {
        throw new Error('Error binding framebuffer: ' + getFramebufferErrorMessage(gl, status));
    }
}
function getFramebufferErrorMessage(gl, status) {
    switch (status) {
        case gl.FRAMEBUFFER_INCOMPLETE_ATTACHMENT:
            return 'FRAMEBUFFER_INCOMPLETE_ATTACHMENT';
        case gl.FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT:
            return 'FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT';
        case gl.FRAMEBUFFER_INCOMPLETE_DIMENSIONS:
            return 'FRAMEBUFFER_INCOMPLETE_DIMENSIONS';
        case gl.FRAMEBUFFER_UNSUPPORTED:
            return 'FRAMEBUFFER_UNSUPPORTED';
        default:
            return `unknown error ${status}`;
    }
}
function throwIfNull(gl, returnTOrNull, failureMessage) {
    const tOrNull = callAndCheck(gl, () => returnTOrNull());
    if (tOrNull == null) {
        throw new Error(failureMessage);
    }
    return tOrNull;
}
function validateTextureUnit(gl, textureUnit) {
    const maxTextureUnit = gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS - 1;
    const glTextureUnit = textureUnit + gl.TEXTURE0;
    if (glTextureUnit < gl.TEXTURE0 || glTextureUnit > maxTextureUnit) {
        const textureUnitRange = `[gl.TEXTURE0, gl.TEXTURE${maxTextureUnit}]`;
        throw new Error(`textureUnit must be in ${textureUnitRange}.`);
    }
}
function getBatchDim(shape, dimsToSkip = 2) {
    return sizeFromShape(shape.slice(0, shape.length - dimsToSkip));
}
function getRowsCols(shape) {
    if (shape.length === 0) {
        throw Error('Cannot get rows and columns of an empty shape array.');
    }
    return [
        shape.length > 1 ? shape[shape.length - 2] : 1, shape[shape.length - 1]
    ];
}
function getShapeAs3D(shape) {
    let shapeAs3D = [1, 1, 1];
    const isScalar = shape.length === 0 || (shape.length === 1 && shape[0] === 1);
    if (!isScalar) {
        shapeAs3D =
            [getBatchDim(shape), ...getRowsCols(shape)];
    }
    return shapeAs3D;
}
function getTextureShapeFromLogicalShape(logShape, isPacked = false) {
    let maxTexSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
    let maxSizeForNarrowTex = env().getNumber('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE');
    if (maxSizeForNarrowTex === Infinity &&
        env().getBool('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE')) {
        maxSizeForNarrowTex = maxTexSize / 2;
    }
    if (isPacked) {
        maxTexSize = maxTexSize * 2;
        maxSizeForNarrowTex = maxSizeForNarrowTex * 2;
        // Pad the last two dims up to even numbers: a packed texel stores a
        // 2x2 block of values, so odd trailing dims cannot tile exactly.
        logShape = logShape.map((d, i) => i >= logShape.length - 2 ?
            nearestLargerEven(logShape[i]) :
            logShape[i]);
        if (logShape.length === 1) {
            logShape = [2, logShape[0]];
        }
    }
    if (logShape.length !== 2) {
        const squeezeResult = squeezeShape(logShape);
        logShape = squeezeResult.newShape;
    }
    let size = sizeFromShape(logShape);
    let textureShape = null;
    if (logShape.length <= 1 && size <= maxTexSize) {
        textureShape = [1, size];
    }
    else if (logShape.length === 2 && logShape[0] <= maxTexSize &&
        logShape[1] <= maxTexSize) {
        textureShape = logShape;
    }
    else if (logShape.length === 3 && logShape[0] * logShape[1] <= maxTexSize &&
        logShape[2] <= maxTexSize) {
        textureShape = [logShape[0] * logShape[1], logShape[2]];
    }
    else if (logShape.length === 3 && logShape[0] <= maxTexSize &&
        logShape[1] * logShape[2] <= maxTexSize) {
        textureShape = [logShape[0], logShape[1] * logShape[2]];
    }
    else if (logShape.length === 4 &&
        logShape[0] * logShape[1] * logShape[2] <= maxTexSize &&
        logShape[3] <= maxTexSize) {
        textureShape = [logShape[0] * logShape[1] * logShape[2], logShape[3]];
    }
    else if (logShape.length === 4 && logShape[0] <= maxTexSize &&
        logShape[1] * logShape[2] * logShape[3] <= maxTexSize) {
        textureShape = [logShape[0], logShape[1] * logShape[2] * logShape[3]];
    }
    const isLongNarrowTex = textureShape != null &&
        Math.max(...textureShape) > maxSizeForNarrowTex &&
        Math.min(...textureShape) <= (isPacked ? 2 : 1) &&
        Math.min(...textureShape) > 0;
    if (textureShape == null || isLongNarrowTex) {
        if (isPacked) {
            // For packed textures the size is counted in texels (2x2 blocks),
            // so halve rows and cols before picking a squarish shape.
            const batchDim = getBatchDim(logShape);
            let rows = 2, cols = 2;
            if (logShape.length) {
                [rows, cols] = getRowsCols(logShape);
            }
            size = batchDim * (rows / 2) * (cols / 2);
            textureShape =
                sizeToSquarishShape(size).map(d => d * 2);
        }
        else {
            textureShape = sizeToSquarishShape(size);
        }
    }
    return textureShape;
}
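// Illustrative note (not part of the upstream bundle): for a logical shape
// [2, 3, 5] on an unpacked texture, the rank-3 branch above collapses batch
// and rows into texture rows, yielding a [6, 5] texture, provided both dims
// stay under WEBGL_MAX_TEXTURE_SIZE; otherwise sizeToSquarishShape() picks a
// near-square fallback from the total element count.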
function isEven(n) {
    return n % 2 === 0;
}
function isReshapeFree(shape1, shape2) {
    shape1 = shape1.slice(-2);
    shape2 = shape2.slice(-2);
    if (arraysEqual(shape1, shape2)) {
        return true;
    }
    if (!shape1.length || !shape2.length) {
        return true;
    }
    if (shape1[0] === 0 || shape1[1] === 0 || shape2[0] === 0 ||
        shape2[1] === 0) {
        return true;
    }
    if (shape1.length !== shape2.length) {
        const shape1Cols = shape1[shape1.length - 1];
        const shape2Cols = shape2[shape2.length - 1];
        if (shape1Cols === shape2Cols) {
            return true;
        }
        if (isEven(shape1Cols) && isEven(shape2Cols) &&
            (shape1[0] === 1 || shape2[0] === 1)) {
            return true;
        }
    }
    return shape1[1] === shape2[1] && isEven(shape1[0]) && isEven(shape2[0]);
}
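// Illustrative note: a reshape between packed textures is "free" when the
// 2x2 texel tiling of the last two dims survives, e.g. reshaping [4, 6] to
// [2, 2, 6] keeps the column count (6) and even row counts, so the final
// check above returns true and the texels are reused without a repacking
// pass.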
let MAX_TEXTURE_SIZE;
let MAX_TEXTURES_IN_SHADER;
function getWebGLMaxTextureSize(webGLVersion) {
    if (MAX_TEXTURE_SIZE == null) {
        const gl = getWebGLContext(webGLVersion);
        MAX_TEXTURE_SIZE = gl.getParameter(gl.MAX_TEXTURE_SIZE);
    }
    return MAX_TEXTURE_SIZE;
}
function getMaxTexturesInShader(webGLVersion) {
    if (MAX_TEXTURES_IN_SHADER == null) {
        const gl = getWebGLContext(webGLVersion);
        MAX_TEXTURES_IN_SHADER = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS);
    }
    return Math.min(16, MAX_TEXTURES_IN_SHADER);
}
function getWebGLDisjointQueryTimerVersion(webGLVersion) {
    if (webGLVersion === 0) {
        return 0;
    }
    let queryTimerVersion;
    const gl = getWebGLContext(webGLVersion);
    if (hasExtension(gl, 'EXT_disjoint_timer_query_webgl2') &&
        webGLVersion === 2) {
        queryTimerVersion = 2;
    }
    else if (hasExtension(gl, 'EXT_disjoint_timer_query')) {
        queryTimerVersion = 1;
    }
    else {
        queryTimerVersion = 0;
    }
    return queryTimerVersion;
}
function hasExtension(gl, extensionName) {
    const ext = gl.getExtension(extensionName);
    return ext != null;
}
function isWebGLVersionEnabled(webGLVersion) {
    try {
        const gl = getWebGLContext(webGLVersion);
        if (gl != null) {
            return true;
        }
    }
    catch (e) {
        console.log('Error when getting WebGL context: ', e);
        return false;
    }
    return false;
}
function isCapableOfRenderingToFloatTexture(webGLVersion) {
    if (webGLVersion === 0) {
        return false;
    }
    const gl = getWebGLContext(webGLVersion);
    if (webGLVersion === 1) {
        if (!hasExtension(gl, 'OES_texture_float')) {
            return false;
        }
    }
    else {
        if (!hasExtension(gl, 'EXT_color_buffer_float')) {
            return false;
        }
    }
    const isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);
    return isFrameBufferComplete;
}
function isDownloadFloatTextureEnabled(webGLVersion) {
    if (webGLVersion === 0) {
        return false;
    }
    const gl = getWebGLContext(webGLVersion);
    if (webGLVersion === 1) {
        if (!hasExtension(gl, 'OES_texture_float')) {
            return false;
        }
        if (!hasExtension(gl, 'WEBGL_color_buffer_float')) {
            return false;
        }
    }
    else {
        if (hasExtension(gl, 'EXT_color_buffer_float')) {
            return createFloatTextureAndBindToFramebuffer(gl);
        }
        const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
        if (hasExtension(gl, COLOR_BUFFER_HALF_FLOAT)) {
            const textureHalfFloatExtension = gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
            return createHalfFloatTextureAndBindToFramebuffer(gl, textureHalfFloatExtension);
        }
        return false;
    }
    const isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl);
    return isFrameBufferComplete;
}
function createFloatTextureAndBindToFramebuffer(gl) {
    const texConfig = getTextureConfig(gl);
    const texture = gl.createTexture();
    gl.bindTexture(gl.TEXTURE_2D, texture);
    const width = 1;
    const height = 1;
    gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeFloat, null);
    const frameBuffer = gl.createFramebuffer();
    gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
    gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
    const isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
    gl.bindTexture(gl.TEXTURE_2D, null);
    gl.bindFramebuffer(gl.FRAMEBUFFER, null);
    gl.deleteTexture(texture);
    gl.deleteFramebuffer(frameBuffer);
    return isFrameBufferComplete;
}
function createHalfFloatTextureAndBindToFramebuffer(gl, textureHalfFloatExtension) {
    const texConfig = getTextureConfig(gl, textureHalfFloatExtension);
    const texture = gl.createTexture();
    gl.bindTexture(gl.TEXTURE_2D, texture);
    const width = 1;
    const height = 1;
    gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatHalfFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeHalfFloat, null);
    const frameBuffer = gl.createFramebuffer();
    gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
    gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
    const isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
    gl.bindTexture(gl.TEXTURE_2D, null);
    gl.bindFramebuffer(gl.FRAMEBUFFER, null);
    gl.deleteTexture(texture);
    gl.deleteFramebuffer(frameBuffer);
    return isFrameBufferComplete;
}
function isWebGLFenceEnabled(webGLVersion) {
    if (webGLVersion !== 2) {
        return false;
    }
    const gl = getWebGLContext(webGLVersion);
    const isEnabled = gl.fenceSync != null;
    return isEnabled;
}
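// These probes follow the usual WebGL feature-detection pattern: allocate a
// 1x1 texture of the candidate float format, attach it to a framebuffer, and
// test for FRAMEBUFFER_COMPLETE. A sketch of how they are consumed (results
// vary by browser / GPU):
//   isCapableOfRenderingToFloatTexture(2); // feeds WEBGL_RENDER_FLOAT32_CAPABLE
//   isDownloadFloatTextureEnabled(2);      // feeds WEBGL_DOWNLOAD_FLOAT_ENABLED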
function assertNotComplex$1(tensor, opName) {
    if (!Array.isArray(tensor)) {
        tensor = [tensor];
    }
    tensor.forEach(t => {
        if (t != null) {
            assert$1(t.dtype !== 'complex64', () => `${opName} does not support complex64 tensors ` +
                'in the WebGL backend.');
        }
    });
}
const ENV = env();
ENV.registerFlag('HAS_WEBGL', () => ENV.getNumber('WEBGL_VERSION') > 0);
ENV.registerFlag('WEBGL_VERSION', () => {
    if (isWebGLVersionEnabled(2)) {
        return 2;
    }
    else if (isWebGLVersionEnabled(1)) {
        return 1;
    }
    return 0;
});
ENV.registerFlag('WEBGL_CHECK_NUMERICAL_PROBLEMS', () => false);
ENV.registerFlag('WEBGL_BUFFER_SUPPORTED', () => ENV.get('WEBGL_VERSION') === 2);
ENV.registerFlag('WEBGL_CPU_FORWARD', () => true);
ENV.registerFlag('WEBGL_FORCE_F16_TEXTURES', () => false);
ENV.registerFlag('WEBGL_PACK', () => ENV.getBool('HAS_WEBGL'));
ENV.registerFlag('WEBGL_PACK_NORMALIZATION', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_PACK_CLIP', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_PACK_DEPTHWISECONV', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_PACK_BINARY_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_PACK_UNARY_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_PACK_ARRAY_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_PACK_IMAGE_OPERATIONS', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_PACK_REDUCE', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_LAZILY_UNPACK', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_CONV_IM2COL', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_PACK_CONV2DTRANSPOSE', () => ENV.getBool('WEBGL_PACK'));
ENV.registerFlag('WEBGL_MAX_TEXTURE_SIZE', () => getWebGLMaxTextureSize(ENV.getNumber('WEBGL_VERSION')));
ENV.registerFlag('WEBGL_MAX_TEXTURES_IN_SHADER', () => getMaxTexturesInShader(ENV.getNumber('WEBGL_VERSION')));
ENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', () => {
    const webGLVersion = ENV.getNumber('WEBGL_VERSION');
    if (webGLVersion === 0) {
        return 0;
    }
    return getWebGLDisjointQueryTimerVersion(webGLVersion);
});
ENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', () => ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0 &&
    !isMobile());
ENV.registerFlag('WEBGL_RENDER_FLOAT32_CAPABLE', () => isCapableOfRenderingToFloatTexture(ENV.getNumber('WEBGL_VERSION')));
ENV.registerFlag('WEBGL_RENDER_FLOAT32_ENABLED', () => {
    return ENV.getBool('WEBGL_FORCE_F16_TEXTURES') ?
        false :
        ENV.getBool('WEBGL_RENDER_FLOAT32_CAPABLE');
});
ENV.registerFlag('WEBGL_DOWNLOAD_FLOAT_ENABLED', () => isDownloadFloatTextureEnabled(ENV.getNumber('WEBGL_VERSION')));
ENV.registerFlag('WEBGL_FENCE_API_ENABLED', () => isWebGLFenceEnabled(ENV.getNumber('WEBGL_VERSION')));
ENV.registerFlag('WEBGL_SIZE_UPLOAD_UNIFORM', () => {
    const useUniforms = ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED');
    return useUniforms ? 4 : 0;
});
ENV.registerFlag('WEBGL_DELETE_TEXTURE_THRESHOLD', () => {
    return -1;
}, threshold => {
    if (!(typeof threshold === 'number')) {
        throw new Error('WEBGL_DELETE_TEXTURE_THRESHOLD must be a number but ' +
            `got ${threshold}.`);
    }
    if (threshold < 0 && threshold !== -1) {
        throw new Error(`WEBGL_DELETE_TEXTURE_THRESHOLD must be -1 (indicating never ` +
            `delete) or at least 0, but got ${threshold}.`);
    }
});
ENV.registerFlag('WEBGL_FLUSH_THRESHOLD', () => {
    return isMobile() ? 1 : -1;
}, threshold => {
    if (!(typeof threshold === 'number')) {
        throw new Error('WEBGL_FLUSH_THRESHOLD must be a number but got ' +
            `${threshold}.`);
    }
    if (threshold < 0 && threshold !== -1) {
        throw new Error(`WEBGL_FLUSH_THRESHOLD must be -1 (indicating never ` +
            `manual flush) or at least 0, but got ${threshold}.`);
    }
});
ENV.registerFlag('CPU_HANDOFF_SIZE_THRESHOLD', () => 128);
ENV.registerFlag('WEBGL_USE_SHAPES_UNIFORMS', () => false);
ENV.registerFlag('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD', () => 100000);
ENV.registerFlag('TOPK_K_CPU_HANDOFF_THRESHOLD', () => 128);
ENV.registerFlag('WEBGL_EXP_CONV', () => false);
ENV.registerFlag('SOFTWARE_WEBGL_ENABLED', () => ENV.getBool('IS_TEST'));
ENV.registerFlag('WEBGL_MAX_SIZE_FOR_NARROW_TEXTURE', () => Infinity);
ENV.registerFlag('WEBGL_AUTO_SQUARIFY_NARROW_TEXTURE_SHAPE', () => false);
ENV.registerFlag('WEBGL2_ISNAN_CUSTOM', () => false);
ENV.registerFlag('ENGINE_COMPILE_ONLY', () => false);
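// Each flag above registers a lazy default evaluator; callers can override a
// value at runtime through the environment, e.g. (a sketch using the env()
// accessor from this bundle):
//   env().set('WEBGL_PACK', false);                 // disable packed textures
//   env().set('WEBGL_DELETE_TEXTURE_THRESHOLD', 0); // delete textures eagerly
// The optional second argument to registerFlag (see the *_THRESHOLD flags)
// validates user-supplied values before they take effect.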
function getGlslDifferences() {
    let version;
    let attribute;
    let varyingVs;
    let varyingFs;
    let texture2D;
    let output;
    let defineOutput;
    let defineSpecialNaN;
    let defineSpecialInf;
    let defineRound;
    if (env().getNumber('WEBGL_VERSION') === 2) {
        version = '#version 300 es';
        attribute = 'in';
        varyingVs = 'out';
        varyingFs = 'in';
        texture2D = 'texture';
        output = 'outputColor';
        defineOutput = 'out vec4 outputColor;';
        defineSpecialNaN = env().getBool('WEBGL2_ISNAN_CUSTOM') ? `
      bool isnan_custom(float val) {
        uint floatToUint = floatBitsToUint(val);
        return (floatToUint & 0x7fffffffu) > 0x7f800000u;
      }

      bvec4 isnan_custom(vec4 val) {
        return bvec4(isnan_custom(val.x),
          isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w));
      }

      #define isnan(value) isnan_custom(value)
    ` :
            '';
        defineSpecialInf = ``;
        defineRound = `
      #define round(value) newRound(value)
      int newRound(float value) {
        return int(floor(value + 0.5));
      }

      ivec4 newRound(vec4 value) {
        return ivec4(floor(value + vec4(0.5)));
      }
    `;
    }
    else {
        version = '';
        attribute = 'attribute';
        varyingVs = 'varying';
        varyingFs = 'varying';
        texture2D = 'texture2D';
        output = 'gl_FragColor';
        defineOutput = '';
        defineSpecialNaN = `
      #define isnan(value) isnan_custom(value)
      bool isnan_custom(float val) {
        return (val > 0. || val < 1. || val == 0.) ? false : true;
      }
      bvec4 isnan_custom(vec4 val) {
        return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w));
      }
    `;
        defineSpecialInf = `
      uniform float INFINITY;

      bool isinf(float val) {
        return abs(val) == INFINITY;
      }
      bvec4 isinf(vec4 val) {
        return equal(abs(val), vec4(INFINITY));
      }
    `;
        defineRound = `
      int round(float value) {
        return int(floor(value + 0.5));
      }

      ivec4 round(vec4 value) {
        return ivec4(floor(value + vec4(0.5)));
      }
    `;
    }
    return {
        version,
        attribute,
        varyingVs,
        varyingFs,
        texture2D,
        output,
        defineOutput,
        defineSpecialNaN,
        defineSpecialInf,
        defineRound
    };
}
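// A sketch of the divergence this helper hides: the same fragment shader is
// emitted once with WebGL2 GLSL ES 3.00 syntax,
//   #version 300 es ... in vec2 resultUV; out vec4 outputColor;
// and once with WebGL1 GLSL ES 1.00 syntax,
//   varying vec2 resultUV; ... gl_FragColor = ...;
// so the shader-generating code elsewhere in the backend stays
// version-agnostic.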
function getLogicalCoordinatesFromFlatIndex(coords, shape, index = 'index') {
    const strides = computeStrides(shape);
    return strides
        .map((stride, i) => {
        const line1 = `int ${coords[i]} = ${index} / ${stride}`;
        const line2 = i === strides.length - 1 ?
            `int ${coords[i + 1]} = ${index} - ${coords[i]} * ${stride}` :
            `index -= ${coords[i]} * ${stride}`;
        return `${line1}; ${line2};`;
    })
        .join('');
}
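// Example of the generated GLSL (a sketch): for shape [4, 3, 2] the strides
// are [6, 2], so getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], [4, 3, 2])
// emits
//   int r = index / 6; index -= r * 6; int c = index / 2; int d = index - c * 2;
// i.e. the flat index is peeled into per-dimension coordinates by repeated
// integer division against the strides.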
function getOutputLogicalCoordinatesFromFlatIndexByUniform(coords, shape, index = 'index') {
    const strides = computeStrides(shape);
    return strides
        .map((_, i) => {
        const line1 = `int ${coords[i]} = ${index} / outShapeStrides[${i}]`;
        const line2 = i === strides.length - 1 ?
            `int ${coords[i + 1]} = ${index} - ${coords[i]} * outShapeStrides[${i}]` :
            `index -= ${coords[i]} * outShapeStrides[${i}]`;
        return `${line1}; ${line2};`;
    })
        .join('');
}
function symbolicallyComputeStrides(indicesArr, variableName) {
    const numCoords = indicesArr.length;
    const shape = indicesArr.map(d => `${variableName}[${d}]`);
    const strides = new Array(numCoords - 1);
    strides[numCoords - 2] = shape[numCoords - 1];
    for (let i = numCoords - 3; i >= 0; --i) {
        strides[i] = `(${strides[i + 1]} * ${shape[i + 1]})`;
    }
    return strides;
}
function getLogicalCoordinatesFromFlatIndexByUniform(coords, variableName, index = 'index') {
    const indicesArray = coords.map((_, i) => i);
    const strides = symbolicallyComputeStrides(indicesArray, variableName);
    return strides
        .map((_, i) => {
        const line1 = `int ${coords[i]} = ${index} / ${strides[i]}`;
        const line2 = i === strides.length - 1 ?
            `int ${coords[i + 1]} = ${index} - ${coords[i]} * ${strides[i]}` :
            `index -= ${coords[i]} * ${strides[i]}`;
        return `${line1}; ${line2};`;
    })
        .join('');
}
function getFlatIndexFrom3D(shape) {
    const strides = computeStrides(shape).map(d => d.toString());
    return `
  int getFlatIndex(ivec3 coords) {
    return coords.x * ${strides[0]} + coords.y * ${strides[1]} + coords.z;
  }
`;
}
function getFlatIndexFrom3DOutput() {
    return `
  int getFlatIndex(ivec3 coords) {
    return coords.x * outShapeStrides[0] + coords.y * outShapeStrides[1] + coords.z;
  }
`;
}
const ENCODE_FLOAT_SNIPPET = `
  const float FLOAT_MAX = 1.70141184e38;
  const float FLOAT_MIN = 1.17549435e-38;

  lowp vec4 encode_float(highp float v) {
    if (isnan(v)) {
      return vec4(255, 255, 255, 255);
    }

    highp float av = abs(v);

    if(av < FLOAT_MIN) {
      return vec4(0.0, 0.0, 0.0, 0.0);
    } else if(v > FLOAT_MAX) {
      return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;
    } else if(v < -FLOAT_MAX) {
      return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;
    }

    highp vec4 c = vec4(0,0,0,0);

    highp float e = floor(log2(av));
    highp float m = exp2(fract(log2(av))) - 1.0;

    c[2] = floor(128.0 * m);
    m -= c[2] / 128.0;
    c[1] = floor(32768.0 * m);
    m -= c[1] / 32768.0;
    c[0] = floor(8388608.0 * m);

    highp float ebias = e + 127.0;
    c[3] = floor(ebias / 2.0);
    ebias -= c[3] * 2.0;
    c[2] += floor(ebias) * 128.0;

    c[3] += 128.0 * step(0.0, -v);

    return c / 255.0;
  }
`;
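// encode_float re-implements IEEE-754 binary32 packing in GLSL so a float
// result can be read back through a plain RGBA/UNSIGNED_BYTE readPixels call
// when float downloads are unsupported: the mantissa is spread over the
// r/g/b bytes and the biased exponent plus sign over b/a. For instance,
// encode_float(1.0) produces the bytes (0, 0, 128, 63), the little-endian
// layout of 0x3f800000.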
const { getBroadcastDims } = backend_util;
function makeShader(inputsInfo, outputShape, program) {
    const prefixSnippets = [];
    inputsInfo.forEach(x => {
        const size = sizeFromShape(x.shapeInfo.logicalShape);
        if (x.shapeInfo.isUniform) {
            prefixSnippets.push(`uniform float ${x.name}${size > 1 ? `[${size}]` : ''};`);
        }
        else {
            prefixSnippets.push(`uniform sampler2D ${x.name};`);
            prefixSnippets.push(`uniform int offset${x.name};`);
        }
        if (program.enableShapeUniforms) {
            const { uniformShape } = getUniformInfoFromShape(program.packedInputs, x.shapeInfo.logicalShape, x.shapeInfo.texShape);
            switch (uniformShape.length) {
                case 1:
                    prefixSnippets.push(`uniform int ${x.name}Shape;`);
                    break;
                case 2:
                    prefixSnippets.push(`uniform ivec2 ${x.name}Shape;`);
                    break;
                case 3:
                    prefixSnippets.push(`uniform ivec3 ${x.name}Shape;`);
                    break;
                case 4:
                    prefixSnippets.push(`uniform ivec4 ${x.name}Shape;`);
                    break;
            }
            prefixSnippets.push(`uniform ivec2 ${x.name}TexShape;`);
        }
    });
    if (program.enableShapeUniforms) {
        switch (outputShape.logicalShape.length) {
            case 1:
                prefixSnippets.push(`uniform int outShape;`);
                break;
            case 2:
                prefixSnippets.push(`uniform ivec2 outShape;`);
                prefixSnippets.push(`uniform int outShapeStrides;`);
                break;
            case 3:
                prefixSnippets.push(`uniform ivec3 outShape;`);
                prefixSnippets.push(`uniform ivec2 outShapeStrides;`);
                break;
            case 4:
                prefixSnippets.push(`uniform ivec4 outShape;`);
                prefixSnippets.push(`uniform ivec3 outShapeStrides;`);
                break;
        }
        prefixSnippets.push(`uniform ivec2 outTexShape;`);
    }
    if (program.customUniforms) {
        program.customUniforms.forEach((d) => {
            prefixSnippets.push(`uniform ${d.type} ${d.name}${d.arrayIndex ? `[${d.arrayIndex}]` : ''};`);
        });
    }
    const inputPrefixSnippet = prefixSnippets.join('\n');
    const inputSamplingSnippet = inputsInfo
        .map(x => getInputSamplingSnippet(x, outputShape, program.packedInputs, program.enableShapeUniforms))
        .join('\n');
    const outTexShape = outputShape.texShape;
    const glsl = getGlslDifferences();
    const floatTextureSampleSnippet = getFloatTextureSampleSnippet(glsl);
    let outputSamplingSnippet;
    let floatTextureSetOutputSnippet;
    let shaderPrefix = getShaderPrefix(glsl);
    if (outputShape.isPacked) {
        outputSamplingSnippet = getPackedOutputSamplingSnippet(outputShape.logicalShape, outTexShape, program.enableShapeUniforms);
        floatTextureSetOutputSnippet = getFloatTextureSetRGBASnippet(glsl);
    }
    else {
        outputSamplingSnippet = getOutputSamplingSnippet(outputShape.logicalShape, outTexShape, program.enableShapeUniforms);
        floatTextureSetOutputSnippet = getFloatTextureSetRSnippet(glsl);
    }
    if (program.packedInputs) {
        shaderPrefix += SHADER_PACKED_PREFIX;
    }
    const source = [
        shaderPrefix, floatTextureSampleSnippet, floatTextureSetOutputSnippet,
        inputPrefixSnippet, outputSamplingSnippet, inputSamplingSnippet,
        program.userCode
    ].join('\n');
    return source;
}
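// Summary of the concatenation above: a generated fragment shader is layered
// as [GLSL version prefix + helpers] -> [sampleTexture/setOutput] ->
// [per-input uniforms] -> [output coordinate decoding] -> [per-input
// samplers] -> program.userCode, where userCode is expected to define main()
// and write its result through setOutput(...).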
function getSamplerFromInInfo(inInfo, enableShapeUniforms = false) {
    const shape = inInfo.shapeInfo.logicalShape;
    switch (shape.length) {
        case 0:
            return getSamplerScalar(inInfo, enableShapeUniforms);
        case 1:
            return getSampler1D(inInfo, enableShapeUniforms);
        case 2:
            return getSampler2D(inInfo, enableShapeUniforms);
        case 3:
            return getSampler3D(inInfo, enableShapeUniforms);
        case 4:
            return getSampler4D(inInfo, enableShapeUniforms);
        case 5:
            return getSampler5D(inInfo);
        case 6:
            return getSampler6D(inInfo);
        default:
            throw new Error(`${shape.length}-D input sampling` +
                ` is not yet supported`);
    }
}
function getPackedSamplerFromInInfo(inInfo, enableShapeUniforms) {
    const shape = inInfo.shapeInfo.logicalShape;
    switch (shape.length) {
        case 0:
            return getPackedSamplerScalar(inInfo);
        case 1:
            return getPackedSampler1D(inInfo, enableShapeUniforms);
        case 2:
            return getPackedSampler2D(inInfo, enableShapeUniforms);
        case 3:
            return getPackedSampler3D(inInfo, enableShapeUniforms);
        default:
            return getPackedSamplerND(inInfo, enableShapeUniforms);
    }
}
function getInputSamplingSnippet(inInfo, outShapeInfo, usesPackedTextures = false, enableShapeUniforms) {
    let res = '';
    if (usesPackedTextures) {
        res += getPackedSamplerFromInInfo(inInfo, enableShapeUniforms);
    }
    else {
        res += getSamplerFromInInfo(inInfo, enableShapeUniforms);
    }
    const inShape = inInfo.shapeInfo.logicalShape;
    const outShape = outShapeInfo.logicalShape;
    if (inShape.length <= outShape.length) {
        if (usesPackedTextures) {
            res += getPackedSamplerAtOutputCoords(inInfo, outShapeInfo);
        }
        else {
            res += getSamplerAtOutputCoords(inInfo, outShapeInfo);
        }
    }
    return res;
}
function getPackedOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) {
    switch (outShape.length) {
        case 0:
            return getOutputScalarCoords();
        case 1:
            return getOutputPacked1DCoords(outShape, outTexShape, enableShapeUniforms);
        case 2:
            return getOutputPacked2DCoords(outShape, outTexShape, enableShapeUniforms);
        case 3:
            return getOutputPacked3DCoords(outShape, outTexShape, enableShapeUniforms);
        default:
            return getOutputPackedNDCoords(outShape, outTexShape, enableShapeUniforms);
    }
}
function getOutputSamplingSnippet(outShape, outTexShape, enableShapeUniforms) {
    switch (outShape.length) {
        case 0:
            return getOutputScalarCoords();
        case 1:
            return getOutput1DCoords(outShape, outTexShape, enableShapeUniforms);
        case 2:
            return getOutput2DCoords(outShape, outTexShape, enableShapeUniforms);
        case 3:
            return getOutput3DCoords(outShape, outTexShape, enableShapeUniforms);
        case 4:
            return getOutput4DCoords(outShape, outTexShape, enableShapeUniforms);
        case 5:
            return getOutput5DCoords(outShape, outTexShape);
        case 6:
            return getOutput6DCoords(outShape, outTexShape);
        default:
            throw new Error(`${outShape.length}-D output sampling is not yet supported`);
    }
}
function getFloatTextureSampleSnippet(glsl) {
    return `
    float sampleTexture(sampler2D textureSampler, vec2 uv) {
      return ${glsl.texture2D}(textureSampler, uv).r;
    }
  `;
}
function getFloatTextureSetRSnippet(glsl) {
    return `
    void setOutput(float val) {
      ${glsl.output} = vec4(val, 0, 0, 0);
    }
  `;
}
function getFloatTextureSetRGBASnippet(glsl) {
    return `
    void setOutput(vec4 val) {
      ${glsl.output} = val;
    }
  `;
}
function getShaderPrefix(glsl) {
    const SHADER_PREFIX = `${glsl.version}
    precision highp float;
    precision highp int;
    precision highp sampler2D;
    ${glsl.varyingFs} vec2 resultUV;
    ${glsl.defineOutput}
    const vec2 halfCR = vec2(0.5, 0.5);

    struct ivec5
    {
      int x;
      int y;
      int z;
      int w;
      int u;
    };

    struct ivec6
    {
      int x;
      int y;
      int z;
      int w;
      int u;
      int v;
    };

    uniform float NAN;
    ${glsl.defineSpecialNaN}
    ${glsl.defineSpecialInf}
    ${glsl.defineRound}

    int imod(int x, int y) {
      return x - y * (x / y);
    }

    int idiv(int a, int b, float sign) {
      int res = a / b;
      int mod = imod(a, b);
      if (sign < 0. && mod != 0) {
        res -= 1;
      }
      return res;
    }

    #define HASHSCALE1 443.8975
    float random(float seed){
      vec2 p = resultUV * seed;
      vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);
      p3 += dot(p3, p3.yzx + 19.19);
      return fract((p3.x + p3.y) * p3.z);
    }

    ${SAMPLE_1D_SNIPPET}
    ${SAMPLE_2D_SNIPPET}
    ${SAMPLE_3D_SNIPPET}
  `;
    return SHADER_PREFIX;
}
const SAMPLE_1D_SNIPPET = `
  vec2 uvFromFlat(int texNumR, int texNumC, int index) {
    int texR = index / texNumC;
    int texC = index - texR * texNumC;
    return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
  }
  vec2 packedUVfrom1D(int texNumR, int texNumC, int index) {
    int texelIndex = index / 2;
    int texR = texelIndex / texNumC;
    int texC = texelIndex - texR * texNumC;
    return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
  }
`;
const SAMPLE_2D_SNIPPET = `
  vec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR,
    int texNumC, int row, int col) {
    int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2);
    int texR = texelIndex / texNumC;
    int texC = texelIndex - texR * texNumC;
    return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
  }
`;
const SAMPLE_3D_SNIPPET = `
  vec2 packedUVfrom3D(int texNumR, int texNumC,
      int texelsInBatch, int texelsInLogicalRow, int b,
      int row, int col) {
    int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2);
    int texR = index / texNumC;
    int texC = index - texR * texNumC;
    return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);
  }
`;
const SHADER_PACKED_PREFIX = `
  float getChannel(vec4 frag, vec2 innerDims) {
    vec2 modCoord = mod(innerDims, 2.);
    return modCoord.x == 0. ?
      (modCoord.y == 0. ? frag.r : frag.g) :
      (modCoord.y == 0. ? frag.b : frag.a);
  }
  float getChannel(vec4 frag, int dim) {
    float modCoord = mod(float(dim), 2.);
    return modCoord == 0. ? frag.r : frag.g;
  }
`;
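// In the packed layout each RGBA texel holds a 2x2 block of the logical
// tensor: r/g are the even row's two columns and b/a the odd row's, which is
// exactly the mapping getChannel decodes from (row % 2, col % 2). The
// uvFromFlat / packedUVfrom* helpers above turn flat or (b, row, col)
// indices into the owning texel's normalized UV center.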
function getOutputScalarCoords() {
    return `
    int getOutputCoords() {
      return 0;
    }
  `;
}
function getOutputPacked1DCoords(shape, texShape, enableShapeUniforms) {
    const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
    if (packedTexShape[0] === 1) {
        if (enableShapeUniforms) {
            return `
      int getOutputCoords() {
        return 2 * int(resultUV.x * ceil(float(outTexShape[1]) / 2.0));
      }
    `;
        }
        return `
      int getOutputCoords() {
        return 2 * int(resultUV.x * ${packedTexShape[1]}.0);
      }
    `;
    }
    if (packedTexShape[1] === 1) {
        if (enableShapeUniforms) {
            return `
      int getOutputCoords() {
        return 2 * int(resultUV.y * ceil(float(outTexShape[0]) / 2.0));
      }
    `;
        }
        return `
      int getOutputCoords() {
        return 2 * int(resultUV.y * ${packedTexShape[0]}.0);
      }
    `;
    }
    if (enableShapeUniforms) {
        return `
    int getOutputCoords() {
      ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(packedTexShape[0], packedTexShape[1]));
      return 2 * (resTexRC.x * packedTexShape[1] + resTexRC.y);
    }
  `;
    }
    return `
    int getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
      return 2 * (resTexRC.x * ${packedTexShape[1]} + resTexRC.y);
    }
  `;
}
function getOutput1DCoords(shape, texShape, enableShapeUniforms) {
    if (texShape[0] === 1) {
        if (enableShapeUniforms) {
            return `
      int getOutputCoords() {
        return int(resultUV.x * float(outTexShape[1]));
      }
    `;
        }
        return `
      int getOutputCoords() {
        return int(resultUV.x * ${texShape[1]}.0);
      }
    `;
    }
    if (texShape[1] === 1) {
        if (enableShapeUniforms) {
            return `
      int getOutputCoords() {
        return int(resultUV.y * float(outTexShape[0]));
      }
    `;
        }
        return `
      int getOutputCoords() {
        return int(resultUV.y * ${texShape[0]}.0);
      }
    `;
    }
    if (enableShapeUniforms) {
        return `
    int getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(outTexShape[0], outTexShape[1]));
      return resTexRC.x * outTexShape[1] + resTexRC.y;
    }
  `;
    }
    return `
    int getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(${texShape[0]}, ${texShape[1]}));
      return resTexRC.x * ${texShape[1]} + resTexRC.y;
    }
  `;
}
function getOutputPacked3DCoords(shape, texShape, enableShapeUniforms) {
    if (enableShapeUniforms) {
        return `
    ivec3 getOutputCoords() {
      ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
      int texelsInLogicalRow = int(ceil(float(outShape[2]) / 2.0));
      int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[1]) / 2.0));
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(packedTexShape[0], packedTexShape[1]));
      int index = resTexRC.x * packedTexShape[1] + resTexRC.y;

      int b = index / texelsInBatch;
      index -= b * texelsInBatch;

      int r = 2 * (index / texelsInLogicalRow);
      int c = imod(index, texelsInLogicalRow) * 2;

      return ivec3(b, r, c);
    }
  `;
    }
    const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
    const texelsInLogicalRow = Math.ceil(shape[2] / 2);
    const texelsInBatch = texelsInLogicalRow * Math.ceil(shape[1] / 2);
    return `
    ivec3 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
      int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;

      int b = index / ${texelsInBatch};
      index -= b * ${texelsInBatch};

      int r = 2 * (index / ${texelsInLogicalRow});
      int c = imod(index, ${texelsInLogicalRow}) * 2;

      return ivec3(b, r, c);
    }
  `;
}
function getOutput3DCoords(shape, texShape, enableShapeUniforms) {
    if (enableShapeUniforms) {
        const coordsFromIndexSnippet = getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], shape);
        return `
  ivec3 getOutputCoords() {
    ivec2 resTexRC = ivec2(resultUV.yx *
      vec2(outTexShape[0], outTexShape[1]));
    int index = resTexRC.x * outTexShape[1] + resTexRC.y;
    ${coordsFromIndexSnippet}
    return ivec3(r, c, d);
  }
`;
    }
    const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
    return `
    ivec3 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(${texShape[0]}, ${texShape[1]}));
      int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
      ${coordsFromIndexSnippet}
      return ivec3(r, c, d);
    }
  `;
}
function getOutputPackedNDCoords(shape, texShape, enableShapeUniforms) {
    if (enableShapeUniforms) {
        return `
    ivec4 getOutputCoords() {
      ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(packedTexShape[0], packedTexShape[1]));
      int index = resTexRC.x * packedTexShape[1] + resTexRC.y;

      int texelsInLogicalRow = int(ceil(float(outShape[3]) / 2.0));
      int texelsInBatch = texelsInLogicalRow * int(ceil(float(outShape[2]) / 2.0));
      int texelsInBatchN = texelsInBatch * outShape[1];

      int b2 = index / texelsInBatchN;
      index -= b2 * texelsInBatchN;

      int b = index / texelsInBatch;
      index -= b * texelsInBatch;

      int r = 2 * (index / texelsInLogicalRow);
      int c = imod(index, texelsInLogicalRow) * 2;

      return ivec4(b2, b, r, c);
    }
  `;
    }
    const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
    const texelsInLogicalRow = Math.ceil(shape[shape.length - 1] / 2);
    const texelsInBatch = texelsInLogicalRow * Math.ceil(shape[shape.length - 2] / 2);
    let texelsInBatchN = texelsInBatch;
    let batches = ``;
    let coords = 'b, r, c';
    for (let b = 2; b < shape.length - 1; b++) {
        texelsInBatchN *= shape[shape.length - b - 1];
        batches = `
      int b${b} = index / ${texelsInBatchN};
      index -= b${b} * ${texelsInBatchN};
    ` + batches;
        coords = `b${b}, ` + coords;
    }
    return `
    ivec${shape.length} getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
      int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;

      ${batches}

      int b = index / ${texelsInBatch};
      index -= b * ${texelsInBatch};

      int r = 2 * (index / ${texelsInLogicalRow});
      int c = imod(index, ${texelsInLogicalRow}) * 2;

      return ivec${shape.length}(${coords});
    }
  `;
}
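// In the precompiled (non-uniform) path above, ranks over 4 are handled by
// folding each extra leading dimension into texelsInBatchN, so e.g. (a
// sketch) a rank-5 output emits an "ivec5 getOutputCoords()" that peels b3,
// then b2, then (b, r, c) out of the flat texel index; the shape-uniform
// path is limited to rank 4.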
function getOutput4DCoords(shape, texShape, enableShapeUniforms) {
    if (enableShapeUniforms) {
        const coordsFromIndexSnippet = getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd', 'd2'], shape);
        return `
    ivec4 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(outTexShape[0], outTexShape[1]));
      int index = resTexRC.x * outTexShape[1] + resTexRC.y;
      ${coordsFromIndexSnippet}
      return ivec4(r, c, d, d2);
    }
  `;
    }
    const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2'], shape);
    return `
    ivec4 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(${texShape[0]}, ${texShape[1]}));
      int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
      ${coordsFromIndexSnippet}
      return ivec4(r, c, d, d2);
    }
  `;
}
function getOutput5DCoords(shape, texShape) {
    const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3'], shape);
    return `
    ivec5 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx * vec2(${texShape[0]},
                             ${texShape[1]}));

      int index = resTexRC.x * ${texShape[1]} + resTexRC.y;

      ${coordsFromIndexSnippet}

      ivec5 outShape = ivec5(r, c, d, d2, d3);
      return outShape;
    }
  `;
}
function getOutput6DCoords(shape, texShape) {
    const coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3', 'd4'], shape);
    return `
    ivec6 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(${texShape[0]}, ${texShape[1]}));
      int index = resTexRC.x * ${texShape[1]} + resTexRC.y;

      ${coordsFromIndexSnippet}

      ivec6 result = ivec6(r, c, d, d2, d3, d4);
      return result;
    }
  `;
}
function getOutputPacked2DCoords(shape, texShape, enableShapeUniforms) {
    const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
    if (arraysEqual(shape, texShape)) {
        if (enableShapeUniforms) {
            return `
      ivec2 getOutputCoords() {
        ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
        return 2 * ivec2(resultUV.yx * vec2(packedTexShape[0], packedTexShape[1]));
      }
    `;
        }
        return `
      ivec2 getOutputCoords() {
        return 2 * ivec2(resultUV.yx * vec2(${packedTexShape[0]}, ${packedTexShape[1]}));
      }
    `;
    }
    const texelsInLogicalRow = Math.ceil(shape[1] / 2);
    if (enableShapeUniforms) {
        return `
    ivec2 getOutputCoords() {
      ivec2 packedTexShape = ivec2(ceil(float(outTexShape[0]) / 2.0), ceil(float(outTexShape[1]) / 2.0));
      int texelsInLogicalRow = int(ceil(float(outShape[1]) / 2.0));
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(packedTexShape[0], packedTexShape[1]));

      int index = resTexRC.x * packedTexShape[1] + resTexRC.y;
      int r = 2 * (index / texelsInLogicalRow);
      int c = imod(index, texelsInLogicalRow) * 2;

      return ivec2(r, c);
    }
  `;
    }
    return `
    ivec2 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(${packedTexShape[0]}, ${packedTexShape[1]}));

      int index = resTexRC.x * ${packedTexShape[1]} + resTexRC.y;
      int r = 2 * (index / ${texelsInLogicalRow});
      int c = imod(index, ${texelsInLogicalRow}) * 2;

      return ivec2(r, c);
    }
  `;
}
function getOutput2DCoords(shape, texShape, enableShapeUniforms) {
    if (arraysEqual(shape, texShape)) {
        if (enableShapeUniforms) {
            return `
      ivec2 getOutputCoords() {
        return ivec2(resultUV.yx * vec2(outTexShape[0], outTexShape[1]));
      }
    `;
        }
        return `
      ivec2 getOutputCoords() {
        return ivec2(resultUV.yx * vec2(${texShape[0]}, ${texShape[1]}));
      }
    `;
    }
    if (shape[1] === 1) {
        if (enableShapeUniforms) {
            return `
      ivec2 getOutputCoords() {
        ivec2 resTexRC = ivec2(resultUV.yx *
          vec2(outTexShape[0], outTexShape[1]));
        int index = resTexRC.x * outTexShape[1] + resTexRC.y;
        return ivec2(index, 0);
      }
    `;
        }
        return `
      ivec2 getOutputCoords() {
        ivec2 resTexRC = ivec2(resultUV.yx *
          vec2(${texShape[0]}, ${texShape[1]}));
        int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
        return ivec2(index, 0);
      }
    `;
    }
    if (shape[0] === 1) {
        if (enableShapeUniforms) {
            return `
      ivec2 getOutputCoords() {
        ivec2 resTexRC = ivec2(resultUV.yx *
          vec2(outTexShape[0], outTexShape[1]));
        int index = resTexRC.x * outTexShape[1] + resTexRC.y;
        return ivec2(0, index);
      }
    `;
        }
        return `
      ivec2 getOutputCoords() {
        ivec2 resTexRC = ivec2(resultUV.yx *
          vec2(${texShape[0]}, ${texShape[1]}));
        int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
        return ivec2(0, index);
      }
    `;
    }
    if (enableShapeUniforms) {
        return `
    ivec2 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(outTexShape[0], outTexShape[1]));
      int index = resTexRC.x * outTexShape[1] + resTexRC.y;
      int r = index / outShape[1];
      int c = index - r * outShape[1];
      return ivec2(r, c);
    }
  `;
    }
    return `
    ivec2 getOutputCoords() {
      ivec2 resTexRC = ivec2(resultUV.yx *
        vec2(${texShape[0]}, ${texShape[1]}));
      int index = resTexRC.x * ${texShape[1]} + resTexRC.y;
      int r = index / ${shape[1]};
      int c = index - r * ${shape[1]};
      return ivec2(r, c);
    }
  `;
}
function getFlatOffsetUniformName(texName) {
    return `offset${texName}`;
}
function getPackedSamplerScalar(inputInfo) {
    const texName = inputInfo.name;
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    const glsl = getGlslDifferences();
    return `
    vec4 ${funcName}() {
      return ${glsl.texture2D}(${texName}, halfCR);
    }
  `;
}
function getSamplerScalar(inputInfo, enableShapeUniforms) {
    const texName = inputInfo.name;
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    if (inputInfo.shapeInfo.isUniform) {
        return `float ${funcName}() {return ${texName};}`;
    }
    const [texNumR, texNumC] = inputInfo.shapeInfo.texShape;
    if (texNumR === 1 && texNumC === 1) {
        return `
      float ${funcName}() {
        return sampleTexture(${texName}, halfCR);
      }
    `;
    }
    const offset = getFlatOffsetUniformName(texName);
    if (enableShapeUniforms) {
        return `
    float ${funcName}() {
      vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], ${offset});
      return sampleTexture(${texName}, uv);
    }
  `;
    }
    const [tNumR, tNumC] = inputInfo.shapeInfo.texShape;
    return `
    float ${funcName}() {
      vec2 uv = uvFromFlat(${tNumR}, ${tNumC}, ${offset});
      return sampleTexture(${texName}, uv);
    }
  `;
}
function getPackedSampler1D(inputInfo, enableShapeUniforms) {
    const texName = inputInfo.name;
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    const texShape = inputInfo.shapeInfo.texShape;
    const glsl = getGlslDifferences();
    if (enableShapeUniforms) {
        return `
    vec4 ${funcName}(int index) {
      ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
      vec2 uv = packedUVfrom1D(
        packedTexShape[0], packedTexShape[1], index);
      return ${glsl.texture2D}(${texName}, uv);
    }
  `;
    }
    const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
    return `
    vec4 ${funcName}(int index) {
      vec2 uv = packedUVfrom1D(
        ${packedTexShape[0]}, ${packedTexShape[1]}, index);
      return ${glsl.texture2D}(${texName}, uv);
    }
  `;
}
function getSampler1D(inputInfo, enableShapeUniforms) {
    const texName = inputInfo.name;
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    if (inputInfo.shapeInfo.isUniform) {
        return `
      float ${funcName}(int index) {
        ${getUniformSampler(inputInfo)}
      }
    `;
    }
    const texShape = inputInfo.shapeInfo.texShape;
    const tNumR = texShape[0];
    const tNumC = texShape[1];
    if (tNumC === 1 && tNumR === 1) {
        return `
      float ${funcName}(int index) {
        return sampleTexture(${texName}, halfCR);
      }
    `;
    }
    const offset = getFlatOffsetUniformName(texName);
    if (tNumC === 1) {
        if (enableShapeUniforms) {
            return `
      float ${funcName}(int index) {
        vec2 uv = vec2(0.5, (float(index + ${offset}) + 0.5) / float(${texName}TexShape[0]));
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        return `
      float ${funcName}(int index) {
        vec2 uv = vec2(0.5, (float(index + ${offset}) + 0.5) / ${tNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    if (tNumR === 1) {
        if (enableShapeUniforms) {
            return `
      float ${funcName}(int index) {
        vec2 uv = vec2((float(index + ${offset}) + 0.5) / float(${texName}TexShape[1]), 0.5);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        return `
      float ${funcName}(int index) {
        vec2 uv = vec2((float(index + ${offset}) + 0.5) / ${tNumC}.0, 0.5);
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    if (enableShapeUniforms) {
        return `
    float ${funcName}(int index) {
      vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index + ${offset});
      return sampleTexture(${texName}, uv);
    }
  `;
    }
    return `
    float ${funcName}(int index) {
      vec2 uv = uvFromFlat(${tNumR}, ${tNumC}, index + ${offset});
      return sampleTexture(${texName}, uv);
    }
  `;
}
function getPackedSampler2D(inputInfo, enableShapeUniforms) {
    const shape = inputInfo.shapeInfo.logicalShape;
    const texName = inputInfo.name;
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    const texShape = inputInfo.shapeInfo.texShape;
    const texNumR = texShape[0];
    const texNumC = texShape[1];
    const glsl = getGlslDifferences();
    if (texShape != null && arraysEqual(shape, texShape)) {
        if (enableShapeUniforms) {
            return `
      vec4 ${funcName}(int row, int col) {
        vec2 uv = (vec2(col, row) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]);

        return ${glsl.texture2D}(${texName}, uv);
      }
    `;
        }
        return `
      vec4 ${funcName}(int row, int col) {
        vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);

        return ${glsl.texture2D}(${texName}, uv);
      }
    `;
    }
    if (enableShapeUniforms) {
        return `
    vec4 ${funcName}(int row, int col) {
      ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
      int valuesPerRow = int(ceil(float(${texName}Shape[1]) / 2.0));
      vec2 uv = packedUVfrom2D(valuesPerRow, packedTexShape[0], packedTexShape[1], row, col);
      return ${glsl.texture2D}(${texName}, uv);
    }
  `;
    }
    const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
    const valuesPerRow = Math.ceil(shape[1] / 2);
    return `
    vec4 ${funcName}(int row, int col) {
      vec2 uv = packedUVfrom2D(${valuesPerRow}, ${packedTexShape[0]}, ${packedTexShape[1]}, row, col);
      return ${glsl.texture2D}(${texName}, uv);
    }
  `;
}
function getSampler2D(inputInfo, enableShapeUniforms) {
    const shape = inputInfo.shapeInfo.logicalShape;
    const texName = inputInfo.name;
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    const texShape = inputInfo.shapeInfo.texShape;
    if (texShape != null && arraysEqual(shape, texShape)) {
        if (enableShapeUniforms) {
            return `
      float ${funcName}(int row, int col) {
        vec2 uv = (vec2(col, row) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        const texNumR = texShape[0];
        const texNumC = texShape[1];
        return `
    float ${funcName}(int row, int col) {
      vec2 uv = (vec2(col, row) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);
      return sampleTexture(${texName}, uv);
    }
  `;
    }
    const { newShape, keptDims } = squeezeShape(shape);
    const squeezedShape = newShape;
    if (squeezedShape.length < shape.length) {
        const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
        const params = ['row', 'col'];
        return `
      ${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)}
      float ${funcName}(int row, int col) {
        return ${funcName}(${getSqueezedParams(params, keptDims)});
      }
    `;
    }
    if (inputInfo.shapeInfo.isUniform) {
        return `
      float ${funcName}(int row, int col) {
        int index = round(dot(vec2(row, col), vec2(${shape[1]}, 1)));
        ${getUniformSampler(inputInfo)}
      }
    `;
    }
    const texNumR = texShape[0];
    const texNumC = texShape[1];
    const offset = getFlatOffsetUniformName(texName);
    if (texNumC === 1) {
        if (enableShapeUniforms) {
            return `
      float ${funcName}(int row, int col) {
        float index = dot(vec3(row, col, ${offset}), vec3(${texName}Shape[1], 1, 1));
        vec2 uv = vec2(0.5, (index + 0.5) / float(${texName}TexShape[0]));
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        return `
    float ${funcName}(int row, int col) {
      float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1));
      vec2 uv = vec2(0.5, (index + 0.5) / ${texNumR}.0);
      return sampleTexture(${texName}, uv);
    }
  `;
    }
    if (texNumR === 1) {
        if (enableShapeUniforms) {
            return `
      float ${funcName}(int row, int col) {
        float index = dot(vec3(row, col, ${offset}), vec3(${texName}Shape[1], 1, 1));
        vec2 uv = vec2((index + 0.5) / float(${texName}TexShape[1]), 0.5);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        return `
    float ${funcName}(int row, int col) {
      float index = dot(vec3(row, col, ${offset}), vec3(${shape[1]}, 1, 1));
      vec2 uv = vec2((index + 0.5) / ${texNumC}.0, 0.5);
      return sampleTexture(${texName}, uv);
    }
  `;
    }
    if (enableShapeUniforms) {
        return `
      float ${funcName}(int row, int col) {
        int index = row * ${texName}Shape[1] + col + ${offset};
        vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index);
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    return `
  float ${funcName}(int row, int col) {
    int index = row * ${shape[1]} + col + ${offset};
    vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
    return sampleTexture(${texName}, uv);
  }
`;
}
function getPackedSampler3D(inputInfo, enableShapeUniforms) {
    const shape = inputInfo.shapeInfo.logicalShape;
    const texName = inputInfo.name;
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    const texShape = inputInfo.shapeInfo.texShape;
    const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
    if (shape[0] === 1) {
        const squeezedShape = shape.slice(1);
        const keptDims = [1, 2];
        const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
        const params = ['b', 'row', 'col'];
        return `
        ${getPackedSamplerFromInInfo(newInputInfo, enableShapeUniforms)}
        vec4 ${funcName}(int b, int row, int col) {
          return ${funcName}(${getSqueezedParams(params, keptDims)});
        }
      `;
    }
    const glsl = getGlslDifferences();
    if (enableShapeUniforms) {
        return `
    vec4 ${funcName}(int b, int row, int col) {
      ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
      int valuesPerRow = int(ceil(float(${texName}Shape[2]) / 2.0));
      int texelsInBatch = valuesPerRow * int(ceil(float(${texName}Shape[1]) / 2.0));
      vec2 uv = packedUVfrom3D(
        packedTexShape[0], packedTexShape[1], texelsInBatch, valuesPerRow, b, row, col);
      return ${glsl.texture2D}(${texName}, uv);
    }
  `;
    }
    const texNumR = packedTexShape[0];
    const texNumC = packedTexShape[1];
    const valuesPerRow = Math.ceil(shape[2] / 2);
    const texelsInBatch = valuesPerRow * Math.ceil(shape[1] / 2);
    return `
    vec4 ${funcName}(int b, int row, int col) {
      vec2 uv = packedUVfrom3D(
        ${texNumR}, ${texNumC}, ${texelsInBatch}, ${valuesPerRow}, b, row, col);
      return ${glsl.texture2D}(${texName}, uv);
    }
  `;
}
function getSampler3D(inputInfo, enableShapeUniforms) {
    const shape = inputInfo.shapeInfo.logicalShape;
    const texName = inputInfo.name;
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    const stride0 = shape[1] * shape[2];
    const stride1 = shape[2];
    const { newShape, keptDims } = squeezeShape(shape);
    const squeezedShape = newShape;
    if (squeezedShape.length < shape.length) {
        const newInputInfo = squeezeInputInfo(inputInfo, squeezedShape);
        const params = ['row', 'col', 'depth'];
        return `
        ${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)}
        float ${funcName}(int row, int col, int depth) {
          return ${funcName}(${getSqueezedParams(params, keptDims)});
        }
      `;
    }
    if (inputInfo.shapeInfo.isUniform) {
        return `
      float ${funcName}(int row, int col, int depth) {
        int index = round(dot(vec3(row, col, depth),
                          vec3(${stride0}, ${stride1}, 1)));
        ${getUniformSampler(inputInfo)}
      }
    `;
    }
    const texShape = inputInfo.shapeInfo.texShape;
    const texNumR = texShape[0];
    const texNumC = texShape[1];
    const flatOffset = inputInfo.shapeInfo.flatOffset;
    if (texNumC === stride0 && flatOffset == null) {
        if (enableShapeUniforms) {
            return `
      float ${funcName}(int row, int col, int depth) {
        int stride1 = ${texName}Shape[2];
        float texR = float(row);
        float texC = dot(vec2(col, depth), vec2(stride1, 1));
        vec2 uv = (vec2(texC, texR) + halfCR) /
                   vec2(${texName}TexShape[1], ${texName}TexShape[0]);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        return `
        float ${funcName}(int row, int col, int depth) {
          float texR = float(row);
          float texC = dot(vec2(col, depth), vec2(${stride1}, 1));
          vec2 uv = (vec2(texC, texR) + halfCR) /
                     vec2(${texNumC}.0, ${texNumR}.0);
          return sampleTexture(${texName}, uv);
        }
      `;
    }
    if (texNumC === stride1 && flatOffset == null) {
        if (enableShapeUniforms) {
            return `
      float ${funcName}(int row, int col, int depth) {
        float texR = dot(vec2(row, col), vec2(${texName}Shape[1], 1));
        float texC = float(depth);
        vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texName}TexShape[1], ${texName}TexShape[0]);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        return `
    float ${funcName}(int row, int col, int depth) {
      float texR = dot(vec2(row, col), vec2(${shape[1]}, 1));
      float texC = float(depth);
      vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}.0, ${texNumR}.0);
      return sampleTexture(${texName}, uv);
    }
  `;
    }
    const offset = getFlatOffsetUniformName(texName);
    if (enableShapeUniforms) {
        return `
    float ${funcName}(int row, int col, int depth) {
      int stride0 = ${texName}Shape[1] * ${texName}Shape[2];
      int stride1 = ${texName}Shape[2];
      int index = row * stride0 + col * stride1 + depth + ${offset};
      vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index);
      return sampleTexture(${texName}, uv);
    }
    `;
    }
    return `
      float ${funcName}(int row, int col, int depth) {
        int index = row * ${stride0} + col * ${stride1} + depth + ${offset};
        vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
        return sampleTexture(${texName}, uv);
      }
  `;
}
function getPackedSamplerND(inputInfo, enableShapeUniforms) {
    const texName = inputInfo.name;
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    const glsl = getGlslDifferences();
    if (enableShapeUniforms) {
        return `
    vec4 ${funcName}(int b2, int b, int row, int col) {
      int valuesPerRow = int(ceil(float(${texName}Shape[3]) / 2.0));
      int texelsInBatch = valuesPerRow * int(ceil(float(${texName}Shape[2]) / 2.0));
      int index = b * texelsInBatch + (row / 2) * valuesPerRow + (col / 2);
      texelsInBatch *= ${texName}Shape[1];
      index = b2 * texelsInBatch + index;
      ivec2 packedTexShape = ivec2(ceil(float(${texName}TexShape[0]) / 2.0), ceil(float(${texName}TexShape[1]) / 2.0));
      int texR = index / packedTexShape[1];
      int texC = index - texR * packedTexShape[1];
      vec2 uv = (vec2(texC, texR) + halfCR) / vec2(packedTexShape[1], packedTexShape[0]);
      return ${glsl.texture2D}(${texName}, uv);
    }
    `;
    }
    const shape = inputInfo.shapeInfo.logicalShape;
    const rank = shape.length;
    const texShape = inputInfo.shapeInfo.texShape;
    const packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)];
    const texNumR = packedTexShape[0];
    const texNumC = packedTexShape[1];
    const valuesPerRow = Math.ceil(shape[rank - 1] / 2);
    let texelsInBatch = valuesPerRow * Math.ceil(shape[rank - 2] / 2);
    let params = `int b, int row, int col`;
    let index = `b * ${texelsInBatch} + (row / 2) * ${valuesPerRow} + (col / 2)`;
    for (let b = 2; b < rank - 1; b++) {
        params = `int b${b}, ` + params;
        texelsInBatch *= shape[rank - b - 1];
        index = `b${b} * ${texelsInBatch} + ` + index;
    }
    return `
    vec4 ${funcName}(${params}) {
      int index = ${index};
      int texR = index / ${texNumC};
      int texC = index - texR * ${texNumC};
      vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${texNumC}, ${texNumR});
      return ${glsl.texture2D}(${texName}, uv);
    }
  `;
}
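
/*
 * getSampler4D emits a GLSL getter for a rank-4 tensor, picking the cheapest
 * addressing scheme available (squeezed shape, shader uniform array, row- or
 * column-aligned texture layout, or generic flat indexing). As an
 * illustrative sketch only (not verbatim output, and assuming the flat
 * offset uniform for a tensor named 'A' is 'offsetA'), for shape [2, 3, 4, 5]
 * the flat-indexing fallback looks roughly like:
 *
 *   float getA(int row, int col, int depth, int depth2) {
 *     int index = row * 60 + col * 20 + depth * 5 + depth2;
 *     vec2 uv = uvFromFlat(texNumR, texNumC, index + offsetA);
 *     return sampleTexture(A, uv);
 *   }
 *
 * where texNumR/texNumC stand in for the concrete texture dimensions.
 */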
function getSampler4D(inputInfo, enableShapeUniforms) {
    const shape = inputInfo.shapeInfo.logicalShape;
    const texName = inputInfo.name;
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    const stride2 = shape[3];
    const stride1 = shape[2] * stride2;
    const stride0 = shape[1] * stride1;
    const { newShape, keptDims } = squeezeShape(shape);
    if (newShape.length < shape.length) {
        const newInputInfo = squeezeInputInfo(inputInfo, newShape);
        const params = ['row', 'col', 'depth', 'depth2'];
        return `
      ${getSamplerFromInInfo(newInputInfo, enableShapeUniforms)}
      float ${funcName}(int row, int col, int depth, int depth2) {
        return ${funcName}(${getSqueezedParams(params, keptDims)});
      }
    `;
    }
    if (inputInfo.shapeInfo.isUniform) {
        return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        int index = round(dot(vec4(row, col, depth, depth2),
                          vec4(${stride0}, ${stride1}, ${stride2}, 1)));
        ${getUniformSampler(inputInfo)}
      }
    `;
    }
    const flatOffset = inputInfo.shapeInfo.flatOffset;
    const texShape = inputInfo.shapeInfo.texShape;
    const texNumR = texShape[0];
    const texNumC = texShape[1];
    const stride2Str = `int stride2 = ${texName}Shape[3];`;
    const stride1Str = `int stride1 = ${texName}Shape[2] * stride2;`;
    const stride0Str = `int stride0 = ${texName}Shape[1] * stride1;`;
    if (texNumC === stride0 && flatOffset == null) {
        if (enableShapeUniforms) {
            return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        ${stride2Str}
        ${stride1Str}
        float texR = float(row);
        float texC =
            dot(vec3(col, depth, depth2),
                vec3(stride1, stride2, 1));
        vec2 uv = (vec2(texC, texR) + halfCR) /
                   vec2(${texName}TexShape[1], ${texName}TexShape[0]);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        float texR = float(row);
        float texC =
            dot(vec3(col, depth, depth2),
                vec3(${stride1}, ${stride2}, 1));
        vec2 uv = (vec2(texC, texR) + halfCR) /
                   vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    if (texNumC === stride2 && flatOffset == null) {
        if (enableShapeUniforms) {
            return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        float texR = dot(vec3(row, col, depth),
          vec3(${texName}Shape[1] * ${texName}Shape[2], ${texName}Shape[2], 1));
        float texC = float(depth2);
        vec2 uv = (vec2(texC, texR) + halfCR) /
                  vec2(${texName}TexShape[1], ${texName}TexShape[0]);
        return sampleTexture(${texName}, uv);
      }
    `;
        }
        return `
      float ${funcName}(int row, int col, int depth, int depth2) {
        float texR = dot(vec3(row, col, depth),
                         vec3(${shape[1] * shape[2]}, ${shape[2]}, 1));
        float texC = float(depth2);
        vec2 uv = (vec2(texC, texR) + halfCR) /
                  vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    const offset = getFlatOffsetUniformName(texName);
    if (enableShapeUniforms) {
        return `
    float ${funcName}(int row, int col, int depth, int depth2) {
      ${stride2Str}
      ${stride1Str}
      ${stride0Str}
      int index = row * stride0 + col * stride1 +
          depth * stride2 + depth2;
      vec2 uv = uvFromFlat(${texName}TexShape[0], ${texName}TexShape[1], index + ${offset});
      return sampleTexture(${texName}, uv);
    }
    `;
    }
    return `
    float ${funcName}(int row, int col, int depth, int depth2) {
      int index = row * ${stride0} + col * ${stride1} +
          depth * ${stride2} + depth2;
      vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index + ${offset});
      return sampleTexture(${texName}, uv);
    }
  `;
}
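
/*
 * getSampler5D applies the same addressing strategies to rank-5 tensors.
 * Note that no enableShapeUniforms flag is threaded through here:
 * useShapeUniforms() elsewhere in this file only enables shape uniforms for
 * rank <= 4, so rank-5 and rank-6 samplers always bake their strides into
 * the generated source.
 */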
function getSampler5D(inputInfo) {
    const shape = inputInfo.shapeInfo.logicalShape;
    const texName = inputInfo.name;
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    const stride3 = shape[4];
    const stride2 = shape[3] * stride3;
    const stride1 = shape[2] * stride2;
    const stride0 = shape[1] * stride1;
    const { newShape, keptDims } = squeezeShape(shape);
    if (newShape.length < shape.length) {
        const newInputInfo = squeezeInputInfo(inputInfo, newShape);
        const params = ['row', 'col', 'depth', 'depth2', 'depth3'];
        return `
      ${getSamplerFromInInfo(newInputInfo)}
      float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
        return ${funcName}(${getSqueezedParams(params, keptDims)});
      }
    `;
    }
    if (inputInfo.shapeInfo.isUniform) {
        return `
      float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
        float index = dot(
          vec4(row, col, depth, depth2),
          vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) +
          depth3;
        ${getUniformSampler(inputInfo)}
      }
    `;
    }
    const flatOffset = inputInfo.shapeInfo.flatOffset;
    const texShape = inputInfo.shapeInfo.texShape;
    const texNumR = texShape[0];
    const texNumC = texShape[1];
    if (texNumC === stride0 && flatOffset == null) {
        return `
      float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
        int texR = row;
        float texC = dot(vec4(col, depth, depth2, depth3),
                         vec4(${stride1}, ${stride2}, ${stride3}, 1));
        vec2 uv = (vec2(texC, texR) + halfCR) /
                  vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    if (texNumC === stride3 && flatOffset == null) {
        return `
      float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
        float texR = dot(
          vec4(row, col, depth, depth2),
          vec4(${shape[1] * shape[2] * shape[3]},
               ${shape[2] * shape[3]}, ${shape[3]}, 1));
        int texC = depth3;
        vec2 uv = (vec2(texC, texR) + halfCR) /
                  vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    const offset = getFlatOffsetUniformName(texName);
    return `
    float ${funcName}(int row, int col, int depth, int depth2, int depth3) {
      int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} +
          depth2 * ${stride3} + depth3 + ${offset};
      vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
      return sampleTexture(${texName}, uv);
    }
  `;
}
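
/*
 * getSampler6D: rank-6 variant of the same scheme; six logical coordinates
 * are reduced to a flat index via the row-major strides computed below.
 */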
function getSampler6D(inputInfo) {
    const shape = inputInfo.shapeInfo.logicalShape;
    const texName = inputInfo.name;
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    const { newShape, keptDims } = squeezeShape(shape);
    if (newShape.length < shape.length) {
        const newInputInfo = squeezeInputInfo(inputInfo, newShape);
        const params = ['row', 'col', 'depth', 'depth2', 'depth3', 'depth4'];
        return `
      ${getSamplerFromInInfo(newInputInfo)}
      float ${funcName}(int row, int col, int depth,
                        int depth2, int depth3, int depth4) {
        return ${funcName}(${getSqueezedParams(params, keptDims)});
      }
    `;
    }
    const stride4 = shape[5];
    const stride3 = shape[4] * stride4;
    const stride2 = shape[3] * stride3;
    const stride1 = shape[2] * stride2;
    const stride0 = shape[1] * stride1;
    if (inputInfo.shapeInfo.isUniform) {
        return `
      float ${funcName}(int row, int col, int depth,
                        int depth2, int depth3, int depth4) {
        int index = round(dot(
          vec4(row, col, depth, depth2),
          vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) +
          dot(
            vec2(depth3, depth4),
            vec2(${stride4}, 1)));
        ${getUniformSampler(inputInfo)}
      }
    `;
    }
    const flatOffset = inputInfo.shapeInfo.flatOffset;
    const texShape = inputInfo.shapeInfo.texShape;
    const texNumR = texShape[0];
    const texNumC = texShape[1];
    if (texNumC === stride0 && flatOffset == null) {
        return `
      float ${funcName}(int row, int col, int depth,
                        int depth2, int depth3, int depth4) {
        int texR = row;
        float texC = dot(vec4(col, depth, depth2, depth3),
                         vec4(${stride1}, ${stride2}, ${stride3}, ${stride4})) +
                     float(depth4);
        vec2 uv = (vec2(texC, texR) + halfCR) /
                  vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    if (texNumC === stride4 && flatOffset == null) {
        return `
      float ${funcName}(int row, int col, int depth,
                        int depth2, int depth3, int depth4) {
        float texR = dot(vec4(row, col, depth, depth2),
          vec4(${shape[1] * shape[2] * shape[3] * shape[4]},
               ${shape[2] * shape[3] * shape[4]},
               ${shape[3] * shape[4]},
               ${shape[4]})) + float(depth3);
        int texC = depth4;
        vec2 uv = (vec2(texC, texR) + halfCR) /
                  vec2(${texNumC}.0, ${texNumR}.0);
        return sampleTexture(${texName}, uv);
      }
    `;
    }
    const offset = getFlatOffsetUniformName(texName);
    return `
    float ${funcName}(int row, int col, int depth,
                      int depth2, int depth3, int depth4) {
      int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} +
          depth2 * ${stride3} + depth3 * ${stride4} + depth4 + ${offset};
      vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);
      return sampleTexture(${texName}, uv);
    }
  `;
}
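
/*
 * getUniformSampler: when a small input is passed as a shader uniform array
 * rather than a texture, the generated lookup walks the array with a loop and
 * compares against the caller-computed `index`, seemingly to stay within
 * GLSL ES restrictions on dynamically indexing uniform arrays.
 */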
function getUniformSampler(inputInfo) {
    const texName = inputInfo.name;
    const inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape);
    if (inSize < 2) {
        return `return ${texName};`;
    }
    return `
    for (int i = 0; i < ${inSize}; i++) {
      if (i == index) {
        return ${texName}[i];
      }
    }
  `;
}
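
/*
 * getPackedSamplerAtOutputCoords: reads a packed (2x2 texel) input at the
 * current output coordinates, zeroing broadcast dimensions and re-swizzling
 * the fetched vec4 so that broadcast rows/columns and scalar inputs replicate
 * into the right channels.
 */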
function getPackedSamplerAtOutputCoords(inputInfo, outShapeInfo) {
    const texName = inputInfo.name;
    const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
    const funcName = 'get' + texFuncSnippet + 'AtOutCoords';
    const inRank = inputInfo.shapeInfo.logicalShape.length;
    const outRank = outShapeInfo.logicalShape.length;
    const broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
    const type = getCoordsDataType(outRank);
    const rankDiff = outRank - inRank;
    let coordsSnippet;
    const fields = ['x', 'y', 'z', 'w', 'u', 'v'];
    if (inRank === 0) {
        coordsSnippet = '';
    }
    else if (outRank < 2 && broadcastDims.length >= 1) {
        coordsSnippet = 'coords = 0;';
    }
    else {
        coordsSnippet =
            broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`)
                .join('\n');
    }
    let unpackedCoordsSnippet = '';
    if (outRank < 2 && inRank > 0) {
        unpackedCoordsSnippet = 'coords';
    }
    else {
        unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape
            .map((s, i) => `coords.${fields[i + rankDiff]}`)
            .join(', ');
    }
    let output = `return outputValue;`;
    const inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape);
    const isInputScalar = inSize === 1;
    const outSize = sizeFromShape(outShapeInfo.logicalShape);
    const isOutputScalar = outSize === 1;
    if (inRank === 1 && !isInputScalar && !isOutputScalar) {
        output = `
      return vec4(outputValue.xy, outputValue.xy);
    `;
    }
    else if (isInputScalar && !isOutputScalar) {
        if (outRank === 1) {
            output = `
        return vec4(outputValue.x, outputValue.x, 0., 0.);
      `;
        }
        else {
            output = `
        return vec4(outputValue.x);
      `;
        }
    }
    else if (broadcastDims.length) {
        const rows = inRank - 2;
        const cols = inRank - 1;
        if (broadcastDims.indexOf(rows) > -1 && broadcastDims.indexOf(cols) > -1) {
            output = `return vec4(outputValue.x);`;
        }
        else if (broadcastDims.indexOf(rows) > -1) {
            output = `return vec4(outputValue.x, outputValue.y, ` +
                `outputValue.x, outputValue.y);`;
        }
        else if (broadcastDims.indexOf(cols) > -1) {
            output = `return vec4(outputValue.xx, outputValue.zz);`;
        }
    }
    return `
    vec4 ${funcName}() {
      ${type} coords = getOutputCoords();
      ${coordsSnippet}
      vec4 outputValue = get${texFuncSnippet}(${unpackedCoordsSnippet});
      ${output}
    }
  `;
}
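
/*
 * getSamplerAtOutputCoords (unpacked variant): when the input and output
 * share rank and texture shape and the input is a plain, unsliced texture,
 * it short-circuits to a direct sampleTexture at resultUV; otherwise it maps
 * output coordinates back to input coordinates, zeroing broadcast dims.
 */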
function getSamplerAtOutputCoords(inputInfo, outShapeInfo) {
    const texName = inputInfo.name;
    const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
    const funcName = 'get' + texFuncSnippet + 'AtOutCoords';
    const outTexShape = outShapeInfo.texShape;
    const inTexShape = inputInfo.shapeInfo.texShape;
    const inRank = inputInfo.shapeInfo.logicalShape.length;
    const outRank = outShapeInfo.logicalShape.length;
    if (!inputInfo.shapeInfo.isUniform && inRank === outRank &&
        inputInfo.shapeInfo.flatOffset == null &&
        arraysEqual(inTexShape, outTexShape)) {
        return `
      float ${funcName}() {
        return sampleTexture(${texName}, resultUV);
      }
    `;
    }
    const type = getCoordsDataType(outRank);
    const broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape);
    const rankDiff = outRank - inRank;
    let coordsSnippet;
    const fields = ['x', 'y', 'z', 'w', 'u', 'v'];
    if (inRank === 0) {
        coordsSnippet = '';
    }
    else if (outRank < 2 && broadcastDims.length >= 1) {
        coordsSnippet = 'coords = 0;';
    }
    else {
        coordsSnippet =
            broadcastDims.map(d => `coords.${fields[d + rankDiff]} = 0;`)
                .join('\n');
    }
    let unpackedCoordsSnippet = '';
    if (outRank < 2 && inRank > 0) {
        unpackedCoordsSnippet = 'coords';
    }
    else {
        unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape
            .map((s, i) => `coords.${fields[i + rankDiff]}`)
            .join(', ');
    }
    return `
    float ${funcName}() {
      ${type} coords = getOutputCoords();
      ${coordsSnippet}
      return get${texFuncSnippet}(${unpackedCoordsSnippet});
    }
  `;
}
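
/*
 * getCoordsDataType maps a tensor rank to the GLSL type used for output
 * coordinates. 'ivec5' and 'ivec6' are not built-in GLSL types; they are
 * presumably emulated by helper structs defined in the shader preamble
 * generated elsewhere in this bundle.
 */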
function getCoordsDataType(rank) {
    if (rank <= 1) {
        return 'int';
    }
    else if (rank === 2) {
        return 'ivec2';
    }
    else if (rank === 3) {
        return 'ivec3';
    }
    else if (rank === 4) {
        return 'ivec4';
    }
    else if (rank === 5) {
        return 'ivec5';
    }
    else if (rank === 6) {
        return 'ivec6';
    }
    else {
        throw Error(`GPU for rank ${rank} is not yet supported`);
    }
}
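
/*
 * getUniformInfoFromShape decides whether a shape can be squeezed (size-1
 * dims dropped) before being uploaded as a shape uniform, which lets more
 * tensors share a single compiled shader variant.
 */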
function getUniformInfoFromShape(isPacked, shape, texShape) {
    const { newShape, keptDims } = squeezeShape(shape);
    const rank = shape.length;
    const useSqueezePackedShape = isPacked && rank === 3 && shape[0] === 1;
    const squeezeShape$1 = useSqueezePackedShape ? shape.slice(1) : newShape;
    const useSqueezeShape = (!isPacked && rank > 1 && !arraysEqual(shape, texShape) &&
        newShape.length < rank) ||
        useSqueezePackedShape;
    const uniformShape = useSqueezeShape ? squeezeShape$1 : shape;
    return { useSqueezeShape, uniformShape, keptDims };
}

function squeezeInputInfo(inInfo, squeezedShape) {
    const newInputInfo = JSON.parse(JSON.stringify(inInfo));
    newInputInfo.shapeInfo.logicalShape = squeezedShape;
    return newInputInfo;
}
function getSqueezedParams(params, keptDims) {
    return keptDims.map(d => params[d]).join(', ');
}
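
/*
 * compileProgram builds the fragment shader source for `program` against the
 * given inputs/output, compiles and links it, and returns the binary together
 * with resolved uniform locations. Under the ENGINE_COMPILE_ONLY flag, VAO
 * setup and uniform lookup are skipped so shaders can be compiled up front
 * (e.g. letting KHR_parallel_shader_compile overlap compilation) before
 * first use.
 */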
function compileProgram(gpgpu, program, inputs, output) {
    const inputInfos = inputs.map((input, i) => {
        const shapeInfo = {
            logicalShape: input.shape,
            texShape: input.isUniform ? null : input.texData.texShape,
            isUniform: input.isUniform,
            isPacked: input.isUniform ? false : input.texData.isPacked,
            flatOffset: null
        };
        if (input.texData != null && input.texData.slice != null &&
            input.texData.slice.flatOffset > 0) {
            shapeInfo.flatOffset = input.texData.slice.flatOffset;
        }
        return { name: program.variableNames[i], shapeInfo };
    });
    const inShapeInfos = inputInfos.map(x => x.shapeInfo);
    const outShapeInfo = {
        logicalShape: output.shape,
        texShape: output.texData.texShape,
        isUniform: false,
        isPacked: output.texData.isPacked,
        flatOffset: null
    };
    const source = makeShader(inputInfos, outShapeInfo, program);
    const fragmentShader = createFragmentShader(gpgpu.gl, source);
    const webGLProgram = gpgpu.createProgram(fragmentShader);
    if (!env().get('ENGINE_COMPILE_ONLY')) {
        gpgpu.buildVao(webGLProgram);
        return Object.assign({ program,
            fragmentShader,
            source,
            webGLProgram,
            inShapeInfos,
            outShapeInfo }, getUniformLocations(gpgpu, program, webGLProgram));
    }
    else {
        return {
            program,
            fragmentShader,
            source,
            webGLProgram,
            inShapeInfos,
            outShapeInfo,
            variablesLocations: null,
            customUniformLocations: null,
            infLoc: null,
            nanLoc: null,
            outShapeLocation: null,
            outShapeStridesLocation: null,
            outTexShapeLocation: null
        };
    }
}
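
/*
 * getUniformLocations resolves every uniform a generated shader might
 * declare. Lookups use shouldThrow = false because the GLSL compiler is free
 * to optimize unused uniforms away; a null location simply causes the
 * corresponding upload in runProgram to be skipped.
 */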
function getUniformLocations(gpgpu, program, webGLProgram) {
    const variablesLocations = [];
    const customUniformLocations = [];
    let outShapeLocation;
    let outTexShapeLocation;
    let outShapeStridesLocation;
    let infLoc = null;
    let nanLoc = null;
    nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false);
    if (env().getNumber('WEBGL_VERSION') === 1) {
        infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false);
    }
    const shouldThrow = false;
    for (const varName of program.variableNames) {
        const varLocs = {
            name: varName,
            uniform: gpgpu.getUniformLocation(webGLProgram, varName, shouldThrow),
            offset: gpgpu.getUniformLocation(webGLProgram, `offset${varName}`, shouldThrow),
        };
        if (program.enableShapeUniforms) {
            varLocs.shape = gpgpu.getUniformLocation(webGLProgram, `${varName}Shape`, shouldThrow);
            varLocs.texShape = gpgpu.getUniformLocation(webGLProgram, `${varName}TexShape`, shouldThrow);
        }
        variablesLocations.push(varLocs);
    }
    if (program.enableShapeUniforms) {
        outShapeLocation =
            gpgpu.getUniformLocation(webGLProgram, 'outShape', shouldThrow);
        outShapeStridesLocation =
            gpgpu.getUniformLocation(webGLProgram, 'outShapeStrides', shouldThrow);
        outTexShapeLocation =
            gpgpu.getUniformLocation(webGLProgram, 'outTexShape', shouldThrow);
    }
    if (program.customUniforms) {
        for (const d of program.customUniforms) {
            customUniformLocations.push(gpgpu.getUniformLocation(webGLProgram, d.name, shouldThrow));
        }
    }
    return {
        variablesLocations,
        customUniformLocations,
        infLoc,
        nanLoc,
        outShapeLocation,
        outShapeStridesLocation,
        outTexShapeLocation
    };
}
function validateBinaryAndProgram(shapeInfos, inputs) {
    if (shapeInfos.length !== inputs.length) {
        throw Error(`Binary was compiled with ${shapeInfos.length} inputs, but ` +
            `was executed with ${inputs.length} inputs`);
    }
    shapeInfos.forEach((s, i) => {
        const shapeA = s.logicalShape;
        const input = inputs[i];
        const shapeB = input.shape;
        if (!arraysEqual(shapeA, shapeB)) {
            throw Error(`Binary was compiled with different shapes than ` +
                `the current args. Shapes ${shapeA} and ${shapeB} must match`);
        }
        if (s.isUniform && input.isUniform) {
            return;
        }
        const texShapeA = s.texShape;
        const texShapeB = input.isUniform ? null : input.texData.texShape;
        if (!arraysEqual(texShapeA, texShapeB)) {
            throw Error(`Binary was compiled with different texture shapes than the` +
                ` current args. Shape ${texShapeA} and ${texShapeB} must match`);
        }
    });
}
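
/*
 * runProgram dispatches one compiled binary: it binds the output texture to
 * the framebuffer, activates the program and its VAO, uploads NaN/Infinity,
 * shape, texture-shape and custom uniforms, binds each input texture, and
 * finally draws the full-screen quad. Typical call sequence (sketch):
 *
 *   const binary = compileProgram(gpgpu, program, inputs, output);
 *   runProgram(gpgpu, binary, inputs, output, customUniformValues);
 */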
function runProgram(gpgpu, binary, inputs, output, customUniformValues) {
    if (!binary.program.enableShapeUniforms) {
        validateBinaryAndProgram(binary.inShapeInfos, inputs);
        validateBinaryAndProgram([binary.outShapeInfo], [output]);
    }
    const outTex = output.texData.texture;
    const outTexShape = output.texData.texShape;
    if (output.texData.isPacked) {
        gpgpu.setOutputPackedMatrixTexture(outTex.texture, outTexShape[0], outTexShape[1]);
    }
    else {
        gpgpu.setOutputMatrixTexture(outTex.texture, outTexShape[0], outTexShape[1]);
    }
    gpgpu.setProgram(binary.webGLProgram);
    gpgpu.bindVertexArray(binary.webGLProgram.vao);
    if (env().getNumber('WEBGL_VERSION') === 1) {
        if (binary.infLoc !== null) {
            gpgpu.gl.uniform1f(binary.infLoc, Infinity);
        }
    }
    if (binary.nanLoc !== null) {
        gpgpu.gl.uniform1f(binary.nanLoc, NaN);
    }
    for (let i = 0; i < inputs.length; ++i) {
        const input = inputs[i];
        const { uniform: varLoc, offset: varOffsetLoc, shape: varShapeLoc, texShape: varTexShapeLoc, } = binary.variablesLocations[i];
        if (varShapeLoc) {
            const { uniformShape } = getUniformInfoFromShape(binary.program.packedInputs, input.shape, input.texData.texShape);
            switch (uniformShape.length) {
                case 1:
                    gpgpu.gl.uniform1iv(varShapeLoc, new Int32Array(uniformShape));
                    break;
                case 2:
                    gpgpu.gl.uniform2iv(varShapeLoc, new Int32Array(uniformShape));
                    break;
                case 3:
                    gpgpu.gl.uniform3iv(varShapeLoc, new Int32Array(uniformShape));
                    break;
                case 4:
                    gpgpu.gl.uniform4iv(varShapeLoc, new Int32Array(uniformShape));
                    break;
            }
        }
        if (varTexShapeLoc) {
            gpgpu.gl.uniform2i(varTexShapeLoc, input.texData.texShape[0], input.texData.texShape[1]);
        }
        if (varLoc == null) {
            continue;
        }
        if (input.isUniform) {
            if (sizeFromShape(input.shape) < 2) {
                gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]);
            }
            else {
                let vals = input.uniformValues;
                if (!(vals instanceof Float32Array)) {
                    vals = new Float32Array(vals);
                }
                gpgpu.gl.uniform1fv(varLoc, vals);
            }
            continue;
        }
        if (input.texData.slice != null && varOffsetLoc != null) {
            gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset);
        }
        gpgpu.setInputMatrixTexture(input.texData.texture.texture, varLoc, i);
    }
    const outShapeLoc = binary.outShapeLocation;
    if (outShapeLoc) {
        switch (output.shape.length) {
            case 1:
                gpgpu.gl.uniform1iv(outShapeLoc, new Int32Array(output.shape));
                break;
            case 2:
                gpgpu.gl.uniform2iv(outShapeLoc, new Int32Array(output.shape));
                break;
            case 3:
                gpgpu.gl.uniform3iv(outShapeLoc, new Int32Array(output.shape));
                break;
            case 4:
                gpgpu.gl.uniform4iv(outShapeLoc, new Int32Array(output.shape));
                break;
        }
    }
    if (binary.outShapeStridesLocation) {
        const strides = computeStrides(output.shape);
        switch (output.shape.length) {
            case 2:
                gpgpu.gl.uniform1iv(binary.outShapeStridesLocation, new Int32Array(strides));
                break;
            case 3:
                gpgpu.gl.uniform2iv(binary.outShapeStridesLocation, new Int32Array(strides));
                break;
            case 4:
                gpgpu.gl.uniform3iv(binary.outShapeStridesLocation, new Int32Array(strides));
                break;
        }
    }
    if (binary.outTexShapeLocation) {
        gpgpu.gl.uniform2i(binary.outTexShapeLocation, output.texData.texShape[0], output.texData.texShape[1]);
    }
    if (binary.program.customUniforms && customUniformValues) {
        for (let i = 0; i < binary.program.customUniforms.length; ++i) {
            const d = binary.program.customUniforms[i];
            const customLoc = binary.customUniformLocations[i];
            const customValue = customUniformValues[i];
            if (d.type === 'float') {
                gpgpu.gl.uniform1fv(customLoc, customValue);
            }
            else if (d.type === 'vec2') {
                gpgpu.gl.uniform2fv(customLoc, customValue);
            }
            else if (d.type === 'vec3') {
                gpgpu.gl.uniform3fv(customLoc, customValue);
            }
            else if (d.type === 'vec4') {
                gpgpu.gl.uniform4fv(customLoc, customValue);
            }
            else if (d.type === 'int') {
                gpgpu.gl.uniform1iv(customLoc, customValue);
            }
            else if (d.type === 'ivec2') {
                gpgpu.gl.uniform2iv(customLoc, customValue);
            }
            else if (d.type === 'ivec3') {
                gpgpu.gl.uniform3iv(customLoc, customValue);
            }
            else if (d.type === 'ivec4') {
                gpgpu.gl.uniform4iv(customLoc, customValue);
            }
            else {
                throw Error(`uniform type ${d.type} is not supported yet.`);
            }
        }
    }
    gpgpu.executeProgram();
}
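
/*
 * makeShaderKey derives the cache key under which a compiled binary is
 * stored. With shape uniforms enabled, only shape properties (rank,
 * broadcast dims, packedness, texture-layout alignment) enter the key, so
 * differently sized tensors can reuse one shader; otherwise the concrete
 * shapes and texture shapes are keyed directly.
 */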
function makeShaderKey(program, inputs, output) {
    let keyInputs = '';
    inputs.concat(output).forEach(x => {
        const hasOffset = x.texData != null && x.texData.slice != null &&
            x.texData.slice.flatOffset > 0;
        if (program.enableShapeUniforms && !x.isUniform) {
            const xTexShape = x.texData.texShape;
            const { useSqueezeShape, uniformShape, keptDims } = getUniformInfoFromShape(program.packedInputs, x.shape, xTexShape);
            let rank1 = '', rank2 = '', rank34 = '';
            if (uniformShape.length === 1 && program.packedInputs) {
                const packedTexShape = [Math.ceil(xTexShape[0] / 2), Math.ceil(xTexShape[1] / 2)];
                rank1 = `${packedTexShape[0] > 1}_${packedTexShape[1] > 1}`;
            }
            else if (uniformShape.length === 2 && !program.packedInputs) {
                rank2 = `${uniformShape[0] > 1}_${uniformShape[1] > 1}`;
            }
            else if (uniformShape.length > 2 && !program.packedInputs) {
                const strides = computeStrides(uniformShape);
                rank34 = `${strides[0] === xTexShape[1]}_${strides[strides.length - 1] === xTexShape[1]}`;
            }
            const xRank = x.shape.length;
            const isLogicalShapeTexShapeEqual = uniformShape.length === 2 && arraysEqual(x.shape, xTexShape);
            const isScalar = sizeFromShape(x.shape) === 1;
            const broadcastDims = getBroadcastDims$1(x.shape, output.shape);
            const isInOutTexShapeEqual = !program.packedInputs &&
                xRank === output.shape.length &&
                arraysEqual(xTexShape, output.texData.texShape);
            const isTexShapeGreaterThanOne = program.packedInputs || uniformShape.length > 2 ?
                '' :
                `${xTexShape[0] > 1}_${xTexShape[1] > 1}`;
            keyInputs += `${xRank}_${isInOutTexShapeEqual}_${useSqueezeShape ? keptDims : ''}_${uniformShape.length}_${isScalar}_${broadcastDims}_${isLogicalShapeTexShapeEqual}_${rank1}_${rank2}_${rank34}_${isTexShapeGreaterThanOne}_${hasOffset}`;
        }
        else {
            const texShape = x.isUniform ? 'uniform' : x.texData.texShape;
            keyInputs += `${x.shape}_${texShape}_${hasOffset}`;
        }
    });
    const keyUserCode = program.userCode;
    let key = program.constructor.name;
    key += '_' + keyInputs + '_' + keyUserCode +
        `${env().getNumber('WEBGL_VERSION')}`;
    return key;
}
function useShapeUniforms(rank) {
    return env().getBool('WEBGL_USE_SHAPES_UNIFORMS') && rank <= 4;
}
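
/*
 * DecodeMatrixProgram / DecodeMatrixPackedProgram convert a texture in the
 * backend's internal layout into the dense, download-friendly layout: each
 * output texel gathers four consecutive logical values by flat index.
 */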
class DecodeMatrixProgram {
    constructor(outputShape) {
        this.variableNames = ['A'];
        this.packedInputs = false;
        this.packedOutput = true;
        this.outPackingScheme = PackingScheme.DENSE;
        this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
        const glsl = getGlslDifferences();
        this.outputShape = outputShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        this.userCode = `
      ivec3 outCoordsFromFlatIndex(int index) {
        ${this.enableShapeUniforms ?
            getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], outputShape) :
            getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape)}
        return ivec3(r, c, d);
      }

      void main() {
        ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1]));
        int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y);

        vec4 result = vec4(0.);

        for (int i=0; i<4; i++) {
          int flatIndex = index + i;
          ivec3 rc = outCoordsFromFlatIndex(flatIndex);
          result[i] = getA(rc.x, rc.y, rc.z);
        }

        ${glsl.output} = result;
      }
    `;
    }
}

class DecodeMatrixPackedProgram {
    constructor(outputShape) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outPackingScheme = PackingScheme.DENSE;
        this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
        const glsl = getGlslDifferences();
        this.outputShape = outputShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        this.userCode = `
      ivec3 outCoordsFromFlatIndex(int index) {
        ${this.enableShapeUniforms ?
            getOutputLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], outputShape) :
            getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape)}
        return ivec3(r, c, d);
      }

      void main() {
        ivec2 resTexRC = ivec2(resultUV.yx * vec2(texShape[0], texShape[1]));
        int index = 4 * (resTexRC.x * texShape[1] + resTexRC.y);

        vec4 result = vec4(0.);

        for (int i=0; i<4; i++) {
          int flatIndex = index + i;
          ivec3 rc = outCoordsFromFlatIndex(flatIndex);
          result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));
        }

        ${glsl.output} = result;
      }
    `;
    }
}
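
/*
 * EncodeFloatProgram / EncodeFloatPackedProgram byte-encode float outputs so
 * they can be read back with readPixels as UNSIGNED_BYTE RGBA, the fallback
 * download path when float readback is unavailable (ENCODE_FLOAT_SNIPPET
 * supplies the encode_float() implementation).
 */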
class EncodeFloatProgram {
    constructor(outputShape) {
        this.variableNames = ['A'];
        this.outTexUsage = TextureUsage.DOWNLOAD;
        const glsl = getGlslDifferences();
        this.outputShape = outputShape;
        this.userCode = `
      ${ENCODE_FLOAT_SNIPPET}

      void main() {
        float x = getAAtOutCoords();
        ${glsl.output} = encode_float(x);
      }
    `;
    }
}

class EncodeFloatPackedProgram {
    constructor(outputShape) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = false;
        this.outTexUsage = TextureUsage.DOWNLOAD;
        const glsl = getGlslDifferences();
        this.outputShape = outputShape;
        this.userCode = `
      ${ENCODE_FLOAT_SNIPPET}

      void main() {
        ivec3 coords = getOutputCoords();
        float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));
        ${glsl.output} = encode_float(x);
      }
    `;
    }
}
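
/*
 * EncodeMatrixProgram / EncodeMatrixPackedProgram re-layout freshly uploaded
 * data from a plain RGBA texture into the backend's unpacked or packed
 * layout; when the input is unsigned byte data, the 0..1 normalized values
 * are mapped back to 0..255 via floor(result * 255. + 0.5).
 */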
const CHANNEL_CHAR_TO_INDEX_MAP = {
    'R': 0,
    'G': 1,
    'B': 2,
    'A': 3
};
class EncodeMatrixProgram {
    constructor(outputShape, inputIsUnsignedByte = false, usedChannels = 'RGBA') {
        this.variableNames = ['A'];
        this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
        const glsl = getGlslDifferences();
        this.outputShape = outputShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        let output = `result`;
        if (inputIsUnsignedByte) {
            output = `floor(result * 255. + 0.5)`;
        }
        let mainLoop = '';
        for (let usedChannelIndex = 0; usedChannelIndex < usedChannels.length; usedChannelIndex++) {
            const curChannel = usedChannels[usedChannelIndex];
            mainLoop += `
          if(offset == ${usedChannelIndex}) {
            result = values[${CHANNEL_CHAR_TO_INDEX_MAP[curChannel]}];
          }`;
        }
        this.userCode = `
      ${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() :
            getFlatIndexFrom3D(outputShape)}

      void main() {
        ivec3 coords = getOutputCoords();
        int flatIndex = getFlatIndex(coords);
        float result = 0.;
        int offset = imod(flatIndex, ${usedChannels.length});

        flatIndex = idiv(flatIndex, ${usedChannels.length}, 1.);

        int r = flatIndex / texShape[1];
        if (r < texShape[0]) {
          int c = imod(flatIndex, texShape[1]);
          vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]);
          vec4 values = ${glsl.texture2D}(A, uv);
          ${mainLoop}
        }
        ${glsl.output} = vec4(${output}, 0., 0., 0.);
      }
    `;
    }
}

class EncodeMatrixPackedProgram {
    constructor(outputShape, inputIsUnsignedByte = false) {
        this.variableNames = ['A'];
        this.packedInputs = false;
        this.packedOutput = true;
        this.customUniforms = [{ name: 'texShape', type: 'ivec2' }];
        const glsl = getGlslDifferences();
        this.outputShape = outputShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        let mainLoop = '';
        let output = 'result';
        if (inputIsUnsignedByte) {
            output = 'floor(result * 255. + 0.5)';
        }
        for (let row = 0; row <= 1; row++) {
            for (let col = 0; col <= 1; col++) {
                const channel = row * 2 + col;
                mainLoop += `
          localCoords = coords;
          if(localCoords[2] + ${col} < ${this.enableShapeUniforms ? 'outShape[2]' : `${outputShape[2]}`}) {
            localCoords[2] += ${col};
            if (localCoords[1] + ${row} < ${this.enableShapeUniforms ? 'outShape[1]' : `${outputShape[1]}`}) {
              localCoords[1] += ${row};

              flatIndex = getFlatIndex(localCoords);
              offset = imod(flatIndex, 4);

              flatIndex = idiv(flatIndex, 4, 1.);

              int r = flatIndex / texShape[1];
              int c = imod(flatIndex, texShape[1]);
              vec2 uv = (vec2(c, r) + halfCR) / vec2(texShape[1], texShape[0]);
              values = ${glsl.texture2D}(A, uv);

              if (offset == 0) {
                result[${channel}] = values[0];
              } else if (offset == 1) {
                result[${channel}] = values[1];
              } else if (offset == 2) {
                result[${channel}] = values[2];
              } else {
                result[${channel}] = values[3];
              }
            }
          }
        `;
            }
        }
        this.userCode = `
      ${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() :
            getFlatIndexFrom3D(outputShape)}

      void main() {
        ivec3 coords = getOutputCoords();

        vec4 result = vec4(0.);
        int flatIndex, r, c, offset;
        ivec3 localCoords;
        vec2 uv;
        vec4 values;

        ${mainLoop}

        ${glsl.output} = ${output};
      }
    `;
    }
}
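
/*
 * The vertex stage is shared by every kernel: a full-screen quad (two
 * triangles over four vertices) whose interleaved buffer carries a vec3
 * clip-space position and a vec2 uv per vertex; all real work happens in the
 * generated fragment shaders above.
 */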
function createVertexShader(gl) {
    const glsl = getGlslDifferences();
    const vertexShaderSource = `${glsl.version}
    precision highp float;
    ${glsl.attribute} vec3 clipSpacePos;
    ${glsl.attribute} vec2 uv;
    ${glsl.varyingVs} vec2 resultUV;

    void main() {
      gl_Position = vec4(clipSpacePos, 1);
      resultUV = uv;
    }`;
    return createVertexShader$1(gl, vertexShaderSource);
}
function createVertexBuffer(gl) {
    const vertexArray = new Float32Array([-1, 1, 0, 0, 1, -1, -1, 0, 0, 0, 1, 1, 0, 1, 1, 1, -1, 0, 1, 0]);
    return createStaticVertexBuffer(gl, vertexArray);
}
function createIndexBuffer(gl) {
    const triangleVertexIndices = new Uint16Array([0, 1, 2, 2, 1, 3]);
    return createStaticIndexBuffer(gl, triangleVertexIndices);
}
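
/*
 * createAndConfigureTexture allocates the 2D texture that backs a tensor:
 * NEAREST filtering and CLAMP_TO_EDGE wrapping (texels are data, not images),
 * allocated with texImage2D on WebGL1 and immutable texStorage2D on WebGL2.
 */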
function createAndConfigureTexture(gl, width, height, internalFormat, textureFormat, textureType) {
    validateTextureSize(width, height);
    const texture = createTexture(gl);
    const tex2d = gl.TEXTURE_2D;
    callAndCheck(gl, () => gl.bindTexture(tex2d, texture));
    callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE));
    callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE));
    callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST));
    callAndCheck(gl, () => gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST));
    if (env().getNumber('WEBGL_VERSION') === 1) {
        callAndCheck(gl, () => gl.texImage2D(tex2d, 0, internalFormat, width, height, 0, textureFormat, textureType, null));
    }
    else {
        callAndCheck(gl, () => gl
            .texStorage2D(tex2d, 1, internalFormat, width, height));
    }
    callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
    return { texture, texShape: [height, width] };
}
function getInternalFormatForFloat32MatrixTexture(textureConfig) {
    return textureConfig.internalFormatFloat;
}
function createFloat32MatrixTexture(gl, rows, columns, textureConfig) {
    const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
    return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat32MatrixTexture(textureConfig), textureConfig.textureFormatFloat, gl.FLOAT);
}
function getInternalFormatForFloat16MatrixTexture(textureConfig) {
    return textureConfig.internalFormatHalfFloat;
}
function createFloat16MatrixTexture(gl, rows, columns, textureConfig) {
    const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
    return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16MatrixTexture(textureConfig), textureConfig.textureFormatFloat, textureConfig.textureTypeHalfFloat);
}
function getInternalFormatForUnsignedBytesMatrixTexture(textureConfig) {
    return textureConfig.downloadTextureFormat;
}
function createUnsignedBytesMatrixTexture(gl, rows, columns, textureConfig) {
    const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
    return createAndConfigureTexture(gl, width, height, getInternalFormatForUnsignedBytesMatrixTexture(textureConfig), gl.RGBA, gl.UNSIGNED_BYTE);
}
function getInternalFormatForPackedMatrixTexture(textureConfig) {
    return textureConfig.internalFormatPackedFloat;
}
function createPackedMatrixTexture(gl, rows, columns, textureConfig) {
    const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
    return createAndConfigureTexture(gl, width, height, getInternalFormatForPackedMatrixTexture(textureConfig), gl.RGBA, gl.FLOAT);
}
function getInternalFormatForFloat16PackedMatrixTexture(textureConfig) {
    return textureConfig.internalFormatPackedHalfFloat;
}
function createFloat16PackedMatrixTexture(gl, rows, columns, textureConfig) {
    const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
    return createAndConfigureTexture(gl, width, height, getInternalFormatForFloat16PackedMatrixTexture(textureConfig), gl.RGBA, textureConfig.textureTypeHalfFloat);
}
function bindVertexProgramAttributeStreams(gl, program, vertexBuffer) {
    const posOffset = 0;
    const uvOffset = 3 * 4;
    const stride = (3 * 4) + (2 * 4);
    callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer));
    const success = bindVertexBufferToProgramAttribute(gl, program, 'clipSpacePos', vertexBuffer, 3, stride, posOffset);
    return success &&
        bindVertexBufferToProgramAttribute(gl, program, 'uv', vertexBuffer, 2, stride, uvOffset);
}
function uploadDenseMatrixToTexture(gl, texture, width, height, data, textureConfig) {
    callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
    let dataForUpload, texelDataType, internalFormat;
    if (data instanceof Uint8Array) {
        dataForUpload = new Uint8Array(width * height * 4);
        texelDataType = gl.UNSIGNED_BYTE;
        internalFormat = gl.RGBA;
    }
    else {
        dataForUpload = new Float32Array(width * height * 4);
        texelDataType = gl.FLOAT;
        internalFormat = textureConfig.internalFormatPackedFloat;
    }
    dataForUpload.set(data);
    if (env().getNumber('WEBGL_VERSION') === 2) {
        callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, width, height, gl.RGBA, texelDataType, dataForUpload));
    }
    else {
        callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA, texelDataType, dataForUpload));
    }
    callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
}
function uploadPixelDataToTexture(gl, texture, pixels) {
    callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture));
    if (pixels.data instanceof Uint8Array) {
        if (env().getNumber('WEBGL_VERSION') === 2) {
            callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, pixels.width, pixels.height, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data));
        }
        else {
            callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data));
        }
    }
    else {
        if (env().getNumber('WEBGL_VERSION') === 2) {
            callAndCheck(gl, () => gl.texSubImage2D(gl.TEXTURE_2D, 0, 0, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels));
        }
        else {
            callAndCheck(gl, () => gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, pixels));
        }
    }
    callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null));
}
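
/*
 * createBufferFromOutputTexture starts an asynchronous GPU->CPU readback
 * (WebGL2 only): readPixels into a PIXEL_PACK_BUFFER returns immediately, and
 * the data is fetched later via getBufferSubData in the download helpers
 * below.
 */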
function createBufferFromOutputTexture(gl2, rows, columns, textureConfig) {
    const buffer = gl2.createBuffer();
    callAndCheck(gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer));
    const bytesPerFloat = 4;
    const valuesPerTexel = 4;
    const bufferSizeBytes = bytesPerFloat * valuesPerTexel * rows * columns;
    callAndCheck(gl2, () => gl2.bufferData(gl2.PIXEL_PACK_BUFFER, bufferSizeBytes, gl2.STREAM_READ));
    callAndCheck(gl2, () => gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0));
    callAndCheck(gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null));
    return buffer;
}
function downloadFloat32MatrixFromBuffer(gl, buffer, size) {
    const gl2 = gl;
    const downloadTarget = new Float32Array(size);
    gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);
    gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget);
    gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);
    return downloadTarget;
}
function downloadByteEncodedFloatMatrixFromOutputTexture(gl, rows, columns, textureConfig) {
    const [w, h] = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
    const numChannels = 4;
    const downloadTarget = new Uint8Array(getUnpackedArraySizeFromMatrixSize(rows * columns, numChannels));
    callAndCheck(gl, () => gl.readPixels(0, 0, w, h, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE, downloadTarget));
    return new Float32Array(downloadTarget.buffer);
}
function downloadPackedMatrixFromBuffer(gl, buffer, batch, rows, cols, physicalRows, physicalCols, textureConfig) {
    const gl2 = gl;
    const downloadTarget = new Float32Array(getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols));
    gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer);
    gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget);
    gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null);
    return downloadTarget;
}
function downloadMatrixFromPackedOutputTexture(gl, physicalRows, physicalCols) {
    const packedRGBA = new Float32Array(physicalRows * physicalCols * 4);
    callAndCheck(gl, () => gl.readPixels(0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, packedRGBA));
    return packedRGBA;
}
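
/*
 * GPGPUContext wraps one WebGL context and owns the shared GPGPU plumbing:
 * vertex/index buffers, the framebuffer used as render target, extension
 * discovery (float textures, color-renderable floats, timer queries, VAOs),
 * and the program/texture lifecycle helpers used by the backend.
 */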
class GPGPUContext {
    constructor(gl) {
        this.outputTexture = null;
        this.program = null;
        this.disposed = false;
        this.itemsToPoll = [];
        const glVersion = env().getNumber('WEBGL_VERSION');
        if (gl != null) {
            this.gl = gl;
            setWebGLContext(glVersion, gl);
        }
        else {
            this.gl = getWebGLContext(glVersion);
        }
        gl = this.gl;
        if (env().getNumber('WEBGL_VERSION') === 2) {
            const gl2 = gl;
            this.createVertexArray = () => {
                return callAndCheck(gl2, () => gl2.createVertexArray());
            };
            this.bindVertexArray = (vao) => {
                return callAndCheck(gl2, () => gl2.bindVertexArray(vao));
            };
            this.deleteVertexArray = (vao) => {
                return callAndCheck(gl2, () => gl2.deleteVertexArray(vao));
            };
            this.getVertexArray = () => {
                return callAndCheck(gl2, () => gl2.getParameter(gl2.VERTEX_ARRAY_BINDING));
            };
        }
        else if (gl != null) {
            const ext = gl.getExtension('OES_vertex_array_object');
            if (ext == null) {
                throw new Error('All WebGL1 implementations are expected to offer' +
                    ' OES_vertex_array_object.');
            }
            this.createVertexArray = () => {
                return callAndCheck(gl, () => ext.createVertexArrayOES());
            };
            this.bindVertexArray = (vao) => {
                return callAndCheck(gl, () => ext.bindVertexArrayOES(vao));
            };
            this.deleteVertexArray = (vao) => {
                return callAndCheck(gl, () => ext.deleteVertexArrayOES(vao));
            };
            this.getVertexArray = () => {
                return callAndCheck(gl, () => gl.getParameter(ext.VERTEX_ARRAY_BINDING_OES));
            };
        }
        let COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float';
        const COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
        this.parallelCompilationExtension =
            this.gl.getExtension('KHR_parallel_shader_compile');
        if (env().getNumber('WEBGL_VERSION') === 1) {
            const TEXTURE_FLOAT = 'OES_texture_float';
            const TEXTURE_HALF_FLOAT = 'OES_texture_half_float';
            this.textureFloatExtension =
                getExtensionOrThrow(this.gl, TEXTURE_FLOAT);
            if (hasExtension(this.gl, TEXTURE_HALF_FLOAT)) {
                this.textureHalfFloatExtension =
                    getExtensionOrThrow(this.gl, TEXTURE_HALF_FLOAT);
            }
            else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {
                throw new Error('GL context does not support half float textures, yet the ' +
                    'environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
            }
            this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT);
            if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
                this.colorBufferHalfFloatExtension =
                    getExtensionOrThrow(this.gl, COLOR_BUFFER_HALF_FLOAT);
            }
            else if (env().get('WEBGL_FORCE_F16_TEXTURES')) {
                throw new Error('GL context does not support color renderable half floats, yet ' +
                    'the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
            }
        }
        else {
            COLOR_BUFFER_FLOAT = 'EXT_color_buffer_float';
            if (hasExtension(this.gl, COLOR_BUFFER_FLOAT)) {
                this.colorBufferFloatExtension =
                    this.gl.getExtension(COLOR_BUFFER_FLOAT);
            }
            else if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
                this.colorBufferHalfFloatExtension =
                    this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
            }
            else {
                throw new Error('GL context does not support color renderable floats');
            }
        }
        this.vertexBuffer = createVertexBuffer(this.gl);
        this.indexBuffer = createIndexBuffer(this.gl);
        this.framebuffer = createFramebuffer(this.gl);
        this.textureConfig =
            getTextureConfig(this.gl, this.textureHalfFloatExtension);
    }
    get debug() {
        return env().getBool('DEBUG');
    }
    dispose() {
        if (this.disposed) {
            return;
        }
        if (this.program != null) {
            console.warn('Disposing a GPGPUContext that still has a bound WebGLProgram.' +
                ' This is probably a resource leak, delete the program with ' +
                'GPGPUContext.deleteProgram before disposing.');
        }
        if (this.outputTexture != null) {
            console.warn('Disposing a GPGPUContext that still has a bound output matrix ' +
                'texture. This is probably a resource leak, delete the output ' +
                'matrix texture with GPGPUContext.deleteMatrixTexture before ' +
                'disposing.');
        }
        const gl = this.gl;
        callAndCheck(gl, () => gl.finish());
        callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null));
        callAndCheck(gl, () => gl.deleteFramebuffer(this.framebuffer));
        callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, null));
        callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null));
        callAndCheck(gl, () => gl.deleteBuffer(this.indexBuffer));
        this.disposed = true;
    }
    createFloat32MatrixTexture(rows, columns) {
        this.throwIfDisposed();
        return createFloat32MatrixTexture(this.gl, rows, columns, this.textureConfig);
    }
    createFloat16MatrixTexture(rows, columns) {
        this.throwIfDisposed();
        return createFloat16MatrixTexture(this.gl, rows, columns, this.textureConfig);
    }
    createUnsignedBytesMatrixTexture(rows, columns) {
        this.throwIfDisposed();
        return createUnsignedBytesMatrixTexture(this.gl, rows, columns, this.textureConfig);
    }
    uploadPixelDataToTexture(texture, pixels) {
        this.throwIfDisposed();
        uploadPixelDataToTexture(this.gl, texture, pixels);
    }
    uploadDenseMatrixToTexture(texture, width, height, data) {
        this.throwIfDisposed();
        uploadDenseMatrixToTexture(this.gl, texture, width, height, data, this.textureConfig);
    }
    createFloat16PackedMatrixTexture(rows, columns) {
        this.throwIfDisposed();
        return createFloat16PackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
    }
    createPackedMatrixTexture(rows, columns) {
        this.throwIfDisposed();
        return createPackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
    }
    deleteMatrixTexture(texture) {
        this.throwIfDisposed();
        if (this.outputTexture === texture) {
            unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
            this.outputTexture = null;
        }
        callAndCheck(this.gl, () => this.gl.deleteTexture(texture));
    }
    downloadByteEncodedFloatMatrixFromOutputTexture(texture, rows, columns) {
        return this.downloadMatrixDriver(texture, () => downloadByteEncodedFloatMatrixFromOutputTexture(this.gl, rows, columns, this.textureConfig));
    }
    downloadPackedMatrixFromBuffer(buffer, batch, rows, columns, physicalRows, physicalCols) {
        return downloadPackedMatrixFromBuffer(this.gl, buffer, batch, rows, columns, physicalRows, physicalCols);
    }
    downloadFloat32MatrixFromBuffer(buffer, size) {
        return downloadFloat32MatrixFromBuffer(this.gl, buffer, size);
    }
    createBufferFromTexture(texture, rows, columns) {
        this.bindTextureToFrameBuffer(texture);
        const result = createBufferFromOutputTexture(this.gl, rows, columns);
        this.unbindTextureToFrameBuffer();
        return result;
    }
    createAndWaitForFence() {
        const fenceContext = this.createFence(this.gl);
        return this.pollFence(fenceContext);
    }
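    /*
     * createFence returns a { query, isFencePassed } pair used to detect when
     * previously submitted GPU work has finished: fenceSync/clientWaitSync
     * where the fence API is available, a disjoint timer query as a fallback,
     * and a fence that reports done immediately when neither exists.
     */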
    createFence(gl) {
        let query;
        let isFencePassed;
        if (env().getBool('WEBGL_FENCE_API_ENABLED')) {
            const gl2 = gl;
            const sync = gl2.fenceSync(gl2.SYNC_GPU_COMMANDS_COMPLETE, 0);
            gl.flush();
            isFencePassed = () => {
                const status = gl2.clientWaitSync(sync, 0, 0);
                return status === gl2.ALREADY_SIGNALED ||
                    status === gl2.CONDITION_SATISFIED;
            };
            query = sync;
        }
        else if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) {
            query = this.beginQuery();
            this.endQuery();
            isFencePassed = () => this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));
        }
        else {
            isFencePassed = () => true;
        }
        return { query, isFencePassed };
    }
    downloadMatrixFromPackedTexture(texture, physicalRows, physicalCols) {
        return this.downloadMatrixDriver(texture, () => downloadMatrixFromPackedOutputTexture(this.gl, physicalRows, physicalCols));
    }
    createProgram(fragmentShader) {
        this.throwIfDisposed();
        const gl = this.gl;
        if (this.vertexShader == null) {
            this.vertexShader = createVertexShader(gl);
        }
        const program = createProgram(gl);
        callAndCheck(gl, () => gl.attachShader(program, this.vertexShader));
        callAndCheck(gl, () => gl.attachShader(program, fragmentShader));
        linkProgram(gl, program);
        const program2 = Object.assign(program, { vao: this.createVertexArray() });
        if (this.debug) {
            validateProgram(gl, program2);
        }
        return program2;
    }
    buildVao(program) {
        this.setProgram(program);
        this.bindVertexArray(program.vao);
        const gl = this.gl;
        callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, this.indexBuffer));
        bindVertexProgramAttributeStreams(gl, program, this.vertexBuffer);
    }
    deleteProgram(program) {
        this.throwIfDisposed();
        if (program === this.program) {
            this.program = null;
        }
        if (program != null) {
            callAndCheck(this.gl, () => this.gl.deleteProgram(program));
            this.deleteVertexArray(program.vao);
        }
    }
    setProgram(program) {
        this.throwIfDisposed();
        this.program = program;
        if (this.program != null) {
            if (this.debug) {
                validateProgram(this.gl, this.program);
            }
        }
        callAndCheck(this.gl, () => this.gl.useProgram(program));
    }
    getUniformLocation(program, uniformName, shouldThrow = true) {
        this.throwIfDisposed();
        if (shouldThrow) {
            return getProgramUniformLocationOrThrow(this.gl, program, uniformName);
        }
        else {
            return getProgramUniformLocation(this.gl, program, uniformName);
        }
    }
    getAttributeLocation(program, attribute) {
        this.throwIfDisposed();
        return callAndCheck(this.gl, () => this.gl.getAttribLocation(program, attribute));
    }
    getUniformLocationNoThrow(program, uniformName) {
        this.throwIfDisposed();
        return this.gl.getUniformLocation(program, uniformName);
    }
    setInputMatrixTexture(inputMatrixTexture, uniformLocation, textureUnit) {
        this.throwIfDisposed();
        this.throwIfNoProgram();
        bindTextureToProgramUniformSampler(this.gl, inputMatrixTexture, uniformLocation, textureUnit);
    }
    setOutputMatrixTexture(outputMatrixTexture, rows, columns) {
        this.setOutputMatrixTextureDriver(outputMatrixTexture, columns, rows);
    }
    setOutputPackedMatrixTexture(outputPackedMatrixTexture, rows, columns) {
        this.throwIfDisposed();
        const [width, height] = getPackedMatrixTextureShapeWidthHeight(rows, columns);
        this.setOutputMatrixTextureDriver(outputPackedMatrixTexture, width, height);
    }
    setOutputMatrixWriteRegion(startRow, numRows, startColumn, numColumns) {
        this.setOutputMatrixWriteRegionDriver(startColumn, startRow, numColumns, numRows);
    }
    setOutputPackedMatrixWriteRegion(startRow, numRows, startColumn, numColumns) {
        throw new Error('setOutputPackedMatrixWriteRegion not implemented.');
    }
    debugValidate() {
        if (this.program != null) {
            validateProgram(this.gl, this.program);
        }
        validateFramebuffer(this.gl);
    }
    executeProgram() {
        this.throwIfDisposed();
        this.throwIfNoProgram();
        const gl = this.gl;
        if (this.debug) {
            const boundVao = this.getVertexArray();
            console.assert(boundVao === this.program.vao, 'VAO changed between setProgram and executeProgram!');
            this.debugValidate();
        }
        callAndCheck(gl, () => gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0));
    }
    blockUntilAllProgramsCompleted() {
        this.throwIfDisposed();
        callAndCheck(this.gl, () => this.gl.finish());
    }
    getQueryTimerExtension() {
        if (this.disjointQueryTimerExtension == null) {
            this.disjointQueryTimerExtension =
                getExtensionOrThrow(this.gl, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ?
                    'EXT_disjoint_timer_query_webgl2' :
                    'EXT_disjoint_timer_query');
        }
        return this.disjointQueryTimerExtension;
    }
    getQueryTimerExtensionWebGL2() {
        return this.getQueryTimerExtension();
    }
    getQueryTimerExtensionWebGL1() {
        return this.getQueryTimerExtension();
    }
    beginQuery() {
        if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
            const gl2 = this.gl;
            const ext = this.getQueryTimerExtensionWebGL2();
            const query = gl2.createQuery();
            gl2.beginQuery(ext.TIME_ELAPSED_EXT, query);
            return query;
        }
        const ext = this.getQueryTimerExtensionWebGL1();
        const query = ext.createQueryEXT();
        ext.beginQueryEXT(ext.TIME_ELAPSED_EXT, query);
        return query;
    }
    endQuery() {
        if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
            const gl2 = this.gl;
            const ext = this.getQueryTimerExtensionWebGL2();
            gl2.endQuery(ext.TIME_ELAPSED_EXT);
            return;
        }
        const ext = this.getQueryTimerExtensionWebGL1();
        ext.endQueryEXT(ext.TIME_ELAPSED_EXT);
    }
    async waitForQueryAndGetTime(query) {
        await repeatedTry(() => this.disposed ||
            this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')));
        return this.getQueryTime(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'));
    }
    getQueryTime(query, queryTimerVersion) {
        if (queryTimerVersion === 0) {
            return null;
        }
        if (queryTimerVersion === 2) {
            const gl2 = this.gl;
            const timeElapsedNanos = gl2.getQueryParameter(query, gl2.QUERY_RESULT);
            return timeElapsedNanos / 1000000;
        }
        else {
            const ext = this.getQueryTimerExtensionWebGL1();
            const timeElapsedNanos = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT);
            return timeElapsedNanos / 1000000;
        }
    }
    isQueryAvailable(query, queryTimerVersion) {
        if (queryTimerVersion === 0) {
            return true;
        }
        if (queryTimerVersion === 2) {
            const gl2 = this.gl;
            const ext = this.getQueryTimerExtensionWebGL2();
            const available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE);
            if (this.disjoint == null) {
                this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
            }
            return available && !this.disjoint;
        }
        else {
            const ext = this.getQueryTimerExtensionWebGL1();
            const available = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_AVAILABLE_EXT);
            if (this.disjoint == null) {
                this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
            }
            return available && !this.disjoint;
        }
    }
    pollFence(fenceContext) {
        return new Promise(resolve => {
            this.addItemToPoll(() => fenceContext.isFencePassed(), () => resolve());
        });
    }
    pollItems() {
        const index = linearSearchLastTrue(this.itemsToPoll.map(x => x.isDoneFn));
        for (let i = 0; i <= index; ++i) {
            const { resolveFn } = this.itemsToPoll[i];
            resolveFn();
        }
        this.itemsToPoll = this.itemsToPoll.slice(index + 1);
    }
    addItemToPoll(isDoneFn, resolveFn) {
        this.itemsToPoll.push({ isDoneFn, resolveFn });
        if (this.itemsToPoll.length > 1) {
            return;
        }
        let scheduleFn = undefined;
        if ('setTimeoutCustom' in env().platform) {
            scheduleFn = env().platform.setTimeoutCustom.bind(env().platform);
        }
        repeatedTry(() => {
            this.pollItems();
            return this.itemsToPoll.length === 0;
        }, () => 0, null, scheduleFn);
    }
    bindTextureToFrameBuffer(texture) {
        this.throwIfDisposed();
        bindColorTextureToFramebuffer(this.gl, texture, this.framebuffer);
        if (this.debug) {
            validateFramebuffer(this.gl);
        }
    }
    unbindTextureToFrameBuffer() {
        if (this.outputTexture != null) {
            bindColorTextureToFramebuffer(this.gl, this.outputTexture, this.framebuffer);
            if (this.debug) {
                validateFramebuffer(this.gl);
            }
        }
        else {
            unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
        }
    }
    downloadMatrixDriver(texture, downloadAndDecode) {
        this.bindTextureToFrameBuffer(texture);
        const result = downloadAndDecode();
        this.unbindTextureToFrameBuffer();
        return result;
    }
    setOutputMatrixTextureDriver(outputMatrixTextureMaybePacked, width, height) {
        this.throwIfDisposed();
        const gl = this.gl;
        bindColorTextureToFramebuffer(gl, outputMatrixTextureMaybePacked, this.framebuffer);
        if (this.debug) {
            validateFramebuffer(gl);
        }
        this.outputTexture = outputMatrixTextureMaybePacked;
        callAndCheck(gl, () => gl.viewport(0, 0, width, height));
        callAndCheck(gl, () => gl.scissor(0, 0, width, height));
    }
    setOutputMatrixWriteRegionDriver(x, y, width, height) {
        this.throwIfDisposed();
        callAndCheck(this.gl, () => this.gl.scissor(x, y, width, height));
    }
    throwIfDisposed() {
        if (this.disposed) {
            throw new Error('Attempted to use disposed GPGPUContext.');
        }
    }
    throwIfNoProgram() {
        if (this.program == null) {
            throw new Error('No GPU program is currently set.');
        }
    }
}
function linearSearchLastTrue(arr) {
    let i = 0;
    for (; i < arr.length; ++i) {
        const isDone = arr[i]();
        if (!isDone) {
            break;
        }
    }
    return i - 1;
}

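// Illustrative note (added comment, not part of the upstream bundle):
// linearSearchLastTrue returns the index of the last item in the leading run
// of predicates that all report true, or -1 if the first predicate is already
// false. For example:
//   linearSearchLastTrue([() => true, () => true, () => false]); // => 1
//   linearSearchLastTrue([() => false]);                         // => -1
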
function assertNotComplex(tensor, opName) {
    if (!Array.isArray(tensor)) {
        tensor = [tensor];
    }
    tensor.forEach(t => {
        if (t != null) {
            assert$1(t.dtype !== 'complex64', () => `${opName} does not support complex64 tensors in the CPU backend.`);
        }
    });
}

function simpleAbsImpl(vals) {
    const resultValues = new Float32Array(vals.length);
    for (let i = 0; i < vals.length; ++i) {
        resultValues[i] = Math.abs(vals[i]);
    }
    return resultValues;
}
const abs$1 = (args) => {
    const { x } = args.inputs;
    const cpuBackend = args.backend;
    assertNotComplex(x, 'abs');
    const values = cpuBackend.data.get(x.dataId).values;
    const resultValues = simpleAbsImpl(values);
    return cpuBackend.makeOutput(resultValues, x.shape, x.dtype);
};
const absConfig$1 = {
    kernelName: Abs,
    backendName: 'cpu',
    kernelFunc: abs$1,
};

function createSimpleBinaryKernelImpl(op) {
    return (aShape, bShape, aVals, bVals, dtype) => {
        const newShape = assertAndGetBroadcastShape(aShape, bShape);
        const resultRank = newShape.length;
        const resultStrides = computeStrides(newShape);
        const resultSize = sizeFromShape(newShape);
        const result = getTypedArrayFromDType(dtype, resultSize);
        const aRank = aShape.length;
        const bRank = bShape.length;
        const aStrides = computeStrides(aShape);
        const bStrides = computeStrides(bShape);
        const aBroadcastDims = getBroadcastDims$1(aShape, newShape);
        const bBroadcastDims = getBroadcastDims$1(bShape, newShape);
        if (aBroadcastDims.length + bBroadcastDims.length === 0) {
            for (let i = 0; i < result.length; ++i) {
                result[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
            }
        }
        else {
            for (let i = 0; i < result.length; ++i) {
                const loc = indexToLoc(i, resultRank, resultStrides);
                const aLoc = loc.slice(-aRank);
                aBroadcastDims.forEach(d => aLoc[d] = 0);
                const aIndex = locToIndex(aLoc, aRank, aStrides);
                const bLoc = loc.slice(-bRank);
                bBroadcastDims.forEach(d => bLoc[d] = 0);
                const bIndex = locToIndex(bLoc, bRank, bStrides);
                result[i] = op(aVals[aIndex], bVals[bIndex]);
            }
        }
        return [result, newShape];
    };
}

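// Illustrative sketch (added comment, not upstream code): the factory above
// turns a scalar op into a broadcasting kernel implementation. Assuming the
// usual tfjs shape conventions, e.g.:
//   const exampleAddImpl = createSimpleBinaryKernelImpl((a, b) => a + b);
//   const [vals, shape] = exampleAddImpl([2, 1], [3], Float32Array.of(1, 2),
//       Float32Array.of(10, 20, 30), 'float32');
//   // vals => [11, 21, 31, 12, 22, 32], shape => [2, 3]
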
function complex$1(args) {
    const { inputs, backend } = args;
    const { real, imag } = inputs;
    const realVals = backend.data.get(real.dataId).values;
    const imagVals = backend.data.get(imag.dataId).values;
    const complexInfo = backend.makeTensorInfo(real.shape, 'complex64');
    const complex = backend.data.get(complexInfo.dataId);
    complex.complexTensorInfos = {
        real: backend.makeTensorInfo(real.shape, 'float32', realVals),
        imag: backend.makeTensorInfo(imag.shape, 'float32', imagVals)
    };
    return complexInfo;
}
const complexConfig$1 = {
    kernelName: Complex,
    backendName: 'cpu',
    kernelFunc: complex$1
};

function zeros(backend, shape, dtype = 'float32') {
    if (dtype === 'complex64') {
        const real = zeros(backend, shape, 'float32');
        const imag = zeros(backend, shape, 'float32');
        return complex$1({ inputs: { real, imag }, backend });
    }
    const values = makeZerosTypedArray(sizeFromShape(shape), dtype);
    return backend.makeTensorInfo(shape, dtype, values);
}

function identity$1(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    backend.incRef(x.dataId);
    return { dataId: x.dataId, shape: x.shape, dtype: x.dtype };
}
const identityConfig$1 = {
    kernelName: Identity$1,
    backendName: 'cpu',
    kernelFunc: identity$1
};

function real$1(args) {
    const { inputs, backend } = args;
    const { input } = inputs;
    const real = backend.data.get(input.dataId).complexTensorInfos.real;
    const realVal = backend.data.get(real.dataId).values;
    return backend.makeTensorInfo(real.shape, real.dtype, realVal);
}
const realConfig$1 = {
    kernelName: Real,
    backendName: 'cpu',
    kernelFunc: real$1
};

function castImpl(values, shape, inputType, dtype) {
    if (dtype === 'int32') {
        const resultValues = Int32Array.from(values);
        return [shape, 'int32', resultValues];
    }
    if (dtype === 'bool') {
        const zero = toTypedArray([0], inputType);
        const [resultData, resultShape] = createSimpleBinaryKernelImpl((a, b) => (a !== b) ? 1 : 0)(shape, [], values, zero, 'bool');
        return [resultShape, 'bool', resultData];
    }
    throw new Error(`Error in Cast: failed to cast ${inputType} to ${dtype}`);
}
function cast$2(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { dtype } = attrs;
    if (dtype === 'complex64') {
        if (x.dtype === 'complex64') {
            return identity$1({ inputs: { x }, backend });
        }
        const zerosTensorInfo = zeros(backend, x.shape, x.dtype);
        const floatX = cast$2({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
        const result = complex$1({ inputs: { real: floatX, imag: zerosTensorInfo }, backend });
        backend.disposeIntermediateTensorInfo(zerosTensorInfo);
        backend.disposeIntermediateTensorInfo(floatX);
        return result;
    }
    if (x.dtype === 'complex64') {
        const realPart = real$1({ inputs: { input: x }, backend });
        const result = cast$2({ inputs: { x: realPart }, backend, attrs: { dtype } });
        backend.disposeIntermediateTensorInfo(realPart);
        return result;
    }
    if (!hasEncodingLoss(x.dtype, dtype)) {
        const result = identity$1({ inputs: { x }, backend });
        return { dataId: result.dataId, shape: result.shape, dtype };
    }
    const values = backend.data.get(x.dataId).values;
    const [resultShape, resultType, resultData] = castImpl(values, x.shape, x.dtype, dtype);
    return backend.makeTensorInfo(resultShape, resultType, resultData);
}
const castConfig$1 = {
    kernelName: Cast,
    backendName: 'cpu',
    kernelFunc: cast$2
};

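// Illustrative note (added comment, not upstream code): castImpl only handles
// the lossy conversions; cast$2 routes everything else (complex round-trips,
// lossless casts via identity). Casting to 'bool' compares each value against
// zero, e.g.:
//   castImpl(Float32Array.of(0, 2.5, -1), [3], 'float32', 'bool');
//   // => [[3], 'bool', [0, 1, 1]]
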
function binaryKernelFunc$1(name, simpleImpl, complexImpl, dtype) {
    if (complexImpl == null) {
        return ({ inputs, backend }) => {
            const { a, b } = inputs;
            const cpuBackend = backend;
            assertNotComplex([a, b], name);
            const aVals = cpuBackend.data.get(a.dataId).values;
            const bVals = cpuBackend.data.get(b.dataId).values;
            const decodedAVals = a.dtype === 'string' ?
                fromUint8ToStringArray(aVals) :
                aVals;
            const decodedBVals = a.dtype === 'string' ?
                fromUint8ToStringArray(bVals) :
                bVals;
            const $dtype = dtype || a.dtype;
            const [resultData, resultShape] = simpleImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype);
            return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
        };
    }
    return ({ inputs, backend }) => {
        const { a, b } = inputs;
        const cpuBackend = backend;
        if (a.dtype === 'complex64' || b.dtype === 'complex64') {
            const $aComplex = cast$2({ inputs: { x: a }, backend: cpuBackend, attrs: { dtype: 'complex64' } });
            const $aComplexVals = cpuBackend.data.get($aComplex.dataId);
            const aReal = $aComplexVals.complexTensorInfos.real;
            const aImag = $aComplexVals.complexTensorInfos.imag;
            const aRealVals = cpuBackend.data.get(aReal.dataId).values;
            const aImagVals = cpuBackend.data.get(aImag.dataId).values;
            const $bComplex = cast$2({ inputs: { x: b }, backend: cpuBackend, attrs: { dtype: 'complex64' } });
            const $bComplexVals = cpuBackend.data.get($bComplex.dataId);
            const bReal = $bComplexVals.complexTensorInfos.real;
            const bImag = $bComplexVals.complexTensorInfos.imag;
            const bRealVals = cpuBackend.data.get(bReal.dataId).values;
            const bImagVals = cpuBackend.data.get(bImag.dataId).values;
            const [resultRealData, resultImagData, resultShape] = complexImpl(a.shape, b.shape, aRealVals, aImagVals, bRealVals, bImagVals);
            const resultReal = cpuBackend.makeTensorInfo(resultShape, 'float32', resultRealData);
            const resultImag = cpuBackend.makeTensorInfo(resultShape, 'float32', resultImagData);
            const result = complex$1({ inputs: { real: resultReal, imag: resultImag }, backend: cpuBackend });
            cpuBackend.disposeIntermediateTensorInfo($aComplex);
            cpuBackend.disposeIntermediateTensorInfo($bComplex);
            cpuBackend.disposeIntermediateTensorInfo(resultReal);
            cpuBackend.disposeIntermediateTensorInfo(resultImag);
            return result;
        }
        else {
            const aVals = cpuBackend.data.get(a.dataId).values;
            const bVals = cpuBackend.data.get(b.dataId).values;
            const $dtype = dtype || a.dtype;
            const [resultData, resultShape] = simpleImpl(a.shape, b.shape, aVals, bVals, $dtype);
            return cpuBackend.makeTensorInfo(resultShape, $dtype, resultData);
        }
    };
}

function createComplexBinaryKernelImpl(op) {
    return (aShape, bShape, aRealVals, aImagVals, bRealVals, bImagVals) => {
        const resultShape = assertAndGetBroadcastShape(aShape, bShape);
        const resultSize = sizeFromShape(resultShape);
        const resultRank = resultShape.length;
        const resultStrides = computeStrides(resultShape);
        const resultRealVals = getTypedArrayFromDType('float32', resultSize);
        const resultImagVals = getTypedArrayFromDType('float32', resultSize);
        const aBroadcastDims = getBroadcastDims$1(aShape, resultShape);
        const bBroadcastDims = getBroadcastDims$1(bShape, resultShape);
        const aVals = mergeRealAndImagArrays(aRealVals, aImagVals);
        const bVals = mergeRealAndImagArrays(bRealVals, bImagVals);
        const aRank = aShape.length;
        const aStrides = computeStrides(aShape);
        const bRank = bShape.length;
        const bStrides = computeStrides(bShape);
        if (aBroadcastDims.length + bBroadcastDims.length === 0) {
            for (let i = 0; i < resultRealVals.length; i++) {
                const aIdx = i % aVals.length;
                const bIdx = i % bVals.length;
                const result = op(aVals[aIdx * 2], aVals[aIdx * 2 + 1], bVals[bIdx * 2], bVals[bIdx * 2 + 1]);
                resultRealVals[i] = result.real;
                resultImagVals[i] = result.imag;
            }
        }
        else {
            for (let i = 0; i < resultRealVals.length; i++) {
                const loc = indexToLoc(i, resultRank, resultStrides);
                const aLoc = loc.slice(-aRank);
                aBroadcastDims.forEach(d => aLoc[d] = 0);
                const aIndex = locToIndex(aLoc, aRank, aStrides);
                const bLoc = loc.slice(-bRank);
                bBroadcastDims.forEach(d => bLoc[d] = 0);
                const bIndex = locToIndex(bLoc, bRank, bStrides);
                const opResult = op(aVals[aIndex * 2], aVals[aIndex * 2 + 1], bVals[bIndex * 2], bVals[bIndex * 2 + 1]);
                resultRealVals[i] = opResult.real;
                resultImagVals[i] = opResult.imag;
            }
        }
        return [resultRealVals, resultImagVals, resultShape];
    };
}

const addImpl = createSimpleBinaryKernelImpl(((a, b) => a + b));
const addComplexImpl = createComplexBinaryKernelImpl(((aReal, aImag, bReal, bImag) => {
    return { real: aReal + bReal, imag: aImag + bImag };
}));
const add = binaryKernelFunc$1(Add, addImpl, addComplexImpl);
const addConfig$1 = {
    kernelName: Add,
    backendName: 'cpu',
    kernelFunc: add
};

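// Illustrative note (added comment, not upstream code): the complex variant
// receives the real and imaginary planes separately and combines them
// element-wise, so (1 + 2i) + (3 + 4i) is computed as:
//   addComplexImpl([1], [1], [1], [2], [3], [4]);
//   // => [[4], [6], [1]]  (real parts, imaginary parts, shape)
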
function bincountImpl(xVals, weightsVals, weightsDtype, weightsShape, size) {
    const weightsSize = sizeFromShape(weightsShape);
    const outVals = makeZerosTypedArray(size, weightsDtype);
    for (let i = 0; i < xVals.length; i++) {
        const value = xVals[i];
        if (value < 0) {
            throw new Error('Input x must be non-negative!');
        }
        if (value >= size) {
            continue;
        }
        if (weightsSize > 0) {
            outVals[value] += weightsVals[i];
        }
        else {
            outVals[value] += 1;
        }
    }
    return outVals;
}
function bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput = false) {
    const numRows = xBuf.shape[0];
    const numCols = xBuf.shape[1];
    const outBuf = buffer([numRows, size], weightsBuf.dtype);
    for (let i = 0; i < numRows; i++) {
        for (let j = 0; j < numCols; j++) {
            const value = xBuf.get(i, j);
            if (value < 0) {
                throw new Error('Input x must be non-negative!');
            }
            if (value >= size) {
                continue;
            }
            if (binaryOutput) {
                outBuf.set(1, i, value);
            }
            else {
                if (weightsBuf.size > 0) {
                    outBuf.set(outBuf.get(i, value) + weightsBuf.get(i, j), i, value);
                }
                else {
                    outBuf.set(outBuf.get(i, value) + 1, i, value);
                }
            }
        }
    }
    return outBuf;
}

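// Illustrative note (added comment, not upstream code): with no weights,
// bincountImpl counts occurrences of each value below `size`; with weights it
// sums them instead. Values at or above `size` are silently dropped, e.g.:
//   bincountImpl([1, 1, 3, 7], [], 'float32', [0], 4); // => [0, 2, 0, 1]
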
const bitwiseAndImpl = createSimpleBinaryKernelImpl(((a, b) => a & b));
const bitwiseAnd$1 = binaryKernelFunc$1(BitwiseAnd, bitwiseAndImpl);
const bitwiseAndConfig$1 = {
    kernelName: BitwiseAnd,
    backendName: 'cpu',
    kernelFunc: bitwiseAnd$1
};

function createSimpleUnaryImpl(op) {
    return (values, dtype, attrs) => {
        const newValues = getArrayFromDType(dtype, values.length);
        for (let i = 0; i < values.length; ++i) {
            newValues[i] = op(values[i], attrs);
        }
        return newValues;
    };
}

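// Illustrative note (added comment, not upstream code): createSimpleUnaryImpl
// maps a scalar function over a flat values array, e.g.:
//   const exampleSquareImpl = createSimpleUnaryImpl((xi) => xi * xi);
//   exampleSquareImpl(Float32Array.of(1, 2, 3), 'float32'); // => [1, 4, 9]
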
function unaryKernelFunc$1(name, op, dtype) {
    const impl = createSimpleUnaryImpl(op);
    return unaryKernelFuncFromImpl(name, impl, dtype);
}

function unaryKernelFuncFromImpl(name, unaryImpl, dtype) {
    return ({ inputs, attrs, backend }) => {
        const { x } = inputs;
        assertNotComplex(x, name);
        const cpuBackend = backend;
        const values = cpuBackend.data.get(x.dataId).values;
        let decoded;
        if (x.dtype === 'string') {
            if (!Array.isArray(values)) {
                throw new Error('String tensor\'s value was not an instance of Array');
            }
            decoded = fromUint8ToStringArray(values);
        }
        else {
            decoded = values;
        }
        const $dtype = dtype || x.dtype;
        const newValues = unaryImpl(decoded, $dtype, attrs);
        return cpuBackend.makeTensorInfo(x.shape, $dtype, newValues);
    };
}

const ceilImpl = createSimpleUnaryImpl((xi) => Math.ceil(xi));
const ceil$1 = unaryKernelFuncFromImpl(Ceil, ceilImpl);
const ceilConfig$1 = {
    kernelName: Ceil,
    backendName: 'cpu',
    kernelFunc: ceil$1,
};

function concatImpl$1(inputs, outShape, dtype, simplyConcat) {
    const outVals = getArrayFromDType(dtype, sizeFromShape(outShape));
    if (simplyConcat && dtype !== 'string') {
        let offset = 0;
        inputs.forEach(input => {
            const size = sizeFromShape(input.shape);
            outVals.set(input.vals, offset);
            offset += size;
        });
    }
    else {
        let colOffset = 0;
        inputs.forEach(input => {
            const decodedData = dtype === 'string' ?
                fromUint8ToStringArray(input.vals) :
                input.vals;
            let tIdx = 0;
            for (let row = 0; row < input.shape[0]; ++row) {
                const resIdx = row * outShape[1] + colOffset;
                for (let col = 0; col < input.shape[1]; ++col) {
                    outVals[resIdx + col] = decodedData[tIdx++];
                }
            }
            colOffset += input.shape[1];
        });
    }
    return outVals;
}

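// Illustrative note (added comment, not upstream code): when the inputs cannot
// simply be appended, they are treated as 2-D blocks laid out side by side
// along the second axis, e.g. a [2,1] and a [2,2] block into a [2,3] output:
//   concatImpl$1([
//     { vals: Float32Array.of(1, 2), shape: [2, 1] },
//     { vals: Float32Array.of(3, 4, 5, 6), shape: [2, 2] },
//   ], [2, 3], 'float32', false);
//   // => [1, 3, 4, 2, 5, 6]
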
const equalImpl = createSimpleBinaryKernelImpl((a, b) => (a === b) ? 1 : 0);
const equal$1 = binaryKernelFunc$1(Equal, equalImpl, null, 'bool');
const equalConfig$1 = {
    kernelName: Equal,
    backendName: 'cpu',
    kernelFunc: equal$1
};

const expImpl = createSimpleUnaryImpl((xi) => Math.exp(xi));
const exp$1 = unaryKernelFuncFromImpl(Exp, expImpl, 'float32');
const expConfig$1 = {
    kernelName: Exp,
    backendName: 'cpu',
    kernelFunc: exp$1,
};

const expm1Impl = createSimpleUnaryImpl((xi) => Math.expm1(xi));
const expm1$1 = unaryKernelFuncFromImpl(Expm1, expm1Impl);
const expm1Config$1 = {
    kernelName: Expm1,
    backendName: 'cpu',
    kernelFunc: expm1$1,
};

const floorImpl = createSimpleUnaryImpl((xi) => Math.floor(xi));
const floor$1 = unaryKernelFuncFromImpl(Floor, floorImpl);
const floorConfig$1 = {
    kernelName: Floor,
    backendName: 'cpu',
    kernelFunc: floor$1,
};

const floorDivImpl = createSimpleBinaryKernelImpl((a, b) => Math.floor(a / b));
const floorDiv$1 = binaryKernelFunc$1(FloorDiv, floorDivImpl, null, 'int32');
const floorDivConfig$1 = {
    kernelName: FloorDiv,
    backendName: 'cpu',
    kernelFunc: floorDiv$1
};

function gatherNdImpl(indicesData, paramsBuf, dtype, numSlices, sliceRank, sliceSize, strides, paramsShape, paramsSize) {
    const outBuf = buffer([numSlices, sliceSize], dtype);
    for (let i = 0; i < numSlices; i++) {
        const index = [];
        let flattenIndex = 0;
        for (let j = 0; j < sliceRank; j++) {
            const dim = indicesData[i * sliceRank + j];
            flattenIndex += dim * strides[j];
            index.push(dim);
        }
        if (flattenIndex < 0 || flattenIndex >= paramsSize / sliceSize) {
            throw new Error(`Invalid indices: ${index} does not index into ${paramsShape}`);
        }
        for (let k = 0; k < sliceSize; k++) {
            outBuf.values[i * sliceSize + k] =
                paramsBuf.get(...paramsBuf.indexToLoc(flattenIndex * sliceSize + k));
        }
    }
    return outBuf;
}

function gatherV2Impl(xBuf, indicesBuf, flattenOutputShape) {
    const outBuf = buffer(flattenOutputShape, xBuf.dtype);
    for (let i = 0; i < outBuf.size; ++i) {
        const newLoc = outBuf.indexToLoc(i);
        const originalLoc = newLoc.slice();
        const batchIdx = originalLoc[0];
        const indicesIdx = originalLoc[2];
        const indicesIndex = indicesBuf.locToIndex([batchIdx, indicesIdx]);
        originalLoc[2] = indicesBuf.values[indicesIndex];
        const originalIndex = xBuf.locToIndex(originalLoc);
        if (0 <= originalIndex && originalIndex < xBuf.values.length) {
            outBuf.values[i] = xBuf.values[originalIndex];
        }
    }
    return outBuf;
}

const greaterImpl = createSimpleBinaryKernelImpl((a, b) => (a > b) ? 1 : 0);
const greater$1 = binaryKernelFunc$1(Greater, greaterImpl, null, 'bool');
const greaterConfig$1 = {
    kernelName: Greater,
    backendName: 'cpu',
    kernelFunc: greater$1
};

const greaterEqualImpl = createSimpleBinaryKernelImpl((a, b) => (a >= b) ? 1 : 0);
const greaterEqual$1 = binaryKernelFunc$1(GreaterEqual, greaterEqualImpl, null, 'bool');
const greaterEqualConfig$1 = {
    kernelName: GreaterEqual,
    backendName: 'cpu',
    kernelFunc: greaterEqual$1
};

const lessImpl = createSimpleBinaryKernelImpl((a, b) => (a < b) ? 1 : 0);
const less$1 = binaryKernelFunc$1(Less, lessImpl, null, 'bool');
const lessConfig$1 = {
    kernelName: Less,
    backendName: 'cpu',
    kernelFunc: less$1
};

const lessEqualImpl = createSimpleBinaryKernelImpl((a, b) => (a <= b) ? 1 : 0);
const lessEqual$1 = binaryKernelFunc$1(LessEqual, lessEqualImpl, null, 'bool');
const lessEqualConfig$1 = {
    kernelName: LessEqual,
    backendName: 'cpu',
    kernelFunc: lessEqual$1
};

function linSpaceImpl(start, stop, num) {
    const step = (stop - start) / (num - 1);
    const values = makeZerosTypedArray(num, 'float32');
    values[0] = start;
    for (let i = 1; i < values.length; i++) {
        values[i] = values[i - 1] + step;
    }
    return values;
}

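// Illustrative note (added comment, not upstream code): linSpaceImpl produces
// `num` evenly spaced samples from start to stop inclusive, e.g.:
//   linSpaceImpl(0, 1, 5); // => [0, 0.25, 0.5, 0.75, 1]
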
const logImpl = createSimpleUnaryImpl((xi) => Math.log(xi));
const log$1 = unaryKernelFuncFromImpl(Log, logImpl);
const logConfig$1 = {
    kernelName: Log,
    backendName: 'cpu',
    kernelFunc: log$1,
};

function maxImpl$1(aVals, reduceSize, outShape, dtype) {
    const vals = getTypedArrayFromDType(dtype, sizeFromShape(outShape));
    for (let i = 0; i < vals.length; ++i) {
        const offset = i * reduceSize;
        let max = aVals[offset];
        for (let j = 0; j < reduceSize; ++j) {
            const value = aVals[offset + j];
            if (Number.isNaN(value) || value > max) {
                max = value;
            }
        }
        vals[i] = max;
    }
    return vals;
}

const maximumImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => Math.max(aValue, bValue)));
const maximum$1 = binaryKernelFunc$1(Maximum, maximumImpl);
const maximumConfig$1 = {
    kernelName: Maximum,
    backendName: 'cpu',
    kernelFunc: maximum$1
};

const minimumImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => Math.min(aValue, bValue)));
const minimum$1 = binaryKernelFunc$1(Minimum, minimumImpl);
const minimumConfig$1 = {
    kernelName: Minimum,
    backendName: 'cpu',
    kernelFunc: minimum$1
};

const multiplyImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => aValue * bValue));
const multiplyComplexImpl = createComplexBinaryKernelImpl(((aReal, aImag, bReal, bImag) => {
    return {
        real: aReal * bReal - aImag * bImag,
        imag: aReal * bImag + aImag * bReal
    };
}));
const multiply$1 = binaryKernelFunc$1(Multiply, multiplyImpl, multiplyComplexImpl);
const multiplyConfig$1 = {
    kernelName: Multiply,
    backendName: 'cpu',
    kernelFunc: multiply$1
};

function negImpl(xVals, xShape, xDtype) {
    const minusOne = createScalarValue(-1, xDtype);
    return multiplyImpl([], xShape, minusOne, xVals, xDtype);
}
function neg$1(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    assertNotComplex(x, 'neg');
    const xVals = backend.data.get(x.dataId).values;
    const [res, newShape] = negImpl(xVals, x.shape, x.dtype);
    return backend.makeTensorInfo(newShape, x.dtype, res);
}
const negConfig$1 = {
    kernelName: Neg,
    backendName: 'cpu',
    kernelFunc: neg$1
};

const notEqualImpl = createSimpleBinaryKernelImpl(((a, b) => (a !== b) ? 1 : 0));
const notEqual$1 = binaryKernelFunc$1(NotEqual, notEqualImpl, null, 'bool');
const notEqualConfig$1 = {
    kernelName: NotEqual,
    backendName: 'cpu',
    kernelFunc: notEqual$1
};

function transposeImpl$1(xVals, xShape, dtype, perm, newShape) {
    const xRank = xShape.length;
    const xSize = sizeFromShape(xShape);
    const xStrides = computeStrides(xShape);
    const newStrides = computeStrides(newShape);
    const result = getTypedArrayFromDType(dtype, sizeFromShape(newShape));
    for (let i = 0; i < xSize; ++i) {
        const loc = indexToLoc(i, xRank, xStrides);
        const newLoc = new Array(loc.length);
        for (let j = 0; j < newLoc.length; j++) {
            newLoc[j] = loc[perm[j]];
        }
        const newIndex = locToIndex(newLoc, xRank, newStrides);
        result[newIndex] = xVals[i];
    }
    return result;
}

function transpose$1(args) {
    const { inputs, attrs, backend } = args;
    const { x } = inputs;
    const { perm } = attrs;
    assertNotComplex(x, 'transpose');
    const xRank = x.shape.length;
    const newShape = new Array(xRank);
    for (let i = 0; i < newShape.length; i++) {
        newShape[i] = x.shape[perm[i]];
    }
    const values = backend.data.get(x.dataId).values;
    const result = transposeImpl$1(values, x.shape, x.dtype, perm, newShape);
    const dataId = backend.write(result, newShape, x.dtype);
    return { dataId, shape: newShape, dtype: x.dtype };
}
const transposeConfig$1 = {
    kernelName: Transpose,
    backendName: 'cpu',
    kernelFunc: transpose$1
};

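// Illustrative note (added comment, not upstream code): transposeImpl$1
// permutes the axes by scattering each input element to its permuted
// location, e.g. a 2x3 matrix transposed with perm [1, 0]:
//   transposeImpl$1(Float32Array.of(1, 2, 3, 4, 5, 6), [2, 3], 'float32',
//       [1, 0], [3, 2]);
//   // => [1, 4, 2, 5, 3, 6]
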
function prodImpl(xShape, xDtype, xVals, reductionAxes) {
    const [outShape, reduceShape] = computeOutAndReduceShapes(xShape, reductionAxes);
    const outDtype = upcastType(xDtype, 'int32');
    const outVals = makeZerosTypedArray(sizeFromShape(outShape), outDtype);
    const reduceSize = sizeFromShape(reduceShape);
    for (let i = 0; i < outVals.length; ++i) {
        const offset = i * reduceSize;
        let prod = 1;
        for (let j = 0; j < reduceSize; ++j) {
            prod *= xVals[offset + j];
        }
        outVals[i] = prod;
    }
    return { outVals, outShape, outDtype };
}
function prod$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex(x, 'prod');
    const xRank = x.shape.length;
    const axes = parseAxisParam(axis, x.shape);
    const permutation = getAxesPermutation(axes, xRank);
    let reductionAxes = axes;
    let permutedX = x;
    const intermediateTensorInfos = [];
    if (permutation != null) {
        permutedX = transpose$1({ inputs: { x }, backend, attrs: { perm: permutation } });
        intermediateTensorInfos.push(permutedX);
        reductionAxes = getInnerMostAxes(reductionAxes.length, xRank);
    }
    const xVals = backend.data.get(permutedX.dataId).values;
    const { outVals, outShape, outDtype } = prodImpl(permutedX.shape, permutedX.dtype, xVals, reductionAxes);
    let resultShape = outShape;
    if (keepDims) {
        resultShape = expandShapeToKeepDim(outShape, axes);
    }
    intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return backend.makeTensorInfo(resultShape, outDtype, outVals);
}
const prodConfig$1 = {
    kernelName: Prod,
    backendName: 'cpu',
    kernelFunc: prod$1
};

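// Illustrative note (added comment, not upstream code): prodImpl reduces over
// the innermost axes after any permutation has been applied, e.g. row
// products of a 2x2 matrix:
//   prodImpl([2, 2], 'float32', Float32Array.of(1, 2, 3, 4), [1]);
//   // => { outVals: [2, 12], outShape: [2], outDtype: 'float32' }
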
function validateIndices(indices, indicesShape, numParams) {
    indices.forEach((index, i) => {
        if (index < 0 || index >= numParams) {
            const locString = indexToLoc(i, indicesShape.length, computeStrides(indicesShape))
                .join(',');
            throw new Error(`indices[${locString}] = ${index} is not in [0, ${numParams})`);
        }
    });
}
function validateSplits(paramsNestedSplits, numParamsDenseValues) {
    for (let dim = 0; dim < paramsNestedSplits.length; ++dim) {
        const splits = paramsNestedSplits[dim];
        const lastSplit = (dim === paramsNestedSplits.length - 1) ?
            numParamsDenseValues :
            paramsNestedSplits[dim + 1].length;
        if (splits.length === 0) {
            throw new Error('Ragged splits may not be empty');
        }
        if (splits[0] < 0) {
            throw new Error('Ragged splits must be non-negative');
        }
        if (splits[splits.length - 1] > lastSplit) {
            throw new Error('Ragged splits must not point past values');
        }
        for (let i = 1; i < splits.length; ++i) {
            if (splits[i - 1] > splits[i]) {
                throw new Error('Ragged splits must be sorted in ascending order');
            }
        }
    }
}

function makeSplits(indices, indicesShape, paramsNestedSplits, numParamsDenseValues) {
    const valueSlices = [];
    let numValues = 0;
    const numSplits = indicesShape.length - 1 + paramsNestedSplits.length;
    const outSplits = new Array(numSplits).fill(null).map(() => [0]);
    validateSplits(paramsNestedSplits, numParamsDenseValues);
    let nrows = 1;
    for (let dim = 0; dim < indicesShape.length - 1; ++dim) {
        nrows *= indicesShape[dim];
        const rowLength = indicesShape[dim + 1];
        for (let i = 1; i < nrows + 1; ++i) {
            outSplits[dim].push(i * rowLength);
        }
    }
    for (let i = 0; i < indices.length; ++i) {
        let start = indices[i];
        let limit = indices[i] + 1;
        for (let dim = 0; dim < paramsNestedSplits.length; ++dim) {
            const splits = paramsNestedSplits[dim];
            const outDim = dim + indicesShape.length - 1;
            if (outDim >= 0) {
                const outSplitsOutDim = outSplits[outDim];
                const delta = outSplitsOutDim[outSplitsOutDim.length - 1] - splits[start];
                for (let j = start; j < limit; ++j) {
                    outSplits[outDim].push(splits[j + 1] + delta);
                }
            }
            start = splits[start];
            limit = splits[limit];
        }
        if (limit !== start) {
            valueSlices.push([start, limit]);
            numValues += limit - start;
        }
    }
    return { outSplits, valueSlices, numValues };
}
function getSplits(outSplits) {
    const splitsOut = [];
    for (let i = 0; i < outSplits.length; ++i) {
        const numSplits = outSplits[i].length;
        const splits = getArrayFromDType('int32', numSplits);
        splitsOut.push(splits);
        outSplits[i].forEach((value, j) => splits[j] = value);
    }
    return splitsOut;
}
function computeFlatOuterDims(orig, numOutDims) {
    const outDims = orig.slice(0, numOutDims);
    while (outDims.length < numOutDims) {
        outDims.push(1);
    }
    for (let inDim = numOutDims; inDim < orig.length; inDim++) {
        outDims[numOutDims - 1] *= orig[inDim];
    }
    return outDims;
}
function writeValueSlices(paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, values, valuesShape) {
    const denseM = computeFlatOuterDims(paramsDenseValuesShape, 2)[1];
    const valuesM = computeFlatOuterDims(valuesShape, 2)[1];
    let outPos = 0;
    for (const slice of valueSlices) {
        for (let i = slice[0]; i < slice[1]; ++i) {
            for (let j = 0; j < valueSize; ++j) {
                values[outPos * valuesM + j] = paramsDenseValues[i * denseM + j];
            }
            ++outPos;
        }
    }
}
function getValues(paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, valueSlices, numValues) {
    const valuesShape = paramsDenseValuesShape.slice();
    valuesShape[0] = numValues;
    const valuesOut = getArrayFromDType(paramsDenseValuesDType, sizeFromShape(valuesShape));
    const numElements = paramsDenseValues.length;
    const valueSize = numElements === 0 ? 0 : (numElements / paramsDenseValuesShape[0]);
    writeValueSlices(paramsDenseValues, paramsDenseValuesShape, valueSlices, valueSize, valuesOut, valuesShape);
    return [valuesOut, valuesShape];
}

function raggedGatherImpl(paramsNestedSplits, paramsNestedSplitsShapes, paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, indices, indicesShape, outputRaggedRank) {
    if (paramsNestedSplits.length === 0) {
        throw new Error('paramsNestedSplits must be non empty');
    }
    if (paramsNestedSplitsShapes[0].length === 0) {
        throw new Error('Split tensors must not be scalars');
    }
    const numParams = paramsNestedSplitsShapes[0][0] - 1;
    validateIndices(indices, indicesShape, numParams);
    if (paramsDenseValuesShape.length === 0) {
        throw new Error('params.rank must be nonzero');
    }
    const numParamsDenseValues = paramsDenseValuesShape[0];
    const { outSplits, valueSlices, numValues } = makeSplits(indices, indicesShape, paramsNestedSplits, numParamsDenseValues);
    const outputNestedSplits = getSplits(outSplits);
    const outputDenseValues = getValues(paramsDenseValues, paramsDenseValuesShape, paramsDenseValuesDType, valueSlices, numValues);
    return [outputNestedSplits, outputDenseValues[0], outputDenseValues[1]];
}

const INT32_MAX = 2147483647;
function raggedRangeImpl(starts, startsShape, startsDType, limits, limitsShape, deltas, deltasShape) {
    if (startsShape.length > 1) {
        throw new Error('starts must be a scalar or vector');
    }
    if (limitsShape.length > 1) {
        throw new Error('limits must be a scalar or vector');
    }
    if (deltasShape.length > 1) {
        throw new Error('deltas must be a scalar or vector');
    }
    const broadcastStarts = startsShape.length === 0;
    const broadcastLimits = limitsShape.length === 0;
    const broadcastDeltas = deltasShape.length === 0;
    const inSizes = [];
    if (!broadcastStarts) {
        inSizes.push(startsShape[0]);
    }
    if (!broadcastLimits) {
        inSizes.push(limitsShape[0]);
    }
    if (!broadcastDeltas) {
        inSizes.push(deltasShape[0]);
    }
    for (let i = 1; i < inSizes.length; ++i) {
        if (inSizes[i] !== inSizes[i - 1]) {
            throw new Error('starts, limits, and deltas must have the same shape');
        }
    }
    const nRows = inSizes.length === 0 ? 1 : inSizes[0];
    const rtNestedSplits = getArrayFromDType('int32', nRows + 1);
    rtNestedSplits[0] = 0;
    for (let row = 0; row < nRows; ++row) {
        const start = broadcastStarts ? starts[0] : starts[row];
        const limit = broadcastLimits ? limits[0] : limits[row];
        const delta = broadcastDeltas ? deltas[0] : deltas[row];
        if (delta === 0) {
            throw new Error('Requires delta != 0');
        }
        let size;
        if (((delta > 0) && (limit < start)) || ((delta < 0) && (limit > start))) {
            size = 0;
        }
        else {
            size = Math.ceil(Math.abs((limit - start) / delta));
            if (size > INT32_MAX) {
                throw new Error(`Requires ((limit - start) / delta) <= ${INT32_MAX}`);
            }
        }
        rtNestedSplits[row + 1] = rtNestedSplits[row] + size;
    }
    const nVals = rtNestedSplits[nRows];
    const rtDenseValues = getArrayFromDType(startsDType, nVals);
    let valueIndex = 0;
    for (let row = 0; row < nRows; ++row) {
        const rowSize = rtNestedSplits[row + 1] - rtNestedSplits[row];
        let value = broadcastStarts ? starts[0] : starts[row];
        const delta = broadcastDeltas ? deltas[0] : deltas[row];
        for (let i = 0; i < rowSize; ++i) {
            rtDenseValues[valueIndex++] = value;
            value += delta;
        }
    }
    return [rtNestedSplits, rtDenseValues];
}

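// Illustrative note (added comment, not upstream code): raggedRangeImpl
// returns a row-splits vector plus the flattened values for per-row ranges,
// e.g. the ranges [0, 3) and [5, 7) with a scalar delta of 1:
//   raggedRangeImpl([0, 5], [2], 'float32', [3, 7], [2], [1], []);
//   // => splits [0, 3, 5], values [0, 1, 2, 5, 6]
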
var RowPartitionType = RowPartitionType$1;

class RaggedTensorToTensorOp {
    constructor(shape, shapeShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypeStrings) {
        this.shape = shape;
        this.shapeShape = shapeShape;
        this.values = values;
        this.valuesShape = valuesShape;
        this.valuesDType = valuesDType;
        this.defaultValue = defaultValue;
        this.defaultValueShape = defaultValueShape;
        this.rowPartitionValues = rowPartitionValues;
        this.rowPartitionValuesShapes = rowPartitionValuesShapes;
        this.rowPartitionTypes =
            getRowPartitionTypesHelper(rowPartitionTypeStrings);
        this.raggedRank = getRaggedRank(this.rowPartitionTypes);
    }
    getRowPartitionTypeByDimension(dimension) {
        if (this.rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) {
            return this.rowPartitionTypes[dimension + 1];
        }
        else {
            return this.rowPartitionTypes[dimension];
        }
    }
    getRowPartitionTensor(dimension) {
        if (this.rowPartitionTypes[0] === RowPartitionType.FIRST_DIM_SIZE) {
            return this.rowPartitionValues[dimension + 1];
        }
        else {
            return this.rowPartitionValues[dimension];
        }
    }
    getMaxWidth(dimension) {
        const rowPartitionTensor = this.getRowPartitionTensor(dimension - 1);
        switch (this.getRowPartitionTypeByDimension(dimension - 1)) {
            case RowPartitionType.VALUE_ROWIDS:
                return RaggedTensorToTensorOp.getMaxWidthValueRowID(rowPartitionTensor);
            case RowPartitionType.ROW_SPLITS:
                return RaggedTensorToTensorOp.getMaxWidthRowSplit(rowPartitionTensor);
            default:
                throw new Error(`Cannot handle partition type ${RowPartitionType[this.getRowPartitionTypeByDimension(dimension - 1)]}`);
        }
    }
    static getMaxWidthRowSplit(rowSplit) {
        const tensorLength = rowSplit.length;
        if (tensorLength === 0 || tensorLength === 1) {
            return 0;
        }
        let maxWidth = 0;
        for (let i = 0; i < tensorLength - 1; ++i) {
            const currentWidth = rowSplit[i + 1] - rowSplit[i];
            if (currentWidth > maxWidth) {
                maxWidth = currentWidth;
            }
        }
        return maxWidth;
    }
    static getMaxWidthValueRowID(valueRowIds) {
        const indexLength = valueRowIds.length;
        if (indexLength === 0) {
            return 0;
        }
        let firstEqualIndex = 0;
        let firstEqualIndexValue = valueRowIds[0];
        let maxWidth = 0;
        for (let i = 1; i < indexLength; ++i) {
            const value = valueRowIds[i];
            if (value !== firstEqualIndexValue) {
                firstEqualIndexValue = value;
                maxWidth = Math.max(i - firstEqualIndex, maxWidth);
                firstEqualIndex = i;
            }
        }
        return Math.max(indexLength - firstEqualIndex, maxWidth);
    }
    tensorShapeFromTensor(t, tShape, isPartial = true) {
        if (tShape.length === 0) {
            if (t[0] === -1) {
                return [];
            }
            throw new Error(`The only valid scalar shape tensor is the fully unknown shape specified as -1.`);
        }
        return makeShape(t, isPartial);
    }
    calculateOutputSize(firstDim) {
        const valueShape = this.valuesShape;
        const defaultValueShape = this.defaultValueShape;
        validateDefaultValueShape(defaultValueShape, valueShape);
        const shape = this.tensorShapeFromTensor(this.shape, this.shapeShape);
        const outputShape = combineRaggedTensorToTensorShapes(this.raggedRank, shape, valueShape);
        const result = outputShape;
        if (result[0] < 0) {
            result[0] = firstDim;
        }
        for (let i = 1; i <= this.raggedRank; ++i) {
            if (result[i] < 0) {
                result[i] = this.getMaxWidth(i);
            }
        }
        return result;
    }
    calculateFirstParentOutputIndex(firstDimension, outputIndexMultiplier, firstDimensionOutput) {
        const minDimension = Math.min(firstDimension, firstDimensionOutput);
        const result = [];
        let currentOutputIndex = 0;
        for (let i = 0; i < minDimension; ++i, currentOutputIndex += outputIndexMultiplier) {
            result.push(currentOutputIndex);
        }
        for (let i = minDimension; i < firstDimension; ++i) {
            result.push(-1);
        }
        assert$1(result.length === firstDimension, () => 'Final length of result must be equal to firstDimension.');
        return result;
    }
    calculateOutputIndexRowSplit(rowSplit, parentOutputIndex, outputIndexMultiplier, outputSize) {
        const rowSplitSize = rowSplit.length;
        const result = [];
        for (let i = 0; i < rowSplitSize - 1; ++i) {
            const rowLength = rowSplit[i + 1] - rowSplit[i];
            let realLength = Math.min(outputSize, rowLength);
            let parentOutputIndexCurrent = parentOutputIndex[i];
            if (parentOutputIndexCurrent === -1) {
                realLength = 0;
            }
            for (let j = 0; j < realLength; ++j) {
                result.push(parentOutputIndexCurrent);
                parentOutputIndexCurrent += outputIndexMultiplier;
            }
            for (let j = 0; j < rowLength - realLength; ++j) {
                result.push(-1);
            }
        }
        if (rowSplitSize > 0 && result.length !== rowSplit[rowSplitSize - 1]) {
            throw new Error('Invalid row split size.');
        }
        return result;
    }
    calculateOutputIndexValueRowID(valueRowIds, parentOutputIndex, outputIndexMultiplier, outputSize) {
        const indexSize = valueRowIds.length;
        const result = [];
        if (indexSize === 0) {
            return [];
        }
        let currentOutputColumn = 0;
        let currentValueRowId = valueRowIds[0];
        if (currentValueRowId >= parentOutputIndex.length) {
            throw new Error(`Got currentValueRowId=${currentValueRowId}, which is not less than ${parentOutputIndex.length}`);
        }
        let currentOutputIndex = parentOutputIndex[currentValueRowId];
        result.push(currentOutputIndex);
        for (let i = 1; i < indexSize; ++i) {
            const nextValueRowId = valueRowIds[i];
            if (nextValueRowId === currentValueRowId) {
                if (currentOutputIndex >= 0) {
                    ++currentOutputColumn;
                    if (currentOutputColumn < outputSize) {
                        currentOutputIndex += outputIndexMultiplier;
                    }
                    else {
                        currentOutputIndex = -1;
                    }
                }
            }
            else {
                currentOutputColumn = 0;
                currentValueRowId = nextValueRowId;
                if (nextValueRowId >= parentOutputIndex.length) {
                    throw new Error(`Got nextValueRowId=${nextValueRowId} which is not less than ${parentOutputIndex.length}`);
                }
                currentOutputIndex = parentOutputIndex[nextValueRowId];
            }
            result.push(currentOutputIndex);
        }
        if (result.length !== valueRowIds.length) {
            throw new Error('Invalid row ids.');
        }
        return result;
    }
    calculateOutputIndex(dimension, parentOutputIndex, outputIndexMultiplier, outputSize) {
        const rowPartitionTensor = this.getRowPartitionTensor(dimension);
        const partitionType = this.getRowPartitionTypeByDimension(dimension);
        switch (partitionType) {
            case RowPartitionType.VALUE_ROWIDS:
                return this.calculateOutputIndexValueRowID(rowPartitionTensor, parentOutputIndex, outputIndexMultiplier, outputSize);
            case RowPartitionType.ROW_SPLITS:
                if (rowPartitionTensor.length - 1 > parentOutputIndex.length) {
                    throw new Error(`Row partition size is greater than output size: ${rowPartitionTensor.length - 1} > ${parentOutputIndex.length}`);
                }
                return this.calculateOutputIndexRowSplit(rowPartitionTensor, parentOutputIndex, outputIndexMultiplier, outputSize);
            default:
                throw new Error(`Unsupported partition type: ${RowPartitionType[partitionType]}`);
        }
    }
    getFirstDimensionSize() {
        const firstPartitionTensor = this.rowPartitionValues[0];
        if (this.rowPartitionTypes.length === 0) {
            throw new Error('No row_partition_types given.');
        }
        const firstPartitionType = this.rowPartitionTypes[0];
        switch (firstPartitionType) {
            case RowPartitionType.FIRST_DIM_SIZE:
                return firstPartitionTensor[0];
            case RowPartitionType.VALUE_ROWIDS:
                throw new Error('Cannot handle VALUE_ROWIDS in first dimension.');
            case RowPartitionType.ROW_SPLITS:
                return this.rowPartitionValuesShapes[0][0] - 1;
            default:
                throw new Error(`Cannot handle type ${RowPartitionType[firstPartitionType]}`);
        }
    }
    compute() {
        const firstPartitionTensor = this.rowPartitionValues[0];
        if (firstPartitionTensor.length <= 0) {
            throw new Error('Invalid first partition input. ' +
                'Tensor requires at least one element.');
        }
        const firstDimension = this.getFirstDimensionSize();
        const outputSize = this.calculateOutputSize(firstDimension);
        const multiplier = new Array(this.raggedRank + 1);
        multiplier[multiplier.length - 1] = 1;
        for (let i = multiplier.length - 2; i >= 0; --i) {
            multiplier[i] = multiplier[i + 1] * outputSize[i + 1];
        }
        const outputShape = makeShape(outputSize, false);
        const outputTensor = getArrayFromDType(this.valuesDType, sizeFromShape(outputShape));
        const fullSize = multiplier[0] * outputSize[0];
        if (fullSize > 0) {
            let outputIndex = this.calculateFirstParentOutputIndex(firstDimension, multiplier[0], outputSize[0]);
            for (let i = 1; i <= this.raggedRank; ++i) {
                const newOutputIndex = this.calculateOutputIndex(i - 1, outputIndex, multiplier[i], outputSize[i]);
                outputIndex = newOutputIndex;
            }
            this.setOutput(this.raggedRank, outputIndex, outputTensor, outputShape);
        }
        return [outputShape, outputTensor];
    }
    setOutput(raggedRank, outputIndex, outputTensor, outputShape) {
        if (outputTensor.length === 0) {
            return;
        }
        const valuesBase = this.values;
        const outputBase = outputTensor;
        let elementShape = outputShape.slice();
        elementShape = elementShape.slice(raggedRank + 1);
        const valueElementSize = sizeFromShape(elementShape);
        const outputIndexSize = outputIndex.length;
        let defaultValue = this.defaultValue;
        if (defaultValue.length !== valueElementSize && defaultValue.length !== 1) {
            const srcShape = this.defaultValueShape;
            tidy(() => {
                const defaultValueTensor = reshape$2(defaultValue, srcShape);
                const bCastDefault = broadcastTo(defaultValueTensor, elementShape);
                defaultValue = bCastDefault.dataSync();
            });
        }
        let srcStart = 0;
        let dstStart = 0;
        let dstEnd = 0;
        for (let srcI = 0; srcI <= outputIndexSize; ++srcI) {
            let dstI = srcI < outputIndexSize ? outputIndex[srcI] : -1;
            if (dstI === dstEnd) {
                ++dstEnd;
                continue;
            }
            if (dstStart < dstEnd) {
                const src = valuesBase.subarray(srcStart * valueElementSize);
                const dst = outputBase.subarray(dstStart * valueElementSize);
                const nVals = (dstEnd - dstStart) * valueElementSize;
                copyArray(dst, src, nVals);
            }
            if (srcI >= outputIndexSize) {
                const outputSize = outputTensor.length;
                dstI = Math.floor(outputSize / valueElementSize);
            }
            if (dstI > dstEnd) {
                if (this.defaultValue.length === 1) {
                    outputBase
                        .subarray(dstEnd * valueElementSize, dstI * valueElementSize)
                        .fill(this.defaultValue[0]);
                    dstEnd = dstI;
                }
                else {
                    while (dstI > dstEnd) {
                        const dst = outputBase.slice(dstEnd * valueElementSize);
                        copyArray(dst, defaultValue, valueElementSize);
                        ++dstEnd;
                    }
                }
            }
            if (dstI < 0) {
                srcStart = srcI + 1;
                dstStart = dstEnd;
            }
            else {
                srcStart = srcI;
                dstStart = dstEnd;
                dstEnd = dstStart + 1;
            }
        }
    }
}

function copyArray(dst, src, size) {
    for (let i = 0; i < size; i++) {
        dst[i] = src[i];
    }
}
function makeShape(shape, isPartial) {
    const out = [];
    for (let dim of shape) {
        if (dim < 0) {
            if (!isPartial) {
                throw new Error(`Dimension ${dim} must be >= 0`);
            }
            if (dim < -1) {
                throw new Error(`Dimension ${dim} must be >= -1`);
            }
            dim = -1;
        }
        out.push(dim);
    }
    return out;
}
function raggedTensorToTensorImpl(shape, shapesShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes) {
    return new RaggedTensorToTensorOp(shape, shapesShape, values, valuesShape, valuesDType, defaultValue, defaultValueShape, rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes)
        .compute();
}

function rangeImpl(start, stop, step, dtype) {
    const sameStartStop = start === stop;
    const increasingRangeNegativeStep = start < stop && step < 0;
    const decreasingRangePositiveStep = stop < start && step > 1;
    if (sameStartStop || increasingRangeNegativeStep ||
        decreasingRangePositiveStep) {
        return makeZerosTypedArray(0, dtype);
    }
    const numElements = Math.abs(Math.ceil((stop - start) / step));
    const values = makeZerosTypedArray(numElements, dtype);
    if (stop < start && step === 1) {
        step = -1;
    }
    values[0] = start;
    for (let i = 1; i < values.length; i++) {
        values[i] = values[i - 1] + step;
    }
    return values;
}

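// Illustrative note (added comment, not upstream code): rangeImpl mirrors
// tf.range semantics, including the step sign fix-up for decreasing ranges:
//   rangeImpl(0, 4, 1, 'int32');  // => [0, 1, 2, 3]
//   rangeImpl(4, 0, 1, 'int32');  // => [4, 3, 2, 1] (step flipped to -1)
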
const rsqrtImpl = createSimpleUnaryImpl((xi) => 1 / Math.sqrt(xi));
const rsqrt$1 = unaryKernelFuncFromImpl(Rsqrt, rsqrtImpl);
const rsqrtConfig$1 = {
    kernelName: Rsqrt,
    backendName: 'cpu',
    kernelFunc: rsqrt$1,
};

function scatterImpl(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices) {
    const flattenShape = [outputSize / sliceSize, sliceSize];
    const indicesData = indices.values;
    const updatesData = updates.values;
    if (outputSize === 0) {
        return buffer(shape, updates.dtype);
    }
    const outBuf = (defaultValue instanceof TensorBuffer) ?
        defaultValue :
        buffer(flattenShape, updates.dtype);
    if (typeof defaultValue === 'string') {
        outBuf.values.fill(defaultValue);
    }
    else if (typeof defaultValue === 'number') {
        outBuf.values.fill(defaultValue);
    }
    else if (typeof defaultValue === 'boolean') {
        outBuf.values.fill(+defaultValue);
    }
    for (let i = 0; i < numUpdates; i++) {
        const index = [];
        let flattenIndex = 0;
        for (let j = 0; j < sliceRank; j++) {
            const dim = indicesData[i * sliceRank + j];
            index.push(dim);
            flattenIndex += dim * strides[j];
        }
        if (flattenIndex < 0 || flattenIndex >= outputSize / sliceSize) {
            throw new Error(`Invalid indices: ${index} does not index into ${shape}`);
        }
        for (let k = 0; k < sliceSize; k++) {
            if (sumDupeIndices) {
                outBuf.values[flattenIndex * sliceSize + k] +=
                    updatesData[i * sliceSize + k];
            }
            else {
                outBuf.values[flattenIndex * sliceSize + k] = updates.rank === 0 ?
                    updatesData[0] :
                    updatesData[i * sliceSize + k];
            }
        }
    }
    return outBuf;
}

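// Illustrative sketch (added comment, not upstream code): scatterImpl writes
// `numUpdates` slices of length `sliceSize` into a zeroed (or default-filled)
// buffer. E.g., scattering two scalar updates into a length-4 vector while
// summing duplicate indices (the `indices`/`updates` objects below are
// hypothetical stand-ins for TensorBuffer-like inputs):
//   const exIndices = { values: Int32Array.of(1, 1) };
//   const exUpdates = { values: Float32Array.of(10, 20), rank: 1, dtype: 'float32' };
//   scatterImpl(exIndices, exUpdates, [4], 4, 1, 2, 1, [1], 0, true).values;
//   // => [0, 30, 0, 0]
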
const sigmoidImpl = createSimpleUnaryImpl((xi) => 1 / (1 + Math.exp(-xi)));
const sigmoid$1 = unaryKernelFunc$1(Sigmoid$1, (xi) => 1 / (1 + Math.exp(-xi)));
const sigmoidConfig$1 = {
    kernelName: Sigmoid$1,
    backendName: 'cpu',
    kernelFunc: sigmoid$1,
};

function sliceImpl(vals, begin, size, shape, dtype) {
    const isContinous = isSliceContinous(shape, begin, size);
    const length = sizeFromShape(size);
    const xStrides = computeStrides(shape);
    if (isContinous) {
        const flatOffset = computeFlatOffset(begin, xStrides);
        if (dtype === 'string') {
            return vals.slice(flatOffset, flatOffset + length);
        }
        return vals.subarray(flatOffset, flatOffset + length);
    }
    const decodedData = dtype === 'string' ?
        fromUint8ToStringArray(vals) :
        vals;
    const inBuf = buffer(shape, dtype, decodedData);
    const outBuf = buffer(size, dtype);
    for (let i = 0; i < outBuf.size; ++i) {
        const outLoc = outBuf.indexToLoc(i);
        const inLoc = outLoc.map((idx, j) => idx + begin[j]);
        outBuf.set(inBuf.get(...inLoc), ...outLoc);
    }
    if (dtype === 'string') {
        return fromStringArrayToUint8(outBuf.values);
    }
    return outBuf.values;
}
function slice$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { begin, size } = attrs;
    assertNotComplex(x, 'slice');
    const [$begin, $size] = parseSliceParams(x, begin, size);
    assertParamsValid(x, $begin, $size);
    const vals = backend.data.get(x.dataId).values;
    const outVals = sliceImpl(vals, $begin, $size, x.shape, x.dtype);
    return backend.makeTensorInfo($size, x.dtype, outVals);
}
const sliceConfig$1 = {
    kernelName: Slice,
    backendName: 'cpu',
    kernelFunc: slice$1
};

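// Illustrative note (added comment, not upstream code): for a contiguous
// slice the data is returned as a subarray view; otherwise it is gathered
// element by element, e.g. taking the last two columns of a 2x3 matrix:
//   sliceImpl(Float32Array.of(1, 2, 3, 4, 5, 6), [0, 1], [2, 2], [2, 3],
//       'float32');
//   // => [2, 3, 5, 6]
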
function sparseFillEmptyRowsImpl(indices, indicesShape, indicesDType, values, valuesDType, denseShape, defaultValue) {
    const indicesCount = indicesShape[0];
    const denseRows = denseShape[0];
    const emptyRowIndicator = new Array(denseRows);
    const reverseIndexMap = new Array(indicesCount);
    const rank = indicesShape[1];
    if (denseRows === 0) {
        if (indicesCount !== 0) {
            throw new Error(getSparseFillEmptyRowsIndicesDenseShapeMismatch(indicesCount));
        }
        const outputIndices = getArrayFromDType(indicesDType, 0);
        const outputValues = getArrayFromDType(valuesDType, 0);
        return [
            outputIndices, [0, rank], outputValues, emptyRowIndicator, reverseIndexMap
        ];
    }
    let rowsAreOrdered = true;
    let lastIndicesRow = 0;
    const csrOffset = new Array(denseRows).fill(0);
    for (let i = 0; i < indicesCount; ++i) {
        const row = indices[i * rank];
        if (row < 0) {
            throw new Error(getSparseFillEmptyRowsNegativeIndexErrorMessage(i, row));
        }
        if (row >= denseRows) {
            throw new Error(getSparseFillEmptyRowsOutOfRangeIndexErrorMessage(i, row, denseRows));
        }
        ++csrOffset[row];
        rowsAreOrdered = rowsAreOrdered && (row >= lastIndicesRow);
        lastIndicesRow = row;
    }
    let allRowsFull = true;
    for (let row = 0; row < denseRows; ++row) {
        const rowEmpty = (csrOffset[row] === 0);
        emptyRowIndicator[row] = rowEmpty;
        allRowsFull = allRowsFull && !rowEmpty;
        csrOffset[row] = Math.max(csrOffset[row], 1);
        if (row > 0) {
            csrOffset[row] += csrOffset[row - 1];
        }
    }
    if (allRowsFull && rowsAreOrdered) {
        const outputIndices = indices;
        const outputValues = values;
        for (let i = 0; i < indicesCount; ++i) {
            reverseIndexMap[i] = i;
        }
        return [
            outputIndices, [indicesCount, rank], outputValues, emptyRowIndicator,
            reverseIndexMap
        ];
    }
    else {
        const fullIndicesCount = csrOffset[denseRows - 1];
        const outputIndices = getArrayFromDType(indicesDType, fullIndicesCount * rank);
        const outputValues = getArrayFromDType(valuesDType, fullIndicesCount);
        const filledCount = new Array(denseRows).fill(0);
        for (let i = 0; i < indicesCount; ++i) {
            const row = indices[i * rank];
            const offset = filledCount[row];
            const outputI = ((row === 0) ? 0 : csrOffset[row - 1]) + offset;
            filledCount[row]++;
            for (let j = 0; j < rank; ++j) {
                outputIndices[outputI * rank + j] = indices[i * rank + j];
            }
            outputValues[outputI] = values[i];
            reverseIndexMap[i] = outputI;
        }
        for (let row = 0; row < denseRows; ++row) {
            const rowCount = filledCount[row];
            if (rowCount === 0) {
                const startingIndex = (row === 0) ? 0 : csrOffset[row - 1];
                outputIndices[startingIndex * rank + 0] = row;
                for (let col = 1; col < rank; ++col) {
                    outputIndices[startingIndex * rank + col] = 0;
                }
                outputValues[startingIndex] = defaultValue;
            }
        }
        return [
            outputIndices, [fullIndicesCount, rank], outputValues, emptyRowIndicator,
            reverseIndexMap
        ];
    }
}

function sparseReshapeImpl(inputIndices, inputIndicesShape, inputDType, inputShape, targetShape) {
    const denseSize = sizeFromShape(inputShape);
    const nnz = inputIndicesShape[0];
    const outputRank = targetShape.length;
    const outputShape = [];
    let product = 1;
    let unknownIndex = -1;
    for (let d = 0; d < outputRank; ++d) {
        const size = targetShape[d];
        if (size === -1) {
            if (unknownIndex !== -1) {
                throw new Error(getSparseReshapeMultipleNegativeOneOutputDimErrorMessage(unknownIndex, d));
            }
            unknownIndex = d;
            outputShape.push(1);
        }
        else {
            if (size < 0) {
                throw new Error(getSparseReshapeNegativeOutputDimErrorMessage(d, size));
            }
            product *= size;
            outputShape.push(size);
        }
    }
    if (unknownIndex !== -1) {
        if (product <= 0) {
            throw new Error(getSparseReshapeEmptyTensorZeroOutputDimErrorMessage());
        }
        const missing = Math.trunc(denseSize / product);
        if (product * missing !== denseSize) {
            throw new Error(getSparseReshapeInputOutputMultipleErrorMessage(inputShape, outputShape));
        }
        outputShape[unknownIndex] = missing;
    }
    const outputSize = sizeFromShape(outputShape);
    if (outputSize !== denseSize) {
        throw new Error(getSparseReshapeInputOutputMismatchErrorMessage(inputShape, outputShape));
    }
    const inputRank = inputShape.length;
    const inputStrides = [];
    if (inputRank > 0) {
        inputStrides[inputRank - 1] = 1;
        for (let d = inputRank - 2; d >= 0; --d) {
            inputStrides[d] = inputStrides[d + 1] * inputShape[d + 1];
        }
    }
    const outputStrides = [];
    if (outputRank > 0) {
        outputStrides[outputRank - 1] = 1;
        for (let d = outputRank - 2; d >= 0; --d) {
            outputStrides[d] = outputStrides[d + 1] * outputShape[d + 1];
        }
    }
    const newIndices = getArrayFromDType(inputDType, nnz * outputRank);
    for (let i = 0; i < nnz; ++i) {
        let id = 0;
        for (let j = 0; j < inputRank; ++j) {
            id += inputIndices[i * inputRank + j] * inputStrides[j];
        }
        for (let j = 0; j < outputRank; ++j) {
            newIndices[i * outputRank + j] = Math.trunc(id / outputStrides[j]);
            id %= outputStrides[j];
        }
    }
    return [newIndices, [nnz, outputRank], outputShape];
}

function sparseSegmentReductionImpl(input, inputShape, inputDType, indices, segmentIds, isMean = false, defaultValue = 0) {
|
|
const numIndices = indices.length;
|
|
|
|
const inputFlat = [inputShape[0], input.length / inputShape[0]];
|
|
const numCol = inputFlat[1];
|
|
|
|
|
|
const lastSegmentIdPlusOne = numIndices > 0 ? segmentIds[numIndices - 1] + 1 : 0;
|
|
const outputRows = lastSegmentIdPlusOne;
|
|
if (outputRows < 0) {
|
|
throw new Error(getSparseSegmentReductionNegativeSegmentIdsErrorMessage());
|
|
}
|
|
const outputShape = inputShape.slice();
|
|
outputShape[0] = outputRows;
|
|
const outputLength = outputShape.reduce((product, value) => product * value, 1);
|
|
|
|
const output = getArrayFromDType(inputDType, outputLength);
|
|
|
|
|
|
if (numIndices === 0) {
|
|
if (outputRows > 0) {
|
|
output.fill(defaultValue);
|
|
}
|
|
return [output, outputShape];
|
|
}
|
|
if (outputRows <= 0) {
|
|
throw new Error(getSparseSegmentReductionNegativeSegmentIdsErrorMessage());
|
|
}
|
|
let start = 0, end = 1;
|
|
|
|
let uninitializedIndex = 0;
|
|
let outIndex = segmentIds[start];
|
|
while (true) {
|
|
|
|
let nextIndex = 0;
|
|
if (end < numIndices) {
|
|
nextIndex = segmentIds[end];
|
|
if (outIndex === nextIndex) {
|
|
++end;
|
|
continue;
|
|
}
|
|
|
|
if (outIndex >= nextIndex) {
|
|
throw new Error(getSparseSegmentReductionNonIncreasingSegmentIdsErrorMessage());
|
|
}
|
|
}
|
|
if (outIndex < 0 || outIndex >= outputRows) {
|
|
throw new Error(getSparseSegmentReductionSegmentIdOutOfRangeErrorMessage(outIndex, outputRows));
|
|
}
|
|
|
|
|
|
if (outIndex > uninitializedIndex) {
|
|
output.fill(defaultValue, uninitializedIndex * numCol, outIndex * numCol);
|
|
}
|
|
for (let i = start; i < end; ++i) {
|
|
const index = indices[i];
|
|
if (index < 0 || index >= inputFlat[0]) {
|
|
throw new Error(getSparseSegmentReductionIndicesOutOfRangeErrorMessage(i, indices[i], inputFlat[0]));
|
|
}
|
|
for (let j = 0; j < numCol; j++) {
|
|
output[outIndex * numCol + j] += input[index * numCol + j];
|
|
}
|
|
}
|
|
if (isMean) {
|
|
for (let j = 0; j < numCol; j++) {
|
|
output[outIndex * numCol + j] /= end - start;
|
|
}
|
|
}
|
|
start = end;
|
|
++end;
|
|
uninitializedIndex = outIndex + 1;
|
|
outIndex = nextIndex;
|
|
if (end > numIndices) {
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (uninitializedIndex < outputRows) {
|
|
output.fill(defaultValue, uninitializedIndex * numCol, outputRows * numCol);
|
|
}
|
|
return [output, outputShape];
|
|
}
|
|
|
|
|
|
const sqrtImpl = createSimpleUnaryImpl((xi) => Math.sqrt(xi));
|
|
const sqrt$1 = unaryKernelFunc$1(Sqrt, (xi) => Math.sqrt(xi));
|
|
const sqrtConfig$1 = {
|
|
kernelName: Sqrt,
|
|
backendName: 'cpu',
|
|
kernelFunc: sqrt$1,
|
|
};
|
|
|
|
|
|
const squaredDifferenceImpl = createSimpleBinaryKernelImpl(((a, b) => {
|
|
const diff = a - b;
|
|
return diff * diff;
|
|
}));
|
|
const squaredDifference$1 = binaryKernelFunc$1(SquaredDifference, squaredDifferenceImpl);
|
|
const squaredDifferenceConfig$1 = {
|
|
kernelName: SquaredDifference,
|
|
backendName: 'cpu',
|
|
kernelFunc: squaredDifference$1
|
|
};
|
|
|
|
|
|
const staticRegexReplaceImpl = createSimpleUnaryImpl((x, attrs) => {
|
|
const { pattern, replaceGlobal, rewrite } = attrs;
|
|
|
|
return x.replace(new RegExp(pattern, replaceGlobal ? 'g' : ''), rewrite);
|
|
});
|
|
const staticRegexReplace$1 = unaryKernelFuncFromImpl(StaticRegexReplace, staticRegexReplaceImpl);
|
|
const staticRegexReplaceConfig$1 = {
|
|
kernelName: StaticRegexReplace,
|
|
backendName: 'cpu',
|
|
kernelFunc: staticRegexReplace$1,
|
|
};
|
|
|
|
|
|
function stridedSliceImpl(outShape, xBuf, strides, begin) {
|
|
const outBuf = buffer(outShape, xBuf.dtype);
|
|
for (let i = 0; i < outBuf.size; i++) {
|
|
const loc = outBuf.indexToLoc(i);
|
|
const newLoc = new Array(loc.length);
|
|
for (let j = 0; j < newLoc.length; j++) {
|
|
newLoc[j] = loc[j] * strides[j] + begin[j];
|
|
}
|
|
outBuf.set(xBuf.get(...newLoc), ...loc);
|
|
}
|
|
return outBuf;
|
|
}
|
|
|
|
|
|
|
|
class StringNGramsOp {
|
|
constructor(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
|
|
this.separator = encodeString(separator);
|
|
this.nGramWidths = nGramWidths;
|
|
this.leftPad = encodeString(leftPad);
|
|
this.rightPad = encodeString(rightPad);
|
|
this.padWidth = padWidth;
|
|
this.preserveShort = preserveShortSequences;
|
|
}
|
|
getPadWidth(nGramWidth) {
|
|
|
|
|
|
|
|
return Math.min(this.padWidth < 0 ? nGramWidth - 1 : this.padWidth, nGramWidth - 1);
|
|
}
|
|
getNumNGrams(length, nGramWidth) {
|
|
const padWidth = this.getPadWidth(nGramWidth);
|
|
return Math.max(0, ((length + 2 * padWidth) - nGramWidth) + 1);
|
|
}
|
|
createNGrams(data, splitIndex, output, outputStartIndex, numNGrams, nGramWidth) {
|
|
for (let nGramIndex = 0; nGramIndex < numNGrams; ++nGramIndex) {
|
|
const padWidth = this.getPadWidth(nGramWidth);
|
|
const leftPadding = Math.max(0, padWidth - nGramIndex);
|
|
const rightPadding = Math.max(0, padWidth - (numNGrams - (nGramIndex + 1)));
|
|
const numTokens = nGramWidth - (leftPadding + rightPadding);
|
|
const dataStartIndex = splitIndex + (leftPadding > 0 ? 0 : nGramIndex - padWidth);
|
|
|
|
|
|
let nGramSize = 0;
|
|
|
|
nGramSize += leftPadding * this.leftPad.length;
|
|
|
|
for (let n = 0; n < numTokens; ++n) {
|
|
nGramSize += data[dataStartIndex + n].length;
|
|
}
|
|
|
|
nGramSize += rightPadding * this.rightPad.length;
|
|
|
|
const numSeparators = leftPadding + rightPadding + numTokens - 1;
|
|
nGramSize += numSeparators * this.separator.length;
|
|
|
|
output[outputStartIndex + nGramIndex] = new Uint8Array(nGramSize);
|
|
const nGram = output[outputStartIndex + nGramIndex];
|
|
let nextNGramIndex = 0;
|
|
const appendToNGram = (str) => str.forEach((value) => nGram[nextNGramIndex++] = value);
|
|
for (let n = 0; n < leftPadding; ++n) {
|
|
appendToNGram(this.leftPad);
|
|
appendToNGram(this.separator);
|
|
}
|
|
|
|
for (let n = 0; n < numTokens - 1; ++n) {
|
|
appendToNGram(data[dataStartIndex + n]);
|
|
appendToNGram(this.separator);
|
|
}
|
|
|
|
|
|
if (numTokens > 0) {
|
|
|
|
|
|
|
|
appendToNGram(data[dataStartIndex + numTokens - 1]);
|
|
for (let n = 0; n < rightPadding; ++n) {
|
|
appendToNGram(this.separator);
|
|
appendToNGram(this.rightPad);
|
|
}
|
|
}
|
|
else {
|
|
|
|
|
|
|
|
|
|
for (let n = 0; n < rightPadding - 1; ++n) {
|
|
appendToNGram(this.rightPad);
|
|
appendToNGram(this.separator);
|
|
}
|
|
appendToNGram(this.rightPad);
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
compute(data, splits) {
|
|
|
|
|
|
const inputDataSize = data.length;
|
|
const splitsSize = splits.length;
|
|
if (splitsSize > 0) {
|
|
let prevSplit = splits[0];
|
|
if (prevSplit !== 0) {
|
|
throw new Error(`First split value must be 0, got ${prevSplit}`);
|
|
}
|
|
for (let i = 1; i < splitsSize; ++i) {
|
|
let validSplits = splits[i] >= prevSplit;
|
|
validSplits = validSplits && (splits[i] <= inputDataSize);
|
|
if (!validSplits) {
|
|
throw new Error(`Invalid split value ${splits[i]}, must be in [${prevSplit}, ${inputDataSize}]`);
|
|
}
|
|
prevSplit = splits[i];
|
|
}
|
|
if (prevSplit !== inputDataSize) {
|
|
throw new Error(`Last split value must be data size. Expected ${inputDataSize}, got ${prevSplit}`);
|
|
}
|
|
}
|
|
const numBatchItems = splitsSize - 1;
|
|
const nGramsSplits = getArrayFromDType('int32', splitsSize);
|
|
|
|
if (inputDataSize === 0 || splitsSize === 0) {
|
|
const empty = new Array(inputDataSize);
|
|
for (let i = 0; i <= numBatchItems; ++i) {
|
|
nGramsSplits[i] = 0;
|
|
}
|
|
return [empty, nGramsSplits];
|
|
}
|
|
nGramsSplits[0] = 0;
|
|
for (let i = 1; i <= numBatchItems; ++i) {
|
|
const length = splits[i] - splits[i - 1];
|
|
let numNGrams = 0;
|
|
this.nGramWidths.forEach((nGramWidth) => {
|
|
numNGrams += this.getNumNGrams(length, nGramWidth);
|
|
});
|
|
if (this.preserveShort && length > 0 && numNGrams === 0) {
|
|
numNGrams = 1;
|
|
}
|
|
nGramsSplits[i] = nGramsSplits[i - 1] + numNGrams;
|
|
}
|
|
const nGrams = new Array(nGramsSplits[numBatchItems]);
|
|
for (let i = 0; i < numBatchItems; ++i) {
|
|
const splitIndex = splits[i];
|
|
let outputStartIdx = nGramsSplits[i];
|
|
this.nGramWidths.forEach((nGramWidth) => {
|
|
const length = splits[i + 1] - splits[i];
|
|
const numNGrams = this.getNumNGrams(length, nGramWidth);
|
|
this.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth);
|
|
outputStartIdx += numNGrams;
|
|
});
|
|
|
|
|
|
|
|
|
|
|
|
if (this.preserveShort && outputStartIdx === nGramsSplits[i]) {
|
|
const dataLength = splits[i + 1] - splits[i];
|
|
|
|
|
|
if (dataLength === 0) {
|
|
continue;
|
|
}
|
|
|
|
|
|
|
|
const nGramWidth = dataLength + 2 * this.padWidth;
|
|
const numNGrams = 1;
|
|
this.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth);
|
|
}
|
|
}
|
|
return [nGrams, nGramsSplits];
|
|
}
|
|
}
|
|
function stringNGramsImpl(data, dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
|
|
return new StringNGramsOp(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences)
|
|
.compute(data, dataSplits);
|
|
}
|
|
|
|
|
|
function split(str, delimiters, skipEmpty, result) {
|
|
if (!str.length) {
|
|
return;
|
|
}
|
|
|
|
if (delimiters.length === 0) {
|
|
for (let i = 0; i < str.length; ++i) {
|
|
result.push(str.subarray(i, i + 1));
|
|
}
|
|
return;
|
|
}
|
|
|
|
if (delimiters.length === 1) {
|
|
const delimiter = delimiters[0];
|
|
let f = str.indexOf(delimiter);
|
|
while (f !== -1) {
|
|
const token = str.subarray(0, f);
|
|
if (!skipEmpty || token.length !== 0) {
|
|
result.push(token);
|
|
}
|
|
str = str.subarray(f + 1);
|
|
f = str.indexOf(delimiter);
|
|
}
|
|
if (!skipEmpty || str.length !== 0) {
|
|
result.push(str);
|
|
}
|
|
return;
|
|
}
|
|
|
|
|
|
let tokenStart = 0;
|
|
for (let i = 0; i < str.length + 1; i++) {
|
|
if ((i === str.length) || (delimiters.indexOf(str[i]) !== -1)) {
|
|
const token = str.subarray(tokenStart, i);
|
|
if (!skipEmpty || token.length !== 0) {
|
|
result.push(token);
|
|
}
|
|
tokenStart = i + 1;
|
|
}
|
|
}
|
|
}
|
|
function stringSplitImpl(input, delimiter, skipEmpty) {
|
|
const batchSize = input.length;
|
|
|
|
const tokens = [];
|
|
let outputSize = 0;
|
|
let maxNumEntries = 0;
|
|
const numIndices = new Array(batchSize);
|
|
for (let i = 0; i < batchSize; ++i) {
|
|
const prevTokensLength = tokens.length;
|
|
split(input[i], delimiter, skipEmpty, tokens);
|
|
const nEntries = tokens.length - prevTokensLength;
|
|
numIndices[i] = nEntries;
|
|
outputSize += nEntries;
|
|
maxNumEntries = Math.max(maxNumEntries, nEntries);
|
|
}
|
|
const indices = getArrayFromDType('int32', outputSize * 2);
|
|
const values = new Array(outputSize);
|
|
const shape = [batchSize, maxNumEntries];
|
|
let c = 0;
|
|
for (let i = 0; i < batchSize; ++i) {
|
|
for (let j = 0; j < numIndices[i]; ++j) {
|
|
|
|
indices[c * 2] = i;
|
|
indices[c * 2 + 1] = j;
|
|
values[c] = tokens[c];
|
|
++c;
|
|
}
|
|
}
|
|
return [indices, values, shape];
|
|
}
|
|
|
|
|
|
function stringToHashBucketFastImpl(input, numBuckets) {
|
|
const output = getArrayFromDType('int32', input.length);
|
|
for (let i = 0; i < input.length; ++i) {
|
|
output[i] =
|
|
fingerPrint64(input[i]).modulo(numBuckets).getLowBitsUnsigned();
|
|
}
|
|
return output;
|
|
}
|
|
|
|
|
|
const subImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => aValue - bValue));
|
|
const subComplexImpl = createComplexBinaryKernelImpl(((aReal, aImag, bReal, bImag) => {
|
|
return { real: aReal - bReal, imag: aImag - bImag };
|
|
}));
|
|
const sub$1 = binaryKernelFunc$1(Sub, subImpl, subComplexImpl);
|
|
const subConfig$1 = {
|
|
kernelName: Sub,
|
|
backendName: 'cpu',
|
|
kernelFunc: sub$1
|
|
};
|
|
|
|
|
|
|
|
function tileImpl(xBuf, reps) {
|
|
const newShape = new Array(xBuf.rank);
|
|
for (let i = 0; i < newShape.length; i++) {
|
|
newShape[i] = xBuf.shape[i] * reps[i];
|
|
}
|
|
const result = buffer(newShape, xBuf.dtype);
|
|
for (let i = 0; i < result.values.length; ++i) {
|
|
const newLoc = result.indexToLoc(i);
|
|
const originalLoc = new Array(xBuf.rank);
|
|
for (let j = 0; j < originalLoc.length; j++) {
|
|
originalLoc[j] = newLoc[j] % xBuf.shape[j];
|
|
}
|
|
const originalIndex = xBuf.locToIndex(originalLoc);
|
|
result.values[i] = xBuf.values[originalIndex];
|
|
}
|
|
return result;
|
|
}
|
|
|
|
|
|
|
|
const comparePair = (a, b) => {
|
|
const valueDiff = b.value - a.value;
|
|
return valueDiff === 0 ? a.index - b.index : valueDiff;
|
|
};
|
|
|
|
function select$2(array, k, left = 0, right = array.length - 1) {
|
|
while (right > left) {
|
|
|
|
|
|
|
|
if (right - left > 600) {
|
|
const n = right - left + 1;
|
|
const i = k - left + 1;
|
|
const z = Math.log(n);
|
|
const s = 0.5 * Math.exp(2 * z / 3);
|
|
const sd = 0.5 * Math.sqrt(z * s * (n - s) / n) * Math.sign(i - n / 2);
|
|
const newLeft = Math.max(left, Math.floor(k - i * s / n + sd));
|
|
const newRight = Math.min(right, Math.floor(k + (n - i) * s / n + sd));
|
|
select$2(array, k, newLeft, newRight);
|
|
}
|
|
|
|
const t = array[k];
|
|
let i = left;
|
|
let j = right;
|
|
swap(array, left, k);
|
|
if (comparePair(array[right], t) > 0) {
|
|
swap(array, left, right);
|
|
}
|
|
while (i < j) {
|
|
swap(array, i, j);
|
|
i++;
|
|
j--;
|
|
while (comparePair(array[i], t) < 0) {
|
|
i = i + 1;
|
|
}
|
|
while (comparePair(array[j], t) > 0) {
|
|
j = j - 1;
|
|
}
|
|
}
|
|
if (comparePair(array[left], t) === 0) {
|
|
swap(array, left, j);
|
|
}
|
|
else {
|
|
j = j + 1;
|
|
swap(array, j, right);
|
|
}
|
|
|
|
|
|
if (j <= k) {
|
|
left = j + 1;
|
|
}
|
|
if (k <= j) {
|
|
right = j - 1;
|
|
}
|
|
}
|
|
}
|
|
function topKImpl(x, xShape, xDtype, k, sorted) {
|
|
|
|
const lastDim = xShape[xShape.length - 1];
|
|
const [batch, size] = [x.length / lastDim, lastDim];
|
|
const allTopKVals = getTypedArrayFromDType(xDtype, batch * k);
|
|
const allTopKIndices = getTypedArrayFromDType('int32', batch * k);
|
|
for (let b = 0; b < batch; b++) {
|
|
const offset = b * size;
|
|
const vals = x.subarray(offset, offset + size);
|
|
let valAndInd = new Array(vals.length);
|
|
vals.forEach((value, index) => valAndInd[index] = { value, index });
|
|
if (k < valAndInd.length) {
|
|
select$2(valAndInd, k);
|
|
valAndInd = valAndInd.slice(0, k);
|
|
}
|
|
if (sorted) {
|
|
valAndInd.sort(comparePair);
|
|
}
|
|
const outOffset = b * k;
|
|
const topKVals = allTopKVals.subarray(outOffset, outOffset + k);
|
|
const topKIndices = allTopKIndices.subarray(outOffset, outOffset + k);
|
|
for (let i = 0; i < k; i++) {
|
|
topKVals[i] = valAndInd[i].value;
|
|
topKIndices[i] = valAndInd[i].index;
|
|
}
|
|
}
|
|
|
|
|
|
const outputShape = xShape.slice();
|
|
outputShape[outputShape.length - 1] = k;
|
|
return [
|
|
buffer(outputShape, xDtype, allTopKVals),
|
|
buffer(outputShape, 'int32', allTopKIndices)
|
|
];
|
|
}
|
|
|
|
|
|
function uniqueImpl(values, axis, shape, dtype) {
|
|
|
|
const $axis = parseAxisParam(axis, shape)[0];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const newShape = [1, shape[0], 1];
|
|
for (let i = 0; i < $axis; i++) {
|
|
newShape[0] *= shape[i];
|
|
}
|
|
newShape[1] = shape[$axis];
|
|
for (let i = $axis + 1; i < shape.length; i++) {
|
|
newShape[2] *= shape[i];
|
|
}
|
|
|
|
|
|
const uniqueElements = new Map();
|
|
|
|
|
|
const indices = new Int32Array(shape[$axis]);
|
|
|
|
const inputBuffer = new TensorBuffer(newShape, dtype, values);
|
|
|
|
|
|
const uniqueIndices = [];
|
|
const is1DTensor = newShape[0] === 1 && newShape[2] === 1;
|
|
for (let i = 0; i < shape[$axis]; i++) {
|
|
|
|
let element;
|
|
if (is1DTensor) {
|
|
|
|
element = values[i].toString();
|
|
}
|
|
else {
|
|
const axisValues = [];
|
|
for (let m = 0; m < newShape[0]; m++) {
|
|
for (let n = 0; n < newShape[2]; n++) {
|
|
axisValues.push(inputBuffer.get(m, i, n));
|
|
}
|
|
}
|
|
element = axisValues.join(',');
|
|
}
|
|
|
|
const existingIndex = uniqueElements.get(element);
|
|
if (existingIndex != null) {
|
|
indices[i] = existingIndex;
|
|
}
|
|
else {
|
|
const uniqueIndex = uniqueElements.size;
|
|
uniqueElements.set(element, uniqueIndex);
|
|
indices[i] = uniqueIndex;
|
|
uniqueIndices.push(i);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
const outputTmpShape = newShape.slice();
|
|
outputTmpShape[1] = uniqueElements.size;
|
|
const outputBuffer = new TensorBuffer(outputTmpShape, dtype);
|
|
uniqueIndices.forEach((uniqueElementIndex, i) => {
|
|
for (let m = 0; m < newShape[0]; m++) {
|
|
for (let n = 0; n < newShape[2]; n++) {
|
|
outputBuffer.set(inputBuffer.get(m, uniqueElementIndex, n), m, i, n);
|
|
}
|
|
}
|
|
});
|
|
|
|
|
|
const outputShape = shape.slice();
|
|
outputShape[$axis] = outputTmpShape[1];
|
|
return {
|
|
outputValues: outputBuffer.values,
|
|
outputShape,
|
|
indices,
|
|
};
|
|
}
|
|
|
|
|
|
|
|
|
|
var shared = Object.freeze({
|
|
__proto__: null,
|
|
addImpl: addImpl,
|
|
bincountImpl: bincountImpl,
|
|
bincountReduceImpl: bincountReduceImpl,
|
|
bitwiseAndImpl: bitwiseAndImpl,
|
|
castImpl: castImpl,
|
|
ceilImpl: ceilImpl,
|
|
concatImpl: concatImpl$1,
|
|
equalImpl: equalImpl,
|
|
expImpl: expImpl,
|
|
expm1Impl: expm1Impl,
|
|
floorDivImpl: floorDivImpl,
|
|
floorImpl: floorImpl,
|
|
gatherNdImpl: gatherNdImpl,
|
|
gatherV2Impl: gatherV2Impl,
|
|
greaterEqualImpl: greaterEqualImpl,
|
|
greaterImpl: greaterImpl,
|
|
lessEqualImpl: lessEqualImpl,
|
|
lessImpl: lessImpl,
|
|
linSpaceImpl: linSpaceImpl,
|
|
logImpl: logImpl,
|
|
maxImpl: maxImpl$1,
|
|
maximumImpl: maximumImpl,
|
|
minimumImpl: minimumImpl,
|
|
multiplyImpl: multiplyImpl,
|
|
negImpl: negImpl,
|
|
notEqualImpl: notEqualImpl,
|
|
prodImpl: prodImpl,
|
|
raggedGatherImpl: raggedGatherImpl,
|
|
raggedRangeImpl: raggedRangeImpl,
|
|
raggedTensorToTensorImpl: raggedTensorToTensorImpl,
|
|
rangeImpl: rangeImpl,
|
|
rsqrtImpl: rsqrtImpl,
|
|
scatterImpl: scatterImpl,
|
|
sigmoidImpl: sigmoidImpl,
|
|
simpleAbsImpl: simpleAbsImpl,
|
|
sliceImpl: sliceImpl,
|
|
sparseFillEmptyRowsImpl: sparseFillEmptyRowsImpl,
|
|
sparseReshapeImpl: sparseReshapeImpl,
|
|
sparseSegmentReductionImpl: sparseSegmentReductionImpl,
|
|
sqrtImpl: sqrtImpl,
|
|
squaredDifferenceImpl: squaredDifferenceImpl,
|
|
staticRegexReplaceImpl: staticRegexReplaceImpl,
|
|
stridedSliceImpl: stridedSliceImpl,
|
|
stringNGramsImpl: stringNGramsImpl,
|
|
stringSplitImpl: stringSplitImpl,
|
|
stringToHashBucketFastImpl: stringToHashBucketFastImpl,
|
|
subImpl: subImpl,
|
|
tileImpl: tileImpl,
|
|
topKImpl: topKImpl,
|
|
transposeImpl: transposeImpl$1,
|
|
uniqueImpl: uniqueImpl
|
|
});
|
|
|
|
|
|
|
|
|
|
|
|
const { addImpl: addImplCPU, bincountImpl: bincountImplCPU, bincountReduceImpl: bincountReduceImplCPU, bitwiseAndImpl: bitwiseAndImplCPU, castImpl: castImplCPU, ceilImpl: ceilImplCPU, concatImpl: concatImplCPU, equalImpl: equalImplCPU, expImpl: expImplCPU, expm1Impl: expm1ImplCPU, floorImpl: floorImplCPU, gatherNdImpl: gatherNdImplCPU, gatherV2Impl: gatherV2ImplCPU, greaterImpl: greaterImplCPU, greaterEqualImpl: greaterEqualImplCPU, lessImpl: lessImplCPU, lessEqualImpl: lessEqualImplCPU, linSpaceImpl: linSpaceImplCPU, logImpl: logImplCPU, maxImpl: maxImplCPU, maximumImpl: maximumImplCPU, minimumImpl: minimumImplCPU, multiplyImpl: multiplyImplCPU, negImpl: negImplCPU, notEqualImpl: notEqualImplCPU, prodImpl: prodImplCPU, raggedGatherImpl: raggedGatherImplCPU, raggedRangeImpl: raggedRangeImplCPU, raggedTensorToTensorImpl: raggedTensorToTensorImplCPU, rangeImpl: rangeImplCPU, rsqrtImpl: rsqrtImplCPU, scatterImpl: scatterImplCPU, sigmoidImpl: sigmoidImplCPU, simpleAbsImpl: simpleAbsImplCPU, sliceImpl: sliceImplCPU, sparseFillEmptyRowsImpl: sparseFillEmptyRowsImplCPU, sparseReshapeImpl: sparseReshapeImplCPU, sparseSegmentReductionImpl: sparseSegmentReductionImplCPU, sqrtImpl: sqrtImplCPU, staticRegexReplaceImpl: staticRegexReplaceImplCPU, stridedSliceImpl: stridedSliceImplCPU, stringNGramsImpl: stringNGramsImplCPU, stringSplitImpl: stringSplitImplCPU, stringToHashBucketFastImpl: stringToHashBucketFastImplCPU, subImpl: subImplCPU, tileImpl: tileImplCPU, topKImpl: topKImplCPU, transposeImpl: transposeImplCPU, uniqueImpl: uniqueImplCPU, } = shared;
|
|
|
|
|
|
function getVecChannels(name, rank) {
|
|
return ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank).map(d => `${name}.${d}`);
|
|
}
|
|
function getChannels(name, rank) {
|
|
if (rank === 1) {
|
|
return [name];
|
|
}
|
|
return getVecChannels(name, rank);
|
|
}
|
|
function getSourceCoords$2(rank, dims) {
|
|
if (rank === 1) {
|
|
return 'rc';
|
|
}
|
|
let coords = '';
|
|
for (let i = 0; i < rank; i++) {
|
|
coords += dims[i];
|
|
if (i < rank - 1) {
|
|
coords += ',';
|
|
}
|
|
}
|
|
return coords;
|
|
}
|
|
|
|
|
|
class PackProgram {
|
|
constructor(outputShape) {
|
|
this.variableNames = ['A'];
|
|
this.packedInputs = false;
|
|
this.packedOutput = true;
|
|
|
|
this.outputShape = outputShape;
|
|
this.rank = outputShape.length;
|
|
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
|
|
if (this.rank === 0) {
|
|
this.userCode = `
|
|
void main() {
|
|
setOutput(vec4(getA(), 0., 0., 0.));
|
|
}
|
|
`;
|
|
}
|
|
else {
|
|
const channels = getChannels('rc', this.rank);
|
|
const dtype = getCoordsDataType(this.rank);
|
|
const outOfBoundsCondition = this.getOutOfBoundsCondition(channels);
|
|
const setup = this.getSetup(channels);
|
|
const output = this.getOutput(channels);
|
|
this.userCode = `
|
|
void main() {
|
|
${dtype} rc = getOutputCoords();
|
|
|
|
if(${outOfBoundsCondition}) {
|
|
setOutput(vec4(0));
|
|
} else {
|
|
${setup}
|
|
|
|
setOutput(vec4(${output}));
|
|
}
|
|
}
|
|
`;
|
|
}
|
|
}
|
|
getSourceCoordsArr(dims) {
|
|
const coords = [];
|
|
for (let row = 0; row <= 1; row++) {
|
|
for (let col = 0; col <= 1; col++) {
|
|
let coord = `${row === 0 ? 'r' : 'rp1'}, ${col === 0 ? 'c' : 'cp1'}`;
|
|
for (let d = 2; d < this.rank; d++) {
|
|
coord = `${dims[dims.length - 1 - d]},` + coord;
|
|
}
|
|
coords.push(coord);
|
|
}
|
|
}
|
|
return coords;
|
|
}
|
|
getOutOfBoundsCondition(dims) {
|
|
if (this.rank === 1) {
|
|
return `rc > ${this.enableShapeUniforms ? 'outShape' : this.outputShape[0]}`;
|
|
}
|
|
let cond = '';
|
|
for (let i = this.rank - 2; i < this.rank; i++) {
|
|
cond += `${dims[i]} >= ${this.enableShapeUniforms ? `outShape[${i}]` : this.outputShape[i]}`;
|
|
if (i < this.rank - 1) {
|
|
cond += '||';
|
|
}
|
|
}
|
|
return cond;
|
|
}
|
|
getSetup(dims) {
|
|
if (this.rank === 1) {
|
|
return '';
|
|
}
|
|
const innerDims = dims.slice(-2);
|
|
const col = this.enableShapeUniforms ? `outShape[${this.rank} - 1]` :
|
|
this.outputShape[this.rank - 1];
|
|
const row = this.enableShapeUniforms ? `outShape[${this.rank} - 2]` :
|
|
this.outputShape[this.rank - 2];
|
|
return `
|
|
int r = ${innerDims[0]};
|
|
int c = ${innerDims[1]};
|
|
int rp1 = r + 1;
|
|
int cp1 = c + 1;
|
|
|
|
bool cEdge = cp1 >= ${col};
|
|
bool rEdge = rp1 >= ${row};
|
|
`;
|
|
}
|
|
getOutput(dims) {
|
|
const sourceCoords = this.getSourceCoordsArr(dims);
|
|
if (this.rank === 1) {
|
|
const outShape = this.enableShapeUniforms ? 'outShape' : this.outputShape[0];
|
|
return `getA(rc), (rc + 1 >= ${outShape} ? 0. : getA(rc + 1)), 0, 0`;
|
|
}
|
|
return `getA(${sourceCoords[0]}),
|
|
cEdge ? 0. : getA(${sourceCoords[1]}),
|
|
rEdge ? 0. : getA(${sourceCoords[2]}),
|
|
rEdge || cEdge ? 0. : getA(${sourceCoords[3]})`;
|
|
}
|
|
}
|
|
|
|
|
|
class ReshapePackedProgram {
|
|
constructor(outputShape, inputShape) {
|
|
this.variableNames = ['A'];
|
|
this.packedInputs = true;
|
|
this.packedOutput = true;
|
|
this.customUniforms = [{ name: 'inputShape', type: 'ivec3' }];
|
|
this.outputShape = outputShape;
|
|
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
|
|
let mainLoop = ``;
|
|
for (let i = 0; i < 4; i++) {
|
|
let thisRC = `thisRC = rc;`;
|
|
if (i % 2 === 1) {
|
|
thisRC += `thisRC.z += 1;`;
|
|
}
|
|
if (i > 1) {
|
|
thisRC += `thisRC.y += 1;`;
|
|
}
|
|
mainLoop += `
|
|
${thisRC}
|
|
${i > 0 ? `if(thisRC.y < rows && thisRC.z < cols){` : ''}
|
|
int flatIndex = getFlatIndex(thisRC);
|
|
|
|
ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);
|
|
vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));
|
|
|
|
result[${i}] =
|
|
getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);
|
|
${i > 0 ? '}' : ''}
|
|
`;
|
|
}
|
|
this.userCode = `
|
|
${getReshapedInputCoords(inputShape, this.enableShapeUniforms)}
|
|
${this.enableShapeUniforms ? getFlatIndexFrom3DOutput() :
|
|
getFlatIndexFrom3D(outputShape)}
|
|
|
|
void main() {
|
|
ivec3 rc = getOutputCoords();
|
|
|
|
vec4 result = vec4(0.);
|
|
|
|
ivec3 thisRC;
|
|
int rows = ${this.enableShapeUniforms ? 'outShape[1]' : outputShape[1]};
|
|
int cols = ${this.enableShapeUniforms ? 'outShape[2]' : outputShape[2]};
|
|
|
|
${mainLoop}
|
|
|
|
setOutput(result);
|
|
}
|
|
`;
|
|
}
|
|
}
|
|
function getReshapedInputCoords(shape, enableShapeUniforms) {
|
|
const coordsFromIndexSnippet = enableShapeUniforms ?
|
|
getLogicalCoordinatesFromFlatIndexByUniform(['r', 'c', 'd'], 'inputShape') :
|
|
getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
|
|
return `
|
|
ivec3 inputCoordsFromReshapedOutCoords(int index) {
|
|
${coordsFromIndexSnippet}
|
|
return ivec3(r, c, d);
|
|
}
|
|
`;
|
|
}
|
|
|
|
|
|
class TextureManager {
|
|
constructor(gpgpu) {
|
|
this.gpgpu = gpgpu;
|
|
this.numUsedTextures = 0;
|
|
this.numFreeTextures = 0;
|
|
this._numBytesAllocated = 0;
|
|
|
|
this._numBytesFree = 0;
|
|
this.freeTextures = {};
|
|
this.usedTextures = {};
|
|
this.logEnabled = false;
|
|
}
|
|
acquireTexture(shapeRC, usage, isPacked) {
|
|
const physicalTexType = getPhysicalFromLogicalTextureType(usage, isPacked);
|
|
const shapeKey = getKeyFromTextureShape(shapeRC, physicalTexType, isPacked);
|
|
if (!(shapeKey in this.freeTextures)) {
|
|
this.freeTextures[shapeKey] = [];
|
|
}
|
|
if (!(shapeKey in this.usedTextures)) {
|
|
this.usedTextures[shapeKey] = [];
|
|
}
|
|
const texBytes = computeBytes(shapeRC, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
|
|
if (this.freeTextures[shapeKey].length > 0) {
|
|
this.numFreeTextures--;
|
|
this.numUsedTextures++;
|
|
this._numBytesFree -= texBytes;
|
|
this.log();
|
|
const newTexture = this.freeTextures[shapeKey].pop();
|
|
this.usedTextures[shapeKey].push(newTexture);
|
|
return newTexture;
|
|
}
|
|
let newTexture;
|
|
if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT32) {
|
|
newTexture = this.gpgpu.createPackedMatrixTexture(shapeRC[0], shapeRC[1]);
|
|
}
|
|
else if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT16) {
|
|
newTexture =
|
|
this.gpgpu.createFloat16PackedMatrixTexture(shapeRC[0], shapeRC[1]);
|
|
}
|
|
else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT32) {
|
|
newTexture =
|
|
this.gpgpu.createFloat32MatrixTexture(shapeRC[0], shapeRC[1]);
|
|
}
|
|
else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT16) {
|
|
newTexture =
|
|
this.gpgpu.createFloat16MatrixTexture(shapeRC[0], shapeRC[1]);
|
|
}
|
|
else if (physicalTexType === PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE) {
|
|
newTexture =
|
|
this.gpgpu.createUnsignedBytesMatrixTexture(shapeRC[0], shapeRC[1]);
|
|
}
|
|
this.usedTextures[shapeKey].push(newTexture);
|
|
this.numUsedTextures++;
|
|
this._numBytesAllocated += texBytes;
|
|
this.log();
|
|
return newTexture;
|
|
}
|
|
releaseTexture(texture, shape, logicalTexType, isPacked) {
|
|
if (this.freeTextures == null) {
|
|
|
|
return;
|
|
}
|
|
const physicalTexType = getPhysicalFromLogicalTextureType(logicalTexType, isPacked);
|
|
const shapeKey = getKeyFromTextureShape(shape, physicalTexType, isPacked);
|
|
if (!(shapeKey in this.freeTextures)) {
|
|
this.freeTextures[shapeKey] = [];
|
|
}
|
|
const texBytes = computeBytes(shape, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
|
|
const deleteTexThreshold = env()
|
|
.getNumber('WEBGL_DELETE_TEXTURE_THRESHOLD');
|
|
if (deleteTexThreshold !== -1 &&
|
|
this._numBytesAllocated > deleteTexThreshold) {
|
|
this.gpgpu.deleteMatrixTexture(texture.texture);
|
|
this._numBytesAllocated -= texBytes;
|
|
}
|
|
else {
|
|
this.freeTextures[shapeKey].push(texture);
|
|
this.numFreeTextures++;
|
|
this._numBytesFree += texBytes;
|
|
}
|
|
this.numUsedTextures--;
|
|
const texList = this.usedTextures[shapeKey];
|
|
const texIndex = texList && texList.indexOf(texture);
|
|
if (texIndex == null || texIndex < 0) {
|
|
throw new Error('Cannot release a texture that was never provided by this ' +
|
|
'texture manager');
|
|
}
|
|
texList[texIndex] = texList[texList.length - 1];
|
|
texList.pop();
|
|
this.log();
|
|
}
|
|
log() {
|
|
if (!this.logEnabled) {
|
|
return;
|
|
}
|
|
const total = this.numFreeTextures + this.numUsedTextures;
|
|
console.log('Free/Used', `${this.numFreeTextures} / ${this.numUsedTextures}`, `(${total})`);
|
|
const freeRatio = this._numBytesFree / this._numBytesAllocated;
|
|
console.log(`Bytes allocated: ${this._numBytesAllocated}`);
|
|
console.log(`Bytes unused: ${this._numBytesFree} (${Math.round(100 * freeRatio)}%)`);
|
|
}
|
|
get numBytesAllocated() {
|
|
return this._numBytesAllocated;
|
|
}
|
|
get numBytesFree() {
|
|
return this._numBytesFree;
|
|
}
|
|
getNumUsedTextures() {
|
|
return this.numUsedTextures;
|
|
}
|
|
getNumFreeTextures() {
|
|
return this.numFreeTextures;
|
|
}
|
|
dispose() {
|
|
if (this.freeTextures == null) {
|
|
|
|
return;
|
|
}
|
|
for (const texShape in this.freeTextures) {
|
|
this.freeTextures[texShape].forEach(tex => {
|
|
this.gpgpu.deleteMatrixTexture(tex.texture);
|
|
});
|
|
}
|
|
for (const texShape in this.usedTextures) {
|
|
this.usedTextures[texShape].forEach(tex => {
|
|
this.gpgpu.deleteMatrixTexture(tex.texture);
|
|
});
|
|
}
|
|
|
|
this.freeTextures = null;
|
|
this.usedTextures = null;
|
|
this.numUsedTextures = 0;
|
|
this.numFreeTextures = 0;
|
|
this._numBytesAllocated = 0;
|
|
this._numBytesFree = 0;
|
|
}
|
|
}
|
|
function numBytesForInternalFormat(gl, internalFormat) {
|
|
|
|
const glany = gl;
|
|
if (internalFormat === glany.R32F) {
|
|
return 4;
|
|
}
|
|
else if (internalFormat === glany.R16F) {
|
|
return 2;
|
|
}
|
|
else if (internalFormat === glany.RGBA32F) {
|
|
return 16;
|
|
}
|
|
else if (internalFormat === gl.RGBA) {
|
|
return 16;
|
|
}
|
|
else if (internalFormat === glany.RGBA16F) {
|
|
return 8;
|
|
}
|
|
else if (internalFormat === glany.RGBA8) {
|
|
return 4;
|
|
}
|
|
throw new Error(`Unknown internal format ${internalFormat}`);
|
|
}
|
|
function computeBytes(shape, physicalTexType, gl, textureConfig, isPacked) {
|
|
|
|
|
|
|
|
|
|
|
|
const internalFormat = internalFormatForPhysicalTexType(physicalTexType, textureConfig);
|
|
let numElements;
|
|
if (isPacked) {
|
|
const [packedWidth, packedHeight] = getPackedMatrixTextureShapeWidthHeight(shape[0], shape[1]);
|
|
numElements = packedWidth * packedHeight;
|
|
}
|
|
else {
|
|
const [width, height] = getUnpackedMatrixTextureShapeWidthHeight(shape[0], shape[1]);
|
|
numElements = width * height;
|
|
}
|
|
const bytesPerElement = numBytesForInternalFormat(gl, internalFormat);
|
|
return numElements * bytesPerElement;
|
|
}
|
|
function internalFormatForPhysicalTexType(physicalTexType, textureConfig) {
|
|
switch (physicalTexType) {
|
|
case PhysicalTextureType.PACKED_2X2_FLOAT32:
|
|
return getInternalFormatForPackedMatrixTexture(textureConfig);
|
|
case PhysicalTextureType.PACKED_2X2_FLOAT16:
|
|
return getInternalFormatForFloat16PackedMatrixTexture(textureConfig);
|
|
case PhysicalTextureType.UNPACKED_FLOAT32:
|
|
return getInternalFormatForFloat32MatrixTexture(textureConfig);
|
|
case PhysicalTextureType.UNPACKED_FLOAT16:
|
|
return getInternalFormatForFloat16MatrixTexture(textureConfig);
|
|
case PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE:
|
|
return getInternalFormatForUnsignedBytesMatrixTexture(textureConfig);
|
|
default:
|
|
throw new Error(`Unknown physical texture type ${physicalTexType}`);
|
|
}
|
|
}
|
|
function getPhysicalTextureForRendering(isPacked) {
|
|
if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) {
|
|
if (isPacked) {
|
|
return PhysicalTextureType.PACKED_2X2_FLOAT32;
|
|
}
|
|
return PhysicalTextureType.UNPACKED_FLOAT32;
|
|
}
|
|
if (isPacked) {
|
|
return PhysicalTextureType.PACKED_2X2_FLOAT16;
|
|
}
|
|
return PhysicalTextureType.UNPACKED_FLOAT16;
|
|
}
|
|
function getPhysicalFromLogicalTextureType(logicalTexType, isPacked) {
|
|
if (logicalTexType === TextureUsage.UPLOAD) {
|
|
return PhysicalTextureType.PACKED_2X2_FLOAT32;
|
|
}
|
|
else if (logicalTexType === TextureUsage.RENDER || logicalTexType == null) {
|
|
return getPhysicalTextureForRendering(isPacked);
|
|
}
|
|
else if (logicalTexType === TextureUsage.DOWNLOAD ||
|
|
logicalTexType === TextureUsage.PIXELS) {
|
|
return PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE;
|
|
}
|
|
throw new Error(`Unknown logical texture type ${logicalTexType}`);
|
|
}
|
|
function getKeyFromTextureShape(shapeRowsCol, physicalTexType, isPacked) {
|
|
return `${shapeRowsCol[0]}_${shapeRowsCol[1]}_${physicalTexType}_${isPacked}`;
|
|
}
|
|
|
|
|
|
class UnaryOpProgram {
|
|
constructor(aShape, opSnippet) {
|
|
this.variableNames = ['A'];
|
|
this.outputShape = aShape;
|
|
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
|
|
this.userCode = `
|
|
float unaryOperation(float x) {
|
|
${opSnippet}
|
|
}
|
|
|
|
void main() {
|
|
float x = getAAtOutCoords();
|
|
float y = unaryOperation(x);
|
|
|
|
setOutput(y);
|
|
}
|
|
`;
|
|
}
|
|
}
|
|
const CHECK_NAN_SNIPPET$1 = `if (isnan(x)) return x;`;
|
|
const LINEAR$1 = `return x;`;
|
|
const ABS$1 = `return abs(x);`;
|
|
const ELU$2 = `return (x >= 0.0) ? x : (exp(x) - 1.0);`;
|
|
const RELU$2 = CHECK_NAN_SNIPPET$1 + `
|
|
return (x < 0.0) ? 0.0 : x;
|
|
`;
|
|
const RELU6$2 = CHECK_NAN_SNIPPET$1 + `
|
|
return (x < 0.0) ? 0.0 : min(6.0, x);
|
|
`;
|
|
const CLONE = 'return x;';
|
|
const SIGMOID$2 = `return 1.0 / (1.0 + exp(-1.0 * x));`;
|
|
|
|
|
|
const LINEAR = `return x;`;
|
|
const ELU$1 = `
|
|
vec4 result;
|
|
|
|
result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);
|
|
result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);
|
|
result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);
|
|
result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);
|
|
|
|
return result;
|
|
`;
|
|
const RELU$1 = `
|
|
vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));
|
|
bvec4 isNaN = isnan(x);
|
|
|
|
result.r = isNaN.r ? x.r : result.r;
|
|
result.g = isNaN.g ? x.g : result.g;
|
|
result.b = isNaN.b ? x.b : result.b;
|
|
result.a = isNaN.a ? x.a : result.a;
|
|
|
|
return result;
|
|
`;
|
|
const RELU6$1 = `
|
|
vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));
|
|
bvec4 isNaN = isnan(x);
|
|
|
|
result.r = isNaN.r ? x.r : result.r;
|
|
result.g = isNaN.g ? x.g : result.g;
|
|
result.b = isNaN.b ? x.b : result.b;
|
|
result.a = isNaN.a ? x.a : result.a;
|
|
|
|
return result;
|
|
`;
|
|
const SIGMOID$1 = `return 1.0 / (1.0 + exp(-1.0 * x));`;
|
|
class UnaryOpPackedProgram {
|
|
constructor(aShape, opSnippet) {
|
|
this.variableNames = ['A'];
|
|
this.packedInputs = true;
|
|
this.packedOutput = true;
|
|
this.outputShape = aShape;
|
|
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
|
|
this.userCode = `
|
|
vec4 unaryOperation(vec4 x) {
|
|
${opSnippet}
|
|
}
|
|
|
|
void main() {
|
|
vec4 x = getAAtOutCoords();
|
|
vec4 y = unaryOperation(x);
|
|
|
|
setOutput(y);
|
|
}
|
|
`;
|
|
}
|
|
}
|
|
|
|
|
|
class UnpackProgram {
|
|
constructor(outputShape) {
|
|
this.variableNames = ['A'];
|
|
this.packedInputs = true;
|
|
this.packedOutput = false;
|
|
this.outputShape = outputShape;
|
|
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
|
|
const rank = outputShape.length;
|
|
const channels = getChannels('rc', rank);
|
|
const dtype = getCoordsDataType(rank);
|
|
const sourceCoords = getSourceCoords$2(rank, channels);
|
|
const innerDims = channels.slice(-2);
|
|
const coords = rank <= 1 ? 'rc' : `vec2(${innerDims.join(',')})`;
|
|
this.userCode = `
|
|
void main() {
|
|
${dtype} rc = getOutputCoords();
|
|
vec4 packedInput = getA(${sourceCoords});
|
|
|
|
setOutput(getChannel(packedInput, ${coords}));
|
|
}
|
|
`;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
const whereImpl$1 = whereImpl$2;
|
|
const EPSILON_FLOAT32 = 1e-7;
|
|
const EPSILON_FLOAT16 = 1e-4;
|
|
const binaryCaches = {};
|
|
function getBinaryCache(webGLVersion) {
|
|
if (webGLVersion in binaryCaches) {
|
|
return binaryCaches[webGLVersion];
|
|
}
|
|
binaryCaches[webGLVersion] = {};
|
|
return binaryCaches[webGLVersion];
|
|
}
|
|
|
|
|
|
const CPU_HANDOFF_SIZE_THRESHOLD = env().getNumber('CPU_HANDOFF_SIZE_THRESHOLD');
|
|
|
|
|
|
|
|
const BEFORE_PAGING_CONSTANT = 600;
|
|
function numMBBeforeWarning() {
|
|
if (env().global.screen == null) {
|
|
return 1024;
|
|
}
|
|
return (env().global.screen.height * env().global.screen.width *
|
|
window.devicePixelRatio) *
|
|
BEFORE_PAGING_CONSTANT / 1024 / 1024;
|
|
}
|
|
class MathBackendWebGL extends KernelBackend {
|
|
nextDataId() {
|
|
return MathBackendWebGL.nextDataId++;
|
|
}
|
|
constructor(gpuResource) {
|
|
super();
|
|
|
|
this.pendingRead = new WeakMap();
|
|
|
|
|
|
this.pendingDisposal = new WeakSet();
|
|
|
|
|
|
this.dataRefCount = new WeakMap();
|
|
this.numBytesInGPU = 0;
|
|
|
|
this.uploadWaitMs = 0;
|
|
|
|
this.downloadWaitMs = 0;
|
|
|
|
this.lastGlFlushTime = 0;
|
|
this.warnedAboutMemory = false;
|
|
this.pendingDeletes = 0;
|
|
this.disposed = false;
|
|
if (!env().getBool('HAS_WEBGL')) {
|
|
throw new Error('WebGL is not supported on this device');
|
|
}
|
|
let newGPGPU;
|
|
if (gpuResource != null) {
|
|
if (gpuResource instanceof GPGPUContext) {
|
|
newGPGPU = gpuResource;
|
|
}
|
|
else {
|
|
const gl = getWebGLContext(env().getNumber('WEBGL_VERSION'), gpuResource);
|
|
newGPGPU = new GPGPUContext(gl);
|
|
}
|
|
this.binaryCache = {};
|
|
this.gpgpuCreatedLocally = false;
|
|
}
|
|
else {
|
|
const gl = getWebGLContext(env().getNumber('WEBGL_VERSION'));
|
|
newGPGPU = new GPGPUContext(gl);
|
|
this.binaryCache = getBinaryCache(env().getNumber('WEBGL_VERSION'));
|
|
this.gpgpuCreatedLocally = true;
|
|
}
|
|
this.gpgpu = newGPGPU;
|
|
this.canvas = this.gpgpu.gl.canvas;
|
|
this.textureManager = new TextureManager(this.gpgpu);
|
|
this.numMBBeforeWarning = numMBBeforeWarning();
|
|
this.texData = new DataStorage(this, engine());
|
|
}
|
|
numDataIds() {
|
|
return this.texData.numDataIds() - this.pendingDeletes;
|
|
}
|
|
|
|
|
|
writeTexture(texture, shape, dtype, texHeight, texWidth, channels) {
|
|
|
|
|
|
const input = this.makeTensorInfo(shape, dtype);
|
|
const inData = this.texData.get(input.dataId);
|
|
|
|
|
|
inData.isPacked = false;
|
|
|
|
inData.texture = { texture, texShape: [texHeight, texWidth] };
|
|
inData.texShape = [texHeight, texWidth];
|
|
const shapeAs3D = getShapeAs3D(shape);
|
|
const program = new EncodeMatrixProgram(shapeAs3D, false , channels);
|
|
const output = this.runWebGLProgram(program, [input], dtype, [[texHeight, texWidth]]);
|
|
output.shape = shape;
|
|
|
|
|
|
inData.texture = null;
|
|
this.disposeIntermediateTensorInfo(input);
|
|
return output.dataId;
|
|
}
|
|
write(values, shape, dtype) {
|
|
if (env().getBool('WEBGL_CHECK_NUMERICAL_PROBLEMS') ||
|
|
env().getBool('DEBUG')) {
|
|
this.checkNumericalProblems(values);
|
|
}
|
|
if (dtype === 'complex64' && values != null) {
|
|
throw new Error(`Cannot write to a complex64 dtype. ` +
|
|
`Please use tf.complex(real, imag).`);
|
|
}
|
|
const dataId = { id: this.nextDataId() };
|
|
this.texData.set(dataId, { shape, dtype, values, usage: TextureUsage.UPLOAD, refCount: 1 });
|
|
return dataId;
|
|
}
|
|
|
|
refCount(dataId) {
|
|
if (this.texData.has(dataId)) {
|
|
const tensorData = this.texData.get(dataId);
|
|
return tensorData.refCount;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
incRef(dataId) {
|
|
const texData = this.texData.get(dataId);
|
|
texData.refCount++;
|
|
}
|
|
|
|
decRef(dataId) {
|
|
if (this.texData.has(dataId)) {
|
|
const texData = this.texData.get(dataId);
|
|
texData.refCount--;
|
|
}
|
|
}
|
|
move(dataId, values, shape, dtype, refCount) {
|
|
if (env().getBool('DEBUG')) {
|
|
this.checkNumericalProblems(values);
|
|
}
|
|
if (dtype === 'complex64') {
|
|
throw new Error(`Cannot write to a complex64 dtype. ` +
|
|
`Please use tf.complex(real, imag).`);
|
|
}
|
|
this.texData.set(dataId, { shape, dtype, values, usage: TextureUsage.UPLOAD, refCount });
|
|
}
|
|
disposeIntermediateTensorInfo(tensorInfo) {
|
|
this.disposeData(tensorInfo.dataId);
|
|
}
|
|
readSync(dataId) {
|
|
const texData = this.texData.get(dataId);
|
|
const { values, dtype, complexTensorInfos, slice, shape, isPacked } = texData;
|
|
|
|
|
|
|
|
if (slice != null) {
|
|
let program;
|
|
if (isPacked) {
|
|
program = new UnaryOpPackedProgram(shape, CLONE);
|
|
}
|
|
else {
|
|
program = new UnaryOpProgram(shape, CLONE);
|
|
}
|
|
const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype);
|
|
const data = this.readSync(res.dataId);
|
|
this.disposeIntermediateTensorInfo(res);
|
|
return data;
|
|
}
|
|
if (values != null) {
|
|
return this.convertAndCacheOnCPU(dataId);
|
|
}
|
|
if (dtype === 'string') {
|
|
return values;
|
|
}
|
|
const shouldTimeProgram = this.activeTimers != null;
|
|
let start;
|
|
if (shouldTimeProgram) {
|
|
start = now();
|
|
}
|
|
let result;
|
|
if (dtype === 'complex64') {
|
|
const realValues = this.readSync(complexTensorInfos.real.dataId);
|
|
const imagValues = this.readSync(complexTensorInfos.imag.dataId);
|
|
result = mergeRealAndImagArrays(realValues, imagValues);
|
|
}
|
|
else {
|
|
result = this.getValuesFromTexture(dataId);
|
|
}
|
|
if (shouldTimeProgram) {
|
|
this.downloadWaitMs += now() - start;
|
|
}
|
|
return this.convertAndCacheOnCPU(dataId, result);
|
|
}
|
|
async read(dataId) {
|
|
if (this.pendingRead.has(dataId)) {
|
|
const subscribers = this.pendingRead.get(dataId);
|
|
return new Promise(resolve => subscribers.push(resolve));
|
|
}
|
|
const texData = this.texData.get(dataId);
|
|
const { values, shape, slice, dtype, complexTensorInfos, isPacked } = texData;
|
|
|
|
|
|
|
|
if (slice != null) {
|
|
let program;
|
|
if (isPacked) {
|
|
program = new UnaryOpPackedProgram(shape, CLONE);
|
|
}
|
|
else {
|
|
program = new UnaryOpProgram(shape, CLONE);
|
|
}
|
|
const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype);
|
|
const data = this.read(res.dataId);
|
|
this.disposeIntermediateTensorInfo(res);
|
|
return data;
|
|
}
|
|
if (values != null) {
|
|
return this.convertAndCacheOnCPU(dataId);
|
|
}
|
|
if (env().getBool('DEBUG')) {
|
|
|
|
|
|
|
|
if (!env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') &&
|
|
env().getNumber('WEBGL_VERSION') === 2) {
|
|
throw new Error(`tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and ` +
|
|
`WEBGL_VERSION=2 not yet supported.`);
|
|
}
|
|
}
|
|
let buffer = null;
|
|
let tmpDownloadTarget;
|
|
if (dtype !== 'complex64' && env().get('WEBGL_BUFFER_SUPPORTED')) {
|
|
|
|
tmpDownloadTarget = this.decode(dataId);
|
|
const tmpData = this.texData.get(tmpDownloadTarget.dataId);
|
|
buffer = this.gpgpu.createBufferFromTexture(tmpData.texture.texture, ...getDenseTexShape(shape));
|
|
}
|
|
this.pendingRead.set(dataId, []);
|
|
if (dtype !== 'complex64') {
|
|
|
|
await this.gpgpu.createAndWaitForFence();
|
|
}
|
|
|
|
let vals;
|
|
if (dtype === 'complex64') {
|
|
const ps = await Promise.all([
|
|
this.read(complexTensorInfos.real.dataId),
|
|
this.read(complexTensorInfos.imag.dataId)
|
|
]);
|
|
const realValues = ps[0];
|
|
const imagValues = ps[1];
|
|
vals = mergeRealAndImagArrays(realValues, imagValues);
|
|
}
|
|
else if (buffer == null) {
|
|
vals = this.getValuesFromTexture(dataId);
|
|
}
|
|
else {
|
|
const size = sizeFromShape(shape);
|
|
vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size);
|
|
}
|
|
if (tmpDownloadTarget != null) {
|
|
this.disposeIntermediateTensorInfo(tmpDownloadTarget);
|
|
}
|
|
if (buffer != null) {
|
|
const gl = this.gpgpu.gl;
|
|
callAndCheck(gl, () => gl.deleteBuffer(buffer));
|
|
}
|
|
const dTypeVals = this.convertAndCacheOnCPU(dataId, vals);
|
|
const subscribers = this.pendingRead.get(dataId);
|
|
this.pendingRead.delete(dataId);
|
|
|
|
subscribers.forEach(resolve => resolve(dTypeVals));
|
|
if (this.pendingDisposal.has(dataId)) {
|
|
this.pendingDisposal.delete(dataId);
|
|
if (this.disposeData(dataId)) {
|
|
engine().removeDataId(dataId, this);
|
|
}
|
|
this.pendingDeletes--;
|
|
}
|
|
return dTypeVals;
|
|
}
|
|
|
|
readToGPU(dataId, options = {}) {
|
|
const texData = this.texData.get(dataId);
|
|
const { values, shape, slice, dtype, isPacked, texture } = texData;
|
|
if (dtype === 'complex64') {
|
|
throw new Error('Does not support reading texture for complex64 dtype.');
|
|
}
|
|
|
|
|
|
|
|
if (slice != null) {
|
|
let program;
|
|
if (isPacked) {
|
|
program = new UnaryOpPackedProgram(shape, CLONE);
|
|
}
|
|
else {
|
|
program = new UnaryOpProgram(shape, CLONE);
|
|
}
|
|
const res = this.runWebGLProgram(program, [{ dataId, shape, dtype }], dtype);
|
|
const gpuResouorce = this.readToGPU(res, options);
|
|
this.disposeIntermediateTensorInfo(res);
|
|
return gpuResouorce;
|
|
}
|
|
if (texture == null) {
|
|
if (values != null) {
|
|
throw new Error('Data is not on GPU but on CPU.');
|
|
}
|
|
else {
|
|
throw new Error('There is no data on GPU or CPU.');
|
|
}
|
|
}
|
|
|
|
const tmpTarget = this.decode(dataId, options.customTexShape);
|
|
|
|
const tensorRef = engine().makeTensorFromTensorInfo(tmpTarget);
|
|
const tmpData = this.texData.get(tmpTarget.dataId);
|
|
return Object.assign({ tensorRef }, tmpData.texture);
|
|
}
|
|
bufferSync(t) {
|
|
const data = this.readSync(t.dataId);
|
|
if (t.dtype === 'string') {
|
|
try {
|
|
|
|
const strings = data.map(d => decodeString(d));
|
|
return buffer(t.shape, t.dtype, strings);
|
|
}
|
|
catch (_a) {
|
|
throw new Error('Failed to decode encoded string bytes into utf-8');
|
|
}
|
|
}
|
|
return buffer(t.shape, t.dtype, data);
|
|
}
|
|
checkNumericalProblems(values) {
|
|
if (values == null) {
|
|
return;
|
|
}
|
|
for (let i = 0; i < values.length; i++) {
|
|
const num = values[i];
|
|
if (!canBeRepresented(num)) {
|
|
if (env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) {
|
|
throw Error(`The value ${num} cannot be represented with your ` +
|
|
`current settings. Consider enabling float32 rendering: ` +
|
|
`'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'`);
|
|
}
|
|
throw Error(`The value ${num} cannot be represented on this device.`);
|
|
}
|
|
}
|
|
}
|
|
getValuesFromTexture(dataId) {
|
|
const { shape, dtype, isPacked } = this.texData.get(dataId);
|
|
const size = sizeFromShape(shape);
|
|
if (env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) {
|
|
const tmpTarget = this.decode(dataId);
|
|
const tmpData = this.texData.get(tmpTarget.dataId);
|
|
const vals = this.gpgpu
|
|
.downloadMatrixFromPackedTexture(tmpData.texture.texture, ...getDenseTexShape(shape))
|
|
.subarray(0, size);
|
|
this.disposeIntermediateTensorInfo(tmpTarget);
|
|
return vals;
|
|
}
|
|
const shouldUsePackedProgram = env().getBool('WEBGL_PACK') && isPacked === true;
|
|
const outputShape = shouldUsePackedProgram ? getShapeAs3D(shape) : shape;
|
|
const program = shouldUsePackedProgram ?
|
|
new EncodeFloatPackedProgram(outputShape) :
|
|
new EncodeFloatProgram(outputShape);
|
|
const output = this.runWebGLProgram(program, [{ shape: outputShape, dtype, dataId }], 'float32');
|
|
const tmpData = this.texData.get(output.dataId);
|
|
const vals = this.gpgpu
|
|
.downloadByteEncodedFloatMatrixFromOutputTexture(tmpData.texture.texture, tmpData.texShape[0], tmpData.texShape[1])
|
|
.subarray(0, size);
|
|
this.disposeIntermediateTensorInfo(output);
|
|
return vals;
|
|
}
|
|
timerAvailable() {
|
|
return env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0;
|
|
}
|
|
time(f) {
|
|
const oldActiveTimers = this.activeTimers;
|
|
const newActiveTimers = [];
|
|
let outerMostTime = false;
|
|
if (this.programTimersStack == null) {
|
|
this.programTimersStack = newActiveTimers;
|
|
outerMostTime = true;
|
|
}
|
|
else {
|
|
this.activeTimers.push(newActiveTimers);
|
|
}
|
|
this.activeTimers = newActiveTimers;
|
|
f();
|
|
|
|
const flattenedActiveTimerQueries = flatten$1(this.activeTimers.map((d) => d.query))
|
|
.filter(d => d != null);
|
|
const flattenedActiveTimerNames = flatten$1(this.activeTimers.map((d) => d.name))
|
|
.filter(d => d != null);
|
|
this.activeTimers = oldActiveTimers;
|
|
if (outerMostTime) {
|
|
this.programTimersStack = null;
|
|
}
|
|
const res = {
|
|
uploadWaitMs: this.uploadWaitMs,
|
|
downloadWaitMs: this.downloadWaitMs,
|
|
kernelMs: null,
|
|
wallMs: null
|
|
};
|
|
return (async () => {
|
|
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') >
|
|
0) {
|
|
const kernelMs = await Promise.all(flattenedActiveTimerQueries);
|
|
res['kernelMs'] = sum$3(kernelMs);
|
|
res['getExtraProfileInfo'] = () => kernelMs
|
|
.map((d, i) => ({ name: flattenedActiveTimerNames[i], ms: d }))
|
|
.map(d => `${d.name}: ${d.ms}`)
|
|
.join(', ');
|
|
}
|
|
else {
|
|
res['kernelMs'] = {
|
|
error: 'WebGL query timers are not supported in this environment.'
|
|
};
|
|
}
|
|
this.uploadWaitMs = 0;
|
|
this.downloadWaitMs = 0;
|
|
return res;
|
|
})();
|
|
}
|
|
memory() {
|
|
return {
|
|
unreliable: false,
|
|
numBytesInGPU: this.numBytesInGPU,
|
|
numBytesInGPUAllocated: this.textureManager.numBytesAllocated,
|
|
numBytesInGPUFree: this.textureManager.numBytesFree
|
|
};
|
|
}
|
|
startTimer() {
|
|
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
|
|
return this.gpgpu.beginQuery();
|
|
}
|
|
return { startMs: now(), endMs: null };
|
|
}
|
|
endTimer(query) {
|
|
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
|
|
this.gpgpu.endQuery();
|
|
return query;
|
|
}
|
|
query.endMs = now();
|
|
return query;
|
|
}
|
|
async getQueryTime(query) {
|
|
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
|
|
return this.gpgpu.waitForQueryAndGetTime(query);
|
|
}
|
|
const timerQuery = query;
|
|
return timerQuery.endMs - timerQuery.startMs;
|
|
}
|
|
|
|
disposeData(dataId, force = false) {
|
|
if (this.pendingDisposal.has(dataId)) {
|
|
return false;
|
|
}
|
|
|
|
if (!this.texData.has(dataId)) {
|
|
return true;
|
|
}
|
|
|
|
|
|
|
|
if (force) {
|
|
this.texData.get(dataId).refCount = 0;
|
|
}
|
|
else {
|
|
this.texData.get(dataId).refCount--;
|
|
}
|
|
if (!force && this.texData.get(dataId).refCount > 0) {
|
|
return false;
|
|
}
|
|
if (this.pendingRead.has(dataId)) {
|
|
this.pendingDisposal.add(dataId);
|
|
this.pendingDeletes++;
|
|
return false;
|
|
}
|
|
this.releaseGPUData(dataId);
|
|
const { complexTensorInfos } = this.texData.get(dataId);
|
|
if (complexTensorInfos != null) {
|
|
this.disposeData(complexTensorInfos.real.dataId, force);
|
|
this.disposeData(complexTensorInfos.imag.dataId, force);
|
|
}
|
|
this.texData.delete(dataId);
|
|
return true;
|
|
}
|
|
releaseGPUData(dataId) {
|
|
const { texture, dtype, texShape, usage, isPacked, slice } = this.texData.get(dataId);
|
|
const key = slice && slice.origDataId || dataId;
|
|
const refCount = this.dataRefCount.get(key);
|
|
if (refCount > 1) {
|
|
this.dataRefCount.set(key, refCount - 1);
|
|
}
|
|
else {
|
|
this.dataRefCount.delete(key);
|
|
if (texture != null) {
|
|
this.numBytesInGPU -= this.computeBytes(texShape, dtype);
|
|
this.textureManager.releaseTexture(texture, texShape, usage, isPacked);
|
|
}
|
|
}
|
|
const texData = this.texData.get(dataId);
|
|
texData.texture = null;
|
|
texData.texShape = null;
|
|
texData.isPacked = false;
|
|
texData.slice = null;
|
|
}
|
|
getTexture(dataId) {
|
|
this.uploadToGPU(dataId);
|
|
return this.texData.get(dataId).texture.texture;
|
|
}
|
|
|
|
getDataInfo(dataId) {
|
|
return this.texData.get(dataId);
|
|
}
|
|
|
|
shouldExecuteOnCPU(inputs, sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD) {
|
|
return env().getBool('WEBGL_CPU_FORWARD') &&
|
|
inputs.every(input => this.texData.get(input.dataId).texture == null &&
|
|
sizeFromShape(input.shape) < sizeThreshold);
|
|
}
|
|
getGPGPUContext() {
|
|
return this.gpgpu;
|
|
}
|
|
where(condition) {
|
|
warn('tf.where() in webgl locks the UI thread. ' +
|
|
'Call tf.whereAsync() instead');
|
|
const condVals = condition.dataSync();
|
|
return whereImpl$1(condition.shape, condVals);
|
|
}
|
|
packedUnaryOp(x, op, dtype) {
|
|
const program = new UnaryOpPackedProgram(x.shape, op);
|
|
const outInfo = this.compileAndRun(program, [x], dtype);
|
|
return engine().makeTensorFromTensorInfo(outInfo);
|
|
}
|
|
|
|
|
|
|
|
abs(x) {
|
|
|
|
if (this.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') {
|
|
const outValues = simpleAbsImplCPU(this.texData.get(x.dataId).values);
|
|
return this.makeOutput(x.shape, x.dtype, outValues);
|
|
}
|
|
if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
|
|
return this.packedUnaryOp(x, ABS$1, x.dtype);
|
|
}
|
|
const program = new UnaryOpProgram(x.shape, ABS$1);
|
|
const outInfo = this.compileAndRun(program, [x]);
|
|
return engine().makeTensorFromTensorInfo(outInfo);
|
|
}
|
|
makeTensorInfo(shape, dtype, values) {
|
|
let dataId;
|
|
if (dtype === 'string' && values != null && values.length > 0 &&
|
|
isString(values[0])) {
|
|
const encodedValues = values.map(d => encodeString(d));
|
|
dataId = this.write(encodedValues, shape, dtype);
|
|
}
|
|
else {
|
|
dataId = this.write(values, shape, dtype);
|
|
}
|
|
this.texData.get(dataId).usage = null;
|
|
return { dataId, shape, dtype };
|
|
}
|
|
makeOutput(shape, dtype, values) {
|
|
return engine().makeTensorFromTensorInfo(this.makeTensorInfo(shape, dtype, values), this);
|
|
}
|
|
unpackTensor(input) {
|
|
const program = new UnpackProgram(input.shape);
|
|
return this.runWebGLProgram(program, [input], input.dtype);
|
|
}
|
|
packTensor(input) {
|
|
const program = new PackProgram(input.shape);
|
|
const preventEagerUnpackingOutput = true;
|
|
return this.runWebGLProgram(program, [input], input.dtype, null , preventEagerUnpackingOutput);
|
|
}
|
|
packedReshape(input, afterShape) {
|
|
const input3DShape = [
|
|
getBatchDim(input.shape),
|
|
...getRowsCols(input.shape)
|
|
];
|
|
const input3D = {
|
|
dtype: input.dtype,
|
|
shape: input3DShape,
|
|
dataId: input.dataId
|
|
};
|
|
const afterShapeAs3D = [
|
|
getBatchDim(afterShape), ...getRowsCols(afterShape)
|
|
];
|
|
const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);
|
|
const preventEagerUnpackingOfOutput = true;
|
|
const customValues = [input3DShape];
|
|
const output = this.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput);
|
|
return { dataId: output.dataId, shape: afterShape, dtype: output.dtype };
|
|
}
|
|
    decode(dataId, customTexShape) {
        const texData = this.texData.get(dataId);
        const { isPacked, shape, dtype } = texData;
        if (customTexShape != null) {
            const size = sizeFromShape(shape);
            const texSize = customTexShape[0] * customTexShape[1] * 4;
            assert$1(size <= texSize, () => 'customTexShape is too small. ' +
                'Row * Column * 4 should be equal or larger than the ' +
                'size of the tensor data.');
        }
        const shapeAs3D = getShapeAs3D(shape);
        let program;
        if (isPacked) {
            program = new DecodeMatrixPackedProgram(shapeAs3D);
        }
        else {
            program = new DecodeMatrixProgram(shapeAs3D);
        }
        const preventEagerUnpackingOfOutput = true;
        const customValues = [customTexShape != null ? customTexShape :
                getDenseTexShape(shapeAs3D)];
        const out = this.runWebGLProgram(program, [{ shape: shapeAs3D, dtype, dataId }], dtype, customValues, preventEagerUnpackingOfOutput, customTexShape);
        return { dtype, shape, dataId: out.dataId };
    }
    runWebGLProgram(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput = false, customTexShape) {
        const output = this.makeTensorInfo(program.outputShape, outputDtype);
        const outData = this.texData.get(output.dataId);
        if (program.packedOutput) {
            outData.isPacked = true;
        }
        if (program.outPackingScheme === PackingScheme.DENSE) {
            const texelShape = customTexShape != null ?
                customTexShape :
                getDenseTexShape(program.outputShape);
            // For a densely packed output each texel holds four values, so the
            // logical texture shape is twice the texel shape in each dimension.
            outData.texShape = texelShape.map(d => d * 2);
        }
        if (program.outTexUsage != null) {
            outData.usage = program.outTexUsage;
        }
        if (sizeFromShape(output.shape) === 0) {
            // Short-circuit zero-sized tensors: no shader needs to run.
            outData.values =
                getTypedArrayFromDType(output.dtype, 0);
            return output;
        }
        const dataToDispose = [];
        const inputsData = inputs.map(input => {
            if (input.dtype === 'complex64') {
                throw new Error(`GPGPUProgram does not support complex64 input. For complex64 ` +
                    `dtypes, please separate the program into real and imaginary ` +
                    `parts.`);
            }
            let texData = this.texData.get(input.dataId);
            if (texData.texture == null) {
                if (!program.packedInputs &&
                    sizeFromShape(input.shape) <=
                        env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) {
                    // Small non-packed inputs that still live on the CPU are
                    // passed to the shader as uniforms instead of being uploaded
                    // to a texture.
                    return {
                        shape: input.shape,
                        texData: null,
                        isUniform: true,
                        uniformValues: texData.values
                    };
                }
                // The input will be uploaded below; if the program expects packed
                // inputs, record that it must land on the GPU in packed form.
                if (program.packedInputs) {
                    texData.isPacked = true;
                    texData.shape = input.shape;
                }
            }
            this.uploadToGPU(input.dataId);
            if (!!texData.isPacked !== !!program.packedInputs) {
                input = texData.isPacked ? this.unpackTensor(input) :
                    this.packTensor(input);
                dataToDispose.push(input);
                texData = this.texData.get(input.dataId);
            }
            else if (texData.isPacked &&
                !isReshapeFree(texData.shape, input.shape)) {
                // The physical (packed) texture layout does not match the
                // logical shape, so a packed reshape has to run before the
                // program can sample this input correctly.
                const savedInput = input;
                const targetShape = input.shape;
                input.shape = texData.shape;
                input = this.packedReshape(input, targetShape);
                dataToDispose.push(input);
                texData = this.texData.get(input.dataId);
                savedInput.shape = targetShape;
            }
            return { shape: input.shape, texData, isUniform: false };
        });
        this.uploadToGPU(output.dataId);
        const outputData = { shape: output.shape, texData: outData, isUniform: false };
        const key = makeShaderKey(program, inputsData, outputData);
        const binary = this.getAndSaveBinary(key, () => {
            return compileProgram(this.gpgpu, program, inputsData, outputData);
        });
        const shouldTimeProgram = this.activeTimers != null;
        let query;
        if (shouldTimeProgram) {
            query = this.startTimer();
        }
        if (!env().get('ENGINE_COMPILE_ONLY')) {
            runProgram(this.gpgpu, binary, inputsData, outputData, customUniformValues);
        }
        dataToDispose.forEach(info => this.disposeIntermediateTensorInfo(info));
        if (shouldTimeProgram) {
            query = this.endTimer(query);
            this.activeTimers.push({ name: program.constructor.name, query: this.getQueryTime(query) });
        }
        const glFlushThreshold = env().getNumber('WEBGL_FLUSH_THRESHOLD');
        // Periodically flush the GL command queue when a flush threshold is
        // configured for the device.
        if (glFlushThreshold > 0) {
            const time = now();
            if ((time - this.lastGlFlushTime) > glFlushThreshold) {
                this.gpgpu.gl.flush();
                this.lastGlFlushTime = time;
            }
        }
        if (!env().getBool('WEBGL_LAZILY_UNPACK') && outData.isPacked &&
            preventEagerUnpackingOfOutput === false) {
            const unpacked = this.unpackTensor(output);
            this.disposeIntermediateTensorInfo(output);
            return unpacked;
        }
        return output;
    }
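    // Illustrative sketch (not part of the original bundle): a kernel drives
    // the pipeline above by constructing a GPGPUProgram and handing it to
    // runWebGLProgram. Assuming a hypothetical program class `TimesTwoProgram`
    // with `outputShape`, `variableNames: ['X']` and a GLSL `userCode` body:
    //
    //   const program = new TimesTwoProgram(x.shape);
    //   const outInfo = backend.runWebGLProgram(program, [x], x.dtype);
    //   const result = engine().makeTensorFromTensorInfo(outInfo);
    //
    // Compiled shader binaries are cached under the key from makeShaderKey, so
    // repeated calls with the same shapes skip compilation entirely.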
    compileAndRun(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput = false) {
        outputDtype = outputDtype || inputs[0].dtype;
        const outInfo = this.runWebGLProgram(program, inputs, outputDtype, customUniformValues, preventEagerUnpackingOfOutput);
        return outInfo;
    }
    getAndSaveBinary(key, getBinary) {
        if (!(key in this.binaryCache)) {
            this.binaryCache[key] = getBinary();
        }
        return this.binaryCache[key];
    }
    getTextureManager() {
        return this.textureManager;
    }
    dispose() {
        if (this.disposed) {
            return;
        }
        // Avoid deleting the compiled shader programs during unit testing,
        // where backends are torn down and recreated frequently.
        if (!env().getBool('IS_TEST')) {
            const allKeys = Object.keys(this.binaryCache);
            allKeys.forEach(key => {
                this.gpgpu.deleteProgram(this.binaryCache[key].webGLProgram);
                delete this.binaryCache[key];
            });
        }
        this.textureManager.dispose();
        if (this.canvas != null &&
            (typeof (HTMLCanvasElement) !== 'undefined' &&
                this.canvas instanceof HTMLCanvasElement)) {
            this.canvas.remove();
        }
        else {
            this.canvas = null;
        }
        if (this.gpgpuCreatedLocally) {
            this.gpgpu.program = null;
            this.gpgpu.dispose();
        }
        this.disposed = true;
    }
    floatPrecision() {
        if (this.floatPrecisionValue == null) {
            this.floatPrecisionValue = tidy(() => {
                if (!env().get('WEBGL_RENDER_FLOAT32_ENABLED')) {
                    // Momentarily switch DEBUG off so the probe below does not
                    // trigger debug-mode checks, then test whether a tiny value
                    // survives a render pass without underflowing to zero.
                    const debugFlag = env().getBool('DEBUG');
                    env().set('DEBUG', false);
                    const underflowCheckValue = this.abs(scalar(1e-8)).dataSync()[0];
                    env().set('DEBUG', debugFlag);
                    if (underflowCheckValue > 0) {
                        return 32;
                    }
                }
                return 16;
            });
        }
        return this.floatPrecisionValue;
    }
    epsilon() {
        return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16;
    }
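    // Illustrative note (not part of the original bundle): epsilon() is what
    // numerically sensitive ops should compare against, since fp16-only
    // devices need the looser 1e-4 bound. A minimal sketch:
    //
    //   const eps = backend.epsilon();           // 1e-7 on fp32, 1e-4 on fp16
    //   const safeDenominator = Math.max(denominator, eps);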
    uploadToGPU(dataId) {
        const texData = this.texData.get(dataId);
        const { shape, dtype, values, texture, usage, isPacked } = texData;
        if (texture != null) {
            // Already on the GPU; nothing to do.
            return;
        }
        const shouldTimeProgram = this.activeTimers != null;
        let start;
        if (shouldTimeProgram) {
            start = now();
        }
        let texShape = texData.texShape;
        if (texShape == null) {
            // The physical texture shape is derived lazily from the logical
            // shape the first time the data is uploaded.
            texShape = getTextureShapeFromLogicalShape(shape, isPacked);
            texData.texShape = texShape;
        }
        if (values != null) {
            const shapeAs3D = getShapeAs3D(shape);
            let program;
            let width = texShape[1], height = texShape[0];
            const isByteArray = values instanceof Uint8Array || values instanceof Uint8ClampedArray;
            // Unpacked byte arrays keep the dense texture shape; all other
            // uploads use the packed texture dimensions.
            if (isPacked || !isByteArray) {
                [width, height] = getPackedMatrixTextureShapeWidthHeight(texShape[0], texShape[1]);
            }
            if (isPacked) {
                program = new EncodeMatrixPackedProgram(shapeAs3D, isByteArray);
            }
            else {
                program = new EncodeMatrixProgram(shapeAs3D, isByteArray);
            }
            // The upload goes through a temporary, densely packed texture; the
            // encode program then converts it to the backend's internal layout.
            const tempDenseInputTexShape = isByteArray ? [height, width] : texShape;
            const tempDenseInputHandle = this.makeTensorInfo(tempDenseInputTexShape, dtype);
            const tempDenseInputTexData = this.texData.get(tempDenseInputHandle.dataId);
            if (isByteArray) {
                tempDenseInputTexData.usage = TextureUsage.PIXELS;
            }
            else {
                tempDenseInputTexData.usage = TextureUsage.UPLOAD;
            }
            tempDenseInputTexData.texShape = tempDenseInputTexShape;
            this.gpgpu.uploadDenseMatrixToTexture(this.getTexture(tempDenseInputHandle.dataId), width, height, values);
            const customValues = [[height, width]];
            // Keep the encode program's output packed; it becomes this tensor's
            // texture directly, so eager unpacking would be wasted work.
            const preventEagerUnpacking = true;
            const encodedOutputTarget = this.runWebGLProgram(program, [tempDenseInputHandle], dtype, customValues, preventEagerUnpacking);
            // The original tensor assumes the identity of the encoded output.
            const outputTexData = this.texData.get(encodedOutputTarget.dataId);
            texData.texShape = outputTexData.texShape;
            texData.isPacked = outputTexData.isPacked;
            texData.usage = outputTexData.usage;
            if (!env().get('ENGINE_COMPILE_ONLY')) {
                texData.texture = outputTexData.texture;
                // Once uploaded, the CPU copy of the values is dropped.
                texData.values = null;
                this.texData.delete(encodedOutputTarget.dataId);
            }
            else {
                this.disposeData(encodedOutputTarget.dataId);
            }
            this.disposeIntermediateTensorInfo(tempDenseInputHandle);
            if (shouldTimeProgram) {
                this.uploadWaitMs += now() - start;
            }
        }
        else {
            const newTexture = this.acquireTexture(texShape, usage, dtype, isPacked);
            texData.texture = newTexture;
        }
    }
    convertAndCacheOnCPU(dataId, float32Values) {
        const texData = this.texData.get(dataId);
        const { dtype } = texData;
        if (float32Values != null) {
            texData.values = float32ToTypedArray(float32Values, dtype);
        }
        return texData.values;
    }
    acquireTexture(texShape, texType, dtype, isPacked) {
        this.numBytesInGPU += this.computeBytes(texShape, dtype);
        if (!this.warnedAboutMemory &&
            this.numBytesInGPU > this.numMBBeforeWarning * 1024 * 1024) {
            const mb = (this.numBytesInGPU / 1024 / 1024).toFixed(2);
            this.warnedAboutMemory = true;
            console.warn(`High memory usage in GPU: ${mb} MB, ` +
                `most likely due to a memory leak`);
        }
        return this.textureManager.acquireTexture(texShape, texType, isPacked);
    }
    computeBytes(shape, dtype) {
        return shape[0] * shape[1] * bytesPerElement(dtype);
    }
    checkCompileCompletion() {
        for (const [, binary] of Object.entries(this.binaryCache)) {
            this.checkCompletion_(binary);
        }
    }
    async checkCompileCompletionAsync() {
        const ps = [];
        if (this.gpgpu.parallelCompilationExtension) {
            for (const [, binary] of Object.entries(this.binaryCache)) {
                ps.push(this.checkCompletionAsync_(binary));
            }
            return Promise.all(ps);
        }
        else {
            for (const [, binary] of Object.entries(this.binaryCache)) {
                const p = new Promise((resolve) => {
                    try {
                        this.checkCompletion_(binary);
                        resolve(true);
                    }
                    catch (error) {
                        throw error;
                    }
                });
                ps.push(p);
            }
            return Promise.all(ps);
        }
    }
    async checkCompletionAsync_(binary) {
        if (this.gpgpu.gl.getProgramParameter(binary.webGLProgram, this.gpgpu.parallelCompilationExtension.COMPLETION_STATUS_KHR)) {
            return this.checkCompletion_(binary);
        }
        else {
            await nextFrame();
            return this.checkCompletionAsync_(binary);
        }
    }
    checkCompletion_(binary) {
        if (this.gpgpu.gl.getProgramParameter(binary.webGLProgram, this.gpgpu.gl.LINK_STATUS) === false) {
            console.log(this.gpgpu.gl.getProgramInfoLog(binary.webGLProgram));
            if (this.gpgpu.gl.getShaderParameter(binary.fragmentShader, this.gpgpu.gl.COMPILE_STATUS) === false) {
                logShaderSourceAndInfoLog(binary.source, this.gpgpu.gl.getShaderInfoLog(binary.fragmentShader));
                throw new Error('Failed to compile fragment shader.');
            }
            throw new Error('Failed to link vertex and fragment shaders.');
        }
        return true;
    }
    getUniformLocations() {
        for (const binary of Object.values(this.binaryCache)) {
            // Build the VAO and cache every uniform location for each compiled
            // program, so runProgram can bind them without further lookups.
            this.gpgpu.buildVao(binary.webGLProgram);
            const { variablesLocations, customUniformLocations, infLoc, nanLoc, outShapeLocation, outShapeStridesLocation, outTexShapeLocation } = getUniformLocations(this.gpgpu, binary.program, binary.webGLProgram);
            binary.variablesLocations = variablesLocations;
            binary.customUniformLocations = customUniformLocations;
            binary.infLoc = infLoc;
            binary.nanLoc = nanLoc;
            binary.outShapeLocation = outShapeLocation;
            binary.outShapeStridesLocation = outShapeStridesLocation;
            binary.outTexShapeLocation = outTexShapeLocation;
        }
    }
    createTensorFromGPUData(values, shape, dtype) {
        values.channels = values.channels || 'RGBA';
        const { texture, height, width, channels } = values;
        const backend = engine().backend;
        if (!backend.gpgpu.gl.isTexture(texture)) {
            throw new Error(`The texture is invalid. Also, please make sure the texture and ` +
                `the TFJS WebGL backend are using the same canvas. If you want to ` +
                `use your own custom canvas, you have to create and use the custom ` +
                `TFJS WebGL backend created from the canvas through ` +
                `'new tf.MathBackendWebGL(customCanvas)'.`);
        }
        const dataId = backend.writeTexture(texture, shape, dtype, height, width, channels);
        return engine().makeTensorFromDataId(dataId, shape, dtype, backend);
    }
}
MathBackendWebGL.nextDataId = 0;
function float32ToTypedArray(a, dtype) {
    if (dtype === 'float32' || dtype === 'complex64') {
        return a;
    }
    else if (dtype === 'int32' || dtype === 'bool') {
        const result = (dtype === 'int32') ? new Int32Array(a.length) :
            new Uint8Array(a.length);
        for (let i = 0; i < result.length; ++i) {
            result[i] = Math.round(a[i]);
        }
        return result;
    }
    else {
        throw new Error(`Unknown dtype ${dtype}`);
    }
}

if (isBrowser()) {
    registerBackend('webgl', () => new MathBackendWebGL(), 2 /* priority */);
}
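// Illustrative sketch (not part of the original bundle): because the webgl
// backend registers with priority 2, tf.js prefers it automatically when it
// can initialize. A caller can still select it explicitly and fall back;
// `tf.setBackend` resolves to false when the requested backend fails to set up
// (assuming the usual `tf` namespace export):
//
//   if (!(await tf.setBackend('webgl'))) {
//     await tf.setBackend('cpu');
//   }
//   console.log(tf.getBackend());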
const CHECK_NAN_SNIPPET = `
  if (isnan(a)) return a;
  if (isnan(b)) return b;
`;
class BinaryOpProgram {
    constructor(op, aShape, bShape) {
        this.variableNames = ['A', 'B'];
        this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        this.userCode = `
      float binaryOperation(float a, float b) {
        ${op}
      }

      void main() {
        float a = getAAtOutCoords();
        float b = getBAtOutCoords();
        setOutput(binaryOperation(a, b));
      }
    `;
    }
}
const CHECK_NAN_SNIPPET_PACKED = `
  result.r = isNaN.r ? NAN : result.r;
  result.g = isNaN.g ? NAN : result.g;
  result.b = isNaN.b ? NAN : result.b;
  result.a = isNaN.a ? NAN : result.a;
`;
class BinaryOpPackedProgram {
    constructor(op, aShape, bShape, checkOutOfBounds = false) {
        this.variableNames = ['A', 'B'];
        this.supportsBroadcasting = true;
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
        const rank = this.outputShape.length;
        this.enableShapeUniforms = useShapeUniforms(rank);
        let checkOutOfBoundsString = '';
        if (checkOutOfBounds) {
            if (rank === 0 || sizeFromShape(this.outputShape) === 1) {
                checkOutOfBoundsString = `
          result.y = 0.;
          result.z = 0.;
          result.w = 0.;
        `;
            }
            else {
                const dtype = getCoordsDataType(rank);
                checkOutOfBoundsString = `
          ${dtype} coords = getOutputCoords();
        `;
                if (rank === 1) {
                    if (this.enableShapeUniforms) {
                        checkOutOfBoundsString += `
            result.y = (coords + 1) >= outShape ? 0. : result.y;
            result.z = 0.;
            result.w = 0.;
          `;
                    }
                    else {
                        checkOutOfBoundsString += `
            result.y = (coords + 1) >= ${this.outputShape[0]} ? 0. : result.y;
            result.z = 0.;
            result.w = 0.;
          `;
                    }
                }
                else {
                    const channels = getChannels('coords', rank);
                    if (this.enableShapeUniforms) {
                        checkOutOfBoundsString += `
            bool nextRowOutOfBounds =
              (${channels[rank - 2]} + 1) >= outShape[${rank} - 2];
            bool nextColOutOfBounds =
              (${channels[rank - 1]} + 1) >= outShape[${rank} - 1];
            result.y = nextColOutOfBounds ? 0. : result.y;
            result.z = nextRowOutOfBounds ? 0. : result.z;
            result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
          `;
                    }
                    else {
                        checkOutOfBoundsString += `
            bool nextRowOutOfBounds =
              (${channels[rank - 2]} + 1) >= ${this.outputShape[rank - 2]};
            bool nextColOutOfBounds =
              (${channels[rank - 1]} + 1) >= ${this.outputShape[rank - 1]};
            result.y = nextColOutOfBounds ? 0. : result.y;
            result.z = nextRowOutOfBounds ? 0. : result.z;
            result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
          `;
                    }
                }
            }
        }
        this.userCode = `
      vec4 binaryOperation(vec4 a, vec4 b) {
        ${op}
      }

      void main() {
        vec4 a = getAAtOutCoords();
        vec4 b = getBAtOutCoords();

        vec4 result = binaryOperation(a, b);
        ${checkOutOfBoundsString}

        setOutput(result);
      }
    `;
    }
}
function identity(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    backend.incRef(x.dataId);
    return { dataId: x.dataId, shape: x.shape, dtype: x.dtype };
}
const identityConfig = {
    kernelName: Identity$1,
    backendName: 'webgl',
    kernelFunc: identity
};
function complex(args) {
    const { inputs, backend } = args;
    const { real, imag } = inputs;
    // A complex64 tensor stores no texture of its own; it carries two child
    // TensorInfos for the real and imaginary parts.
    const complexInfo = backend.makeTensorInfo(real.shape, 'complex64');
    const complex = backend.texData.get(complexInfo.dataId);
    const realTensorInfo = identity({ inputs: { x: real }, backend });
    const imagTensorInfo = identity({ inputs: { x: imag }, backend });
    complex.complexTensorInfos = { real: realTensorInfo, imag: imagTensorInfo };
    return complexInfo;
}
const complexConfig = {
    kernelName: Complex,
    backendName: 'webgl',
    kernelFunc: complex
};
const LEAKYRELU = `return (a < 0.) ? b * a : a;`;
const LEAKYRELU_PACKED = `
  vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
  return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
`;
function leakyRelu$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { alpha } = attrs;
    const $alpha = backend.makeTensorInfo([], 'float32', createScalarValue(alpha, 'float32'));
    const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
        new BinaryOpPackedProgram(LEAKYRELU_PACKED, x.shape, $alpha.shape) :
        new BinaryOpProgram(LEAKYRELU, x.shape, $alpha.shape);
    const result = backend.runWebGLProgram(program, [x, $alpha], 'float32');
    backend.disposeIntermediateTensorInfo($alpha);
    return result;
}
const leakyReluConfig$1 = {
    kernelName: LeakyRelu,
    backendName: 'webgl',
    kernelFunc: leakyRelu$1
};
const PRELU = `return (a < 0.) ? b * a : a;`;
const PRELU_PACKED = `
  vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
  return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
`;
function prelu$1(args) {
    const { inputs, backend } = args;
    const { x, alpha } = inputs;
    const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
        new BinaryOpPackedProgram(PRELU_PACKED, x.shape, alpha.shape) :
        new BinaryOpProgram(PRELU, x.shape, alpha.shape);
    return backend.runWebGLProgram(program, [x, alpha], 'float32');
}
const preluConfig$1 = {
    kernelName: Prelu,
    backendName: 'webgl',
    kernelFunc: prelu$1
};
const CHECK_NAN_SNIPPET_UNARY = `if (isnan(x)) return x;`;

function unaryKernelFunc({ opSnippet, packedOpSnippet, cpuKernelImpl, dtype }) {
    return ({ inputs, backend }) => {
        const { x } = inputs;
        const webglBackend = backend;
        const $dtype = dtype || x.dtype;
        if (webglBackend.shouldExecuteOnCPU([x]) && cpuKernelImpl != null) {
            const xData = webglBackend.texData.get(x.dataId);
            const outValues = cpuKernelImpl(xData.values, $dtype);
            return webglBackend.makeTensorInfo(x.shape, $dtype, outValues);
        }
        const shouldUsePackedProgram = env().getBool('WEBGL_PACK_UNARY_OPERATIONS') && packedOpSnippet != null;
        let program;
        if (shouldUsePackedProgram) {
            program = new UnaryOpPackedProgram(x.shape, packedOpSnippet);
        }
        else {
            program = new UnaryOpProgram(x.shape, opSnippet);
        }
        return webglBackend.runWebGLProgram(program, [x], $dtype);
    };
}
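// Illustrative sketch (not part of the original bundle): unaryKernelFunc turns
// a GLSL expression over `x` into a full kernel. A hypothetical SQUARE kernel
// would only need the snippet and a config object:
//
//   const SQUARE = `return x * x;`;
//   const squareKernelFunc = unaryKernelFunc({ opSnippet: SQUARE });
//   const squareConfig = {
//     kernelName: 'Square', backendName: 'webgl', kernelFunc: squareKernelFunc
//   };
//
// Real configs in this bundle use imported kernel-name constants rather than
// string literals.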
function binaryKernelFunc({ opSnippet, packedOpSnippet, checkOutOfBounds = false, supportsComplex = false, cpuKernelImpl, dtype }) {
    return ({ inputs, backend }) => {
        const { a, b } = inputs;
        const webglBackend = backend;
        if (supportsComplex && a.dtype === 'complex64') {
            const aData = webglBackend.texData.get(a.dataId);
            const bData = webglBackend.texData.get(b.dataId);
            // Run the op separately on the real parts and on the imaginary
            // parts, then recombine them into a complex output.
            const [real, imag] = [
                [aData.complexTensorInfos.real, bData.complexTensorInfos.real],
                [aData.complexTensorInfos.imag, bData.complexTensorInfos.imag]
            ].map(complexParts => {
                const [aPart, bPart] = complexParts;
                const aHandle = {
                    dataId: aPart.dataId,
                    dtype: aPart.dtype,
                    shape: a.shape
                };
                const bHandle = {
                    dataId: bPart.dataId,
                    dtype: bPart.dtype,
                    shape: b.shape
                };
                const program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
                return webglBackend.runWebGLProgram(program, [aHandle, bHandle], upcastType(aPart.dtype, bPart.dtype));
            });
            const complexOutput = complex({ inputs: { real, imag }, backend: webglBackend });
            webglBackend.disposeIntermediateTensorInfo(real);
            webglBackend.disposeIntermediateTensorInfo(imag);
            return complexOutput;
        }
        const $dtype = dtype || upcastType(a.dtype, b.dtype);
        if ((a.dtype === 'string' || b.dtype === 'string' ||
            webglBackend.shouldExecuteOnCPU([a, b])) &&
            cpuKernelImpl != null) {
            const aVals = webglBackend.texData.get(a.dataId).values;
            const bVals = webglBackend.texData.get(b.dataId).values;
            // String tensors are stored as encoded Uint8Arrays and need to be
            // decoded before the CPU implementation can operate on them.
            const decodedAVals = a.dtype === 'string' ?
                fromUint8ToStringArray(aVals) :
                aVals;
            const decodedBVals = a.dtype === 'string' ?
                fromUint8ToStringArray(bVals) :
                bVals;
            const [outValues, outShape] = cpuKernelImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype);
            const out = webglBackend.makeTensorInfo(outShape, $dtype);
            const outData = webglBackend.texData.get(out.dataId);
            outData.values = outValues;
            return out;
        }
        const shouldUsePackedProgram = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') &&
            packedOpSnippet != null;
        let program;
        if (shouldUsePackedProgram) {
            program = new BinaryOpPackedProgram(packedOpSnippet, a.shape, b.shape, checkOutOfBounds);
        }
        else {
            program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
        }
        return webglBackend.runWebGLProgram(program, [a, b], $dtype);
    };
}
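// Illustrative sketch (not part of the original bundle): binaryKernelFunc is
// the same recipe for two-input ops. The GLSL body sees scalars `a`/`b` (or
// vec4s in the packed variant), so a hypothetical subtract kernel is just:
//
//   const SUB = 'return a - b;';
//   const subKernelFunc = binaryKernelFunc(
//       { opSnippet: SUB, packedOpSnippet: SUB, supportsComplex: false });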
function mapActivationToShaderProgram(activation, packed = false) {
    if (activation === 'linear') {
        if (packed) {
            return LINEAR;
        }
        return LINEAR$1;
    }
    else if (activation === 'relu') {
        if (packed) {
            return RELU$1;
        }
        return RELU$2;
    }
    else if (activation === 'elu') {
        if (packed) {
            return ELU$1;
        }
        return ELU$2;
    }
    else if (activation === 'relu6') {
        if (packed) {
            return RELU6$1;
        }
        return RELU6$2;
    }
    else if (activation === 'prelu') {
        if (packed) {
            return PRELU_PACKED;
        }
        return PRELU;
    }
    else if (activation === 'leakyrelu') {
        if (packed) {
            return LEAKYRELU_PACKED;
        }
        return LEAKYRELU;
    }
    else if (activation === 'sigmoid') {
        if (packed) {
            return SIGMOID$1;
        }
        return SIGMOID$2;
    }
    throw new Error(`Activation ${activation} has not been implemented for the WebGL backend.`);
}
class MatMulPackedProgram {
    constructor(aShape, bShape, outputShape, transposeA = false, transposeB = false, addBias = false, activation = null, hasPreluActivation = false, hasLeakyreluActivation = false) {
        this.variableNames = ['matrixA', 'matrixB'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = outputShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        const sharedDim = transposeA ? aShape[1] : aShape[2];
        const sharedDimensionPacked = Math.ceil(sharedDim / 2);
        const aSample = transposeA ? 'i * 2, rc.y' : 'rc.y, i * 2';
        const bSample = transposeB ? 'rc.z, i * 2' : 'i * 2, rc.z';
        const aSwizzle = transposeA ? ['a.xxyy', 'a.zzww'] : ['a.xxzz', 'a.yyww'];
        const bSwizzle = transposeB ? ['b.xzxz', 'b.ywyw'] : ['b.xyxy', 'b.zwzw'];
        let activationSnippet = '', applyActivationSnippet = '';
        if (activation) {
            if (hasPreluActivation) {
                activationSnippet = `vec4 activation(vec4 a) {
          vec4 b = getPreluActivationWeightsAtOutCoords();
          ${activation}
        }`;
            }
            else if (hasLeakyreluActivation) {
                activationSnippet = `vec4 activation(vec4 a) {
          vec4 b = getLeakyreluAlphaAtOutCoords();
          ${activation}
        }`;
            }
            else {
                activationSnippet = `vec4 activation(vec4 x) {
          ${activation}
        }`;
            }
            applyActivationSnippet = `result = activation(result);`;
        }
        const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
        if (addBias) {
            this.variableNames.push('bias');
        }
        if (hasPreluActivation) {
            this.variableNames.push('preluActivationWeights');
        }
        if (hasLeakyreluActivation) {
            this.variableNames.push('leakyreluAlpha');
        }
        let batchASnippet = 'rc.x';
        let batchBSnippet = 'rc.x';
        if (aShape[0] < bShape[0]) {
            batchASnippet = `imod(rc.x, ${aShape[0]})`;
        }
        else if (bShape[0] < aShape[0]) {
            batchBSnippet = `imod(rc.x, ${bShape[0]})`;
        }
        this.userCode = `
      ${activationSnippet}

      const float sharedDimension = ${sharedDimensionPacked}.0;

      vec4 dot2x2ARowBCol(ivec3 rc) {
        vec4 result = vec4(0);
        int batchA = ${batchASnippet};
        int batchB = ${batchBSnippet};
        for (int i = 0; i < ${sharedDimensionPacked}; i++) {
          vec4 a = getMatrixA(batchA, ${aSample});
          vec4 b = getMatrixB(batchB, ${bSample});

          // Each packed texel holds a 2x2 block, so two swizzled partial
          // products accumulate the four output values at once.
          result += (${aSwizzle[0]} * ${bSwizzle[0]});
          result += (${aSwizzle[1]} * ${bSwizzle[1]});
        }
        return result;
      }

      void main() {
        ivec3 rc = getOutputCoords();
        vec4 result = dot2x2ARowBCol(rc);

        ${addBiasSnippet}

        ${applyActivationSnippet}

        setOutput(result);
      }
    `;
    }
}
const COMPLEX_MULTIPLY = {
    REAL: 'return areal * breal - aimag * bimag;',
    IMAG: 'return areal * bimag + aimag * breal;'
};
class BinaryOpComplexProgram {
    constructor(op, aShape, bShape) {
        this.variableNames = ['AReal', 'AImag', 'BReal', 'BImag'];
        this.outputShape = assertAndGetBroadcastShape(aShape, bShape);
        this.userCode = `
      float binaryOpComplex(
          float areal, float aimag, float breal, float bimag) {
        ${op}
      }

      void main() {
        float areal = getARealAtOutCoords();
        float aimag = getAImagAtOutCoords();
        float breal = getBRealAtOutCoords();
        float bimag = getBImagAtOutCoords();
        setOutput(binaryOpComplex(areal, aimag, breal, bimag));
      }
    `;
    }
}
const MUL = 'return a * b;';
function multiply(args) {
    const { inputs, backend } = args;
    const { a, b } = inputs;
    const dtype = upcastType(a.dtype, b.dtype);
    if (a.dtype === 'complex64') {
        // (a.re + i*a.im) * (b.re + i*b.im) is computed as two real-valued
        // programs over the stored real/imaginary parts:
        //   real = a.re * b.re - a.im * b.im
        //   imag = a.re * b.im + a.im * b.re
        const aData = backend.texData.get(a.dataId);
        const bData = backend.texData.get(b.dataId);
        const realProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.REAL, a.shape, b.shape);
        const imagProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.IMAG, a.shape, b.shape);
        const inputs = [
            {
                dataId: aData.complexTensorInfos.real.dataId,
                dtype: aData.complexTensorInfos.real.dtype,
                shape: a.shape
            },
            {
                dataId: aData.complexTensorInfos.imag.dataId,
                dtype: aData.complexTensorInfos.imag.dtype,
                shape: a.shape
            },
            {
                dataId: bData.complexTensorInfos.real.dataId,
                dtype: bData.complexTensorInfos.real.dtype,
                shape: b.shape
            },
            {
                dataId: bData.complexTensorInfos.imag.dataId,
                dtype: bData.complexTensorInfos.imag.dtype,
                shape: b.shape
            }
        ];
        const realPart = backend.runWebGLProgram(realProgram, inputs, 'float32');
        const imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32');
        const complexOutput = complex({ inputs: { real: realPart, imag: imagPart }, backend });
        backend.disposeIntermediateTensorInfo(realPart);
        backend.disposeIntermediateTensorInfo(imagPart);
        return complexOutput;
    }
    if (backend.shouldExecuteOnCPU([a, b])) {
        const aData = backend.texData.get(a.dataId);
        const bData = backend.texData.get(b.dataId);
        const [outValues, outShape] = multiplyImplCPU(a.shape, b.shape, aData.values, bData.values, dtype);
        const out = backend.makeTensorInfo(outShape, dtype);
        const outData = backend.texData.get(out.dataId);
        outData.values = outValues;
        return out;
    }
    let program;
    if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
        program = new BinaryOpPackedProgram(MUL, a.shape, b.shape);
    }
    else {
        program = new BinaryOpProgram(MUL, a.shape, b.shape);
    }
    return backend.runWebGLProgram(program, [a, b], dtype);
}
const multiplyConfig = {
    kernelName: Multiply,
    backendName: 'webgl',
    kernelFunc: multiply
};
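// Illustrative usage (not part of the original bundle): the complex branch
// above is what makes elementwise complex multiplication work end to end:
//
//   const a = tf.complex([1, 2], [3, 4]);   // 1+3i, 2+4i
//   const b = tf.complex([5, 6], [7, 8]);   // 5+7i, 6+8i
//   tf.mul(a, b).print();                   // -16+22i, -20+40i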
function packedReshape(input, afterShape, backend) {
    const input3DShape = [getBatchDim(input.shape),
        ...getRowsCols(input.shape)];
    const input3D = {
        dtype: input.dtype,
        shape: input3DShape,
        dataId: input.dataId
    };
    const afterShapeAs3D = [getBatchDim(afterShape),
        ...getRowsCols(afterShape)];
    const program = new ReshapePackedProgram(afterShapeAs3D, input3DShape);
    const preventEagerUnpackingOfOutput = true;
    const customValues = [input3DShape];
    const output = backend.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput);
    return { dataId: output.dataId, shape: afterShape, dtype: output.dtype };
}
function reshape$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { shape } = attrs;
    const webglBackend = backend;
    const xSize = sizeFromShape(x.shape);
    const $shape = inferFromImplicitShape(shape, xSize);
    const $xSize = sizeFromShape($shape);
    assert$1(xSize === $xSize, () => `The new shape (${$shape}) has ${$xSize} elements and the old ` +
        `shape (${x.shape}) has ${xSize} elements. The new shape and old ` +
        `shape must have the same number of elements.`);
    const xTexData = webglBackend.texData.get(x.dataId);
    if (xTexData.isPacked && !isReshapeFree(x.shape, $shape) &&
        !(xTexData.texture !== null && isReshapeFree(xTexData.shape, $shape))) {
        return packedReshape(x, $shape, webglBackend);
    }
    // In every other case the reshape is free: bump the ref count and alias
    // the existing data under the new shape.
    webglBackend.incRef(x.dataId);
    return { dataId: x.dataId, shape: $shape, dtype: x.dtype };
}
const reshapeConfig$1 = {
    kernelName: Reshape$1,
    backendName: 'webgl',
    kernelFunc: reshape$1
};
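// Illustrative note (not part of the original bundle): the fast path above
// makes most reshapes free on webgl; no shader runs and no texture is copied:
//
//   const x = tf.ones([2, 3]);
//   const y = tf.reshape(x, [3, 2]);   // aliases x's GPU data in most layouts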
class MeanProgram {
    constructor(reduceInfo, divisor) {
        this.variableNames = ['x'];
        const { windowSize, batchSize, inSize, outSize } = reduceInfo;
        this.outputShape = [batchSize, outSize];
        const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
        const windowSizeVec4Remainder = windowSize % 4;
        let updateSnippet = `sumValue += dot(values, ones);`;
        if (divisor != null) {
            const denominator = 1 / divisor;
            updateSnippet = `sumValue += dot(values * ${isInt(denominator) ? denominator.toPrecision(2) :
                denominator}, ones);`;
        }
        let checkOutOfBounds = '';
        if (inSize % windowSize > 0) {
            checkOutOfBounds = `
        if (inIdx < 0 || inIdx >= ${inSize}) {
          return 0.0;
        }
      `;
        }
        this.userCode = `
      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

      float getValue(int batch, int inIdx) {
        ${checkOutOfBounds}
        return getX(batch, inIdx);
      }

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = outIdx * ${windowSize};

        float sumValue = 0.0;

        for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
          int inIdx = inOffset + i;
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            getValue(batch, inIdx + 3)
          );

          ${updateSnippet}
        }

        int inIdx = inOffset + ${windowSizeNearestVec4};
        if (${windowSizeVec4Remainder === 1}) {
          vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0);

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 2}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1), 0.0, 0.0);

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 3}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2), 0.0);

          ${updateSnippet}
        }
        setOutput(sumValue);
      }
    `;
    }
}
class ReduceProgram {
    constructor(reduceInfo, reduceType) {
        this.variableNames = ['x'];
        const { windowSize, batchSize, inSize, outSize } = reduceInfo;
        this.outputShape = [batchSize, outSize];
        let initializationValue = '0.0';
        let compareOp = ``;
        if (reduceType === 'prod') {
            initializationValue = '1.0';
        }
        else if (reduceType === 'min') {
            // 1.0 / 1e-20 stands in for +Infinity, which some GLSL compilers
            // reject as a literal.
            initializationValue = '1.0 / 1e-20';
            compareOp = `min`;
        }
        else if (reduceType === 'max') {
            initializationValue = '-1.0 / 1e-20';
            compareOp = `max`;
        }
        let returnValue = `${reduceType}(${reduceType}(${reduceType}(` +
            'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
        if (reduceType === 'sum') {
            returnValue = `sumValue`;
        }
        else if (reduceType === 'prod') {
            returnValue = `prodValue`;
        }
        else if (reduceType === 'all') {
            returnValue = `allValue`;
        }
        else if (reduceType === 'any') {
            returnValue = `anyValue`;
        }
        const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
        const windowSizeVec4Remainder = windowSize % 4;
        let updateSnippet = `
      if (${reduceType === 'sum'}) {
        sumValue += dot(values, ones);
      } else if (${reduceType === 'prod'}) {
        vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);
        prodValue *= tmp[0] * tmp[1];
      } else {
        minMaxValue = ${compareOp}(values, minMaxValue);
        if (${reduceType === 'min'} || ${reduceType === 'max'}) {
          minMaxValue = ${compareOp}(values, minMaxValue);
          bvec4 isNaN = isnan(values);
          if (isNaN.r || isNaN.g || isNaN.b || isNaN.a) {
            minMaxValue = vec4(NAN);
          }
        }
      }
    `;
        let vecType = `vec4`;
        if (reduceType === 'all') {
            initializationValue = '1.0';
            updateSnippet = `
        bool reducedAllValue = all(values);
        float floatedReducedAllValue = float(reducedAllValue);
        allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);
      `;
            vecType = `bvec4`;
        }
        else if (reduceType === 'any') {
            initializationValue = '0.0';
            updateSnippet = `
        bool reducedAnyValue = any(values);
        float floatedReducedAnyValue = float(reducedAnyValue);
        anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);
      `;
            vecType = `bvec4`;
        }
        let checkOutOfBounds = '';
        if (inSize % windowSize > 0) {
            checkOutOfBounds = `
        if (inIdx < 0 || inIdx >= ${inSize}) {
          return initializationValue;
        }
      `;
        }
        this.userCode = `
      const float initializationValue = ${initializationValue};
      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

      float getValue(int batch, int inIdx) {
        ${checkOutOfBounds}
        return getX(batch, inIdx);
      }

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = outIdx * ${windowSize};

        vec4 minMaxValue = vec4(${initializationValue});
        float prodValue = 1.0;
        float sumValue = 0.0;
        float allValue = 1.0;
        float anyValue = 0.0;

        for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
          int inIdx = inOffset + i;
          ${vecType} values = ${vecType}(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            getValue(batch, inIdx + 3)
          );

          ${updateSnippet}
        }

        int inIdx = inOffset + ${windowSizeNearestVec4};
        if (${windowSizeVec4Remainder === 1}) {
          ${vecType} values = ${vecType}(
            getValue(batch, inIdx),
            initializationValue,
            initializationValue,
            initializationValue
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 2}) {
          ${vecType} values = ${vecType}(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            initializationValue,
            initializationValue
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 3}) {
          ${vecType} values = ${vecType}(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            initializationValue
          );

          ${updateSnippet}
        }
        setOutput(${returnValue});
      }
    `;
    }
}
function getReductionStages(inShape) {
    const stages = [];
    while (stages.length === 0 || stages[stages.length - 1].outSize !== 1) {
        const outSize = stages.length ? stages[stages.length - 1].outSize : inShape[1];
        const windowSize = computeOptimalWindowSize(outSize);
        stages.push({
            inSize: outSize,
            windowSize,
            outSize: Math.ceil(outSize / windowSize)
        });
    }
    return stages;
}
function reduce(x, dtype, reductionType, backend) {
    // Run the reduction as a sequence of passes, each shrinking the inner
    // dimension by its window size, until one value per batch row remains.
    const reductionStages = getReductionStages(x.shape);
    let result = x;
    for (let i = 0; i < reductionStages.length; i++) {
        const { inSize, windowSize, outSize } = reductionStages[i];
        let program;
        let previousResult;
        if (reductionType === 'mean') {
            // Only the first pass divides by the reduced size; later passes sum
            // the already-scaled partial results.
            program = i === 0 ?
                new MeanProgram({ windowSize, inSize, batchSize: x.shape[0], outSize }, inSize) :
                new MeanProgram({ windowSize, inSize, batchSize: x.shape[0], outSize });
        }
        else {
            program = new ReduceProgram({ windowSize, inSize, batchSize: x.shape[0], outSize }, reductionType);
        }
        previousResult = result;
        result = backend.runWebGLProgram(program, [result], dtype);
        if (previousResult.dataId !== x.dataId) {
            backend.disposeIntermediateTensorInfo(previousResult);
        }
    }
    return result;
}
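// Illustrative walk-through (not part of the original bundle): with a
// hypothetical window size of 64 per stage, an input of shape
// [batch, 100000] would reduce roughly as 100000 -> 1563 -> 25 -> 1,
// i.e. three shader passes instead of one pass that reads 100000 texels per
// output value. The actual window sizes come from computeOptimalWindowSize.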
class TransposeProgram {
    constructor(aShape, newDim) {
        this.variableNames = ['A'];
        const outputShape = new Array(aShape.length);
        for (let i = 0; i < outputShape.length; i++) {
            outputShape[i] = aShape[newDim[i]];
        }
        this.outputShape = outputShape;
        this.rank = outputShape.length;
        const dtype = getCoordsDataType(this.rank);
        const switched = getSwitchedCoords(newDim);
        this.userCode = `
      void main() {
        ${dtype} resRC = getOutputCoords();
        setOutput(getA(${switched}));
      }
    `;
    }
}
function getSwitchedCoords(newDim) {
    const rank = newDim.length;
    if (rank > 6) {
        throw Error(`Transpose for rank ${rank} is not yet supported`);
    }
    const originalOrder = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v'];
    const switchedCoords = new Array(rank);
    for (let i = 0; i < newDim.length; i++) {
        switchedCoords[newDim[i]] = originalOrder[i];
    }
    return switchedCoords.join();
}
class TransposePackedProgram {
    constructor(aShape, newDim) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        const outputShape = new Array(aShape.length);
        for (let i = 0; i < outputShape.length; i++) {
            outputShape[i] = aShape[newDim[i]];
        }
        this.outputShape = outputShape;
        this.rank = outputShape.length;
        if (this.rank > 6) {
            throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`);
        }
        const dtype = getCoordsDataType(this.rank);
        const outputOrder = getVecChannels('rc', this.rank);
        const switchedOrder = new Array(this.rank);
        for (let i = 0; i < newDim.length; i++) {
            switchedOrder[newDim[i]] = outputOrder[i];
        }
        const innerDims = `vec2(${switchedOrder.slice(-2).join()})`;
        const nextColumn = `++${outputOrder[this.rank - 1]} < ${outputShape[this.rank - 1]}`;
        const getc = `getChannel(getA(${switchedOrder.join()}), ${innerDims})`;
        this.userCode = `
      void main() {
        ${dtype} rc = getOutputCoords();
        vec4 result = vec4(0.);
        result[0] = ${getc};
        if(${nextColumn}) {
          result[1] = ${getc};
        }
        --${outputOrder[this.rank - 1]};
        if(++${outputOrder[this.rank - 2]} < ${outputShape[this.rank - 2]}) {
          result[2] = ${getc};
          if(${nextColumn}) {
            result[3] = ${getc};
          }
        }
        setOutput(result);
      }
    `;
    }
}
function transposeImpl(x, perm, backend) {
    const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
        new TransposePackedProgram(x.shape, perm) :
        new TransposeProgram(x.shape, perm);
    return backend.runWebGLProgram(program, [x], x.dtype);
}

function sumImpl(x, axis, keepDims, backend) {
    const reductionIndices = axis;
    const xRank = x.shape.length;
    const origAxes = parseAxisParam(reductionIndices, x.shape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, xRank);
    const sumInputIsTransposed = permutedAxes != null;
    let sumInput = x;
    if (sumInputIsTransposed) {
        sumInput = transposeImpl(x, permutedAxes, backend);
        axes = getInnerMostAxes(axes.length, xRank);
    }
    assertAxesAreInnerMostDims('sum', axes, xRank);
    const [sumOutShape, reduceShape] = computeOutAndReduceShapes(sumInput.shape, axes);
    let outShape = sumOutShape;
    if (keepDims) {
        // Re-insert the reduced dimensions with size 1 so the rank is kept.
        outShape = expandShapeToKeepDim(sumOutShape, origAxes);
    }
    const inSize = sizeFromShape(reduceShape);
    const xSize = sizeFromShape(x.shape);
    const batchSize = xSize / inSize;
    const reshapedInput = reshape$1({ inputs: { x: sumInput }, attrs: { shape: [batchSize, inSize] }, backend });
    const outType = sumOutType(x.dtype);
    const reduced = reduce(reshapedInput, outType, 'sum', backend);
    const out = reshape$1({ inputs: { x: reduced }, attrs: { shape: outShape }, backend });
    backend.disposeIntermediateTensorInfo(reshapedInput);
    backend.disposeIntermediateTensorInfo(reduced);
    if (sumInputIsTransposed) {
        backend.disposeIntermediateTensorInfo(sumInput);
    }
    return out;
}
function sum$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    return sumImpl(x, axis, keepDims, backend);
}
const sumConfig$1 = {
    kernelName: Sum,
    backendName: 'webgl',
    kernelFunc: sum$1
};
function transpose(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { perm } = attrs;
    const webglBackend = backend;
    const xRank = x.shape.length;
    const newShape = new Array(xRank);
    for (let i = 0; i < newShape.length; i++) {
        newShape[i] = x.shape[perm[i]];
    }
    let out;
    if (webglBackend.shouldExecuteOnCPU([x])) {
        const xTexData = webglBackend.texData.get(x.dataId);
        const values = xTexData.values;
        const outValues = transposeImplCPU(values, x.shape, x.dtype, perm, newShape);
        out = webglBackend.makeTensorInfo(newShape, x.dtype);
        const outData = webglBackend.texData.get(out.dataId);
        outData.values = outValues;
    }
    else {
        out = transposeImpl(x, perm, webglBackend);
    }
    return out;
}
const transposeConfig = {
    kernelName: Transpose,
    backendName: 'webgl',
    kernelFunc: transpose
};
const MATMUL_SHARED_DIM_THRESHOLD = 1000;
function batchMatMulImpl({ a, b, transposeA, transposeB, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
    const aRank = a.shape.length;
    const bRank = b.shape.length;
    const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
    const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
    const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
    const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
    const outerDimsA = a.shape.slice(0, -2);
    const outerDimsB = b.shape.slice(0, -2);
    const batchDimA = sizeFromShape(outerDimsA);
    const batchDimB = sizeFromShape(outerDimsB);
    const outShapeOuterDims = assertAndGetBroadcastShape(a.shape.slice(0, -2), b.shape.slice(0, -2));
    const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
    assert$1(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
        `${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
        `${b.shape} and transposeA=${transposeA}` +
        ` and transposeB=${transposeB} must match.`);
    const a3dShape = transposeA ?
        [batchDimA, innerShapeA, outerShapeA] :
        [batchDimA, outerShapeA, innerShapeA];
    const b3dShape = transposeB ?
        [batchDimB, outerShapeB, innerShapeB] :
        [batchDimB, innerShapeB, outerShapeB];
    // The rest of the implementation works on rank-3 views of the inputs.
    const a3d = reshape$1({ inputs: { x: a }, backend, attrs: { shape: a3dShape } });
    const b3d = reshape$1({ inputs: { x: b }, backend, attrs: { shape: b3dShape } });
    const intermediates = [a3d, b3d];
    const batchDim = Math.max(batchDimA, batchDimB);
    const sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
    const hasBias = bias != null;
    const hasPreluActivationWeights = preluActivationWeights != null;
    const hasLeakyreluAlpha = activation === 'leakyrelu';
    const fusedActivation = activation != null ?
        mapActivationToShaderProgram(activation, true) :
        null;
    const containsFusedOps = hasBias || hasPreluActivationWeights ||
        hasLeakyreluAlpha || fusedActivation != null;
    let out;
    // A matrix-vector product with a large shared dimension parallelizes
    // better as an elementwise multiply followed by a staged sum reduction
    // than as a single dot-product shader.
    if ((outerShapeA === 1 || outerShapeB === 1) &&
        sharedDim > MATMUL_SHARED_DIM_THRESHOLD && containsFusedOps === false) {
        let aVec = a3d;
        let bVec = b3d;
        if (transposeA) {
            aVec = transpose({ inputs: { x: a3d }, backend, attrs: { perm: [0, 2, 1] } });
            intermediates.push(aVec);
        }
        if (transposeB) {
            bVec = transpose({ inputs: { x: b3d }, backend, attrs: { perm: [0, 2, 1] } });
            intermediates.push(bVec);
        }
        const shouldReshapeA = outerShapeB !== 1;
        const shouldReshapeB = outerShapeB === 1;
        let aVec3d = aVec;
        if (shouldReshapeA) {
            aVec3d = reshape$1({
                inputs: { x: aVec },
                backend,
                attrs: { shape: [batchDim, sharedDim, 1] }
            });
            intermediates.push(aVec3d);
        }
        const axis = outerShapeB === 1 ? 2 : 1;
        let bVec3d = bVec;
        if (shouldReshapeB) {
            bVec3d = reshape$1({
                inputs: { x: bVec },
                backend,
                attrs: { shape: [batchDim, 1, sharedDim] }
            });
            intermediates.push(bVec3d);
        }
        const product = multiply({ inputs: { a: aVec3d, b: bVec3d }, backend });
        out = sum$1({ inputs: { x: product }, backend, attrs: { axis, keepDims: true } });
        intermediates.push(product);
    }
    else {
        const dtype = upcastType(a.dtype, b.dtype);
        const program = new MatMulPackedProgram(a3dShape, b3dShape, [batchDim, outerShapeA, outerShapeB], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
        const inputs = [a3d, b3d];
        if (bias != null) {
            inputs.push(bias);
        }
        if (hasPreluActivationWeights) {
            inputs.push(preluActivationWeights);
        }
        if (hasLeakyreluAlpha) {
            const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
            inputs.push($leakyreluAlpha);
            intermediates.push($leakyreluAlpha);
        }
        out = backend.runWebGLProgram(program, inputs, dtype);
    }
    const outReshaped = reshape$1({ inputs: { x: out }, backend, attrs: { shape: outShape } });
    intermediates.push(out);
    for (const i of intermediates) {
        backend.disposeIntermediateTensorInfo(i);
    }
    return outReshaped;
}
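// Illustrative note (not part of the original bundle): for a of shape
// [1, 1, 10000] and b of shape [1, 10000, 1], the fast path above computes an
// elementwise multiply over the 10000 shared elements followed by a staged
// sum reduction, rather than running one shader invocation that loops over
// the whole shared dimension for a single output value.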
function _fusedMatMul$1(args) {
    const { inputs, backend, attrs } = args;
    const { a, b, bias, preluActivationWeights } = inputs;
    const { transposeA, transposeB, activation, leakyreluAlpha } = attrs;
    return batchMatMulImpl({
        a,
        b,
        transposeA,
        transposeB,
        backend,
        bias,
        preluActivationWeights,
        leakyreluAlpha,
        activation
    });
}
const _fusedMatMulConfig$1 = {
    kernelName: _FusedMatMul,
    backendName: 'webgl',
    kernelFunc: _fusedMatMul$1,
};
const ABS = `return abs(x);`;
function abs(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    if (backend.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64') {
        const xData = backend.texData.get(x.dataId);
        const outValues = simpleAbsImplCPU(xData.values);
        return backend.makeTensorInfo(x.shape, x.dtype, outValues);
    }
    let program;
    if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
        program = new UnaryOpPackedProgram(x.shape, ABS);
    }
    else {
        program = new UnaryOpProgram(x.shape, ABS);
    }
    return backend.runWebGLProgram(program, [x], x.dtype);
}
const absConfig = {
    kernelName: Abs,
    backendName: 'webgl',
    kernelFunc: abs
};
const ACOS = CHECK_NAN_SNIPPET$1 + `
  if (abs(x) > 1.) {
    return NAN;
  }
  return acos(x);
`;
const acos$1 = unaryKernelFunc({ opSnippet: ACOS });
const acosConfig$1 = {
    kernelName: Acos,
    backendName: 'webgl',
    kernelFunc: acos$1,
};

const ACOSH = CHECK_NAN_SNIPPET$1 + `
  if (x < 1.0) return NAN;
  return log(x + sqrt(x * x - 1.0));`;
const acosh$1 = unaryKernelFunc({ opSnippet: ACOSH });
const acoshConfig$1 = {
    kernelName: Acosh,
    backendName: 'webgl',
    kernelFunc: acosh$1,
};
const ADD = 'return a + b;';
const addKernelFunc = binaryKernelFunc({
    opSnippet: ADD,
    packedOpSnippet: ADD,
    supportsComplex: true,
    cpuKernelImpl: addImplCPU
});
const addConfig = {
    kernelName: Add,
    backendName: 'webgl',
    kernelFunc: addKernelFunc
};
class AddNProgram {
    constructor(outputShape, shapes) {
        this.outputShape = [];
        this.outputShape = outputShape;
        this.variableNames = shapes.map((_, i) => `T${i}`);
        const snippets = [];
        // Fetch the element at the output coordinates from every input tensor.
        this.variableNames.forEach(variable => {
            snippets.push(`float v${variable} = get${variable}AtOutCoords();`);
        });
        // Sum the fetched elements.
        const operation = this.variableNames
            .map(variable => {
            return `v${variable}`;
        })
            .join(' + ');
        this.userCode = `
      void main() {
        ${snippets.join('\n ')}

        float result = ${operation};
        setOutput(result);
      }
    `;
    }
}

class AddNPackedProgram {
    constructor(outputShape, shapes) {
        this.outputShape = [];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = outputShape;
        this.variableNames = shapes.map((_, i) => `T${i}`);
        const snippets = [];
        // Fetch the packed texel at the output coordinates from every input.
        this.variableNames.forEach(variable => {
            snippets.push(`vec4 v${variable} = get${variable}AtOutCoords();`);
        });
        // Sum the fetched texels.
        const operation = this.variableNames
            .map(variable => {
            return `v${variable}`;
        })
            .join(' + ');
        this.userCode = `
      void main() {
        ${snippets.join('\n ')}

        vec4 result = ${operation};
        setOutput(result);
      }
    `;
    }
}
function addN$1(args) {
    const { inputs, backend } = args;
    const tensors = inputs;
    if (tensors.length === 1) {
        return identity({ inputs: { x: tensors[0] }, backend });
    }
    // A single shader may only sample a limited number of textures. If the
    // list is too long, split it in half and recurse until each chunk fits.
    if (tensors.length > env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) {
        const midIndex = Math.floor(tensors.length / 2);
        const leftSide = addN$1({ inputs: tensors.slice(0, midIndex), backend });
        const rightSide = addN$1({ inputs: tensors.slice(midIndex), backend });
        return addN$1({ inputs: [leftSide, rightSide], backend });
    }
    const dtype = tensors.map(t => t.dtype).reduce((d1, d2) => upcastType(d1, d2));
    const shapes = tensors.map(t => t.shape);
    const usePackedOp = env().getBool('WEBGL_PACK');
    const program = usePackedOp ?
        new AddNPackedProgram(tensors[0].shape, shapes) :
        new AddNProgram(tensors[0].shape, shapes);
    return backend.runWebGLProgram(program, tensors, dtype);
}
const addNConfig$1 = {
    kernelName: AddN,
    backendName: 'webgl',
    kernelFunc: addN$1
};
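// Illustrative note (not part of the original bundle): given a hypothetical
// texture limit of 16, summing 40 tensors splits into two recursive halves of
// 20, each splitting again into halves of 10 that fit in one AddNProgram; the
// two partial sums are then added in a final pairwise pass.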
function all$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    const xRank = x.shape.length;
    const origAxes = parseAxisParam(axis, x.shape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, xRank);
    let permutedX = x;
    if (permutedAxes != null) {
        permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        axes = getInnerMostAxes(axes.length, xRank);
    }
    assertAxesAreInnerMostDims('all', axes, xRank);
    const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes);
    const inSize = sizeFromShape(reduceShape);
    const a2D = reshape$1({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
    const reduced = reduce(a2D, a2D.dtype, 'all', backend);
    let res;
    if (keepDims) {
        const newShape = expandShapeToKeepDim(outShape, origAxes);
        res = reshape$1({ inputs: { x: reduced }, backend, attrs: { shape: newShape } });
    }
    else {
        res = reshape$1({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
    }
    backend.disposeIntermediateTensorInfo(a2D);
    backend.disposeIntermediateTensorInfo(reduced);
    if (permutedAxes != null) {
        backend.disposeIntermediateTensorInfo(permutedX);
    }
    return res;
}
const allConfig$1 = {
    kernelName: All,
    backendName: 'webgl',
    kernelFunc: all$1
};
function any$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    const xRank = x.shape.length;
    const origAxes = parseAxisParam(axis, x.shape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, xRank);
    let permutedX = x;
    if (permutedAxes != null) {
        permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        axes = getInnerMostAxes(axes.length, xRank);
    }
    assertAxesAreInnerMostDims('any', axes, xRank);
    const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes);
    const inSize = sizeFromShape(reduceShape);
    const a2D = reshape$1({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
    const reduced = reduce(a2D, a2D.dtype, 'any', backend);
    let res;
    if (keepDims) {
        const newShape = expandShapeToKeepDim(outShape, origAxes);
        res = reshape$1({ inputs: { x: reduced }, backend, attrs: { shape: newShape } });
    }
    else {
        res = reshape$1({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
    }
    backend.disposeIntermediateTensorInfo(a2D);
    backend.disposeIntermediateTensorInfo(reduced);
    if (permutedAxes != null) {
        backend.disposeIntermediateTensorInfo(permutedX);
    }
    return res;
}
const anyConfig$1 = {
    kernelName: Any,
    backendName: 'webgl',
    kernelFunc: any$1
};
class ArgMinMaxProgram {
    constructor(reduceInfo, op, firstPass) {
        this.variableNames = ['A'];
        const { windowSize, batchSize, outSize } = reduceInfo;
        if (!firstPass) {
            this.variableNames.push('bestIndicesA');
        }
        this.outputShape = [batchSize, outSize];
        const compOp = (op === 'max') ? '>' : '<';
        const indexSnippet = firstPass ?
            'inOffset + i;' :
            'round(getBestIndicesA(batch, inOffset + i));';
        this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = outIdx * ${windowSize};

        int bestIndex = inOffset;
        float bestValue = getA(batch, bestIndex);

        for (int i = 0; i < ${windowSize}; i++) {
          int inIdx = ${indexSnippet};
          float candidate = getA(batch, inIdx);
          if (candidate ${compOp} bestValue) {
            bestValue = candidate;
            bestIndex = inIdx;
          }
        }
        setOutput(float(bestIndex));
      }
    `;
    }
}
class ArgMinMaxPackedProgram {
    constructor(shape, windowSize, op, firstPass) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        assert$1(shape.length > 2, () => `Packed arg${op.charAt(0).toUpperCase() +
            op.slice(1)} supports only inputs with rank above 2.`);
        const inSize = shape[shape.length - 1];
        const outSize = Math.ceil(inSize / windowSize);
        this.outputShape = shape.slice(0, -1);
        if (outSize > 1) {
            this.outputShape.push(outSize);
        }
        if (!firstPass) {
            this.variableNames.push('bestIndicesA');
        }
        const outShape = this.outputShape;
        const rank = outShape.length;
        const dtype = getCoordsDataType(rank);
        const coords = getChannels('coords', rank);
        let sourceLocSetup;
        let sourceRank;
        if (outSize === 1) {
            sourceRank = rank + 1;
            const sourceLocDType = getCoordsDataType(sourceRank);
            sourceLocSetup = `
        ${sourceLocDType} sourceLocR = ${sourceLocDType}(${coords.join()}, 0);
        ++${coords[rank - 1]};
        ${sourceLocDType} sourceLocG = ${sourceLocDType}(${coords.join()}, 0);
        ++${coords[rank - 2]};
        ${sourceLocDType} sourceLocA = ${sourceLocDType}(${coords.join()}, 0);
        --${coords[rank - 1]};
        ${sourceLocDType} sourceLocB = ${sourceLocDType}(${coords.join()}, 0);
        --${coords[rank - 2]};`;
        }
        else {
            sourceRank = rank;
            sourceLocSetup = `
        ${dtype} sourceLocR = coords;
        ++${coords[rank - 1]};
        ${dtype} sourceLocG = coords;
        ++${coords[rank - 2]};
        ${dtype} sourceLocA = coords;
        --${coords[rank - 1]};
        ${dtype} sourceLocB = coords;
        --${coords[rank - 2]};`;
        }
        const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);
        const inChannel = '.' + channels[sourceRank - 1];
        const intChannels = channels.map(x => 'int ' + x);
        const srcRCoords = getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r');
        const srcGCoords = getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g');
        const srcBCoords = getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b');
        const srcACoords = getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a');
        const compOp = (op === 'max') ? 'greaterThan' : 'lessThan';
        const fetchCandidateIdx = firstPass ? '' : `
          inIdx = round(vec4(getBestIndicesAChannel(${srcRCoords.join()}),
                             getBestIndicesAChannel(${srcGCoords.join()}),
                             getBestIndicesAChannel(${srcBCoords.join()}),
                             getBestIndicesAChannel(${srcACoords.join()})));`;
        const fetchValue = `vec4(
            getAChannel(${srcRCoords.join()}),
            hasNextCol ? getAChannel(${srcGCoords.join()}) : 0.,
            hasNextRow ? getAChannel(${srcBCoords.join()}) : 0.,
            hasNextRow && hasNextCol ? getAChannel(${srcACoords.join()}) : 0.)`;
        const getBestIndicesAChannelSnippet = firstPass ? '' : `
      float getBestIndicesAChannel(${intChannels.join()}) {
        return getChannel(getBestIndicesA(${channels.join()}),
                          vec2(${channels.slice(-2).join()}));
      }`;
        this.userCode = `
      float getAChannel(${intChannels.join()}) {
        return getChannel(getA(${channels.join()}),
                          vec2(${channels.slice(-2).join()}));
      }
      ${getBestIndicesAChannelSnippet}
      void main() {
        ${dtype} coords = getOutputCoords();
        bool hasNextCol = ${coords[rank - 1]} < ${outShape[rank - 1] - 1};
        bool hasNextRow = ${coords[rank - 2]} < ${outShape[rank - 2] - 1};
        ${sourceLocSetup}
        ivec4 srcIdx = ivec4(sourceLocR${inChannel}, sourceLocG${inChannel},
          sourceLocB${inChannel}, sourceLocA${inChannel}) * ${windowSize};
        ivec4 inIdx = srcIdx;
        vec4 bestIndex = vec4(inIdx);
        vec4 bestValue = ${fetchValue};

        for (int i = 0; i < ${windowSize}; i++) {
          inIdx = srcIdx;
          ${fetchCandidateIdx}
          vec4 candidate = ${fetchValue};
          bvec4 nan = isnan(candidate);
          bvec4 replace = bvec4(
            vec4(${compOp}(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));

          bestValue = vec4(replace.x ? candidate.x : bestValue.x,
                           replace.y ? candidate.y : bestValue.y,
                           replace.z ? candidate.z : bestValue.z,
                           replace.w ? candidate.w : bestValue.w);
          bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));
          srcIdx++;
        }
        setOutput(bestIndex);
      }
    `;
    }
}

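// argReduce / argReducePacked run the programs above repeatedly, feeding each
// pass's best indices into the next, until a single column (or a lower-rank
// output) remains. argMinMaxReduce chooses the packed or unpacked path based
// on the WEBGL_PACK_REDUCE flag and the input rank.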
function argReduce(backend, x, reduceType, bestIndicesA = null) {
    let batchSize = x.shape[0];
    let inSize = x.shape[1];
    if (bestIndicesA != null) {
        batchSize = bestIndicesA.shape[0];
        inSize = bestIndicesA.shape[1];
    }
    const windowSize = computeOptimalWindowSize(inSize);
    const reduceInfo = { windowSize, inSize, batchSize, outSize: Math.ceil(inSize / windowSize) };
    const program = new ArgMinMaxProgram(reduceInfo, reduceType, bestIndicesA == null);
    const inputs = [x];
    if (bestIndicesA != null) {
        inputs.push(bestIndicesA);
    }
    const output = backend.runWebGLProgram(program, inputs, 'int32');

    if (output.shape[1] === 1) {
        return output;
    }
    const result = argReduce(backend, x, reduceType, output);
    backend.disposeIntermediateTensorInfo(output);
    return result;
}
function argReducePacked(backend, x, reduceType, bestIndicesA = null) {
    const inShape = bestIndicesA != null ? bestIndicesA.shape : x.shape;
    const inSize = inShape[inShape.length - 1];
    const windowSize = computeOptimalWindowSize(inSize);
    const program = new ArgMinMaxPackedProgram(inShape, windowSize, reduceType, bestIndicesA == null);
    const inputs = bestIndicesA == null ? [x] : [x, bestIndicesA];
    const output = backend.runWebGLProgram(program, inputs, 'int32');
    if (output.shape.length === x.shape.length) {
        const result = argReducePacked(backend, x, reduceType, output);
        backend.disposeIntermediateTensorInfo(output);
        return result;
    }
    return output;
}
function argMinMaxReduce(backend, x, axis, reduceType) {
    const axes = [axis];
    assertAxesAreInnerMostDims('arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes, x.shape.length);
    if (!env().getBool('WEBGL_PACK_REDUCE') || x.shape.length <= 2) {
        const intermediateTensorInfos = [];

        const xtexData = backend.texData.get(x.dataId);
        const xIsPacked = xtexData !== null && xtexData.isPacked;
        let xUnPacked = x;
        if (xIsPacked) {
            xUnPacked = backend.unpackTensor(x);
            intermediateTensorInfos.push(xUnPacked);
        }
        const [outShape, reduceShape] = computeOutAndReduceShapes(xUnPacked.shape, axes);
        const inSize = sizeFromShape(reduceShape);
        const a2D = reshape$1({ inputs: { x: xUnPacked }, backend, attrs: { shape: [-1, inSize] } });
        intermediateTensorInfos.push(a2D);
        const reduced = argReduce(backend, a2D, reduceType);
        intermediateTensorInfos.push(reduced);
        const reshaped = reshape$1({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
        intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
        return reshaped;
    }
    return argReducePacked(backend, x, reduceType);
}

function argMax$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis } = attrs;
    let axes = parseAxisParam(axis, x.shape);
    const permutedAxes = getAxesPermutation(axes, x.shape.length);
    let $x = x;
    const intermediateTensorInfos = [];
    if (permutedAxes != null) {
        $x = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        intermediateTensorInfos.push($x);
        axes = getInnerMostAxes(axes.length, $x.shape.length);
    }
    assertAxesAreInnerMostDims('argMax', [axes[0]], $x.shape.length);
    const out = argMinMaxReduce(backend, $x, axes[0], 'max');
    intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return out;
}
const argMaxConfig$1 = {
    kernelName: ArgMax,
    backendName: 'webgl',
    kernelFunc: argMax$1
};

function argMin$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis } = attrs;
    let axes = parseAxisParam(axis, x.shape);
    const permutedAxes = getAxesPermutation(axes, x.shape.length);
    let $x = x;
    const intermediateTensorInfos = [];
    if (permutedAxes != null) {
        $x = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        intermediateTensorInfos.push($x);
        axes = getInnerMostAxes(axes.length, $x.shape.length);
    }
    assertAxesAreInnerMostDims('argMin', [axes[0]], $x.shape.length);
    const out = argMinMaxReduce(backend, $x, axes[0], 'min');
    intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return out;
}
const argMinConfig$1 = {
    kernelName: ArgMin,
    backendName: 'webgl',
    kernelFunc: argMin$1
};

const ASIN = CHECK_NAN_SNIPPET$1 + `
  if (abs(x) > 1.) {
    return NAN;
  }
  return asin(x);
`;
const asin$1 = unaryKernelFunc({ opSnippet: ASIN });
const asinConfig$1 = {
    kernelName: Asin,
    backendName: 'webgl',
    kernelFunc: asin$1,
};

const ASINH = CHECK_NAN_SNIPPET$1 + `return log(x + sqrt(x * x + 1.0));`;
const asinh$1 = unaryKernelFunc({ opSnippet: ASINH });
const asinhConfig$1 = {
    kernelName: Asinh,
    backendName: 'webgl',
    kernelFunc: asinh$1,
};

const ATAN = CHECK_NAN_SNIPPET$1 + `
  return atan(x);
`;
const atan$1 = unaryKernelFunc({ opSnippet: ATAN });
const atanConfig$1 = {
    kernelName: Atan,
    backendName: 'webgl',
    kernelFunc: atan$1,
};

const ATAN2 = CHECK_NAN_SNIPPET + `
  return atan(a, b);
`;
const ATAN2_PACKED = `
  vec4 result = atan(a, b);
  bvec4 isNaNA = isnan(a);
  bvec4 isNaNB = isnan(b);
  bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w);
  ` +
    CHECK_NAN_SNIPPET_PACKED + `
  return result;
`;
const atan2$1 = binaryKernelFunc({ opSnippet: ATAN2, packedOpSnippet: ATAN2_PACKED });
const atan2Config$1 = {
    kernelName: Atan2,
    backendName: 'webgl',
    kernelFunc: atan2$1,
};

const ATANH = CHECK_NAN_SNIPPET$1 + `
  if ((x < -1.0) || (x > 1.0)) return NAN;
  return (log(1.0 + x) - log(1.0 - x)) / 2.0;`;
const atanh$1 = unaryKernelFunc({ opSnippet: ATANH });
const atanhConfig$1 = {
    kernelName: Atanh,
    backendName: 'webgl',
    kernelFunc: atanh$1,
};

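// Pool2DProgram emits the GLSL for 2-D max/avg pooling. The inner filter loop
// is unrolled four columns at a time into a vec4 so the per-window reduction
// can use dot()/max(); when `computePositions` is set the shader instead
// records the (optionally flattened) position of the max value.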
class Pool2DProgram {
    constructor(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) {
        this.variableNames = ['x'];
        if (poolType === 'avg' && computePositions) {
            throw new Error('Cannot compute positions for average pool.');
        }
        const filterWidth = convInfo.filterWidth;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const dilationHeight = convInfo.dilationHeight;
        const dilationWidth = convInfo.dilationWidth;
        const effectiveFilterHeight = convInfo.effectiveFilterHeight;
        const effectiveFilterWidth = convInfo.effectiveFilterWidth;
        const padTop = convInfo.padInfo.top;
        const padLeft = convInfo.padInfo.left;
        this.outputShape = convInfo.outShape;
        const isAvgPool = poolType === 'avg';
        const batchFlattenPositionStr = `((batch * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;
        const flattenPositionStr = `(xR * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + d`;
        let initializationValue = '0.0';
        if (!isAvgPool) {
            initializationValue = '-1.0 / 1e-20';
        }
        if (computePositions) {
            const compareOp = '>=';
            this.userCode = `
        const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
        const ivec2 pads = ivec2(${padTop}, ${padLeft});

        void main() {
          ivec4 coords = getOutputCoords();
          int batch = coords[0];
          int d = coords[3];

          ivec2 xRCCorner = coords.yz * strides - pads;
          int xRCorner = xRCCorner.x;
          int xCCorner = xRCCorner.y;

          float minMaxValue = 0.0;
          float minMaxValueFound = 0.0;
          int minMaxPosition = 0;
          float avgValue = 0.0;

          for (int wR = 0; wR < ${effectiveFilterHeight};
              wR += ${dilationHeight}) {
            int xR = xRCorner + wR;

            if (xR < 0 || xR >= ${convInfo.inHeight}) {
              continue;
            }

            for (int wC = 0; wC < ${effectiveFilterWidth};
                wC += ${dilationWidth}) {
              int xC = xCCorner + wC;

              if (xC < 0 || xC >= ${convInfo.inWidth}) {
                continue;
              }

              float value = getX(batch, xR, xC, d);

              float currMinMaxValue = mix(
                  value, minMaxValue, minMaxValueFound);
              if (value ${compareOp} currMinMaxValue) {
                minMaxValue = value;
                minMaxValueFound = 1.0;
                minMaxPosition = ${flattenPositions ? (includeBatchInIndex ? batchFlattenPositionStr :
                flattenPositionStr) :
                `wR * ${effectiveFilterWidth} + wC`};
              }
            }
          }
          setOutput(float(minMaxPosition));
        }
      `;
            return;
        }
        const compareOp = 'max';
        let returnValue = `${poolType}(${poolType}(${poolType}(` +
            'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
        if (poolType === 'avg') {
            returnValue = `avgValue / max(count, 1.0)`;
        }
        const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
        const filterWidthVec4Remainder = filterWidth % 4;
        const updateSnippet = `
      if (${isAvgPool}) {
        avgValue += dot(values, ones);
      } else {
        minMaxValue = ${compareOp}(values, minMaxValue);
      }
    `;
        this.userCode = `
      const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
      const ivec2 pads = ivec2(${padTop}, ${padLeft});
      const float initializationValue = ${initializationValue};
      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

      float count = 0.0;

      float getValue(int batch, int xR, int xC, int d) {
        if (xC < 0 || xC >= ${convInfo.inWidth}) {
          return initializationValue;
        }
        count += 1.0;
        return getX(batch, xR, xC, d);
      }

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        int d = coords[3];

        ivec2 xRCCorner = coords.yz * strides - pads;
        int xRCorner = xRCCorner.x;
        int xCCorner = xRCCorner.y;

        vec4 minMaxValue = vec4(${initializationValue});
        float avgValue = 0.0;
        count = 0.0;

        for (int wR = 0; wR < ${effectiveFilterHeight};
            wR += ${dilationHeight}) {
          int xR = xRCorner + wR;

          if (xR < 0 || xR >= ${convInfo.inHeight}) {
            continue;
          }

          for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {
            int xC = xCCorner + wC * ${dilationWidth};

            vec4 values = vec4(
              getValue(batch, xR, xC, d),
              getValue(batch, xR, xC + ${dilationWidth}, d),
              getValue(batch, xR, xC + 2 * ${dilationWidth}, d),
              getValue(batch, xR, xC + 3 * ${dilationWidth}, d)
            );

            ${updateSnippet}
          }

          int xC = xCCorner + ${filterWidthNearestVec4};
          if (${filterWidthVec4Remainder === 1}) {
            vec4 values = vec4(
              getValue(batch, xR, xC, d),
              initializationValue,
              initializationValue,
              initializationValue
            );

            ${updateSnippet}
          } else if (${filterWidthVec4Remainder === 2}) {
            vec4 values = vec4(
              getValue(batch, xR, xC, d),
              getValue(batch, xR, xC + ${dilationWidth}, d),
              initializationValue,
              initializationValue
            );

            ${updateSnippet}
          } else if (${filterWidthVec4Remainder === 3}) {
            vec4 values = vec4(
              getValue(batch, xR, xC, d),
              getValue(batch, xR, xC + ${dilationWidth}, d),
              getValue(batch, xR, xC + 2 * ${dilationWidth}, d),
              initializationValue
            );

            ${updateSnippet}
          }
        }
        setOutput(${returnValue});
      }
    `;
    }
}

class Pool3DProgram {
    constructor(convInfo, poolType, computePositions, flattenPositions = false, includeBatchInIndex = false) {
        this.variableNames = ['x'];
        if (poolType === 'avg' && computePositions) {
            throw new Error('Cannot compute positions for average pool.');
        }
        const filterWidth = convInfo.filterWidth;
        const strideDepth = convInfo.strideDepth;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const dilationDepth = convInfo.dilationDepth;
        const dilationHeight = convInfo.dilationHeight;
        const dilationWidth = convInfo.dilationWidth;
        const effectiveFilterDepth = convInfo.effectiveFilterDepth;
        const effectiveFilterHeight = convInfo.effectiveFilterHeight;
        const effectiveFilterWidth = convInfo.effectiveFilterWidth;
        const padFront = convInfo.padInfo.front;
        const padTop = convInfo.padInfo.top;
        const padLeft = convInfo.padInfo.left;
        this.outputShape = convInfo.outShape;
        const isAvgPool = poolType === 'avg';
        let initializationValue = '0.0';
        if (!isAvgPool) {
            initializationValue = '-1.0 / 1e-20';
        }
        if (computePositions) {
            const compareOp = '>=';
            this.userCode = `
        const ivec3 strides =
            ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
        const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});

        void main() {
          ivec5 coords = getOutputCoords();
          int batch = coords.x;
          int ch = coords.u;

          ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
          int xDCorner = xCorner.x;
          int xRCorner = xCorner.y;
          int xCCorner = xCorner.z;

          float minMaxValue = 0.0;
          float minMaxValueFound = 0.0;
          int minMaxPosition = 0;

          for (int wD = 0; wD < ${effectiveFilterDepth};
              wD += ${dilationDepth}) {
            int xD = xDCorner + wD;

            if (xD < 0 || xD >= ${convInfo.inDepth}) {
              continue;
            }

            for (int wR = 0; wR < ${effectiveFilterHeight};
                wR += ${dilationHeight}) {
              int xR = xRCorner + wR;

              if (xR < 0 || xR >= ${convInfo.inHeight}) {
                continue;
              }

              for (int wC = 0; wC < ${effectiveFilterWidth};
                  wC += ${dilationWidth}) {
                int xC = xCCorner + wC;

                if (xC < 0 || xC >= ${convInfo.inWidth}) {
                  continue;
                }

                float value = getX(batch, xD, xR, xC, ch);

                float currMinMaxValue = mix(
                    value, minMaxValue, minMaxValueFound);
                if (value ${compareOp} currMinMaxValue) {
                  minMaxValue = value;
                  minMaxValueFound = 1.0;
                  minMaxPosition = ${flattenPositions ?
                (includeBatchInIndex ?
                    `(((batch * ${convInfo.inDepth} + xD) * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch` :
                    `((xD * ${convInfo.inHeight} + xR) * ${convInfo.inWidth} + xC) * ${convInfo.inChannels} + ch`) :
                `wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} +
                      wR * ${effectiveFilterWidth} + wC`};
                }
              }
            }
          }
          setOutput(float(minMaxPosition));
        }
      `;
            return;
        }
        const compareOp = 'max';
        let returnValue = `${poolType}(${poolType}(${poolType}(` +
            'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
        if (poolType === 'avg') {
            returnValue = `avgValue / max(count, 1.0)`;
        }
        const filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
        const filterWidthVec4Remainder = filterWidth % 4;
        const updateSnippet = `
      if (${isAvgPool}) {
        avgValue += dot(values, ones);
      } else {
        minMaxValue = ${compareOp}(values, minMaxValue);
      }
    `;
        this.userCode = `
      const ivec3 strides =
        ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
      const float initializationValue = ${initializationValue};
      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

      float count = 0.0;

      float getValue(int batch, int xD, int xR, int xC, int ch) {
        if (xC < 0 || xC >= ${convInfo.inWidth}) {
          return initializationValue;
        }
        count += 1.0;
        return getX(batch, xD, xR, xC, ch);
      }

      void main() {
        ivec5 coords = getOutputCoords();
        int batch = coords.x;
        int ch = coords.u;

        ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
        int xDCorner = xCorner.x;
        int xRCorner = xCorner.y;
        int xCCorner = xCorner.z;

        vec4 minMaxValue = vec4(${initializationValue});
        float avgValue = 0.0;
        count = 0.0;

        for (int wD = 0; wD < ${effectiveFilterDepth};
            wD += ${dilationDepth}) {
          int xD = xDCorner + wD;

          if (xD < 0 || xD >= ${convInfo.inDepth}) {
            continue;
          }

          for (int wR = 0; wR < ${effectiveFilterHeight};
              wR += ${dilationHeight}) {
            int xR = xRCorner + wR;

            if (xR < 0 || xR >= ${convInfo.inHeight}) {
              continue;
            }

            for (int wC = 0; wC < ${filterWidthNearestVec4}; wC += 4) {
              int xC = xCCorner + wC * ${dilationWidth};

              vec4 values = vec4(
                getValue(batch, xD, xR, xC, ch),
                getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
                getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),
                getValue(batch, xD, xR, xC + 3 * ${dilationWidth}, ch)
              );

              ${updateSnippet}
            }

            int xC = xCCorner + ${filterWidthNearestVec4};
            if (${filterWidthVec4Remainder === 1}) {
              vec4 values = vec4(
                getValue(batch, xD, xR, xC, ch),
                initializationValue,
                initializationValue,
                initializationValue
              );

              ${updateSnippet}
            } else if (${filterWidthVec4Remainder === 2}) {
              vec4 values = vec4(
                getValue(batch, xD, xR, xC, ch),
                getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
                initializationValue,
                initializationValue
              );

              ${updateSnippet}
            } else if (${filterWidthVec4Remainder === 3}) {
              vec4 values = vec4(
                getValue(batch, xD, xR, xC, ch),
                getValue(batch, xD, xR, xC + ${dilationWidth}, ch),
                getValue(batch, xD, xR, xC + 2 * ${dilationWidth}, ch),
                initializationValue
              );

              ${updateSnippet}
            }
          }
        }
        setOutput(${returnValue});
      }
    `;
    }
}

function avgPool$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    assertNotComplex$1(x, 'avgPool');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const dilations = 1;
    assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
    if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
        arraysEqual(convInfo.inShape, convInfo.outShape)) {
        return identity({ inputs: { x }, backend });
    }
    const avgPoolProgram = new Pool2DProgram(convInfo, 'avg', false);
    return backend.runWebGLProgram(avgPoolProgram, [x], 'float32');
}
const avgPoolConfig$1 = {
    kernelName: AvgPool,
    backendName: 'webgl',
    kernelFunc: avgPool$1
};

function avgPool3D$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs;
    const dilations = [1, 1, 1];
    const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat);
    const avgPoolProgram = new Pool3DProgram(convInfo, 'avg', false);
    return backend.runWebGLProgram(avgPoolProgram, [x], 'float32');
}
const avgPool3DConfig$1 = {
    kernelName: AvgPool3D,
    backendName: 'webgl',
    kernelFunc: avgPool3D$1
};

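// Gradient programs for average pooling: each input position accumulates
// every output gradient its window contributed to, scaled by 1 / (filter
// area); the fract() checks skip window taps that do not line up with a
// stride.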
class AvgPool2DBackpropProgram {
    constructor(convInfo) {
        this.variableNames = ['dy'];
        this.outputShape = convInfo.inShape;
        const filterHeight = convInfo.filterHeight;
        const filterWidth = convInfo.filterWidth;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const dilationHeight = convInfo.dilationHeight;
        const dilationWidth = convInfo.dilationWidth;
        const effectiveFilterHeight = convInfo.effectiveFilterHeight;
        const effectiveFilterWidth = convInfo.effectiveFilterWidth;
        const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
        const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
        const avgMultiplier = 1 / (filterHeight * filterWidth);
        this.userCode = `
      const ivec2 pads = ivec2(${padTop}, ${padLeft});
      const float avgMultiplier = float(${avgMultiplier});

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];

        ivec2 dyRCCorner = coords.yz - pads;
        int dyRCorner = dyRCCorner.x;
        int dyCCorner = dyRCCorner.y;

        float dotProd = 0.0;
        for (int wR = 0; wR < ${effectiveFilterHeight};
            wR += ${dilationHeight}) {
          float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

          if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
            continue;
          }
          int idyR = int(dyR);

          for (int wC = 0; wC < ${effectiveFilterWidth};
              wC += ${dilationWidth}) {
            float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

            if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
                fract(dyC) > 0.0) {
              continue;
            }
            int idyC = int(dyC);

            float dyValue = getDy(b, idyR, idyC, d);

            dotProd += dyValue * avgMultiplier;
          }
        }
        setOutput(dotProd);
      }
    `;
    }
}
class AvgPool3DBackpropProgram {
    constructor(convInfo) {
        this.variableNames = ['dy'];
        this.outputShape = convInfo.inShape;
        const filterDepth = convInfo.filterDepth;
        const filterHeight = convInfo.filterHeight;
        const filterWidth = convInfo.filterWidth;
        const strideDepth = convInfo.strideDepth;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const dilationDepth = convInfo.dilationDepth;
        const dilationHeight = convInfo.dilationHeight;
        const dilationWidth = convInfo.dilationWidth;
        const effectiveFilterDepth = convInfo.effectiveFilterDepth;
        const effectiveFilterHeight = convInfo.effectiveFilterHeight;
        const effectiveFilterWidth = convInfo.effectiveFilterWidth;
        const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
        const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
        const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
        const avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
        this.userCode = `
      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
      const float avgMultiplier = float(${avgMultiplier});

      void main() {
        ivec5 coords = getOutputCoords();
        int batch = coords.x;
        int ch = coords.u;

        ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
        int dyDCorner = dyCorner.x;
        int dyRCorner = dyCorner.y;
        int dyCCorner = dyCorner.z;

        float dotProd = 0.0;

        for (int wD = 0; wD < ${effectiveFilterDepth};
            wD += ${dilationDepth}) {
          float dyD = float(dyDCorner + wD) / ${strideDepth}.0;

          if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) {
            continue;
          }
          int idyD = int(dyD);

          for (int wR = 0; wR < ${effectiveFilterHeight};
              wR += ${dilationHeight}) {
            float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

            if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
                fract(dyR) > 0.0) {
              continue;
            }
            int idyR = int(dyR);

            for (int wC = 0; wC < ${effectiveFilterWidth};
                wC += ${dilationWidth}) {
              float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

              if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
                  fract(dyC) > 0.0) {
                continue;
              }
              int idyC = int(dyC);

              float dyValue = getDy(batch, idyD, idyR, idyC, ch);

              dotProd += dyValue * avgMultiplier;
            }
          }
        }
        setOutput(dotProd);
      }
    `;
    }
}

function avgPool3DGrad$1(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input } = inputs;
    const x = input;
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const dilations = [1, 1, 1];
    const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
    const avgPoolBackpropProgram = new AvgPool3DBackpropProgram(convInfo);
    return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype);
}
const avgPool3DGradConfig$2 = {
    kernelName: AvgPool3DGrad,
    backendName: 'webgl',
    kernelFunc: avgPool3DGrad$1
};

function avgPoolGrad$2(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input } = inputs;
    const x = input;
    assertNotComplex$1([dy, input], 'avgPoolGrad');
    const { filterSize, strides, pad } = attrs;
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad);
    const avgPoolBackpropProgram = new AvgPool2DBackpropProgram(convInfo);
    return backend.runWebGLProgram(avgPoolBackpropProgram, [dy], x.dtype);
}
const avgPoolGradConfig$2 = {
    kernelName: AvgPoolGrad,
    backendName: 'webgl',
    kernelFunc: avgPoolGrad$2
};

function batchMatMul$1(args) {
    const { inputs, backend, attrs } = args;
    const { a, b } = inputs;
    const { transposeA, transposeB } = attrs;
    return batchMatMulImpl({ a, b, transposeA, transposeB, backend });
}
const batchMatMulConfig$1 = {
    kernelName: BatchMatMul,
    backendName: 'webgl',
    kernelFunc: batchMatMul$1,
};

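// The batch-norm programs evaluate scale * (x - mean) / sqrt(variance +
// epsilon) + offset per element; offset and scale are optional inputs that
// default to 0 and 1 respectively.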
class BatchNormProgram {
    constructor(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
        this.outputShape = [];
        this.variableNames = ['x', 'mean', 'variance'];
        assertAndGetBroadcastShape(xShape, meanShape);
        assertAndGetBroadcastShape(xShape, varianceShape);
        let offsetSnippet = '0.0';
        if (offsetShape != null) {
            assertAndGetBroadcastShape(xShape, offsetShape);
            this.variableNames.push('offset');
            offsetSnippet = 'getOffsetAtOutCoords()';
        }
        let scaleSnippet = '1.0';
        if (scaleShape != null) {
            assertAndGetBroadcastShape(xShape, scaleShape);
            this.variableNames.push('scale');
            scaleSnippet = 'getScaleAtOutCoords()';
        }
        this.outputShape = xShape;
        this.userCode = `
      void main() {
        float x = getXAtOutCoords();
        float mean = getMeanAtOutCoords();
        float variance = getVarianceAtOutCoords();
        float offset = ${offsetSnippet};
        float scale = ${scaleSnippet};
        float inv = scale * inversesqrt(variance + float(${varianceEpsilon}));
        setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1)));
      }
    `;
    }
}

class BatchNormPackedProgram {
    constructor(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
        this.packedInputs = true;
        this.packedOutput = true;
        this.variableNames = ['x', 'mean', 'variance'];
        assertAndGetBroadcastShape(xShape, meanShape);
        assertAndGetBroadcastShape(xShape, varianceShape);
        let offsetSnippet = 'vec4(0.0)';
        if (offsetShape != null) {
            assertAndGetBroadcastShape(xShape, offsetShape);
            this.variableNames.push('offset');
            offsetSnippet = 'getOffsetAtOutCoords()';
        }
        let scaleSnippet = 'vec4(1.0)';
        if (scaleShape != null) {
            assertAndGetBroadcastShape(xShape, scaleShape);
            this.variableNames.push('scale');
            scaleSnippet = 'getScaleAtOutCoords()';
        }
        this.outputShape = xShape;
        this.userCode = `
      void main() {
        vec4 offset = ${offsetSnippet};
        vec4 scale = ${scaleSnippet};

        vec4 x = getXAtOutCoords();
        vec4 mean = getMeanAtOutCoords();
        vec4 variance = getVarianceAtOutCoords();

        vec4 inv = scale * inversesqrt(variance + vec4(${varianceEpsilon}));

        setOutput((x - mean) * inv + offset);
      }
    `;
    }
}

const batchNorm$1 = ({ inputs, backend, attrs }) => {
    const { x, mean, variance, offset, scale } = inputs;
    assert$1(mean.shape.length === variance.shape.length, () => 'Batch normalization gradient requires mean and variance to have ' +
        'equal ranks.');
    assert$1(offset == null || mean.shape.length === offset.shape.length, () => 'Batch normalization gradient requires mean and offset to have ' +
        'equal ranks.');
    assert$1(scale == null || mean.shape.length === scale.shape.length, () => 'Batch normalization gradient requires mean and scale to have ' +
        'equal ranks.');
    let { varianceEpsilon } = attrs;
    if (varianceEpsilon == null) {
        varianceEpsilon = 0.001;
    }
    const finalInputs = [x, mean, variance];
    let offsetShape = null;
    if (offset != null) {
        offsetShape = offset.shape;
        finalInputs.push(offset);
    }
    let scaleShape = null;
    if (scale != null) {
        scaleShape = scale.shape;
        finalInputs.push(scale);
    }
    const program = env().getBool('WEBGL_PACK_NORMALIZATION') ?
        new BatchNormPackedProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon) :
        new BatchNormProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon);
    const output = backend.runWebGLProgram(program, finalInputs, finalInputs[0].dtype);
    return output;
};
const batchNormConfig$1 = {
    kernelName: FusedBatchNorm,
    backendName: 'webgl',
    kernelFunc: batchNorm$1,
};

class SliceProgram {
    constructor(destSize) {
        this.variableNames = ['source'];
        this.outputShape = destSize;
        this.rank = destSize.length;
        const dtype = getCoordsDataType(this.rank);
        this.customUniforms = [{ name: 'start', arrayIndex: this.rank, type: 'int' }];
        const sourceCoords = getCoords$1(this.rank);
        let body;
        const coordSum = destSize.map((_, i) => {
            return `sourceLoc.${coords[i]} = start[${i}] + coords.${coords[i]};`;
        });
        body = `
        ${dtype} sourceLoc;
        ${dtype} coords = getOutputCoords();
        ${coordSum.join('\n')}
      `;
        this.userCode = `
      void main() {
        ${body}
        setOutput(getSource(${sourceCoords}));
      }
    `;
    }
}
const coords = ['x', 'y', 'z', 'w', 'u', 'v'];
function getCoords$1(rank) {
    if (rank === 1) {
        return 'sourceLoc';
    }
    else if (rank <= 6) {
        return coords.slice(0, rank).map(x => 'sourceLoc.' + x).join(',');
    }
    else {
        throw Error(`Slicing for rank ${rank} is not yet supported`);
    }
}

class SlicePackedProgram {
    constructor(destSize) {
        this.variableNames = ['source'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = destSize;
        this.rank = destSize.length;
        this.customUniforms = [{ name: 'start', arrayIndex: this.rank, type: 'int' }];
        const dtype = getCoordsDataType(this.rank);
        const coords = getChannels('coords', this.rank);
        const sourceLoc = getChannels('sourceLoc', this.rank);
        const innerDims = this.rank === 1 ? 'sourceLoc' : `vec2(${sourceLoc.slice(-2).join()})`;
        const getChannel = `getChannel(getSource(${sourceLoc.join()}), ${innerDims})`;
        const upperRow = `
      result.x = ${getChannel};
      if (++${coords[this.rank - 1]} < ${destSize[this.rank - 1]}) {
        ++${sourceLoc[this.rank - 1]};
        result.y = ${getChannel};
        --${sourceLoc[this.rank - 1]};
      }
    `;
        const lowerRow = this.rank === 1 ? '' : `
      --${coords[this.rank - 1]};
      if (++${coords[this.rank - 2]} < ${destSize[this.rank - 2]}) {
        ++${sourceLoc[this.rank - 2]};
        result.z = ${getChannel};
        if (++${coords[this.rank - 1]} < ${destSize[this.rank - 1]}) {
          ++${sourceLoc[this.rank - 1]};
          result.w = ${getChannel};
        }
      }
    `;
        const sourceLocSetup = this.rank <= 4 ?
            `sourceLoc = coords +
            ${dtype}(${destSize.map((_, i) => `start[${i}]`).join()});` :
            destSize.map((_, i) => `${sourceLoc[i]} = ${coords[i]} + start[${i}];`)
                .join('\n');
        this.userCode = `
      void main() {
        ${dtype} coords = getOutputCoords();
        ${dtype} sourceLoc;
        ${sourceLocSetup}
        vec4 result = vec4(0.);
        ${upperRow}
        ${lowerRow}
        setOutput(result);
      }
    `;
    }
}

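// shallowSlice avoids a copy for contiguous slices: the new tensor shares the
// source texture and only records a flat offset into it, bumping a ref count
// on the original data so it is not disposed while the view is still alive.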
function shallowSlice(x, begin, size, backend) {
    const xTexData = backend.texData.get(x.dataId);
    const t = backend.makeTensorInfo(size, x.dtype);
    const newTexData = backend.texData.get(t.dataId);

    Object.assign(newTexData, xTexData);
    newTexData.refCount = 1;
    newTexData.shape = size;
    newTexData.dtype = x.dtype;
    let flatOffset = computeFlatOffset(begin, computeStrides(x.shape));
    if (xTexData.slice) {
        // We are slicing an already sliced tensor, so the offsets accumulate.
        flatOffset += xTexData.slice.flatOffset;
    }
    newTexData.slice = {
        flatOffset,
        origDataId: xTexData.slice && xTexData.slice.origDataId || x.dataId
    };

    const refCount = backend.dataRefCount.get(newTexData.slice.origDataId) || 1;
    backend.dataRefCount.set(newTexData.slice.origDataId, refCount + 1);
    return t;
}
function slice(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { begin, size } = attrs;
    const [$begin, $size] = parseSliceParams(x, begin, size);
    assertParamsValid(x, $begin, $size);
    if (sizeFromShape($size) === 0) {
        return backend.makeTensorInfo($size, x.dtype, []);
    }

    if (backend.shouldExecuteOnCPU([x]) || x.dtype === 'string') {
        const xTexData = backend.texData.get(x.dataId);
        const outValues = sliceImplCPU(xTexData.values, $begin, $size, x.shape, x.dtype);
        return backend.makeTensorInfo($size, x.dtype, outValues);
    }
    const { isPacked } = backend.texData.get(x.dataId);
    const isContinous = isSliceContinous(x.shape, $begin, $size);
    if (isPacked || !isContinous) {
        const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
            new SlicePackedProgram($size) :
            new SliceProgram($size);
        const customValues = [$begin];
        return backend.runWebGLProgram(program, [x], x.dtype, customValues);
    }
    backend.uploadToGPU(x.dataId);
    return shallowSlice(x, $begin, $size, backend);
}
const sliceConfig = {
    kernelName: Slice,
    backendName: 'webgl',
    kernelFunc: slice
};

const batchToSpaceND$1 = (args) => {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { blockShape, crops } = attrs;
    assert$1(x.shape.length <= 4, () => 'batchToSpaceND for rank > 4 with a WebGL backend not ' +
        'implemented yet');
    const prod = blockShape.reduce((a, b) => a * b);
    const reshaped = getReshaped(x.shape, blockShape, prod);
    const permuted = getPermuted(reshaped.length, blockShape.length);
    const reshapedPermuted = getReshapedPermuted(x.shape, blockShape, prod);
    const sliceBeginCoords = getSliceBeginCoords(crops, blockShape.length);
    const sliceSize = getSliceSize(reshapedPermuted, crops, blockShape.length);
    const toDispose = [];
    const reshapedIntermediate = reshape$1({ inputs: { x }, backend, attrs: { shape: reshaped } });
    const transposedIntermediate = transpose({ inputs: { x: reshapedIntermediate }, backend, attrs: { perm: permuted } });
    const reshapedIntermediate2 = reshape$1({
        inputs: { x: transposedIntermediate },
        backend,
        attrs: { shape: reshapedPermuted }
    });
    const sliced = slice({
        inputs: { x: reshapedIntermediate2 },
        backend,
        attrs: { begin: sliceBeginCoords, size: sliceSize }
    });
    toDispose.push(reshapedIntermediate);
    toDispose.push(transposedIntermediate);
    toDispose.push(reshapedIntermediate2);
    toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return sliced;
};
const batchToSpaceNDConfig$1 = {
    kernelName: BatchToSpaceND,
    backendName: 'webgl',
    kernelFunc: batchToSpaceND$1
};

function bincount$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, weights } = inputs;
    const { size } = attrs;
    const xVals = backend.readSync(x.dataId);
    const weightsVals = backend.readSync(weights.dataId);
    const outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size);
    return backend.makeTensorInfo([size], weights.dtype, outVals);
}
const bincountConfig$1 = {
    kernelName: Bincount,
    backendName: 'webgl',
    kernelFunc: bincount$1
};

const BITWISEAND = `
  int r = int(a.r) & int(b.r);
  int g = int(a.g) & int(b.g);
  int rb = int(a.b) & int(b.b);
  int ra = int(a.a) & int(b.a);
  return vec4(r, g, rb, ra);
`;
const BITWISEAND_UNPACKED = `
  return float(int(a.r) & int(b.r));
`;
function bitwiseAnd(args) {
    const { inputs, backend } = args;
    const { a, b } = inputs;
    const shouldUsePackedProgram = env().getBool('WEBGL_PACK_BINARY_OPERATIONS');
    const versionNumber = env().getNumber('WEBGL_VERSION');

    if ((backend.shouldExecuteOnCPU([a, b])) || versionNumber === 1) {
        const aVals = backend.texData.get(a.dataId).values;
        const bVals = backend.texData.get(b.dataId).values;
        const [outValues, outShape] = bitwiseAndImplCPU(a.shape, b.shape, aVals, bVals, a.dtype);
        const out = backend.makeTensorInfo(outShape, a.dtype);
        const outData = backend.texData.get(out.dataId);
        outData.values = outValues;
        return out;
    }
    let program;
    if (shouldUsePackedProgram) {
        program = new BinaryOpPackedProgram(BITWISEAND, a.shape, b.shape, false);
    }
    else {
        program = new BinaryOpProgram(BITWISEAND_UNPACKED, a.shape, b.shape);
    }
    return backend.runWebGLProgram(program, [a, b], a.dtype);
}
const bitwiseAndConfig = {
    kernelName: BitwiseAnd,
    backendName: 'webgl',
    kernelFunc: bitwiseAnd
};

function broadcastArgs$1(args) {
    const { inputs, backend } = args;
    const { s0, s1 } = inputs;
    const s0Vals = backend.readSync(s0.dataId);
    const s1Vals = backend.readSync(s1.dataId);
    const broadcastShape = assertAndGetBroadcastShape(Array.from(s0Vals), Array.from(s1Vals));
    return backend.makeTensorInfo([broadcastShape.length], 'int32', Int32Array.from(broadcastShape));
}
const broadcastArgsConfig$1 = {
    kernelName: BroadcastArgs,
    backendName: 'webgl',
    kernelFunc: broadcastArgs$1
};

const NOT_EQUAL = `return float(a != b);`;
const notEqual = binaryKernelFunc({ opSnippet: NOT_EQUAL, cpuKernelImpl: notEqualImplCPU, dtype: 'bool' });
const notEqualConfig = {
    kernelName: NotEqual,
    backendName: 'webgl',
    kernelFunc: notEqual,
};

function real(args) {
    const { inputs, backend } = args;
    const { input } = inputs;
    const inputData = backend.texData.get(input.dataId);
    return identity({ inputs: { x: inputData.complexTensorInfos.real }, backend });
}
const realConfig = {
    kernelName: Real,
    backendName: 'webgl',
    kernelFunc: real
};

const TO_INT = `return float(int(x));`;
function int(input, backend) {
    const program = new UnaryOpProgram(input.shape, TO_INT);
    const output = backend.runWebGLProgram(program, [input], 'int32');
    return { dataId: output.dataId, shape: output.shape, dtype: output.dtype };
}

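// cast$1 dispatches on dtype: complex64 is assembled or split via complex()
// and real(), lossless casts just re-label the dataId with the new dtype,
// int32 truncates in a shader, and bool is computed as x != 0.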
function cast$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { dtype } = attrs;

    if (dtype === 'complex64') {
        if (x.dtype === 'complex64') {
            return identity({ inputs: { x }, backend });
        }

        const zerosTensor = zeros$1(x.shape);
        const floatX = cast$1({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
        const result = complex({ inputs: { real: floatX, imag: zerosTensor }, backend });
        zerosTensor.dispose();
        backend.disposeIntermediateTensorInfo(floatX);
        return result;
    }

    if (x.dtype === 'complex64') {
        const realPart = real({ inputs: { input: x }, backend });
        const result = cast$1({ inputs: { x: realPart }, backend, attrs: { dtype } });
        backend.disposeIntermediateTensorInfo(realPart);
        return result;
    }
    if (!hasEncodingLoss(x.dtype, dtype)) {
        // The cast loses no encoding information, so the underlying data can
        // be shared and only the dtype label on the tensor info changes.
        const result = identity({ inputs: { x }, backend });
        return { dataId: result.dataId, shape: result.shape, dtype };
    }
    if (backend.shouldExecuteOnCPU([x])) {
        const values = backend.texData.get(x.dataId).values;
        const [resultShape, resultType, resultData] = castImplCPU(values, x.shape, x.dtype, dtype);
        return backend.makeTensorInfo(resultShape, resultType, resultData);
    }
    if (dtype === 'int32') {
        return int(x, backend);
    }
    if (dtype === 'bool') {
        const zerosTensorInfo = backend.makeTensorInfo([], 'bool', getTypedArrayFromDType('bool', 1));
        const binaryInputs = { a: x, b: zerosTensorInfo };
        const result = notEqual({ inputs: binaryInputs, backend });
        backend.disposeIntermediateTensorInfo(zerosTensorInfo);
        return result;
    }
    throw new Error(`Error in Cast: failed to cast ${x.dtype} to ${dtype}`);
}
const castConfig = {
    kernelName: Cast,
    backendName: 'webgl',
    kernelFunc: cast$1
};

const CEIL = `return ceil(x);`;
const ceil = unaryKernelFunc({ opSnippet: CEIL, packedOpSnippet: CEIL, cpuKernelImpl: ceilImplCPU });
const ceilConfig = {
    kernelName: Ceil,
    backendName: 'webgl',
    kernelFunc: ceil
};

class ClipProgram {
    constructor(aShape) {
        this.variableNames = ['A'];
        this.customUniforms = [
            { name: 'minVal', type: 'float' },
            { name: 'maxVal', type: 'float' }
        ];
        this.outputShape = aShape;
        this.userCode = `
      void main() {
        float value = getAAtOutCoords();
        if (isnan(value)) {
          setOutput(value);
          return;
        }

        setOutput(clamp(value, minVal, maxVal));
      }
    `;
    }
}

class ClipPackedProgram {
    constructor(aShape) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.customUniforms = [
            { name: 'minVal', type: 'float' },
            { name: 'maxVal', type: 'float' }
        ];
        this.outputShape = aShape;
        this.userCode = `
      void main() {
        vec4 value = getAAtOutCoords();

        if (any(isnan(value))) {
          setOutput(value);
          return;
        }

        setOutput(clamp(value, vec4(minVal), vec4(maxVal)));
      }
    `;
    }
}

function clipByValue$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { clipValueMin, clipValueMax } = attrs;
    let program;
    if (env().getBool('WEBGL_PACK_CLIP')) {
        program = new ClipPackedProgram(x.shape);
    }
    else {
        program = new ClipProgram(x.shape);
    }
    const customValues = [[clipValueMin], [clipValueMax]];
    return backend.runWebGLProgram(program, [x], x.dtype, customValues);
}
const clipByValueConfig$1 = {
    kernelName: ClipByValue,
    backendName: 'webgl',
    kernelFunc: clipByValue$1
};

class ComplexAbsProgram {
    constructor(shape) {
        this.variableNames = ['real', 'imag'];
        this.outputShape = shape;
        this.userCode = `
      void main() {
        float re = abs(getRealAtOutCoords());
        float im = abs(getImagAtOutCoords());
        float mx = max(re, im);

        // The length function in GLSL squares its inputs, so factor the
        // larger component out first: |z| = mx * sqrt(1 + (min/mx)^2).
        setOutput(
          mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))
        );
      }
    `;
    }
}

function makeComplexComponentTensorInfo(complexTensor, complexPart) {
    return {
        dataId: complexPart.dataId,
        dtype: complexPart.dtype,
        shape: complexTensor.shape
    };
}
function complexAbs$1(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    const xData = backend.texData.get(x.dataId);
    const program = new ComplexAbsProgram(x.shape);
    const programInputs = [
        makeComplexComponentTensorInfo(x, xData.complexTensorInfos.real),
        makeComplexComponentTensorInfo(x, xData.complexTensorInfos.imag),
    ];
    return backend.runWebGLProgram(program, programInputs, programInputs[0].dtype);
}
const complexAbsConfig$1 = {
    kernelName: ComplexAbs,
    backendName: 'webgl',
    kernelFunc: complexAbs$1
};

class ConcatProgram {
    // Concatenates 2-D tensors along axis 1.
    constructor(shapes) {
        this.outputShape = [];
        this.outputShape = computeOutShape$1(shapes, 1 /* axis */);
        this.variableNames = shapes.map((_, i) => `T${i}`);
        const offsets = new Array(shapes.length - 1);
        offsets[0] = shapes[0][1];
        for (let i = 1; i < offsets.length; i++) {
            offsets[i] = offsets[i - 1] + shapes[i][1];
        }
        const snippets = [`if (yC < ${offsets[0]}) setOutput(getT0(yR, yC));`];
        for (let i = 1; i < offsets.length; i++) {
            const shift = offsets[i - 1];
            snippets.push(`else if (yC < ${offsets[i]}) ` +
                `setOutput(getT${i}(yR, yC-${shift}));`);
        }
        const lastIndex = offsets.length;
        const lastShift = offsets[offsets.length - 1];
        snippets.push(`else setOutput(getT${lastIndex}(yR, yC-${lastShift}));`);
        this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        int yR = coords.x;
        int yC = coords.y;

        ${snippets.join('\n        ')}
      }
    `;
    }
}

class ConcatPackedProgram {
    constructor(shapes, axis) {
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = [];
        this.outputShape = computeOutShape$1(shapes, axis);
        const shape = this.outputShape;
        const rank = shape.length;
        const dtype = getCoordsDataType(rank);
        const coords = getChannels('coords', rank);
        const channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank);
        this.variableNames = shapes.map((_, i) => `T${i}`);
        const offsets = new Array(shapes.length - 1);
        offsets[0] = shapes[0][axis];
        for (let i = 1; i < offsets.length; i++) {
            offsets[i] = offsets[i - 1] + shapes[i][axis];
        }
        const channel = channels[axis];
        const lastChannels = channels.slice(-2);
        const allChannels = channels.join();
        let getValueSnippet = `if (${channel} < ${offsets[0]}) {
        return getChannel(
            getT0(${allChannels}), vec2(${lastChannels.join()}));
        }`;
        for (let i = 1; i < offsets.length; i++) {
            const shift = offsets[i - 1];
            // The coordinate along the concat axis is shifted back into the
            // coordinate space of the i-th input before sampling from it.
            getValueSnippet += `
        if (${channel} < ${offsets[i]} && ${channel} >= ${offsets[i - 1]}) {
          return getChannel(
            getT${i}(${shiftedChannels(channels, channel, shift)}),
            vec2(${shiftedChannels(lastChannels, channel, shift)}));
        }`;
        }
        const lastIndex = offsets.length;
        const shift = offsets[offsets.length - 1];
        getValueSnippet += `
        return getChannel(
          getT${lastIndex}(${shiftedChannels(channels, channel, shift)}),
          vec2(${shiftedChannels(lastChannels, channel, shift)}));`;
        this.userCode = `
      float getValue(${channels.map(x => 'int ' + x)}) {
        ${getValueSnippet}
      }

      void main() {
        ${dtype} coords = getOutputCoords();
        vec4 result = vec4(getValue(${coords}), 0., 0., 0.);

        ${coords[rank - 1]} = ${coords[rank - 1]} + 1;
        if (${coords[rank - 1]} < ${shape[rank - 1]}) {
          result.g = getValue(${coords});
        }

        ${coords[rank - 2]} = ${coords[rank - 2]} + 1;
        if (${coords[rank - 2]} < ${shape[rank - 2]}) {
          result.a = getValue(${coords});
        }

        ${coords[rank - 1]} = ${coords[rank - 1]} - 1;
        if (${coords[rank - 2]} < ${shape[rank - 2]} &&
            ${coords[rank - 1]} < ${shape[rank - 1]}) {
          result.b = getValue(${coords});
        }
        setOutput(result);
      }
    `;
    }
}

function shiftedChannels(channels, channel, shift) {
    const channelIdx = channels.indexOf(channel);
    const res = channels.map((c, idx) => {
        if (idx === channelIdx) {
            return `${c} - ${shift}`;
        }
        else {
            return c;
        }
    });
    return res.join();
}

function imag$1(args) {
    const { inputs, backend } = args;
    const { input } = inputs;
    const inputData = backend.texData.get(input.dataId);
    return identity({ inputs: { x: inputData.complexTensorInfos.imag }, backend });
}
const imagConfig$1 = {
    kernelName: Imag,
    backendName: 'webgl',
    kernelFunc: imag$1
};

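// concatImpl handles the general cases: complex64 concatenates the real and
// imaginary parts separately, string (and small CPU-resident) inputs fall
// back to the CPU implementation on flattened 2-D views, and input lists
// longer than WEBGL_MAX_TEXTURES_IN_SHADER are concatenated in batches
// recursively before a final merge.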
function concatImpl(inputs, axis, backend) {
    const dtype = inputs[0].dtype;
    if (dtype === 'complex64') {
        const reals = inputs.map((t) => real({ inputs: { input: t }, backend }));
        const imags = inputs.map((t) => imag$1({ inputs: { input: t }, backend }));
        const realConcated = concatImpl(reals, axis, backend);
        const imagConcated = concatImpl(imags, axis, backend);
        const result = complex({ inputs: { real: realConcated, imag: imagConcated }, backend });
        reals.forEach(r => backend.disposeIntermediateTensorInfo(r));
        imags.forEach(i => backend.disposeIntermediateTensorInfo(i));
        backend.disposeIntermediateTensorInfo(realConcated);
        backend.disposeIntermediateTensorInfo(imagConcated);
        return result;
    }
    let runOnCpu = backend.shouldExecuteOnCPU(inputs);

    if (dtype === 'string') {
        runOnCpu = true;
    }
    if (runOnCpu) {
        // Flatten each input to 2D so the CPU impl only has to concatenate
        // along axis 1.
        const tensors2D = inputs.map(t => {
            const innerSize = sizeFromShape(t.shape.slice(axis));
            const shape = [-1, innerSize];
            return reshape$1({ inputs: { x: t }, backend, attrs: { shape } });
        });
        const inputsValShapes = tensors2D.map(t => {
            return { vals: backend.readSync(t.dataId), shape: t.shape };
        });

        const outShape = computeOutShape$1(tensors2D.map(t => t.shape), 1 /* axis */);
        const simplyConcat = tensors2D[0].shape[0] === 1;
        const outVals = concatImplCPU(inputsValShapes, outShape, dtype, simplyConcat);
        const finalOutShape = computeOutShape$1(inputs.map(t => t.shape), axis);
        const outInfo = backend.makeTensorInfo(finalOutShape, dtype, outVals);
        tensors2D.forEach(t => backend.disposeIntermediateTensorInfo(t));
        return outInfo;
    }

    const $inputs = inputs.filter(t => sizeFromShape(t.shape) > 0);
    const shouldPack = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') &&
        $inputs[0].shape.length > 1;
    if ($inputs.length === 1) {
        // Only one non-empty input remains: clone it.
        const program = shouldPack ?
            new UnaryOpProgram(inputs[0].shape, CLONE) :
            new UnaryOpPackedProgram(inputs[0].shape, CLONE);
        return backend.runWebGLProgram(program, inputs, dtype);
    }
    const maxTexturesInShader = env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER');
    if ($inputs.length > maxTexturesInShader) {
        const reducedInputs = [];
        for (let i = 0; i < $inputs.length; i += maxTexturesInShader) {
            const subArray = $inputs.slice(i, i + maxTexturesInShader);
            reducedInputs.push(concatImpl(subArray, axis, backend));
        }
        const result = concatImpl(reducedInputs, axis, backend);
        for (const i of reducedInputs) {
            backend.disposeIntermediateTensorInfo(i);
        }
        return result;
    }
    if (shouldPack) {
        const program = new ConcatPackedProgram($inputs.map(t => t.shape), axis);
        return backend.runWebGLProgram(program, $inputs, dtype);
    }
    const { tensors2D, outShape } = computeTensors2D($inputs, axis, backend);
    const program = new ConcatProgram(tensors2D.map(t => t.shape));
    const result = backend.runWebGLProgram(program, tensors2D, dtype);
    tensors2D.forEach(r => backend.disposeIntermediateTensorInfo(r));
    const reshapedResult = reshape$1({ inputs: { x: result }, attrs: { shape: outShape }, backend });
    backend.disposeIntermediateTensorInfo(result);
    return reshapedResult;
}
function computeTensors2D(inputs, axis, backend) {
    const outShape = computeOutShape$1(inputs.map(t => t.shape), axis);
    const tensors2D = inputs.map(x => reshape$1({
        inputs: { x },
        attrs: { shape: [-1, sizeFromShape(x.shape.slice(axis))] },
        backend
    }));
    return { tensors2D, outShape };
}

function concat$1(args) {
    const { inputs, backend, attrs } = args;
    const { axis } = attrs;
    const $axis = parseAxisParam(axis, inputs[0].shape)[0];
    const shapes = inputs.map(t => t.shape);
    assertParamsConsistent(shapes, $axis);
    const outShape = computeOutShape$1(inputs.map(t => t.shape), $axis);
    if (sizeFromShape(outShape) === 0) {
        return backend.makeTensorInfo(outShape, inputs[0].dtype, []);
    }

    const $inputs = inputs.filter(t => sizeFromShape(t.shape) > 0);
    if ($inputs.length === 1) {
        return identity({ inputs: { x: $inputs[0] }, backend });
    }
    return concatImpl($inputs, $axis, backend);
}
const concatConfig$1 = {
    kernelName: Concat,
    backendName: 'webgl',
    kernelFunc: concat$1
};

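// Conv2DProgram is a direct (non-im2col) convolution: for each output element
// the shader walks the filter window and accumulates dot products over the
// input channels four at a time, with scalar/vec2/vec3 tails for the channel
// remainder; bias add and fused activations are spliced in as snippets.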
class Conv2DProgram {
    constructor(convInfo, addBias = false, activation = null, hasPreluActivationWeights = false, hasLeakyreluAlpha = false) {
        this.variableNames = ['x', 'W'];
        this.outputShape = convInfo.outShape;
        const padTop = convInfo.padInfo.top;
        const padLeft = convInfo.padInfo.left;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const dilationHeight = convInfo.dilationHeight;
        const dilationWidth = convInfo.dilationWidth;
        const filterHeight = convInfo.filterHeight;
        const filterWidth = convInfo.filterWidth;
        const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
        const inputDepthVec4Remainder = convInfo.inChannels % 4;
        const isChannelsLast = convInfo.dataFormat === 'channelsLast';
        const rowDim = isChannelsLast ? 1 : 2;
        const colDim = isChannelsLast ? 2 : 3;
        const channelDim = isChannelsLast ? 3 : 1;
        let activationSnippet = '', applyActivationSnippet = '';
        if (activation) {
            if (hasPreluActivationWeights) {
                activationSnippet = `float activation(float a) {
          float b = getPreluActivationWeightsAtOutCoords();
          ${activation}
        }`;
            }
            else if (hasLeakyreluAlpha) {
                activationSnippet = `float activation(float a) {
          float b = getLeakyreluAlphaAtOutCoords();
          ${activation}
        }`;
            }
            else {
                activationSnippet = `
          float activation(float x) {
            ${activation}
          }
        `;
            }
            applyActivationSnippet = `result = activation(result);`;
        }
        const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
        if (addBias) {
            this.variableNames.push('bias');
        }
        if (hasPreluActivationWeights) {
            this.variableNames.push('preluActivationWeights');
        }
        if (hasLeakyreluAlpha) {
            this.variableNames.push('leakyreluAlpha');
        }
        this.userCode = `
      ${activationSnippet}

      const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
      const ivec2 pads = ivec2(${padTop}, ${padLeft});

      void main() {
        ivec4 coords = getOutputCoords();
        int batch = coords[0];
        int d2 = coords[${channelDim}];

        ivec2 xRCCorner =
            ivec2(coords[${rowDim}], coords[${colDim}]) * strides - pads;
        int xRCorner = xRCCorner.x;
        int xCCorner = xRCCorner.y;

        float dotProd = 0.0;
        for (int wR = 0; wR < ${filterHeight}; wR++) {
          int xR = xRCorner + wR * ${dilationHeight};

          if (xR < 0 || xR >= ${convInfo.inHeight}) {
            continue;
          }

          for (int wC = 0; wC < ${filterWidth}; wC++) {
            int xC = xCCorner + wC * ${dilationWidth};

            if (xC < 0 || xC >= ${convInfo.inWidth}) {
              continue;
            }

            for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) {
              vec4 wValues = vec4(
                getW(wR, wC, d1, d2),
                getW(wR, wC, d1 + 1, d2),
                getW(wR, wC, d1 + 2, d2),
                getW(wR, wC, d1 + 3, d2)
              );

              if (${isChannelsLast}) {
                vec4 xValues = vec4(
                  getX(batch, xR, xC, d1),
                  getX(batch, xR, xC, d1 + 1),
                  getX(batch, xR, xC, d1 + 2),
                  getX(batch, xR, xC, d1 + 3)
                );
                dotProd += dot(xValues, wValues);
              } else {
                vec4 xValues = vec4(
                  getX(batch, d1, xR, xC),
                  getX(batch, d1 + 1, xR, xC),
                  getX(batch, d1 + 2, xR, xC),
                  getX(batch, d1 + 3, xR, xC)
                );
                dotProd += dot(xValues, wValues);
              }
            }

            if (${inputDepthVec4Remainder === 1}) {

              if (${isChannelsLast}) {
                dotProd +=
                    getX(batch, xR, xC, ${inputDepthNearestVec4}) *
                    getW(wR, wC, ${inputDepthNearestVec4}, d2);
              } else {
                dotProd +=
                    getX(batch, ${inputDepthNearestVec4}, xR, xC) *
                    getW(wR, wC, ${inputDepthNearestVec4}, d2);
              }

            } else if (${inputDepthVec4Remainder === 2}) {
              vec2 wValues = vec2(
                getW(wR, wC, ${inputDepthNearestVec4}, d2),
                getW(wR, wC, ${inputDepthNearestVec4} + 1, d2)
              );

              if (${isChannelsLast}) {
                vec2 xValues = vec2(
                  getX(batch, xR, xC, ${inputDepthNearestVec4}),
                  getX(batch, xR, xC, ${inputDepthNearestVec4} + 1)
                );
                dotProd += dot(xValues, wValues);
              } else {
                vec2 xValues = vec2(
                  getX(batch, ${inputDepthNearestVec4}, xR, xC),
                  getX(batch, ${inputDepthNearestVec4} + 1, xR, xC)
                );
                dotProd += dot(xValues, wValues);
              }

            } else if (${inputDepthVec4Remainder === 3}) {
              vec3 wValues = vec3(
                getW(wR, wC, ${inputDepthNearestVec4}, d2),
                getW(wR, wC, ${inputDepthNearestVec4} + 1, d2),
                getW(wR, wC, ${inputDepthNearestVec4} + 2, d2)
              );

              if (${isChannelsLast}) {
                vec3 xValues = vec3(
                  getX(batch, xR, xC, ${inputDepthNearestVec4}),
|
|
getX(batch, xR, xC, ${inputDepthNearestVec4} + 1),
|
|
getX(batch, xR, xC, ${inputDepthNearestVec4} + 2)
|
|
);
|
|
dotProd += dot(xValues, wValues);
|
|
} else {
|
|
vec3 xValues = vec3(
|
|
getX(batch, ${inputDepthNearestVec4}, xR, xC),
|
|
getX(batch, ${inputDepthNearestVec4} + 1, xR, xC),
|
|
getX(batch, ${inputDepthNearestVec4} + 2, xR, xC)
|
|
);
|
|
dotProd += dot(xValues, wValues);
|
|
}
|
|
|
|
}
|
|
}
|
|
}
|
|
|
|
float result = dotProd;
|
|
${addBiasSnippet}
|
|
${applyActivationSnippet}
|
|
setOutput(result);
|
|
}
|
|
`;
|
|
}
|
|
}
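// Sketch (hypothetical helper, not in the original source) of the channel
// blocking Conv2DProgram bakes into its shader: the inner d1 loop walks input
// channels four at a time so each step is a single vec4 dot(), while the
// 1/2/3 leftover channels are handled by the scalar/vec2/vec3 tail branches.
function channelBlockingSketch(inChannels) {
    const vec4Part = Math.floor(inChannels / 4) * 4; // channels covered by vec4 dots
    const remainder = inChannels % 4;                // channels left for the tail
    return { vec4Part, remainder }; // channelBlockingSketch(10) -> { vec4Part: 8, remainder: 2 }
}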
|
|
class Conv3DProgram {
|
|
constructor(convInfo) {
|
|
this.variableNames = ['x', 'W'];
|
|
this.outputShape = convInfo.outShape;
|
|
const padFront = convInfo.padInfo.front;
|
|
const padTop = convInfo.padInfo.top;
|
|
const padLeft = convInfo.padInfo.left;
|
|
const strideDepth = convInfo.strideDepth;
|
|
const strideHeight = convInfo.strideHeight;
|
|
const strideWidth = convInfo.strideWidth;
|
|
const dilationDepth = convInfo.dilationDepth;
|
|
const dilationHeight = convInfo.dilationHeight;
|
|
const dilationWidth = convInfo.dilationWidth;
|
|
const filterDepth = convInfo.filterDepth;
|
|
const filterHeight = convInfo.filterHeight;
|
|
const filterWidth = convInfo.filterWidth;
|
|
const inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
|
|
const inputDepthVec4Remainder = convInfo.inChannels % 4;
|
|
this.userCode = `
|
|
const ivec3 strides = ivec3(${strideDepth}, ${strideHeight}, ${strideWidth});
|
|
const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
|
|
|
|
void main() {
|
|
ivec5 coords = getOutputCoords();
|
|
int batch = coords.x;
|
|
int d2 = coords.u;
|
|
|
|
ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;
|
|
int xFCorner = xFRCCorner.x;
|
|
int xRCorner = xFRCCorner.y;
|
|
int xCCorner = xFRCCorner.z;
|
|
|
|
|
|
|
|
|
|
float dotProd = 0.0;
|
|
for (int wF = 0; wF < ${filterDepth}; wF++) {
|
|
int xF = xFCorner + wF * ${dilationDepth};
|
|
|
|
if (xF < 0 || xF >= ${convInfo.inDepth}) {
|
|
continue;
|
|
}
|
|
|
|
for (int wR = 0; wR < ${filterHeight}; wR++) {
|
|
int xR = xRCorner + wR * ${dilationHeight};
|
|
|
|
if (xR < 0 || xR >= ${convInfo.inHeight}) {
|
|
continue;
|
|
}
|
|
|
|
for (int wC = 0; wC < ${filterWidth}; wC++) {
|
|
int xC = xCCorner + wC * ${dilationWidth};
|
|
|
|
if (xC < 0 || xC >= ${convInfo.inWidth}) {
|
|
continue;
|
|
}
|
|
|
|
for (int d1 = 0; d1 < ${inputDepthNearestVec4}; d1 += 4) {
|
|
vec4 xValues = vec4(
|
|
getX(batch, xF, xR, xC, d1),
|
|
getX(batch, xF, xR, xC, d1 + 1),
|
|
getX(batch, xF, xR, xC, d1 + 2),
|
|
getX(batch, xF, xR, xC, d1 + 3)
|
|
);
|
|
vec4 wValues = vec4(
|
|
getW(wF, wR, wC, d1, d2),
|
|
getW(wF, wR, wC, d1 + 1, d2),
|
|
getW(wF, wR, wC, d1 + 2, d2),
|
|
getW(wF, wR, wC, d1 + 3, d2)
|
|
);
|
|
|
|
dotProd += dot(xValues, wValues);
|
|
}
|
|
|
|
if (${inputDepthVec4Remainder === 1}) {
|
|
dotProd +=
|
|
getX(batch, xF, xR, xC, ${inputDepthNearestVec4}) *
|
|
getW(wF, wR, wC, ${inputDepthNearestVec4}, d2);
|
|
} else if (${inputDepthVec4Remainder === 2}) {
|
|
vec2 xValues = vec2(
|
|
getX(batch, xF, xR, xC, ${inputDepthNearestVec4}),
|
|
getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1)
|
|
);
|
|
vec2 wValues = vec2(
|
|
getW(wF, wR, wC, ${inputDepthNearestVec4}, d2),
|
|
getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2)
|
|
);
|
|
dotProd += dot(xValues, wValues);
|
|
} else if (${inputDepthVec4Remainder === 3}) {
|
|
vec3 xValues = vec3(
|
|
getX(batch, xF, xR, xC, ${inputDepthNearestVec4}),
|
|
getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 1),
|
|
getX(batch, xF, xR, xC, ${inputDepthNearestVec4} + 2)
|
|
);
|
|
vec3 wValues = vec3(
|
|
getW(wF, wR, wC, ${inputDepthNearestVec4}, d2),
|
|
getW(wF, wR, wC, ${inputDepthNearestVec4} + 1, d2),
|
|
getW(wF, wR, wC, ${inputDepthNearestVec4} + 2, d2)
|
|
);
|
|
dotProd += dot(xValues, wValues);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
setOutput(dotProd);
|
|
}
|
|
`;
|
|
}
|
|
}
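// The conv programs above rely on the standard convolution geometry for each
// spatial dimension (explicit padding assumed); a CPU mirror of that
// arithmetic, with a helper name of our own:
function convOutputSizeSketch(inSize, filterSize, stride, dilation, pad) {
    const effectiveFilterSize = (filterSize - 1) * dilation + 1;
    return Math.floor((inSize - effectiveFilterSize + 2 * pad) / stride) + 1;
}
// convOutputSizeSketch(5, 3, 1, 1, 0) === 3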
|
|
|
|
|
|
class Conv2DPackedProgram {
|
|
constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) {
|
|
this.variableNames = ['x', 'W'];
|
|
this.packedInputs = true;
|
|
this.packedOutput = true;
|
|
this.customUniforms = [
|
|
{ name: 'pads', type: 'ivec2' },
|
|
{ name: 'strides', type: 'ivec2' },
|
|
{ name: 'dilations', type: 'ivec2' },
|
|
{ name: 'inDims', type: 'ivec2' },
|
|
];
|
|
this.outputShape = convInfo.outShape;
|
|
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
|
|
const padLeft = convInfo.padInfo.left;
|
|
const strideWidth = convInfo.strideWidth;
|
|
const dilationWidth = convInfo.dilationWidth;
|
|
const filterHeight = convInfo.filterHeight;
|
|
const filterWidth = convInfo.filterWidth;
|
|
const texelsAcross = filterWidth;
|
|
let mainLoop = `
|
|
int xR; int xC; int xCOffset;
|
|
vec4 wTexel; vec4 previous; vec4 final;`;
|
|
for (let c = 0; c < filterWidth; c++) {
|
|
mainLoop += `
|
|
vec4 xTexelC${c * 2};
|
|
int xTexelC${c * 2}Ready;
|
|
vec4 xTexelC${c * 2 + 1};
|
|
int xTexelC${c * 2 + 1}Ready;
|
|
vec4 xC${c};`;
|
|
}
|
|
|
|
mainLoop += `
|
|
for (int r = 0; r < ${filterHeight}; r++) {
|
|
for (int d1 = 0; d1 < ${convInfo.inChannels}; d1 += 2) {
|
|
`;
|
|
for (let c = 0; c < filterWidth; c++) {
|
|
mainLoop += `
|
|
xTexelC${c * 2} = vec4(0.0);
|
|
xTexelC${c * 2}Ready = 0;
|
|
xTexelC${c * 2 + 1} = vec4(0.0);
|
|
xTexelC${c * 2 + 1}Ready = 0;
|
|
xC${c} = vec4(0.0);`;
|
|
}
|
|
mainLoop += `
|
|
xR = xRCorner + r * dilations[0];
|
|
if (xR >=0 && xR < inDims[0]) {
|
|
`;
|
|
for (let texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) {
|
|
const colIndex = texelC * 2;
|
|
mainLoop += `
|
|
xC = xCCorner + ${colIndex * dilationWidth};
|
|
`;
|
|
if (strideWidth === 1) {
|
|
if (colIndex < filterWidth) {
|
|
|
|
if (padLeft % 2 === 1) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mainLoop += `
|
|
xCOffset = xC + 1;
|
|
if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
|
|
xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);
|
|
|
|
|
|
|
|
if (xCOffset + 1 >= inDims[1]) {
|
|
xTexelC${colIndex}.zw = vec2(0.0);
|
|
}
|
|
xTexelC${colIndex}Ready = 1;
|
|
}
|
|
`;
|
|
|
|
|
|
if (dilationWidth === 1 && colIndex > 0) {
|
|
mainLoop += `
|
|
xC${colIndex} = vec4(xTexelC${colIndex - 2}.zw, xTexelC${colIndex}.xy);
|
|
`;
|
|
}
|
|
else {
|
|
mainLoop += `
|
|
xCOffset = xC + 1 - 2;
|
|
|
|
if (xCOffset >= 0 && xCOffset < inDims[1]) {
|
|
previous = getX(batch, xR, xCOffset, d1);
|
|
|
|
|
|
|
|
if (xCOffset + 1 >= inDims[1]) {
|
|
previous.zw = vec2(0.0);
|
|
}
|
|
|
|
xC${colIndex} = vec4(previous.zw, xTexelC${colIndex}.xy);
|
|
} else {
|
|
xC${colIndex} = vec4(0.0, 0.0, xTexelC${colIndex}.xy);
|
|
}
|
|
`;
|
|
}
|
|
}
|
|
else {
|
|
|
|
mainLoop += `
|
|
if (xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
|
|
xTexelC${colIndex} = getX(batch, xR, xC, d1);
|
|
if (xC + 1 >= inDims[1]) {
|
|
xTexelC${colIndex}.zw = vec2(0.0);
|
|
}
|
|
xTexelC${colIndex}Ready = 1;
|
|
}
|
|
|
|
xC${colIndex} = xTexelC${colIndex};
|
|
`;
|
|
}
|
|
if (colIndex + 1 < filterWidth) {
|
|
|
|
|
|
|
|
|
|
|
|
const nextTexelOffset = padLeft % 2 === 0 ?
|
|
nearestLargerEven(dilationWidth) :
|
|
dilationWidth;
|
|
if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) ||
|
|
(dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) {
|
|
mainLoop += `
|
|
xCOffset = xC + imod(pads[1], 2) + ${nextTexelOffset};
|
|
|
|
if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
|
|
xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
|
|
|
|
|
|
|
|
if (xCOffset + 1 >= inDims[1]) {
|
|
xTexelC${colIndex + 1}.zw = vec2(0.0);
|
|
}
|
|
xTexelC${colIndex + 1}Ready = 1;
|
|
}
|
|
`;
|
|
|
|
|
|
if (dilationWidth > 1) {
|
|
mainLoop += `
|
|
xCOffset -= 2;
|
|
if (xCOffset >= 0 && xCOffset < inDims[1]) {
|
|
previous = getX(batch, xR, xCOffset, d1);
|
|
xC${colIndex + 1} = vec4(previous.zw, xTexelC${colIndex + 1}.xy);
|
|
} else {
|
|
xC${colIndex + 1} = vec4(0.0, 0.0, xTexelC${colIndex + 1}.xy);
|
|
}
|
|
`;
|
|
}
|
|
else {
|
|
mainLoop += `
|
|
xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.xy);
|
|
`;
|
|
}
|
|
}
|
|
else {
|
|
|
|
|
|
|
|
if (nextTexelOffset === 1) {
|
|
mainLoop += `
|
|
xC${colIndex + 1} = xTexelC${colIndex};
|
|
`;
|
|
}
|
|
else {
|
|
mainLoop += `
|
|
xCOffset = xC + ${nextTexelOffset};
|
|
|
|
if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
|
|
xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
|
|
if (xCOffset + 1 >= inDims[1]) {
|
|
xTexelC${colIndex + 1}.zw = vec2(0.0);
|
|
}
|
|
xTexelC${colIndex + 1}Ready = 1;
|
|
}
|
|
|
|
xC${colIndex + 1} = xTexelC${colIndex + 1};
|
|
`;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
else {
|
|
if (colIndex < filterWidth) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (padLeft % 2 === 1) {
|
|
mainLoop += `
|
|
xCOffset = xC + 1 - strides[1];
|
|
if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
|
|
xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);
|
|
|
|
|
|
if (xCOffset + 1 >= inDims[1]) {
|
|
xTexelC${colIndex}.zw = vec2(0.0);
|
|
}
|
|
xTexelC${colIndex}Ready = 1;
|
|
}
|
|
|
|
if(xC + 1 >= 0 && xC + 1 < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
|
|
xTexelC${colIndex + 1} = getX(batch, xR, xC + 1, d1);
|
|
|
|
|
|
if (xC + 2 >= inDims[1]) {
|
|
xTexelC${colIndex + 1}.zw = vec2(0.0);
|
|
}
|
|
xTexelC${colIndex + 1}Ready = 1;
|
|
}
|
|
|
|
xC${colIndex} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
|
|
`;
|
|
if (colIndex + 1 < filterWidth) {
|
|
mainLoop += `
|
|
final = vec4(0.0);
|
|
xCOffset = xC + 1 + strides[1];
|
|
if(xCOffset >= 0 && xCOffset < inDims[1]) {
|
|
final = getX(batch, xR, xCOffset, d1);
|
|
}
|
|
xC${colIndex + 1} = vec4(xTexelC${colIndex + 1}.xy, final.xy);
|
|
`;
|
|
}
|
|
}
|
|
else {
|
|
mainLoop += `
|
|
if(xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
|
|
xTexelC${colIndex} = getX(batch, xR, xC, d1);
|
|
if (xC + 1 >= inDims[1]) {
|
|
xTexelC${colIndex}.zw = vec2(0.0);
|
|
}
|
|
xTexelC${colIndex}Ready = 1;
|
|
}
|
|
|
|
xCOffset = xC + strides[1];
|
|
if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
|
|
xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
|
|
if (xCOffset + 1 >= inDims[1]) {
|
|
xTexelC${colIndex + 1}.zw = vec2(0.);
|
|
}
|
|
xTexelC${colIndex + 1}Ready = 1;
|
|
}
|
|
|
|
xC${colIndex} = vec4(
|
|
xTexelC${colIndex}.xy, xTexelC${colIndex + 1}.xy);
|
|
`;
|
|
if (colIndex + 1 < filterWidth) {
|
|
mainLoop += `
|
|
xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
|
|
`;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
if (colIndex < filterWidth) {
|
|
mainLoop += `
|
|
wTexel = getW(r, ${colIndex}, d1, d2);
|
|
dotProd += xC${colIndex}.xxzz * vec4(wTexel.xy, wTexel.xy);
|
|
if(d1 + 1 < ${convInfo.inChannels}) {
|
|
dotProd += xC${colIndex}.yyww * vec4(wTexel.zw, wTexel.zw);
|
|
}
|
|
`;
|
|
if (colIndex + 1 < filterWidth) {
|
|
mainLoop += `
|
|
wTexel = getW(r, ${colIndex + 1}, d1, d2);
|
|
dotProd += xC${colIndex + 1}.xxzz * vec4(wTexel.xy, wTexel.xy);
|
|
if(d1 + 1 < ${convInfo.inChannels}) {
|
|
dotProd += xC${colIndex + 1}.yyww * vec4(wTexel.zw, wTexel.zw);
|
|
}
|
|
`;
|
|
}
|
|
}
|
|
}
|
|
mainLoop += `
|
|
}
|
|
`;
|
|
mainLoop += `
|
|
}
|
|
`;
|
|
mainLoop += `
|
|
}
|
|
`;
|
|
let activationSnippet = '', applyActivationSnippet = '';
|
|
if (activation) {
|
|
if (hasPreluActivation) {
|
|
activationSnippet = `vec4 activation(vec4 a) {
|
|
vec4 b = getPreluActivationWeightsAtOutCoords();
|
|
${activation}
|
|
}`;
|
|
}
|
|
else if (hasLeakyReluAlpha) {
|
|
activationSnippet = `vec4 activation(vec4 a) {
|
|
vec4 b = getLeakyreluAlphaAtOutCoords();
|
|
${activation}
|
|
}`;
|
|
}
|
|
else {
|
|
activationSnippet = `vec4 activation(vec4 x) {
|
|
${activation}
|
|
}`;
|
|
}
|
|
applyActivationSnippet = `result = activation(result);`;
|
|
}
|
|
const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
|
|
if (addBias) {
|
|
this.variableNames.push('bias');
|
|
}
|
|
if (hasPreluActivation) {
|
|
this.variableNames.push('preluActivationWeights');
|
|
}
|
|
if (hasLeakyReluAlpha) {
|
|
this.variableNames.push('leakyreluAlpha');
|
|
}
|
|
this.userCode = `
|
|
${activationSnippet}
|
|
|
|
void main() {
|
|
ivec4 coords = getOutputCoords();
|
|
int batch = coords.x;
|
|
ivec2 xRCCorner = coords.yz * strides - pads;
|
|
int d2 = coords.w;
|
|
int xRCorner = xRCCorner.x;
|
|
int xCCorner = xRCCorner.y;
|
|
|
|
|
|
vec4 dotProd = vec4(0.000000000000001);
|
|
|
|
${mainLoop}
|
|
|
|
vec4 result = dotProd - vec4(0.000000000000001);
|
|
${addBiasSnippet}
|
|
${applyActivationSnippet}
|
|
setOutput(result);
|
|
}
|
|
`;
|
|
}
|
|
}
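// Reading aid for Conv2DPackedProgram (our interpretation, not upstream
// documentation): each packed texel is a vec4 covering a 2x2 block of the
// logical tensor, i.e. two adjacent pixels by two adjacent channels,
//   texel = [p0c0, p0c1, p1c0, p1c1]
// The xTexelC*/xC* shuffling in the generated loop stitches neighboring
// texels into correctly aligned pixel pairs before each dot with the filter
// texel, which is why odd padding and strides need the offset branches above.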
|
|
|
|
|
|
class Im2ColPackedProgram {
|
|
constructor(outputShape, convInfo) {
|
|
this.variableNames = ['A'];
|
|
this.packedInputs = true;
|
|
this.packedOutput = true;
|
|
this.customUniforms = [
|
|
{ name: 'inputShape', type: 'ivec4' },
|
|
{ name: 'pad', type: 'ivec2' },
|
|
{ name: 'stride', type: 'ivec2' },
|
|
{ name: 'dilation', type: 'ivec2' },
|
|
{ name: 'inChannels', type: 'int' },
|
|
{ name: 'itemsPerBlockRow', type: 'int' },
|
|
{ name: 'outWidth', type: 'int' },
|
|
];
|
|
this.outputShape = outputShape;
|
|
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
|
|
const { dataFormat } = convInfo;
|
|
const glsl = getGlslDifferences();
|
|
const isChannelsLast = dataFormat === 'channelsLast';
|
|
const rowDim = isChannelsLast ? 1 : 2;
|
|
const colDim = isChannelsLast ? 2 : 3;
|
|
const boundsCheckingSnippet = this.enableShapeUniforms ?
|
|
'if(blockIndex < outShape[2] && pos < outShape[1]) {' :
|
|
`if(blockIndex < ${outputShape[2]} && pos < ${outputShape[1]}) {`;
|
|
let unrolled = ``;
|
|
for (let row = 0; row <= 1; row++) {
|
|
for (let col = 0; col <= 1; col++) {
|
|
unrolled += `
|
|
blockIndex = rc.z + ${col};
|
|
pos = rc.y + ${row};
|
|
|
|
${boundsCheckingSnippet}
|
|
offsetY = int(blockIndex / outWidth) * stride[0] - pad[0];
|
|
d0 = offsetY + dilation[0] * (pos / itemsPerBlockRow);
|
|
|
|
if(d0 < inputShape[${rowDim}] && d0 >= 0) {
|
|
|
|
|
|
|
|
offsetX = imod(blockIndex, outWidth) * stride[1] - pad[1];
|
|
d1 = offsetX + dilation[1] * (imod(pos, itemsPerBlockRow) /
|
|
inChannels);
|
|
|
|
if(d1 < inputShape[${colDim}] && d1 >= 0) {
|
|
|
|
ch = imod(pos, inChannels);
|
|
|
|
if (${isChannelsLast}) {
|
|
innerDims = vec2(d1, ch);
|
|
result[${row * 2 + col}] = getChannel(
|
|
getA(rc.x, d0, int(innerDims.x),
|
|
int(innerDims.y)), innerDims);
|
|
} else {
|
|
innerDims = vec2(d0, d1);
|
|
result[${row * 2 + col}] = getChannel(
|
|
getA(rc.x, ch, int(innerDims.x),
|
|
int(innerDims.y)), innerDims);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
`;
|
|
}
|
|
}
|
|
this.userCode = `
|
|
void main() {
|
|
ivec3 rc = getOutputCoords();
|
|
|
|
vec4 result = vec4(0);
|
|
|
|
int blockIndex, pos, offsetY, d0, offsetX, d1, ch;
|
|
vec2 innerDims;
|
|
|
|
${unrolled}
|
|
|
|
${glsl.output} = result;
|
|
}
|
|
`;
|
|
}
|
|
}
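// A CPU sketch of the im2col idea the packed program implements on textures:
// every output pixel's receptive field is gathered into one contiguous patch,
// turning the convolution into a single matrix multiply against the flattened
// filter. Single channel, stride 1, no padding; the GPU version lays patches
// out transposed, this sketch uses rows for readability. Illustrative only.
function im2colSketch(image, height, width, filterH, filterW) {
    const patches = [];
    for (let r = 0; r + filterH <= height; r++) {
        for (let c = 0; c + filterW <= width; c++) {
            const patch = [];
            for (let fr = 0; fr < filterH; fr++) {
                for (let fc = 0; fc < filterW; fc++) {
                    patch.push(image[(r + fr) * width + (c + fc)]);
                }
            }
            patches.push(patch);
        }
    }
    return patches; // [outH * outW, filterH * filterW]
}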
function getShapeForBatchMatMul(shape, isChannelsLast) {
    const length = shape.length;
    if (length >= 3) {
        return isChannelsLast ?
            [
                ...shape.slice(0, -3),
                shape[length - 3] * shape[length - 2],
                shape[length - 1]
            ] :
            [
                ...shape.slice(0, -3), shape[length - 3],
                shape[length - 2] * shape[length - 1]
            ];
    }
    else if (!isChannelsLast && length === 1 && shape[0] > 1) {
        return [shape[0], 1];
    }
    else {
        return null;
    }
}
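// Examples of the collapse above (channels-last fuses the two spatial
// dimensions; channels-first fuses the two trailing dimensions instead):
//   getShapeForBatchMatMul([1, 4, 4, 8], true)  -> [1, 16, 8]
//   getShapeForBatchMatMul([1, 8, 4, 4], false) -> [1, 8, 16]
//   getShapeForBatchMatMul([8], false)          -> [8, 1]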
function conv2dByMatMul({ x, filter, convInfo, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
|
|
|
|
|
|
const xShape = x.shape;
|
|
const xTexData = backend.texData.get(x.dataId);
|
|
const sharedMatMulDim = convInfo.inChannels;
|
|
const outerShapeX = xShape[0] * xShape[1] * xShape[2];
|
|
const outerShapeFilter = convInfo.outChannels;
|
|
const isChannelsLast = convInfo.dataFormat === 'channelsLast';
|
|
const transposeA = false;
|
|
const transposeB = false;
|
|
let out;
|
|
const intermediates = [];
|
|
if (preluActivationWeights != null) {
|
|
const targetShape = getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast);
|
|
if (targetShape != null) {
|
|
preluActivationWeights = reshape$1({
|
|
inputs: { x: preluActivationWeights },
|
|
backend,
|
|
attrs: { shape: targetShape }
|
|
});
|
|
intermediates.push(preluActivationWeights);
|
|
}
|
|
}
|
|
if (bias != null) {
|
|
const targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast);
|
|
if (targetShape != null) {
|
|
bias = reshape$1({ inputs: { x: bias }, backend, attrs: { shape: targetShape } });
|
|
intermediates.push(bias);
|
|
}
|
|
}
|
|
|
|
|
|
const batchMatMulWillBeUnpacked = (outerShapeX === 1 || outerShapeFilter === 1) &&
|
|
sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD;
|
|
|
|
|
|
|
|
|
|
const canOptimize = !batchMatMulWillBeUnpacked && xTexData.isPacked &&
|
|
isChannelsLast && xTexData.texture != null && xShape[2] % 2 !== 0 &&
|
|
arraysEqual(xTexData.shape.slice(-3), xShape.slice(-3));
|
|
if (canOptimize) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const targetShape = xShape[0] * xShape[1] * (xShape[2] + 1);
|
|
const xReshaped = {
|
|
dataId: x.dataId,
|
|
shape: [1, targetShape, convInfo.inChannels],
|
|
dtype: x.dtype
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const originalXTexDataShape = xTexData.shape;
|
|
xTexData.shape = xTexData.shape.slice();
|
|
xTexData.shape[xTexData.shape.length - 2]++;
|
|
assert$1(isReshapeFree(xTexData.shape, xReshaped.shape), () => `packed reshape ${xTexData.shape} to ${xReshaped.shape} isn't free`);
|
|
const filterReshaped = reshape$1({
|
|
inputs: { x: filter },
|
|
backend,
|
|
attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] }
|
|
});
|
|
intermediates.push(filterReshaped);
|
|
const pointwiseConv = batchMatMulImpl({
|
|
a: xReshaped,
|
|
b: filterReshaped,
|
|
backend,
|
|
transposeA,
|
|
transposeB,
|
|
bias,
|
|
activation,
|
|
preluActivationWeights,
|
|
leakyreluAlpha
|
|
});
|
|
const pointwiseConvTexData = backend.texData.get(pointwiseConv.dataId);
|
|
assert$1(pointwiseConvTexData.isPacked, () => 'batchMatMul result is expected to be packed');
|
|
|
|
xTexData.shape = originalXTexDataShape;
|
|
|
|
|
|
pointwiseConvTexData.shape = convInfo.outShape;
|
|
out = identity({ inputs: { x: pointwiseConv }, backend });
|
|
out.shape = convInfo.outShape;
|
|
intermediates.push(pointwiseConv);
|
|
}
|
|
else {
|
|
const numCols = convInfo.outHeight * convInfo.outWidth;
|
|
const xReshaped = reshape$1({
|
|
inputs: { x },
|
|
backend,
|
|
attrs: {
|
|
shape: isChannelsLast ?
|
|
[convInfo.batchSize, numCols, convInfo.inChannels] :
|
|
[convInfo.batchSize, convInfo.inChannels, numCols]
|
|
}
|
|
});
|
|
const filterReshaped = reshape$1({
|
|
inputs: { x: filter },
|
|
backend,
|
|
attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] }
|
|
});
|
|
const result = batchMatMulImpl({
|
|
a: isChannelsLast ? xReshaped : filterReshaped,
|
|
b: isChannelsLast ? filterReshaped : xReshaped,
|
|
transposeA: !isChannelsLast,
|
|
transposeB,
|
|
backend,
|
|
bias,
|
|
activation,
|
|
preluActivationWeights,
|
|
leakyreluAlpha
|
|
});
|
|
out = reshape$1({ inputs: { x: result }, backend, attrs: { shape: convInfo.outShape } });
|
|
intermediates.push(xReshaped);
|
|
intermediates.push(filterReshaped);
|
|
intermediates.push(result);
|
|
}
|
|
for (const i of intermediates) {
|
|
backend.disposeIntermediateTensorInfo(i);
|
|
}
|
|
return out;
|
|
}
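// Shape sketch (hypothetical helper) of why a 1x1, stride-1 convolution
// reduces to the matrix multiply performed by conv2dByMatMul: each output
// pixel is a plain inChannels -> outChannels projection, so a flattened-pixel
// view of x against the reshaped filter covers the whole output.
function pointwiseConvAsMatMulShapesSketch(convInfo) {
    const numPixels = convInfo.batchSize * convInfo.outHeight * convInfo.outWidth;
    return {
        lhs: [numPixels, convInfo.inChannels],            // flattened input pixels
        rhs: [convInfo.inChannels, convInfo.outChannels], // reshaped 1x1 filter
        out: [numPixels, convInfo.outChannels]            // reshaped back to outShape
    };
}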
|
|
|
|
|
|
function conv2dWithIm2Row({ x, filter, convInfo, backend, bias = null, preluActivationWeights = null, leakyreluAlpha = 0, activation = null }) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const { filterWidth, filterHeight, inChannels, outWidth, outHeight, dataFormat } = convInfo;
|
|
const isChannelsLast = dataFormat === 'channelsLast';
|
|
const sharedDim = filterWidth * filterHeight * inChannels;
|
|
const numCols = outHeight * outWidth;
|
|
const x2ColShape = [convInfo.batchSize, sharedDim, numCols];
|
|
const transposeA = true;
|
|
const transposeB = false;
|
|
const intermediates = [];
|
|
if (preluActivationWeights != null) {
|
|
const targetShape = getShapeForBatchMatMul(preluActivationWeights.shape, isChannelsLast);
|
|
if (targetShape != null) {
|
|
preluActivationWeights = reshape$1({
|
|
inputs: { x: preluActivationWeights },
|
|
backend,
|
|
attrs: { shape: targetShape }
|
|
});
|
|
intermediates.push(preluActivationWeights);
|
|
}
|
|
}
|
|
if (bias != null) {
|
|
const targetShape = getShapeForBatchMatMul(bias.shape, isChannelsLast);
|
|
if (targetShape != null) {
|
|
bias = reshape$1({ inputs: { x: bias }, backend, attrs: { shape: targetShape } });
|
|
intermediates.push(bias);
|
|
}
|
|
}
|
|
const w2Row = reshape$1({
|
|
inputs: { x: filter },
|
|
backend,
|
|
attrs: { shape: [1, sharedDim, sizeFromShape(filter.shape) / sharedDim] }
|
|
});
|
|
intermediates.push(w2Row);
|
|
const im2ColProgram = new Im2ColPackedProgram(x2ColShape, convInfo);
|
|
const customValues = [
|
|
x.shape, [convInfo.padInfo.top, convInfo.padInfo.left],
|
|
[convInfo.strideHeight, convInfo.strideWidth],
|
|
[convInfo.dilationHeight, convInfo.dilationWidth], [convInfo.inChannels],
|
|
[convInfo.filterWidth * convInfo.inChannels], [convInfo.outWidth]
|
|
];
|
|
const im2Col = backend.runWebGLProgram(im2ColProgram, [x], 'float32', customValues);
|
|
const im2ColReshaped = reshape$1({ inputs: { x: im2Col }, backend, attrs: { shape: x2ColShape } });
|
|
intermediates.push(im2Col);
|
|
intermediates.push(im2ColReshaped);
|
|
const hasBias = bias != null;
|
|
const hasPreluActivationWeights = preluActivationWeights != null;
|
|
const hasLeakyreluAlpha = activation === 'leakyrelu';
|
|
const fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
|
|
const matmulProgram = new MatMulPackedProgram(isChannelsLast ? im2ColReshaped.shape :
|
|
w2Row.shape, isChannelsLast ? w2Row.shape :
|
|
im2ColReshaped.shape, isChannelsLast ? [convInfo.batchSize, numCols, convInfo.outChannels] :
|
|
[convInfo.batchSize, convInfo.outChannels, numCols], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
|
|
const inputs = isChannelsLast ? [im2ColReshaped, w2Row] : [w2Row, im2ColReshaped];
|
|
if (bias) {
|
|
inputs.push(bias);
|
|
}
|
|
if (hasPreluActivationWeights) {
|
|
inputs.push(preluActivationWeights);
|
|
}
|
|
if (hasLeakyreluAlpha) {
|
|
const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
|
|
inputs.push($leakyreluAlpha);
|
|
intermediates.push($leakyreluAlpha);
|
|
}
|
|
const product = backend.runWebGLProgram(matmulProgram, inputs, 'float32');
|
|
const out = reshape$1({ inputs: { x: product }, backend, attrs: { shape: convInfo.outShape } });
|
|
intermediates.push(product);
|
|
for (const i of intermediates) {
|
|
backend.disposeIntermediateTensorInfo(i);
|
|
}
|
|
return out;
|
|
}
function conv2d(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode } = attrs;
    const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
    let out;
    if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
        convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
        convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
        (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) {
        out = conv2dByMatMul({ x, filter, convInfo, backend });
    }
    else if (convInfo.strideWidth <= 2 && $dataFormat === 'channelsLast'
        && env().getBool('WEBGL_EXP_CONV')) {
        const program = new Conv2DPackedProgram(convInfo);
        const customValues = [
            [convInfo.padInfo.top, convInfo.padInfo.left],
            [convInfo.strideHeight, convInfo.strideWidth],
            [convInfo.dilationHeight, convInfo.dilationWidth],
            [convInfo.inHeight, convInfo.inWidth]
        ];
        out =
            backend.runWebGLProgram(program, [x, filter], 'float32', customValues);
    }
    else if (env().getBool('WEBGL_CONV_IM2COL')) {
        out = conv2dWithIm2Row({ x, filter, convInfo, backend });
    }
    else {
        const program = new Conv2DProgram(convInfo);
        out = backend.runWebGLProgram(program, [x, filter], 'float32');
    }
    const outReshaped = reshape$1({ inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } });
    backend.disposeIntermediateTensorInfo(out);
    return outReshaped;
}
const conv2DConfig$1 = {
    kernelName: Conv2D,
    backendName: 'webgl',
    kernelFunc: conv2d,
};
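// Dispatch summary for the kernel above, as implemented: 1x1 stride-1
// convolutions with SAME/VALID padding go through conv2dByMatMul; with the
// WEBGL_EXP_CONV flag set, channels-last convolutions with strideWidth <= 2
// use the packed shader; the WEBGL_CONV_IM2COL flag selects the im2row +
// matmul path; everything else falls back to the direct Conv2DProgram loop.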
class Conv2DDerFilterProgram {
|
|
constructor(convInfo) {
|
|
this.variableNames = ['x', 'dy'];
|
|
this.outputShape = convInfo.filterShape;
|
|
const strideHeight = convInfo.strideHeight;
|
|
const strideWidth = convInfo.strideWidth;
|
|
const padTop = convInfo.padInfo.top;
|
|
const padLeft = convInfo.padInfo.left;
|
|
const isChannelsLast = convInfo.dataFormat === 'channelsLast';
|
|
this.userCode = `
|
|
void main() {
|
|
ivec4 coords = getOutputCoords();
|
|
int wR = coords.x;
|
|
int wC = coords.y;
|
|
int d1 = coords.z;
|
|
int d2 = coords.w;
|
|
|
|
|
|
|
|
float dotProd = 0.0;
|
|
|
|
for (int b = 0; b < ${convInfo.batchSize}; b++) {
|
|
for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
|
|
int xR = wR + yR * ${strideHeight} - ${padTop};
|
|
|
|
if (xR < 0 || xR >= ${convInfo.inHeight}) {
|
|
continue;
|
|
}
|
|
|
|
for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
|
|
int xC = wC + yC * ${strideWidth} - ${padLeft};
|
|
|
|
if (xC < 0 || xC >= ${convInfo.inWidth}) {
|
|
continue;
|
|
}
|
|
|
|
${isChannelsLast ?
|
|
`float dyValue = getDy(b, yR, yC, d2);
|
|
float xValue = getX(b, xR, xC, d1);
|
|
dotProd += (xValue * dyValue);` :
|
|
`float dyValue = getDy(b, d2, yR, yC);
|
|
float xValue = getX(b, d1, xR, xC);
|
|
dotProd += (xValue * dyValue);`}
|
|
}
|
|
}
|
|
}
|
|
setOutput(dotProd);
|
|
}
|
|
`;
|
|
}
|
|
}
|
|
class Conv2DDerInputProgram {
|
|
constructor(convInfo) {
|
|
this.variableNames = ['dy', 'W'];
|
|
this.outputShape = convInfo.inShape;
|
|
const filterHeight = convInfo.filterHeight;
|
|
const filterWidth = convInfo.filterWidth;
|
|
const strideHeight = convInfo.strideHeight;
|
|
const strideWidth = convInfo.strideWidth;
|
|
const isChannelsLast = convInfo.dataFormat === 'channelsLast';
|
|
const padTop = filterHeight - 1 - convInfo.padInfo.top;
|
|
const padLeft = filterWidth - 1 - convInfo.padInfo.left;
|
|
const rowDim = isChannelsLast ? 1 : 2;
|
|
const colDim = isChannelsLast ? 2 : 3;
|
|
const channelDim = isChannelsLast ? 3 : 1;
|
|
this.userCode = `
|
|
const ivec2 pads = ivec2(${padTop}, ${padLeft});
|
|
|
|
void main() {
|
|
ivec4 coords = getOutputCoords();
|
|
int batch = coords[0];
|
|
int d1 = coords[${channelDim}];
|
|
|
|
ivec2 dyCorner = ivec2(coords[${rowDim}], coords[${colDim}]) - pads;
|
|
int dyRCorner = dyCorner.x;
|
|
int dyCCorner = dyCorner.y;
|
|
|
|
|
|
|
|
float dotProd = 0.0;
|
|
for (int wR = 0; wR < ${filterHeight}; wR++) {
|
|
float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
|
|
|
|
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
|
|
continue;
|
|
}
|
|
int idyR = int(dyR);
|
|
|
|
int wRPerm = ${filterHeight} - 1 - wR;
|
|
|
|
for (int wC = 0; wC < ${filterWidth}; wC++) {
|
|
float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
|
|
|
|
if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
|
|
fract(dyC) > 0.0) {
|
|
continue;
|
|
}
|
|
int idyC = int(dyC);
|
|
|
|
int wCPerm = ${filterWidth} - 1 - wC;
|
|
|
|
for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) {
|
|
|
|
if (${isChannelsLast}) {
|
|
float xValue = getDy(batch, idyR, idyC, d2);
|
|
float wValue = getW(wRPerm, wCPerm, d1, d2);
|
|
dotProd += xValue * wValue;
|
|
} else {
|
|
float xValue = getDy(batch, d2, idyR, idyC);
|
|
float wValue = getW(wRPerm, wCPerm, d1, d2);
|
|
dotProd += xValue * wValue;
|
|
}
|
|
|
|
}
|
|
}
|
|
}
|
|
setOutput(dotProd);
|
|
}
|
|
`;
|
|
}
|
|
}
|
|
class Conv3DDerFilterProgram {
|
|
constructor(convInfo) {
|
|
this.variableNames = ['x', 'dy'];
|
|
this.outputShape = convInfo.filterShape;
|
|
const strideDepth = convInfo.strideDepth;
|
|
const strideHeight = convInfo.strideHeight;
|
|
const strideWidth = convInfo.strideWidth;
|
|
const padFront = convInfo.padInfo.front;
|
|
const padTop = convInfo.padInfo.top;
|
|
const padLeft = convInfo.padInfo.left;
|
|
this.userCode = `
|
|
void main() {
|
|
ivec5 coords = getOutputCoords();
|
|
int wF = coords.x;
|
|
int wR = coords.y;
|
|
int wC = coords.z;
|
|
int d1 = coords.w;
|
|
int d2 = coords.u;
|
|
|
|
float dotProd = 0.0;
|
|
|
|
for (int b = 0; b < ${convInfo.batchSize}; b++) {
|
|
for (int yF = 0; yF < ${convInfo.outDepth}; yF++) {
|
|
int xF = wF + yF * ${strideDepth} - ${padFront};
|
|
|
|
if (xF < 0 || xF >= ${convInfo.inDepth}) {
|
|
continue;
|
|
}
|
|
|
|
for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
|
|
int xR = wR + yR * ${strideHeight} - ${padTop};
|
|
|
|
if (xR < 0 || xR >= ${convInfo.inHeight}) {
|
|
continue;
|
|
}
|
|
|
|
for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
|
|
int xC = wC + yC * ${strideWidth} - ${padLeft};
|
|
|
|
if (xC < 0 || xC >= ${convInfo.inWidth}) {
|
|
continue;
|
|
}
|
|
|
|
float dyValue = getDy(b, yF, yR, yC, d2);
|
|
float xValue = getX(b, xF, xR, xC, d1);
|
|
dotProd += (xValue * dyValue);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
setOutput(dotProd);
|
|
}
|
|
`;
|
|
}
|
|
}
|
|
class Conv3DDerInputProgram {
|
|
constructor(convInfo) {
|
|
this.variableNames = ['dy', 'W'];
|
|
this.outputShape = convInfo.inShape;
|
|
const filterDepth = convInfo.filterDepth;
|
|
const filterHeight = convInfo.filterHeight;
|
|
const filterWidth = convInfo.filterWidth;
|
|
const strideDepth = convInfo.strideDepth;
|
|
const strideHeight = convInfo.strideHeight;
|
|
const strideWidth = convInfo.strideWidth;
|
|
const padFront = filterDepth - 1 - convInfo.padInfo.front;
|
|
const padTop = filterHeight - 1 - convInfo.padInfo.top;
|
|
const padLeft = filterWidth - 1 - convInfo.padInfo.left;
|
|
this.userCode = `
|
|
const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});
|
|
|
|
void main() {
|
|
ivec5 coords = getOutputCoords();
|
|
int batch = coords.x;
|
|
int d1 = coords.u;
|
|
|
|
|
|
ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
|
|
int dyFCorner = dyCorner.x;
|
|
int dyRCorner = dyCorner.y;
|
|
int dyCCorner = dyCorner.z;
|
|
|
|
float dotProd = 0.0;
|
|
for (int wF = 0; wF < ${filterDepth}; wF++) {
|
|
float dyF = float(dyFCorner + wF) / ${strideDepth}.0;
|
|
|
|
if (dyF < 0.0 || dyF >= ${convInfo.outDepth}.0 || fract(dyF) > 0.0) {
|
|
continue;
|
|
}
|
|
int idyF = int(dyF);
|
|
|
|
int wFPerm = ${filterDepth} - 1 - wF;
|
|
|
|
for (int wR = 0; wR < ${filterHeight}; wR++) {
|
|
float dyR = float(dyRCorner + wR) / ${strideHeight}.0;
|
|
|
|
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
|
|
fract(dyR) > 0.0) {
|
|
continue;
|
|
}
|
|
int idyR = int(dyR);
|
|
|
|
int wRPerm = ${filterHeight} - 1 - wR;
|
|
|
|
for (int wC = 0; wC < ${filterWidth}; wC++) {
|
|
float dyC = float(dyCCorner + wC) / ${strideWidth}.0;
|
|
|
|
if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
|
|
fract(dyC) > 0.0) {
|
|
continue;
|
|
}
|
|
int idyC = int(dyC);
|
|
|
|
int wCPerm = ${filterWidth} - 1 - wC;
|
|
|
|
for (int d2 = 0; d2 < ${convInfo.outChannels}; d2++) {
|
|
float xValue = getDy(batch, idyF, idyR, idyC, d2);
|
|
float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2);
|
|
dotProd += xValue * wValue;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
setOutput(dotProd);
|
|
}
|
|
`;
|
|
}
|
|
}
function conv2DBackpropFilter$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, dy } = inputs;
    const { strides, pad, dataFormat, dimRoundingMode, filterShape } = attrs;
    const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
    const program = new Conv2DDerFilterProgram(convInfo);
    return backend.runWebGLProgram(program, [x, dy], 'float32');
}
const conv2DBackpropFilterConfig$1 = {
    kernelName: Conv2DBackpropFilter,
    backendName: 'webgl',
    kernelFunc: conv2DBackpropFilter$1,
};
class Conv2DDerInputPackedProgram {
|
|
constructor(convInfo) {
|
|
this.variableNames = ['dy', 'W'];
|
|
this.packedInputs = true;
|
|
this.packedOutput = true;
|
|
this.customUniforms = [
|
|
{ name: 'strides', type: 'vec2' },
|
|
];
|
|
this.outputShape = convInfo.inShape;
|
|
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
|
|
const filterHeight = convInfo.filterHeight;
|
|
const filterWidth = convInfo.filterWidth;
|
|
const padTop = filterHeight - 1 - convInfo.padInfo.top;
|
|
const padLeft = filterWidth - 1 - convInfo.padInfo.left;
|
|
this.userCode = `
|
|
const ivec2 pads = ivec2(${padTop}, ${padLeft});
|
|
|
|
void main() {
|
|
ivec4 coords = getOutputCoords();
|
|
int batch = coords[0];
|
|
int d1 = coords[3];
|
|
|
|
ivec2 dyCorner = ivec2(coords[1], coords[2]) - pads;
|
|
int dyRCorner = dyCorner.x;
|
|
int dyCCorner = dyCorner.y;
|
|
|
|
vec4 result = vec4(0.);
|
|
for (int wR = 0; wR < ${filterHeight}; wR++) {
|
|
float dyR = float(dyRCorner + wR) / strides[0];
|
|
if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
|
|
continue;
|
|
}
|
|
int idyR = int(dyR);
|
|
int wRPerm = ${filterHeight} - 1 - wR;
|
|
|
|
for (int wC = 0; wC < ${filterWidth}; wC++) {
|
|
int wCPerm = ${filterWidth} - 1 - wC;
|
|
|
|
float dyC = float(dyCCorner + wC) / strides[1];
|
|
bool idyCVal = (dyC >= 0.0) && (dyC < ${convInfo.outWidth}.0)
|
|
&& (fract(dyC) == 0.0);
|
|
int idyC = int(dyC);
|
|
|
|
float dyC2 = float(dyCCorner + wC + 1) / strides[1];
|
|
bool idyCVal2 = (dyC2 >= 0.0) && (dyC2 < ${convInfo.outWidth}.0)
|
|
&& (fract(dyC2) == 0.0);
|
|
int idyC2 = int(dyC2);
|
|
|
|
if (idyCVal && idyCVal2) {
|
|
for (int d2 = 0; d2 < ${convInfo.outChannels}; d2 += 2) {
|
|
vec4 wValue = getW(wRPerm, wCPerm, d1, d2);
|
|
vec4 dySample = getDy(batch, idyR, idyC, d2);
|
|
vec4 dySample2 = (idyC / 2 == idyC2 / 2) ?
|
|
dySample : getDy(batch, idyR, idyC2, d2);
|
|
|
|
vec2 dyValue = mod(float(idyC), 2.) == 0. ?
|
|
dySample.xy : dySample.zw;
|
|
result.xy += vec2(dot(dyValue, wValue.xy),
|
|
dot(dyValue, wValue.zw));
|
|
|
|
dyValue = mod(float(idyC2), 2.) == 0. ?
|
|
dySample2.xy : dySample2.zw;
|
|
result.zw += vec2(dot(dyValue, wValue.xy),
|
|
dot(dyValue, wValue.zw));
|
|
}
|
|
} else if (idyCVal) {
|
|
for (int d2 = 0; d2 < ${convInfo.outChannels}; d2 += 2) {
|
|
vec4 wValue = getW(wRPerm, wCPerm, d1, d2);
|
|
vec4 dySample = getDy(batch, idyR, idyC, d2);
|
|
vec2 dyValue = mod(float(idyC), 2.) == 0. ?
|
|
dySample.xy : dySample.zw;
|
|
result.xy += vec2(dot(dyValue, wValue.xy),
|
|
dot(dyValue, wValue.zw));
|
|
}
|
|
} else if (idyCVal2) {
|
|
for (int d2 = 0; d2 < ${convInfo.outChannels}; d2 += 2) {
|
|
vec4 wValue = getW(wRPerm, wCPerm, d1, d2);
|
|
vec4 dySample = getDy(batch, idyR, idyC2, d2);
|
|
vec2 dyValue = mod(float(idyC2), 2.) == 0. ?
|
|
dySample.xy : dySample.zw;
|
|
result.zw += vec2(dot(dyValue, wValue.xy),
|
|
dot(dyValue, wValue.zw));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
setOutput(result);
|
|
}
|
|
`;
|
|
}
|
|
}
function conv2DBackpropInput$1(args) {
    const { inputs, backend, attrs } = args;
    const { dy, filter } = inputs;
    const { inputShape, strides, pad, dataFormat, dimRoundingMode } = attrs;
    const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad, dimRoundingMode, false /* depthwise */, $dataFormat);
    if (env().getBool('WEBGL_PACK_CONV2DTRANSPOSE') &&
        $dataFormat === 'channelsLast') {
        const customValues = [
            [convInfo.strideHeight, convInfo.strideWidth],
        ];
        const program = new Conv2DDerInputPackedProgram(convInfo);
        return backend.runWebGLProgram(program, [dy, filter], 'float32', customValues);
    }
    else {
        const program = new Conv2DDerInputProgram(convInfo);
        return backend.runWebGLProgram(program, [dy, filter], 'float32');
    }
}
const conv2DBackpropInputConfig$1 = {
    kernelName: Conv2DBackpropInput,
    backendName: 'webgl',
    kernelFunc: conv2DBackpropInput$1,
};
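// The derivative-of-input programs above read dy through a filter flipped in
// both spatial dimensions with complementary padding; a quick CPU mirror of
// that padding arithmetic (helper name is ours, not the library's):
function derInputPadSketch(filterSize, forwardPad) {
    return filterSize - 1 - forwardPad; // derInputPadSketch(3, 1) === 1
}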
function conv3D$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter } = inputs;
    const { strides, pad, dilations } = attrs;
    const convInfo = computeConv3DInfo(x.shape, filter.shape, strides, dilations, pad);
    const program = new Conv3DProgram(convInfo);
    return backend.runWebGLProgram(program, [x, filter], 'float32');
}
const conv3DConfig$1 = {
    kernelName: Conv3D,
    backendName: 'webgl',
    kernelFunc: conv3D$1,
};
function conv3DBackpropFilterV2$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, dy } = inputs;
    const { strides, pad, filterShape } = attrs;
    const convInfo = computeConv3DInfo(x.shape, filterShape, strides, 1 /* dilations */, pad);
    const program = new Conv3DDerFilterProgram(convInfo);
    return backend.runWebGLProgram(program, [x, dy], 'float32');
}
const conv3DBackpropFilterV2Config$1 = {
    kernelName: Conv3DBackpropFilterV2,
    backendName: 'webgl',
    kernelFunc: conv3DBackpropFilterV2$1
};
function conv3DBackpropInput(args) {
    const { inputs, backend, attrs } = args;
    const { dy, filter } = inputs;
    const { pad, strides, inputShape } = attrs;
    const convInfo = computeConv3DInfo(inputShape, filter.shape, strides, 1 /* dilations */, pad);
    const program = new Conv3DDerInputProgram(convInfo);
    return backend.runWebGLProgram(program, [dy, filter], 'float32');
}
const conv3DBackpropInputConfig = {
    kernelName: Conv3DBackpropInputV2,
    backendName: 'webgl',
    kernelFunc: conv3DBackpropInput,
};
const COS = CHECK_NAN_SNIPPET_UNARY + `
  return cos(x);
`;
const COS_PACKED = `
  vec4 result = cos(x);
  bvec4 isNaN = isnan(x);
  ${CHECK_NAN_SNIPPET_PACKED}
  return result;
`;
const cos$1 = unaryKernelFunc({ opSnippet: COS, packedOpSnippet: COS_PACKED });
const cosConfig$1 = {
    kernelName: Cos,
    backendName: 'webgl',
    kernelFunc: cos$1,
};
const COSH = `
  float e2x = exp(-x);
  return (e2x + 1.0 / e2x) / 2.0;
`;
const cosh$1 = unaryKernelFunc({ opSnippet: COSH });
const coshConfig$1 = {
    kernelName: Cosh,
    backendName: 'webgl',
    kernelFunc: cosh$1,
};
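// The COSH snippet relies on cosh(x) = (e^x + e^-x) / 2 written with a single
// exponential: with t = exp(-x), 1 / t = exp(x). A CPU mirror (illustrative):
function coshSketch(x) {
    const t = Math.exp(-x);
    return (t + 1 / t) / 2; // coshSketch(0) === 1
}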
class CropAndResizeProgram {
|
|
constructor(imageShape, boxShape, cropSize, method, extrapolationValue) {
|
|
this.variableNames = ['Image', 'Boxes', 'BoxInd'];
|
|
this.outputShape = [];
|
|
const [batch, imageHeight, imageWidth, depth] = imageShape;
|
|
const [numBoxes,] = boxShape;
|
|
const [cropHeight, cropWidth] = cropSize;
|
|
this.outputShape = [numBoxes, cropHeight, cropWidth, depth];
|
|
const methodId = method === 'bilinear' ? 1 : 0;
|
|
const [inputHeightFloat, inputWidthFloat] = [`${imageHeight - 1}.0`, `${imageWidth - 1}.0`];
|
|
const [heightRatio, heightScale, inY] = cropHeight > 1 ?
|
|
[
|
|
`${(imageHeight - 1) / (cropHeight - 1)}`,
|
|
'(y2-y1) * height_ratio',
|
|
`y1*${inputHeightFloat} + float(y)*(height_scale)`,
|
|
] :
|
|
[
|
|
'0.0',
|
|
'0.0',
|
|
`0.5 * (y1+y2) * ${inputHeightFloat}`,
|
|
];
|
|
const [widthRatio, widthScale, inX] = cropWidth > 1 ?
|
|
[
|
|
`${(imageWidth - 1) / (cropWidth - 1)}`,
|
|
'(x2-x1) * width_ratio',
|
|
`x1*${inputWidthFloat} + float(x)*(width_scale)`,
|
|
] :
|
|
[
|
|
'0.0',
|
|
'0.0',
|
|
`0.5 * (x1+x2) * ${inputWidthFloat}`,
|
|
];
|
|
|
|
|
|
|
|
this.userCode = `
|
|
const float height_ratio = float(${heightRatio});
|
|
const float width_ratio = float(${widthRatio});
|
|
void main() {
|
|
ivec4 coords = getOutputCoords();
|
|
int b = coords[0];
|
|
int y = coords[1];
|
|
int x = coords[2];
|
|
int d = coords[3];
|
|
|
|
|
|
float y1 = getBoxes(b,0);
|
|
float x1 = getBoxes(b,1);
|
|
float y2 = getBoxes(b,2);
|
|
float x2 = getBoxes(b,3);
|
|
|
|
|
|
int bInd = round(getBoxInd(b));
|
|
if(bInd < 0 || bInd >= ${batch}) {
|
|
return;
|
|
}
|
|
|
|
float height_scale = ${heightScale};
|
|
float width_scale = ${widthScale};
|
|
|
|
float in_y = ${inY};
|
|
if( in_y < 0.0 || in_y > ${inputHeightFloat} ) {
|
|
setOutput(float(${extrapolationValue}));
|
|
return;
|
|
}
|
|
float in_x = ${inX};
|
|
if( in_x < 0.0 || in_x > ${inputWidthFloat} ) {
|
|
setOutput(float(${extrapolationValue}));
|
|
return;
|
|
}
|
|
|
|
vec2 sourceFracIndexCR = vec2(in_x,in_y);
|
|
if(${methodId} == 1) {
|
|
|
|
ivec2 sourceFloorCR = ivec2(sourceFracIndexCR);
|
|
ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR));
|
|
|
|
float topLeft = getImage(b, sourceFloorCR.y, sourceFloorCR.x, d);
|
|
float bottomLeft = getImage(b, sourceCeilCR.y, sourceFloorCR.x, d);
|
|
float topRight = getImage(b, sourceFloorCR.y, sourceCeilCR.x, d);
|
|
float bottomRight = getImage(b, sourceCeilCR.y, sourceCeilCR.x, d);
|
|
|
|
vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR);
|
|
|
|
float top = topLeft + (topRight - topLeft) * fracCR.x;
|
|
float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x;
|
|
float newValue = top + (bottom - top) * fracCR.y;
|
|
setOutput(newValue);
|
|
} else {
|
|
|
|
ivec2 sourceNearestCR = ivec2(floor(
|
|
sourceFracIndexCR + vec2(0.5,0.5)));
|
|
float newValue = getImage(b, sourceNearestCR.y, sourceNearestCR.x, d);
|
|
setOutput(newValue);
|
|
}
|
|
}
|
|
`;
|
|
}
|
|
}
|
|
|
|
|
|
const cropAndResize$1 = (args) => {
    const { inputs, backend, attrs } = args;
    const { image, boxes, boxInd } = inputs;
    const { cropSize, method, extrapolationValue } = attrs;
    const program = new CropAndResizeProgram(image.shape, boxes.shape, cropSize, method, extrapolationValue);
    return backend.runWebGLProgram(program, [image, boxes, boxInd], 'float32');
};
const cropAndResizeConfig$1 = {
    kernelName: CropAndResize,
    backendName: 'webgl',
    kernelFunc: cropAndResize$1
};
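// Bilinear sampling as used by CropAndResizeProgram when method === 'bilinear',
// mirrored on the CPU for one point (illustrative helper): blend the four
// surrounding samples by the fractional distances to them.
function bilinearSketch(topLeft, topRight, bottomLeft, bottomRight, fracX, fracY) {
    const top = topLeft + (topRight - topLeft) * fracX;
    const bottom = bottomLeft + (bottomRight - bottomLeft) * fracX;
    return top + (bottom - top) * fracY;
}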
var CumOpType;
|
|
(function (CumOpType) {
|
|
CumOpType["Prod"] = "*";
|
|
CumOpType["Sum"] = "+";
|
|
})(CumOpType || (CumOpType = {}));
|
|
class CumProgram {
|
|
constructor(op, outputShape, exclusive, reverse) {
|
|
this.op = op;
|
|
this.outputShape = outputShape;
|
|
this.variableNames = ['x'];
|
|
this.customUniforms = [{ name: 'index', type: 'float' }];
|
|
const rank = this.outputShape.length;
|
|
const initVal = this.op === CumOpType.Prod ? '1.0' : '0.0';
|
|
const val = exclusive ? initVal : `getX(${getCoords(rank, 'coords', this.op)})`;
|
|
const length = this.outputShape[this.outputShape.length - 1];
|
|
let condition = '';
|
|
let idxString = '';
|
|
|
|
|
|
|
|
if (exclusive) {
|
|
condition = reverse ? `end != ${length - 1}` : 'end != 0';
|
|
idxString = reverse ? 'end + 1' : 'end - 1';
|
|
}
|
|
else {
|
|
condition = reverse ? `end + pow2 < ${length}` : 'end >= pow2';
|
|
idxString = (reverse ? 'end + pow2' : 'end - pow2');
|
|
}
|
|
this.userCode = `
|
|
void main() {
|
|
${getCoordsDataType(rank)} coords = getOutputCoords();
|
|
int end = ${getFinalCoord(rank, 'coords', this.op)};
|
|
float val = ${val};
|
|
int pow2 = int(pow(2.0, index));
|
|
if (${condition}) {
|
|
int idx = ${idxString};
|
|
${getFinalCoord(rank, 'coords', this.op)} = idx;
|
|
val ${this.op}= getX(${getCoords(rank, 'coords', this.op)});
|
|
}
|
|
setOutput(val);
|
|
}
|
|
`;
|
|
}
|
|
}
|
|
function getCoords(rank, name, op) {
|
|
if (rank === 1) {
|
|
return `${name}`;
|
|
}
|
|
else if (rank === 2) {
|
|
return `${name}.x, ${name}.y`;
|
|
}
|
|
else if (rank === 3) {
|
|
return `${name}.x, ${name}.y, ${name}.z`;
|
|
}
|
|
else if (rank === 4) {
|
|
return `${name}.x, ${name}.y, ${name}.z, ${name}.w`;
|
|
}
|
|
else {
|
|
throw new Error(`Cumulative ${op} for rank ${rank} is not yet supported`);
|
|
}
|
|
}
|
|
function getFinalCoord(rank, name, op) {
|
|
if (rank === 1) {
|
|
return `${name}`;
|
|
}
|
|
else if (rank === 2) {
|
|
return `${name}.y`;
|
|
}
|
|
else if (rank === 3) {
|
|
return `${name}.z`;
|
|
}
|
|
else if (rank === 4) {
|
|
return `${name}.w`;
|
|
}
|
|
else {
|
|
throw new Error(`Cumulative ${op} for rank ${rank} is not yet supported`);
|
|
}
|
|
}
|
|
|
|
|
|
function cumImpl(op, x, backend, axis, exclusive, reverse) {
    const xRank = x.shape.length;
    const permutation = getAxesPermutation([axis], xRank);
    let permutedX = x;
    if (permutation != null) {
        permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutation } });
    }
    const permutedAxis = getInnerMostAxes(1, xRank)[0];
    if (permutedAxis !== xRank - 1) {
        throw new Error(`WebGL cumulative op shader expects an inner-most axis=${x.shape.length - 1} ` +
            `but got axis=${axis}`);
    }
    const size = permutedX.shape[permutedAxis];
    let result = identity({ inputs: { x: permutedX }, backend });
    // Parallel scan ladder: pass i combines each element with the one 2^i
    // positions away, so ceil(log2(size)) passes produce the full cumulative
    // result.
    for (let i = 0; i <= Math.ceil(Math.log2(size)) - 1; i++) {
        const program = new CumProgram(op, permutedX.shape, false, reverse);
        const customValues = [[i]];
        const prevResult = result;
        result =
            backend.runWebGLProgram(program, [result], result.dtype, customValues);
        backend.disposeIntermediateTensorInfo(prevResult);
    }
    // For exclusive scans, run one extra pass that shifts the inclusive result
    // by one position.
    if (exclusive) {
        const program = new CumProgram(op, permutedX.shape, exclusive, reverse);
        const prevResult = result;
        result = backend.runWebGLProgram(program, [result], result.dtype);
        backend.disposeIntermediateTensorInfo(prevResult);
    }
    if (permutation != null) {
        const reversePermutation = getUndoAxesPermutation(permutation);
        const reverseTransposedResult = transpose({ inputs: { x: result }, backend, attrs: { perm: reversePermutation } });
        backend.disposeIntermediateTensorInfo(result);
        backend.disposeIntermediateTensorInfo(permutedX);
        return reverseTransposedResult;
    }
    return result;
}
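// CPU mirror (illustrative only) of the O(log n) ladder above for the Sum op:
// pass i adds to every position the value 2^i steps back, matching the
// shader's `end >= pow2` / `end - pow2` logic for a forward inclusive scan.
function cumsumLadderSketch(values) {
    let cur = values.slice();
    for (let pow2 = 1; pow2 < cur.length; pow2 *= 2) {
        const next = cur.slice(); // each GPU pass writes to a fresh texture
        for (let i = pow2; i < cur.length; i++) {
            next[i] = cur[i] + cur[i - pow2];
        }
        cur = next;
    }
    return cur; // cumsumLadderSketch([1, 2, 3, 4]) -> [1, 3, 6, 10]
}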
function cumprod$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, exclusive, reverse } = attrs;
    return cumImpl(CumOpType.Prod, x, backend, axis, exclusive, reverse);
}
const cumprodConfig$1 = {
    kernelName: Cumprod,
    backendName: 'webgl',
    kernelFunc: cumprod$1
};
function cumsum$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, exclusive, reverse } = attrs;
    return cumImpl(CumOpType.Sum, x, backend, axis, exclusive, reverse);
}
const cumsumConfig$1 = {
    kernelName: Cumsum,
    backendName: 'webgl',
    kernelFunc: cumsum$1
};
function denseBincount$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, weights } = inputs;
    const { size, binaryOutput } = attrs;
    if (x.shape.length === 1) {
        const xVals = backend.readSync(x.dataId);
        const weightsVals = backend.readSync(weights.dataId);
        const outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size);
        return backend.makeTensorInfo([size], weights.dtype, outVals);
    }
    else if (x.shape.length === 2) {
        const xBuf = backend.bufferSync(x);
        const weightsBuf = backend.bufferSync(weights);
        const outBuf = bincountReduceImplCPU(xBuf, weightsBuf, size, binaryOutput);
        return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values);
    }
    throw new Error(`Error in denseBincount: input must be at most rank 2, but got rank ` +
        `${x.shape.length}.`);
}
const denseBincountConfig$1 = {
    kernelName: DenseBincount,
    backendName: 'webgl',
    kernelFunc: denseBincount$1
};
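// Behavior sketch for the rank-1 path above (CPU, illustrative; the real
// bincountImplCPU also applies per-element weights): every value in x selects
// an output bin, and each hit adds 1 when no weights are given.
function bincountSketch(xVals, size) {
    const out = new Float32Array(size);
    for (const v of xVals) {
        if (v >= 0 && v < size) {
            out[v] += 1;
        }
    }
    return out; // bincountSketch([1, 1, 3], 4) -> [0, 2, 0, 1]
}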
class DepthToSpaceProgram {
|
|
constructor(outputShape, blockSize, dataFormat) {
|
|
this.variableNames = ['x'];
|
|
this.outputShape = [];
|
|
this.outputShape = outputShape;
|
|
this.blockSize = blockSize;
|
|
this.dataFormat = dataFormat;
|
|
this.userCode = `
|
|
void main() {
|
|
ivec4 coords = getOutputCoords();
|
|
int b = coords[0];
|
|
int h = ${this.getHeightCoordString()};
|
|
int w = ${this.getWidthCoordString()};
|
|
int d = ${this.getDepthCoordString()};
|
|
|
|
int in_h = h / ${blockSize};
|
|
int offset_h = imod(h, ${blockSize});
|
|
int in_w = w / ${blockSize};
|
|
int offset_w = imod(w, ${blockSize});
|
|
int offset_d = (offset_h * ${blockSize} + offset_w) *
|
|
${this.getOutputDepthSize()};
|
|
int in_d = d + offset_d;
|
|
|
|
float result = ${this.getInputSamplingString()};
|
|
setOutput(result);
|
|
}
|
|
`;
|
|
}
|
|
getHeightCoordString() {
|
|
if (this.dataFormat === 'NHWC') {
|
|
return `coords[1]`;
|
|
}
|
|
else {
|
|
return `coords[2]`;
|
|
}
|
|
}
|
|
getWidthCoordString() {
|
|
if (this.dataFormat === 'NHWC') {
|
|
return `coords[2]`;
|
|
}
|
|
else {
|
|
return `coords[3]`;
|
|
}
|
|
}
|
|
getDepthCoordString() {
|
|
if (this.dataFormat === 'NHWC') {
|
|
return `coords[3]`;
|
|
}
|
|
else {
|
|
return `coords[1]`;
|
|
}
|
|
}
|
|
getOutputDepthSize() {
|
|
if (this.dataFormat === 'NHWC') {
|
|
return this.outputShape[3];
|
|
}
|
|
else {
|
|
return this.outputShape[1];
|
|
}
|
|
}
|
|
getInputSamplingString() {
|
|
if (this.dataFormat === 'NHWC') {
|
|
return `getX(b, in_h, in_w, in_d)`;
|
|
}
|
|
else {
|
|
return `getX(b, in_d, in_h, in_w)`;
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
function depthToSpace$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { blockSize, dataFormat } = attrs;
    const batchSize = x.shape[0];
    const inputHeight = (dataFormat === 'NHWC') ? x.shape[1] : x.shape[2];
    const inputWidth = (dataFormat === 'NHWC') ? x.shape[2] : x.shape[3];
    const inputDepth = (dataFormat === 'NHWC') ? x.shape[3] : x.shape[1];
    const outputHeight = inputHeight * blockSize;
    const outputWidth = inputWidth * blockSize;
    const outputDepth = inputDepth / (blockSize * blockSize);
    const outputShape = (dataFormat === 'NHWC') ?
        [batchSize, outputHeight, outputWidth, outputDepth] :
        [batchSize, outputDepth, outputHeight, outputWidth];
    const program = new DepthToSpaceProgram(outputShape, blockSize, dataFormat);
    return backend.runWebGLProgram(program, [x], x.dtype);
}
const depthToSpaceConfig$1 = {
    kernelName: DepthToSpace,
    backendName: 'webgl',
    kernelFunc: depthToSpace$1
};
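// CPU mirror of the index arithmetic in DepthToSpaceProgram for one NHWC
// output coordinate (hypothetical helper): output pixel (h, w, d) reads input
// pixel (h / block, w / block) at a depth offset derived from its in-block
// position.
function depthToSpaceSourceSketch(h, w, d, blockSize, outputDepth) {
    const inH = Math.floor(h / blockSize);
    const inW = Math.floor(w / blockSize);
    const offsetD = ((h % blockSize) * blockSize + (w % blockSize)) * outputDepth;
    return { inH, inW, inD: d + offsetD };
}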
class DepthwiseConv2DProgram {
|
|
constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) {
|
|
this.variableNames = ['x', 'W'];
|
|
this.customUniforms = [
|
|
{ name: 'pads', type: 'ivec2' },
|
|
{ name: 'strides', type: 'ivec2' },
|
|
{ name: 'dilations', type: 'ivec2' },
|
|
{ name: 'inDims', type: 'ivec2' },
|
|
];
|
|
this.outputShape = convInfo.outShape;
|
|
this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
|
|
const filterHeight = convInfo.filterHeight;
|
|
const filterWidth = convInfo.filterWidth;
|
|
const channelMul = convInfo.outChannels / convInfo.inChannels;
|
|
let activationSnippet = '', applyActivationSnippet = '';
|
|
if (activation) {
|
|
if (hasPreluActivation) {
|
|
activationSnippet = `float activation(float a) {
|
|
float b = getPreluActivationWeightsAtOutCoords();
|
|
${activation}
|
|
}`;
|
|
}
|
|
else if (hasLeakyReluAlpha) {
|
|
activationSnippet = `float activation(float a) {
|
|
float b = getLeakyreluAlphaAtOutCoords();
|
|
${activation}
|
|
}`;
|
|
}
|
|
else {
|
|
activationSnippet = `
|
|
float activation(float x) {
|
|
${activation}
|
|
}
|
|
`;
|
|
}
|
|
applyActivationSnippet = `result = activation(result);`;
|
|
}
|
|
const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
|
|
if (addBias) {
|
|
this.variableNames.push('bias');
|
|
}
|
|
if (hasPreluActivation) {
|
|
this.variableNames.push('preluActivationWeights');
|
|
}
|
|
if (hasLeakyReluAlpha) {
|
|
this.variableNames.push('leakyreluAlpha');
|
|
}
|
|
this.userCode = `
|
|
${activationSnippet}
|
|
|
|
void main() {
|
|
ivec4 coords = getOutputCoords();
|
|
int batch = coords.x;
|
|
ivec2 xRCCorner = coords.yz * strides - pads;
|
|
int d2 = coords.w;
|
|
int d1 = d2 / ${channelMul};
|
|
int q = d2 - d1 * ${channelMul};
|
|
|
|
int xRCorner = xRCCorner.x;
|
|
int xCCorner = xRCCorner.y;
|
|
|
|
|
|
|
|
float dotProd = 0.0;
|
|
|
|
for (int wR = 0; wR < ${filterHeight}; wR++) {
|
|
int xR = xRCorner + wR * dilations[0];
|
|
|
|
if (xR < 0 || xR >= inDims[0]) {
|
|
continue;
|
|
}
|
|
|
|
for (int wC = 0; wC < ${filterWidth}; wC++) {
|
|
int xC = xCCorner + wC * dilations[1];
|
|
|
|
if (xC < 0 || xC >= inDims[1]) {
|
|
continue;
|
|
}
|
|
|
|
float xVal = getX(batch, xR, xC, d1);
|
|
float wVal = getW(wR, wC, d1, q);
|
|
dotProd += xVal * wVal;
|
|
}
|
|
}
|
|
|
|
float result = dotProd;
|
|
${addBiasSnippet}
|
|
${applyActivationSnippet}
|
|
setOutput(result);
|
|
}
|
|
`;
|
|
}
|
|
}
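// Channel mapping used by the depthwise program above (illustrative CPU
// mirror): output channel d2 reads input channel d1 = floor(d2 / channelMul)
// with filter multiplier slot q = d2 - d1 * channelMul.
function depthwiseChannelSketch(d2, channelMul) {
    const d1 = Math.floor(d2 / channelMul);
    return { d1, q: d2 - d1 * channelMul };
}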
|
|
|
|
|
|
class DepthwiseConvPacked2DProgram {
    constructor(convInfo, addBias = false, activation = null, hasPreluActivation = false, hasLeakyReluAlpha = false) {
        this.variableNames = ['x', 'W'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.customUniforms = [
            { name: 'pads', type: 'ivec2' },
            { name: 'strides', type: 'ivec2' },
            { name: 'dilations', type: 'ivec2' },
            { name: 'inDims', type: 'ivec2' },
        ];
        this.outputShape = convInfo.outShape;
        this.enableShapeUniforms = useShapeUniforms(this.outputShape.length);
        const channelMul = convInfo.outChannels / convInfo.inChannels;
        const padLeft = convInfo.padInfo.left;
        const strideWidth = convInfo.strideWidth;
        const dilationWidth = convInfo.dilationWidth;
        const filterHeight = convInfo.filterHeight;
        const filterWidth = convInfo.filterWidth;
        const texelsAcross = filterWidth;
        let mainLoop = `
        int xR; int xC; int xCOffset;
        vec4 wTexel; vec4 previous; vec4 final;`;
        for (let c = 0; c < filterWidth; c++) {
            mainLoop += `
            vec4 xTexelC${c * 2};
            int xTexelC${c * 2}Ready;
            vec4 xTexelC${c * 2 + 1};
            int xTexelC${c * 2 + 1}Ready;
            vec4 xC${c};`;
        }

        mainLoop += `
        for (int r = 0; r < ${filterHeight}; r++) {
        `;
        for (let c = 0; c < filterWidth; c++) {
            mainLoop += `
            xTexelC${c * 2} = vec4(0.0);
            xTexelC${c * 2}Ready = 0;
            xTexelC${c * 2 + 1} = vec4(0.0);
            xTexelC${c * 2 + 1}Ready = 0;
            xC${c} = vec4(0.0);`;
        }
        mainLoop += `
        xR = xRCorner + r * dilations[0];
        if (xR >=0 && xR < inDims[0]) {
        `;
        for (let texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) {
            const colIndex = texelC * 2;
            mainLoop += `
            xC = xCCorner + ${colIndex * dilationWidth};
            `;
            if (strideWidth === 1) {
                if (colIndex < filterWidth) {
                    if (padLeft % 2 === 1) {
                        mainLoop += `
                        xCOffset = xC + 1;
                        if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
                            xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);

                            if (xCOffset + 1 >= inDims[1]) {
                                xTexelC${colIndex}.zw = vec2(0.0);
                            }
                            xTexelC${colIndex}Ready = 1;
                        }
                        `;
                        if (dilationWidth === 1 && colIndex > 0) {
                            mainLoop += `
                            xC${colIndex} = vec4(xTexelC${colIndex - 2}.zw, xTexelC${colIndex}.xy);
                            `;
                        }
                        else {
                            mainLoop += `
                            xCOffset = xC + 1 - 2;

                            if (xCOffset >= 0 && xCOffset < inDims[1]) {
                                previous = getX(batch, xR, xCOffset, d1);

                                if (xCOffset + 1 >= inDims[1]) {
                                    previous.zw = vec2(0.0);
                                }

                                xC${colIndex} = vec4(previous.zw, xTexelC${colIndex}.xy);
                            } else {
                                xC${colIndex} = vec4(0.0, 0.0, xTexelC${colIndex}.xy);
                            }
                            `;
                        }
                    }
                    else {
                        mainLoop += `
                        if (xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
                            xTexelC${colIndex} = getX(batch, xR, xC, d1);
                            if (xC + 1 >= inDims[1]) {
                                xTexelC${colIndex}.zw = vec2(0.0);
                            }
                            xTexelC${colIndex}Ready = 1;
                        }

                        xC${colIndex} = xTexelC${colIndex};
                        `;
                    }
                    if (colIndex + 1 < filterWidth) {
                        const nextTexelOffset = padLeft % 2 === 0 ?
                            nearestLargerEven(dilationWidth) :
                            dilationWidth;
                        if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) ||
                            (dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) {
                            mainLoop += `
                            xCOffset = xC + imod(pads[1], 2) + ${nextTexelOffset};

                            if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                                xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);

                                if (xCOffset + 1 >= inDims[1]) {
                                    xTexelC${colIndex + 1}.zw = vec2(0.0);
                                }
                                xTexelC${colIndex + 1}Ready = 1;
                            }
                            `;
                            if (dilationWidth > 1) {
                                mainLoop += `
                                xCOffset -= 2;
                                if (xCOffset >= 0 && xCOffset < inDims[1]) {
                                    previous = getX(batch, xR, xCOffset, d1);
                                    xC${colIndex + 1} = vec4(previous.zw, xTexelC${colIndex + 1}.xy);
                                } else {
                                    xC${colIndex + 1} = vec4(0.0, 0.0, xTexelC${colIndex + 1}.xy);
                                }
                                `;
                            }
                            else {
                                mainLoop += `
                                xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.xy);
                                `;
                            }
                        }
                        else {
                            if (nextTexelOffset === 1) {
                                mainLoop += `
                                xC${colIndex + 1} = xTexelC${colIndex};
                                `;
                            }
                            else {
                                mainLoop += `
                                xCOffset = xC + ${nextTexelOffset};

                                if (xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                                    xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
                                    if (xCOffset + 1 >= inDims[1]) {
                                        xTexelC${colIndex + 1}.zw = vec2(0.0);
                                    }
                                    xTexelC${colIndex + 1}Ready = 1;
                                }

                                xC${colIndex + 1} = xTexelC${colIndex + 1};
                                `;
                            }
                        }
                    }
                }
            }
            else {
                if (colIndex < filterWidth) {
                    if (padLeft % 2 === 1) {
                        mainLoop += `
                        xCOffset = xC + 1 - strides[1];
                        if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex}Ready == 0) {
                            xTexelC${colIndex} = getX(batch, xR, xCOffset, d1);
                            if (xCOffset + 1 >= inDims[1]) {
                                xTexelC${colIndex}.zw = vec2(0.0);
                            }
                            xTexelC${colIndex}Ready = 1;
                        }

                        if(xC + 1 >= 0 && xC + 1 < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                            xTexelC${colIndex + 1} = getX(batch, xR, xC + 1, d1);
                            if (xC + 2 >= inDims[1]) {
                                xTexelC${colIndex + 1}.zw = vec2(0.0);
                            }
                            xTexelC${colIndex + 1}Ready = 1;
                        }

                        xC${colIndex} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
                        `;
                        if (colIndex + 1 < filterWidth) {
                            mainLoop += `
                            final = vec4(0.0);
                            xCOffset = xC + 1 + strides[1];
                            if(xCOffset >= 0 && xCOffset < inDims[1]) {
                                final = getX(batch, xR, xCOffset, d1);
                            }
                            xC${colIndex + 1} = vec4(xTexelC${colIndex + 1}.xy, final.xy);
                            `;
                        }
                    }
                    else {
                        mainLoop += `
                        if(xC >= 0 && xC < inDims[1] && xTexelC${colIndex}Ready == 0) {
                            xTexelC${colIndex} = getX(batch, xR, xC, d1);
                            if (xC + 1 >= inDims[1]) {
                                xTexelC${colIndex}.zw = vec2(0.0);
                            }
                            xTexelC${colIndex}Ready = 1;
                        }

                        xCOffset = xC + strides[1];
                        if(xCOffset >= 0 && xCOffset < inDims[1] && xTexelC${colIndex + 1}Ready == 0) {
                            xTexelC${colIndex + 1} = getX(batch, xR, xCOffset, d1);
                            if (xCOffset + 1 >= inDims[1]) {
                                xTexelC${colIndex + 1}.zw = vec2(0.);
                            }
                            xTexelC${colIndex + 1}Ready = 1;
                        }

                        xC${colIndex} = vec4(
                            xTexelC${colIndex}.xy, xTexelC${colIndex + 1}.xy);
                        `;
                        if (colIndex + 1 < filterWidth) {
                            mainLoop += `
                            xC${colIndex + 1} = vec4(xTexelC${colIndex}.zw, xTexelC${colIndex + 1}.zw);
                            `;
                        }
                    }
                }
            }

            if (colIndex < filterWidth) {
                mainLoop += `
                wTexel = getW(r, ${colIndex}, d1, q);
                dotProd += xC${colIndex} * vec4(wTexel.xz, wTexel.xz);
                `;
                if (colIndex + 1 < filterWidth) {
                    mainLoop += `
                    wTexel = getW(r, ${colIndex + 1}, d1, q);
                    dotProd += xC${colIndex + 1} * vec4(wTexel.xz, wTexel.xz);
                    `;
                }
            }
        }
        mainLoop += `
        }
        `;
        mainLoop += `
        }
        `;
        let activationSnippet = '', applyActivationSnippet = '';
        if (activation) {
            if (hasPreluActivation) {
                activationSnippet = `vec4 activation(vec4 a) {
                    vec4 b = getPreluActivationWeightsAtOutCoords();
                    ${activation}
                }`;
            }
            else if (hasLeakyReluAlpha) {
                activationSnippet = `vec4 activation(vec4 a) {
                    vec4 b = getLeakyreluAlphaAtOutCoords();
                    ${activation}
                }`;
            }
            else {
                activationSnippet = `vec4 activation(vec4 x) {
                    ${activation}
                }`;
            }
            applyActivationSnippet = `result = activation(result);`;
        }
        const addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
        if (addBias) {
            this.variableNames.push('bias');
        }
        if (hasPreluActivation) {
            this.variableNames.push('preluActivationWeights');
        }
        if (hasLeakyReluAlpha) {
            this.variableNames.push('leakyreluAlpha');
        }
        this.userCode = `
            ${activationSnippet}

            void main() {
                ivec4 coords = getOutputCoords();
                int batch = coords.x;
                ivec2 xRCCorner = coords.yz * strides - pads;
                int d2 = coords.w;
                int d1 = d2 / ${channelMul};
                int q = d2 - d1 * ${channelMul};
                int xRCorner = xRCCorner.x;
                int xCCorner = xRCCorner.y;

                vec4 dotProd = vec4(0.000000000000001);

                ${mainLoop}

                vec4 result = dotProd - vec4(0.000000000000001);
                ${addBiasSnippet}
                ${applyActivationSnippet}
                setOutput(result);
            }
        `;
    }
}
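// Kernel entry point: the packed program is used only when
// WEBGL_PACK_DEPTHWISECONV is set, strideWidth <= 2 and channelMul === 1;
// otherwise the unpacked program runs. pads/strides/dilations/inDims are
// passed as custom uniforms. A config like the one below is normally
// registered once at startup (sketch, assuming the tfjs-core helper):
//   registerKernel(depthwiseConv2dNativeConfig$1);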
function depthwiseConv2dNative$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter } = inputs;
    const { strides, pad, dilations, dimRoundingMode } = attrs;
    let $dilations = dilations;
    if ($dilations == null) {
        $dilations = [1, 1];
    }
    assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
        `1. Got strides ${strides} and dilations '${$dilations}'`);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true);
    let program;
    if (env().getBool('WEBGL_PACK_DEPTHWISECONV') && convInfo.strideWidth <= 2 &&
        convInfo.outChannels / convInfo.inChannels === 1) {
        program = new DepthwiseConvPacked2DProgram(convInfo);
    }
    else {
        program = new DepthwiseConv2DProgram(convInfo);
    }
    const customValues = [
        [convInfo.padInfo.top, convInfo.padInfo.left],
        [convInfo.strideHeight, convInfo.strideWidth],
        [convInfo.dilationHeight, convInfo.dilationWidth],
        [convInfo.inHeight, convInfo.inWidth]
    ];
    return backend.runWebGLProgram(program, [x, filter], 'float32', customValues);
}
const depthwiseConv2dNativeConfig$1 = {
    kernelName: DepthwiseConv2dNative,
    backendName: 'webgl',
    kernelFunc: depthwiseConv2dNative$1,
};
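// Gradient kernels for depthwise conv: DerFilter accumulates x * dy over the
// batch and all output positions for each filter tap, while DerInput runs the
// transposed convolution over dy with the filter flipped (wRPerm/wCPerm) and
// skips fractional source positions via the fract() checks.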
class DepthwiseConv2DDerFilterProgram {
    constructor(convInfo) {
        this.variableNames = ['x', 'dy'];
        this.outputShape = convInfo.filterShape;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const padTop = convInfo.padInfo.top;
        const padLeft = convInfo.padInfo.left;
        const channelMul = convInfo.outChannels / convInfo.inChannels;
        this.userCode = `
            void main() {
                ivec4 coords = getOutputCoords();
                int wR = coords.x;
                int wC = coords.y;
                int d1 = coords.z;
                int dm = coords.w;
                int d2 = d1 * ${channelMul} + dm;

                float dotProd = 0.0;

                for (int b = 0; b < ${convInfo.batchSize}; b++) {
                    for (int yR = 0; yR < ${convInfo.outHeight}; yR++) {
                        int xR = wR + yR * ${strideHeight} - ${padTop};

                        if (xR < 0 || xR >= ${convInfo.inHeight}) {
                            continue;
                        }

                        for (int yC = 0; yC < ${convInfo.outWidth}; yC++) {
                            int xC = wC + yC * ${strideWidth} - ${padLeft};

                            if (xC < 0 || xC >= ${convInfo.inWidth}) {
                                continue;
                            }

                            float dyValue = getDy(b, yR, yC, d2);
                            float xValue = getX(b, xR, xC, d1);
                            dotProd += (xValue * dyValue);
                        }
                    }
                }
                setOutput(dotProd);
            }
        `;
    }
}
class DepthwiseConv2DDerInputProgram {
    constructor(convInfo) {
        this.variableNames = ['dy', 'W'];
        this.outputShape = convInfo.inShape;
        const filterHeight = convInfo.filterHeight;
        const filterWidth = convInfo.filterWidth;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const padTop = filterHeight - 1 - convInfo.padInfo.top;
        const padLeft = filterWidth - 1 - convInfo.padInfo.left;
        const channelMul = convInfo.outChannels / convInfo.inChannels;
        this.userCode = `
            const ivec2 pads = ivec2(${padTop}, ${padLeft});

            void main() {
                ivec4 coords = getOutputCoords();
                int batch = coords[0];
                int d1 = coords[3];
                ivec2 dyCorner = coords.yz - pads;
                int dyRCorner = dyCorner.x;
                int dyCCorner = dyCorner.y;

                float dotProd = 0.0;

                for (int wR = 0; wR < ${filterHeight}; wR++) {
                    float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

                    if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
                        continue;
                    }
                    int idyR = int(dyR);

                    int wRPerm = ${filterHeight} - 1 - wR;

                    for (int wC = 0; wC < ${filterWidth}; wC++) {
                        float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

                        if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
                            fract(dyC) > 0.0) {
                            continue;
                        }
                        int idyC = int(dyC);

                        int wCPerm = ${filterWidth} - 1 - wC;

                        for (int dm = 0; dm < ${channelMul}; dm++) {
                            int d2 = d1 * ${channelMul} + dm;
                            float xValue = getDy(batch, idyR, idyC, d2);
                            float wValue = getW(wRPerm, wCPerm, d1, dm);
                            dotProd += xValue * wValue;
                        }
                    }
                }
                setOutput(dotProd);
            }
        `;
    }
}
function depthwiseConv2dNativeBackpropFilter$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, dy } = inputs;
    const { strides, dilations, pad, dimRoundingMode, filterShape } = attrs;
    const convInfo = computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true);
    const program = new DepthwiseConv2DDerFilterProgram(convInfo);
    return backend.runWebGLProgram(program, [x, dy], 'float32');
}
const depthwiseConv2dNativeBackpropFilterConfig$1 = {
    kernelName: DepthwiseConv2dNativeBackpropFilter,
    backendName: 'webgl',
    kernelFunc: depthwiseConv2dNativeBackpropFilter$1
};

function depthwiseConv2dNativeBackpropInput$1(args) {
    const { inputs, backend, attrs } = args;
    const { dy, filter } = inputs;
    const { strides, dilations, pad, dimRoundingMode, inputShape } = attrs;
    const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, dilations, pad, dimRoundingMode, true);
    const program = new DepthwiseConv2DDerInputProgram(convInfo);
    return backend.runWebGLProgram(program, [dy, filter], 'float32');
}
const depthwiseConv2dNativeBackpropInputConfig$1 = {
    kernelName: DepthwiseConv2dNativeBackpropInput,
    backendName: 'webgl',
    kernelFunc: depthwiseConv2dNativeBackpropInput$1
};
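// DiagProgram materializes a [size, size] matrix whose diagonal holds the
// flattened input; diag$1 flattens x, runs the program, then reshapes to
// [...x.shape, ...x.shape].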
class DiagProgram {
    constructor(size) {
        this.variableNames = ['X'];
        this.outputShape = [size, size];
        this.userCode = `
            void main() {
                ivec2 coords = getOutputCoords();
                float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;
                setOutput(val);
            }
        `;
    }
}

function diag$1(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    const outShape = [...x.shape, ...x.shape];
    const xSize = sizeFromShape(x.shape);
    const flat = reshape$1({ inputs: { x }, backend, attrs: { shape: [xSize] } });
    const program = new DiagProgram(xSize);
    const res = backend.runWebGLProgram(program, [flat], flat.dtype);
    const out = reshape$1({ inputs: { x: res }, backend, attrs: { shape: outShape } });
    backend.disposeIntermediateTensorInfo(flat);
    backend.disposeIntermediateTensorInfo(res);
    return out;
}
const diagConfig$1 = {
    kernelName: Diag,
    backendName: 'webgl',
    kernelFunc: diag$1
};
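// Dilation2DProgram implements grayscale morphological dilation: each output
// element is max(x + w) over the filter window, seeded with -3.4e38
// (~ -FLT_MAX) so fully-padded windows still yield a defined value.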
class Dilation2DProgram {
    constructor(convInfo) {
        this.variableNames = ['x', 'W'];
        this.outputShape = convInfo.outShape;
        const { inHeight, inWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth } = convInfo;
        const { top: padTop, left: padLeft } = padInfo;
        this.userCode = `
            const ivec2 strides = ivec2(${strideHeight}, ${strideWidth});
            const ivec2 pads = ivec2(${padTop}, ${padLeft});
            const float neg_infinity = -3.4e38;

            void main() {
                ivec4 coords = getOutputCoords();
                int batch = coords.x;
                int d1 = coords.w;
                ivec2 outTopLeftCorner =
                    coords.yz * strides - pads;
                int hBeg = outTopLeftCorner.x;
                int wBeg = outTopLeftCorner.y;

                float curVal = neg_infinity;
                for (int h = 0; h < ${filterHeight}; h++) {
                    int hIn = hBeg + h * ${dilationHeight};

                    if (hIn >= 0 && hIn < ${inHeight}) {
                        for (int w = 0; w < ${filterWidth}; w++) {
                            int wIn = wBeg + w * ${dilationWidth};

                            if (wIn >= 0 && wIn < ${inWidth}) {
                                float xVal = getX(batch, hIn, wIn, d1);
                                float wVal = getW(h, w, d1);

                                float val = xVal + wVal;
                                if (val > curVal) {
                                    curVal = val;
                                }
                            }
                        }
                    }
                }

                float result = curVal;
                setOutput(result);
            }
        `;
    }
}

function dilation2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter } = inputs;
    const { strides, pad, dilations } = attrs;
    const convInfo = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC', dilations);
    let out;
    const program = new Dilation2DProgram(convInfo);
    out = backend.runWebGLProgram(program, [x, filter], 'float32');
    const outReshaped = reshape$1({ inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } });
    backend.disposeIntermediateTensorInfo(out);
    return outReshaped;
}
const dilation2DConfig$1 = {
    kernelName: Dilation2D,
    backendName: 'webgl',
    kernelFunc: dilation2D,
};
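// einsum is evaluated as a sequence of broadcast multiplies and axis sums:
// each operand is transposed/reshaped so shared dimensions line up, multiplied
// into the running product, and summed-out dimensions are reduced according to
// the compute path returned by getEinsumComputePath.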
function einsum$1(args) {
    const { inputs, backend, attrs } = args;
    const { equation } = attrs;
    const tensors = inputs;
    const { allDims, summedDims, idDims } = decodeEinsumEquation(equation, tensors.length);
    checkEinsumDimSizes(allDims.length, idDims, tensors);
    const { path, steps } = getEinsumComputePath(summedDims, idDims);
    const nSteps = steps.length;
    let out = null;
    let numDimsRemaining = allDims.length;
    const tensorsToDispose = [];
    for (let i = 0; i < nSteps; ++i) {
        for (const idTerm of steps[i]) {
            const { permutationIndices: perm, expandDims: dimsToExpand } = getEinsumPermutation(numDimsRemaining, idDims[idTerm]);
            let x;
            if (isIdentityPermutation(perm)) {
                x = tensors[idTerm];
            }
            else {
                x = transpose({ inputs: { x: tensors[idTerm] }, backend, attrs: { perm } });
                tensorsToDispose.push(x);
            }
            const targetShape = x.shape.slice();
            for (let k = 0; k < dimsToExpand.length; ++k) {
                targetShape.splice(dimsToExpand[k], 0, 1);
            }
            if (!arraysEqual(x.shape, targetShape)) {
                x = reshape$1({ inputs: { x }, backend, attrs: { shape: targetShape } });
                tensorsToDispose.push(x);
            }
            if (out === null) {
                out = x;
            }
            else {
                out = multiply({ inputs: { a: x, b: out }, backend });
                tensorsToDispose.push(out);
            }
        }
        if (i < nSteps - 1) {
            if (path[i] >= 0) {
                out = sum$1({
                    inputs: { x: out },
                    backend,
                    attrs: {
                        axis: path[i] - (allDims.length - numDimsRemaining),
                        keepDims: false
                    }
                });
                tensorsToDispose.push(out);
            }
            numDimsRemaining--;
        }
    }

    for (const tensorInfo of tensorsToDispose) {
        if (tensorInfo === out) {
            continue;
        }
        backend.disposeIntermediateTensorInfo(tensorInfo);
    }
    return out;
}
const einsumConfig$1 = {
    kernelName: Einsum,
    backendName: 'webgl',
    kernelFunc: einsum$1
};
const ELU = `return (x >= 0.0) ? x : (exp(x) - 1.0);`;
const ELU_PACKED = `
  vec4 result;

  result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);
  result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);
  result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);
  result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);

  return result;
`;
const elu$2 = unaryKernelFunc({ opSnippet: ELU, packedOpSnippet: ELU_PACKED });
const eluConfig$1 = {
    kernelName: Elu$1,
    backendName: 'webgl',
    kernelFunc: elu$2
};

const ELU_DER = `return (b >= 0.0) ? a : a * (b + 1.0);`;
const ELU_DER_PACKED = `
  vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));
  return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));
`;
const eluGrad$1 = (args) => {
    const { inputs, backend } = args;
    const { dy, y } = inputs;
    const program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
        new BinaryOpPackedProgram(ELU_DER_PACKED, dy.shape, y.shape) :
        new BinaryOpProgram(ELU_DER, dy.shape, y.shape);
    return backend.runWebGLProgram(program, [dy, y], dy.dtype);
};
const eluGradConfig$2 = {
    kernelName: EluGrad,
    backendName: 'webgl',
    kernelFunc: eluGrad$1
};

const PACKED_EQUAL = `
  return vec4(equal(a, b));
`;
const EQUAL = `return float(a == b);`;
const equal = binaryKernelFunc({
    opSnippet: EQUAL,
    packedOpSnippet: PACKED_EQUAL,
    dtype: 'bool',
    cpuKernelImpl: equalImplCPU,
});
const equalConfig = {
    kernelName: Equal,
    backendName: 'webgl',
    kernelFunc: equal
};
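// ERF evaluates the error function with a rational polynomial approximation
// (ERF_P and ERF_A1..ERF_A5 are defined earlier in this bundle); the form
// matches the classic Abramowitz & Stegun 7.1.26 approximation with
// t = 1 / (1 + p*x) and odd symmetry restored via sign(x).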
const ERF = `
  float p = ${ERF_P};
  float a1 = ${ERF_A1};
  float a2 = ${ERF_A2};
  float a3 = ${ERF_A3};
  float a4 = ${ERF_A4};
  float a5 = ${ERF_A5};

  float sign = sign(x);
  x = abs(x);
  float t = 1.0 / (1.0 + p * x);
  return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x));
`;
const erf$1 = unaryKernelFunc({ opSnippet: ERF });
const erfConfig$1 = {
    kernelName: Erf,
    backendName: 'webgl',
    kernelFunc: erf$1,
};

const EXP = CHECK_NAN_SNIPPET_UNARY + `
  return exp(x);
`;
const EXP_PACKED = `
  vec4 result = exp(x);
  bvec4 isNaN = isnan(x);
  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
const exp = unaryKernelFunc({
    opSnippet: EXP,
    packedOpSnippet: EXP_PACKED,
    cpuKernelImpl: expImplCPU,
    dtype: 'float32',
});
const expConfig = {
    kernelName: Exp,
    backendName: 'webgl',
    kernelFunc: exp
};

function expandDims$2(args) {
    const { inputs, attrs, backend } = args;
    const { dim } = attrs;
    const { input } = inputs;
    const inputRank = input.shape.length;
    const newShape = input.shape.slice();
    let $dim = dim;
    if (dim < 0) {
        assert$1(-(inputRank + 1) <= dim, () => `Axis must be in the interval [${-(inputRank + 1)}, ${inputRank}]`);
        $dim = inputRank + dim + 1;
    }
    newShape.splice($dim, 0, 1);
    return reshape$1({ inputs: { x: input }, backend, attrs: { shape: newShape } });
}
const expandDimsConfig$1 = {
    kernelName: ExpandDims,
    backendName: 'webgl',
    kernelFunc: expandDims$2,
};

const EXPM1 = `return exp(x) - 1.0;`;
const expm1 = unaryKernelFunc({ opSnippet: EXPM1, packedOpSnippet: EXPM1, cpuKernelImpl: expm1ImplCPU });
const expm1Config = {
    kernelName: Expm1,
    backendName: 'webgl',
    kernelFunc: expm1
};
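// FFTProgram computes a direct O(n^2) DFT: every output element evaluates the
// full inner-dimension sum with cos/sin twiddle factors. The 'real' and 'imag'
// variants share the loop and differ only in the combine expression; the
// inverse transform flips the exponent sign and divides by the inner size.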
class FFTProgram {
    constructor(component, inputShape, inverse) {
        this.variableNames = ['real', 'imag'];
        const innerDim = inputShape[1];
        this.outputShape = inputShape;
        const exponentMultiplierSnippet = inverse ? `2.0 * ${Math.PI}` : `-2.0 * ${Math.PI}`;
        const resultDenominator = inverse ? `${innerDim}.0` : '1.0';
        let opString;
        if (component === 'real') {
            opString = 'return real * expR - imag * expI;';
        }
        else if (component === 'imag') {
            opString = 'return real * expI + imag * expR;';
        }
        else {
            throw new Error(`FFT component must be either "real" or "imag", got ${component}.`);
        }
        this.userCode = `
            const float exponentMultiplier = ${exponentMultiplierSnippet};

            float unaryOpComplex(float real, float expR, float imag, float expI) {
                ${opString}
            }

            float mulMatDFT(int batch, int index) {
                float indexRatio = float(index) / float(${innerDim});
                float exponentMultiplierTimesIndexRatio =
                    exponentMultiplier * indexRatio;

                float result = 0.0;

                for (int i = 0; i < ${innerDim}; i++) {
                    float x = exponentMultiplierTimesIndexRatio * float(i);
                    float expR = cos(x);
                    float expI = sin(x);
                    float real = getReal(batch, i);
                    float imag = getImag(batch, i);

                    result +=
                        unaryOpComplex(real, expR, imag, expI) / ${resultDenominator};
                }

                return result;
            }

            void main() {
                ivec2 coords = getOutputCoords();
                setOutput(mulMatDFT(coords[0], coords[1]));
            }
        `;
    }
}

function fftImpl$1(x, inverse, backend) {
    const xData = backend.texData.get(x.dataId);
    const inputSize = sizeFromShape(x.shape);

    const innerDimensionSize = x.shape[x.shape.length - 1];
    const batch = inputSize / innerDimensionSize;
    const input2D = reshape$1({ inputs: { x }, backend, attrs: { shape: [batch, innerDimensionSize] } });
    const xShape = input2D.shape;
    const realProgram = new FFTProgram('real', xShape, inverse);
    const imagProgram = new FFTProgram('imag', xShape, inverse);
    const inputs = [
        {
            dataId: xData.complexTensorInfos.real.dataId,
            dtype: xData.complexTensorInfos.real.dtype,
            shape: xShape
        },
        {
            dataId: xData.complexTensorInfos.imag.dataId,
            dtype: xData.complexTensorInfos.imag.dtype,
            shape: xShape
        }
    ];
    const realPart = backend.runWebGLProgram(realProgram, inputs, 'float32');
    const imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32');
    const complexOutput = complex({ inputs: { real: realPart, imag: imagPart }, backend });
    backend.disposeIntermediateTensorInfo(realPart);
    backend.disposeIntermediateTensorInfo(imagPart);
    const complexOutputReshaped = reshape$1({ inputs: { x: complexOutput }, backend, attrs: { shape: x.shape } });
    backend.disposeIntermediateTensorInfo(input2D);
    backend.disposeIntermediateTensorInfo(complexOutput);
    return complexOutputReshaped;
}

function fft$1(args) {
    const { inputs, backend } = args;
    const { input } = inputs;
    return fftImpl$1(input, false, backend);
}
const fftConfig$1 = {
    kernelName: FFT,
    backendName: 'webgl',
    kernelFunc: fft$1
};
class FillProgram {
    constructor(shape, value) {
        this.outputShape = [];
        this.customUniforms = [{ name: 'value', type: 'float' }];
        this.variableNames = ['x'];
        this.outputShape = shape;
        this.userCode = `
            void main() {
                setOutput(value);
            }
        `;
    }
}

function fill$1(args) {
    const { backend, attrs } = args;
    const { shape, value } = attrs;
    let { dtype } = attrs;
    dtype = dtype || inferDtype(value);
    if (dtype === 'string') {
        const values = getArrayFromDType(dtype, sizeFromShape(shape));
        values.fill(value);
        return backend.makeTensorInfo(shape, dtype, values);
    }
    else {
        const program = new FillProgram(shape, value);
        const customValues = [[value]];
        return backend.runWebGLProgram(program, [], dtype, customValues);
    }
}
const fillConfig$1 = {
    kernelName: Fill,
    backendName: 'webgl',
    kernelFunc: fill$1
};

class FlipLeftRightProgram {
    constructor(imageShape) {
        this.variableNames = ['Image'];
        this.outputShape = [];
        const imageWidth = imageShape[2];
        this.outputShape = imageShape;
        this.userCode = `
            void main() {
                ivec4 coords = getOutputCoords();
                int x = coords[2];

                int coordX = ${imageWidth} - x - 1;
                float outputValue;
                if(coordX >= 0 && coordX < ${imageWidth}) {
                    outputValue = getImage(coords[0], coords[1], coordX, coords[3]);
                } else {
                    outputValue = getImage(coords[0], coords[1], coords[2], coords[3]);
                }
                setOutput(outputValue);
            }
        `;
    }
}

const flipLeftRightConfig$1 = {
    kernelName: FlipLeftRight,
    backendName: 'webgl',
    kernelFunc: ({ inputs, backend }) => {
        const { image } = inputs;
        const webglBackend = backend;
        const program = new FlipLeftRightProgram(image.shape);
        const output = webglBackend.runWebGLProgram(program, [image], image.dtype);
        return output;
    }
};
const FLOOR = `return floor(x);`;
const floor = unaryKernelFunc({ opSnippet: FLOOR, packedOpSnippet: FLOOR, cpuKernelImpl: floorImplCPU });
const floorConfig = {
    kernelName: Floor,
    backendName: 'webgl',
    kernelFunc: floor,
};

const INT_DIV = `
  float s = sign(a) * sign(b);
  int ia = round(a);
  int ib = round(b);
  if (ib != 0) {
    return float(idiv(ia, ib, s));
  } else {
    return NAN;
  }
`;
const INT_DIV_PACKED = `
  ivec4 ia = round(a);
  ivec4 ib = round(b);
  bvec4 cond = notEqual(ib, ivec4(0));
  ivec4 result = ivec4(0);
  vec4 s = sign(a) * sign(b);

  if (cond[0]) {
    result[0] = idiv(ia[0], ib[0], s[0]);
  }
  if (cond[1]) {
    result[1] = idiv(ia[1], ib[1], s[1]);
  }
  if (cond[2]) {
    result[2] = idiv(ia[2], ib[2], s[2]);
  }
  if (cond[3]) {
    result[3] = idiv(ia[3], ib[3], s[3]);
  }
  return vec4(result);
`;
const floorDiv = binaryKernelFunc({ opSnippet: INT_DIV, packedOpSnippet: INT_DIV_PACKED, dtype: 'int32' });
const floorDivConfig = {
    kernelName: FloorDiv,
    backendName: 'webgl',
    kernelFunc: floorDiv
};
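// FromPixels samples the uploaded RGBA texture and rescales each channel from
// [0, 1] back to integer [0, 255] via floor(value * 255.0 + 0.5). The packed
// variant writes four adjacent (column, channel) values into one vec4 output
// texel per invocation.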
class FromPixelsProgram {
    constructor(outputShape) {
        this.variableNames = ['A'];
        const glsl = getGlslDifferences();
        const [height, width,] = outputShape;
        this.outputShape = outputShape;
        this.userCode = `
            void main() {
                ivec3 coords = getOutputCoords();
                int texR = coords[0];
                int texC = coords[1];
                int depth = coords[2];
                vec2 uv = (vec2(texC, texR) + halfCR) / vec2(${width}.0, ${height}.0);

                vec4 values = ${glsl.texture2D}(A, uv);
                float value;
                if (depth == 0) {
                    value = values.r;
                } else if (depth == 1) {
                    value = values.g;
                } else if (depth == 2) {
                    value = values.b;
                } else if (depth == 3) {
                    value = values.a;
                }

                setOutput(floor(value * 255.0 + 0.5));
            }
        `;
    }
}

class FromPixelsPackedProgram {
    constructor(outputShape) {
        this.variableNames = ['A'];
        this.packedInputs = false;
        this.packedOutput = true;
        const glsl = getGlslDifferences();
        const [height, width,] = outputShape;
        this.outputShape = outputShape;
        this.userCode = `
            void main() {
                ivec3 coords = getOutputCoords();
                int texR = coords[0];
                int texC = coords[1];
                int depth = coords[2];

                vec4 result = vec4(0.);

                for(int row=0; row<=1; row++) {
                    for(int col=0; col<=1; col++) {
                        texC = coords[1] + row;
                        depth = coords[2] + col;

                        vec2 uv = (vec2(texC, texR) + halfCR) /
                            vec2(${width}.0, ${height}.0);
                        vec4 values = ${glsl.texture2D}(A, uv);
                        float value;
                        if (depth == 0) {
                            value = values.r;
                        } else if (depth == 1) {
                            value = values.g;
                        } else if (depth == 2) {
                            value = values.b;
                        } else if (depth == 3) {
                            value = values.a;
                        }

                        result[row * 2 + col] = floor(value * 255.0 + 0.5);
                    }
                }

                ${glsl.output} = result;
            }
        `;
    }
}

const fromPixelsConfig = {
    kernelName: FromPixels,
    backendName: 'webgl',
    kernelFunc: fromPixels,
};
let fromPixels2DContext;
let willReadFrequently = env().getBool('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU');
function fromPixels(args) {
    const { inputs, backend, attrs } = args;
    let { pixels } = inputs;
    const { numChannels } = attrs;
    const isVideo = typeof (HTMLVideoElement) !== 'undefined' &&
        pixels instanceof HTMLVideoElement;
    const isImage = typeof (HTMLImageElement) !== 'undefined' &&
        pixels instanceof HTMLImageElement;
    const [width, height] = isVideo ?
        [
            pixels.videoWidth,
            pixels.videoHeight
        ] :
        [pixels.width, pixels.height];
    const texShape = [height, width];
    const outShape = [height, width, numChannels];
    if (isImage || isVideo) {
        const newWillReadFrequently = env().getBool('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU');
        if (fromPixels2DContext == null ||
            newWillReadFrequently !== willReadFrequently) {
            willReadFrequently = newWillReadFrequently;
            fromPixels2DContext =
                document.createElement('canvas').getContext('2d', { willReadFrequently });
        }
        fromPixels2DContext.canvas.width = width;
        fromPixels2DContext.canvas.height = height;
        fromPixels2DContext.drawImage(pixels, 0, 0, width, height);
        pixels = fromPixels2DContext.canvas;
    }
    const tempPixelHandle = backend.makeTensorInfo(texShape, 'int32');

    backend.texData.get(tempPixelHandle.dataId).usage = TextureUsage.PIXELS;
    backend.gpgpu.uploadPixelDataToTexture(backend.getTexture(tempPixelHandle.dataId), pixels);
    const program = env().getBool('WEBGL_PACK') ?
        new FromPixelsPackedProgram(outShape) :
        new FromPixelsProgram(outShape);
    const res = backend.runWebGLProgram(program, [tempPixelHandle], 'int32');
    backend.disposeData(tempPixelHandle.dataId);
    return res;
}
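// fusedConv2d dispatches to one of four paths: a plain matmul for 1x1/stride-1
// convolutions, the packed Conv2D program for small strides in channels-last
// layout (behind WEBGL_EXP_CONV), im2col + matmul (behind WEBGL_CONV_IM2COL),
// or the general Conv2D program. prepareInputs appends bias / PReLU weights /
// leaky-relu alpha, reshaping 1-D side inputs in NCHW so they broadcast.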
function fusedConv2d(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
    const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false, $dataFormat);
    let out;
    const intermediates = [];
    const hasBias = bias != null;
    const hasPreluActivationWeights = preluActivationWeights != null;
    const hasLeakyreluAlpha = activation === 'leakyrelu';
    const prepareInputs = () => {
        const inputs = [x, filter];

        const alignInputWithDataFormat = (input, dataFormat) => {
            if (dataFormat === 'NCHW' && input.shape.length === 1 &&
                input.shape[0] !== 1) {
                const alignedInput = reshape$1({
                    inputs: { x: input },
                    backend,
                    attrs: { shape: [input.shape[0], 1, 1] }
                });
                intermediates.push(alignedInput);
                return alignedInput;
            }
            return input;
        };
        if (hasBias) {
            inputs.push(alignInputWithDataFormat(bias, dataFormat));
        }
        if (hasPreluActivationWeights) {
            inputs.push(alignInputWithDataFormat(preluActivationWeights, dataFormat));
        }
        if (hasLeakyreluAlpha) {
            const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
            inputs.push($leakyreluAlpha);
            intermediates.push($leakyreluAlpha);
        }
        return inputs;
    };
    if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
        convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
        convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
        (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID')) {
        out = conv2dByMatMul({
            x,
            filter,
            convInfo,
            backend,
            bias,
            activation,
            preluActivationWeights,
            leakyreluAlpha
        });
    }
    else if (convInfo.strideWidth <= 2 && $dataFormat === 'channelsLast'
        && env().getBool('WEBGL_EXP_CONV')) {
        const fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null;
        const program = new Conv2DPackedProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
        const customValues = [
            [convInfo.padInfo.top, convInfo.padInfo.left],
            [convInfo.strideHeight, convInfo.strideWidth],
            [convInfo.dilationHeight, convInfo.dilationWidth],
            [convInfo.inHeight, convInfo.inWidth]
        ];
        const inputs = prepareInputs();
        out = backend.runWebGLProgram(program, inputs, 'float32', customValues);
    }
    else if (env().getBool('WEBGL_CONV_IM2COL')) {
        out = conv2dWithIm2Row({
            x,
            filter,
            convInfo,
            backend,
            bias,
            activation,
            preluActivationWeights,
            leakyreluAlpha
        });
    }
    else {
        const fusedActivation = activation ? mapActivationToShaderProgram(activation, false) : null;
        const program = new Conv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
        const inputs = prepareInputs();
        out = backend.runWebGLProgram(program, inputs, 'float32');
    }
    const outReshaped = reshape$1({ inputs: { x: out }, backend, attrs: { shape: convInfo.outShape } });
    intermediates.push(out);
    intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return outReshaped;
}
const fusedConv2DConfig$1 = {
    kernelName: FusedConv2D,
    backendName: 'webgl',
    kernelFunc: fusedConv2d,
};
function fusedDepthwiseConv2D$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
    const intermediates = [];
    let $dilations = dilations;
    if ($dilations == null) {
        $dilations = [1, 1];
    }
    assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
        `1. Got strides ${strides} and dilations '${$dilations}'`);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true);
    const shouldPackDepthwiseConv = env().getBool('WEBGL_PACK_DEPTHWISECONV') &&
        convInfo.strideWidth <= 2 &&
        convInfo.outChannels / convInfo.inChannels === 1;
    const fusedActivation = activation ?
        mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) :
        null;
    const programInputs = [x, filter];
    const hasBias = bias != null;
    const hasPreluActivationWeights = preluActivationWeights != null;
    const hasLeakyreluAlpha = activation === 'leakyrelu';
    if (hasBias) {
        programInputs.push(bias);
    }
    if (hasPreluActivationWeights) {
        programInputs.push(preluActivationWeights);
    }
    if (hasLeakyreluAlpha) {
        const $leakyreluAlpha = backend.makeTensorInfo([], 'float32', createScalarValue(leakyreluAlpha, 'float32'));
        programInputs.push($leakyreluAlpha);
        intermediates.push($leakyreluAlpha);
    }
    let program;
    if (shouldPackDepthwiseConv) {
        program = new DepthwiseConvPacked2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
    }
    else {
        program = new DepthwiseConv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
    }
    const customValues = [
        [convInfo.padInfo.top, convInfo.padInfo.left],
        [convInfo.strideHeight, convInfo.strideWidth],
        [convInfo.dilationHeight, convInfo.dilationWidth],
        [convInfo.inHeight, convInfo.inWidth]
    ];
    const result = backend.runWebGLProgram(program, programInputs, 'float32', customValues);
    intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return result;
}
const fusedDepthwiseConv2DConfig$1 = {
    kernelName: FusedDepthwiseConv2D,
    backendName: 'webgl',
    kernelFunc: fusedDepthwiseConv2D$1,
};
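// GatherNDProgram flattens each index tuple into a single row offset using the
// row-major strides of the params tensor; out-of-range indices short-circuit
// to 0.0 via the out_of_bounds flag. gatherNd reshapes params and indices to
// 2-D first, and falls back to the CPU implementation for string tensors or
// inputs pinned to the CPU.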
class GatherNDProgram {
    constructor(sliceDim, strides, shape, paramsShape) {
        this.sliceDim = sliceDim;
        this.strides = strides;
        this.paramsShape = paramsShape;
        this.variableNames = ['x', 'indices'];
        this.outputShape = shape;
        const dtype = getCoordsDataType(shape.length);
        let mainLoop = `
        int index;`;
        for (let j = 0; j < this.sliceDim; j++) {
            mainLoop += `
            index = round(getIndices(coords[0], ${j}));
            out_of_bounds = out_of_bounds || index < 0;
            out_of_bounds = out_of_bounds || index >= ${this.paramsShape[j]};
            flattenIndex += index * ${this.strides[j]};`;
        }
        this.userCode = `
            void main() {
                ${dtype} coords = getOutputCoords();
                int flattenIndex = 0;
                bool out_of_bounds = false;

                ${mainLoop}

                setOutput(out_of_bounds ? 0.0 : getX(flattenIndex, coords[1]));
            }
        `;
    }
}

function gatherNd$1(args) {
    const { inputs, backend } = args;
    const { params, indices } = inputs;
    const indicesShape = indices.shape;
    const sliceRank = indicesShape[indicesShape.length - 1];
    const paramsSize = sizeFromShape(params.shape);
    const [resultShape, numSlices, sliceSize, strides] = prepareAndValidate(params, indices);
    const flattenIndices = reshape$1({ inputs: { x: indices }, backend, attrs: { shape: [numSlices, sliceRank] } });
    const flattenX = reshape$1({
        inputs: { x: params },
        backend,
        attrs: { shape: [(sizeFromShape(params.shape) / sliceSize), sliceSize] }
    });
    if (backend.shouldExecuteOnCPU([params, indices]) ||
        params.dtype === 'string') {
        const indicesData = backend.readSync(indices.dataId);
        const paramsBuf = backend.bufferSync(params);
        const outValue = gatherNdImplCPU(indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize);
        return backend.makeTensorInfo(resultShape, params.dtype, outValue.values);
    }
    const program = new GatherNDProgram(sliceRank, strides, [numSlices, sliceSize], params.shape);
    const res = backend.runWebGLProgram(program, [flattenX, flattenIndices], flattenX.dtype);
    const reshaped = reshape$1({ inputs: { x: res }, backend, attrs: { shape: resultShape } });
    backend.disposeIntermediateTensorInfo(flattenIndices);
    backend.disposeIntermediateTensorInfo(flattenX);
    backend.disposeIntermediateTensorInfo(res);
    return reshaped;
}
const gatherNdConfig$1 = {
    kernelName: GatherNd,
    backendName: 'webgl',
    kernelFunc: gatherNd$1
};
class GatherProgram {
    constructor(aShape, outputShape) {
        this.variableNames = ['A', 'indices'];
        this.outputShape = outputShape;
        this.rank = outputShape.length;
        const dtype = getCoordsDataType(this.rank);
        const sourceCoords = getSourceCoords$1(aShape);
        this.userCode = `
            void main() {
                ${dtype} resRC = getOutputCoords();
                int index = int(getIndices(resRC.x, resRC.z));
                float inBounds = (index >= 0) && (index < ${aShape[2]}) ? 1.0 : 0.0;
                setOutput(inBounds * getA(${sourceCoords}));
            }
        `;
    }
}

function getSourceCoords$1(aShape, axis) {
    const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
    const sourceCoords = [];
    for (let i = 0; i < aShape.length; i++) {
        if (i === 2) {
            sourceCoords.push('index');
        }
        else {
            sourceCoords.push(`${currentCoords[i]}`);
        }
    }
    return sourceCoords.join();
}

function gatherV2$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, indices } = inputs;
    const { axis, batchDims } = attrs;
    const parsedAxis = parseAxisParam(axis, x.shape)[0];
    if (env().get('DEBUG')) {
        const indicesVals = backend.readSync(indices.dataId);
        const axisDim = x.shape[parsedAxis];
        for (let i = 0; i < indicesVals.length; ++i) {
            const index = indicesVals[i];
            assert$1(index <= axisDim - 1 && index >= 0, () => `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`);
        }
    }
    const shapeInfo = collectGatherOpShapeInfo(x, indices, parsedAxis, batchDims);
    const indicesSize = sizeFromShape(indices.shape);
    const toDispose = [];
    const flattenX = reshape$1({
        inputs: { x },
        backend,
        attrs: {
            shape: [
                shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize,
                shapeInfo.sliceSize
            ]
        }
    });
    const flattenIndex = reshape$1({
        inputs: { x: indices },
        backend,
        attrs: { shape: [shapeInfo.batchSize, indicesSize / shapeInfo.batchSize] }
    });
    toDispose.push(flattenX);
    toDispose.push(flattenIndex);
    const flattenOutputShape = [
        shapeInfo.batchSize, shapeInfo.outerSize, indicesSize / shapeInfo.batchSize,
        shapeInfo.sliceSize
    ];
    if (backend.shouldExecuteOnCPU([x, indices]) || x.dtype === 'string') {
        const indicesBuf = backend.bufferSync(flattenIndex);
        const xBuf = backend.bufferSync(flattenX);
        const outBuf = gatherV2ImplCPU(xBuf, indicesBuf, flattenOutputShape);
        toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
        return backend.makeTensorInfo(shapeInfo.outputShape, outBuf.dtype, outBuf.values);
    }
    const program = new GatherProgram(flattenX.shape, flattenOutputShape);
    const res = backend.runWebGLProgram(program, [flattenX, flattenIndex], flattenX.dtype);
    toDispose.push(res);
    const reshaped = reshape$1({ inputs: { x: res }, backend, attrs: { shape: shapeInfo.outputShape } });
    toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return reshaped;
}
const gatherV2Config$1 = {
    kernelName: GatherV2,
    backendName: 'webgl',
    kernelFunc: gatherV2$1
};
const GREATER = `return float(a > b);`;
const GREATER_PACKED = `
  return vec4(greaterThan(a, b));
`;
const greater = binaryKernelFunc({
    opSnippet: GREATER,
    packedOpSnippet: GREATER_PACKED,
    cpuKernelImpl: greaterImplCPU,
    dtype: 'bool'
});
const greaterConfig = {
    kernelName: Greater,
    backendName: 'webgl',
    kernelFunc: greater
};

const GREATER_EQUAL = `return float(a >= b);`;
const GREATER_EQUAL_PACKED = `
  return vec4(greaterThanEqual(a, b));
`;
const greaterEqual = binaryKernelFunc({
    opSnippet: GREATER_EQUAL,
    packedOpSnippet: GREATER_EQUAL_PACKED,
    dtype: 'bool',
    cpuKernelImpl: greaterEqualImplCPU
});
const greaterEqualConfig = {
    kernelName: GreaterEqual,
    backendName: 'webgl',
    kernelFunc: greaterEqual
};

function ifft$1(args) {
    const { inputs, backend } = args;
    const { input } = inputs;
    return fftImpl$1(input, true, backend);
}
const ifftConfig$1 = {
    kernelName: IFFT,
    backendName: 'webgl',
    kernelFunc: ifft$1
};

const IS_FINITE = `return float(!isnan(x) && !isinf(x));`;
const isFinite$2 = unaryKernelFunc({ opSnippet: IS_FINITE, dtype: 'bool' });
const isFiniteConfig$1 = {
    kernelName: IsFinite,
    backendName: 'webgl',
    kernelFunc: isFinite$2,
};

const IS_INF = `return float(isinf(x));`;
const isInf$1 = unaryKernelFunc({ opSnippet: IS_INF, dtype: 'bool' });
const isInfConfig$1 = {
    kernelName: IsInf,
    backendName: 'webgl',
    kernelFunc: isInf$1,
};

const IS_NAN = `return float(isnan(x));`;
const isNaN$2 = unaryKernelFunc({ opSnippet: IS_NAN, dtype: 'bool' });
const isNaNConfig$1 = {
    kernelName: IsNan,
    backendName: 'webgl',
    kernelFunc: isNaN$2,
};

const LESS = `return float(a < b);`;
const LESS_PACKED = `
  return vec4(lessThan(a, b));
`;
const less = binaryKernelFunc({
    opSnippet: LESS,
    packedOpSnippet: LESS_PACKED,
    cpuKernelImpl: lessImplCPU,
    dtype: 'bool'
});
const lessConfig = {
    kernelName: Less,
    backendName: 'webgl',
    kernelFunc: less
};

const LESS_EQUAL = `return float(a <= b);`;
const LESS_EQUAL_PACKED = `
  return vec4(lessThanEqual(a, b));
`;
const lessEqual = binaryKernelFunc({
    opSnippet: LESS_EQUAL,
    packedOpSnippet: LESS_EQUAL_PACKED,
    cpuKernelImpl: lessEqualImplCPU,
    dtype: 'bool'
});
const lessEqualConfig = {
    kernelName: LessEqual,
    backendName: 'webgl',
    kernelFunc: lessEqual
};

function linSpace$1(args) {
    const { backend, attrs } = args;
    const { start, stop, num } = attrs;

    const outVals = linSpaceImplCPU(start, stop, num);
    return backend.makeTensorInfo([outVals.length], 'float32', outVals);
}
const linSpaceConfig$1 = {
    kernelName: LinSpace,
    backendName: 'webgl',
    kernelFunc: linSpace$1
};
const LOG = CHECK_NAN_SNIPPET_UNARY + `
  return x < 0.0 ? 0./0. : log(x);
`;
const LOG_PACKED = `
  vec4 result = log(x);
  bvec4 isNaN = isnan(x);
  result.r = isNaN.r ? x.r : (x.r < 0.0 ? 0./0. : result.r);
  result.g = isNaN.g ? x.g : (x.g < 0.0 ? 0./0. : result.g);
  result.b = isNaN.b ? x.b : (x.b < 0.0 ? 0./0. : result.b);
  result.a = isNaN.a ? x.a : (x.a < 0.0 ? 0./0. : result.a);
  return result;
`;
const log = unaryKernelFunc({ opSnippet: LOG, packedOpSnippet: LOG_PACKED, cpuKernelImpl: logImplCPU });
const logConfig = {
    kernelName: Log,
    backendName: 'webgl',
    kernelFunc: log
};

const LOG1P = CHECK_NAN_SNIPPET_UNARY + `
  return log(1.0 + x);
`;
const log1p$1 = unaryKernelFunc({ opSnippet: LOG1P });
const log1pConfig$1 = {
    kernelName: Log1p,
    backendName: 'webgl',
    kernelFunc: log1p$1,
};

const LOGICAL_AND = `return float(a >= 1.0 && b >= 1.0);`;
const LOGICAL_AND_PACKED = `
  return vec4(
    vec4(greaterThanEqual(a, vec4(1.0))) *
    vec4(greaterThanEqual(b, vec4(1.0))));
`;
const logicalAnd$1 = binaryKernelFunc({
    opSnippet: LOGICAL_AND,
    packedOpSnippet: LOGICAL_AND_PACKED,
    dtype: 'bool'
});
const logicalAndConfig$1 = {
    kernelName: LogicalAnd,
    backendName: 'webgl',
    kernelFunc: logicalAnd$1
};

const LOGICAL_NOT = `return float(!(x >= 1.0));`;
const logicalNot$1 = unaryKernelFunc({ opSnippet: LOGICAL_NOT });
const logicalNotConfig$1 = {
    kernelName: LogicalNot,
    backendName: 'webgl',
    kernelFunc: logicalNot$1,
};

const LOGICAL_OR = `return float(a >= 1.0 || b >= 1.0);`;
const LOGICAL_OR_PACKED = `
  return min(
    vec4(greaterThanEqual(a, vec4(1.0))) +
    vec4(greaterThanEqual(b, vec4(1.0))),
    vec4(1.0));
`;
const logicalOr$1 = binaryKernelFunc({ opSnippet: LOGICAL_OR, packedOpSnippet: LOGICAL_OR_PACKED, dtype: 'bool' });
const logicalOrConfig$1 = {
    kernelName: LogicalOr,
    backendName: 'webgl',
    kernelFunc: logicalOr$1
};
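// Local response normalization across channels:
//   out = x / (bias + alpha * sum)^beta,
// where sum is the sum of squares over the depthRadius window. The shaders
// specialize the power term: inversesqrt for beta == 0.5, a plain reciprocal
// for beta == 1.0, and exp(log(basis) * -beta) otherwise.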
class LRNProgram {
    constructor(xShape, radius, bias, alpha, beta) {
        this.variableNames = ['x'];
        this.outputShape = [];
        const rad = radius;
        const maxD = xShape[3] - 1;
        this.outputShape = xShape;

        let powOperator;
        const basis = `float(${bias}) + float(${alpha}) * sum`;
        if (beta === 0.5) {
            powOperator = `inversesqrt(${basis})`;
        }
        else if (beta === 1.0) {
            powOperator = `1.0/(${basis})`;
        }
        else {
            powOperator = `exp(log(${basis}) * float(-${beta}));`;
        }
        this.userCode = `
            void main() {
                ivec4 coords = getOutputCoords();
                int b = coords[0];
                int r = coords[1];
                int c = coords[2];
                int d = coords[3];
                float x = getX(b, r, c, d);
                float sum = 0.0;
                for (int j = -${rad}; j <= ${rad}; j++) {
                    int idx = d + j;
                    if (idx >= 0 && idx <= ${maxD}) {
                        float z = getX(b, r, c, idx);
                        sum += z * z;
                    }
                }
                float val = x * ${powOperator};
                setOutput(val);
            }
        `;
    }
}

class LRNPackedProgram {
    constructor(xShape, radius, bias, alpha, beta) {
        this.variableNames = ['x'];
        this.outputShape = [];
        this.packedInputs = true;
        this.packedOutput = true;
        const rad = radius;
        const maxD = xShape[3] - 1;
        this.outputShape = xShape;

        let powOperator;
        const basis = `float(${bias}) + float(${alpha}) * sum`;
        if (beta === 0.5) {
            powOperator = `inversesqrt(${basis})`;
        }
        else if (beta === 1.0) {
            powOperator = `1.0/(${basis})`;
        }
        else {
            powOperator = `exp(log(${basis}) * float(-${beta}));`;
        }
        this.userCode = `
            void main() {
                ivec4 coords = getOutputCoords();
                int b = coords.x;
                int r = coords.y;
                int c = coords.z;
                int d = coords.w;

                bool hasNextCol = d < ${this.outputShape[3]};
                bool hasNextRow = c < ${this.outputShape[2]};

                vec4 sum = vec4(0.);
                vec4 xFragAtOutputCoords = getX(b, r, c, d);

                vec4 xAtOutputCoords = vec4(
                    getChannel(xFragAtOutputCoords, vec2(c, d)),
                    hasNextCol ?
                        getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,
                    hasNextRow ?
                        getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,
                    (hasNextRow && hasNextCol) ?
                        getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0
                );

                int firstChannel = d - ${rad};
                vec2 cache = vec2(0.);
                if(firstChannel >= 0){
                    vec4 firstChannelFrag = getX(b, r, c, firstChannel);
                    cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));
                    if(hasNextRow){
                        cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));
                    }
                }

                ivec2 depth = ivec2(d, d + 1);
                for (int j = - ${rad}; j <= ${rad}; j++) {
                    ivec2 idx = depth + j;
                    bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));
                    bvec2 belowUpperBound = lessThanEqual(idx, ivec2(${maxD}));

                    bool depthInRange = aboveLowerBound.x && belowUpperBound.x;
                    bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;

                    if(depthInRange || depthPlusOneInRange){
                        vec4 z = vec4(0.);
                        vec4 xFragAtCurrentDepth;
                        z.xz = cache.xy;
                        if(depthPlusOneInRange && hasNextCol){
                            xFragAtCurrentDepth = idx.y != d ?
                                getX(b, r, c, idx.y) : xFragAtOutputCoords;
                            z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));
                            if(hasNextRow){
                                z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));
                            }
                        }
                        cache.xy = z.yw;
                        sum += z * z;
                    }
                }
                vec4 result = xAtOutputCoords * ${powOperator};
                setOutput(result);
            }
        `;
    }
}

const lrn = (args) => {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { depthRadius, bias, alpha, beta } = attrs;
    const program = env().getBool('WEBGL_PACK_NORMALIZATION') ?
        new LRNPackedProgram(x.shape, depthRadius, bias, alpha, beta) :
        new LRNProgram(x.shape, depthRadius, bias, alpha, beta);
    return backend.runWebGLProgram(program, [x], x.dtype);
};

const LRNConfig$1 = {
    kernelName: LRN,
    backendName: 'webgl',
    kernelFunc: lrn
};

class LRNGradProgram {
    constructor(inputShape, depthRadius, bias, alpha, beta) {
        this.variableNames = ['inputImage', 'outputImage', 'dy'];
        this.outputShape = [];
        this.outputShape = inputShape;
        this.depth = inputShape[3];
        this.depthRadius = depthRadius;
        this.bias = bias;
        this.alpha = alpha;
        this.beta = beta;
        this.userCode = `
            void main() {
                ivec4 coords = getOutputCoords();
                int b = coords[0];
                int r = coords[1];
                int c = coords[2];

                float result = 0.0;
                for (int d = 0; d < ${this.depth}; ++d) {
                    int depthBegin = int(max(0.0, float(d - ${depthRadius})));
                    int depthEnd = int(min(float(${this.depth}),
                        float(d + ${depthRadius} + 1)));

                    const int MIN_DEPTH_BEGIN = 0;
                    const int MAX_DEPTH_END = ${this.depth};

                    float norm = 0.0;
                    for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) {
                        if (k < depthBegin){
                            continue;
                        }
                        else if (k >= depthBegin && k < depthEnd) {
                            norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k);
                        }
                        else {
                            break;
                        }
                    }

                    norm = float(${alpha}) * norm + float(${bias});

                    for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){
                        if (k < depthBegin){
                            continue;
                        }
                        else if (k >= depthBegin && k < depthEnd){
                            float dyi = -2.0 * float(${alpha})
                                * float(${beta})
                                * getInputImage(b, r, c, k) * getOutputImage(b, r, c, d)
                                / norm;
                            if (k == d) {
                                dyi += pow(norm, -1.0 * ${beta});
                            }
                            if (k == coords[3]) {
                                dyi *= getDy(b, r, c, d);
                                result += dyi;
                            }
                        }
                        else {
                            break;
                        }
                    }
                }
                setOutput(result);
            }
        `;
    }
}

const lrnGrad = (args) => {
    const { inputs, backend, attrs } = args;
    const { x, y, dy } = inputs;
    const { depthRadius, bias, alpha, beta } = attrs;
    const program = new LRNGradProgram(x.shape, depthRadius, bias, alpha, beta);
    return backend.runWebGLProgram(program, [x, y, dy], x.dtype);
};

const LRNGradConfig$1 = {
    kernelName: LRNGrad,
    backendName: 'webgl',
    kernelFunc: lrnGrad
};
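// Max reduction: reduced axes are first permuted innermost (on the CPU when
// the data already lives there, otherwise via a transpose program), the input
// is reshaped to [batch, reduceSize], and reduce(..., 'max') collapses the
// inner dimension before the final reshape (re-expanded when keepDims is set).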
function maxImpl(x, reduceShape, outShape, backend) {
|
|
const inSize = sizeFromShape(reduceShape);
|
|
const xSize = sizeFromShape(x.shape);
|
|
const batchSize = xSize / inSize;
|
|
const reshapedInput = reshape$1({ inputs: { x }, attrs: { shape: [batchSize, inSize] }, backend });
|
|
const reduced = reduce(reshapedInput, x.dtype, 'max', backend);
|
|
const reshapedOutput = reshape$1({ inputs: { x: reduced }, attrs: { shape: outShape }, backend });
|
|
backend.disposeIntermediateTensorInfo(reshapedInput);
|
|
backend.disposeIntermediateTensorInfo(reduced);
|
|
return reshapedOutput;
|
|
}
|
|
|
|
|
|
function max$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { reductionIndices, keepDims } = attrs;
    const xRank = x.shape.length;
    const origAxes = parseAxisParam(reductionIndices, x.shape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, xRank);
    const maxInputIsTransposed = permutedAxes != null;
    const shouldExecuteOnCPU = backend.shouldExecuteOnCPU([x]);
    let maxInput = x;
    if (maxInputIsTransposed) {
        if (shouldExecuteOnCPU) {
            const xTexData = backend.texData.get(maxInput.dataId);
            const values = xTexData.values;
            const newShape = new Array(xRank);
            for (let i = 0; i < newShape.length; i++) {
                newShape[i] = x.shape[permutedAxes[i]];
            }
            const maxInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape);
            maxInput = backend.makeTensorInfo(newShape, x.dtype);
            const maxInputData = backend.texData.get(maxInput.dataId);
            maxInputData.values = maxInputValues;
        }
        else {
            maxInput = transposeImpl(x, permutedAxes, backend);
        }
        axes = getInnerMostAxes(axes.length, xRank);
    }
    assertAxesAreInnerMostDims('max', axes, xRank);
    const [maxOutShape, reduceShape] = computeOutAndReduceShapes(maxInput.shape, axes);
    let outShape = maxOutShape;
    if (keepDims) {
        outShape = expandShapeToKeepDim(maxOutShape, origAxes);
    }
    let out;
    if (shouldExecuteOnCPU) {
        const xTexData = backend.texData.get(maxInput.dataId);
        const values = xTexData.values;
        const outValues = maxImplCPU(values, sizeFromShape(reduceShape), outShape, x.dtype);
        out = backend.makeTensorInfo(outShape, x.dtype);
        const outData = backend.texData.get(out.dataId);
        outData.values = outValues;
    }
    else {
        out = maxImpl(maxInput, reduceShape, outShape, backend);
    }
    if (maxInputIsTransposed) {
        backend.disposeIntermediateTensorInfo(maxInput);
    }
    return out;
}
const maxConfig$1 = {
    kernelName: Max,
    backendName: 'webgl',
    kernelFunc: max$1
};

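// The MAXIMUM/MINIMUM shader snippets handle NaN explicitly: GLSL's built-in
// max()/min() make no guarantee for NaN operands, so the packed variants
// test both inputs with isnan() and patch the affected lanes afterwards via
// CHECK_NAN_SNIPPET_PACKED.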
const MAXIMUM = CHECK_NAN_SNIPPET + `
  return max(a, b);
`;
const MAXIMUM_PACKED = `
  vec4 result = vec4(max(a, b));
  bvec4 isNaNA = isnan(a);
  bvec4 isNaNB = isnan(b);
  bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w);
` +
    CHECK_NAN_SNIPPET_PACKED + `
  return result;
`;
const maximum = binaryKernelFunc({
    opSnippet: MAXIMUM,
    packedOpSnippet: MAXIMUM_PACKED,
    cpuKernelImpl: maximumImplCPU
});
const maximumConfig = {
    kernelName: Maximum,
    backendName: 'webgl',
    kernelFunc: maximum
};

function maxPool$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    assertNotComplex$1(x, 'maxPool');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const dilations = 1;
    assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
    if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
        arraysEqual(convInfo.inShape, convInfo.outShape)) {
        return identity({ inputs: { x }, backend });
    }
    const maxPoolProgram = new Pool2DProgram(convInfo, 'max', false);
    return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype);
}
const maxPoolConfig$1 = {
    kernelName: MaxPool,
    backendName: 'webgl',
    kernelFunc: maxPool$1
};

function maxPool3d(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { filterSize, strides, pad, dataFormat, dimRoundingMode } = attrs;
    const dilations = [1, 1, 1];
    const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat);
    const maxPoolProgram = new Pool3DProgram(convInfo, 'max', false);
    return backend.runWebGLProgram(maxPoolProgram, [x], x.dtype);
}
const maxPool3DConfig$1 = {
    kernelName: MaxPool3D,
    backendName: 'webgl',
    kernelFunc: maxPool3d
};

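// MaxPool2DBackpropProgram routes each upstream gradient back to the input
// position that produced the pooled maximum: a forward pass records the
// argmax offset per output cell ('maxPos'), and the backward shader only
// accumulates dy where that recorded offset matches the current position.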
class MaxPool2DBackpropProgram {
    constructor(convInfo) {
        this.variableNames = ['dy', 'maxPos'];
        this.outputShape = convInfo.inShape;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const dilationHeight = convInfo.dilationHeight;
        const effectiveFilterHeight = convInfo.effectiveFilterHeight;
        const effectiveFilterWidth = convInfo.effectiveFilterWidth;
        const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
        const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
        const lastIndex = effectiveFilterHeight * effectiveFilterWidth - 1;
        this.userCode = `
      const ivec2 pads = ivec2(${padTop}, ${padLeft});

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];

        ivec2 dyRCCorner = coords.yz - pads;
        int dyRCorner = dyRCCorner.x;
        int dyCCorner = dyRCCorner.y;

        float dotProd = 0.0;
        for (int wR = 0; wR < ${effectiveFilterHeight};
            wR += ${dilationHeight}) {
          float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

          if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 || fract(dyR) > 0.0) {
            continue;
          }
          int idyR = int(dyR);

          for (int wC = 0; wC < ${effectiveFilterWidth}; wC++) {
            float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

            if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
                fract(dyC) > 0.0) {
              continue;
            }
            int idyC = int(dyC);

            float dyValue = getDy(b, idyR, idyC, d);
            int maxPosValue = ${lastIndex} - int(getMaxPos(b, idyR, idyC, d));

            int curPosValue = wR * ${effectiveFilterWidth} + wC;
            float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);

            dotProd += dyValue * mask;
          }
        }
        setOutput(dotProd);
      }
    `;
    }
}

class MaxPool3DBackpropProgram {
    constructor(convInfo) {
        this.variableNames = ['dy', 'maxPos'];
        this.outputShape = convInfo.inShape;
        const strideDepth = convInfo.strideDepth;
        const strideHeight = convInfo.strideHeight;
        const strideWidth = convInfo.strideWidth;
        const dilationDepth = convInfo.dilationDepth;
        const dilationHeight = convInfo.dilationHeight;
        const dilationWidth = convInfo.dilationWidth;
        const effectiveFilterDepth = convInfo.effectiveFilterDepth;
        const effectiveFilterHeight = convInfo.effectiveFilterHeight;
        const effectiveFilterWidth = convInfo.effectiveFilterWidth;
        const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
        const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
        const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
        const lastIndex = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1;
        this.userCode = `
      const ivec3 pads = ivec3(${padFront}, ${padTop}, ${padLeft});

      void main() {
        ivec5 coords = getOutputCoords();
        int batch = coords.x;
        int ch = coords.u;

        ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;
        int dyDCorner = dyCorner.x;
        int dyRCorner = dyCorner.y;
        int dyCCorner = dyCorner.z;

        float dotProd = 0.0;

        for (int wD = 0; wD < ${effectiveFilterDepth};
            wD += ${dilationDepth}) {
          float dyD = float(dyDCorner + wD) / ${strideDepth}.0;

          if (dyD < 0.0 || dyD >= ${convInfo.outDepth}.0 || fract(dyD) > 0.0) {
            continue;
          }
          int idyD = int(dyD);

          for (int wR = 0; wR < ${effectiveFilterHeight};
              wR += ${dilationHeight}) {
            float dyR = float(dyRCorner + wR) / ${strideHeight}.0;

            if (dyR < 0.0 || dyR >= ${convInfo.outHeight}.0 ||
                fract(dyR) > 0.0) {
              continue;
            }
            int idyR = int(dyR);

            for (int wC = 0; wC < ${effectiveFilterWidth};
                wC += ${dilationWidth}) {
              float dyC = float(dyCCorner + wC) / ${strideWidth}.0;

              if (dyC < 0.0 || dyC >= ${convInfo.outWidth}.0 ||
                  fract(dyC) > 0.0) {
                continue;
              }
              int idyC = int(dyC);

              float dyValue = getDy(batch, idyD, idyR, idyC, ch);
              int maxPosValue = ${lastIndex} -
                  int(getMaxPos(batch, idyD, idyR, idyC, ch));

              int curPosValue =
                  wD * ${effectiveFilterHeight} * ${effectiveFilterWidth} +
                  wR * ${effectiveFilterWidth} + wC;
              float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);

              dotProd += dyValue * mask;
            }
          }
        }
        setOutput(dotProd);
      }
    `;
    }
}

function maxPool3DGrad$1(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input } = inputs;
    const x = input;
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const dilations = [1, 1, 1];
    const convInfo = computePool3DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
    const maxPool3dPositionsProgram = new Pool3DProgram(convInfo, 'max', true /* computePositions */);
    const maxPool3dPositions = backend.runWebGLProgram(maxPool3dPositionsProgram, [x], x.dtype);
    const maxPoolBackpropProgram = new MaxPool3DBackpropProgram(convInfo);
    const result = backend.runWebGLProgram(maxPoolBackpropProgram, [dy, maxPool3dPositions], x.dtype);
    backend.disposeIntermediateTensorInfo(maxPool3dPositions);
    return result;
}
const maxPool3DGradConfig$2 = {
    kernelName: MaxPool3DGrad,
    backendName: 'webgl',
    kernelFunc: maxPool3DGrad$1
};

function maxPoolGrad$2(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input, output } = inputs;
    const x = input;
    assertNotComplex$1([input, output], 'maxPoolGrad');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    const getPositions = true;
    const maxPoolPositionsProgram = new Pool2DProgram(convInfo, 'max', getPositions);
    const maxPoolPositions = backend.runWebGLProgram(maxPoolPositionsProgram, [x], x.dtype);
    const maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo);
    const result = backend.runWebGLProgram(maxPoolBackPropProgram, [dy, maxPoolPositions], x.dtype);
    backend.disposeIntermediateTensorInfo(maxPoolPositions);
    return result;
}
const maxPoolGradConfig$2 = {
    kernelName: MaxPoolGrad,
    backendName: 'webgl',
    kernelFunc: maxPoolGrad$2
};

function maxPoolWithArgmaxImpl$1(x, includeBatchInIndex, convInfo, backend) {
    let program = new Pool2DProgram(convInfo, 'max', false);
    const poolOutput = backend.runWebGLProgram(program, [x], 'float32');
    program = new Pool2DProgram(convInfo, 'max', true, true, includeBatchInIndex);
    const indexOutput = backend.runWebGLProgram(program, [x], 'float32');
    return [poolOutput, indexOutput];
}

const maxPoolWithArgmaxConfig$1 = {
    kernelName: MaxPoolWithArgmax,
    backendName: 'webgl',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { x } = inputs;
        const { filterSize, strides, pad, includeBatchInIndex } = attrs;
        const webglBackend = backend;
        assert$1(x.shape.length === 4, () => `Error in maxPool: input must be rank 4 but got rank ${x.shape.length}.`);
        const dilations = [1, 1];
        assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
            `Got strides ${strides} and dilations '${dilations}'`);
        const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad);
        const [result, indexes] = maxPoolWithArgmaxImpl$1(x, includeBatchInIndex, convInfo, webglBackend);
        return [result, indexes];
    }
};

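// meanImpl mirrors maxImpl above but always reduces in float32 with a
// 'mean' reduction, presumably so that integer inputs do not truncate the
// average.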
function meanImpl(x, reduceShape, outShape, backend) {
    const inSize = sizeFromShape(reduceShape);
    const xSize = sizeFromShape(x.shape);
    const batchSize = xSize / inSize;
    const reshapedInput = reshape$1({ inputs: { x }, attrs: { shape: [batchSize, inSize] }, backend });
    const reduced = reduce(reshapedInput, 'float32', 'mean', backend);
    const reshapedOutput = reshape$1({ inputs: { x: reduced }, attrs: { shape: outShape }, backend });
    backend.disposeIntermediateTensorInfo(reshapedInput);
    backend.disposeIntermediateTensorInfo(reduced);
    return reshapedOutput;
}

const meanConfig$1 = {
    kernelName: Mean,
    backendName: 'webgl',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { x } = inputs;
        const { keepDims, axis } = attrs;
        const webglBackend = backend;
        const xRank = x.shape.length;
        const origAxes = parseAxisParam(axis, x.shape);
        let axes = origAxes;
        const permutedAxes = getAxesPermutation(axes, xRank);
        const meanInputIsTransposed = permutedAxes != null;
        const shouldExecuteOnCPU = webglBackend.shouldExecuteOnCPU([x]);
        const intermediates = [];
        let meanInput = x;
        if (meanInputIsTransposed) {
            if (shouldExecuteOnCPU) {
                const xTexData = webglBackend.texData.get(meanInput.dataId);
                const values = xTexData.values;
                const newShape = new Array(xRank);
                for (let i = 0; i < newShape.length; i++) {
                    newShape[i] = x.shape[permutedAxes[i]];
                }
                const meanInputValues = transposeImplCPU(values, x.shape, x.dtype, permutedAxes, newShape);
                meanInput = webglBackend.makeTensorInfo(newShape, x.dtype);
                const meanInputData = webglBackend.texData.get(meanInput.dataId);
                meanInputData.values = meanInputValues;
            }
            else {
                meanInput = transposeImpl(x, permutedAxes, webglBackend);
            }
            intermediates.push(meanInput);
            axes = getInnerMostAxes(axes.length, xRank);
        }
        assertAxesAreInnerMostDims('sum', axes, xRank);
        const [meanOutShape, reduceShape] = computeOutAndReduceShapes(meanInput.shape, axes);
        let outShape = meanOutShape;
        if (keepDims) {
            outShape = expandShapeToKeepDim(meanOutShape, origAxes);
        }
        const out = meanImpl(meanInput, reduceShape, outShape, webglBackend);
        for (const i of intermediates) {
            webglBackend.disposeIntermediateTensorInfo(i);
        }
        return out;
    }
};

function min$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    const xRank = x.shape.length;
    const origAxes = parseAxisParam(axis, x.shape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, xRank);
    let permutedX = x;
    if (permutedAxes != null) {
        permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        axes = getInnerMostAxes(axes.length, x.shape.length);
    }
    assertAxesAreInnerMostDims('min', axes, xRank);
    const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes);
    const inSize = sizeFromShape(reduceShape);
    const a2D = reshape$1({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
    const reduced = reduce(a2D, a2D.dtype, 'min', backend);
    let res;
    if (keepDims) {
        const newShape = expandShapeToKeepDim(outShape, origAxes);
        res = reshape$1({ inputs: { x: reduced }, backend, attrs: { shape: newShape } });
    }
    else {
        res = reshape$1({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
    }
    backend.disposeIntermediateTensorInfo(a2D);
    backend.disposeIntermediateTensorInfo(reduced);
    if (permutedAxes != null) {
        backend.disposeIntermediateTensorInfo(permutedX);
    }
    return res;
}
const minConfig$1 = {
    kernelName: Min,
    backendName: 'webgl',
    kernelFunc: min$1
};

const MINIMUM = CHECK_NAN_SNIPPET + `
  return min(a, b);
`;
const MINIMUM_PACKED = `
  vec4 result = vec4(min(a, b));
  bvec4 isNaNA = isnan(a);
  bvec4 isNaNB = isnan(b);
  bvec4 isNaN = bvec4(isNaNA.x || isNaNB.x, isNaNA.y || isNaNB.y, isNaNA.z || isNaNB.z, isNaNA.w || isNaNB.w);
` +
    CHECK_NAN_SNIPPET_PACKED + `
  return result;
`;
const minimum = binaryKernelFunc({
    opSnippet: MINIMUM,
    packedOpSnippet: MINIMUM_PACKED,
    cpuKernelImpl: minimumImplCPU
});
const minimumConfig = {
    kernelName: Minimum,
    backendName: 'webgl',
    kernelFunc: minimum
};

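// MirrorPadProgram maps each output coordinate back into the source range by
// reflecting indices that fall inside the padding. The offset distinguishes
// the two modes: 0 for 'reflect' (edge value not repeated) and 1 for
// 'symmetric' (edge value repeated).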
class MirrorPadProgram {
    constructor(xShape, paddings, mode) {
        this.variableNames = ['x'];
        this.outputShape = paddings.map((p, i) => p[0] + xShape[i] + p[1]);
        const rank = xShape.length;
        const dtype = getCoordsDataType(rank);
        const start = paddings.map(p => p[0]).join(',');
        const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
        const unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
        const offset = mode === 'reflect' ? 0 : 1;
        if (rank === 1) {
            this.userCode = `
        int start = ${start};
        int end = ${end};

        void main() {
          int outC = getOutputCoords();
          if (outC < start) {
            outC = start * 2 - outC - ${offset};
          } else if(outC >= end) {
            outC = (end - 1) * 2 - outC + ${offset};
          }
          setOutput(getX(outC - start));
        }
      `;
            return;
        }
        this.userCode = `
      ${dtype} start = ${dtype}(${start});
      ${dtype} end = ${dtype}(${end});

      void main() {
        ${dtype} outC = getOutputCoords();
        for (int i = 0; i < ${rank}; i++) {
          if (outC[i] < start[i]) {
            outC[i] = start[i] * 2 - outC[i] - ${offset};
          } else if(outC[i] >= end[i]) {
            outC[i] = (end[i] - 1) * 2 - outC[i] + ${offset};
          }
        }
        ${dtype} coords = outC - start;
        setOutput(getX(${unpackedCoords}));
      }
    `;
    }
}

class MirrorPadPackedProgram {
    constructor(xShape, paddings, mode) {
        this.variableNames = ['x'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = paddings.map((p, i) => p[0] + xShape[i] + p[1]);
        const rank = xShape.length;
        const dtype = getCoordsDataType(rank);
        const start = paddings.map(p => p[0]).join(',');
        const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
        const coords = getChannels('rc', rank);
        const source = getChannels('source', rank);
        const cLimit = `${coords[rank - 1]} < ${this.outputShape[rank - 1]}`;
        const innerDims = rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`;
        const offset = mode === 'reflect' ? 0 : 1;
        let mainLoop = '';
        if (rank === 1) {
            const padSetup = `
        ${dtype} source = rc;
        if (source < start) {
          source = start * 2 - source - ${offset};
        } else if (source >= end) {
          source = (end - 1) * 2 - source + ${offset};
        }
        source -= start;
      `;
            mainLoop = `
        ${dtype} rc = outputLoc;
        ${padSetup}
        result[0] = getChannel(getX(${source.join()}), ${innerDims});
        ${coords[rank - 1]} += 1;
        if(${cLimit}) {
          ${padSetup}
          result[1] = getChannel(getX(${source.join()}), ${innerDims});
        }
      `;
        }
        else {
            const padSetup = `
        ${dtype} source = rc;
        ${dtype} lt = ${dtype}(lessThan(source, start));
        ${dtype} gte = ${dtype}(greaterThanEqual(source, end));
        ${dtype} orig = 1 - (lt + gte);
        source = orig * source +
                lt * (start * 2 - source - ${offset}) +
                gte * ((end - 1) * 2 - source + ${offset});
        source -= start;
      `;
            mainLoop = `
        ${dtype} rc = outputLoc;
        ${padSetup}
        result[0] = getChannel(getX(${source.join()}), ${innerDims});
        ${coords[rank - 1]} += 1;
        if(${cLimit}) {
          ${padSetup}
          result[1] = getChannel(getX(${source.join()}), ${innerDims});
        }
        rc = outputLoc;
        ${coords[rank - 2]} += 1;
        if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) {
          ${padSetup}
          result[2] = getChannel(getX(${source.join()}), ${innerDims});
          ${coords[rank - 1]} += 1;
          if(${cLimit}) {
            ${padSetup}
            result[3] = getChannel(getX(${source.join()}), ${innerDims});
          }
        }
      `;
        }
        this.userCode = `
      const ${dtype} start = ${dtype}(${start});
      const ${dtype} end = ${dtype}(${end});

      void main() {
        ${dtype} outputLoc = getOutputCoords();
        vec4 result = vec4(0.);
        ${mainLoop}
        setOutput(result);
      }
    `;
    }
}

const mirrorPadKernelFunc = ({ inputs, backend, attrs }) => {
    const { x } = inputs;
    const { paddings, mode } = attrs;
    const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
        new MirrorPadPackedProgram(x.shape, paddings, mode) :
        new MirrorPadProgram(x.shape, paddings, mode);
    const output = backend.runWebGLProgram(program, [x], x.dtype);
    return output;
};
const mirrorPadConfig$1 = {
    kernelName: MirrorPad,
    backendName: 'webgl',
    kernelFunc: mirrorPadKernelFunc,
};

const MOD = `if (b == 0.0) return NAN;
  return mod(a, b);`;
const MOD_PACKED = `
  vec4 result = mod(a, b);
  bvec4 isNaN = equal(b, vec4(0.0));
` +
    CHECK_NAN_SNIPPET_PACKED + `
  return result;
`;
const mod$1 = binaryKernelFunc({
    opSnippet: MOD,
    packedOpSnippet: MOD_PACKED,
});
const modConfig$1 = {
    kernelName: Mod,
    backendName: 'webgl',
    kernelFunc: mod$1
};

class MultinomialProgram {
    constructor(batchSize, numOutcomes, numSamples) {
        this.variableNames = ['probs'];
        this.customUniforms = [{ name: 'seed', type: 'float' }];
        this.outputShape = [batchSize, numSamples];
        this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];

        float r = random(seed);
        float cdf = 0.0;

        for (int i = 0; i < ${numOutcomes - 1}; i++) {
          cdf += getProbs(batch, i);

          if (r < cdf) {
            setOutput(float(i));
            return;
          }
        }

        setOutput(float(${numOutcomes - 1}));
      }
    `;
    }
}

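// The DIV snippet special-cases equal operands to return exactly 1.0, so
// 0/0 comes out as 1.0 in this kernel rather than NaN; the packed variant
// patches each of the four texel lanes separately.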
const DIV = `
if (a == b) {
  return 1.0;
};
return a / b;`;

const DIV_PACKED = `
  vec4 result = a / b;
  if(a.x == b.x) {
    result.x = 1.;
  }
  if(a.y == b.y) {
    result.y = 1.;
  }
  if(a.z == b.z) {
    result.z = 1.;
  }
  if(a.w == b.w) {
    result.w = 1.;
  }

  return result;
`;
const realDiv = binaryKernelFunc({ opSnippet: DIV, packedOpSnippet: DIV_PACKED, checkOutOfBounds: true });
const realDivConfig$1 = {
    kernelName: RealDiv,
    backendName: 'webgl',
    kernelFunc: realDiv,
};

const SUB = 'return a - b;';
const sub = binaryKernelFunc({
    opSnippet: SUB,
    packedOpSnippet: SUB,
    supportsComplex: true,
    cpuKernelImpl: subImplCPU
});
const subConfig = {
    kernelName: Sub,
    backendName: 'webgl',
    kernelFunc: sub
};

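// softmax$1 composes existing kernels instead of using a dedicated shader:
// it subtracts the per-axis max from the logits for numerical stability
// (exponentiating large values would overflow to Infinity), applies exp,
// and normalizes by the sum of exponentials, disposing every intermediate.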
function softmax$1(args) {
    const { inputs, backend, attrs } = args;
    const { logits } = inputs;
    const { dim } = attrs;
    const axes = parseAxisParam([dim], logits.shape);
    const maxLogit = max$1({
        inputs: { x: logits },
        backend,
        attrs: { reductionIndices: axes, keepDims: false }
    });
    const expandedShape = expandShapeToKeepDim(maxLogit.shape, axes);
    const maxLogitsReshaped = reshape$1({ inputs: { x: maxLogit }, backend, attrs: { shape: expandedShape } });
    const a = sub({ inputs: { a: logits, b: maxLogitsReshaped }, backend });
    const b = exp({ inputs: { x: a }, backend });
    const sumExp = sum$1({ inputs: { x: b }, backend, attrs: { axis: axes, keepDims: false } });
    const sumExpReshaped = reshape$1({ inputs: { x: sumExp }, backend, attrs: { shape: expandedShape } });
    const res = realDiv({ inputs: { a: b, b: sumExpReshaped }, backend });
    backend.disposeIntermediateTensorInfo(maxLogit);
    backend.disposeIntermediateTensorInfo(maxLogitsReshaped);
    backend.disposeIntermediateTensorInfo(a);
    backend.disposeIntermediateTensorInfo(b);
    backend.disposeIntermediateTensorInfo(sumExp);
    backend.disposeIntermediateTensorInfo(sumExpReshaped);
    return res;
}
const softmaxConfig$1 = {
    kernelName: Softmax$1,
    backendName: 'webgl',
    kernelFunc: softmax$1
};

function multinomial$1(args) {
    const { inputs, backend, attrs } = args;
    const { logits } = inputs;
    const { numSamples, seed, normalized } = attrs;
    const probs = normalized ?
        logits :
        softmax$1({ inputs: { logits }, backend, attrs: { dim: logits.shape.length - 1 } });
    const batchSize = probs.shape[0];
    const numOutcomes = probs.shape[1];
    const program = new MultinomialProgram(batchSize, numOutcomes, numSamples);
    const customValues = [[seed]];
    const res = backend.runWebGLProgram(program, [probs], 'int32', customValues);
    if (!normalized) {
        backend.disposeIntermediateTensorInfo(probs);
    }
    return res;
}
const multinomialConfig$1 = {
    kernelName: Multinomial,
    backendName: 'webgl',
    kernelFunc: multinomial$1
};

const NEG = CHECK_NAN_SNIPPET$1 + `
  return -x;
`;
const NEG_PACKED = `
  vec4 result = -x;
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;

function neg(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    if (backend.shouldExecuteOnCPU([x])) {
        const xData = backend.texData.get(x.dataId);
        const [outValues, newShape] = negImplCPU(xData.values, x.shape, x.dtype);
        return backend.makeTensorInfo(newShape, x.dtype, outValues);
    }
    let program;
    if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
        program = new UnaryOpPackedProgram(x.shape, NEG_PACKED);
    }
    else {
        program = new UnaryOpProgram(x.shape, NEG);
    }
    return backend.runWebGLProgram(program, [x], x.dtype);
}
const negConfig = {
    kernelName: Neg,
    backendName: 'webgl',
    kernelFunc: neg
};

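// The three non-max-suppression kernels below run on the CPU: each calls
// backend.readSync() to pull boxes and scores off the GPU synchronously,
// which is why they warn that the UI thread will be blocked and point
// callers at tf.nonMaxSuppressionAsync() instead.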
const nonMaxSuppressionV3Impl$1 = nonMaxSuppressionV3Impl$2;
function nonMaxSuppressionV3$1(args) {
    warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' +
        'Call tf.nonMaxSuppressionAsync() instead');
    const { inputs, backend, attrs } = args;
    const { boxes, scores } = inputs;
    const { maxOutputSize, iouThreshold, scoreThreshold } = attrs;
    const boxesVals = backend.readSync(boxes.dataId);
    const scoresVals = backend.readSync(scores.dataId);
    const { selectedIndices } = nonMaxSuppressionV3Impl$1(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
    return backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
}
const nonMaxSuppressionV3Config$1 = {
    kernelName: NonMaxSuppressionV3,
    backendName: 'webgl',
    kernelFunc: nonMaxSuppressionV3$1
};

const nonMaxSuppressionV4Impl$1 = nonMaxSuppressionV4Impl$2;
function nonMaxSuppressionV4$1(args) {
    warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' +
        'Call tf.nonMaxSuppressionAsync() instead');
    const { inputs, backend, attrs } = args;
    const { boxes, scores } = inputs;
    const { maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize } = attrs;
    const boxesVals = backend.readSync(boxes.dataId);
    const scoresVals = backend.readSync(scores.dataId);
    const { selectedIndices, validOutputs } = nonMaxSuppressionV4Impl$1(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize);
    return [
        backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)),
        backend.makeTensorInfo([], 'int32', new Int32Array([validOutputs]))
    ];
}
const nonMaxSuppressionV4Config$1 = {
    kernelName: NonMaxSuppressionV4,
    backendName: 'webgl',
    kernelFunc: nonMaxSuppressionV4$1
};

const nonMaxSuppressionV5Impl$1 = nonMaxSuppressionV5Impl$2;
function nonMaxSuppressionV5$1(args) {
    warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' +
        'Call tf.nonMaxSuppressionAsync() instead');
    const { inputs, backend, attrs } = args;
    const { boxes, scores } = inputs;
    const { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma } = attrs;
    const boxesVals = backend.readSync(boxes.dataId);
    const scoresVals = backend.readSync(scores.dataId);
    const maxOutputSizeVal = maxOutputSize;
    const iouThresholdVal = iouThreshold;
    const scoreThresholdVal = scoreThreshold;
    const softNmsSigmaVal = softNmsSigma;
    const { selectedIndices, selectedScores } = nonMaxSuppressionV5Impl$1(boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, scoreThresholdVal, softNmsSigmaVal);
    return [
        backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)),
        backend.makeTensorInfo([selectedScores.length], 'float32', new Float32Array(selectedScores))
    ];
}
const nonMaxSuppressionV5Config$1 = {
    kernelName: NonMaxSuppressionV5,
    backendName: 'webgl',
    kernelFunc: nonMaxSuppressionV5$1
};

class OneHotProgram {
    constructor(numIndices, depth, onValue, offValue) {
        this.variableNames = ['indices'];
        this.outputShape = [numIndices, depth];
        this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        int index = round(getIndices(coords.x));
        setOutput(mix(float(${offValue}), float(${onValue}),
                      float(index == coords.y)));
      }
    `;
    }
}

const oneHot$1 = (args) => {
    const { inputs, backend, attrs } = args;
    const { indices } = inputs;
    const { dtype, depth, onValue, offValue } = attrs;
    const indicesSize = sizeFromShape(indices.shape);
    const program = new OneHotProgram(indicesSize, depth, onValue, offValue);
    const reshaped = reshape$1({ inputs: { x: indices }, backend, attrs: { shape: [indicesSize] } });
    const result = backend.runWebGLProgram(program, [reshaped], dtype);
    backend.disposeIntermediateTensorInfo(reshaped);
    const outShape = [...indices.shape, depth];
    const out = reshape$1({ inputs: { x: result }, backend, attrs: { shape: outShape } });
    backend.disposeIntermediateTensorInfo(result);
    return out;
};
const oneHotConfig$1 = {
    kernelName: OneHot,
    backendName: 'webgl',
    kernelFunc: oneHot$1
};

function zerosLike$1(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    if (x.dtype === 'complex64') {
        const realPart = real({ inputs: { input: x }, backend });
        const r = zerosLike$1({ inputs: { x: realPart }, backend });
        const imagPart = imag$1({ inputs: { input: x }, backend });
        const i = zerosLike$1({ inputs: { x: imagPart }, backend });
        const result = complex({ inputs: { real: r, imag: i }, backend });
        backend.disposeIntermediateTensorInfo(realPart);
        backend.disposeIntermediateTensorInfo(r);
        backend.disposeIntermediateTensorInfo(imagPart);
        backend.disposeIntermediateTensorInfo(i);
        return result;
    }
    else {
        return fill$1({
            attrs: {
                shape: x.shape,
                dtype: x.dtype,
                value: x.dtype === 'string' ? '' : 0
            },
            backend
        });
    }
}
const zerosLikeConfig$1 = {
    kernelName: ZerosLike,
    backendName: 'webgl',
    kernelFunc: zerosLike$1
};

function onesLike$1(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    if (x.dtype === 'string') {
        throw new Error('onesLike is not supported under string dtype');
    }
    else if (x.dtype === 'complex64') {
        const realPart = real({ inputs: { input: x }, backend });
        const r = onesLike$1({ inputs: { x: realPart }, backend });
        const imagPart = imag$1({ inputs: { input: x }, backend });
        const i = zerosLike$1({ inputs: { x: imagPart }, backend });
        const result = complex({ inputs: { real: r, imag: i }, backend });
        backend.disposeIntermediateTensorInfo(realPart);
        backend.disposeIntermediateTensorInfo(r);
        backend.disposeIntermediateTensorInfo(imagPart);
        backend.disposeIntermediateTensorInfo(i);
        return result;
    }
    else {
        return fill$1({ attrs: { shape: x.shape, dtype: x.dtype, value: 1 }, backend });
    }
}
const onesLikeConfig$1 = {
    kernelName: OnesLike,
    backendName: 'webgl',
    kernelFunc: onesLike$1
};

function pack$1(args) {
    const { inputs, backend, attrs } = args;
    const { axis } = attrs;
    if (inputs.length === 1) {
        return expandDims$2({ inputs: { input: inputs[0] }, backend, attrs: { dim: axis } });
    }
    const shape = inputs[0].shape;
    const dtype = inputs[0].dtype;
    inputs.forEach(t => {
        assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes');
        assert$1(dtype === t.dtype, () => 'All tensors passed to stack must have matching dtypes');
    });
    const intermediateTensorInfos = [];
    const expandedTensors = inputs.map(t => {
        const expandedT = expandDims$2({ inputs: { input: t }, backend, attrs: { dim: axis } });
        intermediateTensorInfos.push(expandedT);
        return expandedT;
    });
    const result = concat$1({ inputs: expandedTensors, backend, attrs: { axis } });
    intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return result;
}
const packConfig$1 = {
    kernelName: Pack,
    backendName: 'webgl',
    kernelFunc: pack$1
};

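// PadProgram fills everything outside [start, end) with the 'value' uniform
// and shifts coordinates by start for the copied interior; PadPackedProgram
// repeats that logic once per texel channel so packed 2x2 outputs stay valid.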
class PadProgram {
    constructor(xShape, paddings, constantValue) {
        this.variableNames = ['x'];
        this.customUniforms = [{ name: 'value', type: 'float' }];
        this.outputShape = paddings.map((p, i) => p[0] + xShape[i] + p[1]);
        const rank = xShape.length;
        const type = getCoordsDataType(rank);
        const start = paddings.map(p => p[0]).join(',');
        const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
        const unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
        if (rank === 1) {
            this.userCode = `
        int start = ${start};
        int end = ${end};

        void main() {
          int outC = getOutputCoords();
          if (outC < start || outC >= end) {
            setOutput(value);
          } else {
            setOutput(getX(outC - start));
          }
        }
      `;
            return;
        }
        this.userCode = `
      ${type} start = ${type}(${start});
      ${type} end = ${type}(${end});

      void main() {
        ${type} outC = getOutputCoords();
        if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) {
          setOutput(value);
        } else {
          ${type} coords = outC - start;
          setOutput(getX(${unpackedCoords}));
        }
      }
    `;
    }
}

class PadPackedProgram {
    constructor(xShape, paddings, constantValue) {
        this.variableNames = ['x'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.customUniforms = [{ name: 'value', type: 'float' }];
        this.outputShape = paddings.map((p, i) => p[0] + xShape[i] + p[1]);
        const rank = xShape.length;
        const dtype = getCoordsDataType(rank);
        const start = paddings.map(p => p[0]).join(',');
        const end = paddings.map((p, i) => p[0] + xShape[i]).join(',');
        const coords = getChannels('rc', rank);
        const source = getChannels('source', rank);
        const cLimit = `${coords[rank - 1]} < ${this.outputShape[rank - 1]}`;
        const innerDims = rank === 1 ? 'source' : `vec2(${source.slice(-2).join()})`;
        const componentSetup = [
            `${dtype} rc = outputLoc;`, `${coords[rank - 1]} += 1;
       if(${cLimit}) {
      `,
            rank === 1 ? '' : `}
       rc = outputLoc;
       ${coords[rank - 2]} += 1;
       if(${coords[rank - 2]} < ${this.outputShape[rank - 2]}) {`,
            rank === 1 ? '' : ` ${coords[rank - 1]} += 1;
        if(${cLimit}) {`
        ];
        const paddingArea = rank === 1 ?
            'rc < start || rc >= end' :
            'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))';
        let mainLoop = '';
        for (let i = 0, j = rank === 1 ? 2 : 4; i < j; i++) {
            mainLoop += `
        ${componentSetup[i]}
        if (${paddingArea}) {
          result[${i}] = float(value);
        } else {
          ${dtype} source = rc - start;
          result[${i}] = getChannel(getX(${source.join()}), ${innerDims});
        }
      `;
        }
        mainLoop += (rank === 1 ? `} ` : `}}`);
        this.userCode = `
      const ${dtype} start = ${dtype}(${start});
      const ${dtype} end = ${dtype}(${end});

      void main() {
        ${dtype} outputLoc = getOutputCoords();
        vec4 result = vec4(0.);
        ${mainLoop}
        setOutput(result);
      }
    `;
    }
}

const padV2$1 = (args) => {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { paddings, constantValue } = attrs;
    if (sizeFromShape(x.shape) === 0) {
        const outputShape = paddings.map((p, i) => p[0] + x.shape[i] + p[1]);
        return fill$1({
            backend,
            attrs: { shape: outputShape, value: constantValue, dtype: x.dtype }
        });
    }
    const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
        new PadPackedProgram(x.shape, paddings, constantValue) :
        new PadProgram(x.shape, paddings, constantValue);
    const customValues = [[constantValue]];
    return backend.runWebGLProgram(program, [x], x.dtype, customValues);
};
const padV2Config$1 = {
    kernelName: PadV2,
    backendName: 'webgl',
    kernelFunc: padV2$1
};

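// POW emulates TensorFlow's pow semantics in GLSL: a negative base with a
// non-integer exponent yields NaN, b == 0 returns 1, and for odd integer
// exponents the sign of the base is reapplied, since pow(abs(a), b) is used
// to stay within GLSL's defined domain.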
const POW = `
  if(a < 0.0 && floor(b) < b){
    return NAN;
  }
  if (b == 0.0) {
    return 1.0;
  }
  return (round(mod(b, 2.0)) != 1) ?
      pow(abs(a), b) : sign(a) * pow(abs(a), b);
`;
const POW_PACKED = `
  vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1)));
  vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1);
  vec4 result = multiplier * pow(abs(a), b);

  bvec4 isExpZero = equal(b, vec4(0.0));
  result.r = isExpZero.r ? 1.0 : result.r;
  result.g = isExpZero.g ? 1.0 : result.g;
  result.b = isExpZero.b ? 1.0 : result.b;
  result.a = isExpZero.a ? 1.0 : result.a;

  bvec4 isNaN1 = lessThan(a, vec4(0.0));
  bvec4 isNaN2 = lessThan(floor(b), b);
  bvec4 isNaN = bvec4(isNaN1.x && isNaN2.x, isNaN1.y && isNaN2.y, isNaN1.z && isNaN2.z, isNaN1.w && isNaN2.w);
` +
    CHECK_NAN_SNIPPET_PACKED + `
  return result;
`;
const pow$1 = binaryKernelFunc({ opSnippet: POW, packedOpSnippet: POW_PACKED });
const powConfig$1 = {
    kernelName: Pow,
    backendName: 'webgl',
    kernelFunc: pow$1
};

function prod(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    const xRank = x.shape.length;
    const toDispose = [];
    const origAxes = parseAxisParam(axis, x.shape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, xRank);
    let permutedX = x;
    if (permutedAxes != null) {
        permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        axes = getInnerMostAxes(axes.length, xRank);
        toDispose.push(permutedX);
    }
    assertAxesAreInnerMostDims('prod', axes, xRank);
    let res;
    if (backend.shouldExecuteOnCPU([permutedX])) {
        const xVals = backend.texData.get(permutedX.dataId).values;
        const { outVals, outShape, outDtype } = prodImplCPU(permutedX.shape, permutedX.dtype, xVals, axes);
        res = backend.makeTensorInfo(outShape, outDtype, outVals);
    }
    else {
        const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, axes);
        const inSize = sizeFromShape(reduceShape);
        const a2D = reshape$1({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
        const outputDType = sumOutType(x.dtype);
        const reduced = reduce(a2D, outputDType, 'prod', backend);
        res = reshape$1({ inputs: { x: reduced }, backend, attrs: { shape: outShape } });
        toDispose.push(a2D);
        toDispose.push(reduced);
    }
    if (keepDims) {
        toDispose.push(res);
        const newShape = expandShapeToKeepDim(res.shape, origAxes);
        res = reshape$1({ inputs: { x: res }, backend, attrs: { shape: newShape } });
    }
    toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return res;
}
const prodConfig = {
    kernelName: Prod,
    backendName: 'webgl',
    kernelFunc: prod
};

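// The three ragged-tensor kernels below have no WebGL shader implementation:
// they read their inputs back synchronously with backend.readSync(),
// delegate to the shared CPU implementations, and wrap the results as new
// tensor infos.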
function raggedGather$1(args) {
    const { inputs, backend, attrs } = args;
    const { paramsNestedSplits, paramsDenseValues, indices } = inputs;
    const { outputRaggedRank } = attrs;
    const $paramsNestedSplits = paramsNestedSplits.map(t => backend.readSync(t.dataId));
    const $paramsNestedSplitsShapes = paramsNestedSplits.map(t => t.shape);
    const $paramsDenseValues = backend.readSync(paramsDenseValues.dataId);
    const $indices = backend.readSync(indices.dataId);
    const [outputNestedSplits, outputDenseValues, outputDenseValuesShape] = raggedGatherImplCPU($paramsNestedSplits, $paramsNestedSplitsShapes, $paramsDenseValues, paramsDenseValues.shape, paramsDenseValues.dtype, $indices, indices.shape, outputRaggedRank);
    const outputNestedSplitsTensors = outputNestedSplits.map((splits) => backend.makeTensorInfo([splits.length], 'int32', splits));
    const outputDenseValuesTensor = backend.makeTensorInfo(outputDenseValuesShape, paramsDenseValues.dtype, outputDenseValues);
    return outputNestedSplitsTensors.concat([outputDenseValuesTensor]);
}
const raggedGatherConfig$1 = {
    kernelName: RaggedGather,
    backendName: 'webgl',
    kernelFunc: raggedGather$1,
};

function raggedRange$1(args) {
    const { inputs, backend } = args;
    const { starts, limits, deltas } = inputs;
    const $starts = backend.readSync(starts.dataId);
    const $limits = backend.readSync(limits.dataId);
    const $deltas = backend.readSync(deltas.dataId);
    const [rtNestedSplitsData, rtDenseValuesData] = raggedRangeImplCPU($starts, starts.shape, starts.dtype, $limits, limits.shape, $deltas, deltas.shape);
    const rtNestedSplits = backend.makeTensorInfo([rtNestedSplitsData.length], 'int32', rtNestedSplitsData);
    const rtDenseValues = backend.makeTensorInfo([rtDenseValuesData.length], starts.dtype, rtDenseValuesData);
    return [rtNestedSplits, rtDenseValues];
}
const raggedRangeConfig$1 = {
    kernelName: RaggedRange,
    backendName: 'webgl',
    kernelFunc: raggedRange$1,
};

function raggedTensorToTensor$1(args) {
    const { inputs, backend, attrs } = args;
    const { shape, values, defaultValue, rowPartitionTensors } = inputs;
    const { rowPartitionTypes } = attrs;
    const $shape = backend.readSync(shape.dataId);
    const $values = backend.readSync(values.dataId);
    const $defaultValue = backend.readSync(defaultValue.dataId);
    const $rowPartitionValues = rowPartitionTensors.map(t => backend.readSync(t.dataId));
    const rowPartitionValuesShapes = rowPartitionTensors.map(t => t.shape);
    const [outputShape, output] = raggedTensorToTensorImplCPU($shape, shape.shape, $values, values.shape, values.dtype, $defaultValue, defaultValue.shape, $rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes);
    return backend.makeTensorInfo(outputShape, values.dtype, output);
}
const raggedTensorToTensorConfig$1 = {
    kernelName: RaggedTensorToTensor,
    backendName: 'webgl',
    kernelFunc: raggedTensorToTensor$1,
};

const range$2 = (args) => {
    const { backend, attrs } = args;
    const { start, stop, step, dtype } = attrs;
    const values = rangeImplCPU(start, stop, step, dtype);
    return backend.makeTensorInfo([values.length], dtype, values);
};
const rangeConfig$1 = {
    kernelName: Range,
    backendName: 'webgl',
    kernelFunc: range$2
};

const RECIPROCAL = `return 1.0 / x;`;
const reciprocal$1 = unaryKernelFunc({ opSnippet: RECIPROCAL });
const reciprocalConfig$1 = {
    kernelName: Reciprocal,
    backendName: 'webgl',
    kernelFunc: reciprocal$1,
};

const RELU = CHECK_NAN_SNIPPET$1 + `
  return (x < 0.0) ? 0.0 : x;
`;
const RELU_PACKED = `
  vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
const relu$1 = unaryKernelFunc({ opSnippet: RELU, packedOpSnippet: RELU_PACKED });
const reluConfig$1 = {
    kernelName: Relu$1,
    backendName: 'webgl',
    kernelFunc: relu$1
};

const RELU6 = CHECK_NAN_SNIPPET$1 + `
  return (x < 0.0) ? 0.0 : min(6.0, x);
`;
const RELU6_PACKED = `
  vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
const relu6$1 = unaryKernelFunc({ opSnippet: RELU6, packedOpSnippet: RELU6_PACKED });
const relu6Config$1 = {
    kernelName: Relu6$1,
    backendName: 'webgl',
    kernelFunc: relu6$1
};

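// ResizeBilinearProgram interpolates each output pixel from its four source
// neighbors. With alignCorners the first and last samples map exactly onto
// the image corners (hence the size - 1 terms in the effective sizes); with
// halfPixelCenters sampling is offset by half a pixel in each axis.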
class ResizeBilinearProgram {
    constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
        this.variableNames = ['A'];
        this.outputShape = [];
        const [batch, oldHeight, oldWidth, depth] = inputShape;
        this.outputShape = [batch, newHeight, newWidth, depth];
        const effectiveInSize = [
            (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
            (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
        ];
        const effectiveOutSize = [
            (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
            (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
        ];
        let sourceFracIndexRC;
        if (halfPixelCenters) {
            sourceFracIndexRC =
                `(vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC` +
                    ` - vec2(0.5)`;
        }
        else {
            sourceFracIndexRC = `vec2(yRC) * effectiveInputOverOutputRatioRC`;
        }
        this.userCode = `
      const vec2 effectiveInputOverOutputRatioRC = vec2(
          ${effectiveInSize[0] / effectiveOutSize[0]},
          ${effectiveInSize[1] / effectiveOutSize[1]});
      const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0);

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];
        ivec2 yRC = coords.yz;

        vec2 sourceFracIndexRC = ${sourceFracIndexRC};

        ivec2 sourceFloorRC = ivec2(max(sourceFracIndexRC, vec2(0.0)));
        ivec2 sourceCeilRC = ivec2(
            min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));

        float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d);
        float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d);
        float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d);
        float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d);

        vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);

        float top = topLeft + (topRight - topLeft) * fracRC.y;
        float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;
        float newValue = top + (bottom - top) * fracRC.x;

        setOutput(newValue);
      }
    `;
    }
}

class ResizeBilinearPackedProgram {
    constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = [];
        const [batch, oldHeight, oldWidth, depth] = inputShape;
        this.outputShape = [batch, newHeight, newWidth, depth];
        const effectiveInSize = [
            (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
            (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
        ];
        const effectiveOutSize = [
            (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
            (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
        ];
        let sourceFracIndexRC;
        if (halfPixelCenters) {
            sourceFracIndexRC = `(vec3(yRC) + vec3(0.5)) * ` +
                `effectiveInputOverOutputRatioRC - vec3(0.5)`;
        }
        else {
            sourceFracIndexRC = `vec3(yRC) * effectiveInputOverOutputRatioRC`;
        }
        this.userCode = `
      const vec3 effectiveInputOverOutputRatioRC = vec3(
          ${effectiveInSize[0] / effectiveOutSize[0]},
          ${effectiveInSize[1] / effectiveOutSize[1]},
          ${effectiveInSize[1] / effectiveOutSize[1]});
      const vec3 inputShapeRC = vec3(${oldHeight}.0, ${oldWidth}.0,
                                     ${oldWidth}.0);

      float getAValue(int b, int r, int c, int d) {
        return getChannel(getA(b, r, c, d), vec2(c, d));
      }

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];

        ivec3 yRC = coords.yzz + ivec3(0, 0, 1);

        vec3 sourceFracIndexRC = ${sourceFracIndexRC};

        ivec3 sourceFloorRC = ivec3(max(sourceFracIndexRC, vec3(0.0)));
        ivec3 sourceCeilRC = ivec3(
            min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));

        bool hasNextCol = d < ${depth - 1};
        bool hasNextRow = coords.z < ${newWidth - 1};

        vec4 topLeft = vec4(
          getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d),
          hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0);

        vec4 bottomLeft = vec4(
          getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d),
          hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0);

        vec4 topRight = vec4(
          getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d),
          hasNextCol ? getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0);

        vec4 bottomRight = vec4(
          getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d),
          hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0);

        vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC);

        vec4 top = mix(topLeft, topRight, fracRC.yyzz);
        vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz);
        vec4 newValue = mix(top, bottom, fracRC.x);

        setOutput(newValue);
      }
    `;
    }
}

function resizeBilinear$1(args) {
    const { inputs, backend, attrs } = args;
    const { images } = inputs;
    const { alignCorners, halfPixelCenters, size } = attrs;
    const [newHeight, newWidth] = size;
    const program = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ?
        new ResizeBilinearPackedProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters) :
        new ResizeBilinearProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters);
    return backend.runWebGLProgram(program, [images], 'float32');
}
const resizeBilinearConfig$1 = {
    kernelName: ResizeBilinear,
    backendName: 'webgl',
    kernelFunc: resizeBilinear$1
};

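// ResizeBilinearBackpropProgram computes the gradient per input pixel: each
// input position scans the window of output gradients that could have
// sampled it and accumulates dy weighted by the same interpolation fractions
// used in the forward pass.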
class ResizeBilinearBackpropProgram {
    constructor(dyShape, inputShape, alignCorners) {
        this.variableNames = ['dy'];
        this.outputShape = [];
        this.outputShape = inputShape;
        const [, xHeight, xWidth,] = inputShape;
        const [, yHeight, yWidth] = dyShape;

        const effectiveXSize = [
            (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
            (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
        ];
        const effectiveYSize = [
            (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
            (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
        ];
        const heightScale = effectiveXSize[0] / effectiveYSize[0];
        const widthScale = effectiveXSize[1] / effectiveYSize[1];
        const invHeightScale = 1 / heightScale;
        const invWidthScale = 1 / widthScale;

        const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
        const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
        this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];
        int r = coords[1];
        int c = coords[2];

        float accumulator = 0.0;

        const float heightScale = float(${heightScale});
        const float widthScale = float(${widthScale});

        const float invHeightScale = float(${invHeightScale});
        const float invWidthScale = float(${invWidthScale});

        const int winHeight = int(${winHeight});
        const int winWidth = int(${winWidth});

        float startRLerp = floor(float(r) * invHeightScale);
        int startDyR = int(startRLerp - float(winHeight / 2));

        float startCLerp = floor(float(c) * invWidthScale);
        int startDyC = int(startCLerp - float(winWidth / 2));

        for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {
          int dyR = dyROffset + startDyR;

          if (dyR < 0 || dyR >= ${yHeight}) {
            continue;
          }

          for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {
            int dyC = dyCOffset + startDyC;

            if (dyC < 0 || dyC >= ${yWidth}) {
              continue;
            }

            float dxR = float(dyR) * heightScale;
            int topDxRIndex = int(floor(dxR));
            int bottomDxRIndex = int(min(ceil(dxR), ${xHeight - 1}.0));
            float dxRLerp = dxR - float(topDxRIndex);
            float inverseDxRLerp = 1.0 - dxRLerp;

            float dxC = float(dyC) * widthScale;
            int leftDxCIndex = int(floor(dxC));
            int rightDxCIndex = int(min(ceil(dxC), ${xWidth - 1}.0));
            float dxCLerp = dxC - float(leftDxCIndex);
            float inverseDxCLerp = 1.0 - dxCLerp;

            if (r == topDxRIndex && c == leftDxCIndex) {
              accumulator +=
                getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;
            }

            if (r == topDxRIndex && c == rightDxCIndex) {
              accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;
            }

            if (r == bottomDxRIndex && c == leftDxCIndex) {
              accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;
            }

            if (r == bottomDxRIndex && c == rightDxCIndex) {
              accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;
            }
          }
        }

        setOutput(accumulator);
      }
    `;
    }
}

function resizeBilinearGrad$1(args) {
    const { inputs, backend, attrs } = args;
    const { images, dy } = inputs;
    const { alignCorners } = attrs;
    const program = new ResizeBilinearBackpropProgram(dy.shape, images.shape, alignCorners);
    return backend.runWebGLProgram(program, [dy], dy.dtype);
}
const resizeBilinearGradConfig$2 = {
    kernelName: ResizeBilinearGrad,
    backendName: 'webgl',
    kernelFunc: resizeBilinearGrad$1
};

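// ResizeNearestNeighborProgram picks a single source pixel per output pixel.
// roundBase controls the rounding: with alignCorners it adds 0.5 before
// flooring (round-to-nearest), otherwise the fractional index is floored.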
class ResizeNearestNeighborProgram {
    constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
        this.variableNames = ['A'];
        this.outputShape = [];
        const [batch, oldHeight, oldWidth, depth] = inputShape;
        this.outputShape = [batch, newHeight, newWidth, depth];
        const effectiveInSize = [
            (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
            (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
        ];
        const effectiveOutSize = [
            (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
            (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
        ];

        const roundBase = alignCorners ? '0.5' : '0.0';
        let sourceFracIndexRC;
        if (halfPixelCenters) {
            sourceFracIndexRC =
                `max((vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC` +
                    `, vec2(0.0))`;
        }
        else {
            sourceFracIndexRC = `vec2(yRC) * effectiveInputOverOutputRatioRC`;
        }
        this.userCode = `
      const vec2 effectiveInputOverOutputRatioRC = vec2(
          ${effectiveInSize[0] / effectiveOutSize[0]},
          ${effectiveInSize[1] / effectiveOutSize[1]});
      const vec2 inputShapeRC = vec2(${oldHeight}.0, ${oldWidth}.0);

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];
        ivec2 yRC = coords.yz;

        vec2 sourceFracIndexRC = ${sourceFracIndexRC};

        ivec2 sourceNearestRC = ivec2(
          min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ${roundBase})));
        float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d);

        setOutput(newValue);
      }
    `;
    }
}

class ResizeNearestNeighborPackedProgram {
    constructor(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = [];
        const [batch, oldHeight, oldWidth, depth] = inputShape;
        this.outputShape = [batch, newHeight, newWidth, depth];
        const effectiveInSize = [
            (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
            (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
        ];
        const effectiveOutSize = [
            (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
            (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
        ];

        const roundBase = alignCorners ? '0.5' : '0.0';
        let sourceFracIndexRC;
        if (halfPixelCenters) {
            sourceFracIndexRC = `max((vec3(yRC) + vec3(0.5)) * ` +
                `effectiveInputOverOutputRatioRC, vec3(0.0))`;
        }
        else {
            sourceFracIndexRC = `vec3(yRC) * effectiveInputOverOutputRatioRC`;
        }
        this.userCode = `
      const vec3 effectiveInputOverOutputRatioRC = vec3(
          ${effectiveInSize[0] / effectiveOutSize[0]},
          ${effectiveInSize[1] / effectiveOutSize[1]},
          ${effectiveInSize[1] / effectiveOutSize[1]});
      const vec3 inputShapeRC = vec3(${oldHeight}.0, ${oldWidth}.0,
                                     ${oldWidth}.0);

      float getAValue(int b, int r, int c, int d) {
        return getChannel(getA(b, r, c, d), vec2(c, d));
      }

      void main() {
        ivec4 coords = getOutputCoords();
        int b = coords[0];
        int d = coords[3];

        ivec3 yRC = coords.yzz + ivec3(0, 0, 1);

        vec3 sourceFracIndexRC = ${sourceFracIndexRC};

        ivec3 sourceNearestRC = ivec3(
          min(inputShapeRC - 1.0, floor(sourceFracIndexRC + ${roundBase})));

        bool hasNextCol = d < ${depth - 1};
        bool hasNextRow = coords.z < ${newWidth - 1};

        vec4 newValue = vec4(
          getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d),
          hasNextCol ? getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d + 1)
                     : 0.0,
          hasNextRow ? getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d)
                     : 0.0,
          (hasNextRow && hasNextCol) ?
            getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d + 1) : 0.0);

        setOutput(newValue);
      }
    `;
    }
}

function resizeNearestNeighbor$1(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { images } = inputs;
|
|
const { alignCorners, halfPixelCenters, size } = attrs;
|
|
const [newHeight, newWidth] = size;
|
|
const program = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ?
|
|
new ResizeNearestNeighborPackedProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters) :
|
|
new ResizeNearestNeighborProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters);
|
|
return backend.runWebGLProgram(program, [images], images.dtype);
|
|
}
|
|
const resizeNearestNeighborConfig$1 = {
|
|
kernelName: ResizeNearestNeighbor,
|
|
backendName: 'webgl',
|
|
kernelFunc: resizeNearestNeighbor$1
|
|
};
|
|
|
|
|
|
class ResizeNearestNeigborBackpropProgram {
|
|
constructor(dyShape, inputShape, alignCorners) {
|
|
this.variableNames = ['dy'];
|
|
this.outputShape = [];
|
|
this.outputShape = inputShape;
|
|
const [, xHeight, xWidth,] = inputShape;
|
|
const [, yHeight, yWidth] = dyShape;
|
|
|
|
|
|
|
|
const effectiveXSize = [
|
|
(alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
|
|
(alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
|
|
];
|
|
const effectiveYSize = [
|
|
(alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
|
|
(alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
|
|
];
|
|
const heightScale = effectiveXSize[0] / effectiveYSize[0];
|
|
const widthScale = effectiveXSize[1] / effectiveYSize[1];
|
|
const invHeightScale = 1 / heightScale;
|
|
const invWidthScale = 1 / widthScale;
|
|
|
|
|
|
const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
|
|
const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
|
|
this.userCode = `
|
|
void main() {
|
|
ivec4 coords = getOutputCoords();
|
|
int b = coords[0];
|
|
int d = coords[3];
|
|
int r = coords[1];
|
|
int c = coords[2];
|
|
|
|
float accumulator = 0.0;
|
|
|
|
const float heightScale = float(${heightScale});
|
|
const float widthScale = float(${widthScale});
|
|
|
|
const float invHeightScale = float(${invHeightScale});
|
|
const float invWidthScale = float(${invWidthScale});
|
|
|
|
const int winHeight = int(${winHeight});
|
|
const int winWidth = int(${winWidth});
|
|
|
|
|
|
float startRLerp = floor(float(r) * invHeightScale);
|
|
int startDyR = int(floor(startRLerp - float(winHeight / 2)));
|
|
|
|
float startCLerp = floor(float(c) * invWidthScale);
|
|
int startDyC = int(floor(startCLerp - float(winWidth / 2)));
|
|
|
|
|
|
for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {
|
|
int dyR = dyROffset + startDyR;
|
|
|
|
|
|
if (dyR < 0 || dyR >= ${yHeight}) {
|
|
continue;
|
|
}
|
|
|
|
for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {
|
|
int dyC = dyCOffset + startDyC;
|
|
|
|
|
|
if (dyC < 0 || dyC >= ${yWidth}) {
|
|
continue;
|
|
}
|
|
|
|
float sourceFracRow =
|
|
float(${effectiveXSize[0]}) *
|
|
(float(dyR) / float(${effectiveYSize[0]}));
|
|
|
|
float sourceFracCol =
|
|
float(${effectiveXSize[1]}) *
|
|
(float(dyC) / float(${effectiveYSize[1]}));
|
|
|
|
int sourceNearestRow = int(min(
|
|
float(int(${xHeight}) - 1),
|
|
${alignCorners} ? float(round(sourceFracRow)) :
|
|
float(floor(sourceFracRow))));
|
|
|
|
int sourceNearestCol = int(min(
|
|
float(int(${xWidth}) - 1),
|
|
${alignCorners} ? float(round(sourceFracCol)) :
|
|
float(floor(sourceFracCol))));
|
|
|
|
if (r == sourceNearestRow && c == sourceNearestCol) {
|
|
accumulator += getDy(b, dyR, dyC, d);
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
setOutput(accumulator);
|
|
}
|
|
`;
|
|
}
|
|
}
|
|
|
|
|
|
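// The backprop program above inverts nearest-neighbor resize by scanning, for
// each input pixel (r, c), a window of dy pixels that could have sampled it
// (roughly 2 * ceil(1 / scale) + 2 per axis) and summing those dy values
// whose forward-pass nearest source is exactly (r, c). Rough CPU-side sketch
// (illustrative only; helper names are hypothetical):
//
//   for (const [dyR, dyC] of windowAround(r, c)) {
//     if (nearestRow(dyR) === r && nearestCol(dyC) === c) {
//       accumulator += dy[b][dyR][dyC][d];
//     }
//   }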
function resizeNearestNeighborGrad$1(args) {
    const { inputs, backend, attrs } = args;
    const { images, dy } = inputs;
    const { alignCorners } = attrs;
    const program = new ResizeNearestNeighborBackpropProgram(dy.shape, images.shape, alignCorners);
    return backend.runWebGLProgram(program, [dy], dy.dtype);
}
const resizeNearestNeighborGradConfig$2 = {
    kernelName: ResizeNearestNeighborGrad,
    backendName: 'webgl',
    kernelFunc: resizeNearestNeighborGrad$1
};

class ReverseProgram {
    constructor(xShape, axis) {
        this.variableNames = ['x'];
        const rank = xShape.length;
        if (rank > 4) {
            throw new Error(`WebGL backend: Reverse of rank-${rank} tensor is not yet supported`);
        }
        this.outputShape = xShape;
        if (rank === 1) {
            this.userCode = `
        void main() {
          int coord = getOutputCoords();
          setOutput(getX(${xShape[0]} - coord - 1));
        }
      `;
            return;
        }
        const getInCoord = (i) => {
            if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {
                return `${xShape[i]} - coords[${i}] - 1`;
            }
            return `coords[${i}]`;
        };
        const inCoords = xShape.map((_, i) => getInCoord(i)).join(',');
        const type = getCoordsDataType(rank);
        this.userCode = `
      void main() {
        ${type} coords = getOutputCoords();
        setOutput(getX(${inCoords}));
      }
    `;
    }
}

class ReversePackedProgram {
    constructor(xShape, axis) {
        this.variableNames = ['x'];
        this.packedInputs = true;
        this.packedOutput = true;
        const rank = xShape.length;
        if (rank > 4) {
            throw new Error(`WebGL backend: Reverse of rank-${rank} tensor is not yet supported`);
        }
        this.outputShape = xShape;
        const channels = getChannels('rc', rank);
        const nextColumn = `${channels[rank - 1]} + 1 < ${this.outputShape[rank - 1]}`;
        const nextRow = `${channels[rank - 2]} + 1 < ${this.outputShape[rank - 2]}`;
        const type = getCoordsDataType(rank);
        if (rank === 1) {
            this.userCode = `
        void main(){
          int rc = getOutputCoords();
          vec4 result = vec4(0.);
          result.r = getChannel(getX(${xShape[0]} - rc - 1),
            ${xShape[0]} - rc - 1);
          if(${nextColumn}){
            result.g = getChannel(getX(${xShape[0]} - (rc + 1) - 1),
              ${xShape[0]} - (rc + 1) - 1);
          }
          setOutput(result);
        }
      `;
        }
        else {
            this.userCode = `
        void main() {
          ${type} rc = getOutputCoords();
          vec4 result = vec4(0.);
          result.r = ${getR(channels.slice())};
          if(${nextColumn}){
            result.g = ${getG(channels.slice())};
          }
          if(${nextRow}) {
            result.b = ${getB(channels.slice())};
            if(${nextColumn}) {
              result.a = ${getA(channels.slice())};
            }
          }
          setOutput(result);
        }
      `;
        }
        function getR(channels) {
            return getChannel(channels);
        }
        function getG(channels) {
            channels[rank - 1] = '(' + channels[rank - 1] + ` + 1)`;
            return getChannel(channels);
        }
        function getB(channels) {
            channels[rank - 2] = '(' + channels[rank - 2] + ` + 1)`;
            return getChannel(channels);
        }
        function getA(channels) {
            channels[rank - 1] = '(' + channels[rank - 1] + ` + 1)`;
            channels[rank - 2] = '(' + channels[rank - 2] + ` + 1)`;
            return getChannel(channels);
        }
        function getChannel(channels) {
            const inCoordsArray = xShape.map((_, i) => getInCoord(i, channels));
            const inCoords = inCoordsArray.join(',');
            const innerDims = inCoordsArray.slice(-2).join(',');
            return `getChannel(getX(${inCoords}), vec2(${innerDims}))`;
        }
        function getInCoord(i, channels1) {
            if (axis.indexOf(i) !== -1 && xShape[i] !== 1) {
                return `${xShape[i]} - ${channels1[i]} - 1`;
            }
            else {
                return `${channels1[i]}`;
            }
        }
    }
}
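// ReverseProgram maps output coordinate i on each reversed axis to input
// coordinate size - i - 1 and leaves the other axes unchanged; e.g. reversing
// axis 1 of a [1, 4] tensor reads getX(coords[0], 4 - coords[1] - 1), so
// tf.reverse(tf.tensor2d([[1, 2, 3, 4]]), 1) yields [[4, 3, 2, 1]].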
function reverse$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { dims } = attrs;
    const xRank = x.shape.length;
    const $dims = parseAxisParam(dims, x.shape);
    if (xRank === 0) {
        return identity({ inputs: { x }, backend });
    }
    const program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
        new ReversePackedProgram(x.shape, $dims) :
        new ReverseProgram(x.shape, $dims);
    return backend.runWebGLProgram(program, [x], x.dtype);
}
const reverseConfig$1 = {
    kernelName: Reverse,
    backendName: 'webgl',
    kernelFunc: reverse$1
};

class RotateProgram {
    constructor(imageShape, fillValue) {
        this.variableNames = ['Image'];
        this.outputShape = [];
        this.customUniforms = [{ name: 'params', type: 'vec4' }];
        const imageHeight = imageShape[1];
        const imageWidth = imageShape[2];
        this.outputShape = imageShape;
        let fillSnippet = '';
        if (typeof fillValue === 'number') {
            fillSnippet = `float outputValue = ${fillValue.toFixed(2)};`;
        }
        else {
            fillSnippet = `
        vec3 fill = vec3(${fillValue.join(',')});
        float outputValue = fill[coords[3]];`;
        }
        this.userCode = `
      void main() {
        ivec4 coords = getOutputCoords();
        int x = coords[2];
        int y = coords[1];
        float coordXFloat = (float(x) - params[0]) * params[3] -
          (float(y) - params[1]) * params[2];
        float coordYFloat = (float(x) - params[0]) * params[2] +
          (float(y) - params[1]) * params[3];
        int coordX = int(round(coordXFloat + params[0]));
        int coordY = int(round(coordYFloat + params[1]));
        ${fillSnippet}
        if(coordX >= 0 && coordX < ${imageWidth} && coordY >= 0 && coordY < ${imageHeight}) {
          outputValue = getImage(coords[0], coordY, coordX, coords[3]);
        }
        setOutput(outputValue);
      }
    `;
    }
}

const rotateWithOffsetConfig$1 = {
    kernelName: RotateWithOffset,
    backendName: 'webgl',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { image } = inputs;
        const { radians, fillValue, center } = attrs;
        const webglBackend = backend;
        const program = new RotateProgram(image.shape, fillValue);
        const [centerX, centerY] = getImageCenter(center, image.shape[1], image.shape[2]);
        const customValues = [[centerX, centerY, Math.sin(radians), Math.cos(radians)]];
        const output = webglBackend.runWebGLProgram(program, [image], image.dtype, customValues);
        return output;
    }
};
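// The ROUND snippet below implements round-half-to-even (banker's rounding),
// matching TensorFlow's round semantics rather than GLSL's round(). Worked
// examples: 0.5 -> 0, 1.5 -> 2, 2.5 -> 2, 2.4 -> 2, 2.6 -> 3; picking the
// even neighbor at exact halves keeps the rounding unbiased.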
const ROUND = `
  float base = floor(x);
  if ((x - base) < 0.5) {
    return floor(x);
  } else if ((x - base) > 0.5) {
    return ceil(x);
  } else {
    if (mod(base, 2.0) == 0.0) {
      return base;
    } else {
      return base + 1.0;
    }
  }
`;
const round$1 = unaryKernelFunc({ opSnippet: ROUND });
const roundConfig$1 = {
    kernelName: Round,
    backendName: 'webgl',
    kernelFunc: round$1,
};

const RSQRT = `return inversesqrt(x);`;
const rsqrt = unaryKernelFunc({ opSnippet: RSQRT, cpuKernelImpl: rsqrtImplCPU });
const rsqrtConfig = {
    kernelName: Rsqrt,
    backendName: 'webgl',
    kernelFunc: rsqrt
};
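// ScatterProgram below computes each output slice with a brute-force scan
// over all updates: every index row is flattened into a 1-D offset via the
// slice strides, and updates whose offset matches the output row are summed
// (so duplicate indices accumulate). In pseudocode (illustrative only):
//
//   out[i] = anyMatch(i) ? sum of updates[u] with flatIndex(u) === i
//                        : defaultValue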
class ScatterProgram {
    constructor(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex = true, defaultIsTensor = false) {
        this.variableNames = ['updates', 'indices', 'defaultValue'];
        this.outputShape = shape;
        const stridesType = getCoordsDataType(strides.length);
        const dtype = getCoordsDataType(shape.length);
        let indicesString = '';
        if (indicesRank === 1) {
            indicesString = 'i';
        }
        else if (indicesRank === 2) {
            indicesString = 'i, j';
        }
        const indicesSnippet = `getIndices(${indicesString})`;
        let updatesString = '';
        if (updatesRank === 1) {
            updatesString = 'i';
        }
        else if (updatesRank === 2) {
            updatesString = 'i, coords[1]';
        }
        const updatesSnippet = `getUpdates(${updatesString})`;
        let defaultValuesString = '';
        if (defaultIsTensor) {
            defaultValuesString = 'coords[0], coords[1]';
        }
        const defaultValueSnippet = `getDefaultValue(${defaultValuesString})`;
        const strideString = sliceDim > 1 ? 'strides[j]' : 'strides';
        this.userCode = `
      ${stridesType} strides = ${stridesType}(${strides});

      void main() {
        ${dtype} coords = getOutputCoords();
        float sum = 0.0;
        bool found = false;
        for (int i = 0; i < ${updateSize}; i++) {
          int flattenedIndex = 0;
          for (int j = 0; j < ${sliceDim}; j++) {
            int index = round(${indicesSnippet});
            flattenedIndex += index * ${strideString};
          }
          if (flattenedIndex == coords[0]) {
            sum += ${updatesSnippet};
            found = true;
          }
        }
        setOutput(mix(${defaultValueSnippet}, sum, float(found)));
      }
    `;
    }
}

class ScatterPackedProgram {
    constructor(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex = true, defaultIsTensor = false) {
        this.variableNames = ['updates', 'indices', 'defaultValue'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = shape;
        const stridesType = getCoordsDataType(strides.length);
        const dtype = getCoordsDataType(shape.length);
        let indicesString = '';
        if (indicesRank === 1) {
            indicesString = 'i';
        }
        else if (indicesRank === 2) {
            indicesString = 'i, j';
        }
        const indicesSnippet = `getIndices(${indicesString})`;
        let updatesString = '';
        if (updatesRank === 1) {
            updatesString = 'i';
        }
        else if (updatesRank === 2) {
            updatesString = 'i, coords[1]';
        }
        const updatesSnippet = `getUpdates(${updatesString})`;
        let defaultValuesString = '';
        if (defaultIsTensor) {
            defaultValuesString = 'coords[0], coords[1]';
        }
        const defaultValueSnippet = `getDefaultValue(${defaultValuesString})`;
        const strideString = sliceDim > 1 ? 'strides[j]' : 'strides';
        const strideString2 = sliceDim > 1 ? 'strides[j + 1]' : 'strides';
        this.userCode = `
      ${stridesType} strides = ${stridesType}(${strides});

      void main() {
        ${dtype} coords = getOutputCoords();
        vec4 sum = vec4(0.);
        vec4 found = vec4(0.);
        for (int i = 0; i < ${updateSize}; i+=2) {
          ivec2 flattenedIndex = ivec2(0);
          for (int j = 0; j < ${sliceDim}; j+=2) {
            ivec4 index = round(${indicesSnippet});
            flattenedIndex += index.xz * ${strideString};
            if (j + 1 < ${sliceDim}) {
              flattenedIndex += index.yw * ${strideString2};
            }
          }
          if (flattenedIndex[0] == coords[0] || flattenedIndex[1] == coords[0] ||
              flattenedIndex[0] == coords[0] + 1 || flattenedIndex[1] == coords[0] + 1) {
            vec4 updVals = ${updatesSnippet};
            if (flattenedIndex[0] == coords[0]) {
              sum.xy += updVals.xy;
              found.xy = vec2(1.);
            } else if (flattenedIndex[0] == coords[0] + 1) {
              sum.zw += updVals.xy;
              found.zw = vec2(1.);
            }
            if (flattenedIndex[1] == coords[0]) {
              sum.xy += updVals.zw;
              found.xy = vec2(1.);
            } else if (flattenedIndex[1] == coords[0] + 1) {
              sum.zw += updVals.zw;
              found.zw = vec2(1.);
            }
          }
        }
        setOutput(mix(${defaultValueSnippet}, sum, found));
      }
    `;
    }
}

function scatterNd$1(args) {
    const { inputs, backend, attrs } = args;
    const { indices, updates } = inputs;
    const { shape } = attrs;
    const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, shape);
    const flattenShape = [outputSize / sliceSize, sliceSize];
    if (outputSize === 0) {
        return backend.makeTensorInfo(shape, indices.dtype);
    }
    const flattenIndices = reshape$1({ inputs: { x: indices }, backend, attrs: { shape: [numUpdates, sliceRank] } });
    const flattenX = reshape$1({ inputs: { x: updates }, backend, attrs: { shape: [numUpdates, sliceSize] } });
    const defaultValue = backend.makeTensorInfo([], 'float32', new Float32Array([0]));
    let program;
    if (env().getBool('WEBGL_PACK')) {
        program = new ScatterPackedProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape);
    }
    else {
        program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape);
    }
    const res = backend.runWebGLProgram(program, [flattenX, flattenIndices, defaultValue], flattenX.dtype);
    const reshaped = reshape$1({ inputs: { x: res }, backend, attrs: { shape } });
    backend.disposeIntermediateTensorInfo(flattenIndices);
    backend.disposeIntermediateTensorInfo(flattenX);
    backend.disposeIntermediateTensorInfo(res);
    backend.disposeIntermediateTensorInfo(defaultValue);
    return reshaped;
}
const scatterNdConfig$1 = {
    kernelName: ScatterNd,
    backendName: 'webgl',
    kernelFunc: scatterNd$1
};
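// SearchSortedProgram below runs a per-element binary search. WebGL 1 shaders
// cannot use while-loops with data-dependent bounds, so the WebGL 1 variant
// unrolls the search into a fixed for-loop of ceil(log2(numInputs + 1))
// iterations with an early break - always enough to converge; e.g. for
// numInputs = 1000 it emits ceil(log2(1001)) = 10 iterations.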
class SearchSortedProgram {
    constructor(batchSize, numInputs, numValues, side) {
        this.variableNames = ['sortedSequence', 'values'];
        this.customUniforms = [{ name: 'numInputs', type: 'int' }];
        this.outputShape = [batchSize, numValues];
        const webGL2LoopHead = 'while (left < right) {';

        const webGL1LoopHead = `for (int i = 0; i < ${Math.ceil(Math.log2(numInputs + 1))}; ++i) { if (left >= right) break;`;
        const loopHead = env().getNumber('WEBGL_VERSION') === 2 ? webGL2LoopHead :
            webGL1LoopHead;

        const boundComparator = side === 'left' ? '<' : '<=';
        this.userCode = `
      int findBound(int batch, float value) {
        int left = 0;
        int right = numInputs;
        int mid;
        ${loopHead}
          mid = (left + right) / 2;
          if (getSortedSequence(batch, mid) ${boundComparator} value) {
            left = mid + 1;
          } else {
            right = mid;
          }
        }
        return right;
      }

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int valueIndex = coords[1];

        float value = getValues(batch, valueIndex);

        setOutput(float(findBound(batch, value)));
      }
    `;
    }
}

function searchSorted$1(args) {
    const { inputs, backend, attrs } = args;
    const { sortedSequence, values } = inputs;
    const { side } = attrs;
    const program = new SearchSortedProgram(sortedSequence.shape[0], sortedSequence.shape[1], values.shape[1], side);
    const customValues = [[sortedSequence.shape[1]]];
    return backend.runWebGLProgram(program, [sortedSequence, values], 'int32', customValues);
}
const searchSortedConfig$1 = {
    kernelName: SearchSorted,
    backendName: 'webgl',
    kernelFunc: searchSorted$1,
};

class SelectProgram {
    constructor(cRank, shape, rank) {
        this.variableNames = ['c', 'a', 'b'];
        this.outputShape = shape;
        let cCoords;
        let abCoords;
        if (rank > 4) {
            throw Error(`Where for rank ${rank} is not yet supported`);
        }
        if (rank === 1) {
            abCoords = `resRC`;
            cCoords = `resRC`;
        }
        else {
            const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
            const cCoordVars = [];
            const abCoordVars = [];
            for (let i = 0; i < shape.length; i++) {
                abCoordVars.push(`${currentCoords[i]}`);
                if (i < cRank) {
                    cCoordVars.push(`${currentCoords[i]}`);
                }
            }
            cCoords = cCoordVars.join();
            abCoords = abCoordVars.join();
        }
        const dtype = getCoordsDataType(rank);
        this.userCode = `
      void main() {
        ${dtype} resRC = getOutputCoords();
        float cVal = getC(${cCoords});
        if (cVal >= 1.0) {
          setOutput(getA(${abCoords}));
        } else {
          setOutput(getB(${abCoords}));
        }
      }
    `;
    }
}

function select$1(args) {
    const { inputs, backend } = args;
    const { condition, t, e } = inputs;
    const program = new SelectProgram(condition.shape.length, t.shape, t.shape.length);
    return backend.runWebGLProgram(program, [condition, t, e], upcastType(t.dtype, e.dtype));
}
const selectConfig$1 = {
    kernelName: Select,
    backendName: 'webgl',
    kernelFunc: select$1
};

const SELU = `
  float scaleAlpha = ${SELU_SCALEALPHA};
  float scale = ${SELU_SCALE};
  return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0);
`;
const selu$1 = unaryKernelFunc({ opSnippet: SELU });
const seluConfig$1 = {
    kernelName: Selu$1,
    backendName: 'webgl',
    kernelFunc: selu$1,
};

const SIGMOID = CHECK_NAN_SNIPPET_UNARY + `
  return 1.0 / (1.0 + exp(-1.0 * x));
`;
const SIGMOID_PACKED = `
  vec4 result = 1.0 / (1.0 + exp(-1.0 * x));
  bvec4 isNaN = isnan(x);

  result.r = isNaN.r ? x.r : result.r;
  result.g = isNaN.g ? x.g : result.g;
  result.b = isNaN.b ? x.b : result.b;
  result.a = isNaN.a ? x.a : result.a;

  return result;
`;
const sigmoid = unaryKernelFunc({
    opSnippet: SIGMOID,
    packedOpSnippet: SIGMOID_PACKED,
    cpuKernelImpl: sigmoidImplCPU
});
const sigmoidConfig = {
    kernelName: Sigmoid$1,
    backendName: 'webgl',
    kernelFunc: sigmoid,
};

const SIGN = `
  if (isnan(x)) { return 0.0; }
  return sign(x);
`;
const sign$1 = unaryKernelFunc({ opSnippet: SIGN });
const signConfig$1 = {
    kernelName: Sign,
    backendName: 'webgl',
    kernelFunc: sign$1,
};

const SIN = CHECK_NAN_SNIPPET_UNARY + `
  return sin(x);
`;
const SIN_PACKED = `
  vec4 result = sin(x);
  bvec4 isNaN = isnan(x);
  ${CHECK_NAN_SNIPPET_PACKED}
  return result;
`;
const sin$1 = unaryKernelFunc({ opSnippet: SIN, packedOpSnippet: SIN_PACKED });
const sinConfig$1 = {
    kernelName: Sin,
    backendName: 'webgl',
    kernelFunc: sin$1,
};

const SINH = `
  float e2x = exp(x);
  return (e2x - 1.0 / e2x) / 2.0;
`;
const sinh$1 = unaryKernelFunc({ opSnippet: SINH });
const sinhConfig$1 = {
    kernelName: Sinh,
    backendName: 'webgl',
    kernelFunc: sinh$1,
};

const SOFTPLUS = `
  float epsilon = 1.1920928955078125e-7;
  float threshold = log(epsilon) + 2.0;

  bool too_large = x > -threshold;
  bool too_small = x < threshold;

  float result;
  float exp_x = exp(x);

  if (too_large){
    result = x;
  }
  else if (too_small){
    result = exp_x;
  }
  else{
    result = log(exp_x + 1.0);
  }
  return result;
`;
const softplus$1 = unaryKernelFunc({ opSnippet: SOFTPLUS });
const softplusConfig$1 = {
    kernelName: Softplus$1,
    backendName: 'webgl',
    kernelFunc: softplus$1,
};
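// spaceToBatchND below is composed from existing kernels rather than a
// dedicated shader: pad the spatial dims, reshape so each block dimension
// gets its own axis, transpose the block axes into the batch position, then
// flatten. Shape walk-through for x: [1, 4, 4, 1], blockShape: [2, 2] and
// zero padding:
//   pad       -> [1, 4, 4, 1]
//   reshape   -> [1, 2, 2, 2, 2, 1]
//   transpose -> [2, 2, 1, 2, 2, 1]
//   reshape   -> [4, 2, 2, 1]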
const spaceToBatchND$1 = (args) => {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { blockShape, paddings } = attrs;
    assert$1(x.shape.length <= 4, () => 'spaceToBatchND for rank > 4 with a WebGL backend not ' +
        'implemented yet');
    const prod = blockShape.reduce((a, b) => a * b);
    const completePaddings = [[0, 0]];
    completePaddings.push(...paddings);
    for (let i = 1 + blockShape.length; i < x.shape.length; ++i) {
        completePaddings.push([0, 0]);
    }
    const toDispose = [];
    const paddedX = padV2$1({
        inputs: { x },
        backend,
        attrs: { paddings: completePaddings, constantValue: 0 }
    });
    const reshapedPaddedShape = getReshaped(paddedX.shape, blockShape, prod, false);
    const permutedReshapedPaddedPermutation = getPermuted(reshapedPaddedShape.length, blockShape.length, false);
    const flattenShape = getReshapedPermuted(paddedX.shape, blockShape, prod, false);
    const reshapedPaddedX = reshape$1({ inputs: { x: paddedX }, backend, attrs: { shape: reshapedPaddedShape } });
    const paddedXT = transpose({
        inputs: { x: reshapedPaddedX },
        backend,
        attrs: { perm: permutedReshapedPaddedPermutation }
    });
    const result = reshape$1({ inputs: { x: paddedXT }, backend, attrs: { shape: flattenShape } });
    toDispose.push(paddedX);
    toDispose.push(reshapedPaddedX);
    toDispose.push(paddedXT);
    toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return result;
};
const spaceToBatchNDConfig$1 = {
    kernelName: SpaceToBatchND,
    backendName: 'webgl',
    kernelFunc: spaceToBatchND$1
};

function sparseFillEmptyRows$1(args) {
    const { inputs, backend } = args;
    const { indices, values, denseShape, defaultValue } = inputs;
    if (denseShape.shape.length !== 1) {
        throw new Error(`Dense shape must be a vector, saw:
        ${denseShape.shape}`);
    }
    if (indices.shape.length !== 2) {
        throw new Error(`Indices must be a matrix, saw:
        ${indices.shape}`);
    }
    if (values.shape.length !== 1) {
        throw new Error(`Values must be a vector, saw:
        ${values.shape}`);
    }
    if (defaultValue.shape.length !== 0) {
        throw new Error(`Default value must be a scalar, saw:
        ${defaultValue.shape}`);
    }
    const $indices = backend.readSync(indices.dataId);
    const $values = backend.readSync(values.dataId);
    const $denseShape = backend.readSync(denseShape.dataId);
    const $defaultValue = backend.readSync(defaultValue.dataId)[0];
    const [outputIndices, outputIndicesShape, outputValues, emptyRowIndicator, reverseIndexMap] = sparseFillEmptyRowsImplCPU($indices, indices.shape, indices.dtype, $values, values.dtype, $denseShape, $defaultValue);
    return [
        backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices),
        backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues),
        backend.makeTensorInfo([emptyRowIndicator.length], 'bool', new Uint8Array(emptyRowIndicator.map((value) => Number(value)))),
        backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap)),
    ];
}
const sparseFillEmptyRowsConfig$1 = {
    kernelName: SparseFillEmptyRows,
    backendName: 'webgl',
    kernelFunc: sparseFillEmptyRows$1,
};

function sparseReshape$1(args) {
    const { inputs, backend } = args;
    const { inputIndices, inputShape, newShape } = inputs;
    if (inputIndices.shape.length !== 2) {
        throw new Error(`Input indices should be a matrix but received shape ${inputIndices.shape}`);
    }
    if (inputShape.shape.length !== 1) {
        throw new Error(`Input shape should be a vector but received shape ${inputShape.shape}`);
    }
    if (newShape.shape.length !== 1) {
        throw new Error(`Target shape should be a vector but received shape ${newShape.shape}`);
    }
    const $inputShape = Array.from(backend.readSync(inputShape.dataId));
    const $inputIndices = backend.readSync(inputIndices.dataId);
    const targetShape = Array.from(backend.readSync(newShape.dataId));
    const [newIndices, indicesShape, outputShape] = sparseReshapeImplCPU($inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape, targetShape);
    return [
        backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices),
        backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape)),
    ];
}
const sparseReshapeConfig$1 = {
    kernelName: SparseReshape,
    backendName: 'webgl',
    kernelFunc: sparseReshape$1,
};

function sparseSegmentMean$1(args) {
    const { inputs, backend } = args;
    const { data, indices, segmentIds } = inputs;
    if (data.shape.length < 1) {
        throw new Error(`Data should be at least 1 dimensional but received scalar`);
    }
    if (indices.shape.length !== 1) {
        throw new Error(`Indices should be a vector but received shape
        ${indices.shape}`);
    }
    if (segmentIds.shape.length !== 1) {
        throw new Error(`Segment ids should be a vector but received shape
        ${segmentIds.shape}`);
    }
    const $data = backend.readSync(data.dataId);
    const $indices = backend.readSync(indices.dataId);
    const $segmentIds = backend.readSync(segmentIds.dataId);
    const [outputData, outputDataShape] = sparseSegmentReductionImplCPU($data, data.shape, data.dtype, $indices, $segmentIds, true);
    return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
}
const sparseSegmentMeanConfig$1 = {
    kernelName: SparseSegmentMean,
    backendName: 'webgl',
    kernelFunc: sparseSegmentMean$1,
};

function sparseSegmentSum$1(args) {
    const { inputs, backend } = args;
    const { data, indices, segmentIds } = inputs;
    if (data.shape.length < 1) {
        throw new Error(`Data should be at least 1 dimensional but received scalar`);
    }
    if (indices.shape.length !== 1) {
        throw new Error(`Indices should be a vector but received shape
        ${indices.shape}`);
    }
    if (segmentIds.shape.length !== 1) {
        throw new Error(`Segment ids should be a vector but received shape
        ${segmentIds.shape}`);
    }
    const $data = backend.readSync(data.dataId);
    const $indices = backend.readSync(indices.dataId);
    const $segmentIds = backend.readSync(segmentIds.dataId);
    const [outputData, outputDataShape] = sparseSegmentReductionImplCPU($data, data.shape, data.dtype, $indices, $segmentIds);
    return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
}
const sparseSegmentSumConfig$1 = {
    kernelName: SparseSegmentSum,
    backendName: 'webgl',
    kernelFunc: sparseSegmentSum$1,
};

function sparseToDense$1(args) {
    const { inputs, backend, attrs } = args;
    const { sparseIndices, sparseValues, defaultValue } = inputs;
    const { outputShape } = attrs;
    const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(sparseValues, sparseIndices, outputShape);
    const sumDupeIndices = false;
    if (sparseValues.dtype === 'string') {
        const indicesBuf = backend.bufferSync(sparseIndices);
        const updatesBuf = backend.bufferSync(sparseValues);
        const $defaultValue = decodeString(backend.readSync(defaultValue.dataId)[0]);
        const outBuf = scatterImplCPU(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
        return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values);
    }
    const program = new ScatterProgram(numUpdates, sliceRank, sparseIndices.shape.length, sparseValues.shape.length, strides, [outputSize, 1], sumDupeIndices);
    const res = backend.runWebGLProgram(program, [sparseValues, sparseIndices, defaultValue], sparseValues.dtype);
    const reshaped = reshape$1({ inputs: { x: res }, backend, attrs: { shape: outputShape } });
    backend.disposeIntermediateTensorInfo(res);
    return reshaped;
}
const sparseToDenseConfig$1 = {
    kernelName: SparseToDense,
    backendName: 'webgl',
    kernelFunc: sparseToDense$1
};

function splitV$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { numOrSizeSplits, axis } = attrs;
    const $axis = parseAxisParam(axis, x.shape)[0];
    const splitSizes = prepareSplitSize(x, numOrSizeSplits, $axis);
    const xRank = x.shape.length;
    const begin = new Array(xRank).fill(0);
    const size = x.shape.slice();
    return splitSizes.map(s => {
        const sliceSize = [...size];
        sliceSize[$axis] = s;
        const sliceT = slice({ inputs: { x }, backend, attrs: { begin, size: sliceSize } });
        begin[$axis] += s;
        return sliceT;
    });
}
const splitVConfig$1 = {
    kernelName: SplitV,
    backendName: 'webgl',
    kernelFunc: splitV$1
};

const SQRT = `return sqrt(x);`;
const sqrt = unaryKernelFunc({ opSnippet: SQRT, packedOpSnippet: SQRT, cpuKernelImpl: sqrtImplCPU });
const sqrtConfig = {
    kernelName: Sqrt,
    backendName: 'webgl',
    kernelFunc: sqrt
};

const SQUARE = `return x * x;`;
const square$1 = unaryKernelFunc({ opSnippet: SQUARE });
const squareConfig$1 = {
    kernelName: Square,
    backendName: 'webgl',
    kernelFunc: square$1,
};

const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);';
const squaredDifference = binaryKernelFunc({ opSnippet: SQUARED_DIFFERENCE, packedOpSnippet: SQUARED_DIFFERENCE });
const squaredDifferenceConfig = {
    kernelName: SquaredDifference,
    backendName: 'webgl',
    kernelFunc: squaredDifference,
};

function staticRegexReplace(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    if (x.dtype !== 'string') {
        throw new Error('Input must be of datatype string');
    }
    const $x = backend.readSync(x.dataId);
    const stringInput = fromUint8ToStringArray($x);
    const output = staticRegexReplaceImplCPU(stringInput, 'string', attrs);
    return backend.makeTensorInfo(x.shape, 'string', output);
}
const staticRegexReplaceConfig = {
    kernelName: StaticRegexReplace,
    backendName: 'webgl',
    kernelFunc: staticRegexReplace,
};

function step$1({ inputs, attrs, backend }) {
    const { x } = inputs;
    const opSnippet = CHECK_NAN_SNIPPET$1 + `
    return x > 0.0 ? 1.0 : float(${attrs.alpha});
  `;
    const program = new UnaryOpProgram(x.shape, opSnippet);
    return backend.runWebGLProgram(program, [x], x.dtype);
}
const stepConfig$1 = {
    kernelName: Step,
    backendName: 'webgl',
    kernelFunc: step$1,
};

class StridedSliceProgram {
    constructor(begin, strides, size) {
        this.variableNames = ['x'];
        this.outputShape = size;
        const rank = size.length;
        const inputDtype = getCoordsDataType(size.length);
        const dtype = getCoordsDataType(size.length);
        let newCoords = '';
        if (rank === 1) {
            newCoords = 'coords * strides + begin';
        }
        else {
            let outputAxis = 0;
            newCoords =
                size.map((_, i) => {
                    outputAxis++;
                    return size.length === 1 ?
                        `coords * strides[${i}] + begin[${i}]` :
                        `coords[${outputAxis - 1}] * strides[${i}] + begin[${i}]`;
                })
                    .join(',');
        }
        this.userCode = `
      ${inputDtype} begin = ${inputDtype}(${begin});
      ${inputDtype} strides = ${inputDtype}(${strides});

      void main() {
        ${dtype} coords = getOutputCoords();
        setOutput(getX(${newCoords}));
      }
    `;
    }
}
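// StridedSliceProgram reads input coordinate coords[i] * strides[i] + begin[i]
// for each output coordinate, so begin = [1], strides = [2] over a length-7
// vector selects indices 1, 3 and 5. The kernel function below only runs this
// shader in the general case; identity, simple and CPU-resident slices take
// faster paths.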
function stridedSlice$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask } = attrs;
    const { finalShapeSparse, finalShape, isIdentity, sliceDim0, isSimpleSlice, begin: $begin, end: $end, strides: $strides } = sliceInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
    let result;
    if (isIdentity) {
        result = reshape$1({ inputs: { x }, backend, attrs: { shape: finalShape } });
    }
    else if (sliceDim0 || isSimpleSlice) {
        assert$1(x.shape.length >= 1, () => `Input must have rank at least 1, got: ${x.shape.length}`);
        const size = computeOutShape$2($begin, $end, $strides);
        const sliced = slice({ inputs: { x }, backend, attrs: { begin: $begin, size } });
        result =
            reshape$1({ inputs: { x: sliced }, backend, attrs: { shape: finalShape } });
        backend.disposeIntermediateTensorInfo(sliced);
    }
    else {
        const shouldExecuteOnCPU = backend.shouldExecuteOnCPU([x]);
        if (shouldExecuteOnCPU) {
            const values = backend.readSync(x.dataId);
            const xBuf = buffer(x.shape, x.dtype, values);
            const resultValues = stridedSliceImplCPU(finalShapeSparse, xBuf, $strides, $begin);
            result = backend.makeTensorInfo(finalShape, x.dtype, resultValues.values);
        }
        else {
            const program = new StridedSliceProgram($begin, $strides, finalShapeSparse);
            result = backend.runWebGLProgram(program, [x], x.dtype);
        }
    }
    const resultReshaped = reshape$1({ inputs: { x: result }, backend, attrs: { shape: finalShape } });
    backend.disposeIntermediateTensorInfo(result);
    return resultReshaped;
}
const stridedSliceConfig$1 = {
    kernelName: StridedSlice,
    backendName: 'webgl',
    kernelFunc: stridedSlice$1
};

function stringNGrams$1(args) {
    const { inputs, backend, attrs } = args;
    const { separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences } = attrs;
    const { data, dataSplits } = inputs;
    const $data = backend.readSync(data.dataId);
    const $dataSplits = backend.readSync(dataSplits.dataId);
    const [nGrams, nGramsSplits] = stringNGramsImplCPU($data, $dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences);
    return [
        backend.makeTensorInfo([nGrams.length], 'string', nGrams),
        backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits),
    ];
}
const stringNGramsConfig$1 = {
    kernelName: StringNGrams,
    backendName: 'webgl',
    kernelFunc: stringNGrams$1,
};

function stringSplit$1(args) {
    const { inputs, backend, attrs } = args;
    const { skipEmpty } = attrs;
    const { input, delimiter } = inputs;
    if (input.dtype !== 'string') {
        throw new Error('Input must be of datatype string');
    }
    if (input.shape.length !== 1) {
        throw new Error(`Input must be a vector, got shape: ${input.shape}`);
    }
    if (delimiter.shape.length !== 0) {
        throw new Error(`Delimiter must be a scalar, got shape: ${delimiter.shape}`);
    }
    const $input = backend.readSync(input.dataId);
    const $delimiter = backend.readSync(delimiter.dataId)[0];
    const [indices, values, shape] = stringSplitImplCPU($input, $delimiter, skipEmpty);
    const outputSize = values.length;
    return [
        backend.makeTensorInfo([outputSize, 2], 'int32', indices),
        backend.makeTensorInfo([outputSize], 'string', values),
        backend.makeTensorInfo([2], 'int32', new Int32Array(shape))
    ];
}
const stringSplitConfig$1 = {
    kernelName: StringSplit,
    backendName: 'webgl',
    kernelFunc: stringSplit$1,
};

function stringToHashBucketFast$1(args) {
    const { inputs, backend, attrs } = args;
    const { numBuckets } = attrs;
    const { input } = inputs;
    if (input.dtype !== 'string') {
        throw new Error('Input must be of datatype string');
    }
    if (numBuckets <= 0) {
        throw new Error(`Number of buckets must be at least 1`);
    }
    const $input = backend.readSync(input.dataId);
    const output = stringToHashBucketFastImplCPU($input, numBuckets);
    return backend.makeTensorInfo(input.shape, 'int32', output);
}
const stringToHashBucketFastConfig$1 = {
    kernelName: StringToHashBucketFast,
    backendName: 'webgl',
    kernelFunc: stringToHashBucketFast$1,
};

const TAN = `return tan(x);`;
const tan$1 = unaryKernelFunc({ opSnippet: TAN });
const tanConfig$1 = {
    kernelName: Tan,
    backendName: 'webgl',
    kernelFunc: tan$1,
};

const TANH = `
  float e2x = exp(-2.0 * abs(x));
  return sign(x) * (1.0 - e2x) / (1.0 + e2x);
`;
const tanh$1 = unaryKernelFunc({ opSnippet: TANH });
const tanhConfig$1 = {
    kernelName: Tanh$1,
    backendName: 'webgl',
    kernelFunc: tanh$1,
};

function tensorScatterUpdate$1(args) {
    const { inputs, backend } = args;
    const { tensor, indices, updates } = inputs;
    const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, tensor.shape);
    const flattenShape = [outputSize / sliceSize, sliceSize];
    if (outputSize === 0) {
        return backend.makeTensorInfo(tensor.shape, indices.dtype);
    }
    const flattenIndices = reshape$1({ inputs: { x: indices }, backend, attrs: { shape: [numUpdates, sliceRank] } });
    const flattenX = reshape$1({ inputs: { x: updates }, backend, attrs: { shape: [numUpdates, sliceSize] } });
    const flattenTensor = reshape$1({ inputs: { x: tensor }, backend, attrs: { shape: flattenShape } });
    const program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape, false, true);
    const res = backend.runWebGLProgram(program, [flattenX, flattenIndices, flattenTensor], flattenTensor.dtype);
    const reshaped = reshape$1({ inputs: { x: res }, backend, attrs: { shape: tensor.shape } });
    backend.disposeIntermediateTensorInfo(flattenIndices);
    backend.disposeIntermediateTensorInfo(flattenX);
    backend.disposeIntermediateTensorInfo(flattenTensor);
    backend.disposeIntermediateTensorInfo(res);
    return reshaped;
}
const tensorScatterUpdateConfig$1 = {
    kernelName: TensorScatterUpdate,
    backendName: 'webgl',
    kernelFunc: tensorScatterUpdate$1
};

class TileProgram {
    constructor(aShape, reps) {
        this.variableNames = ['A'];
        const outputShape = new Array(aShape.length);
        for (let i = 0; i < outputShape.length; i++) {
            outputShape[i] = aShape[i] * reps[i];
        }
        this.outputShape = outputShape;
        this.rank = outputShape.length;
        const dtype = getCoordsDataType(this.rank);
        const sourceCoords = getSourceCoords(aShape);
        this.userCode = `
      void main() {
        ${dtype} resRC = getOutputCoords();
        setOutput(getA(${sourceCoords}));
      }
    `;
    }
}
function getSourceCoords(aShape) {
    const rank = aShape.length;
    if (rank > 5) {
        throw Error(`Tile for rank ${rank} is not yet supported`);
    }
    if (rank === 1) {
        return `imod(resRC, ${aShape[0]})`;
    }
    const currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u'];
    const sourceCoords = [];
    for (let i = 0; i < aShape.length; i++) {
        sourceCoords.push(`imod(${currentCoords[i]}, ${aShape[i]})`);
    }
    return sourceCoords.join();
}
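// Tiling never copies source data inside the shader; each output coordinate
// wraps back into the input via imod, e.g. tiling a length-3 vector with
// reps = [2] reads input indices imod(0..5, 3) = [0, 1, 2, 0, 1, 2]. String
// tensors and rank > 5 fall back to the CPU implementation below.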
function tile$2(params) {
    const { inputs, backend, attrs } = params;
    const { x } = inputs;
    const { reps } = attrs;

    if (x.dtype === 'string' || x.shape.length > 5) {
        const data = backend.readSync(x.dataId);
        const value = x.dtype === 'string' ?
            data.map(d => decodeString(d)) :
            data;
        const buf = buffer(x.shape, x.dtype, value);
        const outBuf = tileImplCPU(buf, reps);
        return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
    }
    const program = new TileProgram(x.shape, reps);
    const output = backend.runWebGLProgram(program, [x], x.dtype);
    return output;
}
const tileConfig$1 = {
    kernelName: Tile,
    backendName: 'webgl',
    kernelFunc: tile$2,
};
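// The SwapProgram / MergeProgram pair below implements the bitonic-sort based
// top-k used by this backend. SwapProgram performs one compare-and-swap pass
// over index pairs at distance `inc` (out-of-range reads are padded with
// negativeInf so row lengths round up to powers of two), and MergeProgram
// halves a list by keeping the larger of each pair of elements k apart.
// topK$1 further down alternates these passes until only k candidates per
// batch row remain.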
class SwapProgram {
    constructor(shape) {
        this.variableNames = ['x', 'indices'];
        this.customUniforms = [
            { name: 'n', type: 'int' },
            { name: 'firstPass', type: 'int' },
            { name: 'negativeInf', type: 'float' },
            { name: 'dir', type: 'int' },
            { name: 'inc', type: 'int' }
        ];
        this.outputShape = shape;
        this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int elemIdx = coords[1];

        bool isFirstInPair = imod(elemIdx, 2 * inc) < inc;
        int i = isFirstInPair ? elemIdx : elemIdx - inc;

        int i0 = firstPass == 1 ? i : int(getIndices(batch, i));
        int i1 = firstPass == 1 ? i + inc : int(getIndices(batch, i + inc));
        float x0 = i0 < n ? getX(batch, i0) : negativeInf;
        float x1 = i1 < n ? getX(batch, i1) : negativeInf;

        bool reverse = imod(elemIdx, 2 * dir) >= dir;
        bool isGreater = x0 > x1 || (x0 == x1 && i1 > i0);
        if (reverse == isGreater) {
          int iTemp = i0;
          i0 = i1;
          i1 = iTemp;
        }
        if (isFirstInPair) {
          setOutput(float(i0));
        } else {
          setOutput(float(i1));
        }
      }
    `;
    }
}
class MergeProgram {
    constructor(shape) {
        this.variableNames = ['x', 'indices'];
        this.customUniforms = [
            { name: 'n', type: 'int' },
            { name: 'firstPass', type: 'int' },
            { name: 'k', type: 'int' }
        ];
        this.outputShape = shape;
        this.userCode = `
      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int elemIdx = coords[1];

        int i = elemIdx < k ? elemIdx : (elemIdx * 2 - imod(elemIdx, k));
        int i0 = firstPass == 1 ? i : int(getIndices(batch, i));
        int i1 = firstPass == 1 ? i + k : int(getIndices(batch, i + k));

        float x0 = getX(batch, i0);
        float x1 = i1 < n ? getX(batch, i1) : x0;

        setOutput(x0 >= x1 ? float(i0) : float(i1));
      }
    `;
    }
}

function disposeIntermediateTensorInfoOrNull(backend, tensorInfo) {
    if (tensorInfo !== null) {
        backend.disposeIntermediateTensorInfo(tensorInfo);
    }
}
function roundUpToPow2(num) {
    let pow2 = 1;
    while (pow2 < num) {
        pow2 *= 2;
    }
    return pow2;
}
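// topK$1 below hands off to the CPU implementation when the tensor already
// lives on the CPU, when the last dimension is small, or when k is large
// (both thresholds are configurable through env()), since the GPU bitonic
// passes only pay off for wide rows with small k. On the GPU path the values
// themselves are never reordered: only an int32 index texture is shuffled,
// and the values are gathered once at the end.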
function topK$1(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { k, sorted } = attrs;

    const TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD = env().getNumber('TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD');

    const TOPK_K_CPU_HANDOFF_THRESHOLD = env().getNumber('TOPK_K_CPU_HANDOFF_THRESHOLD');
    const xShape = x.shape;
    const lastDim = xShape[xShape.length - 1];
    if (backend.shouldExecuteOnCPU([x]) ||
        lastDim < TOPK_LAST_DIM_CPU_HANDOFF_SIZE_THRESHOLD ||
        k > TOPK_K_CPU_HANDOFF_THRESHOLD) {
        const xVals = backend.readSync(x.dataId);
        const [allTopKVals, allTopKIndices] = topKImplCPU(xVals, xShape, x.dtype, k, sorted);
        return [
            backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values),
            backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)
        ];
    }
    if (k === 0) {
        xShape[xShape.length - 1] = 0;
        return [
            backend.makeTensorInfo(xShape, x.dtype, []),
            backend.makeTensorInfo(xShape, 'int32', [])
        ];
    }
    if (lastDim === 1) {
        return [
            x, fill$1({ attrs: { shape: xShape, dtype: 'int32', value: 0 }, backend })
        ];
    }

    const xtexData = backend.texData.get(x.dataId);
    const xIsPacked = xtexData !== null && xtexData.isPacked;
    const xUnPacked = xIsPacked ? backend.unpackTensor(x) : x;

    const xSize = sizeFromShape(xShape);
    const batch = xSize / lastDim;
    const x2D = reshape$1({ inputs: { x: xUnPacked }, attrs: { shape: [batch, lastDim] }, backend });
    if (xIsPacked) {
        disposeIntermediateTensorInfoOrNull(backend, xUnPacked);
    }
    const kPow2 = roundUpToPow2(k);
    const lastDimPow2 = roundUpToPow2(lastDim);

    let indices = null;

    const getInputs = () => indices === null ? [x2D, x2D] : [x2D, indices];
    const runSwap = (dir, inc, shape) => {
        const inputs = getInputs();
        const program = new SwapProgram(shape);
        const firstPass = indices === null ? 1 : 0;
        const customValues = [[lastDim], [firstPass], [Number.NEGATIVE_INFINITY], [dir], [inc]];
        const prevIndices = indices;
        indices = backend.runWebGLProgram(program, inputs, 'int32', customValues);
        disposeIntermediateTensorInfoOrNull(backend, prevIndices);
    };

    for (let len = 1; len < kPow2; len *= 2) {
        const dir = len * 2;
        for (let inc = len; inc >= 1; inc /= 2) {
            runSwap(dir, inc, [batch, lastDimPow2]);
        }
    }

    for (let indicesSize = lastDimPow2; indicesSize > kPow2; indicesSize /= 2) {
        const inputs = getInputs();
        const mergeProgram = new MergeProgram([batch, indicesSize / 2]);
        const firstPass = indices === null ? 1 : 0;
        const customValues = [[lastDim], [firstPass], [kPow2]];
        const prevIndices = indices;
        indices =
            backend.runWebGLProgram(mergeProgram, inputs, 'int32', customValues);
        disposeIntermediateTensorInfoOrNull(backend, prevIndices);

        const len = kPow2 / 2;
        const dir = len * 2;
        for (let inc = len; inc >= 1; inc /= 2) {
            runSwap(dir, inc, indices.shape);
        }
    }

    let prevIndices = indices;
    indices = slice({ inputs: { x: indices }, backend, attrs: { begin: 0, size: [batch, k] } });
    disposeIntermediateTensorInfoOrNull(backend, prevIndices);

    let values = gatherV2$1({ inputs: { x: x2D, indices }, backend, attrs: { axis: 1, batchDims: 1 } });
    disposeIntermediateTensorInfoOrNull(backend, x2D);

    const newShape = xShape.slice(0, -1);
    newShape.push(k);
    prevIndices = indices;
    indices = reshape$1({ inputs: { x: indices }, attrs: { shape: newShape }, backend });
    disposeIntermediateTensorInfoOrNull(backend, prevIndices);
    const prevValues = values;
    values = reshape$1({ inputs: { x: values }, attrs: { shape: newShape }, backend });
    disposeIntermediateTensorInfoOrNull(backend, prevValues);
    return [values, indices];
}
const topKConfig$1 = {
    kernelName: TopK,
    backendName: 'webgl',
    kernelFunc: topK$1
};

class TransformProgram {
    constructor(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape) {
        this.variableNames = ['Image', 'Transforms'];
        this.outputShape = outShape;
        const interpolationModeId = interpolation === 'nearest' ? 1 : 2;
        let fillModeId;
        switch (fillMode) {
            case 'constant':
                fillModeId = 1;
                break;
            case 'reflect':
                fillModeId = 2;
                break;
            case 'wrap':
                fillModeId = 3;
                break;
            case 'nearest':
                fillModeId = 4;
                break;
            default:
                fillModeId = 1;
                break;
        }
        this.userCode = `
      float mapCoord(float outCoord, float len) {
        float inCoord = outCoord;
        if(${fillModeId} == 2) {
          if (inCoord < 0.0) {
            if (len <= 1.0) {
              inCoord = 0.0;
            } else {
              float sz2 = 2.0 * len;
              if (inCoord < sz2) {
                inCoord = sz2 * float(int(float(-inCoord / sz2))) +
                  inCoord;
              }
              inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1.0;
            }
          } else if (inCoord > len - 1.0) {
            if (len <= 1.0) {
              inCoord = 0.0;
            } else {
              float sz2 = 2.0 * len;
              inCoord -= sz2 * float(int(float(inCoord / sz2)));
              if (inCoord >= len) {
                inCoord = sz2 - inCoord - 1.0;
              }
            }
          }
          return clamp(inCoord, 0.0, len - 1.0);
        } else if (${fillModeId} == 3) {
          if (inCoord < 0.0) {
            if (len <= 1.0) {
              inCoord = 0.0;
            } else {
              float sz = len - 1.0;
              inCoord += len * (float(int(float(-inCoord / sz))) + 1.0);
            }
          } else if (inCoord > len - 1.0) {
            if (len <= 1.0) {
              inCoord = 0.0;
            } else {
              float sz = len - 1.0;
              inCoord -= len * float(int(float(inCoord / sz)));
            }
          }
          return clamp(inCoord, 0.0, len - 1.0);
        } else if (${fillModeId} == 4) {
          return clamp(outCoord, 0.0, len - 1.0);
        } else {
          return outCoord;
        }
      }

      float readWithFillValue(int batch, int coordY, int coordX,
        int channel) {
        float outputValue;
        if (0 <= coordY && coordY < ${imageHeight} && 0 <= coordX && coordX < ${imageWidth}) {
          outputValue = getImage(batch, coordY, coordX, channel);
        } else {
          outputValue = float(${fillValue});
        }
        return outputValue;
      }

      void main() {
        ivec4 coords = getOutputCoords();
        float outputValue;
        int batch = coords[0];
        int x = coords[2];
        int y = coords[1];
        int channel = coords[3];
        float xf = float(x);
        float yf = float(y);
        float a1 = getTransforms(batch, 0);
        float a2 = getTransforms(batch, 1);
        float a3 = getTransforms(batch, 2);
        float b1 = getTransforms(batch, 3);
        float b2 = getTransforms(batch, 4);
        float b3 = getTransforms(batch, 5);
        float c1 = getTransforms(batch, 6);
        float c2 = getTransforms(batch, 7);
        float projection = c1 * xf + c2 * yf + 1.0;
        if (projection == 0.0) {
          outputValue = float(${fillValue});
        } else {
          float inX = (a1 * xf + a2 * yf + a3) / projection;
          float inY = (b1 * xf + b2 * yf + b3) / projection;
          float mapX = mapCoord(inX, float(${imageWidth}));
          float mapY = mapCoord(inY, float(${imageHeight}));

          if (${interpolationModeId} == 1) {
            int coordY = int(round(mapY));
            int coordX = int(round(mapX));
            outputValue = readWithFillValue(batch, coordY, coordX,
              channel);
          } else {
            float yFloor = floor(mapY);
            float xFloor = floor(mapX);
            float yCeil = yFloor + 1.0;
            float xCeil = xFloor + 1.0;
            float valueYFloor = (xCeil - mapX) *
              readWithFillValue(batch, int(yFloor), int(xFloor), channel) +
              (mapX - xFloor) *
              readWithFillValue(batch, int(yFloor), int(xCeil), channel);
            float valueYCeil = (xCeil - mapX) *
              readWithFillValue(batch, int(yCeil), int(xFloor), channel) +
              (mapX - xFloor) *
              readWithFillValue(batch, int(yCeil), int(xCeil), channel);
            outputValue = (yCeil - mapY) * valueYFloor +
              (mapY - yFloor) * valueYCeil;
          }
        }
        setOutput(outputValue);
      }
    `;
    }
}
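// Each row of the `transforms` input above holds the eight coefficients
// [a1, a2, a3, b1, b2, b3, c1, c2] of a projective mapping from output to
// input coordinates:
//   inX = (a1*x + a2*y + a3) / (c1*x + c2*y + 1)
//   inY = (b1*x + b2*y + b3) / (c1*x + c2*y + 1)
// so the identity transform is [1, 0, 0, 0, 1, 0, 0, 0].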
function transform$1(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { image, transforms } = inputs;
|
|
const { interpolation, fillMode, fillValue, outputShape } = attrs;
|
|
const [batch, imageHeight, imageWidth, numChannels] = image.shape;
|
|
const [outHeight, outWidth] = outputShape != null ? outputShape : [imageHeight, imageWidth];
|
|
const outShape = [batch, outHeight, outWidth,
|
|
numChannels];
|
|
const program = new TransformProgram(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape);
|
|
return backend.runWebGLProgram(program, [image, transforms], 'float32');
|
|
}
|
|
const transformConfig$1 = {
|
|
kernelName: Transform,
|
|
backendName: 'webgl',
|
|
kernelFunc: transform$1
|
|
};
|
|
|
|
|
|
function unique$2(args) {
|
|
const { inputs, attrs, backend } = args;
|
|
const { axis } = attrs;
|
|
const { x } = inputs;
|
|
assertNotComplex$1(x, 'unique');
|
|
|
|
console.warn('WARNING: ', 'UI might be locked temporarily as data is being downloaded');
|
|
const values = backend.readSync(x.dataId);
|
|
const { outputValues, outputShape, indices } = uniqueImplCPU(values, axis, x.shape, x.dtype);
|
|
return [
|
|
backend.makeTensorInfo(outputShape, x.dtype, outputValues),
|
|
backend.makeTensorInfo([indices.length], 'int32', indices),
|
|
];
|
|
}
|
|
const uniqueConfig$1 = {
|
|
kernelName: Unique,
|
|
backendName: 'webgl',
|
|
kernelFunc: unique$2,
|
|
};
|
|
|
|
|
|
function unpack$1(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { value } = inputs;
|
|
let { axis } = attrs;
|
|
if (axis < 0) {
|
|
axis += value.shape.length;
|
|
}
|
|
const x = value;
|
|
const xRank = x.shape.length;
|
|
const num = value.shape[axis];
|
|
const outShape = new Array(xRank - 1);
|
|
let outIndex = 0;
|
|
for (let i = 0; i < xRank; i++) {
|
|
if (i !== axis) {
|
|
outShape[outIndex++] = x.shape[i];
|
|
}
|
|
}
|
|
const toDispose = [];
|
|
const begin = new Array(xRank).fill(0);
|
|
const size = x.shape.slice();
|
|
size[axis] = 1;
|
|
const res = new Array(num);
|
|
for (let i = 0; i < res.length; i++) {
|
|
begin[axis] = i;
|
|
const sliced = slice({ inputs: { x }, backend, attrs: { begin, size } });
|
|
const reshaped = reshape$1({ inputs: { x: sliced }, backend, attrs: { shape: outShape } });
|
|
res[i] = reshaped;
|
|
toDispose.push(sliced);
|
|
}
|
|
toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
|
|
return res;
|
|
}
|
|
const unpackConfig$1 = {
|
|
kernelName: Unpack,
|
|
backendName: 'webgl',
|
|
kernelFunc: unpack$1
|
|
};
|
|
|
|
|
|
class SegmentOpProgram {
    constructor(segOpInfo, segOpType) {
        this.variableNames = ['x', 'segmentIds'];
        const windowSize = segOpInfo.windowSize;
        const batchSize = segOpInfo.batchSize;
        const inSize = segOpInfo.inSize;
        const numSegments = segOpInfo.numSegments;
        const outSize = numSegments * Math.ceil(inSize / windowSize);
        this.outputShape = [batchSize, outSize];
        const initializationValue = '0.0';
        const returnValue = `sumValue`;
        const windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
        const windowSizeVec4Remainder = windowSize % 4;
        const updateSnippet = `
        sumValue += dot(values, segFilter);
      `;
        let checkValueOutOfBounds = '';
        if (inSize % windowSize > 0) {
            checkValueOutOfBounds = `
        if (inIdx < 0 || inIdx >= ${inSize}) {
          return initializationValue;
        }
      `;
        }
        let checkSegmentIdOutOfBounds = '';
        if (inSize % windowSize > 0) {
            checkSegmentIdOutOfBounds = `
        if (inIdx < 0 || inIdx >= ${inSize}) {
          return -1.0;
        }
      `;
        }
        this.userCode = `
      const float initializationValue = ${initializationValue};

      float getValue(int batch, int inIdx) {
        ${checkValueOutOfBounds}
        return getX(batch, inIdx);
      }

      float getSegmentIdAtIndex(int inIdx) {
        ${checkSegmentIdOutOfBounds}
        return getSegmentIds(inIdx);
      }

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = int(floor(float(outIdx) / float(
          ${numSegments})) * float(${windowSize}));
        int currentSeg = int(mod(float(outIdx), float(${numSegments})));

        float sumValue = 0.0;

        for (int i = 0; i < ${windowSizeNearestVec4}; i += 4) {
          int inIdx = inOffset + i;
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            getValue(batch, inIdx + 3)
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0
          );

          ${updateSnippet}
        }

        int inIdx = inOffset + ${windowSizeNearestVec4};
        if (${windowSizeVec4Remainder === 1}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            initializationValue,
            initializationValue,
            initializationValue
          );

          int inIdxSeg = int(getSegmentIdAtIndex(inIdx));

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            0,
            0,
            0
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 2}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            initializationValue,
            initializationValue
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            0,
            0
          );

          ${updateSnippet}
        } else if (${windowSizeVec4Remainder === 3}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            initializationValue
          );

          vec4 segFilter = vec4(
            int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,
            int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,
            0
          );

          ${updateSnippet}
        }
        setOutput(${returnValue});
      }
    `;
    }
}

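// unsortedSegmentSum may need several GPU passes: each SegmentOpProgram run
// shrinks the inner dimension from inSize to numSegments * ceil(inSize /
// windowSize), and segOpCompute recurses (re-tiling the segment ids with
// range + tile) until the inner dimension equals numSegments.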
function unsortedSegmentSum$1(args) {
    const { inputs, backend, attrs } = args;
    const { x, segmentIds } = inputs;
    const { numSegments } = attrs;
    const xRank = x.shape.length;
    const toDispose = [];
    let axis = 0;
    const permutation = getAxesPermutation([axis], xRank);
    let permutedX = x;
    if (permutation != null) {
        permutedX = transpose({ inputs: { x }, backend, attrs: { perm: permutation } });
        toDispose.push(permutedX);
        axis = getInnerMostAxes(1, xRank)[0];
    }
    const outShape = computeOutShape(permutedX.shape, axis, numSegments);
    const inSize = sizeFromShape([permutedX.shape[axis]]);
    const a2D = reshape$1({ inputs: { x: permutedX }, backend, attrs: { shape: [-1, inSize] } });
    toDispose.push(a2D);
    const outputDType = sumOutType(x.dtype);
    const segOpCompute = (x, segOpType, segmentIds, dtype, numSegments) => {
        const batchSize = x.shape[0];
        const inSize = x.shape[1];
        const windowSize = segOpComputeOptimalWindowSize(inSize, numSegments);
        const segOpInfo = { windowSize, inSize, batchSize, numSegments };
        const program = new SegmentOpProgram(segOpInfo, segOpType);
        const output = backend.compileAndRun(program, [x, segmentIds], dtype);
        toDispose.push(output);
        if (output.shape[1] === numSegments) {
            return output;
        }
        const rangeInfo = range$2({
            backend,
            attrs: { start: 0, stop: numSegments, step: 1, dtype: 'float32' }
        });
        const tileInfo = tile$2({
            inputs: { x: rangeInfo },
            backend,
            attrs: { reps: [inSize / windowSize] }
        });
        toDispose.push(rangeInfo);
        toDispose.push(tileInfo);
        const result = segOpCompute(output, segOpType, tileInfo, dtype, numSegments);
        return result;
    };
    const segOpResult = segOpCompute(a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments);
    const reshaped = reshape$1({ inputs: { x: segOpResult }, backend, attrs: { shape: outShape } });
    let result = reshaped;
    if (permutation != null) {
        toDispose.push(reshaped);
        const perm = getUndoAxesPermutation(permutation);
        result = transpose({ inputs: { x: result }, backend, attrs: { perm } });
    }
    toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return result;
}
const unsortedSegmentSumConfig$1 = {
    kernelName: UnsortedSegmentSum,
    backendName: 'webgl',
    kernelFunc: unsortedSegmentSum$1
};

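// Registers every WebGL kernel defined above. registerKernel() keys each
// config on its (kernelName, backendName) pair so the engine can dispatch
// ops to this backend once it is selected.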
const kernelConfigs$1 = [
    _fusedMatMulConfig$1,
    absConfig,
    acosConfig$1,
    acoshConfig$1,
    addConfig,
    addNConfig$1,
    allConfig$1,
    anyConfig$1,
    argMaxConfig$1,
    argMinConfig$1,
    asinConfig$1,
    asinhConfig$1,
    atanConfig$1,
    atan2Config$1,
    atanhConfig$1,
    avgPoolConfig$1,
    avgPool3DConfig$1,
    avgPool3DGradConfig$2,
    avgPoolGradConfig$2,
    batchMatMulConfig$1,
    batchNormConfig$1,
    batchToSpaceNDConfig$1,
    bincountConfig$1,
    bitwiseAndConfig,
    broadcastArgsConfig$1,
    castConfig,
    ceilConfig,
    clipByValueConfig$1,
    complexConfig,
    complexAbsConfig$1,
    concatConfig$1,
    conv2DConfig$1,
    conv2DBackpropFilterConfig$1,
    conv2DBackpropInputConfig$1,
    conv3DConfig$1,
    conv3DBackpropFilterV2Config$1,
    conv3DBackpropInputConfig,
    cosConfig$1,
    coshConfig$1,
    cropAndResizeConfig$1,
    cumprodConfig$1,
    cumsumConfig$1,
    denseBincountConfig$1,
    depthToSpaceConfig$1,
    depthwiseConv2dNativeConfig$1,
    depthwiseConv2dNativeBackpropFilterConfig$1,
    depthwiseConv2dNativeBackpropInputConfig$1,
    diagConfig$1,
    dilation2DConfig$1,
    einsumConfig$1,
    eluConfig$1,
    eluGradConfig$2,
    equalConfig,
    erfConfig$1,
    expConfig,
    expandDimsConfig$1,
    expm1Config,
    fftConfig$1,
    fillConfig$1,
    flipLeftRightConfig$1,
    floorConfig,
    floorDivConfig,
    fromPixelsConfig,
    fusedConv2DConfig$1,
    fusedDepthwiseConv2DConfig$1,
    gatherNdConfig$1,
    gatherV2Config$1,
    greaterConfig,
    greaterEqualConfig,
    identityConfig,
    ifftConfig$1,
    imagConfig$1,
    isFiniteConfig$1,
    isInfConfig$1,
    isNaNConfig$1,
    leakyReluConfig$1,
    lessConfig,
    lessEqualConfig,
    linSpaceConfig$1,
    logConfig,
    log1pConfig$1,
    logicalAndConfig$1,
    logicalNotConfig$1,
    logicalOrConfig$1,
    LRNConfig$1,
    LRNGradConfig$1,
    maxConfig$1,
    maximumConfig,
    maxPoolConfig$1,
    maxPool3DConfig$1,
    maxPool3DGradConfig$2,
    maxPoolGradConfig$2,
    maxPoolWithArgmaxConfig$1,
    meanConfig$1,
    minConfig$1,
    minimumConfig,
    mirrorPadConfig$1,
    modConfig$1,
    multinomialConfig$1,
    multiplyConfig,
    negConfig,
    nonMaxSuppressionV3Config$1,
    nonMaxSuppressionV4Config$1,
    nonMaxSuppressionV5Config$1,
    notEqualConfig,
    oneHotConfig$1,
    onesLikeConfig$1,
    packConfig$1,
    padV2Config$1,
    powConfig$1,
    preluConfig$1,
    prodConfig,
    raggedGatherConfig$1,
    raggedRangeConfig$1,
    raggedTensorToTensorConfig$1,
    rangeConfig$1,
    realConfig,
    realDivConfig$1,
    reciprocalConfig$1,
    reluConfig$1,
    relu6Config$1,
    reshapeConfig$1,
    resizeBilinearConfig$1,
    resizeBilinearGradConfig$2,
    resizeNearestNeighborConfig$1,
    resizeNearestNeighborGradConfig$2,
    reverseConfig$1,
    rotateWithOffsetConfig$1,
    roundConfig$1,
    rsqrtConfig,
    scatterNdConfig$1,
    searchSortedConfig$1,
    selectConfig$1,
    seluConfig$1,
    sigmoidConfig,
    signConfig$1,
    sinConfig$1,
    sinhConfig$1,
    sliceConfig,
    softmaxConfig$1,
    softplusConfig$1,
    spaceToBatchNDConfig$1,
    sparseFillEmptyRowsConfig$1,
    sparseReshapeConfig$1,
    sparseSegmentMeanConfig$1,
    sparseSegmentSumConfig$1,
    sparseToDenseConfig$1,
    splitVConfig$1,
    sqrtConfig,
    squareConfig$1,
    squaredDifferenceConfig,
    staticRegexReplaceConfig,
    stepConfig$1,
    stridedSliceConfig$1,
    stringNGramsConfig$1,
    stringSplitConfig$1,
    stringToHashBucketFastConfig$1,
    subConfig,
    sumConfig$1,
    tanConfig$1,
    tanhConfig$1,
    tensorScatterUpdateConfig$1,
    tileConfig$1,
    topKConfig$1,
    transformConfig$1,
    transposeConfig,
    uniqueConfig$1,
    unpackConfig$1,
    unsortedSegmentSumConfig$1,
    zerosLikeConfig$1
];
for (const kernelConfig of kernelConfigs$1) {
    registerKernel(kernelConfig);
}

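// Plain-JS implementation of the KernelBackend contract. Tensor data lives
// in a DataStorage map of dataId -> { values, dtype, refCount } holding
// ordinary TypedArrays, which is why read() can simply defer to readSync().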
const whereImpl = whereImpl$2;
class MathBackendCPU extends KernelBackend {
    nextDataId() {
        return MathBackendCPU.nextDataId++;
    }
    constructor() {
        super();
        this.blockSize = 48;
        this.firstUse = true;
        this.data = new DataStorage(this, engine());
    }
    write(values, shape, dtype) {
        if (this.firstUse) {
            this.firstUse = false;
            if (env().get('IS_NODE')) {
                warn('\n============================\n' +
                    'Hi, looks like you are running TensorFlow.js in ' +
                    'Node.js. To speed things up dramatically, install our node ' +
                    'backend, visit https://github.com/tensorflow/tfjs-node for more details. ' +
                    '\n============================');
            }
        }
        const dataId = { id: this.nextDataId() };
        this.data.set(dataId, { values, dtype, refCount: 1 });
        return dataId;
    }
    makeTensorInfo(shape, dtype, values) {
        let outId;
        if (dtype === 'string' && values != null && values.length > 0 &&
            isString(values[0])) {
            const encodedValues = values.map(d => encodeString(d));
            outId = this.write(encodedValues, shape, dtype);
        }
        else {
            outId = this.write(values, shape, dtype);
        }
        return { dataId: outId, shape, dtype };
    }
    refCount(dataId) {
        if (this.data.has(dataId)) {
            const tensorData = this.data.get(dataId);
            return tensorData.refCount;
        }
        return 0;
    }
    incRef(dataId) {
        const tensorData = this.data.get(dataId);
        tensorData.refCount++;
    }
    decRef(dataId) {
        if (this.data.has(dataId)) {
            const tensorData = this.data.get(dataId);
            tensorData.refCount--;
        }
    }
    move(dataId, values, shape, dtype, refCount) {
        this.data.set(dataId, { values, dtype, refCount });
    }
    numDataIds() {
        return this.data.numDataIds();
    }
    async read(dataId) {
        return this.readSync(dataId);
    }
    readSync(dataId) {
        const { dtype, complexTensorInfos } = this.data.get(dataId);
        if (dtype === 'complex64') {
            const realValues = this.readSync(complexTensorInfos.real.dataId);
            const imagValues = this.readSync(complexTensorInfos.imag.dataId);
            return mergeRealAndImagArrays(realValues, imagValues);
        }
        return convertBackendValuesAndArrayBuffer(this.data.get(dataId).values, dtype);
    }
    bufferSync(t) {
        const data = this.readSync(t.dataId);
        if (t.dtype === 'string') {
            try {
                const strings = data.map(d => decodeString(d));
                return buffer(t.shape, t.dtype, strings);
            }
            catch (_a) {
                throw new Error('Failed to decode encoded string bytes into utf-8');
            }
        }
        return buffer(t.shape, t.dtype, data);
    }
    makeOutput(values, shape, dtype) {
        return engine().makeTensorFromTensorInfo(this.makeTensorInfo(shape, dtype, values), this);
    }
    disposeData(dataId, force = false) {
        if (this.data.has(dataId)) {
            this.data.get(dataId).refCount--;
            if (!force && this.data.get(dataId).refCount > 0) {
                return false;
            }
            const { complexTensorInfos } = this.data.get(dataId);
            if (complexTensorInfos != null) {
                this.disposeData(complexTensorInfos.real.dataId, true);
                this.disposeData(complexTensorInfos.imag.dataId, true);
            }
            this.data.delete(dataId);
        }
        return true;
    }
    disposeIntermediateTensorInfo(tensorInfo) {
        this.disposeData(tensorInfo.dataId);
    }
    async time(f) {
        const start = now();
        f();
        const kernelMs = now() - start;
        return { kernelMs };
    }
    memory() {
        return {
            unreliable: true,
            reasons: ['The reported memory is an upper bound. Due to automatic garbage ' +
                    'collection, the true allocated memory may be less.']
        };
    }
    where(condition) {
        assertNotComplex([condition], 'where');
        const condVals = this.readSync(condition.dataId);
        return whereImpl(condition.shape, condVals);
    }
    dispose() { }
    floatPrecision() {
        return 32;
    }
    epsilon() {
        return super.epsilon();
    }
}
MathBackendCPU.nextDataId = 0;

registerBackend('cpu', () => new MathBackendCPU(), 1 /* priority */);

const elu$1 = unaryKernelFunc$1(Elu$1, (xi) => xi >= 0 ? xi : (Math.exp(xi) - 1));
const eluConfig = {
    kernelName: Elu$1,
    backendName: 'cpu',
    kernelFunc: elu$1,
};

function leakyRelu(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { alpha } = attrs;
    assertNotComplex([x], 'leakyRelu');
    const xSize = sizeFromShape(x.shape);
    const xVals = backend.data.get(x.dataId).values;
    const outVals = getTypedArrayFromDType('float32', xSize);
    for (let i = 0; i < xVals.length; i++) {
        outVals[i] = xVals[i] < 0 ? alpha * xVals[i] : xVals[i];
    }
    return backend.makeTensorInfo(x.shape, 'float32', outVals);
}
const leakyReluConfig = {
    kernelName: LeakyRelu,
    backendName: 'cpu',
    kernelFunc: leakyRelu
};

const preluImpl = createSimpleBinaryKernelImpl((xValue, aValue) => xValue < 0 ? aValue * xValue : xValue);
function prelu(args) {
    const { inputs, backend } = args;
    const { x, alpha } = inputs;
    assertNotComplex([x, alpha], 'prelu');
    const aVals = backend.data.get(x.dataId).values;
    const bVals = backend.data.get(alpha.dataId).values;
    const [resultData, resultShape] = preluImpl(x.shape, alpha.shape, aVals, bVals, 'float32');
    return backend.makeTensorInfo(resultShape, 'float32', resultData);
}
const preluConfig = {
    kernelName: Prelu,
    backendName: 'cpu',
    kernelFunc: prelu,
};

const relu = unaryKernelFunc$1(Relu$1, (xi) => Math.max(0, xi));
const reluConfig = {
    kernelName: Relu$1,
    backendName: 'cpu',
    kernelFunc: relu,
};

const relu6 = unaryKernelFunc$1(Relu6$1, (xi) => Math.min(Math.max(0, xi), 6));
const relu6Config = {
    kernelName: Relu6$1,
    backendName: 'cpu',
    kernelFunc: relu6,
};

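// Maps a fused-kernel activation name onto the matching CPU kernel above.
// Only the activations listed here can be fused into matMul / conv2d on
// this backend; any other name throws.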
function applyActivation(backend, x, activation, preluActivationWeights, leakyreluAlpha) {
    if (activation === 'linear') {
        return identity$1({ inputs: { x }, backend });
    }
    else if (activation === 'relu') {
        return relu({ inputs: { x }, backend });
    }
    else if (activation === 'elu') {
        return elu$1({ inputs: { x }, backend });
    }
    else if (activation === 'relu6') {
        return relu6({ inputs: { x }, backend });
    }
    else if (activation === 'prelu') {
        return prelu({ inputs: { x, alpha: preluActivationWeights }, backend });
    }
    else if (activation === 'leakyrelu') {
        return leakyRelu({ inputs: { x }, backend, attrs: { alpha: leakyreluAlpha } });
    }
    else if (activation === 'sigmoid') {
        return sigmoid$1({ inputs: { x }, backend });
    }
    throw new Error(`Activation ${activation} has not been implemented for the CPU backend.`);
}

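// reshape on the CPU backend is metadata-only: it bumps the refcount of the
// underlying buffer and returns a TensorInfo that shares the same dataId
// under the new shape, so no values are copied.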
function reshape(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { shape } = attrs;
    const xSize = sizeFromShape(x.shape);
    const $shape = inferFromImplicitShape(shape, xSize);
    const $xSize = sizeFromShape($shape);
    assert$1(xSize === $xSize, () => `The new shape (${$shape}) has ${$xSize} elements and the old ` +
        `shape (${x.shape}) has ${xSize} elements. The new shape and old ` +
        `shape must have the same number of elements.`);
    backend.incRef(x.dataId);
    const xData = backend.data.get(x.dataId);
    if (xData.complexTensorInfos != null) {
        const real = xData.complexTensorInfos.real;
        const imag = xData.complexTensorInfos.imag;
        real.shape = $shape;
        imag.shape = $shape;
    }
    return { dataId: x.dataId, shape: $shape, dtype: x.dtype };
}
const reshapeConfig = {
    kernelName: Reshape$1,
    backendName: 'cpu',
    kernelFunc: reshape
};

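// batchMatMul flattens both operands to rank 3 ([batch, rows, cols]) and
// multiplies them with a cache-blocked triple loop: i/j/k are tiled in
// chunks of backend.blockSize (48) so the tiles of A and B being reused
// stay hot in cache. Batch broadcasting is handled by indexing with
// bi % batchDimA and bi % batchDimB.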
function batchMatMul(args) {
    const { inputs, backend, attrs } = args;
    const { a, b } = inputs;
    const { transposeA, transposeB } = attrs;
    assertNotComplex([a, b], 'matMul');
    const aRank = a.shape.length;
    const bRank = b.shape.length;
    const innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
    const innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
    const outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
    const outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
    const outerDimsA = a.shape.slice(0, -2);
    const outerDimsB = b.shape.slice(0, -2);
    const batchDimA = sizeFromShape(outerDimsA);
    const batchDimB = sizeFromShape(outerDimsB);
    const outShapeOuterDims = assertAndGetBroadcastShape(a.shape.slice(0, -2), b.shape.slice(0, -2));
    const outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
    assert$1(innerShapeA === innerShapeB, () => `Error in matMul: inner shapes (${innerShapeA}) and (` +
        `${innerShapeB}) of Tensors with shapes ${a.shape} and ` +
        `${b.shape} and transposeA=${transposeA}` +
        ` and transposeB=${transposeB} must match.`);
    const a3dShape = transposeA ? [batchDimA, innerShapeA, outerShapeA] :
        [batchDimA, outerShapeA, innerShapeA];
    const b3dShape = transposeB ? [batchDimB, outerShapeB, innerShapeB] :
        [batchDimB, innerShapeB, outerShapeB];
    const a3d = reshape({ inputs: { x: a }, backend, attrs: { shape: a3dShape } });
    const b3d = reshape({ inputs: { x: b }, backend, attrs: { shape: b3dShape } });
    const sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
    const leftDim = transposeA ? a3d.shape[2] : a3d.shape[1];
    const rightDim = transposeB ? b3d.shape[1] : b3d.shape[2];
    const batchDim = Math.max(batchDimA, batchDimB);
    const a3dValues = backend.data.get(a3d.dataId).values;
    const b3dValues = backend.data.get(b3d.dataId).values;
    const a3dStrides = computeStrides(a3d.shape);
    const b3dStrides = computeStrides(b3d.shape);
    const [aBatch, aOuterStep, aInnerStep] = transposeA ?
        [a3dStrides[0], 1, a3dStrides[1]] :
        [a3dStrides[0], a3dStrides[1], 1];
    const [bInnerStep, bOuterStep, bBatch] = transposeB ?
        [1, b3dStrides[1], b3dStrides[0]] :
        [b3dStrides[1], 1, b3dStrides[0]];
    const size = leftDim * rightDim;
    const result = buffer([batchDim, leftDim, rightDim], a3d.dtype);
    const resVals = result.values;
    const blockSize = backend.blockSize;
    for (let bi = 0; bi < batchDim; bi++) {
        const batchIndexA = bi % batchDimA;
        const batchIndexB = bi % batchDimB;
        for (let i0 = 0; i0 < leftDim; i0 += blockSize) {
            const iBlock = Math.min(i0 + blockSize, leftDim);
            for (let j0 = 0; j0 < rightDim; j0 += blockSize) {
                const jBlock = Math.min(j0 + blockSize, rightDim);
                for (let k0 = 0; k0 < sharedDim; k0 += blockSize) {
                    const kBlock = Math.min(k0 + blockSize, sharedDim);
                    for (let i = i0; i < iBlock; i++) {
                        for (let j = j0; j < jBlock; j++) {
                            let sum = 0.0;
                            for (let k = k0; k < kBlock; k++) {
                                const aVal =
                                    a3dValues[batchIndexA * aBatch + i * aOuterStep + k * aInnerStep];
                                const bVal =
                                    b3dValues[k * bInnerStep + j * bOuterStep + batchIndexB * bBatch];
                                sum += aVal * bVal;
                            }
                            resVals[bi * size + (i * rightDim + j)] += sum;
                        }
                    }
                }
            }
        }
    }
    backend.disposeIntermediateTensorInfo(a3d);
    backend.disposeIntermediateTensorInfo(b3d);
    return backend.makeTensorInfo(outShape, result.dtype, result.values);
}
const batchMatMulConfig = {
    kernelName: BatchMatMul,
    backendName: 'cpu',
    kernelFunc: batchMatMul,
};

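// _fusedMatMul chains matMul -> optional bias add -> optional activation,
// collecting each intermediate so it can be disposed once the next stage
// has consumed it.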
function _fusedMatMul(args) {
    const { inputs, backend, attrs } = args;
    const { a, b, bias, preluActivationWeights } = inputs;
    const { transposeA, transposeB, activation, leakyreluAlpha } = attrs;
    let current;
    let addRes;
    let activationRes;
    const intermediates = [];
    const matMulRes = batchMatMul({ inputs: { a, b }, attrs: { transposeA, transposeB }, backend });
    current = matMulRes;
    if (bias) {
        addRes = add({ inputs: { a: current, b: bias }, backend });
        intermediates.push(current);
        current = addRes;
    }
    if (activation) {
        activationRes = applyActivation(backend, current, activation, preluActivationWeights, leakyreluAlpha);
        intermediates.push(current);
        current = activationRes;
    }
    for (const i of intermediates) {
        backend.disposeIntermediateTensorInfo(i);
    }
    return current;
}
const _fusedMatMulConfig = {
    kernelName: _FusedMatMul,
    backendName: 'cpu',
    kernelFunc: _fusedMatMul,
};

const acos = unaryKernelFunc$1(Acos, (xi) => Math.acos(xi));
const acosConfig = {
    kernelName: Acos,
    backendName: 'cpu',
    kernelFunc: acos,
};

const acosh = unaryKernelFunc$1(Acosh, (xi) => Math.acosh(xi));
const acoshConfig = {
    kernelName: Acosh,
    backendName: 'cpu',
    kernelFunc: acosh,
};

function addN(args) {
    const { inputs, backend } = args;
    const tensors = inputs;
    assertNotComplex(inputs, 'addN');
    const vals = tensors.map(t => backend.data.get(t.dataId).values);
    const outBuf = buffer(tensors[0].shape, tensors[0].dtype);
    const outVals = outBuf.values;
    for (let i = 0; i < tensors.length; i++) {
        const currVals = vals[i];
        for (let j = 0; j < outVals.length; j++) {
            outVals[j] += currVals[j];
        }
    }
    return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
}
const addNConfig = {
    kernelName: AddN,
    backendName: 'cpu',
    kernelFunc: addN
};

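// The reduction kernels below (all, any, argMax, argMin) share one pattern:
// transpose the reduced axes to the innermost positions if needed, then
// walk the flat data in contiguous runs of reduceSize elements, producing
// one output value per run.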
function all(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex(x, 'all');
    const origAxes = parseAxisParam(axis, x.shape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, x.shape.length);
    let $x = x;
    if (permutedAxes != null) {
        $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        axes = getInnerMostAxes(axes.length, x.shape.length);
    }
    assertAxesAreInnerMostDims('all', axes, $x.shape.length);
    const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
    const reduceSize = sizeFromShape(reduceShape);
    const vals = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
    const aVals = backend.data.get($x.dataId).values;
    for (let i = 0; i < vals.length; ++i) {
        const offset = i * reduceSize;
        let all = aVals[offset];
        for (let j = 0; j < reduceSize; ++j) {
            const value = aVals[offset + j];
            all = all && value;
        }
        vals[i] = all;
    }
    if (permutedAxes != null) {
        backend.disposeIntermediateTensorInfo($x);
    }
    const result = backend.makeTensorInfo(outShape, $x.dtype, vals);
    if (keepDims) {
        const expandedShape = expandShapeToKeepDim(outShape, origAxes);
        const reshapedResult = reshape({ inputs: { x: result }, backend, attrs: { shape: expandedShape } });
        backend.disposeIntermediateTensorInfo(result);
        return reshapedResult;
    }
    return result;
}
const allConfig = {
    kernelName: All,
    backendName: 'cpu',
    kernelFunc: all
};

function any(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex(x, 'any');
    const origAxes = parseAxisParam(axis, x.shape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, x.shape.length);
    let $x = x;
    if (permutedAxes != null) {
        $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        axes = getInnerMostAxes(axes.length, x.shape.length);
    }
    assertAxesAreInnerMostDims('any', axes, $x.shape.length);
    const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
    const reduceSize = sizeFromShape(reduceShape);
    const vals = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
    const aVals = backend.data.get($x.dataId).values;
    for (let i = 0; i < vals.length; ++i) {
        const offset = i * reduceSize;
        let anyVal = aVals[offset];
        for (let j = 0; j < reduceSize; ++j) {
            const value = aVals[offset + j];
            anyVal = anyVal || value;
        }
        vals[i] = anyVal;
    }
    if (permutedAxes != null) {
        backend.disposeIntermediateTensorInfo($x);
    }
    const result = backend.makeTensorInfo(outShape, $x.dtype, vals);
    if (keepDims) {
        const expandedShape = expandShapeToKeepDim(outShape, origAxes);
        const reshapedResult = reshape({ inputs: { x: result }, backend, attrs: { shape: expandedShape } });
        backend.disposeIntermediateTensorInfo(result);
        return reshapedResult;
    }
    return result;
}
const anyConfig = {
    kernelName: Any,
    backendName: 'cpu',
    kernelFunc: any
};

function argMax(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis } = attrs;
    assertNotComplex(x, 'argMax');
    let axes = parseAxisParam(axis, x.shape);
    const permutedAxes = getAxesPermutation(axes, x.shape.length);
    let $x = x;
    const intermediateTensorInfos = [];
    if (permutedAxes != null) {
        $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        intermediateTensorInfos.push($x);
        axes = getInnerMostAxes(axes.length, $x.shape.length);
    }
    axes = [axes[0]];
    assertAxesAreInnerMostDims('argMax', axes, $x.shape.length);
    const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
    const outSize = sizeFromShape(outShape);
    const vals = makeZerosTypedArray(outSize, 'int32');
    const reduceSize = sizeFromShape(reduceShape);
    const aVals = backend.data.get($x.dataId).values;
    for (let i = 0; i < vals.length; ++i) {
        const offset = i * reduceSize;
        let max = aVals[offset];
        let maxIndex = 0;
        for (let j = 0; j < reduceSize; ++j) {
            const value = aVals[offset + j];
            if (value > max) {
                max = value;
                maxIndex = j;
            }
        }
        vals[i] = maxIndex;
    }
    intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return backend.makeTensorInfo(outShape, 'int32', vals);
}
const argMaxConfig = {
    kernelName: ArgMax,
    backendName: 'cpu',
    kernelFunc: argMax
};

function argMin(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis } = attrs;
    assertNotComplex(x, 'argMin');
    let axes = parseAxisParam(axis, x.shape);
    const permutedAxes = getAxesPermutation(axes, x.shape.length);
    let $x = x;
    const intermediateTensorInfos = [];
    if (permutedAxes != null) {
        $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        intermediateTensorInfos.push($x);
        axes = getInnerMostAxes(axes.length, $x.shape.length);
    }
    axes = [axes[0]];
    assertAxesAreInnerMostDims('argMin', axes, $x.shape.length);
    const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
    const outSize = sizeFromShape(outShape);
    const vals = makeZerosTypedArray(outSize, 'int32');
    const reduceSize = sizeFromShape(reduceShape);
    const aVals = backend.data.get($x.dataId).values;
    for (let i = 0; i < vals.length; ++i) {
        const offset = i * reduceSize;
        let min = aVals[offset];
        let minIndex = 0;
        for (let j = 0; j < reduceSize; ++j) {
            const value = aVals[offset + j];
            if (value < min) {
                min = value;
                minIndex = j;
            }
        }
        vals[i] = minIndex;
    }
    intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return backend.makeTensorInfo(outShape, 'int32', vals);
}
const argMinConfig = {
    kernelName: ArgMin,
    backendName: 'cpu',
    kernelFunc: argMin
};

const asin = unaryKernelFunc$1(Asin, (xi) => Math.asin(xi));
const asinConfig = {
    kernelName: Asin,
    backendName: 'cpu',
    kernelFunc: asin,
};

const asinh = unaryKernelFunc$1(Asinh, (xi) => Math.asinh(xi));
const asinhConfig = {
    kernelName: Asinh,
    backendName: 'cpu',
    kernelFunc: asinh,
};

const atan = unaryKernelFunc$1(Atan, (xi) => Math.atan(xi));
const atanConfig = {
    kernelName: Atan,
    backendName: 'cpu',
    kernelFunc: atan,
};

const atan2Impl = createSimpleBinaryKernelImpl((aValue, bValue) => Math.atan2(aValue, bValue));
const atan2 = binaryKernelFunc$1(Atan2, atan2Impl);
const atan2Config = {
    kernelName: Atan2,
    backendName: 'cpu',
    kernelFunc: atan2,
};

const atanh = unaryKernelFunc$1(Atanh, (xi) => Math.atanh(xi));
const atanhConfig = {
    kernelName: Atanh,
    backendName: 'cpu',
    kernelFunc: atanh,
};

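// Shared 2D pooling loop for max and avg pooling: for each output cell the
// (dilated) filter window is clipped against the input bounds, then the
// loop either tracks the running maximum or accumulates a sum and a count
// that is divided out at the end.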
function pool(xValues, xShape, dtype, strides, convInfo, poolType) {
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    const initialValue = (poolType === 'max' ? Number.NEGATIVE_INFINITY :
        Number.POSITIVE_INFINITY);
    const output = buffer(convInfo.outShape, dtype);
    const outputVals = output.values;
    const outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3];
    const outputRowStrides = convInfo.outShape[2] * convInfo.outShape[3];
    const outputColStrides = convInfo.outShape[3];
    for (let b = 0; b < convInfo.batchSize; ++b) {
        const outputBatchOffset = b * outputBatchStrides;
        const inputBatchOffset = b * strides[0];
        for (let d = 0; d < convInfo.inChannels; ++d) {
            for (let yR = 0; yR < convInfo.outHeight; ++yR) {
                const xRCorner = yR * strideHeight - padTop;
                const xRMin = Math.max(0, xRCorner);
                const xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
                const outputRowOffset = outputBatchOffset + yR * outputRowStrides;
                for (let yC = 0; yC < convInfo.outWidth; ++yC) {
                    const xCCorner = yC * strideWidth - padLeft;
                    const xCMin = Math.max(0, xCCorner);
                    const xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
                    let minMaxValue = initialValue;
                    let avgValue = 0;
                    let count = 0;
                    for (let xR = xRMin; xR < xRMax; xR += dilationHeight) {
                        const xROffset = inputBatchOffset + xR * strides[1];
                        for (let xC = xCMin; xC < xCMax; xC += dilationWidth) {
                            const xCOffset = xROffset + xC * strides[2];
                            const pixel = xValues[xCOffset + d];
                            if ((poolType === 'max' && pixel > minMaxValue)) {
                                minMaxValue = pixel;
                            }
                            else if (poolType === 'avg') {
                                avgValue += pixel;
                                count++;
                            }
                        }
                        if (isNaN(minMaxValue)) {
                            break;
                        }
                    }
                    const outputOffset = outputRowOffset + yC * outputColStrides + d;
                    outputVals[outputOffset] =
                        poolType === 'avg' ? avgValue / count : minMaxValue;
                }
            }
        }
    }
    return output;
}

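// maxPoolPositions records, for every max-pool output, the index of the
// winning input element (presumably consumed by the max-pool gradient and
// argmax kernels elsewhere in this bundle). With flattenPositions the index
// is into the flattened input, optionally including the batch dimension;
// otherwise it is the position within the filter window,
// wR * effectiveFilterWidth + wC.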
function maxPoolPositions(xValues, xShape, dtype, convInfo, flattenPositions = false, includeBatchInIndex = false) {
    const maxPositions = buffer(convInfo.outShape, 'int32');
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    const xBuf = buffer(xShape, dtype, xValues);
    for (let b = 0; b < convInfo.batchSize; ++b) {
        for (let d = 0; d < convInfo.inChannels; ++d) {
            for (let yR = 0; yR < convInfo.outHeight; ++yR) {
                const xRCorner = yR * strideHeight - padTop;
                let xRMin = xRCorner;
                while (xRMin < 0) {
                    xRMin += dilationHeight;
                }
                const xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
                for (let yC = 0; yC < convInfo.outWidth; ++yC) {
                    const xCCorner = yC * strideWidth - padLeft;
                    let xCMin = xCCorner;
                    while (xCMin < 0) {
                        xCMin += dilationWidth;
                    }
                    const xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
                    let maxValue = Number.NEGATIVE_INFINITY;
                    let maxPosition = -1;
                    for (let xR = xRMin; xR < xRMax; xR += dilationHeight) {
                        const wR = xR - xRCorner;
                        for (let xC = xCMin; xC < xCMax; xC += dilationWidth) {
                            const wC = xC - xCCorner;
                            const pixel = xBuf.get(b, xR, xC, d);
                            if (pixel > maxValue) {
                                maxValue = pixel;
                                if (flattenPositions) {
                                    maxPosition = includeBatchInIndex ?
                                        ((b * convInfo.inHeight + xR) * convInfo.inWidth + xC) *
                                                convInfo.inChannels +
                                            d :
                                        (xR * convInfo.inWidth + xC) * convInfo.inChannels + d;
                                }
                                else {
                                    maxPosition = wR * effectiveFilterWidth + wC;
                                }
                            }
                        }
                    }
                    maxPositions.set(maxPosition, b, yR, yC, d);
                }
            }
        }
    }
    return maxPositions;
}

function pool3d(xValues, xShape, dtype, strides, convInfo, poolType) {
    const strideDepth = convInfo.strideDepth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationDepth = convInfo.dilationDepth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterDepth = convInfo.effectiveFilterDepth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padFront = convInfo.padInfo.front;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    const initialValue = (poolType === 'max' ? Number.NEGATIVE_INFINITY :
        Number.POSITIVE_INFINITY);
    const output = buffer(convInfo.outShape, dtype);
    const outputVals = output.values;
    const outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] *
        convInfo.outShape[3] * convInfo.outShape[4];
    const outputDepthStrides = convInfo.outShape[2] * convInfo.outShape[3] * convInfo.outShape[4];
    const outputRowStrides = convInfo.outShape[3] * convInfo.outShape[4];
    const outputColStrides = convInfo.outShape[4];
    for (let batch = 0; batch < convInfo.batchSize; ++batch) {
        const outputBatchOffset = batch * outputBatchStrides;
        const inputBatchOffset = batch * strides[0];
        for (let channel = 0; channel < convInfo.inChannels; ++channel) {
            for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
                const xDepthCorner = yDepth * strideDepth - padFront;
                let xDepthMin = xDepthCorner;
                while (xDepthMin < 0) {
                    xDepthMin += dilationDepth;
                }
                const xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
                const outputDepthOffset = outputBatchOffset + yDepth * outputDepthStrides;
                for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {
                    const xRowCorner = yRow * strideHeight - padTop;
                    let xRowMin = xRowCorner;
                    while (xRowMin < 0) {
                        xRowMin += dilationHeight;
                    }
                    const xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
                    const outputRowOffset = outputDepthOffset + yRow * outputRowStrides;
                    for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {
                        const xColCorner = yCol * strideWidth - padLeft;
                        let xColMin = xColCorner;
                        while (xColMin < 0) {
                            xColMin += dilationWidth;
                        }
                        const xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);
                        const outputColOffset = outputRowOffset + yCol * outputColStrides;
                        let minMaxValue = initialValue;
                        let avgValue = 0;
                        let count = 0;
                        for (let xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) {
                            const xDepthOffset = inputBatchOffset + xDepth * strides[1];
                            for (let xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) {
                                const xRowOffset = xDepthOffset + xRow * strides[2];
                                for (let xCol = xColMin; xCol < xColMax; xCol += dilationWidth) {
                                    const xColOffset = xRowOffset + xCol * strides[3];
                                    const pixel = xValues[xColOffset + channel];
                                    if ((poolType === 'max' && pixel > minMaxValue)) {
                                        minMaxValue = pixel;
                                    }
                                    else if (poolType === 'avg') {
                                        avgValue += pixel;
                                        count++;
                                    }
                                    if (isNaN(minMaxValue)) {
                                        break;
                                    }
                                }
                                if (isNaN(minMaxValue)) {
                                    break;
                                }
                            }
                            if (isNaN(minMaxValue)) {
                                break;
                            }
                        }
                        const outputOffset = outputColOffset + channel;
                        outputVals[outputOffset] = poolType === 'avg' ?
                            avgValue / Math.max(count, 1) :
                            minMaxValue;
                    }
                }
            }
        }
    }
    return output;
}

function maxPool3dPositions(xBuf, convInfo) {
    const maxPositions = buffer(convInfo.outShape, 'int32');
    const strideDepth = convInfo.strideDepth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationDepth = convInfo.dilationDepth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterDepth = convInfo.effectiveFilterDepth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padFront = convInfo.padInfo.front;
    const padTop = convInfo.padInfo.top;
    const padLeft = convInfo.padInfo.left;
    for (let batch = 0; batch < convInfo.batchSize; ++batch) {
        for (let channel = 0; channel < convInfo.inChannels; ++channel) {
            for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) {
                const xDepthCorner = yDepth * strideDepth - padFront;
                let xDepthMin = xDepthCorner;
                while (xDepthMin < 0) {
                    xDepthMin += dilationDepth;
                }
                const xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner);
                for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) {
                    const xRowCorner = yRow * strideHeight - padTop;
                    let xRowMin = xRowCorner;
                    while (xRowMin < 0) {
                        xRowMin += dilationHeight;
                    }
                    const xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner);
                    for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) {
                        const xColCorner = yCol * strideWidth - padLeft;
                        let xColMin = xColCorner;
                        while (xColMin < 0) {
                            xColMin += dilationWidth;
                        }
                        const xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner);
                        let maxValue = Number.NEGATIVE_INFINITY;
                        let maxPosition = -1;
                        for (let xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) {
                            const wDepth = xDepth - xDepthCorner;
                            for (let xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) {
                                const wRow = xRow - xRowCorner;
                                for (let xCol = xColMin; xCol < xColMax; xCol += dilationWidth) {
                                    const wCol = xCol - xColCorner;
                                    const pixel = xBuf.get(batch, xDepth, xRow, xCol, channel);
                                    if (pixel >= maxValue) {
                                        maxValue = pixel;
                                        // Row-major flattening of (wDepth, wRow, wCol) over the
                                        // (effectiveFilterDepth, effectiveFilterHeight,
                                        // effectiveFilterWidth) window: the row index must be
                                        // scaled by the window width, not its height.
                                        maxPosition =
                                            wDepth * effectiveFilterHeight * effectiveFilterWidth +
                                            wRow * effectiveFilterWidth + wCol;
                                    }
                                }
                            }
                        }
                        maxPositions.set(maxPosition, batch, yDepth, yRow, yCol, channel);
                    }
                }
            }
        }
    }
    return maxPositions;
}

function avgPool(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    assertNotComplex(x, 'avgPool');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const dilations = 1;
    assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in avgPool: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
    let res;
    if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
        arraysEqual(convInfo.inShape, convInfo.outShape)) {
        res = identity$1({ inputs: { x }, backend });
    }
    else {
        const xValues = backend.data.get(x.dataId).values;
        const strides = computeStrides(x.shape);
        const buffer = pool(xValues, x.shape, x.dtype, strides, convInfo, 'avg');
        res = backend.makeTensorInfo(convInfo.outShape, x.dtype, buffer.values);
    }
    return res;
}
const avgPoolConfig = {
    kernelName: AvgPool,
    backendName: 'cpu',
    kernelFunc: avgPool
};

function avgPool3D(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs;
    assertNotComplex(x, 'avgPool3d');
    const convInfo = computePool3DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode, dataFormat);
    const xValues = backend.data.get(x.dataId).values;
    const outBuf = pool3d(xValues, x.shape, x.dtype, computeStrides(x.shape), convInfo, 'avg');
    return backend.makeTensorInfo(outBuf.shape, 'float32', outBuf.values);
}
const avgPool3DConfig = {
    kernelName: AvgPool3D,
    backendName: 'cpu',
    kernelFunc: avgPool3D
};

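// avgPool3DGrad routes each output gradient back to the cells that were
// averaged into it: for every input cell it finds the output cells whose
// dilated window covered it, i.e. those where (dyCorner + w) / stride is a
// non-negative integer within the output bounds, sums their gradients and
// scales by avgMultiplier = 1 / (filter volume).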
function avgPool3DGrad(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input } = inputs;
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    assertNotComplex([dy, input], 'avgPool3DGrad');
    const convInfo = computePool3DInfo(input.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    const strideDepth = convInfo.strideDepth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const filterDepth = convInfo.filterDepth;
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    const dilationDepth = convInfo.dilationDepth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterDepth = convInfo.effectiveFilterDepth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
    const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
    const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
    const dx = buffer(input.shape, 'float32');
    const avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
    const dyBuf = backend.bufferSync(dy);
    for (let batch = 0; batch < convInfo.batchSize; ++batch) {
        for (let channel = 0; channel < convInfo.inChannels; ++channel) {
            for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
                for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
                    for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
                        const dyDepthCorner = dxDepth - padFront;
                        const dyRowCorner = dxRow - padTop;
                        const dyColCorner = dxCol - padLeft;
                        let dotProd = 0;
                        for (let wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
                            const dyDepth = (dyDepthCorner + wDepth) / strideDepth;
                            if (dyDepth < 0 || dyDepth >= convInfo.outDepth ||
                                Math.floor(dyDepth) !== dyDepth) {
                                continue;
                            }
                            for (let wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
                                const dyRow = (dyRowCorner + wRow) / strideHeight;
                                if (dyRow < 0 || dyRow >= convInfo.outHeight ||
                                    Math.floor(dyRow) !== dyRow) {
                                    continue;
                                }
                                for (let wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
                                    const dyCol = (dyColCorner + wCol) / strideWidth;
                                    if (dyCol < 0 || dyCol >= convInfo.outWidth ||
                                        Math.floor(dyCol) !== dyCol) {
                                        continue;
                                    }
                                    const pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
                                    dotProd += pixel;
                                }
                            }
                        }
                        dx.set(dotProd * avgMultiplier, batch, dxDepth, dxRow, dxCol, channel);
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
const avgPool3DGradConfig$1 = {
    kernelName: AvgPool3DGrad,
    backendName: 'cpu',
    kernelFunc: avgPool3DGrad
};

function avgPoolGrad$1(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input } = inputs;
    const x = input;
    assertNotComplex([dy, input], 'avgPoolGrad');
    const { filterSize, strides, pad } = attrs;
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad);
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
    const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
    const dx = buffer(x.shape, 'float32');
    const avgMultiplier = 1 / (filterHeight * filterWidth);
    const dyData = backend.data.get(dy.dataId).values;
    const dyBuf = buffer(dy.shape, 'float32', dyData);
    for (let b = 0; b < convInfo.batchSize; ++b) {
        for (let d = 0; d < convInfo.inChannels; ++d) {
            for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {
                for (let dxC = 0; dxC < convInfo.inWidth; ++dxC) {
                    const dyRCorner = dxR - padTop;
                    const dyCCorner = dxC - padLeft;
                    let dotProd = 0;
                    for (let wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
                        const dyR = (dyRCorner + wR) / strideHeight;
                        if (dyR < 0 || dyR >= convInfo.outHeight ||
                            Math.floor(dyR) !== dyR) {
                            continue;
                        }
                        for (let wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
                            const dyC = (dyCCorner + wC) / strideWidth;
                            if (dyC < 0 || dyC >= convInfo.outWidth ||
                                Math.floor(dyC) !== dyC) {
                                continue;
                            }
                            const pixel = dyBuf.get(b, dyR, dyC, d);
                            dotProd += pixel;
                        }
                    }
                    dx.set(dotProd * avgMultiplier, b, dxR, dxC, d);
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
const avgPoolGradConfig$1 = {
    kernelName: AvgPoolGrad,
    backendName: 'cpu',
    kernelFunc: avgPoolGrad$1
};

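// The loop below computes, element-wise,
//   out = offset + (x - mean) * scale / sqrt(variance + varianceEpsilon).
// The mean/variance/scale/offset cursors wrap around their own lengths,
// which broadcasts the (smaller) statistics tensors over x without
// materialising expanded copies.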
function batchNorm(args) {
    const { inputs, backend, attrs } = args;
    const { x, scale, offset, mean, variance } = inputs;
    assert$1(mean.shape.length === variance.shape.length, () => 'Batch normalization gradient requires mean and variance to have ' +
        'equal ranks.');
    assert$1(offset == null || mean.shape.length === offset.shape.length, () => 'Batch normalization gradient requires mean and offset to have ' +
        'equal ranks.');
    assert$1(scale == null || mean.shape.length === scale.shape.length, () => 'Batch normalization gradient requires mean and scale to have ' +
        'equal ranks.');
    assertNotComplex([x, mean, variance, scale, offset], 'batchNorm');
    let { varianceEpsilon } = attrs;
    if (varianceEpsilon == null) {
        varianceEpsilon = 0.001;
    }
    const xVals = backend.data.get(x.dataId).values;
    const mVals = backend.data.get(mean.dataId).values;
    const varVals = backend.data.get(variance.dataId).values;
    const sVals = scale ? backend.data.get(scale.dataId).values :
        new Float32Array([1]);
    const offVals = offset ?
        backend.data.get(offset.dataId).values :
        new Float32Array([0]);
    const outVals = new Float32Array(xVals.length);
    const offValsLength = offVals.length;
    const sValsLength = sVals.length;
    const varValsLength = varVals.length;
    const mValsLength = mVals.length;
    let offi = 0;
    let mi = 0;
    let si = 0;
    let vi = 0;
    for (let i = 0; i < xVals.length; ++i) {
        outVals[i] = offVals[offi++] +
            (xVals[i] - mVals[mi++]) * sVals[si++] /
                Math.sqrt(varVals[vi++] + varianceEpsilon);
        if (offi >= offValsLength) {
            offi = 0;
        }
        if (mi >= mValsLength) {
            mi = 0;
        }
        if (si >= sValsLength) {
            si = 0;
        }
        if (vi >= varValsLength) {
            vi = 0;
        }
    }
    return backend.makeTensorInfo(x.shape, x.dtype, outVals);
}
const batchNormConfig = {
    kernelName: FusedBatchNorm,
    backendName: 'cpu',
    kernelFunc: batchNorm,
};

function batchToSpaceND(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { blockShape, crops } = attrs;
    assertNotComplex([x], 'batchToSpaceND');
    const prod = blockShape.reduce((a, b) => a * b);
    const reshaped = getReshaped(x.shape, blockShape, prod);
    const permuted = getPermuted(reshaped.length, blockShape.length);
    const reshapedPermuted = getReshapedPermuted(x.shape, blockShape, prod);
    const sliceBeginCoords = getSliceBeginCoords(crops, blockShape.length);
    const sliceSize = getSliceSize(reshapedPermuted, crops, blockShape.length);
    const xReshaped = reshape({ inputs: { x }, backend, attrs: { shape: reshaped } });
    const xTransposed = transpose$1({ inputs: { x: xReshaped }, backend, attrs: { perm: permuted } });
    const xTransposedReshaped = reshape({ inputs: { x: xTransposed }, backend, attrs: { shape: reshapedPermuted } });
    const result = slice$1({
        inputs: { x: xTransposedReshaped },
        backend,
        attrs: { begin: sliceBeginCoords, size: sliceSize }
    });
    backend.disposeIntermediateTensorInfo(xReshaped);
    backend.disposeIntermediateTensorInfo(xTransposed);
    backend.disposeIntermediateTensorInfo(xTransposedReshaped);
    return result;
}
const batchToSpaceNDConfig = {
    kernelName: BatchToSpaceND,
    backendName: 'cpu',
    kernelFunc: batchToSpaceND
};

function bincount(args) {
    const { inputs, backend, attrs } = args;
    const { x, weights } = inputs;
    const { size } = attrs;
    const xVals = backend.data.get(x.dataId).values;
    const weightsVals = backend.data.get(weights.dataId).values;
    const outVals = bincountImpl(xVals, weightsVals, weights.dtype, weights.shape, size);
    return backend.makeTensorInfo([size], weights.dtype, outVals);
}
const bincountConfig = {
    kernelName: Bincount,
    backendName: 'cpu',
    kernelFunc: bincount
};

function broadcastArgs(args) {
    const { inputs, backend } = args;
    const { s0, s1 } = inputs;
    const s0Vals = backend.data.get(s0.dataId).values;
    const s1Vals = backend.data.get(s1.dataId).values;
    const broadcastShape = assertAndGetBroadcastShape(Array.from(s0Vals), Array.from(s1Vals));
    return backend.makeTensorInfo([broadcastShape.length], 'int32', Int32Array.from(broadcastShape));
}
const broadcastArgsConfig = {
    kernelName: BroadcastArgs,
    backendName: 'cpu',
    kernelFunc: broadcastArgs
};

const clipByValue = unaryKernelFunc$1(ClipByValue, (xi, attrs) => {
    const clipAttrs = attrs;
    if (xi > clipAttrs.clipValueMax) {
        return clipAttrs.clipValueMax;
    }
    return xi < clipAttrs.clipValueMin ? clipAttrs.clipValueMin : xi;
});
const clipByValueConfig = {
    kernelName: ClipByValue,
    backendName: 'cpu',
    kernelFunc: clipByValue,
};

const complexAbs = (args) => {
    const { x } = args.inputs;
    const cpuBackend = args.backend;
    const resultValues = new Float32Array(sizeFromShape(x.shape));
    const complexVals = cpuBackend.data.get(x.dataId);
    const real = complexVals.complexTensorInfos.real;
    const imag = complexVals.complexTensorInfos.imag;
    const realVals = cpuBackend.data.get(real.dataId).values;
    const imagVals = cpuBackend.data.get(imag.dataId).values;
    for (let i = 0; i < realVals.length; i++) {
        const real = realVals[i];
        const imag = imagVals[i];
        resultValues[i] = Math.hypot(real, imag);
    }
    return cpuBackend.makeOutput(resultValues, x.shape, 'float32');
};
const complexAbsConfig = {
    kernelName: ComplexAbs,
    backendName: 'cpu',
    kernelFunc: complexAbs,
};

function imag(args) {
    const { inputs, backend } = args;
    const { input } = inputs;
    const imag = backend.data.get(input.dataId).complexTensorInfos.imag;
    const imagVal = backend.data.get(imag.dataId).values;
    return backend.makeTensorInfo(imag.shape, imag.dtype, imagVal);
}
const imagConfig = {
    kernelName: Imag,
    backendName: 'cpu',
    kernelFunc: imag
};

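// concat reshapes every non-empty input to 2D ([-1, innerSize]) so the copy
// itself is a single row-wise pass in concatImpl$1; complex64 inputs are
// split into real and imaginary parts, concatenated separately, and
// recombined.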
function concat(args) {
    const { inputs, backend, attrs } = args;
    const { axis } = attrs;
    const $axis = parseAxisParam(axis, inputs[0].shape)[0];
    const shapes = inputs.map(t => t.shape);
    assertParamsConsistent(shapes, $axis);
    let outShape = computeOutShape$1(inputs.map(t => t.shape), $axis);
    if (sizeFromShape(outShape) === 0) {
        return backend.makeTensorInfo(outShape, inputs[0].dtype, []);
    }
    const $inputs = inputs.filter(t => sizeFromShape(t.shape) > 0);
    if ($inputs.length === 1) {
        return identity$1({ inputs: { x: $inputs[0] }, backend });
    }
    if ($inputs[0].dtype === 'complex64') {
        const reals = $inputs.map((t) => real$1({ inputs: { input: t }, backend }));
        const imags = $inputs.map((t) => imag({ inputs: { input: t }, backend }));
        const realConcated = concat({ inputs: reals, backend, attrs: { axis: $axis } });
        const imagConcated = concat({ inputs: imags, backend, attrs: { axis: $axis } });
        const result = complex$1({ inputs: { real: realConcated, imag: imagConcated }, backend });
        reals.forEach(r => backend.disposeIntermediateTensorInfo(r));
        imags.forEach(i => backend.disposeIntermediateTensorInfo(i));
        backend.disposeIntermediateTensorInfo(realConcated);
        backend.disposeIntermediateTensorInfo(imagConcated);
        return result;
    }
    const inputs2D = $inputs.map(t => {
        const innerSize = sizeFromShape(t.shape.slice($axis));
        const shape = [-1, innerSize];
        return reshape({ inputs: { x: t }, backend, attrs: { shape } });
    });
    const inputsValShapes = inputs2D.map(t => {
        return { vals: backend.data.get(t.dataId).values, shape: t.shape };
    });
    outShape =
        computeOutShape$1(inputs2D.map(t => t.shape), 1 /* axis */);
    const simplyConcat = inputs2D[0].shape[0] === 1;
    const outVals = concatImpl$1(inputsValShapes, outShape, inputs[0].dtype, simplyConcat);
    const finalOutShape = computeOutShape$1($inputs.map(t => t.shape), $axis);
    const outInfo = backend.makeTensorInfo(finalOutShape, inputs[0].dtype, outVals);
    inputs2D.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return outInfo;
}
const concatConfig = {
    kernelName: Concat,
    backendName: 'cpu',
    kernelFunc: concat
};

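// Direct convolution without im2col: for each batch and output row the loop
// walks filter rows, skipping those that land outside the padded input,
// then accumulates xVal * wVals across input and output channels. The
// stride variables precomputed above let the same loop body serve both
// channelsLast and channelsFirst layouts.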
function conv2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode } = attrs;
    assertNotComplex([x, filter], 'conv2d');
    const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, dilations, pad, dimRoundingMode, false, $dataFormat);
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const padLeft = convInfo.padInfo.left;
    const padTop = convInfo.padInfo.top;
    const isChannelsLast = convInfo.dataFormat === 'channelsLast';
    const y = new TensorBuffer(convInfo.outShape, x.dtype);
    const xStrides = computeStrides(x.shape);
    const filterStrides = computeStrides(filter.shape);
    const xBatchStride = xStrides[0];
    const xRowStride = isChannelsLast ? xStrides[1] : xStrides[2];
    const xColStride = isChannelsLast ? xStrides[2] : 1;
    const xChannelStride = isChannelsLast ? 1 : xStrides[1];
    const yBatchStride = y.strides[0];
    const yRowStride = isChannelsLast ? y.strides[1] : y.strides[2];
    const yColStride = isChannelsLast ? y.strides[2] : 1;
    const yChannelStride = isChannelsLast ? 1 : y.strides[1];
    const xVals = backend.data.get(x.dataId).values;
    const wVals = backend.data.get(filter.dataId).values;
    const yVals = y.values;
    for (let b = 0; b < convInfo.batchSize; ++b) {
        const xOffset1 = b * xBatchStride;
        const yOffset1 = b * yBatchStride;
        for (let yR = 0; yR < convInfo.outHeight; ++yR) {
            const yOffset2 = yOffset1 + yR * yRowStride;
            const xRCorner = yR * convInfo.strideHeight - padTop;
            for (let wR = 0; wR < filterHeight; ++wR) {
                const xR = xRCorner + wR * dilationHeight;
                if (xR < 0 || xR >= convInfo.inHeight) {
                    continue;
                }
                const wOffset1 = wR * filterStrides[0];
                const xOffset2 = xOffset1 + xR * xRowStride;
                for (let yC = 0; yC < convInfo.outWidth; ++yC) {
                    const yOffset3 = yOffset2 + yC * yColStride;
                    const xCCorner = yC * convInfo.strideWidth - padLeft;
                    for (let wC = 0; wC < filterWidth; ++wC) {
                        const xC = xCCorner + wC * dilationWidth;
                        if (xC < 0 || xC >= convInfo.inWidth) {
                            continue;
                        }
                        const wOffset2 = wOffset1 + wC * filterStrides[1];
                        const xOffset3 = xOffset2 + xC * xColStride;
                        let wOffset3 = wOffset2;
                        for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
                            const xVal = xVals[xOffset3 + d1 * xChannelStride];
                            for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                                yVals[yOffset3 + d2 * yChannelStride] +=
                                    xVal * wVals[wOffset3 + d2];
                            }
                            wOffset3 += convInfo.outChannels;
                        }
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(y.shape, y.dtype, yVals);
}
const conv2DConfig = {
    kernelName: Conv2D,
    backendName: 'cpu',
    kernelFunc: conv2D
};

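// Gradient of conv2d w.r.t. the filter: for each filter tap (wR, wC) the
// valid output range is derived from the padding, then x and dy are
// correlated over batch and spatial positions to fill dW[wR, wC, d1, d2].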
function conv2DBackpropFilter(args) {
    const { inputs, backend, attrs } = args;
    const { x, dy } = inputs;
    const { strides, pad, dataFormat, dimRoundingMode, filterShape } = attrs;
    assertNotComplex([x, dy], 'conv2dBackpropFilter');
    const $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(x.shape, filterShape, strides, 1, pad, dimRoundingMode, false, $dataFormat);
    const { strideHeight, strideWidth, filterHeight, filterWidth } = convInfo;
    const isChannelsLast = convInfo.dataFormat === 'channelsLast';
    const dW = new TensorBuffer(convInfo.filterShape, 'float32');
    const leftPad = convInfo.padInfo.left;
    const topPad = convInfo.padInfo.top;
    const xVals = backend.data.get(x.dataId).values;
    const dyVals = backend.data.get(dy.dataId).values;
    const xBuf = new TensorBuffer(x.shape, x.dtype, xVals);
    const dyBuf = new TensorBuffer(dy.shape, dy.dtype, dyVals);
    for (let wR = 0; wR < filterHeight; ++wR) {
        const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
        const yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
        for (let wC = 0; wC < filterWidth; ++wC) {
            const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
            const yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
            for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
                for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                    let dotProd = 0;
                    for (let b = 0; b < convInfo.batchSize; ++b) {
                        for (let yR = yRMin; yR < yRMax; ++yR) {
                            const xR = wR + yR * strideHeight - topPad;
                            for (let yC = yCMin; yC < yCMax; ++yC) {
                                const xC = wC + yC * strideWidth - leftPad;
                                if (isChannelsLast) {
                                    dotProd += xBuf.get(b, xR, xC, d1) *
                                        dyBuf.get(b, yR, yC, d2);
                                }
                                else {
                                    dotProd += xBuf.get(b, d1, xR, xC) *
                                        dyBuf.get(b, d2, yR, yC);
                                }
                            }
                        }
                    }
                    dW.set(dotProd, wR, wC, d1, d2);
                }
            }
        }
    }
    return backend.makeTensorInfo(dW.shape, dW.dtype, dW.values);
}
const conv2DBackpropFilterConfig = {
    kernelName: Conv2DBackpropFilter,
    backendName: 'cpu',
    kernelFunc: conv2DBackpropFilter
};

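// Gradient of conv2d w.r.t. the input: a correlation of dy with the
// 180-degree-rotated filter (see the `filterHeight - 1 - wR` indexing),
// with padding adjusted so each input pixel accumulates exactly from the
// output pixels it contributed to in the forward pass.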
function conv2DBackpropInput(args) {
    const { inputs, backend, attrs } = args;
    const { dy, filter } = inputs;
    const { inputShape, strides, pad, dataFormat, dimRoundingMode } = attrs;
    assertNotComplex([dy, filter], 'conv2dBackpropInput');
    const filterStrides = computeStrides(filter.shape);
    const dyStrides = computeStrides(dy.shape);
    let $dataFormat = convertConv2DDataFormat(dataFormat);
    const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, 1, pad, dimRoundingMode, false, $dataFormat);
    const dx = new TensorBuffer(convInfo.inShape, 'float32');
    const dxValues = dx.values;
    const dyValues = backend.data.get(dy.dataId).values;
    const fltValues = backend.data.get(filter.dataId).values;
    const [fltS0, fltS1, fltS2] = filterStrides;
    const { batchSize, filterHeight, filterWidth, inChannels, inHeight, inWidth, outChannels, outHeight, outWidth, strideHeight, strideWidth } = convInfo;
    $dataFormat = convInfo.dataFormat;
    const topPad = filterHeight - 1 - convInfo.padInfo.top;
    const leftPad = filterWidth - 1 - convInfo.padInfo.left;
    const isChannelsLast = $dataFormat === 'channelsLast';
    const xBatchStride = dx.strides[0];
    const xRowStride = isChannelsLast ? dx.strides[1] : dx.strides[2];
    const xColStride = isChannelsLast ? dx.strides[2] : 1;
    const xChannelStride = isChannelsLast ? 1 : dx.strides[1];
    const yBatchStride = dyStrides[0];
    const yRowStride = isChannelsLast ? dyStrides[1] : dyStrides[2];
    const yColStride = isChannelsLast ? dyStrides[2] : 1;
    const yChannelStride = isChannelsLast ? 1 : dyStrides[1];
    for (let b = 0; b < batchSize; ++b) {
        for (let d1 = 0; d1 < inChannels; ++d1) {
            for (let xR = 0; xR < inHeight; ++xR) {
                const xRCorner = xR - topPad;
                const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
                const yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
                for (let xC = 0; xC < inWidth; ++xC) {
                    const xCCorner = xC - leftPad;
                    const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
                    const yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
                    let dotProd = 0;
                    for (let yR = xRMin; yR < yRMax; ++yR) {
                        const wR = yR * strideHeight - xRCorner;
                        for (let yC = xCMin; yC < yCMax; ++yC) {
                            const wC = yC * strideWidth - xCCorner;
                            const dyOffset = yBatchStride * b + yRowStride * yR + yColStride * yC;
                            const fltOffset = fltS0 * (filterHeight - 1 - wR) +
                                fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;
                            for (let d2 = 0; d2 < outChannels; ++d2) {
                                const pixel = dyValues[dyOffset + yChannelStride * d2];
                                const weight = fltValues[fltOffset + d2];
                                dotProd += pixel * weight;
                            }
                        }
                    }
                    const dxOffset = xBatchStride * b + xRowStride * xR +
                        xColStride * xC + xChannelStride * d1;
                    dxValues[dxOffset] = dotProd;
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
const conv2DBackpropInputConfig = {
    kernelName: Conv2DBackpropInput,
    backendName: 'cpu',
    kernelFunc: conv2DBackpropInput
};

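// Direct 3D convolution over NDHWC input: the conv2D loop extended with a
// depth dimension (yF/wF/xF), again skipping out-of-bounds dilated taps.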
function conv3D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter } = inputs;
    const { strides, pad, dilations } = attrs;
    assertNotComplex([x, filter], 'conv3d');
    const convInfo = computeConv3DInfo(x.shape, filter.shape, strides, dilations, pad);
    const { filterDepth, filterHeight, filterWidth, dilationDepth, dilationHeight, dilationWidth, padInfo } = convInfo;
    const padFront = padInfo.front;
    const padLeft = padInfo.left;
    const padTop = padInfo.top;
    const y = new TensorBuffer(convInfo.outShape, x.dtype);
    const xVals = backend.data.get(x.dataId).values;
    const wVals = backend.data.get(filter.dataId).values;
    const yVals = y.values;
    const xStrides = computeStrides(x.shape);
    const filterStrides = computeStrides(filter.shape);
    for (let b = 0; b < convInfo.batchSize; ++b) {
        const xOffset1 = b * xStrides[0];
        const yOffset1 = b * y.strides[0];
        for (let yF = 0; yF < convInfo.outDepth; ++yF) {
            const yOffset2 = yOffset1 + yF * y.strides[1];
            const xFCorner = yF * convInfo.strideDepth - padFront;
            for (let wF = 0; wF < filterDepth; ++wF) {
                const xF = xFCorner + wF * dilationDepth;
                if (xF < 0 || xF >= convInfo.inDepth) {
                    continue;
                }
                const wOffset1 = wF * filterStrides[0];
                const xOffset2 = xOffset1 + xF * xStrides[1];
                for (let yR = 0; yR < convInfo.outHeight; ++yR) {
                    const yOffset3 = yOffset2 + yR * y.strides[2];
                    const xRCorner = yR * convInfo.strideHeight - padTop;
                    for (let wR = 0; wR < filterHeight; ++wR) {
                        const xR = xRCorner + wR * dilationHeight;
                        if (xR < 0 || xR >= convInfo.inHeight) {
                            continue;
                        }
                        const wOffset2 = wOffset1 + wR * filterStrides[1];
                        const xOffset3 = xOffset2 + xR * xStrides[2];
                        for (let yC = 0; yC < convInfo.outWidth; ++yC) {
                            const yOffset4 = yOffset3 + yC * convInfo.outChannels;
                            const xCCorner = yC * convInfo.strideWidth - padLeft;
                            for (let wC = 0; wC < filterWidth; ++wC) {
                                const xC = xCCorner + wC * dilationWidth;
                                if (xC < 0 || xC >= convInfo.inWidth) {
                                    continue;
                                }
                                const wOffset3 = wOffset2 + wC * filterStrides[2];
                                const xOffset4 = xOffset3 + xC * convInfo.inChannels;
                                let wOffset4 = wOffset3;
                                for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
                                    const xVal = xVals[xOffset4 + d1];
                                    for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                                        yVals[yOffset4 + d2] += xVal * wVals[wOffset4 + d2];
                                    }
                                    wOffset4 += convInfo.outChannels;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(y.shape, y.dtype, y.values);
}
const conv3DConfig = {
    kernelName: Conv3D,
    backendName: 'cpu',
    kernelFunc: conv3D
};

function conv3DBackpropFilterV2(args) {
    const { inputs, backend, attrs } = args;
    const { x, dy } = inputs;
    const { strides, pad, filterShape } = attrs;
    assertNotComplex([x, dy], 'conv3dBackpropFilterV2');
    const xStrides = computeStrides(x.shape);
    const dyStrides = computeStrides(dy.shape);
    const convInfo = computeConv3DInfo(x.shape, filterShape, strides, 1, pad);
    const strideDepth = convInfo.strideDepth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const filterDepth = convInfo.filterDepth;
    const filterHeight = convInfo.filterHeight;
    const filterWidth = convInfo.filterWidth;
    const dw = new TensorBuffer(convInfo.filterShape, 'float32');
    const dwValues = dw.values;
    const [dwS0, dwS1, dwS2, dwS3] = dw.strides;
    const dyValues = backend.data.get(dy.dataId).values;
    const [dyS0, dyS1, dyS2, dyS3] = dyStrides;
    const xValues = backend.data.get(x.dataId).values;
    const [xS0, xS1, xS2, xS3] = xStrides;
    const frontPad = convInfo.padInfo.front;
    const leftPad = convInfo.padInfo.left;
    const topPad = convInfo.padInfo.top;
    for (let wF = 0; wF < filterDepth; ++wF) {
        const yFMin = Math.max(0, Math.ceil((frontPad - wF) / strideDepth));
        const yFMax = Math.min(convInfo.outDepth, (convInfo.inDepth + frontPad - wF) / strideDepth);
        const wOffset1 = wF * dwS0;
        for (let wR = 0; wR < filterHeight; ++wR) {
            const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
            const yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
            const wOffset2 = wR * dwS1 + wOffset1;
            for (let wC = 0; wC < filterWidth; ++wC) {
                const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
                const yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
                const wOffset3 = wC * dwS2 + wOffset2;
                for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
                    const wOffset4 = d1 * dwS3 + wOffset3;
                    for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                        let dotProd = 0;
                        for (let b = 0; b < convInfo.batchSize; ++b) {
                            const xOffset1 = b * xS0;
                            const yOffset1 = b * dyS0;
                            for (let yF = yFMin; yF < yFMax; ++yF) {
                                const xF = wF + yF * strideDepth - frontPad;
                                const xOffset2 = xF * xS1 + xOffset1;
                                const yOffset2 = yF * dyS1 + yOffset1;
                                for (let yR = yRMin; yR < yRMax; ++yR) {
                                    const xR = wR + yR * strideHeight - topPad;
                                    const xOffset3 = xR * xS2 + xOffset2;
                                    const yOffset3 = yR * dyS2 + yOffset2;
                                    for (let yC = yCMin; yC < yCMax; ++yC) {
                                        const xC = wC + yC * strideWidth - leftPad;
                                        const xOffset4 = xC * xS3 + xOffset3;
                                        const yOffset4 = yC * dyS3 + yOffset3;
                                        dotProd += xValues[xOffset4 + d1] * dyValues[yOffset4 + d2];
                                    }
                                }
                            }
                        }
                        dwValues[wOffset4 + d2] = dotProd;
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(dw.shape, dw.dtype, dw.values);
}
const conv3DBackpropFilterV2Config = {
    kernelName: Conv3DBackpropFilterV2,
    backendName: 'cpu',
    kernelFunc: conv3DBackpropFilterV2
};

function conv3DBackpropInputV2(args) {
    const { inputs, backend, attrs } = args;
    const { dy, filter } = inputs;
    const { pad, strides, inputShape } = attrs;
    assertNotComplex([dy], 'conv3dBackpropInputV2');
    const dyStrides = computeStrides(dy.shape);
    const filterStrides = computeStrides(filter.shape);
    const convInfo = computeConv3DInfo(inputShape, filter.shape, strides, 1, pad);
    const dx = new TensorBuffer(convInfo.inShape, 'float32');
    const dxValues = dx.values;
    const [dxS0, dxS1, dxS2, dxS3] = dx.strides;
    const dyValues = backend.data.get(dy.dataId).values;
    const [dyS0, dyS1, dyS2, dyS3] = dyStrides;
    const fltValues = backend.data.get(filter.dataId).values;
    const [fltS0, fltS1, fltS2, fltS3] = filterStrides;
    const { batchSize, filterDepth, filterHeight, filterWidth, inChannels, inDepth, inHeight, inWidth, outChannels, outDepth, outHeight, outWidth, strideDepth, strideHeight, strideWidth } = convInfo;
    const frontPad = filterDepth - 1 - convInfo.padInfo.front;
    const topPad = filterHeight - 1 - convInfo.padInfo.top;
    const leftPad = filterWidth - 1 - convInfo.padInfo.left;
    for (let b = 0; b < batchSize; ++b) {
        for (let d1 = 0; d1 < inChannels; ++d1) {
            for (let xF = 0; xF < inDepth; ++xF) {
                const xFCorner = xF - frontPad;
                const xFMin = Math.max(0, Math.ceil(xFCorner / strideDepth));
                const yFMax = Math.min(outDepth, (filterDepth + xFCorner) / strideDepth);
                for (let xR = 0; xR < inHeight; ++xR) {
                    const xRCorner = xR - topPad;
                    const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
                    const yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
                    for (let xC = 0; xC < inWidth; ++xC) {
                        const xCCorner = xC - leftPad;
                        const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
                        const yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
                        let dotProd = 0;
                        for (let yF = xFMin; yF < yFMax; ++yF) {
                            const wF = yF * strideDepth - xFCorner;
                            for (let yR = xRMin; yR < yRMax; ++yR) {
                                const wR = yR * strideHeight - xRCorner;
                                for (let yC = xCMin; yC < yCMax; ++yC) {
                                    const wC = yC * strideWidth - xCCorner;
                                    const dyOffset = dyS0 * b + dyS1 * yF + dyS2 * yR + dyS3 * yC;
                                    const fltOffset = fltS0 * (filterDepth - 1 - wF) +
                                        fltS1 * (filterHeight - 1 - wR) +
                                        fltS2 * (filterWidth - 1 - wC) + fltS3 * d1;
                                    for (let d2 = 0; d2 < outChannels; ++d2) {
                                        const pixel = dyValues[dyOffset + d2];
                                        const weight = fltValues[fltOffset + d2];
                                        dotProd += pixel * weight;
                                    }
                                }
                            }
                        }
                        dxValues[dxS0 * b + dxS1 * xF + dxS2 * xR + dxS3 * xC + d1] =
                            dotProd;
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
const conv3DBackpropInputV2Config = {
    kernelName: Conv3DBackpropInputV2,
    backendName: 'cpu',
    kernelFunc: conv3DBackpropInputV2
};

const cos = unaryKernelFunc$1(Cos, (xi) => Math.cos(xi));
const cosConfig = {
    kernelName: Cos,
    backendName: 'cpu',
    kernelFunc: cos,
};

const cosh = unaryKernelFunc$1(Cosh, (xi) => Math.cosh(xi));
const coshConfig = {
    kernelName: Cosh,
    backendName: 'cpu',
    kernelFunc: cosh,
};

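// CropAndResize: for each box (normalized [y1, x1, y2, x2] into the batch
// image selected by boxInd), a cropSize patch is sampled either bilinearly
// or by nearest neighbor; samples outside the image get extrapolationValue.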
function cropAndResize(args) {
    const { inputs, backend, attrs } = args;
    const { image, boxes, boxInd } = inputs;
    const { cropSize, method, extrapolationValue } = attrs;
    const [batch, imageHeight, imageWidth, numChannels] = image.shape;
    const numBoxes = boxes.shape[0];
    const [cropHeight, cropWidth] = cropSize;
    const output = buffer([numBoxes, cropHeight, cropWidth, numChannels], 'float32');
    const boxVals = backend.data.get(boxes.dataId).values;
    const boxIndVals = backend.data.get(boxInd.dataId).values;
    const imageVals = backend.data.get(image.dataId).values;
    const inStride = computeStrides(image.shape);
    const outStride = computeStrides(output.shape);
    for (let b = 0; b < numBoxes; b++) {
        const startInd = b * 4;
        const y1 = boxVals[startInd];
        const x1 = boxVals[startInd + 1];
        const y2 = boxVals[startInd + 2];
        const x2 = boxVals[startInd + 3];
        const bInd = boxIndVals[b];
        if (bInd >= batch) {
            continue;
        }
        const heightScale = (cropHeight > 1) ? (y2 - y1) * (imageHeight - 1) / (cropHeight - 1) : 0;
        const widthScale = (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : 0;
        for (let y = 0; y < cropHeight; y++) {
            const yInd = (cropHeight > 1) ?
                y1 * (imageHeight - 1) + y * (heightScale) :
                0.5 * (y1 + y2) * (imageHeight - 1);
            if (yInd < 0 || yInd > imageHeight - 1) {
                for (let x = 0; x < cropWidth; x++) {
                    for (let c = 0; c < numChannels; c++) {
                        const ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
                        output.values[ind] = extrapolationValue;
                    }
                }
                continue;
            }
            if (method === 'bilinear') {
                const topInd = Math.floor(yInd);
                const bottomInd = Math.ceil(yInd);
                const yLerp = yInd - topInd;
                for (let x = 0; x < cropWidth; x++) {
                    const xInd = (cropWidth > 1) ?
                        x1 * (imageWidth - 1) + x * widthScale :
                        0.5 * (x1 + x2) * (imageWidth - 1);
                    if (xInd < 0 || xInd > imageWidth - 1) {
                        for (let c = 0; c < numChannels; c++) {
                            const ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
                            output.values[ind] = extrapolationValue;
                        }
                        continue;
                    }
                    const leftInd = Math.floor(xInd);
                    const rightInd = Math.ceil(xInd);
                    const xLerp = xInd - leftInd;
                    for (let c = 0; c < numChannels; c++) {
                        let ind = c + leftInd * inStride[2] + topInd * inStride[1] +
                            bInd * inStride[0];
                        const topLeft = imageVals[ind];
                        ind = c + rightInd * inStride[2] + topInd * inStride[1] +
                            bInd * inStride[0];
                        const topRight = imageVals[ind];
                        ind = c + leftInd * inStride[2] + bottomInd * inStride[1] +
                            bInd * inStride[0];
                        const bottomLeft = imageVals[ind];
                        ind = c + rightInd * inStride[2] + bottomInd * inStride[1] +
                            bInd * inStride[0];
                        const bottomRight = imageVals[ind];
                        const top = topLeft + (topRight - topLeft) * xLerp;
                        const bottom = bottomLeft + (bottomRight - bottomLeft) * xLerp;
                        ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
                        output.values[ind] = top + ((bottom - top) * yLerp);
                    }
                }
            }
            else {
                for (let x = 0; x < cropWidth; ++x) {
                    const xInd = (cropWidth > 1) ?
                        x1 * (imageWidth - 1) + x * widthScale :
                        0.5 * (x1 + x2) * (imageWidth - 1);
                    if (xInd < 0 || xInd > imageWidth - 1) {
                        for (let c = 0; c < numChannels; c++) {
                            const ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
                            output.values[ind] = extrapolationValue;
                        }
                        continue;
                    }
                    const closestX = Math.round(xInd);
                    const closestY = Math.round(yInd);
                    for (let c = 0; c < numChannels; c++) {
                        const inInd = c + closestX * inStride[2] + closestY * inStride[1] +
                            bInd * inStride[0];
                        const outInd = c + x * outStride[2] + y * outStride[1] + b * outStride[0];
                        output.values[outInd] = imageVals[inInd];
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(output.shape, output.dtype, output.values);
}
const cropAndResizeConfig = {
    kernelName: CropAndResize,
    backendName: 'cpu',
    kernelFunc: cropAndResize
};

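// cumprod (and cumsum below) share one strategy: transpose the scan axis to
// the innermost position if needed, scan each contiguous row (optionally
// exclusive and/or reversed via `indexAdjuster`), then transpose back.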
function cumprod(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, exclusive, reverse } = attrs;
    assertNotComplex(x, 'cumprod');
    const permutation = getAxesPermutation([axis], x.shape.length);
    let $x = x;
    if (permutation != null) {
        $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutation } });
    }
    const permutedAxis = getInnerMostAxes(1, x.shape.length)[0];
    if (permutedAxis !== $x.shape.length - 1) {
        throw new Error(`backend.cumprod in CPU expects an inner-most ` +
            `axis=${$x.shape.length - 1} but got axis=${permutedAxis}`);
    }
    const resultDtype = upcastType($x.dtype, 'int32');
    const vals = makeOnesTypedArray(sizeFromShape($x.shape), resultDtype);
    const aVals = backend.data.get($x.dataId).values;
    const finalDim = $x.shape[$x.shape.length - 1];
    const indexAdjuster = reverse ?
        (i, j) => i + finalDim - j - 1 :
        (i, j) => i + j;
    for (let i = 0; i < aVals.length; i += finalDim) {
        for (let j = 0; j < finalDim; j++) {
            const idx = indexAdjuster(i, j);
            if (j === 0) {
                vals[idx] = exclusive ? 1 : aVals[idx];
            }
            else {
                const prevIdx = indexAdjuster(i, j - 1);
                vals[idx] = exclusive ? aVals[prevIdx] * vals[prevIdx] :
                    aVals[idx] * vals[prevIdx];
            }
        }
    }
    const result = backend.makeTensorInfo($x.shape, resultDtype, vals);
    if (permutation != null) {
        const reversePermutation = getUndoAxesPermutation(permutation);
        const reverseTransposedResult = transpose$1({ inputs: { x: result }, backend, attrs: { perm: reversePermutation } });
        backend.disposeIntermediateTensorInfo(result);
        backend.disposeIntermediateTensorInfo($x);
        return reverseTransposedResult;
    }
    return result;
}
const cumprodConfig = {
    kernelName: Cumprod,
    backendName: 'cpu',
    kernelFunc: cumprod
};

function cumsum(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, exclusive, reverse } = attrs;
    assertNotComplex(x, 'cumsum');
    const permutation = getAxesPermutation([axis], x.shape.length);
    let $x = x;
    if (permutation != null) {
        $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutation } });
    }
    const permutedAxis = getInnerMostAxes(1, x.shape.length)[0];
    if (permutedAxis !== $x.shape.length - 1) {
        throw new Error(`backend.cumsum in CPU expects an inner-most ` +
            `axis=${$x.shape.length - 1} but got axis=${permutedAxis}`);
    }
    const resultDtype = upcastType($x.dtype, 'int32');
    const vals = makeZerosTypedArray(sizeFromShape($x.shape), resultDtype);
    const aVals = backend.data.get($x.dataId).values;
    const finalDim = $x.shape[$x.shape.length - 1];
    const indexAdjuster = reverse ?
        (i, j) => i + finalDim - j - 1 :
        (i, j) => i + j;
    for (let i = 0; i < aVals.length; i += finalDim) {
        for (let j = 0; j < finalDim; j++) {
            const idx = indexAdjuster(i, j);
            if (j === 0) {
                vals[idx] = exclusive ? 0 : aVals[idx];
            }
            else {
                const prevIdx = indexAdjuster(i, j - 1);
                vals[idx] = exclusive ? aVals[prevIdx] + vals[prevIdx] :
                    aVals[idx] + vals[prevIdx];
            }
        }
    }
    const result = backend.makeTensorInfo($x.shape, resultDtype, vals);
    if (permutation != null) {
        const reversePermutation = getUndoAxesPermutation(permutation);
        const reverseTransposedResult = transpose$1({ inputs: { x: result }, backend, attrs: { perm: reversePermutation } });
        backend.disposeIntermediateTensorInfo(result);
        backend.disposeIntermediateTensorInfo($x);
        return reverseTransposedResult;
    }
    return result;
}
const cumsumConfig = {
    kernelName: Cumsum,
    backendName: 'cpu',
    kernelFunc: cumsum
};

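// denseBincount: rank-1 input uses the flat bincount implementation; rank-2
// input is binned row by row (optionally as 0/1 indicators when binaryOutput
// is set). Higher ranks are rejected.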
function denseBincount(args) {
    const { inputs, backend, attrs } = args;
    const { x, weights } = inputs;
    const { size, binaryOutput } = attrs;
    if (x.shape.length === 1) {
        const xVals = backend.data.get(x.dataId).values;
        const weightsVals = backend.data.get(weights.dataId).values;
        const outVals = bincountImpl(xVals, weightsVals, weights.dtype, weights.shape, size);
        return backend.makeTensorInfo([size], weights.dtype, outVals);
    }
    else if (x.shape.length === 2) {
        const xBuf = backend.bufferSync(x);
        const weightsBuf = backend.bufferSync(weights);
        const outBuf = bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput);
        return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values);
    }
    throw new Error(`Error in denseBincount: input must be at most rank 2, but got rank ` +
        `${x.shape.length}.`);
}
const denseBincountConfig = {
    kernelName: DenseBincount,
    backendName: 'cpu',
    kernelFunc: denseBincount
};

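// depthToSpace (NHWC only): moves blockSize*blockSize depth blocks into
// spatial positions, producing [batch, H*blockSize, W*blockSize,
// C/blockSize^2].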
function depthToSpace(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { blockSize, dataFormat } = attrs;
    assert$1(dataFormat === 'NHWC', () => `Only NHWC dataFormat supported on CPU for depthToSpace. Got ${dataFormat}`);
    const batchSize = x.shape[0];
    const inputHeight = x.shape[1];
    const inputWidth = x.shape[2];
    const inputDepth = x.shape[3];
    const outputHeight = inputHeight * blockSize;
    const outputWidth = inputWidth * blockSize;
    const outputDepth = inputDepth / (blockSize * blockSize);
    const xValues = backend.data.get(x.dataId).values;
    const result = new Float32Array(batchSize * outputHeight * outputWidth * outputDepth);
    let outputIdx = 0;
    for (let b = 0; b < batchSize; ++b) {
        for (let h = 0; h < outputHeight; ++h) {
            const inH = Math.floor(h / blockSize);
            const offsetH = (h % blockSize);
            for (let w = 0; w < outputWidth; ++w) {
                const inW = Math.floor(w / blockSize);
                const offsetW = (w % blockSize);
                const offsetD = (offsetH * blockSize + offsetW) * outputDepth;
                for (let d = 0; d < outputDepth; ++d) {
                    const inD = d + offsetD;
                    const inputIdx = inD + inputDepth * (inW + inputWidth * (inH + inputHeight * b));
                    result[outputIdx++] = xValues[inputIdx];
                }
            }
        }
    }
    return backend.makeTensorInfo([batchSize, outputHeight, outputWidth, outputDepth], x.dtype, result);
}
const depthToSpaceConfig = {
    kernelName: DepthToSpace,
    backendName: 'cpu',
    kernelFunc: depthToSpace
};

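// Depthwise convolution: like conv2D, but each input channel d1 is convolved
// with its own chMul (= outChannels / inChannels) filters, writing output
// channels d1*chMul .. d1*chMul + chMul - 1.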
function depthwiseConv2dNative(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter } = inputs;
    const { strides, pad, dilations, dimRoundingMode } = attrs;
    assertNotComplex([x, filter], 'depthwiseConv2DNative');
    const xStrides = computeStrides(x.shape);
    const filterStrides = computeStrides(filter.shape);
    let $dilations = dilations;
    if ($dilations == null) {
        $dilations = [1, 1];
    }
    assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in depthwiseConv2d: Either strides or dilations must be ' +
        `1. Got strides ${strides} and dilations '${$dilations}'`);
    const convInfo = computeConv2DInfo(x.shape, filter.shape, strides, $dilations, pad, dimRoundingMode, true);
    const { filterHeight, filterWidth, dilationHeight, dilationWidth, padInfo } = convInfo;
    const padLeft = padInfo.left;
    const padTop = padInfo.top;
    const chMul = convInfo.outChannels / convInfo.inChannels;
    const y = new TensorBuffer(convInfo.outShape, x.dtype);
    const xVals = backend.data.get(x.dataId).values;
    const wVals = backend.data.get(filter.dataId).values;
    const yVals = y.values;
    for (let b = 0; b < convInfo.batchSize; ++b) {
        const xOffset1 = b * xStrides[0];
        const yOffset1 = b * y.strides[0];
        for (let yR = 0; yR < convInfo.outHeight; ++yR) {
            const yOffset2 = yOffset1 + yR * y.strides[1];
            const xRCorner = yR * convInfo.strideHeight - padTop;
            for (let wR = 0; wR < filterHeight; ++wR) {
                const xR = xRCorner + wR * dilationHeight;
                if (xR < 0 || xR >= convInfo.inHeight) {
                    continue;
                }
                const wOffset1 = wR * filterStrides[0];
                const xOffset2 = xOffset1 + xR * xStrides[1];
                for (let yC = 0; yC < convInfo.outWidth; ++yC) {
                    const yOffset3 = yOffset2 + yC * y.strides[2];
                    const xCCorner = yC * convInfo.strideWidth - padLeft;
                    for (let wC = 0; wC < filterWidth; ++wC) {
                        const xC = xCCorner + wC * dilationWidth;
                        if (xC < 0 || xC >= convInfo.inWidth) {
                            continue;
                        }
                        const wOffset2 = wOffset1 + wC * filterStrides[1];
                        const xOffset3 = xOffset2 + xC * convInfo.inChannels;
                        let yOffset4 = yOffset3;
                        let wOffset3 = wOffset2;
                        for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
                            const xVal = xVals[xOffset3 + d1];
                            for (let q = 0; q < chMul; ++q) {
                                yVals[yOffset4 + q] += xVal * wVals[wOffset3 + q];
                            }
                            yOffset4 += chMul;
                            wOffset3 += chMul;
                        }
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(y.shape, y.dtype, y.values);
}
const depthwiseConv2dNativeConfig = {
    kernelName: DepthwiseConv2dNative,
    backendName: 'cpu',
    kernelFunc: depthwiseConv2dNative
};

function depthwiseConv2dNativeBackpropFilter(args) {
    const { inputs, backend, attrs } = args;
    const { x, dy } = inputs;
    const { strides, dilations, pad, dimRoundingMode, filterShape } = attrs;
    assertNotComplex([x, dy], 'depthwiseConv2dNativeBackpropFilter');
    const convInfo = computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true);
    const { strideHeight, strideWidth, filterHeight, filterWidth } = convInfo;
    const dW = new TensorBuffer(convInfo.filterShape, 'float32');
    const leftPad = convInfo.padInfo.left;
    const topPad = convInfo.padInfo.top;
    const chMul = convInfo.outChannels / convInfo.inChannels;
    const xVals = backend.data.get(x.dataId).values;
    const xBuf = new TensorBuffer(x.shape, x.dtype, xVals);
    const dyVals = backend.data.get(dy.dataId).values;
    const dyBuf = new TensorBuffer(dy.shape, dy.dtype, dyVals);
    for (let wR = 0; wR < filterHeight; ++wR) {
        const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight));
        const yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight);
        for (let wC = 0; wC < filterWidth; ++wC) {
            const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth));
            const yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth);
            for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
                const d1 = Math.trunc(d2 / chMul);
                const dm = d2 % chMul;
                let dotProd = 0;
                for (let b = 0; b < convInfo.batchSize; ++b) {
                    for (let yR = yRMin; yR < yRMax; ++yR) {
                        const xR = wR + yR * strideHeight - topPad;
                        for (let yC = yCMin; yC < yCMax; ++yC) {
                            const xC = wC + yC * strideWidth - leftPad;
                            dotProd += xBuf.get(b, xR, xC, d1) *
                                dyBuf.get(b, yR, yC, d2);
                        }
                    }
                }
                dW.set(dotProd, wR, wC, d1, dm);
            }
        }
    }
    return backend.makeTensorInfo(dW.shape, dW.dtype, dW.values);
}
const depthwiseConv2dNativeBackpropFilterConfig = {
    kernelName: DepthwiseConv2dNativeBackpropFilter,
    backendName: 'cpu',
    kernelFunc: depthwiseConv2dNativeBackpropFilter
};

function depthwiseConv2dNativeBackpropInput(args) {
    const { inputs, backend, attrs } = args;
    const { dy, filter } = inputs;
    const { strides, dilations, pad, dimRoundingMode, inputShape } = attrs;
    assertNotComplex([dy, filter], 'depthwiseConv2DNativeBackpropInput');
    const dyStrides = computeStrides(dy.shape);
    const filterStrides = computeStrides(filter.shape);
    const convInfo = computeConv2DInfo(inputShape, filter.shape, strides, dilations, pad, dimRoundingMode, true);
    const dx = new TensorBuffer(convInfo.inShape, 'float32');
    const dxValues = dx.values;
    const [dxS0, dxS1, dxS2] = dx.strides;
    const dyValues = backend.data.get(dy.dataId).values;
    const [dyS0, dyS1, dyS2] = dyStrides;
    const fltValues = backend.data.get(filter.dataId).values;
    const [fltS0, fltS1, fltS2] = filterStrides;
    const { batchSize, filterHeight, filterWidth, inChannels, inHeight, inWidth, outChannels, outHeight, outWidth, strideHeight, strideWidth } = convInfo;
    const topPad = filterHeight - 1 - convInfo.padInfo.top;
    const leftPad = filterWidth - 1 - convInfo.padInfo.left;
    const chMul = outChannels / inChannels;
    for (let b = 0; b < batchSize; ++b) {
        for (let d1 = 0; d1 < inChannels; ++d1) {
            for (let xR = 0; xR < inHeight; ++xR) {
                const xRCorner = xR - topPad;
                const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight));
                const yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight);
                for (let xC = 0; xC < inWidth; ++xC) {
                    const xCCorner = xC - leftPad;
                    const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth));
                    const yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth);
                    let dotProd = 0;
                    for (let yR = xRMin; yR < yRMax; ++yR) {
                        const wR = yR * strideHeight - xRCorner;
                        for (let yC = xCMin; yC < yCMax; ++yC) {
                            const wC = yC * strideWidth - xCCorner;
                            const dyOffset = dyS0 * b + dyS1 * yR + dyS2 * yC;
                            const fltOffset = fltS0 * (filterHeight - 1 - wR) +
                                fltS1 * (filterWidth - 1 - wC) + fltS2 * d1;
                            for (let dm = 0; dm < chMul; ++dm) {
                                const d2 = d1 * chMul + dm;
                                const pixel = dyValues[dyOffset + d2];
                                const weight = fltValues[fltOffset + dm];
                                dotProd += pixel * weight;
                            }
                        }
                    }
                    dxValues[dxS0 * b + dxS1 * xR + dxS2 * xC + d1] = dotProd;
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
const depthwiseConv2dNativeBackpropInputConfig = {
    kernelName: DepthwiseConv2dNativeBackpropInput,
    backendName: 'cpu',
    kernelFunc: depthwiseConv2dNativeBackpropInput
};

function diag(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    const xSize = sizeFromShape(x.shape);
    const xVals = backend.data.get(x.dataId).values;
    const outBuf = buffer([xSize, xSize], x.dtype);
    const vals = outBuf.values;
    for (let i = 0; i < xVals.length; i++) {
        vals[i * xSize + i] = xVals[i];
    }
    const outShape = [...x.shape, ...x.shape];
    return backend.makeTensorInfo(outShape, outBuf.dtype, outBuf.values);
}
const diagConfig = {
    kernelName: Diag,
    backendName: 'cpu',
    kernelFunc: diag
};

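// Morphological dilation (NHWC): each output value is the maximum of
// input + filter over the dilated window, per channel. The two backprop
// kernels below repeat the same argmax scan and route the incoming gradient
// to the winning filter tap or input pixel, respectively.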
const dilation2DConfig = {
    kernelName: Dilation2D,
    backendName: 'cpu',
    kernelFunc: ({ inputs, backend, attrs }) => {
        const { x, filter } = inputs;
        const { strides, pad, dilations } = attrs;
        const cpuBackend = backend;
        const xVals = cpuBackend.data.get(x.dataId).values;
        const xRank = x.shape.length;
        const filterVals = cpuBackend.data.get(filter.dataId).values;
        const filterRank = filter.shape.length;
        const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC', dilations);
        const outSize = sizeFromShape(outShape);
        const outRank = outShape.length;
        const outputVals = getArrayFromDType(x.dtype, outSize);
        for (let b = 0; b < batchSize; ++b) {
            for (let hOut = 0; hOut < outHeight; ++hOut) {
                const hBeg = hOut * strideHeight - padInfo.top;
                for (let wOut = 0; wOut < outWidth; ++wOut) {
                    const wBeg = wOut * strideWidth - padInfo.left;
                    for (let d = 0; d < inChannels; ++d) {
                        let curVal = Number.MIN_SAFE_INTEGER;
                        for (let h = 0; h < filterHeight; ++h) {
                            const hIn = hBeg + h * dilationHeight;
                            if (hIn >= 0 && hIn < inHeight) {
                                for (let w = 0; w < filterWidth; ++w) {
                                    const wIn = wBeg + w * dilationWidth;
                                    if (wIn >= 0 && wIn < inWidth) {
                                        const xIndex = locToIndex([b, hIn, wIn, d], xRank, computeStrides(x.shape));
                                        const filterIndex = locToIndex([h, w, d], filterRank, computeStrides(filter.shape));
                                        const val = xVals[xIndex] + filterVals[filterIndex];
                                        if (val > curVal) {
                                            curVal = val;
                                        }
                                    }
                                }
                            }
                        }
                        const outputIndex = locToIndex([b, hOut, wOut, d], outRank, computeStrides(outShape));
                        outputVals[outputIndex] = curVal;
                    }
                }
            }
        }
        const dataId = cpuBackend.write(toTypedArray(outputVals, x.dtype), outShape, x.dtype);
        return { dataId, shape: outShape, dtype: x.dtype };
    }
};

const dilation2DBackpropFilterConfig = {
    kernelName: Dilation2DBackpropFilter,
    backendName: 'cpu',
    kernelFunc: ({ inputs, backend, attrs }) => {
        const { x, filter, dy } = inputs;
        const { strides, pad, dilations } = attrs;
        const cpuBackend = backend;
        const $x = toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
        const $filter = toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
        const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC', dilations);
        assert$1(dy.rank === outShape.length, () => `Error in ${Dilation2DBackpropFilter}, dy ` +
            `must have the same rank as output ${outShape.length}, but got ` +
            `${dy.rank}`);
        const $dy = toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values);
        const gradients = makeZerosNestedTypedArray(filter.shape, filter.dtype);
        for (let b = 0; b < batchSize; ++b) {
            for (let hOut = 0; hOut < outHeight; ++hOut) {
                const hBeg = hOut * strideHeight - padInfo.top;
                for (let wOut = 0; wOut < outWidth; ++wOut) {
                    const wBeg = wOut * strideWidth - padInfo.left;
                    for (let d = 0; d < inChannels; ++d) {
                        let curVal = Number.MIN_SAFE_INTEGER;
                        let hMax = 0;
                        let wMax = 0;
                        for (let h = 0; h < filterHeight; ++h) {
                            const hIn = hBeg + h * dilationHeight;
                            if (hIn >= 0 && hIn < inHeight) {
                                for (let w = 0; w < filterWidth; ++w) {
                                    const wIn = wBeg + w * dilationWidth;
                                    if (wIn >= 0 && wIn < inWidth) {
                                        const val = $x[b][hIn][wIn][d] + $filter[h][w][d];
                                        if (val > curVal) {
                                            curVal = val;
                                            hMax = h;
                                            wMax = w;
                                        }
                                    }
                                }
                            }
                        }
                        gradients[hMax][wMax][d] += $dy[b][hOut][wOut][d];
                    }
                }
            }
        }
        const dataId = cpuBackend.write(toTypedArray(gradients, x.dtype), filter.shape, filter.dtype);
        return { dataId, shape: filter.shape, dtype: filter.dtype };
    }
};

const dilation2DBackpropInputConfig = {
    kernelName: Dilation2DBackpropInput,
    backendName: 'cpu',
    kernelFunc: ({ inputs, backend, attrs }) => {
        const { x, filter, dy } = inputs;
        const { strides, pad, dilations } = attrs;
        const cpuBackend = backend;
        const $x = toNestedArray(x.shape, cpuBackend.data.get(x.dataId).values);
        const $filter = toNestedArray(filter.shape, cpuBackend.data.get(filter.dataId).values);
        const { batchSize, inHeight, inWidth, inChannels, outHeight, outWidth, padInfo, strideHeight, strideWidth, filterHeight, filterWidth, dilationHeight, dilationWidth, outShape } = computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC', dilations);
        assert$1(dy.rank === outShape.length, () => `Error in ${Dilation2DBackpropInput}, dy ` +
            `must have the same rank as output ${outShape.length}, but got ` +
            `${dy.rank}`);
        const $dy = toNestedArray(outShape, cpuBackend.data.get(dy.dataId).values);
        const gradients = makeZerosNestedTypedArray(x.shape, x.dtype);
        for (let b = 0; b < batchSize; ++b) {
            for (let hOut = 0; hOut < outHeight; ++hOut) {
                const hBeg = hOut * strideHeight - padInfo.top;
                for (let wOut = 0; wOut < outWidth; ++wOut) {
                    const wBeg = wOut * strideWidth - padInfo.left;
                    for (let d = 0; d < inChannels; ++d) {
                        let curVal = Number.MIN_SAFE_INTEGER;
                        let hInMax = (hBeg < 0) ? 0 : hBeg;
                        let wInMax = (wBeg < 0) ? 0 : wBeg;
                        for (let h = 0; h < filterHeight; ++h) {
                            const hIn = hBeg + h * dilationHeight;
                            if (hIn >= 0 && hIn < inHeight) {
                                for (let w = 0; w < filterWidth; ++w) {
                                    const wIn = wBeg + w * dilationWidth;
                                    if (wIn >= 0 && wIn < inWidth) {
                                        const val = $x[b][hIn][wIn][d] + $filter[h][w][d];
                                        if (val > curVal) {
                                            curVal = val;
                                            hInMax = hIn;
                                            wInMax = wIn;
                                        }
                                    }
                                }
                            }
                        }
                        gradients[b][hInMax][wInMax][d] += $dy[b][hOut][wOut][d];
                    }
                }
            }
        }
        const dataId = cpuBackend.write(toTypedArray(gradients, x.dtype), x.shape, x.dtype);
        return { dataId, shape: x.shape, dtype: x.dtype };
    }
};

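// Renders an image tensor into a '2d' canvas context. float32 values must
// lie in [0, 1] and are scaled by 255; int32 values must lie in [0, 255].
// Depth-1 (grayscale) values are broadcast to RGB and alpha comes from
// imageOptions.alpha.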
function draw(args) {
    const { inputs, backend, attrs } = args;
    const { image } = inputs;
    const { canvas, options } = attrs;
    const { contextOptions, imageOptions } = options || {};
    const alpha = (imageOptions === null || imageOptions === void 0 ? void 0 : imageOptions.alpha) || 1;
    const contextType = (contextOptions === null || contextOptions === void 0 ? void 0 : contextOptions.contextType) || '2d';
    if (contextType !== '2d') {
        throw new Error(`Context type ${contextOptions.contextType} is not supported by the CPU backend.`);
    }
    const ctx = canvas.getContext(contextType, (contextOptions === null || contextOptions === void 0 ? void 0 : contextOptions.contextAttributes) || {});
    if (ctx == null) {
        throw new Error(`Could not get the context with ${contextType} type.`);
    }
    const [height, width] = image.shape.slice(0, 2);
    const depth = image.shape.length === 2 ? 1 : image.shape[2];
    const data = backend.data.get(image.dataId).values;
    const multiplier = image.dtype === 'float32' ? 255 : 1;
    const bytes = new Uint8ClampedArray(width * height * 4);
    for (let i = 0; i < height * width; ++i) {
        const rgba = [0, 0, 0, 255 * alpha];
        for (let d = 0; d < depth; d++) {
            const value = data[i * depth + d];
            if (image.dtype === 'float32') {
                if (value < 0 || value > 1) {
                    throw new Error(`Tensor values for a float32 Tensor must be in the ` +
                        `range [0 - 1] but encountered ${value}.`);
                }
            }
            else if (image.dtype === 'int32') {
                if (value < 0 || value > 255) {
                    throw new Error(`Tensor values for an int32 Tensor must be in the ` +
                        `range [0 - 255] but encountered ${value}.`);
                }
            }
            if (depth === 1) {
                rgba[0] = value * multiplier;
                rgba[1] = value * multiplier;
                rgba[2] = value * multiplier;
            }
            else {
                rgba[d] = value * multiplier;
            }
        }
        const j = i * 4;
        bytes[j + 0] = Math.round(rgba[0]);
        bytes[j + 1] = Math.round(rgba[1]);
        bytes[j + 2] = Math.round(rgba[2]);
        bytes[j + 3] = Math.round(rgba[3]);
    }
    canvas.width = width;
    canvas.height = height;
    const imageData = new ImageData(bytes, width, height);
    ctx.putImageData(imageData, 0, 0);
    return image;
}
const drawConfig = {
    kernelName: Draw,
    backendName: 'cpu',
    kernelFunc: draw
};

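// Sum reduction: bool inputs are upcast to int32, the reduced axes are
// transposed to the innermost positions, and each output element is then the
// sum of one contiguous run of `reduceSize` values.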
function sum(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex(x, 'sum');
    let $x;
    if (x.dtype === 'bool') {
        $x = cast$2({ inputs: { x }, backend, attrs: { dtype: 'int32' } });
    }
    else {
        $x = identity$1({ inputs: { x }, backend });
    }
    const xRank = $x.shape.length;
    const axes = parseAxisParam(axis, $x.shape);
    const permutation = getAxesPermutation(axes, xRank);
    let reductionAxes = axes;
    let permutedX = $x;
    if (permutation != null) {
        permutedX =
            transpose$1({ inputs: { x: $x }, backend, attrs: { perm: permutation } });
        reductionAxes = getInnerMostAxes(reductionAxes.length, xRank);
    }
    assertAxesAreInnerMostDims('sum', reductionAxes, permutedX.shape.length);
    const [outShape, reduceShape] = computeOutAndReduceShapes(permutedX.shape, reductionAxes);
    const resultDtype = upcastType(permutedX.dtype, 'int32');
    let result = zeros(backend, outShape, resultDtype);
    const reduceSize = sizeFromShape(reduceShape);
    const vals = backend.data.get(result.dataId).values;
    const aVals = backend.data.get(permutedX.dataId).values;
    for (let i = 0; i < vals.length; ++i) {
        const offset = i * reduceSize;
        let sum = 0;
        for (let j = 0; j < reduceSize; ++j) {
            sum += aVals[offset + j];
        }
        vals[i] = sum;
    }
    if (keepDims) {
        const newShape = expandShapeToKeepDim(result.shape, axes);
        const oldResult = result;
        result = reshape({ inputs: { x: result }, backend, attrs: { shape: newShape } });
        backend.disposeIntermediateTensorInfo(oldResult);
    }
    backend.disposeIntermediateTensorInfo($x);
    if (permutation != null) {
        backend.disposeIntermediateTensorInfo(permutedX);
    }
    return result;
}
const sumConfig = {
    kernelName: Sum,
    backendName: 'cpu',
    kernelFunc: sum
};

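// einsum: the decoded equation yields a multiply-then-reduce compute path.
// Each step permutes and reshapes an operand so it broadcasts against the
// running product, multiplies it in, and then sums out dimensions that no
// longer appear downstream.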
function einsum(args) {
    const { inputs, backend, attrs } = args;
    const { equation } = attrs;
    const tensors = inputs;
    const { allDims, summedDims, idDims } = decodeEinsumEquation(equation, tensors.length);
    checkEinsumDimSizes(allDims.length, idDims, tensors);
    const { path, steps } = getEinsumComputePath(summedDims, idDims);
    const nSteps = steps.length;
    let out = null;
    let numDimsRemaining = allDims.length;
    const tensorsToDispose = [];
    for (let i = 0; i < nSteps; ++i) {
        for (const idTerm of steps[i]) {
            const { permutationIndices: perm, expandDims: dimsToExpand } = getEinsumPermutation(numDimsRemaining, idDims[idTerm]);
            let x;
            if (isIdentityPermutation(perm)) {
                x = tensors[idTerm];
            }
            else {
                x = transpose$1({ inputs: { x: tensors[idTerm] }, backend, attrs: { perm } });
                tensorsToDispose.push(x);
            }
            const targetShape = x.shape.slice();
            for (let k = 0; k < dimsToExpand.length; ++k) {
                targetShape.splice(dimsToExpand[k], 0, 1);
            }
            if (!arraysEqual(x.shape, targetShape)) {
                x = reshape({ inputs: { x }, backend, attrs: { shape: targetShape } });
                tensorsToDispose.push(x);
            }
            if (out === null) {
                out = x;
            }
            else {
                out = multiply$1({ inputs: { a: x, b: out }, backend });
                tensorsToDispose.push(out);
            }
        }
        if (i < nSteps - 1) {
            if (path[i] >= 0) {
                out = sum({
                    inputs: { x: out },
                    backend,
                    attrs: {
                        axis: path[i] - (allDims.length - numDimsRemaining),
                        keepDims: false
                    }
                });
                tensorsToDispose.push(out);
            }
            numDimsRemaining--;
        }
    }
    for (const tensorInfo of tensorsToDispose) {
        if (tensorInfo === out) {
            continue;
        }
        backend.disposeIntermediateTensorInfo(tensorInfo);
    }
    return out;
}
const einsumConfig = {
    kernelName: Einsum,
    backendName: 'cpu',
    kernelFunc: einsum
};

function eluGrad(args) {
    const { inputs, backend } = args;
    const { dy, y } = inputs;
    assertNotComplex([dy, y], 'eluGrad');
    const resultValues = new Float32Array(sizeFromShape(y.shape));
    const values = backend.data.get(y.dataId).values;
    const dyValues = backend.data.get(dy.dataId).values;
    for (let i = 0; i < values.length; ++i) {
        const v = values[i];
        if (v >= 0) {
            resultValues[i] = dyValues[i];
        }
        else {
            resultValues[i] = dyValues[i] * (v + 1);
        }
    }
    return backend.makeTensorInfo(y.shape, 'float32', resultValues);
}
const eluGradConfig$1 = {
    kernelName: EluGrad,
    backendName: 'cpu',
    kernelFunc: eluGrad
};

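// Elementwise error function using a polynomial approximation in
// t = 1 / (1 + p*|x|) (the ERF_* constants appear to be the classic
// Abramowitz & Stegun 7.1.26 coefficients), with erf(-x) = -erf(x).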
const p = ERF_P;
const a1 = ERF_A1;
const a2 = ERF_A2;
const a3 = ERF_A3;
const a4 = ERF_A4;
const a5 = ERF_A5;
const erf = unaryKernelFunc$1(Erf, (xi) => {
    const sign = Math.sign(xi);
    const v = Math.abs(xi);
    const t = 1.0 / (1.0 + p * v);
    return sign *
        (1.0 -
            (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t *
                Math.exp(-v * v));
});
const erfConfig = {
    kernelName: Erf,
    backendName: 'cpu',
    kernelFunc: erf,
};

function expandDims$1(args) {
    const { inputs, backend, attrs } = args;
    const { input } = inputs;
    const { dim } = attrs;
    const inputRank = input.shape.length;
    const newShape = input.shape.slice();
    let $dim = dim;
    if (dim < 0) {
        assert$1(-(inputRank + 1) <= dim, () => `Axis must be in the interval [${-(inputRank + 1)}, ${inputRank}]`);
        $dim = inputRank + dim + 1;
    }
    newShape.splice($dim, 0, 1);
    return reshape({ inputs: { x: input }, backend, attrs: { shape: newShape } });
}
const expandDimsConfig = {
    kernelName: ExpandDims,
    backendName: 'cpu',
    kernelFunc: expandDims$1
};

const realDivImpl = createSimpleBinaryKernelImpl((a, b) => a / b);
const div = binaryKernelFunc$1(RealDiv, realDivImpl);
const realDivConfig = {
    kernelName: RealDiv,
    backendName: 'cpu',
    kernelFunc: div
};

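// FFT kernels: fftBatch runs fftImpl over each row of a [batch, innerDim]
// complex tensor. fftImpl uses the radix-2 path when the length is a power
// of two and otherwise falls back to the dense O(n^2) DFT in
// fourierTransformByMatmul; for inverse transforms the radix-2 result is
// divided by the size afterwards.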
function fftBatch(input, inverse, cpuBackend) {
    const inputShape = input.shape;
    const batch = inputShape[0];
    const innerDim = inputShape[1];
    const inputVals = cpuBackend.data.get(input.dataId);
    const real2D = inputVals.complexTensorInfos.real;
    const imag2D = inputVals.complexTensorInfos.imag;
    const resultShape = [batch, innerDim];
    const resultSize = sizeFromShape(resultShape);
    const resultReal = getTypedArrayFromDType('float32', resultSize);
    const resultImag = getTypedArrayFromDType('float32', resultSize);
    for (let b = 0; b < batch; b++) {
        const r = slice$1({
            inputs: { x: real2D },
            backend: cpuBackend,
            attrs: { begin: [b, 0], size: [1, innerDim] }
        });
        const i = slice$1({
            inputs: { x: imag2D },
            backend: cpuBackend,
            attrs: { begin: [b, 0], size: [1, innerDim] }
        });
        const input = complex$1({ inputs: { real: r, imag: i }, backend: cpuBackend });
        const { real, imag } = fftImpl(input, inverse, cpuBackend);
        const res = mergeRealAndImagArrays(real, imag);
        for (let d = 0; d < innerDim; d++) {
            const c = getComplexWithIndex(res, d);
            resultReal[b * innerDim + d] = c.real;
            resultImag[b * innerDim + d] = c.imag;
        }
        cpuBackend.disposeIntermediateTensorInfo(r);
        cpuBackend.disposeIntermediateTensorInfo(i);
        cpuBackend.disposeIntermediateTensorInfo(input);
    }
    const $realInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', resultReal);
    const $imagInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', resultImag);
    const result = complex$1({ inputs: { real: $realInfo, imag: $imagInfo }, backend: cpuBackend });
    cpuBackend.disposeIntermediateTensorInfo($realInfo);
    cpuBackend.disposeIntermediateTensorInfo($imagInfo);
    return result;
}

function fftImpl(input, inverse, cpuBackend) {
    const inputSize = sizeFromShape(input.shape);
    const inputVals = cpuBackend.data.get(input.dataId);
    const realVals = cpuBackend.data.get(inputVals.complexTensorInfos.real.dataId).values;
    const imagVals = cpuBackend.data.get(inputVals.complexTensorInfos.imag.dataId).values;
    if (isExponentOf2(inputSize)) {
        const result = fftRadix2(realVals, imagVals, inputSize, inverse, cpuBackend);
        const resultShape = [input.shape[0], input.shape[1]];
        if (inverse) {
            const realInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', result.real);
            const imagInfo = cpuBackend.makeTensorInfo(resultShape, 'float32', result.imag);
            const sizeInfo = cpuBackend.makeTensorInfo([], 'float32', createScalarValue(inputSize, 'float32'));
            const sizeInfoCopy = identity$1({ inputs: { x: sizeInfo }, backend: cpuBackend });
            const divRealInfo = realDivConfig.kernelFunc({ inputs: { a: realInfo, b: sizeInfo }, backend: cpuBackend });
            const divImagInfo = realDivConfig.kernelFunc({ inputs: { a: imagInfo, b: sizeInfoCopy }, backend: cpuBackend });
            const divRealVals = cpuBackend.data.get(divRealInfo.dataId).values;
            const divImagVals = cpuBackend.data.get(divImagInfo.dataId).values;
            cpuBackend.disposeIntermediateTensorInfo(realInfo);
            cpuBackend.disposeIntermediateTensorInfo(imagInfo);
            cpuBackend.disposeIntermediateTensorInfo(sizeInfo);
            cpuBackend.disposeIntermediateTensorInfo(sizeInfoCopy);
            cpuBackend.disposeIntermediateTensorInfo(divRealInfo);
            cpuBackend.disposeIntermediateTensorInfo(divImagInfo);
            return { real: divRealVals, imag: divImagVals };
        }
        return result;
    }
    else {
        const data = mergeRealAndImagArrays(realVals, imagVals);
        const rawOutput = fourierTransformByMatmul(data, inputSize, inverse);
        return splitRealAndImagArrays(rawOutput);
    }
}
function isExponentOf2(size) {
    return (size & size - 1) === 0;
}

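// Recursive radix-2 Cooley-Tukey step: split into even/odd halves, transform
// each recursively, scale the odd half by the twiddle factors from
// `exponents`, and combine as [even + t*odd, even - t*odd]. Tensor infos are
// materialized so the complex multiply/add/sub kernels can be reused, then
// disposed.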
function fftRadix2(realVals, imagVals, size, inverse, cpuBackend) {
|
|
if (size === 1) {
|
|
return { real: realVals, imag: imagVals };
|
|
}
|
|
const data = mergeRealAndImagArrays(realVals, imagVals);
|
|
const half = size / 2;
|
|
const evenComplex = complexWithEvenIndex(data);
|
|
const evenRealVals = evenComplex.real;
|
|
const evenImagVals = evenComplex.imag;
|
|
const evenShape = [evenRealVals.length];
|
|
const evenRealInfo = cpuBackend.makeTensorInfo(evenShape, 'float32', evenRealVals);
|
|
const evenImagInfo = cpuBackend.makeTensorInfo(evenShape, 'float32', evenImagVals);
|
|
const evenTensorInfo = complex$1({ inputs: { real: evenRealInfo, imag: evenImagInfo }, backend: cpuBackend });
|
|
const oddComplex = complexWithOddIndex(data);
|
|
const oddRealVals = oddComplex.real;
|
|
const oddImagVals = oddComplex.imag;
|
|
const oddShape = [oddRealVals.length];
|
|
const oddRealInfo = cpuBackend.makeTensorInfo(oddShape, 'float32', oddRealVals);
|
|
const oddImagInfo = cpuBackend.makeTensorInfo(oddShape, 'float32', oddImagVals);
|
|
const oddTensorInfo = complex$1({ inputs: { real: oddRealInfo, imag: oddImagInfo }, backend: cpuBackend });
|
|
|
|
const $evenComplex = fftRadix2(evenRealVals, evenImagVals, half, inverse, cpuBackend);
|
|
const $evenRealVals = $evenComplex.real;
|
|
const $evenImagVals = $evenComplex.imag;
|
|
const $evenShape = [$evenRealVals.length];
|
|
const $evenRealInfo = cpuBackend.makeTensorInfo($evenShape, 'float32', $evenRealVals);
|
|
const $evenImagInfo = cpuBackend.makeTensorInfo($evenShape, 'float32', $evenImagVals);
|
|
const $evenTensorInfo = complex$1({
|
|
inputs: { real: $evenRealInfo, imag: $evenImagInfo },
|
|
backend: cpuBackend
|
|
});
|
|
const $oddComplex = fftRadix2(oddRealVals, oddImagVals, half, inverse, cpuBackend);
|
|
const $oddRealVals = $oddComplex.real;
|
|
const $oddImagVals = $oddComplex.imag;
|
|
const $oddShape = [$oddRealVals.length];
|
|
const $oddRealInfo = cpuBackend.makeTensorInfo($oddShape, 'float32', $oddRealVals);
|
|
const $oddImagInfo = cpuBackend.makeTensorInfo($oddShape, 'float32', $oddImagVals);
|
|
const $oddTensorInfo = complex$1({ inputs: { real: $oddRealInfo, imag: $oddImagInfo }, backend: cpuBackend });
|
|
const e = exponents(size, inverse);
|
|
const eShape = [e.real.length];
|
|
const eRealInfo = cpuBackend.makeTensorInfo(eShape, 'float32', e.real);
|
|
const eImagInfo = cpuBackend.makeTensorInfo(eShape, 'float32', e.imag);
|
|
const complexInfo = complex$1({ inputs: { real: eRealInfo, imag: eImagInfo }, backend: cpuBackend });
|
|
const exponentInfo = multiply$1({ inputs: { a: complexInfo, b: $oddTensorInfo }, backend: cpuBackend });
|
|
const addPart = add({
|
|
inputs: { a: $evenTensorInfo, b: exponentInfo },
|
|
backend: cpuBackend
|
|
});
|
|
const subPart = sub$1({
|
|
inputs: { a: $evenTensorInfo, b: exponentInfo },
|
|
backend: cpuBackend
|
|
});
|
|
const addPartReal = real$1({ inputs: { input: addPart }, backend: cpuBackend });
|
|
const subPartReal = real$1({ inputs: { input: subPart }, backend: cpuBackend });
|
|
const addPartImag = imag({ inputs: { input: addPart }, backend: cpuBackend });
|
|
const subPartImag = imag({ inputs: { input: subPart }, backend: cpuBackend });
|
|
const $real = concat({
|
|
inputs: [addPartReal, subPartReal],
|
|
backend: cpuBackend,
|
|
attrs: { axis: 0 }
|
|
});
|
|
const $imag = concat({
|
|
inputs: [addPartImag, subPartImag],
|
|
backend: cpuBackend,
|
|
attrs: { axis: 0 }
|
|
});
|
|
const $realVals = cpuBackend.data.get($real.dataId).values;
|
|
const $imagVals = cpuBackend.data.get($imag.dataId).values;
|
|
cpuBackend.disposeIntermediateTensorInfo(evenRealInfo);
|
|
cpuBackend.disposeIntermediateTensorInfo(evenImagInfo);
|
|
cpuBackend.disposeIntermediateTensorInfo(evenTensorInfo);
|
|
cpuBackend.disposeIntermediateTensorInfo(oddRealInfo);
|
|
cpuBackend.disposeIntermediateTensorInfo(oddImagInfo);
|
|
cpuBackend.disposeIntermediateTensorInfo(oddTensorInfo);
|
|
cpuBackend.disposeIntermediateTensorInfo($evenRealInfo);
|
|
cpuBackend.disposeIntermediateTensorInfo($evenImagInfo);
|
|
cpuBackend.disposeIntermediateTensorInfo($evenTensorInfo);
|
|
cpuBackend.disposeIntermediateTensorInfo($oddRealInfo);
|
|
cpuBackend.disposeIntermediateTensorInfo($oddImagInfo);
|
|
cpuBackend.disposeIntermediateTensorInfo($oddTensorInfo);
|
|
cpuBackend.disposeIntermediateTensorInfo(eRealInfo);
|
|
cpuBackend.disposeIntermediateTensorInfo(eImagInfo);
|
|
cpuBackend.disposeIntermediateTensorInfo(complexInfo);
|
|
cpuBackend.disposeIntermediateTensorInfo(exponentInfo);
|
|
cpuBackend.disposeIntermediateTensorInfo(addPart);
|
|
cpuBackend.disposeIntermediateTensorInfo(subPart);
|
|
cpuBackend.disposeIntermediateTensorInfo(addPartReal);
|
|
cpuBackend.disposeIntermediateTensorInfo(addPartImag);
|
|
cpuBackend.disposeIntermediateTensorInfo(subPartReal);
|
|
cpuBackend.disposeIntermediateTensorInfo(subPartImag);
|
|
cpuBackend.disposeIntermediateTensorInfo($real);
|
|
cpuBackend.disposeIntermediateTensorInfo($imag);
|
|
return { real: $realVals, imag: $imagVals };
|
|
}
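
// Naive O(n^2) discrete Fourier transform, used when the input length is not
// a power of two. For each output index r it accumulates the complex sum
//   X[r] = sum_c x[c] * e(r * c), with e(k) = exp(+/- 2*pi*i * k / size)
// (the sign and the 1/size normalization depend on `inverse`).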

function fourierTransformByMatmul(data, size, inverse) {
    const ret = new Float32Array(size * 2);
    for (let r = 0; r < size; r++) {
        let real = 0.0;
        let imag = 0.0;
        for (let c = 0; c < size; c++) {
            const e = exponent(r * c, size, inverse);
            const term = getComplexWithIndex(data, c);
            real += term.real * e.real - term.imag * e.imag;
            imag += term.real * e.imag + term.imag * e.real;
        }
        if (inverse) {
            real /= size;
            imag /= size;
        }
        assignToTypedArray(ret, real, imag, r);
    }
    return ret;
}
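
// FFT kernel entry point: collapses all leading dimensions into a batch,
// runs a 1D FFT over the innermost dimension of every row, then restores
// the original shape.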

function fft(args) {
    const { inputs, backend } = args;
    const { input } = inputs;
    const inputSize = sizeFromShape(input.shape);
    const innerDimensionSize = input.shape[input.shape.length - 1];
    const batch = inputSize / innerDimensionSize;
    const input2D = reshape({
        inputs: { x: input },
        backend,
        attrs: { shape: [batch, innerDimensionSize] }
    });
    const result = fftBatch(input2D, false, backend);
    const resultReshaped = reshape({ inputs: { x: result }, backend, attrs: { shape: input.shape } });
    backend.disposeIntermediateTensorInfo(input2D);
    backend.disposeIntermediateTensorInfo(result);
    return resultReshaped;
}
const fftConfig = {
    kernelName: FFT,
    backendName: 'cpu',
    kernelFunc: fft
};
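
// Fill kernel: allocates a flat array for the requested dtype and writes the
// scalar `value` into every slot. The two branches of `fillValues` are
// identical at runtime; string tensors are backed by plain arrays rather
// than typed arrays, so the split presumably mirrors a type distinction in
// the original TypeScript source.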

function fill(args) {
    const { backend, attrs } = args;
    const { shape, value, dtype } = attrs;
    const $dtype = dtype || inferDtype(value);
    const values = getArrayFromDType($dtype, sizeFromShape(shape));
    fillValues(values, value, $dtype);
    return backend.makeTensorInfo(shape, $dtype, values);
}
const fillConfig = {
    kernelName: Fill,
    backendName: 'cpu',
    kernelFunc: fill
};
function fillValues(values, value, dtype) {
    if (dtype === 'string') {
        values.fill(value);
    }
    else {
        values.fill(value);
    }
}

const flipLeftRightConfig = {
    kernelName: FlipLeftRight,
    backendName: 'cpu',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { image } = inputs;
        const cpuBackend = backend;
        const output = getTypedArrayFromDType(image.dtype, sizeFromShape(image.shape));
        const [batch, imageHeight, imageWidth, numChannels] = image.shape;
        const imageVals = cpuBackend.data.get(image.dataId).values;
        for (let batchIdx = 0; batchIdx < batch; batchIdx++) {
            const batchOffset = batchIdx * imageWidth * imageHeight * numChannels;
            for (let row = 0; row < imageHeight; row++) {
                const rowOffset = row * (imageWidth * numChannels);
                for (let col = 0; col < imageWidth; col++) {
                    const colOffset = col * numChannels;
                    for (let channel = 0; channel < numChannels; channel++) {
                        const coordX = Math.round(imageWidth - col - 1);
                        const outIdx = batchOffset + rowOffset + colOffset + channel;
                        let outputValue = imageVals[outIdx];
                        if (coordX >= 0 && coordX < imageWidth) {
                            const rotatedColOffset = coordX * numChannels;
                            const imageIdx = batchOffset + rowOffset + rotatedColOffset + channel;
                            outputValue = imageVals[imageIdx];
                        }
                        output[outIdx] = outputValue;
                    }
                }
            }
        }
        const dataId = cpuBackend.write(output, image.shape, image.dtype);
        return { dataId, shape: image.shape, dtype: image.dtype };
    }
};
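
// Fused Conv2D: runs the plain convolution, then folds in the bias add and
// the activation as separate CPU kernels. For NCHW data, a rank-1 bias (or
// rank-1 PReLU alpha) is first reshaped to [C, 1, 1] so broadcasting lines
// it up with the channel dimension.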

function fusedConv2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
    let result = conv2D({
        inputs: { x, filter },
        backend,
        attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
    });
    if (bias) {
        const resultOld = result;
        if (dataFormat === 'NCHW' && bias.shape.length === 1 &&
            bias.shape[0] !== 1) {
            const reshapedBias = reshape({ inputs: { x: bias }, backend, attrs: { shape: [bias.shape[0], 1, 1] } });
            result =
                add({ inputs: { a: result, b: reshapedBias }, backend });
            backend.disposeIntermediateTensorInfo(reshapedBias);
        }
        else {
            result = add({ inputs: { a: result, b: bias }, backend });
        }
        backend.disposeIntermediateTensorInfo(resultOld);
    }
    if (activation) {
        const resultOld = result;
        if (dataFormat === 'NCHW' && activation === 'prelu' &&
            preluActivationWeights.shape.length === 1 &&
            preluActivationWeights.shape[0] !== 1) {
            const reshapedAlpha = reshape({
                inputs: { x: preluActivationWeights },
                backend,
                attrs: { shape: [preluActivationWeights.shape[0], 1, 1] }
            });
            result = applyActivation(backend, result, activation, reshapedAlpha, leakyreluAlpha);
            backend.disposeIntermediateTensorInfo(reshapedAlpha);
        }
        else {
            result = applyActivation(backend, result, activation, preluActivationWeights, leakyreluAlpha);
        }
        backend.disposeIntermediateTensorInfo(resultOld);
    }
    return result;
}
const fusedConv2DConfig = {
    kernelName: FusedConv2D,
    backendName: 'cpu',
    kernelFunc: fusedConv2D
};

function fusedDepthwiseConv2D(args) {
    const { inputs, backend, attrs } = args;
    const { x, filter, bias, preluActivationWeights } = inputs;
    const { strides, pad, dataFormat, dilations, dimRoundingMode, activation, leakyreluAlpha } = attrs;
    let result = depthwiseConv2dNative({
        inputs: { x, filter },
        backend,
        attrs: { strides, pad, dataFormat, dilations, dimRoundingMode }
    });
    if (bias) {
        const oldResult = result;
        result = add({ inputs: { a: result, b: bias }, backend });
        backend.disposeIntermediateTensorInfo(oldResult);
    }
    if (activation) {
        const oldResult = result;
        result = applyActivation(backend, result, activation, preluActivationWeights, leakyreluAlpha);
        backend.disposeIntermediateTensorInfo(oldResult);
    }
    return result;
}
const fusedDepthwiseConv2DConfig = {
    kernelName: FusedDepthwiseConv2D,
    backendName: 'cpu',
    kernelFunc: fusedDepthwiseConv2D
};

function gatherNd(args) {
    const { inputs, backend } = args;
    const { params, indices } = inputs;
    const paramsSize = sizeFromShape(params.shape);
    const indicesShape = indices.shape;
    const sliceRank = indicesShape[indicesShape.length - 1];
    const [resultShape, numSlices, sliceSize, strides] = prepareAndValidate(params, indices);
    if (numSlices === 0) {
        return backend.makeTensorInfo(resultShape, params.dtype, []);
    }
    const indicesData = backend.data.get(indices.dataId).values;
    const paramsBuf = backend.bufferSync(params);
    const outBuf = gatherNdImpl(indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize);
    return backend.makeTensorInfo(resultShape, params.dtype, outBuf.values);
}
const gatherNdConfig = {
    kernelName: GatherNd,
    backendName: 'cpu',
    kernelFunc: gatherNd
};
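
// GatherV2: after validating that every index lies inside the gathered
// axis, the input is reshaped to [batch, outer, dimSize, slice] and the
// indices to [batch, indicesPerBatch], so gatherV2Impl only has to handle
// one canonical 4D layout.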

function gatherV2(args) {
    const { inputs, backend, attrs } = args;
    const { x, indices } = inputs;
    const { axis, batchDims } = attrs;
    assertNotComplex([x, indices], 'gatherV2');
    const parsedAxis = parseAxisParam(axis, x.shape)[0];
    const indicesVals = backend.data.get(indices.dataId).values;
    const axisDim = x.shape[parsedAxis];
    for (let i = 0; i < indicesVals.length; ++i) {
        const index = indicesVals[i];
        assert$1(index <= axisDim - 1 && index >= 0, () => `GatherV2: the index value ${index} is not in [0, ${axisDim - 1}]`);
    }
    let $batchDims = batchDims;
    if (batchDims == null) {
        $batchDims = 0;
    }
    const indicesSize = sizeFromShape(indices.shape);
    const shapeInfo = collectGatherOpShapeInfo(x, indices, parsedAxis, $batchDims);
    const flattenX = reshape({
        inputs: { x },
        backend,
        attrs: {
            shape: [
                shapeInfo.batchSize, shapeInfo.outerSize, shapeInfo.dimSize,
                shapeInfo.sliceSize
            ]
        }
    });
    const flattenIndex = reshape({
        inputs: { x: indices },
        backend,
        attrs: { shape: [shapeInfo.batchSize, indicesSize / shapeInfo.batchSize] }
    });
    const flattenOutputShape = [
        shapeInfo.batchSize, shapeInfo.outerSize, indicesSize / shapeInfo.batchSize,
        shapeInfo.sliceSize
    ];
    const indicesBuf = backend.bufferSync(flattenIndex);
    const xBuf = backend.bufferSync(flattenX);
    const outBuf = gatherV2Impl(xBuf, indicesBuf, flattenOutputShape);
    backend.disposeIntermediateTensorInfo(flattenX);
    backend.disposeIntermediateTensorInfo(flattenIndex);
    return backend.makeTensorInfo(shapeInfo.outputShape, outBuf.dtype, outBuf.values);
}
const gatherV2Config = {
    kernelName: GatherV2,
    backendName: 'cpu',
    kernelFunc: gatherV2
};

function ifft(args) {
    const { inputs, backend } = args;
    const { input } = inputs;
    const inputSize = sizeFromShape(input.shape);
    const innerDimensionSize = input.shape[input.shape.length - 1];
    const batch = inputSize / innerDimensionSize;
    const input2D = reshape({
        inputs: { x: input },
        backend,
        attrs: { shape: [batch, innerDimensionSize] }
    });
    const result = fftBatch(input2D, true, backend);
    const resultReshaped = reshape({ inputs: { x: result }, backend, attrs: { shape: input.shape } });
    backend.disposeIntermediateTensorInfo(input2D);
    backend.disposeIntermediateTensorInfo(result);
    return resultReshaped;
}
const ifftConfig = {
    kernelName: IFFT,
    backendName: 'cpu',
    kernelFunc: ifft
};

const isFinite$1 = unaryKernelFunc$1(IsFinite, (xi) => Number.isFinite(xi) ? 1 : 0, 'bool');
const isFiniteConfig = {
    kernelName: IsFinite,
    backendName: 'cpu',
    kernelFunc: isFinite$1,
};

const isInf = unaryKernelFunc$1(IsInf, (xi) => Math.abs(xi) === Infinity ? 1 : 0, 'bool');
const isInfConfig = {
    kernelName: IsInf,
    backendName: 'cpu',
    kernelFunc: isInf,
};

const isNaN$1 = unaryKernelFunc$1(IsNan, (xi) => Number.isNaN(xi) ? 1 : 0, 'bool');
const isNaNConfig = {
    kernelName: IsNan,
    backendName: 'cpu',
    kernelFunc: isNaN$1,
};

function linSpace(args) {
    const { backend, attrs } = args;
    const { start, stop, num } = attrs;
    const outVals = linSpaceImpl(start, stop, num);
    return backend.makeTensorInfo([outVals.length], 'float32', outVals);
}
const linSpaceConfig = {
    kernelName: LinSpace,
    backendName: 'cpu',
    kernelFunc: linSpace
};

const log1p = unaryKernelFunc$1(Log1p, (xi) => Math.log1p(xi));
const log1pConfig = {
    kernelName: Log1p,
    backendName: 'cpu',
    kernelFunc: log1p,
};

const logicalAndImpl = createSimpleBinaryKernelImpl((a, b) => a && b);
const logicalAnd = binaryKernelFunc$1(LogicalAnd, logicalAndImpl, null, 'bool');
const logicalAndConfig = {
    kernelName: LogicalAnd,
    backendName: 'cpu',
    kernelFunc: logicalAnd
};

const logicalNot = unaryKernelFunc$1(LogicalNot, (xi) => xi ? 0 : 1, 'bool');
const logicalNotConfig = {
    kernelName: LogicalNot,
    backendName: 'cpu',
    kernelFunc: logicalNot,
};

const logicalOrImpl = createSimpleBinaryKernelImpl((a, b) => a || b);
const logicalOr = binaryKernelFunc$1(LogicalOr, logicalOrImpl, null, 'bool');
const logicalOrConfig = {
    kernelName: LogicalOr,
    backendName: 'cpu',
    kernelFunc: logicalOr
};
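
// Local Response Normalization across the channel (innermost) dimension:
//   out[i] = x[i] * (bias + alpha * sum(x[j]^2))^(-beta)
// where j ranges over the window of `depthRadius` channels on either side
// of i. lRNGrad below is its hand-written gradient.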

function lRN(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { depthRadius, bias, alpha, beta } = attrs;
    assertNotComplex(x, 'LRN');
    const channels = x.shape[3];
    const maxD = channels - 1;
    const xValues = backend.data.get(x.dataId).values;
    const size = sizeFromShape(x.shape);
    const result = new Float32Array(size);
    function sumAcrossChannels(offset) {
        const currentChannel = offset % channels;
        let beginSumOffset = offset - currentChannel + Math.max(0, currentChannel - depthRadius);
        const endSumOffset = offset - currentChannel + Math.min(currentChannel + depthRadius, maxD);
        let sum = 0.0;
        for (; beginSumOffset <= endSumOffset; beginSumOffset++) {
            const z = xValues[beginSumOffset];
            sum += z * z;
        }
        return sum;
    }
    for (let offset = 0; offset < size; offset++) {
        const sum = sumAcrossChannels(offset);
        const val = xValues[offset] * Math.pow(bias + alpha * sum, -beta);
        result[offset] = val;
    }
    return backend.makeTensorInfo(x.shape, x.dtype, result);
}
const LRNConfig = {
    kernelName: LRN,
    backendName: 'cpu',
    kernelFunc: lRN
};

function lRNGrad(args) {
    const { inputs, backend, attrs } = args;
    const { x, y, dy } = inputs;
    const { depthRadius, bias, alpha, beta } = attrs;
    assertNotComplex(dy, 'LRNGrad');
    const dySize = sizeFromShape(dy.shape);
    const channels = dy.shape[3];
    const dyValues = backend.data.get(dy.dataId).values;
    const xValues = backend.data.get(x.dataId).values;
    const yValues = backend.data.get(y.dataId).values;
    const result = new Float32Array(dySize);
    const size = dySize;
    for (let offset = 0; offset < size; offset++) {
        const currentChannel = offset % channels;
        const depthBegin = (offset - currentChannel) + Math.max(0, currentChannel - depthRadius);
        const depthEnd = (offset - currentChannel) +
            Math.min(channels, currentChannel + depthRadius + 1);
        let norm = 0;
        for (let k = depthBegin; k < depthEnd; k++) {
            norm += Math.pow(xValues[k], 2);
        }
        norm = alpha * norm + bias;
        for (let k = depthBegin; k < depthEnd; k++) {
            let dyi = -2 * alpha * beta * xValues[k] * yValues[offset] / norm;
            if (offset === k) {
                dyi += Math.pow(norm, -beta);
            }
            dyi *= dyValues[offset];
            result[k] += dyi;
        }
    }
    return backend.makeTensorInfo(dy.shape, x.dtype, result);
}
const LRNGradConfig = {
    kernelName: LRNGrad,
    backendName: 'cpu',
    kernelFunc: lRNGrad
};
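
// Max reduction: if the reduced axes are not already the innermost ones,
// the input is transposed first so they are; each contiguous run of
// `reduceSize` values is then collapsed by maxImpl$1. `keepDims` only
// changes the reported shape, not the data.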

function max(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { reductionIndices, keepDims } = attrs;
    const cpuBackend = backend;
    let xShape = x.shape;
    const xRank = xShape.length;
    const origAxes = parseAxisParam(reductionIndices, xShape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, xRank);
    let xVals = cpuBackend.data.get(x.dataId).values;
    if (permutedAxes != null) {
        const newShape = new Array(xRank);
        for (let i = 0; i < newShape.length; i++) {
            newShape[i] = xShape[permutedAxes[i]];
        }
        xVals = transposeImpl$1(xVals, xShape, x.dtype, permutedAxes, newShape);
        axes = getInnerMostAxes(axes.length, xRank);
        xShape = newShape;
    }
    assertNotComplex(x, 'max');
    assertAxesAreInnerMostDims('max', axes, xRank);
    const [maxOutShape, reduceShape] = computeOutAndReduceShapes(xShape, axes);
    const reduceSize = sizeFromShape(reduceShape);
    const result = maxImpl$1(xVals, reduceSize, maxOutShape, x.dtype);
    const dataId = cpuBackend.write(result, maxOutShape, x.dtype);
    let outShape = maxOutShape;
    if (keepDims) {
        const newShape = expandShapeToKeepDim(maxOutShape, origAxes);
        outShape = newShape;
    }
    return { dataId, shape: outShape, dtype: x.dtype };
}
const maxConfig = {
    kernelName: Max,
    backendName: 'cpu',
    kernelFunc: max
};

function maxPool(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    assertNotComplex(x, 'maxPool');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const dilations = 1;
    assert$1(eitherStridesOrDilationsAreOne(strides, dilations), () => 'Error in maxPool: Either strides or dilations must be 1. ' +
        `Got strides ${strides} and dilations '${dilations}'`);
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
    let res;
    if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
        arraysEqual(convInfo.inShape, convInfo.outShape)) {
        // A 1x1 filter with matching in/out shapes is a no-op.
        res = identity$1({ inputs: { x }, backend });
    }
    else {
        const xValues = backend.data.get(x.dataId).values;
        const strides = computeStrides(x.shape);
        const buffer = pool(xValues, x.shape, x.dtype, strides, convInfo, 'max');
        res = backend.makeTensorInfo(convInfo.outShape, x.dtype, buffer.values);
    }
    return res;
}
const maxPoolConfig = {
    kernelName: MaxPool,
    backendName: 'cpu',
    kernelFunc: maxPool
};

function maxPool3D(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { filterSize, strides, pad, dimRoundingMode, dataFormat } = attrs;
    assertNotComplex(x, 'maxPool3d');
    const convInfo = computePool3DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode, dataFormat);
    const xValues = backend.data.get(x.dataId).values;
    const outBuf = pool3d(xValues, x.shape, x.dtype, computeStrides(x.shape), convInfo, 'max');
    return backend.makeTensorInfo(outBuf.shape, 'float32', outBuf.values);
}
const maxPool3DConfig = {
    kernelName: MaxPool3D,
    backendName: 'cpu',
    kernelFunc: maxPool3D
};
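
// MaxPool3D gradient: recomputes the argmax positions of the forward pass,
// then routes each dy value back to the input element that produced the
// corresponding output maximum; every other position gets zero gradient.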

function maxPool3DGrad(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input } = inputs;
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    assertNotComplex([dy, input], 'maxPool3DGrad');
    const convInfo = computePool3DInfo(input.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    const inputBuf = backend.bufferSync(input);
    const maxPosBuf = maxPool3dPositions(inputBuf, convInfo);
    const strideDepth = convInfo.strideDepth;
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationDepth = convInfo.dilationDepth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterDepth = convInfo.effectiveFilterDepth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
    const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
    const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
    const dx = buffer(input.shape, 'float32');
    const dyBuf = backend.bufferSync(dy);
    for (let batch = 0; batch < convInfo.batchSize; ++batch) {
        for (let channel = 0; channel < convInfo.inChannels; ++channel) {
            for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
                for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
                    for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
                        // Shader code template.
                        const dyDepthCorner = dxDepth - padFront;
                        const dyRowCorner = dxRow - padTop;
                        const dyColCorner = dxCol - padLeft;
                        let dotProd = 0;
                        for (let wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) {
                            const dyDepth = (dyDepthCorner + wDepth) / strideDepth;
                            if (dyDepth < 0 || dyDepth >= convInfo.outDepth ||
                                Math.floor(dyDepth) !== dyDepth) {
                                continue;
                            }
                            for (let wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) {
                                const dyRow = (dyRowCorner + wRow) / strideHeight;
                                if (dyRow < 0 || dyRow >= convInfo.outHeight ||
                                    Math.floor(dyRow) !== dyRow) {
                                    continue;
                                }
                                for (let wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) {
                                    const dyCol = (dyColCorner + wCol) / strideWidth;
                                    if (dyCol < 0 || dyCol >= convInfo.outWidth ||
                                        Math.floor(dyCol) !== dyCol) {
                                        continue;
                                    }
                                    const maxPos = effectiveFilterDepth * effectiveFilterHeight *
                                        effectiveFilterWidth -
                                        1 -
                                        maxPosBuf.get(batch, dyDepth, dyRow, dyCol, channel);
                                    const curPos = wDepth * effectiveFilterHeight * effectiveFilterWidth +
                                        wRow * effectiveFilterWidth + wCol;
                                    const mask = maxPos === curPos ? 1 : 0;
                                    if (mask === 0) {
                                        continue;
                                    }
                                    const pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel);
                                    dotProd += pixel * mask;
                                }
                            }
                        }
                        dx.set(dotProd, batch, dxDepth, dxRow, dxCol, channel);
                    }
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
const maxPool3DGradConfig$1 = {
    kernelName: MaxPool3DGrad,
    backendName: 'cpu',
    kernelFunc: maxPool3DGrad
};

function maxPoolGrad$1(args) {
    const { inputs, backend, attrs } = args;
    const { dy, input, output } = inputs;
    const x = input;
    assertNotComplex([input, output], 'maxPoolGrad');
    const { filterSize, strides, pad, dimRoundingMode } = attrs;
    const convInfo = computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad, dimRoundingMode);
    const xValues = backend.data.get(x.dataId).values;
    const maxPosBuf = buffer(convInfo.outShape, x.dtype, maxPoolPositions(xValues, x.shape, x.dtype, convInfo).values);
    const strideHeight = convInfo.strideHeight;
    const strideWidth = convInfo.strideWidth;
    const dilationHeight = convInfo.dilationHeight;
    const dilationWidth = convInfo.dilationWidth;
    const effectiveFilterHeight = convInfo.effectiveFilterHeight;
    const effectiveFilterWidth = convInfo.effectiveFilterWidth;
    const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
    const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
    const dx = buffer(x.shape, 'float32');
    const dyData = backend.data.get(dy.dataId).values;
    const dyBuf = buffer(dy.shape, 'float32', dyData);
    for (let b = 0; b < convInfo.batchSize; ++b) {
        for (let d = 0; d < convInfo.inChannels; ++d) {
            for (let dxR = 0; dxR < convInfo.inHeight; ++dxR) {
                for (let dxC = 0; dxC < convInfo.inWidth; ++dxC) {
                    const dyRCorner = dxR - padTop;
                    const dyCCorner = dxC - padLeft;
                    let dotProd = 0;
                    for (let wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) {
                        const dyR = (dyRCorner + wR) / strideHeight;
                        if (dyR < 0 || dyR >= convInfo.outHeight ||
                            Math.floor(dyR) !== dyR) {
                            continue;
                        }
                        for (let wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) {
                            const dyC = (dyCCorner + wC) / strideWidth;
                            if (dyC < 0 || dyC >= convInfo.outWidth ||
                                Math.floor(dyC) !== dyC) {
                                continue;
                            }
                            const maxPos = effectiveFilterHeight * effectiveFilterWidth - 1 -
                                maxPosBuf.get(b, dyR, dyC, d);
                            const curPos = wR * effectiveFilterWidth + wC;
                            const mask = maxPos === curPos ? 1 : 0;
                            if (mask === 0) {
                                continue;
                            }
                            const pixel = dyBuf.get(b, dyR, dyC, d);
                            dotProd += pixel * mask;
                        }
                    }
                    dx.set(dotProd, b, dxR, dxC, d);
                }
            }
        }
    }
    return backend.makeTensorInfo(dx.shape, dx.dtype, dx.values);
}
const maxPoolGradConfig$1 = {
    kernelName: MaxPoolGrad,
    backendName: 'cpu',
    kernelFunc: maxPoolGrad$1
};
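
// MaxPoolWithArgmax: computes the pooled values and, alongside them, the
// flattened index of each maximum (optionally including the batch offset),
// returned as a second int32 tensor of the same pooled shape.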

function maxPoolWithArgmaxImpl(xValues, xShape, dtype, includeBatchInIndex, convInfo) {
    const strides = computeStrides(xShape);
    const maxPools = pool(xValues, xShape, dtype, strides, convInfo, 'max');
    const maxPositions = maxPoolPositions(xValues, xShape, dtype, convInfo, true, includeBatchInIndex);
    return [maxPools.values, maxPositions.values];
}

const maxPoolWithArgmaxConfig = {
    kernelName: MaxPoolWithArgmax,
    backendName: 'cpu',
    kernelFunc: ({ inputs, attrs, backend }) => {
        const { x } = inputs;
        const { filterSize, strides, pad, includeBatchInIndex } = attrs;
        const cpuBackend = backend;
        assertNotComplex(x, 'MaxPoolWithArgmax');
        const values = cpuBackend.data.get(x.dataId).values;
        const convInfo = computePool2DInfo(x.shape, filterSize, strides, [1, 1], pad);
        const [pooled, indexes] = maxPoolWithArgmaxImpl(values, x.shape, x.dtype, includeBatchInIndex, convInfo);
        const pooledDataId = cpuBackend.write(pooled, convInfo.outShape, x.dtype);
        const indexesDataId = cpuBackend.write(indexes, convInfo.outShape, x.dtype);
        return [
            { dataId: pooledDataId, shape: convInfo.outShape, dtype: x.dtype },
            { dataId: indexesDataId, shape: convInfo.outShape, dtype: 'int32' }
        ];
    }
};
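
// Mean reduction expressed with existing kernels: cast to float32, divide
// every element by the reduction size, then sum over the requested axes,
// i.e. mean(x) = sum(x / N) rather than sum(x) / N.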

function mean(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    const axes = parseAxisParam(axis, x.shape);
    const shapes = computeOutAndReduceShapes(x.shape, axes);
    const reduceShape = shapes[1];
    const reduceSize = sizeFromShape(reduceShape);
    const toDispose = [];
    const reduceSizeScalar = backend.makeTensorInfo([], 'float32', new Float32Array([reduceSize]));
    toDispose.push(reduceSizeScalar);
    const $x = cast$2({ inputs: { x }, backend, attrs: { dtype: 'float32' } });
    toDispose.push($x);
    const res = div({ inputs: { a: $x, b: reduceSizeScalar }, backend });
    toDispose.push(res);
    const result = sum({ inputs: { x: res }, backend, attrs: { axis, keepDims } });
    toDispose.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return result;
}
const meanConfig = {
    kernelName: Mean,
    backendName: 'cpu',
    kernelFunc: mean
};

function min(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { axis, keepDims } = attrs;
    assertNotComplex(x, 'min');
    const origAxes = parseAxisParam(axis, x.shape);
    let axes = origAxes;
    const permutedAxes = getAxesPermutation(axes, x.shape.length);
    let $x = x;
    if (permutedAxes != null) {
        $x = transpose$1({ inputs: { x }, backend, attrs: { perm: permutedAxes } });
        axes = getInnerMostAxes(axes.length, x.shape.length);
    }
    assertAxesAreInnerMostDims('min', axes, $x.shape.length);
    const [outShape, reduceShape] = computeOutAndReduceShapes($x.shape, axes);
    const reduceSize = sizeFromShape(reduceShape);
    const vals = makeZerosTypedArray(sizeFromShape(outShape), $x.dtype);
    const aVals = backend.data.get($x.dataId).values;
    for (let i = 0; i < vals.length; ++i) {
        const offset = i * reduceSize;
        let min = aVals[offset];
        for (let j = 0; j < reduceSize; ++j) {
            const value = aVals[offset + j];
            if (Number.isNaN(value) ||
                value < min) {
                min = value;
            }
        }
        vals[i] = min;
    }
    if (permutedAxes != null) {
        backend.disposeIntermediateTensorInfo($x);
    }
    const result = backend.makeTensorInfo(outShape, $x.dtype, vals);
    if (keepDims) {
        const expandedShape = expandShapeToKeepDim(outShape, origAxes);
        const reshapedResult = reshape({ inputs: { x: result }, backend, attrs: { shape: expandedShape } });
        backend.disposeIntermediateTensorInfo(result);
        return reshapedResult;
    }
    return result;
}
const minConfig = {
    kernelName: Min,
    backendName: 'cpu',
    kernelFunc: min
};
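
// MirrorPad: output coordinates that fall in the padded region are
// reflected back into the input range; `offset` is 0 in 'reflect' mode (the
// edge value is not repeated) and 1 in 'symmetric' mode (it is).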

function mirrorPad(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { paddings, mode } = attrs;
    assertNotComplex(x, 'mirrorPad');
    const outShape = paddings.map((p, i) => p[0] + x.shape[i] + p[1]);
    const start = paddings.map(p => p[0]);
    const end = paddings.map((p, i) => p[0] + x.shape[i]);
    const offset = mode === 'reflect' ? 0 : 1;
    const xVals = backend.data.get(x.dataId).values;
    const xRank = x.shape.length;
    const xStrides = computeStrides(x.shape);
    const resultSize = sizeFromShape(outShape);
    const resultRank = outShape.length;
    const resultStrides = computeStrides(outShape);
    const resVals = getTypedArrayFromDType(x.dtype, resultSize);
    for (let i = 0; i < resultSize; i++) {
        let coords = indexToLoc(i, resultRank, resultStrides);
        for (let i = 0; i < resultRank; i++) {
            if (coords[i] < start[i]) {
                coords[i] = start[i] * 2 - coords[i] - offset;
            }
            else if (coords[i] >= end[i]) {
                coords[i] = (end[i] - 1) * 2 - coords[i] + offset;
            }
        }
        coords = coords.map((c, i) => c - start[i]);
        const inIndex = locToIndex(coords, xRank, xStrides);
        resVals[i] = xVals[inIndex];
    }
    const outId = backend.write(resVals, outShape, x.dtype);
    return { dataId: outId, shape: outShape, dtype: x.dtype };
}
const mirrorPadConfig = {
    kernelName: MirrorPad,
    backendName: 'cpu',
    kernelFunc: mirrorPad
};

const modImpl = createSimpleBinaryKernelImpl(((aValue, bValue) => {
    const rem = aValue % bValue;
    if ((aValue < 0 && bValue < 0) || (aValue >= 0 && bValue >= 0)) {
        return rem;
    }
    else {
        return (rem + bValue) % bValue;
    }
}));
const mod = binaryKernelFunc$1(Mod, modImpl);
const modConfig = {
    kernelName: Mod,
    backendName: 'cpu',
    kernelFunc: mod
};
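
// Numerically stable softmax along the last axis:
//   softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))
// Subtracting the per-row maximum first keeps exp() from overflowing.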

function softmax(args) {
    const { inputs, backend, attrs } = args;
    const { logits } = inputs;
    const { dim } = attrs;
    const logitsRank = logits.shape.length;
    let $dim = dim;
    if ($dim === -1) {
        $dim = logitsRank - 1;
    }
    if ($dim !== logitsRank - 1) {
        throw Error('Softmax along a non-last dimension is not yet supported. ' +
            `Logits was rank ${logitsRank} and dim was ${$dim}`);
    }
    const axes = parseAxisParam([$dim], logits.shape);
    const maxLogit = max({
        inputs: { x: logits },
        backend,
        attrs: { reductionIndices: axes, keepDims: false }
    });
    const expandedShape = expandShapeToKeepDim(maxLogit.shape, axes);
    const maxLogitReshaped = reshape({ inputs: { x: maxLogit }, backend, attrs: { shape: expandedShape } });
    const a = sub$1({ inputs: { a: logits, b: maxLogitReshaped }, backend });
    const b = exp$1({ inputs: { x: a }, backend });
    const sumExp = sum({ inputs: { x: b }, backend, attrs: { axis: axes, keepDims: false } });
    const sumReshaped = reshape({ inputs: { x: sumExp }, backend, attrs: { shape: expandedShape } });
    const result = div({ inputs: { a: b, b: sumReshaped }, backend });
    backend.disposeIntermediateTensorInfo(maxLogit);
    backend.disposeIntermediateTensorInfo(maxLogitReshaped);
    backend.disposeIntermediateTensorInfo(a);
    backend.disposeIntermediateTensorInfo(b);
    backend.disposeIntermediateTensorInfo(sumExp);
    backend.disposeIntermediateTensorInfo(sumReshaped);
    return result;
}
const softmaxConfig = {
    kernelName: Softmax$1,
    backendName: 'cpu',
    kernelFunc: softmax
};
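
// Multinomial sampling by inverse CDF: build the cumulative distribution of
// each batch row (the last event is implicit, so the CDF array holds
// numEvents - 1 entries), draw uniform samples from a seeded PRNG, and pick
// the first event whose CDF value exceeds the draw.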

function multinomial(args) {
    const { inputs, backend, attrs } = args;
    const { logits } = inputs;
    const { numSamples, seed, normalized } = attrs;
    assertNotComplex(logits, 'multinomial');
    const probabilities = normalized ?
        logits :
        softmax({ inputs: { logits }, backend, attrs: { dim: -1 } });
    const batchSize = probabilities.shape[0];
    const numEvents = probabilities.shape[1];
    const probVals = backend.data.get(probabilities.dataId).values;
    const resShape = [batchSize, numSamples];
    const resVals = makeZerosTypedArray(sizeFromShape(resShape), 'int32');
    for (let b = 0; b < batchSize; ++b) {
        const offset = b * numEvents;
        const cdf = new Float32Array(numEvents - 1);
        cdf[0] = probVals[offset];
        for (let event = 1; event < cdf.length; ++event) {
            cdf[event] = cdf[event - 1] + probVals[offset + event];
        }
        const random = seedrandom.alea(seed.toString());
        const outOffset = b * numSamples;
        for (let sampleId = 0; sampleId < numSamples; ++sampleId) {
            const r = random();
            resVals[outOffset + sampleId] = cdf.length;
            for (let event = 0; event < cdf.length; event++) {
                if (r < cdf[event]) {
                    resVals[outOffset + sampleId] = event;
                    break;
                }
            }
        }
    }
    if (!normalized) {
        backend.disposeIntermediateTensorInfo(probabilities);
    }
    return backend.makeTensorInfo(resShape, 'int32', resVals);
}
const multinomialConfig = {
    kernelName: Multinomial,
    backendName: 'cpu',
    kernelFunc: multinomial
};

const nonMaxSuppressionV3Impl = nonMaxSuppressionV3Impl$2;
function nonMaxSuppressionV3(args) {
    const { inputs, backend, attrs } = args;
    const { boxes, scores } = inputs;
    const { maxOutputSize, iouThreshold, scoreThreshold } = attrs;
    assertNotComplex(boxes, 'NonMaxSuppression');
    const boxesVals = backend.data.get(boxes.dataId).values;
    const scoresVals = backend.data.get(scores.dataId).values;
    const { selectedIndices } = nonMaxSuppressionV3Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold);
    return backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices));
}
const nonMaxSuppressionV3Config = {
    kernelName: NonMaxSuppressionV3,
    backendName: 'cpu',
    kernelFunc: nonMaxSuppressionV3
};

const nonMaxSuppressionV4Impl = nonMaxSuppressionV4Impl$2;
function nonMaxSuppressionV4(args) {
    const { inputs, backend, attrs } = args;
    const { boxes, scores } = inputs;
    const { maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize } = attrs;
    assertNotComplex(boxes, 'NonMaxSuppressionPadded');
    const boxesVals = backend.data.get(boxes.dataId).values;
    const scoresVals = backend.data.get(scores.dataId).values;
    const { selectedIndices, validOutputs } = nonMaxSuppressionV4Impl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold, padToMaxOutputSize);
    return [
        backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)),
        backend.makeTensorInfo([], 'int32', new Int32Array([validOutputs]))
    ];
}
const nonMaxSuppressionV4Config = {
    kernelName: NonMaxSuppressionV4,
    backendName: 'cpu',
    kernelFunc: nonMaxSuppressionV4
};

const nonMaxSuppressionV5Impl = nonMaxSuppressionV5Impl$2;
function nonMaxSuppressionV5(args) {
    const { inputs, backend, attrs } = args;
    const { boxes, scores } = inputs;
    const { maxOutputSize, iouThreshold, scoreThreshold, softNmsSigma } = attrs;
    assertNotComplex(boxes, 'NonMaxSuppressionWithScore');
    const boxesVals = backend.data.get(boxes.dataId).values;
    const scoresVals = backend.data.get(scores.dataId).values;
    const maxOutputSizeVal = maxOutputSize;
    const iouThresholdVal = iouThreshold;
    const scoreThresholdVal = scoreThreshold;
    const softNmsSigmaVal = softNmsSigma;
    const { selectedIndices, selectedScores } = nonMaxSuppressionV5Impl(boxesVals, scoresVals, maxOutputSizeVal, iouThresholdVal, scoreThresholdVal, softNmsSigmaVal);
    return [
        backend.makeTensorInfo([selectedIndices.length], 'int32', new Int32Array(selectedIndices)),
        backend.makeTensorInfo([selectedScores.length], 'float32', new Float32Array(selectedScores))
    ];
}
const nonMaxSuppressionV5Config = {
    kernelName: NonMaxSuppressionV5,
    backendName: 'cpu',
    kernelFunc: nonMaxSuppressionV5
};
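
// OneHot: the output starts out filled with `offValue`; for every in-range
// index, the matching slot of its depth-sized row is set to `onValue`.
// Out-of-range indices leave their row untouched.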

function oneHot(args) {
    const { inputs, backend, attrs } = args;
    const { indices } = inputs;
    const { dtype, depth, onValue, offValue } = attrs;
    assertNotComplex(indices, 'oneHot');
    const indicesSize = sizeFromShape(indices.shape);
    const res = new Float32Array(indicesSize * depth);
    res.fill(offValue);
    const indicesVal = backend.data.get(indices.dataId).values;
    for (let event = 0; event < indicesSize; ++event) {
        if (indicesVal[event] >= 0 && indicesVal[event] < depth) {
            res[event * depth + indicesVal[event]] = onValue;
        }
    }
    return backend.makeTensorInfo([...indices.shape, depth], dtype, res);
}
const oneHotConfig = {
    kernelName: OneHot,
    backendName: 'cpu',
    kernelFunc: oneHot
};

function zerosLike(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    if (x.dtype === 'string') {
        throw new Error('zerosLike is not supported for string tensors');
    }
    else if (x.dtype === 'complex64') {
        const realPart = real$1({ inputs: { input: x }, backend });
        const r = zerosLike({ inputs: { x: realPart }, backend });
        const imagPart = imag({ inputs: { input: x }, backend });
        const i = zerosLike({ inputs: { x: imagPart }, backend });
        const result = complex$1({ inputs: { real: r, imag: i }, backend });
        backend.disposeIntermediateTensorInfo(realPart);
        backend.disposeIntermediateTensorInfo(r);
        backend.disposeIntermediateTensorInfo(imagPart);
        backend.disposeIntermediateTensorInfo(i);
        return result;
    }
    else {
        return fill({ backend, attrs: { shape: x.shape, value: 0, dtype: x.dtype } });
    }
}
const zerosLikeConfig = {
    kernelName: ZerosLike,
    backendName: 'cpu',
    kernelFunc: zerosLike
};

function onesLike(args) {
    const { inputs, backend } = args;
    const { x } = inputs;
    if (x.dtype === 'string') {
        throw new Error('onesLike is not supported for string tensors');
    }
    else if (x.dtype === 'complex64') {
        const realPart = real$1({ inputs: { input: x }, backend });
        const r = onesLike({ inputs: { x: realPart }, backend });
        const imagPart = imag({ inputs: { input: x }, backend });
        const i = zerosLike({ inputs: { x: imagPart }, backend });
        const result = complex$1({ inputs: { real: r, imag: i }, backend });
        backend.disposeIntermediateTensorInfo(realPart);
        backend.disposeIntermediateTensorInfo(r);
        backend.disposeIntermediateTensorInfo(imagPart);
        backend.disposeIntermediateTensorInfo(i);
        return result;
    }
    else {
        return fill({ backend, attrs: { shape: x.shape, value: 1, dtype: x.dtype } });
    }
}
const onesLikeConfig = {
    kernelName: OnesLike,
    backendName: 'cpu',
    kernelFunc: onesLike
};

function pack(args) {
    const { inputs, backend, attrs } = args;
    const { axis } = attrs;
    if (inputs.length === 1) {
        return expandDims$1({ inputs: { input: inputs[0] }, backend, attrs: { dim: axis } });
    }
    const shape = inputs[0].shape;
    const dtype = inputs[0].dtype;
    inputs.forEach(t => {
        assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes');
        assert$1(dtype === t.dtype, () => 'All tensors passed to stack must have matching dtypes');
    });
    const intermediateTensorInfos = [];
    const expandedTensors = inputs.map(t => {
        const expandedT = expandDims$1({ inputs: { input: t }, backend, attrs: { dim: axis } });
        intermediateTensorInfos.push(expandedT);
        return expandedT;
    });
    const result = concat({ inputs: expandedTensors, backend, attrs: { axis } });
    intermediateTensorInfos.forEach(t => backend.disposeIntermediateTensorInfo(t));
    return result;
}
const packConfig = {
    kernelName: Pack,
    backendName: 'cpu',
    kernelFunc: pack
};

function padV2(args) {
    const { inputs, backend, attrs } = args;
    const { x } = inputs;
    const { paddings, constantValue } = attrs;
    assertNotComplex(x, 'pad');
    const outShape = paddings.map((p, i) => p[0] + x.shape[i] + p[1]);
    const start = paddings.map(p => p[0]);
    const xVals = backend.data.get(x.dataId).values;
    const xSize = sizeFromShape(x.shape);
    const xRank = x.shape.length;
    const xStrides = computeStrides(x.shape);
    const resultSize = sizeFromShape(outShape);
    const resultRank = outShape.length;
    const resultStrides = computeStrides(outShape);
    const resVals = getTypedArrayFromDType(x.dtype, resultSize);
    if (constantValue !== 0) {
        resVals.fill(constantValue);
    }
    for (let i = 0; i < xSize; i++) {
        const coords = indexToLoc(i, xRank, xStrides);
        const outCoords = coords.map((c, i) => c + start[i]);
        const outIndex = locToIndex(outCoords, resultRank, resultStrides);
        resVals[outIndex] = xVals[i];
    }
    const outId = backend.write(resVals, outShape, x.dtype);
    return { dataId: outId, shape: outShape, dtype: x.dtype };
}
const padV2Config = {
    kernelName: PadV2,
    backendName: 'cpu',
    kernelFunc: padV2
};

const powImpl = createSimpleBinaryKernelImpl((a, b) => Math.pow(a, b));
const pow = binaryKernelFunc$1(Pow, powImpl);
const powConfig = {
    kernelName: Pow,
    backendName: 'cpu',
    kernelFunc: pow
};

function raggedGather(args) {
    const { inputs, backend } = args;
    const { paramsNestedSplits, paramsDenseValues, indices } = inputs;
    const $paramsNestedSplits = paramsNestedSplits.map(t => backend.data.get(t.dataId).values);
    const $paramsNestedSplitsShapes = paramsNestedSplits.map(t => t.shape);
    const $paramsDenseValues = backend.data.get(paramsDenseValues.dataId).values;
    const $indices = backend.data.get(indices.dataId).values;
    const [outputNestedSplits, outputDenseValues, outputDenseValuesShape] = raggedGatherImpl($paramsNestedSplits, $paramsNestedSplitsShapes, $paramsDenseValues, paramsDenseValues.shape, paramsDenseValues.dtype, $indices, indices.shape);
    const outputNestedSplitsTensors = outputNestedSplits.map((splits) => backend.makeTensorInfo([splits.length], 'int32', splits));
    const outputDenseValuesTensor = backend.makeTensorInfo(outputDenseValuesShape, paramsDenseValues.dtype, outputDenseValues);
    return outputNestedSplitsTensors.concat([outputDenseValuesTensor]);
}
const raggedGatherConfig = {
    kernelName: RaggedGather,
    backendName: 'cpu',
    kernelFunc: raggedGather,
};

function raggedRange(args) {
    const { inputs, backend } = args;
    const { starts, limits, deltas } = inputs;
    const $starts = backend.data.get(starts.dataId).values;
    const $limits = backend.data.get(limits.dataId).values;
    const $deltas = backend.data.get(deltas.dataId).values;
    const [rtNestedSplitsData, rtDenseValuesData] = raggedRangeImpl($starts, starts.shape, starts.dtype, $limits, limits.shape, $deltas, deltas.shape);
    const rtNestedSplits = backend.makeTensorInfo([rtNestedSplitsData.length], 'int32', rtNestedSplitsData);
    const rtDenseValues = backend.makeTensorInfo([rtDenseValuesData.length], starts.dtype, rtDenseValuesData);
    return [rtNestedSplits, rtDenseValues];
}
const raggedRangeConfig = {
    kernelName: RaggedRange,
    backendName: 'cpu',
    kernelFunc: raggedRange,
};

function raggedTensorToTensor(args) {
    const { inputs, backend, attrs } = args;
    const { shape, values, defaultValue, rowPartitionTensors } = inputs;
    const { rowPartitionTypes } = attrs;
    const $shape = backend.data.get(shape.dataId).values;
    const $values = backend.data.get(values.dataId).values;
    const $defaultValue = backend.data.get(defaultValue.dataId).values;
    const $rowPartitionValues = rowPartitionTensors.map(t => backend.data.get(t.dataId).values);
    const rowPartitionValuesShapes = rowPartitionTensors.map(t => t.shape);
    const [outputShape, output] = raggedTensorToTensorImpl($shape, shape.shape, $values, values.shape, values.dtype, $defaultValue, defaultValue.shape, $rowPartitionValues, rowPartitionValuesShapes, rowPartitionTypes);
    return backend.makeTensorInfo(outputShape, values.dtype, output);
}
const raggedTensorToTensorConfig = {
    kernelName: RaggedTensorToTensor,
    backendName: 'cpu',
    kernelFunc: raggedTensorToTensor,
};

function range$1(args) {
    const { backend, attrs } = args;
    const { start, stop, dtype, step } = attrs;
    const values = rangeImpl(start, stop, step, dtype);
    return backend.makeTensorInfo([values.length], dtype, values);
}
const rangeConfig = {
    kernelName: Range,
    backendName: 'cpu',
    kernelFunc: range$1
};

const reciprocal = unaryKernelFunc$1(Reciprocal, (xi) => 1 / xi);
const reciprocalConfig = {
    kernelName: Reciprocal,
    backendName: 'cpu',
    kernelFunc: reciprocal,
};
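
// Bilinear resize: each output pixel is interpolated from the four input
// pixels around its fractional source coordinate:
//   top    = topLeft + (topRight - topLeft) * colFrac
//   bottom = bottomLeft + (bottomRight - bottomLeft) * colFrac
//   out    = top + (bottom - top) * rowFrac
// With `halfPixelCenters`, source coordinates are measured from pixel
// centers (r + 0.5) instead of pixel corners.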

function resizeBilinear(args) {
    const { inputs, backend, attrs } = args;
    const { images } = inputs;
    const { alignCorners, halfPixelCenters, size } = attrs;
    assertNotComplex(images, 'resizeBilinear');
    const imagesStrides = computeStrides(images.shape);
    const [newHeight, newWidth] = size;
    const [batch, oldHeight, oldWidth, numChannels] = images.shape;
    const xValues = backend.data.get(images.dataId).values;
    const result = new Float32Array(sizeFromShape([batch, newHeight, newWidth, numChannels]));
    const effectiveInputSize = [
        (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
        (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
    ];
    const effectiveOutputSize = [
        (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
        (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
    ];
    let outputIdx = 0;
    const effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0];
    const effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1];
    for (let b = 0; b < batch; b++) {
        for (let r = 0; r < newHeight; r++) {
            let sourceFracRow;
            if (halfPixelCenters) {
                sourceFracRow = effectiveRowSizeRatio * (r + 0.5) - 0.5;
            }
            else {
                sourceFracRow = effectiveRowSizeRatio * r;
            }
            const sourceRowFloor = Math.max(0, Math.floor(sourceFracRow));
            const rowFrac = sourceFracRow - sourceRowFloor;
            const sourceRowCeil = Math.min(oldHeight - 1, Math.ceil(sourceFracRow));
            const topRowOffset = b * imagesStrides[0] + sourceRowFloor * imagesStrides[1];
            const botRowOffset = b * imagesStrides[0] + sourceRowCeil * imagesStrides[1];
            for (let c = 0; c < newWidth; c++) {
                let sourceFracCol;
                if (halfPixelCenters) {
                    sourceFracCol = effectiveColSizeRatio * (c + 0.5) - 0.5;
                }
                else {
                    sourceFracCol = effectiveColSizeRatio * c;
                }
                const sourceColFloor = Math.max(0, Math.floor(sourceFracCol));
                const colFrac = sourceFracCol - sourceColFloor;
                const sourceColCeil = Math.min(oldWidth - 1, Math.ceil(sourceFracCol));
                const topLeftOffest = topRowOffset + sourceColFloor * imagesStrides[2];
                const botLeftOffset = botRowOffset + sourceColFloor * imagesStrides[2];
                const topRightOffset = topRowOffset + sourceColCeil * imagesStrides[2];
                const botRightOffest = botRowOffset + sourceColCeil * imagesStrides[2];
                for (let d = 0; d < numChannels; d++) {
                    const topLeft = xValues[topLeftOffest + d];
                    const bottomLeft = xValues[botLeftOffset + d];
                    const topRight = xValues[topRightOffset + d];
                    const bottomRight = xValues[botRightOffest + d];
                    const top = topLeft + (topRight - topLeft) * colFrac;
                    const bottom = bottomLeft + (bottomRight - bottomLeft) * colFrac;
                    const newValue = top + (bottom - top) * rowFrac;
                    result[outputIdx++] = newValue;
                }
            }
        }
    }
    return backend.makeTensorInfo([batch, newHeight, newWidth, numChannels], 'float32', result);
}
const resizeBilinearConfig = {
    kernelName: ResizeBilinear,
    backendName: 'cpu',
    kernelFunc: resizeBilinear
};
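
// Gradient of bilinear resize: each dy value is scattered back onto the
// four input pixels it was interpolated from, weighted by the same lerp
// factors as the forward pass.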

function resizeBilinearGrad(args) {
    const { inputs, backend, attrs } = args;
    const { images, dy } = inputs;
    const { alignCorners } = attrs;
    assertNotComplex([dy, images], 'resizeBilinearGrad');
    const imagesStrides = computeStrides(images.shape);
    const [batch, xHeight, xWidth, depth] = images.shape;
    const [, yHeight, yWidth] = dy.shape;
    const output = new Float32Array(batch * xHeight * xWidth * depth);
    const effectiveXSize = [
        (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
        (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
    ];
    const effectiveYSize = [
        (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
        (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
    ];
    const heightScale = effectiveXSize[0] / effectiveYSize[0];
    const widthScale = effectiveXSize[1] / effectiveYSize[1];
    const dyValues = backend.data.get(dy.dataId).values;
    let offset = 0;
    for (let b = 0; b < batch; b++) {
        const bOffset = b * imagesStrides[0];
        for (let r = 0; r < yHeight; r++) {
            const dxR = r * heightScale;
            const topDxRIndex = Math.floor(dxR);
            const bottomDxRIndex = Math.min(Math.ceil(dxR), xHeight - 1);
            const topDxROffset = bOffset + topDxRIndex * imagesStrides[1];
            const bottomDxROffset = bOffset + bottomDxRIndex * imagesStrides[1];
            const dxRLerp = dxR - topDxRIndex;
            const inverseDxRLerp = 1.0 - dxRLerp;
            for (let c = 0; c < yWidth; c++) {
                const dxC = c * widthScale;
                const leftDxCIndex = Math.floor(dxC);
                const rightDxCIndex = Math.min(Math.ceil(dxC), xWidth - 1);
                const dxCLerp = dxC - leftDxCIndex;
                const inverseDxCLerp = 1.0 - dxCLerp;
                const topLeftRCOffset = topDxROffset + leftDxCIndex * imagesStrides[2];
                const topRightRCOffset = topDxROffset + rightDxCIndex * imagesStrides[2];
                const bottomLeftRCOffset = bottomDxROffset + leftDxCIndex * imagesStrides[2];
                const bottomRightRCOffset = bottomDxROffset + rightDxCIndex * imagesStrides[2];
                const inverseDxRLerpTimesInverseDxCLerp = inverseDxRLerp * inverseDxCLerp;
                const inverseDxRLerpTimesDxCLerp = inverseDxRLerp * dxCLerp;
                const dxRLerpTimesInverseDxCLerp = dxRLerp * inverseDxCLerp;
                const dxRLerpTimesDxCLerp = dxRLerp * dxCLerp;
                for (let d = 0; d < depth; d++) {
                    const dyVal = dyValues[offset++];
                    output[topLeftRCOffset + d] +=
                        dyVal * inverseDxRLerpTimesInverseDxCLerp;
                    output[topRightRCOffset + d] += dyVal * inverseDxRLerpTimesDxCLerp;
                    output[bottomLeftRCOffset + d] += dyVal * dxRLerpTimesInverseDxCLerp;
                    output[bottomRightRCOffset + d] += dyVal * dxRLerpTimesDxCLerp;
                }
            }
        }
    }
    // `output` is laid out with the input image's strides, so dx has the
    // input's shape: [batch, xHeight, xWidth, depth].
    return backend.makeTensorInfo([batch, xHeight, xWidth, depth], 'float32', output);
}
const resizeBilinearGradConfig$1 = {
    kernelName: ResizeBilinearGrad,
    backendName: 'cpu',
    kernelFunc: resizeBilinearGrad
};

function resizeNearestNeighbor(args) {
    const { inputs, backend, attrs } = args;
    const { images } = inputs;
    const { alignCorners, halfPixelCenters, size } = attrs;
    assertNotComplex(images, 'resizeNearestNeighbor');
    const imagesStrides = computeStrides(images.shape);
    const [newHeight, newWidth] = size;
    const [batch, oldHeight, oldWidth, numChannels] = images.shape;
    const xValues = backend.data.get(images.dataId).values;
    const output = new Float32Array(batch * newHeight * newWidth * numChannels);
    const effectiveInputSize = [
        (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
        (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
    ];
    const effectiveOutputSize = [
        (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
        (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
    ];
    const effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0];
    const effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1];
    let outputOffset = 0;
    for (let b = 0; b < batch; b++) {
        const batchOffset = b * imagesStrides[0];
        for (let r = 0; r < newHeight; r++) {
            const sourceFracRow = halfPixelCenters ?
                effectiveRowSizeRatio * (r + 0.5) :
                effectiveRowSizeRatio * r;
            let sourceNearestRow = Math.min(oldHeight - 1, alignCorners ? Math.round(sourceFracRow) : Math.floor(sourceFracRow));
            if (halfPixelCenters) {
                sourceNearestRow = Math.max(0, sourceNearestRow);
            }
            const rowOffset = batchOffset + sourceNearestRow * imagesStrides[1];
            for (let c = 0; c < newWidth; c++) {
                const sourceFracCol = halfPixelCenters ?
                    effectiveColSizeRatio * (c + 0.5) :
                    effectiveColSizeRatio * c;
                let sourceNearestCol = Math.min(oldWidth - 1, alignCorners ? Math.round(sourceFracCol) :
                    Math.floor(sourceFracCol));
                if (halfPixelCenters) {
                    sourceNearestCol = Math.max(0, sourceNearestCol);
                }
                const colOffset = rowOffset + sourceNearestCol * imagesStrides[2];
                for (let d = 0; d < numChannels; d++) {
                    const newVal = xValues[colOffset + d];
                    output[outputOffset++] = newVal;
                }
            }
        }
    }
    return backend.makeTensorInfo([batch, newHeight, newWidth, numChannels], images.dtype, output);
}
const resizeNearestNeighborConfig = {
    kernelName: ResizeNearestNeighbor,
    backendName: 'cpu',
    kernelFunc: resizeNearestNeighbor
};
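
// Gradient of nearest-neighbor resize: for every input pixel, scan a small
// window of candidate dy pixels (sized from the inverse scale factors) and
// accumulate those whose nearest source pixel maps back to this position.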
|
|
|
|
|
|
function resizeNearestNeighborGrad(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { images, dy } = inputs;
|
|
const { alignCorners } = attrs;
|
|
assertNotComplex([dy, images], 'resizeNearestNeighborGrad');
|
|
const imagesStrides = computeStrides(images.shape);
|
|
const dyStrides = computeStrides(dy.shape);
|
|
const [batch, xHeight, xWidth, depth] = images.shape;
|
|
const [, yHeight, yWidth] = dy.shape;
|
|
const output = new Float32Array(batch * xHeight * xWidth * depth);
|
|
const dyValues = backend.data.get(dy.dataId).values;
|
|
|
|
|
|
const effectiveXSize = [
|
|
(alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
|
|
(alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
|
|
];
|
|
const effectiveYSize = [
|
|
(alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
|
|
(alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
|
|
];
|
|
const heightScale = effectiveXSize[0] / effectiveYSize[0];
|
|
const widthScale = effectiveXSize[1] / effectiveYSize[1];
|
|
const invHeightScale = 1 / heightScale;
|
|
const invWidthScale = 1 / widthScale;
|
|
|
|
|
|
const winHeight = (Math.ceil(invHeightScale) * 2) + 2;
|
|
const winWidth = (Math.ceil(invWidthScale) * 2) + 2;
|
|
|
|
for (let b = 0; b < batch; b++) {
|
|
const batchOffset = b * imagesStrides[0];
|
|
for (let r = 0; r < xHeight; r++) {
|
|
const rowOffset = batchOffset + r * imagesStrides[1];
|
|
|
|
const startRLerp = Math.floor(r * invHeightScale);
|
|
const startDyR = Math.floor(startRLerp - (winHeight / 2));
|
|
for (let c = 0; c < xWidth; c++) {
|
|
const colOffset = rowOffset + c * imagesStrides[2];
|
|
|
|
const startCLerp = Math.floor(c * invWidthScale);
|
|
const startDyC = Math.floor(startCLerp - (winWidth / 2));
|
|
for (let d = 0; d < depth; d++) {
|
|
let accum = 0;
|
|
|
|
for (let dyRIndex = 0; dyRIndex < winHeight; dyRIndex++) {
|
|
const dyR = dyRIndex + startDyR;
|
|
|
|
if (dyR < 0 || dyR >= yHeight) {
|
|
continue;
|
|
}
|
|
const dyROffset = b * dyStrides[0] + dyR * dyStrides[1];
|
|
const sourceFracRow = dyR * heightScale;
|
|
const sourceNearestRow = Math.min(xHeight - 1, alignCorners ? Math.round(sourceFracRow) :
|
|
Math.floor(sourceFracRow));
|
|
if (r !== sourceNearestRow) {
|
|
continue;
|
|
}
|
|
for (let dyCIndex = 0; dyCIndex < winWidth; dyCIndex++) {
|
|
const dyC = dyCIndex + startDyC;
|
|
if (dyC < 0 || dyC >= yWidth) {
|
|
continue;
|
|
}
|
|
const dyCOffset = dyROffset + dyC * dyStrides[2];
|
|
const sourceFracCol = dyC * widthScale;
|
|
const sourceNearestCol = Math.min(xWidth - 1, alignCorners ? Math.round(sourceFracCol) :
|
|
Math.floor(sourceFracCol));
|
|
if (c === sourceNearestCol) {
|
|
accum += dyValues[dyCOffset + d];
|
|
}
|
|
}
|
|
}
|
|
output[colOffset + d] = accum;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return backend.makeTensorInfo(images.shape, images.dtype, output);
|
|
}
|
|
const resizeNearestNeighborGradConfig$1 = {
|
|
kernelName: ResizeNearestNeighborGrad,
|
|
backendName: 'cpu',
|
|
kernelFunc: resizeNearestNeighborGrad
|
|
};
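// The gradient inverts the forward mapping: each dy pixel was copied from
// exactly one source pixel, so for every input pixel (r, c) the kernel scans
// a winHeight x winWidth window of candidate dy positions around
// (r * invHeightScale, c * invWidthScale) and accumulates only those dy
// entries whose nearest source pixel is (r, c) itself.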
|
|
function reverse(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { x } = inputs;
|
|
const { dims } = attrs;
|
|
assertNotComplex(x, 'reverse');
|
|
const xRank = x.shape.length;
|
|
const $dims = parseAxisParam(dims, x.shape);
|
|
if (xRank === 0) {
|
|
return identity$1({ inputs: { x }, backend });
|
|
}
|
|
const outBuf = new TensorBuffer(x.shape, x.dtype);
|
|
const xBuf = backend.bufferSync(x);
|
|
for (let i = 0; i < outBuf.size; i++) {
|
|
const outLoc = outBuf.indexToLoc(i);
|
|
const inLoc = outLoc.slice();
|
|
$dims.forEach(d => inLoc[d] = x.shape[d] - 1 - inLoc[d]);
|
|
outBuf.set(xBuf.get(...inLoc), ...outLoc);
|
|
}
|
|
return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
|
|
}
|
|
const reverseConfig = {
|
|
kernelName: Reverse,
|
|
backendName: 'cpu',
|
|
kernelFunc: reverse
|
|
};
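// reverse remaps each output location to its mirrored input location along
// the requested dims: inLoc[d] = x.shape[d] - 1 - outLoc[d]. For example,
// reversing [[1, 2], [3, 4]] with dims=[1] yields [[2, 1], [4, 3]].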
|
|
const rotateWithOffsetConfig = {
|
|
kernelName: RotateWithOffset,
|
|
backendName: 'cpu',
|
|
kernelFunc: ({ inputs, attrs, backend }) => {
|
|
const { image } = inputs;
|
|
const { radians, fillValue, center } = attrs;
|
|
const cpuBackend = backend;
|
|
const output = getTypedArrayFromDType(image.dtype, sizeFromShape(image.shape));
|
|
const [batch, imageHeight, imageWidth, numChannels] = image.shape;
|
|
const [centerX, centerY] = getImageCenter(center, imageHeight, imageWidth);
|
|
const fullOpacityValue = 255;
|
|
const sinFactor = Math.sin(radians);
|
|
const cosFactor = Math.cos(radians);
|
|
const imageVals = cpuBackend.data.get(image.dataId).values;
|
|
for (let batchIdx = 0; batchIdx < batch; batchIdx++) {
|
|
const batchOffset = batchIdx * imageWidth * imageHeight * numChannels;
|
|
for (let row = 0; row < imageHeight; row++) {
|
|
const rowOffset = row * (imageWidth * numChannels);
|
|
for (let col = 0; col < imageWidth; col++) {
|
|
const colOffset = col * numChannels;
|
|
for (let channel = 0; channel < numChannels; channel++) {
|
|
const coords = [batch, row, col, channel];
|
|
const x = coords[2];
|
|
const y = coords[1];
|
|
let coordX = (x - centerX) * cosFactor - (y - centerY) * sinFactor;
|
|
let coordY = (x - centerX) * sinFactor + (y - centerY) * cosFactor;
|
|
coordX = Math.round(coordX + centerX);
|
|
coordY = Math.round(coordY + centerY);
|
|
let outputValue = fillValue;
|
|
if (typeof fillValue !== 'number') {
|
|
if (channel === 3) {
|
|
outputValue = fullOpacityValue;
|
|
}
|
|
else {
|
|
outputValue = fillValue[channel];
|
|
}
|
|
}
|
|
if (coordX >= 0 && coordX < imageWidth && coordY >= 0 &&
|
|
coordY < imageHeight) {
|
|
const rotatedRowOffset = coordY * (imageWidth * numChannels);
|
|
const rotatedColOffset = coordX * numChannels;
|
|
const imageIdx = batchOffset + rotatedRowOffset + rotatedColOffset + channel;
|
|
outputValue = imageVals[imageIdx];
|
|
}
|
|
const outIdx = batchOffset + rowOffset + colOffset + channel;
|
|
output[outIdx] = outputValue;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
const dataId = cpuBackend.write(output, image.shape, image.dtype);
|
|
return { dataId, shape: image.shape, dtype: image.dtype };
|
|
}
|
|
};
|
|
const round = unaryKernelFunc$1(Round, (xi) => {
|
|
const base = Math.floor(xi);
|
|
if (xi - base < 0.5) {
|
|
return Math.floor(xi);
|
|
}
|
|
else if (xi - base > 0.5) {
|
|
return Math.ceil(xi);
|
|
}
|
|
else {
|
|
if (base % 2.0 === 0.0) {
|
|
return base;
|
|
}
|
|
else {
|
|
return base + 1.0;
|
|
}
|
|
}
|
|
});
|
|
const roundConfig = {
|
|
kernelName: Round,
|
|
backendName: 'cpu',
|
|
kernelFunc: round,
|
|
};
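// The kernel above is round-half-to-even (banker's rounding), matching
// TensorFlow rather than JavaScript's Math.round. Worked examples:
//   round(0.5)  -> 0  (base 0 is even)
//   round(1.5)  -> 2  (base 1 is odd, so base + 1)
//   round(2.5)  -> 2  (base 2 is even)
//   round(-0.5) -> 0  (base -1 is odd, so base + 1)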
|
|
function scatterNd(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { indices, updates } = inputs;
|
|
const { shape } = attrs;
|
|
const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, shape);
|
|
const sumDupeIndices = true;
|
|
const indicesBuf = backend.bufferSync(indices);
|
|
const updatesBuf = backend.bufferSync(updates);
|
|
const outBuf = scatterImpl(indicesBuf, updatesBuf, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, 0 , sumDupeIndices);
|
|
return backend.makeTensorInfo(shape, outBuf.dtype, outBuf.values);
|
|
}
|
|
const scatterNdConfig = {
|
|
kernelName: ScatterNd,
|
|
backendName: 'cpu',
|
|
kernelFunc: scatterNd
|
|
};
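// scatterNd writes each updates slice into a zero-filled output at the
// position named by the matching indices row, summing duplicate indices
// (sumDupeIndices = true above). Illustrative examples:
//   indices=[[1], [3]], updates=[10, 20], shape=[5] -> [0, 10, 0, 20, 0]
//   indices=[[1], [1]], updates=[10, 20], shape=[5] -> [0, 30, 0, 0, 0]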
|
|
function lowerBound(array, value) {
|
|
let left = 0;
|
|
let right = array.length;
|
|
let mid = 0;
|
|
while (left < right) {
|
|
mid = Math.floor((left + right) / 2);
|
|
if (array[mid] < value) {
|
|
left = mid + 1;
|
|
}
|
|
else {
|
|
right = mid;
|
|
}
|
|
}
|
|
return right;
|
|
}
|
|
function upperBound(array, value) {
|
|
let left = 0;
|
|
let right = array.length;
|
|
let mid = 0;
|
|
while (left < right) {
|
|
mid = Math.floor((left + right) / 2);
|
|
if (array[mid] <= value) {
|
|
left = mid + 1;
|
|
}
|
|
else {
|
|
right = mid;
|
|
}
|
|
}
|
|
return right;
|
|
}
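// Both helpers are binary searches over one sorted row: lowerBound returns
// the first index i with array[i] >= value, upperBound the first index with
// array[i] > value. For array = [1, 3, 3, 5]:
//   lowerBound(array, 3) -> 1    upperBound(array, 3) -> 3
// searchSortedImpl below picks one of the two per the 'left'/'right' side
// attribute.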
|
|
function searchSortedImpl(sortedInputs, values, batchSize, numInputs, numValues, side) {
|
|
const output = getArrayFromDType('int32', batchSize * numValues);
|
|
for (let b = 0; b < batchSize; ++b) {
|
|
const sortedInputsSlice = sortedInputs.slice(b * numInputs, (b + 1) * numInputs);
|
|
const outputOffset = b * numValues;
|
|
for (let i = 0; i < numValues; ++i) {
|
|
output[outputOffset + i] = side === 'left' ?
|
|
lowerBound(sortedInputsSlice, values[i + outputOffset]) :
|
|
upperBound(sortedInputsSlice, values[i + outputOffset]);
|
|
}
|
|
}
|
|
return output;
|
|
}
|
|
function searchSorted(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { sortedSequence, values } = inputs;
|
|
const { side } = attrs;
|
|
const $sortedSequence = backend.data.get(sortedSequence.dataId).values;
|
|
const $values = backend.data.get(values.dataId).values;
|
|
const output = searchSortedImpl($sortedSequence, $values, sortedSequence.shape[0], sortedSequence.shape[1], values.shape[1], side);
|
|
return backend.makeTensorInfo(values.shape, 'int32', output);
|
|
}
|
|
const searchSortedConfig = {
|
|
kernelName: SearchSorted,
|
|
backendName: 'cpu',
|
|
kernelFunc: searchSorted,
|
|
};
|
|
function select(args) {
|
|
const { inputs, backend } = args;
|
|
const { condition, t, e } = inputs;
|
|
assertNotComplex([condition, t, e], 'select');
|
|
const conditionRank = condition.shape.length;
|
|
const values = backend.data.get(condition.dataId).values;
|
|
const tValues = backend.data.get(t.dataId).values;
|
|
const eValues = backend.data.get(e.dataId).values;
|
|
const resultDtype = upcastType(t.dtype, e.dtype);
|
|
const newValues = makeZerosTypedArray(sizeFromShape(t.shape), resultDtype);
|
|
let index = 0;
|
|
const offset = conditionRank === 0 || conditionRank > 1 || t.shape.length === 1 ?
|
|
1 :
|
|
sizeFromShape(t.shape.slice(1));
|
|
for (let i = 0; i < values.length; i++) {
|
|
for (let j = 0; j < offset; j++) {
|
|
if (values[i] === 1) {
|
|
newValues[index++] = tValues[i * offset + j];
|
|
}
|
|
else {
|
|
newValues[index++] = eValues[i * offset + j];
|
|
}
|
|
}
|
|
}
|
|
return backend.makeTensorInfo(t.shape, resultDtype, newValues);
|
|
}
|
|
const selectConfig = {
|
|
kernelName: Select,
|
|
backendName: 'cpu',
|
|
kernelFunc: select
|
|
};
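// When condition is rank 1 and t/e are higher-rank, each condition element
// selects a whole leading slice: offset = sizeFromShape(t.shape.slice(1)),
// so row i of the output is copied wholesale from t or e. Sketch:
//   condition=[1, 0], t=[[1, 2], [3, 4]], e=[[5, 6], [7, 8]]
//   -> [[1, 2], [7, 8]]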
|
|
const scaleAlpha = SELU_SCALEALPHA;
|
|
const scale = SELU_SCALE;
|
|
const selu = unaryKernelFunc$1(Selu$1, (xi) => {
|
|
if (xi >= 0) {
|
|
return scale * xi;
|
|
}
|
|
else {
|
|
return scaleAlpha * (Math.exp(xi) - 1);
|
|
}
|
|
});
|
|
const seluConfig = {
|
|
kernelName: Selu$1,
|
|
backendName: 'cpu',
|
|
kernelFunc: selu,
|
|
};
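// selu(x) = SELU_SCALE * x                  for x >= 0
//         = SELU_SCALEALPHA * (exp(x) - 1)  for x < 0
// with SELU_SCALE ~= 1.0507 and SELU_SCALEALPHA = SELU_SCALE * alpha
// ~= 1.7581 (alpha ~= 1.6733), the self-normalizing constants from
// Klambauer et al., "Self-Normalizing Neural Networks".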
|
|
const sign = unaryKernelFunc$1(Sign, (xi) => {
|
|
if (xi < 0) {
|
|
return -1;
|
|
}
|
|
else if (xi > 0) {
|
|
return 1;
|
|
}
|
|
else {
|
|
return 0;
|
|
}
|
|
});
|
|
const signConfig = {
|
|
kernelName: Sign,
|
|
backendName: 'cpu',
|
|
kernelFunc: sign,
|
|
};
|
|
const sin = unaryKernelFunc$1(Sin, (xi) => Math.sin(xi));
|
|
const sinConfig = {
|
|
kernelName: Sin,
|
|
backendName: 'cpu',
|
|
kernelFunc: sin,
|
|
};
|
|
const sinh = unaryKernelFunc$1(Sinh, (xi) => Math.sinh(xi));
|
|
const sinhConfig = {
|
|
kernelName: Sinh,
|
|
backendName: 'cpu',
|
|
kernelFunc: sinh,
|
|
};
|
|
const epsilon$1 = 1.1920928955078125e-7;
|
|
const threshold = Math.log(epsilon$1) + 2.0;
|
|
const softplus = unaryKernelFunc$1(Softplus$1, (xi) => {
|
|
const tooLarge = xi > -threshold;
|
|
const tooSmall = xi < threshold;
|
|
const expX = Math.exp(xi);
|
|
let result;
|
|
if (tooSmall) {
|
|
result = expX;
|
|
}
|
|
else if (tooLarge) {
|
|
result = xi;
|
|
}
|
|
else {
|
|
result = Math.log(1.0 + expX);
|
|
}
|
|
return result;
|
|
});
|
|
const softplusConfig = {
|
|
kernelName: Softplus$1,
|
|
backendName: 'cpu',
|
|
kernelFunc: softplus,
|
|
};
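// The branches keep softplus(x) = log(1 + exp(x)) numerically stable in
// float32: threshold = log(eps) + 2 ~= -13.94, so inputs above 13.94 return
// x itself (exp(x) dwarfs 1) and inputs below -13.94 return exp(x) (the
// first-order expansion of log(1 + exp(x)) for very negative x); only the
// middle range evaluates the log.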
|
|
function spaceToBatchND(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { x } = inputs;
|
|
const { blockShape, paddings } = attrs;
|
|
assertNotComplex([x], 'spaceToBatchND');
|
|
const prod = sizeFromShape(blockShape);
|
|
const completePaddings = [[0, 0]];
|
|
completePaddings.push(...paddings);
|
|
for (let i = 1 + blockShape.length; i < x.shape.length; ++i) {
|
|
completePaddings.push([0, 0]);
|
|
}
|
|
const paddedX = padV2Config.kernelFunc({
|
|
inputs: { x },
|
|
backend,
|
|
attrs: { paddings: completePaddings, constantValue: 0 }
|
|
});
|
|
const reshapedPaddedShape = getReshaped(paddedX.shape, blockShape, prod, false);
|
|
const permutedReshapedPaddedPermutation = getPermuted(reshapedPaddedShape.length, blockShape.length, false);
|
|
const flattenShape = getReshapedPermuted(paddedX.shape, blockShape, prod, false);
|
|
const reshapeInputs = { x: paddedX };
|
|
const reshapeAttrs = { shape: reshapedPaddedShape };
|
|
const paddedXReshaped = reshape({ inputs: reshapeInputs, backend, attrs: reshapeAttrs });
|
|
const transposeInputs = { x: paddedXReshaped };
|
|
const transposeAttrs = { perm: permutedReshapedPaddedPermutation };
|
|
const paddedXT = transpose$1({ inputs: transposeInputs, backend, attrs: transposeAttrs });
|
|
const resultReshapeInputs = { x: paddedXT };
|
|
const resultReshapeAttrs = { shape: flattenShape };
|
|
const result = reshape({ inputs: resultReshapeInputs, backend, attrs: resultReshapeAttrs });
|
|
backend.disposeIntermediateTensorInfo(paddedX);
|
|
backend.disposeIntermediateTensorInfo(paddedXReshaped);
|
|
backend.disposeIntermediateTensorInfo(paddedXT);
|
|
return result;
|
|
}
|
|
const spaceToBatchNDConfig = {
|
|
kernelName: SpaceToBatchND,
|
|
backendName: 'cpu',
|
|
kernelFunc: spaceToBatchND
|
|
};
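// spaceToBatchND is composed from existing kernels rather than a bespoke
// loop: pad the spatial dims to a multiple of blockShape, reshape so each
// block gets its own axes, transpose the block axes ahead of batch, then
// reshape again to fold them into the batch dimension; the three
// intermediates are disposed once the result exists.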
|
|
function sparseFillEmptyRows(args) {
|
|
const { inputs, backend } = args;
|
|
const { indices, values, denseShape, defaultValue } = inputs;
|
|
if (denseShape.shape.length !== 1) {
|
|
throw new Error(`Dense shape must be a vector, saw:
|
|
${denseShape.shape}`);
|
|
}
|
|
if (indices.shape.length !== 2) {
|
|
throw new Error(`Indices must be a matrix, saw:
|
|
${indices.shape}`);
|
|
}
|
|
if (values.shape.length !== 1) {
|
|
throw new Error(`Values must be a vector, saw:
|
|
${values.shape}`);
|
|
}
|
|
if (defaultValue.shape.length !== 0) {
|
|
throw new Error(`Default value must be a scalar, saw:
|
|
${defaultValue.shape}`);
|
|
}
|
|
const $indices = backend.data.get(indices.dataId).values;
|
|
const $values = backend.data.get(values.dataId).values;
|
|
const $denseShape = backend.data.get(denseShape.dataId).values;
|
|
const $defaultValue = backend.data.get(defaultValue.dataId).values[0];
|
|
const [outputIndices, outputIndicesShape, outputValues, emptyRowIndicator, reverseIndexMap] = sparseFillEmptyRowsImpl($indices, indices.shape, indices.dtype, $values, values.dtype, $denseShape, $defaultValue);
|
|
return [
|
|
backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices),
|
|
backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues),
|
|
backend.makeTensorInfo([emptyRowIndicator.length], 'bool', new Uint8Array(emptyRowIndicator.map((value) => Number(value)))),
|
|
backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap)),
|
|
];
|
|
}
|
|
const sparseFillEmptyRowsConfig = {
|
|
kernelName: SparseFillEmptyRows,
|
|
backendName: 'cpu',
|
|
kernelFunc: sparseFillEmptyRows,
|
|
};
|
|
function sparseReshape(args) {
|
|
const { inputs, backend } = args;
|
|
const { inputIndices, inputShape, newShape } = inputs;
|
|
if (inputIndices.shape.length !== 2) {
|
|
throw new Error(`Input indices should be a matrix but received shape
|
|
${inputIndices.shape}`);
|
|
}
|
|
if (inputShape.shape.length !== 1) {
|
|
throw new Error(`Input shape should be a vector but received shape
|
|
${inputShape.shape}`);
|
|
}
|
|
if (newShape.shape.length !== 1) {
|
|
throw new Error(`Target shape should be a vector but received shape ${newShape.shape}`);
|
|
}
|
|
const $inputShape = Array.from(backend.data.get(inputShape.dataId).values);
|
|
const $inputIndices = backend.data.get(inputIndices.dataId).values;
|
|
const targetShape = Array.from(backend.data.get(newShape.dataId).values);
|
|
const [newIndices, indicesShape, outputShape] = sparseReshapeImpl($inputIndices, inputIndices.shape, inputIndices.dtype, $inputShape, targetShape);
|
|
return [
|
|
backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices),
|
|
backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape)),
|
|
];
|
|
}
|
|
const sparseReshapeConfig = {
|
|
kernelName: SparseReshape,
|
|
backendName: 'cpu',
|
|
kernelFunc: sparseReshape,
|
|
};
|
|
function sparseSegmentMean(args) {
|
|
const { inputs, backend } = args;
|
|
const { data, indices, segmentIds } = inputs;
|
|
if (data.shape.length < 1) {
|
|
throw new Error(`Data should be at least 1 dimensional but received scalar`);
|
|
}
|
|
if (indices.shape.length !== 1) {
|
|
throw new Error(`Indices should be a vector but received shape
|
|
${indices.shape}`);
|
|
}
|
|
if (segmentIds.shape.length !== 1) {
|
|
throw new Error(`Segment ids should be a vector but received shape
|
|
${segmentIds.shape}`);
|
|
}
|
|
if (indices.shape[0] !== segmentIds.shape[0]) {
|
|
throw new Error(`segmentIds and indices should have the same size.`);
|
|
}
|
|
const $data = backend.data.get(data.dataId).values;
|
|
const $indices = backend.data.get(indices.dataId).values;
|
|
const $segmentIds = backend.data.get(segmentIds.dataId).values;
|
|
const [outputData, outputDataShape] = sparseSegmentReductionImpl($data, data.shape, data.dtype, $indices, $segmentIds, true);
|
|
return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
|
|
}
|
|
const sparseSegmentMeanConfig = {
|
|
kernelName: SparseSegmentMean,
|
|
backendName: 'cpu',
|
|
kernelFunc: sparseSegmentMean,
|
|
};
|
|
function sparseSegmentSum(args) {
|
|
const { inputs, backend } = args;
|
|
const { data, indices, segmentIds } = inputs;
|
|
if (data.shape.length < 1) {
|
|
throw new Error(`Data should be at least 1 dimensional but received scalar`);
|
|
}
|
|
if (indices.shape.length !== 1) {
|
|
throw new Error(`Indices should be a vector but received shape
|
|
${indices.shape}`);
|
|
}
|
|
if (segmentIds.shape.length !== 1) {
|
|
throw new Error(`Segment ids should be a vector but received shape
|
|
${segmentIds.shape}`);
|
|
}
|
|
if (indices.shape[0] !== segmentIds.shape[0]) {
|
|
throw new Error(`segmentIds and indices should have the same size.`);
|
|
}
|
|
const $data = backend.data.get(data.dataId).values;
|
|
const $indices = backend.data.get(indices.dataId).values;
|
|
const $segmentIds = backend.data.get(segmentIds.dataId).values;
|
|
const [outputData, outputDataShape] = sparseSegmentReductionImpl($data, data.shape, data.dtype, $indices, $segmentIds);
|
|
return backend.makeTensorInfo(outputDataShape, data.dtype, outputData);
|
|
}
|
|
const sparseSegmentSumConfig = {
|
|
kernelName: SparseSegmentSum,
|
|
backendName: 'cpu',
|
|
kernelFunc: sparseSegmentSum,
|
|
};
|
|
function sparseToDense(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { sparseIndices, sparseValues, defaultValue } = inputs;
|
|
const { outputShape } = attrs;
|
|
const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(sparseValues, sparseIndices, outputShape);
|
|
const sumDupeIndices = false;
|
|
const indicesBuf = backend.bufferSync(sparseIndices);
|
|
let outBuf;
|
|
switch (sparseValues.dtype) {
|
|
case 'bool': {
|
|
const updatesBuf = backend.bufferSync(sparseValues);
|
|
const $defaultValue = Boolean(backend.data.get(defaultValue.dataId).values[0]);
|
|
outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
|
|
break;
|
|
}
|
|
case 'float32': {
|
|
const updatesBuf = backend.bufferSync(sparseValues);
|
|
const $defaultValue = backend.data.get(defaultValue.dataId).values[0];
|
|
outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
|
|
break;
|
|
}
|
|
case 'int32': {
|
|
const updatesBuf = backend.bufferSync(sparseValues);
|
|
const $defaultValue = backend.data.get(defaultValue.dataId).values[0];
|
|
outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
|
|
break;
|
|
}
|
|
case 'string': {
|
|
const updatesBuf = backend.bufferSync(sparseValues);
|
|
const $defaultValue = decodeString(backend.data.get(defaultValue.dataId).values[0]);
|
|
outBuf = scatterImpl(indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, $defaultValue, sumDupeIndices);
|
|
break;
|
|
}
|
|
default:
|
|
throw new Error(`Unsupported type ${sparseValues.dtype}`);
|
|
}
|
|
return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values);
|
|
}
|
|
const sparseToDenseConfig = {
|
|
kernelName: SparseToDense,
|
|
backendName: 'cpu',
|
|
kernelFunc: sparseToDense
|
|
};
|
|
function splitV(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { x } = inputs;
|
|
const { numOrSizeSplits, axis } = attrs;
|
|
const $axis = parseAxisParam(axis, x.shape)[0];
|
|
const splitSizes = prepareSplitSize(x, numOrSizeSplits, $axis);
|
|
const begin = new Array(x.shape.length).fill(0);
|
|
const size = x.shape.slice();
|
|
return splitSizes.map(s => {
|
|
const sliceSize = [...size];
|
|
sliceSize[$axis] = s;
|
|
const sliceT = slice$1({ inputs: { x }, backend, attrs: { begin, size: sliceSize } });
|
|
begin[$axis] += s;
|
|
return sliceT;
|
|
});
|
|
}
|
|
const splitVConfig = {
|
|
kernelName: SplitV,
|
|
backendName: 'cpu',
|
|
kernelFunc: splitV
|
|
};
|
|
const squareConfig = {
|
|
kernelName: Square,
|
|
backendName: 'cpu',
|
|
kernelFunc: ({ inputs, backend }) => {
|
|
const { x } = inputs;
|
|
const cpuBackend = backend;
|
|
assertNotComplex(x, 'square');
|
|
const values = cpuBackend.data.get(x.dataId).values;
|
|
const newValues = new Float32Array(values.length);
|
|
for (let i = 0; i < values.length; ++i) {
|
|
const value = values[i];
|
|
newValues[i] = value * value;
|
|
}
|
|
const dataId = cpuBackend.write(newValues, x.shape, x.dtype);
|
|
return { dataId, shape: x.shape, dtype: x.dtype };
|
|
}
|
|
};
|
|
const step = unaryKernelFunc$1(Step, (xi, attrs) => {
|
|
const stepAttrs = attrs;
|
|
if (isNaN(xi)) {
|
|
return NaN;
|
|
}
|
|
else {
|
|
return xi > 0 ? 1 : stepAttrs.alpha;
|
|
}
|
|
});
|
|
const stepConfig = {
|
|
kernelName: Step,
|
|
backendName: 'cpu',
|
|
kernelFunc: step,
|
|
};
|
|
function stridedSlice(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { x } = inputs;
|
|
const { begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask } = attrs;
|
|
assertNotComplex(x, 'stridedSlice');
|
|
const { finalShapeSparse, finalShape, isIdentity, sliceDim0, isSimpleSlice, begin: $begin, end: $end, strides: $strides } = sliceInfo(x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
|
|
let result;
|
|
if (isIdentity) {
|
|
result = reshape({ inputs: { x }, backend, attrs: { shape: finalShape } });
|
|
}
|
|
else if (sliceDim0 || isSimpleSlice) {
|
|
assert$1(x.shape.length >= 1, () => `Input must have rank at least 1, got: ${x.shape.length}`);
|
|
const size = computeOutShape$2($begin, $end, $strides);
|
|
const sliced = slice$1({ inputs: { x }, backend, attrs: { begin: $begin, size } });
|
|
result =
|
|
reshape({ inputs: { x: sliced }, backend, attrs: { shape: finalShape } });
|
|
backend.disposeIntermediateTensorInfo(sliced);
|
|
}
|
|
else {
|
|
const xBuf = backend.bufferSync(x);
|
|
const outBuf = stridedSliceImpl(finalShapeSparse, xBuf, $strides, $begin);
|
|
result = backend.makeTensorInfo(finalShape, outBuf.dtype, outBuf.values);
|
|
}
|
|
return result;
|
|
}
|
|
const stridedSliceConfig = {
|
|
kernelName: StridedSlice,
|
|
backendName: 'cpu',
|
|
kernelFunc: stridedSlice
|
|
};
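// stridedSlice picks the cheapest of three paths computed by sliceInfo: an
// identity slice becomes a bare reshape, a contiguous unit-stride slice
// becomes slice + reshape, and only the general case walks the input via
// stridedSliceImpl with the resolved per-axis begin/stride vectors.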
|
|
function stringNGrams(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences } = attrs;
|
|
const { data, dataSplits } = inputs;
|
|
const $data = backend.data.get(data.dataId).values;
|
|
const $dataSplits = backend.data.get(dataSplits.dataId).values;
|
|
const [nGrams, nGramsSplits] = stringNGramsImpl($data, $dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences);
|
|
return [
|
|
backend.makeTensorInfo([nGrams.length], 'string', nGrams),
|
|
backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits),
|
|
];
|
|
}
|
|
const stringNGramsConfig = {
|
|
kernelName: StringNGrams,
|
|
backendName: 'cpu',
|
|
kernelFunc: stringNGrams,
|
|
};
|
|
function stringSplit(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { skipEmpty } = attrs;
|
|
const { input, delimiter } = inputs;
|
|
if (input.dtype !== 'string') {
|
|
throw new Error('Input must be of datatype string');
|
|
}
|
|
if (input.shape.length !== 1) {
|
|
throw new Error(`Input must be a vector, got shape: ${input.shape}`);
|
|
}
|
|
if (delimiter.shape.length !== 0) {
|
|
throw new Error(`Delimiter must be a scalar, got shape: ${delimiter.shape}`);
|
|
}
|
|
const $input = backend.data.get(input.dataId).values;
|
|
const $delimiter = backend.data.get(delimiter.dataId).values[0];
|
|
const [indices, values, shape] = stringSplitImpl($input, $delimiter, skipEmpty);
|
|
const outputSize = values.length;
|
|
return [
|
|
backend.makeTensorInfo([outputSize, 2], 'int32', indices),
|
|
backend.makeTensorInfo([outputSize], 'string', values),
|
|
backend.makeTensorInfo([2], 'int32', new Int32Array(shape))
|
|
];
|
|
}
|
|
const stringSplitConfig = {
|
|
kernelName: StringSplit,
|
|
backendName: 'cpu',
|
|
kernelFunc: stringSplit,
|
|
};
|
|
function stringToHashBucketFast(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { numBuckets } = attrs;
|
|
const { input } = inputs;
|
|
if (input.dtype !== 'string') {
|
|
throw new Error('Input must be of datatype string');
|
|
}
|
|
if (numBuckets <= 0) {
|
|
throw new Error(`Number of buckets must be at least 1`);
|
|
}
|
|
const $input = backend.data.get(input.dataId).values;
|
|
const output = stringToHashBucketFastImpl($input, numBuckets);
|
|
return backend.makeTensorInfo(input.shape, 'int32', output);
|
|
}
|
|
const stringToHashBucketFastConfig = {
|
|
kernelName: StringToHashBucketFast,
|
|
backendName: 'cpu',
|
|
kernelFunc: stringToHashBucketFast,
|
|
};
|
|
const tan = unaryKernelFunc$1(Tan, (xi) => Math.tan(xi));
|
|
const tanConfig = {
|
|
kernelName: Tan,
|
|
backendName: 'cpu',
|
|
kernelFunc: tan,
|
|
};
|
|
const tanh = unaryKernelFunc$1(Tanh$1, (xi) => Math.tanh(xi));
|
|
const tanhConfig = {
|
|
kernelName: Tanh$1,
|
|
backendName: 'cpu',
|
|
kernelFunc: tanh,
|
|
};
|
|
function tensorScatterUpdate(args) {
|
|
const { inputs, backend } = args;
|
|
const { tensor, indices, updates } = inputs;
|
|
const { sliceRank, numUpdates, sliceSize, strides, outputSize } = calculateShapes(updates, indices, tensor.shape);
|
|
const sumDupeIndices = false;
|
|
const indicesBuf = backend.bufferSync(indices);
|
|
const updatesBuf = backend.bufferSync(updates);
|
|
const tensorBuf = backend.bufferSync(tensor);
|
|
const outBuf = scatterImpl(indicesBuf, updatesBuf, tensor.shape, outputSize, sliceSize, numUpdates, sliceRank, strides, tensorBuf, sumDupeIndices);
|
|
return backend.makeTensorInfo(tensor.shape, outBuf.dtype, outBuf.values);
|
|
}
|
|
const tensorScatterUpdateConfig = {
|
|
kernelName: TensorScatterUpdate,
|
|
backendName: 'cpu',
|
|
kernelFunc: tensorScatterUpdate
|
|
};
|
|
function tile$1(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { x } = inputs;
|
|
const { reps } = attrs;
|
|
assertNotComplex(x, 'tile');
|
|
const outBuf = tileImpl(backend.bufferSync(x), reps);
|
|
return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
|
|
}
|
|
const tileConfig = {
|
|
kernelName: Tile,
|
|
backendName: 'cpu',
|
|
kernelFunc: tile$1
|
|
};
|
|
function topK(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { x } = inputs;
|
|
const { k, sorted } = attrs;
|
|
assertNotComplex(x, 'topk');
|
|
const xVals = backend.data.get(x.dataId).values;
|
|
const [allTopKVals, allTopKIndices] = topKImpl(xVals, x.shape, x.dtype, k, sorted);
|
|
return [
|
|
backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values),
|
|
backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)
|
|
];
|
|
}
|
|
const topKConfig = {
|
|
kernelName: TopK,
|
|
backendName: 'cpu',
|
|
kernelFunc: topK
|
|
};
|
|
function transform(args) {
|
|
const { inputs, attrs, backend } = args;
|
|
const { image, transforms } = inputs;
|
|
const { interpolation, fillMode, fillValue, outputShape } = attrs;
|
|
const [batch, imageHeight, imageWidth, numChannels] = image.shape;
|
|
const [outHeight, outWidth] = outputShape != null ? outputShape : [imageHeight, imageWidth];
|
|
const outShape = [batch, outHeight, outWidth, numChannels];
|
|
const inStrides = computeStrides(image.shape);
|
|
const batchInStride = inStrides[0];
|
|
const rowInStride = inStrides[1];
|
|
const colInStride = inStrides[2];
|
|
const outStrides = computeStrides(outShape);
|
|
const batchOutStride = outStrides[0];
|
|
const rowOutStride = outStrides[1];
|
|
const colOutStride = outStrides[2];
|
|
const outVals = getTypedArrayFromDType(image.dtype, sizeFromShape(outShape));
|
|
outVals.fill(fillValue);
|
|
const imageVals = backend.data.get(image.dataId).values;
|
|
const transformVals = backend.data.get(transforms.dataId).values;
|
|
for (let b = 0; b < batch; ++b) {
|
|
const transform = transforms.shape[0] === 1 ?
|
|
transformVals :
|
|
transformVals.subarray(b * 8, b * 8 + 8);
|
|
for (let outY = 0; outY < outHeight; ++outY) {
|
|
for (let outX = 0; outX < outWidth; ++outX) {
|
|
for (let channel = 0; channel < numChannels; ++channel) {
|
|
let val;
|
|
const projection = transform[6] * outX + transform[7] * outY + 1;
|
|
if (projection === 0) {
|
|
continue;
|
|
}
|
|
const inX = (transform[0] * outX + transform[1] * outY + transform[2]) /
|
|
projection;
|
|
const inY = (transform[3] * outX + transform[4] * outY + transform[5]) /
|
|
projection;
|
|
const x = mapCoord(inX, imageWidth, fillMode);
|
|
const y = mapCoord(inY, imageHeight, fillMode);
|
|
switch (interpolation) {
|
|
case 'nearest':
|
|
val = nearestInterpolation(imageVals, imageHeight, imageWidth, batchInStride, rowInStride, colInStride, b, y, x, channel, fillValue);
|
|
break;
|
|
case 'bilinear':
|
|
val = bilinearInterpolation(imageVals, imageHeight, imageWidth, batchInStride, rowInStride, colInStride, b, y, x, channel, fillValue);
|
|
break;
|
|
default:
|
|
throw new Error(`Error in Transform: Expect 'nearest' or ` +
|
|
`'bilinear', but got ${interpolation}`);
|
|
}
|
|
const ind = b * batchOutStride + outY * rowOutStride +
|
|
outX * colOutStride + channel;
|
|
outVals[ind] = val;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return backend.makeTensorInfo(outShape, image.dtype, outVals);
|
|
}
|
|
const transformConfig = {
|
|
kernelName: Transform,
|
|
backendName: 'cpu',
|
|
kernelFunc: transform
|
|
};
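// Each row of `transforms` holds eight projective coefficients
// [a0, a1, a2, b0, b1, b2, c0, c1]; output pixel (x, y) samples the input at
//   inX = (a0*x + a1*y + a2) / p,  inY = (b0*x + b1*y + b2) / p,
//   p   = c0*x + c1*y + 1,
// after which mapCoord applies the fill mode and the chosen interpolation
// reads the (possibly out-of-range) source location.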
|
|
function mapCoord(outCoord, len, mode) {
|
|
switch (mode) {
|
|
case 'reflect':
|
|
return mapCoordReflect(outCoord, len);
|
|
case 'wrap':
|
|
return mapCoordWrap(outCoord, len);
|
|
case 'nearest':
|
|
return mapCoordNearest(outCoord, len);
|
|
case 'constant':
|
|
default:
|
|
return mapCoordConstant(outCoord);
|
|
}
|
|
}
|
|
function mapCoordReflect(outCoord, len) {
|
|
let inCoord = outCoord;
|
|
if (inCoord < 0) {
|
|
if (len <= 1) {
|
|
inCoord = 0;
|
|
}
|
|
else {
|
|
const sz2 = 2 * len;
|
|
if (inCoord < sz2) {
|
|
inCoord = sz2 * Math.trunc(-inCoord / sz2) + inCoord;
|
|
}
|
|
inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1;
|
|
}
|
|
}
|
|
else if (inCoord > len - 1) {
|
|
if (len <= 1) {
|
|
inCoord = 0;
|
|
}
|
|
else {
|
|
const sz2 = 2 * len;
|
|
inCoord -= sz2 * Math.trunc(inCoord / sz2);
|
|
if (inCoord >= len) {
|
|
inCoord = sz2 - inCoord - 1;
|
|
}
|
|
}
|
|
}
|
|
return clamp(0, inCoord, len - 1);
|
|
}
|
|
function mapCoordWrap(outCoord, len) {
|
|
let inCoord = outCoord;
|
|
if (inCoord < 0) {
|
|
if (len <= 1) {
|
|
inCoord = 0;
|
|
}
|
|
else {
|
|
const sz = len - 1;
|
|
inCoord += len * (Math.trunc(-inCoord / sz) + 1);
|
|
}
|
|
}
|
|
else if (inCoord > len - 1) {
|
|
if (len <= 1) {
|
|
inCoord = 0;
|
|
}
|
|
else {
|
|
const sz = len - 1;
|
|
inCoord -= len * Math.trunc(inCoord / sz);
|
|
}
|
|
}
|
|
return clamp(0, inCoord, len - 1);
|
|
}
|
|
function mapCoordConstant(outCoord, len) {
|
|
return outCoord;
|
|
}
|
|
function mapCoordNearest(outCoord, len) {
|
|
return clamp(0, outCoord, len - 1);
|
|
}
|
|
function readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
|
|
const ind = batch * batchStride + y * rowStride + x * colStride + channel;
|
|
if (0 <= y && y < imageHeight && 0 <= x && x < imageWidth) {
|
|
return imageVals[ind];
|
|
}
|
|
else {
|
|
return fillValue;
|
|
}
|
|
}
|
|
function nearestInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
|
|
const $y = Math.round(y);
|
|
const $x = Math.round(x);
|
|
return readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, $y, $x, channel, fillValue);
|
|
}
|
|
function bilinearInterpolation(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, y, x, channel, fillValue) {
|
|
const yFloor = Math.floor(y);
|
|
const xFloor = Math.floor(x);
|
|
const yCeil = yFloor + 1;
|
|
const xCeil = xFloor + 1;
|
|
const valueYFloor = (xCeil - x) *
|
|
readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yFloor, xFloor, channel, fillValue) +
|
|
(x - xFloor) *
|
|
readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yFloor, xCeil, channel, fillValue);
|
|
const valueYCeil = (xCeil - x) *
|
|
readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yCeil, xFloor, channel, fillValue) +
|
|
(x - xFloor) *
|
|
readWithFillValue(imageVals, imageHeight, imageWidth, batchStride, rowStride, colStride, batch, yCeil, xCeil, channel, fillValue);
|
|
return (yCeil - y) * valueYFloor + (y - yFloor) * valueYCeil;
|
|
}
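// bilinearInterpolation is two linear blends: mix the floor/ceil columns
// within each neighbouring row, then mix those row values by the fractional
// y, i.e. v = (yCeil - y) * valueYFloor + (y - yFloor) * valueYCeil. Reads
// outside the image fall back to fillValue via readWithFillValue.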
|
|
function unique$1(args) {
|
|
const { inputs, attrs, backend } = args;
|
|
const { axis } = attrs;
|
|
const { x } = inputs;
|
|
assertNotComplex(x, 'unique');
|
|
const values = backend.data.get(x.dataId).values;
|
|
const { outputValues, outputShape, indices } = uniqueImpl(values, axis, x.shape, x.dtype);
|
|
return [
|
|
backend.makeTensorInfo(outputShape, x.dtype, outputValues),
|
|
backend.makeTensorInfo([indices.length], 'int32', indices),
|
|
];
|
|
}
|
|
const uniqueConfig = {
|
|
kernelName: Unique,
|
|
backendName: 'cpu',
|
|
kernelFunc: unique$1,
|
|
};
|
|
function unpack(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { value } = inputs;
|
|
let { axis } = attrs;
|
|
if (axis < 0) {
|
|
axis += value.shape.length;
|
|
}
|
|
const valueRank = value.shape.length;
|
|
const num = value.shape[axis];
|
|
const outShape = new Array(valueRank - 1);
|
|
let outIndex = 0;
|
|
for (let i = 0; i < valueRank; i++) {
|
|
if (i !== axis) {
|
|
outShape[outIndex++] = value.shape[i];
|
|
}
|
|
}
|
|
const begin = new Array(valueRank).fill(0);
|
|
const size = value.shape.slice();
|
|
size[axis] = 1;
|
|
const res = new Array(num);
|
|
for (let i = 0; i < res.length; i++) {
|
|
begin[axis] = i;
|
|
const tempRes = slice$1({ inputs: { x: value }, backend, attrs: { begin, size } });
|
|
res[i] = reshape({ inputs: { x: tempRes }, backend, attrs: { shape: outShape } });
|
|
backend.disposeIntermediateTensorInfo(tempRes);
|
|
}
|
|
return res;
|
|
}
|
|
const unpackConfig = {
|
|
kernelName: Unpack,
|
|
backendName: 'cpu',
|
|
kernelFunc: unpack
|
|
};
|
|
function unsortedSegmentSum(args) {
|
|
const { inputs, backend, attrs } = args;
|
|
const { x, segmentIds } = inputs;
|
|
const { numSegments } = attrs;
|
|
assertNotComplex(x, 'unsortedSegmentSum');
|
|
const xRank = x.shape.length;
|
|
const segmentIdsRank = segmentIds.shape.length;
|
|
const res = [];
|
|
const intermediates = [];
|
|
const numIters = xRank - segmentIdsRank;
|
|
let $segmentIds = segmentIds;
|
|
for (let i = 0; i < numIters; ++i) {
|
|
const expanded = expandDims$1({ inputs: { input: $segmentIds }, backend, attrs: { dim: i + 1 } });
|
|
$segmentIds = expanded;
|
|
intermediates.push(expanded);
|
|
}
|
|
for (let i = 0; i < numSegments; ++i) {
|
|
const scalarValue = createScalarValue(i, 'int32');
|
|
const segmentId = backend.makeTensorInfo([], 'int32', scalarValue);
|
|
const mask = equal$1({ inputs: { a: segmentId, b: $segmentIds }, backend });
|
|
const maskCasted = cast$2({ inputs: { x: mask }, backend, attrs: { dtype: 'float32' } });
|
|
const mul = multiply$1({ inputs: { a: maskCasted, b: x }, backend });
|
|
const sumTensorInfo = sum({ inputs: { x: mul }, backend, attrs: { axis: 0, keepDims: false } });
|
|
res.push(sumTensorInfo);
|
|
intermediates.push(segmentId);
|
|
intermediates.push(mask);
|
|
intermediates.push(maskCasted);
|
|
intermediates.push(mul);
|
|
intermediates.push(sumTensorInfo);
|
|
}
|
|
const result = pack({ inputs: res, backend, attrs: { axis: 0 } });
|
|
intermediates.forEach(t => backend.disposeIntermediateTensorInfo(t));
|
|
return result;
|
|
}
|
|
const unsortedSegmentSumConfig = {
|
|
kernelName: UnsortedSegmentSum,
|
|
backendName: 'cpu',
|
|
kernelFunc: unsortedSegmentSum
|
|
};
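// unsortedSegmentSum is assembled from existing kernels: for each segment id
// s it builds a 0/1 mask (segmentIds == s), casts it to float32, multiplies
// with x and sums over axis 0, then packs the per-segment results along a
// new axis 0. Example: x=[1, 2, 3], segmentIds=[0, 1, 0], numSegments=2
// -> [4, 2].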
|
|
const kernelConfigs = [
|
|
_fusedMatMulConfig,
|
|
absConfig$1,
|
|
acosConfig,
|
|
acoshConfig,
|
|
addConfig$1,
|
|
addNConfig,
|
|
allConfig,
|
|
anyConfig,
|
|
argMaxConfig,
|
|
argMinConfig,
|
|
asinConfig,
|
|
asinhConfig,
|
|
atanConfig,
|
|
atan2Config,
|
|
atanhConfig,
|
|
avgPoolConfig,
|
|
avgPool3DConfig,
|
|
avgPool3DGradConfig$1,
|
|
avgPoolGradConfig$1,
|
|
batchMatMulConfig,
|
|
batchNormConfig,
|
|
batchToSpaceNDConfig,
|
|
bincountConfig,
|
|
bitwiseAndConfig$1,
|
|
broadcastArgsConfig,
|
|
castConfig$1,
|
|
ceilConfig$1,
|
|
clipByValueConfig,
|
|
complexConfig$1,
|
|
complexAbsConfig,
|
|
concatConfig,
|
|
conv2DConfig,
|
|
conv2DBackpropFilterConfig,
|
|
conv2DBackpropInputConfig,
|
|
conv3DConfig,
|
|
conv3DBackpropFilterV2Config,
|
|
conv3DBackpropInputV2Config,
|
|
cosConfig,
|
|
coshConfig,
|
|
cropAndResizeConfig,
|
|
cumprodConfig,
|
|
cumsumConfig,
|
|
denseBincountConfig,
|
|
depthToSpaceConfig,
|
|
depthwiseConv2dNativeConfig,
|
|
depthwiseConv2dNativeBackpropFilterConfig,
|
|
depthwiseConv2dNativeBackpropInputConfig,
|
|
diagConfig,
|
|
dilation2DConfig,
|
|
dilation2DBackpropFilterConfig,
|
|
dilation2DBackpropInputConfig,
|
|
drawConfig,
|
|
einsumConfig,
|
|
eluConfig,
|
|
eluGradConfig$1,
|
|
equalConfig$1,
|
|
erfConfig,
|
|
expConfig$1,
|
|
expandDimsConfig,
|
|
expm1Config$1,
|
|
fftConfig,
|
|
fillConfig,
|
|
flipLeftRightConfig,
|
|
floorConfig$1,
|
|
floorDivConfig$1,
|
|
fusedConv2DConfig,
|
|
fusedDepthwiseConv2DConfig,
|
|
gatherNdConfig,
|
|
gatherV2Config,
|
|
greaterConfig$1,
|
|
greaterEqualConfig$1,
|
|
identityConfig$1,
|
|
ifftConfig,
|
|
imagConfig,
|
|
isFiniteConfig,
|
|
isInfConfig,
|
|
isNaNConfig,
|
|
leakyReluConfig,
|
|
lessConfig$1,
|
|
lessEqualConfig$1,
|
|
linSpaceConfig,
|
|
logConfig$1,
|
|
log1pConfig,
|
|
logicalAndConfig,
|
|
logicalNotConfig,
|
|
logicalOrConfig,
|
|
LRNConfig,
|
|
LRNGradConfig,
|
|
maxConfig,
|
|
maximumConfig$1,
|
|
maxPoolConfig,
|
|
maxPool3DConfig,
|
|
maxPool3DGradConfig$1,
|
|
maxPoolGradConfig$1,
|
|
maxPoolWithArgmaxConfig,
|
|
meanConfig,
|
|
minConfig,
|
|
minimumConfig$1,
|
|
mirrorPadConfig,
|
|
modConfig,
|
|
multinomialConfig,
|
|
multiplyConfig$1,
|
|
negConfig$1,
|
|
nonMaxSuppressionV3Config,
|
|
nonMaxSuppressionV4Config,
|
|
nonMaxSuppressionV5Config,
|
|
notEqualConfig$1,
|
|
oneHotConfig,
|
|
onesLikeConfig,
|
|
packConfig,
|
|
padV2Config,
|
|
powConfig,
|
|
preluConfig,
|
|
prodConfig$1,
|
|
raggedGatherConfig,
|
|
raggedRangeConfig,
|
|
raggedTensorToTensorConfig,
|
|
rangeConfig,
|
|
realConfig$1,
|
|
realDivConfig,
|
|
reciprocalConfig,
|
|
reluConfig,
|
|
relu6Config,
|
|
reshapeConfig,
|
|
resizeBilinearConfig,
|
|
resizeBilinearGradConfig$1,
|
|
resizeNearestNeighborConfig,
|
|
resizeNearestNeighborGradConfig$1,
|
|
reverseConfig,
|
|
rotateWithOffsetConfig,
|
|
roundConfig,
|
|
rsqrtConfig$1,
|
|
scatterNdConfig,
|
|
searchSortedConfig,
|
|
selectConfig,
|
|
seluConfig,
|
|
sigmoidConfig$1,
|
|
signConfig,
|
|
sinConfig,
|
|
sinhConfig,
|
|
sliceConfig$1,
|
|
softmaxConfig,
|
|
softplusConfig,
|
|
spaceToBatchNDConfig,
|
|
sparseFillEmptyRowsConfig,
|
|
sparseReshapeConfig,
|
|
sparseSegmentMeanConfig,
|
|
sparseSegmentSumConfig,
|
|
sparseToDenseConfig,
|
|
splitVConfig,
|
|
sqrtConfig$1,
|
|
squareConfig,
|
|
squaredDifferenceConfig$1,
|
|
staticRegexReplaceConfig$1,
|
|
stepConfig,
|
|
stridedSliceConfig,
|
|
stringNGramsConfig,
|
|
stringSplitConfig,
|
|
stringToHashBucketFastConfig,
|
|
subConfig$1,
|
|
sumConfig,
|
|
tanConfig,
|
|
tanhConfig,
|
|
tensorScatterUpdateConfig,
|
|
tileConfig,
|
|
topKConfig,
|
|
transformConfig,
|
|
transposeConfig$1,
|
|
uniqueConfig,
|
|
unpackConfig,
|
|
unsortedSegmentSumConfig,
|
|
zerosLikeConfig
|
|
];
|
|
for (const kernelConfig of kernelConfigs) {
|
|
registerKernel(kernelConfig);
|
|
}
|
|
const absGradConfig = {
|
|
kernelName: Abs,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved) => {
|
|
const [x] = saved;
|
|
return { x: () => mul(dy, step$2(cast$3(x, 'float32'), -1)) };
|
|
}
|
|
};
|
|
const acosGradConfig = {
|
|
kernelName: Acos,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved) => {
|
|
const [x] = saved;
|
|
return {
|
|
x: () => {
|
|
const a = square$2(cast$3(x, 'float32'));
|
|
const b = sqrt$2(sub$2(scalar(1), a));
|
|
return neg$2(div$1(dy, b));
|
|
}
|
|
};
|
|
}
|
|
};
|
|
const acoshGradConfig = {
|
|
kernelName: Acosh,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved) => {
|
|
const [x] = saved;
|
|
return {
|
|
x: () => {
|
|
const a = sqrt$2(sub$2(square$2(cast$3(x, 'float32')), 1));
|
|
return div$1(dy, a);
|
|
}
|
|
};
|
|
}
|
|
};
|
|
const addGradConfig = {
|
|
kernelName: Add,
|
|
inputsToSave: ['a', 'b'],
|
|
gradFunc: (dy, saved) => {
|
|
const [a, b] = saved;
|
|
const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
|
|
const derA = () => {
|
|
let res = dy;
|
|
const reduceAxes = getReductionAxes(a.shape, outShape);
|
|
if (reduceAxes.length > 0) {
|
|
res = sum$2(res, reduceAxes);
|
|
}
|
|
return reshape$2(res, a.shape);
|
|
};
|
|
const derB = () => {
|
|
let res = dy;
|
|
const reduceAxes = getReductionAxes(b.shape, outShape);
|
|
if (reduceAxes.length > 0) {
|
|
res = sum$2(res, reduceAxes);
|
|
}
|
|
return reshape$2(res, b.shape);
|
|
};
|
|
return { a: derA, b: derB };
|
|
}
|
|
};
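// This is the broadcasting-gradient pattern reused by most binary ops below:
// dy arrives in the broadcast output shape, so each input's gradient sums dy
// over the expanded axes and reshapes back. Sketch for a.shape=[2, 1],
// b.shape=[2, 3] (outShape=[2, 3]):
//   derA = reshape(sum(dy, [1]), [2, 1])    derB = reshape(dy, [2, 3])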
|
|
const addNGradConfig = {
|
|
kernelName: AddN,
|
|
saveAllInputs: true,
|
|
gradFunc: (dy, saved) => {
|
|
const ders = {};
|
|
saved.forEach((_, i) => {
|
|
ders[i] = () => dy.clone();
|
|
});
|
|
return ders;
|
|
}
|
|
};
|
|
const argMaxGradConfig = {
|
|
kernelName: ArgMax,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved) => {
|
|
const [x] = saved;
|
|
return { x: () => zerosLike$2(x) };
|
|
}
|
|
};
|
|
const argMinGradConfig = {
|
|
kernelName: ArgMin,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved) => {
|
|
const [x] = saved;
|
|
return { x: () => zerosLike$2(x) };
|
|
}
|
|
};
|
|
const asinGradConfig = {
|
|
kernelName: Asin,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved) => {
|
|
const [x] = saved;
|
|
return { x: () => div$1(dy, sqrt$2(sub$2(scalar(1), square$2(cast$3(x, 'float32'))))) };
|
|
}
|
|
};
|
|
const asinhGradConfig = {
|
|
kernelName: Asinh,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved) => {
|
|
const [x] = saved;
|
|
return {
|
|
x: () => {
|
|
const a = sqrt$2(add$1(scalar(1), square$2(cast$3(x, 'float32'))));
|
|
return div$1(dy, a);
|
|
}
|
|
};
|
|
}
|
|
};
|
|
const atan2GradConfig = {
|
|
kernelName: Atan2,
|
|
inputsToSave: ['a', 'b'],
|
|
gradFunc: (dy, saved) => {
|
|
const [a, b] = saved;
|
|
const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
|
|
const derA = () => {
|
|
const d = add$1(square$2(a), square$2(b));
|
|
let res = mul(dy, div$1(b, d));
|
|
const reduceAxes = getReductionAxes(a.shape, outShape);
|
|
if (reduceAxes.length > 0) {
|
|
res = sum$2(res, reduceAxes);
|
|
}
|
|
return reshape$2(res, a.shape);
|
|
};
|
|
const derB = () => {
|
|
const d = add$1(square$2(a), square$2(b));
|
|
let res = neg$2(mul(dy, div$1(a, d)));
|
|
const reduceAxes = getReductionAxes(b.shape, outShape);
|
|
if (reduceAxes.length > 0) {
|
|
res = sum$2(res, reduceAxes);
|
|
}
|
|
return reshape$2(res, b.shape);
|
|
};
|
|
return { a: derA, b: derB };
|
|
}
|
|
};
|
|
const atanGradConfig = {
|
|
kernelName: Atan,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved) => {
|
|
const [x] = saved;
|
|
return { x: () => div$1(dy, add$1(square$2(cast$3(x, 'float32')), 1)) };
|
|
}
|
|
};
|
|
const atanhGradConfig = {
|
|
kernelName: Atanh,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved) => {
|
|
const [x] = saved;
|
|
return { x: () => div$1(dy, sub$2(scalar(1), square$2(cast$3(x, 'float32')))) };
|
|
}
|
|
};
|
|
function avgPool3dGrad_(dy, input, filterSize, strides, pad, dimRoundingMode) {
|
|
const $dy = convertToTensor(dy, 'dy', 'avgPool3dGrad');
|
|
const $input = convertToTensor(input, 'input', 'avgPool3dGrad');
|
|
let dy5D = $dy;
|
|
let input5D = $input;
|
|
let reshapedTo5D = false;
|
|
if ($input.rank === 4) {
|
|
reshapedTo5D = true;
|
|
dy5D = reshape$2($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]);
|
|
input5D = reshape$2($input, [
|
|
1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]
|
|
]);
|
|
}
|
|
assert$1(dy5D.rank === 5, () => `Error in avgPool3dGrad: dy must be rank 5 but got rank ` +
|
|
`${dy5D.rank}.`);
|
|
assert$1(input5D.rank === 5, () => `Error in avgPool3dGrad: input must be rank 5 but got rank ` +
|
|
`${input5D.rank}.`);
|
|
checkPadOnDimRoundingMode('avgPool3dGrad', pad, dimRoundingMode);
|
|
const inputs = { dy: dy5D, input: input5D };
|
|
const attrs = { filterSize, strides, pad, dimRoundingMode };
|
|
const res = ENGINE.runKernel(AvgPool3DGrad, inputs, attrs);
|
|
if (reshapedTo5D) {
|
|
return reshape$2(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
|
|
}
|
|
return res;
|
|
}
|
|
const avgPool3dGrad = op({ avgPool3dGrad_ });
|
|
const avgPool3DGradConfig = {
|
|
kernelName: AvgPool3D,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved, attrs) => {
|
|
const [x] = saved;
|
|
const { filterSize, strides, pad, dimRoundingMode } = attrs;
|
|
return {
|
|
x: () => avgPool3dGrad(dy, x, filterSize, strides, pad, dimRoundingMode)
|
|
};
|
|
}
|
|
};
|
|
function avgPoolGrad_(dy, input, filterSize, strides, pad) {
|
|
const $dy = convertToTensor(dy, 'dy', 'avgPoolGrad');
|
|
const $input = convertToTensor(input, 'input', 'avgPoolGrad');
|
|
assert$1($input.rank === $dy.rank, () => `Rank of input (${$input.rank}) does not match rank of dy (${$dy.rank})`);
|
|
let input4D = $input;
|
|
let dy4D = $dy;
|
|
let reshapedTo4D = false;
|
|
if ($input.rank === 3) {
|
|
reshapedTo4D = true;
|
|
input4D =
|
|
reshape$2($input, [1, $input.shape[0], $input.shape[1], $input.shape[2]]);
|
|
dy4D = reshape$2($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2]]);
|
|
}
|
|
assert$1(dy4D.rank === 4, () => `Error in avgPoolGrad: dy must be rank 4 but got rank ` +
|
|
`${dy4D.rank}.`);
|
|
assert$1(input4D.rank === 4, () => `Error in avgPoolGrad: input must be rank 4 but got rank ` +
|
|
`${input4D.rank}.`);
|
|
const inputs = { dy: dy4D, input: input4D };
|
|
const attrs = { filterSize, strides, pad };
|
|
const res = ENGINE.runKernel(AvgPoolGrad, inputs, attrs);
|
|
if (reshapedTo4D) {
|
|
return reshape$2(res, [res.shape[1], res.shape[2], res.shape[3]]);
|
|
}
|
|
return res;
|
|
}
|
|
const avgPoolGrad = op({ avgPoolGrad_ });
|
|
const avgPoolGradConfig = {
|
|
kernelName: AvgPool,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved, attrs) => {
|
|
const [x] = saved;
|
|
const { filterSize, strides, pad } = attrs;
|
|
return { x: () => avgPoolGrad(dy, x, filterSize, strides, pad) };
|
|
}
|
|
};
|
|
const batchMatMulGradConfig = {
|
|
kernelName: BatchMatMul,
|
|
inputsToSave: ['a', 'b'],
|
|
gradFunc: (dy, saved, attrs) => {
|
|
const [a, b] = saved;
|
|
const { transposeA, transposeB } = attrs;
|
|
if (!transposeA && !transposeB) {
|
|
return {
|
|
a: () => matMul$1(dy, b, false, true),
|
|
b: () => matMul$1(a, dy, true, false)
|
|
};
|
|
}
|
|
else if (!transposeA && transposeB) {
|
|
return {
|
|
a: () => matMul$1(dy, b, false, false),
|
|
b: () => matMul$1(dy, a, true, false)
|
|
};
|
|
}
|
|
else if (transposeA && !transposeB) {
|
|
return {
|
|
a: () => matMul$1(b, dy, false, true),
|
|
b: () => matMul$1(a, dy, false, false)
|
|
};
|
|
}
|
|
else {
|
|
return {
|
|
a: () => matMul$1(b, dy, true, true),
|
|
b: () => matMul$1(dy, a, true, true)
|
|
};
|
|
}
|
|
}
|
|
};
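// The four branches are the matmul gradient identities for y = op(A)*op(B),
// expressed with matMul's transpose flags instead of explicit transposes:
//   neither transposed: dA = dy*B^T,   dB = A^T*dy
//   B transposed:       dA = dy*B,     dB = dy^T*A
//   A transposed:       dA = B*dy^T,   dB = A*dy
//   both transposed:    dA = B^T*dy^T, dB = dy^T*A^T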
|
|
const batchToSpaceNDGradConfig = {
|
|
kernelName: BatchToSpaceND,
|
|
gradFunc: (dy, saved, attrs) => {
|
|
const { blockShape, crops } = attrs;
|
|
return { x: () => spaceToBatchND$2(dy, blockShape, crops) };
|
|
}
|
|
};
|
|
const broadcastToGradConfig = {
|
|
kernelName: BroadcastTo,
|
|
gradFunc: (dy, saved, attrs) => {
|
|
const broadCastToAttrs = attrs;
|
|
const inputShape = broadCastToAttrs.inputShape;
|
|
const outputShape = broadCastToAttrs.shape;
|
|
const reps = Array.from(outputShape);
|
|
for (let i = inputShape.length - 1; i >= 0; i--) {
|
|
if (inputShape[i] === outputShape[i]) {
|
|
reps[i] = 1;
|
|
}
|
|
else if (inputShape[i] !== 1) {
|
|
throw new Error(`broadcastTo(): [${inputShape}] cannot be broadcast to [${outputShape}].`);
|
|
}
|
|
}
|
|
const axes = [];
|
|
for (let i = 0; i < reps.length; i++) {
|
|
if (reps[i] > 1) {
|
|
axes.push(i);
|
|
}
|
|
}
|
|
return { x: () => sum$2(dy, axes, true ) };
|
|
}
|
|
};
|
|
const castGradConfig = {
|
|
kernelName: Cast,
|
|
gradFunc: (dy) => {
|
|
return { x: () => dy.clone() };
|
|
}
|
|
};
|
|
const ceilGradConfig = {
|
|
kernelName: Ceil,
|
|
gradFunc: (dy) => {
|
|
return { x: () => zerosLike$2(dy) };
|
|
}
|
|
};
|
|
const clipByValueGradConfig = {
|
|
kernelName: ClipByValue,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved, attrs) => {
|
|
const [x] = saved;
|
|
const { clipValueMin, clipValueMax } = attrs;
|
|
return {
|
|
x: () => where(logicalAnd$2(greaterEqual$2(x, clipValueMin), lessEqual$2(x, clipValueMax)), dy, zerosLike$2(dy)),
|
|
};
|
|
}
|
|
};
|
|
const complexAbsGradConfig = {
|
|
kernelName: ComplexAbs,
|
|
inputsToSave: ['x'],
|
|
gradFunc: absGradConfig.gradFunc,
|
|
};
|
|
const concatGradConfig = {
|
|
kernelName: Concat,
|
|
saveAllInputs: true,
|
|
gradFunc: (dy, saved, attrs) => {
|
|
const shapes = saved.map(t => t.shape);
|
|
const { axis } = attrs;
|
|
const $axis = parseAxisParam(axis, saved[0].shape)[0];
|
|
const sizeSplits = shapes.map(s => s[$axis]);
|
|
const derTensors = split$1(dy, sizeSplits, $axis);
|
|
return derTensors.map(t => () => t);
|
|
}
|
|
};
|
|
const conv2DGradConfig = {
|
|
kernelName: Conv2D,
|
|
inputsToSave: ['x', 'filter'],
|
|
gradFunc: (dy, saved, attrs) => {
|
|
const [x4D, $filter] = saved;
|
|
const { dilations, strides, pad, dataFormat } = attrs;
|
|
assert$1(tupleValuesAreOne(dilations), () => 'Error in gradient of conv2D: dilation rates greater than 1 ' +
|
|
`are not yet supported in gradients. Got dilations '${dilations}'`);
|
|
return {
|
|
x: () => conv2DBackpropInput$2(x4D.shape, dy, $filter, strides, pad, dataFormat),
|
|
filter: () => conv2DBackpropFilter$2(x4D, dy, $filter.shape, strides, pad, dataFormat)
|
|
};
|
|
}
|
|
};
|
|
const conv2DBackpropInputGradConfig = {
|
|
kernelName: Conv2DBackpropInput,
|
|
inputsToSave: ['dy', 'filter'],
|
|
gradFunc: (ddx, saved, attrs) => {
|
|
const [dy, filter] = saved;
|
|
const { strides, pad, dataFormat, dimRoundingMode } = attrs;
|
|
return {
|
|
dy: () => conv2d$1(ddx, filter, strides, pad, dataFormat, 1 , dimRoundingMode),
|
|
filter: () => conv2DBackpropFilter$2(ddx, dy, filter.shape, strides, pad, dataFormat, dimRoundingMode)
|
|
};
|
|
}
|
|
};
|
|
function conv3DBackpropFilter_(x, dy, filterShape, strides, pad) {
|
|
let x5D = x;
|
|
if (x.rank === 4) {
|
|
x5D = reshape$2(x, [1, x.shape[0], x.shape[1], x.shape[2], x.shape[3]]);
|
|
}
|
|
let dy5D = dy;
|
|
if (dy5D.rank === 4) {
|
|
dy5D = reshape$2(dy, [1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]]);
|
|
}
|
|
assert$1(x5D.rank === 5, () => `Error in conv3dDerFilter: input must be rank 5, but got shape ` +
|
|
`${x5D.shape}.`);
|
|
assert$1(dy5D.rank === 5, () => `Error in conv3dDerFilter: dy must be rank 5, but got shape ` +
|
|
`${dy5D.shape}.`);
|
|
assert$1(filterShape.length === 5, () => `Error in conv3dDerFilter: filterShape must be length 5, but got ` +
|
|
`${filterShape}.`);
|
|
assert$1(x5D.shape[4] === filterShape[3], () => `Error in conv3dDerFilter: depth of input (${x5D.shape[4]}) must ` +
|
|
`match input depth in filter (${filterShape[3]}).`);
|
|
assert$1(dy5D.shape[4] === filterShape[4], () => `Error in conv3dDerFilter: depth of dy (${dy5D.shape[4]}) must ` +
|
|
`match output depth for filter (${filterShape[4]}).`);
|
|
const inputs = { x: x5D, dy: dy5D };
|
|
const attrs = { strides, pad, filterShape };
|
|
return ENGINE.runKernel(Conv3DBackpropFilterV2, inputs, attrs);
|
|
}
|
|
const conv3DBackpropFilter = op({ conv3DBackpropFilter_ });
|
|
const conv3DGradConfig = {
|
|
kernelName: Conv3D,
|
|
inputsToSave: ['x', 'filter'],
|
|
gradFunc: (dy, saved, attrs) => {
|
|
const { dilations, strides, pad } = attrs;
|
|
assert$1(tupleValuesAreOne(dilations), () => 'Error in gradient of conv3D: dilation rates greater than 1 are ' +
|
|
`not yet supported in gradients. Got dilations '${dilations}'`);
|
|
const [x5D, $filter] = saved;
|
|
return {
|
|
x: () => conv3DBackpropInput$1(x5D.shape, dy, $filter, strides, pad),
|
|
filter: () => conv3DBackpropFilter(x5D, dy, $filter.shape, strides, pad)
|
|
};
|
|
}
|
|
};
|
|
const cosGradConfig = {
|
|
kernelName: Cos,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved) => {
|
|
const [x] = saved;
|
|
return { x: () => mul(neg$2(sin$2(cast$3(x, 'float32'))), dy) };
|
|
}
|
|
};
|
|
const coshGradConfig = {
|
|
kernelName: Cosh,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved) => {
|
|
const [x] = saved;
|
|
return { x: () => mul(sinh$2(cast$3(x, 'float32')), dy) };
|
|
}
|
|
};
|
|
const cumsumGradConfig = {
|
|
kernelName: Cumsum,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved, attrs) => {
|
|
const [x] = saved;
|
|
const { axis, exclusive, reverse } = attrs;
|
|
return {
|
|
x: () => {
|
|
const permutation = getAxesPermutation([axis], x.rank);
|
|
let out = cumsum$2(dy, axis, exclusive, !reverse);
|
|
if (permutation != null) {
|
|
out = transpose$2(out, permutation);
|
|
}
|
|
return out;
|
|
}
|
|
};
|
|
}
|
|
};
|
|
const depthwiseConv2dNativeGradConfig = {
|
|
kernelName: DepthwiseConv2dNative,
|
|
inputsToSave: ['x', 'filter'],
|
|
gradFunc: (dy, saved, attrs) => {
|
|
const { dilations, strides, pad, dimRoundingMode } = attrs;
|
|
const $dilations = dilations == null ? [1, 1] : dilations;
|
|
assert$1(tupleValuesAreOne($dilations), () => 'Error in gradient of depthwiseConv2dNative: dilation rates ' +
|
|
`greater than 1 are not yet supported. Got dilations ` +
|
|
`'${$dilations}'`);
|
|
const [x, filter] = saved;
|
|
assert$1(x.rank === 4, () => `Error in gradient of depthwiseConv2dNative: input must be ` +
|
|
`rank 4, but got rank ${x.rank}.`);
|
|
assert$1(filter.rank === 4, () => `Error in gradient of depthwiseConv2dNative: filter must be ` +
|
|
`rank 4, but got rank ${filter.rank}.`);
|
|
assert$1(x.shape[3] === filter.shape[2], () => `Error in gradient of depthwiseConv2d: number of input ` +
|
|
`channels (${x.shape[3]}) must match the inChannels dimension ` +
|
|
`in filter ${filter.shape[2]}.`);
|
|
assert$1(eitherStridesOrDilationsAreOne(strides, $dilations), () => 'Error in gradient of depthwiseConv2d: Either strides or ' +
|
|
`dilations must be 1. Got strides ${strides} and dilations ` +
|
|
`'${$dilations}'.`);
|
|
checkPadOnDimRoundingMode('depthwiseConv2d', pad, dimRoundingMode);
|
|
return {
|
|
x: () => depthwiseConv2dNativeBackpropInput$2(x.shape, dy, filter, strides, pad, $dilations, dimRoundingMode),
|
|
filter: () => depthwiseConv2dNativeBackpropFilter$2(x, dy, filter.shape, strides, pad, $dilations, dimRoundingMode),
|
|
};
|
|
}
|
|
};
|
|
const dilation2dGradConfig = {
|
|
kernelName: Dilation2D,
|
|
inputsToSave: ['x', 'filter'],
|
|
gradFunc: (dy, saved, attrs) => {
|
|
const [x, filter] = saved;
|
|
const inputInputs = { x, filter, dy };
|
|
const filterInputs = { x, filter, dy };
|
|
return {
|
|
x: () => ENGINE.runKernel(Dilation2DBackpropInput, inputInputs, attrs),
|
|
filter: () => ENGINE.runKernel(Dilation2DBackpropFilter, filterInputs, attrs)
|
|
};
|
|
}
|
|
};
|
|
const eluGradConfig = {
|
|
kernelName: Elu$1,
|
|
outputsToSave: [true],
|
|
gradFunc: (dy, saved) => {
|
|
const [y] = saved;
|
|
const inputs = { dy, y };
|
|
return { x: () => ENGINE.runKernel(EluGrad, inputs) };
|
|
}
|
|
};
|
|
const erfGradConfig = {
|
|
kernelName: Erf,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved) => {
|
|
const [x] = saved;
|
|
const a = mul(exp$2(neg$2(square$2(x))), 2 / Math.sqrt(Math.PI));
|
|
return { x: () => mul(dy, a) };
|
|
}
|
|
};
|
|
const expGradConfig = {
|
|
kernelName: Exp,
|
|
outputsToSave: [true],
|
|
gradFunc: (dy, saved) => {
|
|
const [y] = saved;
|
|
return { x: () => mul(dy, y) };
|
|
}
|
|
};
|
|
const expandDimsGradConfig = {
|
|
kernelName: ExpandDims,
|
|
inputsToSave: ['input'],
|
|
gradFunc: (dy, saved) => {
|
|
const [input] = saved;
|
|
return { input: () => reshape$2(dy, input.shape) };
|
|
}
|
|
};
|
|
const expm1GradConfig = {
|
|
kernelName: Expm1,
|
|
inputsToSave: ['x'],
|
|
gradFunc: (dy, saved) => {
|
|
const [x] = saved;
|
|
return { x: () => mul(dy, exp$2(x)) };
|
|
}
|
|
};
|
|
|
|
|
|
const floorGradConfig = {
|
|
kernelName: Floor,
|
|
gradFunc: (dy) => {
|
|
return { x: () => zerosLike$2(dy) };
|
|
}
|
|
};
|
|
|
|
|
|
const floorDivGradConfig = {
    kernelName: FloorDiv,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
        const derA = () => {
            const res = div$1(dy, cast$3(b, 'float32'));
            const reduceAxes = getReductionAxes(a.shape, outShape);
            if (reduceAxes.length > 0) {
                return reshape$2(sum$2(res, reduceAxes), a.shape);
            }
            return res;
        };
        const derB = () => {
            let res = mul(dy, cast$3(a, 'float32'));
            const reduceAxes = getReductionAxes(b.shape, outShape);
            if (reduceAxes.length > 0) {
                res = reshape$2(sum$2(res, reduceAxes), b.shape);
            }
            const tmp = square$2(b);
            return neg$2(div$1(res, cast$3(tmp, 'float32')));
        };
        return { a: derA, b: derB };
    }
};

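// Gradient for FusedBatchNorm: with r = 1/sqrt(variance + epsilon), the
// derivatives below follow from y = (x - mean) * r * scale + offset, e.g.
// dx = dy * r * scale and dVariance = sum(dy * scale * (x - mean)) * (-0.5) * r^3,
// with reductions applied when mean/variance are rank-1 per-channel tensors.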
const fusedBatchNormGradConfig = {
    kernelName: FusedBatchNorm,
    inputsToSave: ['x', 'mean', 'variance', 'scale'],
    gradFunc: (dy, saved, attrs) => {
        const { varianceEpsilon } = attrs;
        const [x, mean, variance, scale] = saved;
        const scaleValue = scale == null ? scalar(1) : scale;
        const reductionAxes = getReductionAxes(mean.shape, x.shape);
        const tileShape = [];
        if (mean.rank === 1) {
            for (let i = 0; i < x.shape.length - 1; ++i) {
                tileShape.push(x.shape[i]);
            }
            tileShape.push(1);
        }
        const xMinusMean = sub$2(x, mean);
        const dyTimesScaleValue = mul(dy, scaleValue);
        const oneOverSqrtVariance = rsqrt$2(add$1(variance, scalar(varianceEpsilon)));
        const minusHalfRCube = mul(mul(mul(oneOverSqrtVariance, oneOverSqrtVariance), oneOverSqrtVariance), scalar(-0.5));
        const derX = () => {
            if (mean.rank === 1) {
                return reshape$2(mul(mul(dy, tile$3(reshape$2(oneOverSqrtVariance, [1, 1, 1, mean.shape[0]]), tileShape)), scaleValue), x.shape);
            }
            else {
                return reshape$2(mul(mul(dy, oneOverSqrtVariance), scaleValue), x.shape);
            }
        };
        const derMean = () => {
            let meanDer = mul(mul(oneOverSqrtVariance, scalar(-1)), dyTimesScaleValue);
            if (mean.rank === 1) {
                meanDer = sum$2(meanDer, reductionAxes);
            }
            return reshape$2(meanDer, mean.shape);
        };
        const derVariance = () => {
            let varianceDer = mul(mul(minusHalfRCube, xMinusMean), dyTimesScaleValue);
            if (mean.rank === 1) {
                varianceDer = sum$2(varianceDer, reductionAxes);
            }
            return reshape$2(varianceDer, mean.shape);
        };
        const derScale = () => {
            const xMinusMean2TimesRsqrt = mul(xMinusMean, oneOverSqrtVariance);
            let scaleDer = mul(dy, xMinusMean2TimesRsqrt);
            if (mean.rank === 1) {
                scaleDer = sum$2(scaleDer, reductionAxes);
            }
            return reshape$2(scaleDer, mean.shape);
        };
        const derOffset = () => {
            let offsetDer = dy;
            if (mean.rank === 1) {
                offsetDer = sum$2(offsetDer, reductionAxes);
            }
            return reshape$2(offsetDer, mean.shape);
        };
        return {
            x: derX,
            mean: derMean,
            variance: derVariance,
            scale: derScale,
            offset: derOffset
        };
    }
};

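// Gradient for GatherV2: dy is scattered back to the gathered positions of x
// via unsortedSegmentSum, after transposing the gather axis to the front and
// flattening the indices. batchDims === 1 is handled by splitting along the
// batch dimension and stacking the per-batch gradients.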
const gatherGradConfig = {
    kernelName: GatherV2,
    inputsToSave: ['x', 'indices'],
    gradFunc: (dy, saved, attrs) => {
        const [x, indices] = saved;
        const { axis, batchDims } = attrs;
        const parsedAxis = parseAxisParam(axis, x.shape)[0];
        const derXBatch = (x, indices, dy) => {
            return () => {
                const paramsShape = x.shape;
                const indicesSize = indices.size;
                const outerShape = paramsShape.slice(0, parsedAxis);
                const outerDims = outerShape.length;
                const innerShape = paramsShape.slice(axis, paramsShape.length).slice(1);
                const innerDims = innerShape.length;
                const outerAxesIndices = arrayRange(0, outerDims);
                const innerAxesIndices = arrayRange(outerDims + 1, outerDims + 1 + innerDims);
                const valuesShape = arrayConcat([outerShape, [indicesSize], innerShape]);
                const values = reshape$2(dy, valuesShape);
                const reshapedIndices = reshape$2(indices, [indicesSize]);
                const transposeDims = arrayConcat([[outerDims], outerAxesIndices, innerAxesIndices]);
                const valuesTranspose = transpose$2(values, transposeDims);
                let paramsGrad = unsortedSegmentSum$2(valuesTranspose, reshapedIndices, x.shape[parsedAxis]);
                const invertTransposeDims = getUndoAxesPermutation(transposeDims);
                paramsGrad = transpose$2(paramsGrad, invertTransposeDims);
                return paramsGrad;
            };
        };
        if (batchDims === 1) {
            const batchSize = x.shape[0];
            const xBatch = x.split(batchSize, 0);
            const derXBatched = () => {
                const stacked = stack(xBatch.map((x, i) => {
                    return derXBatch(x, indices.slice(i, 1), dy.slice(i, 1))();
                }));
                return stacked.reshape(x.shape);
            };
            return { x: derXBatched, indices: () => indices };
        }
        else {
            return { x: derXBatch(x, indices, dy), indices: () => indices };
        }
    }
};
function arrayRange(start, stop) {
    const result = [];
    for (let i = start; i < stop; ++i) {
        result.push(i);
    }
    return result;
}
function arrayConcat(arrays) {
    const result = [];
    for (let i = 0; i < arrays.length; ++i) {
        for (let j = 0; j < arrays[i].length; ++j) {
            result.push(arrays[i][j]);
        }
    }
    return result;
}

const greaterEqualGradConfig = {
    kernelName: GreaterEqual,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        return { a: () => zerosLike$2(a), b: () => zerosLike$2(b) };
    }
};

const identityGradConfig = {
    kernelName: Identity$1,
    gradFunc: (dy) => {
        return { x: () => cast$3(dy, 'float32') };
    }
};

const isFiniteGradConfig = {
    kernelName: IsFinite,
    gradFunc: (dy) => {
        return { x: () => zerosLike$2(dy) };
    }
};

const isInfGradConfig = {
    kernelName: IsInf,
    gradFunc: (dy) => {
        return { x: () => zerosLike$2(dy) };
    }
};

const isNanGradConfig = {
    kernelName: IsNan,
    gradFunc: (dy) => {
        return { x: () => zerosLike$2(dy) };
    }
};

const leakyReluGradConfig = {
    kernelName: LeakyRelu,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const [x] = saved;
        const { alpha } = attrs;
        const mask = greater$2(x, 0);
        return { x: () => where(mask, dy, mul(dy, alpha)) };
    }
};

const log1pGradConfig = {
    kernelName: Log1p,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return { x: () => div$1(dy, add$1(x, 1)) };
    }
};

const logGradConfig = {
    kernelName: Log,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return { x: () => div$1(dy, cast$3(x, 'float32')) };
    }
};

const logSoftmaxGradConfig = {
    kernelName: LogSoftmax$1,
    inputsToSave: [],
    outputsToSave: [true],
    gradFunc: (dy, saved, attrs) => {
        const [value] = saved;
        const { axis } = attrs;
        return {
            logits: () => {
                const keepDims = true;
                const softmax = exp$2(value);
                return sub$2(dy, mul(sum$2(dy, axis, keepDims), softmax));
            }
        };
    }
};

function localResponseNormalizationBackprop_(x, y, dy, depthRadius = 5, bias = 1, alpha = 1, beta = 0.5) {
    const inputs = { x, y, dy };
    const attrs = { depthRadius, bias, alpha, beta };
    return ENGINE.runKernel(LRNGrad, inputs, attrs);
}
const localResponseNormalizationBackprop = op({ localResponseNormalizationBackprop_ });

const lrnGradConfig = {
    kernelName: LRN,
    inputsToSave: ['x'],
    outputsToSave: [true],
    gradFunc: (dy, saved, attrs) => {
        const [x, y] = saved;
        const { depthRadius, bias, alpha, beta } = attrs;
        return {
            x: () => localResponseNormalizationBackprop(x, y, dy, depthRadius, bias, alpha, beta)
        };
    }
};

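// Shared gradient helper for Min and Max: after broadcasting y and dy back to
// the input rank, dy flows only to the positions where the input equals the
// reduced extremum (ties each receive the full gradient).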
function gradForMinAndMax(dy, y, xOrig, origAxes) {
    if (y.rank < xOrig.rank) {
        y = reshape$2(y, expandShapeToKeepDim(y.shape, origAxes));
    }
    if (dy.rank < xOrig.rank) {
        dy = reshape$2(dy, expandShapeToKeepDim(dy.shape, origAxes));
    }
    return {
        x: () => {
            const dx = mul(dy, cast$3(equal$2(xOrig, y), dy.dtype));
            return dx;
        }
    };
}

const maxGradConfig = {
    kernelName: Max,
    inputsToSave: ['x'],
    outputsToSave: [true],
    gradFunc: (dy, saved, attrs) => {
        const maxAttrs = attrs;
        const { reductionIndices } = maxAttrs;
        const x = saved[0];
        const y = saved[1];
        const origAxes = parseAxisParam(reductionIndices, x.shape);
        const maxGrad = gradForMinAndMax(dy, y, x, origAxes);
        return {
            x: () => {
                return maxGrad['x']();
            }
        };
    }
};

const maximumGradConfig = {
    kernelName: Maximum,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        const derA = () => mul(dy, cast$3(greaterEqual$2(a, b), 'float32'));
        const derB = () => mul(dy, cast$3(less$2(a, b), 'float32'));
        return { a: derA, b: derB };
    }
};

function maxPool3dGrad_(dy, input, output, filterSize, strides, pad, dimRoundingMode) {
    const $dy = convertToTensor(dy, 'dy', 'maxPool3dGrad');
    const $input = convertToTensor(input, 'input', 'maxPool3dGrad');
    const $output = convertToTensor(output, 'output', 'maxPool3dGrad');
    let dy5D = $dy;
    let input5D = $input;
    let output5D = $output;
    let reshapedTo5D = false;
    if ($input.rank === 4) {
        reshapedTo5D = true;
        dy5D = reshape$2($dy, [1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]]);
        input5D = reshape$2($input, [
            1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]
        ]);
        output5D = reshape$2($output, [
            1, $output.shape[0], $output.shape[1], $output.shape[2], $output.shape[3]
        ]);
    }
    assert$1(dy5D.rank === 5, () => `Error in maxPool3dGrad: dy must be rank 5 but got rank ` +
        `${dy5D.rank}.`);
    assert$1(input5D.rank === 5, () => `Error in maxPool3dGrad: input must be rank 5 but got rank ` +
        `${input5D.rank}.`);
    assert$1(output5D.rank === 5, () => `Error in maxPool3dGrad: output must be rank 5 but got rank ` +
        `${output5D.rank}.`);
    checkPadOnDimRoundingMode('maxPool3dGrad', pad, dimRoundingMode);
    const inputs = { dy: dy5D, input: input5D, output: output5D };
    const attrs = { filterSize, strides, pad, dimRoundingMode };
    const res = ENGINE.runKernel(MaxPool3DGrad, inputs, attrs);
    if (reshapedTo5D) {
        return reshape$2(res, [res.shape[1], res.shape[2], res.shape[3], res.shape[4]]);
    }
    return res;
}
const maxPool3dGrad = op({ maxPool3dGrad_ });

const maxPool3DGradConfig = {
    kernelName: MaxPool3D,
    inputsToSave: ['x'],
    outputsToSave: [true],
    gradFunc: (dy, saved, attrs) => {
        const [x, y] = saved;
        const { filterSize, strides, pad, dimRoundingMode } = attrs;
        return {
            x: () => maxPool3dGrad(dy, x, y, filterSize, strides, pad, dimRoundingMode)
        };
    }
};

function maxPoolGrad_(dy, input, output, filterSize, strides, pad, dimRoundingMode) {
    const $dy = convertToTensor(dy, 'dy', 'maxPoolGrad');
    const $input = convertToTensor(input, 'input', 'maxPoolGrad');
    const $output = convertToTensor(output, 'output', 'maxPoolGrad');
    assert$1($input.rank === $dy.rank, () => `Rank of input (${$input.rank}) does not match rank of dy ` +
        `(${$dy.rank})`);
    assert$1($dy.rank === 4, () => `Error in maxPoolGrad: dy must be rank 4 but got rank ` +
        `${$dy.rank}.`);
    assert$1($input.rank === 4, () => `Error in maxPoolGrad: input must be rank 4 but got rank ` +
        `${$input.rank}.`);
    checkPadOnDimRoundingMode('maxPoolGrad', pad, dimRoundingMode);
    const inputs = { dy: $dy, input: $input, output: $output };
    const attrs = { filterSize, strides, pad, dimRoundingMode };
    return ENGINE.runKernel(MaxPoolGrad, inputs, attrs);
}
const maxPoolGrad = op({ maxPoolGrad_ });

const maxPoolGradConfig = {
    kernelName: MaxPool,
    inputsToSave: ['x'],
    outputsToSave: [true],
    gradFunc: (dy, saved, attrs) => {
        const [x, y] = saved;
        const { filterSize, strides, pad } = attrs;
        return {
            x: () => maxPoolGrad(dy, x, y, filterSize, strides, pad)
        };
    }
};

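// Gradient for Mean: each input element contributes 1/reduceSize to the mean,
// so dy is broadcast back over the reduced axes and divided by reduceSize.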
const meanGradConfig = {
    kernelName: Mean,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const [x] = saved;
        const { axis } = attrs;
        const axes = parseAxisParam(axis, x.shape);
        const shapes = computeOutAndReduceShapes(x.shape, axes);
        const reduceShape = shapes[1];
        const reduceSize = sizeFromShape(reduceShape);
        const derX = () => {
            const expandedDyShape = x.shape.slice();
            axes.forEach(axis => {
                expandedDyShape[axis] = 1;
            });
            const expandedDy = reshape$2(dy, expandedDyShape);
            const res = div$1(mul(expandedDy, ones(x.shape, 'float32')), reduceSize);
            return res;
        };
        return { x: derX };
    }
};

const minGradConfig = {
    kernelName: Min,
    inputsToSave: ['x'],
    outputsToSave: [true],
    gradFunc: (dy, saved, attrs) => {
        const minAttrs = attrs;
        const { axis } = minAttrs;
        const [x, y] = saved;
        const origAxes = parseAxisParam(axis, x.shape);
        const minGrad = gradForMinAndMax(dy, y, x, origAxes);
        return {
            x: () => {
                return minGrad['x']();
            }
        };
    }
};

const minimumGradConfig = {
    kernelName: Minimum,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        const derA = () => mul(dy, cast$3(lessEqual$2(a, b), 'float32'));
        const derB = () => mul(dy, cast$3(greater$2(a, b), 'float32'));
        return { a: derA, b: derB };
    }
};

const mirrorPadGradConfig = {
    kernelName: MirrorPad,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const x = saved[0];
        const { paddings } = attrs;
        const begin = paddings.map(p => p[0]);
        return { x: () => slice$2(dy, begin, x.shape) };
    }
};

const modGradConfig = {
    kernelName: Mod,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
        const derA = () => {
            const reduceAxes = getReductionAxes(a.shape, outShape);
            if (reduceAxes.length > 0) {
                return reshape$2(sum$2(dy, reduceAxes), a.shape);
            }
            return dy;
        };
        const derB = () => {
            const res = mul(dy, neg$2(floor$2(div$1(a, b))));
            const reduceAxes = getReductionAxes(b.shape, outShape);
            if (reduceAxes.length > 0) {
                return reshape$2(sum$2(res, reduceAxes), b.shape);
            }
            return res;
        };
        return { a: derA, b: derB };
    }
};

const multiplyGradConfig = {
    kernelName: Multiply,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
        const derA = () => {
            const res = mul(dy, cast$3(b, 'float32'));
            const reduceAxes = getReductionAxes(a.shape, outShape);
            if (reduceAxes.length > 0) {
                return reshape$2(sum$2(res, reduceAxes), a.shape);
            }
            return res;
        };
        const derB = () => {
            const res = mul(dy, cast$3(a, 'float32'));
            const reduceAxes = getReductionAxes(b.shape, outShape);
            if (reduceAxes.length > 0) {
                return reshape$2(sum$2(res, reduceAxes), b.shape);
            }
            return res;
        };
        return { a: derA, b: derB };
    }
};

const negGradConfig = {
    kernelName: Neg,
    gradFunc: (dy) => {
        return { x: () => neg$2(dy) };
    }
};

const oneHotGradConfig = {
    kernelName: OneHot,
    inputsToSave: ['indices'],
    gradFunc: (dy, saved) => {
        const indices = saved[0];
        return { indices: () => zeros$1(indices.shape, 'float32') };
    }
};

const onesLikeGradConfig = {
    kernelName: OnesLike,
    gradFunc: (dy) => {
        return { x: () => zerosLike$2(dy) };
    }
};

const packGradConfig = {
    kernelName: Pack,
    saveAllInputs: true,
    gradFunc: (dy, saved, attrs) => {
        const { axis } = attrs;
        const derTensors = unstack(dy, axis);
        return derTensors.map(t => () => t);
    }
};

const padV2GradConfig = {
    kernelName: PadV2,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const x = saved[0];
        const { paddings } = attrs;
        const begin = paddings.map(p => p[0]);
        return { x: () => slice$2(dy, begin, x.shape) };
    }
};

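// Gradient for Pow (y = a^b): da = dy * b * a^(b-1) and db = dy * y * ln(a),
// with ln(a) masked to zero where a <= 0; both are then reduced and reshaped
// to undo broadcasting.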
const powGradConfig = {
    kernelName: Pow,
    inputsToSave: ['a', 'b'],
    outputsToSave: [true],
    gradFunc: (dy, saved) => {
        const [a, b, y] = saved;
        const base = a;
        const exp = b;
        const outShape = assertAndGetBroadcastShape(base.shape, exp.shape);
        const derBase = () => {
            const expFloat = cast$3(exp, 'float32');
            let res = mul(dy, mul(expFloat, pow$2(base, sub$2(expFloat, scalar(1)))));
            const reduceAxes = getReductionAxes(base.shape, outShape);
            if (reduceAxes.length > 0) {
                res = sum$2(res, reduceAxes);
            }
            return reshape$2(res, base.shape);
        };
        const derExp = () => {
            const condition = greater$2(base, 0);
            const logBase = where(condition, log$2(base), zerosLike$2(base));
            let res = mul(dy, mul(y, logBase));
            const reduceAxes = getReductionAxes(exp.shape, outShape);
            if (reduceAxes.length > 0) {
                res = sum$2(res, reduceAxes);
            }
            return reshape$2(res, exp.shape);
        };
        return { a: derBase, b: derExp };
    }
};

const preluGradConfig = {
    kernelName: Prelu,
    inputsToSave: ['x', 'alpha'],
    gradFunc: (dy, saved) => {
        const [x, alpha] = saved;
        const mask = greater$2(x, 0);
        return {
            x: () => where(mask, dy, mul(dy, alpha)),
            alpha: () => {
                let res = where(mask, zerosLike$2(dy), mul(dy, x));
                const reduceAxes = getReductionAxes(alpha.shape, dy.shape);
                if (reduceAxes.length > 0) {
                    res = sum$2(res, reduceAxes);
                }
                return reshape$2(res, alpha.shape);
            }
        };
    }
};

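// Gradient for Prod along a single trailing axis: the derivative of the
// product with respect to element i is the product of all other elements,
// computed here as cumprod(exclusive, forward) * cumprod(exclusive, reversed).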
function prodGradFn_(x, dy, axis) {
    const expandedYShape = x.shape.slice();
    expandedYShape[axis] = 1;
    const expandedDy = reshape$2(dy, expandedYShape);
    const xCumProd = cumprod$2(x, axis, true, false);
    const xCumRevProd = cumprod$2(x, axis, true, true);
    const dx = mul(xCumProd, xCumRevProd);
    return mul(expandedDy, dx);
}

function prodsGradFn_(x, dy, axis) {
    const xRank = x.shape.length;
    const finalProdAxis = xRank - axis.length;
    const xPermutation = getAxesPermutation(axis, xRank);
    let permutedX = x;
    if (xPermutation != null) {
        permutedX = transpose$2(x, xPermutation);
    }
    const newShape = permutedX.shape.slice();
    const removedShape = newShape.splice(xRank - axis.length, axis.length);
    const endPartShape = removedShape.reduce((p, c) => p * c, 1);
    newShape.push(endPartShape);
    const reshapedPermutedX = permutedX.reshape(newShape);
    let prodGrad = prodGradFn_(reshapedPermutedX, dy, finalProdAxis);
    prodGrad = prodGrad.reshape(permutedX.shape);
    if (xPermutation != null) {
        const undoPermutation = getUndoAxesPermutation(xPermutation);
        prodGrad = transpose$2(prodGrad, undoPermutation);
    }
    return prodGrad;
}

const prodGradConfig = {
    kernelName: Prod,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const [x] = saved;
        const { axis } = attrs;
        let axisArr = [];
        if (axis === undefined || axis === null) {
            axisArr = x.shape.map((_, i) => i);
        }
        else if (typeof axis === 'number') {
            axisArr = [axis];
        }
        else {
            axisArr = axis;
        }
        return { x: () => prodsGradFn_(x, dy, axisArr) };
    }
};

const divGradConfig = {
    kernelName: RealDiv,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
        const derA = () => {
            const res = div$1(dy, cast$3(b, 'float32'));
            const reduceAxes = getReductionAxes(a.shape, outShape);
            if (reduceAxes.length > 0) {
                return reshape$2(sum$2(res, reduceAxes), a.shape);
            }
            return res;
        };
        const derB = () => {
            let res = mul(dy, cast$3(a, 'float32'));
            const reduceAxes = getReductionAxes(b.shape, outShape);
            if (reduceAxes.length > 0) {
                res = reshape$2(sum$2(res, reduceAxes), b.shape);
            }
            const tmp = square$2(b);
            return neg$2(div$1(res, cast$3(tmp, 'float32')));
        };
        return { a: derA, b: derB };
    }
};

const reciprocalGradConfig = {
    kernelName: Reciprocal,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return { x: () => div$1(dy, neg$2(square$2(x))) };
    }
};

const relu6GradConfig = {
    kernelName: Relu6$1,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        const mask = mul(lessEqual$2(x, 6), step$2(x));
        return { x: () => mul(dy, cast$3(mask, 'float32')) };
    }
};

const reluGradConfig = {
    kernelName: Relu$1,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return { x: () => mul(dy, cast$3(step$2(x), 'float32')) };
    }
};

const reshapeGradConfig = {
    kernelName: Reshape$1,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return { x: () => reshape$2(dy, x.shape) };
    }
};

const resizeBilinearGradConfig = {
    kernelName: ResizeBilinear,
    inputsToSave: ['images'],
    gradFunc: (dy, saved, attrs) => {
        const [images] = saved;
        const inputs = { dy, images };
        const imagesDer = () => ENGINE.runKernel(ResizeBilinearGrad, inputs, attrs);
        return { images: imagesDer };
    }
};

const resizeNearestNeighborGradConfig = {
    kernelName: ResizeNearestNeighbor,
    inputsToSave: ['images'],
    gradFunc: (dy, saved, attrs) => {
        const [images] = saved;
        const inputs = { dy, images };
        const imagesDer = () => ENGINE.runKernel(ResizeNearestNeighborGrad, inputs, attrs);
        return { images: imagesDer };
    }
};

const reverseGradConfig = {
    kernelName: Reverse,
    gradFunc: (dy, saved, attrs) => {
        const { dims } = attrs;
        const axes = parseAxisParam(dims, dy.shape);
        return { x: () => reverse$2(dy, axes) };
    }
};

const roundGradConfig = {
    kernelName: Round,
    gradFunc: (dy) => {
        return { x: () => zerosLike$2(dy) };
    }
};

const rsqrtGradConfig = {
    kernelName: Rsqrt,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return { x: () => neg$2(div$1(dy, mul(pow$2(x, 1.5), 2))) };
    }
};

const selectGradConfig = {
    kernelName: Select,
    inputsToSave: ['condition'],
    gradFunc: (dy, saved) => {
        const [condition] = saved;
        return {
            condition: () => cast$3(zerosLike$2(condition), 'float32'),
            t: () => mul(dy, cast$3(condition, dy.dtype)),
            e: () => mul(dy, cast$3(logicalNot$2(condition), dy.dtype))
        };
    }
};

const seluGradConfig = {
    kernelName: Selu$1,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return {
            x: () => {
                const mask = greater$2(x, scalar(0));
                const scaleAlpha = scalar(SELU_SCALEALPHA);
                const scale = scalar(SELU_SCALE);
                const greaterThanZeroDer = mul(dy, scale);
                const lessEqualZeroDer = mul(mul(dy, scaleAlpha), exp$2(cast$3(x, 'float32')));
                return where(mask, greaterThanZeroDer, lessEqualZeroDer);
            }
        };
    }
};

const sigmoidGradConfig = {
    kernelName: Sigmoid$1,
    outputsToSave: [true],
    gradFunc: (dy, saved) => {
        const [y] = saved;
        return { x: () => mul(dy, mul(y, sub$2(scalar(1), y))) };
    }
};

const signGradConfig = {
    kernelName: Sign,
    gradFunc: (dy) => {
        return { x: () => zerosLike$2(dy) };
    }
};

const sinGradConfig = {
    kernelName: Sin,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return { x: () => mul(cos$2(cast$3(x, 'float32')), dy) };
    }
};

const sinhGradConfig = {
    kernelName: Sinh,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return { x: () => mul(cosh$2(cast$3(x, 'float32')), dy) };
    }
};

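// Gradient for Slice: the gradient of a slice is dy padded with zeros back to
// the input shape, using the slice begin offsets as the leading pad amounts.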
const sliceGradConfig = {
    kernelName: Slice,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const [x] = saved;
        const { begin, size } = attrs;
        const inputShape = x.shape;
        const [begin_, size_] = parseSliceParams(x, begin, size);
        const paddings = [];
        for (let i = 0; i < dy.rank; i++) {
            paddings.push([begin_[i], inputShape[i] - begin_[i] - size_[i]]);
        }
        return { x: () => pad(dy, paddings) };
    }
};

const softmaxGradConfig = {
    kernelName: Softmax$1,
    outputsToSave: [true],
    gradFunc: (dy, saved, attrs) => {
        const [y] = saved;
        const { dim } = attrs;
        const keepDims = true;
        const dyTimesY = mul(dy, y);
        return {
            logits: () => sub$2(dyTimesY, mul(sum$2(dyTimesY, [dim], keepDims), y))
        };
    }
};

const softplusGradConfig = {
    kernelName: Softplus$1,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return { x: () => mul(dy, sigmoid$2(x)) };
    }
};

const spaceToBatchNDGradConfig = {
    kernelName: SpaceToBatchND,
    gradFunc: (dy, saved, attrs) => {
        const { blockShape, paddings } = attrs;
        return { x: () => batchToSpaceND$2(dy, blockShape, paddings) };
    }
};

const splitVGradConfig = {
    kernelName: SplitV,
    gradFunc: (dy, saved, attrs) => {
        const { axis } = attrs;
        return { x: () => concat$2(dy, axis) };
    }
};

const sqrtGradConfig = {
    kernelName: Sqrt,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return { x: () => div$1(dy, mul(sqrt$2(cast$3(x, 'float32')), 2)) };
    }
};

const squareGradConfig = {
    kernelName: Square,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return { x: () => mul(dy, mul(cast$3(x, 'float32'), 2)) };
    }
};

const squaredDifferenceGradConfig = {
    kernelName: SquaredDifference,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        const two = scalar(2);
        const derA = () => mul(dy, mul(two, sub$2(a, b)));
        const derB = () => mul(dy, mul(two, sub$2(b, a)));
        return { a: derA, b: derB };
    }
};

const stepGradConfig = {
    kernelName: Step,
    gradFunc: (dy) => {
        return { x: () => zerosLike$2(dy) };
    }
};

const subGradConfig = {
    kernelName: Sub,
    inputsToSave: ['a', 'b'],
    gradFunc: (dy, saved) => {
        const [a, b] = saved;
        const outShape = assertAndGetBroadcastShape(a.shape, b.shape);
        const derA = () => {
            let res = dy;
            const reduceAxes = getReductionAxes(a.shape, outShape);
            if (reduceAxes.length > 0) {
                res = sum$2(res, reduceAxes);
            }
            return reshape$2(res, a.shape);
        };
        const derB = () => {
            let res = dy;
            const reduceAxes = getReductionAxes(b.shape, outShape);
            if (reduceAxes.length > 0) {
                res = sum$2(res, reduceAxes);
            }
            return reshape$2(neg$2(res), b.shape);
        };
        return { a: derA, b: derB };
    }
};

const sumGradConfig = {
    kernelName: Sum,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const [x] = saved;
        const expandedDyShape = x.shape.slice();
        const { axis } = attrs;
        const axes = parseAxisParam(axis, x.shape);
        axes.forEach(axis => {
            expandedDyShape[axis] = 1;
        });
        const expandedDy = reshape$2(dy, expandedDyShape);
        const derX = mul(expandedDy, ones(x.shape, 'float32'));
        return { x: () => derX };
    }
};

const tanGradConfig = {
    kernelName: Tan,
    inputsToSave: ['x'],
    gradFunc: (dy, saved) => {
        const [x] = saved;
        return { x: () => div$1(dy, square$2(cos$2(x))) };
    }
};

const tanhGradConfig = {
    kernelName: Tanh$1,
    outputsToSave: [true],
    gradFunc: (dy, saved) => {
        const [y] = saved;
        return { x: () => mul(sub$2(scalar(1), square$2(y)), dy) };
    }
};

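// Gradient for Tile: every tile is a copy of x, so the gradient accumulates
// the corresponding slice of dy for each repetition. Only ranks 1-4 are
// supported by the explicit loops below.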
const tileGradConfig = {
    kernelName: Tile,
    inputsToSave: ['x'],
    gradFunc: (dy, saved, attrs) => {
        const [x] = saved;
        const { reps } = attrs;
        const derX = () => {
            let xGrad = zerosLike$2(x);
            if (x.rank === 1) {
                for (let i = 0; i < reps[0]; ++i) {
                    xGrad = add$1(xGrad, slice$2(dy, [i * x.shape[0]], [x.shape[0]]));
                }
            }
            else if (x.rank === 2) {
                for (let i = 0; i < reps[0]; ++i) {
                    for (let j = 0; j < reps[1]; ++j) {
                        xGrad = add$1(xGrad, slice$2(dy, [i * x.shape[0], j * x.shape[1]], [
                            x.shape[0], x.shape[1]
                        ]));
                    }
                }
            }
            else if (x.rank === 3) {
                for (let i = 0; i < reps[0]; ++i) {
                    for (let j = 0; j < reps[1]; ++j) {
                        for (let k = 0; k < reps[2]; ++k) {
                            xGrad = add$1(xGrad, slice$2(dy, [i * x.shape[0], j * x.shape[1], k * x.shape[2]], [x.shape[0], x.shape[1], x.shape[2]]));
                        }
                    }
                }
            }
            else if (x.rank === 4) {
                for (let i = 0; i < reps[0]; ++i) {
                    for (let j = 0; j < reps[1]; ++j) {
                        for (let k = 0; k < reps[2]; ++k) {
                            for (let l = 0; l < reps[3]; ++l) {
                                xGrad = add$1(xGrad, slice$2(dy, [
                                    i * x.shape[0], j * x.shape[1], k * x.shape[2],
                                    l * x.shape[3]
                                ], [x.shape[0], x.shape[1], x.shape[2], x.shape[3]]));
                            }
                        }
                    }
                }
            }
            else {
                throw new Error(`Gradient for tile operation is not implemented for rank-` +
                    `${x.rank} tensors yet.`);
            }
            return xGrad;
        };
        return { x: derX };
    },
};

const transposeGradConfig = {
    kernelName: Transpose,
    gradFunc: (dy, saved, attrs) => {
        const transposeAttrs = attrs;
        const { perm } = transposeAttrs;
        const undoPerm = getUndoAxesPermutation(perm);
        return { x: () => transpose$2(dy, undoPerm) };
    }
};

const unpackGradConfig = {
    kernelName: Unpack,
    gradFunc: (dy, saved, attrs) => {
        const unpackAttrs = attrs;
        const { axis } = unpackAttrs;
        return { value: () => stack(dy, axis) };
    }
};

const unsortedSegmentSumGradConfig = {
    kernelName: UnsortedSegmentSum,
    inputsToSave: ['segmentIds'],
    gradFunc: (dy, saved) => {
        const [segmentIds] = saved;
        const derX = () => {
            return gatherDropNegatives(dy, segmentIds);
        };
        return { x: derX };
    }
};

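// Helper mirroring TensorFlow's unsorted_segment_sum gradient: gathers rows
// for non-negative segment ids and zeroes out the rows whose id was negative
// (negative ids mean "drop this element" in unsortedSegmentSum).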
function gatherDropNegatives(x, indices) {
    const zeroClippedIndices = maximum$2(indices, zerosLike$2(indices));
    const gathered = gather$1(x, zeroClippedIndices);
    let isPositive = greaterEqual$2(indices, scalar(0, 'int32'));
    const numIters = gathered.rank - isPositive.rank;
    for (let i = 0; i < numIters; ++i) {
        isPositive = expandDims$3(isPositive, i + 1);
    }
    isPositive = logicalAnd$2(isPositive, ones(gathered.shape, 'bool'));
    const zeroSlice = zerosLike$2(gathered);
    return where(isPositive, gathered, zeroSlice);
}

const zerosLikeGradConfig = {
    kernelName: ZerosLike,
    gradFunc: (dy) => {
        return { x: () => zerosLike$2(dy) };
    }
};

const gradConfigs = [
    absGradConfig,
    acosGradConfig,
    acoshGradConfig,
    addGradConfig,
    addNGradConfig,
    argMaxGradConfig,
    argMinGradConfig,
    asinGradConfig,
    asinhGradConfig,
    atan2GradConfig,
    atanGradConfig,
    atanhGradConfig,
    avgPool3DGradConfig,
    avgPoolGradConfig,
    batchMatMulGradConfig,
    batchToSpaceNDGradConfig,
    broadcastToGradConfig,
    castGradConfig,
    ceilGradConfig,
    clipByValueGradConfig,
    complexAbsGradConfig,
    concatGradConfig,
    conv2DBackpropInputGradConfig,
    conv2DGradConfig,
    conv3DGradConfig,
    cosGradConfig,
    coshGradConfig,
    cumsumGradConfig,
    depthwiseConv2dNativeGradConfig,
    dilation2dGradConfig,
    divGradConfig,
    eluGradConfig,
    erfGradConfig,
    expGradConfig,
    expandDimsGradConfig,
    expm1GradConfig,
    floorDivGradConfig,
    floorGradConfig,
    fusedBatchNormGradConfig,
    gatherGradConfig,
    greaterEqualGradConfig,
    identityGradConfig,
    isFiniteGradConfig,
    isInfGradConfig,
    isNanGradConfig,
    leakyReluGradConfig,
    log1pGradConfig,
    logGradConfig,
    logSoftmaxGradConfig,
    lrnGradConfig,
    maxGradConfig,
    maximumGradConfig,
    maxPool3DGradConfig,
    maxPoolGradConfig,
    meanGradConfig,
    minGradConfig,
    minimumGradConfig,
    mirrorPadGradConfig,
    modGradConfig,
    multiplyGradConfig,
    negGradConfig,
    oneHotGradConfig,
    onesLikeGradConfig,
    packGradConfig,
    padV2GradConfig,
    powGradConfig,
    preluGradConfig,
    prodGradConfig,
    reciprocalGradConfig,
    relu6GradConfig,
    reluGradConfig,
    reshapeGradConfig,
    resizeBilinearGradConfig,
    resizeNearestNeighborGradConfig,
    reverseGradConfig,
    roundGradConfig,
    rsqrtGradConfig,
    selectGradConfig,
    seluGradConfig,
    sigmoidGradConfig,
    signGradConfig,
    sinGradConfig,
    sinhGradConfig,
    sliceGradConfig,
    softmaxGradConfig,
    softplusGradConfig,
    spaceToBatchNDGradConfig,
    splitVGradConfig,
    sqrtGradConfig,
    squaredDifferenceGradConfig,
    squareGradConfig,
    stepGradConfig,
    subGradConfig,
    sumGradConfig,
    tanGradConfig,
    tanhGradConfig,
    tileGradConfig,
    transposeGradConfig,
    unpackGradConfig,
    unsortedSegmentSumGradConfig,
    zerosLikeGradConfig
];
for (const gradientConfig of gradConfigs) {
    registerGradient(gradientConfig);
}

class AttributeError extends Error {
    constructor(message) {
        super(message);
        Object.setPrototypeOf(this, AttributeError.prototype);
    }
}

class RuntimeError extends Error {
    constructor(message) {
        super(message);
        Object.setPrototypeOf(this, RuntimeError.prototype);
    }
}

class ValueError extends Error {
    constructor(message) {
        super(message);
        Object.setPrototypeOf(this, ValueError.prototype);
    }
}

class NotImplementedError extends Error {
    constructor(message) {
        super(message);
        Object.setPrototypeOf(this, NotImplementedError.prototype);
    }
}

class AssertionError extends Error {
    constructor(message) {
        super(message);
        Object.setPrototypeOf(this, AssertionError.prototype);
    }
}

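// Least-recently-used cache backed by a Map, relying on Map preserving
// insertion order: get() re-inserts a hit to mark it as most recently used,
// and put() evicts the oldest key once maxEntries is exceeded.
// Illustrative usage (not part of the library):
//   const cache = new LruCache(2);
//   cache.put('a', 1); cache.put('b', 2);
//   cache.get('a');      // refreshes 'a'
//   cache.put('c', 3);   // evicts 'b', the least recently used entry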
class LruCache {
    constructor(maxEntries) {
        this.maxEntries = maxEntries || 100;
        this.cache = new Map();
    }
    get(key) {
        let entry;
        if (this.cache.has(key)) {
            entry = this.cache.get(key);
            this.cache.delete(key);
            this.cache.set(key, entry);
        }
        return entry;
    }
    put(key, value) {
        if (this.cache.has(key)) {
            this.cache.delete(key);
        }
        else if (this.cache.size >= this.maxEntries) {
            const keyToDelete = this.cache.keys().next().value;
            this.cache.delete(keyToDelete);
        }
        this.cache.set(key, value);
    }
    getMaxEntries() {
        return this.maxEntries;
    }
    setMaxEntries(maxEntries) {
        if (maxEntries < 0) {
            throw new Error(`The maxEntries of LRU caches must be at least 0, but got ${maxEntries}.`);
        }
        if (this.maxEntries > maxEntries) {
            for (let i = 0; i < this.maxEntries - maxEntries; i++) {
                const keyToDelete = this.cache.keys().next().value;
                this.cache.delete(keyToDelete);
            }
        }
        this.maxEntries = maxEntries;
    }
}

function pyListRepeat(value, numValues) {
    if (Array.isArray(value)) {
        let newArray = [];
        for (let i = 0; i < numValues; i++) {
            newArray = newArray.concat(value);
        }
        return newArray;
    }
    else {
        const newArray = new Array(numValues);
        newArray.fill(value);
        return newArray;
    }
}
function assert(val, message) {
    if (!val) {
        throw new AssertionError(message);
    }
}

function count(array, reference) {
    let counter = 0;
    for (const item of array) {
        if (item === reference) {
            counter++;
        }
    }
    return counter;
}

function singletonOrArray(xs) {
    if (xs.length === 1) {
        return xs[0];
    }
    return xs;
}

function toList(x) {
    if (Array.isArray(x)) {
        return x;
    }
    return [x];
}

function toSnakeCase(name) {
    const intermediate = name.replace(/(.)([A-Z][a-z0-9]+)/g, '$1_$2');
    const insecure = intermediate.replace(/([a-z])([A-Z])/g, '$1_$2').toLowerCase();
    if (insecure[0] !== '_') {
        return insecure;
    }
    return 'private' + insecure;
}
function toCamelCase(identifier) {
    if (identifier.length <= 1) {
        return identifier;
    }
    if (identifier.indexOf('_') === -1) {
        return identifier;
    }
    return identifier.replace(/[_]+(\w|$)/g, (m, p1) => p1.toUpperCase());
}

let _GLOBAL_CUSTOM_OBJECTS = {};
function serializeKerasObject(instance) {
    if (instance === null || instance === undefined) {
        return null;
    }
    const dict = {};
    dict['className'] = instance.getClassName();
    dict['config'] = instance.getConfig();
    return dict;
}

function convertNDArrayScalarsInConfig(config) {
    if (config == null || typeof config !== 'object') {
        return;
    }
    else if (Array.isArray(config)) {
        config.forEach(configItem => convertNDArrayScalarsInConfig(configItem));
    }
    else {
        const fields = Object.keys(config);
        for (const field of fields) {
            const value = config[field];
            if (value != null && typeof value === 'object') {
                if (!Array.isArray(value) && value['type'] === 'ndarray' &&
                    typeof value['value'] === 'number') {
                    config[field] = value['value'];
                }
                else {
                    convertNDArrayScalarsInConfig(value);
                }
            }
        }
    }
}

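// Deserializes a Keras-style config ({className, config}) or a bare string
// identifier into an instance, resolving the class from, in order: the local
// customObjects, the global custom-object registry, then moduleObjects.
// While fromConfig runs, customObjects are temporarily merged into the global
// registry so nested deserialization can see them, and restored afterwards.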
function deserializeKerasObject(identifier, moduleObjects = {}, customObjects = {}, printableModuleName = 'object', fastWeightInit = false) {
    if (typeof identifier === 'string') {
        const functionName = identifier;
        let fn;
        if (functionName in customObjects) {
            fn = customObjects[functionName];
        }
        else if (functionName in _GLOBAL_CUSTOM_OBJECTS) {
            fn = _GLOBAL_CUSTOM_OBJECTS[functionName];
        }
        else {
            fn = moduleObjects[functionName];
            if (fn == null) {
                throw new ValueError(`Unknown ${printableModuleName}: ${identifier}. ` +
                    `This may be due to one of the following reasons:\n` +
                    `1. The ${printableModuleName} is defined in Python, in which ` +
                    `case it needs to be ported to TensorFlow.js or your JavaScript ` +
                    `code.\n` +
                    `2. The custom ${printableModuleName} is defined in JavaScript, ` +
                    `but is not registered properly with ` +
                    `tf.serialization.registerClass().`);
            }
        }
        return fn;
    }
    else {
        const config = identifier;
        if (config['className'] == null || config['config'] == null) {
            throw new ValueError(`${printableModuleName}: Improper config format: ` +
                `${JSON.stringify(config)}.\n` +
                `'className' and 'config' must be set.`);
        }
        const className = config['className'];
        let cls, fromConfig;
        if (className in customObjects) {
            [cls, fromConfig] = customObjects[className];
        }
        else if (className in _GLOBAL_CUSTOM_OBJECTS) {
            [cls, fromConfig] = _GLOBAL_CUSTOM_OBJECTS[className];
        }
        else if (className in moduleObjects) {
            [cls, fromConfig] = moduleObjects[className];
        }
        if (cls == null) {
            throw new ValueError(`Unknown ${printableModuleName}: ${className}. ` +
                `This may be due to one of the following reasons:\n` +
                `1. The ${printableModuleName} is defined in Python, in which ` +
                `case it needs to be ported to TensorFlow.js or your JavaScript ` +
                `code.\n` +
                `2. The custom ${printableModuleName} is defined in JavaScript, ` +
                `but is not registered properly with ` +
                `tf.serialization.registerClass().`);
        }
        if (fromConfig != null) {
            const customObjectsCombined = {};
            for (const key of Object.keys(_GLOBAL_CUSTOM_OBJECTS)) {
                customObjectsCombined[key] = _GLOBAL_CUSTOM_OBJECTS[key];
            }
            for (const key of Object.keys(customObjects)) {
                customObjectsCombined[key] = customObjects[key];
            }
            const nestedConfig = config['config'];
            nestedConfig['customObjects'] = customObjectsCombined;
            const backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS);
            for (const key of Object.keys(customObjects)) {
                _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];
            }
            convertNDArrayScalarsInConfig(config['config']);
            const returnObj = fromConfig(cls, config['config'], customObjects, fastWeightInit);
            _GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects);
            return returnObj;
        }
        else {
            const backupCustomObjects = Object.assign({}, _GLOBAL_CUSTOM_OBJECTS);
            for (const key of Object.keys(customObjects)) {
                _GLOBAL_CUSTOM_OBJECTS[key] = customObjects[key];
            }
            const returnObj = new cls(config['config']);
            _GLOBAL_CUSTOM_OBJECTS = Object.assign({}, backupCustomObjects);
            return returnObj;
        }
    }
}

function numberCompare(a, b) {
    return (a < b) ? -1 : ((a > b) ? 1 : 0);
}

function reverseNumberCompare(a, b) {
    return -1 * numberCompare(a, b);
}

function unique(xs) {
    if (xs == null) {
        return xs;
    }
    const out = [];
    for (const x of xs) {
        if (out.indexOf(x) === -1) {
            out.push(x);
        }
    }
    return out;
}

function isObjectEmpty(obj) {
    if (obj == null) {
        throw new ValueError(`Invalid value in obj: ${JSON.stringify(obj)}`);
    }
    for (const key in obj) {
        if (obj.hasOwnProperty(key)) {
            return false;
        }
    }
    return true;
}

function checkStringTypeUnionValue(values, label, value) {
    if (value == null) {
        return;
    }
    if (values.indexOf(value) < 0) {
        throw new ValueError(`${value} is not a valid ${label}. Valid values are ${values} or null/undefined.`);
    }
}

function assertPositiveInteger(value, name) {
    if (Array.isArray(value)) {
        assert$1(value.length > 0, () => `${name} is unexpectedly an empty array.`);
        value.forEach((v, i) => assertPositiveInteger(v, `element ${i + 1} of ${name}`));
    }
    else {
        assert$1(Number.isInteger(value) && value > 0, () => `Expected ${name} to be a positive integer, but got ` +
            `${formatAsFriendlyString(value)}.`);
    }
}

function formatAsFriendlyString(value) {
    if (value === null) {
        return 'null';
    }
    else if (Array.isArray(value)) {
        return '[' + value.map(v => formatAsFriendlyString(v)).join(',') + ']';
    }
    else if (typeof value === 'string') {
        return `"${value}"`;
    }
    else {
        return `${value}`;
    }
}

function debounce(f, waitMs, nowFunc) {
    let lastTime = nowFunc != null ? nowFunc() : now();
    let lastResult;
    const f2 = (...args) => {
        const now$1 = nowFunc != null ? nowFunc() : now();
        if (now$1 - lastTime < waitMs) {
            return lastResult;
        }
        lastTime = now$1;
        lastResult = f(...args);
        return lastResult;
    };
    return f2;
}

function mapActivationToFusedKernel(activationName) {
    if (activationName === 'relu') {
        return 'relu';
    }
    if (activationName === 'linear') {
        return 'linear';
    }
    if (activationName === 'elu') {
        return 'elu';
    }
    return null;
}

let _nextUniqueTensorId = 0;
function getNextUniqueTensorId() {
    return _nextUniqueTensorId++;
}
const _uidPrefixes = {};

function getUid(prefix = '') {
    if (!(prefix in _uidPrefixes)) {
        _uidPrefixes[prefix] = 0;
    }
    _uidPrefixes[prefix] += 1;
    return prefix + _uidPrefixes[prefix].toString();
}

const VALID_DATA_FORMAT_VALUES = ['channelsFirst', 'channelsLast'];

const nameMap = new Map();
function checkDataFormat(value) {
    checkStringTypeUnionValue(VALID_DATA_FORMAT_VALUES, 'DataFormat', value);
}
const _nameScopeStack = [];
const _nameScopeDivider = '/';

function nameScope(name, fn) {
    _nameScopeStack.push(name);
    try {
        const val = fn();
        _nameScopeStack.pop();
        return val;
    }
    catch (e) {
        _nameScopeStack.pop();
        throw e;
    }
}

function currentNameScopePrefix() {
    if (_nameScopeStack.length === 0) {
        return '';
    }
    else {
        return _nameScopeStack.join(_nameScopeDivider) + _nameScopeDivider;
    }
}

function getScopedTensorName(tensorName) {
    if (!isValidTensorName(tensorName)) {
        throw new Error('Not a valid tensor name: \'' + tensorName + '\'');
    }
    return currentNameScopePrefix() + tensorName;
}

function getUniqueTensorName(scopedName) {
    if (!isValidTensorName(scopedName)) {
        throw new Error('Not a valid tensor name: \'' + scopedName + '\'');
    }
    if (!nameMap.has(scopedName)) {
        nameMap.set(scopedName, 0);
    }
    const index = nameMap.get(scopedName);
    nameMap.set(scopedName, nameMap.get(scopedName) + 1);
    if (index > 0) {
        const result = `${scopedName}_${index}`;
        nameMap.set(result, 1);
        return result;
    }
    else {
        return scopedName;
    }
}
const tensorNameRegex = new RegExp(/^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/);

function isValidTensorName(name) {
    return !!name.match(tensorNameRegex);
}

function arrayProd(array, begin, end) {
    if (begin == null) {
        begin = 0;
    }
    if (end == null) {
        end = array.length;
    }
    let prod = 1;
    for (let i = begin; i < end; ++i) {
        prod *= array[i];
    }
    return prod;
}

function range(begin, end) {
    if (end < begin) {
        throw new ValueError(`end (${end}) < begin (${begin}) is forbidden.`);
    }
    const out = [];
    for (let i = begin; i < end; ++i) {
        out.push(i);
    }
    return out;
}

let _epsilon;

function epsilon() {
    if (_epsilon == null) {
        _epsilon = backend().epsilon();
    }
    return _epsilon;
}

function imageDataFormat() {
    return 'channelsLast';
}

function cast(x, dtype) {
    return cast$3(x, dtype);
}

function expandDims(x, axis = -1) {
    const outShape = x.shape.slice();
    if (axis < 0) {
        axis = outShape.length + axis + 1;
    }
    outShape.splice(axis, 0, 1);
    return reshape$2(x, outShape);
}

function repeat(x, n) {
    return tidy(() => {
        if (x.shape.length !== 2) {
            throw new ValueError(`repeat() expects a rank-2 tensor, but received a ` +
                `rank-${x.shape.length} tensor.`);
        }
        const y = expandDims(x, 1);
        return tile(y, [1, n, 1]);
    });
}

function flatten(x) {
    const newShape = [arrayProd(x.shape)];
    return reshape$2(x, newShape);
}

function batchFlatten(x) {
    if (x.rank <= 1) {
        throw new ValueError(`batchFlatten requires a minimum rank of 2. Got rank: ${x.rank}.`);
    }
    const newShape = [x.shape[0], arrayProd(x.shape, 1)];
    return reshape$2(x, newShape);
}

function sliceAlongFirstAxis(array, start, size) {
    return tidy(() => {
        switch (array.rank) {
            case 1:
                return slice1d(array, start, size);
            case 2:
                return slice2d(array, [start, 0], [size, array.shape[1]]);
            case 3:
                return slice3d(array, [start, 0, 0], [size, array.shape[1], array.shape[2]]);
            case 4:
                return slice4d(array, [start, 0, 0, 0], [size, array.shape[1], array.shape[2], array.shape[3]]);
            case 5:
                return slice$2(array, [start, 0, 0, 0, 0], [
                    size, array.shape[1], array.shape[2], array.shape[3], array.shape[4]
                ]);
            case 6:
                return slice$2(array, [start, 0, 0, 0, 0, 0], [
                    size, array.shape[1], array.shape[2], array.shape[3], array.shape[4],
                    array.shape[5]
                ]);
            default:
                throw new ValueError(`sliceAlongFirstAxis() received an unsupported tensor rank: ` +
                    `${array.rank}`);
        }
    });
}

function tile(x, n) {
    if (!Array.isArray(n)) {
        n = [n];
    }
    if (x.rank !== n.length) {
        throw new ValueError(`The length of input n (${n.length}) does not match ` +
            `the number of dimensions in input x (${x.rank})`);
    }
    return tile$3(x, n);
}

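// Keras-style dot product: rank-2 x rank-2 falls through to a fused matMul;
// for higher-rank b, both operands are reshaped to rank 2 (b after moving its
// second-to-last, contracted dimension to the front), multiplied, and the
// result reshaped back to [...aOuterDims, ...bOtherDims].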
function randomNormal(shape, mean = 0.0, stddev = 1.0, dtype, seed) {
    return randomNormal$1(shape, mean, stddev, dtype, seed);
}

function dot(a, b, activation, bias) {
    if ((a.rank < 2) || (b.rank < 2)) {
        throw new NotImplementedError(`dot requires both inputs to be rank >= 2` +
            ` but got x shape = ${a.shape} and y shape = ${b.shape}`);
    }
    if (b.rank >= 3) {
        const xLastDim = a.shape.slice(-1)[0];
        const ySecondLastDim = b.shape.slice(-2)[0];
        if (xLastDim !== ySecondLastDim) {
            throw new NotImplementedError(`If rank y >= 3, then the second last dim` +
                ` of y must equal the last dim of x but got x shape = ${a.shape} and ` +
                ` y shape = ${b.shape}`);
        }
    }
    if ((a.rank === 2) && (b.rank === 2)) {
        const transposeA = false;
        const transposeB = false;
        return matMul({
            a,
            b: b,
            transposeA,
            transposeB,
            bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
            activation
        });
    }
    else {
        const aFirstDims = a.shape.slice();
        const aLastDim = aFirstDims.pop();
        a = reshape$2(a, [-1, aLastDim]);
        const bShape = b.shape.slice();
        const bLastDim = bShape.pop();
        const ySecondLastDim = bShape.pop();
        const yOtherDims = [...bShape, bLastDim];
        const perm = Array.from({ length: b.rank }, (_, i) => {
            if (i === 0) {
                return b.rank - 2;
            }
            else if (i <= b.rank - 2) {
                return i - 1;
            }
            return i;
        });
        b = reshape$2(transpose$2(b, perm), [ySecondLastDim, -1]);
        const outputShape = [...aFirstDims, ...yOtherDims];
        const transposeA = false;
        const transposeB = false;
        return reshape$2(matMul({
            a,
            b,
            transposeA,
            transposeB,
            bias: bias ? reshapeBias(a.rank, bias, imageDataFormat()) : null,
            activation
        }), outputShape);
    }
}

function gather(reference, indices, axis) {
    return tidy(() => {
        if (Array.isArray(indices)) {
            indices = tensor1d(indices, 'int32');
        }
        else {
            indices = cast$3(indices, 'int32');
        }
        return gather$1(reference, indices, axis);
    });
}

function square(x) {
    return mul(x, x);
}

function reshapeBias(xRank, bias, dataFormat) {
    const biasShape = bias.shape;
    if (bias.rank !== 1 && bias.rank !== xRank) {
        throw new ValueError(`Unexpected bias dimensions: ${bias.rank}` +
            `; expected it to be 1 or ${xRank}`);
    }
    if (xRank === 5) {
        if (dataFormat === 'channelsFirst') {
            if (biasShape.length === 1) {
                return reshape$2(bias, [1, biasShape[0], 1, 1, 1]);
            }
            else {
                return reshape$2(bias, [1, biasShape[3], biasShape[0], biasShape[1], biasShape[2]]);
            }
        }
        else if (dataFormat === 'channelsLast') {
            if (biasShape.length === 1) {
                return reshape$2(bias, [1, 1, 1, 1, biasShape[0]]);
            }
            else {
                return reshape$2(bias, [1].concat(biasShape));
            }
        }
    }
    else if (xRank === 4) {
        if (dataFormat === 'channelsFirst') {
            if (biasShape.length === 1) {
                return reshape$2(bias, [1, biasShape[0], 1, 1]);
            }
            else {
                return reshape$2(bias, [1, biasShape[2], biasShape[0], biasShape[1]]);
            }
        }
        else if (dataFormat === 'channelsLast') {
            if (biasShape.length === 1) {
                return reshape$2(bias, [1, 1, 1, biasShape[0]]);
            }
            else {
                return reshape$2(bias, [1].concat(biasShape));
            }
        }
    }
    else if (xRank === 3) {
        if (dataFormat === 'channelsFirst') {
            if (biasShape.length === 1) {
                return reshape$2(bias, [1, biasShape[0], 1]);
            }
            else {
                return reshape$2(bias, [1, biasShape[1], biasShape[0]]);
            }
        }
        else if (dataFormat === 'channelsLast') {
            if (biasShape.length === 1) {
                return reshape$2(bias, [1, 1, biasShape[0]]);
            }
            else {
                return reshape$2(bias, [1].concat(biasShape));
            }
        }
    }
    else if (xRank < 3) {
        return bias;
    }
    throw new ValueError(`Unsupported input rank by biasAdd: ${bias.rank}`);
}

function biasAdd(x, bias, dataFormat) {
    return tidy(() => {
        if (dataFormat == null) {
            dataFormat = imageDataFormat();
        }
        checkDataFormat(dataFormat);
        return add$1(x, reshapeBias(x.rank, bias, dataFormat));
    });
}

function elu(x, alpha = 1) {
    if (alpha !== 1) {
        throw new NotImplementedError(`Support for alpha values other than 1 (${alpha}) is not implemented ` +
            `yet.`);
    }
    return elu$3(x);
}

function softsign(x) {
    return tidy(() => div$1(x, add$1(abs$2(x), 1)));
}

function dropout$1(x, level, noiseShape, seed) {
    return tidy(() => dropout$2(x, level, noiseShape, seed));
}

function hardSigmoid(x) {
    return tidy(() => {
        const y = add$1(.5, mul(.2, x));
        return clipByValue$2(y, 0, 1);
    });
}

function inTrainPhase(x, alt, training = false) {
    return training ? x() : alt();
}

const VALID_FAN_MODE_VALUES = ['fanIn', 'fanOut', 'fanAvg'];
|
|
const VALID_DISTRIBUTION_VALUES = ['normal', 'uniform', 'truncatedNormal'];
|
|
|
|
|
|
function checkFanMode(value) {
|
|
checkStringTypeUnionValue(VALID_FAN_MODE_VALUES, 'FanMode', value);
|
|
}
|
|
function checkDistribution(value) {
|
|
checkStringTypeUnionValue(VALID_DISTRIBUTION_VALUES, 'Distribution', value);
|
|
}
|
|
|
|
class Initializer extends Serializable {
|
|
fromConfigUsesCustomObjects() {
|
|
return false;
|
|
}
|
|
getConfig() {
|
|
return {};
|
|
}
|
|
}
|
|
class Zeros extends Initializer {
|
|
apply(shape, dtype) {
|
|
return zeros$1(shape, dtype);
|
|
}
|
|
}
|
|
|
|
Zeros.className = 'Zeros';
|
|
registerClass(Zeros);
|
|
class Ones extends Initializer {
|
|
apply(shape, dtype) {
|
|
return ones(shape, dtype);
|
|
}
|
|
}
|
|
|
|
Ones.className = 'Ones';
|
|
registerClass(Ones);
|
|
class Constant extends Initializer {
|
|
constructor(args) {
|
|
super();
|
|
if (typeof args !== 'object') {
|
|
throw new ValueError(`Expected argument of type ConstantConfig but got ${args}`);
|
|
}
|
|
if (args.value === undefined) {
|
|
throw new ValueError(`config must have value set but got ${args}`);
|
|
}
|
|
this.value = args.value;
|
|
}
|
|
apply(shape, dtype) {
|
|
return tidy(() => mul(scalar(this.value), ones(shape, dtype)));
|
|
}
|
|
getConfig() {
|
|
return {
|
|
value: this.value,
|
|
};
|
|
}
|
|
}
|
|
|
|
Constant.className = 'Constant';
|
|
registerClass(Constant);
|
|
class RandomUniform extends Initializer {
|
|
constructor(args) {
|
|
super();
|
|
this.DEFAULT_MINVAL = -0.05;
|
|
this.DEFAULT_MAXVAL = 0.05;
|
|
this.minval = args.minval || this.DEFAULT_MINVAL;
|
|
this.maxval = args.maxval || this.DEFAULT_MAXVAL;
|
|
this.seed = args.seed;
|
|
}
|
|
apply(shape, dtype) {
|
|
return randomUniform(shape, this.minval, this.maxval, dtype, this.seed);
|
|
}
|
|
getConfig() {
|
|
return { minval: this.minval, maxval: this.maxval, seed: this.seed };
|
|
}
|
|
}
|
|
|
|
RandomUniform.className = 'RandomUniform';
|
|
registerClass(RandomUniform);
|
|
class RandomNormal extends Initializer {
    constructor(args) {
        super();
        this.DEFAULT_MEAN = 0.;
        this.DEFAULT_STDDEV = 0.05;
        this.mean = args.mean || this.DEFAULT_MEAN;
        this.stddev = args.stddev || this.DEFAULT_STDDEV;
        this.seed = args.seed;
    }
    apply(shape, dtype) {
        dtype = dtype || 'float32';
        if (dtype !== 'float32' && dtype !== 'int32') {
            throw new NotImplementedError(`randomNormal does not support dType ${dtype}.`);
        }
        return randomNormal(shape, this.mean, this.stddev, dtype, this.seed);
    }
    getConfig() {
        return { mean: this.mean, stddev: this.stddev, seed: this.seed };
    }
}
RandomNormal.className = 'RandomNormal';
registerClass(RandomNormal);
class TruncatedNormal extends Initializer {
    constructor(args) {
        super();
        this.DEFAULT_MEAN = 0.;
        this.DEFAULT_STDDEV = 0.05;
        this.mean = args.mean || this.DEFAULT_MEAN;
        this.stddev = args.stddev || this.DEFAULT_STDDEV;
        this.seed = args.seed;
    }
    apply(shape, dtype) {
        dtype = dtype || 'float32';
        if (dtype !== 'float32' && dtype !== 'int32') {
            throw new NotImplementedError(`truncatedNormal does not support dType ${dtype}.`);
        }
        return truncatedNormal(shape, this.mean, this.stddev, dtype, this.seed);
    }
    getConfig() {
        return { mean: this.mean, stddev: this.stddev, seed: this.seed };
    }
}
TruncatedNormal.className = 'TruncatedNormal';
registerClass(TruncatedNormal);
class Identity extends Initializer {
    constructor(args) {
        super();
        this.gain = args.gain != null ? args.gain : 1.0;
    }
    apply(shape, dtype) {
        return tidy(() => {
            if (shape.length !== 2 || shape[0] !== shape[1]) {
                throw new ValueError('Identity matrix initializer can only be used for 2D square matrices.');
            }
            else {
                return mul(this.gain, eye(shape[0]));
            }
        });
    }
    getConfig() {
        return { gain: this.gain };
    }
}
Identity.className = 'Identity';
registerClass(Identity);

function computeFans(shape, dataFormat = 'channelsLast') {
    let fanIn;
    let fanOut;
    checkDataFormat(dataFormat);
    if (shape.length === 2) {
        fanIn = shape[0];
        fanOut = shape[1];
    }
    else if ([3, 4, 5].indexOf(shape.length) !== -1) {
        if (dataFormat === 'channelsFirst') {
            const receptiveFieldSize = arrayProd(shape, 2);
            fanIn = shape[1] * receptiveFieldSize;
            fanOut = shape[0] * receptiveFieldSize;
        }
        else if (dataFormat === 'channelsLast') {
            const receptiveFieldSize = arrayProd(shape, 0, shape.length - 2);
            fanIn = shape[shape.length - 2] * receptiveFieldSize;
            fanOut = shape[shape.length - 1] * receptiveFieldSize;
        }
    }
    else {
        const shapeProd = arrayProd(shape);
        fanIn = Math.sqrt(shapeProd);
        fanOut = Math.sqrt(shapeProd);
    }
    return [fanIn, fanOut];
}
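
// Worked example (comment only): for a channelsLast conv kernel of shape
// [3, 3, 16, 32], the receptive field size is 3 * 3 = 9, so
// fanIn = 16 * 9 = 144 and fanOut = 32 * 9 = 288.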
class VarianceScaling extends Initializer {
    constructor(args) {
        super();
        if (args.scale < 0.0) {
            throw new ValueError(`scale must be a positive float. Got: ${args.scale}`);
        }
        this.scale = args.scale == null ? 1.0 : args.scale;
        this.mode = args.mode == null ? 'fanIn' : args.mode;
        checkFanMode(this.mode);
        this.distribution =
            args.distribution == null ? 'normal' : args.distribution;
        checkDistribution(this.distribution);
        this.seed = args.seed;
    }
    apply(shape, dtype) {
        const fans = computeFans(shape);
        const fanIn = fans[0];
        const fanOut = fans[1];
        let scale = this.scale;
        if (this.mode === 'fanIn') {
            scale /= Math.max(1, fanIn);
        }
        else if (this.mode === 'fanOut') {
            scale /= Math.max(1, fanOut);
        }
        else {
            scale /= Math.max(1, (fanIn + fanOut) / 2);
        }
        if (this.distribution === 'normal') {
            const stddev = Math.sqrt(scale);
            dtype = dtype || 'float32';
            if (dtype !== 'float32' && dtype !== 'int32') {
                throw new NotImplementedError(`${this.getClassName()} does not support dType ${dtype}.`);
            }
            return truncatedNormal(shape, 0, stddev, dtype, this.seed);
        }
        else {
            const limit = Math.sqrt(3 * scale);
            return randomUniform(shape, -limit, limit, dtype, this.seed);
        }
    }
    getConfig() {
        return {
            scale: this.scale,
            mode: this.mode,
            distribution: this.distribution,
            seed: this.seed
        };
    }
}
VarianceScaling.className = 'VarianceScaling';
registerClass(VarianceScaling);
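
// Usage sketch (hypothetical helper, never invoked by the bundle): a
// He-normal style initializer expressed directly through VarianceScaling.
// With mode 'fanIn' and distribution 'normal', the sampled stddev is
// sqrt(scale / fanIn), drawn from a truncated normal distribution.
function exampleVarianceScaling() {
    const init = new VarianceScaling({ scale: 2.0, mode: 'fanIn', distribution: 'normal', seed: 1 });
    return init.apply([3, 3, 16, 32], 'float32');
}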
class GlorotUniform extends VarianceScaling {
    constructor(args) {
        super({
            scale: 1.0,
            mode: 'fanAvg',
            distribution: 'uniform',
            seed: args == null ? null : args.seed
        });
    }
    getClassName() {
        return VarianceScaling.className;
    }
}
GlorotUniform.className = 'GlorotUniform';
registerClass(GlorotUniform);
class GlorotNormal extends VarianceScaling {
    constructor(args) {
        super({
            scale: 1.0,
            mode: 'fanAvg',
            distribution: 'normal',
            seed: args == null ? null : args.seed
        });
    }
    getClassName() {
        return VarianceScaling.className;
    }
}
GlorotNormal.className = 'GlorotNormal';
registerClass(GlorotNormal);
class HeNormal extends VarianceScaling {
    constructor(args) {
        super({
            scale: 2.0,
            mode: 'fanIn',
            distribution: 'normal',
            seed: args == null ? null : args.seed
        });
    }
    getClassName() {
        return VarianceScaling.className;
    }
}
HeNormal.className = 'HeNormal';
registerClass(HeNormal);
class HeUniform extends VarianceScaling {
    constructor(args) {
        super({
            scale: 2.0,
            mode: 'fanIn',
            distribution: 'uniform',
            seed: args == null ? null : args.seed
        });
    }
    getClassName() {
        return VarianceScaling.className;
    }
}
HeUniform.className = 'HeUniform';
registerClass(HeUniform);
class LeCunNormal extends VarianceScaling {
    constructor(args) {
        super({
            scale: 1.0,
            mode: 'fanIn',
            distribution: 'normal',
            seed: args == null ? null : args.seed
        });
    }
    getClassName() {
        return VarianceScaling.className;
    }
}
LeCunNormal.className = 'LeCunNormal';
registerClass(LeCunNormal);
class LeCunUniform extends VarianceScaling {
    constructor(args) {
        super({
            scale: 1.0,
            mode: 'fanIn',
            distribution: 'uniform',
            seed: args == null ? null : args.seed
        });
    }
    getClassName() {
        return VarianceScaling.className;
    }
}
LeCunUniform.className = 'LeCunUniform';
registerClass(LeCunUniform);
class Orthogonal extends Initializer {
    constructor(args) {
        super();
        this.DEFAULT_GAIN = 1;
        this.ELEMENTS_WARN_SLOW = 2000;
        this.gain = args.gain == null ? this.DEFAULT_GAIN : args.gain;
        this.seed = args.seed;
    }
    apply(shape, dtype) {
        return tidy(() => {
            if (shape.length < 2) {
                throw new NotImplementedError('Shape must be at least 2D.');
            }
            if (dtype !== 'int32' && dtype !== 'float32' && dtype !== undefined) {
                throw new TypeError(`Unsupported data type ${dtype}.`);
            }
            const numRows = sizeFromShape(shape.slice(0, -1));
            const numCols = shape[shape.length - 1];
            const numElements = numRows * numCols;
            if (numElements > this.ELEMENTS_WARN_SLOW) {
                console.warn(`Orthogonal initializer is being called on a matrix with more ` +
                    `than ${this.ELEMENTS_WARN_SLOW} (${numElements}) elements: ` +
                    `Slowness may result.`);
            }
            const flatShape = [Math.max(numCols, numRows), Math.min(numCols, numRows)];
            const randNormalMat = randomNormal(flatShape, 0, 1, dtype, this.seed);
            const qr = linalg.qr(randNormalMat, false);
            let qMat = qr[0];
            const rMat = qr[1];
            const diag = rMat.flatten().stridedSlice([0], [Math.min(numCols, numRows) * Math.min(numCols, numRows)], [Math.min(numCols, numRows) + 1]);
            qMat = mul(qMat, diag.sign());
            if (numRows < numCols) {
                qMat = qMat.transpose();
            }
            return mul(scalar(this.gain), qMat.reshape(shape));
        });
    }
    getConfig() {
        return {
            gain: this.gain,
            seed: this.seed,
        };
    }
}
Orthogonal.className = 'Orthogonal';
registerClass(Orthogonal);
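
// Comment sketch: the orthogonal init draws a random Gaussian matrix, takes
// its reduced QR decomposition, and sign-corrects Q by the diagonal of R so
// that the result is drawn uniformly over orthogonal matrices; Q is then
// scaled by `gain` and reshaped to the requested shape.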

const INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
    'constant': 'Constant',
    'glorotNormal': 'GlorotNormal',
    'glorotUniform': 'GlorotUniform',
    'heNormal': 'HeNormal',
    'heUniform': 'HeUniform',
    'identity': 'Identity',
    'leCunNormal': 'LeCunNormal',
    'leCunUniform': 'LeCunUniform',
    'ones': 'Ones',
    'orthogonal': 'Orthogonal',
    'randomNormal': 'RandomNormal',
    'randomUniform': 'RandomUniform',
    'truncatedNormal': 'TruncatedNormal',
    'varianceScaling': 'VarianceScaling',
    'zeros': 'Zeros'
};
function deserializeInitializer(config, customObjects = {}) {
    return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'initializer');
}
function serializeInitializer(initializer) {
    return serializeKerasObject(initializer);
}
function getInitializer(identifier) {
    if (typeof identifier === 'string') {
        const className = identifier in INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
            INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
            identifier;
        if (className === 'GlorotNormal') {
            return new GlorotNormal();
        }
        else if (className === 'GlorotUniform') {
            return new GlorotUniform();
        }
        else if (className === 'HeNormal') {
            return new HeNormal();
        }
        else if (className === 'HeUniform') {
            return new HeUniform();
        }
        else if (className === 'LeCunNormal') {
            return new LeCunNormal();
        }
        else if (className === 'LeCunUniform') {
            return new LeCunUniform();
        }
        else {
            const config = {};
            config['className'] = className;
            config['config'] = {};
            return deserializeInitializer(config);
        }
    }
    else if (identifier instanceof Initializer) {
        return identifier;
    }
    else {
        return deserializeInitializer(identifier);
    }
}
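
// Usage sketch (hypothetical, never invoked): getInitializer accepts a
// camelCase identifier, a registered class name, an Initializer instance, or
// a serialized config, so all three calls below yield a GlorotUniform.
function exampleGetInitializer() {
    const a = getInitializer('glorotUniform');
    const b = getInitializer('GlorotUniform');
    const c = getInitializer(new GlorotUniform());
    return [a, b, c];
}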

function normalizeShapeList(x) {
    if (x.length === 0) {
        return [];
    }
    if (!Array.isArray(x[0])) {
        return [x];
    }
    return x;
}

function getExactlyOneTensor(xs) {
    let x;
    if (Array.isArray(xs)) {
        if (xs.length !== 1) {
            throw new ValueError(`Expected Tensor length to be 1; got ${xs.length}`);
        }
        x = xs[0];
    }
    else {
        x = xs;
    }
    return x;
}

function getExactlyOneShape(shapes) {
    if (Array.isArray(shapes) && Array.isArray(shapes[0])) {
        if (shapes.length === 1) {
            return shapes[0];
        }
        else {
            throw new ValueError(`Expected exactly 1 Shape; got ${shapes.length}`);
        }
    }
    else {
        return shapes;
    }
}

function countParamsInWeights(weights) {
    let count = 0;
    for (const weight of weights) {
        if (weight.shape.length === 0) {
            count += 1;
        }
        else {
            count += weight.shape.reduce((a, b) => a * b);
        }
    }
    return count;
}
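
// Worked example (comment only): a dense layer with kernel shape [784, 128]
// and bias shape [128] counts 784 * 128 + 128 = 100480 parameters; a scalar
// weight (rank 0) contributes exactly 1.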

const DEFAULT_VARIABLE_NAME_PREFIX = 'Variable';

class LayerVariable {
    constructor(val, dtype = 'float32', name = DEFAULT_VARIABLE_NAME_PREFIX, trainable = true, constraint = null) {
        this.dtype = dtype == null ? 'float32' : dtype;
        this.shape = val.shape;
        this.id = getNextUniqueTensorId();
        name = name == null ? DEFAULT_VARIABLE_NAME_PREFIX : name;
        this.originalName = getScopedTensorName(name);
        this.name = getUniqueTensorName(this.originalName);
        this.trainable_ = trainable;
        this.constraint = constraint;
        this.val = variable(val, this.trainable_, this.name, this.dtype);
    }
    read() {
        this.assertNotDisposed();
        return this.val;
    }
    write(newVal) {
        this.assertNotDisposed();
        checkShapesMatch(this.val, newVal);
        if (this.val.id !== newVal.id) {
            this.val.assign(newVal);
            if (this.constraint != null) {
                this.val.assign(this.constraint.apply(this.val));
            }
        }
        return this;
    }
    dispose() {
        this.assertNotDisposed();
        this.val.dispose();
    }
    assertNotDisposed() {
        if (this.val.isDisposed) {
            throw new Error(`LayersVariable ${this.name} is already disposed.`);
        }
    }
    get trainable() {
        return this.trainable_;
    }
    set trainable(trainable) {
        this.trainable_ = trainable;
        this.val.trainable = trainable;
    }
}
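
// Usage sketch (hypothetical, never invoked): LayerVariable wraps a backing
// tf Variable; write() assigns in place and re-applies any constraint, so
// reads after a write observe the constrained value.
function exampleLayerVariable() {
    const v = new LayerVariable(zeros$1([2, 2]), 'float32', 'demo_weight');
    v.write(ones([2, 2]));
    return v.read(); // all ones, shape [2, 2]
}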
function checkShapesMatch(x, y) {
    if (x.shape.toString() !== y.shape.toString()) {
        throw new Error('Shape mismatch: ' + JSON.stringify(x.shape) + ' vs. ' +
            JSON.stringify(y.shape));
    }
}

function batchGetValue(xs) {
    return xs.map(x => x.read());
}

function batchSetValue(variablesAndValues) {
    variablesAndValues.forEach(variableAndValue => {
        const variable = variableAndValue[0];
        variable.write(variableAndValue[1]);
    });
}

class InputSpec {
    constructor(args) {
        this.dtype = args.dtype;
        this.shape = args.shape;
        if (args.shape != null) {
            this.ndim = args.shape.length;
        }
        else {
            this.ndim = args.ndim;
        }
        this.maxNDim = args.maxNDim;
        this.minNDim = args.minNDim;
        this.axes = args.axes || {};
    }
}

class SymbolicTensor {
    constructor(dtype, shape, sourceLayer, inputs, callArgs, name, outputTensorIndex) {
        this.dtype = dtype;
        this.shape = shape;
        this.sourceLayer = sourceLayer;
        this.inputs = inputs;
        this.callArgs = callArgs;
        this.outputTensorIndex = outputTensorIndex;
        this.id = getNextUniqueTensorId();
        if (name != null) {
            this.originalName = getScopedTensorName(name);
            this.name = getUniqueTensorName(this.originalName);
        }
        this.rank = shape.length;
    }
}
let _nextNodeID = 0;

class Node {
    constructor(args, callArgs) {
        this.callArgs = callArgs;
        this.id = _nextNodeID++;
        this.outboundLayer = args.outboundLayer;
        this.inboundLayers = args.inboundLayers;
        this.nodeIndices = args.nodeIndices;
        this.tensorIndices = args.tensorIndices;
        this.inputTensors = args.inputTensors;
        this.outputTensors = args.outputTensors;
        this.inputMasks = args.inputMasks;
        this.outputMasks = args.outputMasks;
        this.inputShapes = args.inputShapes;
        this.outputShapes = args.outputShapes;
        for (const layer of args.inboundLayers) {
            if (layer != null) {
                layer.outboundNodes.push(this);
            }
        }
        args.outboundLayer.inboundNodes.push(this);
    }
    getConfig() {
        const inboundNames = [];
        for (const layer of this.inboundLayers) {
            if (layer != null) {
                inboundNames.push(layer.name);
            }
            else {
                inboundNames.push(null);
            }
        }
        return {
            outboundLayer: this.outboundLayer ? this.outboundLayer.name : null,
            inboundLayers: inboundNames,
            nodeIndices: this.nodeIndices,
            tensorIndices: this.tensorIndices
        };
    }
}
let _nextLayerID = 0;

class Layer extends Serializable {
    constructor(args = {}) {
        super();
        this._callHook = null;
        this._addedWeightNames = [];
        this._stateful = false;
        this.id = _nextLayerID++;
        this.activityRegularizer = null;
        this.inputSpec = null;
        this.supportsMasking = false;
        this._trainableWeights = [];
        this._nonTrainableWeights = [];
        this._losses = [];
        this._updates = [];
        this._built = false;
        this.inboundNodes = [];
        this.outboundNodes = [];
        let name = args.name;
        if (!name) {
            const prefix = this.getClassName();
            name = toSnakeCase(prefix) + '_' + getUid(prefix);
        }
        this.name = name;
        this.trainable_ = args.trainable == null ? true : args.trainable;
        if (args.inputShape != null || args.batchInputShape != null) {
            let batchInputShape;
            if (args.batchInputShape != null) {
                batchInputShape = args.batchInputShape;
            }
            else if (args.inputShape != null) {
                let batchSize = null;
                if (args.batchSize != null) {
                    batchSize = args.batchSize;
                }
                batchInputShape = [batchSize].concat(args.inputShape);
            }
            this.batchInputShape = batchInputShape;
            let dtype = args.dtype;
            if (dtype == null) {
                dtype = args.inputDType;
            }
            if (dtype == null) {
                dtype = 'float32';
            }
            this.dtype = dtype;
        }
        if (args.weights != null) {
            this.initialWeights = args.weights;
        }
        else {
            this.initialWeights = null;
        }
        this._refCount = null;
        this.fastWeightInitDuringBuild = false;
    }
    static nodeKey(layer, nodeIndex) {
        return layer.name + '_ib-' + nodeIndex.toString();
    }
    getNodeAtIndex(nodeIndex, attrName) {
        if (this.inboundNodes.length === 0) {
            throw new RuntimeError('The layer has never been called ' +
                `and thus has no defined ${attrName}.`);
        }
        if (this.inboundNodes.length <= nodeIndex) {
            throw new ValueError(`Asked to get ${attrName} at node ${nodeIndex}, ` +
                `but the layer has only ${this.inboundNodes.length} inbound nodes.`);
        }
        return this.inboundNodes[nodeIndex];
    }
    getInputAt(nodeIndex) {
        return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'input').inputTensors);
    }
    getOutputAt(nodeIndex) {
        return singletonOrArray(this.getNodeAtIndex(nodeIndex, 'output').outputTensors);
    }
    get input() {
        if (this.inboundNodes.length > 1) {
            throw new AttributeError(`Layer ${this.name}` +
                ' has multiple inbound nodes, ' +
                'hence the notion of "layer input" ' +
                'is ill-defined. ' +
                'Use `getInputAt(nodeIndex)` instead.');
        }
        else if (this.inboundNodes.length === 0) {
            throw new AttributeError(`Layer ${this.name}` +
                ' is not connected, no input to return.');
        }
        return singletonOrArray(this.getNodeAtIndex(0, 'input').inputTensors);
    }
    get output() {
        if (this.inboundNodes.length === 0) {
            throw new AttributeError(`Layer ${this.name}` +
                ' has no inbound nodes.');
        }
        if (this.inboundNodes.length > 1) {
            throw new AttributeError(`Layer ${this.name}` +
                ' has multiple inbound nodes, ' +
                'hence the notion of "layer output" ' +
                'is ill-defined. ' +
                'Use `getOutputAt(nodeIndex)` instead.');
        }
        return singletonOrArray(this.getNodeAtIndex(0, 'output').outputTensors);
    }
    get losses() {
        return this._losses;
    }
    calculateLosses() {
        return this.losses.map(lossFn => lossFn());
    }
    get updates() {
        return this._updates;
    }
    get built() {
        return this._built;
    }
    set built(built) {
        this._built = built;
    }
    get trainable() {
        return this.trainable_;
    }
    set trainable(trainable) {
        this._trainableWeights.forEach(w => w.trainable = trainable);
        this.trainable_ = trainable;
    }
    get trainableWeights() {
        if (this.trainable_) {
            return this._trainableWeights.filter(w => w.trainable);
        }
        else {
            return [];
        }
    }
    set trainableWeights(weights) {
        this._trainableWeights = weights;
    }
    get nonTrainableWeights() {
        if (this.trainable) {
            return this._trainableWeights.filter(w => !w.trainable)
                .concat(this._nonTrainableWeights);
        }
        else {
            return this._trainableWeights.concat(this._nonTrainableWeights);
        }
    }
    set nonTrainableWeights(weights) {
        this._nonTrainableWeights = weights;
    }
    get weights() {
        return this.trainableWeights.concat(this.nonTrainableWeights);
    }
    get stateful() {
        return this._stateful;
    }
    resetStates() {
        if (!this.stateful) {
            throw new Error('Cannot call the resetStates() method of a non-stateful Layer object.');
        }
    }
    assertInputCompatibility(inputs) {
        const inputsList = toList(inputs);
        if (this.inputSpec == null || this.inputSpec.length === 0) {
            return;
        }
        const inputSpec = toList(this.inputSpec);
        if (inputsList.length !== inputSpec.length) {
            throw new ValueError(`Layer ${this.name} expects ${inputSpec.length} inputs, ` +
                `but it received ${inputsList.length} input tensors. ` +
                `Input received: ${inputs}`);
        }
        for (let inputIndex = 0; inputIndex < inputsList.length; inputIndex++) {
            const x = inputsList[inputIndex];
            const spec = inputSpec[inputIndex];
            if (spec == null) {
                continue;
            }
            const ndim = x.rank;
            if (spec.ndim != null) {
                if (ndim !== spec.ndim) {
                    throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}: ` +
                        `expected ndim=${spec.ndim}, found ndim=${ndim}`);
                }
            }
            if (spec.maxNDim != null) {
                if (ndim > spec.maxNDim) {
                    throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}` +
                        `: expected max_ndim=${spec.maxNDim}, found ndim=${ndim}`);
                }
            }
            if (spec.minNDim != null) {
                if (ndim < spec.minNDim) {
                    throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name}` +
                        `: expected min_ndim=${spec.minNDim}, found ndim=${ndim}.`);
                }
            }
            if (spec.dtype != null) {
                if (x.dtype !== spec.dtype) {
                    throw new ValueError(`Input ${inputIndex} is incompatible with layer ${this.name} ` +
                        `: expected dtype=${spec.dtype}, found dtype=${x.dtype}.`);
                }
            }
            if (spec.axes) {
                const xShape = x.shape;
                for (const key in spec.axes) {
                    const axis = Number(key);
                    const value = spec.axes[key];
                    const xShapeAtAxis = axis >= 0 ? xShape[axis] : xShape[xShape.length + axis];
                    if (value != null && [value, null].indexOf(xShapeAtAxis) === -1) {
                        throw new ValueError(`Input ${inputIndex} is incompatible with layer ` +
                            `${this.name}: expected axis ${axis} of input shape to ` +
                            `have value ${value} but got shape ${xShape}.`);
                    }
                }
            }
            if (spec.shape != null) {
                for (let i = 0; i < spec.shape.length; ++i) {
                    const specDim = spec.shape[i];
                    const dim = x.shape[i];
                    if (specDim != null && dim != null) {
                        if (specDim !== dim) {
                            throw new ValueError(`Input ${inputIndex} is incompatible with layer ` +
                                `${this.name}: expected shape=${spec.shape}, ` +
                                `found shape=${x.shape}.`);
                        }
                    }
                }
            }
        }
    }
    call(inputs, kwargs) {
        return inputs;
    }
    invokeCallHook(inputs, kwargs) {
        if (this._callHook != null) {
            this._callHook(inputs, kwargs);
        }
    }
    setCallHook(callHook) {
        this._callHook = callHook;
    }
    clearCallHook() {
        this._callHook = null;
    }
    apply(inputs, kwargs) {
        kwargs = kwargs || {};
        this.assertNotDisposed();
        const inputsList = toList(inputs);
        const allAreSymbolic = checkAllSymbolic(inputs);
        const noneAreSymbolic = checkNoneSymbolic(inputs);
        if (allAreSymbolic === noneAreSymbolic) {
            throw new ValueError('Arguments to apply() must be all ' +
                'SymbolicTensors or all Tensors');
        }
        return nameScope(this.name, () => {
            if (!this.built) {
                this.assertInputCompatibility(inputs);
                const inputShapes = [];
                for (const xElem of toList(inputs)) {
                    inputShapes.push(xElem.shape);
                }
                this.build(singletonOrArray(inputShapes));
                this.built = true;
                if (this.initialWeights) {
                    this.setWeights(this.initialWeights);
                }
                if (this._refCount === null && noneAreSymbolic) {
                    this._refCount = 1;
                }
            }
            this.assertInputCompatibility(inputs);
            if (noneAreSymbolic) {
                let output = this.call(inputs, kwargs);
                if (this.supportsMasking) {
                    this.setMaskMetadata(inputs, output);
                }
                const outputList = toList(output);
                const outputListCopy = [];
                for (let x of outputList) {
                    if (inputsList.indexOf(x) !== -1) {
                        x = x.clone();
                    }
                    outputListCopy.push(x);
                }
                output = singletonOrArray(outputListCopy);
                if (this.activityRegularizer != null) {
                    throw new NotImplementedError('Layer invocation in the presence of activity ' +
                        'regularizer(s) is not supported yet.');
                }
                return output;
            }
            else {
                const inputShape = collectInputShape(inputs);
                const outputShape = this.computeOutputShape(inputShape);
                let output;
                const outputDType = guessOutputDType();
                this.warnOnIncompatibleInputShape(Array.isArray(inputs) ? inputShape[0] :
                    inputShape);
                if (outputShape != null && outputShape.length > 0 &&
                    Array.isArray(outputShape[0])) {
                    output = outputShape
                        .map((shape, index) => new SymbolicTensor(outputDType, shape, this, toList(inputs), kwargs, this.name, index));
                }
                else {
                    output = new SymbolicTensor(outputDType, outputShape, this, toList(inputs), kwargs, this.name);
                }
                this.addInboundNode(inputs, output, null, null, inputShape, outputShape, kwargs);
                this._refCount++;
                if (this.activityRegularizer != null) {
                    throw new NotImplementedError('Layer invocation in the presence of activity ' +
                        'regularizer(s) is not supported yet.');
                }
                return output;
            }
        });
    }
    warnOnIncompatibleInputShape(inputShape) {
        if (this.batchInputShape == null) {
            return;
        }
        else if (inputShape.length !== this.batchInputShape.length) {
            console.warn(`The rank of the input tensor provided (shape: ` +
                `${JSON.stringify(inputShape)}) does not match that of the ` +
                `batchInputShape (${JSON.stringify(this.batchInputShape)}) ` +
                `of the layer ${this.name}`);
        }
        else {
            let dimMismatch = false;
            this.batchInputShape.forEach((dimension, i) => {
                if (dimension != null && inputShape[i] != null &&
                    inputShape[i] !== dimension) {
                    dimMismatch = true;
                }
            });
            if (dimMismatch) {
                console.warn(`The shape of the input tensor ` +
                    `(${JSON.stringify(inputShape)}) does not ` +
                    `match the expectation of layer ${this.name}: ` +
                    `${JSON.stringify(this.batchInputShape)}`);
            }
        }
    }
    get outputShape() {
        if (this.inboundNodes == null || this.inboundNodes.length === 0) {
            throw new AttributeError(`The layer ${this.name} has never been called and thus has no ` +
                `defined output shape.`);
        }
        const allOutputShapes = [];
        for (const node of this.inboundNodes) {
            const shapeString = JSON.stringify(node.outputShapes);
            if (allOutputShapes.indexOf(shapeString) === -1) {
                allOutputShapes.push(shapeString);
            }
        }
        if (allOutputShapes.length === 1) {
            const outputShapes = this.inboundNodes[0].outputShapes;
            if (Array.isArray(outputShapes) && Array.isArray(outputShapes[0]) &&
                outputShapes.length === 1) {
                return outputShapes[0];
            }
            else {
                return outputShapes;
            }
        }
        else {
            throw new AttributeError(`The layer ${this.name} has multiple inbound nodes with different ` +
                `output shapes. Hence the notion of "output shape" is ill-defined ` +
                `for the layer.`);
        }
    }
    countParams() {
        if (!this.built) {
            throw new RuntimeError(`You tried to call countParams() on ${this.name}, ` +
                `but the layer is not built yet. Build it first by calling ` +
                `build(batchInputShape).`);
        }
        return countParamsInWeights(this.weights);
    }
    build(inputShape) {
        this.built = true;
    }
    getWeights(trainableOnly = false) {
        return batchGetValue(trainableOnly ? this.trainableWeights : this.weights);
    }
    setWeights(weights) {
        tidy(() => {
            const params = this.weights;
            if (params.length !== weights.length) {
                throw new ValueError(`You called setWeights(weights) on layer "${this.name}" ` +
                    `with a weight list of length ${weights.length}, ` +
                    `but the layer was expecting ${params.length} weights. ` +
                    `Provided weights: ${weights}...`);
            }
            if (params.length === 0) {
                return;
            }
            const weightValueTuples = [];
            const paramValues = batchGetValue(params);
            for (let i = 0; i < paramValues.length; ++i) {
                const pv = paramValues[i];
                const p = params[i];
                const w = weights[i];
                if (!arraysEqual(pv.shape, w.shape)) {
                    throw new ValueError(`Layer weight shape ${pv.shape} ` +
                        `not compatible with provided weight shape ${w.shape}`);
                }
                weightValueTuples.push([p, w]);
            }
            batchSetValue(weightValueTuples);
        });
    }
    addWeight(name, shape, dtype, initializer, regularizer, trainable, constraint, getInitializerFunc) {
        if (this._addedWeightNames.indexOf(name) !== -1) {
            throw new ValueError(`Duplicate weight name ${name} for layer ${this.name}`);
        }
        this._addedWeightNames.push(name);
        if (dtype == null) {
            dtype = 'float32';
        }
        if (this.fastWeightInitDuringBuild) {
            initializer = getInitializerFunc != null ? getInitializerFunc() :
                getInitializer('zeros');
        }
        const initValue = initializer.apply(shape, dtype);
        const weight = new LayerVariable(initValue, dtype, name, trainable, constraint);
        initValue.dispose();
        if (regularizer != null) {
            this.addLoss(() => regularizer.apply(weight.read()));
        }
        if (trainable == null) {
            trainable = true;
        }
        if (trainable) {
            this._trainableWeights.push(weight);
        }
        else {
            this._nonTrainableWeights.push(weight);
        }
        return weight;
    }
    setFastWeightInitDuringBuild(value) {
        this.fastWeightInitDuringBuild = value;
    }
    addLoss(losses) {
        if (losses == null || Array.isArray(losses) && losses.length === 0) {
            return;
        }
        losses = toList(losses);
        if (this._losses !== undefined && this._losses !== null) {
            this.losses.push(...losses);
        }
    }
    computeOutputShape(inputShape) {
        return inputShape;
    }
    computeMask(inputs, mask) {
        if (!this.supportsMasking) {
            if (mask != null) {
                if (Array.isArray(mask)) {
                    mask.forEach(maskElement => {
                        if (maskElement != null) {
                            throw new TypeError(`Layer ${this.name} does not support masking, ` +
                                'but was passed an inputMask.');
                        }
                    });
                }
                else {
                    throw new TypeError(`Layer ${this.name} does not support masking, ` +
                        'but was passed an inputMask.');
                }
            }
            return null;
        }
        return mask;
    }
    setMaskMetadata(inputs, outputs, previousMask) {
        if (!this.supportsMasking) {
            return;
        }
        const outputMasks = this.computeMask(inputs, previousMask);
        const outputsList = toList(outputs);
        const outputMasksList = toList(outputMasks);
        if (outputsList.length !== outputMasksList.length) {
            throw new Error(`${this.name} outputs ${outputsList.length} tensors ` +
                `but ${outputMasksList.length} masks for those tensors`);
        }
        for (let i = 0; i < outputsList.length; i++) {
            outputsList[i].kerasMask = outputMasksList[i];
        }
    }
    addInboundNode(inputTensors, outputTensors, inputMasks, outputMasks, inputShapes, outputShapes, kwargs = null) {
        const inputTensorList = toList(inputTensors);
        outputTensors = toList(outputTensors);
        inputMasks = toList(inputMasks);
        outputMasks = toList(outputMasks);
        inputShapes = normalizeShapeList(inputShapes);
        outputShapes = normalizeShapeList(outputShapes);
        const inboundLayers = [];
        const nodeIndices = [];
        const tensorIndices = [];
        for (const x of inputTensorList) {
            inboundLayers.push(x.sourceLayer);
            nodeIndices.push(x.nodeIndex);
            tensorIndices.push(x.tensorIndex);
        }
        new Node({
            outboundLayer: this,
            inboundLayers,
            nodeIndices,
            tensorIndices,
            inputTensors: inputTensorList,
            outputTensors,
            inputMasks,
            outputMasks,
            inputShapes,
            outputShapes
        }, kwargs);
        for (let i = 0; i < outputTensors.length; i++) {
            outputTensors[i].sourceLayer = this;
            outputTensors[i].nodeIndex = this.inboundNodes.length - 1;
            outputTensors[i].tensorIndex = i;
        }
    }
    getConfig() {
        const config = { name: this.name, trainable: this.trainable };
        if (this.batchInputShape != null) {
            config['batchInputShape'] = this.batchInputShape;
        }
        if (this.dtype != null) {
            config['dtype'] = this.dtype;
        }
        return config;
    }
    disposeWeights() {
        this.weights.forEach(weight => weight.dispose());
        return this.weights.length;
    }
    assertNotDisposed() {
        if (this._refCount === 0) {
            throw new Error(`Layer '${this.name}' is already disposed.`);
        }
    }
    dispose() {
        if (!this.built) {
            throw new Error(`Cannot dispose Layer ${this.name} because it has not been built yet.`);
        }
        if (this._refCount === null) {
            throw new Error(`Cannot dispose Layer ${this.name} because it has not been used yet.`);
        }
        this.assertNotDisposed();
        let numDisposedVariables = 0;
        if (--this._refCount === 0) {
            numDisposedVariables = this.disposeWeights();
        }
        return { refCountAfterDispose: this._refCount, numDisposedVariables };
    }
}
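
// Comment sketch: Layer.apply() is dual-mode. Called with concrete Tensors it
// builds the layer (if needed) and runs call() eagerly; called with
// SymbolicTensors it only propagates shapes/dtypes and records a Node in the
// graph, which is how functional models are traced before execution.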

function collectInputShape(inputTensors) {
    inputTensors = toList(inputTensors);
    const shapes = [];
    for (const x of inputTensors) {
        shapes.push(x.shape);
    }
    return singletonOrArray(shapes);
}

function guessOutputDType(inputTensors) {
    return 'float32';
}

function getSourceInputs(tensor, layer, nodeIndex) {
    if (layer == null || (nodeIndex != null && nodeIndex > 0)) {
        layer = tensor.sourceLayer;
        nodeIndex = tensor.nodeIndex;
    }
    if (layer.inboundNodes.length === 0) {
        return [tensor];
    }
    else {
        const node = layer.inboundNodes[nodeIndex];
        if (node.inboundLayers.length === 0) {
            return node.inputTensors;
        }
        else {
            const sourceTensors = [];
            for (let i = 0; i < node.inboundLayers.length; i++) {
                const x = node.inputTensors[i];
                const layer = node.inboundLayers[i];
                const nodeIndex = node.nodeIndices[i];
                const previousSources = getSourceInputs(x, layer, nodeIndex);
                for (const x of previousSources) {
                    if (sourceTensors.indexOf(x) === -1) {
                        sourceTensors.push(x);
                    }
                }
            }
            return sourceTensors;
        }
    }
}
function checkAllSymbolic(tensors) {
    let allAreSymbolic = true;
    for (const tensor of toList(tensors)) {
        if (!(tensor instanceof SymbolicTensor)) {
            allAreSymbolic = false;
            break;
        }
    }
    return allAreSymbolic;
}
function checkNoneSymbolic(tensors) {
    let noneAreSymbolic = true;
    for (const tensor of toList(tensors)) {
        if (tensor instanceof SymbolicTensor) {
            noneAreSymbolic = false;
            break;
        }
    }
    return noneAreSymbolic;
}

class InputLayer extends Layer {
    constructor(args) {
        super({
            dtype: args.dtype,
            name: args.name != null ? args.name : getUid('input').toString()
        });
        if (args.batchSize == null) {
            args.batchSize = null;
        }
        if (args.sparse == null) {
            args.sparse = false;
        }
        this.trainable = false;
        this.built = true;
        this.sparse = args.sparse;
        if (args.inputShape != null && args.batchInputShape != null) {
            throw new ValueError('Only provide the inputShape OR ' +
                'batchInputShape argument to inputLayer, not both at the same time.');
        }
        let batchInputShape = args.batchInputShape;
        if (batchInputShape == null) {
            if (args.inputShape == null) {
                throw new ValueError('An InputLayer should be passed either a ' +
                    '`batchInputShape` or an `inputShape`.');
            }
            else {
                batchInputShape = [args.batchSize].concat(args.inputShape);
            }
        }
        else {
            if (args.batchSize != null) {
                throw new ValueError('Cannot specify batchSize if batchInputShape is ' +
                    'specified when creating an InputLayer.');
            }
        }
        const dtype = args.dtype || 'float32';
        this.batchInputShape = batchInputShape;
        this.dtype = dtype;
        this.inputSpec = [{ shape: batchInputShape }];
        const inputTensor = new SymbolicTensor(this.dtype, this.batchInputShape, this, [], {}, this.name);
        inputTensor.nodeIndex = 0;
        inputTensor.tensorIndex = 0;
        new Node({
            outboundLayer: this,
            inboundLayers: [],
            nodeIndices: [],
            tensorIndices: [],
            inputTensors: [inputTensor],
            outputTensors: [inputTensor],
            inputMasks: [null],
            outputMasks: [null],
            inputShapes: [batchInputShape],
            outputShapes: [batchInputShape]
        });
    }
    apply(inputs, kwargs) {
        throw new ValueError('Cannot pass any input to an ' +
            `InputLayer's apply() method. InputLayer name: ${this.name}`);
    }
    dispose() {
        return { refCountAfterDispose: this._refCount, numDisposedVariables: 0 };
    }
    getConfig() {
        return {
            batchInputShape: this.batchInputShape,
            dtype: this.dtype,
            sparse: this.sparse,
            name: this.name
        };
    }
}
InputLayer.className = 'InputLayer';
registerClass(InputLayer);
function Input(config) {
    if (config.batchShape == null && config.shape == null) {
        throw new Error('Please provide to Input either a `shape`' +
            ' or a `batchShape` argument. Note that ' +
            '`shape` does not include the batch ' +
            'dimension.');
    }
    if (config.batchShape != null && config.shape != null) {
        throw new ValueError('Please provide either a `shape` or `batchShape` ' +
            'argument to Input, but not both.');
    }
    let batchShape = config.batchShape;
    if (config.shape != null && batchShape == null) {
        batchShape = [null].concat(config.shape);
    }
    let dtype = config.dtype;
    if (dtype == null) {
        dtype = 'float32';
    }
    const inputLayer = new InputLayer({
        batchInputShape: batchShape,
        name: config.name,
        dtype,
        sparse: config.sparse
    });
    const outputs = inputLayer.inboundNodes[0].outputTensors;
    return outputs[0];
}
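
// Usage sketch (hypothetical, never invoked): Input() wraps an InputLayer and
// returns its single SymbolicTensor, e.g. a placeholder for 28x28 images.
function exampleInput() {
    const x = Input({ shape: [28, 28], name: 'demo_input' });
    return x.shape; // [null, 28, 28] -- null is the batch dimension
}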

function assertFeedCompatibility(key, val) {
    if (key.dtype == null || key.dtype === val.dtype) {
        return val;
    }
    try {
        return cast$3(val, key.dtype);
    }
    catch (err) {
        throw new ValueError(`The dtype of the feed (${val.dtype}) can not be cast to the dtype ` +
            `of the key '${key.name}' (${key.dtype}).`);
    }
}

class FeedDict {
    constructor(feeds) {
        this.id2Value = {};
        this.id2Mask = {};
        this.name2Id = {};
        if (feeds instanceof FeedDict) {
            for (const id in feeds.id2Value) {
                this.id2Value[id] = feeds.id2Value[id];
                if (id in feeds.id2Mask) {
                    this.id2Mask[id] = feeds.id2Mask[id];
                }
            }
        }
        else {
            if (feeds == null) {
                return;
            }
            for (const feed of feeds) {
                this.add(feed.key, feed.value);
            }
        }
    }
    add(key, value, mask) {
        if (this.id2Value[key.id] == null) {
            this.id2Value[key.id] = assertFeedCompatibility(key, value);
            this.name2Id[key.name] = key.id;
            if (mask != null) {
                this.id2Mask[key.id] = mask;
            }
        }
        else {
            throw new ValueError(`Duplicate key: name=${key.name}, id=${key.id}`);
        }
        return this;
    }
    addFeed(feed) {
        this.add(feed.key, feed.value);
    }
    hasKey(key) {
        return this.id2Value[key.id] != null;
    }
    names() {
        return Object.keys(this.name2Id);
    }
    getValue(key) {
        if (key instanceof SymbolicTensor) {
            if (this.id2Value[key.id] == null) {
                throw new ValueError(`Nonexistent key: ${key.name}`);
            }
            else {
                return this.id2Value[key.id];
            }
        }
        else {
            const id = this.name2Id[key];
            if (id == null) {
                throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);
            }
            return this.id2Value[id];
        }
    }
    getMask(key) {
        if (key instanceof SymbolicTensor) {
            if (this.id2Value[key.id] == null) {
                throw new ValueError(`Nonexistent key: ${key.name}`);
            }
            else {
                return this.id2Mask[key.id];
            }
        }
        else {
            const id = this.name2Id[key];
            if (id == null) {
                throw new ValueError(`Feed dict has no SymbolicTensor name: ${key}`);
            }
            return this.id2Mask[id];
        }
    }
    disposeMasks() {
        if (this.id2Mask != null) {
            dispose(this.id2Mask);
        }
    }
}
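
// Comment sketch: FeedDict maps SymbolicTensors (by id and by name) to the
// concrete Tensors fed into execute(); add() casts the fed value to the key's
// dtype via assertFeedCompatibility and rejects duplicate keys.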

const cachedSorted = new LruCache();

const cachedRecipientCounts = new LruCache();

function execute(fetches, feedDict, kwargs, probe) {
    const training = kwargs == null ? false : kwargs['training'];
    const arrayFetches = Array.isArray(fetches);
    const fetchArray = arrayFetches ? fetches : [fetches];
    const outputNames = fetchArray.map(t => t.name);
    const finalOutputs = [];
    const feedNames = feedDict.names();
    for (const outputName of outputNames) {
        if (feedNames.indexOf(outputName) !== -1) {
            finalOutputs.push(feedDict.getValue(outputName));
        }
        else {
            finalOutputs.push(null);
        }
    }
    const fetchAndFeedKey = outputNames.join(',') + '|' + feedDict.names().sort().join(',');
    let sorted = cachedSorted.get(fetchAndFeedKey);
    let recipientCounts;
    if (sorted == null) {
        const out = getTopologicalSortAndRecipientCounts(fetchArray, feedDict);
        sorted = out.sorted;
        recipientCounts = out.recipientCounts;
        cachedSorted.put(fetchAndFeedKey, sorted);
        cachedRecipientCounts.put(fetchAndFeedKey, recipientCounts);
    }
    recipientCounts = {};
    if (!training) {
        Object.assign(recipientCounts, cachedRecipientCounts.get(fetchAndFeedKey));
    }
    const internalFeedDict = new FeedDict(feedDict);
    for (let i = 0; i < sorted.length; ++i) {
        const symbolic = sorted[i];
        const srcLayer = symbolic.sourceLayer;
        if (srcLayer instanceof InputLayer) {
            continue;
        }
        const inputValues = [];
        const inputMasks = [];
        const tensorsToDispose = [];
        let maskExists = false;
        for (const input of symbolic.inputs) {
            const value = internalFeedDict.getValue(input);
            const mask = internalFeedDict.getMask(input);
            inputValues.push(value);
            inputMasks.push(mask);
            if (mask != null) {
                maskExists = true;
            }
            if (!training) {
                recipientCounts[input.name]--;
                if (recipientCounts[input.name] === 0 && !feedDict.hasKey(input) &&
                    outputNames.indexOf(input.name) === -1 && !value.isDisposed &&
                    input.sourceLayer.stateful !== true) {
                    tensorsToDispose.push(value);
                }
            }
        }
        if (maskExists) {
            kwargs = kwargs || {};
            kwargs['mask'] = inputMasks[0];
        }
        const outputTensors = toList(srcLayer.apply(inputValues, kwargs));
        let outputMask = null;
        if (srcLayer.supportsMasking) {
            outputMask = srcLayer.computeMask(inputValues, inputMasks);
        }
        const layerOutputs = getNodeOutputs(symbolic);
        const outputSymbolicTensors = Array.isArray(layerOutputs) ? layerOutputs : [layerOutputs];
        for (let i = 0; i < outputSymbolicTensors.length; ++i) {
            if (!internalFeedDict.hasKey(outputSymbolicTensors[i])) {
                internalFeedDict.add(outputSymbolicTensors[i], outputTensors[i], Array.isArray(outputMask) ? outputMask[0] : outputMask);
            }
            const index = outputNames.indexOf(outputSymbolicTensors[i].name);
            if (index !== -1) {
                finalOutputs[index] = outputTensors[i];
            }
        }
        if (!training) {
            dispose(tensorsToDispose);
        }
    }
    internalFeedDict.disposeMasks();
    return arrayFetches ? finalOutputs : finalOutputs[0];
}
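
// Comment sketch: execute() walks the cached topological order, feeding each
// layer the concrete values of its input SymbolicTensors. Outside of
// training, recipientCounts tracks how many downstream consumers remain for
// each intermediate tensor so it can be disposed once its count hits zero.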

function getTopologicalSortAndRecipientCounts(fetches, feedDict) {
    assert$1(fetches != null && fetches.length > 0, () => `Expected at least one fetch, got none`);
    let finalSorted = [];
    let finalRecipientMap = {};
    if (fetches.length === 1) {
        const out = getTopologicalSortAndRecipientCountsForOneFetch(fetches[0], feedDict);
        finalSorted = out.sorted;
        finalRecipientMap = out.recipientMap;
    }
    else {
        const visited = new Set();
        for (const fetch of fetches) {
            const { sorted, recipientMap } = getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict);
            for (const symbolicTensor of sorted) {
                if (!visited.has(symbolicTensor.name)) {
                    finalSorted.push(symbolicTensor);
                    visited.add(symbolicTensor.name);
                }
            }
            for (const name in recipientMap) {
                if (finalRecipientMap[name] == null) {
                    finalRecipientMap[name] = new Set();
                }
                recipientMap[name].forEach(recipient => finalRecipientMap[name].add(recipient));
            }
        }
    }
    return {
        sorted: finalSorted,
        recipientCounts: recipientMap2Counts(finalRecipientMap)
    };
}
function recipientMap2Counts(recipientMap) {
    const recipientCounts = {};
    for (const name in recipientMap) {
        recipientCounts[name] = recipientMap[name].size;
    }
    return recipientCounts;
}

function getTopologicalSortAndRecipientCountsForOneFetch(fetch, feedDict) {
    const visited = new Set();
    const sorted = [];
    const recipientMap = {};
    for (const key of feedDict.names()) {
        visited.add(key);
    }
    const stack = [];
    const marks = [];
    stack.push(fetch);
    while (stack.length > 0) {
        const top = stack[stack.length - 1];
        if (visited.has(top.name)) {
            stack.pop();
            continue;
        }
        const topIsMarked = marks[marks.length - 1] === stack.length - 1;
        if (top.inputs.length === 0 || topIsMarked) {
            stack.pop();
            sorted.push(top);
            visited.add(top.name);
            if (topIsMarked) {
                marks.pop();
            }
        }
        else {
            marks.push(stack.length - 1);
            for (const input of top.inputs) {
                if (recipientMap[input.name] == null) {
                    recipientMap[input.name] = new Set();
                }
                recipientMap[input.name].add(top.name);
                if (visited.has(input.name)) {
                    continue;
                }
                stack.push(input);
            }
        }
    }
    return { sorted, recipientMap };
}
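
// Comment sketch: this is an iterative post-order DFS. A tensor's stack slot
// is "marked" once its inputs have been pushed; when it surfaces again with
// the mark set, all of its dependencies are already sorted, so it can be
// emitted. Fed tensors are pre-visited and therefore act as graph sources.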

function getNodeOutputs(fetch) {
    let layerOutputs;
    if (fetch.sourceLayer.inboundNodes.length === 1) {
        layerOutputs = fetch.sourceLayer.output;
    }
    else {
        let nodeIndex = null;
        for (let i = 0; i < fetch.sourceLayer.inboundNodes.length; ++i) {
            for (const outputTensor of fetch.sourceLayer.inboundNodes[i]
                .outputTensors) {
                if (outputTensor.id === fetch.id) {
                    nodeIndex = i;
                    break;
                }
            }
        }
        layerOutputs = fetch.sourceLayer.getOutputAt(nodeIndex);
    }
    return layerOutputs;
}

function calcL2Norms(w, axis) {
    return tidy(() => sqrt$2(sum$2(mul(w, w), axis, true)));
}

class Constraint extends Serializable {
    getConfig() {
        return {};
    }
}
class MaxNorm extends Constraint {
    constructor(args) {
        super();
        this.defaultMaxValue = 2;
        this.defaultAxis = 0;
        this.maxValue =
            args.maxValue != null ? args.maxValue : this.defaultMaxValue;
        this.axis = args.axis != null ? args.axis : this.defaultAxis;
    }
    apply(w) {
        return tidy(() => {
            const norms = calcL2Norms(w, this.axis);
            const desired = clipByValue$2(norms, 0, this.maxValue);
            return mul(w, div$1(desired, add$1(epsilon(), norms)));
        });
    }
    getConfig() {
        return { maxValue: this.maxValue, axis: this.axis };
    }
}
MaxNorm.className = 'MaxNorm';
registerClass(MaxNorm);
class UnitNorm extends Constraint {
    constructor(args) {
        super();
        this.defaultAxis = 0;
        this.axis = args.axis != null ? args.axis : this.defaultAxis;
    }
    apply(w) {
        return tidy(() => div$1(w, add$1(epsilon(), calcL2Norms(w, this.axis))));
    }
    getConfig() {
        return { axis: this.axis };
    }
}
UnitNorm.className = 'UnitNorm';
registerClass(UnitNorm);
class NonNeg extends Constraint {
    apply(w) {
        return relu$2(w);
    }
}
NonNeg.className = 'NonNeg';
registerClass(NonNeg);
class MinMaxNorm extends Constraint {
    constructor(args) {
        super();
        this.defaultMinValue = 0.0;
        this.defaultMaxValue = 1.0;
        this.defaultRate = 1.0;
        this.defaultAxis = 0;
        this.minValue =
            args.minValue != null ? args.minValue : this.defaultMinValue;
        this.maxValue =
            args.maxValue != null ? args.maxValue : this.defaultMaxValue;
        this.rate = args.rate != null ? args.rate : this.defaultRate;
        this.axis = args.axis != null ? args.axis : this.defaultAxis;
    }
    apply(w) {
        return tidy(() => {
            const norms = calcL2Norms(w, this.axis);
            const desired = add$1(mul(this.rate, clipByValue$2(norms, this.minValue, this.maxValue)), mul(1.0 - this.rate, norms));
            return mul(w, div$1(desired, add$1(epsilon(), norms)));
        });
    }
    getConfig() {
        return {
            minValue: this.minValue,
            maxValue: this.maxValue,
            rate: this.rate,
            axis: this.axis
        };
    }
}
MinMaxNorm.className = 'MinMaxNorm';
registerClass(MinMaxNorm);
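
// Comment sketch: MinMaxNorm rescales weights toward the band
// [minValue, maxValue] of per-axis L2 norms using
// w * (rate * clip(norm, min, max) + (1 - rate) * norm) / (epsilon + norm);
// rate = 1 enforces the band exactly, smaller rates move norms only partway.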

const CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
    'maxNorm': 'MaxNorm',
    'minMaxNorm': 'MinMaxNorm',
    'nonNeg': 'NonNeg',
    'unitNorm': 'UnitNorm'
};
function serializeConstraint(constraint) {
    return serializeKerasObject(constraint);
}
function deserializeConstraint(config, customObjects = {}) {
    return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'constraint');
}
function getConstraint(identifier) {
    if (identifier == null) {
        return null;
    }
    if (typeof identifier === 'string') {
        const className = identifier in CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
            CONSTRAINT_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
            identifier;
        const config = { className, config: {} };
        return deserializeConstraint(config);
    }
    else if (identifier instanceof Constraint) {
        return identifier;
    }
    else {
        return deserializeConstraint(identifier);
    }
}

function glorotUniform(args) {
    return new GlorotUniform(args);
}

async function resolveScalarsInLogs(logs) {
    if (logs == null) {
        return;
    }
    const promises = [];
    const keys = [];
    const scalarsToDispose = [];
    for (const key in logs) {
        const value = logs[key];
        if (typeof value !== 'number') {
            const valueScalar = value;
            promises.push(valueScalar.data());
            keys.push(key);
            scalarsToDispose.push(valueScalar);
        }
    }
    if (promises.length > 0) {
        const values = await Promise.all(promises);
        for (let i = 0; i < values.length; ++i) {
            logs[keys[i]] = values[i][0];
        }
        dispose(scalarsToDispose);
    }
}

function disposeTensorsInLogs(logs) {
    if (logs == null) {
        return;
    }
    for (const key in logs) {
        const value = logs[key];
        if (typeof value !== 'number') {
            value.dispose();
        }
    }
}
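
// Comment sketch: logs may hold scalar Tensors during training, presumably to
// defer blocking device-to-CPU reads; resolveScalarsInLogs awaits all of them
// in one Promise.all, replaces them with plain numbers, and disposes the
// original tensors.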

var ModelLoggingVerbosity;
(function (ModelLoggingVerbosity) {
    ModelLoggingVerbosity[ModelLoggingVerbosity["SILENT"] = 0] = "SILENT";
    ModelLoggingVerbosity[ModelLoggingVerbosity["VERBOSE"] = 1] = "VERBOSE";
})(ModelLoggingVerbosity || (ModelLoggingVerbosity = {}));

const DEFAULT_YIELD_EVERY_MS = 125;

class BaseCallback {
    constructor() {
        this.validationData = null;
    }
    setParams(params) {
        this.params = params;
    }
    async onEpochBegin(epoch, logs) { }
    async onEpochEnd(epoch, logs) { }
    async onBatchBegin(batch, logs) { }
    async onBatchEnd(batch, logs) { }
    async onTrainBegin(logs) { }
    async onTrainEnd(logs) { }
    setModel(model) {
    }
}

class CallbackList {
    constructor(callbacks, queueLength = 10) {
        if (callbacks == null) {
            callbacks = [];
        }
        this.callbacks = callbacks;
        this.queueLength = queueLength;
    }
    append(callback) {
        this.callbacks.push(callback);
    }
    setParams(params) {
        for (const callback of this.callbacks) {
            callback.setParams(params);
        }
    }
    setModel(model) {
        for (const callback of this.callbacks) {
            callback.setModel(model);
        }
    }
    async onEpochBegin(epoch, logs) {
        if (logs == null) {
            logs = {};
        }
        for (const callback of this.callbacks) {
            await callback.onEpochBegin(epoch, logs);
        }
    }
    async onEpochEnd(epoch, logs) {
        if (logs == null) {
            logs = {};
        }
        for (const callback of this.callbacks) {
            await callback.onEpochEnd(epoch, logs);
        }
    }
    async onBatchBegin(batch, logs) {
        if (logs == null) {
            logs = {};
        }
        for (const callback of this.callbacks) {
            await callback.onBatchBegin(batch, logs);
        }
    }
    async onBatchEnd(batch, logs) {
        if (logs == null) {
            logs = {};
        }
        for (const callback of this.callbacks) {
            await callback.onBatchEnd(batch, logs);
        }
    }
    async onTrainBegin(logs) {
        if (logs == null) {
            logs = {};
        }
        for (const callback of this.callbacks) {
            await callback.onTrainBegin(logs);
        }
    }
    async onTrainEnd(logs) {
        if (logs == null) {
            logs = {};
        }
        for (const callback of this.callbacks) {
            await callback.onTrainEnd(logs);
        }
    }
}

class BaseLogger extends BaseCallback {
    constructor() {
        super();
    }
    async onEpochBegin(epoch) {
        this.seen = 0;
        this.totals = {};
    }
    async onBatchEnd(batch, logs) {
        if (logs == null) {
            logs = {};
        }
        const batchSize = logs['size'] == null ? 0 : logs['size'];
        this.seen += batchSize;
        for (const key in logs) {
            const value = logs[key];
            if (typeof value === 'number') {
                if (!this.totals.hasOwnProperty(key)) {
                    this.totals[key] = 0;
                }
                this.totals[key] = this.totals[key] + value * batchSize;
            }
            else {
                let oldTotalsToDispose;
                if (key in this.totals) {
                    oldTotalsToDispose = this.totals[key];
                }
                else {
                    this.totals[key] = 0;
                }
                const total = tidy(() => add$1(this.totals[key], mul(value, batchSize)));
                this.totals[key] = total;
                if (oldTotalsToDispose != null) {
                    oldTotalsToDispose.dispose();
                }
            }
        }
    }
    async onEpochEnd(epoch, logs) {
        if (logs != null) {
            for (const key of this.params['metrics']) {
                if (this.totals[key] == null) {
                    continue;
                }
                if (typeof this.totals[key] === 'number') {
                    logs[key] = this.totals[key] / this.seen;
                }
                else {
                    tidy(() => {
                        const log = mul(div$1(1, this.seen), this.totals[key]);
                        logs[key] = log;
                        this.totals[key].dispose();
                        keep(logs[key]);
                    });
                }
            }
        }
    }
}

class History extends BaseCallback {
    async onTrainBegin(logs) {
        this.epoch = [];
        this.history = {};
    }
    async onEpochEnd(epoch, logs) {
        if (logs == null) {
            logs = {};
        }
        this.epoch.push(epoch);
        for (const key in logs) {
            if (this.history[key] == null) {
                this.history[key] = [];
            }
            this.history[key].push(logs[key]);
        }
    }
    async syncData() {
        const promises = [];
        const keys = [];
        const indices = [];
        for (const key in this.history) {
            const valueArray = this.history[key];
            for (let i = 0; i < valueArray.length; ++i) {
                if (typeof valueArray[i] !== 'number') {
                    const valueScalar = valueArray[i];
                    promises.push(valueScalar.data());
                    keys.push(key);
                    indices.push(i);
                }
            }
        }
        const values = await Promise.all(promises);
        for (let n = 0; n < values.length; ++n) {
            const tensorToDispose = this.history[keys[n]][indices[n]];
            tensorToDispose.dispose();
            this.history[keys[n]][indices[n]] = values[n][0];
        }
    }
}

class CustomCallback extends BaseCallback {
    constructor(args, yieldEvery) {
        super();
        this.currentEpoch = 0;
        this.nowFunc = args.nowFunc;
        this.nextFrameFunc = args.nextFrameFunc || nextFrame;
        this.yieldEvery = yieldEvery || 'auto';
        if (this.yieldEvery === 'auto') {
            this.yieldEvery = DEFAULT_YIELD_EVERY_MS;
        }
        if (this.yieldEvery === 'never' && args.onYield != null) {
            throw new Error('yieldEvery is `never` but you provided an `onYield` callback. ' +
                'Either change `yieldEvery` or remove the callback');
        }
        if (isNumber(this.yieldEvery)) {
            this.maybeWait = debounce(this.maybeWait.bind(this), this.yieldEvery, this.nowFunc);
        }
        this.trainBegin = args.onTrainBegin;
        this.trainEnd = args.onTrainEnd;
        this.epochBegin = args.onEpochBegin;
        this.epochEnd = args.onEpochEnd;
        this.batchBegin = args.onBatchBegin;
        this.batchEnd = args.onBatchEnd;
        this.yield = args.onYield;
    }
    async maybeWait(epoch, batch, logs) {
        const ps = [];
        if (this.yield != null) {
            await resolveScalarsInLogs(logs);
            ps.push(this.yield(epoch, batch, logs));
        }
        ps.push(this.nextFrameFunc());
        await Promise.all(ps);
    }
    async onEpochBegin(epoch, logs) {
        this.currentEpoch = epoch;
        if (this.epochBegin != null) {
            await resolveScalarsInLogs(logs);
            await this.epochBegin(epoch, logs);
        }
    }
    async onEpochEnd(epoch, logs) {
        const ps = [];
        if (this.epochEnd != null) {
            await resolveScalarsInLogs(logs);
            ps.push(this.epochEnd(epoch, logs));
        }
        if (this.yieldEvery === 'epoch') {
            ps.push(this.nextFrameFunc());
        }
        await Promise.all(ps);
    }
    async onBatchBegin(batch, logs) {
        if (this.batchBegin != null) {
            await resolveScalarsInLogs(logs);
            await this.batchBegin(batch, logs);
        }
    }
    async onBatchEnd(batch, logs) {
        const ps = [];
        if (this.batchEnd != null) {
            await resolveScalarsInLogs(logs);
            ps.push(this.batchEnd(batch, logs));
        }
        if (this.yieldEvery === 'batch') {
            ps.push(this.nextFrameFunc());
        }
        else if (isNumber(this.yieldEvery)) {
            ps.push(this.maybeWait(this.currentEpoch, batch, logs));
        }
        await Promise.all(ps);
    }
    async onTrainBegin(logs) {
        if (this.trainBegin != null) {
            await resolveScalarsInLogs(logs);
            await this.trainBegin(logs);
        }
    }
    async onTrainEnd(logs) {
        if (this.trainEnd != null) {
            await resolveScalarsInLogs(logs);
            await this.trainEnd(logs);
        }
    }
}
|
|
|
|
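// Normalizes the `callbacks` argument of fit()/fitDataset() into an array of
// BaseCallback instances, wrapping plain config objects in CustomCallback.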
function standardizeCallbacks(callbacks, yieldEvery) {
    if (callbacks == null) {
        callbacks = {};
    }
    if (callbacks instanceof BaseCallback) {
        return [callbacks];
    }
    if (Array.isArray(callbacks) && callbacks[0] instanceof BaseCallback) {
        return callbacks;
    }
    const callbackConfigs = toList(callbacks);
    return callbackConfigs.map(callbackConfig => new CustomCallback(callbackConfig, yieldEvery));
}

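// Registry of callback constructors keyed by verbosity level. Environments
// can register extra callbacks here (a progress bar in Node.js, for
// instance); createCallbacks() instantiates every constructor registered at
// or below the requested verbosity.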
class CallbackConstructorRegistry {
    constructor() { }
    static registerCallbackConstructor(verbosityLevel, callbackConstructor) {
        assert$1(verbosityLevel >= 0 && Number.isInteger(verbosityLevel), () => `Verbosity level is expected to be an integer >= 0, ` +
            `but got ${verbosityLevel}`);
        CallbackConstructorRegistry.checkForDuplicate(callbackConstructor);
        if (CallbackConstructorRegistry.constructors[verbosityLevel] == null) {
            CallbackConstructorRegistry.constructors[verbosityLevel] = [];
        }
        CallbackConstructorRegistry.constructors[verbosityLevel].push(callbackConstructor);
    }
    static checkForDuplicate(callbackConstructor) {
        for (const levelName in CallbackConstructorRegistry.constructors) {
            const constructors = CallbackConstructorRegistry.constructors[+levelName];
            constructors.forEach(ctor => {
                if (ctor === callbackConstructor) {
                    throw new ValueError('Duplicate callback constructor.');
                }
            });
        }
    }
    static clear() {
        CallbackConstructorRegistry.constructors = {};
    }
    static createCallbacks(verbosityLevel) {
        const constructors = [];
        for (const levelName in CallbackConstructorRegistry.constructors) {
            const level = +levelName;
            if (verbosityLevel >= level) {
                constructors.push(...CallbackConstructorRegistry.constructors[level]);
            }
        }
        return constructors.map(ctor => new ctor());
    }
}
CallbackConstructorRegistry.constructors = {};

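// Assembles the full callback list for a training run: a BaseLogger for
// metric averaging, any registered constructors for the given verbosity,
// user callbacks, and a History instance that is also returned separately.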
function configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics) {
    const history = new History();
    const actualCallbacks = [
        new BaseLogger(), ...CallbackConstructorRegistry.createCallbacks(verbose)
    ];
    if (callbacks != null) {
        actualCallbacks.push(...callbacks);
    }
    actualCallbacks.push(history);
    const callbackList = new CallbackList(actualCallbacks);
    callbackList.setParams({
        epochs,
        initialEpoch,
        samples: numTrainSamples,
        steps: stepsPerEpoch,
        batchSize,
        verbose,
        doValidation,
        metrics: callbackMetrics,
    });
    return { callbackList, history };
}

function deserialize(config, customObjects = {}, fastWeightInit = false) {
    return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'layer', fastWeightInit);
}

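// Loss functions. Each takes (yTrue, yPred) tensors, reduces over the last
// axis, and runs inside tidy() so intermediate tensors are disposed.
// epsilon() clamping guards the log and division ops against 0 and 1.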
function l2Normalize(x, axis) {
    return tidy(() => {
        if (x.dtype !== 'float32') {
            x = cast$3(x, 'float32');
        }
        const squareSum = sum$2(square(x), axis, true);
        const epsilonTensor = fill$2(squareSum.shape, epsilon());
        const norm = sqrt$2(maximum$2(squareSum, epsilonTensor));
        return div$1(x, norm);
    });
}
function meanSquaredError(yTrue, yPred) {
    return tidy(() => mean$1(square(sub$2(yPred, yTrue)), -1));
}
function meanAbsoluteError(yTrue, yPred) {
    return tidy(() => mean$1(abs$2(sub$2(yPred, yTrue)), -1));
}
function meanAbsolutePercentageError(yTrue, yPred) {
    return tidy(() => {
        const diff = sub$2(yTrue, yPred);
        const clippedTrue = clipByValue$2(abs$2(yTrue), epsilon(), Number.MAX_VALUE);
        const absResult = abs$2(div$1(diff, clippedTrue));
        return mul(100, mean$1(absResult, -1));
    });
}
function meanSquaredLogarithmicError(yTrue, yPred) {
    return tidy(() => {
        const clippedPred = clipByValue$2(yPred, epsilon(), Number.MAX_VALUE);
        const firstLog = log$2(add$1(1, clippedPred));
        const clippedTrue = clipByValue$2(yTrue, epsilon(), Number.MAX_VALUE);
        const secondLog = log$2(add$1(1, clippedTrue));
        return mean$1(square(sub$2(firstLog, secondLog)), -1);
    });
}
function squaredHinge(yTrue, yPred) {
    return tidy(() => {
        const maxResult = maximum$2(0, sub$2(1, mul(yTrue, yPred)));
        return mean$1(square(maxResult), -1);
    });
}
function hinge(yTrue, yPred) {
    return tidy(() => {
        const maxResult = maximum$2(0, sub$2(1, mul(yTrue, yPred)));
        return mean$1(maxResult, -1);
    });
}
function categoricalHinge(yTrue, yPred) {
    return tidy(() => {
        const pos = sum$2(mul(yTrue, yPred), -1);
        const neg = max$2(mul(sub$2(1, yTrue), yPred), -1);
        return maximum$2(0, add$1(1, sub$2(neg, pos)));
    });
}
function logcosh(yTrue, yPred) {
    return tidy(() => {
        const log2 = Math.log(2);
        const predictionDiff = sub$2(yPred, yTrue);
        const logcoshResult = sub$2(add$1(predictionDiff, softplus$2(mul(-2, predictionDiff))), log2);
        return mean$1(logcoshResult, -1);
    });
}
function categoricalCrossentropy$1(target, output, fromLogits = false) {
    return tidy(() => {
        if (fromLogits) {
            output = softmax$2(output);
        }
        else {
            const outputSum = sum$2(output, output.shape.length - 1, true);
            output = div$1(output, outputSum);
        }
        output = clipByValue$2(output, epsilon(), 1 - epsilon());
        return neg$2(sum$2(mul(cast$3(target, 'float32'), log$2(output)), output.shape.length - 1));
    });
}
function sparseCategoricalCrossentropy$1(target, output, fromLogits = false) {
    return tidy(() => {
        const flatTarget = cast$3(floor$2(flatten(target)), 'int32');
        output = clipByValue$2(output, epsilon(), 1 - epsilon());
        const outputShape = output.shape;
        const oneHotTarget = reshape$2(oneHot$2(flatTarget, outputShape[outputShape.length - 1]), outputShape);
        return categoricalCrossentropy$1(oneHotTarget, output, fromLogits);
    });
}
function sigmoidCrossEntropyWithLogits(labels, logits) {
    if (!arraysEqual(labels.shape, logits.shape)) {
        throw new ValueError(`logits and labels must have the same shape, but got shapes ` +
            `${JSON.stringify(labels.shape)} and ${JSON.stringify(logits.shape)}`);
    }
    return tidy(() => {
        const reluLogits = relu$2(logits);
        const negAbsLogits = neg$2(abs$2(logits));
        return add$1(sub$2(reluLogits, mul(logits, labels)), log1p$2(exp$2(negAbsLogits)));
    });
}
function binaryCrossentropy$1(yTrue, yPred) {
    return tidy(() => {
        let y;
        y = clipByValue$2(yPred, epsilon(), 1 - epsilon());
        y = log$2(div$1(y, sub$2(1, y)));
        return mean$1(sigmoidCrossEntropyWithLogits(yTrue, y), -1);
    });
}
function kullbackLeiblerDivergence(yTrue, yPred) {
    return tidy(() => {
        const clippedTrue = clipByValue$2(yTrue, epsilon(), 1);
        const clippedPred = clipByValue$2(yPred, epsilon(), 1);
        return sum$2(mul(yTrue, log$2(div$1(clippedTrue, clippedPred))), -1);
    });
}
function poisson(yTrue, yPred) {
    return tidy(() => {
        const logPred = log$2(add$1(epsilon(), yPred));
        return mean$1(sub$2(yPred, mul(yTrue, logPred)), -1);
    });
}
function cosineProximity(yTrue, yPred) {
    return tidy(() => {
        const trueNormalized = l2Normalize(yTrue, -1);
        const predNormalized = l2Normalize(yPred, -1);
        const trueXPred = mul(trueNormalized, predNormalized);
        return neg$2(sum$2(trueXPred, -1));
    });
}

const lossesMap = {
    meanSquaredError,
    meanAbsoluteError,
    meanAbsolutePercentageError,
    meanSquaredLogarithmicError,
    squaredHinge,
    hinge,
    categoricalHinge,
    logcosh,
    categoricalCrossentropy: categoricalCrossentropy$1,
    sparseCategoricalCrossentropy: sparseCategoricalCrossentropy$1,
    binaryCrossentropy: binaryCrossentropy$1,
    kullbackLeiblerDivergence,
    poisson,
    cosineProximity
};

function get$1(identifierOrFn) {
    if (typeof identifierOrFn === 'string') {
        if (identifierOrFn in lossesMap) {
            return lossesMap[identifierOrFn];
        }
        let errMsg = `Unknown loss ${identifierOrFn}`;
        if (identifierOrFn.toLowerCase().includes('softmaxcrossentropy')) {
            errMsg = `Unknown loss ${identifierOrFn}. ` +
                'Use "categoricalCrossentropy" as the string name for ' +
                'tf.losses.softmaxCrossEntropy';
        }
        throw new ValueError(errMsg);
    }
    else {
        return identifierOrFn;
    }
}

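// Metric functions, mirroring the losses above but reported per example
// (binaryAccuracy, for instance, thresholds predictions at 0.5 before
// comparing). The alias table below (mse/MSE, mae/MAE, ...) follows the
// Keras naming convention.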
function binaryAccuracy(yTrue, yPred) {
    return tidy(() => {
        const threshold = mul(.5, onesLike$2(yPred));
        const yPredThresholded = cast(greater$2(yPred, threshold), yTrue.dtype);
        return mean$1(equal$2(yTrue, yPredThresholded), -1);
    });
}
function categoricalAccuracy(yTrue, yPred) {
    return tidy(() => cast(equal$2(argMax$2(yTrue, -1), argMax$2(yPred, -1)), 'float32'));
}
function truePositives(yTrue, yPred) {
    return tidy(() => {
        return cast$3(sum$2(logicalAnd$2(equal$2(yTrue, 1), equal$2(yPred, 1))), 'float32');
    });
}
function falsePositives(yTrue, yPred) {
    return tidy(() => {
        return cast$3(sum$2(logicalAnd$2(equal$2(yTrue, 0), equal$2(yPred, 1))), 'float32');
    });
}
function precision(yTrue, yPred) {
    return tidy(() => {
        const tp = truePositives(yTrue, yPred);
        const fp = falsePositives(yTrue, yPred);
        const denominator = add$1(tp, fp);
        return cast$3(where(greater$2(denominator, 0), div$1(tp, denominator), 0), 'float32');
    });
}
function binaryCrossentropy(yTrue, yPred) {
    return binaryCrossentropy$1(yTrue, yPred);
}
function sparseCategoricalAccuracy(yTrue, yPred) {
    if (yTrue.rank === yPred.rank) {
        yTrue = squeeze(yTrue, [yTrue.rank - 1]);
    }
    yPred = argMax$2(yPred, -1);
    if (yPred.dtype !== yTrue.dtype) {
        yPred = cast$3(yPred, yTrue.dtype);
    }
    return cast$3(equal$2(yTrue, yPred), 'float32');
}

const mse = meanSquaredError;
const MSE = meanSquaredError;
const mae = meanAbsoluteError;
const MAE = meanAbsoluteError;
const mape = meanAbsolutePercentageError;
const MAPE = meanAbsolutePercentageError;
const categoricalCrossentropy = categoricalCrossentropy$1;
const cosine = cosineProximity;
const sparseCategoricalCrossentropy = sparseCategoricalCrossentropy$1;

const metricsMap = {
    binaryAccuracy,
    categoricalAccuracy,
    precision,
    categoricalCrossentropy,
    sparseCategoricalCrossentropy,
    mse,
    MSE,
    mae,
    MAE,
    mape,
    MAPE,
    cosine
};
function get(identifier) {
    if (typeof identifier === 'string' && identifier in metricsMap) {
        return metricsMap[identifier];
    }
    else if (typeof identifier !== 'string' && identifier != null) {
        return identifier;
    }
    else {
        throw new ValueError(`Unknown metric ${identifier}`);
    }
}

function getLossOrMetricName(fn) {
    assert(fn !== null, `Unknown LossOrMetricFn ${fn}`);
    if (typeof fn === 'string') {
        return fn;
    }
    else {
        let fnName;
        for (const key of Object.keys(lossesMap)) {
            if (lossesMap[key] === fn) {
                fnName = key;
                break;
            }
        }
        if (fnName !== undefined) {
            return fnName;
        }
        for (const key of Object.keys(metricsMap)) {
            if (metricsMap[key] === fn) {
                fnName = key;
                break;
            }
        }
        if (fnName !== undefined) {
            return fnName;
        }
        return fn.name;
    }
}

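// Maps a Keras optimizer name (in either capitalization) to a tfjs-core
// optimizer constructed with the Keras default hyperparameters.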
function getOptimizer(identifier) {
    const optimizerMap = {
        'Adagrad': () => train.adagrad(0.01),
        'Adadelta': () => train.adadelta(1, 0.95, epsilon()),
        'Adam': () => train.adam(0.001, 0.9, 0.999, epsilon()),
        'Adamax': () => train.adamax(0.002, 0.9, 0.999, epsilon(), 0),
        'RMSProp': () => train.rmsprop(0.001, 0.9, 0, epsilon()),
        'SGD': () => train.sgd(0.01)
    };
    optimizerMap['adagrad'] = optimizerMap['Adagrad'];
    optimizerMap['adadelta'] = optimizerMap['Adadelta'];
    optimizerMap['adam'] = optimizerMap['Adam'];
    optimizerMap['adamax'] = optimizerMap['Adamax'];
    optimizerMap['rmsprop'] = optimizerMap['RMSProp'];
    optimizerMap['sgd'] = optimizerMap['SGD'];
    if (identifier in optimizerMap) {
        return optimizerMap[identifier]();
    }
    throw new ValueError(`Unknown Optimizer ${identifier}`);
}

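// Support for model.userDefinedMetadata: the metadata must be a plain JSON
// object, and a warning is printed when its serialized form exceeds
// MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH (1 MiB of characters).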
const MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH = 1 * 1024 * 1024;

function checkUserDefinedMetadata(userDefinedMetadata, modelName, checkSize = false) {
    if (userDefinedMetadata == null ||
        typeof userDefinedMetadata !== 'object' ||
        Object.getPrototypeOf(userDefinedMetadata) !== Object.prototype ||
        !plainObjectCheck(userDefinedMetadata)) {
        throw new Error('User-defined metadata is expected to be a JSON object, but is not.');
    }
    if (checkSize) {
        const out = JSON.stringify(userDefinedMetadata);
        if (out.length > MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH) {
            console.warn(`User-defined metadata of model "${modelName}" is too large in ` +
                `size (length=${out.length} when serialized). It is not ` +
                `recommended to store such large objects in user-defined metadata. ` +
                `Please make sure its serialized length is <= ` +
                `${MAX_USER_DEFINED_METADATA_SERIALIZED_LENGTH}.`);
        }
    }
}

function plainObjectCheck(x) {
    if (x === null) {
        return true;
    }
    else if (typeof x === 'object') {
        if (Object.getPrototypeOf(x) === Object.prototype) {
            const keys = Object.keys(x);
            for (const key of keys) {
                if (typeof key !== 'string') {
                    return false;
                }
                if (!plainObjectCheck(x[key])) {
                    return false;
                }
            }
            return true;
        }
        else {
            if (Array.isArray(x)) {
                for (const item of x) {
                    if (!plainObjectCheck(item)) {
                        return false;
                    }
                }
                return true;
            }
            else {
                return false;
            }
        }
    }
    else {
        const xType = typeof x;
        return xType === 'string' || xType === 'number' || xType === 'boolean';
    }
}

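// model.summary() implementation: renders a fixed-width table of layers,
// shapes and parameter counts via printRow(); the non-sequential variant
// adds a "Receives inputs" column listing inbound connections per node.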
function printSummary(model, lineLength, positions, printFn = console.log) {
    const sequentialLike = isModelSequentialLike(model);
    const toDisplay = ['Layer (type)', 'Input Shape', 'Output shape', 'Param #'];
    if (sequentialLike) {
        lineLength = lineLength || 90;
        positions = positions || [0.32, 0.61, 0.89, 1];
    }
    else {
        lineLength = lineLength || 115;
        positions = positions || [0.24, 0.48, 0.70, 0.80, 1];
    }
    if (positions[positions.length - 1] <= 1) {
        positions = positions.map(p => Math.floor(lineLength * p));
    }
    let relevantNodes;
    if (!sequentialLike) {
        toDisplay.push('Receives inputs');
        relevantNodes = [];
        for (const depth in model.nodesByDepth) {
            relevantNodes.push(...model.nodesByDepth[depth]);
        }
    }
    printFn('_'.repeat(lineLength));
    printRow(toDisplay, positions, printFn);
    printFn('='.repeat(lineLength));
    const layers = model.layers;
    for (let i = 0; i < layers.length; ++i) {
        if (sequentialLike) {
            printLayerSummary(layers[i], positions, printFn);
        }
        else {
            printLayerSummaryWithConnections(layers[i], positions, relevantNodes, printFn);
        }
        printFn((i === layers.length - 1 ? '=' : '_').repeat(lineLength));
    }
    model.checkTrainableWeightsConsistency();
    const trainableCount = countTrainableParams(model);
    const nonTrainableCount = countParamsInWeights(model.nonTrainableWeights);
    printFn(`Total params: ${trainableCount + nonTrainableCount}`);
    printFn(`Trainable params: ${trainableCount}`);
    printFn(`Non-trainable params: ${nonTrainableCount}`);
    printFn('_'.repeat(lineLength));
}
function countTrainableParams(model) {
    let trainableCount;
    if (model.collectedTrainableWeights != null) {
        trainableCount = countParamsInWeights(model.collectedTrainableWeights);
    }
    else {
        trainableCount = countParamsInWeights(model.trainableWeights);
    }
    return trainableCount;
}
function isModelSequentialLike(model) {
    let sequentialLike = true;
    const nodesByDepth = [];
    const nodes = [];
    for (const depth in model.nodesByDepth) {
        nodesByDepth.push(model.nodesByDepth[depth]);
    }
    for (const depthNodes of nodesByDepth) {
        if (depthNodes.length > 1 ||
            depthNodes.length === 1 && depthNodes[0].inboundLayers.length > 1) {
            sequentialLike = false;
            break;
        }
        nodes.push(...depthNodes);
    }
    if (sequentialLike) {
        for (const layer of model.layers) {
            let flag = false;
            for (const node of layer.inboundNodes) {
                if (nodes.indexOf(node) !== -1) {
                    if (flag) {
                        sequentialLike = false;
                        break;
                    }
                    else {
                        flag = true;
                    }
                }
            }
            if (!sequentialLike) {
                break;
            }
        }
    }
    return sequentialLike;
}
function printRow(fields, positions, printFn = console.log) {
    let line = '';
    for (let i = 0; i < fields.length; ++i) {
        if (i > 0) {
            line = line.slice(0, line.length - 1) + ' ';
        }
        line += fields[i];
        line = line.slice(0, positions[i]);
        line += ' '.repeat(positions[i] - line.length);
    }
    printFn(line);
}

function printLayerSummary(layer, positions, printFn) {
    let outputShape;
    let inputShape;
    try {
        inputShape = (layer.inboundNodes.map(x => JSON.stringify(x.inputShapes))).join(',');
    }
    catch (err) {
        inputShape = 'multiple';
    }
    try {
        outputShape = JSON.stringify(layer.outputShape);
    }
    catch (err) {
        outputShape = 'multiple';
    }
    const name = layer.name;
    const className = layer.getClassName();
    const fields = [`${name} (${className})`, inputShape,
        outputShape, layer.countParams().toString()];
    printRow(fields, positions, printFn);
}

function printLayerSummaryWithConnections(layer, positions, relevantNodes, printFn) {
    let outputShape;
    let inputShape;
    try {
        inputShape = (layer.inboundNodes.map(x => JSON.stringify(x.inputShapes))).join(',');
    }
    catch (err) {
        inputShape = 'multiple';
    }
    try {
        outputShape = JSON.stringify(layer.outputShape);
    }
    catch (err) {
        outputShape = 'multiple';
    }
    const connections = [];
    for (const node of layer.inboundNodes) {
        if (relevantNodes != null && relevantNodes.length > 0 &&
            relevantNodes.indexOf(node) === -1) {
            continue;
        }
        for (let i = 0; i < node.inboundLayers.length; ++i) {
            const inboundLayer = node.inboundLayers[i].name;
            const inboundLayerIndex = node.nodeIndices[i];
            const inboundTensorIndex = node.tensorIndices[i];
            connections.push(`${inboundLayer}[${inboundLayerIndex}][${inboundTensorIndex}]`);
        }
    }
    const name = layer.name;
    const className = layer.getClassName();
    const firstConnection = connections.length === 0 ? '' : connections[0];
    const fields = [
        `${name} (${className})`, inputShape,
        outputShape, layer.countParams().toString(),
        firstConnection
    ];
    printRow(fields, positions, printFn);
    for (let i = 1; i < connections.length; ++i) {
        printRow(['', '', '', '', connections[i]], positions, printFn);
    }
}

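// Helpers for converting between the Python Keras JSON convention
// (snake_case keys) and the TypeScript convention (camelCase keys).
// `name`/`className` values and the leading string entries of
// inboundNodes/inputLayers/outputLayers arrays are passed through verbatim.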
function isArrayItemInputOrOutputName(key, index, value) {
    return (key === 'inboundNodes' || key === 'outputLayers' ||
        key === 'inputLayers') &&
        index === 0 && typeof value === 'string';
}

function convertPythonicToTs(pythonicConfig, key) {
    if (pythonicConfig === null) {
        return null;
    }
    else if (typeof pythonicConfig === 'string') {
        return toCamelCase(pythonicConfig);
    }
    else if ((typeof pythonicConfig === 'number') ||
        (typeof pythonicConfig === 'boolean')) {
        return pythonicConfig;
    }
    else if (pythonicConfig instanceof Array) {
        const tsArray = [];
        const arrayLength = pythonicConfig.length;
        for (let i = 0; i < arrayLength; ++i) {
            const item = pythonicConfig[i];
            if (isArrayItemInputOrOutputName(key, i, item)) {
                tsArray.push(item);
            }
            else {
                tsArray.push(convertPythonicToTs(item, key));
            }
        }
        return tsArray;
    }
    else {
        const tsDict = {};
        for (const pythonicKey of Object.keys(pythonicConfig)) {
            const pythonicValue = pythonicConfig[pythonicKey];
            if (pythonicKey === 'name' && typeof pythonicValue === 'string') {
                tsDict[pythonicKey] = pythonicValue;
            }
            else {
                const tsKey = toCamelCase(pythonicKey);
                tsDict[tsKey] = convertPythonicToTs(pythonicValue, tsKey);
            }
        }
        return tsDict;
    }
}

function convertTsToPythonic(tsConfig, key) {
    if (tsConfig === null || tsConfig === undefined) {
        return null;
    }
    else if (typeof tsConfig === 'string') {
        return toSnakeCase(tsConfig);
    }
    else if ((typeof tsConfig === 'number') || (typeof tsConfig === 'boolean')) {
        return tsConfig;
    }
    else if (tsConfig instanceof Array) {
        const pyArray = [];
        const arrayLength = tsConfig.length;
        for (let i = 0; i < arrayLength; ++i) {
            const item = tsConfig[i];
            if (isArrayItemInputOrOutputName(key, i, item)) {
                pyArray.push(item);
            }
            else {
                pyArray.push(convertTsToPythonic(item, key));
            }
        }
        return pyArray;
    }
    else {
        const pyDict = {};
        for (const tsKey of Object.keys(tsConfig)) {
            const tsValue = tsConfig[tsKey];
            const pyKey = toSnakeCase(tsKey);
            if ((tsKey === 'name' || tsKey === 'className') &&
                typeof tsValue === 'string') {
                pyDict[pyKey] = tsValue;
            }
            else {
                pyDict[pyKey] = convertTsToPythonic(tsValue, tsKey);
            }
        }
        return pyDict;
    }
}

const version = '4.22.0';

const isKerasSavedModelFormat = (weights) => {
    const keys = Object.keys(weights);
    if (keys.length === 0) {
        return false;
    }
    const key = keys[0].split('/');
    return !isNaN(parseInt(key[key.length - 1], 10));
};

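// Container: the base class behind LayersModel. The constructor walks the
// symbolic tensor graph backwards from `outputs`, assigns every node and
// layer a depth, checks for cycles, disconnected inputs and duplicate layer
// names, and caches the per-depth execution order used by call().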
class Container extends Layer {
    constructor(args) {
        super({});
        this.containerNodes = new Set();
        this.name = args.name;
        if (this.name == null) {
            const prefix = this.getClassName().toLowerCase();
            this.name = getUid(prefix);
        }
        this.supportsMasking = false;
        this.trainable_ = true;
        if (Array.isArray(args.inputs)) {
            this.inputs = args.inputs.slice();
        }
        else {
            this.inputs = [args.inputs];
        }
        if (Array.isArray(args.outputs)) {
            this.outputs = args.outputs.slice();
        }
        else {
            this.outputs = [args.outputs];
        }
        if (unique(this.inputs).length !== this.inputs.length) {
            throw new ValueError('The list of inputs passed to the model is ' +
                'redundant. All inputs should only appear once. Found: ' +
                `${this.inputs.map(x => x.name)}`);
        }
        if (unique(this.outputs).length !== this.outputs.length) {
            console.warn('The list of outputs passed to the model is redundant. ' +
                'All outputs should only appear once. Found: ' +
                `${this.outputs.map(x => x.name)}`);
        }
        this.inputLayers = [];
        this.inputLayersNodeIndices = [];
        this.inputLayersTensorIndices = [];
        this.outputLayers = [];
        this.outputLayersNodeIndices = [];
        this.outputLayersTensorIndices = [];
        this.layers = [];
        this.internalContainerRefs = [];
        for (const x of this.outputs) {
            const layer = x.sourceLayer;
            const nodeIndex = x.nodeIndex;
            const tensorIndex = x.tensorIndex;
            this.outputLayers.push(layer);
            this.outputLayersNodeIndices.push(nodeIndex);
            this.outputLayersTensorIndices.push(tensorIndex);
        }
        for (const x of this.inputs) {
            const layer = x.sourceLayer;
            const nodeIndex = x.nodeIndex;
            const tensorIndex = x.tensorIndex;
            assert(nodeIndex === 0, 'input layer has >1 nodes');
            assert(tensorIndex === 0, 'input layer has >1 tensors');
            this.inputLayers.push(layer);
            this.inputLayersNodeIndices.push(nodeIndex);
            this.inputLayersTensorIndices.push(tensorIndex);
        }
        this.inputNames = [];
        this.outputNames = [];
        this.feedInputShapes = [];
        this.feedInputNames = [];
        this.feedOutputNames = [];
        for (let i = 0; i < this.inputLayers.length; i++) {
            const layer = this.inputLayers[i];
            if (!(layer instanceof InputLayer)) {
                throw new TypeError('Input layers to a LayersModel must be InputLayer objects. ' +
                    `Received inputs: ${args.inputs}. ` +
                    `Input ${i} (0-based) originates ` +
                    `from layer type ${layer.getClassName()}.`);
            }
            this.inputNames.push(layer.name);
            this.feedInputShapes.push(layer.batchInputShape);
            this.feedInputNames.push(layer.name);
        }
        for (const layer of this.outputLayers) {
            this.outputNames.push(layer.name);
        }
        this.internalInputShapes = this.inputs.map(x => x.shape);
        this.internalOutputShapes = this.outputs.map(x => x.shape);
        const nodesDepths = {};
        const nodeIDToNode = {};
        const layersDepths = {};
        const layerIDToLayer = {};
        const layerIndices = {};
        const nodesInDecreasingDepth = [];
        const buildMapOfGraph = (tensor, finishedNodes, nodesInProgress, layer, nodeIndex, tensorIndex) => {
            if (layer == null || nodeIndex == null || tensorIndex == null) {
                layer = tensor.sourceLayer;
                nodeIndex = tensor.nodeIndex;
                tensorIndex = tensor.tensorIndex;
            }
            const node = layer.inboundNodes[nodeIndex];
            if (nodesInProgress.indexOf(node) !== -1) {
                throw new RuntimeError(`The tensor ${tensor.name} at layer "${layer.name}" ` +
                    'is part of a cycle.');
            }
            if (finishedNodes.indexOf(node) !== -1) {
                return;
            }
            this.containerNodes.add(Container.nodeKey(layer, nodeIndex));
            if (!(layer.id in layerIndices)) {
                layerIndices[layer.id] = Object.keys(layerIndices).length;
            }
            if (nodesInProgress.indexOf(node) === -1) {
                nodesInProgress.push(node);
            }
            const numInboundLayers = node.inboundLayers.length;
            for (let i = 0; i < numInboundLayers; i++) {
                const x = node.inputTensors[i];
                const layer = node.inboundLayers[i];
                const nodeIndex = node.nodeIndices[i];
                const tensorIndex = node.tensorIndices[i];
                buildMapOfGraph(x, finishedNodes, nodesInProgress, layer, nodeIndex, tensorIndex);
            }
            finishedNodes.push(node);
            while (nodesInProgress.indexOf(node) >= 0) {
                nodesInProgress.splice(nodesInProgress.indexOf(node), 1);
            }
            nodesInDecreasingDepth.push(node);
        };
        const finishedNodes = [];
        const nodesInProgress = [];
        for (const x of this.outputs) {
            buildMapOfGraph(x, finishedNodes, nodesInProgress);
        }
        const reversedNodesInDecreasingDepth = nodesInDecreasingDepth.slice().reverse();
        for (const node of reversedNodesInDecreasingDepth) {
            nodeIDToNode[node.id] = node;
            if (!(node.id in nodesDepths)) {
                nodesDepths[node.id] = 0;
            }
            let depth = nodesDepths[node.id];
            const previousDepth = (layersDepths[node.outboundLayer.id] == null ?
                0 : layersDepths[node.outboundLayer.id]);
            depth = Math.max(depth, previousDepth);
            layersDepths[node.outboundLayer.id] = depth;
            layerIDToLayer[node.outboundLayer.id] = node.outboundLayer;
            nodesDepths[node.id] = depth;
            for (let i = 0; i < node.inboundLayers.length; i++) {
                const inboundLayer = node.inboundLayers[i];
                const nodeIndex = node.nodeIndices[i];
                const inboundNode = inboundLayer.inboundNodes[nodeIndex];
                const previousDepth = (nodesDepths[inboundNode.id] == null ?
                    0 : nodesDepths[inboundNode.id]);
                nodesDepths[inboundNode.id] = Math.max(depth + 1, previousDepth);
                nodeIDToNode[inboundNode.id] = inboundNode;
            }
        }
        const nodesByDepth = {};
        for (const nodeID in nodesDepths) {
            const depth = nodesDepths[nodeID];
            if (!(depth in nodesByDepth)) {
                nodesByDepth[depth] = [];
            }
            nodesByDepth[depth].push(nodeIDToNode[nodeID]);
        }
        const layersByDepth = {};
        for (const layerID in layersDepths) {
            const depth = layersDepths[layerID];
            if (!(depth in layersByDepth)) {
                layersByDepth[depth] = [];
            }
            layersByDepth[depth].push(layerIDToLayer[layerID]);
        }
        let depthKeys = Object.keys(layersByDepth)
            .map(x => parseInt(x, 10))
            .sort(reverseNumberCompare);
        this.layers = [];
        for (const depth of depthKeys) {
            const layersForDepth = layersByDepth[depth];
            layersForDepth.sort((a, b) => {
                const aIndex = layerIndices[a.id];
                const bIndex = layerIndices[b.id];
                if (aIndex < bIndex) {
                    return -1;
                }
                if (aIndex > bIndex) {
                    return 1;
                }
                return 0;
            });
            for (const layer of layersForDepth) {
                if (layer instanceof Container) {
                    this.internalContainerRefs.push(layer);
                }
                this.layers.push(layer);
            }
        }
        this.layersByDepth = layersByDepth;
        depthKeys = Object.keys(nodesByDepth)
            .map(x => parseInt(x, 10))
            .sort(reverseNumberCompare);
        const computableTensors = this.inputs.slice();
        const layersWithCompleteInput = [];
        for (const depth of depthKeys) {
            for (const node of nodesByDepth[depth]) {
                const layer = node.outboundLayer;
                if (layer != null) {
                    for (const x of node.inputTensors) {
                        if (computableTensors.indexOf(x) === -1) {
                            throw new RuntimeError(`Graph disconnected: cannot obtain value for tensor ${x}` +
                                ` at layer "${layer.name}". ` +
                                'The following previous layers were accessed without ' +
                                `issue: ${layersWithCompleteInput}`);
                        }
                    }
                    for (const x of node.outputTensors) {
                        computableTensors.push(x);
                    }
                    layersWithCompleteInput.push(layer.name);
                }
            }
        }
        this.nodesByDepth = nodesByDepth;
        const allNames = this.layers.map(x => x.name);
        for (const name of allNames) {
            const numOccurrences = allNames.filter(x => x === name).length;
            if (numOccurrences !== 1) {
                throw new RuntimeError(`The name "${name}" is used ${numOccurrences} times ` +
                    'in the model. All layer names should be unique. Layer names: ' +
                    JSON.stringify(allNames));
            }
        }
        this.outboundNodes = [];
        this.inboundNodes = [];
        new Node({
            outboundLayer: this,
            inboundLayers: [],
            nodeIndices: [],
            tensorIndices: [],
            inputTensors: this.inputs,
            outputTensors: this.outputs,
            inputMasks: this.inputs.map(x => null),
            outputMasks: this.outputs.map(x => null),
            inputShapes: this.inputs.map(x => x.shape),
            outputShapes: this.outputs.map(x => x.shape)
        });
        this.built = true;
        this._refCount = 1;
    }
    assertNotDisposed() {
        if (this._refCount === 0) {
            throw new Error(`Container '${this.name}' is already disposed.`);
        }
    }
    dispose() {
        this.assertNotDisposed();
        const result = { refCountAfterDispose: null, numDisposedVariables: 0 };
        if (--this._refCount === 0) {
            for (const layer of this.layers) {
                result.numDisposedVariables += layer.dispose().numDisposedVariables;
            }
            for (const container of this.internalContainerRefs) {
                result.numDisposedVariables += container.dispose().numDisposedVariables;
            }
        }
        result.refCountAfterDispose = this._refCount;
        return result;
    }
    get trainable() {
        return this.trainable_;
    }
    set trainable(trainable) {
        this.layers.forEach(layer => {
            layer._trainableWeights
                .forEach(w => w.trainable = trainable);
        });
        this.trainable_ = trainable;
    }
    get trainableWeights() {
        if (this._trainableWeights.length > 0) {
            throw new ValueError('Container instance unexpectedly contains _trainableWeights. ' +
                'The trainable weights of a Container are a union of the ' +
                'trainable weights of its constituent Layers. Its own ' +
                '_trainableWeights must remain an empty Array.');
        }
        if (!this.trainable) {
            return [];
        }
        let weights = [];
        for (const layer of this.layers) {
            weights = weights.concat(layer.trainableWeights);
        }
        return weights;
    }
    get nonTrainableWeights() {
        const weights = [];
        for (const layer of this.layers) {
            weights.push(...layer.nonTrainableWeights);
        }
        if (!this.trainable) {
            const trainableWeights = [];
            for (const layer of this.layers) {
                trainableWeights.push(...layer.trainableWeights);
            }
            return trainableWeights.concat(weights);
        }
        return weights;
    }
    get weights() {
        return this.trainableWeights.concat(this.nonTrainableWeights);
    }
    loadWeights(weights, strict = true) {
        const nameToWeight = {};
        let totalWeightsCount = 0;
        const modelIsKerasSavedModelFormat = isKerasSavedModelFormat(weights);
        if (modelIsKerasSavedModelFormat) {
            this.parseWeights(weights);
        }
        for (const layer of this.layers) {
            for (const [index, weight] of layer.weights.entries()) {
                const parsedName = modelIsKerasSavedModelFormat ?
                    `${weight.name.split('/').slice(0, -1).join('/') + '/'}${index}` :
                    weight.originalName;
                if (nameToWeight[parsedName] != null) {
                    throw new ValueError(`Duplicate weight name: ${parsedName}`);
                }
                nameToWeight[parsedName] = weight;
                totalWeightsCount++;
            }
        }
        const weightValueTuples = [];
        for (const name in weights) {
            let validatedName = name;
            if (nameToWeight[name] == null) {
                const tokens = name.split('/');
                const shortenNameArray = tokens.slice(0, -2).concat([tokens[tokens.length - 1]]);
                validatedName = shortenNameArray.join('/');
            }
            if (nameToWeight[validatedName] != null) {
                weightValueTuples.push([nameToWeight[validatedName], weights[name]]);
            }
            else if (strict) {
                throw new ValueError(`Provided weight data has no target variable: ${name}`);
            }
            delete nameToWeight[validatedName];
        }
        if (strict) {
            const unsetNames = [];
            for (const name in nameToWeight) {
                unsetNames.push(name);
            }
            if (unsetNames.length > 0) {
                throw new ValueError(`${unsetNames.length} of ${totalWeightsCount} weights are not set: ` +
                    `${unsetNames}`);
            }
        }
        batchSetValue(weightValueTuples);
    }
    parseWeights(weights) {
        for (const key of Object.keys(weights)) {
            const listParts = key.split('/');
            const list = ['vars', 'layer_checkpoint_dependencies'];
            const newKey = listParts
                .map(str => {
                    if (str.startsWith('_')) {
                        return str.slice(1);
                    }
                    return str;
                })
                .filter(str => !list.includes(str))
                .join('/');
            if (newKey !== key) {
                weights[newKey] = weights[key];
                delete weights[key];
            }
        }
    }
    updatedConfig() {
        const theConfig = this.getConfig();
        const modelConfig = {};
        modelConfig['className'] = this.getClassName();
        modelConfig['config'] = theConfig;
        modelConfig['kerasVersion'] = `tfjs-layers ${version}`;
        modelConfig['backend'] = 'TensorFlow.js';
        return modelConfig;
    }
    toJSON(unused, returnString = true) {
        const modelConfig = convertTsToPythonic(this.updatedConfig());
        return returnString ? JSON.stringify(modelConfig) : modelConfig;
    }
    call(inputs, kwargs) {
        return tidy(() => {
            inputs = toList(inputs);
            const feedDict = new FeedDict();
            for (let i = 0; i < this.inputs.length; ++i) {
                feedDict.add(this.inputs[i], inputs[i]);
            }
            return execute(this.outputs, feedDict, kwargs);
        });
    }
    computeMask(inputs, mask) {
        return tidy(() => {
            inputs = toList(inputs);
            let masks;
            if (mask == null) {
                masks = pyListRepeat(null, inputs.length);
            }
            else {
                masks = toList(mask);
            }
            return this.runInternalGraph(inputs, masks)[1];
        });
    }
    computeOutputShape(inputShape) {
        const inputShapes = normalizeShapeList(inputShape);
        if (inputShapes.length !== this.inputLayers.length) {
            throw new ValueError(`Invalid inputShape argument ${inputShape}: ` +
                `model has ${this.inputLayers.length} tensor inputs.`);
        }
        const layersToOutputShapes = {};
        for (let i = 0; i < inputShapes.length; i++) {
            const layer = this.inputLayers[i];
            const inputShape = inputShapes[i];
            const shapeKey = layer.name + '_0_0';
            layersToOutputShapes[shapeKey] = inputShape;
        }
        const depthKeys = Object.keys(this.nodesByDepth)
            .map(x => parseInt(x, 10))
            .sort(reverseNumberCompare);
        if (depthKeys.length > 1) {
            for (const depth of depthKeys) {
                const nodes = this.nodesByDepth[depth];
                for (const node of nodes) {
                    const layer = node.outboundLayer;
                    if (this.inputLayers.map(x => x.id).indexOf(layer.id) !== -1) {
                        continue;
                    }
                    const inputShapes = [];
                    for (let j = 0; j < node.inboundLayers.length; j++) {
                        const inboundLayer = node.inboundLayers[j];
                        const nodeIndex = node.nodeIndices[j];
                        const tensorIndex = node.tensorIndices[j];
                        const shapeKey = `${inboundLayer.name}_${nodeIndex}_${tensorIndex}`;
                        const inputShape = layersToOutputShapes[shapeKey];
                        inputShapes.push(inputShape);
                    }
                    const outputShape = layer.computeOutputShape(singletonOrArray(inputShapes));
                    const outputShapes = normalizeShapeList(outputShape);
                    const nodeIndex = layer.inboundNodes.indexOf(node);
                    for (let j = 0; j < outputShapes.length; j++) {
                        const shapeKey = `${layer.name}_${nodeIndex}_${j}`;
                        layersToOutputShapes[shapeKey] = outputShapes[j];
                    }
                }
            }
        }
        const outputShapes = [];
        const outputShapeKeys = [];
        for (let i = 0; i < this.outputLayers.length; i++) {
            const layer = this.outputLayers[i];
            const nodeIndex = this.outputLayersNodeIndices[i];
            const tensorIndex = this.outputLayersTensorIndices[i];
            const shapeKey = `${layer.name}_${nodeIndex}_${tensorIndex}`;
            outputShapeKeys.push(shapeKey);
        }
        for (let i = 0; i < outputShapeKeys.length; i++) {
            const key = outputShapeKeys[i];
            assert(key in layersToOutputShapes);
            outputShapes.push(layersToOutputShapes[key]);
        }
        return singletonOrArray(outputShapes);
    }
    runInternalGraph(inputs, masks) {
        if (masks == null) {
            masks = pyListRepeat(null, inputs.length);
        }
        const tensorMap = {};
        for (let i = 0; i < this.inputs.length; ++i) {
            const x = this.inputs[i];
            const y = inputs[i];
            const mask = masks[i];
            tensorMap[x.id] = [y, mask];
        }
        const depthKeys = Object.keys(this.nodesByDepth)
            .map(x => parseInt(x, 10))
            .sort(reverseNumberCompare);
        for (const depth of depthKeys) {
            const nodes = this.nodesByDepth[depth];
            for (const node of nodes) {
                const layer = node.outboundLayer;
                const referenceInputTensors = node.inputTensors;
                const referenceOutputTensors = node.outputTensors;
                const computedData = new Array();
                for (const x of referenceInputTensors) {
                    if (x.id in tensorMap) {
                        computedData.push(tensorMap[x.id]);
                    }
                }
                if (computedData.length === referenceInputTensors.length) {
                    let kwargs = {};
                    let computedTensors;
                    let computedMasks;
                    let outputTensors;
                    let outputMasks;
                    if (node.callArgs != null) {
                        kwargs = node.callArgs;
                    }
                    if (computedData.length === 1) {
                        const [computedTensor, computedMask] = computedData[0];
                        if (kwargs['mask'] == null) {
                            kwargs['mask'] = computedMask;
                        }
                        outputTensors = toList(layer.call(computedTensor, kwargs));
                        outputMasks = toList(layer.computeMask(computedTensor, computedMask));
                        computedTensors = [computedTensor];
                        computedMasks = [computedMask];
                    }
                    else {
                        computedTensors = computedData.map(x => x[0]);
                        computedMasks = computedData.map(x => x[1]);
                        if (kwargs['mask'] == null) {
                            kwargs['mask'] = computedMasks;
                        }
                        outputTensors = toList(layer.call(computedTensors, kwargs));
                        outputMasks = toList(layer.computeMask(computedTensors, computedMasks));
                    }
                    if (layer.activityRegularizer) {
                        throw new NotImplementedError('LayersModel invocation with concrete Tensor value(s) in the ' +
                            'presence of activity regularizer(s) is not supported yet.');
                    }
                    for (let i = 0; i < referenceOutputTensors.length; ++i) {
                        const x = referenceOutputTensors[i];
                        const y = outputTensors[i];
                        const mask = outputMasks[i];
                        tensorMap[x.id] = [y, mask];
                    }
                }
            }
        }
        const outputTensors = [];
        const outputMasks = [];
        const outputShapes = [];
        for (const x of this.outputs) {
            assert(x.id in tensorMap, `Could not compute output ${x.name} : ${x.id}`);
            const [tensor, mask] = tensorMap[x.id];
            outputShapes.push(tensor.shape);
            outputTensors.push(tensor);
            outputMasks.push(mask);
        }
        return [outputTensors, outputMasks, outputShapes];
    }
    buildNodeConversionMap(layers) {
        const nodeConversionMap = {};
        let keptNodes;
        for (const layer of this.layers) {
            keptNodes = layer instanceof Container ? 1 : 0;
            for (let originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) {
                const nodeKey = Container.nodeKey(layer, originalNodeIndex);
                if (this.containerNodes.has(nodeKey)) {
                    nodeConversionMap[nodeKey] = keptNodes;
                    keptNodes += 1;
                }
            }
        }
        return nodeConversionMap;
    }
    getLayer(nameOrIndex, index) {
        if (index != null) {
            return this.findLayer(index);
        }
        else {
            if (nameOrIndex == null) {
                throw new ValueError('Provide either a layer name or layer index');
            }
            if (typeof nameOrIndex === 'number') {
                return this.findLayer(nameOrIndex);
            }
        }
        for (const layer of this.layers) {
            if (layer.name === nameOrIndex) {
                return layer;
            }
        }
        throw new ValueError(`No such layer: ${nameOrIndex}`);
    }
    findLayer(index) {
        if (this.layers.length <= index) {
            throw new ValueError(`Was asked to retrieve layer at index ${index}, but model only ` +
                `has ${this.layers.length} layer(s).`);
        }
        else {
            return this.layers[index];
        }
    }
    calculateLosses() {
        return tidy(() => {
            const losses = [];
            for (const layer of this.layers) {
                for (let nodeIndex = 0; nodeIndex < layer.inboundNodes.length; ++nodeIndex) {
                    const nodeKey = Container.nodeKey(layer, nodeIndex);
                    if (this.containerNodes.has(nodeKey)) {
                        losses.push(...layer.calculateLosses());
                    }
                }
            }
            return losses;
        });
    }
    getConfig() {
        const config = { name: this.name };
        const nodeConversionMap = this.buildNodeConversionMap(this.layers);
        const layerConfigs = [];
        for (const layer of this.layers) {
            const layerClassName = layer.getClassName();
            const layerConfig = layer.getConfig();
            const filteredInboundNodes = [];
            for (let originalNodeIndex = 0; originalNodeIndex < layer.inboundNodes.length; originalNodeIndex++) {
                const node = layer.inboundNodes[originalNodeIndex];
                const nodeKey = Container.nodeKey(layer, originalNodeIndex);
                let kwargs = {};
                if (this.containerNodes.has(nodeKey)) {
                    if (node.callArgs) {
                        try {
                            JSON.stringify(node.callArgs);
                            kwargs = node.callArgs;
                        }
                        catch (err) {
                            console.warn(`Layer ${layer.name} was passed ` +
                                `non-serializable keyword arguments: ` +
                                `${node.callArgs}. They will not be included ` +
                                `in the serialized model (and thus will be ` +
                                `missing at deserialization time).`);
                            kwargs = {};
                        }
                    }
                    if (node.inboundLayers.length > 0) {
                        const nodeData = [];
                        for (let i = 0; i < node.inboundLayers.length; i++) {
                            const inboundLayer = node.inboundLayers[i];
                            const nodeIndex = node.nodeIndices[i];
                            const tensorIndex = node.tensorIndices[i];
                            const nodeKey = Container.nodeKey(inboundLayer, nodeIndex);
                            let newNodeIndex = nodeConversionMap[nodeKey];
                            if (newNodeIndex == null) {
                                newNodeIndex = 0;
                            }
                            nodeData.push([inboundLayer.name, newNodeIndex, tensorIndex, kwargs]);
                        }
                        filteredInboundNodes.push(nodeData);
                    }
                }
            }
            const dict = {};
            dict['name'] = layer.name;
            dict['className'] = layerClassName;
            dict['config'] = layerConfig;
            dict['inboundNodes'] = filteredInboundNodes;
            layerConfigs.push(dict);
        }
        config['layers'] = layerConfigs;
        const modelInputs = [];
        for (let i = 0; i < this.inputLayers.length; i++) {
            const layer = this.inputLayers[i];
            const nodeIndex = this.inputLayersNodeIndices[i];
            const nodeKey = Container.nodeKey(layer, nodeIndex);
            if (!this.containerNodes.has(nodeKey)) {
                continue;
            }
            let newNodeIndex = nodeConversionMap[nodeKey];
            if (newNodeIndex === null || newNodeIndex === undefined) {
                newNodeIndex = 0;
            }
            const tensorIndex = this.inputLayersTensorIndices[i];
            modelInputs.push([layer.name, newNodeIndex, tensorIndex]);
        }
        config['inputLayers'] = modelInputs;
        const modelOutputs = [];
        for (let i = 0; i < this.outputLayers.length; i++) {
            const layer = this.outputLayers[i];
            const nodeIndex = this.outputLayersNodeIndices[i];
            const nodeKey = Container.nodeKey(layer, nodeIndex);
            if (!this.containerNodes.has(nodeKey)) {
                continue;
            }
            let newNodeIndex = nodeConversionMap[nodeKey];
            if (newNodeIndex === null || newNodeIndex === undefined) {
                newNodeIndex = 0;
            }
            const tensorIndex = this.outputLayersTensorIndices[i];
            modelOutputs.push([layer.name, newNodeIndex, tensorIndex]);
        }
        config['outputLayers'] = modelOutputs;
        return config;
    }
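    // fromConfig rebuilds a Container from a deserialized config: layers are
    // instantiated first, then their connections are replayed; nodes whose
    // inbound layers are not ready yet are queued in `unprocessedNodes` and
    // retried until the queue drains.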
    static fromConfig(cls, config, customObjects = {}, fastWeightInit = false) {
        const createdLayers = {};
        const unprocessedNodes = {};
        function addUnprocessedNode(layer, nodeData) {
            if (!(layer.name in unprocessedNodes)) {
                unprocessedNodes[layer.name] = [nodeData];
            }
            else {
                unprocessedNodes[layer.name].push(nodeData);
            }
        }
        function processNode(layer, nodeData) {
            const inputTensors = [];
            let kwargs;
            for (const inputData of nodeData) {
                const inboundLayerName = inputData[0];
                const inboundNodeIndex = inputData[1];
                const inboundTensorIndex = inputData[2];
                kwargs = inputData[3] == null ? {} : inputData[3];
                if (!(inboundLayerName in createdLayers)) {
                    addUnprocessedNode(layer, nodeData);
                    return;
                }
                const inboundLayer = createdLayers[inboundLayerName];
                if (inboundLayer.inboundNodes.length <= inboundNodeIndex) {
                    addUnprocessedNode(layer, nodeData);
                    return;
                }
                const inboundNode = inboundLayer.inboundNodes[inboundNodeIndex];
                inputTensors.push(inboundNode.outputTensors[inboundTensorIndex]);
            }
            if (inputTensors.length > 0) {
                layer.apply(singletonOrArray(inputTensors), kwargs);
            }
        }
        function processLayer(layerData) {
            const layerName = layerData['name'];
            const layer = deserialize(layerData, config['customObjects'] != null ?
                config['customObjects'] : {});
            layer.setFastWeightInitDuringBuild(fastWeightInit);
            createdLayers[layerName] = layer;
            const inboundNodesData = layerData['inboundNodes'];
            inboundNodesData.forEach(nodeData => {
                if (!(nodeData instanceof Array)) {
                    throw new ValueError(`Corrupted configuration, expected array for nodeData: ${nodeData}`);
                }
                addUnprocessedNode(layer, nodeData);
            });
        }
        const name = config['name'];
        const layersFromConfig = config['layers'];
        for (const layerData of layersFromConfig) {
            processLayer(layerData);
        }
        while (!isObjectEmpty(unprocessedNodes)) {
            for (const layerData of layersFromConfig) {
                const layer = createdLayers[layerData['name']];
                if (layer.name in unprocessedNodes) {
                    const currentUnprocessedNodesForLayer = unprocessedNodes[layer.name];
                    delete unprocessedNodes[layer.name];
                    for (const nodeData of currentUnprocessedNodesForLayer) {
                        processNode(layer, nodeData);
                    }
                }
            }
        }
        const inputTensors = [];
        const outputTensors = [];
        const inputLayersFromConfig = config['inputLayers'];
        for (const layerData of inputLayersFromConfig) {
            const layerName = layerData[0];
            const nodeIndex = layerData[1];
            const tensorIndex = layerData[2];
            assert(layerName in createdLayers);
            const layer = createdLayers[layerName];
            const layerOutputTensors = layer.inboundNodes[nodeIndex].outputTensors;
            inputTensors.push(layerOutputTensors[tensorIndex]);
        }
        const outputLayersFromConfig = config['outputLayers'];
        for (const layerData of outputLayersFromConfig) {
            const layerName = layerData[0];
            const nodeIndex = layerData[1];
            const tensorIndex = layerData[2];
            assert(layerName in createdLayers);
            const layer = createdLayers[layerName];
            const layerOutputTensors = layer.inboundNodes[nodeIndex].outputTensors;
            outputTensors.push(layerOutputTensors[tensorIndex]);
        }
        return new cls({ inputs: inputTensors, outputs: outputTensors, name });
    }
    get stateful() {
        if (this._stateful) {
            throw new ValueError('Container instance unexpectedly has _stateful = true. The ' +
                'statefulness of a Container is determined by the Layers it ' +
                'contains. Its _stateful property must remain the default false.');
        }
        for (const layer of this.layers) {
            if (layer.stateful) {
                return true;
            }
        }
        return false;
    }
    resetStates() {
        tidy(() => {
            this.layers.forEach(layer => {
                if (layer.stateful) {
                    layer.resetStates();
                }
            });
        });
    }
}

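// Standardization of sample/class weights for training: per-output weight
// specs are normalized to one entry per model output, and class weights are
// expanded into a per-example weight tensor.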
function standardizeSampleOrClassWeights(xWeight, outputNames, weightType) {
    const numOutputs = outputNames.length;
    if (xWeight == null || (Array.isArray(xWeight) && xWeight.length === 0)) {
        return outputNames.map(name => null);
    }
    if (numOutputs === 1) {
        if (Array.isArray(xWeight) && xWeight.length === 1) {
            return xWeight;
        }
        else if (typeof xWeight === 'object' && outputNames[0] in xWeight) {
            return [xWeight[outputNames[0]]];
        }
        else {
            return [xWeight];
        }
    }
    if (Array.isArray(xWeight)) {
        if (xWeight.length !== numOutputs) {
            throw new Error(`Provided ${weightType} is an array of ${xWeight.length} ` +
                `element(s), but the model has ${numOutputs} outputs. ` +
                `Make sure a set of weights is provided for each model output.`);
        }
        return xWeight;
    }
    else if (typeof xWeight === 'object' && Object.keys(xWeight).length > 0 &&
        typeof xWeight[Object.keys(xWeight)[0]] === 'object') {
        const output = [];
        outputNames.forEach(outputName => {
            if (outputName in xWeight) {
                output.push(xWeight[outputName]);
            }
            else {
                output.push(null);
            }
        });
        return output;
    }
    else {
        throw new Error(`The model has multiple (${numOutputs}) outputs, ` +
            `so ${weightType} must be either an array with ` +
            `${numOutputs} elements or an object with ${outputNames} keys. ` +
            `Provided ${weightType} not understood: ${JSON.stringify(xWeight)}`);
    }
}

function standardizeClassWeights(classWeight, outputNames) {
    return standardizeSampleOrClassWeights(classWeight, outputNames, 'classWeight');
}

async function standardizeWeights(y, sampleWeight, classWeight, sampleWeightMode) {
    if (classWeight != null) {
        const yClasses = tidy(() => {
            if (y.shape.length === 1) {
                return clone(y);
            }
            else if (y.shape.length === 2) {
                if (y.shape[1] > 1) {
                    const axis = 1;
                    return argMax$2(y, axis);
                }
                else if (y.shape[1] === 1) {
                    return reshape$2(y, [y.shape[0]]);
                }
                else {
                    throw new Error(`Encountered unexpected last-dimension size (${y.shape[1]}) ` +
                        `during handling of class weights. The size is expected to be ` +
                        `>= 1.`);
                }
            }
            else {
                throw new Error(`Unexpected rank of target (y) tensor (${y.rank}) during ` +
                    `handling of class weights. The rank is expected to be 1 or 2.`);
            }
        });
        const yClassIndices = Array.from(await yClasses.data());
        dispose(yClasses);
        const classSampleWeight = [];
        yClassIndices.forEach(classIndex => {
            if (classWeight[classIndex] == null) {
                throw new Error(`classWeight must contain all classes in the training data. ` +
                    `The class ${classIndex} exists in the data but not in ` +
                    `classWeight`);
            }
            else {
                classSampleWeight.push(classWeight[classIndex]);
            }
        });
        return tensor1d(classSampleWeight, 'float32');
    }
    else {
        return null;
    }
}

function computeWeightedLoss(losses, sampleWeights) {
    return mul(losses, sampleWeights);
}

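// fitDataset() support: the helpers below validate each `{xs, ys}` batch
// produced by a tf.data.Dataset against the model's input/output names and
// batch size; fitDataset() itself then drives the epoch/batch training loop
// with the callback machinery defined above.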
const DEFAULT_VALIDATION_BATCH_SIZE = 32;

function standardizeDataIteratorOutput(model, iteratorOut) {
    let xs;
    let ys;
    const iteratorOutObj = iteratorOut;
    xs = iteratorOutObj['xs'];
    ys = iteratorOutObj['ys'];
    assert$1(xs != null && ys != null, () => 'A Dataset iterator for fitDataset() is expected to generate ' +
        'objects of the form `{xs: xVal, ys: yVal}`, where the two ' +
        'values may be `tf.Tensor`, an array of Tensors, or a map of ' +
        'string to Tensor. The provided Dataset instead generates ' +
        `${iteratorOut}`);
    const flattenedXs = flattenTensorOrArrayOrMap('input', model.inputNames, xs);
    const flattenedYs = flattenTensorOrArrayOrMap('output', model.outputNames, ys);
    const batchSize = flattenedXs[0].shape[0];
    assert$1(flattenedXs.length === model.inputs.length, () => `LayersModel has ${model.inputs.length} inputs, but the dataset ` +
        `provides ${flattenedXs.length} inputs. (Expected input keys: ` +
        `${JSON.stringify(model.inputNames)})`);
    assert$1(flattenedYs.length === model.outputs.length, () => `LayersModel has ${model.outputs.length} outputs, but the dataset ` +
        `provides ${flattenedYs.length} outputs. (Expected output keys: ` +
        `${JSON.stringify(model.outputNames)})`);
    for (let xIndex = 0; xIndex < flattenedXs.length; xIndex++) {
        assert$1(flattenedXs[xIndex].shape[0] === batchSize, () => `Batch size mismatch: input ` +
            `${model.inputNames[xIndex]} has ${flattenedXs[xIndex].shape[0]}; ` +
            `expected ${batchSize} based on input ${model.inputNames[0]}.`);
    }
    for (let yIndex = 0; yIndex < flattenedYs.length; yIndex++) {
        assert$1(flattenedYs[yIndex].shape[0] === batchSize, () => `Batch size mismatch: output ` +
            `${model.outputNames[yIndex]} has ${flattenedYs[yIndex].shape[0]}; ` +
            `expected ${batchSize} based on input ${model.inputNames[0]}.`);
    }
    return { xs: flattenedXs, ys: flattenedYs };
}
function flattenTensorOrArrayOrMap(inputOrOutput, names, values) {
    if (values instanceof Tensor) {
        return [values];
    }
    else if (Array.isArray(values)) {
        assert$1(values.length === names.length, () => `Received an array of ${values.length} Tensors, but expected ${names.length} to match the ${inputOrOutput} keys ${names}.`);
        return values;
    }
    else {
        const result = [];
        for (const name of names) {
            if (values[name] == null) {
                throw new ValueError(`The feature data generated by the dataset lacks the required ` +
                    `${inputOrOutput} key '${name}'.`);
            }
            result.push(values[name]);
        }
        return result;
    }
}
function standardizeTensorValidationData(data) {
    if (data.length === 3) {
        throw new NotImplementedError('Validation with sample weights is not implemented yet.');
    }
    return { xs: data[0], ys: data[1] };
}
async function fitDataset(model, dataset, args) {
    const hasBatchesPerEpoch = args.batchesPerEpoch != null;
    assert$1(model.optimizer != null, () => 'You must compile a model before training/testing. Use ' +
        'LayersModel.compile(modelCompileConfig).');
    assert$1(args != null, () => `For fitDataset(), the 2nd argument (config) is required, ` +
        `but it is not provided in this call.`);
    assert$1(args.epochs != null && args.epochs > 0 && Number.isInteger(args.epochs), () => `For fitDataset(), config.epochs is expected to be a positive ` +
        `integer, but got ${args.epochs}`);
    assert$1(!hasBatchesPerEpoch ||
        (args.batchesPerEpoch > 0 && Number.isInteger(args.batchesPerEpoch)), () => `For fitDataset(), config.batchesPerEpoch is expected to be a ` +
        `positive integer if specified, but got ${args.batchesPerEpoch}`);
    assert$1(args['validationSplit'] == null, () => '`validationSplit` is not supported by `fitDataset()`. ' +
        'Use validationData instead.');
    if (model.isTraining) {
        throw new Error('Cannot start training because another fit() call is ongoing.');
    }
    model.isTraining = true;
    try {
        const doValidation = args.validationData != null;
        let valXs;
        let valYs;
        if (doValidation) {
            if (isDatasetObject(args.validationData)) {
                assert$1(args.validationBatches == null ||
                    (args.validationBatches > 0 &&
                        Number.isInteger(args.validationBatches)), () => `For fitDataset() with dataset-based validation, ` +
                    `config.validationBatches is expected not to be provided, ` +
                    `or to be a positive integer, ` +
                    `but got ${args.validationBatches}`);
            }
            else {
                const validationData = standardizeTensorValidationData(args.validationData);
                valXs = validationData.xs;
                valYs = validationData.ys;
            }
        }
        const trainFunction = model.makeTrainFunction();
        const outLabels = model.getDedupedMetricsNames();
        let callbackMetrics;
        if (doValidation) {
            callbackMetrics =
                outLabels.slice().concat(outLabels.map(n => 'val_' + n));
        }
        else {
            callbackMetrics = outLabels.slice();
        }
        const callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery);
        const verbose = args.verbose == null ? 1 : args.verbose;
        const { callbackList, history } = configureCallbacks(callbacks, verbose, args.epochs, null, null, getStepsPerEpoch(dataset, args), null, doValidation, callbackMetrics);
        callbackList.setModel(model);
        model.history = history;
        await callbackList.onTrainBegin();
        model.stopTraining_ = false;
        let epoch = args.initialEpoch == null ? 0 : args.initialEpoch;
        let dataIterator = await dataset.iterator();
        while (epoch < args.epochs) {
            const epochLogs = {};
            await callbackList.onEpochBegin(epoch);
            let stepsDone = 0;
            let batchIndex = 0;
            if (!hasBatchesPerEpoch) {
                dataIterator = await dataset.iterator();
            }
            while (hasBatchesPerEpoch ? stepsDone < args.batchesPerEpoch : true) {
|
|
const iteratorOut = await dataIterator.next();
|
|
|
|
|
|
if (hasBatchesPerEpoch && iteratorOut.done) {
|
|
console.warn('You provided `batchesPerEpoch` as ' +
|
|
`${args.batchesPerEpoch}, ` +
|
|
'but your dataset iterator ran out of data after ' +
|
|
`${stepsDone} batches; ` +
|
|
'interrupting training. Make sure that your ' +
|
|
'dataset can generate at least `batchesPerEpoch * epochs` ' +
|
|
'batches (in this case, ' +
|
|
`${args.batchesPerEpoch * args.epochs} batches). ` +
|
|
'You may need to use the repeat() function when building ' +
|
|
'your dataset.');
|
|
break;
|
|
}
|
|
if (iteratorOut.value != null) {
|
|
const { xs, ys } = standardizeDataIteratorOutput(model, iteratorOut.value);
|
|
const batchLogs = {};
|
|
batchLogs['batch'] = batchIndex;
|
|
batchLogs['size'] = xs[0].shape[0];
|
|
await callbackList.onBatchBegin(batchIndex, batchLogs);
|
|
const sampleWeights = [];
|
|
if (args.classWeight != null) {
|
|
const standardClassWeights = standardizeClassWeights(args.classWeight, model.outputNames);
|
|
for (let i = 0; i < standardClassWeights.length; ++i) {
|
|
sampleWeights.push(await standardizeWeights(ys[i], null, standardClassWeights[i]));
|
|
}
|
|
}
|
|
|
|
const ins = xs.concat(ys).concat(sampleWeights);
|
|
const outs = trainFunction(ins);
|
|
dispose(ins);
|
|
for (let i = 0; i < outLabels.length; ++i) {
|
|
const label = outLabels[i];
|
|
const out = outs[i];
|
|
batchLogs[label] = out;
|
|
keep(out);
|
|
}
|
|
await callbackList.onBatchEnd(batchIndex, batchLogs);
|
|
disposeTensorsInLogs(batchLogs);
|
|
batchIndex++;
|
|
stepsDone++;
|
|
}
|
|
if (hasBatchesPerEpoch ? stepsDone >= args.batchesPerEpoch :
|
|
iteratorOut.done) {
|
|
|
|
if (doValidation) {
|
|
let valOuts;
|
|
if (isDatasetObject(args.validationData)) {
|
|
valOuts = toList(await model.evaluateDataset(args.validationData, { batches: args.validationBatches }));
|
|
}
|
|
else {
|
|
valOuts = toList(model.evaluate(valXs, valYs, {
|
|
batchSize: args.validationBatchSize == null ?
|
|
DEFAULT_VALIDATION_BATCH_SIZE :
|
|
args.validationBatchSize,
|
|
verbose: 0
|
|
}));
|
|
}
|
|
for (let i = 0; i < model.metricsNames.length; ++i) {
|
|
epochLogs[`val_${model.metricsNames[i]}`] = valOuts[i];
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
break;
|
|
}
|
|
if (model.stopTraining_) {
|
|
break;
|
|
}
|
|
}
|
|
await callbackList.onEpochEnd(epoch, epochLogs);
|
|
epoch++;
|
|
if (model.stopTraining_) {
|
|
break;
|
|
}
|
|
}
|
|
await callbackList.onTrainEnd();
|
|
await model.history.syncData();
|
|
return model.history;
|
|
}
|
|
finally {
|
|
model.isTraining = false;
|
|
}
|
|
}
|
|
|
|
function getStepsPerEpoch(dataset, args) {
|
|
|
|
let stepsPerEpoch = null;
|
|
if (args.batchesPerEpoch != null) {
|
|
stepsPerEpoch = args.batchesPerEpoch;
|
|
}
|
|
else if (Number.isFinite(dataset.size)) {
|
|
stepsPerEpoch = dataset.size;
|
|
}
|
|
return stepsPerEpoch;
|
|
}
|
|
|
|
|
|
function isDatasetObject(dataset) {
|
|
return (typeof dataset.iterator === 'function');
|
|
}
|
|
|
|
|
|
function isLazyIteratorObject(iterator) {
|
|
return (typeof iterator.next === 'function');
|
|
}
|
|
async function evaluateDataset(
|
|
|
|
|
|
|
|
model, dataset, args) {
|
|
args = args || {};
|
|
const hasBatches = args.batches != null;
|
|
const f = model.testFunction;
|
|
let outs = [];
|
|
if (args.verbose > 0) {
|
|
throw new NotImplementedError('Verbose mode is not implemented yet.');
|
|
}
|
|
assert$1(!hasBatches || (args.batches > 0 && Number.isInteger(args.batches)), () => 'Test loop expects `batches` to be a positive integer, but ' +
|
|
`received ${JSON.stringify(args.batches)}`);
|
|
const dataIterator = isLazyIteratorObject(dataset) ?
|
|
dataset :
|
|
await dataset.iterator();
|
|
|
|
let numExamples = 0;
|
|
let batch = 0;
|
|
while (hasBatches ? batch < args.batches : true) {
|
|
const iteratorOut = await dataIterator.next();
|
|
outs = tidy(() => {
|
|
if (iteratorOut.value) {
|
|
|
|
|
|
const { xs, ys } = standardizeDataIteratorOutput(model, iteratorOut.value);
|
|
const xsAndYs = xs.concat(ys);
|
|
const batchOuts = tidy(() => f(xsAndYs));
|
|
dispose(xsAndYs);
|
|
if (batch === 0) {
|
|
for (let i = 0; i < batchOuts.length; ++i) {
|
|
outs.push(scalar(0));
|
|
}
|
|
}
|
|
const batchSize = xsAndYs[0].shape[0];
|
|
for (let i = 0; i < batchOuts.length; ++i) {
|
|
const batchOut = batchOuts[i];
|
|
const oldScalar = outs[i];
|
|
outs[i] =
|
|
tidy(() => add$1(outs[i], mul(batchSize, batchOut)));
|
|
if (batch > 0) {
|
|
dispose(oldScalar);
|
|
}
|
|
}
|
|
dispose(batchOuts);
|
|
numExamples += batchSize;
|
|
++batch;
|
|
}
|
|
return outs;
|
|
});
|
|
if (iteratorOut.done) {
|
|
if (hasBatches) {
|
|
console.warn('Your dataset iterator ran out of data during evaluateDataset(). ' +
|
|
'Interrupting evalution. Make sure that your ' +
|
|
'dataset can generate at least `batches` ' +
|
|
`batches (in this case, ${args.batches} batches). ` +
|
|
'You may need to use the repeat() function when building ' +
|
|
'your dataset.');
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
for (let i = 0; i < outs.length; ++i) {
|
|
const oldScalar = outs[i];
|
|
outs[i] = div$1(outs[i], numExamples);
|
|
dispose(oldScalar);
|
|
}
|
|
return singletonOrArray(outs);
|
|
}
|
|
|
|
|
|
|
|
function checkBatchSize(batchSize) {
|
|
assert$1(batchSize > 0 && Number.isInteger(batchSize), () => `batchSize is required to be a positive integer, but got ${batchSize}`);
|
|
}
|
|
|
|
function sliceArrays(arrays, start, stop) {
|
|
if (arrays == null) {
|
|
return [null];
|
|
}
|
|
else if (Array.isArray(arrays)) {
|
|
return arrays.map(array => sliceAlongFirstAxis(array, start, stop - start));
|
|
}
|
|
else {
|
|
return sliceAlongFirstAxis(arrays, start, stop - start);
|
|
}
|
|
}
|
|
|
|
function sliceArraysByIndices(arrays, indices) {
|
|
return tidy(() => {
|
|
if (arrays == null) {
|
|
return null;
|
|
}
|
|
else if (Array.isArray(arrays)) {
|
|
return arrays.map(array => sliceArraysByIndices(array, indices));
|
|
}
|
|
else {
|
|
|
|
|
|
return gather(arrays, indices.dtype === 'int32' ? indices : cast$3(indices, 'int32'));
|
|
}
|
|
});
|
|
}
|
|
|
|
function makeBatches(size, batchSize) {
|
|
const output = [];
|
|
let batchStart = 0;
|
|
let batchEnd = null;
|
|
while (batchStart < size) {
|
|
batchEnd = batchStart + batchSize;
|
|
if (batchEnd >= size) {
|
|
batchEnd = size;
|
|
}
|
|
output.push([batchStart, batchEnd]);
|
|
batchStart = batchEnd;
|
|
}
|
|
return output;
|
|
}
|
|
|
|
function ensureTensorsRank2OrHigher(tensors) {
|
|
const outs = [];
|
|
if (tensors instanceof Tensor) {
|
|
tensors = [tensors];
|
|
}
|
|
|
|
for (let i = 0; i < tensors.length; ++i) {
|
|
const tensor = tensors[i];
|
|
if (tensor.rank === 1) {
|
|
outs.push(expandDims(tensor, 1));
|
|
}
|
|
else if (tensor.rank === 0) {
|
|
throw new Error('Expected tensor to be at least 1D, but received a 0D tensor ' +
|
|
'(scalar).');
|
|
}
|
|
else {
|
|
outs.push(tensor);
|
|
}
|
|
}
|
|
return outs;
|
|
}
|
|
|
|
|
|
function disposeNewTensors(tensors, refTensors) {
|
|
if (tensors == null) {
|
|
return;
|
|
}
|
|
const oldTensorIds = [];
|
|
if (refTensors instanceof Tensor) {
|
|
oldTensorIds.push(refTensors.id);
|
|
}
|
|
else if (Array.isArray(refTensors)) {
|
|
refTensors.forEach(t => oldTensorIds.push(t.id));
|
|
}
|
|
else if (refTensors != null) {
|
|
|
|
for (const name in refTensors) {
|
|
const oldTensor = refTensors[name];
|
|
oldTensorIds.push(oldTensor.id);
|
|
}
|
|
}
|
|
const tensorsToDispose = [];
|
|
if (tensors instanceof Tensor) {
|
|
if (oldTensorIds.indexOf(tensors.id) === -1) {
|
|
tensorsToDispose.push(tensors);
|
|
}
|
|
}
|
|
else if (Array.isArray(tensors)) {
|
|
tensors.forEach(t => {
|
|
if (oldTensorIds.indexOf(t.id) === -1) {
|
|
tensorsToDispose.push(t);
|
|
}
|
|
});
|
|
}
|
|
else if (tensors != null) {
|
|
|
|
for (const name in tensors) {
|
|
const tensor = tensors[name];
|
|
if (oldTensorIds.indexOf(tensor.id) === -1) {
|
|
tensorsToDispose.push(tensor);
|
|
}
|
|
}
|
|
}
|
|
tensorsToDispose.forEach(t => {
|
|
if (!t.isDisposed) {
|
|
t.dispose();
|
|
}
|
|
});
|
|
}
|
|
|
|
|
|
|
|
|
|
function isDataTensor(x) {
|
|
return x instanceof Tensor;
|
|
}
|
|
|
|
function isDataArray(x) {
|
|
return Array.isArray(x);
|
|
}
|
|
|
|
function isDataDict(x) {
|
|
return !isDataTensor(x) && !isDataArray(x);
|
|
}
|
|
|
|
function standardizeInputData(data, names, shapes, checkBatchAxis = true, exceptionPrefix = '') {
|
|
if (names == null || names.length === 0) {
|
|
|
|
|
|
if (data != null) {
|
|
let gotUnexpectedData = false;
|
|
if (isDataArray(data) && data.length > 0) {
|
|
gotUnexpectedData = true;
|
|
}
|
|
else if (isDataDict(data)) {
|
|
for (const key in data) {
|
|
if (data.hasOwnProperty(key)) {
|
|
gotUnexpectedData = true;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
else {
|
|
|
|
gotUnexpectedData = true;
|
|
}
|
|
if (gotUnexpectedData) {
|
|
throw new ValueError(`Error when checking model ${exceptionPrefix} expected no data, ` +
|
|
`but got ${data}`);
|
|
}
|
|
}
|
|
return [];
|
|
}
|
|
if (data == null) {
|
|
return names.map(name => null);
|
|
}
|
|
let arrays;
|
|
if (isDataDict(data)) {
|
|
data = data;
|
|
arrays = [];
|
|
for (const name of names) {
|
|
if (data[name] == null) {
|
|
throw new ValueError(`No data provided for "${name}". Need data for each key in: ` +
|
|
`${names}`);
|
|
}
|
|
arrays.push(data[name]);
|
|
}
|
|
}
|
|
else if (isDataArray(data)) {
|
|
data = data;
|
|
if (data.length !== names.length) {
|
|
throw new ValueError(`Error when checking model ${exceptionPrefix}: the Array of ` +
|
|
`Tensors that you are passing to your model is not the size the ` +
|
|
`model expected. Expected to see ${names.length} Tensor(s), but ` +
|
|
`instead got the following list of Tensor(s): ${data}`);
|
|
}
|
|
arrays = data;
|
|
}
|
|
else {
|
|
data = data;
|
|
if (names.length > 1) {
|
|
throw new ValueError(`The model ${exceptionPrefix} expects ${names.length} Tensor(s), ` +
|
|
`but only received one Tensor. Found: Tensor with shape ${data.shape}`);
|
|
}
|
|
arrays = [data];
|
|
}
|
|
arrays = ensureTensorsRank2OrHigher(arrays);
|
|
|
|
if (shapes != null) {
|
|
for (let i = 0; i < names.length; ++i) {
|
|
if (shapes[i] == null) {
|
|
continue;
|
|
}
|
|
const array = arrays[i];
|
|
if (array.shape.length !== shapes[i].length) {
|
|
throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
|
|
`to have ${shapes[i].length} dimension(s). but got array with ` +
|
|
`shape ${array.shape}`);
|
|
}
|
|
for (let j = 0; j < shapes[i].length; ++j) {
|
|
if (j === 0 && !checkBatchAxis) {
|
|
|
|
continue;
|
|
}
|
|
const dim = array.shape[j];
|
|
const refDim = shapes[i][j];
|
|
if (refDim != null && refDim >= 0 && dim !== refDim) {
|
|
throw new ValueError(`${exceptionPrefix} expected a batch of elements where each ` +
|
|
`example has shape [${shapes[i].slice(1, shapes[i].length)}] ` +
|
|
`(i.e.,tensor shape [*,${shapes[i].slice(1, shapes[i].length)}])` +
|
|
` but the ${exceptionPrefix} received an input with ${array.shape[0]}` +
|
|
` examples, each with shape [${array.shape.slice(1, array.shape.length)}]` +
|
|
` (tensor shape [${array.shape}])`);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return arrays;
|
|
}
|
|
|
|
function checkArrayLengths(inputs, targets, weights) {
|
|
const setX = unique(inputs.map(input => input.shape[0]));
|
|
setX.sort();
|
|
const setY = unique(targets.map(target => target.shape[0]));
|
|
setY.sort();
|
|
|
|
if (setX.length > 1) {
|
|
throw new ValueError(`All input Tensors (x) should have the same number of samples. ` +
|
|
`Got array shapes: ` +
|
|
`${JSON.stringify(inputs.map(input => input.shape))}`);
|
|
}
|
|
if (setY.length > 1) {
|
|
throw new ValueError(`All target Tensors (y) should have the same number of samples. ` +
|
|
`Got array shapes: ` +
|
|
`${JSON.stringify(targets.map(target => target.shape))}`);
|
|
}
|
|
if (setX.length > 0 && setY.length > 0 && !arraysEqual(setX, setY)) {
|
|
throw new ValueError(`Input Tensors should have the same number of samples as target ` +
|
|
`Tensors. Found ${setX[0]} input sample(s) and ${setY[0]} target ` +
|
|
`sample(s).`);
|
|
}
|
|
}
|
|
|
|
function checkLossAndTargetCompatibility(targets, lossFns, outputShapes) {
|
|
|
|
const keyLosses = [
|
|
meanSquaredError, binaryCrossentropy$1,
|
|
categoricalCrossentropy$1
|
|
];
|
|
for (let i = 0; i < targets.length; ++i) {
|
|
const y = targets[i];
|
|
const loss = lossFns[i];
|
|
const shape = outputShapes[i];
|
|
if (loss == null) {
|
|
continue;
|
|
}
|
|
if (loss === categoricalCrossentropy$1) {
|
|
if (y.shape[y.shape.length - 1] === 1) {
|
|
throw new ValueError(`You are passing a target array of shape ${y.shape} while using ` +
|
|
`a loss 'categorical_crossentropy'. 'categorical_crossentropy'` +
|
|
`expects targets to be binary matrices (1s and 0s) of shape ` +
|
|
`[samples, classes].`);
|
|
|
|
}
|
|
}
|
|
if (keyLosses.indexOf(loss) !== -1) {
|
|
const slicedYShape = y.shape.slice(1);
|
|
const slicedShape = shape.slice(1);
|
|
for (let j = 0; j < slicedYShape.length; ++j) {
|
|
const targetDim = slicedYShape[j];
|
|
const outDim = slicedShape[j];
|
|
if (outDim != null && targetDim !== outDim) {
|
|
throw new ValueError(`A target Tensor with shape ${y.shape} was passed for an ` +
|
|
`output of shape ${shape}, while using a loss function that ` +
|
|
`expects targets to have the same shape as the output.`);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
function checkInputData(data, names, shapes, checkBatchAxis = true, exceptionPrefix = '') {
|
|
let arrays;
|
|
if (Array.isArray(data)) {
|
|
if (data.length !== names.length) {
|
|
throw new ValueError(`Error when checking model ${exceptionPrefix}: the Array of ` +
|
|
`Tensors that you are passing to your model is not the size the ` +
|
|
`the model expected. Expected to see ${names.length} Tensor(s),` +
|
|
` but instead got ${data.length} Tensors(s).`);
|
|
}
|
|
arrays = data;
|
|
}
|
|
else {
|
|
if (names.length > 1) {
|
|
throw new ValueError(`The model expects ${names.length} ${exceptionPrefix} Tensors, ` +
|
|
`but only received one Tensor. Found: array with shape ` +
|
|
`${JSON.stringify(data.shape)}.`);
|
|
}
|
|
arrays = [data];
|
|
}
|
|
if (shapes != null) {
|
|
for (let i = 0; i < names.length; ++i) {
|
|
if (shapes[i] == null) {
|
|
continue;
|
|
}
|
|
const array = arrays[i];
|
|
if (array.shape.length !== shapes[i].length) {
|
|
throw new ValueError(`Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
|
|
`to have ${shapes[i].length} dimension(s), but got array with ` +
|
|
`shape ${JSON.stringify(array.shape)}`);
|
|
}
|
|
for (let j = 0; j < shapes[i].length; ++j) {
|
|
if (j === 0 && !checkBatchAxis) {
|
|
continue;
|
|
}
|
|
const dim = array.shape[j];
|
|
const refDim = shapes[i][j];
|
|
if (refDim != null) {
|
|
if (refDim !== dim) {
|
|
throw new ValueError(`Error when checking ${exceptionPrefix}: expected ` +
|
|
`${names[i]} to have shape ${JSON.stringify(shapes[i])} but ` +
|
|
`got array with shape ${JSON.stringify(array.shape)}.`);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
function collectMetrics(metrics, outputNames) {
|
|
if (metrics == null || Array.isArray(metrics) && metrics.length === 0) {
|
|
return outputNames.map(name => []);
|
|
}
|
|
let wrappedMetrics;
|
|
if (typeof metrics === 'string' || typeof metrics === 'function') {
|
|
wrappedMetrics = [metrics];
|
|
}
|
|
else if (Array.isArray(metrics) || typeof metrics === 'object') {
|
|
wrappedMetrics = metrics;
|
|
}
|
|
else {
|
|
throw new TypeError('Type of metrics argument not understood. Expected an string,' +
|
|
`function, Array, or Object, found: ${metrics}`);
|
|
}
|
|
if (Array.isArray(wrappedMetrics)) {
|
|
|
|
return outputNames.map(name => wrappedMetrics);
|
|
}
|
|
else {
|
|
|
|
const nestedMetrics = [];
|
|
for (const name of outputNames) {
|
|
let outputMetrics = wrappedMetrics.hasOwnProperty(name) ? wrappedMetrics[name] : [];
|
|
if (!Array.isArray(outputMetrics)) {
|
|
outputMetrics = [outputMetrics];
|
|
}
|
|
nestedMetrics.push(outputMetrics);
|
|
}
|
|
return nestedMetrics;
|
|
}
|
|
}
|
|
const LAYERS_MODEL_FORMAT_NAME = 'layers-model';
|
|
|
|
class LayersModel extends Container {
|
|
constructor(args) {
|
|
super(args);
|
|
this.isTraining = false;
|
|
}
|
|
|
|
summary(lineLength, positions, printFn = console.log) {
|
|
if (!this.built) {
|
|
throw new ValueError(`This model has never been called, thus its weights have not been ` +
|
|
`created yet. So no summary can be displayed. Build the model ` +
|
|
`first (e.g., by calling it on some test data).`);
|
|
}
|
|
printSummary(this, lineLength, positions, printFn);
|
|
}
|
|
|
|
compile(args) {
|
|
if (args.loss == null) {
|
|
args.loss = [];
|
|
}
|
|
this.loss = args.loss;
|
|
if (typeof args.optimizer === 'string') {
|
|
this.optimizer_ = getOptimizer(args.optimizer);
|
|
this.isOptimizerOwned = true;
|
|
}
|
|
else {
|
|
if (!(args.optimizer instanceof Optimizer)) {
|
|
throw new ValueError(`User-defined optimizer must be an instance of tf.Optimizer.`);
|
|
}
|
|
this.optimizer_ = args.optimizer;
|
|
this.isOptimizerOwned = false;
|
|
}
|
|
|
|
|
|
|
|
let lossFunctions = [];
|
|
if (!Array.isArray(args.loss) && typeof args.loss !== 'string' &&
|
|
typeof args.loss !== 'function') {
|
|
args.loss = args.loss;
|
|
for (const name in args.loss) {
|
|
if (this.outputNames.indexOf(name) === -1) {
|
|
throw new ValueError(`Unknown entry in loss dictionary: "${name}". ` +
|
|
`Only expected the following keys: ${this.outputNames}`);
|
|
}
|
|
}
|
|
for (const name of this.outputNames) {
|
|
if (args.loss[name] == null) {
|
|
console.warn(`Output "${name}" is missing from loss dictionary. We assume ` +
|
|
`this was done on purpose, and we will not be expecting data ` +
|
|
`to be passed to ${name} during training`);
|
|
}
|
|
lossFunctions.push(get$1(args.loss[name]));
|
|
}
|
|
}
|
|
else if (Array.isArray(args.loss)) {
|
|
if (args.loss.length !== this.outputs.length) {
|
|
throw new ValueError(`When passing an Array as loss, it should have one entry per ` +
|
|
`model output. The model has ${this.outputs.length} output(s), ` +
|
|
`but you passed loss=${args.loss}.`);
|
|
}
|
|
const theLosses = args.loss;
|
|
lossFunctions = theLosses.map(l => get$1(l));
|
|
}
|
|
else {
|
|
const lossFunction = get$1(args.loss);
|
|
this.outputs.forEach(_ => {
|
|
lossFunctions.push(lossFunction);
|
|
});
|
|
}
|
|
this.lossFunctions = lossFunctions;
|
|
this.feedOutputNames = [];
|
|
this.feedOutputShapes = [];
|
|
this.feedLossFns = [];
|
|
for (let i = 0; i < this.outputs.length; ++i) {
|
|
|
|
const shape = this.internalOutputShapes[i];
|
|
const name = this.outputNames[i];
|
|
this.feedOutputNames.push(name);
|
|
this.feedOutputShapes.push(shape);
|
|
this.feedLossFns.push(this.lossFunctions[i]);
|
|
}
|
|
|
|
|
|
const skipTargetIndices = [];
|
|
|
|
this.metrics = args.metrics;
|
|
|
|
this.metricsNames = ['loss'];
|
|
this.metricsTensors = [];
|
|
|
|
|
|
|
|
|
|
nameScope('loss', () => {
|
|
for (let i = 0; i < this.outputs.length; ++i) {
|
|
if (skipTargetIndices.indexOf(i) !== -1) {
|
|
continue;
|
|
}
|
|
|
|
|
|
const weightedLoss = this.lossFunctions[i];
|
|
if (this.outputs.length > 1) {
|
|
this.metricsTensors.push([weightedLoss, i]);
|
|
this.metricsNames.push(this.outputNames[i] + '_loss');
|
|
}
|
|
}
|
|
|
|
|
|
});
|
|
const nestedMetrics = collectMetrics(args.metrics, this.outputNames);
|
|
|
|
|
|
const appendMetric = (outputIndex, metricName, metricTensor) => {
|
|
if (this.outputNames.length > 1) {
|
|
metricName = this.outputNames[outputIndex] + '_' + metricName;
|
|
}
|
|
this.metricsNames.push(metricName);
|
|
this.metricsTensors.push([metricTensor, outputIndex]);
|
|
};
|
|
nameScope('metric', () => {
|
|
for (let i = 0; i < this.outputs.length; ++i) {
|
|
if (skipTargetIndices.indexOf(i) !== -1) {
|
|
continue;
|
|
}
|
|
const outputMetrics = nestedMetrics[i];
|
|
|
|
|
|
const handleMetrics = (metrics) => {
|
|
const metricNamePrefix = '';
|
|
let metricName;
|
|
let accFn;
|
|
let weightedMetricFn;
|
|
|
|
for (const metric of metrics) {
|
|
if (typeof metric === 'string' &&
|
|
['accuracy', 'acc', 'crossentropy', 'ce'].indexOf(metric) !==
|
|
-1) {
|
|
const outputShape = this.internalOutputShapes[i];
|
|
if (outputShape[outputShape.length - 1] === 1 ||
|
|
this.lossFunctions[i] === binaryCrossentropy$1) {
|
|
|
|
if (['accuracy', 'acc'].indexOf(metric) !== -1) {
|
|
accFn = binaryAccuracy;
|
|
}
|
|
else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
|
|
accFn = binaryCrossentropy;
|
|
}
|
|
}
|
|
else if (this.lossFunctions[i] ===
|
|
sparseCategoricalCrossentropy$1) {
|
|
|
|
|
|
if (['accuracy', 'acc'].indexOf(metric) !== -1) {
|
|
accFn = sparseCategoricalAccuracy;
|
|
}
|
|
else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
|
|
accFn = sparseCategoricalCrossentropy;
|
|
}
|
|
}
|
|
else {
|
|
|
|
if (['accuracy', 'acc'].indexOf(metric) !== -1) {
|
|
accFn = categoricalAccuracy;
|
|
}
|
|
else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
|
|
accFn = categoricalCrossentropy;
|
|
}
|
|
}
|
|
let suffix;
|
|
if (['accuracy', 'acc'].indexOf(metric) !== -1) {
|
|
suffix = 'acc';
|
|
}
|
|
else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
|
|
suffix = 'ce';
|
|
}
|
|
|
|
weightedMetricFn = accFn;
|
|
metricName = metricNamePrefix + suffix;
|
|
}
|
|
else {
|
|
const metricFn = get(metric);
|
|
|
|
weightedMetricFn = metricFn;
|
|
metricName =
|
|
metricNamePrefix + getLossOrMetricName(metric);
|
|
}
|
|
|
|
let metricResult;
|
|
nameScope(metricName, () => {
|
|
metricResult = weightedMetricFn;
|
|
});
|
|
appendMetric(i, metricName, metricResult);
|
|
}
|
|
};
|
|
handleMetrics(outputMetrics);
|
|
|
|
}
|
|
});
|
|
|
|
|
|
this.collectedTrainableWeights = this.trainableWeights;
|
|
}
|
|
|
|
checkTrainableWeightsConsistency() {
|
|
if (this.collectedTrainableWeights == null) {
|
|
return;
|
|
}
|
|
if (this.trainableWeights.length !==
|
|
this.collectedTrainableWeights.length) {
|
|
console.warn('Discrepancy between trainableweights and collected trainable ' +
|
|
'weights. Did you set `model.trainable` without calling ' +
|
|
'`model.compile()` afterwards?');
|
|
}
|
|
}
|
|
|
|
evaluate(x, y, args = {}) {
|
|
const batchSize = args.batchSize == null ? 32 : args.batchSize;
|
|
checkBatchSize(batchSize);
|
|
|
|
|
|
const checkBatchAxis = true;
|
|
const standardizedOuts = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
|
|
try {
|
|
|
|
|
|
const ins = standardizedOuts[0].concat(standardizedOuts[1]);
|
|
this.makeTestFunction();
|
|
const f = this.testFunction;
|
|
const testOuts = this.testLoop(f, ins, batchSize, args.verbose, args.steps);
|
|
return singletonOrArray(testOuts);
|
|
}
|
|
finally {
|
|
disposeNewTensors(standardizedOuts[0], x);
|
|
disposeNewTensors(standardizedOuts[1], y);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
async evaluateDataset(dataset, args) {
|
|
this.makeTestFunction();
|
|
return evaluateDataset(this, dataset, args);
|
|
}
|
|
|
|
checkNumSamples(ins, batchSize, steps, stepsName = 'steps') {
|
|
let numSamples;
|
|
if (steps != null) {
|
|
numSamples = null;
|
|
if (batchSize != null) {
|
|
throw new ValueError(`If ${stepsName} is set, batchSize must be null or undefined.` +
|
|
`Got batchSize = ${batchSize}`);
|
|
}
|
|
}
|
|
else if (ins != null) {
|
|
if (Array.isArray(ins)) {
|
|
numSamples = ins[0].shape[0];
|
|
}
|
|
else {
|
|
numSamples = ins.shape[0];
|
|
}
|
|
}
|
|
else {
|
|
throw new ValueError(`Either the input data should have a defined shape, or ` +
|
|
`${stepsName} shoud be specified.`);
|
|
}
|
|
return numSamples;
|
|
}
|
|
|
|
execute(inputs, outputs) {
|
|
if (Array.isArray(outputs) && outputs.length === 0) {
|
|
throw new ValueError('`outputs` is an empty Array, which is not allowed.');
|
|
}
|
|
const outputsIsArray = Array.isArray(outputs);
|
|
const outputNames = (outputsIsArray ? outputs : [outputs]);
|
|
const outputSymbolicTensors = this.retrieveSymbolicTensors(outputNames);
|
|
|
|
const feedDict = new FeedDict();
|
|
if (inputs instanceof Tensor) {
|
|
inputs = [inputs];
|
|
}
|
|
if (Array.isArray(inputs)) {
|
|
if (inputs.length !== this.inputs.length) {
|
|
throw new ValueError(`The number of inputs provided (${inputs.length}) ` +
|
|
`does not match the number of inputs of this model ` +
|
|
`(${this.inputs.length}).`);
|
|
}
|
|
for (let i = 0; i < this.inputs.length; ++i) {
|
|
feedDict.add(this.inputs[i], inputs[i]);
|
|
}
|
|
}
|
|
else {
|
|
for (const input of this.inputs) {
|
|
const tensorValue = inputs[input.name];
|
|
if (tensorValue == null) {
|
|
throw new ValueError(`No value is provided for the model's input ${input.name}`);
|
|
}
|
|
feedDict.add(input, tensorValue);
|
|
}
|
|
}
|
|
|
|
const executeOutputs = execute(outputSymbolicTensors, feedDict);
|
|
return outputsIsArray ? executeOutputs : executeOutputs[0];
|
|
}
|
|
|
|
retrieveSymbolicTensors(symbolicTensorNames) {
|
|
const outputSymbolicTensors = pyListRepeat(null, symbolicTensorNames.length);
|
|
let outputsRemaining = symbolicTensorNames.length;
|
|
for (const layer of this.layers) {
|
|
const layerOutputs = Array.isArray(layer.output) ? layer.output : [layer.output];
|
|
const layerOutputNames = layerOutputs.map(output => output.name);
|
|
for (let i = 0; i < symbolicTensorNames.length; ++i) {
|
|
const index = layerOutputNames.indexOf(symbolicTensorNames[i]);
|
|
if (index !== -1) {
|
|
outputSymbolicTensors[i] = layerOutputs[index];
|
|
outputsRemaining--;
|
|
}
|
|
if (outputsRemaining === 0) {
|
|
break;
|
|
}
|
|
}
|
|
if (outputsRemaining === 0) {
|
|
break;
|
|
}
|
|
}
|
|
if (outputsRemaining > 0) {
|
|
const remainingNames = [];
|
|
outputSymbolicTensors.forEach((tensor, i) => {
|
|
if (tensor == null) {
|
|
remainingNames.push(symbolicTensorNames[i]);
|
|
}
|
|
});
|
|
throw new ValueError(`Cannot find SymbolicTensors for output name(s): ` +
|
|
`${JSON.stringify(remainingNames)}`);
|
|
}
|
|
return outputSymbolicTensors;
|
|
}
|
|
|
|
predictLoop(ins, batchSize = 32, verbose = false) {
|
|
return tidy(() => {
|
|
const numSamples = this.checkNumSamples(ins);
|
|
if (verbose) {
|
|
throw new NotImplementedError('Verbose predictLoop() is not implemented yet.');
|
|
}
|
|
|
|
|
|
|
|
|
|
const batches = makeBatches(numSamples, batchSize);
|
|
const outsBatches = this.outputs.map(output => []);
|
|
|
|
for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
|
|
const batchOuts = tidy(() => {
|
|
const batchStart = batches[batchIndex][0];
|
|
const batchEnd = batches[batchIndex][1];
|
|
|
|
|
|
const insBatch = sliceArrays(ins, batchStart, batchEnd);
|
|
|
|
const feeds = [];
|
|
if (Array.isArray(insBatch)) {
|
|
for (let i = 0; i < insBatch.length; ++i) {
|
|
feeds.push({ key: this.inputs[i], value: insBatch[i] });
|
|
}
|
|
}
|
|
else {
|
|
feeds.push({ key: this.inputs[0], value: insBatch });
|
|
}
|
|
const feedDict = new FeedDict(feeds);
|
|
return execute(this.outputs, feedDict);
|
|
});
|
|
batchOuts.forEach((batchOut, i) => outsBatches[i].push(batchOut));
|
|
}
|
|
return singletonOrArray(outsBatches.map(batches => concat$2(batches, 0)));
|
|
});
|
|
}
|
|
|
|
predict(x, args = {}) {
|
|
const xsRank2OrHigher = ensureTensorsRank2OrHigher(x);
|
|
checkInputData(xsRank2OrHigher, this.inputNames, this.feedInputShapes, false);
|
|
try {
|
|
|
|
|
|
|
|
|
|
const batchSize = args.batchSize == null ? 32 : args.batchSize;
|
|
checkBatchSize(batchSize);
|
|
return this.predictLoop(xsRank2OrHigher, batchSize);
|
|
}
|
|
finally {
|
|
disposeNewTensors(xsRank2OrHigher, x);
|
|
}
|
|
}
|
|
|
|
predictOnBatch(x) {
|
|
checkInputData(x, this.inputNames, this.feedInputShapes, true);
|
|
|
|
|
|
const batchSize = (Array.isArray(x) ? x[0] : x).shape[0];
|
|
return this.predictLoop(x, batchSize);
|
|
}
|
|
standardizeUserDataXY(x, y, checkBatchAxis = true, batchSize) {
|
|
|
|
if (this.optimizer_ == null) {
|
|
throw new RuntimeError('You must compile a model before training/testing. Use ' +
|
|
'LayersModel.compile(modelCompileArgs).');
|
|
}
|
|
const outputShapes = [];
|
|
for (let i = 0; i < this.feedOutputShapes.length; ++i) {
|
|
const outputShape = this.feedOutputShapes[i];
|
|
const lossFn = this.feedLossFns[i];
|
|
if (lossFn === sparseCategoricalCrossentropy$1) {
|
|
outputShapes.push(outputShape.slice(0, outputShape.length - 1).concat([1]));
|
|
}
|
|
else {
|
|
|
|
outputShapes.push(outputShape);
|
|
}
|
|
}
|
|
x = standardizeInputData(x, this.feedInputNames, this.feedInputShapes, false, 'input');
|
|
y = standardizeInputData(y, this.feedOutputNames, outputShapes, false, 'target');
|
|
|
|
checkArrayLengths(x, y);
|
|
|
|
checkLossAndTargetCompatibility(y, this.feedLossFns, this.feedOutputShapes);
|
|
if (this.stateful && batchSize != null && batchSize > 0) {
|
|
if (x[0].shape[0] % batchSize !== 0) {
|
|
throw new ValueError(`In a stateful network, you should only pass inputs with a ` +
|
|
`number of samples that is divisible by the batch size ` +
|
|
`${batchSize}. Found: ${x[0].shape[0]} sample(s).`);
|
|
}
|
|
}
|
|
return [x, y];
|
|
}
|
|
async standardizeUserData(x, y, sampleWeight, classWeight, checkBatchAxis = true, batchSize) {
|
|
const [standardXs, standardYs] = this.standardizeUserDataXY(x, y, checkBatchAxis, batchSize);
|
|
|
|
if (sampleWeight != null) {
|
|
throw new Error('sample weight is not supported yet.');
|
|
}
|
|
let standardSampleWeights = null;
|
|
if (classWeight != null) {
|
|
const classWeights = standardizeClassWeights(classWeight, this.outputNames);
|
|
standardSampleWeights = [];
|
|
for (let i = 0; i < classWeights.length; ++i) {
|
|
standardSampleWeights.push(await standardizeWeights(standardYs[i], null, classWeights[i]));
|
|
}
|
|
}
|
|
|
|
return [standardXs, standardYs, standardSampleWeights];
|
|
}
|
|
|
|
testLoop(f, ins, batchSize, verbose = 0, steps) {
|
|
return tidy(() => {
|
|
const numSamples = this.checkNumSamples(ins, batchSize, steps, 'steps');
|
|
const outs = [];
|
|
if (verbose > 0) {
|
|
throw new NotImplementedError('Verbose mode is not implemented yet.');
|
|
}
|
|
|
|
if (steps != null) {
|
|
throw new NotImplementedError('steps mode in testLoop() is not implemented yet');
|
|
}
|
|
else {
|
|
const batches = makeBatches(numSamples, batchSize);
|
|
const indexArray = tensor1d(range(0, numSamples));
|
|
for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
|
|
const batchStart = batches[batchIndex][0];
|
|
const batchEnd = batches[batchIndex][1];
|
|
const batchIds = sliceAlongFirstAxis(indexArray, batchStart, batchEnd - batchStart);
|
|
|
|
|
|
const insBatch = sliceArraysByIndices(ins, batchIds);
|
|
const batchOuts = f(insBatch);
|
|
if (batchIndex === 0) {
|
|
for (let i = 0; i < batchOuts.length; ++i) {
|
|
outs.push(scalar(0));
|
|
}
|
|
}
|
|
for (let i = 0; i < batchOuts.length; ++i) {
|
|
const batchOut = batchOuts[i];
|
|
outs[i] =
|
|
add$1(outs[i], mul(batchEnd - batchStart, batchOut));
|
|
}
|
|
}
|
|
for (let i = 0; i < outs.length; ++i) {
|
|
outs[i] = div$1(outs[i], numSamples);
|
|
}
|
|
}
|
|
return outs;
|
|
});
|
|
}
|
|
getDedupedMetricsNames() {
|
|
const outLabels = this.metricsNames;
|
|
|
|
|
|
const dedupedOutLabels = [];
|
|
for (let i = 0; i < outLabels.length; ++i) {
|
|
const label = outLabels[i];
|
|
let newLabel = label;
|
|
if (count(outLabels, label) > 1) {
|
|
const dupIndex = count(outLabels.slice(0, i), label);
|
|
newLabel += `_${dupIndex}`;
|
|
}
|
|
dedupedOutLabels.push(newLabel);
|
|
}
|
|
return dedupedOutLabels;
|
|
}
|
|
|
|
makeTrainFunction() {
|
|
return (data) => {
|
|
const lossValues = [];
|
|
const inputs = data.slice(0, this.inputs.length);
|
|
const targets = data.slice(this.inputs.length, this.inputs.length + this.outputs.length);
|
|
const sampleWeights = data.slice(this.inputs.length + this.outputs.length, this.inputs.length + this.outputs.length * 2);
|
|
const metricsValues = [];
|
|
|
|
|
|
|
|
const totalLossFunction = () => {
|
|
const feeds = [];
|
|
for (let i = 0; i < this.inputs.length; ++i) {
|
|
feeds.push({ key: this.inputs[i], value: inputs[i] });
|
|
}
|
|
const feedDict = new FeedDict(feeds);
|
|
const outputs = execute(this.outputs, feedDict, { 'training': true });
|
|
|
|
|
|
let totalLoss;
|
|
for (let i = 0; i < this.lossFunctions.length; ++i) {
|
|
const lossFunction = this.lossFunctions[i];
|
|
let loss = lossFunction(targets[i], outputs[i]);
|
|
if (sampleWeights[i] != null) {
|
|
loss = computeWeightedLoss(loss, sampleWeights[i]);
|
|
}
|
|
|
|
const meanLoss = mean$1(loss);
|
|
|
|
lossValues.push(meanLoss);
|
|
if (i === 0) {
|
|
totalLoss = loss;
|
|
}
|
|
else {
|
|
totalLoss = add$1(totalLoss, loss);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
for (let i = 0; i < this.metricsTensors.length; ++i) {
|
|
let weightedMetric;
|
|
if (this.outputs.length > 1 && i < this.outputs.length) {
|
|
weightedMetric = lossValues[i];
|
|
}
|
|
else {
|
|
const metric = this.metricsTensors[i][0];
|
|
const outputIndex = this.metricsTensors[i][1];
|
|
weightedMetric =
|
|
mean$1(metric(targets[outputIndex], outputs[outputIndex]));
|
|
}
|
|
keep(weightedMetric);
|
|
|
|
metricsValues.push(weightedMetric);
|
|
}
|
|
totalLoss = mean$1(totalLoss);
|
|
|
|
this.calculateLosses().forEach(regularizerLoss => {
|
|
totalLoss = add$1(totalLoss, regularizerLoss);
|
|
});
|
|
return totalLoss;
|
|
};
|
|
const variables = this.collectedTrainableWeights.map(param => param.read());
|
|
const returnCost = true;
|
|
const totalLossValue = this.optimizer_.minimize(totalLossFunction, returnCost, variables);
|
|
return [totalLossValue].concat(metricsValues);
|
|
};
|
|
}
|
|
|
|
makeTestFunction() {
|
|
this.testFunction = (data) => {
|
|
return tidy(() => {
|
|
const valOutputs = [];
|
|
let totalLoss;
|
|
const inputs = data.slice(0, this.inputs.length);
|
|
const targets = data.slice(this.inputs.length, this.inputs.length + this.outputs.length);
|
|
const feeds = [];
|
|
for (let i = 0; i < this.inputs.length; ++i) {
|
|
feeds.push({ key: this.inputs[i], value: inputs[i] });
|
|
}
|
|
const feedDict = new FeedDict(feeds);
|
|
const outputs = execute(this.outputs, feedDict);
|
|
|
|
for (let i = 0; i < this.lossFunctions.length; ++i) {
|
|
const lossFunction = this.lossFunctions[i];
|
|
|
|
|
|
const loss = mean$1(lossFunction(targets[i], outputs[i]));
|
|
if (i === 0) {
|
|
totalLoss = loss;
|
|
}
|
|
else {
|
|
totalLoss = add$1(totalLoss, loss);
|
|
}
|
|
valOutputs.push(totalLoss);
|
|
}
|
|
|
|
for (let i = 0; i < this.metricsTensors.length; ++i) {
|
|
const metric = this.metricsTensors[i][0];
|
|
const outputIndex = this.metricsTensors[i][1];
|
|
|
|
const meanMetric = mean$1(metric(targets[outputIndex], outputs[outputIndex]));
|
|
valOutputs.push(meanMetric);
|
|
}
|
|
return valOutputs;
|
|
});
|
|
};
|
|
}
|
|
|
|
async fit(x, y, args = {}) {
|
|
if (this.isTraining) {
|
|
throw new Error('Cannot start training because another fit() call is ongoing.');
|
|
}
|
|
this.isTraining = true;
|
|
let inputs;
|
|
let targets;
|
|
let originalInputs;
|
|
let originalTargets;
|
|
let inputValX;
|
|
let inputValY;
|
|
let valX;
|
|
let valY;
|
|
let sampleWeights;
|
|
try {
|
|
const batchSize = args.batchSize == null ? 32 : args.batchSize;
|
|
checkBatchSize(batchSize);
|
|
|
|
|
|
const checkBatchAxis = false;
|
|
const standardizedOuts = await this.standardizeUserData(x, y, args.sampleWeight, args.classWeight, checkBatchAxis, batchSize);
|
|
inputs = standardizedOuts[0];
|
|
targets = standardizedOuts[1];
|
|
sampleWeights = standardizedOuts[2];
|
|
|
|
let doValidation = false;
|
|
let valIns;
|
|
if (args.validationData != null && args.validationData.length > 0) {
|
|
doValidation = true;
|
|
if (args.validationData.length === 2) {
|
|
|
|
inputValX = args.validationData[0];
|
|
inputValY = args.validationData[1];
|
|
}
|
|
else if (args.validationData.length === 3) {
|
|
throw new NotImplementedError('validationData including sample weights is not supported yet.');
|
|
}
|
|
else {
|
|
throw new ValueError(`When passing validation data, it must contain 2 (valX, valY) ` +
|
|
`or 3 (valX, valY, valSampleWeight) items; ` +
|
|
`${args.validationData} is invalid.`);
|
|
}
|
|
const checkBatchAxis = true;
|
|
const valStandardized = await this.standardizeUserData(inputValX, inputValY, null, null, checkBatchAxis, batchSize);
|
|
valX = valStandardized[0];
|
|
valY = valStandardized[1];
|
|
valIns = valX.concat(valY);
|
|
|
|
}
|
|
else if (args.validationSplit != null && args.validationSplit > 0 &&
|
|
args.validationSplit < 1) {
|
|
doValidation = true;
|
|
|
|
const splitAt = Math.floor(inputs[0].shape[0] * (1 - args.validationSplit));
|
|
const originalBatchSize = inputs[0].shape[0];
|
|
valX = sliceArrays(inputs, splitAt, originalBatchSize);
|
|
originalInputs = inputs;
|
|
inputs = sliceArrays(inputs, 0, splitAt);
|
|
valY = sliceArrays(targets, splitAt, originalBatchSize);
|
|
originalTargets = targets;
|
|
targets = sliceArrays(targets, 0, splitAt);
|
|
|
|
|
|
valIns = valX.concat(valY);
|
|
|
|
}
|
|
else if (args.validationSteps != null) {
|
|
doValidation = true;
|
|
|
|
}
|
|
const ins = inputs.concat(targets).concat(sampleWeights);
|
|
this.checkTrainableWeightsConsistency();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const trainFunction = this.makeTrainFunction();
|
|
const outLabels = this.getDedupedMetricsNames();
|
|
let valFunction;
|
|
let callbackMetrics;
|
|
if (doValidation) {
|
|
this.makeTestFunction();
|
|
valFunction = this.testFunction;
|
|
callbackMetrics =
|
|
outLabels.slice().concat(outLabels.map(n => 'val_' + n));
|
|
}
|
|
else {
|
|
valFunction = null;
|
|
valIns = [];
|
|
callbackMetrics = outLabels.slice();
|
|
}
|
|
const callbacks = standardizeCallbacks(args.callbacks, args.yieldEvery);
|
|
const out = await this.fitLoop(trainFunction, ins, outLabels, batchSize, args.epochs, args.verbose, callbacks, valFunction, valIns, args.shuffle, callbackMetrics, args.initialEpoch, null, null);
|
|
return out;
|
|
}
|
|
finally {
|
|
this.isTraining = false;
|
|
|
|
disposeNewTensors(inputs, x);
|
|
disposeNewTensors(targets, y);
|
|
disposeNewTensors(originalInputs, x);
|
|
disposeNewTensors(originalTargets, y);
|
|
disposeNewTensors(valX, inputValX);
|
|
disposeNewTensors(valY, inputValY);
|
|
if (sampleWeights != null) {
|
|
dispose(sampleWeights);
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
async fitLoop(f, ins, outLabels, batchSize, epochs, verbose, callbacks, valF, valIns, shuffle$1, callbackMetrics, initialEpoch, stepsPerEpoch, validationSteps) {
|
|
if (batchSize == null) {
|
|
batchSize = 32;
|
|
}
|
|
if (epochs == null) {
|
|
epochs = 1;
|
|
}
|
|
if (shuffle$1 == null) {
|
|
shuffle$1 = true;
|
|
}
|
|
if (initialEpoch == null) {
|
|
initialEpoch = 0;
|
|
}
|
|
|
|
let doValidation = false;
|
|
if (valF != null && valIns != null) {
|
|
doValidation = true;
|
|
|
|
}
|
|
if (validationSteps != null) {
|
|
doValidation = true;
|
|
if (stepsPerEpoch == null) {
|
|
throw new ValueError('Can only use `validationSteps` when doing step-wise training, ' +
|
|
'i.e., `stepsPerEpoch` must be set.');
|
|
}
|
|
}
|
|
const numTrainSamples = this.checkNumSamples(ins, batchSize, stepsPerEpoch, 'steps_per_epoch');
|
|
let indexArray;
|
|
if (numTrainSamples != null) {
|
|
indexArray = range(0, numTrainSamples);
|
|
}
|
|
if (verbose == null) {
|
|
verbose = 1;
|
|
}
|
|
const { callbackList, history } = configureCallbacks(callbacks, verbose, epochs, initialEpoch, numTrainSamples, stepsPerEpoch, batchSize, doValidation, callbackMetrics);
|
|
callbackList.setModel(this);
|
|
this.history = history;
|
|
await callbackList.onTrainBegin();
|
|
this.stopTraining_ = false;
|
|
|
|
|
|
for (let epoch = initialEpoch; epoch < epochs; ++epoch) {
|
|
await callbackList.onEpochBegin(epoch);
|
|
const epochLogs = {};
|
|
if (stepsPerEpoch != null) {
|
|
throw new NotImplementedError('stepsPerEpoch mode is not implemented yet.');
|
|
}
|
|
else {
|
|
if (shuffle$1 === 'batch') {
|
|
throw new NotImplementedError('batch shuffling is not implemneted'
|
|
+ ' yet');
|
|
}
|
|
else if (shuffle$1) {
|
|
shuffle(indexArray);
|
|
}
|
|
|
|
|
|
const epochIndexArray1D = tensor1d(indexArray);
|
|
const batches = makeBatches(numTrainSamples, batchSize);
|
|
for (let batchIndex = 0; batchIndex < batches.length; ++batchIndex) {
|
|
const batchLogs = {};
|
|
await callbackList.onBatchBegin(batchIndex, batchLogs);
|
|
tidy(() => {
|
|
const batchStart = batches[batchIndex][0];
|
|
const batchEnd = batches[batchIndex][1];
|
|
const batchIds = sliceAlongFirstAxis(epochIndexArray1D, batchStart, batchEnd - batchStart);
|
|
batchLogs['batch'] = batchIndex;
|
|
batchLogs['size'] = batchEnd - batchStart;
|
|
|
|
|
|
const insBatch = sliceArraysByIndices(ins, batchIds);
|
|
const outs = f(insBatch);
|
|
for (let i = 0; i < outLabels.length; ++i) {
|
|
const label = outLabels[i];
|
|
const out = outs[i];
|
|
batchLogs[label] = out;
|
|
keep(out);
|
|
|
|
}
|
|
if (batchIndex === batches.length - 1) {
|
|
if (doValidation) {
|
|
const valOuts = this.testLoop(valF, valIns, batchSize);
|
|
|
|
for (let i = 0; i < outLabels.length; ++i) {
|
|
const label = outLabels[i];
|
|
const out = valOuts[i];
|
|
keep(out);
|
|
|
|
epochLogs['val_' + label] = out;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
await callbackList.onBatchEnd(batchIndex, batchLogs);
|
|
disposeTensorsInLogs(batchLogs);
|
|
if (this.stopTraining_) {
|
|
break;
|
|
}
|
|
|
|
}
|
|
epochIndexArray1D.dispose();
|
|
}
|
|
|
|
await callbackList.onEpochEnd(epoch, epochLogs);
|
|
if (this.stopTraining_) {
|
|
break;
|
|
}
|
|
}
|
|
await callbackList.onTrainEnd();
|
|
await this.history.syncData();
|
|
return this.history;
|
|
}
|
|
|
|
|
|
|
|
async fitDataset(dataset, args) {
|
|
return fitDataset(this, dataset, args);
|
|
}
|
|
|
|
async trainOnBatch(x, y) {
|
|
|
|
|
|
const standardizeOut = await this.standardizeUserData(x, y);
|
|
const inputs = standardizeOut[0];
|
|
const targets = standardizeOut[1];
|
|
const trainFunction = this.makeTrainFunction();
|
|
const losses = trainFunction(inputs.concat(targets));
|
|
const lossValues = [];
|
|
for (const loss of losses) {
|
|
const v = await loss.data();
|
|
lossValues.push(v[0]);
|
|
}
|
|
dispose(losses);
|
|
disposeNewTensors(standardizeOut[0], x);
|
|
disposeNewTensors(standardizeOut[1], y);
|
|
return singletonOrArray(lossValues);
|
|
}
|
|
|
|
getNamedWeights(config) {
|
|
const namedWeights = [];
|
|
const trainableOnly = config != null && config.trainableOnly;
|
|
const weights = trainableOnly ? this.trainableWeights : this.weights;
|
|
const weightValues = this.getWeights(trainableOnly);
|
|
for (let i = 0; i < weights.length; ++i) {
|
|
if (trainableOnly && !weights[i].trainable) {
|
|
|
|
continue;
|
|
}
|
|
namedWeights.push({ name: weights[i].originalName, tensor: weightValues[i] });
|
|
}
|
|
return namedWeights;
|
|
}
|
|
|
|
set stopTraining(stop) {
|
|
this.stopTraining_ = stop;
|
|
}
|
|
get stopTraining() {
|
|
return this.stopTraining_;
|
|
}
|
|
get optimizer() {
|
|
return this.optimizer_;
|
|
}
|
|
set optimizer(optimizer) {
|
|
if (this.optimizer_ !== optimizer) {
|
|
this.optimizer_ = optimizer;
|
|
this.isOptimizerOwned = false;
|
|
}
|
|
}
|
|
dispose() {
|
|
const result = super.dispose();
|
|
if (result.refCountAfterDispose === 0 && this.optimizer != null &&
|
|
this.isOptimizerOwned) {
|
|
const numTensorsBeforeOptmizerDisposal = memory().numTensors;
|
|
this.optimizer_.dispose();
|
|
result.numDisposedVariables +=
|
|
numTensorsBeforeOptmizerDisposal - memory().numTensors;
|
|
}
|
|
return result;
|
|
}
|
|
getLossIdentifiers() {
|
|
let lossNames;
|
|
if (typeof this.loss === 'string') {
|
|
lossNames = toSnakeCase(this.loss);
|
|
}
|
|
else if (Array.isArray(this.loss)) {
|
|
for (const loss of this.loss) {
|
|
if (typeof loss !== 'string') {
|
|
throw new Error('Serialization of non-string loss is not supported.');
|
|
}
|
|
}
|
|
lossNames = this.loss.map(name => toSnakeCase(name));
|
|
}
|
|
else {
|
|
const outputNames = Object.keys(this.loss);
|
|
lossNames = {};
|
|
const losses = this.loss;
|
|
for (const outputName of outputNames) {
|
|
if (typeof losses[outputName] === 'string') {
|
|
lossNames[outputName] =
|
|
toSnakeCase(losses[outputName]);
|
|
}
|
|
else {
|
|
throw new Error('Serialization of non-string loss is not supported.');
|
|
}
|
|
}
|
|
}
|
|
return lossNames;
|
|
}
|
|
getMetricIdentifiers() {
|
|
if (typeof this.metrics === 'string' ||
|
|
typeof this.metrics === 'function') {
|
|
return [toSnakeCase(getLossOrMetricName(this.metrics))];
|
|
}
|
|
else if (Array.isArray(this.metrics)) {
|
|
return this.metrics.map(metric => toSnakeCase(getLossOrMetricName(metric)));
|
|
}
|
|
else {
|
|
const metricsIdentifiers = {};
|
|
for (const key in this.metrics) {
|
|
metricsIdentifiers[key] =
|
|
toSnakeCase(getLossOrMetricName(this.metrics[key]));
|
|
}
|
|
return metricsIdentifiers;
|
|
}
|
|
}
|
|
getTrainingConfig() {
|
|
return {
|
|
loss: this.getLossIdentifiers(),
|
|
metrics: this.getMetricIdentifiers(),
|
|
optimizer_config: {
|
|
class_name: this.optimizer.getClassName(),
|
|
config: this.optimizer.getConfig()
|
|
}
|
|
};
|
|
|
|
|
|
|
|
}
|
|
loadTrainingConfig(trainingConfig) {
|
|
if (trainingConfig.weighted_metrics != null) {
|
|
throw new Error('Loading weight_metrics is not supported yet.');
|
|
}
|
|
if (trainingConfig.loss_weights != null) {
|
|
throw new Error('Loading loss_weights is not supported yet.');
|
|
}
|
|
if (trainingConfig.sample_weight_mode != null) {
|
|
throw new Error('Loading sample_weight_mode is not supported yet.');
|
|
}
|
|
const tsConfig = convertPythonicToTs(trainingConfig.optimizer_config);
|
|
const optimizer = deserialize(tsConfig);
|
|
let loss;
|
|
if (typeof trainingConfig.loss === 'string') {
|
|
loss = toCamelCase(trainingConfig.loss);
|
|
}
|
|
else if (Array.isArray(trainingConfig.loss)) {
|
|
loss = trainingConfig.loss.map(lossEntry => toCamelCase(lossEntry));
|
|
}
|
|
else if (trainingConfig.loss != null) {
|
|
loss = {};
|
|
for (const key in trainingConfig.loss) {
|
|
loss[key] = toCamelCase(trainingConfig.loss[key]);
|
|
}
|
|
}
|
|
let metrics;
|
|
if (Array.isArray(trainingConfig.metrics)) {
|
|
metrics = trainingConfig.metrics.map(metric => toCamelCase(metric));
|
|
}
|
|
else if (trainingConfig.metrics != null) {
|
|
metrics = {};
|
|
for (const key in trainingConfig.metrics) {
|
|
metrics[key] = toCamelCase(trainingConfig.metrics[key]);
|
|
}
|
|
}
|
|
this.compile({ loss, metrics, optimizer });
|
|
}
|
|
|
|
async save(handlerOrURL, config) {
|
|
if (typeof handlerOrURL === 'string') {
|
|
const handlers = getSaveHandlers(handlerOrURL);
|
|
if (handlers.length === 0) {
|
|
throw new ValueError(`Cannot find any save handlers for URL '${handlerOrURL}'`);
|
|
}
|
|
else if (handlers.length > 1) {
|
|
throw new ValueError(`Found more than one (${handlers.length}) save handlers for ` +
|
|
`URL '${handlerOrURL}'`);
|
|
}
|
|
handlerOrURL = handlers[0];
|
|
}
|
|
if (handlerOrURL.save == null) {
|
|
throw new ValueError('LayersModel.save() cannot proceed because the IOHandler ' +
|
|
'provided does not have the `save` attribute defined.');
|
|
}
|
|
const weightDataAndSpecs = await encodeWeights(this.getNamedWeights(config));
|
|
const returnString = false;
|
|
const unusedArg = null;
|
|
const modelConfig = this.toJSON(unusedArg, returnString);
|
|
const modelArtifacts = {
|
|
modelTopology: modelConfig,
|
|
format: LAYERS_MODEL_FORMAT_NAME,
|
|
generatedBy: `TensorFlow.js tfjs-layers v${version}`,
|
|
convertedBy: null,
|
|
};
|
|
const includeOptimizer = config == null ? false : config.includeOptimizer;
|
|
if (includeOptimizer && this.optimizer != null) {
|
|
modelArtifacts.trainingConfig = this.getTrainingConfig();
|
|
const weightType = 'optimizer';
|
|
const { data: optimizerWeightData, specs: optimizerWeightSpecs } = await encodeWeights(await this.optimizer.getWeights(), weightType);
|
|
weightDataAndSpecs.specs.push(...optimizerWeightSpecs);
|
|
weightDataAndSpecs.data = concatenateArrayBuffers([weightDataAndSpecs.data, optimizerWeightData]);
|
|
}
|
|
if (this.userDefinedMetadata != null) {
|
|
|
|
const checkSize = true;
|
|
checkUserDefinedMetadata(this.userDefinedMetadata, this.name, checkSize);
|
|
modelArtifacts.userDefinedMetadata = this.userDefinedMetadata;
|
|
}
|
|
modelArtifacts.weightData = weightDataAndSpecs.data;
|
|
modelArtifacts.weightSpecs = weightDataAndSpecs.specs;
|
|
return handlerOrURL.save(modelArtifacts);
|
|
}
|
|
|
|
setUserDefinedMetadata(userDefinedMetadata) {
|
|
checkUserDefinedMetadata(userDefinedMetadata, this.name);
|
|
this.userDefinedMetadata = userDefinedMetadata;
|
|
}
|
|
|
|
getUserDefinedMetadata() {
|
|
return this.userDefinedMetadata;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
LayersModel.className = 'Model';
|
|
registerClass(LayersModel);
|
|
|
|
|
|
class Functional extends LayersModel {
|
|
}
|
|
Functional.className = 'Functional';
|
|
registerClass(Functional);
|
|
|
|
|
|
|
|
|
|
async function loadLayersModelFromIOHandler(handler, customObjects, options) {
|
|
if (options == null) {
|
|
options = {};
|
|
}
|
|
if (handler.load == null) {
|
|
throw new ValueError('Cannot proceed with model loading because the IOHandler provided ' +
|
|
'does not have the `load` method implemented.');
|
|
}
|
|
const artifacts = await handler.load();
|
|
let modelTopology = artifacts.modelTopology;
|
|
if (modelTopology['model_config'] != null) {
|
|
modelTopology = modelTopology['model_config'];
|
|
}
|
|
const strict = options.strict == null ? true : options.strict;
|
|
|
|
|
|
|
|
|
|
|
|
const fastWeightInit = artifacts.weightData != null && artifacts.weightSpecs != null && strict;
|
|
const model = deserialize(convertPythonicToTs(modelTopology), customObjects, fastWeightInit);
|
|
const trainingConfig = artifacts.trainingConfig;
|
|
if (trainingConfig != null) {
|
|
model.loadTrainingConfig(trainingConfig);
|
|
}
|
|
if (artifacts.userDefinedMetadata != null) {
|
|
model.setUserDefinedMetadata(artifacts.userDefinedMetadata);
|
|
}
|
|
|
|
if (artifacts.weightData != null) {
|
|
|
|
if (artifacts.weightSpecs == null) {
|
|
throw new ValueError('LayersModel artifacts contains weight data, but not weight specs. ' +
|
|
'Therefore loading of weights cannot proceed.');
|
|
}
|
|
const { modelWeights, optimizerWeights } = decodeModelAndOptimizerWeights(artifacts.weightData, artifacts.weightSpecs);
|
|
model.loadWeights(modelWeights, strict);
|
|
if (model.optimizer != null && optimizerWeights.length > 0) {
|
|
await model.optimizer.setWeights(optimizerWeights);
|
|
}
|
|
|
|
dispose(modelWeights);
|
|
dispose(optimizerWeights.map(w => w.tensor));
|
|
}
|
|
return model;
|
|
}
|
|
function decodeModelAndOptimizerWeights(weightData, specs) {
|
|
const name2Tensor = decodeWeights(weightData, specs);
|
|
const modelWeights = {};
|
|
const optimizerWeights = [];
|
|
specs.forEach(spec => {
|
|
if (spec.group === 'optimizer') {
|
|
optimizerWeights.push({ name: spec.name, tensor: name2Tensor[spec.name] });
|
|
}
|
|
else {
|
|
modelWeights[spec.name] = name2Tensor[spec.name];
|
|
}
|
|
});
|
|
return { modelWeights, optimizerWeights };
|
|
}

class Sequential extends LayersModel {
    constructor(args) {
        super({ inputs: [], outputs: [] });
        args = args || {};
        this.trainable = true;
        this.built = false;
        this.name = (args.name != null) ? args.name : getUid('sequential_');
        if (args.layers != null) {
            for (const layer of args.layers) {
                this.add(layer);
            }
        }
    }
    checkShape(layer) {
        const shape = layer.inboundNodes[0].outputTensors[0].shape;
        if (shape.some(x => x < 0)) {
            throw new ValueError('Negative dimension size caused by adding layer ' +
                `${layer.name} with input shape [` +
                `${layer.inboundNodes[0].inputTensors[0].shape}]`);
        }
    }
    add(layer) {
        const isLayerModelInstance = layer instanceof Sequential || layer instanceof LayersModel;
        let modelLayer;
        if (isLayerModelInstance) {
            modelLayer = layer;
            if (modelLayer.outputs.length !== 1) {
                throw new ValueError('All layers in a Sequential model ' +
                    'should have a single output tensor. ' +
                    'For multi-output layers, ' +
                    'use the functional API.');
            }
            if (modelLayer.inputs.length !== 1) {
                throw new ValueError('All layers in a Sequential model ' +
                    'should have a single input tensor. ' +
                    'For multi-input layers, ' +
                    'use the functional API.');
            }
        }
        if (this.outputs.length === 0) {
            if (layer.inboundNodes.length === 0) {
                if (layer.batchInputShape == null) {
                    throw new ValueError('The first layer in a Sequential model must ' +
                        'get an `inputShape` or `batchInputShape` argument.');
                }
                const x = Input({
                    batchShape: layer.batchInputShape,
                    dtype: layer.dtype,
                    name: layer.name + '_input'
                });
                layer.apply(x);
            }
            if (isLayerModelInstance) {
                this.outputs = modelLayer.outputs;
                this.inputs = modelLayer.inputs;
            }
            else {
                if (layer.inboundNodes.length !== 1) {
                    throw new ValueError('A layer added to a Sequential model must not already be ' +
                        `connected somewhere else. LayersModel received layer ${layer.name} ` +
                        `which has ${layer.inboundNodes.length} pre-existing inbound ` +
                        'connections.');
                }
                if (layer.inboundNodes[0].outputTensors.length !== 1) {
                    throw new ValueError('All layers in a Sequential model ' +
                        'should have a single output tensor. ' +
                        'For multi-output layers, ' +
                        'use the functional API.');
                }
                this.checkShape(layer);
                this.outputs = [layer.inboundNodes[0].outputTensors[0]];
                this.inputs = getSourceInputs(this.outputs[0]);
            }
            this.inboundNodes = [];
            new Node({
                outboundLayer: this,
                inboundLayers: [],
                nodeIndices: [],
                tensorIndices: [],
                inputTensors: this.inputs,
                outputTensors: this.outputs,
                inputMasks: pyListRepeat(null, this.inputs.length),
                outputMasks: [null],
                inputShapes: this.inputs.map(x => x.shape),
                outputShapes: this.outputs[0].shape
            });
        }
        else {
            const outputTensor = layer.apply(this.outputs[0]);
            if (Array.isArray(outputTensor)) {
                throw new TypeError('All layers in a Sequential model ' +
                    'should have a single output tensor. ' +
                    'For multi-output layers, ' +
                    'use the functional API.');
            }
            this.checkShape(layer);
            this.outputs = [outputTensor];
            this.inboundNodes[0].outputTensors = this.outputs;
            this.inboundNodes[0].outputShapes = [this.outputs[0].shape];
        }
        this.layers.push(layer);
        this.built = false;
    }
    pop() {
        if (this.layers.length === 0) {
            throw new TypeError('There are no layers in the model.');
        }
        this.layers.pop();
        if (this.layers.length === 0) {
            this.outputs = [];
            this.inboundNodes = [];
            this.outboundNodes = [];
        }
        else {
            const lastLayerIndex = this.layers.length - 1;
            this.layers[lastLayerIndex].outboundNodes = [];
            this.outputs = [this.layers[lastLayerIndex].output];
            this.inboundNodes[0].outputTensors = this.outputs;
            this.inboundNodes[0].outputShapes = [this.outputs[0].shape];
        }
    }
    call(inputs, kwargs) {
        if (this.model == null) {
            this.build();
        }
        return this.model.call(inputs, kwargs);
    }
    build(inputShape) {
        getExactlyOneShape(inputShape);
        if (this.inputs.length === 0 || this.outputs.length === 0) {
            throw new TypeError('Sequential model cannot be built: model is empty.' +
                ' Add some layers first.');
        }
        this.model = new LayersModel({
            inputs: this.inputs,
            outputs: this.outputs[0],
            name: this.name + '_model'
        });
        this.model.trainable = this.trainable;
        this.supportsMasking = this.model.supportsMasking;
        this.inputLayers = this.model.inputLayers;
        this.inputLayersNodeIndices = this.model.inputLayersNodeIndices;
        this.inputLayersTensorIndices = this.model.inputLayersTensorIndices;
        this.outputLayers = this.model.outputLayers;
        this.outputLayersNodeIndices = this.model.outputLayersNodeIndices;
        this.outputLayersTensorIndices = this.model.outputLayersTensorIndices;
        this.nodesByDepth = this.model.nodesByDepth;
        this.containerNodes = this.model.containerNodes;
        this.outputNames = this.model.outputNames;
        this.inputNames = this.model.inputNames;
        this.built = true;
    }
    countParams() {
        if (!this.built) {
            this.build();
        }
        return super.countParams();
    }
    summary(lineLength, positions, printFn = console.log) {
        if (!this.built) {
            this.build();
        }
        super.summary(lineLength, positions, printFn);
    }
    setWeights(weights) {
        if (this.model == null) {
            this.build();
        }
        this.model.setWeights(weights);
    }
    evaluate(x, y, args = {}) {
        if (!this.built) {
            throw new RuntimeError('The model needs to be compiled before being used.');
        }
        return this.model.evaluate(x, y, args);
    }
    async evaluateDataset(dataset, args) {
        if (!this.built) {
            throw new RuntimeError('The model needs to be compiled before being used.');
        }
        return this.model.evaluateDataset(dataset, args);
    }
    predict(x, args = {}) {
        if (this.model == null) {
            this.build();
        }
        return this.model.predict(x, args);
    }
    predictOnBatch(x) {
        if (this.model == null) {
            this.build();
        }
        return this.model.predictOnBatch(x);
    }
    compile(args) {
        this.build();
        this.model.compile(args);
        this.optimizer_ = this.model.optimizer;
        this.isOptimizerOwned = this.model.isOptimizerOwned;
        this.loss = this.model.loss;
        this.metrics = this.model.metrics;
        this.metricsTensors = this.model.metricsTensors;
        this.metricsNames = this.model.metricsNames;
    }
    get optimizer() {
        return this.model == null ? undefined : this.model.optimizer;
    }
    set optimizer(optimizer) {
        this.model.optimizer = optimizer;
    }
    async fit(x, y, args = {}) {
        if (!this.built) {
            throw new RuntimeError('The model needs to be compiled before ' +
                'being used.');
        }
        return this.model.fit(x, y, args);
    }
    async fitDataset(dataset, args) {
        if (!this.built) {
            throw new RuntimeError('The model needs to be compiled before ' +
                'being used.');
        }
        return this.model.fitDataset(dataset, args);
    }
    async trainOnBatch(x, y) {
        return this.model.trainOnBatch(x, y);
    }
    static fromConfig(cls, config, customObjects = {}, fastWeightInit = false) {
        let configArray;
        let extraModelConfig = {};
        if (config instanceof Array) {
            if (!(config[0].className != null) ||
                config[0]['className'] === 'Merge') {
                throw new ValueError('Legacy serialization format not supported yet.');
            }
            configArray = config;
        }
        else {
            assert$1(config['layers'] != null, () => `When the config data for a Sequential model is not an Array, ` +
                `it must be an Object that contains the 'layers' field.`);
            configArray = config['layers'];
            delete config['layers'];
            extraModelConfig = config;
        }
        const model = new cls(extraModelConfig);
        if (!(model instanceof Sequential)) {
            throw new NotImplementedError(`Sequential.fromConfig called on non-Sequential input: ${model}`);
        }
        for (const conf of configArray) {
            // Note: this shadows the `customObjects` parameter, so per-layer
            // deserialization below proceeds without custom objects.
            const customObjects = undefined;
            const layer = deserialize(conf, customObjects, fastWeightInit);
            if (fastWeightInit) {
                layer.setFastWeightInitDuringBuild(true);
            }
            model.add(layer);
        }
        return model;
    }
    set stopTraining(stop) {
        if (this.model == null) {
            throw new ValueError('Cannot set the stopTraining property of a sequential model before ' +
                'it is compiled.');
        }
        this.model.stopTraining = stop;
    }
    get stopTraining() {
        if (this.model == null) {
            throw new ValueError('Cannot get the stopTraining property of a sequential model before ' +
                'it is compiled.');
        }
        return this.model.stopTraining;
    }
    getConfig() {
        const layers = [];
        for (const layer of this.layers) {
            const dict = {};
            dict['className'] = layer.getClassName();
            dict['config'] = layer.getConfig();
            layers.push(dict);
        }
        return { name: this.name, layers };
    }
}

Sequential.className = 'Sequential';
registerClass(Sequential);

function sequential(config) {
    return new Sequential(config);
}
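// Usage sketch (illustrative only, not executed by this bundle; shapes and
// hyperparameters are arbitrary): stacking layers with the Sequential API
// defined above, then compiling and fitting.
//   const model = sequential();
//   model.add(dense({ units: 8, inputShape: [4], activation: 'relu' }));
//   model.add(dense({ units: 1 }));
//   model.compile({ optimizer: 'sgd', loss: 'meanSquaredError' });
//   await model.fit(tensor2d([[0, 1, 2, 3]]), tensor2d([[1]]), { epochs: 5 });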

let Activation$1 = class Activation extends Serializable {
    getConfig() {
        return {};
    }
};

class Elu extends Activation$1 {
    apply(x, alpha = 1) {
        return elu(x, alpha);
    }
}

Elu.className = 'elu';
registerClass(Elu);

class Selu extends Activation$1 {
    apply(x) {
        return selu$2(x);
    }
}

Selu.className = 'selu';
registerClass(Selu);

class Relu extends Activation$1 {
    apply(x) {
        return relu$2(x);
    }
}

Relu.className = 'relu';
registerClass(Relu);

class Relu6 extends Activation$1 {
    apply(x) {
        return tidy(() => minimum$2(6.0, relu$2(x)));
    }
}

Relu6.className = 'relu6';
registerClass(Relu6);

class Linear extends Activation$1 {
    apply(x) {
        return x;
    }
}

Linear.className = 'linear';
registerClass(Linear);

class Sigmoid extends Activation$1 {
    apply(x) {
        return sigmoid$2(x);
    }
}

Sigmoid.className = 'sigmoid';
registerClass(Sigmoid);

class HardSigmoid extends Activation$1 {
    apply(x) {
        return hardSigmoid(x);
    }
}

HardSigmoid.className = 'hardSigmoid';
registerClass(HardSigmoid);

class Softplus extends Activation$1 {
    apply(x) {
        return softplus$2(x);
    }
}

Softplus.className = 'softplus';
registerClass(Softplus);

class Softsign extends Activation$1 {
    apply(x) {
        return softsign(x);
    }
}

Softsign.className = 'softsign';
registerClass(Softsign);

class Tanh extends Activation$1 {
    apply(x) {
        return tanh$2(x);
    }
}

Tanh.className = 'tanh';
registerClass(Tanh);

class Softmax extends Activation$1 {
    apply(x, axis = (-1)) {
        return softmax$2(x, axis);
    }
}

Softmax.className = 'softmax';
registerClass(Softmax);

class LogSoftmax extends Activation$1 {
    apply(x, axis = (-1)) {
        return logSoftmax(x, axis);
    }
}

LogSoftmax.className = 'logSoftmax';
registerClass(LogSoftmax);

class Gelu extends Activation$1 {
    apply(x) {
        return tidy(() => {
            const sqrtTwo = Math.sqrt(2);
            const cdf = mul(0.5, add$1(1, erf$2(div$1(x, sqrtTwo))));
            return mul(x, cdf);
        });
    }
}

Gelu.className = 'gelu';
registerClass(Gelu);

class GeluNew extends Activation$1 {
    apply(x) {
        return tidy(() => {
            return mul(0.5, mul(x, add$1(1, tanh$2(mul(sqrt$2(div$1(2, Math.PI)), add$1(x, mul(0.044715, pow$2(x, 3))))))));
        });
    }
}

GeluNew.className = 'gelu_new';
registerClass(GeluNew);

class Mish extends Activation$1 {
    apply(x) {
        return tidy(() => mul(x, tanh$2(softplus$2(x))));
    }
}

Mish.className = 'mish';
registerClass(Mish);

class Swish extends Activation$1 {
    apply(x, alpha = 1) {
        return tidy(() => mul(sigmoid$2(mul(x, alpha)), x));
    }
}

Swish.className = 'swish';
registerClass(Swish);

function serializeActivation(activation) {
    return activation.getClassName();
}
function deserializeActivation(config, customObjects = {}) {
    return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'activation');
}
function getActivation(identifier) {
    if (identifier == null) {
        const config = {};
        config['className'] = 'linear';
        config['config'] = {};
        return deserializeActivation(config);
    }
    if (typeof identifier === 'string') {
        const config = {};
        config['className'] = identifier;
        config['config'] = {};
        return deserializeActivation(config);
    }
    else if (identifier instanceof Activation$1) {
        return identifier;
    }
    else {
        return deserializeActivation(identifier);
    }
}
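// Resolution sketch (illustrative only): `getActivation` accepts a registered
// class name, an Activation instance, or a serialized config dict; `null`
// falls back to the identity ('linear') activation.
//   getActivation('relu');                          // -> Relu instance
//   getActivation(null);                            // -> Linear instance
//   getActivation({ className: 'softmax', config: {} });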

function assertObjectArgs(args) {
    if (args != null && typeof args !== 'object') {
        throw new Error(`Argument to L1L2 regularizer's constructor is expected to be an ` +
            `object, but received: ${args}`);
    }
}

class Regularizer extends Serializable {
}
class L1L2 extends Regularizer {
    constructor(args) {
        super();
        assertObjectArgs(args);
        this.l1 = args == null || args.l1 == null ? 0.01 : args.l1;
        this.l2 = args == null || args.l2 == null ? 0.01 : args.l2;
        this.hasL1 = this.l1 !== 0;
        this.hasL2 = this.l2 !== 0;
    }
    apply(x) {
        return tidy(() => {
            let regularization = zeros$1([1]);
            if (this.hasL1) {
                regularization = add$1(regularization, sum$2(mul(this.l1, abs$2(x))));
            }
            if (this.hasL2) {
                regularization = add$1(regularization, sum$2(mul(this.l2, square(x))));
            }
            return reshape$2(regularization, []);
        });
    }
    getConfig() {
        return { 'l1': this.l1, 'l2': this.l2 };
    }
    static fromConfig(cls, config) {
        return new cls({ l1: config['l1'], l2: config['l2'] });
    }
}

L1L2.className = 'L1L2';
registerClass(L1L2);

const REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP = {
    'l1l2': 'L1L2'
};
function serializeRegularizer(constraint) {
    return serializeKerasObject(constraint);
}
function deserializeRegularizer(config, customObjects = {}) {
    return deserializeKerasObject(config, SerializationMap.getMap().classNameMap, customObjects, 'regularizer');
}
function getRegularizer(identifier) {
    if (identifier == null) {
        return null;
    }
    if (typeof identifier === 'string') {
        const className = identifier in REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP ?
            REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP[identifier] :
            identifier;
        const config = { className, config: {} };
        return deserializeRegularizer(config);
    }
    else if (identifier instanceof Regularizer) {
        return identifier;
    }
    else {
        return deserializeRegularizer(identifier);
    }
}
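// Sketch (illustrative only): the L1L2 penalty computed by `apply` above is
// l1 * sum(|x|) + l2 * sum(x^2), returned as a scalar tensor. The string
// alias 'l1l2' resolves through REGULARIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP,
// with both coefficients defaulting to 0.01:
//   getRegularizer('l1l2');             // L1L2 with l1 = l2 = 0.01
//   new L1L2({ l1: 0, l2: 1e-4 });      // pure L2 penalty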

class Dropout extends Layer {
    constructor(args) {
        super(args);
        this.rate = Math.max(Math.min(args.rate, 1), 0);
        this.noiseShape = args.noiseShape;
        this.seed = args.seed;
        this.supportsMasking = true;
    }
    getNoiseShape(input) {
        if (this.noiseShape == null) {
            return this.noiseShape;
        }
        const inputShape = input.shape;
        const noiseShape = [];
        // Entries left as null in `noiseShape` fall back to the corresponding
        // input dimension.
        for (let i = 0; i < this.noiseShape.length; ++i) {
            noiseShape.push(this.noiseShape[i] == null ? inputShape[i] : this.noiseShape[i]);
        }
        return noiseShape;
    }
    call(inputs, kwargs) {
        return tidy(() => {
            this.invokeCallHook(inputs, kwargs);
            const input = getExactlyOneTensor(inputs);
            if (0 < this.rate && this.rate < 1) {
                const training = kwargs['training'] == null ? false : kwargs['training'];
                const noiseShape = this.getNoiseShape(input);
                // Apply dropout only in the training phase; at inference the
                // input passes through unchanged.
                const output = inTrainPhase(() => dropout$1(input, this.rate, noiseShape, this.seed), () => input, training);
                return output;
            }
            return inputs;
        });
    }
    getConfig() {
        const config = {
            rate: this.rate,
            noiseShape: this.noiseShape,
            seed: this.seed,
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
    dispose() {
        return super.dispose();
    }
}

Dropout.className = 'Dropout';
registerClass(Dropout);
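// Noise-shape sketch (illustrative only): with an input of shape
// [batch, time, channels], a noiseShape of [null, 1, null] resolves via
// getNoiseShape to [batch, 1, channels], so the same dropout mask is shared
// across all timesteps.
//   const layer = dropout({ rate: 0.5, noiseShape: [null, 1, null] });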

class SpatialDropout1D extends Dropout {
    constructor(args) {
        super(args);
        this.inputSpec = [{ ndim: 3 }];
    }
    getNoiseShape(input) {
        const inputShape = input.shape;
        return [inputShape[0], 1, inputShape[2]];
    }
}

SpatialDropout1D.className = 'SpatialDropout1D';
registerClass(SpatialDropout1D);
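// Note (illustrative): SpatialDropout1D is Dropout with a fixed noise shape
// of [batch, 1, channels], i.e. entire feature channels of a 3D
// [batch, time, channels] input are dropped together rather than individual
// elements.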

class Dense extends Layer {
    constructor(args) {
        super(args);
        this.activation = null;
        this.useBias = true;
        this.kernel = null;
        this.bias = null;
        this.DEFAULT_KERNEL_INITIALIZER = 'glorotNormal';
        this.DEFAULT_BIAS_INITIALIZER = 'zeros';
        if (args.batchInputShape == null && args.inputShape == null &&
            args.inputDim != null) {
            let batchSize = null;
            if (args.batchSize != null) {
                batchSize = args.batchSize;
            }
            this.batchInputShape = [batchSize, args.inputDim];
        }
        this.units = args.units;
        assertPositiveInteger(this.units, 'units');
        this.activation = getActivation(args.activation);
        if (args.useBias != null) {
            this.useBias = args.useBias;
        }
        this.kernelInitializer = getInitializer(args.kernelInitializer || this.DEFAULT_KERNEL_INITIALIZER);
        this.biasInitializer = getInitializer(args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER);
        this.kernelConstraint = getConstraint(args.kernelConstraint);
        this.biasConstraint = getConstraint(args.biasConstraint);
        this.kernelRegularizer = getRegularizer(args.kernelRegularizer);
        this.biasRegularizer = getRegularizer(args.biasRegularizer);
        this.activityRegularizer = getRegularizer(args.activityRegularizer);
        this.supportsMasking = true;
        this.inputSpec = [{ minNDim: 2 }];
    }
    build(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        const inputLastDim = inputShape[inputShape.length - 1];
        if (this.kernel == null) {
            this.kernel = this.addWeight('kernel', [inputLastDim, this.units], null, this.kernelInitializer, this.kernelRegularizer, true, this.kernelConstraint);
            if (this.useBias) {
                this.bias = this.addWeight('bias', [this.units], null, this.biasInitializer, this.biasRegularizer, true, this.biasConstraint);
            }
        }
        this.inputSpec = [{ minNDim: 2, axes: { [-1]: inputLastDim } }];
        this.built = true;
    }
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        const outputShape = inputShape.slice();
        outputShape[outputShape.length - 1] = this.units;
        return outputShape;
    }
    call(inputs, kwargs) {
        return tidy(() => {
            this.invokeCallHook(inputs, kwargs);
            const input = getExactlyOneTensor(inputs);
            // When the activation maps to a fused kernel, fold matMul, bias
            // add, and activation into a single fused `dot` call.
            const fusedActivationName = mapActivationToFusedKernel(this.activation.getClassName());
            let output;
            if (fusedActivationName != null) {
                output = dot(input, this.kernel.read(), fusedActivationName, this.bias ? this.bias.read() : null);
            }
            else {
                output = dot(input, this.kernel.read());
                if (this.bias != null) {
                    output = biasAdd(output, this.bias.read());
                }
                if (this.activation != null) {
                    output = this.activation.apply(output);
                }
            }
            return output;
        });
    }
    getConfig() {
        const config = {
            units: this.units,
            activation: serializeActivation(this.activation),
            useBias: this.useBias,
            kernelInitializer: serializeInitializer(this.kernelInitializer),
            biasInitializer: serializeInitializer(this.biasInitializer),
            kernelRegularizer: serializeRegularizer(this.kernelRegularizer),
            biasRegularizer: serializeRegularizer(this.biasRegularizer),
            activityRegularizer: serializeRegularizer(this.activityRegularizer),
            kernelConstraint: serializeConstraint(this.kernelConstraint),
            biasConstraint: serializeConstraint(this.biasConstraint)
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}

Dense.className = 'Dense';
registerClass(Dense);
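// Computation sketch (illustrative only): Dense implements
// output = activation(dot(input, kernel) + bias), mapping the last input
// axis from `inputLastDim` to `units`.
//   const layer = dense({ units: 3, inputShape: [5], activation: 'relu',
//                         kernelRegularizer: 'l1l2' });
//   // input [batch, 5] -> output [batch, 3]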

class Flatten extends Layer {
    constructor(args) {
        args = args || {};
        super(args);
        this.inputSpec = [{ minNDim: 3 }];
        this.dataFormat = args.dataFormat;
    }
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        for (const dim of inputShape.slice(1)) {
            if (dim == null) {
                throw new ValueError(`The shape of the input to "Flatten" is not fully defined ` +
                    `(got ${inputShape.slice(1)}). Make sure to pass a complete ` +
                    `"input_shape" or "batch_input_shape" argument to the first ` +
                    `layer in your model.`);
            }
        }
        return [inputShape[0], arrayProd(inputShape, 1)];
    }
    call(inputs, kwargs) {
        return tidy(() => {
            this.invokeCallHook(inputs, kwargs);
            let input = getExactlyOneTensor(inputs);
            if (this.dataFormat === 'channelsFirst' && input.rank > 1) {
                const permutation = [0];
                for (let i = 2; i < input.rank; ++i) {
                    permutation.push(i);
                }
                permutation.push(1);
                input = transpose$2(input, permutation);
            }
            return batchFlatten(input);
        });
    }
    getConfig() {
        const config = {};
        if (this.dataFormat != null) {
            config['dataFormat'] = this.dataFormat;
        }
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}

Flatten.className = 'Flatten';
registerClass(Flatten);
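// Shape sketch (illustrative only): Flatten collapses all non-batch axes,
// e.g. [batch, 4, 4, 3] -> [batch, 48]. With dataFormat 'channelsFirst',
// the call above first transposes with permutation [0, 2, ..., rank-1, 1]
// so the channel axis ends up last before flattening.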

class Activation extends Layer {
    constructor(args) {
        super(args);
        this.supportsMasking = true;
        this.activation = getActivation(args.activation);
    }
    call(inputs, kwargs) {
        return tidy(() => {
            this.invokeCallHook(inputs, kwargs);
            const input = getExactlyOneTensor(inputs);
            return this.activation.apply(input);
        });
    }
    getConfig() {
        const config = { activation: serializeActivation(this.activation) };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}

Activation.className = 'Activation';
registerClass(Activation);

class RepeatVector extends Layer {
    constructor(args) {
        super(args);
        this.n = args.n;
        this.inputSpec = [{ ndim: 2 }];
    }
    computeOutputShape(inputShape) {
        return [inputShape[0], this.n, inputShape[1]];
    }
    call(inputs, kwargs) {
        return tidy(() => {
            inputs = getExactlyOneTensor(inputs);
            return repeat(inputs, this.n);
        });
    }
    getConfig() {
        const config = {
            n: this.n,
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}

RepeatVector.className = 'RepeatVector';
registerClass(RepeatVector);
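// Shape sketch (illustrative only): RepeatVector tiles a 2D input n times
// along a new time axis, [batch, features] -> [batch, n, features], which is
// a common way to feed a fixed vector to a recurrent decoder.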

class Reshape extends Layer {
    constructor(args) {
        super(args);
        this.targetShape = args.targetShape;
        for (let i = 0; i < this.targetShape.length; ++i) {
            if (this.isUnknown(this.targetShape[i])) {
                this.targetShape[i] = null;
            }
        }
    }
    isUnknown(dim) {
        return dim < 0 || dim == null;
    }
    fixUnknownDimension(inputShape, outputShape) {
        const errorMsg = 'Total size of new array must be unchanged.';
        const finalShape = outputShape.slice();
        let known = 1;
        let unknown = null;
        for (let i = 0; i < finalShape.length; ++i) {
            const dim = finalShape[i];
            if (this.isUnknown(dim)) {
                if (unknown === null) {
                    unknown = i;
                }
                else {
                    throw new ValueError('Can only specify one unknown dimension.');
                }
            }
            else {
                known *= dim;
            }
        }
        const originalSize = arrayProd(inputShape);
        if (unknown !== null) {
            if (known === 0 || originalSize % known !== 0) {
                throw new ValueError(errorMsg);
            }
            finalShape[unknown] = originalSize / known;
        }
        else if (originalSize !== known) {
            throw new ValueError(errorMsg);
        }
        return finalShape;
    }
    computeOutputShape(inputShape) {
        let anyUnknownDims = false;
        for (let i = 0; i < inputShape.length; ++i) {
            if (this.isUnknown(inputShape[i])) {
                anyUnknownDims = true;
                break;
            }
        }
        if (anyUnknownDims) {
            return inputShape.slice(0, 1).concat(this.targetShape);
        }
        else {
            return inputShape.slice(0, 1).concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape));
        }
    }
    call(inputs, kwargs) {
        return tidy(() => {
            this.invokeCallHook(inputs, kwargs);
            const input = getExactlyOneTensor(inputs);
            const inputShape = input.shape;
            const outputShape = inputShape.slice(0, 1).concat(this.fixUnknownDimension(inputShape.slice(1), this.targetShape));
            return reshape$2(input, outputShape);
        });
    }
    getConfig() {
        const config = {
            targetShape: this.targetShape,
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}

Reshape.className = 'Reshape';
registerClass(Reshape);
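// Inference sketch (illustrative only): fixUnknownDimension allows at most
// one null/-1 entry in the target shape and solves it so the element count
// is preserved, e.g. a non-batch input shape [6, 4] with targetShape
// [null, 8] resolves to [3, 8] (24 elements either way).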

class Permute extends Layer {
    constructor(args) {
        super(args);
        if (args.dims == null) {
            throw new Error('Required configuration field `dims` is missing during Permute ' +
                'constructor call.');
        }
        if (!Array.isArray(args.dims)) {
            throw new Error('Permute constructor requires `dims` to be an Array, but received ' +
                `${args.dims} instead.`);
        }
        const expectedSortedIndices = range(1, args.dims.length + 1);
        if (!arraysEqual(args.dims.slice().sort(), expectedSortedIndices)) {
            throw new Error('Invalid permutation `dims`: ' + JSON.stringify(args.dims) +
                ' `dims` must contain consecutive integers starting from 1.');
        }
        this.dims = args.dims;
        this.dimsIncludingBatch = [0].concat(this.dims);
        this.inputSpec = [new InputSpec({ ndim: this.dims.length + 1 })];
    }
    computeOutputShape(inputShape) {
        inputShape = getExactlyOneShape(inputShape);
        const outputShape = inputShape.slice();
        this.dims.forEach((dim, i) => {
            outputShape[i + 1] = inputShape[dim];
        });
        return outputShape;
    }
    call(inputs, kwargs) {
        return transpose$2(getExactlyOneTensor(inputs), this.dimsIncludingBatch);
    }
    getConfig() {
        const config = {
            dims: this.dims,
        };
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
    }
}

Permute.className = 'Permute';
registerClass(Permute);
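// Permutation sketch (illustrative only): `dims` is 1-based and excludes the
// batch axis, which always stays first via dimsIncludingBatch = [0, ...dims].
//   new Permute({ dims: [2, 1] });   // [batch, a, b] -> [batch, b, a]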

class Masking extends Layer {
    constructor(args) {
        super(args == null ? {} : args);
        this.supportsMasking = true;
        if (args != null) {
            this.maskValue = args.maskValue == null ? 0 : args.maskValue;
        }
        else {
            this.maskValue = 0;
        }
    }
    computeOutputShape(inputShape) {
        return inputShape;
    }
    getConfig() {
        const baseConfig = super.getConfig();
        const config = { maskValue: this.maskValue };
        Object.assign(config, baseConfig);
        return config;
    }
    computeMask(inputs, mask) {
        const input = getExactlyOneTensor(inputs);
        const axis = -1;
        return any$2(notEqual$2(input, this.maskValue), axis);
    }
    call(inputs, kwargs) {
        return tidy(() => {
            this.invokeCallHook(inputs, kwargs);
            const input = getExactlyOneTensor(inputs);
            const axis = -1;
            const keepDims = true;
            const booleanMask = any$2(notEqual$2(input, this.maskValue), axis, keepDims);
            const output = mul(input, cast$3(booleanMask, input.dtype));
            return output;
        });
    }
}

Masking.className = 'Masking';
registerClass(Masking);
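// Masking sketch (illustrative only): a timestep is kept iff any feature
// along the last axis differs from maskValue; call() zeroes out fully-masked
// timesteps and computeMask() exposes the boolean mask to downstream layers.
//   new Masking({ maskValue: 0 });   // skip all-zero timesteps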

function dense(args) {
    return new Dense(args);
}

function dropout(args) {
    return new Dropout(args);
}

export { LayersModel, PlatformStub, dense, dropout, enableProdMode, env, fromMemory, glorotUniform, loadLayersModelFromIOHandler, sequential, stringToHashBucketFast$2 as stringToHashBucketFast, tensor1d, tensor2d, withSaveHandler };