mirror of
https://github.com/tutao/tutanota.git
synced 2025-12-07 05:39:56 +00:00
fix unnecessary retraining when lastTrainingDataIndexId did not change
We were not correctly updating the lastTrainingDataIndexId on partial re-training of the spam classification model. This led to unnecessary re-training of the spam classification model on every login, even though there was actually no new training data available. Co-authored-by: das <das@tutao.de>
This commit is contained in:
parent
b1207f1db0
commit
e703609e87
2 changed files with 43 additions and 22 deletions
|
|
@ -170,9 +170,8 @@ export class SpamClassifier {
|
|||
lastTrainedFromScratchTime: Date.now(),
|
||||
lastTrainingDataIndexId: trainingDataset.lastTrainingDataIndexId,
|
||||
}
|
||||
const classifier = {
|
||||
const classifier: Classifier = {
|
||||
layersModel: layersModel,
|
||||
isEnabled: true,
|
||||
metaData,
|
||||
threshold,
|
||||
}
|
||||
|
|
@ -180,7 +179,7 @@ export class SpamClassifier {
|
|||
await this.activateAndSaveClassifier(ownerGroup, classifier)
|
||||
|
||||
console.log(
|
||||
`### Finished Initial Spam Classification Model Training ### (total trained mails: ${clientSpamTrainingData.length} (ham:spam ${hamCount}:${spamCount} => threshold:${threshold}), training time: ${trainingTime})`,
|
||||
`🐞 finished initial spam classification model training for mailbox ${ownerGroup} (total trained mails: ${clientSpamTrainingData.length} (ham:spam ${hamCount}:${spamCount} => threshold:${threshold}), training time: ${trainingTime})`,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -188,7 +187,7 @@ export class SpamClassifier {
|
|||
this.classifierByMailGroup.set(ownerGroup, classifier)
|
||||
const spamClassificationModel = await this.getSpamClassificationModel(ownerGroup, classifier)
|
||||
if (spamClassificationModel == null) {
|
||||
throw new Error("spam classification model is not available, and therefore can not be saved")
|
||||
throw new Error(`spam classification model for mailbox ${ownerGroup} is not available, and therefore can not be saved`)
|
||||
}
|
||||
await this.spamClassifierStorageFacade.setSpamClassificationModel(spamClassificationModel)
|
||||
}
|
||||
|
|
@ -203,12 +202,12 @@ export class SpamClassifier {
|
|||
|
||||
const trainingDataset = await this.spamClassifierDataDealer.fetchPartialTrainingDataFromIndexStartId(indexStartId, ownerGroup)
|
||||
if (isEmpty(trainingDataset.trainingData)) {
|
||||
console.log("no new spam classification training data since last update")
|
||||
console.log(`no new spam classification training data since last update for mailbox ${ownerGroup}`)
|
||||
return
|
||||
}
|
||||
|
||||
console.log(
|
||||
`retraining spam classification model with ${trainingDataset.trainingData.length} new mails (ham:spam ${trainingDataset.hamCount}:${trainingDataset.spamCount}) (lastTrainingDataIndexId > ${indexStartId})`,
|
||||
`retraining spam classification model for mailbox ${ownerGroup} with ${trainingDataset.trainingData.length} new mails (ham:spam ${trainingDataset.hamCount}:${trainingDataset.spamCount}) (lastTrainingDataIndexId > ${indexStartId})`,
|
||||
)
|
||||
await this.updateModel(ownerGroup, trainingDataset)
|
||||
} catch (e) {
|
||||
|
|
@ -219,7 +218,7 @@ export class SpamClassifier {
|
|||
// visibleForTesting
|
||||
async updateModel(ownerGroup: Id, trainingDataset: TrainingDataset): Promise<void> {
|
||||
if (isEmpty(trainingDataset.trainingData)) {
|
||||
console.log("no new spam classification training data since last update")
|
||||
console.log(`no new spam classification training data for mailbox ${ownerGroup} since last update`)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -283,18 +282,23 @@ export class SpamClassifier {
|
|||
ys.dispose()
|
||||
}
|
||||
} finally {
|
||||
classifierToUpdate.metaData.hamCount += trainingDataset.hamCount
|
||||
classifierToUpdate.metaData.spamCount += trainingDataset.spamCount
|
||||
classifierToUpdate.threshold = this.calculateThreshold(classifierToUpdate.metaData.hamCount, classifierToUpdate.metaData.spamCount)
|
||||
classifierToUpdate.metaData = {
|
||||
hamCount: classifierToUpdate.metaData.hamCount + trainingDataset.hamCount,
|
||||
spamCount: classifierToUpdate.metaData.spamCount + trainingDataset.spamCount,
|
||||
lastTrainingDataIndexId: trainingDataset.lastTrainingDataIndexId,
|
||||
// lastTrainedFromScratchTime update only happens on full training
|
||||
lastTrainedFromScratchTime: classifierToUpdate.metaData.lastTrainedFromScratchTime,
|
||||
}
|
||||
classifierToUpdate.layersModel = layersModelToUpdate
|
||||
}
|
||||
|
||||
//This does not set the classifier, so it probably would have failed.
|
||||
const trainingMetadata = `Total Ham: ${classifierToUpdate.metaData.hamCount} Spam: ${classifierToUpdate.metaData.spamCount} threshold: ${classifierToUpdate.threshold}`
|
||||
|
||||
await this.activateAndSaveClassifier(ownerGroup, classifierToUpdate)
|
||||
|
||||
console.log(`retraining spam classification model finished, took: ${performance.now() - retrainingStart}ms ${trainingMetadata}`)
|
||||
const trainingMetadata = `Total Ham: ${classifierToUpdate.metaData.hamCount} Spam: ${classifierToUpdate.metaData.spamCount} threshold: ${classifierToUpdate.threshold}`
|
||||
console.log(
|
||||
`retraining spam classification model finished for mailbox ${ownerGroup}, took: ${performance.now() - retrainingStart}ms ${trainingMetadata}`,
|
||||
)
|
||||
}
|
||||
|
||||
// visibleForTesting
|
||||
|
|
@ -311,7 +315,7 @@ export class SpamClassifier {
|
|||
const predictionData = await predictionTensor.data()
|
||||
const prediction = predictionData[0]
|
||||
|
||||
console.log(`predicted new mail to be with probability ${prediction.toFixed(2)} spam. Owner Group: ${ownerGroup}`)
|
||||
console.log(`predicted new mail to be with probability ${prediction.toFixed(2)} spam for mailbox: ${ownerGroup}`)
|
||||
|
||||
// when using the webgl backend, we need to manually dispose @tensorflow tensors
|
||||
xs.dispose()
|
||||
|
|
@ -383,12 +387,12 @@ export class SpamClassifier {
|
|||
loss: "binaryCrossentropy",
|
||||
metrics: ["accuracy"],
|
||||
})
|
||||
const metadata = spamClassificationModel.metaData
|
||||
const threshold = this.calculateThreshold(metadata.hamCount, metadata.spamCount)
|
||||
const metaData = spamClassificationModel.metaData
|
||||
const threshold = this.calculateThreshold(metaData.hamCount, metaData.spamCount)
|
||||
return {
|
||||
layersModel: layersModel,
|
||||
threshold,
|
||||
metaData: spamClassificationModel.metaData,
|
||||
metaData,
|
||||
}
|
||||
} else {
|
||||
console.log(`loading the spam classification spamClassificationModel from storage failed for mailbox ${ownerGroup} ... `)
|
||||
|
|
@ -427,7 +431,7 @@ export class SpamClassifier {
|
|||
private async trainFromScratch(ownerGroup: string) {
|
||||
const trainingDataset = await this.spamClassifierDataDealer.fetchAllTrainingData(ownerGroup)
|
||||
if (isEmpty(trainingDataset.trainingData)) {
|
||||
console.log("no training trainingData found. training from scratch aborted.")
|
||||
console.log(`no training trainingData found for mailbox ${ownerGroup} training from scratch aborted.`)
|
||||
return
|
||||
}
|
||||
await this.initialTraining(ownerGroup, trainingDataset)
|
||||
|
|
|
|||
|
|
@ -182,6 +182,7 @@ o.spec("SpamClassifierTest", () => {
|
|||
await testClassifier(spamClassifier, testSet, compressor)
|
||||
|
||||
const trainingDatasetSecondHalf = getTrainingDataset(trainSetSecondHalf)
|
||||
trainingDatasetSecondHalf.lastTrainingDataIndexId = "some new index id"
|
||||
await spamClassifier.updateModel(TEST_OWNER_GROUP, trainingDatasetSecondHalf)
|
||||
console.log(`==> Result when testing with mails in two steps (second step).`)
|
||||
await testClassifier(spamClassifier, testSet, compressor)
|
||||
|
|
@ -191,6 +192,7 @@ o.spec("SpamClassifierTest", () => {
|
|||
const finalSpamCount = initialTrainingDataset.spamCount + trainingDatasetSecondHalf.spamCount
|
||||
o(classifier?.metaData.hamCount).equals(finalHamCount)
|
||||
o(classifier?.metaData.spamCount).equals(finalSpamCount)
|
||||
o(classifier?.metaData.lastTrainingDataIndexId).equals(trainingDatasetSecondHalf.lastTrainingDataIndexId)
|
||||
o(classifier?.threshold).equals(spamClassifier.calculateThreshold(finalHamCount, finalSpamCount))
|
||||
})
|
||||
|
||||
|
|
@ -474,7 +476,7 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
|
|||
spamClassifier.spamMailProcessor = spamProcessor
|
||||
})
|
||||
|
||||
o("time to refit", async () => {
|
||||
o("time to refit and multiple refits work correctly", async () => {
|
||||
o.timeout(20_000_000)
|
||||
const { spamData, hamData } = await readMailDataFromCSV(DATASET_FILE_PATH)
|
||||
const hamSlice = await convertToClientTrainingDatum(hamData.slice(0, 1000), spamProcessor, false)
|
||||
|
|
@ -483,17 +485,32 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
|
|||
seededShuffle(dataSlice, 42)
|
||||
|
||||
const start = performance.now()
|
||||
await spamClassifier.initialTraining(TEST_OWNER_GROUP, getTrainingDataset(dataSlice))
|
||||
const initialTrainingDataset = getTrainingDataset(dataSlice)
|
||||
await spamClassifier.initialTraining(TEST_OWNER_GROUP, initialTrainingDataset)
|
||||
const initialClassifier = spamClassifier.classifierByMailGroup.get(TEST_OWNER_GROUP)!
|
||||
const initialHamCount = initialClassifier.metaData.hamCount
|
||||
const initialSpamCount = initialClassifier.metaData.spamCount
|
||||
const initialTrainingDuration = performance.now() - start
|
||||
console.log(`initial training time ${initialTrainingDuration}ms`)
|
||||
|
||||
for (let i = 0; i < 20; i++) {
|
||||
const nowSpam = [hamSlice[0]]
|
||||
nowSpam.map((formerHam) => (formerHam.spamDecision = "1"))
|
||||
nowSpam.map((formerHam) => (formerHam.spamDecision = SpamDecision.BLACKLIST))
|
||||
const retrainingStart = performance.now()
|
||||
await spamClassifier.updateModel(TEST_OWNER_GROUP, getTrainingDataset(nowSpam))
|
||||
const newPartialRetrainingDataset = getTrainingDataset(nowSpam)
|
||||
newPartialRetrainingDataset.lastTrainingDataIndexId = "lastTrainingDataIndexId" + i
|
||||
await spamClassifier.updateModel(TEST_OWNER_GROUP, newPartialRetrainingDataset)
|
||||
const retrainingDuration = performance.now() - retrainingStart
|
||||
console.log(`retraining time ${retrainingDuration}ms`)
|
||||
|
||||
// verify classifier correctness
|
||||
const classifier = spamClassifier.classifierByMailGroup.get(TEST_OWNER_GROUP)!
|
||||
const finalHamCount = initialHamCount
|
||||
const finalSpamCount = initialSpamCount + i + 1
|
||||
o(classifier?.metaData.hamCount).equals(finalHamCount)
|
||||
o(classifier?.metaData.spamCount).equals(finalSpamCount)
|
||||
o(classifier?.metaData.lastTrainingDataIndexId).equals(newPartialRetrainingDataset.lastTrainingDataIndexId)
|
||||
o(classifier?.threshold).equals(spamClassifier.calculateThreshold(finalHamCount, finalSpamCount))
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue