diff --git a/src/mail-app/workerUtils/spamClassification/SpamClassifier.ts b/src/mail-app/workerUtils/spamClassification/SpamClassifier.ts index 50e7700552..22e693ff5b 100644 --- a/src/mail-app/workerUtils/spamClassification/SpamClassifier.ts +++ b/src/mail-app/workerUtils/spamClassification/SpamClassifier.ts @@ -170,9 +170,8 @@ export class SpamClassifier { lastTrainedFromScratchTime: Date.now(), lastTrainingDataIndexId: trainingDataset.lastTrainingDataIndexId, } - const classifier = { + const classifier: Classifier = { layersModel: layersModel, - isEnabled: true, metaData, threshold, } @@ -180,7 +179,7 @@ export class SpamClassifier { await this.activateAndSaveClassifier(ownerGroup, classifier) console.log( - `### Finished Initial Spam Classification Model Training ### (total trained mails: ${clientSpamTrainingData.length} (ham:spam ${hamCount}:${spamCount} => threshold:${threshold}), training time: ${trainingTime})`, + `🐞 finished initial spam classification model training for mailbox ${ownerGroup} (total trained mails: ${clientSpamTrainingData.length} (ham:spam ${hamCount}:${spamCount} => threshold:${threshold}), training time: ${trainingTime})`, ) } @@ -188,7 +187,7 @@ export class SpamClassifier { this.classifierByMailGroup.set(ownerGroup, classifier) const spamClassificationModel = await this.getSpamClassificationModel(ownerGroup, classifier) if (spamClassificationModel == null) { - throw new Error("spam classification model is not available, and therefore can not be saved") + throw new Error(`spam classification model for mailbox ${ownerGroup} is not available, and therefore can not be saved`) } await this.spamClassifierStorageFacade.setSpamClassificationModel(spamClassificationModel) } @@ -203,12 +202,12 @@ export class SpamClassifier { const trainingDataset = await this.spamClassifierDataDealer.fetchPartialTrainingDataFromIndexStartId(indexStartId, ownerGroup) if (isEmpty(trainingDataset.trainingData)) { - console.log("no new spam classification training data since last update") + console.log(`no new spam classification training data since last update for mailbox ${ownerGroup}`) return } console.log( - `retraining spam classification model with ${trainingDataset.trainingData.length} new mails (ham:spam ${trainingDataset.hamCount}:${trainingDataset.spamCount}) (lastTrainingDataIndexId > ${indexStartId})`, + `retraining spam classification model for mailbox ${ownerGroup} with ${trainingDataset.trainingData.length} new mails (ham:spam ${trainingDataset.hamCount}:${trainingDataset.spamCount}) (lastTrainingDataIndexId > ${indexStartId})`, ) await this.updateModel(ownerGroup, trainingDataset) } catch (e) { @@ -219,7 +218,7 @@ export class SpamClassifier { // visibleForTesting async updateModel(ownerGroup: Id, trainingDataset: TrainingDataset): Promise { if (isEmpty(trainingDataset.trainingData)) { - console.log("no new spam classification training data since last update") + console.log(`no new spam classification training data for mailbox ${ownerGroup} since last update`) return } @@ -283,18 +282,23 @@ export class SpamClassifier { ys.dispose() } } finally { - classifierToUpdate.metaData.hamCount += trainingDataset.hamCount - classifierToUpdate.metaData.spamCount += trainingDataset.spamCount classifierToUpdate.threshold = this.calculateThreshold(classifierToUpdate.metaData.hamCount, classifierToUpdate.metaData.spamCount) + classifierToUpdate.metaData = { + hamCount: classifierToUpdate.metaData.hamCount + trainingDataset.hamCount, + spamCount: classifierToUpdate.metaData.spamCount + trainingDataset.spamCount, + lastTrainingDataIndexId: trainingDataset.lastTrainingDataIndexId, + // lastTrainedFromScratchTime update only happens on full training + lastTrainedFromScratchTime: classifierToUpdate.metaData.lastTrainedFromScratchTime, + } classifierToUpdate.layersModel = layersModelToUpdate } - //This does not set the classifier, so it probably would have failed. - const trainingMetadata = `Total Ham: ${classifierToUpdate.metaData.hamCount} Spam: ${classifierToUpdate.metaData.spamCount} threshold: ${classifierToUpdate.threshold}` - await this.activateAndSaveClassifier(ownerGroup, classifierToUpdate) - console.log(`retraining spam classification model finished, took: ${performance.now() - retrainingStart}ms ${trainingMetadata}`) + const trainingMetadata = `Total Ham: ${classifierToUpdate.metaData.hamCount} Spam: ${classifierToUpdate.metaData.spamCount} threshold: ${classifierToUpdate.threshold}` + console.log( + `retraining spam classification model finished for mailbox ${ownerGroup}, took: ${performance.now() - retrainingStart}ms ${trainingMetadata}`, + ) } // visibleForTesting @@ -311,7 +315,7 @@ export class SpamClassifier { const predictionData = await predictionTensor.data() const prediction = predictionData[0] - console.log(`predicted new mail to be with probability ${prediction.toFixed(2)} spam. Owner Group: ${ownerGroup}`) + console.log(`predicted new mail to be with probability ${prediction.toFixed(2)} spam for mailbox: ${ownerGroup}`) // when using the webgl backend, we need to manually dispose @tensorflow tensors xs.dispose() @@ -383,12 +387,12 @@ export class SpamClassifier { loss: "binaryCrossentropy", metrics: ["accuracy"], }) - const metadata = spamClassificationModel.metaData - const threshold = this.calculateThreshold(metadata.hamCount, metadata.spamCount) + const metaData = spamClassificationModel.metaData + const threshold = this.calculateThreshold(metaData.hamCount, metaData.spamCount) return { layersModel: layersModel, threshold, - metaData: spamClassificationModel.metaData, + metaData, } } else { console.log(`loading the spam classification spamClassificationModel from storage failed for mailbox ${ownerGroup} ... `) @@ -427,7 +431,7 @@ export class SpamClassifier { private async trainFromScratch(ownerGroup: string) { const trainingDataset = await this.spamClassifierDataDealer.fetchAllTrainingData(ownerGroup) if (isEmpty(trainingDataset.trainingData)) { - console.log("no training trainingData found. training from scratch aborted.") + console.log(`no training trainingData found for mailbox ${ownerGroup} training from scratch aborted.`) return } await this.initialTraining(ownerGroup, trainingDataset) diff --git a/test/tests/api/worker/utils/spamClassification/SpamClassifierTest.ts b/test/tests/api/worker/utils/spamClassification/SpamClassifierTest.ts index f5f3baf832..5606802bc3 100644 --- a/test/tests/api/worker/utils/spamClassification/SpamClassifierTest.ts +++ b/test/tests/api/worker/utils/spamClassification/SpamClassifierTest.ts @@ -182,6 +182,7 @@ o.spec("SpamClassifierTest", () => { await testClassifier(spamClassifier, testSet, compressor) const trainingDatasetSecondHalf = getTrainingDataset(trainSetSecondHalf) + trainingDatasetSecondHalf.lastTrainingDataIndexId = "some new index id" await spamClassifier.updateModel(TEST_OWNER_GROUP, trainingDatasetSecondHalf) console.log(`==> Result when testing with mails in two steps (second step).`) await testClassifier(spamClassifier, testSet, compressor) @@ -191,6 +192,7 @@ o.spec("SpamClassifierTest", () => { const finalSpamCount = initialTrainingDataset.spamCount + trainingDatasetSecondHalf.spamCount o(classifier?.metaData.hamCount).equals(finalHamCount) o(classifier?.metaData.spamCount).equals(finalSpamCount) + o(classifier?.metaData.lastTrainingDataIndexId).equals(trainingDatasetSecondHalf.lastTrainingDataIndexId) o(classifier?.threshold).equals(spamClassifier.calculateThreshold(finalHamCount, finalSpamCount)) }) @@ -474,7 +476,7 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) { spamClassifier.spamMailProcessor = spamProcessor }) - o("time to refit", async () => { + o("time to refit and multiple refits work correctly", async () => { o.timeout(20_000_000) const { spamData, hamData } = await readMailDataFromCSV(DATASET_FILE_PATH) const hamSlice = await convertToClientTrainingDatum(hamData.slice(0, 1000), spamProcessor, false) @@ -483,17 +485,32 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) { seededShuffle(dataSlice, 42) const start = performance.now() - await spamClassifier.initialTraining(TEST_OWNER_GROUP, getTrainingDataset(dataSlice)) + const initialTrainingDataset = getTrainingDataset(dataSlice) + await spamClassifier.initialTraining(TEST_OWNER_GROUP, initialTrainingDataset) + const initialClassifier = spamClassifier.classifierByMailGroup.get(TEST_OWNER_GROUP)! + const initialHamCount = initialClassifier.metaData.hamCount + const initialSpamCount = initialClassifier.metaData.spamCount const initialTrainingDuration = performance.now() - start console.log(`initial training time ${initialTrainingDuration}ms`) for (let i = 0; i < 20; i++) { const nowSpam = [hamSlice[0]] - nowSpam.map((formerHam) => (formerHam.spamDecision = "1")) + nowSpam.map((formerHam) => (formerHam.spamDecision = SpamDecision.BLACKLIST)) const retrainingStart = performance.now() - await spamClassifier.updateModel(TEST_OWNER_GROUP, getTrainingDataset(nowSpam)) + const newPartialRetrainingDataset = getTrainingDataset(nowSpam) + newPartialRetrainingDataset.lastTrainingDataIndexId = "lastTrainingDataIndexId" + i + await spamClassifier.updateModel(TEST_OWNER_GROUP, newPartialRetrainingDataset) const retrainingDuration = performance.now() - retrainingStart console.log(`retraining time ${retrainingDuration}ms`) + + // verify classifier correctness + const classifier = spamClassifier.classifierByMailGroup.get(TEST_OWNER_GROUP)! + const finalHamCount = initialHamCount + const finalSpamCount = initialSpamCount + i + 1 + o(classifier?.metaData.hamCount).equals(finalHamCount) + o(classifier?.metaData.spamCount).equals(finalSpamCount) + o(classifier?.metaData.lastTrainingDataIndexId).equals(newPartialRetrainingDataset.lastTrainingDataIndexId) + o(classifier?.threshold).equals(spamClassifier.calculateThreshold(finalHamCount, finalSpamCount)) } })