import { assertWorkerOrNode } from "../../../common/api/common/Env"
import { assertNotNull, groupByAndMap, isEmpty, Nullable, promiseMap } from "@tutao/tutanota-utils"
import { SpamClassificationDataDealer, TrainingDataset } from "./SpamClassificationDataDealer"
import { CacheStorage } from "../../../common/api/worker/rest/DefaultEntityRestCache"
import {
	dense,
	enableProdMode,
	fromMemory,
	glorotUniform,
	LayersModel,
	loadLayersModelFromIOHandler,
	sequential,
	tensor1d,
	tensor2d,
	withSaveHandler,
} from "./tensorflow-custom"
import type { ModelArtifacts } from "@tensorflow/tfjs-core/dist/io/types"
import type { ModelFitArgs } from "@tensorflow/tfjs-layers"
import type { Tensor } from "@tensorflow/tfjs-core"
import { DEFAULT_PREPROCESS_CONFIGURATION, SpamMailDatum, SpamMailProcessor } from "../../../common/api/common/utils/spamClassificationUtils/SpamMailProcessor"
import { SparseVectorCompressor } from "../../../common/api/common/utils/spamClassificationUtils/SparseVectorCompressor"
import { SpamDecision } from "../../../common/api/common/TutanotaConstants"
export type SpamClassificationModel = {
	modelTopology: string
	weightSpecs: string
	weightData: Uint8Array
	ownerGroup: Id
	hamCount: number
	spamCount: number
}
export const DEFAULT_PREDICTION_THRESHOLD = 0.55
const TRAINING_INTERVAL = 1000 * 60 * 10 // 10 minutes
const FULL_RETRAINING_INTERVAL = 1000 * 60 * 60 * 24 * 7 // 1 week
export type Classifier = {
	isEnabled: boolean
	layersModel: LayersModel
	threshold: number
	hamCount: number
	spamCount: number
}
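
/**
 * Per-mailbox (per owner group) spam classifier backed by a small TensorFlow.js feed-forward network.
 *
 * Mails are turned into sparse feature vectors by {@link SpamMailProcessor} and {@link SparseVectorCompressor},
 * one model is trained per owner group, persisted to the offline cache storage and periodically updated with
 * newly indexed training data.
 */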
export class SpamClassifier {
	// visibleForTesting
	readonly classifiers: Map<Id, Classifier>
	sparseVectorCompressor: SparseVectorCompressor
	spamMailProcessor: SpamMailProcessor
	constructor(
		private readonly cacheStorage: CacheStorage,
		private readonly initializer: SpamClassificationDataDealer,
		private readonly deterministic: boolean = false,
	) {
		// enable tensorflow production mode
		enableProdMode()
		this.classifiers = new Map()
		this.sparseVectorCompressor = new SparseVectorCompressor()
		this.spamMailProcessor = new SpamMailProcessor(DEFAULT_PREPROCESS_CONFIGURATION, this.sparseVectorCompressor)
	}
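
	/**
	 * Derives the spam prediction threshold from the ham:spam ratio of the training data, clamped to the range
	 * [DEFAULT_PREDICTION_THRESHOLD, 0.75]. The intent appears to be that the more spam there is relative to ham,
	 * the higher the bar for flagging a new mail as spam.
	 *
	 * Illustrative example: hamCount = 1, spamCount = 100 gives a ratio of 0.01, so
	 * -0.1 * log10(0.1) + 0.65 = 0.75; a balanced dataset (ratio 1) yields exactly DEFAULT_PREDICTION_THRESHOLD.
	 */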
	calculateThreshold(hamCount: number, spamCount: number) {
		const hamToSpamRatio = hamCount / spamCount
		let threshold = -0.1 * Math.log10(hamToSpamRatio * 10) + 0.65
		if (threshold < DEFAULT_PREDICTION_THRESHOLD) {
			threshold = DEFAULT_PREDICTION_THRESHOLD
		} else if (threshold > 0.75) {
			threshold = 0.75
		}
		return threshold
	}
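
	/**
	 * Loads a persisted classifier for the given owner group if one exists. If the last full training is older than
	 * FULL_RETRAINING_INTERVAL the model is retrained from scratch, otherwise it is incrementally updated. In either
	 * case (and also when no model exists yet) an incremental update is scheduled every TRAINING_INTERVAL.
	 */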
	public async initialize(ownerGroup: Id): Promise<void> {
		const classifier = await this.loadClassifier(ownerGroup)
		if (classifier) {
			const fullRetrainingCutoff = Date.now() - FULL_RETRAINING_INTERVAL
			const lastFullTrainingTime = await this.cacheStorage.getLastTrainedFromScratchTime()
			if (fullRetrainingCutoff > lastFullTrainingTime) {
				console.log(`Retraining from scratch as last training (${new Date(lastFullTrainingTime)}) was more than a week ago`)
				await this.trainFromScratch(this.cacheStorage, ownerGroup)
			} else {
				console.log("loaded existing spam classification model from database")
				this.classifiers.set(ownerGroup, classifier)
				await this.updateAndSaveModel(this.cacheStorage, ownerGroup)
			}
		} else {
			console.log("no existing model found. Training from scratch ...")
			await this.trainFromScratch(this.cacheStorage, ownerGroup)
		}
		setInterval(async () => {
			await this.updateAndSaveModel(this.cacheStorage, ownerGroup)
		}, TRAINING_INTERVAL)
	}
	// visibleForTesting
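	/** Incrementally retrains the model with all training data indexed since the last stored training data index id and logs the result. */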
	public async updateAndSaveModel(storage: CacheStorage, ownerGroup: Id) {
		const isModelUpdated = await this.updateModelFromIndexStartId(await storage.getLastTrainingDataIndexId(), ownerGroup)
		if (isModelUpdated) {
			console.log(`Model updated successfully at ${Date.now()}`)
		}
	}
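
	/**
	 * Trains a fresh model for the given owner group on the full training dataset: each stored sparse vector is
	 * decompressed, BLACKLIST decisions are labelled as spam (1) and everything else as ham (0), the network is fit
	 * for 16 epochs, and the resulting classifier is registered with a threshold derived from the ham:spam counts.
	 */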
	public async initialTraining(ownerGroup: Id, trainingDataset: TrainingDataset): Promise<void> {
		const { trainingData: clientSpamTrainingData, hamCount, spamCount } = trainingDataset
		const trainingInput = await promiseMap(
			clientSpamTrainingData,
			(d) => {
				const vector = this.sparseVectorCompressor.binaryToVector(d.vector)
				const label = d.spamDecision === SpamDecision.BLACKLIST ? 1 : 0
				return { vector, label }
			},
			{
				concurrency: 5,
			},
		)
		const vectors = trainingInput.map((input) => input.vector)
		const labels = trainingInput.map((input) => input.label)

		const xs = tensor2d(vectors, [trainingInput.length, this.sparseVectorCompressor.dimension], undefined)
		const ys = tensor1d(labels, undefined)
		const layersModel = this.buildModel(this.sparseVectorCompressor.dimension)
		const trainingStart = performance.now()
		await layersModel.fit(xs, ys, {
			epochs: 16,
			batchSize: 32,
			shuffle: !this.deterministic,
			// callbacks: {
			// 	onEpochEnd: async (epoch, logs) => {
			// 		if (logs) {
			// 			console.log(`Epoch ${epoch + 1} - Loss: ${logs.loss.toFixed(4)}`)
			// 		}
			// 	},
			// },
			yieldEvery: 15,
		})
		const trainingTime = performance.now() - trainingStart
		// when using the webgl backend we need to manually dispose @tensorflow tensors
		xs.dispose()
		ys.dispose()
		const threshold = this.calculateThreshold(trainingDataset.hamCount, trainingDataset.spamCount)
		const classifier = {
			layersModel: layersModel,
			isEnabled: true,
			hamCount,
			spamCount,
			threshold,
		}
		this.classifiers.set(ownerGroup, classifier)

		console.log(
			`### Finished Initial Spam Classification Model Training ### (total trained mails: ${clientSpamTrainingData.length} (ham:spam ${hamCount}:${spamCount} => threshold: ${threshold}), training time: ${trainingTime})`,
		)
	}
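
	/**
	 * Fetches all training data newer than the given index id and, if the classifier for the owner group is enabled
	 * and there is new data, runs an incremental update. Returns whether the model was actually updated.
	 */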
	public async updateModelFromIndexStartId(indexStartId: Id, ownerGroup: Id): Promise<boolean> {
		try {
			const modelNotEnabled = this.classifiers.get(ownerGroup) === undefined || this.classifiers.get(ownerGroup)?.isEnabled === false
			if (modelNotEnabled) {
				console.warn("client spam classification is not enabled or there were errors during training")
				return false
			}
			const trainingDataset = await this.initializer.fetchPartialTrainingDataFromIndexStartId(indexStartId, ownerGroup)
			if (isEmpty(trainingDataset.trainingData)) {
				console.log("no new spam classification training data since last update")
				return false
			}
			console.log(
				`retraining spam classification model with ${trainingDataset.trainingData.length} new mails (ham:spam ${trainingDataset.hamCount}:${trainingDataset.spamCount}) (lastTrainingDataIndexId > ${indexStartId})`,
			)
			return await this.updateModel(ownerGroup, trainingDataset)
		} catch (e) {
			console.error("failed to update the model", e)
			return false
		}
	}
	// visibleForTesting
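	/**
	 * Incrementally fits the existing model on the given dataset. Samples are grouped by their isSpamConfidence and
	 * each group is re-fit (confidence + 1) times as a stand-in for per-sample weights (see the sampleWeight comment
	 * below). The classifier is disabled while fitting and re-enabled afterwards, the ham/spam counts and threshold
	 * are updated, and the model is persisted together with the new last training data index id.
	 */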
	async updateModel(ownerGroup: Id, trainingDataset: TrainingDataset): Promise<boolean> {
		const retrainingStart = performance.now()
		if (isEmpty(trainingDataset.trainingData)) {
			console.log("no new spam classification training data since last update")
			return false
		}

		const modelToUpdate = assertNotNull(this.classifiers.get(ownerGroup))
		const trainingInput = await promiseMap(
			trainingDataset.trainingData,
			(d) => {
				const vector = this.sparseVectorCompressor.binaryToVector(d.vector)
				const label = d.spamDecision === SpamDecision.BLACKLIST ? 1 : 0
				const isSpamConfidence = Number(d.confidence)
				return { vector, label, isSpamConfidence }
			},
			{
				concurrency: 5,
			},
		)
		const trainingInputByConfidence = groupByAndMap(
			trainingInput,
			({ isSpamConfidence }) => isSpamConfidence,
			({ vector, label }) => {
				return { vector, label }
			},
		)

		modelToUpdate.isEnabled = false

		try {
			for (const [isSpamConfidence, inputsForConfidence] of trainingInputByConfidence) {
				const vectors = inputsForConfidence.map((input) => input.vector)
				const labels = inputsForConfidence.map((input) => input.label)
				const xs = tensor2d(vectors, [vectors.length, this.sparseVectorCompressor.dimension], "int32")
				const ys = tensor1d(labels, "int32")
				// We need a way to put weight on a specific mail, the ideal way would be to pass sampleWeight to modelFitArgs,
				// but it is not yet implemented: https://github.com/tensorflow/tfjs/blob/0fc04d958ea592f3b8db79a8b3b497b5c8904097/tfjs-layers/src/engine/training.ts#L1487
				//
				// For now, we use the following workaround:
				// Re-fit the vector multiple times corresponding to `isSpamConfidence`
				const modelFitArgs: ModelFitArgs = {
					epochs: 8,
					batchSize: 32,
					shuffle: !this.deterministic,
					// callbacks: {
					// 	onEpochEnd: async (epoch, logs) => {
					// 		console.log(`Epoch ${epoch + 1} - Loss: ${logs!.loss.toFixed(4)}`)
					// 	},
					// },
					yieldEvery: 15,
				}
				for (let i = 0; i <= isSpamConfidence; i++) {
					await modelToUpdate.layersModel.fit(xs, ys, modelFitArgs)
				}
				// when using the webgl backend we need to manually dispose @tensorflow tensors
				xs.dispose()
				ys.dispose()
			}
		} finally {
			modelToUpdate.hamCount += trainingDataset.hamCount
			modelToUpdate.spamCount += trainingDataset.spamCount
			modelToUpdate.threshold = this.calculateThreshold(modelToUpdate.hamCount, modelToUpdate.spamCount)
			modelToUpdate.isEnabled = true
		}
		const trainingMetadata = `Total Ham: ${modelToUpdate.hamCount} Spam: ${modelToUpdate.spamCount} threshold: ${modelToUpdate.threshold}`
		console.log(`retraining spam classification model finished, took: ${performance.now() - retrainingStart} ms ${trainingMetadata}`)
		await this.saveModel(ownerGroup)
		await this.cacheStorage.setLastTrainingDataIndexId(trainingDataset.lastTrainingDataIndexId)
		return true
	}
	// visibleForTesting
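	/**
	 * Runs the owner group's model on a single feature vector and returns whether the predicted spam probability
	 * exceeds the classifier's threshold, or null if no enabled classifier exists for that owner group.
	 */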
	public async predict(vector: number[], ownerGroup: Id): Promise<Nullable<boolean>> {
		const classifier = this.classifiers.get(ownerGroup)
		if (classifier == null || !classifier.isEnabled) {
			return null
		}
		const vectors = [vector]
		const xs = tensor2d(vectors, [vectors.length, this.sparseVectorCompressor.dimension], "int32")

		const predictionTensor = classifier.layersModel.predict(xs) as Tensor
		const predictionData = await predictionTensor.data()
		const prediction = predictionData[0]
		console.log(`predicted new mail to be spam with probability ${prediction.toFixed(2)}. Owner Group: ${ownerGroup}`)

		// when using the webgl backend we need to manually dispose @tensorflow tensors
		xs.dispose()
		predictionTensor.dispose()

		return prediction > classifier.threshold
	}
	// visibleForTesting
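	/**
	 * Builds the network: three dense ReLU layers with 16 units each followed by a single sigmoid output unit,
	 * compiled with Adam and binary cross-entropy. When `deterministic` is set, the kernel initializers use a fixed
	 * seed so that tests get reproducible weights.
	 */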
	public buildModel(inputDimension: number): LayersModel {
		const model = sequential()
		model.add(
			dense({
				inputShape: [inputDimension],
				units: 16,
				activation: "relu",
				kernelInitializer: this.deterministic ? glorotUniform({ seed: 42 }) : glorotUniform({}),
			}),
		)
		model.add(
			dense({
				inputShape: [16],
				units: 16,
				activation: "relu",
				kernelInitializer: this.deterministic ? glorotUniform({ seed: 42 }) : glorotUniform({}),
			}),
		)
		model.add(
			dense({
				inputShape: [16],
				units: 16,
				activation: "relu",
				kernelInitializer: this.deterministic ? glorotUniform({ seed: 42 }) : glorotUniform({}),
			}),
		)
		model.add(
			dense({
				units: 1,
				activation: "sigmoid",
				kernelInitializer: this.deterministic ? glorotUniform({ seed: 42 }) : glorotUniform({}),
			}),
		)
		model.compile({ optimizer: "adam", loss: "binaryCrossentropy", metrics: ["accuracy"] })
		return model
	}
	public async saveModel(ownerGroup: Id): Promise<void> {
		const spamClassificationModel = await this.getSpamClassificationModel(ownerGroup)
		if (spamClassificationModel == null) {
			throw new Error("spam classification model is not available and therefore cannot be saved")
		}
		await this.cacheStorage.setSpamClassificationModel(spamClassificationModel)
	}
	async vectorizeAndCompress(mailDatum: SpamMailDatum) {
		return await this.spamMailProcessor.vectorizeAndCompress(mailDatum)
	}
	async vectorize(mailDatum: SpamMailDatum) {
		return await this.spamMailProcessor.vectorize(mailDatum)
	}
	// visibleForTesting
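	/**
	 * Restores a classifier from the persisted SpamClassificationModel in cache storage (topology and weight specs
	 * are stored as JSON strings, weights as raw bytes), recompiles it and recomputes the threshold from the stored
	 * ham/spam counts. Returns null when no model has been persisted for the owner group.
	 */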
	public async loadClassifier(ownerGroup: Id): Promise<Nullable<Classifier>> {
		const spamClassificationModel = await assertNotNull(this.cacheStorage).getSpamClassificationModel(ownerGroup)
		if (spamClassificationModel) {
			const modelTopology = JSON.parse(spamClassificationModel.modelTopology)
			const weightSpecs = JSON.parse(spamClassificationModel.weightSpecs)
			const weightData = spamClassificationModel.weightData.buffer.slice(
				spamClassificationModel.weightData.byteOffset,
				spamClassificationModel.weightData.byteOffset + spamClassificationModel.weightData.byteLength,
			)
			const modelArtifacts = { modelTopology, weightSpecs, weightData }
			const layersModel = await loadLayersModelFromIOHandler(fromMemory(modelArtifacts), undefined, undefined)
			layersModel.compile({
				optimizer: "adam",
				loss: "binaryCrossentropy",
				metrics: ["accuracy"],
			})
			const threshold = this.calculateThreshold(spamClassificationModel.hamCount, spamClassificationModel.spamCount)
			return {
				isEnabled: true,
				layersModel: layersModel,
				threshold,
				hamCount: spamClassificationModel.hamCount,
				spamCount: spamClassificationModel.spamCount,
			}
		} else {
			console.log("no persisted spam classification model found in the offline db")
			return null
		}
	}
	// visibleForTesting
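	/**
	 * Creates an independent copy of this SpamClassifier by serializing each per-group model to ModelArtifacts and
	 * loading them into fresh LayersModels, presumably so the clone can be trained or evaluated without mutating
	 * the original models.
	 */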
	public async cloneClassifier(): Promise<SpamClassifier> {
		const newClassifier = new SpamClassifier(this.cacheStorage, this.initializer, this.deterministic)
		newClassifier.spamMailProcessor = this.spamMailProcessor
		newClassifier.sparseVectorCompressor = this.sparseVectorCompressor
		for (const [ownerGroup, { layersModel: _, isEnabled, threshold, hamCount, spamCount }] of this.classifiers) {
			const modelArtifacts = assertNotNull(await this.getModelArtifacts(ownerGroup))
			const newModel = await loadLayersModelFromIOHandler(fromMemory(modelArtifacts), undefined, undefined)
			newModel.compile({
				optimizer: "adam",
				loss: "binaryCrossentropy",
				metrics: ["accuracy"],
			})
			newClassifier.classifiers.set(ownerGroup, {
				layersModel: newModel,
				isEnabled,
				threshold,
				hamCount,
				spamCount,
			})
		}
		return newClassifier
	}
	// visibleForTesting
	public addSpamClassifierForOwner(ownerGroup: Id, classifier: Classifier) {
		this.classifiers.set(ownerGroup, classifier)
	}
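
	/**
	 * Fetches the complete training dataset for the owner group and, if it is non-empty, runs the initial training,
	 * persists the model and records both the time of this full training and the last training data index id.
	 */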
	private async trainFromScratch(storage: CacheStorage, ownerGroup: string) {
		const trainingDataset = await this.initializer.fetchAllTrainingData(ownerGroup)
		const { trainingData, lastTrainingDataIndexId } = trainingDataset
		if (isEmpty(trainingData)) {
			console.log("No training data found. Training from scratch aborted.")
			return
		}
		await this.initialTraining(ownerGroup, trainingDataset)
		await this.saveModel(ownerGroup)
		await storage.setLastTrainedFromScratchTime(Date.now())
		await storage.setLastTrainingDataIndexId(lastTrainingDataIndexId)
	}
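
	/**
	 * Serializes the in-memory classifier of the given owner group into the persistable SpamClassificationModel
	 * shape (topology and weight specs as JSON strings, weights as a Uint8Array), or returns null if there is none.
	 */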
	private async getSpamClassificationModel(ownerGroup: Id): Promise<SpamClassificationModel | null> {
		const classifier = this.classifiers.get(ownerGroup)
		if (!classifier) {
			return null
		}
		const modelArtifacts = await this.getModelArtifacts(ownerGroup)
		if (!modelArtifacts) {
			return null
		}
		const modelTopology = JSON.stringify(modelArtifacts.modelTopology)
		const weightSpecs = JSON.stringify(modelArtifacts.weightSpecs)
		const weightData = new Uint8Array(modelArtifacts.weightData as ArrayBuffer)
		return {
			modelTopology,
			weightSpecs,
			weightData,
			ownerGroup,
			hamCount: classifier.hamCount,
			spamCount: classifier.spamCount,
		}
	}
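
	/**
	 * Extracts the TensorFlow.js ModelArtifacts of the in-memory model by calling save() with an in-memory save
	 * handler that resolves the artifacts instead of writing them anywhere.
	 */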
	private async getModelArtifacts(ownerGroup: Id) {
		const classifier = this.classifiers.get(ownerGroup)
		if (!classifier) {
			return null
		}
		return await new Promise<ModelArtifacts>((resolve) => {
			const saveInfo = withSaveHandler(async (artifacts: any) => {
				resolve(artifacts)
				return {
					modelArtifactsInfo: {
						dateSaved: new Date(),
						modelTopologyType: "JSON",
					},
				}
			})
			classifier.layersModel.save(saveInfo, undefined)
		})
	}
}