mirror of
https://github.com/tutao/tutanota.git
synced 2025-12-07 05:39:56 +00:00
instantiate and import spam classifier lazily
Co-authored-by: das <das@tutao.de>
This commit is contained in:
parent
858d7e1e71
commit
c33591eaca
10 changed files with 29 additions and 33 deletions
|
|
@ -33,7 +33,7 @@ export const allowedImports = {
|
|||
wasm: ["wasm-fallback"],
|
||||
"common-min": ["polyfill-helpers"],
|
||||
boot: ["polyfill-helpers", "common-min"],
|
||||
common: ["polyfill-helpers", "common-min", "spam-classifier"],
|
||||
common: ["polyfill-helpers", "common-min"],
|
||||
"gui-base": ["polyfill-helpers", "common-min", "common", "boot"],
|
||||
main: ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "date"],
|
||||
sanitizer: ["polyfill-helpers", "common-min", "common", "boot", "gui-base"],
|
||||
|
|
@ -47,7 +47,7 @@ export const allowedImports = {
|
|||
"calendar-view": ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "main", "date", "date-gui", "sharing", "contacts"],
|
||||
login: ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "main"],
|
||||
"spam-classifier": ["polyfill-helpers", "common", "common-min"],
|
||||
worker: ["polyfill-helpers", "common-min", "common", "native-common", "native-worker", "wasm", "wasm-fallback", "spam-classifier"],
|
||||
worker: ["polyfill-helpers", "common-min", "common", "native-common", "native-worker", "wasm", "wasm-fallback"],
|
||||
"pow-worker": [],
|
||||
settings: [
|
||||
"polyfill-helpers",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import { HashingVectorizer } from "../../../../../mail-app/workerUtils/spamClassification/HashingVectorizer"
|
||||
import type { HashingVectorizer } from "../../../../../mail-app/workerUtils/spamClassification/HashingVectorizer"
|
||||
import { htmlToText } from "../IndexUtils"
|
||||
import {
|
||||
ML_BITCOIN_REGEX,
|
||||
|
|
@ -19,11 +19,11 @@ import {
|
|||
ML_URL_TOKEN,
|
||||
} from "./PreprocessPatterns"
|
||||
import { SparseVectorCompressor } from "./SparseVectorCompressor"
|
||||
import { ProgrammingError } from "../../error/ProgrammingError"
|
||||
import { assertNotNull, tokenize } from "@tutao/tutanota-utils"
|
||||
import { assertNotNull, lazyAsync, lazyMemoized, tokenize } from "@tutao/tutanota-utils"
|
||||
import { Mail, MailAddress, MailDetails } from "../../../entities/tutanota/TypeRefs"
|
||||
import { getMailBodyText } from "../../CommonMailUtils"
|
||||
import { MailAuthenticationStatus } from "../../TutanotaConstants"
|
||||
import { ProgrammingError } from "../../error/ProgrammingError"
|
||||
|
||||
export type PreprocessConfiguration = {
|
||||
isPreprocessMails: boolean
|
||||
|
|
@ -69,15 +69,12 @@ export type PreprocessedMailContent = string
|
|||
export class SpamMailProcessor {
|
||||
constructor(
|
||||
private readonly preprocessConfiguration: PreprocessConfiguration = DEFAULT_PREPROCESS_CONFIGURATION,
|
||||
readonly vectorizer: HashingVectorizer = new HashingVectorizer(),
|
||||
private readonly sparseVectorCompressor: SparseVectorCompressor = new SparseVectorCompressor(),
|
||||
) {
|
||||
if (vectorizer.dimension !== sparseVectorCompressor.dimension) {
|
||||
throw new ProgrammingError(
|
||||
`a spam preprocessor was created with different dimensions. Vectorizer:${vectorizer.dimension} compressor: ${sparseVectorCompressor.dimension}`,
|
||||
)
|
||||
}
|
||||
}
|
||||
private readonly vectorizer: lazyAsync<HashingVectorizer> = lazyMemoized(async () => {
|
||||
const { HashingVectorizer } = await import("../../../../../mail-app/workerUtils/spamClassification/HashingVectorizer")
|
||||
return new HashingVectorizer(this.sparseVectorCompressor.dimension)
|
||||
}),
|
||||
) {}
|
||||
|
||||
public async vectorizeAndCompress(spamMailDatum: SpamMailDatum): Promise<Uint8Array> {
|
||||
const vector = await this.vectorize(spamMailDatum)
|
||||
|
|
@ -85,13 +82,13 @@ export class SpamMailProcessor {
|
|||
}
|
||||
|
||||
public async vectorize(spamMailDatum: SpamMailDatum): Promise<number[]> {
|
||||
const vectorizer = await this.vectorizer()
|
||||
const preprocessedMail = this.preprocessMail(spamMailDatum)
|
||||
const tokenizedMail = spamClassifierTokenizer(preprocessedMail)
|
||||
const vector = await this.vectorizer.vectorize(tokenizedMail)
|
||||
return vector
|
||||
return await vectorizer.vectorize(tokenizedMail)
|
||||
}
|
||||
|
||||
public async compress(uncompressedVector: number[]): Promise<Uint8Array> {
|
||||
private async compress(uncompressedVector: number[]): Promise<Uint8Array> {
|
||||
return this.sparseVectorCompressor.vectorToBinary(uncompressedVector)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ import { AttributeModel } from "../../common/AttributeModel"
|
|||
import { TypeModelResolver } from "../../common/EntityFunctions"
|
||||
import { collapseId, expandId } from "../rest/RestClientIdUtils"
|
||||
import { Category, syncMetrics } from "../utils/SyncMetrics"
|
||||
import { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
|
||||
import type { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
|
||||
|
||||
/**
|
||||
* this is the value of SQLITE_MAX_VARIABLE_NUMBER in sqlite3.c
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import { Nullable, TypeRef } from "@tutao/tutanota-utils"
|
|||
import { OfflineStorage, OfflineStorageInitArgs } from "../offline/OfflineStorage.js"
|
||||
import { EphemeralCacheStorage, EphemeralStorageInitArgs } from "./EphemeralCacheStorage"
|
||||
import { CustomCacheHandlerMap } from "./cacheHandler/CustomCacheHandler.js"
|
||||
import { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
|
||||
import type { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
|
||||
|
||||
export interface EphemeralStorageArgs extends EphemeralStorageInitArgs {
|
||||
type: "ephemeral"
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ import { AttributeModel } from "../../common/AttributeModel"
|
|||
import { collapseId, expandId } from "./RestClientIdUtils"
|
||||
import { PatchMerger } from "../offline/PatchMerger"
|
||||
import { hasError, isExpectedErrorForSynchronization } from "../../common/utils/ErrorUtils"
|
||||
import { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
|
||||
import type { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
|
||||
|
||||
assertWorkerOrNode()
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import { ModelMapper } from "../crypto/ModelMapper"
|
|||
import { ServerTypeModelResolver } from "../../common/EntityFunctions"
|
||||
import { expandId } from "./RestClientIdUtils"
|
||||
import { hasError } from "../../common/utils/ErrorUtils"
|
||||
import { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
|
||||
import type { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
|
||||
|
||||
/** Cache for a single list. */
|
||||
type ListCache = {
|
||||
|
|
|
|||
|
|
@ -20,9 +20,6 @@ import type { Tensor } from "@tensorflow/tfjs-core"
|
|||
import { DEFAULT_PREPROCESS_CONFIGURATION, SpamMailDatum, SpamMailProcessor } from "../../../common/api/common/utils/spamClassificationUtils/SpamMailProcessor"
|
||||
import { SparseVectorCompressor } from "../../../common/api/common/utils/spamClassificationUtils/SparseVectorCompressor"
|
||||
import { SpamDecision } from "../../../common/api/common/TutanotaConstants"
|
||||
import { HashingVectorizer } from "./HashingVectorizer"
|
||||
|
||||
assertWorkerOrNode()
|
||||
|
||||
export type SpamClassificationModel = {
|
||||
modelTopology: string
|
||||
|
|
@ -61,7 +58,7 @@ export class SpamClassifier {
|
|||
enableProdMode()
|
||||
this.classifiers = new Map()
|
||||
this.sparseVectorCompressor = new SparseVectorCompressor()
|
||||
this.spamMailProcessor = new SpamMailProcessor(DEFAULT_PREPROCESS_CONFIGURATION, new HashingVectorizer(), this.sparseVectorCompressor)
|
||||
this.spamMailProcessor = new SpamMailProcessor(DEFAULT_PREPROCESS_CONFIGURATION, this.sparseVectorCompressor)
|
||||
}
|
||||
|
||||
calculateThreshold(hamCount: number, spamCount: number) {
|
||||
|
|
|
|||
|
|
@ -308,7 +308,7 @@ export class WorkerImpl implements NativeInterface {
|
|||
return locator.autosaveFacade()
|
||||
},
|
||||
async spamClassifier() {
|
||||
return locator.spamClassifier
|
||||
return locator.spamClassifier()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -112,10 +112,9 @@ import { PublicKeySignatureFacade } from "../../../common/api/worker/facades/Pub
|
|||
import { AdminKeyLoaderFacade } from "../../../common/api/worker/facades/AdminKeyLoaderFacade"
|
||||
import { IdentityKeyCreator } from "../../../common/api/worker/facades/lazy/IdentityKeyCreator"
|
||||
import { PublicIdentityKeyProvider } from "../../../common/api/worker/facades/PublicIdentityKeyProvider"
|
||||
import { SpamClassifier } from "../spamClassification/SpamClassifier"
|
||||
import { IdentityKeyTrustDatabase } from "../../../common/api/worker/facades/IdentityKeyTrustDatabase"
|
||||
import { AutosaveFacade } from "../../../common/api/worker/facades/lazy/AutosaveFacade"
|
||||
import { SpamClassificationDataDealer } from "../spamClassification/SpamClassificationDataDealer"
|
||||
import type { SpamClassifier } from "../spamClassification/SpamClassifier"
|
||||
|
||||
assertWorkerOrNode()
|
||||
|
||||
|
|
@ -198,7 +197,7 @@ export type WorkerLocatorType = {
|
|||
contactFacade: lazyAsync<ContactFacade>
|
||||
|
||||
//spam classification
|
||||
spamClassifier: SpamClassifier
|
||||
spamClassifier: lazyAsync<SpamClassifier>
|
||||
}
|
||||
export const locator: WorkerLocatorType = {} as any
|
||||
|
||||
|
|
@ -740,8 +739,12 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
|
|||
)
|
||||
})
|
||||
|
||||
const spamClassificationDataDealer = new SpamClassificationDataDealer(locator.cachingEntityClient, locator.bulkMailLoader, locator.mail)
|
||||
locator.spamClassifier = new SpamClassifier(locator.cacheStorage, spamClassificationDataDealer)
|
||||
locator.spamClassifier = lazyMemoized(async () => {
|
||||
const { SpamClassificationDataDealer } = await import("../spamClassification/SpamClassificationDataDealer")
|
||||
const { SpamClassifier } = await import("../spamClassification/SpamClassifier")
|
||||
const spamClassificationDataDealer = new SpamClassificationDataDealer(locator.cachingEntityClient, locator.bulkMailLoader, locator.mail)
|
||||
return new SpamClassifier(locator.cacheStorage, spamClassificationDataDealer)
|
||||
})
|
||||
|
||||
const nativePushFacade = new NativePushFacadeSendDispatcher(worker)
|
||||
locator.calendar = lazyMemoized(async () => {
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import { SpamClassificationDataDealer, TrainingDataset } from "../../../../../..
|
|||
import { CacheStorage } from "../../../../../../src/common/api/worker/rest/DefaultEntityRestCache"
|
||||
import { mockAttribute } from "@tutao/tutanota-test-utils"
|
||||
import "@tensorflow/tfjs-backend-cpu"
|
||||
import { HashingVectorizer } from "../../../../../../src/mail-app/workerUtils/spamClassification/HashingVectorizer"
|
||||
import { LayersModel, tensor1d } from "../../../../../../src/mail-app/workerUtils/spamClassification/tensorflow-custom"
|
||||
import { createTestEntity } from "../../../../TestUtils"
|
||||
import { ClientSpamTrainingDatum, ClientSpamTrainingDatumTypeRef, MailTypeRef } from "../../../../../../src/common/api/entities/tutanota/TypeRefs"
|
||||
|
|
@ -111,7 +110,7 @@ o.spec("SpamClassifierTest", () => {
|
|||
const vectorLength = 512
|
||||
|
||||
compressor = new SparseVectorCompressor(vectorLength)
|
||||
spamProcessor = new SpamMailProcessor(DEFAULT_PREPROCESS_CONFIGURATION, new HashingVectorizer(vectorLength), compressor)
|
||||
spamProcessor = new SpamMailProcessor(DEFAULT_PREPROCESS_CONFIGURATION, compressor)
|
||||
spamClassifier = new SpamClassifier(mockCacheStorage, mockSpamClassificationDataDealer, true)
|
||||
spamClassifier.spamMailProcessor = spamProcessor
|
||||
spamClassifier.sparseVectorCompressor = compressor
|
||||
|
|
@ -529,7 +528,7 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
|
|||
mockSpamClassificationDataDealer.fetchAllTrainingData = async () => {
|
||||
return getTrainingDataset(dataSlice)
|
||||
}
|
||||
spamProcessor = new SpamMailProcessor(DEFAULT_PREPROCESS_CONFIGURATION, new HashingVectorizer(), compressor)
|
||||
spamProcessor = new SpamMailProcessor(DEFAULT_PREPROCESS_CONFIGURATION, compressor)
|
||||
spamClassifier = new SpamClassifier(mockOfflineStorageCache, mockSpamClassificationDataDealer, false)
|
||||
spamClassifier.spamMailProcessor = spamProcessor
|
||||
})
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue