mirror of
https://github.com/tutao/tutanota.git
synced 2025-12-07 13:49:47 +00:00
fix unnecessary retraining when lastTrainingDataIndexId did not change
We were not correctly updating the lastTrainingDataIndexId on partial re-training of the spam classification model. This led to unnecessary re-training of the spam classification model on every login, even though there was actually no new training data available. Co-authored-by: das <das@tutao.de>
This commit is contained in:
parent
b1207f1db0
commit
e703609e87
2 changed files with 43 additions and 22 deletions
|
|
@ -170,9 +170,8 @@ export class SpamClassifier {
|
||||||
lastTrainedFromScratchTime: Date.now(),
|
lastTrainedFromScratchTime: Date.now(),
|
||||||
lastTrainingDataIndexId: trainingDataset.lastTrainingDataIndexId,
|
lastTrainingDataIndexId: trainingDataset.lastTrainingDataIndexId,
|
||||||
}
|
}
|
||||||
const classifier = {
|
const classifier: Classifier = {
|
||||||
layersModel: layersModel,
|
layersModel: layersModel,
|
||||||
isEnabled: true,
|
|
||||||
metaData,
|
metaData,
|
||||||
threshold,
|
threshold,
|
||||||
}
|
}
|
||||||
|
|
@ -180,7 +179,7 @@ export class SpamClassifier {
|
||||||
await this.activateAndSaveClassifier(ownerGroup, classifier)
|
await this.activateAndSaveClassifier(ownerGroup, classifier)
|
||||||
|
|
||||||
console.log(
|
console.log(
|
||||||
`### Finished Initial Spam Classification Model Training ### (total trained mails: ${clientSpamTrainingData.length} (ham:spam ${hamCount}:${spamCount} => threshold:${threshold}), training time: ${trainingTime})`,
|
`🐞 finished initial spam classification model training for mailbox ${ownerGroup} (total trained mails: ${clientSpamTrainingData.length} (ham:spam ${hamCount}:${spamCount} => threshold:${threshold}), training time: ${trainingTime})`,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -188,7 +187,7 @@ export class SpamClassifier {
|
||||||
this.classifierByMailGroup.set(ownerGroup, classifier)
|
this.classifierByMailGroup.set(ownerGroup, classifier)
|
||||||
const spamClassificationModel = await this.getSpamClassificationModel(ownerGroup, classifier)
|
const spamClassificationModel = await this.getSpamClassificationModel(ownerGroup, classifier)
|
||||||
if (spamClassificationModel == null) {
|
if (spamClassificationModel == null) {
|
||||||
throw new Error("spam classification model is not available, and therefore can not be saved")
|
throw new Error(`spam classification model for mailbox ${ownerGroup} is not available, and therefore can not be saved`)
|
||||||
}
|
}
|
||||||
await this.spamClassifierStorageFacade.setSpamClassificationModel(spamClassificationModel)
|
await this.spamClassifierStorageFacade.setSpamClassificationModel(spamClassificationModel)
|
||||||
}
|
}
|
||||||
|
|
@ -203,12 +202,12 @@ export class SpamClassifier {
|
||||||
|
|
||||||
const trainingDataset = await this.spamClassifierDataDealer.fetchPartialTrainingDataFromIndexStartId(indexStartId, ownerGroup)
|
const trainingDataset = await this.spamClassifierDataDealer.fetchPartialTrainingDataFromIndexStartId(indexStartId, ownerGroup)
|
||||||
if (isEmpty(trainingDataset.trainingData)) {
|
if (isEmpty(trainingDataset.trainingData)) {
|
||||||
console.log("no new spam classification training data since last update")
|
console.log(`no new spam classification training data since last update for mailbox ${ownerGroup}`)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
console.log(
|
console.log(
|
||||||
`retraining spam classification model with ${trainingDataset.trainingData.length} new mails (ham:spam ${trainingDataset.hamCount}:${trainingDataset.spamCount}) (lastTrainingDataIndexId > ${indexStartId})`,
|
`retraining spam classification model for mailbox ${ownerGroup} with ${trainingDataset.trainingData.length} new mails (ham:spam ${trainingDataset.hamCount}:${trainingDataset.spamCount}) (lastTrainingDataIndexId > ${indexStartId})`,
|
||||||
)
|
)
|
||||||
await this.updateModel(ownerGroup, trainingDataset)
|
await this.updateModel(ownerGroup, trainingDataset)
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
|
|
@ -219,7 +218,7 @@ export class SpamClassifier {
|
||||||
// visibleForTesting
|
// visibleForTesting
|
||||||
async updateModel(ownerGroup: Id, trainingDataset: TrainingDataset): Promise<void> {
|
async updateModel(ownerGroup: Id, trainingDataset: TrainingDataset): Promise<void> {
|
||||||
if (isEmpty(trainingDataset.trainingData)) {
|
if (isEmpty(trainingDataset.trainingData)) {
|
||||||
console.log("no new spam classification training data since last update")
|
console.log(`no new spam classification training data for mailbox ${ownerGroup} since last update`)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -283,18 +282,23 @@ export class SpamClassifier {
|
||||||
ys.dispose()
|
ys.dispose()
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
classifierToUpdate.metaData.hamCount += trainingDataset.hamCount
|
|
||||||
classifierToUpdate.metaData.spamCount += trainingDataset.spamCount
|
|
||||||
classifierToUpdate.threshold = this.calculateThreshold(classifierToUpdate.metaData.hamCount, classifierToUpdate.metaData.spamCount)
|
classifierToUpdate.threshold = this.calculateThreshold(classifierToUpdate.metaData.hamCount, classifierToUpdate.metaData.spamCount)
|
||||||
|
classifierToUpdate.metaData = {
|
||||||
|
hamCount: classifierToUpdate.metaData.hamCount + trainingDataset.hamCount,
|
||||||
|
spamCount: classifierToUpdate.metaData.spamCount + trainingDataset.spamCount,
|
||||||
|
lastTrainingDataIndexId: trainingDataset.lastTrainingDataIndexId,
|
||||||
|
// lastTrainedFromScratchTime update only happens on full training
|
||||||
|
lastTrainedFromScratchTime: classifierToUpdate.metaData.lastTrainedFromScratchTime,
|
||||||
|
}
|
||||||
classifierToUpdate.layersModel = layersModelToUpdate
|
classifierToUpdate.layersModel = layersModelToUpdate
|
||||||
}
|
}
|
||||||
|
|
||||||
//This does not set the classifier, so it probably would have failed.
|
|
||||||
const trainingMetadata = `Total Ham: ${classifierToUpdate.metaData.hamCount} Spam: ${classifierToUpdate.metaData.spamCount} threshold: ${classifierToUpdate.threshold}`
|
|
||||||
|
|
||||||
await this.activateAndSaveClassifier(ownerGroup, classifierToUpdate)
|
await this.activateAndSaveClassifier(ownerGroup, classifierToUpdate)
|
||||||
|
|
||||||
console.log(`retraining spam classification model finished, took: ${performance.now() - retrainingStart}ms ${trainingMetadata}`)
|
const trainingMetadata = `Total Ham: ${classifierToUpdate.metaData.hamCount} Spam: ${classifierToUpdate.metaData.spamCount} threshold: ${classifierToUpdate.threshold}`
|
||||||
|
console.log(
|
||||||
|
`retraining spam classification model finished for mailbox ${ownerGroup}, took: ${performance.now() - retrainingStart}ms ${trainingMetadata}`,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// visibleForTesting
|
// visibleForTesting
|
||||||
|
|
@ -311,7 +315,7 @@ export class SpamClassifier {
|
||||||
const predictionData = await predictionTensor.data()
|
const predictionData = await predictionTensor.data()
|
||||||
const prediction = predictionData[0]
|
const prediction = predictionData[0]
|
||||||
|
|
||||||
console.log(`predicted new mail to be with probability ${prediction.toFixed(2)} spam. Owner Group: ${ownerGroup}`)
|
console.log(`predicted new mail to be with probability ${prediction.toFixed(2)} spam for mailbox: ${ownerGroup}`)
|
||||||
|
|
||||||
// when using the webgl backend, we need to manually dispose @tensorflow tensors
|
// when using the webgl backend, we need to manually dispose @tensorflow tensors
|
||||||
xs.dispose()
|
xs.dispose()
|
||||||
|
|
@ -383,12 +387,12 @@ export class SpamClassifier {
|
||||||
loss: "binaryCrossentropy",
|
loss: "binaryCrossentropy",
|
||||||
metrics: ["accuracy"],
|
metrics: ["accuracy"],
|
||||||
})
|
})
|
||||||
const metadata = spamClassificationModel.metaData
|
const metaData = spamClassificationModel.metaData
|
||||||
const threshold = this.calculateThreshold(metadata.hamCount, metadata.spamCount)
|
const threshold = this.calculateThreshold(metaData.hamCount, metaData.spamCount)
|
||||||
return {
|
return {
|
||||||
layersModel: layersModel,
|
layersModel: layersModel,
|
||||||
threshold,
|
threshold,
|
||||||
metaData: spamClassificationModel.metaData,
|
metaData,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
console.log(`loading the spam classification spamClassificationModel from storage failed for mailbox ${ownerGroup} ... `)
|
console.log(`loading the spam classification spamClassificationModel from storage failed for mailbox ${ownerGroup} ... `)
|
||||||
|
|
@ -427,7 +431,7 @@ export class SpamClassifier {
|
||||||
private async trainFromScratch(ownerGroup: string) {
|
private async trainFromScratch(ownerGroup: string) {
|
||||||
const trainingDataset = await this.spamClassifierDataDealer.fetchAllTrainingData(ownerGroup)
|
const trainingDataset = await this.spamClassifierDataDealer.fetchAllTrainingData(ownerGroup)
|
||||||
if (isEmpty(trainingDataset.trainingData)) {
|
if (isEmpty(trainingDataset.trainingData)) {
|
||||||
console.log("no training trainingData found. training from scratch aborted.")
|
console.log(`no training trainingData found for mailbox ${ownerGroup} training from scratch aborted.`)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
await this.initialTraining(ownerGroup, trainingDataset)
|
await this.initialTraining(ownerGroup, trainingDataset)
|
||||||
|
|
|
||||||
|
|
@ -182,6 +182,7 @@ o.spec("SpamClassifierTest", () => {
|
||||||
await testClassifier(spamClassifier, testSet, compressor)
|
await testClassifier(spamClassifier, testSet, compressor)
|
||||||
|
|
||||||
const trainingDatasetSecondHalf = getTrainingDataset(trainSetSecondHalf)
|
const trainingDatasetSecondHalf = getTrainingDataset(trainSetSecondHalf)
|
||||||
|
trainingDatasetSecondHalf.lastTrainingDataIndexId = "some new index id"
|
||||||
await spamClassifier.updateModel(TEST_OWNER_GROUP, trainingDatasetSecondHalf)
|
await spamClassifier.updateModel(TEST_OWNER_GROUP, trainingDatasetSecondHalf)
|
||||||
console.log(`==> Result when testing with mails in two steps (second step).`)
|
console.log(`==> Result when testing with mails in two steps (second step).`)
|
||||||
await testClassifier(spamClassifier, testSet, compressor)
|
await testClassifier(spamClassifier, testSet, compressor)
|
||||||
|
|
@ -191,6 +192,7 @@ o.spec("SpamClassifierTest", () => {
|
||||||
const finalSpamCount = initialTrainingDataset.spamCount + trainingDatasetSecondHalf.spamCount
|
const finalSpamCount = initialTrainingDataset.spamCount + trainingDatasetSecondHalf.spamCount
|
||||||
o(classifier?.metaData.hamCount).equals(finalHamCount)
|
o(classifier?.metaData.hamCount).equals(finalHamCount)
|
||||||
o(classifier?.metaData.spamCount).equals(finalSpamCount)
|
o(classifier?.metaData.spamCount).equals(finalSpamCount)
|
||||||
|
o(classifier?.metaData.lastTrainingDataIndexId).equals(trainingDatasetSecondHalf.lastTrainingDataIndexId)
|
||||||
o(classifier?.threshold).equals(spamClassifier.calculateThreshold(finalHamCount, finalSpamCount))
|
o(classifier?.threshold).equals(spamClassifier.calculateThreshold(finalHamCount, finalSpamCount))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
@ -474,7 +476,7 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
|
||||||
spamClassifier.spamMailProcessor = spamProcessor
|
spamClassifier.spamMailProcessor = spamProcessor
|
||||||
})
|
})
|
||||||
|
|
||||||
o("time to refit", async () => {
|
o("time to refit and multiple refits work correctly", async () => {
|
||||||
o.timeout(20_000_000)
|
o.timeout(20_000_000)
|
||||||
const { spamData, hamData } = await readMailDataFromCSV(DATASET_FILE_PATH)
|
const { spamData, hamData } = await readMailDataFromCSV(DATASET_FILE_PATH)
|
||||||
const hamSlice = await convertToClientTrainingDatum(hamData.slice(0, 1000), spamProcessor, false)
|
const hamSlice = await convertToClientTrainingDatum(hamData.slice(0, 1000), spamProcessor, false)
|
||||||
|
|
@ -483,17 +485,32 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
|
||||||
seededShuffle(dataSlice, 42)
|
seededShuffle(dataSlice, 42)
|
||||||
|
|
||||||
const start = performance.now()
|
const start = performance.now()
|
||||||
await spamClassifier.initialTraining(TEST_OWNER_GROUP, getTrainingDataset(dataSlice))
|
const initialTrainingDataset = getTrainingDataset(dataSlice)
|
||||||
|
await spamClassifier.initialTraining(TEST_OWNER_GROUP, initialTrainingDataset)
|
||||||
|
const initialClassifier = spamClassifier.classifierByMailGroup.get(TEST_OWNER_GROUP)!
|
||||||
|
const initialHamCount = initialClassifier.metaData.hamCount
|
||||||
|
const initialSpamCount = initialClassifier.metaData.spamCount
|
||||||
const initialTrainingDuration = performance.now() - start
|
const initialTrainingDuration = performance.now() - start
|
||||||
console.log(`initial training time ${initialTrainingDuration}ms`)
|
console.log(`initial training time ${initialTrainingDuration}ms`)
|
||||||
|
|
||||||
for (let i = 0; i < 20; i++) {
|
for (let i = 0; i < 20; i++) {
|
||||||
const nowSpam = [hamSlice[0]]
|
const nowSpam = [hamSlice[0]]
|
||||||
nowSpam.map((formerHam) => (formerHam.spamDecision = "1"))
|
nowSpam.map((formerHam) => (formerHam.spamDecision = SpamDecision.BLACKLIST))
|
||||||
const retrainingStart = performance.now()
|
const retrainingStart = performance.now()
|
||||||
await spamClassifier.updateModel(TEST_OWNER_GROUP, getTrainingDataset(nowSpam))
|
const newPartialRetrainingDataset = getTrainingDataset(nowSpam)
|
||||||
|
newPartialRetrainingDataset.lastTrainingDataIndexId = "lastTrainingDataIndexId" + i
|
||||||
|
await spamClassifier.updateModel(TEST_OWNER_GROUP, newPartialRetrainingDataset)
|
||||||
const retrainingDuration = performance.now() - retrainingStart
|
const retrainingDuration = performance.now() - retrainingStart
|
||||||
console.log(`retraining time ${retrainingDuration}ms`)
|
console.log(`retraining time ${retrainingDuration}ms`)
|
||||||
|
|
||||||
|
// verify classifier correctness
|
||||||
|
const classifier = spamClassifier.classifierByMailGroup.get(TEST_OWNER_GROUP)!
|
||||||
|
const finalHamCount = initialHamCount
|
||||||
|
const finalSpamCount = initialSpamCount + i + 1
|
||||||
|
o(classifier?.metaData.hamCount).equals(finalHamCount)
|
||||||
|
o(classifier?.metaData.spamCount).equals(finalSpamCount)
|
||||||
|
o(classifier?.metaData.lastTrainingDataIndexId).equals(newPartialRetrainingDataset.lastTrainingDataIndexId)
|
||||||
|
o(classifier?.threshold).equals(spamClassifier.calculateThreshold(finalHamCount, finalSpamCount))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue