mirror of
https://github.com/tutao/tutanota.git
synced 2025-12-08 06:09:50 +00:00
improve inbox rule handling and run spam prediction after inbox rules
Instead of applying inbox rules based on the unread mail state in the inbox folder, we introduce the new ProcessingState enum on the mail type. If a mail has been processed by the leader client, which is checking for matching inbox rules, the ProcessingState is updated. If there is a matching rule the flag is updated through the MoveMailService, if there is no matching rule, the flag is updated using the ClientClassifierResultService. Both requests are throttled / debounced. After processing inbox rules, spam prediction is conducted for mails that have not yet been moved by an inbox rule. The ProcessingState for not matching ham mails is also updated using the ClientClassifierResultService. This new inbox rule handing solves the following two problems: - when clicking on a notification it could still happen, that sometimes the inbox rules where not applied - when the inbox folder had a lot of unread mails, the loading time did massively increase, since inbox rules were re-applied on every load Co-authored-by: amm <amm@tutao.de> Co-authored-by: Nick <nif@tutao.de> Co-authored-by: das <das@tutao.de> Co-authored-by: abp <abp@tutao.de> Co-authored-by: jhm <17314077+jomapp@users.noreply.github.com> Co-authored-by: map <mpfau@users.noreply.github.com> Co-authored-by: Kinan <104761667+kibibytium@users.noreply.github.com>
This commit is contained in:
parent
030bea4fe6
commit
f11e59672e
53 changed files with 1269 additions and 1010 deletions
|
|
@ -4,9 +4,9 @@ import { parseCsv } from "../../../../../../src/common/misc/parsing/CsvParser"
|
|||
import {
|
||||
DEFAULT_PREPROCESS_CONFIGURATION,
|
||||
SpamClassifier,
|
||||
spamClassifierTokenizer as testTokenize,
|
||||
SpamTrainMailDatum,
|
||||
} from "../../../../../../src/mail-app/workerUtils/spamClassification/SpamClassifier"
|
||||
import { tokenize as testTokenize } from "./HashingVectorizerTest"
|
||||
import { OfflineStoragePersistence } from "../../../../../../src/mail-app/workerUtils/index/OfflineStoragePersistence"
|
||||
import { matchers, object, when } from "testdouble"
|
||||
import { assertNotNull, promiseMap } from "@tutao/tutanota-utils"
|
||||
|
|
@ -16,7 +16,11 @@ import { mockAttribute } from "@tutao/tutanota-test-utils"
|
|||
import "@tensorflow/tfjs-backend-cpu"
|
||||
import { HashingVectorizer } from "../../../../../../src/mail-app/workerUtils/spamClassification/HashingVectorizer"
|
||||
import { LayersModel, tensor1d } from "../../../../../../src/mail-app/workerUtils/spamClassification/tensorflow-custom"
|
||||
import { createTestEntity } from "../../../../TestUtils"
|
||||
import { MailTypeRef } from "../../../../../../src/common/api/entities/tutanota/TypeRefs"
|
||||
import { Sequential } from "@tensorflow/tfjs-layers"
|
||||
|
||||
const { anything } = matchers
|
||||
export const DATASET_FILE_PATH: string = "./tests/api/worker/utils/spamClassification/spam_classification_test_mails.csv"
|
||||
|
||||
export async function readMailDataFromCSV(filePath: string): Promise<{
|
||||
|
|
@ -50,7 +54,7 @@ export async function readMailDataFromCSV(filePath: string): Promise<{
|
|||
}
|
||||
|
||||
// Initial training (cutoff by day or amount)
|
||||
o.spec("SpamClassifier", () => {
|
||||
o.spec("SpamClassifierTest", () => {
|
||||
const mockOfflineStorageCache = object<CacheStorage>()
|
||||
const mockOfflineStorage = object<OfflineStoragePersistence>()
|
||||
const mockSpamClassificationInitializer = object<SpamClassificationInitializer>()
|
||||
|
|
@ -68,9 +72,6 @@ o.spec("SpamClassifier", () => {
|
|||
dataSlice = spamData.concat(hamData)
|
||||
seededShuffle(dataSlice, 42)
|
||||
|
||||
mockOfflineStorage.tokenize = async (text) => {
|
||||
return testTokenize(text)
|
||||
}
|
||||
mockSpamClassificationInitializer.init = async () => {
|
||||
return dataSlice
|
||||
}
|
||||
|
|
@ -86,6 +87,48 @@ o.spec("SpamClassifier", () => {
|
|||
)
|
||||
})
|
||||
|
||||
o("processSpam maintains server classification when client classification is not enabled", async function () {
|
||||
const mail = createTestEntity(MailTypeRef, {
|
||||
_id: ["mailListId", "mailId"],
|
||||
sets: [["folderList", "serverFolder"]],
|
||||
})
|
||||
const spamTrainMailDatum: SpamTrainMailDatum = {
|
||||
mailId: mail._id,
|
||||
subject: mail.subject,
|
||||
body: "some body",
|
||||
isSpam: true,
|
||||
isSpamConfidence: 1,
|
||||
ownerGroup: "owner",
|
||||
}
|
||||
const layersModel = object<Sequential>()
|
||||
spamClassifier.addSpamClassifierForOwner(spamTrainMailDatum.ownerGroup, layersModel, false)
|
||||
|
||||
const predictedSpam = await spamClassifier.predict(spamTrainMailDatum)
|
||||
o(predictedSpam).equals(null)
|
||||
})
|
||||
|
||||
o("processSpam uses client classification when enabled", async function () {
|
||||
const mail = createTestEntity(MailTypeRef, {
|
||||
_id: ["mailListId", "mailId"],
|
||||
sets: [["folderList", "serverFolder"]],
|
||||
})
|
||||
const spamTrainMailDatum: SpamTrainMailDatum = {
|
||||
mailId: mail._id,
|
||||
subject: mail.subject,
|
||||
body: "some body",
|
||||
isSpam: false,
|
||||
isSpamConfidence: 0,
|
||||
ownerGroup: "owner",
|
||||
}
|
||||
|
||||
const layersModel = object<Sequential>()
|
||||
when(layersModel.predict(anything())).thenReturn(tensor1d([1]))
|
||||
spamClassifier.addSpamClassifierForOwner(spamTrainMailDatum.ownerGroup, layersModel, true)
|
||||
|
||||
const predictedSpam = await spamClassifier.predict(spamTrainMailDatum)
|
||||
o(predictedSpam).equals(true)
|
||||
})
|
||||
|
||||
o("Initial training only", async () => {
|
||||
o.timeout(20_000)
|
||||
|
||||
|
|
@ -94,13 +137,12 @@ o.spec("SpamClassifier", () => {
|
|||
const testSet = dataSlice.slice(trainTestSplit)
|
||||
|
||||
await spamClassifier.initialTraining(trainSet)
|
||||
await spamClassifier.test(testSet)
|
||||
await testClassifier(spamClassifier, testSet)
|
||||
})
|
||||
|
||||
o("Initial training and refitting in multi step", async () => {
|
||||
o.timeout(20_000)
|
||||
|
||||
const testStart = Date.now()
|
||||
const trainTestSplit = dataSlice.length * 0.8
|
||||
const trainSet = dataSlice.slice(0, trainTestSplit)
|
||||
const testSet = dataSlice.slice(trainTestSplit)
|
||||
|
|
@ -112,15 +154,15 @@ o.spec("SpamClassifier", () => {
|
|||
o(await mockSpamClassificationInitializer.init("owner")).deepEquals(trainSetFirstHalf)
|
||||
await spamClassifier.initialTraining(dataSlice)
|
||||
console.log(`==> Result when testing with mails in two steps (first step).`)
|
||||
await spamClassifier.test(testSet)
|
||||
await testClassifier(spamClassifier, testSet)
|
||||
|
||||
await spamClassifier.updateModel("owner", trainSetSecondHalf)
|
||||
console.log(`==> Result when testing with mails in two steps (second step).`)
|
||||
await spamClassifier.test(testSet)
|
||||
await testClassifier(spamClassifier, testSet)
|
||||
})
|
||||
|
||||
o("preprocessMail outputs expected tokens for mail content", async () => {
|
||||
const classifier = new SpamClassifier(null, object(), object())
|
||||
const classifier = new SpamClassifier(object(), object(), object())
|
||||
const mail = {
|
||||
subject: `Sample Tokens and values`,
|
||||
// prettier-ignore
|
||||
|
|
@ -336,18 +378,14 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
|
|||
o.spec("SpamClassifier - Performance Analysis", () => {
|
||||
const mockOfflineStorageCache = object<CacheStorage>()
|
||||
const mockOfflineStorage = object<OfflineStoragePersistence>()
|
||||
let classifier = object<SpamClassifier>()
|
||||
let spamClassifier = object<SpamClassifier>()
|
||||
let dataSlice: SpamTrainMailDatum[]
|
||||
o.beforeEach(() => {
|
||||
mockOfflineStorage.tokenize = async (text) => {
|
||||
return testTokenize(text)
|
||||
}
|
||||
|
||||
const mockSpamClassificationInitializer = object<SpamClassificationInitializer>()
|
||||
mockSpamClassificationInitializer.init = async () => {
|
||||
return dataSlice
|
||||
}
|
||||
classifier = new SpamClassifier(mockOfflineStorage, mockOfflineStorageCache, mockSpamClassificationInitializer)
|
||||
spamClassifier = new SpamClassifier(mockOfflineStorage, mockOfflineStorageCache, mockSpamClassificationInitializer)
|
||||
})
|
||||
|
||||
o("time to refit", async () => {
|
||||
|
|
@ -359,7 +397,7 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
|
|||
seededShuffle(dataSlice, 42)
|
||||
|
||||
const start = performance.now()
|
||||
await classifier.initialTraining(dataSlice)
|
||||
await spamClassifier.initialTraining(dataSlice)
|
||||
const initialTrainingDuration = performance.now() - start
|
||||
console.log(`initial training time ${initialTrainingDuration}ms`)
|
||||
|
||||
|
|
@ -367,7 +405,7 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
|
|||
const nowSpam = [hamSlice[0]]
|
||||
nowSpam.map((formerHam) => (formerHam.isSpam = true))
|
||||
const retrainingStart = performance.now()
|
||||
await classifier.updateModel("owner", nowSpam)
|
||||
await spamClassifier.updateModel("owner", nowSpam)
|
||||
const retrainingDuration = performance.now() - retrainingStart
|
||||
console.log(`retraining time ${retrainingDuration}ms`)
|
||||
}
|
||||
|
|
@ -381,17 +419,17 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
|
|||
dataSlice = hamSlice.concat(spamSlice)
|
||||
// seededShuffle(dataSlice, 42)
|
||||
|
||||
await classifier.initialTraining(dataSlice)
|
||||
await spamClassifier.initialTraining(dataSlice)
|
||||
const falseNegatives = spamData
|
||||
.slice(10)
|
||||
.filter(async (mailDatum) => mailDatum.isSpam !== (await classifier.predict(mailDatum)))
|
||||
.filter(async (mailDatum) => mailDatum.isSpam !== (await spamClassifier.predict(mailDatum)))
|
||||
.sort()
|
||||
.slice(0, 10)
|
||||
|
||||
let retrainingNeeded = new Array<number>(falseNegatives.length).fill(0)
|
||||
for (let i = 0; i < falseNegatives.length; i++) {
|
||||
const sample = falseNegatives[i]
|
||||
const copiedClassifier = await classifier.cloneClassifier()
|
||||
const copiedClassifier = await spamClassifier.cloneClassifier()
|
||||
|
||||
let retrainCount = 0
|
||||
let predictedSpam = false
|
||||
|
|
@ -458,15 +496,15 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
|
|||
dataSlice = hamSlice.concat(spamSlice)
|
||||
// seededShuffle(dataSlice, 42)
|
||||
|
||||
await classifier.initialTraining(dataSlice)
|
||||
await spamClassifier.initialTraining(dataSlice)
|
||||
const falsePositive = hamData
|
||||
.slice(10)
|
||||
.filter(async (mailDatum) => mailDatum.isSpam !== (await classifier.predict(mailDatum)))
|
||||
.filter(async (mailDatum) => mailDatum.isSpam !== (await spamClassifier.predict(mailDatum)))
|
||||
.slice(0, 10)
|
||||
let retrainingNeeded = new Array<number>(falsePositive.length).fill(0)
|
||||
for (let i = 0; i < falsePositive.length; i++) {
|
||||
const sample = falsePositive[i]
|
||||
const copiedClassifier = await classifier.cloneClassifier()
|
||||
const copiedClassifier = await spamClassifier.cloneClassifier()
|
||||
|
||||
let retrainCount = 0
|
||||
let predictedSpam = false
|
||||
|
|
@ -492,16 +530,16 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
|
|||
dataSlice = hamSlice.concat(spamSlice)
|
||||
seededShuffle(dataSlice, 42)
|
||||
|
||||
await classifier.initialTraining(dataSlice)
|
||||
await spamClassifier.initialTraining(dataSlice)
|
||||
const falseNegatives = spamData
|
||||
.slice(10)
|
||||
.filter(async (mailDatum) => mailDatum.isSpam !== (await classifier.predict(mailDatum)))
|
||||
.filter(async (mailDatum) => mailDatum.isSpam !== (await spamClassifier.predict(mailDatum)))
|
||||
.slice(0, 10)
|
||||
|
||||
let retrainingNeeded = new Array<number>(falseNegatives.length).fill(0)
|
||||
for (let i = 0; i < falseNegatives.length; i++) {
|
||||
const sample = falseNegatives[i]
|
||||
const copiedClassifier = await classifier.cloneClassifier()
|
||||
const copiedClassifier = await spamClassifier.cloneClassifier()
|
||||
|
||||
let retrainCount = 0
|
||||
let predictedSpam = false
|
||||
|
|
@ -532,7 +570,7 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
|
|||
await promiseMap(
|
||||
new Array<number>(ITERATION_COUNT).fill(0),
|
||||
async () => {
|
||||
const { vectorizationTime, trainingTime } = await classifier.initialTraining(dataSlice)
|
||||
const { vectorizationTime, trainingTime } = await spamClassifier.initialTraining(dataSlice)
|
||||
trainingTimes.push(trainingTime)
|
||||
vectorizationTimes.push(vectorizationTime)
|
||||
trainingWithoutVectorization.push(trainingTime - vectorizationTime)
|
||||
|
|
@ -560,6 +598,47 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
|
|||
})
|
||||
})
|
||||
}
|
||||
|
||||
async function testClassifier(classifier: SpamClassifier, mails: SpamTrainMailDatum[]): Promise<void> {
|
||||
let predictionArray: number[] = []
|
||||
for (let mail of mails) {
|
||||
const prediction = await classifier.predict(mail)
|
||||
predictionArray.push(prediction ? 1 : 0)
|
||||
}
|
||||
const ysArray = mails.map((mail) => mail.isSpam)
|
||||
|
||||
let tp = 0,
|
||||
tn = 0,
|
||||
fp = 0,
|
||||
fn = 0
|
||||
|
||||
for (let i = 0; i < predictionArray.length; i++) {
|
||||
const predictedSpam = predictionArray[i] > 0.5
|
||||
const isActuallyASpam = ysArray[i]
|
||||
if (predictedSpam && isActuallyASpam) tp++
|
||||
else if (!predictedSpam && !isActuallyASpam) tn++
|
||||
else if (predictedSpam && !isActuallyASpam) fp++
|
||||
else if (!predictedSpam && isActuallyASpam) fn++
|
||||
}
|
||||
|
||||
const total = tp + tn + fp + fn
|
||||
const accuracy = (tp + tn) / total
|
||||
const precision = tp / (tp + fp + 1e-7)
|
||||
const recall = tp / (tp + fn + 1e-7)
|
||||
const f1 = 2 * ((precision * recall) / (precision + recall + 1e-7))
|
||||
|
||||
console.log("\n--- Evaluation Metrics ---")
|
||||
console.log(`Accuracy: \t${(accuracy * 100).toFixed(2)}%`)
|
||||
console.log(`Precision:\t${(precision * 100).toFixed(2)}%`)
|
||||
console.log(`Recall: \t${(recall * 100).toFixed(2)}%`)
|
||||
console.log(`F1 Score: \t${(f1 * 100).toFixed(2)}%`)
|
||||
console.log("\nConfusion Matrix:")
|
||||
console.log({
|
||||
Predicted_Spam: { True_Positive: tp, False_Positive: fp },
|
||||
Predicted_Ham: { False_Negative: fn, True_Negative: tn },
|
||||
})
|
||||
}
|
||||
|
||||
// For testing, we need deterministic shuffling which is not provided by tf.util.shuffle(dataSlice)
|
||||
// Seeded Fisher-Yates shuffle
|
||||
function seededShuffle<T>(array: T[], seed: number): void {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue