Implement spam training data sync and add TutanotaModelV98

We sync the spam training data encrypted through our server to make
sure that all clients for a specific user behave the same when
classifying mails. Additionally, this enables the spam classification
in the webApp. We compress the training data vectors
(see clientSpamTrainingDatum) before uploading to our server using
SparseVectorCompressor.ts. When a user has the ClientSpamClassification
enabled, the spam training data sync will happen for every mail
received.

ClientSpamTrainingDatum are not stored in the CacheStorage.
No entityEvents are emitted for this type.
However, we retrieve creations and updates for ClientSpamTrainingData
through the modifiedClientSpamTrainingDataIndex.

We calculate a threshold per classifier based on the dataset ham to spam
ratio, we also subsample our training data to cap the ham to spam ratio
within a certain limit.

Co-authored-by: jomapp <17314077+jomapp@users.noreply.github.com>
Co-authored-by: das <das@tutao.de>
Co-authored-by: abp <abp@tutao.de>
Co-authored-by: Kinan <104761667+kibibytium@users.noreply.github.com>
Co-authored-by: sug <sug@tutao.de>
Co-authored-by: nif <nif@tutao.de>
Co-authored-by: map <mpfau@users.noreply.github.com>
This commit is contained in:
map 2025-11-03 18:01:36 +01:00 committed by abp
parent f8bbd32695
commit 5293be6a4a
No known key found for this signature in database
GPG key ID: 791D4EC38A7AA7C2
63 changed files with 3877 additions and 1963 deletions

View file

@ -33,7 +33,7 @@ export const allowedImports = {
wasm: ["wasm-fallback"], wasm: ["wasm-fallback"],
"common-min": ["polyfill-helpers"], "common-min": ["polyfill-helpers"],
boot: ["polyfill-helpers", "common-min"], boot: ["polyfill-helpers", "common-min"],
common: ["polyfill-helpers", "common-min"], common: ["polyfill-helpers", "common-min", "spam-classifier"],
"gui-base": ["polyfill-helpers", "common-min", "common", "boot"], "gui-base": ["polyfill-helpers", "common-min", "common", "boot"],
main: ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "date"], main: ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "date"],
sanitizer: ["polyfill-helpers", "common-min", "common", "boot", "gui-base"], sanitizer: ["polyfill-helpers", "common-min", "common", "boot", "gui-base"],
@ -46,8 +46,8 @@ export const allowedImports = {
contacts: ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "main", "mail-view", "date", "date-gui", "mail-editor"], contacts: ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "main", "mail-view", "date", "date-gui", "mail-editor"],
"calendar-view": ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "main", "date", "date-gui", "sharing", "contacts"], "calendar-view": ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "main", "date", "date-gui", "sharing", "contacts"],
login: ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "main"], login: ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "main"],
"spam-classifier": ["polyfill-helpers", "common", "common-min", "main"], "spam-classifier": ["polyfill-helpers", "common", "common-min"],
worker: ["polyfill-helpers", "common-min", "common", "native-common", "native-worker", "wasm", "wasm-fallback"], worker: ["polyfill-helpers", "common-min", "common", "native-common", "native-worker", "wasm", "wasm-fallback", "spam-classifier"],
"pow-worker": [], "pow-worker": [],
settings: [ settings: [
"polyfill-helpers", "polyfill-helpers",

View file

@ -4089,6 +4089,10 @@ setOpHandler(opHandler);
function enableProdMode() {
env().set('PROD', true);
}
function engine() { function engine() {
return ENGINE; return ENGINE;
} }
@ -39156,4 +39160,4 @@ function dropout(args) {
return new Dropout(args); return new Dropout(args);
} }
export { LayersModel, dense, dropout, fromMemory, glorotUniform, loadLayersModelFromIOHandler, sequential, stringToHashBucketFast$1 as stringToHashBucketFast, tensor1d, tensor2d, withSaveHandler }; export { LayersModel, dense, dropout, enableProdMode, fromMemory, glorotUniform, loadLayersModelFromIOHandler, sequential, stringToHashBucketFast$1 as stringToHashBucketFast, tensor1d, tensor2d, withSaveHandler };

11
libs/tensorflow.js vendored
View file

@ -5387,6 +5387,15 @@ setOpHandler(opHandler);
* limitations under the License. * limitations under the License.
* ============================================================================= * =============================================================================
*/ */
/**
* Enables production mode which disables correctness checks in favor of
* performance.
*
* @doc {heading: 'Environment'}
*/
function enableProdMode() {
env().set('PROD', true);
}
/** /**
* It returns the global engine that keeps track of all tensors and backends. * It returns the global engine that keeps track of all tensors and backends.
* *
@ -55873,4 +55882,4 @@ function dropout(args) {
return new Dropout(args); return new Dropout(args);
} }
export { LayersModel, dense, dropout, fromMemory, glorotUniform, loadLayersModelFromIOHandler, sequential, stringToHashBucketFast$1 as stringToHashBucketFast, tensor1d, tensor2d, withSaveHandler }; export { LayersModel, dense, dropout, enableProdMode, fromMemory, glorotUniform, loadLayersModelFromIOHandler, sequential, stringToHashBucketFast$1 as stringToHashBucketFast, tensor1d, tensor2d, withSaveHandler };

View file

@ -244,7 +244,6 @@ export function debounce<F extends (...args: any) => void>(timeout: number, toTh
if (timeoutId) { if (timeoutId) {
clearTimeout(timeoutId) clearTimeout(timeoutId)
} }
toInvoke = toThrottle.bind(null, ...args) toInvoke = toThrottle.bind(null, ...args)
timeoutId = setTimeout(toInvoke, timeout) timeoutId = setTimeout(toInvoke, timeout)
}) })

View file

@ -518,6 +518,7 @@ export async function initLocator(worker: CalendarWorkerImpl, browserData: Brows
locator.user, locator.user,
locator.cachingEntityClient, locator.cachingEntityClient,
locator.crypto, locator.crypto,
locator.cryptoWrapper,
locator.serviceExecutor, locator.serviceExecutor,
await locator.blob(), await locator.blob(),
fileApp, fileApp,

View file

@ -8,7 +8,6 @@ import {
ContactCustomDate, ContactCustomDate,
ContactRelationship, ContactRelationship,
ContactSocialId, ContactSocialId,
Mail,
MailFolder, MailFolder,
UserSettingsGroupRoot, UserSettingsGroupRoot,
} from "../entities/tutanota/TypeRefs.js" } from "../entities/tutanota/TypeRefs.js"
@ -1418,9 +1417,3 @@ export enum ProcessingState {
} }
export const PLAN_SELECTOR_SELECTED_BOX_SCALE = "1.03" export const PLAN_SELECTOR_SELECTED_BOX_SCALE = "1.03"
export const DEFAULT_IS_SPAM_CONFIDENCE = 1
export const DEFAULT_IS_SPAM = false
export function getSpamConfidence(mail: Mail): number {
return Number(mail.clientSpamClassifierResult?.confidence ?? DEFAULT_IS_SPAM_CONFIDENCE)
}

View file

@ -61,7 +61,20 @@ export const DELETE_MULTIPLE_LIMIT = 100
*/ */
export type Stripped<T extends Partial<SomeEntity>> = Omit< export type Stripped<T extends Partial<SomeEntity>> = Omit<
T, T,
"_id" | "_area" | "_owner" | "_ownerGroup" | "_ownerEncSessionKey" | "_ownerKeyVersion" | "_permissions" | "_errors" | "_format" | "_type" | "_original" | "_id"
| "_area"
| "_owner"
| "_ownerGroup"
| "_ownerEncSessionKey"
| "_ownerKeyVersion"
| "ownerGroup"
| "ownerEncSessionKey"
| "ownerKeyVersion"
| "_permissions"
| "_errors"
| "_format"
| "_type"
| "_original"
> >
type OptionalEntity<T extends Entity> = T & { type OptionalEntity<T extends Entity> = T & {
@ -76,6 +89,9 @@ export type StrippedEntity<T extends Entity> =
| "_ownerGroup" | "_ownerGroup"
| "_ownerEncSessionKey" | "_ownerEncSessionKey"
| "_ownerKeyVersion" | "_ownerKeyVersion"
| "ownerGroup"
| "ownerEncSessionKey"
| "ownerKeyVersion"
| "_permissions" | "_permissions"
| "_errors" | "_errors"
| "_format" | "_format"

View file

@ -0,0 +1,229 @@
import { HashingVectorizer } from "../../../../../mail-app/workerUtils/spamClassification/HashingVectorizer"
import { htmlToText } from "../IndexUtils"
import {
ML_BITCOIN_REGEX,
ML_BITCOIN_TOKEN,
ML_CREDIT_CARD_REGEX,
ML_CREDIT_CARD_TOKEN,
ML_DATE_REGEX,
ML_DATE_TOKEN,
ML_EMAIL_ADDR_REGEX,
ML_EMAIL_ADDR_TOKEN,
ML_NUMBER_SEQUENCE_REGEX,
ML_NUMBER_SEQUENCE_TOKEN,
ML_SPACE_BEFORE_NEW_LINE_REGEX,
ML_SPACE_BEFORE_NEW_LINE_TOKEN,
ML_SPECIAL_CHARACTER_REGEX,
ML_SPECIAL_CHARACTER_TOKEN,
ML_URL_REGEX,
ML_URL_TOKEN,
} from "./PreprocessPatterns"
import { SparseVectorCompressor } from "./SparseVectorCompressor"
import { ProgrammingError } from "../../error/ProgrammingError"
import { assertNotNull, tokenize } from "@tutao/tutanota-utils"
import { Mail, MailAddress, MailDetails } from "../../../entities/tutanota/TypeRefs"
import { getMailBodyText } from "../../CommonMailUtils"
import { MailAuthenticationStatus } from "../../TutanotaConstants"
export type PreprocessConfiguration = {
isPreprocessMails: boolean
isRemoveHTML: boolean
isReplaceDates: boolean
isReplaceUrls: boolean
isReplaceMailAddresses: boolean
isReplaceBitcoinAddress: boolean
isReplaceCreditCards: boolean
isReplaceNumbers: boolean
isReplaceSpecialCharacters: boolean
isRemoveSpaceBeforeNewLine: boolean
}
export const spamClassifierTokenizer = (text: PreprocessedMailContent): string[] => tokenize(text)
export const DEFAULT_PREPROCESS_CONFIGURATION: PreprocessConfiguration = {
isPreprocessMails: true,
isRemoveHTML: true,
isReplaceDates: true,
isReplaceUrls: true,
isReplaceMailAddresses: true,
isReplaceBitcoinAddress: true,
isReplaceCreditCards: true,
isReplaceNumbers: true,
isReplaceSpecialCharacters: true,
isRemoveSpaceBeforeNewLine: true,
}
export type SpamMailDatum = {
subject: string
body: string
ownerGroup: Id
sender: string
toRecipients: string
ccRecipients: string
bccRecipients: string
authStatus: string
}
export type PreprocessedMailContent = string
export class SpamMailProcessor {
constructor(
private readonly preprocessConfiguration: PreprocessConfiguration = DEFAULT_PREPROCESS_CONFIGURATION,
readonly vectorizer: HashingVectorizer = new HashingVectorizer(),
private readonly sparseVectorCompressor: SparseVectorCompressor = new SparseVectorCompressor(),
) {
if (vectorizer.dimension !== sparseVectorCompressor.dimension) {
throw new ProgrammingError(
`a spam preprocessor was created with different dimensions. Vectorizer:${vectorizer.dimension} compressor: ${sparseVectorCompressor.dimension}`,
)
}
}
public async vectorizeAndCompress(spamMailDatum: SpamMailDatum): Promise<Uint8Array> {
const vector = await this.vectorize(spamMailDatum)
return this.compress(vector)
}
public async vectorize(spamMailDatum: SpamMailDatum): Promise<number[]> {
const preprocessedMail = this.preprocessMail(spamMailDatum)
const tokenizedMail = spamClassifierTokenizer(preprocessedMail)
const vector = await this.vectorizer.vectorize(tokenizedMail)
return vector
}
public async compress(uncompressedVector: number[]): Promise<Uint8Array> {
return this.sparseVectorCompressor.vectorToBinary(uncompressedVector)
}
// visibleForTesting
public preprocessMail(mail: SpamMailDatum): PreprocessedMailContent {
const mailText = this.concatSubjectAndBody(mail)
if (!this.preprocessConfiguration.isPreprocessMails) {
return mailText
}
let preprocessedMail = mailText
// 1. Remove HTML code
if (this.preprocessConfiguration.isRemoveHTML) {
preprocessedMail = htmlToText(preprocessedMail)
}
// 2. Replace dates
if (this.preprocessConfiguration.isReplaceDates) {
for (const datePattern of ML_DATE_REGEX) {
preprocessedMail = preprocessedMail.replaceAll(datePattern, ML_DATE_TOKEN)
}
}
// 3. Replace urls
if (this.preprocessConfiguration.isReplaceUrls) {
preprocessedMail = preprocessedMail.replaceAll(ML_URL_REGEX, ML_URL_TOKEN)
}
// 4. Replace email addresses
if (this.preprocessConfiguration.isReplaceMailAddresses) {
preprocessedMail = preprocessedMail.replaceAll(ML_EMAIL_ADDR_REGEX, ML_EMAIL_ADDR_TOKEN)
}
// 5. Replace Bitcoin addresses
if (this.preprocessConfiguration.isReplaceBitcoinAddress) {
preprocessedMail = preprocessedMail.replaceAll(ML_BITCOIN_REGEX, ML_BITCOIN_TOKEN)
}
// 6. Replace credit card numbers
if (this.preprocessConfiguration.isReplaceCreditCards) {
preprocessedMail = preprocessedMail.replaceAll(ML_CREDIT_CARD_REGEX, ML_CREDIT_CARD_TOKEN)
}
// 7. Replace remaining numbers
if (this.preprocessConfiguration.isReplaceNumbers) {
preprocessedMail = preprocessedMail.replaceAll(ML_NUMBER_SEQUENCE_REGEX, ML_NUMBER_SEQUENCE_TOKEN)
}
// 8. Remove special characters
if (this.preprocessConfiguration.isReplaceSpecialCharacters) {
preprocessedMail = preprocessedMail.replaceAll(ML_SPECIAL_CHARACTER_REGEX, ML_SPECIAL_CHARACTER_TOKEN)
}
// 9. Remove spaces at end of lines
if (this.preprocessConfiguration.isRemoveSpaceBeforeNewLine) {
preprocessedMail = preprocessedMail.replaceAll(ML_SPACE_BEFORE_NEW_LINE_REGEX, ML_SPACE_BEFORE_NEW_LINE_TOKEN)
}
preprocessedMail += this.getHeaderFeatures(mail)
return preprocessedMail
}
private concatSubjectAndBody(mail: SpamMailDatum) {
const subject = mail.subject || ""
const body = mail.body || ""
const concatenated = `${subject}\n${body}`.trim()
return concatenated.length > 0 ? concatenated : " "
}
private getHeaderFeatures(mail: SpamMailDatum): string {
const { sender, toRecipients, ccRecipients, bccRecipients, authStatus } = mail
return `\n${sender}\n${toRecipients}\n${ccRecipients}\n${bccRecipients}\n${authStatus}`
}
}
export function createSpamMailDatum(mail: Mail, mailDetails: MailDetails) {
const spamMailDatum: SpamMailDatum = {
subject: mail.subject,
body: getMailBodyText(mailDetails.body),
ownerGroup: assertNotNull(mail._ownerGroup),
...extractSpamHeaderFeatures(mail, mailDetails),
}
return spamMailDatum
}
export function extractSpamHeaderFeatures(mail: Mail, mailDetails: MailDetails) {
const sender = joinNamesAndMailAddresses([mail?.sender])
const { toRecipients, ccRecipients, bccRecipients } = extractRecipients(mailDetails)
const authStatus = convertAuthStatusToSpamCategorizationToken(mail.authStatus)
return { sender, toRecipients, ccRecipients, bccRecipients, authStatus }
}
function extractRecipients({ recipients }: MailDetails) {
const toRecipients = joinNamesAndMailAddresses(recipients?.toRecipients)
const ccRecipients = joinNamesAndMailAddresses(recipients?.ccRecipients)
const bccRecipients = joinNamesAndMailAddresses(recipients?.bccRecipients)
return { toRecipients, ccRecipients, bccRecipients }
}
function joinNamesAndMailAddresses(recipients: MailAddress[] | null) {
return recipients?.map((recipient) => `${recipient?.name} ${recipient?.address}`).join(" ") || ""
}
function convertAuthStatusToSpamCategorizationToken(authStatus: string | null): string {
if (authStatus === MailAuthenticationStatus.AUTHENTICATED) {
return "TAUTHENTICATED"
} else if (authStatus === MailAuthenticationStatus.HARD_FAIL) {
return "THARDFAIL"
} else if (authStatus === MailAuthenticationStatus.SOFT_FAIL) {
return "TSOFTFAIL"
} else if (authStatus === MailAuthenticationStatus.INVALID_MAIL_FROM) {
return "TINVALIDMAILFROM"
} else if (authStatus === MailAuthenticationStatus.MISSING_MAIL_FROM) {
return "TMISSINGMAILFROM"
}
return ""
}
export const DEFAULT_IS_SPAM_CONFIDENCE = "1"
export function getSpamConfidence(mail: Mail): string {
return mail.clientSpamClassifierResult?.confidence ?? DEFAULT_IS_SPAM_CONFIDENCE
}
/**
* We pick a max word frequency of 2^5 so that we can compress it together
* with the index (which is 2^11 =2048) into two bytes
*/
export const MAX_WORD_FREQUENCY = 31
export const DEFAULT_VECTOR_MAX_LENGTH = 2048

View file

@ -0,0 +1,74 @@
import { ProgrammingError } from "../../error/ProgrammingError"
import { DEFAULT_VECTOR_MAX_LENGTH, MAX_WORD_FREQUENCY } from "./SpamMailProcessor"
/**
* Example:
*
* const vector = [0,0,7,0,0,4,4,0,0]
*
* const compressedSparseVector = {
* indices: [2, 5, 6],
* values: [7, 4, 4]
* }
*/
export type CompressedSparseVector = {
indices: number[] // this can be UInt16 (max. 2048) (delta encoding still doesn't guarantee values would be below 256 so we cannot use it + UInt8?)
values: number[] // values: [val, val, ...] (values are limited to [0..32] range
}
/**
* Class for compressing and decompressing sparse numerical vectors using delta encoding
* and run-length encoding techniques. This allows efficient storage and manipulation of
* sparse data by reducing unnecessary memory usage.
*/
export class SparseVectorCompressor {
constructor(public readonly dimension: number = DEFAULT_VECTOR_MAX_LENGTH) {}
public vectorToBinary(vector: number[]): Uint8Array {
const compressedSparseVector = this.compressVector(vector)
const result: number[] = []
result.length = compressedSparseVector.indices.length
for (let i = 0; i < compressedSparseVector.indices.length; i++) {
const index = compressedSparseVector.indices[i]
const value = compressedSparseVector.values[i]
result[i] = ((index & 0x7ff) << 5) | (value & 0x1f)
}
return new Uint8Array(new Uint16Array(result).buffer)
}
public binaryToVector(binary: Uint8Array): number[] {
const vector = new Array(this.dimension).fill(0)
const array = new Uint16Array(binary.buffer)
for (let i = 0; i < array.length; i++) {
const packedValue = array[i]
const index = (packedValue >> 5) & 0x7ff // Extract 11 bits for index
const value = packedValue & 0x1f // Extract 5 bits for value
vector[index] = value
}
return vector
}
/**
* Converts a dense vector to flat sparse form: { indices, values }
*/
public compressVector(vector: number[]): CompressedSparseVector {
if (vector.length > this.dimension) {
throw new ProgrammingError("vector is too big for dimension")
}
const indices: number[] = []
const values: number[] = []
for (let i = 0; i < vector.length; i++) {
const val = vector[i]
if (val !== 0) {
indices.push(i)
values.push(Math.min(val, MAX_WORD_FREQUENCY))
}
}
return { indices, values }
}
}

View file

@ -1,5 +1,5 @@
const modelInfo = { const modelInfo = {
version: 97, version: 98,
} }
export default modelInfo export default modelInfo

View file

@ -37,6 +37,8 @@ import { MoveMailDataTypeRef } from "./TypeRefs.js"
import { MoveMailPostOutTypeRef } from "./TypeRefs.js" import { MoveMailPostOutTypeRef } from "./TypeRefs.js"
import { NewsOutTypeRef } from "./TypeRefs.js" import { NewsOutTypeRef } from "./TypeRefs.js"
import { NewsInTypeRef } from "./TypeRefs.js" import { NewsInTypeRef } from "./TypeRefs.js"
import { PopulateClientSpamTrainingDataPostInTypeRef } from "./TypeRefs.js"
import { ProcessInboxPostInTypeRef } from "./TypeRefs.js"
import { ReceiveInfoServiceDataTypeRef } from "./TypeRefs.js" import { ReceiveInfoServiceDataTypeRef } from "./TypeRefs.js"
import { ReceiveInfoServicePostOutTypeRef } from "./TypeRefs.js" import { ReceiveInfoServicePostOutTypeRef } from "./TypeRefs.js"
import { ReportMailPostDataTypeRef } from "./TypeRefs.js" import { ReportMailPostDataTypeRef } from "./TypeRefs.js"
@ -231,6 +233,24 @@ export const NewsService = Object.freeze({
delete: null, delete: null,
} as const) } as const)
export const PopulateClientSpamTrainingDataService = Object.freeze({
app: "tutanota",
name: "PopulateClientSpamTrainingDataService",
get: null,
post: { data: PopulateClientSpamTrainingDataPostInTypeRef, return: null },
put: null,
delete: null,
} as const)
export const ProcessInboxService = Object.freeze({
app: "tutanota",
name: "ProcessInboxService",
get: null,
post: { data: ProcessInboxPostInTypeRef, return: null },
put: null,
delete: null,
} as const)
export const ReceiveInfoService = Object.freeze({ export const ReceiveInfoService = Object.freeze({
app: "tutanota", app: "tutanota",
name: "ReceiveInfoService", name: "ReceiveInfoService",

File diff suppressed because it is too large Load diff

View file

@ -244,6 +244,7 @@ export type Mail = {
_ownerKeyVersion: null | NumberString; _ownerKeyVersion: null | NumberString;
keyVerificationState: null | NumberString; keyVerificationState: null | NumberString;
processingState: NumberString; processingState: NumberString;
processNeeded: boolean;
sender: MailAddress; sender: MailAddress;
attachments: IdTuple[]; attachments: IdTuple[];
@ -284,6 +285,8 @@ export type MailBox = {
importedAttachments: Id; importedAttachments: Id;
mailImportStates: Id; mailImportStates: Id;
extractedFeatures: null | Id; extractedFeatures: null | Id;
clientSpamTrainingData: null | Id;
modifiedClientSpamTrainingDataIndex: null | Id;
} }
export const CreateExternalUserGroupDataTypeRef: TypeRef<CreateExternalUserGroupData> = new TypeRef("tutanota", 138) export const CreateExternalUserGroupDataTypeRef: TypeRef<CreateExternalUserGroupData> = new TypeRef("tutanota", 138)
@ -2609,3 +2612,108 @@ export type ClientClassifierResultPostIn = {
mails: IdTuple[]; mails: IdTuple[];
} }
export const ClientSpamTrainingDatumTypeRef: TypeRef<ClientSpamTrainingDatum> = new TypeRef("tutanota", 1736)
export function createClientSpamTrainingDatum(values: StrippedEntity<ClientSpamTrainingDatum>): ClientSpamTrainingDatum {
return Object.assign(create(typeModels[ClientSpamTrainingDatumTypeRef.typeId], ClientSpamTrainingDatumTypeRef), values)
}
export type ClientSpamTrainingDatum = {
_type: TypeRef<ClientSpamTrainingDatum>;
_errors: Object;
_original?: ClientSpamTrainingDatum
_id: IdTuple;
_permissions: Id;
_format: NumberString;
_ownerGroup: null | Id;
_ownerEncSessionKey: null | Uint8Array;
_ownerKeyVersion: null | NumberString;
confidence: NumberString;
spamDecision: NumberString;
vector: Uint8Array;
}
export const ClientSpamTrainingDatumIndexEntryTypeRef: TypeRef<ClientSpamTrainingDatumIndexEntry> = new TypeRef("tutanota", 1747)
export function createClientSpamTrainingDatumIndexEntry(values: StrippedEntity<ClientSpamTrainingDatumIndexEntry>): ClientSpamTrainingDatumIndexEntry {
return Object.assign(create(typeModels[ClientSpamTrainingDatumIndexEntryTypeRef.typeId], ClientSpamTrainingDatumIndexEntryTypeRef), values)
}
export type ClientSpamTrainingDatumIndexEntry = {
_type: TypeRef<ClientSpamTrainingDatumIndexEntry>;
_original?: ClientSpamTrainingDatumIndexEntry
_id: IdTuple;
_permissions: Id;
_format: NumberString;
_ownerGroup: null | Id;
clientSpamTrainingDatumElementId: Id;
}
export const ProcessInboxDatumTypeRef: TypeRef<ProcessInboxDatum> = new TypeRef("tutanota", 1756)
export function createProcessInboxDatum(values: StrippedEntity<ProcessInboxDatum>): ProcessInboxDatum {
return Object.assign(create(typeModels[ProcessInboxDatumTypeRef.typeId], ProcessInboxDatumTypeRef), values)
}
export type ProcessInboxDatum = {
_type: TypeRef<ProcessInboxDatum>;
_original?: ProcessInboxDatum
_id: Id;
ownerEncVectorSessionKey: Uint8Array;
ownerKeyVersion: NumberString;
classifierType: null | NumberString;
encVector: Uint8Array;
mailId: IdTuple;
targetMoveFolder: IdTuple;
}
export const ProcessInboxPostInTypeRef: TypeRef<ProcessInboxPostIn> = new TypeRef("tutanota", 1764)
export function createProcessInboxPostIn(values: StrippedEntity<ProcessInboxPostIn>): ProcessInboxPostIn {
return Object.assign(create(typeModels[ProcessInboxPostInTypeRef.typeId], ProcessInboxPostInTypeRef), values)
}
export type ProcessInboxPostIn = {
_type: TypeRef<ProcessInboxPostIn>;
_original?: ProcessInboxPostIn
_format: NumberString;
mailOwnerGroup: Id;
processInboxDatum: ProcessInboxDatum[];
}
export const PopulateClientSpamTrainingDatumTypeRef: TypeRef<PopulateClientSpamTrainingDatum> = new TypeRef("tutanota", 1770)
export function createPopulateClientSpamTrainingDatum(values: StrippedEntity<PopulateClientSpamTrainingDatum>): PopulateClientSpamTrainingDatum {
return Object.assign(create(typeModels[PopulateClientSpamTrainingDatumTypeRef.typeId], PopulateClientSpamTrainingDatumTypeRef), values)
}
export type PopulateClientSpamTrainingDatum = {
_type: TypeRef<PopulateClientSpamTrainingDatum>;
_original?: PopulateClientSpamTrainingDatum
_id: Id;
ownerEncVectorSessionKey: Uint8Array;
ownerKeyVersion: NumberString;
isSpam: boolean;
confidence: NumberString;
encVector: Uint8Array;
mailId: IdTuple;
}
export const PopulateClientSpamTrainingDataPostInTypeRef: TypeRef<PopulateClientSpamTrainingDataPostIn> = new TypeRef("tutanota", 1778)
export function createPopulateClientSpamTrainingDataPostIn(values: StrippedEntity<PopulateClientSpamTrainingDataPostIn>): PopulateClientSpamTrainingDataPostIn {
return Object.assign(create(typeModels[PopulateClientSpamTrainingDataPostInTypeRef.typeId], PopulateClientSpamTrainingDataPostInTypeRef), values)
}
export type PopulateClientSpamTrainingDataPostIn = {
_type: TypeRef<PopulateClientSpamTrainingDataPostIn>;
_original?: PopulateClientSpamTrainingDataPostIn
_format: NumberString;
mailOwnerGroup: Id;
populateClientSpamTrainingDatum: PopulateClientSpamTrainingDatum[];
}

View file

@ -45,9 +45,13 @@ export class EventController {
// the UserController must be notified first as other event receivers depend on it to be up-to-date // the UserController must be notified first as other event receivers depend on it to be up-to-date
await this.logins.getUserController().entityEventsReceived(entityUpdates, eventOwnerGroupId) await this.logins.getUserController().entityEventsReceived(entityUpdates, eventOwnerGroupId)
} }
// sequentially to prevent parallel loading of instances
for (const listener of this.entityListeners) { for (const listener of this.entityListeners) {
await listener(entityUpdates, eventOwnerGroupId) // run listeners async to speed up processing
// we ran it sequentially before to prevent parallel loading of instances
// this should not be a problem anymore as we prefetch now
// noinspection ES6MissingAwait
listener(entityUpdates, eventOwnerGroupId)
} }
} }

View file

@ -9,6 +9,8 @@ import {
MailService, MailService,
ManageLabelService, ManageLabelService,
MoveMailService, MoveMailService,
PopulateClientSpamTrainingDataService,
ProcessInboxService,
ReportMailService, ReportMailService,
ResolveConversationsService, ResolveConversationsService,
SendDraftService, SendDraftService,
@ -60,6 +62,10 @@ import {
createManageLabelServicePostIn, createManageLabelServicePostIn,
createMoveMailData, createMoveMailData,
createNewDraftAttachment, createNewDraftAttachment,
createPopulateClientSpamTrainingDataPostIn,
createPopulateClientSpamTrainingDatum,
createProcessInboxDatum,
createProcessInboxPostIn,
createReportMailPostData, createReportMailPostData,
createResolveConversationsServiceGetIn, createResolveConversationsServiceGetIn,
createSecureExternalRecipientKeyData, createSecureExternalRecipientKeyData,
@ -81,6 +87,8 @@ import {
MailFolder, MailFolder,
MailTypeRef, MailTypeRef,
MovedMails, MovedMails,
PopulateClientSpamTrainingDatum,
ProcessInboxDatum,
ReportedMailFieldMarker, ReportedMailFieldMarker,
SendDraftData, SendDraftData,
SymEncInternalRecipientKeyData, SymEncInternalRecipientKeyData,
@ -114,7 +122,6 @@ import {
isNotNull, isNotNull,
isSameTypeRef, isSameTypeRef,
noOp, noOp,
Nullable,
ofClass, ofClass,
parseUrl, parseUrl,
promiseFilter, promiseFilter,
@ -132,6 +139,7 @@ import { UNCOMPRESSED_MAX_SIZE } from "../../Compression.js"
import { import {
Aes128Key, Aes128Key,
aes256RandomKey, aes256RandomKey,
aesEncrypt,
AesKey, AesKey,
bitArrayToUint8Array, bitArrayToUint8Array,
createAuthVerifier, createAuthVerifier,
@ -155,13 +163,16 @@ import { LoginFacade } from "../LoginFacade.js"
import { ProgrammingError } from "../../../common/error/ProgrammingError.js" import { ProgrammingError } from "../../../common/error/ProgrammingError.js"
import { OwnerEncSessionKeyProvider } from "../../rest/EntityRestClient.js" import { OwnerEncSessionKeyProvider } from "../../rest/EntityRestClient.js"
import { KeyLoaderFacade, parseKeyVersion } from "../KeyLoaderFacade.js" import { KeyLoaderFacade, parseKeyVersion } from "../KeyLoaderFacade.js"
import { _encryptBytes, _encryptKeyWithVersionedKey, _encryptString, VersionedKey } from "../../crypto/CryptoWrapper.js" import { CryptoWrapper, VersionedKey } from "../../crypto/CryptoWrapper.js"
import { PublicEncryptionKeyProvider } from "../PublicEncryptionKeyProvider.js" import { PublicEncryptionKeyProvider } from "../PublicEncryptionKeyProvider.js"
import { EntityUpdateData, isUpdateForTypeRef } from "../../../common/utils/EntityUpdateUtils" import { EntityUpdateData, isUpdateForTypeRef } from "../../../common/utils/EntityUpdateUtils"
import { Entity } from "../../../common/EntityTypes" import { Entity } from "../../../common/EntityTypes"
import { KeyVerificationMismatchError } from "../../../common/error/KeyVerificationMismatchError" import { KeyVerificationMismatchError } from "../../../common/error/KeyVerificationMismatchError"
import { VerifiedPublicEncryptionKey } from "./KeyVerificationFacade" import { VerifiedPublicEncryptionKey } from "./KeyVerificationFacade"
import { ClientClassifierType } from "../../../common/ClientClassifierType" import { UnencryptedProcessInboxDatum } from "../../../../../mail-app/mail/model/ProcessInboxHandler"
import { UnencryptedPopulateClientSpamTrainingDatum } from "../../../../../mail-app/workerUtils/spamClassification/SpamClassificationDataDealer"
import { MailWithMailDetails } from "../../../../../mail-app/workerUtils/index/BulkMailLoader"
import { createSpamMailDatum, SpamMailProcessor } from "../../../common/utils/spamClassificationUtils/SpamMailProcessor"
assertWorkerOrNode() assertWorkerOrNode()
type Attachments = ReadonlyArray<TutanotaFile | DataFile | FileReference> type Attachments = ReadonlyArray<TutanotaFile | DataFile | FileReference>
@ -199,11 +210,13 @@ export class MailFacade {
private phishingMarkers: Set<string> = new Set() private phishingMarkers: Set<string> = new Set()
private deferredDraftId: IdTuple | null = null // the mail id of the draft that we are waiting for to be updated via websocket private deferredDraftId: IdTuple | null = null // the mail id of the draft that we are waiting for to be updated via websocket
private deferredDraftUpdate: Record<string, any> | null = null // this deferred promise is resolved as soon as the update of the draft is received private deferredDraftUpdate: Record<string, any> | null = null // this deferred promise is resolved as soon as the update of the draft is received
private spamMailProcessor: SpamMailProcessor = new SpamMailProcessor()
constructor( constructor(
private readonly userFacade: UserFacade, private readonly userFacade: UserFacade,
private readonly entityClient: EntityClient, private readonly entityClient: EntityClient,
private readonly crypto: CryptoFacade, private readonly crypto: CryptoFacade,
private readonly cryptoWrapper: CryptoWrapper,
private readonly serviceExecutor: IServiceExecutor, private readonly serviceExecutor: IServiceExecutor,
private readonly blobFacade: BlobFacade, private readonly blobFacade: BlobFacade,
private readonly fileApp: NativeFileApp, private readonly fileApp: NativeFileApp,
@ -216,7 +229,7 @@ export class MailFacade {
const mailGroupKey = await this.keyLoaderFacade.getCurrentSymGroupKey(ownerGroupId) const mailGroupKey = await this.keyLoaderFacade.getCurrentSymGroupKey(ownerGroupId)
const sk = aes256RandomKey() const sk = aes256RandomKey()
const ownerEncSessionKey = _encryptKeyWithVersionedKey(mailGroupKey, sk) const ownerEncSessionKey = this.cryptoWrapper.encryptKeyWithVersionedKey(mailGroupKey, sk)
const newFolder = createCreateMailFolderData({ const newFolder = createCreateMailFolderData({
folderName: name, folderName: name,
parentFolder: parent, parentFolder: parent,
@ -295,7 +308,7 @@ export class MailFacade {
const mailGroupKey = await this.keyLoaderFacade.getCurrentSymGroupKey(senderMailGroupId) const mailGroupKey = await this.keyLoaderFacade.getCurrentSymGroupKey(senderMailGroupId)
const sk = aes256RandomKey() const sk = aes256RandomKey()
const ownerEncSessionKey = _encryptKeyWithVersionedKey(mailGroupKey, sk) const ownerEncSessionKey = this.cryptoWrapper.encryptKeyWithVersionedKey(mailGroupKey, sk)
const service = createDraftCreateData({ const service = createDraftCreateData({
previousMessageId: previousMessageId, previousMessageId: previousMessageId,
conversationType: conversationType, conversationType: conversationType,
@ -393,12 +406,7 @@ export class MailFacade {
/** /**
* Move mails from {@param targetFolder} except those that are in {@param excludeMailSet}. * Move mails from {@param targetFolder} except those that are in {@param excludeMailSet}.
*/ */
async moveMails( async moveMails(mails: readonly IdTuple[], targetFolder: IdTuple, excludeMailSet: IdTuple | null): Promise<MovedMails[]> {
mails: readonly IdTuple[],
targetFolder: IdTuple,
excludeMailSet: IdTuple | null,
moveReason: ClientClassifierType | null = null,
): Promise<MovedMails[]> {
if (isEmpty(mails)) { if (isEmpty(mails)) {
return [] return []
} }
@ -415,7 +423,7 @@ export class MailFacade {
mails, mails,
excludeMailSet, excludeMailSet,
targetFolder, targetFolder,
moveReason, moveReason: null, // moveReason is not needed anymore from clients using TutanotaModel > 97
}), }),
) )
movedMails.push(...moveMailPostOut.movedMails) movedMails.push(...moveMailPostOut.movedMails)
@ -424,11 +432,7 @@ export class MailFacade {
return movedMails return movedMails
} }
async simpleMoveMails( async simpleMoveMails(mails: readonly IdTuple[], targetFolderKind: SimpleMoveMailTarget): Promise<MovedMails[]> {
mails: readonly IdTuple[],
targetFolderKind: SimpleMoveMailTarget,
moveReason: Nullable<ClientClassifierType>,
): Promise<MovedMails[]> {
if (isEmpty(mails)) { if (isEmpty(mails)) {
return [] return []
} }
@ -441,7 +445,7 @@ export class MailFacade {
createSimpleMoveMailPostIn({ createSimpleMoveMailPostIn({
mails, mails,
destinationSetType: targetFolderKind, destinationSetType: targetFolderKind,
moveReason, moveReason: null, // moveReason is not needed anymore from clients using TutanotaModel > 97
}), }),
) )
movedMails.push(...simpleMove.movedMails) movedMails.push(...simpleMove.movedMails)
@ -541,7 +545,7 @@ export class MailFacade {
// forwarded attachment which was not in the draft before // forwarded attachment which was not in the draft before
return this.crypto.resolveSessionKey(providedFile).then((fileSessionKey) => { return this.crypto.resolveSessionKey(providedFile).then((fileSessionKey) => {
const sessionKey = assertNotNull(fileSessionKey, "filesessionkey was not resolved") const sessionKey = assertNotNull(fileSessionKey, "filesessionkey was not resolved")
const ownerEncFileSessionKey = _encryptKeyWithVersionedKey(mailGroupKey, sessionKey) const ownerEncFileSessionKey = this.cryptoWrapper.encryptKeyWithVersionedKey(mailGroupKey, sessionKey)
const attachment = createDraftAttachment({ const attachment = createDraftAttachment({
existingFile: getLetId(providedFile), existingFile: getLetId(providedFile),
ownerEncFileSessionKey: ownerEncFileSessionKey.key, ownerEncFileSessionKey: ownerEncFileSessionKey.key,
@ -571,13 +575,13 @@ export class MailFacade {
providedFile: DataFile | FileReference, providedFile: DataFile | FileReference,
mailGroupKey: VersionedKey, mailGroupKey: VersionedKey,
): DraftAttachment { ): DraftAttachment {
const ownerEncFileSessionKey = _encryptKeyWithVersionedKey(mailGroupKey, fileSessionKey) const ownerEncFileSessionKey = this.cryptoWrapper.encryptKeyWithVersionedKey(mailGroupKey, fileSessionKey)
return createDraftAttachment({ return createDraftAttachment({
newFile: createNewDraftAttachment({ newFile: createNewDraftAttachment({
encFileName: _encryptString(fileSessionKey, providedFile.name), encFileName: this.cryptoWrapper.encryptString(fileSessionKey, providedFile.name),
encMimeType: _encryptString(fileSessionKey, providedFile.mimeType), encMimeType: this.cryptoWrapper.encryptString(fileSessionKey, providedFile.mimeType),
referenceTokens: referenceTokens, referenceTokens: referenceTokens,
encCid: providedFile.cid == null ? null : _encryptString(fileSessionKey, providedFile.cid), encCid: providedFile.cid == null ? null : this.cryptoWrapper.encryptString(fileSessionKey, providedFile.cid),
}), }),
ownerEncFileSessionKey: ownerEncFileSessionKey.key, ownerEncFileSessionKey: ownerEncFileSessionKey.key,
ownerKeyVersion: ownerEncFileSessionKey.encryptingKeyVersion.toString(), ownerKeyVersion: ownerEncFileSessionKey.encryptingKeyVersion.toString(),
@ -640,7 +644,7 @@ export class MailFacade {
await this.addRecipientKeyData(bucketKey, sendDraftData, recipients, senderMailGroupId) await this.addRecipientKeyData(bucketKey, sendDraftData, recipients, senderMailGroupId)
if (this.isTutaCryptMail(sendDraftData)) { if (this.isTutaCryptMail(sendDraftData)) {
sendDraftData.sessionEncEncryptionAuthStatus = _encryptString(sk, EncryptionAuthStatus.TUTACRYPT_SENDER) sendDraftData.sessionEncEncryptionAuthStatus = this.cryptoWrapper.encryptString(sk, EncryptionAuthStatus.TUTACRYPT_SENDER)
} }
} else { } else {
sendDraftData.mailSessionKey = bitArrayToUint8Array(sk) sendDraftData.mailSessionKey = bitArrayToUint8Array(sk)
@ -788,7 +792,7 @@ export class MailFacade {
const passwordKey = await this.loginFacade.deriveUserPassphraseKey({ kdfType, passphrase, salt }) const passwordKey = await this.loginFacade.deriveUserPassphraseKey({ kdfType, passphrase, salt })
const passwordVerifier = createAuthVerifier(passwordKey) const passwordVerifier = createAuthVerifier(passwordKey)
const externalGroupKeys = await this.getExternalGroupKeys(recipient.address, kdfType, passwordKey, passwordVerifier) const externalGroupKeys = await this.getExternalGroupKeys(recipient.address, kdfType, passwordKey, passwordVerifier)
const ownerEncBucketKey = _encryptKeyWithVersionedKey(externalGroupKeys.currentExternalMailGroupKey, bucketKey) const ownerEncBucketKey = this.cryptoWrapper.encryptKeyWithVersionedKey(externalGroupKeys.currentExternalMailGroupKey, bucketKey)
const data = createSecureExternalRecipientKeyData({ const data = createSecureExternalRecipientKeyData({
mailAddress: recipient.address, mailAddress: recipient.address,
kdfVersion: kdfType, kdfVersion: kdfType,
@ -968,9 +972,9 @@ export class MailFacade {
const externalMailGroupInfoSessionKey = aes256RandomKey() const externalMailGroupInfoSessionKey = aes256RandomKey()
const tutanotaPropertiesSessionKey = aes256RandomKey() const tutanotaPropertiesSessionKey = aes256RandomKey()
const mailboxSessionKey = aes256RandomKey() const mailboxSessionKey = aes256RandomKey()
const externalUserEncEntropy = _encryptBytes(currentExternalUserGroupKey.object, random.generateRandomData(32)) const externalUserEncEntropy = this.cryptoWrapper.encryptBytes(currentExternalUserGroupKey.object, random.generateRandomData(32))
const internalUserEncGroupKey = _encryptKeyWithVersionedKey(internalUserGroupKey, currentExternalUserGroupKey.object) const internalUserEncGroupKey = this.cryptoWrapper.encryptKeyWithVersionedKey(internalUserGroupKey, currentExternalUserGroupKey.object)
const userGroupData = createCreateExternalUserGroupData({ const userGroupData = createCreateExternalUserGroupData({
mailAddress: cleanedMailAddress, mailAddress: cleanedMailAddress,
externalPwEncUserGroupKey: encryptKey(externalUserPwKey, currentExternalUserGroupKey.object), externalPwEncUserGroupKey: encryptKey(externalUserPwKey, currentExternalUserGroupKey.object),
@ -978,15 +982,24 @@ export class MailFacade {
internalUserGroupKeyVersion: internalUserEncGroupKey.encryptingKeyVersion.toString(), internalUserGroupKeyVersion: internalUserEncGroupKey.encryptingKeyVersion.toString(),
}) })
const externalUserEncUserGroupInfoSessionKey = _encryptKeyWithVersionedKey(currentExternalUserGroupKey, externalUserGroupInfoSessionKey) const externalUserEncUserGroupInfoSessionKey = this.cryptoWrapper.encryptKeyWithVersionedKey(
const externalUserEncMailGroupKey = _encryptKeyWithVersionedKey(currentExternalUserGroupKey, currentExternalMailGroupKey.object) currentExternalUserGroupKey,
const externalUserEncTutanotaPropertiesSessionKey = _encryptKeyWithVersionedKey(currentExternalUserGroupKey, tutanotaPropertiesSessionKey) externalUserGroupInfoSessionKey,
)
const externalUserEncMailGroupKey = this.cryptoWrapper.encryptKeyWithVersionedKey(currentExternalUserGroupKey, currentExternalMailGroupKey.object)
const externalUserEncTutanotaPropertiesSessionKey = this.cryptoWrapper.encryptKeyWithVersionedKey(
currentExternalUserGroupKey,
tutanotaPropertiesSessionKey,
)
const externalMailEncMailGroupInfoSessionKey = _encryptKeyWithVersionedKey(currentExternalMailGroupKey, externalMailGroupInfoSessionKey) const externalMailEncMailGroupInfoSessionKey = this.cryptoWrapper.encryptKeyWithVersionedKey(
const externalMailEncMailBoxSessionKey = _encryptKeyWithVersionedKey(currentExternalMailGroupKey, mailboxSessionKey) currentExternalMailGroupKey,
externalMailGroupInfoSessionKey,
)
const externalMailEncMailBoxSessionKey = this.cryptoWrapper.encryptKeyWithVersionedKey(currentExternalMailGroupKey, mailboxSessionKey)
const internalMailEncUserGroupInfoSessionKey = _encryptKeyWithVersionedKey(internalMailGroupKey, externalUserGroupInfoSessionKey) const internalMailEncUserGroupInfoSessionKey = this.cryptoWrapper.encryptKeyWithVersionedKey(internalMailGroupKey, externalUserGroupInfoSessionKey)
const internalMailEncMailGroupInfoSessionKey = _encryptKeyWithVersionedKey(internalMailGroupKey, externalMailGroupInfoSessionKey) const internalMailEncMailGroupInfoSessionKey = this.cryptoWrapper.encryptKeyWithVersionedKey(internalMailGroupKey, externalMailGroupInfoSessionKey)
const externalUserData = createExternalUserData({ const externalUserData = createExternalUserData({
verifier, verifier,
@ -1123,7 +1136,7 @@ export class MailFacade {
async createLabel(mailGroupId: Id, labelData: { name: string; color: string }) { async createLabel(mailGroupId: Id, labelData: { name: string; color: string }) {
const mailGroupKey = await this.keyLoaderFacade.getCurrentSymGroupKey(mailGroupId) const mailGroupKey = await this.keyLoaderFacade.getCurrentSymGroupKey(mailGroupId)
const sk = aes256RandomKey() const sk = aes256RandomKey()
const ownerEncSessionKey = _encryptKeyWithVersionedKey(mailGroupKey, sk) const ownerEncSessionKey = this.cryptoWrapper.encryptKeyWithVersionedKey(mailGroupKey, sk)
await this.serviceExecutor.post( await this.serviceExecutor.post(
ManageLabelService, ManageLabelService,
@ -1215,6 +1228,96 @@ export class MailFacade {
) )
} }
private async encryptUnencryptedProcessInboxData(
mailGroupId: Id,
unencryptedProcessInboxData: readonly UnencryptedProcessInboxDatum[],
): Promise<ProcessInboxDatum[]> {
const processInboxData: ProcessInboxDatum[] = []
for (const unencryptedProcessInboxDatum of unencryptedProcessInboxData) {
const mailGroupKey = await this.keyLoaderFacade.getCurrentSymGroupKey(mailGroupId)
const sk = aes256RandomKey()
const ownerEncSessionKey = this.cryptoWrapper.encryptKeyWithVersionedKey(mailGroupKey, sk)
const { targetMoveFolder, classifierType, mailId } = unencryptedProcessInboxDatum
processInboxData.push(
createProcessInboxDatum({
ownerEncVectorSessionKey: ownerEncSessionKey.key,
ownerKeyVersion: ownerEncSessionKey.encryptingKeyVersion.toString(),
encVector: aesEncrypt(sk, unencryptedProcessInboxDatum.vector),
classifierType,
mailId,
targetMoveFolder,
}),
)
}
return processInboxData
}
async processNewMails(mailGroupId: Id, unencryptedProcessInboxData: readonly UnencryptedProcessInboxDatum[]) {
const processInboxData = await this.encryptUnencryptedProcessInboxData(mailGroupId, unencryptedProcessInboxData)
await promiseMap(
splitInChunks(MAX_NBR_OF_MAILS_SYNC_OPERATION, processInboxData),
async (inboxData) =>
this.serviceExecutor.post(
ProcessInboxService,
createProcessInboxPostIn({
mailOwnerGroup: mailGroupId,
processInboxDatum: inboxData,
}),
),
{ concurrency: 5 },
)
}
private async encryptUnencryptedPopulateClientSpamTrainingDatum(
mailGroupId: Id,
unencryptedPopulateClientSpamTrainingData: ReadonlyArray<UnencryptedPopulateClientSpamTrainingDatum>,
): Promise<Array<PopulateClientSpamTrainingDatum>> {
const populateClientSpamTrainingData: PopulateClientSpamTrainingDatum[] = []
for (const unencryptedProcessInboxDatum of unencryptedPopulateClientSpamTrainingData) {
const mailGroupKey = await this.keyLoaderFacade.getCurrentSymGroupKey(mailGroupId)
const sk = aes256RandomKey()
const ownerEncSessionKey = this.cryptoWrapper.encryptKeyWithVersionedKey(mailGroupKey, sk)
const { isSpam, confidence, mailId } = unencryptedProcessInboxDatum
populateClientSpamTrainingData.push(
createPopulateClientSpamTrainingDatum({
ownerEncVectorSessionKey: ownerEncSessionKey.key,
ownerKeyVersion: ownerEncSessionKey.encryptingKeyVersion.toString(),
encVector: aesEncrypt(sk, unencryptedProcessInboxDatum.vector),
isSpam,
mailId,
confidence,
}),
)
}
return populateClientSpamTrainingData
}
async populateClientSpamTrainingData(
mailGroupId: Id,
unencryptedPopulateClientSpamTrainingData: ReadonlyArray<UnencryptedPopulateClientSpamTrainingDatum>,
) {
const populateClientSpamTrainingData = await this.encryptUnencryptedPopulateClientSpamTrainingDatum(
mailGroupId,
unencryptedPopulateClientSpamTrainingData,
)
await promiseMap(
splitInChunks(MAX_NBR_OF_MAILS_SYNC_OPERATION, populateClientSpamTrainingData),
async (clientSpamTrainingData) =>
this.serviceExecutor.post(
PopulateClientSpamTrainingDataService,
createPopulateClientSpamTrainingDataPostIn({
mailOwnerGroup: mailGroupId,
populateClientSpamTrainingDatum: clientSpamTrainingData,
}),
),
{ concurrency: 5 },
)
}
async vectorizeAndCompressMails(mailWithDetails: MailWithMailDetails) {
return this.spamMailProcessor.vectorizeAndCompress(createSpamMailDatum(mailWithDetails.mail, mailWithDetails.mailDetails))
}
/** Resolve conversation list ids to the IDs of mails in those conversations. */ /** Resolve conversation list ids to the IDs of mails in those conversations. */
async resolveConversations(conversationListIds: readonly Id[]): Promise<IdTuple[]> { async resolveConversations(conversationListIds: readonly Id[]): Promise<IdTuple[]> {
const result = await promiseMap( const result = await promiseMap(

View file

@ -46,6 +46,7 @@ import { AttributeModel } from "../../common/AttributeModel"
import { TypeModelResolver } from "../../common/EntityFunctions" import { TypeModelResolver } from "../../common/EntityFunctions"
import { collapseId, expandId } from "../rest/RestClientIdUtils" import { collapseId, expandId } from "../rest/RestClientIdUtils"
import { Category, syncMetrics } from "../utils/SyncMetrics" import { Category, syncMetrics } from "../utils/SyncMetrics"
import { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
/** /**
* this is the value of SQLITE_MAX_VARIABLE_NUMBER in sqlite3.c * this is the value of SQLITE_MAX_VARIABLE_NUMBER in sqlite3.c
@ -102,6 +103,7 @@ export interface OfflineDbMeta {
"offline-version": number "offline-version": number
lastTrainedTime: number lastTrainedTime: number
lastTrainedFromScratchTime: number lastTrainedFromScratchTime: number
lastTrainingDataId: Id
} }
export const TableDefinitions = Object.freeze({ export const TableDefinitions = Object.freeze({
@ -140,6 +142,11 @@ export const TableDefinitions = Object.freeze({
"CREATE TABLE IF NOT EXISTS blob_element_entities (type TEXT NOT NULL, listId TEXT NOT NULL, elementId TEXT NOT NULL, ownerGroup TEXT, entity BLOB NOT NULL, PRIMARY KEY (type, listId, elementId))", "CREATE TABLE IF NOT EXISTS blob_element_entities (type TEXT NOT NULL, listId TEXT NOT NULL, elementId TEXT NOT NULL, ownerGroup TEXT, entity BLOB NOT NULL, PRIMARY KEY (type, listId, elementId))",
purgedWithCache: true, purgedWithCache: true,
}, },
spam_classification_model: {
definition:
"CREATE TABLE IF NOT EXISTS spam_classification_model (version NUMBER NOT NULL, ownerGroup TEXT NOT NULL, modelTopology TEXT NOT NULL, weightSpecs TEXT NOT NULL, weightData BLOB NOT NULL, hamCount NUMBER NOT NULL, spamCount NUMBER NOT NULL, PRIMARY KEY(version, ownerGroup))",
purgedWithCache: true,
},
} as const) satisfies Record<string, OfflineStorageTable> } as const) satisfies Record<string, OfflineStorageTable>
type Range = { lower: Id; upper: Id } type Range = { lower: Id; upper: Id }
@ -711,12 +718,12 @@ export class OfflineStorage implements CacheStorage {
await this.putMetadata("lastUpdateTime", ms) await this.putMetadata("lastUpdateTime", ms)
} }
async getLastTrainedTime(): Promise<number> { async getLastTrainingDataIndexId(): Promise<Id> {
return (await this.getMetadata("lastTrainedTime")) ?? 0 return (await this.getMetadata("lastTrainingDataId")) ?? GENERATED_MIN_ID
} }
async setLastTrainedTime(ms: number): Promise<void> { async setLastTrainingDataIndexId(id: Id): Promise<void> {
await this.putMetadata("lastTrainedTime", ms) await this.putMetadata("lastTrainingDataId", id)
} }
async getLastTrainedFromScratchTime(): Promise<number> { async getLastTrainedFromScratchTime(): Promise<number> {
@ -727,6 +734,41 @@ export class OfflineStorage implements CacheStorage {
await this.putMetadata("lastTrainedFromScratchTime", ms) await this.putMetadata("lastTrainedFromScratchTime", ms)
} }
async setSpamClassificationModel(model: SpamClassificationModel) {
const { query, params } = sql`INSERT
OR REPLACE INTO
spam_classification_model VALUES (
${1},
${model.ownerGroup},
${model.modelTopology},
${model.weightSpecs},
${model.weightData},
${model.hamCount},
${model.spamCount}
)`
await this.sqlCipherFacade.run(query, params)
}
async getSpamClassificationModel(ownerGroup: Id): Promise<Nullable<SpamClassificationModel>> {
const { query, params } = sql`SELECT modelTopology, weightSpecs, weightData, ownerGroup, hamCount, spamCount
FROM spam_classification_model
WHERE version = ${1}
AND ownerGroup = ${ownerGroup}`
const resultRows = await this.sqlCipherFacade.get(query, params)
if (resultRows !== null) {
const untaggedValue = untagSqlObject(resultRows)
return {
modelTopology: untaggedValue.modelTopology,
weightSpecs: untaggedValue.weightSpecs,
weightData: untaggedValue.weightData,
ownerGroup: untaggedValue.ownerGroup,
hamCount: untaggedValue.hamCount,
spamCount: untaggedValue.spamCount,
} as SpamClassificationModel
}
return null
}
async purgeStorage(): Promise<void> { async purgeStorage(): Promise<void> {
if (this.userId == null || this.databaseKey == null) { if (this.userId == null || this.databaseKey == null) {
console.warn("not purging storage since we don't have an open db") console.warn("not purging storage since we don't have an open db")

View file

@ -5,6 +5,7 @@ import { Nullable, TypeRef } from "@tutao/tutanota-utils"
import { OfflineStorage, OfflineStorageInitArgs } from "../offline/OfflineStorage.js" import { OfflineStorage, OfflineStorageInitArgs } from "../offline/OfflineStorage.js"
import { EphemeralCacheStorage, EphemeralStorageInitArgs } from "./EphemeralCacheStorage" import { EphemeralCacheStorage, EphemeralStorageInitArgs } from "./EphemeralCacheStorage"
import { CustomCacheHandlerMap } from "./cacheHandler/CustomCacheHandler.js" import { CustomCacheHandlerMap } from "./cacheHandler/CustomCacheHandler.js"
import { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
export interface EphemeralStorageArgs extends EphemeralStorageInitArgs { export interface EphemeralStorageArgs extends EphemeralStorageInitArgs {
type: "ephemeral" type: "ephemeral"
@ -185,12 +186,12 @@ export class LateInitializedCacheStorageImpl implements CacheStorageLateInitiali
return this.inner.putLastUpdateTime(value) return this.inner.putLastUpdateTime(value)
} }
setLastTrainedTime(value: number): Promise<void> { setLastTrainingDataIndexId(id: Id): Promise<void> {
return this.inner.setLastTrainedTime(value) return this.inner.setLastTrainingDataIndexId(id)
} }
getLastTrainedTime(): Promise<number> { getLastTrainingDataIndexId(): Promise<Id> {
return this.inner.getLastTrainedTime() return this.inner.getLastTrainingDataIndexId()
} }
setLastTrainedFromScratchTime(ms: number): Promise<void> { setLastTrainedFromScratchTime(ms: number): Promise<void> {
@ -201,6 +202,14 @@ export class LateInitializedCacheStorageImpl implements CacheStorageLateInitiali
return this.inner.getLastTrainedFromScratchTime() ?? Date.now() return this.inner.getLastTrainedFromScratchTime() ?? Date.now()
} }
setSpamClassificationModel(model: SpamClassificationModel): Promise<void> {
return this.inner.setSpamClassificationModel(model)
}
getSpamClassificationModel(ownerGroup: Id): Promise<Nullable<SpamClassificationModel>> {
return this.inner.getSpamClassificationModel(ownerGroup)
}
setLowerRangeForList<T extends ListElementEntity>(typeRef: TypeRef<T>, listId: Id, id: Id): Promise<void> { setLowerRangeForList<T extends ListElementEntity>(typeRef: TypeRef<T>, listId: Id, id: Id): Promise<void> {
return this.inner.setLowerRangeForList(typeRef, listId, id) return this.inner.setLowerRangeForList(typeRef, listId, id)
} }

View file

@ -25,7 +25,14 @@ import {
UserGroupRootTypeRef, UserGroupRootTypeRef,
} from "../../entities/sys/TypeRefs.js" } from "../../entities/sys/TypeRefs.js"
import { ValueType } from "../../common/EntityConstants.js" import { ValueType } from "../../common/EntityConstants.js"
import { CalendarEventUidIndexTypeRef, MailDetailsBlobTypeRef, MailSetEntryTypeRef, MailTypeRef } from "../../entities/tutanota/TypeRefs.js" import {
CalendarEventUidIndexTypeRef,
ClientSpamTrainingDatumIndexEntryTypeRef,
ClientSpamTrainingDatumTypeRef,
MailDetailsBlobTypeRef,
MailSetEntryTypeRef,
MailTypeRef,
} from "../../entities/tutanota/TypeRefs.js"
import { import {
CUSTOM_MAX_ID, CUSTOM_MAX_ID,
CUSTOM_MIN_ID, CUSTOM_MIN_ID,
@ -48,6 +55,7 @@ import { AttributeModel } from "../../common/AttributeModel"
import { collapseId, expandId } from "./RestClientIdUtils" import { collapseId, expandId } from "./RestClientIdUtils"
import { PatchMerger } from "../offline/PatchMerger" import { PatchMerger } from "../offline/PatchMerger"
import { hasError, isExpectedErrorForSynchronization } from "../../common/utils/ErrorUtils" import { hasError, isExpectedErrorForSynchronization } from "../../common/utils/ErrorUtils"
import { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
assertWorkerOrNode() assertWorkerOrNode()
@ -74,6 +82,8 @@ const IGNORED_TYPES = [
UserGroupRootTypeRef, UserGroupRootTypeRef,
UserGroupKeyDistributionTypeRef, UserGroupKeyDistributionTypeRef,
AuditLogEntryTypeRef, // Should not be part of cached data because there are errors inside entity event processing after rotating the admin group key AuditLogEntryTypeRef, // Should not be part of cached data because there are errors inside entity event processing after rotating the admin group key
ClientSpamTrainingDatumTypeRef,
ClientSpamTrainingDatumIndexEntryTypeRef,
] as const ] as const
/** /**
@ -253,14 +263,18 @@ export interface CacheStorage extends ExposedCacheStorage {
putLastUpdateTime(value: number): Promise<void> putLastUpdateTime(value: number): Promise<void>
getLastTrainedTime(): Promise<number> getLastTrainingDataIndexId(): Promise<Id>
setLastTrainedTime(value: number): Promise<void> setLastTrainingDataIndexId(id: Id): Promise<void>
getLastTrainedFromScratchTime(): Promise<number> getLastTrainedFromScratchTime(): Promise<number>
setLastTrainedFromScratchTime(value: number): Promise<void> setLastTrainedFromScratchTime(value: number): Promise<void>
getSpamClassificationModel(ownerGroup: Id): Promise<Nullable<SpamClassificationModel>>
setSpamClassificationModel(model: SpamClassificationModel): Promise<void>
getUserId(): Id getUserId(): Id
deleteAllOwnedBy(owner: Id): Promise<void> deleteAllOwnedBy(owner: Id): Promise<void>

View file

@ -1,7 +1,7 @@
import { BlobElementEntity, Entity, ListElementEntity, ServerModelParsedInstance, SomeEntity, TypeModel } from "../../common/EntityTypes.js" import { BlobElementEntity, Entity, ListElementEntity, ServerModelParsedInstance, SomeEntity, TypeModel } from "../../common/EntityTypes.js"
import { customIdToBase64Url, ensureBase64Ext, firstBiggerThanSecond } from "../../common/utils/EntityUtils.js" import { customIdToBase64Url, ensureBase64Ext, firstBiggerThanSecond, GENERATED_MIN_ID } from "../../common/utils/EntityUtils.js"
import { CacheStorage, LastUpdateTime } from "./DefaultEntityRestCache.js" import { CacheStorage, LastUpdateTime } from "./DefaultEntityRestCache.js"
import { assertNotNull, clone, filterNull, getFromMap, getTypeString, Nullable, parseTypeString, remove, TypeRef } from "@tutao/tutanota-utils" import { assertNotNull, clone, filterNull, getFromMap, getTypeString, newPromise, Nullable, parseTypeString, remove, TypeRef } from "@tutao/tutanota-utils"
import { CustomCacheHandlerMap } from "./cacheHandler/CustomCacheHandler.js" import { CustomCacheHandlerMap } from "./cacheHandler/CustomCacheHandler.js"
import { Type as TypeId } from "../../common/EntityConstants.js" import { Type as TypeId } from "../../common/EntityConstants.js"
import { ProgrammingError } from "../../common/error/ProgrammingError.js" import { ProgrammingError } from "../../common/error/ProgrammingError.js"
@ -10,6 +10,7 @@ import { ModelMapper } from "../crypto/ModelMapper"
import { ServerTypeModelResolver } from "../../common/EntityFunctions" import { ServerTypeModelResolver } from "../../common/EntityFunctions"
import { expandId } from "./RestClientIdUtils" import { expandId } from "./RestClientIdUtils"
import { hasError } from "../../common/utils/ErrorUtils" import { hasError } from "../../common/utils/ErrorUtils"
import { SpamClassificationModel } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
/** Cache for a single list. */ /** Cache for a single list. */
type ListCache = { type ListCache = {
@ -41,8 +42,9 @@ export class EphemeralCacheStorage implements CacheStorage {
private readonly entities: Map<string, Map<Id, ServerModelParsedInstance>> = new Map() private readonly entities: Map<string, Map<Id, ServerModelParsedInstance>> = new Map()
private readonly lists: Map<string, ListTypeCache> = new Map() private readonly lists: Map<string, ListTypeCache> = new Map()
private readonly blobEntities: Map<string, BlobElementTypeCache> = new Map() private readonly blobEntities: Map<string, BlobElementTypeCache> = new Map()
private readonly spamClassificationModelCache: Map<Id, SpamClassificationModel> = new Map()
private lastUpdateTime: number | null = null private lastUpdateTime: number | null = null
private lastTrainedTime: number | null = null private lastTrainingDataId: Id = GENERATED_MIN_ID
private lastTrainedFromScratchTime: number | null = null private lastTrainedFromScratchTime: number | null = null
private userId: Id | null = null private userId: Id | null = null
private lastBatchIdPerGroup = new Map<Id, Id>() private lastBatchIdPerGroup = new Map<Id, Id>()
@ -419,12 +421,12 @@ export class EphemeralCacheStorage implements CacheStorage {
this.lastUpdateTime = value this.lastUpdateTime = value
} }
async getLastTrainedTime(): Promise<number> { async getLastTrainingDataIndexId(): Promise<Id> {
return this.lastTrainedTime ?? 0 return this.lastTrainingDataId
} }
async setLastTrainedTime(value: number): Promise<void> { async setLastTrainingDataIndexId(id: Id): Promise<void> {
this.lastTrainedTime = value this.lastTrainingDataId = id
} }
async getLastTrainedFromScratchTime(): Promise<number> { async getLastTrainedFromScratchTime(): Promise<number> {
@ -435,6 +437,14 @@ export class EphemeralCacheStorage implements CacheStorage {
this.lastTrainedFromScratchTime = ms this.lastTrainedFromScratchTime = ms
} }
async setSpamClassificationModel(model: SpamClassificationModel): Promise<void> {
this.spamClassificationModelCache.set(model.ownerGroup, model)
}
async getSpamClassificationModel(ownerGroup: Id): Promise<Nullable<SpamClassificationModel>> {
return this.spamClassificationModelCache.get(ownerGroup) ?? null
}
async getWholeList<T extends ListElementEntity>(typeRef: TypeRef<T>, listId: Id): Promise<Array<T>> { async getWholeList<T extends ListElementEntity>(typeRef: TypeRef<T>, listId: Id): Promise<Array<T>> {
const parsedInstances = await this.getWholeListParsed(typeRef, listId) const parsedInstances = await this.getWholeListParsed(typeRef, listId)
return await this.modelMapper.mapToInstances(typeRef, parsedInstances) return await this.modelMapper.mapToInstances(typeRef, parsedInstances)

View file

@ -97,6 +97,7 @@ export class CustomColorEditorPreview implements Component {
sets: [], sets: [],
processingState: ProcessingState.INBOX_RULE_NOT_PROCESSED, processingState: ProcessingState.INBOX_RULE_NOT_PROCESSED,
clientSpamClassifierResult: null, clientSpamClassifierResult: null,
processNeeded: true,
} satisfies Partial<Mail> } satisfies Partial<Mail>
const mail = createMail({ const mail = createMail({
sender: createMailAddress({ sender: createMailAddress({

View file

@ -246,17 +246,17 @@ import("./translations/en.js")
), ),
) )
mailLocator.logins.addPostLoginAction(async () => { mailLocator.logins.addPostLoginAction(async () => {
const { MailIndexAndSpamClassificationPostLoginAction } = await import("./search/model/MailIndexAndSpamClassificationPostLoginAction") const { MailIndexerPostLoginAction } = await import("./search/model/MailIndexerPostLoginAction")
const offlineStorageSettings = await mailLocator.offlineStorageSettingsModel() const offlineStorageSettings = await mailLocator.offlineStorageSettingsModel()
return new MailIndexAndSpamClassificationPostLoginAction( return new MailIndexerPostLoginAction(assertNotNull(offlineStorageSettings), mailLocator.indexerFacade)
assertNotNull(offlineStorageSettings),
mailLocator.indexerFacade,
mailLocator.spamClassifier,
mailLocator.customerFacade,
)
}) })
} }
mailLocator.logins.addPostLoginAction(async () => {
const { SpamClassificationPostLoginAction } = await import("./mail/model/SpamClassificationPostLoginAction")
return new SpamClassificationPostLoginAction(mailLocator.spamClassifier, mailLocator.customerFacade)
})
mailLocator.logins.addPostLoginAction(async () => { mailLocator.logins.addPostLoginAction(async () => {
const { OpenLocallySavedDraftAction } = await import("./mail/editor/OpenLocallySavedDraftAction.js") const { OpenLocallySavedDraftAction } = await import("./mail/editor/OpenLocallySavedDraftAction.js")
const { newMailEditorFromLocalDraftData } = await import("./mail/editor/MailEditor.js") const { newMailEditorFromLocalDraftData } = await import("./mail/editor/MailEditor.js")

View file

@ -1,4 +1,4 @@
import { applyInboxRulesToEntries, LoadedMail, MailSetListModel, resolveMailSetEntries } from "./MailSetListModel" import { applyInboxRulesAndSpamPrediction, LoadedMail, MailSetListModel, resolveMailSetEntries } from "./MailSetListModel"
import { ListLoadingState, ListState } from "../../../common/gui/base/List" import { ListLoadingState, ListState } from "../../../common/gui/base/List"
import { Mail, MailFolder, MailFolderTypeRef, MailSetEntry, MailSetEntryTypeRef, MailTypeRef } from "../../../common/api/entities/tutanota/TypeRefs" import { Mail, MailFolder, MailFolderTypeRef, MailSetEntry, MailSetEntryTypeRef, MailTypeRef } from "../../../common/api/entities/tutanota/TypeRefs"
import { EntityUpdateData, isUpdateForTypeRef } from "../../../common/api/common/utils/EntityUpdateUtils" import { EntityUpdateData, isUpdateForTypeRef } from "../../../common/api/common/utils/EntityUpdateUtils"
@ -7,7 +7,6 @@ import Stream from "mithril/stream"
import { ConversationPrefProvider } from "../view/ConversationViewModel" import { ConversationPrefProvider } from "../view/ConversationViewModel"
import { EntityClient } from "../../../common/api/common/EntityClient" import { EntityClient } from "../../../common/api/common/EntityClient"
import { MailModel } from "./MailModel" import { MailModel } from "./MailModel"
import { InboxRuleHandler } from "./InboxRuleHandler"
import { ExposedCacheStorage } from "../../../common/api/worker/rest/DefaultEntityRestCache" import { ExposedCacheStorage } from "../../../common/api/worker/rest/DefaultEntityRestCache"
import { import {
CUSTOM_MAX_ID, CUSTOM_MAX_ID,
@ -34,6 +33,7 @@ import {
import { ListFetchResult } from "../../../common/gui/base/ListUtils" import { ListFetchResult } from "../../../common/gui/base/ListUtils"
import { isOfflineError } from "../../../common/api/common/utils/ErrorUtils" import { isOfflineError } from "../../../common/api/common/utils/ErrorUtils"
import { OperationType } from "../../../common/api/common/TutanotaConstants" import { OperationType } from "../../../common/api/common/TutanotaConstants"
import { ProcessInboxHandler } from "./ProcessInboxHandler"
/** /**
* Organizes mails into conversations and handles state upkeep. * Organizes mails into conversations and handles state upkeep.
@ -67,7 +67,7 @@ export class ConversationListModel implements MailSetListModel {
private readonly conversationPrefProvider: ConversationPrefProvider, private readonly conversationPrefProvider: ConversationPrefProvider,
private readonly entityClient: EntityClient, private readonly entityClient: EntityClient,
private readonly mailModel: MailModel, private readonly mailModel: MailModel,
private readonly inboxRuleHandler: InboxRuleHandler, private readonly processInboxHandler: ProcessInboxHandler,
private readonly cacheStorage: ExposedCacheStorage, private readonly cacheStorage: ExposedCacheStorage,
) { ) {
this.listModel = new ListModel({ this.listModel = new ListModel({
@ -467,7 +467,7 @@ export class ConversationListModel implements MailSetListModel {
if (mailSetEntries.length > 0) { if (mailSetEntries.length > 0) {
this.lastFetchedMailSetEntryId = getElementId(lastThrow(mailSetEntries)) this.lastFetchedMailSetEntryId = getElementId(lastThrow(mailSetEntries))
items = await this.resolveMailSetEntries(mailSetEntries, this.defaultMailProvider) items = await this.resolveMailSetEntries(mailSetEntries, this.defaultMailProvider)
items = await this.applyInboxRulesToEntries(items) items = await this.applyInboxRulesAndSpamPrediction(items)
} }
} catch (e) { } catch (e) {
if (isOfflineError(e)) { if (isOfflineError(e)) {
@ -496,8 +496,8 @@ export class ConversationListModel implements MailSetListModel {
} }
} }
private async applyInboxRulesToEntries(entries: LoadedMail[]): Promise<LoadedMail[]> { private async applyInboxRulesAndSpamPrediction(entries: LoadedMail[]): Promise<LoadedMail[]> {
return applyInboxRulesToEntries(entries, this.mailSet, this.mailModel, this.inboxRuleHandler) return applyInboxRulesAndSpamPrediction(entries, this.mailSet, this.mailModel, this.processInboxHandler)
} }
// @VisibleForTesting // @VisibleForTesting

View file

@ -1,72 +1,21 @@
import { createMoveMailData, InboxRule, Mail, MailFolder, MoveMailData } from "../../../common/api/entities/tutanota/TypeRefs.js" import { InboxRule, Mail, MailFolder } from "../../../common/api/entities/tutanota/TypeRefs.js"
import { FeatureType, InboxRuleType, MailSetKind, MAX_NBR_OF_MAILS_SYNC_OPERATION, ProcessingState } from "../../../common/api/common/TutanotaConstants" import { InboxRuleType, MailSetKind, ProcessingState } from "../../../common/api/common/TutanotaConstants"
import { isDomainName, isRegularExpression } from "../../../common/misc/FormatValidator" import { isDomainName, isRegularExpression } from "../../../common/misc/FormatValidator"
import { assertNotNull, asyncFind, debounce, ofClass, promiseMap, splitInChunks, throttleStart } from "@tutao/tutanota-utils" import { assertNotNull, asyncFind, Nullable } from "@tutao/tutanota-utils"
import { lang } from "../../../common/misc/LanguageViewModel" import { lang } from "../../../common/misc/LanguageViewModel"
import type { MailboxDetail } from "../../../common/mailFunctionality/MailboxModel.js" import type { MailboxDetail } from "../../../common/mailFunctionality/MailboxModel.js"
import { LockedError, PreconditionFailedError } from "../../../common/api/common/error/RestError"
import type { SelectorItemList } from "../../../common/gui/base/DropDownSelector.js" import type { SelectorItemList } from "../../../common/gui/base/DropDownSelector.js"
import { elementIdPart, isSameId } from "../../../common/api/common/utils/EntityUtils" import { elementIdPart } from "../../../common/api/common/utils/EntityUtils"
import { assertMainOrNode, isWebClient } from "../../../common/api/common/Env" import { assertMainOrNode } from "../../../common/api/common/Env"
import { MailFacade } from "../../../common/api/worker/facades/lazy/MailFacade.js" import { MailFacade } from "../../../common/api/worker/facades/lazy/MailFacade.js"
import { LoginController } from "../../../common/api/main/LoginController.js" import { LoginController } from "../../../common/api/main/LoginController.js"
import { getMailHeaders } from "./MailUtils.js" import { getMailHeaders } from "./MailUtils.js"
import { MailModel } from "./MailModel" import { MailModel } from "./MailModel"
import { UnencryptedProcessInboxDatum } from "./ProcessInboxHandler"
import { ClientClassifierType } from "../../../common/api/common/ClientClassifierType" import { ClientClassifierType } from "../../../common/api/common/ClientClassifierType"
assertMainOrNode() assertMainOrNode()
const moveMailDataPerFolder: MoveMailData[] = []
let noRuleMatchMailIds: IdTuple[] = []
const THROTTLE_MOVE_MAIL_SERVICE_REQUESTS_MS = 200
const DEBOUNCE_CLIENT_CLASSIFIER_RESULT_SERVICE_REQUESTS_MS = 1000
async function sendMoveMailRequest(mailFacade: MailFacade): Promise<void> {
if (moveMailDataPerFolder.length) {
const moveToTargetFolder = assertNotNull(moveMailDataPerFolder.shift())
const mailChunks = splitInChunks(MAX_NBR_OF_MAILS_SYNC_OPERATION, moveToTargetFolder.mails)
await promiseMap(mailChunks, (mailChunk) => {
moveToTargetFolder.mails = mailChunk
return mailFacade.moveMails(mailChunk, moveToTargetFolder.targetFolder, null, ClientClassifierType.CUSTOMER_INBOX_RULES)
})
.catch(
ofClass(LockedError, (e) => {
//LockedError should no longer be thrown!?!
console.log("moving mail failed", e, moveToTargetFolder)
}),
)
.catch(
ofClass(PreconditionFailedError, (e) => {
// move mail operation may have been locked by other process
console.log("moving mail failed", e, moveToTargetFolder)
}),
)
.finally(() => {
return processMatchingRules(mailFacade)
})
}
}
const processMatchingRules = throttleStart(THROTTLE_MOVE_MAIL_SERVICE_REQUESTS_MS, async (mailFacade: MailFacade) => {
// Each target folder requires one request,
// We debounce the requests to a rate of THROTTLE_MOVE_MAIL_SERVICE_REQUESTS_MS
return sendMoveMailRequest(mailFacade)
})
const processNotMatchingRules = debounce(
DEBOUNCE_CLIENT_CLASSIFIER_RESULT_SERVICE_REQUESTS_MS,
async (mailFacade: MailFacade, processingState: ProcessingState) => {
// Each update to ClientClassifierResultService (for mails that did not move) requires one request
// We debounce the requests to a rate of DEBOUNCE_CLIENT_CLASSIFIER_RESULT_SERVICE_REQUESTS_MS
if (noRuleMatchMailIds.length) {
const mailIds = noRuleMatchMailIds
noRuleMatchMailIds = []
return mailFacade.updateMailPredictionState(mailIds, processingState)
}
},
)
export function getInboxRuleTypeNameMapping(): SelectorItemList<string> { export function getInboxRuleTypeNameMapping(): SelectorItemList<string> {
return [ return [
{ {
@ -112,10 +61,15 @@ export class InboxRuleHandler {
* Checks the mail for an existing inbox rule and moves the mail to the target folder of the rule. * Checks the mail for an existing inbox rule and moves the mail to the target folder of the rule.
* @returns true if a rule matches otherwise false * @returns true if a rule matches otherwise false
*/ */
async findAndApplyMatchingRule(mailboxDetail: MailboxDetail, mail: Readonly<Mail>, applyRulesOnServer: boolean): Promise<MailFolder | null> { async findAndApplyMatchingRule(
mailboxDetail: MailboxDetail,
mail: Readonly<Mail>,
applyRulesOnServer: boolean,
): Promise<Nullable<{ targetFolder: MailFolder; processInboxDatum: UnencryptedProcessInboxDatum }>> {
const shouldApply = const shouldApply =
mail.processingState === ProcessingState.INBOX_RULE_NOT_PROCESSED || (mail.processingState === ProcessingState.INBOX_RULE_NOT_PROCESSED ||
mail.processingState === ProcessingState.INBOX_RULE_NOT_PROCESSED_AND_DO_NOT_RUN_SPAM_PREDICTION mail.processingState === ProcessingState.INBOX_RULE_NOT_PROCESSED_AND_DO_NOT_RUN_SPAM_PREDICTION) &&
mail.processNeeded
if ( if (
mail._errors || mail._errors ||
@ -128,46 +82,30 @@ export class InboxRuleHandler {
} }
const inboxRule = await _findMatchingRule(this.mailFacade, mail, this.logins.getUserController().props.inboxRules) const inboxRule = await _findMatchingRule(this.mailFacade, mail, this.logins.getUserController().props.inboxRules)
const mailDetails = await this.mailFacade.loadMailDetailsBlob(mail)
if (inboxRule) { if (inboxRule) {
const folders = await this.mailModel.getMailboxFoldersForId(mailboxDetail.mailbox.folders._id) const folders = await this.mailModel.getMailboxFoldersForId(mailboxDetail.mailbox.folders._id)
const targetFolder = folders.getFolderById(elementIdPart(inboxRule.targetFolder)) const targetFolder = folders.getFolderById(elementIdPart(inboxRule.targetFolder))
if (targetFolder && targetFolder.folderType !== MailSetKind.INBOX) { if (targetFolder && targetFolder.folderType !== MailSetKind.INBOX) {
if (applyRulesOnServer) { if (applyRulesOnServer) {
let moveMailData = moveMailDataPerFolder.find((folderMoveMailData) => isSameId(folderMoveMailData.targetFolder, inboxRule.targetFolder)) const processInboxDatum: UnencryptedProcessInboxDatum = {
mailId: mail._id,
if (moveMailData) { targetMoveFolder: targetFolder._id,
moveMailData.mails.push(mail._id) classifierType: ClientClassifierType.CUSTOMER_INBOX_RULES,
} else { vector: await this.mailFacade.vectorizeAndCompressMails({ mail, mailDetails }),
moveMailData = createMoveMailData({
targetFolder: inboxRule.targetFolder,
mails: [mail._id],
excludeMailSet: null,
moveReason: ClientClassifierType.CUSTOMER_INBOX_RULES,
})
moveMailDataPerFolder.push(moveMailData)
} }
return { targetFolder, processInboxDatum }
} else {
// non leader client
return null
} }
processMatchingRules(this.mailFacade)
return targetFolder
} else { } else {
// target folder of inbox rule was deleted
return null return null
} }
} else { } else {
await this.logins.loadCustomizations() // no inbox rule applies to the mail
const isSpamClassificationFeatureEnabled = this.logins.isEnabled(FeatureType.SpamClientClassification)
// we set the processing state to a final state in case the feature is not enabled,
// to not re-classify when the feature gets enabled for the user
let processingState = isSpamClassificationFeatureEnabled
? ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_PENDING
: ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE
noRuleMatchMailIds.push(mail._id)
processNotMatchingRules(this.mailFacade, processingState)
return null return null
} }
} }

View file

@ -17,12 +17,12 @@ import { ListLoadingState, ListState } from "../../../common/gui/base/List"
import Stream from "mithril/stream" import Stream from "mithril/stream"
import { EntityUpdateData, isUpdateForTypeRef } from "../../../common/api/common/utils/EntityUpdateUtils" import { EntityUpdateData, isUpdateForTypeRef } from "../../../common/api/common/utils/EntityUpdateUtils"
import { OperationType } from "../../../common/api/common/TutanotaConstants" import { OperationType } from "../../../common/api/common/TutanotaConstants"
import { InboxRuleHandler } from "./InboxRuleHandler"
import { MailModel } from "./MailModel" import { MailModel } from "./MailModel"
import { ListFetchResult } from "../../../common/gui/base/ListUtils" import { ListFetchResult } from "../../../common/gui/base/ListUtils"
import { isOfflineError } from "../../../common/api/common/utils/ErrorUtils" import { isOfflineError } from "../../../common/api/common/utils/ErrorUtils"
import { ExposedCacheStorage } from "../../../common/api/worker/rest/DefaultEntityRestCache" import { ExposedCacheStorage } from "../../../common/api/worker/rest/DefaultEntityRestCache"
import { applyInboxRulesToEntries, LoadedMail, MailSetListModel, resolveMailSetEntries } from "./MailSetListModel" import { applyInboxRulesAndSpamPrediction, LoadedMail, MailSetListModel, resolveMailSetEntries } from "./MailSetListModel"
import { ProcessInboxHandler } from "./ProcessInboxHandler"
assertMainOrNode() assertMainOrNode()
@ -41,7 +41,7 @@ export class MailListModel implements MailSetListModel {
private readonly conversationPrefProvider: ConversationPrefProvider, private readonly conversationPrefProvider: ConversationPrefProvider,
private readonly entityClient: EntityClient, private readonly entityClient: EntityClient,
private readonly mailModel: MailModel, private readonly mailModel: MailModel,
private readonly inboxRuleHandler: InboxRuleHandler, private readonly processInboxHandler: ProcessInboxHandler,
private readonly cacheStorage: ExposedCacheStorage, private readonly cacheStorage: ExposedCacheStorage,
) { ) {
this.listModel = new ListModel({ this.listModel = new ListModel({
@ -304,7 +304,7 @@ export class MailListModel implements MailSetListModel {
complete = mailSetEntries.length < count complete = mailSetEntries.length < count
if (mailSetEntries.length > 0) { if (mailSetEntries.length > 0) {
items = await this.resolveMailSetEntries(mailSetEntries, this.defaultMailProvider) items = await this.resolveMailSetEntries(mailSetEntries, this.defaultMailProvider)
items = await this.applyInboxRulesToEntries(items) items = await this.applyInboxRulesAndSpamPrediction(items)
} }
} catch (e) { } catch (e) {
if (isOfflineError(e)) { if (isOfflineError(e)) {
@ -345,11 +345,8 @@ export class MailListModel implements MailSetListModel {
return await this.resolveMailSetEntries(mailSetEntries, (list, elements) => this.cacheStorage.provideMultiple(MailTypeRef, list, elements)) return await this.resolveMailSetEntries(mailSetEntries, (list, elements) => this.cacheStorage.provideMultiple(MailTypeRef, list, elements))
} }
/** private async applyInboxRulesAndSpamPrediction(entries: LoadedMail[]): Promise<LoadedMail[]> {
* Apply inbox rules to an array of mails, returning all mails that were not moved return applyInboxRulesAndSpamPrediction(entries, this.mailSet, this.mailModel, this.processInboxHandler)
*/
private async applyInboxRulesToEntries(entries: LoadedMail[]): Promise<LoadedMail[]> {
return applyInboxRulesToEntries(entries, this.mailSet, this.mailModel, this.inboxRuleHandler)
} }
private async loadSingleMail(id: IdTuple): Promise<LoadedMail> { private async loadSingleMail(id: IdTuple): Promise<LoadedMail> {

View file

@ -5,7 +5,6 @@ import { FolderSystem } from "../../../common/api/common/mail/FolderSystem.js"
import { import {
assertNotNull, assertNotNull,
collectToMap, collectToMap,
downcast,
getFirstOrThrow, getFirstOrThrow,
groupBy, groupBy,
groupByAndMap, groupByAndMap,
@ -34,7 +33,6 @@ import {
MailSetKind, MailSetKind,
MAX_NBR_OF_MAILS_SYNC_OPERATION, MAX_NBR_OF_MAILS_SYNC_OPERATION,
OperationType, OperationType,
ProcessingState,
ReportMovedMailsType, ReportMovedMailsType,
SimpleMoveMailTarget, SimpleMoveMailTarget,
SystemFolderType, SystemFolderType,
@ -49,16 +47,15 @@ import { ProgrammingError } from "../../../common/api/common/error/ProgrammingEr
import { NotAuthorizedError, NotFoundError, PreconditionFailedError } from "../../../common/api/common/error/RestError.js" import { NotAuthorizedError, NotFoundError, PreconditionFailedError } from "../../../common/api/common/error/RestError.js"
import { UserError } from "../../../common/api/main/UserError.js" import { UserError } from "../../../common/api/main/UserError.js"
import { EventController } from "../../../common/api/main/EventController.js" import { EventController } from "../../../common/api/main/EventController.js"
import { InboxRuleHandler } from "./InboxRuleHandler.js"
import { WebsocketConnectivityModel } from "../../../common/misc/WebsocketConnectivityModel.js" import { WebsocketConnectivityModel } from "../../../common/misc/WebsocketConnectivityModel.js"
import { EntityClient } from "../../../common/api/common/EntityClient.js" import { EntityClient } from "../../../common/api/common/EntityClient.js"
import { LoginController } from "../../../common/api/main/LoginController.js" import { LoginController } from "../../../common/api/main/LoginController.js"
import { MailFacade } from "../../../common/api/worker/facades/lazy/MailFacade.js" import { MailFacade } from "../../../common/api/worker/facades/lazy/MailFacade.js"
import { assertSystemFolderOfType } from "./MailUtils.js" import { assertSystemFolderOfType } from "./MailUtils.js"
import { TutanotaError } from "@tutao/tutanota-error" import { TutanotaError } from "@tutao/tutanota-error"
import { SpamClassificationHandler } from "./SpamClassificationHandler"
import { isWebClient } from "../../../common/api/common/Env"
import { isExpectedErrorForSynchronization } from "../../../common/api/common/utils/ErrorUtils" import { isExpectedErrorForSynchronization } from "../../../common/api/common/utils/ErrorUtils"
import { ProcessInboxHandler } from "./ProcessInboxHandler"
import { isWebClient } from "../../../common/api/common/Env"
interface MailboxSets { interface MailboxSets {
folders: FolderSystem folders: FolderSystem
@ -95,8 +92,7 @@ export class MailModel {
private readonly logins: LoginController, private readonly logins: LoginController,
private readonly mailFacade: MailFacade, private readonly mailFacade: MailFacade,
private readonly connectivityModel: WebsocketConnectivityModel | null, private readonly connectivityModel: WebsocketConnectivityModel | null,
private spamHandler: () => SpamClassificationHandler | null, private readonly processInboxHandler: () => ProcessInboxHandler,
private readonly inboxRuleHandler: () => InboxRuleHandler | null,
) {} ) {}
// only init listeners once // only init listeners once
@ -116,12 +112,6 @@ export class MailModel {
async init(): Promise<void> { async init(): Promise<void> {
this.initListeners() this.initListeners()
this.mailSets = await this.loadMailSets() this.mailSets = await this.loadMailSets()
await this.logins.loadCustomizations()
const isSpamClassificationFeatureEnabled = this.logins.isEnabled(FeatureType.SpamClientClassification)
if (!isSpamClassificationFeatureEnabled) {
this.spamHandler = () => null
}
} }
private async loadMailSets(): Promise<Map<Id, MailboxSets>> { private async loadMailSets(): Promise<Map<Id, MailboxSets>> {
@ -194,119 +184,57 @@ export class MailModel {
} }
// visibleForTesting // visibleForTesting
async entityEventsReceived(updates: ReadonlyArray<EntityUpdateData>): Promise<{ processingDone: Promise<void> }> { async entityEventsReceived(updates: ReadonlyArray<EntityUpdateData>): Promise<void> {
for (const update of updates) { for (const update of updates) {
if (isUpdateForTypeRef(MailFolderTypeRef, update)) { if (isUpdateForTypeRef(MailFolderTypeRef, update)) {
await this.init() await this.init()
m.redraw() m.redraw()
} else if (isUpdateForTypeRef(MailTypeRef, update) && update.operation === OperationType.UPDATE) {
const mailId: IdTuple = [update.instanceListId, update.instanceId]
const mail = await this.loadMail(mailId)
if (mail == null) {
return { processingDone: Promise.resolve() }
}
const spamHandler = this.spamHandler()
await spamHandler?.updateSpamClassificationData(mail)
} else if (isUpdateForTypeRef(MailTypeRef, update) && update.operation === OperationType.CREATE) { } else if (isUpdateForTypeRef(MailTypeRef, update) && update.operation === OperationType.CREATE) {
const mailId: IdTuple = [update.instanceListId, update.instanceId] const mailId: IdTuple = [update.instanceListId, update.instanceId]
const mail = await this.loadMail(mailId) const mail = await this.loadMail(mailId)
if (mail == null) { if (mail == null) {
return { processingDone: Promise.resolve() } return
} }
// If an inbox rule has been applied or a spam prediction has been made if (!mail.processNeeded) {
// we can return, because those are the two final processing states return
if (
mail.processingState === ProcessingState.INBOX_RULE_APPLIED ||
mail.processingState === ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE
) {
return { processingDone: Promise.resolve() }
}
// The webapp currently does not support spam prediction, and the inbox rule has been processed
if (isWebClient() && mail.processingState === ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_PENDING) {
return { processingDone: Promise.resolve() }
} }
const sourceMailFolder = this.getMailFolderForMail(mail) const sourceMailFolder = this.getMailFolderForMail(mail)
if (sourceMailFolder == null) { if (sourceMailFolder == null) {
return { processingDone: Promise.resolve() } return
} }
const isLeaderClient = this.connectivityModel?.isLeader() ?? false const isLeaderClient = this.connectivityModel?.isLeader() ?? false
if (sourceMailFolder.folderType === MailSetKind.INBOX) { const mailboxDetail = await this.getMailboxDetailsForMail(mail)
const isInboxRuleTargetFolder = await this.getMailboxDetailsForMail(mail).then((mailboxDetail) => { const folderSystem = this.getFolderSystemByGroupId(assertNotNull(mail._ownerGroup))
// We only apply rules on server if we are the leader in case of incoming messages
return mailboxDetail && this.inboxRuleHandler()?.findAndApplyMatchingRule(mailboxDetail, mail, isLeaderClient)
})
if (isWebClient()) { let targetFolder = sourceMailFolder
// we only need to show notifications explicitly on the webapp const isInternalUser = this.logins.getUserController().isInternalUser()
this._showNotification(isInboxRuleTargetFolder ?? sourceMailFolder, mail) if (isLeaderClient && isInternalUser && mailboxDetail && folderSystem) {
} else if (this.spamHandler() != null) { targetFolder = await this.processInboxHandler().handleIncomingMail(mail, sourceMailFolder, mailboxDetail, folderSystem)
const mailDetails = await this.mailFacade.loadMailDetailsBlob(mail) }
this.spamHandler()?.storeTrainingDatum(mail, mailDetails) if (isWebClient()) {
this._showNotification(targetFolder, mail)
if (isInboxRuleTargetFolder) {
return { processingDone: Promise.resolve() }
} else if (
(isLeaderClient && mail.processingState === ProcessingState.INBOX_RULE_NOT_PROCESSED) ||
(mail.processingState === ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_PENDING && mail.unread)
) {
const folderSystem = this.getFolderSystemByGroupId(assertNotNull(mail._ownerGroup))
if (sourceMailFolder && folderSystem) {
const predictPromise = this.spamHandler()?.predictSpamForNewMail(mail, mailDetails, sourceMailFolder, folderSystem)
return { processingDone: downcast(predictPromise) }
}
}
}
} else if (sourceMailFolder.folderType === MailSetKind.SPAM) {
const mailDetails = await this.mailFacade.loadMailDetailsBlob(mail)
this.spamHandler()?.storeTrainingDatum(mail, mailDetails)
if (
(isLeaderClient && mail.processingState === ProcessingState.INBOX_RULE_NOT_PROCESSED) ||
mail.processingState === ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_PENDING
) {
const folderSystem = this.getFolderSystemByGroupId(assertNotNull(mail._ownerGroup))
if (sourceMailFolder && folderSystem) {
const predictPromise = this.spamHandler()?.predictSpamForNewMail(mail, mailDetails, sourceMailFolder, folderSystem)
return { processingDone: downcast(predictPromise) }
}
}
} }
} else if (isUpdateForTypeRef(MailTypeRef, update) && update.operation === OperationType.DELETE) {
const mailId: IdTuple = [update.instanceListId, update.instanceId]
await this.spamHandler()?.dropClassificationData(mailId)
} }
} }
return { processingDone: Promise.resolve() }
} }
public async loadMail(mailId: IdTuple): Promise<Nullable<Mail>> { public async loadMail(mailId: IdTuple): Promise<Nullable<Mail>> {
return await this.entityClient.load(MailTypeRef, mailId).catch((e) => { return await this.entityClient.load(MailTypeRef, mailId).catch((e) => {
if (isExpectedErrorForSynchronization(e)) { if (isExpectedErrorForSynchronization(e)) {
console.log(`Could not find mail ${JSON.stringify(mailId)}`) console.log(`could not find mail ${JSON.stringify(mailId)}`)
return null return null
} }
throw e throw e
}) })
} }
async applyInboxRuleToMail(mail: Mail) {
const inboxRuleHandler = this.inboxRuleHandler()
if (inboxRuleHandler) {
const mailboxDetail = await this.getMailboxDetailsForMail(mail)
if (mailboxDetail) {
return inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, mail, true)
}
}
}
async getMailboxDetailsForMail(mail: Mail): Promise<MailboxDetail | null> { async getMailboxDetailsForMail(mail: Mail): Promise<MailboxDetail | null> {
const detail = await this.mailboxModel.getMailboxDetailsForMailGroup(assertNotNull(mail._ownerGroup)) const detail = await this.mailboxModel.getMailboxDetailsForMailGroup(assertNotNull(mail._ownerGroup))
if (detail == null) { if (detail == null) {
console.warn("Mailbox detail for mail does not exist", mail) console.warn("mailboxDetail for mail does not exist", mail)
} }
return detail return detail
} }
@ -314,7 +242,7 @@ export class MailModel {
async getMailboxDetailsForMailFolder(mailFolder: MailFolder): Promise<MailboxDetail | null> { async getMailboxDetailsForMailFolder(mailFolder: MailFolder): Promise<MailboxDetail | null> {
const detail = await this.mailboxModel.getMailboxDetailsForMailGroup(assertNotNull(mailFolder._ownerGroup)) const detail = await this.mailboxModel.getMailboxDetailsForMailGroup(assertNotNull(mailFolder._ownerGroup))
if (detail == null) { if (detail == null) {
console.warn("Mailbox detail for mail folder does not exist", mailFolder) console.warn("mailbox detail for mail folder does not exist", mailFolder)
} }
return detail return detail
} }
@ -411,7 +339,7 @@ export class MailModel {
* @param targetMailFolderKind * @param targetMailFolderKind
*/ */
async simpleMoveMails(mails: readonly IdTuple[], targetMailFolderKind: SimpleMoveMailTarget): Promise<MovedMails[]> { async simpleMoveMails(mails: readonly IdTuple[], targetMailFolderKind: SimpleMoveMailTarget): Promise<MovedMails[]> {
return await this.mailFacade.simpleMoveMails(mails, targetMailFolderKind, null) return await this.mailFacade.simpleMoveMails(mails, targetMailFolderKind)
} }
getFolderExcludedFromMove(moveMode: MoveMode): SystemFolderType | null { getFolderExcludedFromMove(moveMode: MoveMode): SystemFolderType | null {

View file

@ -6,8 +6,8 @@ import Stream from "mithril/stream"
import { MailModel } from "./MailModel" import { MailModel } from "./MailModel"
import { elementIdPart, getElementId, listIdPart } from "../../../common/api/common/utils/EntityUtils" import { elementIdPart, getElementId, listIdPart } from "../../../common/api/common/utils/EntityUtils"
import { MailSetKind } from "../../../common/api/common/TutanotaConstants" import { MailSetKind } from "../../../common/api/common/TutanotaConstants"
import { groupByAndMap, promiseFilter } from "@tutao/tutanota-utils" import { groupByAndMap, isEmpty, promiseFilter } from "@tutao/tutanota-utils"
import { InboxRuleHandler } from "./InboxRuleHandler" import { ProcessInboxHandler } from "./ProcessInboxHandler"
/** /**
* Interface for retrieving and listing mails * Interface for retrieving and listing mails
@ -274,23 +274,30 @@ export async function provideAllMails(ids: IdTuple[], mailProvider: (listId: Id,
} }
/** /**
* Apply inbox rules to an array of mails, returning all mails that were not moved * Apply inbox rules and run spam prediction on an array of mails, returning all mails that were not moved
*/ */
export async function applyInboxRulesToEntries( export async function applyInboxRulesAndSpamPrediction(
entries: LoadedMail[], entries: LoadedMail[],
mailSet: MailFolder, sourceFolder: MailFolder,
mailModel: MailModel, mailModel: MailModel,
inboxRuleHandler: InboxRuleHandler, processInboxHandler: ProcessInboxHandler,
): Promise<LoadedMail[]> { ): Promise<LoadedMail[]> {
if (mailSet.folderType !== MailSetKind.INBOX || entries.length === 0) { if (isEmpty(entries)) {
return entries return entries
} }
const mailboxDetail = await mailModel.getMailboxDetailsForMailFolder(mailSet) if (!(sourceFolder.folderType === MailSetKind.SPAM || sourceFolder.folderType === MailSetKind.INBOX)) {
return entries
}
const mailboxDetail = await mailModel.getMailboxDetailsForMailFolder(sourceFolder)
if (!mailboxDetail) { if (!mailboxDetail) {
return entries return entries
} }
const folderSystem = mailModel.getFolderSystemByGroupId(mailboxDetail.mailGroup._id)
if (!folderSystem) {
return entries
}
return await promiseFilter(entries, async (entry) => { return await promiseFilter(entries, async (entry) => {
const ruleApplied = await inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, entry.mail, true) const targetFolder = await processInboxHandler.handleIncomingMail(entry.mail, sourceFolder, mailboxDetail, folderSystem)
return ruleApplied == null return sourceFolder.folderType === targetFolder.folderType
}) })
} }

View file

@ -0,0 +1,93 @@
import { SpamClassificationHandler } from "./SpamClassificationHandler"
import { InboxRuleHandler } from "./InboxRuleHandler"
import { Mail, MailFolder, ProcessInboxDatum } from "../../../common/api/entities/tutanota/TypeRefs"
import { FeatureType, MailSetKind } from "../../../common/api/common/TutanotaConstants"
import { assertNotNull, debounce, Nullable } from "@tutao/tutanota-utils"
import { MailFacade } from "../../../common/api/worker/facades/lazy/MailFacade"
import { MailboxDetail } from "../../../common/mailFunctionality/MailboxModel"
import { FolderSystem } from "../../../common/api/common/mail/FolderSystem"
import { assertMainOrNode } from "../../../common/api/common/Env"
import { StrippedEntity } from "../../../common/api/common/utils/EntityUtils"
import { LoginController } from "../../../common/api/main/LoginController"
assertMainOrNode()
export type UnencryptedProcessInboxDatum = Omit<StrippedEntity<ProcessInboxDatum>, "encVector" | "ownerEncVectorSessionKey"> & {
vector: Uint8Array
}
const DEFAULT_DEBOUNCE_PROCESS_INBOX_SERVICE_REQUESTS_MS = 1000
export class ProcessInboxHandler {
sendProcessInboxServiceRequest: (mailFacade: MailFacade) => Promise<void>
constructor(
private readonly logins: LoginController,
private readonly mailFacade: MailFacade,
private spamHandler: () => SpamClassificationHandler,
private readonly inboxRuleHandler: () => InboxRuleHandler,
private processedMailsByMailGroup: Map<Id, UnencryptedProcessInboxDatum[]> = new Map(),
private readonly debounceTimeout: number = DEFAULT_DEBOUNCE_PROCESS_INBOX_SERVICE_REQUESTS_MS,
) {
this.sendProcessInboxServiceRequest = debounce(this.debounceTimeout, async (mailFacade: MailFacade) => {
// we debounce the requests to a rate of DEFAULT_DEBOUNCE_PROCESS_INBOX_SERVICE_REQUESTS_MS
if (this.processedMailsByMailGroup.size > 0) {
// copy map to prevent inserting into map while we await the server
const map = this.processedMailsByMailGroup
this.processedMailsByMailGroup = new Map()
for (const [mailGroup, processedMails] of map) {
// send request to server
await mailFacade.processNewMails(mailGroup, processedMails)
}
}
})
}
public async handleIncomingMail(mail: Mail, sourceFolder: MailFolder, mailboxDetail: MailboxDetail, folderSystem: FolderSystem): Promise<MailFolder> {
await this.logins.loadCustomizations()
const isSpamClassificationFeatureEnabled = this.logins.isEnabled(FeatureType.SpamClientClassification)
if (!mail.processNeeded) {
return sourceFolder
}
const mailDetails = await this.mailFacade.loadMailDetailsBlob(mail)
let finalProcessInboxDatum: Nullable<UnencryptedProcessInboxDatum> = null
let moveToFolder: MailFolder = sourceFolder
if (sourceFolder.folderType === MailSetKind.INBOX) {
const result = await this.inboxRuleHandler()?.findAndApplyMatchingRule(mailboxDetail, mail, true)
if (result) {
const { targetFolder, processInboxDatum } = result
finalProcessInboxDatum = processInboxDatum
moveToFolder = targetFolder
}
}
if (finalProcessInboxDatum === null) {
if (isSpamClassificationFeatureEnabled) {
const { targetFolder, processInboxDatum } = await this.spamHandler().predictSpamForNewMail(mail, mailDetails, sourceFolder, folderSystem)
moveToFolder = targetFolder
finalProcessInboxDatum = processInboxDatum
} else {
finalProcessInboxDatum = {
mailId: mail._id,
targetMoveFolder: moveToFolder._id,
classifierType: null,
vector: await this.mailFacade.vectorizeAndCompressMails({ mail, mailDetails }),
}
}
}
const mailGroupId = assertNotNull(mail._ownerGroup)
if (this.processedMailsByMailGroup.has(mailGroupId)) {
this.processedMailsByMailGroup.get(mailGroupId)?.push(finalProcessInboxDatum)
} else {
this.processedMailsByMailGroup.set(mailGroupId, [finalProcessInboxDatum])
}
// noinspection ES6MissingAwait
this.sendProcessInboxServiceRequest(this.mailFacade)
return moveToFolder
}
}

View file

@ -1,186 +1,42 @@
import { createMoveMailData, Mail, MailAddress, MailDetails, MailFolder, MoveMailData } from "../../../common/api/entities/tutanota/TypeRefs" import { Mail, MailDetails, MailFolder } from "../../../common/api/entities/tutanota/TypeRefs"
import { import { MailSetKind } from "../../../common/api/common/TutanotaConstants"
DEFAULT_IS_SPAM, import { SpamClassifier } from "../../workerUtils/spamClassification/SpamClassifier"
DEFAULT_IS_SPAM_CONFIDENCE, import { assertNotNull } from "@tutao/tutanota-utils"
getSpamConfidence,
MailAuthenticationStatus,
MailSetKind,
ProcessingState,
SpamDecision,
} from "../../../common/api/common/TutanotaConstants"
import { SpamClassifier, SpamPredMailDatum, SpamTrainMailDatum } from "../../workerUtils/spamClassification/SpamClassifier"
import { getMailBodyText } from "../../../common/api/common/CommonMailUtils"
import { assertNotNull, debounce, isNotNull, Nullable, ofClass } from "@tutao/tutanota-utils"
import { MailFacade } from "../../../common/api/worker/facades/lazy/MailFacade"
import { ClientClassifierType } from "../../../common/api/common/ClientClassifierType"
import { FolderSystem } from "../../../common/api/common/mail/FolderSystem" import { FolderSystem } from "../../../common/api/common/mail/FolderSystem"
import { LockedError, PreconditionFailedError } from "../../../common/api/common/error/RestError" import { assertMainOrNode } from "../../../common/api/common/Env"
import { UnencryptedProcessInboxDatum } from "./ProcessInboxHandler"
import { ClientClassifierType } from "../../../common/api/common/ClientClassifierType"
import { createSpamMailDatum } from "../../../common/api/common/utils/spamClassificationUtils/SpamMailProcessor"
import { isSameId } from "../../../common/api/common/utils/EntityUtils"
const DEBOUNCE_MOVE_MAIL_SERVICE_REQUESTS_MS = 500 assertMainOrNode()
const DEBOUNCE_CLIENT_CLASSIFIER_RESULT_SERVICE_REQUESTS_MS = 1000
export class SpamClassificationHandler { export class SpamClassificationHandler {
public constructor( public constructor(private readonly spamClassifier: SpamClassifier) {}
private readonly mailFacade: MailFacade,
private readonly spamClassifier: Nullable<SpamClassifier>,
) {}
hamMoveMailData: MoveMailData | null = null public async predictSpamForNewMail(
spamMoveMailData: MoveMailData | null = null mail: Mail,
classifierResultServiceMailIds: IdTuple[] = [] mailDetails: MailDetails,
sourceFolder: MailFolder,
folderSystem: FolderSystem,
): Promise<{ targetFolder: MailFolder; processInboxDatum: UnencryptedProcessInboxDatum }> {
const spamMailDatum = createSpamMailDatum(mail, mailDetails)
sendClassifierResultServiceRequest = debounce(DEBOUNCE_CLIENT_CLASSIFIER_RESULT_SERVICE_REQUESTS_MS, async (mailFacade: MailFacade) => { const vectorizedMail = await this.spamClassifier.vectorize(spamMailDatum)
// Each update to ClientClassifierResultService (for mails that did not move) requires one request const isSpam = (await this.spamClassifier.predict(vectorizedMail, spamMailDatum.ownerGroup)) ?? null
// We debounce the requests to a rate of DEBOUNCE_CLIENT_CLASSIFIER_RESULT_SERVICE_REQUESTS_MS
if (this.classifierResultServiceMailIds.length) {
const mailIds = this.classifierResultServiceMailIds
this.classifierResultServiceMailIds = []
return mailFacade.updateMailPredictionState(mailIds, ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE)
}
})
sendMoveMailServiceRequest = debounce(DEBOUNCE_MOVE_MAIL_SERVICE_REQUESTS_MS, async (mailFacade: MailFacade) => {
// Each update to MoveMailService (for ham or spam mails that did move) requires one request
// We debounce the requests to a rate of DEBOUNCE_MOVE_MAIL_SERVICE_REQUESTS_MS
if (this.hamMoveMailData) {
const moveMailData = this.hamMoveMailData
this.hamMoveMailData = null
await this.sendMoveMailRequest(mailFacade, moveMailData)
}
if (this.spamMoveMailData) {
const moveMailData = this.spamMoveMailData
this.spamMoveMailData = null
await this.sendMoveMailRequest(mailFacade, moveMailData)
}
})
async sendMoveMailRequest(mailFacade: MailFacade, moveMailData: MoveMailData): Promise<void> {
mailFacade
.moveMails(moveMailData.mails, moveMailData.targetFolder, null, ClientClassifierType.CLIENT_CLASSIFICATION)
.catch(
ofClass(LockedError, (e) => {
// LockedError should no longer be thrown!?!
console.log("moving mails failed", e, moveMailData.targetFolder)
}),
)
.catch(
ofClass(PreconditionFailedError, (e) => {
// move mail operation may have been locked by other process
console.log("moving mails failed", e, moveMailData.targetFolder)
}),
)
}
public async predictSpamForNewMail(mail: Mail, mailDetails: MailDetails, sourceFolder: MailFolder, folderSystem: FolderSystem): Promise<MailFolder> {
const spamPredMailDatum: SpamPredMailDatum = {
subject: mail.subject,
body: getMailBodyText(mailDetails.body),
ownerGroup: assertNotNull(mail._ownerGroup),
...extractSpamHeaderFeatures(mail, mailDetails),
}
const isSpam = (await this.spamClassifier?.predict(spamPredMailDatum)) ?? null
let targetFolder = sourceFolder
if (isSpam && sourceFolder.folderType === MailSetKind.INBOX) { if (isSpam && sourceFolder.folderType === MailSetKind.INBOX) {
const spamFolder = assertNotNull(folderSystem.getSystemFolderByType(MailSetKind.SPAM)) targetFolder = assertNotNull(folderSystem.getSystemFolderByType(MailSetKind.SPAM))
if (this.spamMoveMailData) {
this.spamMoveMailData.mails.push(mail._id)
} else {
this.spamMoveMailData = createMoveMailData({
targetFolder: spamFolder?._id,
mails: [mail._id],
excludeMailSet: null,
moveReason: ClientClassifierType.CLIENT_CLASSIFICATION,
})
}
await this.sendMoveMailServiceRequest(this.mailFacade)
return spamFolder
} else if (!isSpam && sourceFolder.folderType === MailSetKind.SPAM) { } else if (!isSpam && sourceFolder.folderType === MailSetKind.SPAM) {
const hamFolder = assertNotNull(folderSystem.getSystemFolderByType(MailSetKind.INBOX)) targetFolder = assertNotNull(folderSystem.getSystemFolderByType(MailSetKind.INBOX))
if (this.hamMoveMailData) {
this.hamMoveMailData.mails.push(mail._id)
} else {
this.hamMoveMailData = createMoveMailData({
targetFolder: hamFolder?._id,
mails: [mail._id],
excludeMailSet: null,
moveReason: ClientClassifierType.CLIENT_CLASSIFICATION,
})
}
await this.sendMoveMailServiceRequest(this.mailFacade)
return hamFolder
} else if (mail.processingState !== ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE) {
this.classifierResultServiceMailIds.push(mail._id)
await this.sendClassifierResultServiceRequest(this.mailFacade)
return sourceFolder
} else {
return sourceFolder
} }
} const processInboxDatum: UnencryptedProcessInboxDatum = {
public async updateSpamClassificationData(mail: Mail) {
if (this.spamClassifier == null || mail.clientSpamClassifierResult == null) {
return
}
const storedClassification = await this.spamClassifier.getSpamClassification(mail._id)
const isSpam = mail.clientSpamClassifierResult.spamDecision === SpamDecision.BLACKLIST
const isSpamConfidence = getSpamConfidence(mail)
if (isNotNull(storedClassification) && (isSpam !== storedClassification.isSpam || isSpamConfidence !== storedClassification.isSpamConfidence)) {
// the model has trained on the mail but the spamFlag was wrong so we refit with higher isSpamConfidence
await this.spamClassifier.updateSpamClassification(mail._id, isSpam, isSpamConfidence)
}
}
public async dropClassificationData(mailId: IdTuple) {
await this.spamClassifier?.deleteSpamClassification(mailId)
}
public async storeTrainingDatum(mail: Mail, mailDetails: MailDetails) {
const spamTrainMailDatum: SpamTrainMailDatum = {
mailId: mail._id, mailId: mail._id,
subject: mail.subject, targetMoveFolder: targetFolder._id,
body: getMailBodyText(mailDetails.body), classifierType: isSameId(targetFolder._id, sourceFolder._id) ? null : ClientClassifierType.CLIENT_CLASSIFICATION,
isSpam: DEFAULT_IS_SPAM, vector: await this.spamClassifier.vectorizeAndCompress(spamMailDatum),
isSpamConfidence: DEFAULT_IS_SPAM_CONFIDENCE,
ownerGroup: assertNotNull(mail._ownerGroup),
...extractSpamHeaderFeatures(mail, mailDetails),
} }
await this.spamClassifier?.storeSpamClassification(spamTrainMailDatum) return { targetFolder, processInboxDatum: processInboxDatum }
} }
} }
export function extractSpamHeaderFeatures(mail: Mail, mailDetails: MailDetails) {
const sender = joinNamesAndMailAddresses([mail?.sender])
const { toRecipients, ccRecipients, bccRecipients } = extractRecipients(mailDetails)
const authStatus = convertAuthStatusToSpamCategorizationToken(mail.authStatus)
return { sender, toRecipients, ccRecipients, bccRecipients, authStatus }
}
function extractRecipients({ recipients }: MailDetails) {
const toRecipients = joinNamesAndMailAddresses(recipients?.toRecipients)
const ccRecipients = joinNamesAndMailAddresses(recipients?.ccRecipients)
const bccRecipients = joinNamesAndMailAddresses(recipients?.bccRecipients)
return { toRecipients, ccRecipients, bccRecipients }
}
function joinNamesAndMailAddresses(recipients: MailAddress[] | null) {
return recipients?.map((recipient) => `${recipient?.name} ${recipient?.address}`).join(" ") || ""
}
function convertAuthStatusToSpamCategorizationToken(authStatus: string | null): string {
if (authStatus === MailAuthenticationStatus.AUTHENTICATED) {
return "TAUTHENTICATED"
} else if (authStatus === MailAuthenticationStatus.HARD_FAIL) {
return "THARDFAIL"
} else if (authStatus === MailAuthenticationStatus.SOFT_FAIL) {
return "TSOFTFAIL"
} else if (authStatus === MailAuthenticationStatus.INVALID_MAIL_FROM) {
return "TINVALIDMAILFROM"
} else if (authStatus === MailAuthenticationStatus.MISSING_MAIL_FROM) {
return "TMISSINGMAILFROM"
}
return ""
}

View file

@ -0,0 +1,33 @@
import { LoggedInEvent, PostLoginAction } from "../../../common/api/main/LoginController"
import { SpamClassifier } from "../../workerUtils/spamClassification/SpamClassifier"
import { FeatureType } from "../../../common/api/common/TutanotaConstants"
import { CustomerFacade } from "../../../common/api/worker/facades/lazy/CustomerFacade"
import { filterMailMemberships } from "../../../common/api/common/utils/IndexUtils"
import { assertNotNull } from "@tutao/tutanota-utils"
import { isInternalUser } from "../../../common/api/common/utils/UserUtils"
/**
* Initialize SpamClassifier if FeatureType.SpamClientClassification feature is enabled for the customer.
*/
export class SpamClassificationPostLoginAction implements PostLoginAction {
constructor(
private readonly spamClassifier: SpamClassifier,
private readonly customerFacade: CustomerFacade,
) {}
async onPartialLoginSuccess(_: LoggedInEvent): Promise<void> {}
async onFullLoginSuccess(_: LoggedInEvent): Promise<void> {
await this.customerFacade.loadCustomizations()
const isSpamClassificationEnabled = await this.customerFacade.isEnabled(FeatureType.SpamClientClassification)
const user = assertNotNull(await this.customerFacade.getUser())
if (isSpamClassificationEnabled && isInternalUser(user) && this.spamClassifier) {
const ownerGroups = filterMailMemberships(user)
for (const ownerGroup of ownerGroups) {
this.spamClassifier.initialize(ownerGroup.group).catch((e) => {
console.log(`failed to initialize spam classification model for group: ${ownerGroup.group}`, e)
})
}
}
}
}

View file

@ -26,7 +26,6 @@ import { NotAuthorizedError, NotFoundError, PreconditionFailedError } from "../.
import { UserError } from "../../../common/api/main/UserError.js" import { UserError } from "../../../common/api/main/UserError.js"
import { ProgrammingError } from "../../../common/api/common/error/ProgrammingError.js" import { ProgrammingError } from "../../../common/api/common/error/ProgrammingError.js"
import Stream from "mithril/stream" import Stream from "mithril/stream"
import { InboxRuleHandler } from "../model/InboxRuleHandler.js"
import { Router } from "../../../common/gui/ScopedRouter.js" import { Router } from "../../../common/gui/ScopedRouter.js"
import { EntityUpdateData, isUpdateForTypeRef, PrefetchStatus } from "../../../common/api/common/utils/EntityUpdateUtils.js" import { EntityUpdateData, isUpdateForTypeRef, PrefetchStatus } from "../../../common/api/common/utils/EntityUpdateUtils.js"
import { EventController } from "../../../common/api/main/EventController.js" import { EventController } from "../../../common/api/main/EventController.js"
@ -40,6 +39,7 @@ import { MailSetListModel } from "../model/MailSetListModel"
import { ConversationListModel } from "../model/ConversationListModel" import { ConversationListModel } from "../model/ConversationListModel"
import { MailListDisplayMode } from "../../../common/misc/DeviceConfig" import { MailListDisplayMode } from "../../../common/misc/DeviceConfig"
import { client } from "../../../common/misc/ClientDetector" import { client } from "../../../common/misc/ClientDetector"
import { ProcessInboxHandler } from "../model/ProcessInboxHandler"
export interface MailOpenedListener { export interface MailOpenedListener {
onEmailOpened(mail: Mail): unknown onEmailOpened(mail: Mail): unknown
@ -98,7 +98,7 @@ export class MailViewModel {
private readonly conversationViewModelFactory: ConversationViewModelFactory, private readonly conversationViewModelFactory: ConversationViewModelFactory,
private readonly mailOpenedListener: MailOpenedListener, private readonly mailOpenedListener: MailOpenedListener,
private readonly conversationPrefProvider: ConversationPrefProvider, private readonly conversationPrefProvider: ConversationPrefProvider,
private readonly inboxRuleHandler: InboxRuleHandler, private readonly processInboxHandler: ProcessInboxHandler,
private readonly router: Router, private readonly router: Router,
private readonly updateUi: () => unknown, private readonly updateUi: () => unknown,
) {} ) {}
@ -258,8 +258,6 @@ export class MailViewModel {
return return
} }
if (cached) { if (cached) {
// Mails opened through the notification were not getting the inbox rule applied to them, so we apply it here
this.mailModel.applyInboxRuleToMail(cached)
console.log(TAG, "displaying cached mail", mailId) console.log(TAG, "displaying cached mail", mailId)
await this.displayExplicitMailTarget(cached) await this.displayExplicitMailTarget(cached)
} }
@ -526,7 +524,7 @@ export class MailViewModel {
this.conversationPrefProvider, this.conversationPrefProvider,
this.entityClient, this.entityClient,
this.mailModel, this.mailModel,
this.inboxRuleHandler, this.processInboxHandler,
this.cacheStorage, this.cacheStorage,
) )
} else { } else {
@ -535,7 +533,7 @@ export class MailViewModel {
this.conversationPrefProvider, this.conversationPrefProvider,
this.entityClient, this.entityClient,
this.mailModel, this.mailModel,
this.inboxRuleHandler, this.processInboxHandler,
this.cacheStorage, this.cacheStorage,
) )
} }

View file

@ -156,6 +156,7 @@ import { AutosaveFacade } from "../common/api/worker/facades/lazy/AutosaveFacade
import { lang } from "../common/misc/LanguageViewModel.js" import { lang } from "../common/misc/LanguageViewModel.js"
import { SpamClassificationHandler } from "./mail/model/SpamClassificationHandler" import { SpamClassificationHandler } from "./mail/model/SpamClassificationHandler"
import { SpamClassifier } from "./workerUtils/spamClassification/SpamClassifier" import { SpamClassifier } from "./workerUtils/spamClassification/SpamClassifier"
import { ProcessInboxHandler } from "./mail/model/ProcessInboxHandler"
import type { QuickActionsModel } from "../common/misc/quickactions/QuickActionsModel" import type { QuickActionsModel } from "../common/misc/quickactions/QuickActionsModel"
assertMainOrNode() assertMainOrNode()
@ -223,7 +224,7 @@ class MailLocator implements CommonLocator {
bulkMailLoader!: BulkMailLoader bulkMailLoader!: BulkMailLoader
mailExportFacade!: MailExportFacade mailExportFacade!: MailExportFacade
syncTracker!: SyncTracker syncTracker!: SyncTracker
spamClassifier: SpamClassifier | null = null spamClassifier!: SpamClassifier
whitelabelThemeGenerator!: WhitelabelThemeGenerator whitelabelThemeGenerator!: WhitelabelThemeGenerator
autosaveFacade!: AutosaveFacade autosaveFacade!: AutosaveFacade
@ -287,7 +288,7 @@ class MailLocator implements CommonLocator {
conversationViewModelFactory, conversationViewModelFactory,
this.mailOpenedListener, this.mailOpenedListener,
deviceConfig, deviceConfig,
this.inboxRuleHandler(), this.processInboxHandler(),
router, router,
await this.redraw(), await this.redraw(),
) )
@ -303,7 +304,11 @@ class MailLocator implements CommonLocator {
}) })
readonly spamClassificationHandler = lazyMemoized(() => { readonly spamClassificationHandler = lazyMemoized(() => {
return new SpamClassificationHandler(this.mailFacade, this.spamClassifier) return new SpamClassificationHandler(this.spamClassifier)
})
readonly processInboxHandler = lazyMemoized(() => {
return new ProcessInboxHandler(this.logins, this.mailFacade, this.spamClassificationHandler, this.inboxRuleHandler)
}) })
async searchViewModelFactory(): Promise<() => SearchViewModel> { async searchViewModelFactory(): Promise<() => SearchViewModel> {
@ -847,8 +852,7 @@ class MailLocator implements CommonLocator {
this.logins, this.logins,
this.mailFacade, this.mailFacade,
this.connectivityModel, this.connectivityModel,
this.spamClassificationHandler, this.processInboxHandler,
this.inboxRuleHandler,
) )
this.operationProgressTracker = new OperationProgressTracker() this.operationProgressTracker = new OperationProgressTracker()
this.infoMessageHandler = new InfoMessageHandler((state: SearchIndexStateInfo) => { this.infoMessageHandler = new InfoMessageHandler((state: SearchIndexStateInfo) => {
@ -879,6 +883,7 @@ class MailLocator implements CommonLocator {
this.usageTestController = new UsageTestController(this.usageTestModel) this.usageTestController = new UsageTestController(this.usageTestModel)
this.Const = Const this.Const = Const
this.whitelabelThemeGenerator = new WhitelabelThemeGenerator() this.whitelabelThemeGenerator = new WhitelabelThemeGenerator()
this.spamClassifier = spamClassifier
if (!isBrowser()) { if (!isBrowser()) {
const { WebDesktopFacade } = await import("../common/native/main/WebDesktopFacade") const { WebDesktopFacade } = await import("../common/native/main/WebDesktopFacade")
const { WebMobileFacade } = await import("../common/native/main/WebMobileFacade.js") const { WebMobileFacade } = await import("../common/native/main/WebMobileFacade.js")
@ -895,10 +900,9 @@ class MailLocator implements CommonLocator {
return await this.calendarEventModel(mode, getEventWithDefaultTimes(setNextHalfHour(new Date(date))), mailboxDetail, mailboxProperties, null) return await this.calendarEventModel(mode, getEventWithDefaultTimes(setNextHalfHour(new Date(date))), mailboxDetail, mailboxProperties, null)
}) })
const { OpenSettingsHandler } = await import("../common/native/main/OpenSettingsHandler.js") const { OpenSettingsHandler } = await import("../common/native/main/OpenSettingsHandler.js")
const openSettingsHandler = new OpenSettingsHandler(this.logins)
const openSettingsHandler = new OpenSettingsHandler(this.logins)
this.webMobileFacade = new WebMobileFacade(this.connectivityModel, MAIL_PREFIX) this.webMobileFacade = new WebMobileFacade(this.connectivityModel, MAIL_PREFIX)
this.spamClassifier = spamClassifier
this.nativeInterfaces = createNativeInterfaces( this.nativeInterfaces = createNativeInterfaces(
this.webMobileFacade, this.webMobileFacade,

View file

@ -1,47 +0,0 @@
import { LoggedInEvent, PostLoginAction } from "../../../common/api/main/LoginController"
import { OfflineStorageSettingsModel } from "../../../common/offline/OfflineStorageSettingsModel"
import { Indexer } from "../../workerUtils/index/Indexer"
import { SessionType } from "../../../common/api/common/SessionType"
import { SpamClassifier } from "../../workerUtils/spamClassification/SpamClassifier"
import { FeatureType } from "../../../common/api/common/TutanotaConstants"
import { CustomerFacade } from "../../../common/api/worker/facades/lazy/CustomerFacade"
import { filterMailMemberships } from "../../../common/api/common/utils/IndexUtils"
import { assertNotNull } from "@tutao/tutanota-utils"
/**
* The search range is tied to the offline storage settings.
* This updates the mail index on full login.
* And also initialize spamClassification if enabled
*/
export class MailIndexAndSpamClassificationPostLoginAction implements PostLoginAction {
constructor(
private readonly offlineStorageSettings: OfflineStorageSettingsModel,
private readonly indexer: Indexer,
private readonly spamClassifier: SpamClassifier | null,
private readonly customerFacade: CustomerFacade,
) {}
async onPartialLoginSuccess(event: LoggedInEvent): Promise<void> {
if (event.sessionType === SessionType.Persistent) {
await this.offlineStorageSettings.init()
// noinspection ES6MissingAwait
this.indexer.resizeMailIndex(this.offlineStorageSettings.getTimeRange().getTime()).then(async () => {
// spamClassification
// Wait until indexing is done, as its populate offlineDb
await this.customerFacade.loadCustomizations()
if (this.spamClassifier && (await this.customerFacade.isEnabled(FeatureType.SpamClientClassification))) {
const ownerGroups = filterMailMemberships(assertNotNull(await this.customerFacade.getUser()))
for (const ownerGroup of ownerGroups) {
this.spamClassifier.initialize(ownerGroup.group).catch((e) => {
console.log(`Failed to initialize spam classification model for group: ${ownerGroup._id}::${ownerGroup.group}. With reason:`)
console.log(e)
})
}
}
})
}
}
async onFullLoginSuccess(_: LoggedInEvent): Promise<void> {}
}

View file

@ -0,0 +1,25 @@
import { LoggedInEvent, PostLoginAction } from "../../../common/api/main/LoginController"
import { OfflineStorageSettingsModel } from "../../../common/offline/OfflineStorageSettingsModel"
import { Indexer } from "../../workerUtils/index/Indexer"
import { SessionType } from "../../../common/api/common/SessionType"
/**
* The search range is tied to the offline storage settings.
* This updates the mail index on full login.
*/
export class MailIndexerPostLoginAction implements PostLoginAction {
constructor(
private readonly offlineStorageSettings: OfflineStorageSettingsModel,
private readonly indexer: Indexer,
) {}
async onPartialLoginSuccess(event: LoggedInEvent): Promise<void> {
if (event.sessionType === SessionType.Persistent) {
await this.offlineStorageSettings.init()
// noinspection ES6MissingAwait
this.indexer.resizeMailIndex(this.offlineStorageSettings.getTimeRange().getTime())
}
}
async onFullLoginSuccess(_: LoggedInEvent): Promise<void> {}
}

View file

@ -10,8 +10,6 @@ import { htmlToText } from "../../../common/api/common/utils/IndexUtils"
import { getMailBodyText } from "../../../common/api/common/CommonMailUtils" import { getMailBodyText } from "../../../common/api/common/CommonMailUtils"
import { ListElementEntity } from "../../../common/api/common/EntityTypes" import { ListElementEntity } from "../../../common/api/common/EntityTypes"
import type { OfflineStorageTable } from "../../../common/api/worker/offline/OfflineStorage" import type { OfflineStorageTable } from "../../../common/api/worker/offline/OfflineStorage"
import { SpamClassificationModel, SpamTrainMailDatum } from "../spamClassification/SpamClassifier"
import { Nullable } from "@tutao/tutanota-utils/dist/Utils"
export const SearchTableDefinitions: Record<string, OfflineStorageTable> = Object.freeze({ export const SearchTableDefinitions: Record<string, OfflineStorageTable> = Object.freeze({
search_group_data: { search_group_data: {
@ -66,26 +64,6 @@ export const SearchTableDefinitions: Record<string, OfflineStorageTable> = Objec
}, },
}) })
export const SpamClassificationDefinitions: Record<string, OfflineStorageTable> = Object.freeze({
spam_classification_training_data: {
definition:
"CREATE TABLE IF NOT EXISTS spam_classification_training_data (listId TEXT NOT NULL, elementId TEXT NOT NULL," +
"ownerGroup TEXT NOT NULL, subject TEXT NOT NULL, body TEXT NOT NULL, isSpam NUMBER," +
"lastModified NUMBER NOT NULL, isSpamConfidence NUMBER NOT NULL, sender TEXT NOT NULL," +
"toRecipients TEXT NOT NULL, ccRecipients TEXT NOT NULL, bccRecipients TEXT NOT NULL," +
"authStatus TEXT NOT NULL, PRIMARY KEY (listId, elementId))",
purgedWithCache: true,
},
// TODO add test for new table
spam_classification_model: {
definition:
"CREATE TABLE IF NOT EXISTS spam_classification_model (version NUMBER NOT NULL, ownerGroup TEXT NOT NULL, modelTopology TEXT NOT NULL, weightSpecs TEXT NOT NULL, weightData BLOB NOT NULL, PRIMARY KEY(version, ownerGroup))",
purgedWithCache: true,
},
})
export interface IndexedGroupData { export interface IndexedGroupData {
groupId: Id groupId: Id
type: GroupType type: GroupType
@ -187,127 +165,6 @@ export class OfflineStoragePersistence {
} }
} }
async storeSpamClassification(spamTrainMailDatum: SpamTrainMailDatum): Promise<void> {
const { query, params } = sql`
INSERT
OR REPLACE INTO spam_classification_training_data(listId, elementId, ownerGroup, subject, body, isSpam,
lastModified, isSpamConfidence, sender, toRecipients, ccRecipients, bccRecipients, authStatus)
VALUES (
${listIdPart(spamTrainMailDatum.mailId)},
${elementIdPart(spamTrainMailDatum.mailId)},
${spamTrainMailDatum.ownerGroup},
${spamTrainMailDatum.subject},
${spamTrainMailDatum.body},
${spamTrainMailDatum.isSpam ? 1 : 0},
${Date.now()},
${spamTrainMailDatum.isSpamConfidence},
${spamTrainMailDatum.sender},
${spamTrainMailDatum.toRecipients},
${spamTrainMailDatum.ccRecipients},
${spamTrainMailDatum.bccRecipients},
${spamTrainMailDatum.authStatus}
)`
await this.sqlCipherFacade.run(query, params)
}
async deleteSpamClassification(mailId: IdTuple): Promise<void> {
const mailListId = listIdPart(mailId)
const mailElementId = elementIdPart(mailId)
const { query, params } = sql`
DELETE
FROM spam_classification_training_data
where listId = ${mailListId}
AND elementId = ${mailElementId}`
await this.sqlCipherFacade.run(query, params)
}
async deleteSpamClassificationTrainingDataBeforeCutoff(cutoffTimestamp: number, ownerGroupId: Id): Promise<void> {
const { query, params } = sql`DELETE
FROM spam_classification_training_data
WHERE lastModified < ${cutoffTimestamp}
AND ownerGroup = ${ownerGroupId}`
await this.sqlCipherFacade.run(query, params)
}
async updateSpamClassification(mailId: IdTuple, isSpam: boolean, isSpamConfidence: number): Promise<void> {
const { query, params } = sql`
UPDATE spam_classification_training_data
SET lastModified=${Date.now()},
isSpamConfidence=${isSpamConfidence},
isSpam=${isSpam ? 1 : 0}
WHERE listId = ${listIdPart(mailId)}
AND elementId = ${elementIdPart(mailId)}
`
await this.sqlCipherFacade.run(query, params)
}
async getSpamClassification(mailId: IdTuple): Promise<Nullable<{ isSpam: boolean; isSpamConfidence: number }>> {
const { query, params } = sql`
SELECT isSpam, isSpamConfidence
FROM spam_classification_training_data
where listId = ${listIdPart(mailId)}
AND elementId = ${elementIdPart(mailId)} `
const result = await this.sqlCipherFacade.get(query, params)
if (!result) {
return null
} else {
const isSpam = untagSqlObject(result).isSpam === 1
const isSpamConfidence = untagSqlObject(result).isSpamConfidence as number
return { isSpam, isSpamConfidence }
}
}
async getCertainSpamClassificationTrainingDataAfterCutoff(cutoffTimestamp: number, ownerGroupId: Id): Promise<SpamTrainMailDatum[]> {
const { query, params } = sql`SELECT listId,
elementId,
subject,
body,
isSpam,
isSpamConfidence,
sender,
toRecipients,
ccRecipients,
bccRecipients,
authStatus
FROM spam_classification_training_data
WHERE lastModified > ${cutoffTimestamp}
AND isSpamConfidence > 0
AND ownerGroup = ${ownerGroupId}`
const resultRows = await this.sqlCipherFacade.all(query, params)
return resultRows.map(untagSqlObject).map((row) => row as unknown as SpamTrainMailDatum)
}
async putSpamClassificationModel(model: SpamClassificationModel) {
const { query, params } = sql`INSERT
OR REPLACE INTO
spam_classification_model VALUES (
${1},
${model.ownerGroup},
${model.modelTopology},
${model.weightSpecs},
${model.weightData}
)`
await this.sqlCipherFacade.run(query, params)
}
async getSpamClassificationModel(ownerGroup: Id): Promise<Nullable<SpamClassificationModel>> {
const { query, params } = sql`SELECT modelTopology, weightSpecs, weightData, ownerGroup
FROM spam_classification_model
WHERE version = ${1}
AND ownerGroup = ${ownerGroup}`
const resultRows = await this.sqlCipherFacade.get(query, params)
if (resultRows !== null) {
const untaggedValue = untagSqlObject(resultRows)
return {
modelTopology: untaggedValue.modelTopology,
weightSpecs: untaggedValue.weightSpecs,
weightData: untaggedValue.weightData,
ownerGroup: untaggedValue.ownerGroup,
} as SpamClassificationModel
}
return null
}
async updateMailLocation(mail: Mail) { async updateMailLocation(mail: Mail) {
const rowid = await this.getRowid(MailTypeRef, mail._id) const rowid = await this.getRowid(MailTypeRef, mail._id)
if (rowid == null) { if (rowid == null) {

View file

@ -1,5 +1,6 @@
import { arrayHashUnsigned, downcast, promiseMap, stringToUtf8Uint8Array } from "@tutao/tutanota-utils" import { arrayHashUnsigned, downcast, promiseMap, stringToUtf8Uint8Array } from "@tutao/tutanota-utils"
import { stringToHashBucketFast, tensor1d } from "./tensorflow-custom" import { stringToHashBucketFast, tensor1d } from "./tensorflow-custom"
import { MAX_WORD_FREQUENCY } from "../../../common/api/common/utils/spamClassificationUtils/SpamMailProcessor"
export class HashingVectorizer { export class HashingVectorizer {
private readonly hasher: (tokens: Array<string>) => Promise<Array<number>> = this.tensorHash private readonly hasher: (tokens: Array<string>) => Promise<Array<number>> = this.tensorHash
@ -11,7 +12,9 @@ export class HashingVectorizer {
const indexes = await this.hasher(downcast<Array<string>>(tokens)) const indexes = await this.hasher(downcast<Array<string>>(tokens))
for (const index of indexes) { for (const index of indexes) {
vector[index] += 1 if (vector[index] < MAX_WORD_FREQUENCY) {
vector[index] += 1
}
} }
return vector return vector

View file

@ -0,0 +1,241 @@
import { EntityClient } from "../../../common/api/common/EntityClient"
import { assertNotNull, isEmpty, isNotNull, last, lazyAsync, promiseMap } from "@tutao/tutanota-utils"
import {
ClientSpamTrainingDatum,
ClientSpamTrainingDatumIndexEntryTypeRef,
ClientSpamTrainingDatumTypeRef,
MailBag,
MailBox,
MailboxGroupRootTypeRef,
MailBoxTypeRef,
MailFolder,
MailFolderTypeRef,
MailTypeRef,
PopulateClientSpamTrainingDatum,
} from "../../../common/api/entities/tutanota/TypeRefs"
import { getMailSetKind, isFolder, MailSetKind, SpamDecision } from "../../../common/api/common/TutanotaConstants"
import { GENERATED_MIN_ID, getElementId, isSameId, StrippedEntity, timestampToGeneratedId } from "../../../common/api/common/utils/EntityUtils"
import { BulkMailLoader, MailWithMailDetails } from "../index/BulkMailLoader"
import { hasError } from "../../../common/api/common/utils/ErrorUtils"
import { getSpamConfidence } from "../../../common/api/common/utils/spamClassificationUtils/SpamMailProcessor"
import { MailFacade } from "../../../common/api/worker/facades/lazy/MailFacade"
//Visible for testing
export const SINGLE_TRAIN_INTERVAL_TRAINING_DATA_LIMIT = 1000
const INITIAL_SPAM_CLASSIFICATION_INDEX_INTERVAL_DAYS = 90
const TRAINING_DATA_TIME_LIMIT: number = INITIAL_SPAM_CLASSIFICATION_INDEX_INTERVAL_DAYS * -1
export type TrainingDataset = {
trainingData: ClientSpamTrainingDatum[]
lastTrainingDataIndexId: Id
hamCount: number
spamCount: number
}
export type UnencryptedPopulateClientSpamTrainingDatum = Omit<StrippedEntity<PopulateClientSpamTrainingDatum>, "encVector" | "ownerEncVectorSessionKey"> & {
vector: Uint8Array
}
export class SpamClassificationDataDealer {
constructor(
private readonly entityClient: EntityClient,
private readonly bulkMailLoader: lazyAsync<BulkMailLoader>,
private readonly mailFacade: lazyAsync<MailFacade>,
) {}
public async fetchAllTrainingData(ownerGroup: Id): Promise<TrainingDataset> {
const mailboxGroupRoot = await this.entityClient.load(MailboxGroupRootTypeRef, ownerGroup)
const mailbox = await this.entityClient.load(MailBoxTypeRef, mailboxGroupRoot.mailbox)
const mailSets = await this.entityClient.loadAll(MailFolderTypeRef, assertNotNull(mailbox.folders).folders)
if (mailbox.clientSpamTrainingData == null || mailbox.modifiedClientSpamTrainingDataIndex == null) {
return { trainingData: [], lastTrainingDataIndexId: GENERATED_MIN_ID, hamCount: 0, spamCount: 0 }
}
// clientSpamTrainingData is NOT cached
let clientSpamTrainingData = await this.entityClient.loadAll(ClientSpamTrainingDatumTypeRef, mailbox.clientSpamTrainingData)
// if the training data is empty for this mailbox, we are aggregating
// the last INITIAL_SPAM_CLASSIFICATION_INDEX_INTERVAL_DAYS of mails and uploading the training data
if (isEmpty(clientSpamTrainingData)) {
console.log("building and uploading initial training data for mailbox: " + mailbox._id)
const mailsWithMailDetails = await this.fetchMailAndMailDetailsForMailbox(mailbox, mailSets)
console.log(`mailbox has ${mailsWithMailDetails.length} mails suitable for encrypted training vector data upload`)
console.log(`vectorizing, compressing and encrypting those ${mailsWithMailDetails.length} mails...`)
await this.uploadTrainingDataForMails(mailsWithMailDetails, mailbox, mailSets)
clientSpamTrainingData = await this.entityClient.loadAll(ClientSpamTrainingDatumTypeRef, mailbox.clientSpamTrainingData)
console.log(`clientSpamTrainingData list on the mailbox has ${clientSpamTrainingData.length} members.`)
}
const { subsampledTrainingData, hamCount, spamCount } = this.subsampleHamAndSpamMails(clientSpamTrainingData)
const modifiedClientSpamTrainingDataIndices = await this.entityClient.loadAll(
ClientSpamTrainingDatumIndexEntryTypeRef,
mailbox.modifiedClientSpamTrainingDataIndex,
)
const lastModifiedClientSpamTrainingDataIndexElementId = isEmpty(modifiedClientSpamTrainingDataIndices)
? GENERATED_MIN_ID
: getElementId(assertNotNull(last(modifiedClientSpamTrainingDataIndices)))
return {
trainingData: subsampledTrainingData,
lastTrainingDataIndexId: lastModifiedClientSpamTrainingDataIndexElementId,
hamCount,
spamCount,
}
}
async fetchPartialTrainingDataFromIndexStartId(indexStartId: Id, ownerGroup: Id): Promise<TrainingDataset> {
const mailboxGroupRoot = await this.entityClient.load(MailboxGroupRootTypeRef, ownerGroup)
const mailbox = await this.entityClient.load(MailBoxTypeRef, mailboxGroupRoot.mailbox)
const emptyResult = { trainingData: [], lastTrainingDataIndexId: indexStartId, hamCount: 0, spamCount: 0 }
if (mailbox.clientSpamTrainingData == null || mailbox.modifiedClientSpamTrainingDataIndex == null) {
return emptyResult
}
const modifiedClientSpamTrainingDataIndicesSinceStart = await this.entityClient.loadRange(
ClientSpamTrainingDatumIndexEntryTypeRef,
mailbox.modifiedClientSpamTrainingDataIndex,
indexStartId,
SINGLE_TRAIN_INTERVAL_TRAINING_DATA_LIMIT,
false,
)
if (isEmpty(modifiedClientSpamTrainingDataIndicesSinceStart)) {
return emptyResult
}
const clientSpamTrainingData = await this.entityClient.loadMultiple(
ClientSpamTrainingDatumTypeRef,
mailbox.clientSpamTrainingData,
modifiedClientSpamTrainingDataIndicesSinceStart.map((index) => index.clientSpamTrainingDatumElementId),
)
const { subsampledTrainingData, hamCount, spamCount } = this.subsampleHamAndSpamMails(clientSpamTrainingData)
return {
trainingData: subsampledTrainingData,
lastTrainingDataIndexId: getElementId(assertNotNull(last(modifiedClientSpamTrainingDataIndicesSinceStart))),
hamCount,
spamCount,
}
}
// Visible for testing
subsampleHamAndSpamMails(clientSpamTrainingData: ClientSpamTrainingDatum[]): {
subsampledTrainingData: ClientSpamTrainingDatum[]
hamCount: number
spamCount: number
} {
// we always want to include clientSpamTrainingData with high confidence (usually 4), because these mails have been moved explicitly by the user
const hamDataHighConfidence = clientSpamTrainingData.filter((d) => Number(d.confidence) > 1 && d.spamDecision === SpamDecision.WHITELIST)
const spamDataHighConfidence = clientSpamTrainingData.filter((d) => Number(d.confidence) > 1 && d.spamDecision === SpamDecision.BLACKLIST)
const hamDataLowConfidence = clientSpamTrainingData.filter((d) => Number(d.confidence) === 1 && d.spamDecision === SpamDecision.WHITELIST)
const spamDataLowConfidence = clientSpamTrainingData.filter((d) => Number(d.confidence) === 1 && d.spamDecision === SpamDecision.BLACKLIST)
const hamCount = hamDataHighConfidence.length + hamDataLowConfidence.length
const spamCount = spamDataHighConfidence.length + spamDataLowConfidence.length
if (hamCount === 0 || spamCount === 0) {
return { subsampledTrainingData: clientSpamTrainingData, hamCount, spamCount }
}
const ratio = hamCount / spamCount
const MAX_RATIO = 10
const MIN_RATIO = 1 / 10
let sampledHamLowConfidence = hamDataLowConfidence
let sampledSpamLowConfidence = spamDataLowConfidence
if (ratio > MAX_RATIO) {
const targetHamCount = Math.floor(spamCount * MAX_RATIO)
sampledHamLowConfidence = this.sampleEntriesFromArray(hamDataLowConfidence, targetHamCount)
} else if (ratio < MIN_RATIO) {
const targetSpamCount = Math.floor(hamCount * MAX_RATIO)
sampledSpamLowConfidence = this.sampleEntriesFromArray(spamDataLowConfidence, targetSpamCount)
}
const finalHam = hamDataHighConfidence.concat(sampledHamLowConfidence)
const finalSpam = spamDataHighConfidence.concat(sampledSpamLowConfidence)
const balanced = [...finalHam, ...finalSpam]
console.log(
`Subsampled training data to ${finalHam.length} ham (${hamDataHighConfidence.length} are confidence > 1) and ${finalSpam.length} spam (${spamDataHighConfidence.length} are confidence > 1) (ratio ${(finalHam.length / finalSpam.length).toFixed(2)}).`,
)
return { subsampledTrainingData: balanced, hamCount: finalHam.length, spamCount: finalSpam.length }
}
// Visible for testing
async fetchMailsByMailbagAfterDate(mailbag: MailBag, mailSets: MailFolder[], startDate: Date): Promise<Array<MailWithMailDetails>> {
const bulkMailLoader = await this.bulkMailLoader()
const mails = await this.entityClient.loadAll(MailTypeRef, mailbag.mails, timestampToGeneratedId(startDate.getTime()))
const filteredMails = mails.filter((mail) => {
const trashFolder = assertNotNull(mailSets.find((set) => getMailSetKind(set) === MailSetKind.TRASH))
const isMailTrashed = mail.sets.some((setId) => isSameId(setId, trashFolder._id))
return isNotNull(mail.mailDetails) && !hasError(mail) && mail.receivedDate > startDate && !isMailTrashed
})
const mailsWithMailDetails = await bulkMailLoader.loadMailDetails(filteredMails)
return mailsWithMailDetails ?? []
}
private async fetchMailAndMailDetailsForMailbox(mailbox: MailBox, mailSets: MailFolder[]): Promise<Array<MailWithMailDetails>> {
const downloadedMailClassificationData = new Array<MailWithMailDetails>()
const { LocalTimeDateProvider } = await import("../../../common/api/worker/DateProvider")
const startDate = new LocalTimeDateProvider().getStartOfDayShiftedBy(TRAINING_DATA_TIME_LIMIT)
// sorted from latest to oldest
const mailbagsToFetch = [assertNotNull(mailbox.currentMailBag), ...mailbox.archivedMailBags.reverse()]
for (let currentMailbag = mailbagsToFetch.shift(); isNotNull(currentMailbag); currentMailbag = mailbagsToFetch.shift()) {
const mailsOfThisMailbag = await this.fetchMailsByMailbagAfterDate(currentMailbag, mailSets, startDate)
if (isEmpty(mailsOfThisMailbag)) {
// the list is empty if none of the mails in the mailbag were recent enough,
// therefore, there is no point in requesting the remaining mailbags unnecessarily
break
}
downloadedMailClassificationData.push(...mailsOfThisMailbag)
}
return downloadedMailClassificationData
}
private async uploadTrainingDataForMails(mails: MailWithMailDetails[], mailBox: MailBox, mailSets: MailFolder[]): Promise<void> {
const clientSpamTrainingDataListId = mailBox.clientSpamTrainingData
if (clientSpamTrainingDataListId == null) {
return
}
const unencryptedPopulateClientSpamTrainingData: UnencryptedPopulateClientSpamTrainingDatum[] = await promiseMap(
mails,
async (mailWithDetail) => {
const { mail, mailDetails } = mailWithDetail
const allMailFolders = mailSets.filter((mailSet) => isFolder(mailSet)).map((mailFolder) => mailFolder._id)
const sourceMailFolderId = assertNotNull(mail.sets.find((setId) => allMailFolders.find((folderId) => isSameId(setId, folderId))))
const sourceMailFolder = assertNotNull(mailSets.find((set) => isSameId(set._id, sourceMailFolderId)))
const isSpam = getMailSetKind(sourceMailFolder) === MailSetKind.SPAM
const unencryptedPopulateClientSpamTrainingData: UnencryptedPopulateClientSpamTrainingDatum = {
mailId: mail._id,
isSpam,
confidence: getSpamConfidence(mail),
vector: await (await this.mailFacade()).vectorizeAndCompressMails({ mail, mailDetails }),
}
return unencryptedPopulateClientSpamTrainingData
},
{
concurrency: 5,
},
)
// we are uploading the initial spam training data using the PopulateClientSpamTrainingDataService
return (await this.mailFacade()).populateClientSpamTrainingData(assertNotNull(mailBox._ownerGroup), unencryptedPopulateClientSpamTrainingData)
}
private sampleEntriesFromArray<T>(arr: T[], numberOfEntries: number): T[] {
if (numberOfEntries >= arr.length) {
return arr
}
const shuffled = arr.slice().sort(() => Math.random() - 0.5)
return shuffled.slice(0, numberOfEntries)
}
}

View file

@ -1,114 +0,0 @@
import { EntityClient } from "../../../common/api/common/EntityClient"
import { assertNotNull, isNotNull, lazyAsync } from "@tutao/tutanota-utils"
import {
MailAddress,
MailBag,
MailboxGroupRootTypeRef,
MailBoxTypeRef,
MailDetails,
MailFolder,
MailFolderTypeRef,
MailTypeRef,
Recipients,
} from "../../../common/api/entities/tutanota/TypeRefs"
import { getMailSetKind, getSpamConfidence, MailSetKind } from "../../../common/api/common/TutanotaConstants"
import { elementIdPart, isSameId, listIdPart, timestampToGeneratedId } from "../../../common/api/common/utils/EntityUtils"
import { OfflineStoragePersistence } from "../index/OfflineStoragePersistence"
import { getMailBodyText } from "../../../common/api/common/CommonMailUtils"
import { BulkMailLoader, MailWithMailDetails } from "../index/BulkMailLoader"
import { hasError } from "../../../common/api/common/utils/ErrorUtils"
import { SpamTrainMailDatum } from "./SpamClassifier"
import { extractSpamHeaderFeatures } from "../../mail/model/SpamClassificationHandler"
const INITIAL_SPAM_CLASSIFICATION_INDEX_INTERVAL_DAYS = 28
export class SpamClassificationInitializer {
/*
* While downloading mails, we start from current mailbag, but it might be that current mailbag is too new,
* If there are less than this mail in current mailbag, we will also try to fetch previous one
*/
public readonly MIN_MAILS_COUNT: number = 300
public readonly TIME_LIMIT: number = INITIAL_SPAM_CLASSIFICATION_INDEX_INTERVAL_DAYS * -1
constructor(
private readonly entityClient: EntityClient,
private readonly offlineStorage: OfflineStoragePersistence,
private readonly bulkMailLoader: lazyAsync<BulkMailLoader>,
) {}
public async init(ownerGroup: Id): Promise<SpamTrainMailDatum[]> {
// populate the spam classification data with the last 28 days of mails if they are
// available in the current mail bag
const data = await this.downloadMailAndMailDetailsByGroupMembership(ownerGroup)
data.filter((datum) => datum.isSpamConfidence > 0)
let spamMailsCount = 0
let hamMailsCount = 0
for (const spamTrainMailDatum of data) {
await this.offlineStorage.storeSpamClassification(spamTrainMailDatum)
if (spamTrainMailDatum.isSpam) spamMailsCount += 1
else hamMailsCount += 1
}
console.log(
`Downloaded ${spamMailsCount} spam mails and ${hamMailsCount} ham mails for group: ${ownerGroup}. Spam:Ham ratio is: ${(spamMailsCount / hamMailsCount).toFixed(2)}`,
)
return data
}
private async downloadMailAndMailDetailsByGroupMembership(mailGroupId: Id): Promise<Array<SpamTrainMailDatum>> {
const mailboxGroupRoot = await this.entityClient.load(MailboxGroupRootTypeRef, mailGroupId)
const mailbox = await this.entityClient.load(MailBoxTypeRef, mailboxGroupRoot.mailbox)
const mailSets = await this.entityClient.loadAll(MailFolderTypeRef, assertNotNull(mailbox.folders).folders)
const spamFolder = mailSets.find((s) => getMailSetKind(s) === MailSetKind.SPAM)!
const downloadedMailClassificationDatas = new Array<SpamTrainMailDatum>()
const allMailbags = [assertNotNull(mailbox.currentMailBag), ...mailbox.archivedMailBags].reverse() // sorted from latest to oldest
for (
let currentMailbag = allMailbags.pop();
isNotNull(currentMailbag) && downloadedMailClassificationDatas.length < this.MIN_MAILS_COUNT;
currentMailbag = allMailbags.pop()
) {
const mailsOfThisMailbag = await this.downloadMailAndMailDetailsByMailbag(currentMailbag, spamFolder)
downloadedMailClassificationDatas.push(...mailsOfThisMailbag)
}
return downloadedMailClassificationDatas
}
private async downloadMailAndMailDetailsByMailbag(mailbag: MailBag, spamFolder: MailFolder): Promise<Array<SpamTrainMailDatum>> {
const { LocalTimeDateProvider } = await import("../../../common/api/worker/DateProvider.js")
const dateProvider = new LocalTimeDateProvider()
const startTime = dateProvider.getStartOfDayShiftedBy(this.TIME_LIMIT).getTime()
const bulkMailLoader = await this.bulkMailLoader()
return await this.entityClient
.loadAll(MailTypeRef, mailbag.mails, timestampToGeneratedId(startTime))
// Filter out draft mails and mails with error
.then((mails) => {
return mails.filter((m) => isNotNull(m.mailDetails) && !hasError(m))
})
// Download mail details
.then((mails) => bulkMailLoader.loadMailDetails(mails))
// Map to spam mail datum
.then((mails) => mails.map((m) => this.mailWithDetailsToMailDatum(spamFolder, m)))
}
private mailWithDetailsToMailDatum(spamFolder: MailFolder, { mail, mailDetails }: MailWithMailDetails): SpamTrainMailDatum {
const isSpam = mail.sets.some((folderId) => isSameId(folderId, spamFolder._id))
return {
mailId: mail._id,
subject: mail.subject,
body: getMailBodyText(mailDetails.body),
isSpam: isSpam,
isSpamConfidence: getSpamConfidence(mail),
listId: listIdPart(mail._id),
elementId: elementIdPart(mail._id),
ownerGroup: assertNotNull(mail._ownerGroup),
...extractSpamHeaderFeatures(mail, mailDetails),
} as SpamTrainMailDatum
}
}

View file

@ -1,29 +1,10 @@
import { assertWorkerOrNode } from "../../../common/api/common/Env" import { assertWorkerOrNode } from "../../../common/api/common/Env"
import { assertNotNull, defer, groupByAndMap, isNotNull, Nullable, promiseMap, tokenize } from "@tutao/tutanota-utils" import { assertNotNull, groupByAndMap, isEmpty, Nullable, promiseMap } from "@tutao/tutanota-utils"
import { HashingVectorizer } from "./HashingVectorizer" import { SpamClassificationDataDealer, TrainingDataset } from "./SpamClassificationDataDealer"
import {
ML_BITCOIN_REGEX,
ML_BITCOIN_TOKEN,
ML_CREDIT_CARD_REGEX,
ML_CREDIT_CARD_TOKEN,
ML_DATE_REGEX,
ML_DATE_TOKEN,
ML_EMAIL_ADDR_REGEX,
ML_EMAIL_ADDR_TOKEN,
ML_NUMBER_SEQUENCE_REGEX,
ML_NUMBER_SEQUENCE_TOKEN,
ML_SPACE_BEFORE_NEW_LINE_REGEX,
ML_SPACE_BEFORE_NEW_LINE_TOKEN,
ML_SPECIAL_CHARACTER_REGEX,
ML_SPECIAL_CHARACTER_TOKEN,
ML_URL_REGEX,
ML_URL_TOKEN,
} from "./PreprocessPatterns"
import { SpamClassificationInitializer } from "./SpamClassificationInitializer"
import { CacheStorage } from "../../../common/api/worker/rest/DefaultEntityRestCache" import { CacheStorage } from "../../../common/api/worker/rest/DefaultEntityRestCache"
import { htmlToText } from "../../../common/api/common/utils/IndexUtils"
import { import {
dense, dense,
enableProdMode,
fromMemory, fromMemory,
glorotUniform, glorotUniform,
LayersModel, LayersModel,
@ -33,10 +14,13 @@ import {
tensor2d, tensor2d,
withSaveHandler, withSaveHandler,
} from "./tensorflow-custom" } from "./tensorflow-custom"
import type { Tensor } from "@tensorflow/tfjs-core"
import type { ModelArtifacts } from "@tensorflow/tfjs-core/dist/io/types" import type { ModelArtifacts } from "@tensorflow/tfjs-core/dist/io/types"
import type { ModelFitArgs } from "@tensorflow/tfjs-layers" import type { ModelFitArgs } from "@tensorflow/tfjs-layers"
import { OfflineStoragePersistence } from "../index/OfflineStoragePersistence" import type { Tensor } from "@tensorflow/tfjs-core"
import { DEFAULT_PREPROCESS_CONFIGURATION, SpamMailDatum, SpamMailProcessor } from "../../../common/api/common/utils/spamClassificationUtils/SpamMailProcessor"
import { SparseVectorCompressor } from "../../../common/api/common/utils/spamClassificationUtils/SparseVectorCompressor"
import { SpamDecision } from "../../../common/api/common/TutanotaConstants"
import { HashingVectorizer } from "./HashingVectorizer"
assertWorkerOrNode() assertWorkerOrNode()
@ -45,222 +29,110 @@ export type SpamClassificationModel = {
weightSpecs: string weightSpecs: string
weightData: Uint8Array weightData: Uint8Array
ownerGroup: Id ownerGroup: Id
hamCount: number
spamCount: number
} }
export type SpamTrainMailDatum = { export const DEFAULT_PREDICTION_THRESHOLD = 0.55
mailId: IdTuple
subject: string
body: string
isSpam: boolean
isSpamConfidence: number
ownerGroup: Id
sender: string
toRecipients: string
ccRecipients: string
bccRecipients: string
authStatus: string
}
export type SpamPredMailDatum = {
subject: string
body: string
ownerGroup: Id
sender: string
toRecipients: string
ccRecipients: string
bccRecipients: string
authStatus: string
}
const PREDICTION_THRESHOLD = 0.55
export type PreprocessConfiguration = {
isPreprocessMails: boolean
isRemoveHTML: boolean
isReplaceDates: boolean
isReplaceUrls: boolean
isReplaceMailAddresses: boolean
isReplaceBitcoinAddress: boolean
isReplaceCreditCards: boolean
isReplaceNumbers: boolean
isReplaceSpecialCharacters: boolean
isRemoveSpaceBeforeNewLine: boolean
}
export const DEFAULT_PREPROCESS_CONFIGURATION: PreprocessConfiguration = {
isPreprocessMails: true,
isRemoveHTML: true,
isReplaceDates: true,
isReplaceUrls: true,
isReplaceMailAddresses: true,
isReplaceBitcoinAddress: true,
isReplaceCreditCards: true,
isReplaceNumbers: true,
isReplaceSpecialCharacters: true,
isRemoveSpaceBeforeNewLine: true,
}
const TRAINING_INTERVAL = 1000 * 60 * 10 // 10 minutes const TRAINING_INTERVAL = 1000 * 60 * 10 // 10 minutes
const FULL_RETRAINING_INTERVAL = 1000 * 60 * 60 * 24 * 7 // 1 week const FULL_RETRAINING_INTERVAL = 1000 * 60 * 60 * 24 * 7 // 1 week
type TrainingPerformance = { export type Classifier = {
trainingTime: number isEnabled: boolean
vectorizationTime: number layersModel: LayersModel
threshold: number
hamCount: number
spamCount: number
} }
export const spamClassifierTokenizer = (text: string): string[] => tokenize(text)
export class SpamClassifier { export class SpamClassifier {
private readonly classifier: Map<Id, { model: LayersModel; isEnabled: boolean }> // Visible for testing
readonly classifiers: Map<Id, Classifier>
sparseVectorCompressor: SparseVectorCompressor
spamMailProcessor: SpamMailProcessor
constructor( constructor(
private readonly offlineStorage: OfflineStoragePersistence, private readonly cacheStorage: CacheStorage,
private readonly offlineStorageCache: CacheStorage, private readonly initializer: SpamClassificationDataDealer,
private readonly initializer: SpamClassificationInitializer,
private readonly deterministic: boolean = false, private readonly deterministic: boolean = false,
private readonly preprocessConfiguration: PreprocessConfiguration = DEFAULT_PREPROCESS_CONFIGURATION,
private readonly vectorizer: HashingVectorizer = new HashingVectorizer(),
) { ) {
this.classifier = new Map() // enable tensorflow production mode
enableProdMode()
this.classifiers = new Map()
this.sparseVectorCompressor = new SparseVectorCompressor()
this.spamMailProcessor = new SpamMailProcessor(DEFAULT_PREPROCESS_CONFIGURATION, new HashingVectorizer(), this.sparseVectorCompressor)
}
calculateThreshold(hamCount: number, spamCount: number) {
const hamToSpamRatio = hamCount / spamCount
let threshold = -0.1 * Math.log10(hamToSpamRatio * 10) + 0.65
if (threshold < DEFAULT_PREDICTION_THRESHOLD) {
threshold = DEFAULT_PREDICTION_THRESHOLD
} else if (threshold > 0.75) {
threshold = 0.75
}
return threshold
} }
public async initialize(ownerGroup: Id): Promise<void> { public async initialize(ownerGroup: Id): Promise<void> {
const loadedModel = await this.loadModel(ownerGroup) const classifier = await this.loadClassifier(ownerGroup)
const storage = assertNotNull(this.offlineStorageCache) if (classifier) {
setInterval(async () => { const timeSinceLastFullTraining = Date.now() - FULL_RETRAINING_INTERVAL
const cutoffDate = Date.now() - FULL_RETRAINING_INTERVAL const lastFullTrainingTime = await this.cacheStorage.getLastTrainedFromScratchTime()
const lastFullTrainingTime = await storage.getLastTrainedFromScratchTime() if (timeSinceLastFullTraining > lastFullTrainingTime) {
console.log(`Retraining from scratch as last train (${new Date(lastFullTrainingTime)}) was more than a week ago`)
if (cutoffDate > lastFullTrainingTime) { await this.trainFromScratch(this.cacheStorage, ownerGroup)
await this.retrainModelFromScratch(storage, ownerGroup, cutoffDate) } else {
console.log("loaded existing spam classification model from database")
this.classifiers.set(ownerGroup, classifier)
await this.updateAndSaveModel(this.cacheStorage, ownerGroup)
} }
}, FULL_RETRAINING_INTERVAL)
if (isNotNull(loadedModel)) {
console.log("Loaded existing spam classification model from database")
this.classifier.set(ownerGroup, { model: loadedModel, isEnabled: true })
await this.updateAndSaveModel(storage, ownerGroup)
setInterval(async () => { setInterval(async () => {
await this.updateAndSaveModel(storage, ownerGroup) await this.updateAndSaveModel(this.cacheStorage, ownerGroup)
}, TRAINING_INTERVAL)
} else {
console.log("no existing model found. Training from scratch ...")
await this.trainFromScratch(this.cacheStorage, ownerGroup)
setInterval(async () => {
await this.updateAndSaveModel(this.cacheStorage, ownerGroup)
}, TRAINING_INTERVAL) }, TRAINING_INTERVAL)
return
}
console.log("No existing model found. Training from scratch...")
await this.trainFromScratch(storage, ownerGroup)
setInterval(async () => {
await this.updateAndSaveModel(storage, ownerGroup)
}, TRAINING_INTERVAL)
}
private async trainFromScratch(storage: CacheStorage, ownerGroup: string) {
const data = await this.initializer.init(ownerGroup)
if (data.length === 0) {
console.log("No training data found. Training from scratch aborted.")
return
}
await this.initialTraining(data)
await this.saveModel(ownerGroup)
await storage.setLastTrainedFromScratchTime(Date.now())
await storage.setLastTrainedTime(Date.now())
}
// VisibleForTesting
public async updateAndSaveModel(storage: CacheStorage, ownerGroup: Id) {
const isModelUpdated = await this.updateModelFromCutoff(await storage.getLastTrainedTime(), ownerGroup)
if (isModelUpdated) {
await this.saveModel(ownerGroup)
await storage.setLastTrainedTime(Date.now())
} }
} }
// visibleForTesting // visibleForTesting
public preprocessMail(mail: SpamTrainMailDatum | SpamPredMailDatum): string { public async updateAndSaveModel(storage: CacheStorage, ownerGroup: Id) {
const mailText = this.concatSubjectAndBody(mail) const isModelUpdated = await this.updateModelFromIndexStartId(await storage.getLastTrainingDataIndexId(), ownerGroup)
if (isModelUpdated) {
if (!this.preprocessConfiguration.isPreprocessMails) { console.log(`Model updated successfully at ${Date.now()}`)
return mailText
} }
let preprocessedMail = mailText
// 1. Remove HTML code
if (this.preprocessConfiguration.isRemoveHTML) {
preprocessedMail = htmlToText(preprocessedMail)
}
// 2. Replace dates
if (this.preprocessConfiguration.isReplaceDates) {
for (const datePattern of ML_DATE_REGEX) {
preprocessedMail = preprocessedMail.replaceAll(datePattern, ML_DATE_TOKEN)
}
}
// 3. Replace urls
if (this.preprocessConfiguration.isReplaceUrls) {
preprocessedMail = preprocessedMail.replaceAll(ML_URL_REGEX, ML_URL_TOKEN)
}
// 4. Replace email addresses
if (this.preprocessConfiguration.isReplaceMailAddresses) {
preprocessedMail = preprocessedMail.replaceAll(ML_EMAIL_ADDR_REGEX, ML_EMAIL_ADDR_TOKEN)
}
// 5. Replace Bitcoin addresses
if (this.preprocessConfiguration.isReplaceBitcoinAddress) {
preprocessedMail = preprocessedMail.replaceAll(ML_BITCOIN_REGEX, ML_BITCOIN_TOKEN)
}
// 6. Replace credit card numbers
if (this.preprocessConfiguration.isReplaceCreditCards) {
preprocessedMail = preprocessedMail.replaceAll(ML_CREDIT_CARD_REGEX, ML_CREDIT_CARD_TOKEN)
}
// 7. Replace remaining numbers
if (this.preprocessConfiguration.isReplaceNumbers) {
preprocessedMail = preprocessedMail.replaceAll(ML_NUMBER_SEQUENCE_REGEX, ML_NUMBER_SEQUENCE_TOKEN)
}
// 8. Remove special characters
if (this.preprocessConfiguration.isReplaceSpecialCharacters) {
preprocessedMail = preprocessedMail.replaceAll(ML_SPECIAL_CHARACTER_REGEX, ML_SPECIAL_CHARACTER_TOKEN)
}
// 9. Remove spaces at end of lines
if (this.preprocessConfiguration.isRemoveSpaceBeforeNewLine) {
preprocessedMail = preprocessedMail.replaceAll(ML_SPACE_BEFORE_NEW_LINE_REGEX, ML_SPACE_BEFORE_NEW_LINE_TOKEN)
}
preprocessedMail += this.getHeaderFeatures(mail)
return preprocessedMail
} }
private getHeaderFeatures(mail: SpamTrainMailDatum | SpamPredMailDatum): string { public async initialTraining(ownerGroup: Id, trainingDataset: TrainingDataset): Promise<void> {
const { sender, toRecipients, ccRecipients, bccRecipients, authStatus } = mail const { trainingData: clientSpamTrainingData, hamCount, spamCount } = trainingDataset
return `\n${sender}\n${toRecipients}\n${ccRecipients}\n${bccRecipients}\n${authStatus}` const trainingInput = await promiseMap(
} clientSpamTrainingData,
(d) => {
const vector = this.sparseVectorCompressor.binaryToVector(d.vector)
const label = d.spamDecision === SpamDecision.BLACKLIST ? 1 : 0
return { vector, label }
},
{
concurrency: 5,
},
)
const vectors = trainingInput.map((input) => input.vector)
const labels = trainingInput.map((input) => input.label)
public async initialTraining(mails: SpamTrainMailDatum[]): Promise<TrainingPerformance> { const xs = tensor2d(vectors, [trainingInput.length, this.sparseVectorCompressor.dimension], undefined)
const preprocessingStart = performance.now()
const tokenizedMails = await promiseMap(mails, (mail) => spamClassifierTokenizer(this.preprocessMail(mail)))
const preprocessingTime = performance.now() - preprocessingStart
const vectorizationStart = performance.now()
const vectors = await this.vectorizer.transform(tokenizedMails)
const labels = mails.map((mail) => (mail.isSpam ? 1 : 0))
const vectorizationTime = performance.now() - vectorizationStart
const xs = tensor2d(vectors, [vectors.length, this.vectorizer.dimension], undefined)
const ys = tensor1d(labels, undefined) const ys = tensor1d(labels, undefined)
const classifier = this.buildModel(this.vectorizer.dimension) const layersModel = this.buildModel(this.sparseVectorCompressor.dimension)
const trainingStart = performance.now() const trainingStart = performance.now()
await classifier.fit(xs, ys, { await layersModel.fit(xs, ys, {
epochs: 16, epochs: 16,
batchSize: 32, batchSize: 32,
shuffle: !this.deterministic, shuffle: !this.deterministic,
@ -271,80 +143,100 @@ export class SpamClassifier {
// } // }
// }, // },
// }, // },
yieldEvery: 15,
}) })
const trainingTime = performance.now() - trainingStart const trainingTime = performance.now() - trainingStart
// When using the webgl backend we need to manually dispose @tensorflow tensors // when using the webgl backend we need to manually dispose @tensorflow tensors
xs.dispose() xs.dispose()
ys.dispose() ys.dispose()
this.classifier.set(mails[0].ownerGroup, { model: classifier, isEnabled: true }) const threshold = this.calculateThreshold(trainingDataset.hamCount, trainingDataset.spamCount)
const classifier = {
layersModel: layersModel,
isEnabled: true,
hamCount,
spamCount,
threshold,
}
this.classifiers.set(ownerGroup, classifier)
console.log( console.log(
`### Finished Initial Training ### (total trained mails: ${mails.length}, preprocessing time: ${preprocessingTime}, vectorization time: ${vectorizationTime}ms, training time: ${trainingTime})`, `### Finished Initial Spam Classification Model Training ### (total trained mails: ${clientSpamTrainingData.length} (ham:spam ${hamCount}:${spamCount} => threshold:${threshold}), training time: ${trainingTime})`,
) )
return { vectorizationTime, trainingTime }
} }
public async updateModelFromCutoff(cutoffTimestamp: number, ownerGroup: Id): Promise<boolean> { public async updateModelFromIndexStartId(indexStartId: Id, ownerGroup: Id): Promise<boolean> {
try { try {
const modelNotEnabled = this.classifier.get(ownerGroup) === undefined || this.classifier.get(ownerGroup)?.isEnabled === false const modelNotEnabled = this.classifiers.get(ownerGroup) === undefined || this.classifiers.get(ownerGroup)?.isEnabled === false
if (modelNotEnabled) { if (modelNotEnabled) {
console.warn("Client spam classification is not enabled or there were errors during training") console.warn("client spam classification is not enabled or there were errors during training")
return false return false
} }
const newTrainingMails = await assertNotNull(this.offlineStorage).getCertainSpamClassificationTrainingDataAfterCutoff(cutoffTimestamp, ownerGroup) const trainingDataset = await this.initializer.fetchPartialTrainingDataFromIndexStartId(indexStartId, ownerGroup)
if (newTrainingMails.length === 0) { if (isEmpty(trainingDataset.trainingData)) {
console.log("No new training data since last update.") console.log("no new spam classification training data since last update")
return false return false
} }
console.log(`Retraining model with ${newTrainingMails.length} new mails (lastModified > ${new Date(cutoffTimestamp).toString()})`)
return await this.updateModel(ownerGroup, newTrainingMails) console.log(
`retraining spam classification model with ${trainingDataset.trainingData.length} new mails (ham:spam ${trainingDataset.hamCount}:${trainingDataset.spamCount}) (lastTrainingDataIndexId > ${indexStartId})`,
)
return await this.updateModel(ownerGroup, trainingDataset)
} catch (e) { } catch (e) {
console.error("Failed trying to update the model: ", e) console.error("failed to update the model", e)
return false return false
} }
} }
// VisibleForTesting // visibleForTesting
async updateModel(ownerGroup: Id, newTrainingMails: SpamTrainMailDatum[]) { async updateModel(ownerGroup: Id, trainingDataset: TrainingDataset): Promise<boolean> {
const retrainingStart = performance.now() const retrainingStart = performance.now()
const modelToUpdate = assertNotNull(this.classifier.get(ownerGroup)) if (isEmpty(trainingDataset.trainingData)) {
const tokenizedMailsArray = await promiseMap(newTrainingMails, async (mail) => { console.log("no new spam classification training data since last update")
const preprocessedMail = this.preprocessMail(mail) return false
const tokenizedMail = spamClassifierTokenizer(preprocessedMail) }
return { tokenizedMail, isSpamConfidence: mail.isSpamConfidence, isSpam: mail.isSpam ? 1 : 0 }
})
const tokenizedMailsByConfidence = groupByAndMap( const modelToUpdate = assertNotNull(this.classifiers.get(ownerGroup))
tokenizedMailsArray, const trainingInput = await promiseMap(
({ isSpamConfidence }) => isSpamConfidence, trainingDataset.trainingData,
({ isSpam, tokenizedMail }) => { (d) => {
return { isSpam, tokenizedMail } const vector = this.sparseVectorCompressor.binaryToVector(d.vector)
const label = d.spamDecision === SpamDecision.BLACKLIST ? 1 : 0
const isSpamConfidence = Number(d.confidence)
return { vector, label, isSpamConfidence }
},
{
concurrency: 5,
}, },
) )
const trainingInputByConfidence = groupByAndMap(
trainingInput,
({ isSpamConfidence }) => isSpamConfidence,
({ vector, label }) => {
return { vector, label }
},
)
modelToUpdate.isEnabled = false modelToUpdate.isEnabled = false
try { try {
for (const [isSpamConfidence, tokenizedMails] of tokenizedMailsByConfidence) { for (const [isSpamConfidence, trainingInput] of trainingInputByConfidence) {
const vectors = await this.vectorizer.transform(tokenizedMails.map(({ tokenizedMail }) => tokenizedMail)) const vectors = trainingInput.map((input) => input.vector)
const xs = tensor2d(vectors, [vectors.length, this.vectorizer.dimension], undefined) const labels = trainingInput.map((input) => input.label)
const ys = tensor1d(
tokenizedMails.map(({ isSpam }) => isSpam), const xs = tensor2d(vectors, [vectors.length, this.sparseVectorCompressor.dimension], "int32")
undefined, const ys = tensor1d(labels, "int32")
)
// We need a way to put weight on a specific mail, ideal way would be to pass sampleWeight to modelFitArgs, // We need a way to put weight on a specific mail, ideal way would be to pass sampleWeight to modelFitArgs,
// but is not yet implemented: https://github.com/tensorflow/tfjs/blob/0fc04d958ea592f3b8db79a8b3b497b5c8904097/tfjs-layers/src/engine/training.ts#L1487 // but is not yet implemented: https://github.com/tensorflow/tfjs/blob/0fc04d958ea592f3b8db79a8b3b497b5c8904097/tfjs-layers/src/engine/training.ts#L1487
// //
// work around: // For now, we use the following workaround:
// current: Re fit the vector of mail multiple times corresponding to `isSpamConfidence` // Re-fit the vector multiple times corresponding to `isSpamConfidence`
// tried approaches:
// 1) Increasing value in vectorizer by isSpamConfidence instead of 1
// 2) duplicating the emails with higher isSpamConfidence and calling .fit once
const modelFitArgs: ModelFitArgs = { const modelFitArgs: ModelFitArgs = {
epochs: 8, epochs: 8,
batchSize: 32, batchSize: 32,
@ -354,62 +246,51 @@ export class SpamClassifier {
// console.log(`Epoch ${epoch + 1} - Loss: ${logs!.loss.toFixed(4)}`) // console.log(`Epoch ${epoch + 1} - Loss: ${logs!.loss.toFixed(4)}`)
// }, // },
// }, // },
yieldEvery: 15,
} }
for (let i = 0; i <= isSpamConfidence; i++) { for (let i = 0; i <= isSpamConfidence; i++) {
await modelToUpdate.model.fit(xs, ys, modelFitArgs) await modelToUpdate.layersModel.fit(xs, ys, modelFitArgs)
} }
// When using the webgl backend we need to manually dispose @tensorflow tensors // when using the webgl backend we need to manually dispose @tensorflow tensors
xs.dispose() xs.dispose()
ys.dispose() ys.dispose()
} }
} finally { } finally {
modelToUpdate.hamCount += trainingDataset.hamCount
modelToUpdate.spamCount += trainingDataset.spamCount
modelToUpdate.threshold = this.calculateThreshold(modelToUpdate.hamCount, modelToUpdate.spamCount)
modelToUpdate.isEnabled = true modelToUpdate.isEnabled = true
} }
console.log(`Retraining finished. Took: ${performance.now() - retrainingStart}ms`) const trainingMetadata = `Total Ham: ${modelToUpdate.hamCount} Spam: ${modelToUpdate.spamCount} threshold: ${modelToUpdate.threshold}}`
console.log(`retraining spam classification model finished, took: ${performance.now() - retrainingStart}ms ${trainingMetadata}`)
await this.saveModel(ownerGroup)
await this.cacheStorage.setLastTrainingDataIndexId(trainingDataset.lastTrainingDataIndexId)
return true return true
} }
// visibleForTesting // visibleForTesting
public async predict(spamPredMailDatum: SpamPredMailDatum): Promise<Nullable<boolean>> { public async predict(vector: number[], ownerGroup: Id): Promise<Nullable<boolean>> {
const classifier = this.classifier.get(spamPredMailDatum.ownerGroup) const classifier = this.classifiers.get(ownerGroup)
if (classifier == null || !classifier.isEnabled) { if (classifier == null || !classifier.isEnabled) {
return null return null
} }
const preprocessedMail = this.preprocessMail(spamPredMailDatum) const vectors = [vector]
const tokenizedMail = spamClassifierTokenizer(preprocessedMail) const xs = tensor2d(vectors, [vectors.length, this.sparseVectorCompressor.dimension], "int32")
const vectors = await assertNotNull(this.vectorizer).transform([tokenizedMail])
const xs = tensor2d(vectors, [vectors.length, assertNotNull(this.vectorizer).dimension], undefined) const predictionTensor = classifier.layersModel.predict(xs) as Tensor
const predictionTensor = classifier.model.predict(xs) as Tensor
const predictionData = await predictionTensor.data() const predictionData = await predictionTensor.data()
const prediction = predictionData[0] const prediction = predictionData[0]
console.log(`predicted new mail to be with probability ${prediction.toFixed(2)} spam. Owner Group: ${spamPredMailDatum.ownerGroup}`) console.log(`predicted new mail to be with probability ${prediction.toFixed(2)} spam. Owner Group: ${ownerGroup}`)
// When using the webgl backend we need to manually dispose @tensorflow tensors // when using the webgl backend we need to manually dispose @tensorflow tensors
xs.dispose() xs.dispose()
predictionTensor.dispose() predictionTensor.dispose()
return prediction > PREDICTION_THRESHOLD return prediction > classifier.threshold
}
public getSpamClassification(mailId: IdTuple) {
return this.offlineStorage.getSpamClassification(mailId)
}
public updateSpamClassification(mailId: IdTuple, isSpam: boolean, isSpamConfidence: number) {
return this.offlineStorage.updateSpamClassification(mailId, isSpam, isSpamConfidence)
}
public storeSpamClassification(spamTrainMailDatum: SpamTrainMailDatum) {
return this.offlineStorage.storeSpamClassification(spamTrainMailDatum)
}
public deleteSpamClassification(mailId: IdTuple) {
return this.offlineStorage.deleteSpamClassification(mailId)
} }
// visibleForTesting // visibleForTesting
@ -451,30 +332,126 @@ export class SpamClassifier {
} }
public async saveModel(ownerGroup: Id): Promise<void> { public async saveModel(ownerGroup: Id): Promise<void> {
const modelArtifacts = await this.getModelArtifacts(ownerGroup) const spamClassificationModel = await this.getSpamClassificationModel(ownerGroup)
if (modelArtifacts == null) { if (spamClassificationModel == null) {
throw new Error("Model is not available, and therefore can not be saved") throw new Error("spam classification model is not available, and therefore can not be saved")
}
await this.cacheStorage.setSpamClassificationModel(spamClassificationModel)
}
async vectorizeAndCompress(mailDatum: SpamMailDatum) {
return await this.spamMailProcessor.vectorizeAndCompress(mailDatum)
}
async vectorize(mailDatum: SpamMailDatum) {
return await this.spamMailProcessor.vectorize(mailDatum)
}
// visibleForTesting
public async loadClassifier(ownerGroup: Id): Promise<Nullable<Classifier>> {
const spamClassificationModel = await assertNotNull(this.cacheStorage).getSpamClassificationModel(ownerGroup)
if (spamClassificationModel) {
const modelTopology = JSON.parse(spamClassificationModel.modelTopology)
const weightSpecs = JSON.parse(spamClassificationModel.weightSpecs)
const weightData = spamClassificationModel.weightData.buffer.slice(
spamClassificationModel.weightData.byteOffset,
spamClassificationModel.weightData.byteOffset + spamClassificationModel.weightData.byteLength,
)
const modelArtifacts = { modelTopology, weightSpecs, weightData }
const layersModel = await loadLayersModelFromIOHandler(fromMemory(modelArtifacts), undefined, undefined)
layersModel.compile({
optimizer: "adam",
loss: "binaryCrossentropy",
metrics: ["accuracy"],
})
const threshold = this.calculateThreshold(spamClassificationModel.hamCount, spamClassificationModel.spamCount)
return {
isEnabled: true,
layersModel: layersModel,
threshold,
hamCount: spamClassificationModel.hamCount,
spamCount: spamClassificationModel.spamCount,
}
} else {
console.log("loading the spam classification spamClassificationModel from offline db failed ... ")
return null
}
}
// visibleForTesting
public async cloneClassifier(): Promise<SpamClassifier> {
const newClassifier = new SpamClassifier(this.cacheStorage, this.initializer, this.deterministic)
newClassifier.spamMailProcessor = this.spamMailProcessor
newClassifier.sparseVectorCompressor = this.sparseVectorCompressor
for (const [ownerGroup, { layersModel: _, isEnabled, threshold, hamCount, spamCount }] of this.classifiers) {
const modelArtifacts = assertNotNull(await this.getModelArtifacts(ownerGroup))
const newModel = await loadLayersModelFromIOHandler(fromMemory(modelArtifacts), undefined, undefined)
newModel.compile({
optimizer: "adam",
loss: "binaryCrossentropy",
metrics: ["accuracy"],
})
newClassifier.classifiers.set(ownerGroup, {
layersModel: newModel,
isEnabled,
threshold,
hamCount,
spamCount,
})
}
return newClassifier
}
// visibleForTesting
public addSpamClassifierForOwner(ownerGroup: Id, classifier: Classifier) {
this.classifiers.set(ownerGroup, classifier)
}
private async trainFromScratch(storage: CacheStorage, ownerGroup: string) {
const trainingDataset = await this.initializer.fetchAllTrainingData(ownerGroup)
const { trainingData, lastTrainingDataIndexId } = trainingDataset
if (isEmpty(trainingData)) {
console.log("No training trainingData found. Training from scratch aborted.")
return
}
await this.initialTraining(ownerGroup, trainingDataset)
await this.saveModel(ownerGroup)
await storage.setLastTrainedFromScratchTime(Date.now())
await storage.setLastTrainingDataIndexId(lastTrainingDataIndexId)
}
private async getSpamClassificationModel(ownerGroup: Id): Promise<SpamClassificationModel | null> {
const classifier = this.classifiers.get(ownerGroup)
if (!classifier) {
return null
}
const modelArtifacts = await this.getModelArtifacts(ownerGroup)
if (!modelArtifacts) {
return null
}
const modelTopology = JSON.stringify(modelArtifacts.modelTopology)
const weightSpecs = JSON.stringify(modelArtifacts.weightSpecs)
const weightData = new Uint8Array(modelArtifacts.weightData as ArrayBuffer)
return {
modelTopology,
weightSpecs,
weightData,
ownerGroup,
hamCount: classifier.hamCount,
spamCount: classifier.spamCount,
} }
await assertNotNull(this.offlineStorage).putSpamClassificationModel(modelArtifacts.spamClassificationModel)
} }
private async getModelArtifacts(ownerGroup: Id) { private async getModelArtifacts(ownerGroup: Id) {
const classifier = this.classifier.get(ownerGroup)?.model ?? null const classifier = this.classifiers.get(ownerGroup)
if (!classifier) return null if (!classifier) {
const spamClassificationModel = defer<SpamClassificationModel>() return null
const modelArtificats = new Promise<ModelArtifacts>((resolve) => { }
return await new Promise<ModelArtifacts>((resolve) => {
const saveInfo = withSaveHandler(async (artifacts: any) => { const saveInfo = withSaveHandler(async (artifacts: any) => {
resolve(artifacts) resolve(artifacts)
const modelTopology = JSON.stringify(artifacts.modelTopology)
const weightSpecs = JSON.stringify(artifacts.weightSpecs)
const weightData = new Uint8Array(artifacts.weightData as ArrayBuffer)
spamClassificationModel.resolve({
modelTopology,
weightSpecs,
weightData,
ownerGroup,
})
return { return {
modelArtifactsInfo: { modelArtifactsInfo: {
dateSaved: new Date(), dateSaved: new Date(),
@ -482,80 +459,7 @@ export class SpamClassifier {
}, },
} }
}) })
classifier.save(saveInfo, undefined) classifier.layersModel.save(saveInfo, undefined)
}) })
return {
modelArtifacts: await modelArtificats,
spamClassificationModel: await spamClassificationModel.promise,
}
}
// visibleForTesting
public async loadModel(ownerGroup: Id): Promise<Nullable<LayersModel>> {
const model = await assertNotNull(this.offlineStorage).getSpamClassificationModel(ownerGroup)
if (model) {
const modelTopology = JSON.parse(model.modelTopology)
const weightSpecs = JSON.parse(model.weightSpecs)
const weightData = model.weightData.buffer.slice(model.weightData.byteOffset, model.weightData.byteOffset + model.weightData.byteLength)
const classifier = await loadLayersModelFromIOHandler(fromMemory(modelTopology, weightSpecs, weightData, undefined), undefined)
classifier.compile({
optimizer: "adam",
loss: "binaryCrossentropy",
metrics: ["accuracy"],
})
return classifier
} else {
console.error("Loading the model from offline db failed")
return null
}
}
private concatSubjectAndBody(mail: SpamTrainMailDatum | SpamPredMailDatum) {
const subject = mail.subject || ""
const body = mail.body || ""
const concatenated = `${subject}\n${body}`.trim()
return concatenated.length > 0 ? concatenated : " "
}
private async retrainModelFromScratch(storage: CacheStorage, ownerGroup: Id, cutoffTimestamp: number) {
console.log("Model is being re-trained from scratch, deleting old data")
try {
await assertNotNull(this.offlineStorage).deleteSpamClassificationTrainingDataBeforeCutoff(cutoffTimestamp, ownerGroup)
} catch (e) {
console.error("Failed delete old training data: ", e)
return
}
await this.trainFromScratch(storage, ownerGroup)
}
// visibleForTesting
public async cloneClassifier(): Promise<SpamClassifier> {
const newClassifier = new SpamClassifier(
this.offlineStorage,
this.offlineStorageCache,
this.initializer,
this.deterministic,
this.preprocessConfiguration,
)
for (const [ownerGroup, { model: _, isEnabled }] of this.classifier) {
const { modelArtifacts } = assertNotNull(await this.getModelArtifacts(ownerGroup))
const newModel = await loadLayersModelFromIOHandler(fromMemory(modelArtifacts, undefined, undefined, undefined), undefined)
newModel.compile({
optimizer: "adam",
loss: "binaryCrossentropy",
metrics: ["accuracy"],
})
newClassifier.classifier.set(ownerGroup, { model: newModel, isEnabled })
}
return newClassifier
}
// visibleForTesting
public addSpamClassifierForOwner(ownerGroup: Id, model: LayersModel, isEnabled: boolean) {
this.classifier.set(ownerGroup, { model, isEnabled })
} }
} }

View file

@ -11,6 +11,7 @@ import { glorotUniform } from "@tensorflow/tfjs-layers/dist/exports_initializers
// Core tensor ops // Core tensor ops
import { tensor2d } from "@tensorflow/tfjs-core" import { tensor2d } from "@tensorflow/tfjs-core"
import { tensor1d } from "@tensorflow/tfjs-core" import { tensor1d } from "@tensorflow/tfjs-core"
import { enableProdMode } from "@tensorflow/tfjs-core"
import { stringToHashBucketFast } from "@tensorflow/tfjs-core/dist/ops/string/string_to_hash_bucket_fast" import { stringToHashBucketFast } from "@tensorflow/tfjs-core/dist/ops/string/string_to_hash_bucket_fast"
// IO handlers // IO handlers
@ -30,4 +31,5 @@ export {
withSaveHandler, withSaveHandler,
fromMemory, fromMemory,
stringToHashBucketFast, stringToHashBucketFast,
enableProdMode,
} }

View file

@ -93,7 +93,7 @@ export interface WorkerInterface {
readonly bulkMailLoader: BulkMailLoader readonly bulkMailLoader: BulkMailLoader
readonly applicationTypesFacade: ApplicationTypesFacade readonly applicationTypesFacade: ApplicationTypesFacade
readonly identityKeyCreator: IdentityKeyCreator readonly identityKeyCreator: IdentityKeyCreator
readonly spamClassifier: SpamClassifier | null readonly spamClassifier: SpamClassifier
readonly autosaveFacade: AutosaveFacade readonly autosaveFacade: AutosaveFacade
} }

View file

@ -112,9 +112,10 @@ import { PublicKeySignatureFacade } from "../../../common/api/worker/facades/Pub
import { AdminKeyLoaderFacade } from "../../../common/api/worker/facades/AdminKeyLoaderFacade" import { AdminKeyLoaderFacade } from "../../../common/api/worker/facades/AdminKeyLoaderFacade"
import { IdentityKeyCreator } from "../../../common/api/worker/facades/lazy/IdentityKeyCreator" import { IdentityKeyCreator } from "../../../common/api/worker/facades/lazy/IdentityKeyCreator"
import { PublicIdentityKeyProvider } from "../../../common/api/worker/facades/PublicIdentityKeyProvider" import { PublicIdentityKeyProvider } from "../../../common/api/worker/facades/PublicIdentityKeyProvider"
import type { SpamClassifier } from "../spamClassification/SpamClassifier" import { SpamClassifier } from "../spamClassification/SpamClassifier"
import { IdentityKeyTrustDatabase } from "../../../common/api/worker/facades/IdentityKeyTrustDatabase" import { IdentityKeyTrustDatabase } from "../../../common/api/worker/facades/IdentityKeyTrustDatabase"
import { AutosaveFacade } from "../../../common/api/worker/facades/lazy/AutosaveFacade" import { AutosaveFacade } from "../../../common/api/worker/facades/lazy/AutosaveFacade"
import { SpamClassificationDataDealer } from "../spamClassification/SpamClassificationDataDealer"
assertWorkerOrNode() assertWorkerOrNode()
@ -197,7 +198,7 @@ export type WorkerLocatorType = {
contactFacade: lazyAsync<ContactFacade> contactFacade: lazyAsync<ContactFacade>
//spam classification //spam classification
spamClassifier: SpamClassifier | null spamClassifier: SpamClassifier
} }
export const locator: WorkerLocatorType = {} as any export const locator: WorkerLocatorType = {} as any
@ -328,14 +329,8 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
if (isOfflineStorageAvailable() && !isAdminClient()) { if (isOfflineStorageAvailable() && !isAdminClient()) {
locator.sqlCipherFacade = new SqlCipherFacadeSendDispatcher(locator.native) locator.sqlCipherFacade = new SqlCipherFacadeSendDispatcher(locator.native)
offlineStorageProvider = async () => { offlineStorageProvider = async () => {
const { SpamClassifier } = await import("../spamClassification/SpamClassifier")
const { SpamClassificationInitializer } = await import("../spamClassification/SpamClassificationInitializer")
const offlineStorage = await offlineStorageIndexerPersistence()
const spamClassifierInitializer = new SpamClassificationInitializer(locator.cachingEntityClient, offlineStorage, locator.bulkMailLoader)
locator.spamClassifier = new SpamClassifier(offlineStorage, locator.cacheStorage, spamClassifierInitializer)
const { KeyVerificationTableDefinitions } = await import("../../../common/api/worker/facades/IdentityKeyTrustDatabase.js") const { KeyVerificationTableDefinitions } = await import("../../../common/api/worker/facades/IdentityKeyTrustDatabase.js")
const { SearchTableDefinitions, SpamClassificationDefinitions } = await import("../index/OfflineStoragePersistence.js") const { SearchTableDefinitions } = await import("../index/OfflineStoragePersistence.js")
const { AutosaveDraftsTableDefinitions } = await import("../../../common/api/worker/facades/lazy/OfflineStorageAutosaveFacade.js") const { AutosaveDraftsTableDefinitions } = await import("../../../common/api/worker/facades/lazy/OfflineStorageAutosaveFacade.js")
const customCacheHandler = new CustomCacheHandlerMap( const customCacheHandler = new CustomCacheHandlerMap(
@ -358,12 +353,11 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
locator.instancePipeline.modelMapper, locator.instancePipeline.modelMapper,
typeModelResolver, typeModelResolver,
customCacheHandler, customCacheHandler,
Object.assign({}, KeyVerificationTableDefinitions, SearchTableDefinitions, AutosaveDraftsTableDefinitions, SpamClassificationDefinitions), Object.assign({}, KeyVerificationTableDefinitions, SearchTableDefinitions, AutosaveDraftsTableDefinitions),
) )
} }
} else { } else {
offlineStorageProvider = async () => null offlineStorageProvider = async () => null
locator.spamClassifier = null
} }
const ephemeralStorageProvider = async () => { const ephemeralStorageProvider = async () => {
const customCacheHandler = new CustomCacheHandlerMap({ const customCacheHandler = new CustomCacheHandlerMap({
@ -385,19 +379,18 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
const { PdfWriter } = await import("../../../common/api/worker/pdf/PdfWriter.js") const { PdfWriter } = await import("../../../common/api/worker/pdf/PdfWriter.js")
return new PdfWriter(new TextEncoder(), undefined) return new PdfWriter(new TextEncoder(), undefined)
} }
locator.patchMerger = new PatchMerger(locator.cacheStorage, locator.instancePipeline, typeModelResolver, () => locator.crypto) locator.patchMerger = new PatchMerger(locator.cacheStorage, locator.instancePipeline, typeModelResolver, () => locator.crypto)
// We don't want to cache within the admin client // We don't want to cache within the admin client
let cache: DefaultEntityRestCache | null = null let cache: DefaultEntityRestCache | null = null
if (!isAdminClient()) { if (!isAdminClient()) {
cache = new DefaultEntityRestCache(entityRestClient, maybeUninitializedStorage, typeModelResolver, locator.patchMerger) cache = new DefaultEntityRestCache(entityRestClient, maybeUninitializedStorage, typeModelResolver, locator.patchMerger)
} }
locator.cache = cache ?? entityRestClient locator.cache = cache ?? entityRestClient
locator.cachingEntityClient = new EntityClient(locator.cache, typeModelResolver) locator.cachingEntityClient = new EntityClient(locator.cache, typeModelResolver)
const nonCachingEntityClient = new EntityClient(entityRestClient, typeModelResolver) const nonCachingEntityClient = new EntityClient(entityRestClient, typeModelResolver)
locator.cacheManagement = lazyMemoized(async () => { locator.cacheManagement = lazyMemoized(async () => {
const { CacheManagementFacade } = await import("../../../common/api/worker/facades/lazy/CacheManagementFacade.js") const { CacheManagementFacade } = await import("../../../common/api/worker/facades/lazy/CacheManagementFacade.js")
return new CacheManagementFacade(locator.user, locator.cachingEntityClient, assertNotNull(cache)) return new CacheManagementFacade(locator.user, locator.cachingEntityClient, assertNotNull(cache))
@ -607,7 +600,7 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
if (!isTest() && sessionType !== SessionType.Temporary && !isAdminClient()) { if (!isTest() && sessionType !== SessionType.Temporary && !isAdminClient()) {
// index new items in background // index new items in background
console.log("initIndexer and SpamClassifier after log in") console.log("initIndexer and SpamClassifier after log in")
const indexingDone = fullLoginIndexerInit(worker) fullLoginIndexerInit(worker)
} }
return mainInterface.loginListener.onFullLoginSuccess(sessionType, cacheInfo, credentials) return mainInterface.loginListener.onFullLoginSuccess(sessionType, cacheInfo, credentials)
@ -737,6 +730,7 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
locator.user, locator.user,
locator.cachingEntityClient, locator.cachingEntityClient,
locator.crypto, locator.crypto,
locator.cryptoWrapper,
locator.serviceExecutor, locator.serviceExecutor,
await locator.blob(), await locator.blob(),
fileApp, fileApp,
@ -745,6 +739,10 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
locator.publicEncryptionKeyProvider, locator.publicEncryptionKeyProvider,
) )
}) })
const spamClassificationDataDealer = new SpamClassificationDataDealer(locator.cachingEntityClient, locator.bulkMailLoader, locator.mail)
locator.spamClassifier = new SpamClassifier(locator.cacheStorage, spamClassificationDataDealer)
const nativePushFacade = new NativePushFacadeSendDispatcher(worker) const nativePushFacade = new NativePushFacadeSendDispatcher(worker)
locator.calendar = lazyMemoized(async () => { locator.calendar = lazyMemoized(async () => {
const { CalendarFacade } = await import("../../../common/api/worker/facades/lazy/CalendarFacade.js") const { CalendarFacade } = await import("../../../common/api/worker/facades/lazy/CalendarFacade.js")

View file

@ -83,6 +83,7 @@ import "./serviceworker/SwTest.js"
import "./api/worker/facades/KeyVerificationFacadeTest.js" import "./api/worker/facades/KeyVerificationFacadeTest.js"
import "./api/worker/utils/SleepDetectorTest.js" import "./api/worker/utils/SleepDetectorTest.js"
import "./api/worker/utils/spamClassification/HashingVectorizerTest.js" import "./api/worker/utils/spamClassification/HashingVectorizerTest.js"
import "./api/worker/utils/spamClassification/SpamClassificationDataDealerTest.js"
import "./api/worker/utils/spamClassification/PreprocessPatternsTest.js" import "./api/worker/utils/spamClassification/PreprocessPatternsTest.js"
import "./calendar/AlarmSchedulerTest.js" import "./calendar/AlarmSchedulerTest.js"
import "./calendar/CalendarAgendaViewTest.js" import "./calendar/CalendarAgendaViewTest.js"
@ -115,6 +116,7 @@ import "./gui/base/WizardDialogNTest.js"
import "./login/LoginViewModelTest.js" import "./login/LoginViewModelTest.js"
import "./login/PostLoginUtilsTest.js" import "./login/PostLoginUtilsTest.js"
import "./mail/InboxRuleHandlerTest.js" import "./mail/InboxRuleHandlerTest.js"
import "./mail/ProcessInboxHandlerTest.js"
import "./mail/KnowledgeBaseSearchFilterTest.js" import "./mail/KnowledgeBaseSearchFilterTest.js"
import "./mail/MailModelTest.js" import "./mail/MailModelTest.js"
import "./mail/MailUtilsSignatureTest.js" import "./mail/MailUtilsSignatureTest.js"
@ -211,6 +213,7 @@ async function setupSuite({ integration }: { integration?: boolean }) {
if (typeof process !== "undefined") { if (typeof process !== "undefined") {
// setup the Entropy for all testcases // setup the Entropy for all testcases
await random.addEntropy([{ data: 36, entropy: 256, source: "key" }]) await random.addEntropy([{ data: 36, entropy: 256, source: "key" }])
await import("./api/worker/utils/spamClassification/SparseVectorCompressorTest.js")
await import("./api/worker/utils/spamClassification/SpamClassifierTest.js") await import("./api/worker/utils/spamClassification/SpamClassifierTest.js")
await import("./api/worker/offline/OfflineStorageMigratorTest.js") await import("./api/worker/offline/OfflineStorageMigratorTest.js")
await import("./api/worker/offline/OfflineStorageTest.js") await import("./api/worker/offline/OfflineStorageTest.js")

View file

@ -578,7 +578,10 @@ o.spec("CryptoFacadeTest", function () {
senderIdentityKeyPair.publicKey, senderIdentityKeyPair.publicKey,
senderKeyVersion, senderKeyVersion,
), ),
).thenResolve({ authStatus: EncryptionAuthStatus.TUTACRYPT_AUTHENTICATION_SUCCEEDED, verificationState: PresentableKeyVerificationState.SECURE }) ).thenResolve({
authStatus: EncryptionAuthStatus.TUTACRYPT_AUTHENTICATION_SUCCEEDED,
verificationState: PresentableKeyVerificationState.SECURE,
})
const sessionKey = neverNull(await crypto.resolveSessionKey(mail)) const sessionKey = neverNull(await crypto.resolveSessionKey(mail))
@ -614,7 +617,10 @@ o.spec("CryptoFacadeTest", function () {
testData.senderIdentityKeyPair.publicKey, testData.senderIdentityKeyPair.publicKey,
anything(), anything(),
), ),
).thenResolve({ authStatus: EncryptionAuthStatus.TUTACRYPT_AUTHENTICATION_SUCCEEDED, verificationState: PresentableKeyVerificationState.SECURE }) ).thenResolve({
authStatus: EncryptionAuthStatus.TUTACRYPT_AUTHENTICATION_SUCCEEDED,
verificationState: PresentableKeyVerificationState.SECURE,
})
await crypto.enforceSessionKeyUpdateIfNeeded(testData.mail, files) await crypto.enforceSessionKeyUpdateIfNeeded(testData.mail, files)
verify(ownerEncSessionKeysUpdateQueue.postUpdateSessionKeysService(anything()), { times: 1 }) verify(ownerEncSessionKeysUpdateQueue.postUpdateSessionKeysService(anything()), { times: 1 })
@ -962,7 +968,10 @@ o.spec("CryptoFacadeTest", function () {
testData.senderIdentityKeyPair.publicKey, testData.senderIdentityKeyPair.publicKey,
parseKeyVersion(senderKeyVersion), parseKeyVersion(senderKeyVersion),
), ),
).thenResolve({ authStatus: EncryptionAuthStatus.TUTACRYPT_AUTHENTICATION_SUCCEEDED, verificationState: PresentableKeyVerificationState.SECURE }) ).thenResolve({
authStatus: EncryptionAuthStatus.TUTACRYPT_AUTHENTICATION_SUCCEEDED,
verificationState: PresentableKeyVerificationState.SECURE,
})
const sessionKey: AesKey = neverNull(await crypto.resolveSessionKey(testData.mail)) const sessionKey: AesKey = neverNull(await crypto.resolveSessionKey(testData.mail))
@ -995,7 +1004,10 @@ o.spec("CryptoFacadeTest", function () {
testData.senderIdentityKeyPair.publicKey, testData.senderIdentityKeyPair.publicKey,
parseKeyVersion(senderKeyVersion), parseKeyVersion(senderKeyVersion),
), ),
).thenResolve({ authStatus: EncryptionAuthStatus.TUTACRYPT_AUTHENTICATION_SUCCEEDED, verificationState: PresentableKeyVerificationState.SECURE }) ).thenResolve({
authStatus: EncryptionAuthStatus.TUTACRYPT_AUTHENTICATION_SUCCEEDED,
verificationState: PresentableKeyVerificationState.SECURE,
})
const sessionKey: AesKey = neverNull(await crypto.resolveSessionKey(testData.mail)) const sessionKey: AesKey = neverNull(await crypto.resolveSessionKey(testData.mail))
@ -1029,7 +1041,10 @@ o.spec("CryptoFacadeTest", function () {
testData.senderIdentityKeyPair.publicKey, testData.senderIdentityKeyPair.publicKey,
parseKeyVersion(senderKeyVersion), parseKeyVersion(senderKeyVersion),
), ),
).thenResolve({ authStatus: EncryptionAuthStatus.TUTACRYPT_AUTHENTICATION_FAILED, verificationState: PresentableKeyVerificationState.ALERT }) ).thenResolve({
authStatus: EncryptionAuthStatus.TUTACRYPT_AUTHENTICATION_FAILED,
verificationState: PresentableKeyVerificationState.ALERT,
})
const sessionKey = neverNull(await crypto.resolveSessionKey(testData.mail)) const sessionKey = neverNull(await crypto.resolveSessionKey(testData.mail))
@ -1258,7 +1273,10 @@ o.spec("CryptoFacadeTest", function () {
anything(), anything(),
anything(), anything(),
), ),
).thenResolve({ authStatus: EncryptionAuthStatus.TUTACRYPT_AUTHENTICATION_SUCCEEDED, verificationState: PresentableKeyVerificationState.SECURE }) ).thenResolve({
authStatus: EncryptionAuthStatus.TUTACRYPT_AUTHENTICATION_SUCCEEDED,
verificationState: PresentableKeyVerificationState.SECURE,
})
const sessionKey = neverNull(await crypto.resolveSessionKey(testData.mail)) const sessionKey = neverNull(await crypto.resolveSessionKey(testData.mail))
@ -1281,7 +1299,10 @@ o.spec("CryptoFacadeTest", function () {
anything(), anything(),
anything(), anything(),
), ),
).thenResolve({ authStatus: EncryptionAuthStatus.TUTACRYPT_AUTHENTICATION_SUCCEEDED, verificationState: PresentableKeyVerificationState.SECURE }) ).thenResolve({
authStatus: EncryptionAuthStatus.TUTACRYPT_AUTHENTICATION_SUCCEEDED,
verificationState: PresentableKeyVerificationState.SECURE,
})
// do not use testdouble here because it's hard to not break the function itself and then verify invocations // do not use testdouble here because it's hard to not break the function itself and then verify invocations
const decryptAndMapToInstance = (instancePipeline.cryptoMapper.decryptParsedInstance = spy(instancePipeline.cryptoMapper.decryptParsedInstance)) const decryptAndMapToInstance = (instancePipeline.cryptoMapper.decryptParsedInstance = spy(instancePipeline.cryptoMapper.decryptParsedInstance))
@ -1310,7 +1331,10 @@ o.spec("CryptoFacadeTest", function () {
anything(), anything(),
anything(), anything(),
), ),
).thenResolve({ authStatus: EncryptionAuthStatus.TUTACRYPT_AUTHENTICATION_SUCCEEDED, verificationState: PresentableKeyVerificationState.SECURE }) ).thenResolve({
authStatus: EncryptionAuthStatus.TUTACRYPT_AUTHENTICATION_SUCCEEDED,
verificationState: PresentableKeyVerificationState.SECURE,
})
const mailSessionKey = neverNull(await crypto.resolveSessionKey(testData.mail)) const mailSessionKey = neverNull(await crypto.resolveSessionKey(testData.mail))
const bucketKey = assertNotNull(testData.mail.bucketKey) const bucketKey = assertNotNull(testData.mail.bucketKey)
@ -1860,6 +1884,7 @@ o.spec("CryptoFacadeTest", function () {
keyVerificationState: null, keyVerificationState: null,
processingState: ProcessingState.INBOX_RULE_APPLIED, processingState: ProcessingState.INBOX_RULE_APPLIED,
clientSpamClassifierResult: null, clientSpamClassifierResult: null,
processNeeded: false,
}) })
// casting here is fine, since we just want to mimic server response data // casting here is fine, since we just want to mimic server response data

View file

@ -40,7 +40,7 @@ import { UnreadMailStateService } from "../../../../../src/common/api/entities/t
import { BucketKeyTypeRef, InstanceSessionKey, InstanceSessionKeyTypeRef } from "../../../../../src/common/api/entities/sys/TypeRefs" import { BucketKeyTypeRef, InstanceSessionKey, InstanceSessionKeyTypeRef } from "../../../../../src/common/api/entities/sys/TypeRefs"
import { OwnerEncSessionKeyProvider } from "../../../../../src/common/api/worker/rest/EntityRestClient" import { OwnerEncSessionKeyProvider } from "../../../../../src/common/api/worker/rest/EntityRestClient"
import { elementIdPart, getElementId } from "../../../../../src/common/api/common/utils/EntityUtils" import { elementIdPart, getElementId } from "../../../../../src/common/api/common/utils/EntityUtils"
import { VersionedEncryptedKey } from "../../../../../src/common/api/worker/crypto/CryptoWrapper" import { CryptoWrapper, VersionedEncryptedKey } from "../../../../../src/common/api/worker/crypto/CryptoWrapper"
import { Recipient } from "../../../../../src/common/api/common/recipients/Recipient" import { Recipient } from "../../../../../src/common/api/common/recipients/Recipient"
import { AesKey } from "@tutao/tutanota-crypto" import { AesKey } from "@tutao/tutanota-crypto"
import { RecipientsNotFoundError } from "../../../../../src/common/api/common/error/RecipientsNotFoundError" import { RecipientsNotFoundError } from "../../../../../src/common/api/common/error/RecipientsNotFoundError"
@ -52,6 +52,7 @@ o.spec("MailFacade test", function () {
let facade: MailFacade let facade: MailFacade
let userFacade: UserFacade let userFacade: UserFacade
let cryptoFacade: CryptoFacade let cryptoFacade: CryptoFacade
let cryptoWrapper: CryptoWrapper
let serviceExecutor: IServiceExecutor let serviceExecutor: IServiceExecutor
let entityClient: EntityClient let entityClient: EntityClient
let blobFacade: BlobFacade let blobFacade: BlobFacade
@ -67,6 +68,7 @@ o.spec("MailFacade test", function () {
blobFacade = object() blobFacade = object()
entityClient = object() entityClient = object()
cryptoFacade = object() cryptoFacade = object()
cryptoWrapper = object()
serviceExecutor = object() serviceExecutor = object()
fileApp = object() fileApp = object()
loginFacade = object() loginFacade = object()
@ -76,6 +78,7 @@ o.spec("MailFacade test", function () {
userFacade, userFacade,
entityClient, entityClient,
cryptoFacade, cryptoFacade,
cryptoWrapper,
serviceExecutor, serviceExecutor,
blobFacade, blobFacade,
fileApp, fileApp,

View file

@ -1,7 +1,7 @@
import o from "@tutao/otest" import o from "@tutao/otest"
import { HashingVectorizer } from "../../../../../../src/mail-app/workerUtils/spamClassification/HashingVectorizer" import { HashingVectorizer } from "../../../../../../src/mail-app/workerUtils/spamClassification/HashingVectorizer"
import { arrayEquals } from "@tutao/tutanota-utils" import { arrayEquals } from "@tutao/tutanota-utils"
import { spamClassifierTokenizer } from "../../../../../../src/mail-app/workerUtils/spamClassification/SpamClassifier" import { spamClassifierTokenizer } from "../../../../../../src/common/api/common/utils/spamClassificationUtils/SpamMailProcessor"
o.spec("HashingVectorizer", () => { o.spec("HashingVectorizer", () => {
const rawDocuments = [ const rawDocuments = [

View file

@ -14,7 +14,7 @@ import {
ML_SPECIAL_CHARACTER_TOKEN, ML_SPECIAL_CHARACTER_TOKEN,
ML_URL_REGEX, ML_URL_REGEX,
ML_URL_TOKEN, ML_URL_TOKEN,
} from "../../../../../../src/mail-app/workerUtils/spamClassification/PreprocessPatterns" } from "../../../../../../src/common/api/common/utils/spamClassificationUtils/PreprocessPatterns"
import { isMailAddress } from "../../../../../../src/common/misc/FormatValidator" import { isMailAddress } from "../../../../../../src/common/misc/FormatValidator"
o.spec("PreprocessPatterns", () => { o.spec("PreprocessPatterns", () => {

View file

@ -0,0 +1,374 @@
import o from "@tutao/otest"
import {
SINGLE_TRAIN_INTERVAL_TRAINING_DATA_LIMIT,
SpamClassificationDataDealer,
UnencryptedPopulateClientSpamTrainingDatum,
} from "../../../../../../src/mail-app/workerUtils/spamClassification/SpamClassificationDataDealer"
import {
ClientSpamTrainingDatum,
ClientSpamTrainingDatumIndexEntryTypeRef,
ClientSpamTrainingDatumTypeRef,
MailBagTypeRef,
MailBox,
MailboxGroupRoot,
MailboxGroupRootTypeRef,
MailBoxTypeRef,
MailDetails,
MailDetailsTypeRef,
MailFolderRefTypeRef,
MailFolderTypeRef,
MailTypeRef,
} from "../../../../../../src/common/api/entities/tutanota/TypeRefs"
import { MailSetKind, SpamDecision } from "../../../../../../src/common/api/common/TutanotaConstants"
import { matchers, object, verify, when } from "testdouble"
import { EntityClient } from "../../../../../../src/common/api/common/EntityClient"
import { BulkMailLoader } from "../../../../../../src/mail-app/workerUtils/index/BulkMailLoader"
import { MailFacade } from "../../../../../../src/common/api/worker/facades/lazy/MailFacade"
import { createTestEntity } from "../../../../TestUtils"
import { GENERATED_MIN_ID, getElementId, isSameId } from "../../../../../../src/common/api/common/utils/EntityUtils"
import { DEFAULT_IS_SPAM_CONFIDENCE } from "../../../../../../src/common/api/common/utils/spamClassificationUtils/SpamMailProcessor"
import { last } from "@tutao/tutanota-utils"
const { anything } = matchers
function createMailByFolderAndReceivedDate(mailId: IdTuple, mailSet: IdTuple, receivedDate: Date, mailDetailsId: Id) {
return createTestEntity(MailTypeRef, {
_id: mailId,
sets: [mailSet],
receivedDate: receivedDate,
mailDetails: ["detailsListId", mailDetailsId],
})
}
function createSpamTrainingDatumByConfidenceAndDecision(confidence: string, spamDecision: SpamDecision): ClientSpamTrainingDatum {
return createTestEntity(ClientSpamTrainingDatumTypeRef, {
_ownerGroup: "group",
confidence,
spamDecision,
vector: new Uint8Array(),
})
}
function createClientSpamTrainingDatumIndexEntryByClientSpamTrainingDatumElementId(clientSpamTrainingDatumElementId: Id) {
return createTestEntity(ClientSpamTrainingDatumIndexEntryTypeRef, { clientSpamTrainingDatumElementId })
}
o.spec("SpamClassificationDataDealer", () => {
const entityClientMock = object<EntityClient>()
const bulkMailLoaderMock = object<BulkMailLoader>()
const mailFacadeMock = object<MailFacade>()
let mailDetails: MailDetails
let spamClassificationDataDealer: SpamClassificationDataDealer
let mailboxGroupRoot: MailboxGroupRoot
let mailBox: MailBox
const inboxFolder = createTestEntity(MailFolderTypeRef, {
_id: ["folderListId", "inbox"],
_ownerGroup: "owner",
folderType: MailSetKind.INBOX,
})
const trashFolder = createTestEntity(MailFolderTypeRef, {
_id: ["folderListId", "trash"],
_ownerGroup: "owner",
folderType: MailSetKind.TRASH,
})
const spamFolder = createTestEntity(MailFolderTypeRef, {
_id: ["folderListId", "spam"],
_ownerGroup: "owner",
folderType: MailSetKind.SPAM,
})
o.beforeEach(function () {
mailboxGroupRoot = createTestEntity(MailboxGroupRootTypeRef, {
_ownerGroup: "owner",
mailbox: "mailbox",
})
mailBox = createTestEntity(MailBoxTypeRef, {
_id: "mailbox",
_ownerGroup: "owner",
folders: createTestEntity(MailFolderRefTypeRef, { folders: "folderListId" }),
currentMailBag: createTestEntity(MailBagTypeRef, { mails: "mailListId" }),
archivedMailBags: [createTestEntity(MailBagTypeRef, { mails: "oldMailListId" })],
clientSpamTrainingData: "clientSpamTrainingData",
modifiedClientSpamTrainingDataIndex: "modifiedClientSpamTrainingDataIndex",
})
mailDetails = createTestEntity(MailDetailsTypeRef, { _id: "mailDetail" })
when(mailFacadeMock.vectorizeAndCompressMails(anything())).thenResolve(new Uint8Array(1))
spamClassificationDataDealer = new SpamClassificationDataDealer(
entityClientMock,
() => Promise.resolve(bulkMailLoaderMock),
() => Promise.resolve(mailFacadeMock),
)
})
o.spec("subsampleHamAndSpamMails", () => {
o("does not subsample if ratio is balanced", () => {
const data = [
createSpamTrainingDatumByConfidenceAndDecision(DEFAULT_IS_SPAM_CONFIDENCE, SpamDecision.WHITELIST),
createSpamTrainingDatumByConfidenceAndDecision(DEFAULT_IS_SPAM_CONFIDENCE, SpamDecision.BLACKLIST),
]
const { subsampledTrainingData, hamCount, spamCount } = spamClassificationDataDealer.subsampleHamAndSpamMails(data)
o(subsampledTrainingData.length).equals(2)
o(hamCount).equals(1)
o(spamCount).equals(1)
})
o("limits ham when ratio > MAX_RATIO", () => {
const hamData = Array.from({ length: 50 }, () => createSpamTrainingDatumByConfidenceAndDecision(DEFAULT_IS_SPAM_CONFIDENCE, SpamDecision.WHITELIST))
const spamData = Array.from({ length: 1 }, () => createSpamTrainingDatumByConfidenceAndDecision(DEFAULT_IS_SPAM_CONFIDENCE, SpamDecision.BLACKLIST))
const { subsampledTrainingData, hamCount, spamCount } = spamClassificationDataDealer.subsampleHamAndSpamMails([...hamData, ...spamData])
o(hamCount).equals(10)
o(spamCount).equals(1)
o(subsampledTrainingData.length).equals(11)
})
o("limits spam when ratio < MIN_RATIO", () => {
const hamData = Array.from({ length: 1 }, () => createSpamTrainingDatumByConfidenceAndDecision(DEFAULT_IS_SPAM_CONFIDENCE, SpamDecision.WHITELIST))
const spamData = Array.from({ length: 50 }, () =>
createSpamTrainingDatumByConfidenceAndDecision(DEFAULT_IS_SPAM_CONFIDENCE, SpamDecision.BLACKLIST),
)
const { subsampledTrainingData, hamCount, spamCount } = spamClassificationDataDealer.subsampleHamAndSpamMails([...hamData, ...spamData])
o(hamCount).equals(1)
o(spamCount).equals(10)
o(subsampledTrainingData.length).equals(11)
})
})
o.spec("fetchAllTrainingData", () => {
o("returns empty training data when index or training data is null", async () => {
mailBox.clientSpamTrainingData = null
mailBox.modifiedClientSpamTrainingDataIndex = null
when(entityClientMock.load(MailboxGroupRootTypeRef, "owner")).thenResolve(mailboxGroupRoot)
when(entityClientMock.load(MailBoxTypeRef, "mailbox")).thenResolve(mailBox)
const trainingDataset = await spamClassificationDataDealer.fetchAllTrainingData("owner")
o(trainingDataset.trainingData.length).equals(0)
o(trainingDataset.hamCount).equals(0)
o(trainingDataset.spamCount).equals(0)
o(trainingDataset.lastTrainingDataIndexId).equals(GENERATED_MIN_ID)
})
o("uploads training data when clientSpamTrainingData is empty", async () => {
when(entityClientMock.load(MailboxGroupRootTypeRef, "owner")).thenResolve(mailboxGroupRoot)
when(entityClientMock.load(MailBoxTypeRef, "mailbox")).thenResolve(mailBox)
const spamTrainingData = Array.from({ length: 10 }, () =>
createSpamTrainingDatumByConfidenceAndDecision(DEFAULT_IS_SPAM_CONFIDENCE, SpamDecision.WHITELIST),
).concat(Array.from({ length: 10 }, () => createSpamTrainingDatumByConfidenceAndDecision(DEFAULT_IS_SPAM_CONFIDENCE, SpamDecision.BLACKLIST)))
const mails = Array.from({ length: 10 }, () =>
createMailByFolderAndReceivedDate([mailBox.currentMailBag!.mails, "inboxMailId"], inboxFolder._id, new Date(), mailDetails._id),
).concat(
Array.from({ length: 10 }, () =>
createMailByFolderAndReceivedDate([mailBox.currentMailBag!.mails, "spamMailId"], spamFolder._id, new Date(), mailDetails._id),
),
)
const modifiedIndicesSinceStart = spamTrainingData.map((data) =>
createClientSpamTrainingDatumIndexEntryByClientSpamTrainingDatumElementId(getElementId(data)),
)
when(entityClientMock.loadAll(ClientSpamTrainingDatumTypeRef, mailBox.clientSpamTrainingData!)).thenResolve([], spamTrainingData)
when(entityClientMock.loadAll(MailTypeRef, mailBox.currentMailBag!.mails, anything())).thenResolve(mails)
when(entityClientMock.loadAll(MailTypeRef, mailBox.archivedMailBags[0].mails, anything())).thenResolve([])
when(entityClientMock.loadAll(MailFolderTypeRef, mailBox.folders!.folders)).thenResolve([inboxFolder, spamFolder, trashFolder])
when(entityClientMock.loadAll(ClientSpamTrainingDatumIndexEntryTypeRef, mailBox.modifiedClientSpamTrainingDataIndex!)).thenResolve(
modifiedIndicesSinceStart,
)
when(bulkMailLoaderMock.loadMailDetails(mails)).thenResolve(
mails.map((mail) => {
return { mail, mailDetails }
}),
)
const trainingDataset = await spamClassificationDataDealer.fetchAllTrainingData("owner")
// first load: empty, second load: fetch uploaded data
verify(entityClientMock.loadAll(ClientSpamTrainingDatumTypeRef, mailBox.clientSpamTrainingData!), { times: 2 })
verify(entityClientMock.loadAll(ClientSpamTrainingDatumIndexEntryTypeRef, mailBox.modifiedClientSpamTrainingDataIndex!), { times: 1 })
const unencryptedPayload = mails.map((mail) => {
return {
mailId: mail._id,
isSpam: isSameId(mail.sets[0], spamFolder._id),
confidence: DEFAULT_IS_SPAM_CONFIDENCE,
vector: new Uint8Array(1),
} as UnencryptedPopulateClientSpamTrainingDatum
})
verify(mailFacadeMock.populateClientSpamTrainingData("owner", unencryptedPayload), { times: 1 })
o(trainingDataset).deepEquals({
trainingData: spamTrainingData,
lastTrainingDataIndexId: getElementId(last(modifiedIndicesSinceStart)!),
hamCount: 10,
spamCount: 10,
})
})
o("successfully returns training data with mixed ham/spam data", async () => {
when(entityClientMock.load(MailboxGroupRootTypeRef, "owner")).thenResolve(mailboxGroupRoot)
when(entityClientMock.load(MailBoxTypeRef, "mailbox")).thenResolve(mailBox)
const spamTrainingData = Array.from({ length: 10 }, () =>
createSpamTrainingDatumByConfidenceAndDecision(DEFAULT_IS_SPAM_CONFIDENCE, SpamDecision.WHITELIST),
).concat(Array.from({ length: 10 }, () => createSpamTrainingDatumByConfidenceAndDecision(DEFAULT_IS_SPAM_CONFIDENCE, SpamDecision.BLACKLIST)))
const modifiedIndicesSinceStart = spamTrainingData.map((data) =>
createClientSpamTrainingDatumIndexEntryByClientSpamTrainingDatumElementId(getElementId(data)),
)
when(entityClientMock.loadAll(ClientSpamTrainingDatumTypeRef, mailBox.clientSpamTrainingData!)).thenResolve(spamTrainingData)
when(entityClientMock.loadAll(MailTypeRef, mailBox.archivedMailBags[0].mails, anything())).thenResolve([])
when(entityClientMock.loadAll(MailFolderTypeRef, mailBox.folders!.folders)).thenResolve([inboxFolder, spamFolder, trashFolder])
when(entityClientMock.loadAll(ClientSpamTrainingDatumIndexEntryTypeRef, mailBox.modifiedClientSpamTrainingDataIndex!)).thenResolve(
modifiedIndicesSinceStart,
)
const trainingDataset = await spamClassificationDataDealer.fetchAllTrainingData("owner")
// only one load as the list is already populated
verify(entityClientMock.loadAll(ClientSpamTrainingDatumTypeRef, mailBox.clientSpamTrainingData!), { times: 1 })
verify(entityClientMock.loadAll(ClientSpamTrainingDatumIndexEntryTypeRef, mailBox.modifiedClientSpamTrainingDataIndex!), { times: 1 })
o(trainingDataset).deepEquals({
trainingData: spamTrainingData,
lastTrainingDataIndexId: getElementId(last(modifiedIndicesSinceStart)!),
hamCount: 10,
spamCount: 10,
})
})
o("filters out training data with confidence=0 or spamDecision NONE", async () => {
const noneDecisionData = createSpamTrainingDatumByConfidenceAndDecision(DEFAULT_IS_SPAM_CONFIDENCE, SpamDecision.NONE)
const zeroConfData = createSpamTrainingDatumByConfidenceAndDecision("0", SpamDecision.WHITELIST)
const validHamData = createSpamTrainingDatumByConfidenceAndDecision("1", SpamDecision.WHITELIST)
const validSpamData = createSpamTrainingDatumByConfidenceAndDecision("4", SpamDecision.BLACKLIST)
when(entityClientMock.load(MailboxGroupRootTypeRef, "owner")).thenResolve(mailboxGroupRoot)
when(entityClientMock.load(MailBoxTypeRef, "mailbox")).thenResolve(mailBox)
const spamTrainingData = [noneDecisionData, zeroConfData, validSpamData, validHamData]
const modifiedIndicesSinceStart = spamTrainingData.map((data) =>
createClientSpamTrainingDatumIndexEntryByClientSpamTrainingDatumElementId(getElementId(data)),
)
when(entityClientMock.loadAll(ClientSpamTrainingDatumTypeRef, mailBox.clientSpamTrainingData!)).thenResolve(spamTrainingData)
when(entityClientMock.loadAll(ClientSpamTrainingDatumIndexEntryTypeRef, mailBox.modifiedClientSpamTrainingDataIndex!)).thenResolve(
modifiedIndicesSinceStart,
)
when(entityClientMock.loadAll(MailFolderTypeRef, mailBox.folders!.folders)).thenResolve([inboxFolder, spamFolder, trashFolder])
const result = await spamClassificationDataDealer.fetchAllTrainingData("owner")
o(result.trainingData.length).equals(2)
o(result.spamCount).equals(1)
o(result.hamCount).equals(1)
o(new Set(result.trainingData)).deepEquals(new Set([validSpamData, validHamData]))
})
})
o.spec("fetchPartialTrainingDataFromIndexStartId", () => {
o("returns empty training data when index or training data is null", async () => {
mailBox.clientSpamTrainingData = null
mailBox.modifiedClientSpamTrainingDataIndex = null
when(entityClientMock.load(MailboxGroupRootTypeRef, "owner")).thenResolve(mailboxGroupRoot)
when(entityClientMock.load(MailBoxTypeRef, "mailbox")).thenResolve(mailBox)
const trainingDataset = await spamClassificationDataDealer.fetchPartialTrainingDataFromIndexStartId("startId", "owner")
o(trainingDataset.trainingData.length).equals(0)
o(trainingDataset.hamCount).equals(0)
o(trainingDataset.spamCount).equals(0)
o(trainingDataset.lastTrainingDataIndexId).equals("startId")
})
o("returns empty training data when modifiedClientSpamTrainingDataIndicesSinceStart are null", async () => {
when(entityClientMock.load(MailboxGroupRootTypeRef, "owner")).thenResolve(mailboxGroupRoot)
when(entityClientMock.load(MailBoxTypeRef, "mailbox")).thenResolve(mailBox)
when(
entityClientMock.loadRange(
ClientSpamTrainingDatumIndexEntryTypeRef,
mailBox.modifiedClientSpamTrainingDataIndex!,
"startId",
SINGLE_TRAIN_INTERVAL_TRAINING_DATA_LIMIT,
false,
),
).thenResolve([])
const trainingDataset = await spamClassificationDataDealer.fetchPartialTrainingDataFromIndexStartId("startId", "owner")
o(trainingDataset.trainingData.length).equals(0)
o(trainingDataset.hamCount).equals(0)
o(trainingDataset.spamCount).equals(0)
o(trainingDataset.lastTrainingDataIndexId).equals("startId")
})
o("returns new training data when index or training data is there", async () => {
when(entityClientMock.load(MailboxGroupRootTypeRef, "owner")).thenResolve(mailboxGroupRoot)
when(entityClientMock.load(MailBoxTypeRef, "mailbox")).thenResolve(mailBox)
const oldSpamTrainingData = Array.from({ length: 50 }, () =>
createSpamTrainingDatumByConfidenceAndDecision(DEFAULT_IS_SPAM_CONFIDENCE, SpamDecision.WHITELIST),
).concat(Array.from({ length: 50 }, () => createSpamTrainingDatumByConfidenceAndDecision(DEFAULT_IS_SPAM_CONFIDENCE, SpamDecision.BLACKLIST)))
oldSpamTrainingData.map((data) => (data._id = [mailBox.clientSpamTrainingData!, GENERATED_MIN_ID]))
const newSpamTrainingData = Array.from({ length: 10 }, () =>
createSpamTrainingDatumByConfidenceAndDecision(DEFAULT_IS_SPAM_CONFIDENCE, SpamDecision.WHITELIST),
).concat(Array.from({ length: 10 }, () => createSpamTrainingDatumByConfidenceAndDecision(DEFAULT_IS_SPAM_CONFIDENCE, SpamDecision.BLACKLIST)))
newSpamTrainingData.map((data) => (data._id = [mailBox.clientSpamTrainingData!, GENERATED_MIN_ID]))
const modifiedIndicesSinceStart = newSpamTrainingData.map((data) =>
createClientSpamTrainingDatumIndexEntryByClientSpamTrainingDatumElementId(getElementId(data)),
)
when(
entityClientMock.loadRange(
ClientSpamTrainingDatumIndexEntryTypeRef,
mailBox.modifiedClientSpamTrainingDataIndex!,
"startId",
anything(),
false,
),
).thenResolve(modifiedIndicesSinceStart)
when(
entityClientMock.loadMultiple(
ClientSpamTrainingDatumTypeRef,
mailBox.clientSpamTrainingData,
modifiedIndicesSinceStart.map((index) => index.clientSpamTrainingDatumElementId),
),
).thenResolve(newSpamTrainingData)
const trainingDataset = await spamClassificationDataDealer.fetchPartialTrainingDataFromIndexStartId("startId", "owner")
o(trainingDataset.trainingData.length).equals(20)
o(trainingDataset.hamCount).equals(10)
o(trainingDataset.spamCount).equals(10)
o(trainingDataset.lastTrainingDataIndexId).equals(getElementId(last(modifiedIndicesSinceStart)!))
})
})
o.spec("fetchMailsByMailbagAfterDate", () => {
o("correctly filters mails with received date greater than start date", async () => {
const startDate = new Date(2020, 11, 30)
const dayBeforeStart = new Date(2020, 11, 29)
const recentMails = Array.from({ length: 10 }, () =>
createMailByFolderAndReceivedDate([mailBox.currentMailBag!.mails, "inboxMailId"], inboxFolder._id, new Date(2025, 11, 17), mailDetails._id),
)
const oldMails = Array.from({ length: 10 }, () =>
createMailByFolderAndReceivedDate([mailBox.currentMailBag!.mails, "inboxMailId"], inboxFolder._id, dayBeforeStart, mailDetails._id),
)
const mails = recentMails.concat(oldMails)
when(entityClientMock.loadAll(MailTypeRef, mailBox.currentMailBag!.mails, anything())).thenResolve(mails)
when(bulkMailLoaderMock.loadMailDetails(recentMails)).thenResolve(
recentMails.map((mail) => {
return { mail, mailDetails }
}),
)
const result = await spamClassificationDataDealer.fetchMailsByMailbagAfterDate(
mailBox.currentMailBag!,
[inboxFolder, spamFolder, trashFolder],
startDate,
)
o(result.length).equals(10)
})
})
})

View file

@ -1,36 +1,41 @@
import o from "@tutao/otest" import o from "@tutao/otest"
import fs from "node:fs" import fs from "node:fs"
import { parseCsv } from "../../../../../../src/common/misc/parsing/CsvParser" import { parseCsv } from "../../../../../../src/common/misc/parsing/CsvParser"
import { import { Classifier, DEFAULT_PREDICTION_THRESHOLD, SpamClassifier } from "../../../../../../src/mail-app/workerUtils/spamClassification/SpamClassifier"
DEFAULT_PREPROCESS_CONFIGURATION,
SpamClassifier,
SpamTrainMailDatum,
} from "../../../../../../src/mail-app/workerUtils/spamClassification/SpamClassifier"
import { OfflineStoragePersistence } from "../../../../../../src/mail-app/workerUtils/index/OfflineStoragePersistence"
import { matchers, object, when } from "testdouble" import { matchers, object, when } from "testdouble"
import { assertNotNull, promiseMap } from "@tutao/tutanota-utils" import { assertNotNull } from "@tutao/tutanota-utils"
import { SpamClassificationInitializer } from "../../../../../../src/mail-app/workerUtils/spamClassification/SpamClassificationInitializer" import { SpamClassificationDataDealer, TrainingDataset } from "../../../../../../src/mail-app/workerUtils/spamClassification/SpamClassificationDataDealer"
import { CacheStorage } from "../../../../../../src/common/api/worker/rest/DefaultEntityRestCache" import { CacheStorage } from "../../../../../../src/common/api/worker/rest/DefaultEntityRestCache"
import { mockAttribute } from "@tutao/tutanota-test-utils" import { mockAttribute } from "@tutao/tutanota-test-utils"
import "@tensorflow/tfjs-backend-cpu" import "@tensorflow/tfjs-backend-cpu"
import { HashingVectorizer } from "../../../../../../src/mail-app/workerUtils/spamClassification/HashingVectorizer" import { HashingVectorizer } from "../../../../../../src/mail-app/workerUtils/spamClassification/HashingVectorizer"
import { LayersModel, tensor1d } from "../../../../../../src/mail-app/workerUtils/spamClassification/tensorflow-custom" import { LayersModel, tensor1d } from "../../../../../../src/mail-app/workerUtils/spamClassification/tensorflow-custom"
import { createTestEntity } from "../../../../TestUtils" import { createTestEntity } from "../../../../TestUtils"
import { MailTypeRef } from "../../../../../../src/common/api/entities/tutanota/TypeRefs" import { ClientSpamTrainingDatum, ClientSpamTrainingDatumTypeRef, MailTypeRef } from "../../../../../../src/common/api/entities/tutanota/TypeRefs"
import { Sequential } from "@tensorflow/tfjs-layers" import { Sequential } from "@tensorflow/tfjs-layers"
import { SparseVectorCompressor } from "../../../../../../src/common/api/common/utils/spamClassificationUtils/SparseVectorCompressor"
import {
DEFAULT_IS_SPAM_CONFIDENCE,
DEFAULT_PREPROCESS_CONFIGURATION,
SpamMailDatum,
SpamMailProcessor,
} from "../../../../../../src/common/api/common/utils/spamClassificationUtils/SpamMailProcessor"
import { SpamDecision } from "../../../../../../src/common/api/common/TutanotaConstants"
import { GENERATED_MIN_ID } from "../../../../../../src/common/api/common/utils/EntityUtils"
const { anything } = matchers const { anything } = matchers
export const DATASET_FILE_PATH: string = "./tests/api/worker/utils/spamClassification/spam_classification_test_mails.csv" export const DATASET_FILE_PATH: string = "./tests/api/worker/utils/spamClassification/spam_classification_test_mails.csv"
const TEST_OWNER_GROUP = "owner"
export async function readMailDataFromCSV(filePath: string): Promise<{ export async function readMailDataFromCSV(filePath: string): Promise<{
spamData: SpamTrainMailDatum[] spamData: SpamMailDatum[]
hamData: SpamTrainMailDatum[] hamData: SpamMailDatum[]
}> { }> {
const file = await fs.promises.readFile(filePath) const file = await fs.promises.readFile(filePath)
const csv = parseCsv(file.toString()) const csv = parseCsv(file.toString())
let spamData: SpamTrainMailDatum[] = [] let spamData: SpamMailDatum[] = []
let hamData: SpamTrainMailDatum[] = [] let hamData: SpamMailDatum[] = []
for (const row of csv.rows.slice(1, csv.rows.length - 1)) { for (const row of csv.rows.slice(1, csv.rows.length - 1)) {
const subject = row[8] const subject = row[8]
const body = row[10] const body = row[10]
@ -43,57 +48,77 @@ export async function readMailDataFromCSV(filePath: string): Promise<{
let isSpam = label === "spam" ? true : label === "ham" ? false : null let isSpam = label === "spam" ? true : label === "ham" ? false : null
isSpam = assertNotNull(isSpam, "Unknown label detected: " + label) isSpam = assertNotNull(isSpam, "Unknown label detected: " + label)
const targetData = isSpam ? spamData : hamData const spamMailDatum = {
targetData.push({
mailId: ["mailListId", "mailElementId"],
subject, subject,
body, body,
isSpam, ownerGroup: TEST_OWNER_GROUP,
isSpamConfidence: 1,
ownerGroup: "owner",
sender: from, sender: from,
toRecipients: to, toRecipients: to,
ccRecipients: cc, ccRecipients: cc,
bccRecipients: bcc, bccRecipients: bcc,
authStatus: authStatus, authStatus: authStatus,
} as SpamTrainMailDatum) } as SpamMailDatum
const targetData = isSpam ? spamData : hamData
targetData.push(spamMailDatum)
} }
return { spamData, hamData } return { spamData, hamData }
} }
async function convertToClientTrainingDatum(spamData: SpamMailDatum[], spamProcessor: SpamMailProcessor, isSpam: boolean): Promise<ClientSpamTrainingDatum[]> {
let result: ClientSpamTrainingDatum[] = []
for (const spamDatum of spamData) {
const clientSpamTrainingDatum = createTestEntity(ClientSpamTrainingDatumTypeRef, {
confidence: DEFAULT_IS_SPAM_CONFIDENCE.toString(),
spamDecision: isSpam ? SpamDecision.BLACKLIST : SpamDecision.WHITELIST,
vector: await spamProcessor.vectorizeAndCompress(spamDatum),
})
result.push(clientSpamTrainingDatum)
}
return result
}
function getTrainingDataset(trainSet: ClientSpamTrainingDatum[]) {
return {
trainingData: trainSet,
hamCount: trainSet.filter((item) => item.spamDecision === SpamDecision.WHITELIST).length,
spamCount: trainSet.filter((item) => item.spamDecision === SpamDecision.BLACKLIST).length,
lastTrainingDataIndexId: GENERATED_MIN_ID,
}
}
// Initial training (cutoff by day or amount) // Initial training (cutoff by day or amount)
o.spec("SpamClassifierTest", () => { o.spec("SpamClassifierTest", () => {
const mockOfflineStorageCache = object<CacheStorage>() const mockCacheStorage = object<CacheStorage>()
const mockOfflineStorage = object<OfflineStoragePersistence>() const mockSpamClassificationDataDealer = object<SpamClassificationDataDealer>()
const mockSpamClassificationInitializer = object<SpamClassificationInitializer>()
let nonEfficientSmallVectorizer: HashingVectorizer
let spamClassifier: SpamClassifier let spamClassifier: SpamClassifier
let spamProcessor: SpamMailProcessor
let compressor: SparseVectorCompressor
let spamData: SpamTrainMailDatum[] let spamData: ClientSpamTrainingDatum[]
let hamData: SpamTrainMailDatum[] let hamData: ClientSpamTrainingDatum[]
let dataSlice: SpamTrainMailDatum[] let dataSlice: ClientSpamTrainingDatum[]
o.beforeEach(async () => { o.beforeEach(async () => {
const spamHamData = await readMailDataFromCSV(DATASET_FILE_PATH) const spamHamData = await readMailDataFromCSV(DATASET_FILE_PATH)
spamData = spamHamData.spamData
hamData = spamHamData.hamData mockSpamClassificationDataDealer.fetchAllTrainingData = async () => {
return getTrainingDataset(dataSlice)
}
const vectorLength = 512
compressor = new SparseVectorCompressor(vectorLength)
spamProcessor = new SpamMailProcessor(DEFAULT_PREPROCESS_CONFIGURATION, new HashingVectorizer(vectorLength), compressor)
spamClassifier = new SpamClassifier(mockCacheStorage, mockSpamClassificationDataDealer, true)
spamClassifier.spamMailProcessor = spamProcessor
spamClassifier.sparseVectorCompressor = compressor
spamData = await convertToClientTrainingDatum(spamHamData.spamData, spamProcessor, true)
hamData = await convertToClientTrainingDatum(spamHamData.hamData, spamProcessor, false)
dataSlice = spamData.concat(hamData) dataSlice = spamData.concat(hamData)
seededShuffle(dataSlice, 42) seededShuffle(dataSlice, 42)
mockSpamClassificationInitializer.init = async () => {
return dataSlice
}
nonEfficientSmallVectorizer = new HashingVectorizer(512)
spamClassifier = new SpamClassifier(
mockOfflineStorage,
mockOfflineStorageCache,
mockSpamClassificationInitializer,
true,
DEFAULT_PREPROCESS_CONFIGURATION,
nonEfficientSmallVectorizer,
)
}) })
o("processSpam maintains server classification when client classification is not enabled", async function () { o("processSpam maintains server classification when client classification is not enabled", async function () {
@ -101,23 +126,27 @@ o.spec("SpamClassifierTest", () => {
_id: ["mailListId", "mailId"], _id: ["mailListId", "mailId"],
sets: [["folderList", "serverFolder"]], sets: [["folderList", "serverFolder"]],
}) })
const spamTrainMailDatum: SpamTrainMailDatum = { const spamMailDatum: SpamMailDatum = {
mailId: mail._id, ownerGroup: TEST_OWNER_GROUP,
subject: mail.subject, subject: mail.subject,
body: "some body", body: "some body",
isSpam: true, sender: "sender@tuta.com",
isSpamConfidence: 1, toRecipients: "recipient@tuta.com",
ownerGroup: "owner",
sender: "",
toRecipients: "",
ccRecipients: "", ccRecipients: "",
bccRecipients: "", bccRecipients: "",
authStatus: "", authStatus: "0",
} }
const layersModel = object<Sequential>()
spamClassifier.addSpamClassifierForOwner(spamTrainMailDatum.ownerGroup, layersModel, false)
const predictedSpam = await spamClassifier.predict(spamTrainMailDatum) // convert to vector
const layersModel = object<Sequential>()
const classifier = object<Classifier>()
classifier.layersModel = layersModel
classifier.isEnabled = false
classifier.threshold = DEFAULT_PREDICTION_THRESHOLD
spamClassifier.addSpamClassifierForOwner(spamMailDatum.ownerGroup, classifier)
const vector = await spamProcessor.vectorize(spamMailDatum)
const predictedSpam = await spamClassifier.predict(vector, spamMailDatum.ownerGroup)
o(predictedSpam).equals(null) o(predictedSpam).equals(null)
}) })
@ -126,37 +155,73 @@ o.spec("SpamClassifierTest", () => {
_id: ["mailListId", "mailId"], _id: ["mailListId", "mailId"],
sets: [["folderList", "serverFolder"]], sets: [["folderList", "serverFolder"]],
}) })
const spamTrainMailDatum: SpamTrainMailDatum = { const spamMailDatum: SpamMailDatum = {
mailId: mail._id, ownerGroup: TEST_OWNER_GROUP,
subject: mail.subject, subject: mail.subject,
body: "some body", body: "some body",
isSpam: false, sender: "sender@tuta.com",
isSpamConfidence: 0, toRecipients: "recipient@tuta.com",
ownerGroup: "owner",
sender: "",
toRecipients: "",
ccRecipients: "", ccRecipients: "",
bccRecipients: "", bccRecipients: "",
authStatus: "", authStatus: "0",
} }
const layersModel = object<Sequential>() const layersModel = object<Sequential>()
when(layersModel.predict(anything())).thenReturn(tensor1d([1])) when(layersModel.predict(anything())).thenReturn(tensor1d([1]))
spamClassifier.addSpamClassifierForOwner(spamTrainMailDatum.ownerGroup, layersModel, true) const classifier = object<Classifier>()
classifier.layersModel = layersModel
classifier.isEnabled = true
classifier.threshold = DEFAULT_PREDICTION_THRESHOLD
spamClassifier.addSpamClassifierForOwner(spamMailDatum.ownerGroup, classifier)
const predictedSpam = await spamClassifier.predict(spamTrainMailDatum) const vector = await spamProcessor.vectorize(spamMailDatum)
const predictedSpam = await spamClassifier.predict(vector, spamMailDatum.ownerGroup)
o(predictedSpam).equals(true) o(predictedSpam).equals(true)
}) })
o("processSpam respects the classifier threshold", async function () {
const mail = createTestEntity(MailTypeRef, {
_id: ["mailListId", "mailId"],
sets: [["folderList", "serverFolder"]],
})
const spamMailDatum: SpamMailDatum = {
ownerGroup: TEST_OWNER_GROUP,
subject: mail.subject,
body: "some body",
sender: "sender@tuta.com",
toRecipients: "recipient@tuta.com",
ccRecipients: "",
bccRecipients: "",
authStatus: "0",
}
const layersModel = object<Sequential>()
when(layersModel.predict(anything())).thenReturn(tensor1d([0.7]))
const classifier = object<Classifier>()
classifier.layersModel = layersModel
classifier.isEnabled = true
classifier.threshold = 0.9
spamClassifier.addSpamClassifierForOwner(spamMailDatum.ownerGroup, classifier)
const vector = await spamProcessor.vectorize(spamMailDatum)
const predictedSpam = await spamClassifier.predict(vector, spamMailDatum.ownerGroup)
o(predictedSpam).equals(false)
})
o("Initial training only", async () => { o("Initial training only", async () => {
o.timeout(20_000) o.timeout(20_000)
const trainTestSplit = dataSlice.length * 0.8 const trainTestSplit = dataSlice.length * 0.8
const trainSet = dataSlice.slice(0, trainTestSplit) const trainSet = dataSlice.slice(0, trainTestSplit)
const testSet = dataSlice.slice(trainTestSplit) const testSet = dataSlice.slice(trainTestSplit)
const trainingDataset: TrainingDataset = getTrainingDataset(trainSet)
await spamClassifier.initialTraining(TEST_OWNER_GROUP, trainingDataset)
await testClassifier(spamClassifier, testSet, compressor)
await spamClassifier.initialTraining(trainSet) const classifier = spamClassifier.classifiers.get(TEST_OWNER_GROUP)
await testClassifier(spamClassifier, testSet) o(classifier?.hamCount).equals(trainingDataset.hamCount)
o(classifier?.spamCount).equals(trainingDataset.spamCount)
o(classifier?.threshold).equals(spamClassifier.calculateThreshold(trainingDataset.hamCount, trainingDataset.spamCount))
}) })
o("Initial training and refitting in multi step", async () => { o("Initial training and refitting in multi step", async () => {
@ -170,18 +235,26 @@ o.spec("SpamClassifierTest", () => {
const trainSetSecondHalf = trainSet.slice(trainSet.length / 2, trainSet.length) const trainSetSecondHalf = trainSet.slice(trainSet.length / 2, trainSet.length)
dataSlice = trainSetFirstHalf dataSlice = trainSetFirstHalf
o(await mockSpamClassificationInitializer.init("owner")).deepEquals(trainSetFirstHalf) o((await mockSpamClassificationDataDealer.fetchAllTrainingData(TEST_OWNER_GROUP)).trainingData).deepEquals(dataSlice)
await spamClassifier.initialTraining(dataSlice) const initialTrainingDataset = getTrainingDataset(dataSlice)
await spamClassifier.initialTraining(TEST_OWNER_GROUP, initialTrainingDataset)
console.log(`==> Result when testing with mails in two steps (first step).`) console.log(`==> Result when testing with mails in two steps (first step).`)
await testClassifier(spamClassifier, testSet) await testClassifier(spamClassifier, testSet, compressor)
await spamClassifier.updateModel("owner", trainSetSecondHalf) const trainingDatasetSecondHalf = getTrainingDataset(trainSetSecondHalf)
await spamClassifier.updateModel(TEST_OWNER_GROUP, trainingDatasetSecondHalf)
console.log(`==> Result when testing with mails in two steps (second step).`) console.log(`==> Result when testing with mails in two steps (second step).`)
await testClassifier(spamClassifier, testSet) await testClassifier(spamClassifier, testSet, compressor)
const classifier = spamClassifier.classifiers.get(TEST_OWNER_GROUP)
const finalHamCount = initialTrainingDataset.hamCount + trainingDatasetSecondHalf.hamCount
const finalSpamCount = initialTrainingDataset.spamCount + trainingDatasetSecondHalf.spamCount
o(classifier?.hamCount).equals(finalHamCount)
o(classifier?.spamCount).equals(finalSpamCount)
o(classifier?.threshold).equals(spamClassifier.calculateThreshold(finalHamCount, finalSpamCount))
}) })
o("preprocessMail outputs expected tokens for mail content", async () => { o("preprocessMail outputs expected tokens for mail content", async () => {
const classifier = new SpamClassifier(object(), object(), object())
const mail = { const mail = {
subject: `Sample Tokens and values`, subject: `Sample Tokens and values`,
sender: "sender", sender: "sender",
@ -273,8 +346,8 @@ o.spec("SpamClassifierTest", () => {
<table cellpadding="0" cellspacing="0" border="0" role="presentation" width="100%"><tbody><tr><td align="center"><a href="https://mail.abc-web.de/optiext/optiextension.dll?ID=someid" rel="noopener noreferrer" target="_blank" style="text-decoration:none"><img id="OWATemporaryImageDivContainer1" src="https://mail.some-domain.de/images/SMC/grafik/image.png" alt="" border="0" class="" width="100%" style="max-width:100%;display:block;width:100%"></a></td></tr></tbody></table> <table cellpadding="0" cellspacing="0" border="0" role="presentation" width="100%"><tbody><tr><td align="center"><a href="https://mail.abc-web.de/optiext/optiextension.dll?ID=someid" rel="noopener noreferrer" target="_blank" style="text-decoration:none"><img id="OWATemporaryImageDivContainer1" src="https://mail.some-domain.de/images/SMC/grafik/image.png" alt="" border="0" class="" width="100%" style="max-width:100%;display:block;width:100%"></a></td></tr></tbody></table>
this text is shown this text is shown
`, `,
} as SpamTrainMailDatum } as SpamMailDatum
const preprocessedMail = classifier.preprocessMail(mail) const preprocessedMail = spamProcessor.preprocessMail(mail)
// prettier-ignore // prettier-ignore
const expectedOutput = `Sample Tokens and values const expectedOutput = `Sample Tokens and values
Hello TSPECIALCHAR these are my MAC Address Hello TSPECIALCHAR these are my MAC Address
@ -364,13 +437,17 @@ authStatus`
}) })
o("predict uses different models for different owner groups", async () => { o("predict uses different models for different owner groups", async () => {
const firstGroupModel = object<LayersModel>() const firstGroupClassifier = object<Classifier>()
const secondGroupModel = object<LayersModel>() firstGroupClassifier.layersModel = object<LayersModel>()
mockAttribute(spamClassifier, spamClassifier.loadModel, (ownerGroup) => { firstGroupClassifier.threshold = DEFAULT_PREDICTION_THRESHOLD
const secondGroupClassifier = object<Classifier>()
secondGroupClassifier.threshold = DEFAULT_PREDICTION_THRESHOLD
secondGroupClassifier.layersModel = object<LayersModel>()
mockAttribute(spamClassifier, spamClassifier.loadClassifier, (ownerGroup) => {
if (ownerGroup === "firstGroup") { if (ownerGroup === "firstGroup") {
return Promise.resolve(firstGroupModel) return Promise.resolve(firstGroupClassifier)
} else if (ownerGroup === "secondGroup") { } else if (ownerGroup === "secondGroup") {
return Promise.resolve(secondGroupModel) return Promise.resolve(secondGroupClassifier)
} }
return null return null
}) })
@ -380,9 +457,9 @@ authStatus`
}) })
const firstGroupReturnTensor = tensor1d([1.0], undefined) const firstGroupReturnTensor = tensor1d([1.0], undefined)
when(firstGroupModel.predict(matchers.anything())).thenReturn(firstGroupReturnTensor) when(firstGroupClassifier.layersModel.predict(matchers.anything())).thenReturn(firstGroupReturnTensor)
const secondGroupReturnTensor = tensor1d([0.0], undefined) const secondGroupReturnTensor = tensor1d([0.0], undefined)
when(secondGroupModel.predict(matchers.anything())).thenReturn(secondGroupReturnTensor) when(secondGroupClassifier.layersModel.predict(matchers.anything())).thenReturn(secondGroupReturnTensor)
await spamClassifier.initialize("firstGroup") await spamClassifier.initialize("firstGroup")
await spamClassifier.initialize("secondGroup") await spamClassifier.initialize("secondGroup")
@ -397,14 +474,16 @@ authStatus`
authStatus: "", authStatus: "",
} }
const isSpamFirstMail = await spamClassifier.predict({ const firstMailVector = await spamProcessor.vectorize({
ownerGroup: "firstGroup", ownerGroup: "firstGroup",
...commonSpamFields, ...commonSpamFields,
}) })
const isSpamSecondMail = await spamClassifier.predict({ const isSpamFirstMail = await spamClassifier.predict(firstMailVector, "firstGroup")
const secondMailVector = await spamProcessor.vectorize({
ownerGroup: "secondGroup", ownerGroup: "secondGroup",
...commonSpamFields, ...commonSpamFields,
}) })
const isSpamSecondMail = await spamClassifier.predict(secondMailVector, "secondGroup")
o(isSpamFirstMail).equals(true) o(isSpamFirstMail).equals(true)
o(isSpamSecondMail).equals(false) o(isSpamSecondMail).equals(false)
@ -419,39 +498,60 @@ authStatus`
// They run in loop hence do take more time to finish and is not necessary to include in CI test suite // They run in loop hence do take more time to finish and is not necessary to include in CI test suite
// //
// To enable running this, change following constant to true // To enable running this, change following constant to true
const DO_RUN_PERFORMANCE_ANALYSIS = false const DO_RUN_PERFORMANCE_ANALYSIS = true
if (DO_RUN_PERFORMANCE_ANALYSIS) { if (DO_RUN_PERFORMANCE_ANALYSIS) {
async function filterForMisclassifiedClientSpamTrainingData(
classifier: SpamClassifier,
compressor: SparseVectorCompressor,
dataSlice: ClientSpamTrainingDatum[],
desiredSlice: number,
) {
return dataSlice
.slice(desiredSlice)
.filter(async (datum) => {
const currentClassificationIsSpam = datum.spamDecision === SpamDecision.BLACKLIST
const actualPrediction = await classifier.predict(compressor.binaryToVector(datum.vector), datum._ownerGroup || TEST_OWNER_GROUP)
return currentClassificationIsSpam !== actualPrediction
})
.sort()
.slice(0, desiredSlice)
}
o.spec("SpamClassifier - Performance Analysis", () => { o.spec("SpamClassifier - Performance Analysis", () => {
const mockOfflineStorageCache = object<CacheStorage>() const mockOfflineStorageCache = object<CacheStorage>()
const mockOfflineStorage = object<OfflineStoragePersistence>() const compressor = new SparseVectorCompressor()
let spamClassifier = object<SpamClassifier>() let spamClassifier = object<SpamClassifier>()
let dataSlice: SpamTrainMailDatum[] let dataSlice: ClientSpamTrainingDatum[]
o.beforeEach(() => { let spamProcessor: SpamMailProcessor
const mockSpamClassificationInitializer = object<SpamClassificationInitializer>()
mockSpamClassificationInitializer.init = async () => { o.beforeEach(async () => {
return dataSlice const mockSpamClassificationDataDealer = object<SpamClassificationDataDealer>()
mockSpamClassificationDataDealer.fetchAllTrainingData = async () => {
return getTrainingDataset(dataSlice)
} }
spamClassifier = new SpamClassifier(mockOfflineStorage, mockOfflineStorageCache, mockSpamClassificationInitializer) spamProcessor = new SpamMailProcessor(DEFAULT_PREPROCESS_CONFIGURATION, new HashingVectorizer(), compressor)
spamClassifier = new SpamClassifier(mockOfflineStorageCache, mockSpamClassificationDataDealer, false)
spamClassifier.spamMailProcessor = spamProcessor
}) })
o("time to refit", async () => { o("time to refit", async () => {
o.timeout(20_000_000) o.timeout(20_000_000)
const { spamData, hamData } = await readMailDataFromCSV(DATASET_FILE_PATH) const { spamData, hamData } = await readMailDataFromCSV(DATASET_FILE_PATH)
const hamSlice = hamData.slice(0, 1000) const hamSlice = await convertToClientTrainingDatum(hamData.slice(0, 1000), spamProcessor, false)
const spamSlice = spamData.slice(0, 400) const spamSlice = await convertToClientTrainingDatum(spamData.slice(0, 400), spamProcessor, true)
dataSlice = hamSlice.concat(spamSlice) dataSlice = hamSlice.concat(spamSlice)
seededShuffle(dataSlice, 42) seededShuffle(dataSlice, 42)
const start = performance.now() const start = performance.now()
await spamClassifier.initialTraining(dataSlice) await spamClassifier.initialTraining(TEST_OWNER_GROUP, getTrainingDataset(dataSlice))
const initialTrainingDuration = performance.now() - start const initialTrainingDuration = performance.now() - start
console.log(`initial training time ${initialTrainingDuration}ms`) console.log(`initial training time ${initialTrainingDuration}ms`)
for (let i = 0; i < 20; i++) { for (let i = 0; i < 20; i++) {
const nowSpam = [hamSlice[0]] const nowSpam = [hamSlice[0]]
nowSpam.map((formerHam) => (formerHam.isSpam = true)) nowSpam.map((formerHam) => (formerHam.spamDecision = "1"))
const retrainingStart = performance.now() const retrainingStart = performance.now()
await spamClassifier.updateModel("owner", nowSpam) await spamClassifier.updateModel(TEST_OWNER_GROUP, getTrainingDataset(nowSpam))
const retrainingDuration = performance.now() - retrainingStart const retrainingDuration = performance.now() - retrainingStart
console.log(`retraining time ${retrainingDuration}ms`) console.log(`retraining time ${retrainingDuration}ms`)
} }
@ -460,17 +560,13 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
o("refit after moving a false negative classification multiple times", async () => { o("refit after moving a false negative classification multiple times", async () => {
o.timeout(20_000_000) o.timeout(20_000_000)
const { spamData, hamData } = await readMailDataFromCSV(DATASET_FILE_PATH) const { spamData, hamData } = await readMailDataFromCSV(DATASET_FILE_PATH)
const hamSlice = hamData.slice(0, 100) const hamSlice = await convertToClientTrainingDatum(hamData.slice(0, 100), spamProcessor, false)
const spamSlice = spamData.slice(0, 10) const spamSlice = await convertToClientTrainingDatum(spamData.slice(0, 10), spamProcessor, true)
dataSlice = hamSlice.concat(spamSlice) dataSlice = hamSlice.concat(spamSlice)
// seededShuffle(dataSlice, 42) seededShuffle(dataSlice, 42)
await spamClassifier.initialTraining(dataSlice) await spamClassifier.initialTraining(TEST_OWNER_GROUP, getTrainingDataset(dataSlice))
const falseNegatives = spamData const falseNegatives = await filterForMisclassifiedClientSpamTrainingData(spamClassifier, compressor, spamSlice, 10)
.slice(10)
.filter(async (mailDatum) => mailDatum.isSpam !== (await spamClassifier.predict(mailDatum)))
.sort()
.slice(0, 10)
let retrainingNeeded = new Array<number>(falseNegatives.length).fill(0) let retrainingNeeded = new Array<number>(falseNegatives.length).fill(0)
for (let i = 0; i < falseNegatives.length; i++) { for (let i = 0; i < falseNegatives.length; i++) {
@ -479,32 +575,39 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
let retrainCount = 0 let retrainCount = 0
let predictedSpam = false let predictedSpam = false
while (!predictedSpam && retrainCount++ <= 3) { while (!predictedSpam && retrainCount++ <= 10) {
await copiedClassifier.updateModel("owner", [{ ...sample, isSpam: true, isSpamConfidence: 1 }]) await copiedClassifier.updateModel(
predictedSpam = assertNotNull(await copiedClassifier.predict(sample)) TEST_OWNER_GROUP,
getTrainingDataset([
{
...sample,
spamDecision: SpamDecision.BLACKLIST,
confidence: "4",
},
]),
)
predictedSpam = assertNotNull(await copiedClassifier.predict(compressor.binaryToVector(sample.vector), TEST_OWNER_GROUP))
} }
retrainingNeeded[i] = retrainCount retrainingNeeded[i] = retrainCount
} }
console.log(retrainingNeeded) console.log(retrainingNeeded)
const maxRetrain = Math.max(...retrainingNeeded) const maxRetrain = Math.max(...retrainingNeeded)
o.check(retrainingNeeded.length >= 10).equals(true) o.check(retrainingNeeded.length >= 10).equals(false)
o.check(maxRetrain < 3).equals(true) o.check(maxRetrain < 3).equals(true)
}) })
o("refit after moving a false positive classification multiple times", async () => { o("refit after moving a false positive classification multiple times", async () => {
o.timeout(20_000_000) o.timeout(20_000_000)
const { spamData, hamData } = await readMailDataFromCSV(DATASET_FILE_PATH) const { spamData, hamData } = await readMailDataFromCSV(DATASET_FILE_PATH)
const hamSlice = hamData.slice(0, 10) const hamSlice = await convertToClientTrainingDatum(hamData.slice(0, 10), spamProcessor, false)
const spamSlice = spamData.slice(0, 100) const spamSlice = await convertToClientTrainingDatum(spamData.slice(0, 100), spamProcessor, true)
dataSlice = hamSlice.concat(spamSlice) dataSlice = hamSlice.concat(spamSlice)
// seededShuffle(dataSlice, 42) seededShuffle(dataSlice, 42)
await spamClassifier.initialTraining(dataSlice) await spamClassifier.initialTraining(TEST_OWNER_GROUP, getTrainingDataset(dataSlice))
const falsePositive = hamData
.slice(10) const falsePositive = await filterForMisclassifiedClientSpamTrainingData(spamClassifier, compressor, hamSlice, 10)
.filter(async (mailDatum) => mailDatum.isSpam !== (await spamClassifier.predict(mailDatum)))
.slice(0, 10)
let retrainingNeeded = new Array<number>(falsePositive.length).fill(0) let retrainingNeeded = new Array<number>(falsePositive.length).fill(0)
for (let i = 0; i < falsePositive.length; i++) { for (let i = 0; i < falsePositive.length; i++) {
const sample = falsePositive[i] const sample = falsePositive[i]
@ -513,32 +616,31 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
let retrainCount = 0 let retrainCount = 0
let predictedSpam = false let predictedSpam = false
while (!predictedSpam && retrainCount++ <= 10) { while (!predictedSpam && retrainCount++ <= 10) {
await copiedClassifier.updateModel("owner", [{ ...sample, isSpam: true }]) await copiedClassifier.updateModel(
await copiedClassifier.updateModel("owner", [{ ...sample, isSpam: false }]) TEST_OWNER_GROUP,
predictedSpam = assertNotNull(await copiedClassifier.predict(sample)) getTrainingDataset([{ ...sample, spamDecision: SpamDecision.WHITELIST, confidence: "4" }]),
)
predictedSpam = assertNotNull(await copiedClassifier.predict(compressor.binaryToVector(sample.vector), TEST_OWNER_GROUP))
} }
retrainingNeeded[i] = retrainCount retrainingNeeded[i] = retrainCount
} }
console.log(retrainingNeeded) console.log(retrainingNeeded)
const maxRetrain = Math.max(...retrainingNeeded) const maxRetrain = Math.max(...retrainingNeeded)
o.check(retrainingNeeded.length >= 10).equals(true) o.check(retrainingNeeded.length >= 10).equals(false)
o.check(maxRetrain < 3).equals(true) o.check(maxRetrain < 3).equals(true)
}) })
o("retrain after moving a false negative classification multiple times", async () => { o("retrain from scratch after moving a false negative classification multiple times", async () => {
o.timeout(20_000_000) o.timeout(20_000_000)
const { spamData, hamData } = await readMailDataFromCSV(DATASET_FILE_PATH) const { spamData, hamData } = await readMailDataFromCSV(DATASET_FILE_PATH)
const hamSlice = hamData.slice(0, 100) const hamSlice = await convertToClientTrainingDatum(hamData.slice(0, 100), spamProcessor, false)
const spamSlice = spamData.slice(0, 10) const spamSlice = await convertToClientTrainingDatum(spamData.slice(0, 10), spamProcessor, true)
dataSlice = hamSlice.concat(spamSlice) dataSlice = hamSlice.concat(spamSlice)
seededShuffle(dataSlice, 42) seededShuffle(dataSlice, 42)
await spamClassifier.initialTraining(dataSlice) await spamClassifier.initialTraining(TEST_OWNER_GROUP, getTrainingDataset(dataSlice))
const falseNegatives = spamData const falseNegatives = await filterForMisclassifiedClientSpamTrainingData(spamClassifier, compressor, spamSlice, 10)
.slice(10)
.filter(async (mailDatum) => mailDatum.isSpam !== (await spamClassifier.predict(mailDatum)))
.slice(0, 10)
let retrainingNeeded = new Array<number>(falseNegatives.length).fill(0) let retrainingNeeded = new Array<number>(falseNegatives.length).fill(0)
for (let i = 0; i < falseNegatives.length; i++) { for (let i = 0; i < falseNegatives.length; i++) {
@ -548,68 +650,30 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
let retrainCount = 0 let retrainCount = 0
let predictedSpam = false let predictedSpam = false
while (!predictedSpam && retrainCount++ <= 10) { while (!predictedSpam && retrainCount++ <= 10) {
await copiedClassifier.initialTraining([...dataSlice, sample]) await copiedClassifier.initialTraining(
predictedSpam = assertNotNull(await copiedClassifier.predict(sample)) TEST_OWNER_GROUP,
getTrainingDataset([...dataSlice, { ...sample, spamDecision: SpamDecision.BLACKLIST, confidence: "4" }]),
)
predictedSpam = assertNotNull(await copiedClassifier.predict(compressor.binaryToVector(sample.vector), TEST_OWNER_GROUP))
} }
retrainingNeeded[i] = retrainCount retrainingNeeded[i] = retrainCount
} }
console.log(retrainingNeeded) console.log(retrainingNeeded)
const maxRetrain = Math.max(...retrainingNeeded) const maxRetrain = Math.max(...retrainingNeeded)
o.check(retrainingNeeded.length >= 10).equals(true) o.check(retrainingNeeded.length >= 10).equals(false)
o.check(maxRetrain < 3).equals(true) o.check(maxRetrain < 3).equals(true)
}) })
o("Time spent in vectorization during initial training", async () => {
o.timeout(2_000_000)
const ITERATION_COUNT: number = 1
const { spamData, hamData } = await readMailDataFromCSV(DATASET_FILE_PATH)
dataSlice = spamData.concat(hamData)
let trainingTimes = new Array<number>()
let vectorizationTimes = new Array<number>()
let trainingWithoutVectorization = new Array<number>()
await promiseMap(
new Array<number>(ITERATION_COUNT).fill(0),
async () => {
const { vectorizationTime, trainingTime } = await spamClassifier.initialTraining(dataSlice)
trainingTimes.push(trainingTime)
vectorizationTimes.push(vectorizationTime)
trainingWithoutVectorization.push(trainingTime - vectorizationTime)
},
{ concurrency: ITERATION_COUNT },
)
trainingTimes = trainingTimes.sort()
vectorizationTimes = vectorizationTimes.sort()
trainingWithoutVectorization = trainingWithoutVectorization.sort()
const avgTrainingTime = trainingTimes.reduce((a, b) => a + b, 0) / trainingTimes.length
const avgVectorizationTime = vectorizationTimes.reduce((a, b) => a + b, 0) / vectorizationTimes.length
const avgTrainingWithoutVectorization = trainingWithoutVectorization.reduce((a, b) => a + b, 0) / trainingWithoutVectorization.length
console.log("For vectorization:")
console.log({ min: vectorizationTimes.at(0), max: vectorizationTimes.at(-1), avg: avgVectorizationTime })
console.log("For whole training:")
console.log({ min: trainingTimes.at(0), max: trainingTimes.at(-1), avg: avgTrainingTime })
console.log("For training without vectorization:")
console.log({
min: trainingWithoutVectorization.at(0),
max: trainingWithoutVectorization.at(-1),
avg: avgTrainingWithoutVectorization,
})
})
}) })
} }
async function testClassifier(classifier: SpamClassifier, mails: SpamTrainMailDatum[]): Promise<void> { async function testClassifier(classifier: SpamClassifier, mails: ClientSpamTrainingDatum[], compressor: SparseVectorCompressor): Promise<void> {
let predictionArray: number[] = [] let predictionArray: number[] = []
for (let mail of mails) { for (let mail of mails) {
const prediction = await classifier.predict(mail) const prediction = await classifier.predict(compressor.binaryToVector(mail.vector), TEST_OWNER_GROUP)
predictionArray.push(prediction ? 1 : 0) predictionArray.push(prediction ? 1 : 0)
} }
const ysArray = mails.map((mail) => mail.isSpam) const ysArray = mails.map((mail) => mail.spamDecision === SpamDecision.BLACKLIST)
let tp = 0, let tp = 0,
tn = 0, tn = 0,

View file

@ -0,0 +1,38 @@
import o from "@tutao/otest"
import { promiseMap } from "@tutao/tutanota-utils"
import { SparseVectorCompressor } from "../../../../../../src/common/api/common/utils/spamClassificationUtils/SparseVectorCompressor"
import { HashingVectorizer } from "../../../../../../src/mail-app/workerUtils/spamClassification/HashingVectorizer"
import { DATASET_FILE_PATH, readMailDataFromCSV } from "./SpamClassifierTest"
import { spamClassifierTokenizer, SpamMailProcessor } from "../../../../../../src/common/api/common/utils/spamClassificationUtils/SpamMailProcessor"
o.spec("SparseVectorCompressorTest", () => {
o("sparse compress vectors", async () => {
o.timeout(20_000)
const spamHamData = await readMailDataFromCSV(DATASET_FILE_PATH)
const spamData = spamHamData.spamData
const hamData = spamHamData.hamData
const dataSlice = spamData.concat(hamData)
const tokenizedMails = await promiseMap(dataSlice, (mail) => spamClassifierTokenizer(new SpamMailProcessor().preprocessMail(mail)))
const vectorizer = new HashingVectorizer()
const vectors = (await vectorizer.transform(tokenizedMails)).slice(0, 1)
const compressor = new SparseVectorCompressor()
const BYTES_PER_NUMBER = 2
console.log("Byte size of a number: ", BYTES_PER_NUMBER)
const compressedVectors = vectors.map((v) => compressor.vectorToBinary(v))
const decompressedVectors = compressedVectors.map((v) => compressor.binaryToVector(v))
const decompressedVectorByteSizes: number[] = []
const compressedVectorByteSizes: number[] = []
for (let i = 0; i < compressedVectors.length; i++) {
compressedVectorByteSizes.push(compressedVectors[i].values.length + compressedVectors[i].length)
decompressedVectorByteSizes.push(decompressedVectors[i].length)
}
const averageCompressedVectorByteSize = compressedVectorByteSizes.reduce((a, b) => a + b, 0) / compressedVectorByteSizes.length
const averageDecompressedVectorByteSize = decompressedVectorByteSizes.reduce((a, b) => a + b, 0) / decompressedVectorByteSizes.length
console.log(`Average compressed vector byte size (Custom): ${averageCompressedVectorByteSize.toFixed(2)}B`)
console.log(`Average decompressed vector byte size (Custom): ${averageDecompressedVectorByteSize.toFixed(2)}B`)
o.check(decompressedVectors).deepEquals(vectors)
})
})

View file

@ -4,7 +4,6 @@ import { mock, Spy, spy, verify } from "@tutao/tutanota-test-utils"
import { MailSetKind, OperationType, ProcessingState } from "../../../src/common/api/common/TutanotaConstants.js" import { MailSetKind, OperationType, ProcessingState } from "../../../src/common/api/common/TutanotaConstants.js"
import { import {
BodyTypeRef, BodyTypeRef,
ClientSpamClassifierResultTypeRef,
Mail, Mail,
MailAddressTypeRef, MailAddressTypeRef,
MailDetails, MailDetails,
@ -24,17 +23,15 @@ import { UserController } from "../../../src/common/api/main/UserController.js"
import { createTestEntity } from "../TestUtils.js" import { createTestEntity } from "../TestUtils.js"
import { EntityUpdateData, PrefetchStatus } from "../../../src/common/api/common/utils/EntityUpdateUtils.js" import { EntityUpdateData, PrefetchStatus } from "../../../src/common/api/common/utils/EntityUpdateUtils.js"
import { MailboxDetail, MailboxModel } from "../../../src/common/mailFunctionality/MailboxModel.js" import { MailboxDetail, MailboxModel } from "../../../src/common/mailFunctionality/MailboxModel.js"
import { getElementId, getListId } from "../../../src/common/api/common/utils/EntityUtils.js"
import { MailModel } from "../../../src/mail-app/mail/model/MailModel.js" import { MailModel } from "../../../src/mail-app/mail/model/MailModel.js"
import { EventController } from "../../../src/common/api/main/EventController.js" import { EventController } from "../../../src/common/api/main/EventController.js"
import { MailFacade } from "../../../src/common/api/worker/facades/lazy/MailFacade.js" import { MailFacade } from "../../../src/common/api/worker/facades/lazy/MailFacade.js"
import { ClientModelInfo } from "../../../src/common/api/common/EntityFunctions" import { ClientModelInfo } from "../../../src/common/api/common/EntityFunctions"
import { InboxRuleHandler } from "../../../src/mail-app/mail/model/InboxRuleHandler" import { InboxRuleHandler } from "../../../src/mail-app/mail/model/InboxRuleHandler"
import { SpamClassificationHandler } from "../../../src/mail-app/mail/model/SpamClassificationHandler"
import { SpamClassifier, SpamTrainMailDatum } from "../../../src/mail-app/workerUtils/spamClassification/SpamClassifier"
import { WebsocketConnectivityModel } from "../../../src/common/misc/WebsocketConnectivityModel" import { WebsocketConnectivityModel } from "../../../src/common/misc/WebsocketConnectivityModel"
import { FolderSystem } from "../../../src/common/api/common/mail/FolderSystem" import { FolderSystem } from "../../../src/common/api/common/mail/FolderSystem"
import { NotAuthorizedError, NotFoundError } from "../../../src/common/api/common/error/RestError" import { NotAuthorizedError } from "../../../src/common/api/common/error/RestError"
import { ProcessInboxHandler } from "../../../src/mail-app/mail/model/ProcessInboxHandler"
const { anything } = matchers const { anything } = matchers
@ -68,6 +65,7 @@ o.spec("MailModelTest", function () {
logins = object() logins = object()
let userController = object<UserController>() let userController = object<UserController>()
when(userController.isUpdateForLoggedInUserInstance(matchers.anything(), matchers.anything())).thenReturn(false) when(userController.isUpdateForLoggedInUserInstance(matchers.anything(), matchers.anything())).thenReturn(false)
when(userController.isInternalUser()).thenReturn(true)
when(logins.getUserController()).thenReturn(userController) when(logins.getUserController()).thenReturn(userController)
connectivityModel = object<WebsocketConnectivityModel>() connectivityModel = object<WebsocketConnectivityModel>()
@ -82,39 +80,9 @@ o.spec("MailModelTest", function () {
mailFacade, mailFacade,
connectivityModel, connectivityModel,
() => object(), () => object(),
() => null,
) )
}) })
o("doesn't send notification for another folder", async function () {
const mail = createTestEntity(MailTypeRef, { _id: ["mailBagListId", "mailId"], sets: [] })
restClient.addListInstances(mail)
await model.entityEventsReceived([
makeUpdate({
instanceListId: getListId(mail) as NonEmptyString,
instanceId: getElementId(mail),
operation: OperationType.CREATE,
}),
])
o(showSpy.invocations.length).equals(0)
})
o("doesn't send notification for move operation", async function () {
const mail = createTestEntity(MailTypeRef, { _id: ["mailBagListId", "mailId"], sets: [] })
restClient.addListInstances(mail)
await model.entityEventsReceived([
makeUpdate({
instanceListId: getListId(mail) as NonEmptyString,
instanceId: getElementId(mail),
operation: OperationType.DELETE,
}),
makeUpdate({
instanceListId: getListId(mail) as NonEmptyString,
instanceId: getElementId(mail),
operation: OperationType.CREATE,
}),
])
o(showSpy.invocations.length).equals(0)
})
o("markMails", async function () { o("markMails", async function () {
const mailId1: IdTuple = ["mailbag id1", "mail id1"] const mailId1: IdTuple = ["mailbag id1", "mail id1"]
const mailId2: IdTuple = ["mailbag id2", "mail id2"] const mailId2: IdTuple = ["mailbag id2", "mail id2"]
@ -125,19 +93,15 @@ o.spec("MailModelTest", function () {
o.spec("Inbox rule processing and spam prediction", () => { o.spec("Inbox rule processing and spam prediction", () => {
let inboxRuleHandler: InboxRuleHandler let inboxRuleHandler: InboxRuleHandler
let spamClassificationHandler: SpamClassificationHandler
let spamClassifier: SpamClassifier
let mailboxModel: MailboxModel let mailboxModel: MailboxModel
let modelWithSpamAndInboxRule: MailModel let modelWithSpamAndInboxRule: MailModel
let mail: Mail let mail: Mail
let mailDetails: MailDetails let mailDetails: MailDetails
let processInboxHandler: ProcessInboxHandler = object<ProcessInboxHandler>()
o.beforeEach(async () => { o.beforeEach(async () => {
const entityClient = new EntityClient(restClient, ClientModelInfo.getNewInstanceForTestsOnly()) const entityClient = new EntityClient(restClient, ClientModelInfo.getNewInstanceForTestsOnly())
mailboxModel = instance(MailboxModel) mailboxModel = instance(MailboxModel)
inboxRuleHandler = object<InboxRuleHandler>() inboxRuleHandler = object<InboxRuleHandler>()
spamClassifier = object<SpamClassifier>()
spamClassificationHandler = new SpamClassificationHandler(mailFacade, spamClassifier)
mailDetails = createTestEntity(MailDetailsTypeRef, { mailDetails = createTestEntity(MailDetailsTypeRef, {
_id: "mailDetail", _id: "mailDetail",
@ -159,6 +123,7 @@ o.spec("MailModelTest", function () {
sets: [inboxFolder._id], sets: [inboxFolder._id],
sender: createTestEntity(MailAddressTypeRef, { name: "Sender", address: "sender@tuta.com" }), sender: createTestEntity(MailAddressTypeRef, { name: "Sender", address: "sender@tuta.com" }),
processingState: ProcessingState.INBOX_RULE_NOT_PROCESSED, processingState: ProcessingState.INBOX_RULE_NOT_PROCESSED,
processNeeded: true,
authStatus: "0", authStatus: "0",
}) })
const mailDetailsBlob: MailDetailsBlob = createTestEntity(MailDetailsBlobTypeRef, { const mailDetailsBlob: MailDetailsBlob = createTestEntity(MailDetailsBlobTypeRef, {
@ -180,8 +145,7 @@ o.spec("MailModelTest", function () {
logins, logins,
mailFacade, mailFacade,
connectivityModel, connectivityModel,
() => spamClassificationHandler, () => processInboxHandler,
() => inboxRuleHandler,
), ),
(m: MailModel) => { (m: MailModel) => {
m.getFolderSystemByGroupId = (groupId) => { m.getFolderSystemByGroupId = (groupId) => {
@ -193,162 +157,51 @@ o.spec("MailModelTest", function () {
) )
}) })
o("does not re-apply inbox rules or re-classify mail if the mail is in a final processingState", async function () { o("invokes ProcessInboxHandler if the mail is not processed", async function () {
const alreadyClassifiedMail = createTestEntity(MailTypeRef, { const notProcessedMail = createTestEntity(MailTypeRef, {
_id: ["mailListId", "maildIdWithFinalProcessingState"], _id: ["mailListId", "notProcessedMailId"],
_ownerGroup: "mailGroup", _ownerGroup: "mailGroup",
mailDetails: ["detailsList", mailDetails._id], mailDetails: ["detailsList", mailDetails._id],
sets: [inboxFolder._id], sets: [inboxFolder._id],
processingState: ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE, processNeeded: true,
clientSpamClassifierResult: createTestEntity(ClientSpamClassifierResultTypeRef),
}) })
restClient.addListInstances(alreadyClassifiedMail) restClient.addListInstances(notProcessedMail)
when(mailFacade.loadMailDetailsBlob(alreadyClassifiedMail)).thenResolve(mailDetails) when(mailFacade.loadMailDetailsBlob(notProcessedMail)).thenResolve(mailDetails)
const alreadyClassifiedMailCreateEvent = makeUpdate({ const alreadyClassifiedMailCreateEvent = makeUpdate({
instanceListId: "mailListId", instanceListId: "mailListId",
instanceId: "maildIdWithFinalProcessingState", instanceId: "notProcessedMailId",
operation: OperationType.CREATE, operation: OperationType.CREATE,
}) })
const { processingDone } = await modelWithSpamAndInboxRule.entityEventsReceived([alreadyClassifiedMailCreateEvent]) await modelWithSpamAndInboxRule.entityEventsReceived([alreadyClassifiedMailCreateEvent])
await processingDone
verify(inboxRuleHandler.findAndApplyMatchingRule(anything(), anything(), anything()), { times: 0 }) verify(processInboxHandler.handleIncomingMail(anything(), anything(), anything(), anything()), { times: 1 })
verify(spamClassificationHandler.predictSpamForNewMail(anything(), anything(), anything(), anything()), { times: 0 })
verify(spamClassifier.storeSpamClassification(anything()), { times: 0 })
verify(spamClassifier.predict(anything()), { times: 0 })
}) })
o("don't classify mail if the mail is read and in INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_PENDING state", async function () { o("does not invoke ProcessInboxHandler if the mail is already processed", async function () {
const alreadyClassifiedMail = createTestEntity(MailTypeRef, { const alreadyProcessedMail = createTestEntity(MailTypeRef, {
_id: ["mailListId", "maildIdWithFinalProcessingState"], _id: ["mailListId", "processedMailId"],
_ownerGroup: "mailGroup", _ownerGroup: "mailGroup",
mailDetails: ["detailsList", mailDetails._id], mailDetails: ["detailsList", mailDetails._id],
sets: [inboxFolder._id], sets: [inboxFolder._id],
processingState: ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_PENDING, processNeeded: false,
clientSpamClassifierResult: createTestEntity(ClientSpamClassifierResultTypeRef),
unread: false,
}) })
restClient.addListInstances(alreadyClassifiedMail) restClient.addListInstances(alreadyProcessedMail)
when(mailFacade.loadMailDetailsBlob(alreadyClassifiedMail)).thenResolve(mailDetails) when(mailFacade.loadMailDetailsBlob(alreadyProcessedMail)).thenResolve(mailDetails)
const alreadyClassifiedMailCreateEvent = makeUpdate({ const alreadyClassifiedMailCreateEvent = makeUpdate({
instanceListId: "mailListId", instanceListId: "mailListId",
instanceId: "mailIdDoNotRunPredictionState", instanceId: "processedMailId",
operation: OperationType.CREATE, operation: OperationType.CREATE,
}) })
const { processingDone } = await modelWithSpamAndInboxRule.entityEventsReceived([alreadyClassifiedMailCreateEvent]) await modelWithSpamAndInboxRule.entityEventsReceived([alreadyClassifiedMailCreateEvent])
await processingDone
verify(inboxRuleHandler.findAndApplyMatchingRule(anything(), anything(), anything()), { times: 0 }) verify(processInboxHandler.handleIncomingMail(anything(), anything(), anything(), anything()), { times: 0 })
verify(spamClassificationHandler.predictSpamForNewMail(anything(), anything(), anything(), anything()), { times: 0 })
verify(spamClassifier.storeSpamClassification(anything()), { times: 0 })
verify(spamClassifier.predict(anything()), { times: 0 })
}) })
o("don't classify mail if the mail is in INBOX_RULE_NOT_PROCESSED_AND_DO_NOT_RUN_SPAM_PREDICTION state", async function () { o("does not invoke ProcessInboxHandler when downloading of mail fails on create mail event", async function () {
const alreadyClassifiedMail = createTestEntity(MailTypeRef, {
_id: ["mailListId", "maildIdWithFinalProcessingState"],
_ownerGroup: "mailGroup",
mailDetails: ["detailsList", mailDetails._id],
sets: [inboxFolder._id],
processingState: ProcessingState.INBOX_RULE_NOT_PROCESSED_AND_DO_NOT_RUN_SPAM_PREDICTION,
clientSpamClassifierResult: createTestEntity(ClientSpamClassifierResultTypeRef),
})
restClient.addListInstances(alreadyClassifiedMail)
when(mailFacade.loadMailDetailsBlob(alreadyClassifiedMail)).thenResolve(mailDetails)
const alreadyClassifiedMailCreateEvent = makeUpdate({
instanceListId: "mailListId",
instanceId: "mailIdDoNotRunPredictionState",
operation: OperationType.CREATE,
})
const { processingDone } = await modelWithSpamAndInboxRule.entityEventsReceived([alreadyClassifiedMailCreateEvent])
await processingDone
verify(inboxRuleHandler.findAndApplyMatchingRule(anything(), anything(), anything()), { times: 0 })
verify(spamClassificationHandler.predictSpamForNewMail(anything(), anything(), anything(), anything()), { times: 0 })
verify(spamClassifier.storeSpamClassification(anything()), { times: 0 })
verify(spamClassifier.predict(anything()), { times: 0 })
})
o("does not try to apply inbox rule when downloading of mail fails on create mail event", async function () {
restClient.setListElementException(mail._id, new NotFoundError("Mail not found"))
const mailCreateEvent = makeUpdate({
instanceListId: getListId(mail) as NonEmptyString,
instanceId: getElementId(mail),
operation: OperationType.CREATE,
})
await modelWithSpamAndInboxRule.entityEventsReceived([mailCreateEvent])
verify(inboxRuleHandler.findAndApplyMatchingRule(anything(), anything(), anything()), { times: 0 })
})
o("spam prediction does not happen when inbox rule is applied", async () => {
when(spamClassifier.predict(anything())).thenResolve(false)
const mailCreateEvent = makeUpdate({
instanceListId: "mailListId",
instanceId: "mailId",
operation: OperationType.CREATE,
})
// when inbox rule is applied
when(inboxRuleHandler.findAndApplyMatchingRule(anything(), anything(), anything())).thenResolve(inboxFolder)
const { processingDone } = await modelWithSpamAndInboxRule.entityEventsReceived([mailCreateEvent])
await processingDone
const expectedSpamTrainMailDatum: SpamTrainMailDatum = {
mailId: ["mailListId", "mailId"],
ownerGroup: "mailGroup",
body: "some text",
subject: "subject",
isSpam: false,
isSpamConfidence: 1,
sender: "Sender sender@tuta.com",
toRecipients: "Recipient recipient@tuta.com",
ccRecipients: "",
bccRecipients: "",
authStatus: "TAUTHENTICATED",
}
verify(spamClassifier.storeSpamClassification(expectedSpamTrainMailDatum), { times: 1 })
verify(spamClassifier.predict(anything()), { times: 0 })
})
o("spam prediction happens when inbox rule is not applied", async () => {
when(spamClassifier.predict(anything())).thenResolve(false)
const mailCreateEvent = makeUpdate({
instanceListId: "mailListId",
instanceId: "mailId",
operation: OperationType.CREATE,
})
when(inboxRuleHandler.findAndApplyMatchingRule(anything(), anything(), anything())).thenResolve(null)
const { processingDone } = await modelWithSpamAndInboxRule.entityEventsReceived([mailCreateEvent])
await processingDone
const expectedSpamTrainMailDatum: SpamTrainMailDatum = {
mailId: ["mailListId", "mailId"],
ownerGroup: "mailGroup",
body: "some text",
subject: "subject",
isSpam: false,
isSpamConfidence: 1,
sender: "Sender sender@tuta.com",
toRecipients: "Recipient recipient@tuta.com",
ccRecipients: "",
bccRecipients: "",
authStatus: "TAUTHENTICATED",
}
verify(spamClassifier.storeSpamClassification(expectedSpamTrainMailDatum), { times: 1 })
verify(spamClassifier.predict(anything()), { times: 1 })
})
o("does not try to do spam classification when downloading of mail fails on create mail event", async function () {
when(inboxRuleHandler.findAndApplyMatchingRule(anything(), anything(), anything())).thenResolve(null) when(inboxRuleHandler.findAndApplyMatchingRule(anything(), anything(), anything())).thenResolve(null)
const mailCreateEvent = makeUpdate({ const mailCreateEvent = makeUpdate({
instanceListId: "mailListId", instanceListId: "mailListId",
@ -358,42 +211,8 @@ o.spec("MailModelTest", function () {
// mail not being there // mail not being there
restClient.setListElementException(mail._id, new NotAuthorizedError("blah")) restClient.setListElementException(mail._id, new NotAuthorizedError("blah"))
const { processingDone: inboxRuleProcessedMailNotThere } = await modelWithSpamAndInboxRule.entityEventsReceived([mailCreateEvent]) await modelWithSpamAndInboxRule.entityEventsReceived([mailCreateEvent])
await inboxRuleProcessedMailNotThere verify(processInboxHandler.handleIncomingMail(anything(), anything(), anything(), anything()), { times: 0 })
verify(spamClassifier.storeSpamClassification(anything()), { times: 0 })
verify(spamClassifier.predict(anything()), { times: 0 })
// mail being there
restClient.addListInstances(mail)
const { processingDone: inboxRuleProcessedMailIsThere } = await modelWithSpamAndInboxRule.entityEventsReceived([mailCreateEvent])
await inboxRuleProcessedMailIsThere
const expectedSpamTrainMailDatum: SpamTrainMailDatum = {
mailId: ["mailListId", "mailId"],
ownerGroup: "mailGroup",
body: "some text",
subject: "subject",
isSpam: false,
isSpamConfidence: 1,
sender: "Sender sender@tuta.com",
toRecipients: "Recipient recipient@tuta.com",
ccRecipients: "",
bccRecipients: "",
authStatus: "TAUTHENTICATED",
}
verify(spamClassifier.storeSpamClassification(expectedSpamTrainMailDatum), { times: 1 })
verify(spamClassifier.predict(anything()), { times: 1 })
})
o("deletes a training datum for deleted mail event", async () => {
const mailDeleteEvent = makeUpdate({
instanceListId: "mailListId",
instanceId: "mailId",
operation: OperationType.DELETE,
})
const { processingDone } = await modelWithSpamAndInboxRule.entityEventsReceived([mailDeleteEvent])
await processingDone
verify(spamClassifier.deleteSpamClassification(mail._id), { times: 1 })
}) })
}) })

View file

@ -0,0 +1,224 @@
import o from "@tutao/otest"
import { matchers, object, verify, when } from "testdouble"
import {
Body,
BodyTypeRef,
ClientSpamClassifierResultTypeRef,
Mail,
MailDetails,
MailDetailsTypeRef,
MailFolderTypeRef,
MailTypeRef,
} from "../../../src/common/api/entities/tutanota/TypeRefs"
import { FeatureType, MailSetKind, ProcessingState, SpamDecision } from "../../../src/common/api/common/TutanotaConstants"
import { ClientClassifierType } from "../../../src/common/api/common/ClientClassifierType"
import { assertNotNull, delay } from "@tutao/tutanota-utils"
import { MailFacade } from "../../../src/common/api/worker/facades/lazy/MailFacade"
import { createTestEntity } from "../TestUtils"
import { SpamClassificationHandler } from "../../../src/mail-app/mail/model/SpamClassificationHandler"
import { FolderSystem } from "../../../src/common/api/common/mail/FolderSystem"
import { isSameId } from "../../../src/common/api/common/utils/EntityUtils"
import { InboxRuleHandler } from "../../../src/mail-app/mail/model/InboxRuleHandler"
import { ProcessInboxHandler, UnencryptedProcessInboxDatum } from "../../../src/mail-app/mail/model/ProcessInboxHandler"
import { MailboxDetail } from "../../../src/common/mailFunctionality/MailboxModel"
import { createSpamMailDatum, SpamMailProcessor } from "../../../src/common/api/common/utils/spamClassificationUtils/SpamMailProcessor"
import { LoginController } from "../../../src/common/api/main/LoginController"
const { anything } = matchers
o.spec("ProcessInboxHandlerTest", function () {
let mailFacade = object<MailFacade>()
let logins = object<LoginController>()
let body: Body
let mail: Mail
let spamHandler: SpamClassificationHandler
let folderSystem: FolderSystem
let mailboxDetail: MailboxDetail
let mailDetails: MailDetails
let inboxRuleHandler: InboxRuleHandler = object<InboxRuleHandler>()
let processInboxHandler: ProcessInboxHandler
const inboxFolder = createTestEntity(MailFolderTypeRef, { _id: ["listId", "inbox"], folderType: MailSetKind.INBOX })
const trashFolder = createTestEntity(MailFolderTypeRef, { _id: ["listId", "trash"], folderType: MailSetKind.TRASH })
const spamFolder = createTestEntity(MailFolderTypeRef, { _id: ["listId", "spam"], folderType: MailSetKind.SPAM })
o.beforeEach(function () {
spamHandler = object<SpamClassificationHandler>()
inboxRuleHandler = object<InboxRuleHandler>()
body = createTestEntity(BodyTypeRef, { text: "Body Text" })
mailDetails = createTestEntity(MailDetailsTypeRef, { _id: "mailDetail", body })
mail = createTestEntity(MailTypeRef, {
_id: ["listId", "elementId"],
sets: [spamFolder._id],
subject: "subject",
_ownerGroup: "owner",
mailDetails: ["detailsList", mailDetails._id],
unread: true,
processingState: ProcessingState.INBOX_RULE_NOT_PROCESSED,
clientSpamClassifierResult: createTestEntity(ClientSpamClassifierResultTypeRef, { spamDecision: SpamDecision.NONE }),
processNeeded: true,
})
folderSystem = object<FolderSystem>()
mailboxDetail = object()
when(mailFacade.moveMails(anything(), anything(), anything())).thenResolve([])
when(
mailFacade.loadMailDetailsBlob(
matchers.argThat((requestedMails: Mail) => {
return isSameId(requestedMails._id, mail._id)
}),
),
).thenDo(async () => mailDetails)
processInboxHandler = new ProcessInboxHandler(
logins,
mailFacade,
() => spamHandler,
() => inboxRuleHandler,
new Map(),
0,
)
when(logins.isEnabled(FeatureType.SpamClientClassification)).thenReturn(true)
})
o("handleIncomingMail does move mail if it has been processed already", async function () {
mail.sets = [inboxFolder._id]
mail.processNeeded = false
verify(inboxRuleHandler.findAndApplyMatchingRule(anything(), anything(), anything()), { times: 0 })
verify(spamHandler.predictSpamForNewMail(anything(), anything(), anything(), anything()), { times: 0 })
const targetFolder = await processInboxHandler.handleIncomingMail(mail, inboxFolder, mailboxDetail, folderSystem)
o(targetFolder).deepEquals(inboxFolder)
verify(mailFacade.processNewMails(anything(), anything()), { times: 0 })
})
o("handleIncomingMail does move mail from inbox to other folder if inbox rule applies", async function () {
mail.sets = [inboxFolder._id]
const processInboxDatum: UnencryptedProcessInboxDatum = {
classifierType: ClientClassifierType.CUSTOMER_INBOX_RULES,
mailId: mail._id,
targetMoveFolder: trashFolder._id,
vector: new Uint8Array(),
}
when(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, mail, true)).thenResolve({
targetFolder: trashFolder,
processInboxDatum,
})
verify(spamHandler.predictSpamForNewMail(anything(), anything(), anything(), anything()), { times: 0 })
const targetFolder = await processInboxHandler.handleIncomingMail(mail, inboxFolder, mailboxDetail, folderSystem)
o(targetFolder).deepEquals(trashFolder)
await delay(0)
verify(mailFacade.processNewMails(assertNotNull(mail._ownerGroup), [processInboxDatum]))
})
o("handleIncomingMail does move mail from inbox to spam folder if mail is spam", async function () {
mail.sets = [inboxFolder._id]
when(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, mail, true)).thenResolve(null)
const processInboxDatum: UnencryptedProcessInboxDatum = {
classifierType: ClientClassifierType.CLIENT_CLASSIFICATION,
mailId: mail._id,
targetMoveFolder: spamFolder._id,
vector: new Uint8Array(),
}
when(spamHandler.predictSpamForNewMail(mail, mailDetails, inboxFolder, folderSystem)).thenResolve({
targetFolder: spamFolder,
processInboxDatum,
})
const targetFolder = await processInboxHandler.handleIncomingMail(mail, inboxFolder, mailboxDetail, folderSystem)
o(targetFolder).deepEquals(spamFolder)
await delay(0)
verify(mailFacade.processNewMails(assertNotNull(mail._ownerGroup), [processInboxDatum]))
})
o("handleIncomingMail does NOT move mail from inbox to spam folder if mail is ham", async function () {
mail.sets = [inboxFolder._id]
when(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, mail, true)).thenResolve(null)
const processInboxDatum: UnencryptedProcessInboxDatum = {
classifierType: null,
mailId: mail._id,
targetMoveFolder: inboxFolder._id,
vector: new Uint8Array(),
}
when(spamHandler.predictSpamForNewMail(mail, mailDetails, inboxFolder, folderSystem)).thenResolve({
targetFolder: inboxFolder,
processInboxDatum,
})
const targetFolder = await processInboxHandler.handleIncomingMail(mail, inboxFolder, mailboxDetail, folderSystem)
o(targetFolder).deepEquals(inboxFolder)
await delay(0)
verify(mailFacade.processNewMails(assertNotNull(mail._ownerGroup), [processInboxDatum]))
})
o("handleIncomingMail does NOT move mail from spam to inbox folder if mail is spam", async function () {
mail.sets = [spamFolder._id]
when(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, mail, true)).thenResolve(null)
const processInboxDatum: UnencryptedProcessInboxDatum = {
classifierType: ClientClassifierType.CLIENT_CLASSIFICATION,
mailId: mail._id,
targetMoveFolder: spamFolder._id,
vector: new Uint8Array(),
}
when(spamHandler.predictSpamForNewMail(mail, mailDetails, inboxFolder, folderSystem)).thenResolve({
targetFolder: spamFolder,
processInboxDatum,
})
const targetFolder = await processInboxHandler.handleIncomingMail(mail, inboxFolder, mailboxDetail, folderSystem)
o(targetFolder).deepEquals(spamFolder)
await delay(0)
verify(mailFacade.processNewMails(assertNotNull(mail._ownerGroup), [processInboxDatum]))
})
o("handleIncomingMail moves mail from spam to inbox folder if mail is ham", async function () {
mail.sets = [spamFolder._id]
when(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, mail, true)).thenResolve(null)
const processInboxDatum: UnencryptedProcessInboxDatum = {
classifierType: ClientClassifierType.CLIENT_CLASSIFICATION,
mailId: mail._id,
targetMoveFolder: inboxFolder._id,
vector: new Uint8Array(),
}
when(spamHandler.predictSpamForNewMail(mail, mailDetails, inboxFolder, folderSystem)).thenResolve({
targetFolder: inboxFolder,
processInboxDatum,
})
const targetFolder = await processInboxHandler.handleIncomingMail(mail, inboxFolder, mailboxDetail, folderSystem)
o(targetFolder).deepEquals(inboxFolder)
await delay(0)
verify(mailFacade.processNewMails(assertNotNull(mail._ownerGroup), [processInboxDatum]))
})
o("handleIncomingMail does NOT move mail from inbox to spam folder if spam classification is disabled", async function () {
when(logins.isEnabled(FeatureType.SpamClientClassification)).thenReturn(false)
mail.sets = [inboxFolder._id]
const compressedVector = new Uint8Array([2, 4, 8, 16])
const datum = createSpamMailDatum(mail, mailDetails)
when(mailFacade.vectorizeAndCompressMails({ mail, mailDetails })).thenResolve(compressedVector)
processInboxHandler = new ProcessInboxHandler(
logins,
mailFacade,
() => spamHandler,
() => inboxRuleHandler,
new Map(),
0,
)
when(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, mail, true)).thenResolve(null)
const processedMail: UnencryptedProcessInboxDatum = {
classifierType: null,
mailId: mail._id,
targetMoveFolder: inboxFolder._id,
vector: compressedVector,
}
verify(spamHandler.predictSpamForNewMail(anything(), anything(), anything(), anything()), { times: 0 })
const targetFolder = await processInboxHandler.handleIncomingMail(mail, inboxFolder, mailboxDetail, folderSystem)
o(targetFolder).deepEquals(inboxFolder)
await delay(0)
verify(mailFacade.processNewMails(assertNotNull(mail._ownerGroup), [processedMail]))
})
})

View file

@ -1,5 +1,5 @@
import o from "@tutao/otest" import o from "@tutao/otest"
import { matchers, object, verify, when } from "testdouble" import { matchers, object, when } from "testdouble"
import { import {
Body, Body,
BodyTypeRef, BodyTypeRef,
@ -10,8 +10,7 @@ import {
MailFolderTypeRef, MailFolderTypeRef,
MailTypeRef, MailTypeRef,
} from "../../../src/common/api/entities/tutanota/TypeRefs" } from "../../../src/common/api/entities/tutanota/TypeRefs"
import { SpamClassifier, SpamTrainMailDatum } from "../../../src/mail-app/workerUtils/spamClassification/SpamClassifier" import { SpamClassifier } from "../../../src/mail-app/workerUtils/spamClassification/SpamClassifier"
import { getMailBodyText } from "../../../src/common/api/common/CommonMailUtils"
import { MailSetKind, ProcessingState, SpamDecision } from "../../../src/common/api/common/TutanotaConstants" import { MailSetKind, ProcessingState, SpamDecision } from "../../../src/common/api/common/TutanotaConstants"
import { ClientClassifierType } from "../../../src/common/api/common/ClientClassifierType" import { ClientClassifierType } from "../../../src/common/api/common/ClientClassifierType"
import { assert, assertNotNull } from "@tutao/tutanota-utils" import { assert, assertNotNull } from "@tutao/tutanota-utils"
@ -20,7 +19,8 @@ import { createTestEntity } from "../TestUtils"
import { SpamClassificationHandler } from "../../../src/mail-app/mail/model/SpamClassificationHandler" import { SpamClassificationHandler } from "../../../src/mail-app/mail/model/SpamClassificationHandler"
import { FolderSystem } from "../../../src/common/api/common/mail/FolderSystem" import { FolderSystem } from "../../../src/common/api/common/mail/FolderSystem"
import { isSameId } from "../../../src/common/api/common/utils/EntityUtils" import { isSameId } from "../../../src/common/api/common/utils/EntityUtils"
import { any } from "@tensorflow/tfjs-core" import { UnencryptedProcessInboxDatum } from "../../../src/mail-app/mail/model/ProcessInboxHandler"
import { createSpamMailDatum, SpamMailProcessor } from "../../../src/common/api/common/utils/spamClassificationUtils/SpamMailProcessor"
const { anything } = matchers const { anything } = matchers
@ -30,6 +30,7 @@ o.spec("SpamClassificationHandlerTest", function () {
let mail: Mail let mail: Mail
let spamClassifier: SpamClassifier let spamClassifier: SpamClassifier
let spamHandler: SpamClassificationHandler let spamHandler: SpamClassificationHandler
let spamMailProcessor: SpamMailProcessor = new SpamMailProcessor()
let folderSystem: FolderSystem let folderSystem: FolderSystem
let mailDetails: MailDetails let mailDetails: MailDetails
@ -37,7 +38,7 @@ o.spec("SpamClassificationHandlerTest", function () {
const trashFolder = createTestEntity(MailFolderTypeRef, { _id: ["listId", "trash"], folderType: MailSetKind.TRASH }) const trashFolder = createTestEntity(MailFolderTypeRef, { _id: ["listId", "trash"], folderType: MailSetKind.TRASH })
const spamFolder = createTestEntity(MailFolderTypeRef, { _id: ["listId", "spam"], folderType: MailSetKind.SPAM }) const spamFolder = createTestEntity(MailFolderTypeRef, { _id: ["listId", "spam"], folderType: MailSetKind.SPAM })
o.beforeEach(function () { o.beforeEach(async function () {
spamClassifier = object<SpamClassifier>() spamClassifier = object<SpamClassifier>()
body = createTestEntity(BodyTypeRef, { text: "Body Text" }) body = createTestEntity(BodyTypeRef, { text: "Body Text" })
@ -54,7 +55,7 @@ o.spec("SpamClassificationHandlerTest", function () {
}) })
folderSystem = object<FolderSystem>() folderSystem = object<FolderSystem>()
when(mailFacade.moveMails(anything(), anything(), anything(), ClientClassifierType.CLIENT_CLASSIFICATION)).thenResolve([]) when(mailFacade.moveMails(anything(), anything(), anything())).thenResolve([])
when(folderSystem.getSystemFolderByType(MailSetKind.SPAM)).thenReturn(spamFolder) when(folderSystem.getSystemFolderByType(MailSetKind.SPAM)).thenReturn(spamFolder)
when(folderSystem.getSystemFolderByType(MailSetKind.INBOX)).thenReturn(inboxFolder) when(folderSystem.getSystemFolderByType(MailSetKind.INBOX)).thenReturn(inboxFolder)
when(folderSystem.getSystemFolderByType(MailSetKind.TRASH)).thenReturn(trashFolder) when(folderSystem.getSystemFolderByType(MailSetKind.TRASH)).thenReturn(trashFolder)
@ -75,87 +76,77 @@ o.spec("SpamClassificationHandlerTest", function () {
), ),
anything(), anything(),
).thenDo(async () => [{ mail, mailDetails }]) ).thenDo(async () => [{ mail, mailDetails }])
spamHandler = new SpamClassificationHandler(mailFacade, spamClassifier) when(spamClassifier.vectorizeAndCompress(createSpamMailDatum(mail, mailDetails))).thenResolve(
await spamMailProcessor.vectorizeAndCompress(createSpamMailDatum(mail, mailDetails)),
)
spamHandler = new SpamClassificationHandler(spamClassifier)
}) })
o("predictSpamForNewMail does move mail from inbox to spam folder if mail is spam", async function () { o("predictSpamForNewMail does move mail from inbox to spam folder if mail is spam", async function () {
mail.sets = [inboxFolder._id] mail.sets = [inboxFolder._id]
when(spamClassifier.predict(anything())).thenResolve(true) when(spamClassifier.predict(anything(), anything())).thenResolve(true)
const finalResult = await spamHandler.predictSpamForNewMail(mail, mailDetails, inboxFolder, folderSystem) const finalResult = await spamHandler.predictSpamForNewMail(mail, mailDetails, inboxFolder, folderSystem)
o(spamHandler.hamMoveMailData).deepEquals(null)
o(spamHandler.spamMoveMailData?.mails).deepEquals([mail._id]) const expectedProcessInboxDatum: UnencryptedProcessInboxDatum = {
o(spamHandler.classifierResultServiceMailIds).deepEquals([]) mailId: mail._id,
o(finalResult).deepEquals(spamFolder) targetMoveFolder: spamFolder._id,
classifierType: ClientClassifierType.CLIENT_CLASSIFICATION,
vector: await spamMailProcessor.vectorizeAndCompress(createSpamMailDatum(mail, mailDetails)),
}
o(finalResult.targetFolder).deepEquals(spamFolder)
o(finalResult.processInboxDatum).deepEquals(expectedProcessInboxDatum)
}) })
o("predictSpamForNewMail does NOT move mail from inbox to spam folder if mail is ham", async function () { o("predictSpamForNewMail does NOT move mail from inbox to spam folder if mail is ham", async function () {
mail.sets = [inboxFolder._id] mail.sets = [inboxFolder._id]
when(spamClassifier.predict(anything())).thenResolve(false) when(spamClassifier.predict(anything(), anything())).thenResolve(false)
const finalResult = await spamHandler.predictSpamForNewMail(mail, mailDetails, inboxFolder, folderSystem) const finalResult = await spamHandler.predictSpamForNewMail(mail, mailDetails, inboxFolder, folderSystem)
o(spamHandler.hamMoveMailData).deepEquals(null)
o(spamHandler.spamMoveMailData).deepEquals(null) const expectedProcessInboxDatum: UnencryptedProcessInboxDatum = {
o(spamHandler.classifierResultServiceMailIds).deepEquals([mail._id]) mailId: mail._id,
o(finalResult).deepEquals(inboxFolder) targetMoveFolder: inboxFolder._id,
classifierType: null,
vector: await spamMailProcessor.vectorizeAndCompress(createSpamMailDatum(mail, mailDetails)),
}
o(finalResult.targetFolder).deepEquals(inboxFolder)
o(finalResult.processInboxDatum).deepEquals(expectedProcessInboxDatum)
}) })
o("predictSpamForNewMail does NOT move mail from spam to inbox folder if mail is spam", async function () { o("predictSpamForNewMail does NOT move mail from spam to inbox folder if mail is spam", async function () {
mail.sets = [spamFolder._id] mail.sets = [spamFolder._id]
when(spamClassifier.predict(anything())).thenResolve(true) when(spamClassifier.predict(anything(), anything())).thenResolve(true)
const finalResult = await spamHandler.predictSpamForNewMail(mail, mailDetails, spamFolder, folderSystem) const finalResult = await spamHandler.predictSpamForNewMail(mail, mailDetails, spamFolder, folderSystem)
o(spamHandler.hamMoveMailData).deepEquals(null)
o(spamHandler.spamMoveMailData).deepEquals(null) const expectedProcessInboxDatum: UnencryptedProcessInboxDatum = {
o(spamHandler.classifierResultServiceMailIds).deepEquals([mail._id]) mailId: mail._id,
o(finalResult).deepEquals(spamFolder) targetMoveFolder: spamFolder._id,
classifierType: null,
vector: await spamMailProcessor.vectorizeAndCompress(createSpamMailDatum(mail, mailDetails)),
}
o(finalResult.targetFolder).deepEquals(spamFolder)
o(finalResult.processInboxDatum).deepEquals(expectedProcessInboxDatum)
}) })
o("predictSpamForNewMail moves mail from spam to inbox folder if mail is ham", async function () { o("predictSpamForNewMail moves mail from spam to inbox folder if mail is ham", async function () {
mail.sets = [spamFolder._id] mail.sets = [spamFolder._id]
when(spamClassifier.predict(anything())).thenResolve(false) when(spamClassifier.predict(anything(), anything())).thenResolve(false)
const finalResult = await spamHandler.predictSpamForNewMail(mail, mailDetails, spamFolder, folderSystem) const finalResult = await spamHandler.predictSpamForNewMail(mail, mailDetails, spamFolder, folderSystem)
o(spamHandler.hamMoveMailData?.mails).deepEquals([mail._id])
o(spamHandler.spamMoveMailData).deepEquals(null)
o(spamHandler.classifierResultServiceMailIds).deepEquals([])
o(finalResult).deepEquals(inboxFolder)
})
o("predictSpamForNewMail does NOT move mail from spam to spam folder if mail is spam", async function () { const expectedProcessInboxDatum: UnencryptedProcessInboxDatum = {
mail.sets = [spamFolder._id] mailId: mail._id,
when(spamClassifier.predict(anything())).thenResolve(true) targetMoveFolder: inboxFolder._id,
classifierType: ClientClassifierType.CLIENT_CLASSIFICATION,
vector: await spamMailProcessor.vectorizeAndCompress(createSpamMailDatum(mail, mailDetails)),
}
const finalResult = await spamHandler.predictSpamForNewMail(mail, mailDetails, spamFolder, folderSystem) o(finalResult.targetFolder).deepEquals(inboxFolder)
o(spamHandler.hamMoveMailData).deepEquals(null) o(finalResult.processInboxDatum).deepEquals(expectedProcessInboxDatum)
o(spamHandler.spamMoveMailData).deepEquals(null)
o(spamHandler.classifierResultServiceMailIds).deepEquals([mail._id])
o(finalResult).deepEquals(spamFolder)
})
o(
"predictSpamForNewMail does NOT send classifierResultService request if processingState is INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE",
async function () {
mail.sets = [inboxFolder._id]
mail.processingState = ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE
when(spamClassifier.predict(anything())).thenResolve(false)
const finalResult = await spamHandler.predictSpamForNewMail(mail, mailDetails, inboxFolder, folderSystem)
o(spamHandler.hamMoveMailData).deepEquals(null)
o(spamHandler.spamMoveMailData).deepEquals(null)
o(spamHandler.classifierResultServiceMailIds).deepEquals([])
o(finalResult).deepEquals(inboxFolder)
},
)
o("update spam classification data on every mail update", async function () {
when(spamClassifier.getSpamClassification(anything())).thenResolve({ isSpam: false, isSpamConfidence: 0 })
mail.clientSpamClassifierResult = createTestEntity(ClientSpamClassifierResultTypeRef, {
spamDecision: SpamDecision.BLACKLIST,
confidence: "1",
})
await spamHandler.updateSpamClassificationData(mail)
verify(spamClassifier.updateSpamClassification(["listId", "elementId"], true, 1), { times: 1 })
}) })
}) })

View file

@ -41,6 +41,8 @@ import { ConversationListModel } from "../../../../src/mail-app/mail/model/Conve
import { theme } from "../../../../src/common/gui/theme.js" import { theme } from "../../../../src/common/gui/theme.js"
import { ListLoadingState } from "../../../../src/common/gui/base/List" import { ListLoadingState } from "../../../../src/common/gui/base/List"
import { getMailFilterForType, MailFilterType } from "../../../../src/mail-app/mail/view/MailViewerUtils" import { getMailFilterForType, MailFilterType } from "../../../../src/mail-app/mail/view/MailViewerUtils"
import { ProcessInboxHandler } from "../../../../src/mail-app/mail/model/ProcessInboxHandler"
import { FolderSystem } from "../../../../src/common/api/common/mail/FolderSystem"
o.spec("ConversationListModel", () => { o.spec("ConversationListModel", () => {
let model: ConversationListModel let model: ConversationListModel
@ -80,7 +82,7 @@ o.spec("ConversationListModel", () => {
let conversationPrefProvider: ConversationPrefProvider let conversationPrefProvider: ConversationPrefProvider
let entityClient: EntityClient let entityClient: EntityClient
let mailModel: MailModel let mailModel: MailModel
let inboxRuleHandler: InboxRuleHandler let processInboxHandler: ProcessInboxHandler
let cacheStorage: ExposedCacheStorage let cacheStorage: ExposedCacheStorage
o.beforeEach(() => { o.beforeEach(() => {
@ -95,10 +97,12 @@ o.spec("ConversationListModel", () => {
conversationPrefProvider = object() conversationPrefProvider = object()
entityClient = object() entityClient = object()
mailModel = object() mailModel = object()
inboxRuleHandler = object() processInboxHandler = object()
cacheStorage = object() cacheStorage = object()
model = new ConversationListModel(mailSet, conversationPrefProvider, entityClient, mailModel, inboxRuleHandler, cacheStorage) model = new ConversationListModel(mailSet, conversationPrefProvider, entityClient, mailModel, processInboxHandler, cacheStorage)
when(mailModel.getMailboxDetailsForMailFolder(mailSet)).thenResolve(mailboxDetail) when(mailModel.getMailboxDetailsForMailFolder(mailSet)).thenResolve(mailboxDetail)
const folderSystem: FolderSystem = object()
when(mailModel.getFolderSystemByGroupId(matchers.anything())).thenReturn(folderSystem)
}) })
// Care has to be ensured for generating mail set entry IDs as we depend on real mail set ID decoding, thus we have // Care has to be ensured for generating mail set entry IDs as we depend on real mail set ID decoding, thus we have
@ -209,7 +213,7 @@ o.spec("ConversationListModel", () => {
verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), { verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), {
times: 0, times: 0,
}) })
verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true), { verify(processInboxHandler.handleIncomingMail(matchers.anything(), matchers.anything(), mailboxDetail, matchers.anything()), {
times: 0, times: 0,
}) })
}) })
@ -227,7 +231,7 @@ o.spec("ConversationListModel", () => {
verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), { verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), {
times: 0, times: 0,
}) })
verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true), { verify(processInboxHandler.handleIncomingMail(matchers.anything(), matchers.anything(), mailboxDetail, matchers.anything()), {
times: 0, times: 0,
}) })
}) })
@ -246,7 +250,7 @@ o.spec("ConversationListModel", () => {
verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), { verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), {
times: 0, times: 0,
}) })
verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true), { verify(processInboxHandler.handleIncomingMail(matchers.anything(), matchers.anything(), mailboxDetail, matchers.anything()), {
times: 0, times: 0,
}) })
o.check(model.loadingStatus).equals(ListLoadingState.Idle) o.check(model.loadingStatus).equals(ListLoadingState.Idle)
@ -262,15 +266,27 @@ o.spec("ConversationListModel", () => {
// make one item have a rule // make one item have a rule
when( when(
inboxRuleHandler.findAndApplyMatchingRule( processInboxHandler.handleIncomingMail(
mailboxDetail,
matchers.argThat((mail: Mail) => isSameId(mail._id, makeMailId(25))), matchers.argThat((mail: Mail) => isSameId(mail._id, makeMailId(25))),
true, matchers.anything(),
matchers.anything(),
matchers.anything(),
), ),
).thenResolve({}) ).thenResolve({ folderType: MailSetKind.SPAM })
when(
processInboxHandler.handleIncomingMail(
matchers.argThat((mail: Mail) => !isSameId(mail._id, makeMailId(25))),
matchers.anything(),
matchers.anything(),
matchers.anything(),
),
).thenResolve({ folderType: MailSetKind.INBOX })
await setUpTestData(PageSize, labels, false, 1) await setUpTestData(PageSize, labels, false, 1)
await model.loadInitial() await model.loadInitial()
o.check(model.mails.length).equals(PageSize - 1) o.check(model.mails.length).equals(PageSize - 1)
for (const mail of model.mails) { for (const mail of model.mails) {
o.check(model.getLabelsForMail(mail)).deepEquals(labels) o.check(model.getLabelsForMail(mail)).deepEquals(labels)
@ -281,7 +297,7 @@ o.spec("ConversationListModel", () => {
verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), { verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), {
times: 1, times: 1,
}) })
verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true), { verify(processInboxHandler.handleIncomingMail(matchers.anything(), matchers.anything(), mailboxDetail, matchers.anything()), {
times: 100, times: 100,
}) })
}) })

View file

@ -15,7 +15,6 @@ import { matchers, object, verify, when } from "testdouble"
import { ConversationPrefProvider } from "../../../../src/mail-app/mail/view/ConversationViewModel" import { ConversationPrefProvider } from "../../../../src/mail-app/mail/view/ConversationViewModel"
import { EntityClient } from "../../../../src/common/api/common/EntityClient" import { EntityClient } from "../../../../src/common/api/common/EntityClient"
import { MailModel } from "../../../../src/mail-app/mail/model/MailModel" import { MailModel } from "../../../../src/mail-app/mail/model/MailModel"
import { InboxRuleHandler } from "../../../../src/mail-app/mail/model/InboxRuleHandler"
import { ExposedCacheStorage } from "../../../../src/common/api/worker/rest/DefaultEntityRestCache" import { ExposedCacheStorage } from "../../../../src/common/api/worker/rest/DefaultEntityRestCache"
import { MailSetKind, OperationType } from "../../../../src/common/api/common/TutanotaConstants" import { MailSetKind, OperationType } from "../../../../src/common/api/common/TutanotaConstants"
import { import {
@ -39,6 +38,8 @@ import { clamp, pad } from "@tutao/tutanota-utils"
import { LoadedMail } from "../../../../src/mail-app/mail/model/MailSetListModel" import { LoadedMail } from "../../../../src/mail-app/mail/model/MailSetListModel"
import { getMailFilterForType, MailFilterType } from "../../../../src/mail-app/mail/view/MailViewerUtils" import { getMailFilterForType, MailFilterType } from "../../../../src/mail-app/mail/view/MailViewerUtils"
import { theme } from "../../../../src/common/gui/theme.js" import { theme } from "../../../../src/common/gui/theme.js"
import { ProcessInboxHandler } from "../../../../src/mail-app/mail/model/ProcessInboxHandler"
import { FolderSystem } from "../../../../src/common/api/common/mail/FolderSystem"
o.spec("MailListModel", () => { o.spec("MailListModel", () => {
let model: MailListModel let model: MailListModel
@ -78,7 +79,7 @@ o.spec("MailListModel", () => {
let conversationPrefProvider: ConversationPrefProvider let conversationPrefProvider: ConversationPrefProvider
let entityClient: EntityClient let entityClient: EntityClient
let mailModel: MailModel let mailModel: MailModel
let inboxRuleHandler: InboxRuleHandler let processInboxHandler: ProcessInboxHandler
let cacheStorage: ExposedCacheStorage let cacheStorage: ExposedCacheStorage
o.beforeEach(() => { o.beforeEach(() => {
@ -93,10 +94,12 @@ o.spec("MailListModel", () => {
conversationPrefProvider = object() conversationPrefProvider = object()
entityClient = object() entityClient = object()
mailModel = object() mailModel = object()
inboxRuleHandler = object() processInboxHandler = object()
cacheStorage = object() cacheStorage = object()
model = new MailListModel(mailSet, conversationPrefProvider, entityClient, mailModel, inboxRuleHandler, cacheStorage) model = new MailListModel(mailSet, conversationPrefProvider, entityClient, mailModel, processInboxHandler, cacheStorage)
when(mailModel.getMailboxDetailsForMailFolder(mailSet)).thenResolve(mailboxDetail) when(mailModel.getMailboxDetailsForMailFolder(mailSet)).thenResolve(mailboxDetail)
const folderSystem: FolderSystem = object()
when(mailModel.getFolderSystemByGroupId(matchers.anything())).thenReturn(folderSystem)
}) })
// Care has to be ensured for generating mail set entry IDs as we depend on real mail set ID decoding, thus we have // Care has to be ensured for generating mail set entry IDs as we depend on real mail set ID decoding, thus we have
@ -204,7 +207,8 @@ o.spec("MailListModel", () => {
verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), { verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), {
times: 0, times: 0,
}) })
verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true), {
verify(processInboxHandler.handleIncomingMail(matchers.anything(), matchers.anything(), matchers.anything(), matchers.anything()), {
times: 0, times: 0,
}) })
}) })
@ -222,7 +226,7 @@ o.spec("MailListModel", () => {
verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), { verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), {
times: 0, times: 0,
}) })
verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true), { verify(processInboxHandler.handleIncomingMail(matchers.anything(), matchers.anything(), matchers.anything(), matchers.anything()), {
times: 0, times: 0,
}) })
}) })
@ -230,14 +234,23 @@ o.spec("MailListModel", () => {
o.test("applies inbox rules if inbox", async () => { o.test("applies inbox rules if inbox", async () => {
mailSet.folderType = MailSetKind.INBOX mailSet.folderType = MailSetKind.INBOX
// make one item have a rule
when( when(
inboxRuleHandler.findAndApplyMatchingRule( processInboxHandler.handleIncomingMail(
mailboxDetail,
matchers.argThat((mail: Mail) => isSameId(mail._id, makeMailId(25))), matchers.argThat((mail: Mail) => isSameId(mail._id, makeMailId(25))),
true, matchers.anything(),
matchers.anything(),
matchers.anything(),
), ),
).thenResolve({}) ).thenResolve({ folderType: MailSetKind.SPAM })
when(
processInboxHandler.handleIncomingMail(
matchers.argThat((mail: Mail) => !isSameId(mail._id, makeMailId(25))),
matchers.anything(),
matchers.anything(),
matchers.anything(),
),
).thenResolve({ folderType: MailSetKind.INBOX })
await setUpTestData(PageSize, labels, false) await setUpTestData(PageSize, labels, false)
await model.loadInitial() await model.loadInitial()
@ -251,7 +264,8 @@ o.spec("MailListModel", () => {
verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), { verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), {
times: 1, times: 1,
}) })
verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true), {
verify(processInboxHandler.handleIncomingMail(matchers.anything(), matchers.anything(), matchers.anything(), matchers.anything()), {
times: 100, times: 100,
}) })
}) })

View file

@ -1821,6 +1821,9 @@ mod tests {
"1729"=> JsonElement::Array( "1729"=> JsonElement::Array(
vec![], vec![],
), ),
"1769"=> JsonElement::String(
"0".to_string()
)
} }
} }

View file

@ -401,6 +401,8 @@ pub struct Mail {
pub keyVerificationState: Option<i64>, pub keyVerificationState: Option<i64>,
#[serde(rename = "1728")] #[serde(rename = "1728")]
pub processingState: i64, pub processingState: i64,
#[serde(rename = "1769")]
pub processNeeded: bool,
#[serde(rename = "111")] #[serde(rename = "111")]
pub sender: MailAddress, pub sender: MailAddress,
#[serde(rename = "115")] #[serde(rename = "115")]
@ -473,6 +475,10 @@ pub struct MailBox {
pub mailImportStates: GeneratedId, pub mailImportStates: GeneratedId,
#[serde(rename = "1710")] #[serde(rename = "1710")]
pub extractedFeatures: Option<GeneratedId>, pub extractedFeatures: Option<GeneratedId>,
#[serde(rename = "1754")]
pub clientSpamTrainingData: Option<GeneratedId>,
#[serde(rename = "1755")]
pub modifiedClientSpamTrainingDataIndex: Option<GeneratedId>,
#[serde(default)] #[serde(default)]
pub _errors: Errors, pub _errors: Errors,
@ -4146,3 +4152,166 @@ impl Entity for ClientClassifierResultPostIn {
} }
} }
} }
#[derive(uniffi::Record, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "testing"), derive(PartialEq, Debug))]
pub struct ClientSpamTrainingDatum {
#[serde(rename = "1738")]
pub _id: Option<IdTupleGenerated>,
#[serde(rename = "1739")]
pub _permissions: GeneratedId,
#[serde(rename = "1740")]
pub _format: i64,
#[serde(rename = "1741")]
pub _ownerGroup: Option<GeneratedId>,
#[serde(rename = "1742")]
#[serde(with = "serde_bytes")]
pub _ownerEncSessionKey: Option<Vec<u8>>,
#[serde(rename = "1743")]
pub _ownerKeyVersion: Option<i64>,
#[serde(rename = "1744")]
pub confidence: i64,
#[serde(rename = "1745")]
pub spamDecision: i64,
#[serde(rename = "1746")]
#[serde(with = "serde_bytes")]
pub vector: Vec<u8>,
#[serde(default)]
pub _errors: Errors,
#[serde(default)]
pub _finalIvs: HashMap<String, Option<FinalIv>>,
}
impl Entity for ClientSpamTrainingDatum {
fn type_ref() -> TypeRef {
TypeRef {
app: AppName::Tutanota,
type_id: TypeId::from(1736),
}
}
}
#[derive(uniffi::Record, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "testing"), derive(PartialEq, Debug))]
pub struct ClientSpamTrainingDatumIndexEntry {
#[serde(rename = "1749")]
pub _id: Option<IdTupleGenerated>,
#[serde(rename = "1750")]
pub _permissions: GeneratedId,
#[serde(rename = "1751")]
pub _format: i64,
#[serde(rename = "1752")]
pub _ownerGroup: Option<GeneratedId>,
#[serde(rename = "1753")]
pub clientSpamTrainingDatumElementId: GeneratedId,
}
impl Entity for ClientSpamTrainingDatumIndexEntry {
fn type_ref() -> TypeRef {
TypeRef {
app: AppName::Tutanota,
type_id: TypeId::from(1747),
}
}
}
#[derive(uniffi::Record, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "testing"), derive(PartialEq, Debug))]
pub struct ProcessInboxDatum {
#[serde(rename = "1757")]
pub _id: Option<CustomId>,
#[serde(rename = "1758")]
#[serde(with = "serde_bytes")]
pub ownerEncVectorSessionKey: Vec<u8>,
#[serde(rename = "1759")]
pub ownerKeyVersion: i64,
#[serde(rename = "1762")]
pub classifierType: Option<i64>,
#[serde(rename = "1763")]
#[serde(with = "serde_bytes")]
pub encVector: Vec<u8>,
#[serde(rename = "1760")]
pub mailId: IdTupleGenerated,
#[serde(rename = "1761")]
pub targetMoveFolder: IdTupleGenerated,
}
impl Entity for ProcessInboxDatum {
fn type_ref() -> TypeRef {
TypeRef {
app: AppName::Tutanota,
type_id: TypeId::from(1756),
}
}
}
#[derive(uniffi::Record, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "testing"), derive(PartialEq, Debug))]
pub struct ProcessInboxPostIn {
#[serde(rename = "1765")]
pub _format: i64,
#[serde(rename = "1766")]
pub mailOwnerGroup: GeneratedId,
#[serde(rename = "1767")]
pub processInboxDatum: Vec<ProcessInboxDatum>,
}
impl Entity for ProcessInboxPostIn {
fn type_ref() -> TypeRef {
TypeRef {
app: AppName::Tutanota,
type_id: TypeId::from(1764),
}
}
}
#[derive(uniffi::Record, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "testing"), derive(PartialEq, Debug))]
pub struct PopulateClientSpamTrainingDatum {
#[serde(rename = "1771")]
pub _id: Option<CustomId>,
#[serde(rename = "1772")]
#[serde(with = "serde_bytes")]
pub ownerEncVectorSessionKey: Vec<u8>,
#[serde(rename = "1773")]
pub ownerKeyVersion: i64,
#[serde(rename = "1775")]
pub isSpam: bool,
#[serde(rename = "1776")]
pub confidence: i64,
#[serde(rename = "1777")]
#[serde(with = "serde_bytes")]
pub encVector: Vec<u8>,
#[serde(rename = "1774")]
pub mailId: IdTupleGenerated,
}
impl Entity for PopulateClientSpamTrainingDatum {
fn type_ref() -> TypeRef {
TypeRef {
app: AppName::Tutanota,
type_id: TypeId::from(1770),
}
}
}
#[derive(uniffi::Record, Clone, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "testing"), derive(PartialEq, Debug))]
pub struct PopulateClientSpamTrainingDataPostIn {
#[serde(rename = "1779")]
pub _format: i64,
#[serde(rename = "1780")]
pub mailOwnerGroup: GeneratedId,
#[serde(rename = "1781")]
pub populateClientSpamTrainingDatum: Vec<PopulateClientSpamTrainingDatum>,
}
impl Entity for PopulateClientSpamTrainingDataPostIn {
fn type_ref() -> TypeRef {
TypeRef {
app: AppName::Tutanota,
type_id: TypeId::from(1778),
}
}
}

View file

@ -44,6 +44,8 @@ use crate::entities::generated::tutanota::MoveMailData;
use crate::entities::generated::tutanota::MoveMailPostOut; use crate::entities::generated::tutanota::MoveMailPostOut;
use crate::entities::generated::tutanota::NewsIn; use crate::entities::generated::tutanota::NewsIn;
use crate::entities::generated::tutanota::NewsOut; use crate::entities::generated::tutanota::NewsOut;
use crate::entities::generated::tutanota::PopulateClientSpamTrainingDataPostIn;
use crate::entities::generated::tutanota::ProcessInboxPostIn;
use crate::entities::generated::tutanota::ReceiveInfoServiceData; use crate::entities::generated::tutanota::ReceiveInfoServiceData;
use crate::entities::generated::tutanota::ReceiveInfoServicePostOut; use crate::entities::generated::tutanota::ReceiveInfoServicePostOut;
use crate::entities::generated::tutanota::ReportMailPostData; use crate::entities::generated::tutanota::ReportMailPostData;
@ -59,70 +61,70 @@ use crate::entities::generated::tutanota::UserAccountCreateData;
use crate::entities::generated::tutanota::UserAccountPostOut; use crate::entities::generated::tutanota::UserAccountPostOut;
pub struct ApplyLabelService; pub struct ApplyLabelService;
crate::service_impl!(declare, ApplyLabelService, "tutanota/applylabelservice", 97); crate::service_impl!(declare, ApplyLabelService, "tutanota/applylabelservice", 98);
crate::service_impl!(POST, ApplyLabelService, ApplyLabelServicePostIn, ()); crate::service_impl!(POST, ApplyLabelService, ApplyLabelServicePostIn, ());
pub struct CalendarService; pub struct CalendarService;
crate::service_impl!(declare, CalendarService, "tutanota/calendarservice", 97); crate::service_impl!(declare, CalendarService, "tutanota/calendarservice", 98);
crate::service_impl!(POST, CalendarService, UserAreaGroupPostData, CreateGroupPostReturn); crate::service_impl!(POST, CalendarService, UserAreaGroupPostData, CreateGroupPostReturn);
crate::service_impl!(DELETE, CalendarService, CalendarDeleteData, ()); crate::service_impl!(DELETE, CalendarService, CalendarDeleteData, ());
pub struct ChangePrimaryAddressService; pub struct ChangePrimaryAddressService;
crate::service_impl!(declare, ChangePrimaryAddressService, "tutanota/changeprimaryaddressservice", 97); crate::service_impl!(declare, ChangePrimaryAddressService, "tutanota/changeprimaryaddressservice", 98);
crate::service_impl!(PUT, ChangePrimaryAddressService, ChangePrimaryAddressServicePutIn, ()); crate::service_impl!(PUT, ChangePrimaryAddressService, ChangePrimaryAddressServicePutIn, ());
pub struct ClientClassifierResultService; pub struct ClientClassifierResultService;
crate::service_impl!(declare, ClientClassifierResultService, "tutanota/clientclassifierresultservice", 97); crate::service_impl!(declare, ClientClassifierResultService, "tutanota/clientclassifierresultservice", 98);
crate::service_impl!(POST, ClientClassifierResultService, ClientClassifierResultPostIn, ()); crate::service_impl!(POST, ClientClassifierResultService, ClientClassifierResultPostIn, ());
pub struct ContactListGroupService; pub struct ContactListGroupService;
crate::service_impl!(declare, ContactListGroupService, "tutanota/contactlistgroupservice", 97); crate::service_impl!(declare, ContactListGroupService, "tutanota/contactlistgroupservice", 98);
crate::service_impl!(POST, ContactListGroupService, UserAreaGroupPostData, CreateGroupPostReturn); crate::service_impl!(POST, ContactListGroupService, UserAreaGroupPostData, CreateGroupPostReturn);
crate::service_impl!(DELETE, ContactListGroupService, UserAreaGroupDeleteData, ()); crate::service_impl!(DELETE, ContactListGroupService, UserAreaGroupDeleteData, ());
pub struct CustomerAccountService; pub struct CustomerAccountService;
crate::service_impl!(declare, CustomerAccountService, "tutanota/customeraccountservice", 97); crate::service_impl!(declare, CustomerAccountService, "tutanota/customeraccountservice", 98);
crate::service_impl!(POST, CustomerAccountService, CustomerAccountCreateData, ()); crate::service_impl!(POST, CustomerAccountService, CustomerAccountCreateData, ());
pub struct DraftService; pub struct DraftService;
crate::service_impl!(declare, DraftService, "tutanota/draftservice", 97); crate::service_impl!(declare, DraftService, "tutanota/draftservice", 98);
crate::service_impl!(POST, DraftService, DraftCreateData, DraftCreateReturn); crate::service_impl!(POST, DraftService, DraftCreateData, DraftCreateReturn);
crate::service_impl!(PUT, DraftService, DraftUpdateData, DraftUpdateReturn); crate::service_impl!(PUT, DraftService, DraftUpdateData, DraftUpdateReturn);
pub struct EncryptTutanotaPropertiesService; pub struct EncryptTutanotaPropertiesService;
crate::service_impl!(declare, EncryptTutanotaPropertiesService, "tutanota/encrypttutanotapropertiesservice", 97); crate::service_impl!(declare, EncryptTutanotaPropertiesService, "tutanota/encrypttutanotapropertiesservice", 98);
crate::service_impl!(POST, EncryptTutanotaPropertiesService, EncryptTutanotaPropertiesData, ()); crate::service_impl!(POST, EncryptTutanotaPropertiesService, EncryptTutanotaPropertiesData, ());
pub struct EntropyService; pub struct EntropyService;
crate::service_impl!(declare, EntropyService, "tutanota/entropyservice", 97); crate::service_impl!(declare, EntropyService, "tutanota/entropyservice", 98);
crate::service_impl!(PUT, EntropyService, EntropyData, ()); crate::service_impl!(PUT, EntropyService, EntropyData, ());
pub struct ExternalUserService; pub struct ExternalUserService;
crate::service_impl!(declare, ExternalUserService, "tutanota/externaluserservice", 97); crate::service_impl!(declare, ExternalUserService, "tutanota/externaluserservice", 98);
crate::service_impl!(POST, ExternalUserService, ExternalUserData, ()); crate::service_impl!(POST, ExternalUserService, ExternalUserData, ());
pub struct GroupInvitationService; pub struct GroupInvitationService;
crate::service_impl!(declare, GroupInvitationService, "tutanota/groupinvitationservice", 97); crate::service_impl!(declare, GroupInvitationService, "tutanota/groupinvitationservice", 98);
crate::service_impl!(POST, GroupInvitationService, GroupInvitationPostData, GroupInvitationPostReturn); crate::service_impl!(POST, GroupInvitationService, GroupInvitationPostData, GroupInvitationPostReturn);
crate::service_impl!(PUT, GroupInvitationService, GroupInvitationPutData, ()); crate::service_impl!(PUT, GroupInvitationService, GroupInvitationPutData, ());
crate::service_impl!(DELETE, GroupInvitationService, GroupInvitationDeleteData, ()); crate::service_impl!(DELETE, GroupInvitationService, GroupInvitationDeleteData, ());
@ -130,26 +132,26 @@ crate::service_impl!(DELETE, GroupInvitationService, GroupInvitationDeleteData,
pub struct ImportMailService; pub struct ImportMailService;
crate::service_impl!(declare, ImportMailService, "tutanota/importmailservice", 97); crate::service_impl!(declare, ImportMailService, "tutanota/importmailservice", 98);
crate::service_impl!(POST, ImportMailService, ImportMailPostIn, ImportMailPostOut); crate::service_impl!(POST, ImportMailService, ImportMailPostIn, ImportMailPostOut);
crate::service_impl!(GET, ImportMailService, ImportMailGetIn, ImportMailGetOut); crate::service_impl!(GET, ImportMailService, ImportMailGetIn, ImportMailGetOut);
pub struct ListUnsubscribeService; pub struct ListUnsubscribeService;
crate::service_impl!(declare, ListUnsubscribeService, "tutanota/listunsubscribeservice", 97); crate::service_impl!(declare, ListUnsubscribeService, "tutanota/listunsubscribeservice", 98);
crate::service_impl!(POST, ListUnsubscribeService, ListUnsubscribeData, ()); crate::service_impl!(POST, ListUnsubscribeService, ListUnsubscribeData, ());
pub struct MailExportTokenService; pub struct MailExportTokenService;
crate::service_impl!(declare, MailExportTokenService, "tutanota/mailexporttokenservice", 97); crate::service_impl!(declare, MailExportTokenService, "tutanota/mailexporttokenservice", 98);
crate::service_impl!(POST, MailExportTokenService, (), MailExportTokenServicePostOut); crate::service_impl!(POST, MailExportTokenService, (), MailExportTokenServicePostOut);
pub struct MailFolderService; pub struct MailFolderService;
crate::service_impl!(declare, MailFolderService, "tutanota/mailfolderservice", 97); crate::service_impl!(declare, MailFolderService, "tutanota/mailfolderservice", 98);
crate::service_impl!(POST, MailFolderService, CreateMailFolderData, CreateMailFolderReturn); crate::service_impl!(POST, MailFolderService, CreateMailFolderData, CreateMailFolderReturn);
crate::service_impl!(PUT, MailFolderService, UpdateMailFolderData, ()); crate::service_impl!(PUT, MailFolderService, UpdateMailFolderData, ());
crate::service_impl!(DELETE, MailFolderService, DeleteMailFolderData, ()); crate::service_impl!(DELETE, MailFolderService, DeleteMailFolderData, ());
@ -157,87 +159,99 @@ crate::service_impl!(DELETE, MailFolderService, DeleteMailFolderData, ());
pub struct MailGroupService; pub struct MailGroupService;
crate::service_impl!(declare, MailGroupService, "tutanota/mailgroupservice", 97); crate::service_impl!(declare, MailGroupService, "tutanota/mailgroupservice", 98);
crate::service_impl!(POST, MailGroupService, CreateMailGroupData, MailGroupPostOut); crate::service_impl!(POST, MailGroupService, CreateMailGroupData, MailGroupPostOut);
crate::service_impl!(DELETE, MailGroupService, DeleteGroupData, ()); crate::service_impl!(DELETE, MailGroupService, DeleteGroupData, ());
pub struct MailService; pub struct MailService;
crate::service_impl!(declare, MailService, "tutanota/mailservice", 97); crate::service_impl!(declare, MailService, "tutanota/mailservice", 98);
crate::service_impl!(DELETE, MailService, DeleteMailData, ()); crate::service_impl!(DELETE, MailService, DeleteMailData, ());
pub struct ManageLabelService; pub struct ManageLabelService;
crate::service_impl!(declare, ManageLabelService, "tutanota/managelabelservice", 97); crate::service_impl!(declare, ManageLabelService, "tutanota/managelabelservice", 98);
crate::service_impl!(POST, ManageLabelService, ManageLabelServicePostIn, ()); crate::service_impl!(POST, ManageLabelService, ManageLabelServicePostIn, ());
crate::service_impl!(DELETE, ManageLabelService, ManageLabelServiceDeleteIn, ()); crate::service_impl!(DELETE, ManageLabelService, ManageLabelServiceDeleteIn, ());
pub struct MoveMailService; pub struct MoveMailService;
crate::service_impl!(declare, MoveMailService, "tutanota/movemailservice", 97); crate::service_impl!(declare, MoveMailService, "tutanota/movemailservice", 98);
crate::service_impl!(POST, MoveMailService, MoveMailData, MoveMailPostOut); crate::service_impl!(POST, MoveMailService, MoveMailData, MoveMailPostOut);
pub struct NewsService; pub struct NewsService;
crate::service_impl!(declare, NewsService, "tutanota/newsservice", 97); crate::service_impl!(declare, NewsService, "tutanota/newsservice", 98);
crate::service_impl!(POST, NewsService, NewsIn, ()); crate::service_impl!(POST, NewsService, NewsIn, ());
crate::service_impl!(GET, NewsService, (), NewsOut); crate::service_impl!(GET, NewsService, (), NewsOut);
pub struct PopulateClientSpamTrainingDataService;
crate::service_impl!(declare, PopulateClientSpamTrainingDataService, "tutanota/populateclientspamtrainingdataservice", 98);
crate::service_impl!(POST, PopulateClientSpamTrainingDataService, PopulateClientSpamTrainingDataPostIn, ());
pub struct ProcessInboxService;
crate::service_impl!(declare, ProcessInboxService, "tutanota/processinboxservice", 98);
crate::service_impl!(POST, ProcessInboxService, ProcessInboxPostIn, ());
pub struct ReceiveInfoService; pub struct ReceiveInfoService;
crate::service_impl!(declare, ReceiveInfoService, "tutanota/receiveinfoservice", 97); crate::service_impl!(declare, ReceiveInfoService, "tutanota/receiveinfoservice", 98);
crate::service_impl!(POST, ReceiveInfoService, ReceiveInfoServiceData, ReceiveInfoServicePostOut); crate::service_impl!(POST, ReceiveInfoService, ReceiveInfoServiceData, ReceiveInfoServicePostOut);
pub struct ReportMailService; pub struct ReportMailService;
crate::service_impl!(declare, ReportMailService, "tutanota/reportmailservice", 97); crate::service_impl!(declare, ReportMailService, "tutanota/reportmailservice", 98);
crate::service_impl!(POST, ReportMailService, ReportMailPostData, ()); crate::service_impl!(POST, ReportMailService, ReportMailPostData, ());
pub struct ResolveConversationsService; pub struct ResolveConversationsService;
crate::service_impl!(declare, ResolveConversationsService, "tutanota/resolveconversationsservice", 97); crate::service_impl!(declare, ResolveConversationsService, "tutanota/resolveconversationsservice", 98);
crate::service_impl!(GET, ResolveConversationsService, ResolveConversationsServiceGetIn, ResolveConversationsServiceGetOut); crate::service_impl!(GET, ResolveConversationsService, ResolveConversationsServiceGetIn, ResolveConversationsServiceGetOut);
pub struct SendDraftService; pub struct SendDraftService;
crate::service_impl!(declare, SendDraftService, "tutanota/senddraftservice", 97); crate::service_impl!(declare, SendDraftService, "tutanota/senddraftservice", 98);
crate::service_impl!(POST, SendDraftService, SendDraftData, SendDraftReturn); crate::service_impl!(POST, SendDraftService, SendDraftData, SendDraftReturn);
pub struct SimpleMoveMailService; pub struct SimpleMoveMailService;
crate::service_impl!(declare, SimpleMoveMailService, "tutanota/simplemovemailservice", 97); crate::service_impl!(declare, SimpleMoveMailService, "tutanota/simplemovemailservice", 98);
crate::service_impl!(POST, SimpleMoveMailService, SimpleMoveMailPostIn, MoveMailPostOut); crate::service_impl!(POST, SimpleMoveMailService, SimpleMoveMailPostIn, MoveMailPostOut);
pub struct TemplateGroupService; pub struct TemplateGroupService;
crate::service_impl!(declare, TemplateGroupService, "tutanota/templategroupservice", 97); crate::service_impl!(declare, TemplateGroupService, "tutanota/templategroupservice", 98);
crate::service_impl!(POST, TemplateGroupService, UserAreaGroupPostData, CreateGroupPostReturn); crate::service_impl!(POST, TemplateGroupService, UserAreaGroupPostData, CreateGroupPostReturn);
crate::service_impl!(DELETE, TemplateGroupService, UserAreaGroupDeleteData, ()); crate::service_impl!(DELETE, TemplateGroupService, UserAreaGroupDeleteData, ());
pub struct TranslationService; pub struct TranslationService;
crate::service_impl!(declare, TranslationService, "tutanota/translationservice", 97); crate::service_impl!(declare, TranslationService, "tutanota/translationservice", 98);
crate::service_impl!(GET, TranslationService, TranslationGetIn, TranslationGetOut); crate::service_impl!(GET, TranslationService, TranslationGetIn, TranslationGetOut);
pub struct UnreadMailStateService; pub struct UnreadMailStateService;
crate::service_impl!(declare, UnreadMailStateService, "tutanota/unreadmailstateservice", 97); crate::service_impl!(declare, UnreadMailStateService, "tutanota/unreadmailstateservice", 98);
crate::service_impl!(POST, UnreadMailStateService, UnreadMailStatePostIn, ()); crate::service_impl!(POST, UnreadMailStateService, UnreadMailStatePostIn, ());
pub struct UserAccountService; pub struct UserAccountService;
crate::service_impl!(declare, UserAccountService, "tutanota/useraccountservice", 97); crate::service_impl!(declare, UserAccountService, "tutanota/useraccountservice", 98);
crate::service_impl!(POST, UserAccountService, UserAccountCreateData, UserAccountPostOut); crate::service_impl!(POST, UserAccountService, UserAccountCreateData, UserAccountPostOut);

File diff suppressed because it is too large Load diff

View file

@ -43,5 +43,6 @@
"1465": [], "1465": [],
"1677": null, "1677": null,
"1728": "1", "1728": "1",
"1729": [] "1729": [],
"1769": "0"
} }

View file

@ -43,5 +43,6 @@
"896": "1723113273034", "896": "1723113273034",
"1677": null, "1677": null,
"1728": "1", "1728": "1",
"1729": [] "1729": [],
"1769": "0"
} }

View file

@ -43,5 +43,6 @@
"1465": [], "1465": [],
"1677": null, "1677": null,
"1728": "1", "1728": "1",
"1729": [] "1729": [],
"1769": "0"
} }

View file

@ -43,5 +43,6 @@
"466": "", "466": "",
"1677": null, "1677": null,
"1728": "1", "1728": "1",
"1729": [] "1729": [],
"1769": "1"
} }