instantiate and import spam classifier lazily

Co-authored-by: das <das@tutao.de>
This commit is contained in:
abp 2025-11-18 16:42:23 +01:00 committed by das
parent 858d7e1e71
commit c33591eaca
10 changed files with 29 additions and 33 deletions

View file

@ -33,7 +33,7 @@ export const allowedImports = {
wasm: ["wasm-fallback"],
"common-min": ["polyfill-helpers"],
boot: ["polyfill-helpers", "common-min"],
common: ["polyfill-helpers", "common-min", "spam-classifier"],
common: ["polyfill-helpers", "common-min"],
"gui-base": ["polyfill-helpers", "common-min", "common", "boot"],
main: ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "date"],
sanitizer: ["polyfill-helpers", "common-min", "common", "boot", "gui-base"],
@ -47,7 +47,7 @@ export const allowedImports = {
"calendar-view": ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "main", "date", "date-gui", "sharing", "contacts"],
login: ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "main"],
"spam-classifier": ["polyfill-helpers", "common", "common-min"],
worker: ["polyfill-helpers", "common-min", "common", "native-common", "native-worker", "wasm", "wasm-fallback", "spam-classifier"],
worker: ["polyfill-helpers", "common-min", "common", "native-common", "native-worker", "wasm", "wasm-fallback"],
"pow-worker": [],
settings: [
"polyfill-helpers",

View file

@ -1,4 +1,4 @@
import { HashingVectorizer } from "../../../../../mail-app/workerUtils/spamClassification/HashingVectorizer"
import type { HashingVectorizer } from "../../../../../mail-app/workerUtils/spamClassification/HashingVectorizer"
import { htmlToText } from "../IndexUtils"
import {
ML_BITCOIN_REGEX,
@ -19,11 +19,11 @@ import {
ML_URL_TOKEN,
} from "./PreprocessPatterns"
import { SparseVectorCompressor } from "./SparseVectorCompressor"
import { ProgrammingError } from "../../error/ProgrammingError"
import { assertNotNull, tokenize } from "@tutao/tutanota-utils"
import { assertNotNull, lazyAsync, lazyMemoized, tokenize } from "@tutao/tutanota-utils"
import { Mail, MailAddress, MailDetails } from "../../../entities/tutanota/TypeRefs"
import { getMailBodyText } from "../../CommonMailUtils"
import { MailAuthenticationStatus } from "../../TutanotaConstants"
import { ProgrammingError } from "../../error/ProgrammingError"
export type PreprocessConfiguration = {
isPreprocessMails: boolean
@ -69,15 +69,12 @@ export type PreprocessedMailContent = string
export class SpamMailProcessor {
constructor(
private readonly preprocessConfiguration: PreprocessConfiguration = DEFAULT_PREPROCESS_CONFIGURATION,
readonly vectorizer: HashingVectorizer = new HashingVectorizer(),
private readonly sparseVectorCompressor: SparseVectorCompressor = new SparseVectorCompressor(),
) {
if (vectorizer.dimension !== sparseVectorCompressor.dimension) {
throw new ProgrammingError(
`a spam preprocessor was created with different dimensions. Vectorizer:${vectorizer.dimension} compressor: ${sparseVectorCompressor.dimension}`,
)
}
}
private readonly vectorizer: lazyAsync<HashingVectorizer> = lazyMemoized(async () => {
const { HashingVectorizer } = await import("../../../../../mail-app/workerUtils/spamClassification/HashingVectorizer")
return new HashingVectorizer(this.sparseVectorCompressor.dimension)
}),
) {}
public async vectorizeAndCompress(spamMailDatum: SpamMailDatum): Promise<Uint8Array> {
const vector = await this.vectorize(spamMailDatum)
@ -85,13 +82,13 @@ export class SpamMailProcessor {
}
public async vectorize(spamMailDatum: SpamMailDatum): Promise<number[]> {
const vectorizer = await this.vectorizer()
const preprocessedMail = this.preprocessMail(spamMailDatum)
const tokenizedMail = spamClassifierTokenizer(preprocessedMail)
const vector = await this.vectorizer.vectorize(tokenizedMail)
return vector
return await vectorizer.vectorize(tokenizedMail)
}
public async compress(uncompressedVector: number[]): Promise<Uint8Array> {
private async compress(uncompressedVector: number[]): Promise<Uint8Array> {
return this.sparseVectorCompressor.vectorToBinary(uncompressedVector)
}

View file

@ -46,7 +46,7 @@ import { AttributeModel } from "../../common/AttributeModel"
import { TypeModelResolver } from "../../common/EntityFunctions"
import { collapseId, expandId } from "../rest/RestClientIdUtils"
import { Category, syncMetrics } from "../utils/SyncMetrics"
import { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
import type { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
/**
* this is the value of SQLITE_MAX_VARIABLE_NUMBER in sqlite3.c

View file

@ -5,7 +5,7 @@ import { Nullable, TypeRef } from "@tutao/tutanota-utils"
import { OfflineStorage, OfflineStorageInitArgs } from "../offline/OfflineStorage.js"
import { EphemeralCacheStorage, EphemeralStorageInitArgs } from "./EphemeralCacheStorage"
import { CustomCacheHandlerMap } from "./cacheHandler/CustomCacheHandler.js"
import { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
import type { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
export interface EphemeralStorageArgs extends EphemeralStorageInitArgs {
type: "ephemeral"

View file

@ -55,7 +55,7 @@ import { AttributeModel } from "../../common/AttributeModel"
import { collapseId, expandId } from "./RestClientIdUtils"
import { PatchMerger } from "../offline/PatchMerger"
import { hasError, isExpectedErrorForSynchronization } from "../../common/utils/ErrorUtils"
import { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
import type { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
assertWorkerOrNode()

View file

@ -10,7 +10,7 @@ import { ModelMapper } from "../crypto/ModelMapper"
import { ServerTypeModelResolver } from "../../common/EntityFunctions"
import { expandId } from "./RestClientIdUtils"
import { hasError } from "../../common/utils/ErrorUtils"
import { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
import type { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
/** Cache for a single list. */
type ListCache = {

View file

@ -20,9 +20,6 @@ import type { Tensor } from "@tensorflow/tfjs-core"
import { DEFAULT_PREPROCESS_CONFIGURATION, SpamMailDatum, SpamMailProcessor } from "../../../common/api/common/utils/spamClassificationUtils/SpamMailProcessor"
import { SparseVectorCompressor } from "../../../common/api/common/utils/spamClassificationUtils/SparseVectorCompressor"
import { SpamDecision } from "../../../common/api/common/TutanotaConstants"
import { HashingVectorizer } from "./HashingVectorizer"
assertWorkerOrNode()
export type SpamClassificationModel = {
modelTopology: string
@ -61,7 +58,7 @@ export class SpamClassifier {
enableProdMode()
this.classifiers = new Map()
this.sparseVectorCompressor = new SparseVectorCompressor()
this.spamMailProcessor = new SpamMailProcessor(DEFAULT_PREPROCESS_CONFIGURATION, new HashingVectorizer(), this.sparseVectorCompressor)
this.spamMailProcessor = new SpamMailProcessor(DEFAULT_PREPROCESS_CONFIGURATION, this.sparseVectorCompressor)
}
calculateThreshold(hamCount: number, spamCount: number) {

View file

@ -308,7 +308,7 @@ export class WorkerImpl implements NativeInterface {
return locator.autosaveFacade()
},
async spamClassifier() {
return locator.spamClassifier
return locator.spamClassifier()
},
}
}

View file

@ -112,10 +112,9 @@ import { PublicKeySignatureFacade } from "../../../common/api/worker/facades/Pub
import { AdminKeyLoaderFacade } from "../../../common/api/worker/facades/AdminKeyLoaderFacade"
import { IdentityKeyCreator } from "../../../common/api/worker/facades/lazy/IdentityKeyCreator"
import { PublicIdentityKeyProvider } from "../../../common/api/worker/facades/PublicIdentityKeyProvider"
import { SpamClassifier } from "../spamClassification/SpamClassifier"
import { IdentityKeyTrustDatabase } from "../../../common/api/worker/facades/IdentityKeyTrustDatabase"
import { AutosaveFacade } from "../../../common/api/worker/facades/lazy/AutosaveFacade"
import { SpamClassificationDataDealer } from "../spamClassification/SpamClassificationDataDealer"
import type { SpamClassifier } from "../spamClassification/SpamClassifier"
assertWorkerOrNode()
@ -198,7 +197,7 @@ export type WorkerLocatorType = {
contactFacade: lazyAsync<ContactFacade>
//spam classification
spamClassifier: SpamClassifier
spamClassifier: lazyAsync<SpamClassifier>
}
export const locator: WorkerLocatorType = {} as any
@ -740,8 +739,12 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
)
})
const spamClassificationDataDealer = new SpamClassificationDataDealer(locator.cachingEntityClient, locator.bulkMailLoader, locator.mail)
locator.spamClassifier = new SpamClassifier(locator.cacheStorage, spamClassificationDataDealer)
locator.spamClassifier = lazyMemoized(async () => {
const { SpamClassificationDataDealer } = await import("../spamClassification/SpamClassificationDataDealer")
const { SpamClassifier } = await import("../spamClassification/SpamClassifier")
const spamClassificationDataDealer = new SpamClassificationDataDealer(locator.cachingEntityClient, locator.bulkMailLoader, locator.mail)
return new SpamClassifier(locator.cacheStorage, spamClassificationDataDealer)
})
const nativePushFacade = new NativePushFacadeSendDispatcher(worker)
locator.calendar = lazyMemoized(async () => {

View file

@ -8,7 +8,6 @@ import { SpamClassificationDataDealer, TrainingDataset } from "../../../../../..
import { CacheStorage } from "../../../../../../src/common/api/worker/rest/DefaultEntityRestCache"
import { mockAttribute } from "@tutao/tutanota-test-utils"
import "@tensorflow/tfjs-backend-cpu"
import { HashingVectorizer } from "../../../../../../src/mail-app/workerUtils/spamClassification/HashingVectorizer"
import { LayersModel, tensor1d } from "../../../../../../src/mail-app/workerUtils/spamClassification/tensorflow-custom"
import { createTestEntity } from "../../../../TestUtils"
import { ClientSpamTrainingDatum, ClientSpamTrainingDatumTypeRef, MailTypeRef } from "../../../../../../src/common/api/entities/tutanota/TypeRefs"
@ -111,7 +110,7 @@ o.spec("SpamClassifierTest", () => {
const vectorLength = 512
compressor = new SparseVectorCompressor(vectorLength)
spamProcessor = new SpamMailProcessor(DEFAULT_PREPROCESS_CONFIGURATION, new HashingVectorizer(vectorLength), compressor)
spamProcessor = new SpamMailProcessor(DEFAULT_PREPROCESS_CONFIGURATION, compressor)
spamClassifier = new SpamClassifier(mockCacheStorage, mockSpamClassificationDataDealer, true)
spamClassifier.spamMailProcessor = spamProcessor
spamClassifier.sparseVectorCompressor = compressor
@ -529,7 +528,7 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
mockSpamClassificationDataDealer.fetchAllTrainingData = async () => {
return getTrainingDataset(dataSlice)
}
spamProcessor = new SpamMailProcessor(DEFAULT_PREPROCESS_CONFIGURATION, new HashingVectorizer(), compressor)
spamProcessor = new SpamMailProcessor(DEFAULT_PREPROCESS_CONFIGURATION, compressor)
spamClassifier = new SpamClassifier(mockOfflineStorageCache, mockSpamClassificationDataDealer, false)
spamClassifier.spamMailProcessor = spamProcessor
})