mirror of
https://github.com/tutao/tutanota.git
synced 2025-12-08 06:09:50 +00:00
improve inbox rule handling and run spam prediction after inbox rules
Instead of applying inbox rules based on the unread mail state in the inbox folder, we introduce the new ProcessingState enum on the mail type. If a mail has been processed by the leader client, which is checking for matching inbox rules, the ProcessingState is updated. If there is a matching rule the flag is updated through the MoveMailService, if there is no matching rule, the flag is updated using the ClientClassifierResultService. Both requests are throttled / debounced. After processing inbox rules, spam prediction is conducted for mails that have not yet been moved by an inbox rule. The ProcessingState for not matching ham mails is also updated using the ClientClassifierResultService. This new inbox rule handing solves the following two problems: - when clicking on a notification it could still happen, that sometimes the inbox rules where not applied - when the inbox folder had a lot of unread mails, the loading time did massively increase, since inbox rules were re-applied on every load Co-authored-by: amm <amm@tutao.de> Co-authored-by: Nick <nif@tutao.de> Co-authored-by: das <das@tutao.de> Co-authored-by: abp <abp@tutao.de> Co-authored-by: jhm <17314077+jomapp@users.noreply.github.com> Co-authored-by: map <mpfau@users.noreply.github.com> Co-authored-by: Kinan <104761667+kibibytium@users.noreply.github.com>
This commit is contained in:
parent
030bea4fe6
commit
f11e59672e
53 changed files with 1269 additions and 1010 deletions
|
|
@ -1,5 +1,5 @@
|
|||
import { assertWorkerOrNode } from "../../../common/api/common/Env"
|
||||
import { assertNotNull, defer, groupByAndMap, isNotNull, Nullable, promiseMap } from "@tutao/tutanota-utils"
|
||||
import { assertNotNull, defer, groupByAndMap, isNotNull, Nullable, promiseMap, tokenize } from "@tutao/tutanota-utils"
|
||||
import { DynamicTfVectorizer } from "./DynamicTfVectorizer"
|
||||
import { HashingVectorizer } from "./HashingVectorizer"
|
||||
import {
|
||||
|
|
@ -22,11 +22,9 @@ import {
|
|||
} from "./PreprocessPatterns"
|
||||
import { SpamClassificationInitializer } from "./SpamClassificationInitializer"
|
||||
import { CacheStorage } from "../../../common/api/worker/rest/DefaultEntityRestCache"
|
||||
import { OfflineStoragePersistence } from "../index/OfflineStoragePersistence"
|
||||
import { filterMailMemberships, htmlToText } from "../../../common/api/common/utils/IndexUtils"
|
||||
import { htmlToText } from "../../../common/api/common/utils/IndexUtils"
|
||||
import {
|
||||
dense,
|
||||
dropout,
|
||||
fromMemory,
|
||||
glorotUniform,
|
||||
LayersModel,
|
||||
|
|
@ -39,6 +37,7 @@ import {
|
|||
import type { Tensor } from "@tensorflow/tfjs-core"
|
||||
import type { ModelArtifacts } from "@tensorflow/tfjs-core/dist/io/types"
|
||||
import type { ModelFitArgs } from "@tensorflow/tfjs-layers"
|
||||
import { OfflineStoragePersistence } from "../index/OfflineStoragePersistence"
|
||||
|
||||
assertWorkerOrNode()
|
||||
|
||||
|
|
@ -64,7 +63,7 @@ export type SpamPredMailDatum = {
|
|||
ownerGroup: Id
|
||||
}
|
||||
|
||||
const PREDICTION_THRESHOLD = 0.5
|
||||
const PREDICTION_THRESHOLD = 0.55
|
||||
|
||||
export type PreprocessConfiguration = {
|
||||
isPreprocessMails: boolean
|
||||
|
|
@ -92,18 +91,21 @@ export const DEFAULT_PREPROCESS_CONFIGURATION: PreprocessConfiguration = {
|
|||
isRemoveSpaceBeforeNewLine: true,
|
||||
}
|
||||
|
||||
const TRAINING_INTERVAL = 1000 * 60 * 10
|
||||
const TRAINING_INTERVAL = 1000 * 60 * 10 // 10 minutes
|
||||
const FULL_RETRAINING_INTERVAL = 1000 * 60 * 60 * 24 * 7 // 1 week
|
||||
|
||||
type TrainingPerformance = {
|
||||
trainingTime: number
|
||||
vectorizationTime: number
|
||||
}
|
||||
|
||||
export const spamClassifierTokenizer = (text: string): string[] => tokenize(text)
|
||||
|
||||
export class SpamClassifier {
|
||||
private readonly classifier: Map<Id, { model: LayersModel; isEnabled: boolean }>
|
||||
|
||||
constructor(
|
||||
private readonly offlineStorage: OfflineStoragePersistence | null,
|
||||
private readonly offlineStorage: OfflineStoragePersistence,
|
||||
private readonly offlineStorageCache: CacheStorage,
|
||||
private readonly initializer: SpamClassificationInitializer,
|
||||
private readonly deterministic: boolean = false,
|
||||
|
|
@ -113,18 +115,18 @@ export class SpamClassifier {
|
|||
this.classifier = new Map()
|
||||
}
|
||||
|
||||
public getEnabledSpamClassifierForOwnerGroup(ownerGroup: Id): Nullable<LayersModel> {
|
||||
const classifier = this.classifier.get(ownerGroup) ?? null
|
||||
if (classifier && classifier.isEnabled) {
|
||||
return classifier.model
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
public async initialize(ownerGroup: Id): Promise<void> {
|
||||
const loadedModel = await this.loadModel(ownerGroup)
|
||||
|
||||
const storage = assertNotNull(this.offlineStorageCache)
|
||||
setInterval(async () => {
|
||||
const cutoffDate = Date.now() - FULL_RETRAINING_INTERVAL
|
||||
const lastFullTrainingTime = await storage.getLastTrainedFromScratchTime()
|
||||
|
||||
if (cutoffDate > lastFullTrainingTime) {
|
||||
await this.retrainModelFromScratch(storage, ownerGroup, cutoffDate)
|
||||
}
|
||||
}, FULL_RETRAINING_INTERVAL)
|
||||
if (isNotNull(loadedModel)) {
|
||||
console.log("Loaded existing spam classification model from database")
|
||||
|
||||
|
|
@ -138,14 +140,19 @@ export class SpamClassifier {
|
|||
}
|
||||
|
||||
console.log("No existing model found. Training from scratch...")
|
||||
const data = await this.initializer.init(ownerGroup)
|
||||
await this.initialTraining(data)
|
||||
await this.saveModel(ownerGroup)
|
||||
await this.trainFromScratch(storage, ownerGroup)
|
||||
setInterval(async () => {
|
||||
await this.updateAndSaveModel(storage, ownerGroup)
|
||||
}, TRAINING_INTERVAL)
|
||||
}
|
||||
|
||||
private async trainFromScratch(storage: CacheStorage, ownerGroup: string) {
|
||||
const data = await this.initializer.init(ownerGroup)
|
||||
await this.initialTraining(data)
|
||||
await this.saveModel(ownerGroup)
|
||||
await storage.setLastTrainedFromScratchTime(Date.now())
|
||||
}
|
||||
|
||||
// VisibleForTesting
|
||||
public async updateAndSaveModel(storage: CacheStorage, ownerGroup: Id) {
|
||||
const isModelUpdated = await this.updateModelFromCutoff(await storage.getLastTrainedTime(), ownerGroup)
|
||||
|
|
@ -216,9 +223,11 @@ export class SpamClassifier {
|
|||
}
|
||||
|
||||
public async initialTraining(mails: SpamTrainMailDatum[]): Promise<TrainingPerformance> {
|
||||
const vectorizationStart = performance.now()
|
||||
const preprocessingStart = performance.now()
|
||||
const tokenizedMails = await promiseMap(mails, (mail) => spamClassifierTokenizer(this.preprocessMail(mail)))
|
||||
const preprocessingTime = performance.now() - preprocessingStart
|
||||
|
||||
const tokenizedMails = await promiseMap(mails, (mail) => assertNotNull(this.offlineStorage).tokenize(this.preprocessMail(mail)))
|
||||
const vectorizationStart = performance.now()
|
||||
if (this.vectorizer instanceof DynamicTfVectorizer) {
|
||||
this.vectorizer.buildInitialTokenVocabulary(tokenizedMails)
|
||||
}
|
||||
|
|
@ -237,11 +246,13 @@ export class SpamClassifier {
|
|||
epochs: 16,
|
||||
batchSize: 32,
|
||||
shuffle: !this.deterministic,
|
||||
callbacks: {
|
||||
onEpochEnd: async (epoch, logs) => {
|
||||
console.log(`Epoch ${epoch + 1} - Loss: ${logs!.loss.toFixed(4)}`)
|
||||
},
|
||||
},
|
||||
// callbacks: {
|
||||
// onEpochEnd: async (epoch, logs) => {
|
||||
// if (logs) {
|
||||
// console.log(`Epoch ${epoch + 1} - Loss: ${logs.loss.toFixed(4)}`)
|
||||
// }
|
||||
// },
|
||||
// },
|
||||
})
|
||||
const trainingTime = performance.now() - trainingStart
|
||||
|
||||
|
|
@ -251,7 +262,9 @@ export class SpamClassifier {
|
|||
|
||||
this.classifier.set(mails[0].ownerGroup, { model: classifier, isEnabled: true })
|
||||
|
||||
console.log(`### Finished Initial Training ### (total trained mails: ${mails.length})`)
|
||||
console.log(
|
||||
`### Finished Initial Training ### (total trained mails: ${mails.length}, preprocessing time: ${preprocessingTime}, vectorization time: ${vectorizationTime}ms, training time: ${trainingTime})`,
|
||||
)
|
||||
|
||||
return { vectorizationTime, trainingTime }
|
||||
}
|
||||
|
|
@ -283,10 +296,9 @@ export class SpamClassifier {
|
|||
const retrainingStart = performance.now()
|
||||
|
||||
const modelToUpdate = assertNotNull(this.classifier.get(ownerGroup))
|
||||
const offlineStorage = assertNotNull(this.offlineStorage)
|
||||
const tokenizedMailsArray = await promiseMap(newTrainingMails, async (mail) => {
|
||||
const preprocessedMail = this.preprocessMail(mail)
|
||||
const tokenizedMail = await offlineStorage.tokenize(preprocessedMail)
|
||||
const tokenizedMail = spamClassifierTokenizer(preprocessedMail)
|
||||
return { tokenizedMail, isSpamConfidence: mail.isSpamConfidence, isSpam: mail.isSpam ? 1 : 0 }
|
||||
})
|
||||
|
||||
|
|
@ -319,11 +331,11 @@ export class SpamClassifier {
|
|||
epochs: 8,
|
||||
batchSize: 32,
|
||||
shuffle: !this.deterministic,
|
||||
callbacks: {
|
||||
onEpochEnd: async (epoch, logs) => {
|
||||
console.log(`Epoch ${epoch + 1} - Loss: ${logs!.loss.toFixed(4)}`)
|
||||
},
|
||||
},
|
||||
// callbacks: {
|
||||
// onEpochEnd: async (epoch, logs) => {
|
||||
// console.log(`Epoch ${epoch + 1} - Loss: ${logs!.loss.toFixed(4)}`)
|
||||
// },
|
||||
// },
|
||||
}
|
||||
for (let i = 0; i <= isSpamConfidence; i++) {
|
||||
await modelToUpdate.model.fit(xs, ys, modelFitArgs)
|
||||
|
|
@ -349,7 +361,7 @@ export class SpamClassifier {
|
|||
}
|
||||
|
||||
const preprocessedMail = this.preprocessMail(spamPredMailDatum)
|
||||
const tokenizedMail = await assertNotNull(this.offlineStorage).tokenize(preprocessedMail)
|
||||
const tokenizedMail = spamClassifierTokenizer(preprocessedMail)
|
||||
const vectors = await assertNotNull(this.vectorizer).transform([tokenizedMail])
|
||||
|
||||
const xs = tensor2d(vectors, [vectors.length, assertNotNull(this.vectorizer).dimension], undefined)
|
||||
|
|
@ -357,7 +369,7 @@ export class SpamClassifier {
|
|||
const predictionData = await predictionTensor.data()
|
||||
const prediction = predictionData[0]
|
||||
|
||||
console.log(`predicted new mail to be with probability ${prediction.toFixed(2)} spam. Owner Group: ${spamPredMailDatum.ownerGroup}`)
|
||||
// console.log(`predicted new mail to be with probability ${prediction.toFixed(2)} spam. Owner Group: ${spamPredMailDatum.ownerGroup}`)
|
||||
|
||||
// When using the webgl backend we need to manually dispose @tensorflow tensors
|
||||
xs.dispose()
|
||||
|
|
@ -366,97 +378,20 @@ export class SpamClassifier {
|
|||
return prediction > PREDICTION_THRESHOLD
|
||||
}
|
||||
|
||||
/*
|
||||
* TODO: Only for internal release
|
||||
*
|
||||
* Allows to check the accuracy of your currently trained classifier against the content of mailbox itself
|
||||
* How-to:
|
||||
* 1) Open console and switch context to worker-bootstrap.js
|
||||
* 2) Execute this method in console: `locator.spamClassifier.getSpamMetricsForCurrentMailBox()`
|
||||
* 3) Let execution continue from breakpoint
|
||||
*
|
||||
* Since we change constant of this.initializer,
|
||||
* it's better to restart the client to not have unexpected effect
|
||||
*/
|
||||
public async getSpamMetricsForCurrentMailBox(ownerGroup?: Id): Promise<void> {
|
||||
const { LocalTimeDateProvider } = await import("../../../common/api/worker/DateProvider.js")
|
||||
const dateProvider = new LocalTimeDateProvider()
|
||||
|
||||
const getIdOfClassificationMail = (classificationData: any) => {
|
||||
return ((classificationData.listId as Id) + "/" + classificationData.elementId) as Id
|
||||
}
|
||||
const user = assertNotNull((this.initializer as any).userFacade.getUser())
|
||||
const firstOwnerGroup = ownerGroup ?? filterMailMemberships(user)[0]._id
|
||||
console.log(`Testing with ownergroup: ${firstOwnerGroup}`)
|
||||
|
||||
const readingAllSpamStart = performance.now()
|
||||
const trainedMails = await assertNotNull(this.offlineStorage)
|
||||
.getCertainSpamClassificationTrainingDataAfterCutoff(0, firstOwnerGroup)
|
||||
.then((mails) => new Set(mails.map(getIdOfClassificationMail)))
|
||||
console.log(`Done reading ${trainedMails.size} certain training mail data in: ${performance.now() - readingAllSpamStart}ms`)
|
||||
|
||||
// since we train with last -28 days, we can test with last -90
|
||||
;(this.initializer as any).TIME_LIMIT = dateProvider.getStartOfDayShiftedBy(-90)
|
||||
// if exists, try to test with at 5xleast same number of mails as in training sample
|
||||
;(this.initializer as any).MIN_MAILS_COUNT = trainedMails.size * 5
|
||||
// to avoid putting stuff into offline storage
|
||||
;(this.initializer as any).offlineStorage.storeSpamClassification = async () => {
|
||||
console.log("not putting classification datum into offline storage")
|
||||
}
|
||||
|
||||
const downloadingExtraMailsStart = performance.now()
|
||||
const testingMails = (await this.initializer.init(firstOwnerGroup))
|
||||
// do not test with the same mails that was used to train
|
||||
.filter((classificationData) => !trainedMails.has(getIdOfClassificationMail(classificationData)))
|
||||
console.log(`Done downloading extra ${testingMails.length} of last 90 days mail data in: ${performance.now() - downloadingExtraMailsStart}ms`)
|
||||
|
||||
const testingAllSamplesStart = performance.now()
|
||||
await this.test(testingMails)
|
||||
console.log(`Done testing all extra mails sample in: ${performance.now() - testingAllSamplesStart}ms`)
|
||||
public getSpamClassification(mailId: IdTuple) {
|
||||
return this.offlineStorage.getSpamClassification(mailId)
|
||||
}
|
||||
|
||||
public async test(mails: SpamTrainMailDatum[]): Promise<void> {
|
||||
if (!this.classifier) {
|
||||
throw new Error("Model has not been loaded")
|
||||
}
|
||||
public updateSpamClassification(mailId: IdTuple, isSpam: boolean, isSpamConfidence: number) {
|
||||
return this.offlineStorage.updateSpamClassification(mailId, isSpam, isSpamConfidence)
|
||||
}
|
||||
|
||||
let predictionArray: number[] = []
|
||||
for (let mail of mails) {
|
||||
const prediction = await this.predict(mail)
|
||||
predictionArray.push(prediction ? 1 : 0)
|
||||
}
|
||||
const ysArray = mails.map((mail) => mail.isSpam)
|
||||
public storeSpamClassification(spamTrainMailDatum: SpamTrainMailDatum) {
|
||||
return this.offlineStorage.storeSpamClassification(spamTrainMailDatum)
|
||||
}
|
||||
|
||||
let tp = 0,
|
||||
tn = 0,
|
||||
fp = 0,
|
||||
fn = 0
|
||||
|
||||
for (let i = 0; i < predictionArray.length; i++) {
|
||||
const predictedSpam = predictionArray[i] > 0.5
|
||||
const isActuallyASpam = ysArray[i]
|
||||
if (predictedSpam && isActuallyASpam) tp++
|
||||
else if (!predictedSpam && !isActuallyASpam) tn++
|
||||
else if (predictedSpam && !isActuallyASpam) fp++
|
||||
else if (!predictedSpam && isActuallyASpam) fn++
|
||||
}
|
||||
|
||||
const total = tp + tn + fp + fn
|
||||
const accuracy = (tp + tn) / total
|
||||
const precision = tp / (tp + fp + 1e-7)
|
||||
const recall = tp / (tp + fn + 1e-7)
|
||||
const f1 = 2 * ((precision * recall) / (precision + recall + 1e-7))
|
||||
|
||||
console.log("\n--- Evaluation Metrics ---")
|
||||
console.log(`Accuracy: \t${(accuracy * 100).toFixed(2)}%`)
|
||||
console.log(`Precision:\t${(precision * 100).toFixed(2)}%`)
|
||||
console.log(`Recall: \t${(recall * 100).toFixed(2)}%`)
|
||||
console.log(`F1 Score: \t${(f1 * 100).toFixed(2)}%`)
|
||||
console.log("\nConfusion Matrix:")
|
||||
console.log({
|
||||
Predicted_Spam: { True_Positive: tp, False_Positive: fp },
|
||||
Predicted_Ham: { False_Negative: fn, True_Negative: tn },
|
||||
})
|
||||
public deleteSpamClassification(mailId: IdTuple) {
|
||||
return this.offlineStorage.deleteSpamClassification(mailId)
|
||||
}
|
||||
|
||||
// visibleForTesting
|
||||
|
|
@ -538,7 +473,7 @@ export class SpamClassifier {
|
|||
}
|
||||
}
|
||||
|
||||
// VisibleForTesting
|
||||
// visibleForTesting
|
||||
public async loadModel(ownerGroup: Id): Promise<Nullable<LayersModel>> {
|
||||
const model = await assertNotNull(this.offlineStorage).getSpamClassificationModel(ownerGroup)
|
||||
if (model) {
|
||||
|
|
@ -565,7 +500,19 @@ export class SpamClassifier {
|
|||
return concatenated.length > 0 ? concatenated : " "
|
||||
}
|
||||
|
||||
// === Testing methods
|
||||
private async retrainModelFromScratch(storage: CacheStorage, ownerGroup: Id, cutoffTimestamp: number) {
|
||||
console.log("Model is being re-trained from scratch, deleting old data")
|
||||
try {
|
||||
await assertNotNull(this.offlineStorage).deleteSpamClassificationTrainingDataBeforeCutoff(cutoffTimestamp, ownerGroup)
|
||||
} catch (e) {
|
||||
console.error("Failed delete old training data: ", e)
|
||||
return
|
||||
}
|
||||
|
||||
await this.trainFromScratch(storage, ownerGroup)
|
||||
}
|
||||
|
||||
// visibleForTesting
|
||||
public async cloneClassifier(): Promise<SpamClassifier> {
|
||||
const newClassifier = new SpamClassifier(
|
||||
this.offlineStorage,
|
||||
|
|
@ -587,4 +534,9 @@ export class SpamClassifier {
|
|||
|
||||
return newClassifier
|
||||
}
|
||||
|
||||
// visibleForTesting
|
||||
public addSpamClassifierForOwner(ownerGroup: Id, model: LayersModel, isEnabled: boolean) {
|
||||
this.classifier.set(ownerGroup, { model, isEnabled })
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue