fix unnecessary retraining when lastTrainingDataIndexId did not change

We were not correctly updating the lastTrainingDataIndexId on partial
retraining of the spam classification model. This led to unnecessary
retraining of the spam classification model on every login, even though
no new training data was actually available.

Co-authored-by: das <das@tutao.de>
This commit is contained in:
jhm 2025-12-04 17:09:29 +01:00
parent b1207f1db0
commit e703609e87
No known key found for this signature in database
GPG key ID: 8932FDB35DF1C9E7
2 changed files with 43 additions and 22 deletions

View file

@ -170,9 +170,8 @@ export class SpamClassifier {
lastTrainedFromScratchTime: Date.now(), lastTrainedFromScratchTime: Date.now(),
lastTrainingDataIndexId: trainingDataset.lastTrainingDataIndexId, lastTrainingDataIndexId: trainingDataset.lastTrainingDataIndexId,
} }
const classifier = { const classifier: Classifier = {
layersModel: layersModel, layersModel: layersModel,
isEnabled: true,
metaData, metaData,
threshold, threshold,
} }
@ -180,7 +179,7 @@ export class SpamClassifier {
await this.activateAndSaveClassifier(ownerGroup, classifier) await this.activateAndSaveClassifier(ownerGroup, classifier)
console.log( console.log(
`### Finished Initial Spam Classification Model Training ### (total trained mails: ${clientSpamTrainingData.length} (ham:spam ${hamCount}:${spamCount} => threshold:${threshold}), training time: ${trainingTime})`, `🐞 finished initial spam classification model training for mailbox ${ownerGroup} (total trained mails: ${clientSpamTrainingData.length} (ham:spam ${hamCount}:${spamCount} => threshold:${threshold}), training time: ${trainingTime})`,
) )
} }
@ -188,7 +187,7 @@ export class SpamClassifier {
this.classifierByMailGroup.set(ownerGroup, classifier) this.classifierByMailGroup.set(ownerGroup, classifier)
const spamClassificationModel = await this.getSpamClassificationModel(ownerGroup, classifier) const spamClassificationModel = await this.getSpamClassificationModel(ownerGroup, classifier)
if (spamClassificationModel == null) { if (spamClassificationModel == null) {
throw new Error("spam classification model is not available, and therefore can not be saved") throw new Error(`spam classification model for mailbox ${ownerGroup} is not available, and therefore can not be saved`)
} }
await this.spamClassifierStorageFacade.setSpamClassificationModel(spamClassificationModel) await this.spamClassifierStorageFacade.setSpamClassificationModel(spamClassificationModel)
} }
@ -203,12 +202,12 @@ export class SpamClassifier {
const trainingDataset = await this.spamClassifierDataDealer.fetchPartialTrainingDataFromIndexStartId(indexStartId, ownerGroup) const trainingDataset = await this.spamClassifierDataDealer.fetchPartialTrainingDataFromIndexStartId(indexStartId, ownerGroup)
if (isEmpty(trainingDataset.trainingData)) { if (isEmpty(trainingDataset.trainingData)) {
console.log("no new spam classification training data since last update") console.log(`no new spam classification training data since last update for mailbox ${ownerGroup}`)
return return
} }
console.log( console.log(
`retraining spam classification model with ${trainingDataset.trainingData.length} new mails (ham:spam ${trainingDataset.hamCount}:${trainingDataset.spamCount}) (lastTrainingDataIndexId > ${indexStartId})`, `retraining spam classification model for mailbox ${ownerGroup} with ${trainingDataset.trainingData.length} new mails (ham:spam ${trainingDataset.hamCount}:${trainingDataset.spamCount}) (lastTrainingDataIndexId > ${indexStartId})`,
) )
await this.updateModel(ownerGroup, trainingDataset) await this.updateModel(ownerGroup, trainingDataset)
} catch (e) { } catch (e) {
@ -219,7 +218,7 @@ export class SpamClassifier {
// visibleForTesting // visibleForTesting
async updateModel(ownerGroup: Id, trainingDataset: TrainingDataset): Promise<void> { async updateModel(ownerGroup: Id, trainingDataset: TrainingDataset): Promise<void> {
if (isEmpty(trainingDataset.trainingData)) { if (isEmpty(trainingDataset.trainingData)) {
console.log("no new spam classification training data since last update") console.log(`no new spam classification training data for mailbox ${ownerGroup} since last update`)
return return
} }
@ -283,18 +282,23 @@ export class SpamClassifier {
ys.dispose() ys.dispose()
} }
} finally { } finally {
classifierToUpdate.metaData.hamCount += trainingDataset.hamCount
classifierToUpdate.metaData.spamCount += trainingDataset.spamCount
classifierToUpdate.threshold = this.calculateThreshold(classifierToUpdate.metaData.hamCount, classifierToUpdate.metaData.spamCount) classifierToUpdate.threshold = this.calculateThreshold(classifierToUpdate.metaData.hamCount, classifierToUpdate.metaData.spamCount)
classifierToUpdate.metaData = {
hamCount: classifierToUpdate.metaData.hamCount + trainingDataset.hamCount,
spamCount: classifierToUpdate.metaData.spamCount + trainingDataset.spamCount,
lastTrainingDataIndexId: trainingDataset.lastTrainingDataIndexId,
// lastTrainedFromScratchTime update only happens on full training
lastTrainedFromScratchTime: classifierToUpdate.metaData.lastTrainedFromScratchTime,
}
classifierToUpdate.layersModel = layersModelToUpdate classifierToUpdate.layersModel = layersModelToUpdate
} }
//This does not set the classifier, so it probably would have failed.
const trainingMetadata = `Total Ham: ${classifierToUpdate.metaData.hamCount} Spam: ${classifierToUpdate.metaData.spamCount} threshold: ${classifierToUpdate.threshold}`
await this.activateAndSaveClassifier(ownerGroup, classifierToUpdate) await this.activateAndSaveClassifier(ownerGroup, classifierToUpdate)
console.log(`retraining spam classification model finished, took: ${performance.now() - retrainingStart}ms ${trainingMetadata}`) const trainingMetadata = `Total Ham: ${classifierToUpdate.metaData.hamCount} Spam: ${classifierToUpdate.metaData.spamCount} threshold: ${classifierToUpdate.threshold}`
console.log(
`retraining spam classification model finished for mailbox ${ownerGroup}, took: ${performance.now() - retrainingStart}ms ${trainingMetadata}`,
)
} }
// visibleForTesting // visibleForTesting
@ -311,7 +315,7 @@ export class SpamClassifier {
const predictionData = await predictionTensor.data() const predictionData = await predictionTensor.data()
const prediction = predictionData[0] const prediction = predictionData[0]
console.log(`predicted new mail to be with probability ${prediction.toFixed(2)} spam. Owner Group: ${ownerGroup}`) console.log(`predicted new mail to be with probability ${prediction.toFixed(2)} spam for mailbox: ${ownerGroup}`)
// when using the webgl backend, we need to manually dispose @tensorflow tensors // when using the webgl backend, we need to manually dispose @tensorflow tensors
xs.dispose() xs.dispose()
@ -383,12 +387,12 @@ export class SpamClassifier {
loss: "binaryCrossentropy", loss: "binaryCrossentropy",
metrics: ["accuracy"], metrics: ["accuracy"],
}) })
const metadata = spamClassificationModel.metaData const metaData = spamClassificationModel.metaData
const threshold = this.calculateThreshold(metadata.hamCount, metadata.spamCount) const threshold = this.calculateThreshold(metaData.hamCount, metaData.spamCount)
return { return {
layersModel: layersModel, layersModel: layersModel,
threshold, threshold,
metaData: spamClassificationModel.metaData, metaData,
} }
} else { } else {
console.log(`loading the spam classification spamClassificationModel from storage failed for mailbox ${ownerGroup} ... `) console.log(`loading the spam classification spamClassificationModel from storage failed for mailbox ${ownerGroup} ... `)
@ -427,7 +431,7 @@ export class SpamClassifier {
private async trainFromScratch(ownerGroup: string) { private async trainFromScratch(ownerGroup: string) {
const trainingDataset = await this.spamClassifierDataDealer.fetchAllTrainingData(ownerGroup) const trainingDataset = await this.spamClassifierDataDealer.fetchAllTrainingData(ownerGroup)
if (isEmpty(trainingDataset.trainingData)) { if (isEmpty(trainingDataset.trainingData)) {
console.log("no training trainingData found. training from scratch aborted.") console.log(`no training trainingData found for mailbox ${ownerGroup} training from scratch aborted.`)
return return
} }
await this.initialTraining(ownerGroup, trainingDataset) await this.initialTraining(ownerGroup, trainingDataset)

View file

@ -182,6 +182,7 @@ o.spec("SpamClassifierTest", () => {
await testClassifier(spamClassifier, testSet, compressor) await testClassifier(spamClassifier, testSet, compressor)
const trainingDatasetSecondHalf = getTrainingDataset(trainSetSecondHalf) const trainingDatasetSecondHalf = getTrainingDataset(trainSetSecondHalf)
trainingDatasetSecondHalf.lastTrainingDataIndexId = "some new index id"
await spamClassifier.updateModel(TEST_OWNER_GROUP, trainingDatasetSecondHalf) await spamClassifier.updateModel(TEST_OWNER_GROUP, trainingDatasetSecondHalf)
console.log(`==> Result when testing with mails in two steps (second step).`) console.log(`==> Result when testing with mails in two steps (second step).`)
await testClassifier(spamClassifier, testSet, compressor) await testClassifier(spamClassifier, testSet, compressor)
@ -191,6 +192,7 @@ o.spec("SpamClassifierTest", () => {
const finalSpamCount = initialTrainingDataset.spamCount + trainingDatasetSecondHalf.spamCount const finalSpamCount = initialTrainingDataset.spamCount + trainingDatasetSecondHalf.spamCount
o(classifier?.metaData.hamCount).equals(finalHamCount) o(classifier?.metaData.hamCount).equals(finalHamCount)
o(classifier?.metaData.spamCount).equals(finalSpamCount) o(classifier?.metaData.spamCount).equals(finalSpamCount)
o(classifier?.metaData.lastTrainingDataIndexId).equals(trainingDatasetSecondHalf.lastTrainingDataIndexId)
o(classifier?.threshold).equals(spamClassifier.calculateThreshold(finalHamCount, finalSpamCount)) o(classifier?.threshold).equals(spamClassifier.calculateThreshold(finalHamCount, finalSpamCount))
}) })
@ -474,7 +476,7 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
spamClassifier.spamMailProcessor = spamProcessor spamClassifier.spamMailProcessor = spamProcessor
}) })
o("time to refit", async () => { o("time to refit and multiple refits work correctly", async () => {
o.timeout(20_000_000) o.timeout(20_000_000)
const { spamData, hamData } = await readMailDataFromCSV(DATASET_FILE_PATH) const { spamData, hamData } = await readMailDataFromCSV(DATASET_FILE_PATH)
const hamSlice = await convertToClientTrainingDatum(hamData.slice(0, 1000), spamProcessor, false) const hamSlice = await convertToClientTrainingDatum(hamData.slice(0, 1000), spamProcessor, false)
@ -483,17 +485,32 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
seededShuffle(dataSlice, 42) seededShuffle(dataSlice, 42)
const start = performance.now() const start = performance.now()
await spamClassifier.initialTraining(TEST_OWNER_GROUP, getTrainingDataset(dataSlice)) const initialTrainingDataset = getTrainingDataset(dataSlice)
await spamClassifier.initialTraining(TEST_OWNER_GROUP, initialTrainingDataset)
const initialClassifier = spamClassifier.classifierByMailGroup.get(TEST_OWNER_GROUP)!
const initialHamCount = initialClassifier.metaData.hamCount
const initialSpamCount = initialClassifier.metaData.spamCount
const initialTrainingDuration = performance.now() - start const initialTrainingDuration = performance.now() - start
console.log(`initial training time ${initialTrainingDuration}ms`) console.log(`initial training time ${initialTrainingDuration}ms`)
for (let i = 0; i < 20; i++) { for (let i = 0; i < 20; i++) {
const nowSpam = [hamSlice[0]] const nowSpam = [hamSlice[0]]
nowSpam.map((formerHam) => (formerHam.spamDecision = "1")) nowSpam.map((formerHam) => (formerHam.spamDecision = SpamDecision.BLACKLIST))
const retrainingStart = performance.now() const retrainingStart = performance.now()
await spamClassifier.updateModel(TEST_OWNER_GROUP, getTrainingDataset(nowSpam)) const newPartialRetrainingDataset = getTrainingDataset(nowSpam)
newPartialRetrainingDataset.lastTrainingDataIndexId = "lastTrainingDataIndexId" + i
await spamClassifier.updateModel(TEST_OWNER_GROUP, newPartialRetrainingDataset)
const retrainingDuration = performance.now() - retrainingStart const retrainingDuration = performance.now() - retrainingStart
console.log(`retraining time ${retrainingDuration}ms`) console.log(`retraining time ${retrainingDuration}ms`)
// verify classifier correctness
const classifier = spamClassifier.classifierByMailGroup.get(TEST_OWNER_GROUP)!
const finalHamCount = initialHamCount
const finalSpamCount = initialSpamCount + i + 1
o(classifier?.metaData.hamCount).equals(finalHamCount)
o(classifier?.metaData.spamCount).equals(finalSpamCount)
o(classifier?.metaData.lastTrainingDataIndexId).equals(newPartialRetrainingDataset.lastTrainingDataIndexId)
o(classifier?.threshold).equals(spamClassifier.calculateThreshold(finalHamCount, finalSpamCount))
} }
}) })