improve inbox rule handling and run spam prediction after inbox rules

Instead of applying inbox rules based on the unread mail state in the
inbox folder, we introduce the new ProcessingState enum on
the mail type. If a mail has been processed by the leader client, which
is checking for matching inbox rules, the ProcessingState is
updated. If there is a matching rule the flag is updated through the
MoveMailService, if there is no matching rule, the flag is updated
using the ClientClassifierResultService. Both requests are
throttled / debounced. After processing inbox rules, spam prediction
is conducted for mails that have not yet been moved by an inbox rule.
The ProcessingState for not matching ham mails is also updated using
the ClientClassifierResultService.

This new inbox rule handling solves the following two problems:
 - when clicking on a notification it could still happen
   that sometimes the inbox rules were not applied
 - when the inbox folder had a lot of unread mails, the loading time did
   massively increase, since inbox rules were re-applied on every load

Co-authored-by: amm <amm@tutao.de>
Co-authored-by: Nick <nif@tutao.de>
Co-authored-by: das <das@tutao.de>
Co-authored-by: abp <abp@tutao.de>
Co-authored-by: jhm <17314077+jomapp@users.noreply.github.com>
Co-authored-by: map <mpfau@users.noreply.github.com>
Co-authored-by: Kinan <104761667+kibibytium@users.noreply.github.com>
This commit is contained in:
sug 2025-10-14 12:11:22 +02:00 committed by abp
parent 030bea4fe6
commit f11e59672e
No known key found for this signature in database
GPG key ID: 791D4EC38A7AA7C2
53 changed files with 1269 additions and 1010 deletions

View file

@ -35,7 +35,7 @@ export const allowedImports = {
boot: ["polyfill-helpers", "common-min"], boot: ["polyfill-helpers", "common-min"],
common: ["polyfill-helpers", "common-min"], common: ["polyfill-helpers", "common-min"],
"gui-base": ["polyfill-helpers", "common-min", "common", "boot"], "gui-base": ["polyfill-helpers", "common-min", "common", "boot"],
main: ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "date"], main: ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "date", "spam-classifier"],
sanitizer: ["polyfill-helpers", "common-min", "common", "boot", "gui-base"], sanitizer: ["polyfill-helpers", "common-min", "common", "boot", "gui-base"],
date: ["polyfill-helpers", "common-min", "common"], date: ["polyfill-helpers", "common-min", "common"],
"date-gui": ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "main", "sharing", "date", "contacts", "ui-extra"], "date-gui": ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "main", "sharing", "date", "contacts", "ui-extra"],
@ -47,7 +47,7 @@ export const allowedImports = {
"calendar-view": ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "main", "date", "date-gui", "sharing", "contacts"], "calendar-view": ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "main", "date", "date-gui", "sharing", "contacts"],
login: ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "main"], login: ["polyfill-helpers", "common-min", "common", "boot", "gui-base", "main"],
"spam-classifier": ["polyfill-helpers", "common", "common-min"], "spam-classifier": ["polyfill-helpers", "common", "common-min"],
worker: ["polyfill-helpers", "common-min", "common", "native-common", "native-worker", "wasm", "wasm-fallback", "spam-classifier"], worker: ["polyfill-helpers", "common-min", "common", "native-common", "native-worker", "wasm", "wasm-fallback"],
"pow-worker": [], "pow-worker": [],
settings: [ settings: [
"polyfill-helpers", "polyfill-helpers",

View file

@ -172,7 +172,7 @@ async function rollupTensorFlow(src, target, banner) {
}, },
], ],
}), }),
logResolvePlugin, // logResolvePlugin,
nodeResolve(), nodeResolve(),
commonjs(), commonjs(),
], ],

View file

@ -112,6 +112,7 @@ import("../mail-app/translations/en.js")
if (isApp()) { if (isApp()) {
calendarLocator.fileApp.clearFileData().catch((e) => console.log("Failed to clean file data", e)) calendarLocator.fileApp.clearFileData().catch((e) => console.log("Failed to clean file data", e))
} }
return { asyncAction: Promise.resolve() }
}, },
async onFullLoginSuccess() {}, async onFullLoginSuccess() {},
} }
@ -333,7 +334,13 @@ import("../mail-app/translations/en.js")
}, },
calendarLocator.logins, calendarLocator.logins,
), ),
webauthnmobile: makeViewResolver<MobileWebauthnAttrs, MobileWebauthnView, { browserWebauthn: BrowserWebauthn }>( webauthnmobile: makeViewResolver<
MobileWebauthnAttrs,
MobileWebauthnView,
{
browserWebauthn: BrowserWebauthn
}
>(
{ {
prepareRoute: async () => { prepareRoute: async () => {
const { MobileWebauthnView } = await import("../common/login/MobileWebauthnView.js") const { MobileWebauthnView } = await import("../common/login/MobileWebauthnView.js")

View file

@ -521,7 +521,6 @@ export async function initLocator(worker: CalendarWorkerImpl, browserData: Brows
locator.login, locator.login,
locator.keyLoader, locator.keyLoader,
locator.publicEncryptionKeyProvider, locator.publicEncryptionKeyProvider,
null,
) )
}) })
const nativePushFacade = new NativePushFacadeSendDispatcher(worker) const nativePushFacade = new NativePushFacadeSendDispatcher(worker)

View file

@ -1,4 +1,5 @@
// Keep in sync with server: ClassifierType // Keep in sync with server: ClassifierType
export enum ClientClassifierType { export enum ClientClassifierType {
CLIENT_CLASSIFICATION = "20", CLIENT_CLASSIFICATION = "20",
CUSTOMER_INBOX_RULES = "21",
} }

View file

@ -3,13 +3,13 @@
import { DAY_IN_MILLIS, downcast } from "@tutao/tutanota-utils" import { DAY_IN_MILLIS, downcast } from "@tutao/tutanota-utils"
import type { CertificateInfo, CreditCard, EmailSenderListElement, GroupMembership } from "../entities/sys/TypeRefs.js" import type { CertificateInfo, CreditCard, EmailSenderListElement, GroupMembership } from "../entities/sys/TypeRefs.js"
import { AccountingInfo, Customer } from "../entities/sys/TypeRefs.js" import { AccountingInfo, Customer } from "../entities/sys/TypeRefs.js"
import type { CalendarEventAttendee, ContactCustomDate, ContactRelationship, UserSettingsGroupRoot } from "../entities/tutanota/TypeRefs.js" import type { CalendarEventAttendee, ContactCustomDate, ContactRelationship, Mail, UserSettingsGroupRoot } from "../entities/tutanota/TypeRefs.js"
import { ContactSocialId, MailFolder } from "../entities/tutanota/TypeRefs.js" import { ContactSocialId, MailFolder } from "../entities/tutanota/TypeRefs.js"
import { isApp, isElectronClient, isIOSApp } from "./Env" import { isApp, isElectronClient, isIOSApp } from "./Env"
import type { Country } from "./CountryList" import type { Country } from "./CountryList"
import { ProgrammingError } from "./error/ProgrammingError" import { ProgrammingError } from "./error/ProgrammingError"
export const MAX_NBR_MOVE_DELETE_MAIL_SERVICE = 50 export const MAX_NBR_OF_MAILS_SYNC_OPERATION = 50
export const MAX_NBR_OF_CONVERSATIONS = 50 export const MAX_NBR_OF_CONVERSATIONS = 50
// visible for testing // visible for testing
@ -1386,4 +1386,22 @@ export enum DeactivationReason {
MassSignup, MassSignup,
} }
// Possible spam decisions; every member maps to a numeric string value.
// NOTE(review): where these values are consumed is not visible here — presumably they mirror a server-side enum, confirm before renumbering.
export enum SpamDecision {
NONE = "0",
WHITELIST = "1",
BLACKLIST = "2",
DISCARD = "3",
}
/**
 * State of a mail with respect to inbox rule processing and spam prediction.
 * The member comments below paraphrase the member names.
 * NOTE(review): the string values presumably have to match the server model — confirm before changing them.
 */
export enum ProcessingState {
// inbox rules were checked and a spam prediction was made
INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE = "0",
// inbox rules have not been checked for this mail yet
INBOX_RULE_NOT_PROCESSED = "1",
// an inbox rule matched and was applied to the mail
INBOX_RULE_APPLIED = "2",
// inbox rules were checked, the spam prediction is still outstanding
INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_PENDING = "3",
}
/**
 * Reads the processing state stored on the given mail.
 * The raw `processingState` value is cast without validation, so a value that is not a
 * ProcessingState member is handed back unchanged.
 */
export function getProcessingState(mail: Mail): ProcessingState {
	const { processingState } = mail
	return processingState as ProcessingState
}
export const PLAN_SELECTOR_SELECTED_BOX_SCALE = "1.03" export const PLAN_SELECTOR_SELECTED_BOX_SCALE = "1.03"

View file

@ -199,3 +199,11 @@ export function objToError(o: Record<string, any>): Error {
e.data = o.data e.data = o.data
return e return e
} }
/**
* Returns whether the error is expected for the cases where our local state might not be up-to-date with the server yet. E.g. we might be processing an update
* for the instance that was already deleted. Normally this would be optimized away but it might still happen due to timing.
*/
export function isExpectedErrorForSynchronization(e: Error): boolean {
return e instanceof NotFoundError || e instanceof NotAuthorizedError
}

View file

@ -87,6 +87,8 @@ export function _createNewIndexUpdate(typeInfo: TypeInfo): IndexUpdate {
} }
} }
// Removes most html tags from a text.
// NOTE: This function is not covering all edge-cases.
export function htmlToText(html: string | null): string { export function htmlToText(html: string | null): string {
if (html == null) return "" if (html == null) return ""
let text = html.replace(/<[^>]*>?/gm, " ") let text = html.replace(/<[^>]*>?/gm, " ")

View file

@ -20,7 +20,7 @@ assertMainOrNodeBoot()
export interface PostLoginAction { export interface PostLoginAction {
/** Partial login is achieved with getting the user, can happen offline. The login will wait for the returned promise. */ /** Partial login is achieved with getting the user, can happen offline. The login will wait for the returned promise. */
onPartialLoginSuccess(loggedInEvent: LoggedInEvent): Promise<void> onPartialLoginSuccess(loggedInEvent: LoggedInEvent): Promise<{ asyncAction: Promise<void> }>
/** Full login is achieved with getting group keys. Can do service calls from this point on. */ /** Full login is achieved with getting group keys. Can do service calls from this point on. */
onFullLoginSuccess(loggedInEvent: LoggedInEvent): Promise<void> onFullLoginSuccess(loggedInEvent: LoggedInEvent): Promise<void>

View file

@ -5,10 +5,10 @@ import { assertNotNull, getTypeString, groupBy, isNotNull, isSameTypeRef, parseT
import { parseKeyVersion } from "./facades/KeyLoaderFacade" import { parseKeyVersion } from "./facades/KeyLoaderFacade"
import { VersionedEncryptedKey } from "./crypto/CryptoWrapper" import { VersionedEncryptedKey } from "./crypto/CryptoWrapper"
import { OperationType } from "../common/TutanotaConstants" import { OperationType } from "../common/TutanotaConstants"
import { NotAuthorizedError, NotFoundError } from "../common/error/RestError"
import { ElementEntity, ListElementEntity, SomeEntity } from "../common/EntityTypes" import { ElementEntity, ListElementEntity, SomeEntity } from "../common/EntityTypes"
import { CacheMode, type EntityRestInterface } from "./rest/EntityRestClient" import { CacheMode, type EntityRestInterface } from "./rest/EntityRestClient"
import { ProgressMonitorDelegate } from "./ProgressMonitorDelegate" import { ProgressMonitorDelegate } from "./ProgressMonitorDelegate"
import { isExpectedErrorForSynchronization } from "../common/utils/ErrorUtils"
export class EventInstancePrefetcher { export class EventInstancePrefetcher {
constructor(private readonly entityCache: EntityRestInterface) {} constructor(private readonly entityCache: EntityRestInterface) {}
@ -191,11 +191,3 @@ export class EventInstancePrefetcher {
} }
} }
} }
/**
* Returns whether the error is expected for the cases where our local state might not be up-to-date with the server yet. E.g. we might be processing an update
* for the instance that was already deleted. Normally this would be optimized away but it might still happen due to timing.
*/
function isExpectedErrorForSynchronization(e: Error): boolean {
return e instanceof NotFoundError || e instanceof NotAuthorizedError
}

View file

@ -1,6 +1,7 @@
import type { CryptoFacade } from "../../crypto/CryptoFacade.js" import type { CryptoFacade } from "../../crypto/CryptoFacade.js"
import { import {
ApplyLabelService, ApplyLabelService,
ClientClassifierResultService,
DraftService, DraftService,
ExternalUserService, ExternalUserService,
ListUnsubscribeService, ListUnsubscribeService,
@ -27,10 +28,11 @@ import {
MailMethod, MailMethod,
MailReportType, MailReportType,
MailSetKind, MailSetKind,
MAX_NBR_MOVE_DELETE_MAIL_SERVICE,
MAX_NBR_OF_CONVERSATIONS, MAX_NBR_OF_CONVERSATIONS,
MAX_NBR_OF_MAILS_SYNC_OPERATION,
OperationType, OperationType,
PhishingMarkerStatus, PhishingMarkerStatus,
ProcessingState,
PublicKeyIdentifierType, PublicKeyIdentifierType,
ReportedMailFieldType, ReportedMailFieldType,
SimpleMoveMailTarget, SimpleMoveMailTarget,
@ -40,6 +42,7 @@ import {
Contact, Contact,
createApplyLabelServicePostIn, createApplyLabelServicePostIn,
createAttachmentKeyData, createAttachmentKeyData,
createClientClassifierResultPostIn,
createCreateExternalUserGroupData, createCreateExternalUserGroupData,
createCreateMailFolderData, createCreateMailFolderData,
createDeleteMailData, createDeleteMailData,
@ -155,11 +158,8 @@ import { EntityUpdateData, isUpdateForTypeRef } from "../../../common/utils/Enti
import { Entity } from "../../../common/EntityTypes" import { Entity } from "../../../common/EntityTypes"
import { KeyVerificationMismatchError } from "../../../common/error/KeyVerificationMismatchError" import { KeyVerificationMismatchError } from "../../../common/error/KeyVerificationMismatchError"
import { VerifiedPublicEncryptionKey } from "./KeyVerificationFacade" import { VerifiedPublicEncryptionKey } from "./KeyVerificationFacade"
import { SpamClassifier, SpamPredMailDatum } from "../../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
import { isDraft } from "../../../../../mail-app/mail/model/MailChecks"
import { Nullable } from "@tutao/tutanota-utils/dist/Utils" import { Nullable } from "@tutao/tutanota-utils/dist/Utils"
import { ClientClassifierType } from "../../../common/ClientClassifierType" import { ClientClassifierType } from "../../../common/ClientClassifierType"
import { getMailBodyText } from "../../../common/CommonMailUtils"
assertWorkerOrNode() assertWorkerOrNode()
type Attachments = ReadonlyArray<TutanotaFile | DataFile | FileReference> type Attachments = ReadonlyArray<TutanotaFile | DataFile | FileReference>
@ -208,7 +208,6 @@ export class MailFacade {
private readonly loginFacade: LoginFacade, private readonly loginFacade: LoginFacade,
private readonly keyLoaderFacade: KeyLoaderFacade, private readonly keyLoaderFacade: KeyLoaderFacade,
private readonly publicEncryptionKeyProvider: PublicEncryptionKeyProvider, private readonly publicEncryptionKeyProvider: PublicEncryptionKeyProvider,
private readonly spamClassifier: SpamClassifier | null,
) {} ) {}
async createMailFolder(name: string, parent: IdTuple | null, ownerGroupId: Id): Promise<void> { async createMailFolder(name: string, parent: IdTuple | null, ownerGroupId: Id): Promise<void> {
@ -392,7 +391,12 @@ export class MailFacade {
/** /**
* Move mails from {@param targetFolder} except those that are in {@param excludeMailSet}. * Move mails from {@param targetFolder} except those that are in {@param excludeMailSet}.
*/ */
async moveMails(mails: readonly IdTuple[], targetFolder: IdTuple, excludeMailSet: IdTuple | null): Promise<MovedMails[]> { async moveMails(
mails: readonly IdTuple[],
targetFolder: IdTuple,
excludeMailSet: IdTuple | null,
moveReason: ClientClassifierType | null = null,
): Promise<MovedMails[]> {
if (isEmpty(mails)) { if (isEmpty(mails)) {
return [] return []
} }
@ -401,7 +405,7 @@ export class MailFacade {
const mailsPerList = groupBy(mails, (mailId) => listIdPart(mailId)) const mailsPerList = groupBy(mails, (mailId) => listIdPart(mailId))
const movedMails: MovedMails[] = [] const movedMails: MovedMails[] = []
for (const [_, mailsInList] of mailsPerList) { for (const [_, mailsInList] of mailsPerList) {
const mailChunks = splitInChunks(MAX_NBR_MOVE_DELETE_MAIL_SERVICE, mailsInList) const mailChunks = splitInChunks(MAX_NBR_OF_MAILS_SYNC_OPERATION, mailsInList)
for (const mails of mailChunks) { for (const mails of mailChunks) {
const moveMailPostOut = await this.serviceExecutor.post( const moveMailPostOut = await this.serviceExecutor.post(
MoveMailService, MoveMailService,
@ -409,7 +413,7 @@ export class MailFacade {
mails, mails,
excludeMailSet, excludeMailSet,
targetFolder, targetFolder,
moveReason: null, moveReason,
}), }),
) )
movedMails.push(...moveMailPostOut.movedMails) movedMails.push(...moveMailPostOut.movedMails)
@ -421,13 +425,13 @@ export class MailFacade {
async simpleMoveMails( async simpleMoveMails(
mails: readonly IdTuple[], mails: readonly IdTuple[],
targetFolderKind: SimpleMoveMailTarget, targetFolderKind: SimpleMoveMailTarget,
clientSpamClassifier: Nullable<ClientClassifierType>, moveReason: Nullable<ClientClassifierType>,
): Promise<MovedMails[]> { ): Promise<MovedMails[]> {
if (isEmpty(mails)) { if (isEmpty(mails)) {
return [] return []
} }
const mailChunks = splitInChunks(MAX_NBR_MOVE_DELETE_MAIL_SERVICE, mails) const mailChunks = splitInChunks(MAX_NBR_OF_MAILS_SYNC_OPERATION, mails)
const movedMails: MovedMails[] = [] const movedMails: MovedMails[] = []
for (const mails of mailChunks) { for (const mails of mailChunks) {
const simpleMove = await this.serviceExecutor.post( const simpleMove = await this.serviceExecutor.post(
@ -435,7 +439,7 @@ export class MailFacade {
createSimpleMoveMailPostIn({ createSimpleMoveMailPostIn({
mails, mails,
destinationSetType: targetFolderKind, destinationSetType: targetFolderKind,
moveReason: clientSpamClassifier, moveReason,
}), }),
) )
movedMails.push(...simpleMove.movedMails) movedMails.push(...simpleMove.movedMails)
@ -453,28 +457,6 @@ export class MailFacade {
await this.serviceExecutor.post(ReportMailService, postData) await this.serviceExecutor.post(ReportMailService, postData)
} }
public isSpamClassificationEnabled(ownerGroup: Id): boolean {
return this.spamClassifier != null && this.spamClassifier.getEnabledSpamClassifierForOwnerGroup(ownerGroup) != null
}
async predictSpamResult(mail: Mail): Promise<Nullable<boolean>> {
if (isDraft(mail)) {
return null
} else {
const spamClassifier = this.spamClassifier?.getEnabledSpamClassifierForOwnerGroup(assertNotNull(mail._ownerGroup)) ?? null
if (isNotNull(spamClassifier)) {
const mailDetails = await this.loadMailDetailsBlob(mail)
const spamPredMailDatum: SpamPredMailDatum = {
subject: mail.subject,
body: getMailBodyText(mailDetails.body),
ownerGroup: assertNotNull(mail._ownerGroup),
}
return await assertNotNull(this.spamClassifier).predict(spamPredMailDatum)
}
return null
}
}
async deleteMails(mails: readonly IdTuple[], filterMailSet: IdTuple | null): Promise<void> { async deleteMails(mails: readonly IdTuple[], filterMailSet: IdTuple | null): Promise<void> {
if (isEmpty(mails)) { if (isEmpty(mails)) {
return return
@ -483,7 +465,7 @@ export class MailFacade {
// Must be split by list (mailbag) // Must be split by list (mailbag)
const mailsGrouped = groupBy(mails, listIdPart) const mailsGrouped = groupBy(mails, listIdPart)
for (const [_, mails] of mailsGrouped) { for (const [_, mails] of mailsGrouped) {
const mailChunks = splitInChunks(MAX_NBR_MOVE_DELETE_MAIL_SERVICE, mails) const mailChunks = splitInChunks(MAX_NBR_OF_MAILS_SYNC_OPERATION, mails)
for (const mailChunk of mailChunks) { for (const mailChunk of mailChunks) {
const deleteMailData = createDeleteMailData({ const deleteMailData = createDeleteMailData({
mails: mailChunk, mails: mailChunk,
@ -1197,7 +1179,7 @@ export class MailFacade {
*/ */
async markMails(mails: readonly IdTuple[], unread: boolean) { async markMails(mails: readonly IdTuple[], unread: boolean) {
await promiseMap( await promiseMap(
splitInChunks(MAX_NBR_MOVE_DELETE_MAIL_SERVICE, mails), splitInChunks(MAX_NBR_OF_MAILS_SYNC_OPERATION, mails),
async (mails) => async (mails) =>
this.serviceExecutor.post( this.serviceExecutor.post(
UnreadMailStateService, UnreadMailStateService,
@ -1210,6 +1192,27 @@ export class MailFacade {
) )
} }
/**
 * Reports the processing state of the given mails by posting to the ClientClassifierResultService.
 * The mail ids are split with splitInChunks(MAX_NBR_OF_MAILS_SYNC_OPERATION, ...) and one request is
 * posted per chunk; promiseMap is invoked with a concurrency of 5.
 * @param mails ids of the mails whose processing state should be updated
 * @param processingState only compared against INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE:
 *   that state is sent as isPredictionMade = true, every other state as isPredictionMade = false
 */
async updateMailPredictionState(mails: readonly IdTuple[], processingState: ProcessingState): Promise<void> {
	// the service only receives a boolean, so the enum is reduced to "prediction made or not"
	const isPredictionMade = processingState === ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE
	await promiseMap(
		splitInChunks(MAX_NBR_OF_MAILS_SYNC_OPERATION, mails),
		async (mails) =>
			this.serviceExecutor.post(
				ClientClassifierResultService,
				createClientClassifierResultPostIn({
					mails,
					isPredictionMade,
				}),
			),
		{ concurrency: 5 },
	)
}
/** Resolve conversation list ids to the IDs of mails in those conversations. */ /** Resolve conversation list ids to the IDs of mails in those conversations. */
async resolveConversations(conversationListIds: readonly Id[]): Promise<IdTuple[]> { async resolveConversations(conversationListIds: readonly Id[]): Promise<IdTuple[]> {
const result = await promiseMap( const result = await promiseMap(

View file

@ -47,6 +47,8 @@ import { TypeModelResolver } from "../../common/EntityFunctions"
import { collapseId, expandId } from "../rest/RestClientIdUtils" import { collapseId, expandId } from "../rest/RestClientIdUtils"
import { Category, syncMetrics } from "../utils/SyncMetrics" import { Category, syncMetrics } from "../utils/SyncMetrics"
import { hasError } from "../../common/utils/ErrorUtils" import { hasError } from "../../common/utils/ErrorUtils"
import { SpamClassificationModel, SpamTrainMailDatum } from "../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
import { Mail } from "../../entities/tutanota/TypeRefs"
/** /**
* this is the value of SQLITE_MAX_VARIABLE_NUMBER in sqlite3.c * this is the value of SQLITE_MAX_VARIABLE_NUMBER in sqlite3.c
@ -102,6 +104,7 @@ export interface OfflineDbMeta {
// offline db schema version // offline db schema version
"offline-version": number "offline-version": number
lastTrainedTime: number lastTrainedTime: number
lastTrainedFromScratchTime: number
} }
export const TableDefinitions = Object.freeze({ export const TableDefinitions = Object.freeze({
@ -719,6 +722,14 @@ export class OfflineStorage implements CacheStorage {
await this.putMetadata("lastTrainedTime", ms) await this.putMetadata("lastTrainedTime", ms)
} }
/**
 * Resolves to the value stored under the "lastTrainedFromScratchTime" metadata key,
 * or to the current time (Date.now()) when that value is null or undefined.
 */
async getLastTrainedFromScratchTime(): Promise<number> {
	const storedTime = await this.getMetadata("lastTrainedFromScratchTime")
	return storedTime ?? Date.now()
}
async setLastTrainedFromScratchTime(ms: number): Promise<void> {
await this.putMetadata("lastTrainedFromScratchTime", ms)
}
async purgeStorage(): Promise<void> { async purgeStorage(): Promise<void> {
if (this.userId == null || this.databaseKey == null) { if (this.userId == null || this.databaseKey == null) {
console.warn("not purging storage since we don't have an open db") console.warn("not purging storage since we don't have an open db")

View file

@ -1,11 +1,10 @@
import { CacheStorage, LastUpdateTime, Range } from "./DefaultEntityRestCache.js" import { CacheStorage, LastUpdateTime, Range } from "./DefaultEntityRestCache.js"
import { ProgrammingError } from "../../common/error/ProgrammingError" import { ProgrammingError } from "../../common/error/ProgrammingError"
import { Entity, ListElementEntity, ServerModelParsedInstance, SomeEntity } from "../../common/EntityTypes" import { Entity, ListElementEntity, ServerModelParsedInstance, SomeEntity } from "../../common/EntityTypes"
import { TypeRef } from "@tutao/tutanota-utils" import { Nullable, TypeRef } from "@tutao/tutanota-utils"
import { OfflineStorage, OfflineStorageInitArgs } from "../offline/OfflineStorage.js" import { OfflineStorage, OfflineStorageInitArgs } from "../offline/OfflineStorage.js"
import { EphemeralCacheStorage, EphemeralStorageInitArgs } from "./EphemeralCacheStorage" import { EphemeralCacheStorage, EphemeralStorageInitArgs } from "./EphemeralCacheStorage"
import { CustomCacheHandlerMap } from "./cacheHandler/CustomCacheHandler.js" import { CustomCacheHandlerMap } from "./cacheHandler/CustomCacheHandler.js"
import { Nullable } from "@tutao/tutanota-utils"
export interface EphemeralStorageArgs extends EphemeralStorageInitArgs { export interface EphemeralStorageArgs extends EphemeralStorageInitArgs {
type: "ephemeral" type: "ephemeral"
@ -194,6 +193,14 @@ export class LateInitializedCacheStorageImpl implements CacheStorageLateInitiali
return this.inner.getLastTrainedTime() return this.inner.getLastTrainedTime()
} }
setLastTrainedFromScratchTime(ms: number): Promise<void> {
return this.inner.setLastTrainedFromScratchTime(ms)
}
/**
 * Forwards to the inner storage and returns its promise as-is.
 * No `??` fallback is applied here: the un-awaited result is a Promise, which is never nullish,
 * so such a fallback could never take effect.
 * NOTE(review): the visible OfflineStorage and EphemeralCacheStorage implementations both default to Date.now() themselves — confirm that every inner storage does.
 */
getLastTrainedFromScratchTime(): Promise<number> {
	return this.inner.getLastTrainedFromScratchTime()
}
setLowerRangeForList<T extends ListElementEntity>(typeRef: TypeRef<T>, listId: Id, id: Id): Promise<void> { setLowerRangeForList<T extends ListElementEntity>(typeRef: TypeRef<T>, listId: Id, id: Id): Promise<void> {
return this.inner.setLowerRangeForList(typeRef, listId, id) return this.inner.setLowerRangeForList(typeRef, listId, id)
} }

View file

@ -8,7 +8,7 @@ import {
OwnerEncSessionKeyProvider, OwnerEncSessionKeyProvider,
} from "./EntityRestClient" } from "./EntityRestClient"
import { OperationType } from "../../common/TutanotaConstants" import { OperationType } from "../../common/TutanotaConstants"
import { assertNotNull, downcast, getFirstOrThrow, getTypeString, isNotEmpty, isSameTypeRef, lastThrow, TypeRef } from "@tutao/tutanota-utils" import { assertNotNull, downcast, getFirstOrThrow, getTypeString, isNotEmpty, isSameTypeRef, lastThrow, Nullable, TypeRef } from "@tutao/tutanota-utils"
import { import {
AuditLogEntryTypeRef, AuditLogEntryTypeRef,
BucketPermissionTypeRef, BucketPermissionTypeRef,
@ -25,7 +25,7 @@ import {
UserGroupRootTypeRef, UserGroupRootTypeRef,
} from "../../entities/sys/TypeRefs.js" } from "../../entities/sys/TypeRefs.js"
import { ValueType } from "../../common/EntityConstants.js" import { ValueType } from "../../common/EntityConstants.js"
import { Body, CalendarEventUidIndexTypeRef, Mail, MailDetailsBlobTypeRef, MailSetEntryTypeRef, MailTypeRef } from "../../entities/tutanota/TypeRefs.js" import { CalendarEventUidIndexTypeRef, MailDetailsBlobTypeRef, MailSetEntryTypeRef, MailTypeRef } from "../../entities/tutanota/TypeRefs.js"
import { import {
CUSTOM_MAX_ID, CUSTOM_MAX_ID,
CUSTOM_MIN_ID, CUSTOM_MIN_ID,
@ -47,9 +47,7 @@ import { TypeModelResolver } from "../../common/EntityFunctions"
import { AttributeModel } from "../../common/AttributeModel" import { AttributeModel } from "../../common/AttributeModel"
import { collapseId, expandId } from "./RestClientIdUtils" import { collapseId, expandId } from "./RestClientIdUtils"
import { PatchMerger } from "../offline/PatchMerger" import { PatchMerger } from "../offline/PatchMerger"
import { NotAuthorizedError, NotFoundError } from "../../common/error/RestError" import { hasError, isExpectedErrorForSynchronization } from "../../common/utils/ErrorUtils"
import { Nullable } from "@tutao/tutanota-utils"
import { hasError } from "../../common/utils/ErrorUtils"
assertWorkerOrNode() assertWorkerOrNode()
@ -259,6 +257,10 @@ export interface CacheStorage extends ExposedCacheStorage {
setLastTrainedTime(value: number): Promise<void> setLastTrainedTime(value: number): Promise<void>
getLastTrainedFromScratchTime(): Promise<number>
setLastTrainedFromScratchTime(value: number): Promise<void>
getUserId(): Id getUserId(): Id
deleteAllOwnedBy(owner: Id): Promise<void> deleteAllOwnedBy(owner: Id): Promise<void>
@ -954,14 +956,6 @@ export class DefaultEntityRestCache implements EntityRestCache {
} }
} }
/**
* Returns whether the error is expected for the cases where our local state might not be up-to-date with the server yet. E.g. we might be processing an update
* for the instance that was already deleted. Normally this would be optimized away but it might still happen due to timing.
*/
function isExpectedErrorForSynchronization(e: Error): boolean {
return e instanceof NotFoundError || e instanceof NotAuthorizedError
}
/** /**
* Check if a range request begins inside an existing range * Check if a range request begins inside an existing range
*/ */

View file

@ -43,6 +43,7 @@ export class EphemeralCacheStorage implements CacheStorage {
private readonly blobEntities: Map<string, BlobElementTypeCache> = new Map() private readonly blobEntities: Map<string, BlobElementTypeCache> = new Map()
private lastUpdateTime: number | null = null private lastUpdateTime: number | null = null
private lastTrainedTime: number | null = null private lastTrainedTime: number | null = null
private lastTrainedFromScratchTime: number | null = null
private userId: Id | null = null private userId: Id | null = null
private lastBatchIdPerGroup = new Map<Id, Id>() private lastBatchIdPerGroup = new Map<Id, Id>()
@ -426,6 +427,14 @@ export class EphemeralCacheStorage implements CacheStorage {
this.lastTrainedTime = value this.lastTrainedTime = value
} }
/**
 * Resolves to the in-memory lastTrainedFromScratchTime, or to the current time (Date.now())
 * while that field is still null.
 */
async getLastTrainedFromScratchTime(): Promise<number> {
	const { lastTrainedFromScratchTime } = this
	return lastTrainedFromScratchTime ?? Date.now()
}
async setLastTrainedFromScratchTime(ms: number): Promise<void> {
this.lastTrainedFromScratchTime = ms
}
async getWholeList<T extends ListElementEntity>(typeRef: TypeRef<T>, listId: Id): Promise<Array<T>> { async getWholeList<T extends ListElementEntity>(typeRef: TypeRef<T>, listId: Id): Promise<Array<T>> {
const parsedInstances = await this.getWholeListParsed(typeRef, listId) const parsedInstances = await this.getWholeListParsed(typeRef, listId)
return await this.modelMapper.mapToInstances(typeRef, parsedInstances) return await this.modelMapper.mapToInstances(typeRef, parsedInstances)

View file

@ -1,27 +1,14 @@
import { Mail, MailFolder, MailFolderTypeRef, MailSetEntryTypeRef, MailTypeRef } from "../../../entities/tutanota/TypeRefs" import { Mail, MailDetailsBlobTypeRef } from "../../../entities/tutanota/TypeRefs"
import { assertNotNull, isSameTypeRef, lazy, lazyAsync } from "@tutao/tutanota-utils" import { assertNotNull, lazy, lazyAsync } from "@tutao/tutanota-utils"
import { MailIndexer } from "../../../../../mail-app/workerUtils/index/MailIndexer" import { MailIndexer } from "../../../../../mail-app/workerUtils/index/MailIndexer"
import { CustomCacheHandler } from "./CustomCacheHandler" import { CustomCacheHandler } from "./CustomCacheHandler"
import { EntityUpdateData } from "../../../common/utils/EntityUpdateUtils"
import { MailFacade } from "../../facades/lazy/MailFacade"
import { OfflineStoragePersistence } from "../../../../../mail-app/workerUtils/index/OfflineStoragePersistence" import { OfflineStoragePersistence } from "../../../../../mail-app/workerUtils/index/OfflineStoragePersistence"
import { MailSetKind } from "../../../common/TutanotaConstants"
import { CacheStorage } from "../DefaultEntityRestCache"
import { elementIdPart, isSameId, listIdPart } from "../../../common/utils/EntityUtils"
import { ClientClassifierType } from "../../../common/ClientClassifierType"
import { MailWithDetailsAndAttachments } from "../../../../../mail-app/workerUtils/index/MailIndexerBackend"
import { getMailBodyText } from "../../../common/CommonMailUtils"
import { SpamTrainMailDatum } from "../../../../../mail-app/workerUtils/spamClassification/SpamClassifier"
/** /**
* Handles telling the indexer to index or un-index mail data on updates. * Handles telling the indexer to index or un-index mail data on updates.
*/ */
export class CustomMailEventCacheHandler implements CustomCacheHandler<Mail> { export class CustomMailEventCacheHandler implements CustomCacheHandler<Mail> {
constructor( constructor(private readonly indexer: lazyAsync<MailIndexer>) {}
private readonly indexerAndMailFacade: lazyAsync<{ mailIndexer: MailIndexer; mailFacade: MailFacade }>,
private readonly offlineStoragePersistence: lazy<Promise<OfflineStoragePersistence>>,
private readonly storage: CacheStorage,
) {}
shouldLoadOnCreateEvent(): boolean { shouldLoadOnCreateEvent(): boolean {
// New emails should be pre-cached. // New emails should be pre-cached.
@ -33,127 +20,17 @@ export class CustomMailEventCacheHandler implements CustomCacheHandler<Mail> {
} }
async onBeforeCacheDeletion(id: IdTuple): Promise<void> { async onBeforeCacheDeletion(id: IdTuple): Promise<void> {
const { mailIndexer } = await this.indexerAndMailFacade() const indexer = await this.indexer()
return mailIndexer.beforeMailDeleted(id) return indexer.beforeMailDeleted(id)
} }
async onEntityEventCreate(id: IdTuple, events: EntityUpdateData[]) { async onEntityEventCreate(id: IdTuple) {
const { mailIndexer, mailFacade } = await this.indexerAndMailFacade() const indexer = await this.indexer()
// At this point, the mail entity, itself, is cached, so when we go to download it again, it will come from cache return indexer.afterMailCreated(id)
const newMailData = await mailIndexer.downloadNewMailData(id)
await mailIndexer.afterMailCreated(id, newMailData)
await this.processSpam(newMailData, mailFacade, id)
} }
async onEntityEventUpdate(id: IdTuple, events: EntityUpdateData[]) { async onEntityEventUpdate(id: IdTuple) {
const { mailIndexer } = await this.indexerAndMailFacade() const indexer = await this.indexer()
await mailIndexer.afterMailUpdated(id) return indexer.afterMailUpdated(id)
await this.updateSpamClassificationData(events, id)
}
private async processSpam(newMailData: MailWithDetailsAndAttachments | null, mailFacade: MailFacade, id: readonly [string, string]) {
const usedClientSpamClassifier = ClientClassifierType.CLIENT_CLASSIFICATION
if (newMailData == null) {
return
}
// update spam classification table
const mail = newMailData.mail
const allFolders = await this.storage.getWholeList(MailFolderTypeRef, listIdPart(mail.sets[0]))
const spamFolder = allFolders.find((folder) => folder.folderType === MailSetKind.SPAM)!
const isStoredInSpamFolder = mail.sets.some((folderId) => isSameId(folderId, spamFolder._id))
const { isStoredInTrashFolder, confidence } = this.getSpamConfidence(allFolders, mail)
// isStoredInSpamFolder is true
// this might be run multiple times for a single user if they use multiple devices
const predictedSpam = isStoredInTrashFolder ? null : await mailFacade.predictSpamResult(mail)
// use the server classification for initial training, mixed with data from when user moves mails in and out of spam
const isSpam = predictedSpam ?? isStoredInSpamFolder
const offlineStoragePersistence = await this.offlineStoragePersistence()
const spamTrainMailDatum: SpamTrainMailDatum = {
mailId: mail._id,
subject: mail.subject,
body: getMailBodyText(newMailData.mailDetails.body),
isSpam,
isSpamConfidence: confidence,
ownerGroup: assertNotNull(mail._ownerGroup),
}
let moveServiceCall
if (!isStoredInTrashFolder && isSpam && !isStoredInSpamFolder) {
spamTrainMailDatum.isSpamConfidence = 1
moveServiceCall = mailFacade.simpleMoveMails([id], MailSetKind.SPAM, usedClientSpamClassifier)
} else if (!isStoredInTrashFolder && !isSpam && isStoredInSpamFolder) {
spamTrainMailDatum.isSpamConfidence = 0
moveServiceCall = mailFacade.simpleMoveMails([id], MailSetKind.INBOX, usedClientSpamClassifier)
}
await offlineStoragePersistence.storeSpamClassification(spamTrainMailDatum)
await moveServiceCall
}
private async updateSpamClassificationData(events: EntityUpdateData[], id: readonly [string, string]) {
const mail = assertNotNull(await this.storage.get(MailTypeRef, listIdPart(id), elementIdPart(id)))
const changedMailSetEntry = events.some((ev) => isSameTypeRef(ev.typeRef, MailSetEntryTypeRef))
const mailHasBeenRead = !mail.unread
if (!mailHasBeenRead && !changedMailSetEntry) {
return
}
const allFolders = await this.storage.getWholeList(MailFolderTypeRef, listIdPart(mail.sets[0]))
const spamFolder = allFolders.find((folder) => folder.folderType === MailSetKind.SPAM)!
const isSpam = mail.sets.some((folderId) => isSameId(folderId, spamFolder._id))
let { confidence: isSpamConfidence, isStoredInTrashFolder } = this.getSpamConfidence(allFolders, mail)
const offlineStoragePersistence = await this.offlineStoragePersistence()
const storedClassification = await offlineStoragePersistence.getStoredClassification(mail)
if (storedClassification != null) {
// email is in classification data
const wasDeletedFromSpamFolder = isStoredInTrashFolder && storedClassification.isSpam
if (wasDeletedFromSpamFolder) {
// This is the case if we delete from spam Folder, in that case we do not need any change in storedClassification
} else if (isSpam !== storedClassification.isSpam || isSpamConfidence !== storedClassification.isSpamConfidence) {
// the model has trained on the mail but the spamFlag was wrong so we refit with higher isSpamConfidence
await offlineStoragePersistence.updateSpamClassificationData(id, isSpam, isSpamConfidence)
}
} else {
const { mailIndexer } = await this.indexerAndMailFacade()
// At this point, the mail entity, itself, is cached, so when we go to download it again, it will come from cache
const newMailData = await mailIndexer.downloadNewMailData(id)
if (newMailData) {
const spamTrainMailDatum: SpamTrainMailDatum = {
mailId: mail._id,
subject: mail.subject,
body: getMailBodyText(newMailData.mailDetails.body),
isSpam,
isSpamConfidence,
ownerGroup: assertNotNull(mail._ownerGroup),
}
await offlineStoragePersistence.storeSpamClassification(spamTrainMailDatum)
} else {
// race: mail deleted in meantime
}
}
}
// visible for testing
public getSpamConfidence(allFolders: Array<MailFolder>, mail: Mail): { confidence: number; isStoredInTrashFolder: boolean } {
const spamFolder = allFolders.find((folder) => folder.folderType === MailSetKind.SPAM)!
const trashFolder = allFolders.find((folder) => folder.folderType === MailSetKind.TRASH)!
const isStoredInSpamFolder = mail.sets.some((folderId) => isSameId(folderId, spamFolder._id))
const isStoredInTrashFolder = mail.sets.some((folderId) => isSameId(folderId, trashFolder._id))
const isReadAndNotInSpamAndNotInTrash = !mail.unread && !isStoredInSpamFolder && !isStoredInTrashFolder
if (isStoredInSpamFolder || isReadAndNotInSpamAndNotInTrash) {
return { confidence: 1, isStoredInTrashFolder }
} else {
return { confidence: 0, isStoredInTrashFolder }
}
} }
} }

View file

@ -11,9 +11,10 @@ export class DesktopPostLoginActions implements PostLoginAction {
private readonly windowId: number, private readonly windowId: number,
) {} ) {}
async onPartialLoginSuccess({ userId }: LoggedInEvent): Promise<void> { async onPartialLoginSuccess({ userId }: LoggedInEvent): Promise<{ asyncAction: Promise<void> }> {
this.wm.get(this.windowId)?.setUserId(userId) this.wm.get(this.windowId)?.setUserId(userId)
await this.notifier.clearUserNotifications(userId) await this.notifier.clearUserNotifications(userId)
return { asyncAction: Promise.resolve() }
} }
async onFullLoginSuccess({ userId }: LoggedInEvent): Promise<void> { async onFullLoginSuccess({ userId }: LoggedInEvent): Promise<void> {

View file

@ -57,7 +57,7 @@ export class PostLoginActions implements PostLoginAction {
private readonly updateClient: () => unknown, private readonly updateClient: () => unknown,
) {} ) {}
async onPartialLoginSuccess(loggedInEvent: LoggedInEvent): Promise<void> { async onPartialLoginSuccess(loggedInEvent: LoggedInEvent): Promise<{ asyncAction: Promise<void> }> {
// We establish websocket connection even for temporary sessions because we need to get updates e.g. during signup // We establish websocket connection even for temporary sessions because we need to get updates e.g. during signup
windowFacade.addOnlineListener(() => { windowFacade.addOnlineListener(() => {
console.log(new Date().toISOString(), "online - try reconnect") console.log(new Date().toISOString(), "online - try reconnect")
@ -80,7 +80,7 @@ export class PostLoginActions implements PostLoginAction {
document.title = "Tuta Mail" document.title = "Tuta Mail"
} }
return return { asyncAction: Promise.resolve() }
} else { } else {
let postLoginTitle = document.title === LOGIN_TITLE ? "Tuta Mail" : document.title let postLoginTitle = document.title === LOGIN_TITLE ? "Tuta Mail" : document.title
document.title = neverNull(this.logins.getUserController().userGroupInfo.mailAddress) + " - " + postLoginTitle document.title = neverNull(this.logins.getUserController().userGroupInfo.mailAddress) + " - " + postLoginTitle
@ -106,6 +106,7 @@ export class PostLoginActions implements PostLoginAction {
if (isApp() || isDesktop()) { if (isApp() || isDesktop()) {
await this.storeNewCustomThemes() await this.storeNewCustomThemes()
} }
return { asyncAction: Promise.resolve() }
} }
async onFullLoginSuccess(loggedInEvent: LoggedInEvent): Promise<void> { async onFullLoginSuccess(loggedInEvent: LoggedInEvent): Promise<void> {

View file

@ -41,12 +41,13 @@ export class CachePostLoginAction implements PostLoginAction {
progressMonitor.completed() progressMonitor.completed()
} }
async onPartialLoginSuccess(event: LoggedInEvent): Promise<void> { async onPartialLoginSuccess(event: LoggedInEvent): Promise<{ asyncAction: Promise<void> }> {
if (event.sessionType === SessionType.Persistent && this.offlineStorageSettings != null) { if (event.sessionType === SessionType.Persistent && this.offlineStorageSettings != null) {
await this.offlineStorageSettings.init() await this.offlineStorageSettings.init()
// Clear the excluded data (i.e. trash and spam lists, old data) in the offline storage. // Clear the excluded data (i.e. trash and spam lists, old data) in the offline storage.
await this.cacheStorage.clearExcludedData(this.offlineStorageSettings.getTimeRange()) await this.cacheStorage.clearExcludedData(this.offlineStorageSettings.getTimeRange())
} }
return { asyncAction: Promise.resolve() }
} }
} }

View file

@ -10,6 +10,7 @@ import { ToggleButton } from "../../gui/base/buttons/ToggleButton.js"
import { isApp, isDesktop } from "../../api/common/Env.js" import { isApp, isDesktop } from "../../api/common/Env.js"
import { LoginButton } from "../../gui/base/buttons/LoginButton.js" import { LoginButton } from "../../gui/base/buttons/LoginButton.js"
import { lang } from "../../misc/LanguageViewModel.js" import { lang } from "../../misc/LanguageViewModel.js"
import { ProcessingState } from "../../api/common/TutanotaConstants"
export const BUTTON_WIDTH = 270 export const BUTTON_WIDTH = 270
@ -94,6 +95,8 @@ export class CustomColorEditorPreview implements Component {
phishingStatus: "0", phishingStatus: "0",
recipientCount: "0", recipientCount: "0",
sets: [], sets: [],
processingState: ProcessingState.INBOX_RULE_NOT_PROCESSED,
clientSpamClassifierResult: null,
} satisfies Partial<Mail> } satisfies Partial<Mail>
const mail = createMail({ const mail = createMail({
sender: createMailAddress({ sender: createMailAddress({

View file

@ -155,13 +155,14 @@ import("./translations/en.js")
const canSync = await syncManager.canSync() const canSync = await syncManager.canSync()
if (!canSync) { if (!canSync) {
await syncManager.disableSync() await syncManager.disableSync()
return return { asyncAction: Promise.resolve() }
} }
} }
syncManager.syncContacts() syncManager.syncContacts()
} }
await mailLocator.mailboxModel.init() await mailLocator.mailboxModel.init()
await mailLocator.mailModel.init() await mailLocator.mailModel.init()
return { asyncAction: Promise.resolve() }
}, },
async onFullLoginSuccess() { async onFullLoginSuccess() {
// We might have outdated Customer features, force reload the customer to make sure the customizations are up-to-date // We might have outdated Customer features, force reload the customer to make sure the customizations are up-to-date
@ -235,7 +236,9 @@ import("./translations/en.js")
if (isDesktop()) { if (isDesktop()) {
mailLocator.logins.addPostLoginAction(async () => { mailLocator.logins.addPostLoginAction(async () => {
return { return {
onPartialLoginSuccess: async () => {}, onPartialLoginSuccess: async () => {
return { asyncAction: Promise.resolve() }
},
onFullLoginSuccess: async (event) => { onFullLoginSuccess: async (event) => {
// not a temporary aka signup login // not a temporary aka signup login
if (event.sessionType === SessionType.Persistent) { if (event.sessionType === SessionType.Persistent) {

View file

@ -30,9 +30,10 @@ export class OpenLocallySavedDraftAction implements PostLoginAction {
async onFullLoginSuccess(_: LoggedInEvent): Promise<void> {} async onFullLoginSuccess(_: LoggedInEvent): Promise<void> {}
async onPartialLoginSuccess(_: LoggedInEvent): Promise<void> { async onPartialLoginSuccess(_: LoggedInEvent): Promise<{ asyncAction: Promise<void> }> {
// fire and forget; this might take some time // not awaited here because this might take some time; the pending promise is handed back as asyncAction
this._loadAutosavedDraft() const asyncAction = this._loadAutosavedDraft()
return { asyncAction }
} }
/** /**

View file

@ -1,31 +1,34 @@
import type { InboxRule, Mail, MailFolder, MoveMailData } from "../../../common/api/entities/tutanota/TypeRefs.js" import { createMoveMailData, InboxRule, Mail, MailFolder, MoveMailData } from "../../../common/api/entities/tutanota/TypeRefs.js"
import { createMoveMailData } from "../../../common/api/entities/tutanota/TypeRefs.js" import { InboxRuleType, MailSetKind, MAX_NBR_OF_MAILS_SYNC_OPERATION, ProcessingState } from "../../../common/api/common/TutanotaConstants"
import { InboxRuleType, MailSetKind, MAX_NBR_MOVE_DELETE_MAIL_SERVICE } from "../../../common/api/common/TutanotaConstants"
import { isDomainName, isRegularExpression } from "../../../common/misc/FormatValidator" import { isDomainName, isRegularExpression } from "../../../common/misc/FormatValidator"
import { assertNotNull, asyncFind, ofClass, promiseMap, splitInChunks, throttleStart } from "@tutao/tutanota-utils" import { assertNotNull, asyncFind, debounce, ofClass, promiseMap, splitInChunks, throttleStart } from "@tutao/tutanota-utils"
import { lang } from "../../../common/misc/LanguageViewModel" import { lang } from "../../../common/misc/LanguageViewModel"
import type { MailboxDetail } from "../../../common/mailFunctionality/MailboxModel.js" import type { MailboxDetail } from "../../../common/mailFunctionality/MailboxModel.js"
import { LockedError, PreconditionFailedError } from "../../../common/api/common/error/RestError" import { LockedError, PreconditionFailedError } from "../../../common/api/common/error/RestError"
import type { SelectorItemList } from "../../../common/gui/base/DropDownSelector.js" import type { SelectorItemList } from "../../../common/gui/base/DropDownSelector.js"
import { elementIdPart, isSameId } from "../../../common/api/common/utils/EntityUtils" import { elementIdPart, isSameId } from "../../../common/api/common/utils/EntityUtils"
import { assertMainOrNode } from "../../../common/api/common/Env" import { assertMainOrNode, isWebClient } from "../../../common/api/common/Env"
import { MailFacade } from "../../../common/api/worker/facades/lazy/MailFacade.js" import { MailFacade } from "../../../common/api/worker/facades/lazy/MailFacade.js"
import { LoginController } from "../../../common/api/main/LoginController.js" import { LoginController } from "../../../common/api/main/LoginController.js"
import { getMailHeaders } from "./MailUtils.js" import { getMailHeaders } from "./MailUtils.js"
import { MailModel } from "./MailModel" import { MailModel } from "./MailModel"
import { ClientClassifierType } from "../../../common/api/common/ClientClassifierType"
assertMainOrNode() assertMainOrNode()
const moveMailDataPerFolder: MoveMailData[] = [] const moveMailDataPerFolder: MoveMailData[] = []
const DEBOUNCE_FIRST_MOVE_MAIL_REQUEST_MS = 200 let noRuleMatchMailIds: IdTuple[] = []
let applyingRules = false // used to avoid concurrent application of rules (-> requests to locked service)
const THROTTLE_MOVE_MAIL_SERVICE_REQUESTS_MS = 200
const DEBOUNCE_CLIENT_CLASSIFIER_RESULT_SERVICE_REQUESTS_MS = 1000
async function sendMoveMailRequest(mailFacade: MailFacade): Promise<void> { async function sendMoveMailRequest(mailFacade: MailFacade): Promise<void> {
if (moveMailDataPerFolder.length) { if (moveMailDataPerFolder.length) {
const moveToTargetFolder = assertNotNull(moveMailDataPerFolder.shift()) const moveToTargetFolder = assertNotNull(moveMailDataPerFolder.shift())
const mailChunks = splitInChunks(MAX_NBR_MOVE_DELETE_MAIL_SERVICE, moveToTargetFolder.mails) const mailChunks = splitInChunks(MAX_NBR_OF_MAILS_SYNC_OPERATION, moveToTargetFolder.mails)
await promiseMap(mailChunks, (mailChunk) => { await promiseMap(mailChunks, (mailChunk) => {
moveToTargetFolder.mails = mailChunk moveToTargetFolder.mails = mailChunk
return mailFacade.moveMails(mailChunk, moveToTargetFolder.targetFolder, null) return mailFacade.moveMails(mailChunk, moveToTargetFolder.targetFolder, null, ClientClassifierType.CUSTOMER_INBOX_RULES)
}) })
.catch( .catch(
ofClass(LockedError, (e) => { ofClass(LockedError, (e) => {
@ -40,20 +43,25 @@ async function sendMoveMailRequest(mailFacade: MailFacade): Promise<void> {
}), }),
) )
.finally(() => { .finally(() => {
return sendMoveMailRequest(mailFacade) return processMatchingRules(mailFacade)
}) })
} //We are done and unlock for future requests }
} }
// We throttle the moveMail requests to a rate of 200ms const processMatchingRules = throttleStart(THROTTLE_MOVE_MAIL_SERVICE_REQUESTS_MS, async (mailFacade: MailFacade) => {
// Each target folder requires one request // Each target folder requires one request,
const applyMatchingRules = throttleStart(DEBOUNCE_FIRST_MOVE_MAIL_REQUEST_MS, async (mailFacade: MailFacade) => { // We throttle the requests to a rate of THROTTLE_MOVE_MAIL_SERVICE_REQUESTS_MS
if (applyingRules) return return sendMoveMailRequest(mailFacade)
// We lock to avoid concurrent requests })
applyingRules = true
sendMoveMailRequest(mailFacade).finally(() => { const processNotMatchingRules = debounce(DEBOUNCE_CLIENT_CLASSIFIER_RESULT_SERVICE_REQUESTS_MS, async (mailFacade: MailFacade) => {
applyingRules = false // Each update to ClientClassifierResultService (for mails that did not move) requires one request
}) // We debounce the requests to a rate of DEBOUNCE_CLIENT_CLASSIFIER_RESULT_SERVICE_REQUESTS_MS
if (noRuleMatchMailIds.length) {
const mailIds = noRuleMatchMailIds
noRuleMatchMailIds = []
return mailFacade.updateMailPredictionState(mailIds, ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_PENDING)
}
}) })
export function getInboxRuleTypeNameMapping(): SelectorItemList<string> { export function getInboxRuleTypeNameMapping(): SelectorItemList<string> {
@ -101,16 +109,8 @@ export class InboxRuleHandler {
* Checks the mail for an existing inbox rule and moves the mail to the target folder of the rule. * Checks the mail for an existing inbox rule and moves the mail to the target folder of the rule.
* @returns true if a rule matches otherwise false * @returns true if a rule matches otherwise false
*/ */
async findAndApplyMatchingRule( async findAndApplyMatchingRule(mailboxDetail: MailboxDetail, mail: Readonly<Mail>, applyRulesOnServer: boolean): Promise<MailFolder | null> {
mailboxDetail: MailboxDetail, const shouldApply = mail.processingState === ProcessingState.INBOX_RULE_NOT_PROCESSED
mail: Mail,
applyRulesOnServer: boolean,
applyIfRead: boolean,
): Promise<{
folder: MailFolder
mail: Mail
} | null> {
const shouldApply = applyIfRead || mail.unread
if ( if (
mail._errors || mail._errors ||
@ -138,19 +138,25 @@ export class InboxRuleHandler {
targetFolder: inboxRule.targetFolder, targetFolder: inboxRule.targetFolder,
mails: [mail._id], mails: [mail._id],
excludeMailSet: null, excludeMailSet: null,
moveReason: null, moveReason: ClientClassifierType.CUSTOMER_INBOX_RULES,
}) })
moveMailDataPerFolder.push(moveMailData) moveMailDataPerFolder.push(moveMailData)
} }
applyMatchingRules(this.mailFacade)
} }
return { folder: targetFolder, mail } processMatchingRules(this.mailFacade)
return targetFolder
} else { } else {
return null return null
} }
} else { } else {
// if we are not on the webapp this is handled in SpamClassificationHandler
if (isWebClient()) {
noRuleMatchMailIds.push(mail._id)
processNotMatchingRules(this.mailFacade)
}
return null return null
} }
} }

View file

@ -5,12 +5,13 @@ import { FolderSystem } from "../../../common/api/common/mail/FolderSystem.js"
import { import {
assertNotNull, assertNotNull,
collectToMap, collectToMap,
downcast,
getFirstOrThrow, getFirstOrThrow,
groupBy, groupBy,
groupByAndMap, groupByAndMap,
isNotNull, isNotNull,
lazyMemoized, lazyMemoized,
noOp, Nullable,
ofClass, ofClass,
partition, partition,
promiseMap, promiseMap,
@ -31,8 +32,9 @@ import {
isLabel, isLabel,
MailReportType, MailReportType,
MailSetKind, MailSetKind,
MAX_NBR_MOVE_DELETE_MAIL_SERVICE, MAX_NBR_OF_MAILS_SYNC_OPERATION,
OperationType, OperationType,
ProcessingState,
ReportMovedMailsType, ReportMovedMailsType,
SimpleMoveMailTarget, SimpleMoveMailTarget,
SystemFolderType, SystemFolderType,
@ -54,6 +56,9 @@ import { LoginController } from "../../../common/api/main/LoginController.js"
import { MailFacade } from "../../../common/api/worker/facades/lazy/MailFacade.js" import { MailFacade } from "../../../common/api/worker/facades/lazy/MailFacade.js"
import { assertSystemFolderOfType } from "./MailUtils.js" import { assertSystemFolderOfType } from "./MailUtils.js"
import { TutanotaError } from "@tutao/tutanota-error" import { TutanotaError } from "@tutao/tutanota-error"
import { SpamClassificationHandler } from "../../workerUtils/spamClassification/SpamClassificationHandler"
import { isWebClient } from "../../../common/api/common/Env"
import { isExpectedErrorForSynchronization } from "../../../common/api/common/utils/ErrorUtils"
interface MailboxSets { interface MailboxSets {
folders: FolderSystem folders: FolderSystem
@ -90,6 +95,7 @@ export class MailModel {
private readonly logins: LoginController, private readonly logins: LoginController,
private readonly mailFacade: MailFacade, private readonly mailFacade: MailFacade,
private readonly connectivityModel: WebsocketConnectivityModel | null, private readonly connectivityModel: WebsocketConnectivityModel | null,
private readonly spamHandler: () => SpamClassificationHandler,
private readonly inboxRuleHandler: () => InboxRuleHandler | null, private readonly inboxRuleHandler: () => InboxRuleHandler | null,
) {} ) {}
@ -182,53 +188,103 @@ export class MailModel {
} }
// visibleForTesting // visibleForTesting
async entityEventsReceived(updates: ReadonlyArray<EntityUpdateData>): Promise<void> { async entityEventsReceived(updates: ReadonlyArray<EntityUpdateData>): Promise<{ processingDone: Promise<void> }> {
for (const update of updates) { for (const update of updates) {
if (isUpdateForTypeRef(MailFolderTypeRef, update)) { if (isUpdateForTypeRef(MailFolderTypeRef, update)) {
await this.init() await this.init()
m.redraw() m.redraw()
} else if (isUpdateForTypeRef(MailTypeRef, update) && update.operation === OperationType.CREATE) { } else if (isUpdateForTypeRef(MailTypeRef, update) && update.operation === OperationType.UPDATE) {
if (this.inboxRuleHandler && this.connectivityModel) {
const mailId: IdTuple = [update.instanceListId, update.instanceId] const mailId: IdTuple = [update.instanceListId, update.instanceId]
try { const mail = await this.loadMail(mailId)
const mail = await this.entityClient.load(MailTypeRef, mailId) if (mail == null) {
const folder = this.getMailFolderForMail(mail) return { processingDone: Promise.resolve() }
}
const spamHandler = this.spamHandler()
await spamHandler.updateSpamClassificationData(mail)
} else if (isUpdateForTypeRef(MailTypeRef, update) && update.operation === OperationType.CREATE) {
const mailId: IdTuple = [update.instanceListId, update.instanceId]
const mail = await this.loadMail(mailId)
if (mail == null) {
return { processingDone: Promise.resolve() }
}
if (folder && folder.folderType === MailSetKind.INBOX) { // If an inbox rule has been applied or a spam prediction has been made
// If we don't find another delete operation on this email in the batch, then it should be a create operation, // we can return, because those are the two final processing states
// otherwise it's a move if (
await this.getMailboxDetailsForMail(mail) mail.processingState === ProcessingState.INBOX_RULE_APPLIED ||
.then((mailboxDetail) => { mail.processingState === ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE
) {
return { processingDone: Promise.resolve() }
}
// The webapp currently does not support spam prediction, and the inbox rule has been processed
if (isWebClient() && mail.processingState === ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_PENDING) {
return { processingDone: Promise.resolve() }
}
const sourceMailFolder = this.getMailFolderForMail(mail)
if (sourceMailFolder == null) {
return { processingDone: Promise.resolve() }
}
const isLeaderClient = this.connectivityModel?.isLeader() ?? false
if (sourceMailFolder.folderType === MailSetKind.INBOX) {
const isInboxRuleTargetFolder = await this.getMailboxDetailsForMail(mail).then((mailboxDetail) => {
// We only apply rules on server if we are the leader in case of incoming messages // We only apply rules on server if we are the leader in case of incoming messages
return ( return mailboxDetail && this.inboxRuleHandler()?.findAndApplyMatchingRule(mailboxDetail, mail, isLeaderClient)
mailboxDetail &&
this.inboxRuleHandler()?.findAndApplyMatchingRule(
mailboxDetail,
mail,
this.connectivityModel ? this.connectivityModel.isLeader() : false,
false,
)
)
}) })
.then((newFolderAndMail) => {
if (newFolderAndMail) { if (isWebClient()) {
this._showNotification(newFolderAndMail.folder, newFolderAndMail.mail) // we only need to show notifications explicitly on the webapp
this._showNotification(isInboxRuleTargetFolder ?? sourceMailFolder, mail)
} else { } else {
this._showNotification(folder, mail) const mailDetails = await this.mailFacade.loadMailDetailsBlob(mail)
this.spamHandler().storeTrainingDatum(mail, mailDetails)
if (isInboxRuleTargetFolder) {
return { processingDone: Promise.resolve() }
} else if (
(isLeaderClient && mail.processingState === ProcessingState.INBOX_RULE_NOT_PROCESSED) ||
mail.processingState === ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_PENDING
) {
const folderSystem = this.getFolderSystemByGroupId(assertNotNull(mail._ownerGroup))
if (sourceMailFolder && folderSystem) {
const predictPromise = this.spamHandler().predictSpamForNewMail(mail, mailDetails, sourceMailFolder, folderSystem)
return { processingDone: downcast(predictPromise) }
} }
})
.catch(noOp)
} }
} catch (e) { }
if (e instanceof NotFoundError) { } else if (sourceMailFolder.folderType === MailSetKind.SPAM) {
console.log(`Could not find updated mail ${JSON.stringify(mailId)}`) const mailDetails = await this.mailFacade.loadMailDetailsBlob(mail)
} else { this.spamHandler().storeTrainingDatum(mail, mailDetails)
if (
(isLeaderClient && mail.processingState === ProcessingState.INBOX_RULE_NOT_PROCESSED) ||
mail.processingState === ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_PENDING
) {
const folderSystem = this.getFolderSystemByGroupId(assertNotNull(mail._ownerGroup))
if (sourceMailFolder && folderSystem) {
const predictPromise = this.spamHandler().predictSpamForNewMail(mail, mailDetails, sourceMailFolder, folderSystem)
return { processingDone: downcast(predictPromise) }
}
}
}
} else if (isUpdateForTypeRef(MailTypeRef, update) && update.operation === OperationType.DELETE) {
const mailId: IdTuple = [update.instanceListId, update.instanceId]
await this.spamHandler().dropClassificationData(mailId)
}
}
return { processingDone: Promise.resolve() }
}
public async loadMail(mailId: IdTuple): Promise<Nullable<Mail>> {
return await this.entityClient.load(MailTypeRef, mailId).catch((e) => {
if (isExpectedErrorForSynchronization(e)) {
console.log(`Could not find mail ${JSON.stringify(mailId)}`)
return null
}
throw e throw e
} })
}
}
}
}
} }
async applyInboxRuleToMail(mail: Mail) { async applyInboxRuleToMail(mail: Mail) {
@ -236,7 +292,7 @@ export class MailModel {
if (inboxRuleHandler) { if (inboxRuleHandler) {
const mailboxDetail = await this.getMailboxDetailsForMail(mail) const mailboxDetail = await this.getMailboxDetailsForMail(mail)
if (mailboxDetail) { if (mailboxDetail) {
inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, mail, true, true) return inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, mail, true)
} }
} }
} }
@ -370,10 +426,6 @@ export class MailModel {
return await this.mailFacade.moveMails(mails, targetFolder._id, excludeFolder) return await this.mailFacade.moveMails(mails, targetFolder._id, excludeFolder)
} }
async trashMails(mails: readonly IdTuple[]): Promise<MovedMails[]> {
return await this.mailFacade.simpleMoveMails(mails, MailSetKind.TRASH, null)
}
/** /**
* Finally deletes all given mails. Caller must ensure that all mails are in folders that allows final delete operation. * Finally deletes all given mails. Caller must ensure that all mails are in folders that allows final delete operation.
* @param mailIds mails to delete * @param mailIds mails to delete
@ -436,7 +488,7 @@ export class MailModel {
async applyLabels(mails: readonly IdTuple[], addedLabels: readonly MailFolder[], removedLabels: readonly MailFolder[]): Promise<void> { async applyLabels(mails: readonly IdTuple[], addedLabels: readonly MailFolder[], removedLabels: readonly MailFolder[]): Promise<void> {
const groupedByListIds = groupBy(mails, (mailId) => listIdPart(mailId)) const groupedByListIds = groupBy(mails, (mailId) => listIdPart(mailId))
for (const [_, groupedMails] of groupedByListIds) { for (const [_, groupedMails] of groupedByListIds) {
const mailChunks = splitInChunks(MAX_NBR_MOVE_DELETE_MAIL_SERVICE, groupedMails) const mailChunks = splitInChunks(MAX_NBR_OF_MAILS_SYNC_OPERATION, groupedMails)
for (const mailChunk of mailChunks) { for (const mailChunk of mailChunks) {
await this.mailFacade.applyLabels(mailChunk, addedLabels, removedLabels) await this.mailFacade.applyLabels(mailChunk, addedLabels, removedLabels)
} }

View file

@ -290,7 +290,7 @@ export async function applyInboxRulesToEntries(
return entries return entries
} }
return await promiseFilter(entries, async (entry) => { return await promiseFilter(entries, async (entry) => {
const ruleApplied = await inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, entry.mail, true, false) const ruleApplied = await inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, entry.mail, true)
return ruleApplied == null return ruleApplied == null
}) })
} }

View file

@ -151,10 +151,11 @@ import { IdentityKeyCreator } from "../common/api/worker/facades/lazy/IdentityKe
import { PublicIdentityKeyProvider } from "../common/api/worker/facades/PublicIdentityKeyProvider" import { PublicIdentityKeyProvider } from "../common/api/worker/facades/PublicIdentityKeyProvider"
import { WhitelabelThemeGenerator } from "../common/gui/WhitelabelThemeGenerator" import { WhitelabelThemeGenerator } from "../common/gui/WhitelabelThemeGenerator"
import { UndoModel } from "./UndoModel" import { UndoModel } from "./UndoModel"
import { SpamClassifier } from "./workerUtils/spamClassification/SpamClassifier"
import { GroupSettingsModel } from "../common/sharing/model/GroupSettingsModel" import { GroupSettingsModel } from "../common/sharing/model/GroupSettingsModel"
import { AutosaveFacade } from "../common/api/worker/facades/lazy/AutosaveFacade" import { AutosaveFacade } from "../common/api/worker/facades/lazy/AutosaveFacade"
import { lang } from "../common/misc/LanguageViewModel.js" import { lang } from "../common/misc/LanguageViewModel.js"
import { SpamClassificationHandler } from "./workerUtils/spamClassification/SpamClassificationHandler"
import { SpamClassifier } from "./workerUtils/spamClassification/SpamClassifier"
assertMainOrNode() assertMainOrNode()
@ -300,6 +301,10 @@ class MailLocator implements CommonLocator {
return new InboxRuleHandler(this.mailFacade, this.logins, this.mailModel) return new InboxRuleHandler(this.mailFacade, this.logins, this.mailModel)
}) })
readonly spamClassificationHandler = lazyMemoized(() => {
return new SpamClassificationHandler(this.mailFacade, this.spamClassifier)
})
async searchViewModelFactory(): Promise<() => SearchViewModel> { async searchViewModelFactory(): Promise<() => SearchViewModel> {
const { SearchViewModel } = await import("../mail-app/search/view/SearchViewModel.js") const { SearchViewModel } = await import("../mail-app/search/view/SearchViewModel.js")
const conversationViewModelFactory = await this.conversationViewModelFactory() const conversationViewModelFactory = await this.conversationViewModelFactory()
@ -836,6 +841,7 @@ class MailLocator implements CommonLocator {
this.logins, this.logins,
this.mailFacade, this.mailFacade,
this.connectivityModel, this.connectivityModel,
this.spamClassificationHandler,
this.inboxRuleHandler, this.inboxRuleHandler,
) )
this.operationProgressTracker = new OperationProgressTracker() this.operationProgressTracker = new OperationProgressTracker()
@ -886,6 +892,7 @@ class MailLocator implements CommonLocator {
const openSettingsHandler = new OpenSettingsHandler(this.logins) const openSettingsHandler = new OpenSettingsHandler(this.logins)
this.webMobileFacade = new WebMobileFacade(this.connectivityModel, MAIL_PREFIX) this.webMobileFacade = new WebMobileFacade(this.connectivityModel, MAIL_PREFIX)
this.spamClassifier = spamClassifier
this.nativeInterfaces = createNativeInterfaces( this.nativeInterfaces = createNativeInterfaces(
this.webMobileFacade, this.webMobileFacade,
@ -1032,7 +1039,6 @@ class MailLocator implements CommonLocator {
if (selectedThemeFacade instanceof WebThemeFacade) { if (selectedThemeFacade instanceof WebThemeFacade) {
selectedThemeFacade.addDarkListener(() => mailLocator.themeController.reloadTheme()) selectedThemeFacade.addDarkListener(() => mailLocator.themeController.reloadTheme())
} }
this.spamClassifier = spamClassifier
} }
readonly calendarModel: () => Promise<CalendarModel> = lazyMemoized(async () => { readonly calendarModel: () => Promise<CalendarModel> = lazyMemoized(async () => {

View file

@ -21,11 +21,11 @@ export class MailIndexAndSpamClassificationPostLoginAction implements PostLoginA
private readonly customerFacade: CustomerFacade, private readonly customerFacade: CustomerFacade,
) {} ) {}
async onPartialLoginSuccess(event: LoggedInEvent): Promise<void> { async onPartialLoginSuccess(event: LoggedInEvent): Promise<{ asyncAction: Promise<void> }> {
if (event.sessionType === SessionType.Persistent) { if (event.sessionType === SessionType.Persistent) {
await this.offlineStorageSettings.init() await this.offlineStorageSettings.init()
// noinspection ES6MissingAwait // noinspection ES6MissingAwait
this.indexer.resizeMailIndex(this.offlineStorageSettings.getTimeRange().getTime()).then(async () => { const resizeMailIndex = this.indexer.resizeMailIndex(this.offlineStorageSettings.getTimeRange().getTime()).then(async () => {
// spamClassification // spamClassification
// Wait until indexing is done, as its populate offlineDb // Wait until indexing is done, as its populate offlineDb
@ -40,7 +40,9 @@ export class MailIndexAndSpamClassificationPostLoginAction implements PostLoginA
} }
} }
}) })
return { asyncAction: resizeMailIndex }
} }
return { asyncAction: Promise.resolve() }
} }
async onFullLoginSuccess(_: LoggedInEvent): Promise<void> {} async onFullLoginSuccess(_: LoggedInEvent): Promise<void> {}

View file

@ -642,7 +642,7 @@ export class IndexedDbIndexer implements Indexer {
await this.mailIndexer.afterMailUpdated(mailId) await this.mailIndexer.afterMailUpdated(mailId)
break break
case OperationType.CREATE: case OperationType.CREATE:
await this.mailIndexer.afterMailCreated(mailId, await this.mailIndexer.downloadNewMailData(mailId)) await this.mailIndexer.afterMailCreated(mailId)
break break
} }
} catch (e) { } catch (e) {

View file

@ -606,14 +606,17 @@ export class MailIndexer {
* @throws NotAuthorizedError if the mail cannot be accessed (and has not been cached) * @throws NotAuthorizedError if the mail cannot be accessed (and has not been cached)
* @throws NotFoundError if the mail no longer exists (and has not been cached) * @throws NotFoundError if the mail no longer exists (and has not been cached)
*/ */
async afterMailCreated(mailId: IdTuple, newMailData: MailWithDetailsAndAttachments | null) { async afterMailCreated(mailId: IdTuple) {
await this.initialized.promise await this.initialized.promise
if (!this._mailIndexingEnabled) return if (!this._mailIndexingEnabled) return
const mail = newMailData?.mail const newMail = await this.entityClient.load(MailTypeRef, mailId)
if (mail == null || !this.canIndexMail(mail)) { if (!this.canIndexMail(newMail)) {
return return
} }
// At this point, the mail entity, itself, is cached, so when we go to download it again, it will come from cache
const newMailData = await this.downloadNewMailData(mailId)
if (newMailData) { if (newMailData) {
await this.backend.onMailCreated(newMailData) await this.backend.onMailCreated(newMailData)
} }

View file

@ -70,7 +70,9 @@ export const SpamClassificationDefinitions: Record<string, OfflineStorageTable>
// Spam classification training data // Spam classification training data
spam_classification_training_data: { spam_classification_training_data: {
definition: definition:
"CREATE TABLE IF NOT EXISTS spam_classification_training_data (listId TEXT NOT NULL, elementId TEXT NOT NULL, ownerGroup TEXT NOT NULL, subject TEXT NOT NULL, body TEXT NOT NULL, isSpam NUMBER, lastModified NUMBER NOT NULL, isSpamConfidence NUMBER NOT NULL, PRIMARY KEY (listId, elementId))", "CREATE TABLE IF NOT EXISTS spam_classification_training_data (listId TEXT NOT NULL, elementId TEXT NOT NULL," +
" ownerGroup TEXT NOT NULL, subject TEXT NOT NULL, body TEXT NOT NULL, isSpam NUMBER, " +
"lastModified NUMBER NOT NULL, isSpamConfidence NUMBER NOT NULL, PRIMARY KEY (listId, elementId))",
purgedWithCache: true, purgedWithCache: true,
}, },
@ -200,24 +202,43 @@ export class OfflineStoragePersistence {
await this.sqlCipherFacade.run(query, params) await this.sqlCipherFacade.run(query, params)
} }
async updateSpamClassificationData(id: IdTuple, isSpam: boolean, isSpamConfidence: number): Promise<void> { async deleteSpamClassification(mailId: IdTuple): Promise<void> {
const mailListId = listIdPart(mailId)
const mailElementId = elementIdPart(mailId)
const { query, params } = sql`
DELETE
FROM spam_classification_training_data
where listId = ${mailListId}
AND elementId = ${mailElementId}`
await this.sqlCipherFacade.run(query, params)
}
/**
 * Deletes the rows of spam_classification_training_data that belong to the given owner group
 * and whose lastModified is older than the cutoff.
 *
 * @param cutoffTimestamp rows with lastModified strictly below this value are deleted
 * @param ownerGroupId only rows whose ownerGroup equals this id are deleted
 */
async deleteSpamClassificationTrainingDataBeforeCutoff(cutoffTimestamp: number, ownerGroupId: Id): Promise<void> {
	const deleteStatement = sql`DELETE
FROM spam_classification_training_data
WHERE lastModified < ${cutoffTimestamp}
AND ownerGroup = ${ownerGroupId}`
	await this.sqlCipherFacade.run(deleteStatement.query, deleteStatement.params)
}
async updateSpamClassification(mailId: IdTuple, isSpam: boolean, isSpamConfidence: number): Promise<void> {
const { query, params } = sql` const { query, params } = sql`
UPDATE spam_classification_training_data UPDATE spam_classification_training_data
SET lastModified=${Date.now()}, SET lastModified=${Date.now()},
isSpamConfidence=${isSpamConfidence}, isSpamConfidence=${isSpamConfidence},
isSpam=${isSpam ? 1 : 0} isSpam=${isSpam ? 1 : 0}
WHERE listId = ${listIdPart(id)} WHERE listId = ${listIdPart(mailId)}
AND elementId = ${elementIdPart(id)} AND elementId = ${elementIdPart(mailId)}
` `
await this.sqlCipherFacade.run(query, params) await this.sqlCipherFacade.run(query, params)
} }
async getStoredClassification(mail: Mail): Promise<Nullable<{ isSpam: boolean; isSpamConfidence: number }>> { async getSpamClassification(mailId: IdTuple): Promise<Nullable<{ isSpam: boolean; isSpamConfidence: number }>> {
const { query, params } = sql` const { query, params } = sql`
SELECT isSpam, isSpamConfidence SELECT isSpam, isSpamConfidence
FROM spam_classification_training_data FROM spam_classification_training_data
where listId = ${listIdPart(mail._id)} where listId = ${listIdPart(mailId)}
AND elementId = ${elementIdPart(mail._id)} ` AND elementId = ${elementIdPart(mailId)} `
const result = await this.sqlCipherFacade.get(query, params) const result = await this.sqlCipherFacade.get(query, params)
if (!result) { if (!result) {
return null return null
@ -369,10 +390,6 @@ export class OfflineStoragePersistence {
} }
return untagSqlObject(rowIdResult).rowid return untagSqlObject(rowIdResult).rowid
} }
public async tokenize(text: string): Promise<ReadonlyArray<string>> {
return this.sqlCipherFacade.tokenize(text)
}
} }
function serializeMailAddresses(recipients: readonly MailAddress[]): string { function serializeMailAddresses(recipients: readonly MailAddress[]): string {

View file

@ -0,0 +1,147 @@
import { createMoveMailData, Mail, MailDetails, MailFolder, MoveMailData } from "../../../common/api/entities/tutanota/TypeRefs"
import { MailSetKind, ProcessingState, SpamDecision } from "../../../common/api/common/TutanotaConstants"
import { SpamClassifier, SpamPredMailDatum, SpamTrainMailDatum } from "./SpamClassifier"
import { getMailBodyText } from "../../../common/api/common/CommonMailUtils"
import { assertNotNull, debounce, isNotNull, Nullable, ofClass } from "@tutao/tutanota-utils"
import { MailFacade } from "../../../common/api/worker/facades/lazy/MailFacade"
import { ClientClassifierType } from "../../../common/api/common/ClientClassifierType"
import { FolderSystem } from "../../../common/api/common/mail/FolderSystem"
import { LockedError, PreconditionFailedError } from "../../../common/api/common/error/RestError"
// Debounce time (ms) for the batched move requests of mails that change folder after spam prediction.
const DEBOUNCE_MOVE_MAIL_SERVICE_REQUESTS_MS = 500
// Debounce time (ms) for the batched processing-state updates of mails that stay in their folder.
const DEBOUNCE_CLIENT_CLASSIFIER_RESULT_SERVICE_REQUESTS_MS = 1000
// Confidence used for newly stored training data and when a mail has no clientSpamClassifierResult confidence.
const DEFAULT_IS_SPAM_CONFIDENCE = 1
// Spam flag written for newly stored training data (see storeTrainingDatum).
const DEFAULT_IS_SPAM = false
/**
 * Runs the client-side spam prediction for mails and batches the resulting server updates:
 * mails that change folder are sent through MailFacade.moveMails, mails that stay in their folder
 * are reported through MailFacade.updateMailPredictionState. Both kinds of requests are debounced.
 * It also keeps the locally stored spam classification (training) data up to date.
 */
export class SpamClassificationHandler {
	public constructor(
		private readonly mailFacade: MailFacade,
		private readonly spamClassifier: Nullable<SpamClassifier>,
	) {}

	/** Mails predicted as ham that are queued to be moved from spam to the inbox; null when nothing is queued. */
	hamMoveMailData: MoveMailData | null = null

	/** Mails predicted as spam that are queued to be moved from the inbox to spam; null when nothing is queued. */
	spamMoveMailData: MoveMailData | null = null

	/** Ids of mails that stay in their folder and whose processing state still has to be updated. */
	classifierResultServiceMailIds: IdTuple[] = []

	sendClassifierResultServiceRequest = debounce(DEBOUNCE_CLIENT_CLASSIFIER_RESULT_SERVICE_REQUESTS_MS, async (mailFacade: MailFacade) => {
		// Each update to ClientClassifierResultService (for mails that did not move) requires one request
		// We debounce the requests to a rate of DEBOUNCE_CLIENT_CLASSIFIER_RESULT_SERVICE_REQUESTS_MS
		if (this.classifierResultServiceMailIds.length) {
			// take over the queued ids and reset the queue before sending, so ids added meanwhile go into the next request
			const mailIds = this.classifierResultServiceMailIds
			this.classifierResultServiceMailIds = []
			return mailFacade.updateMailPredictionState(mailIds, ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE)
		}
	})

	sendMoveMailServiceRequest = debounce(DEBOUNCE_MOVE_MAIL_SERVICE_REQUESTS_MS, async (mailFacade: MailFacade) => {
		// Each update to MoveMailService (for ham or spam mails that did move) requires one request
		// We debounce the requests to a rate of DEBOUNCE_MOVE_MAIL_SERVICE_REQUESTS_MS
		if (this.hamMoveMailData) {
			const moveMailData = this.hamMoveMailData
			this.hamMoveMailData = null
			await this.sendMoveMailRequest(mailFacade, moveMailData)
		}
		if (this.spamMoveMailData) {
			const moveMailData = this.spamMoveMailData
			this.spamMoveMailData = null
			await this.sendMoveMailRequest(mailFacade, moveMailData)
		}
	})

	/**
	 * Moves the mails of one batch to the batch's target folder.
	 * LockedError and PreconditionFailedError are logged and swallowed; any other error rejects the returned promise.
	 *
	 * @param mailFacade facade used to send the move request
	 * @param moveMailData batch of mail ids together with their target folder
	 * @returns a promise that settles only once the move request has finished
	 */
	async sendMoveMailRequest(mailFacade: MailFacade, moveMailData: MoveMailData): Promise<void> {
		// The chain must be awaited: otherwise this method resolves before the request is done, callers that
		// await it are not actually sequenced, and unexpected rejections are detached from the caller.
		await mailFacade
			.moveMails(moveMailData.mails, moveMailData.targetFolder, null, ClientClassifierType.CLIENT_CLASSIFICATION)
			.catch(
				ofClass(LockedError, (e) => {
					// LockedError should no longer be thrown!?!
					console.log("moving mails failed", e, moveMailData.targetFolder)
				}),
			)
			.catch(
				ofClass(PreconditionFailedError, (e) => {
					// move mail operation may have been locked by other process
					console.log("moving mails failed", e, moveMailData.targetFolder)
				}),
			)
	}

	/**
	 * Predicts whether the mail is spam and queues the matching server update.
	 *
	 * - predicted spam while in the inbox: queued for a move to the spam folder
	 * - predicted ham while in the spam folder: queued for a move to the inbox
	 * - otherwise: the mail stays where it is; its id is queued for a processing state update unless the
	 *   state already is INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE
	 *
	 * @returns the folder the mail is expected to be in after the queued update
	 * @throws if mail._ownerGroup is null, or if the required system folder does not exist (assertNotNull)
	 */
	public async predictSpamForNewMail(mail: Mail, mailDetails: MailDetails, sourceFolder: MailFolder, folderSystem: FolderSystem): Promise<MailFolder> {
		const spamPredMailDatum: SpamPredMailDatum = {
			subject: mail.subject,
			body: getMailBodyText(mailDetails.body),
			ownerGroup: assertNotNull(mail._ownerGroup),
		}
		// null means that no prediction is available (no classifier, or the classifier did not return one)
		const isSpam = (await this.spamClassifier?.predict(spamPredMailDatum)) ?? null

		if (isSpam === true && sourceFolder.folderType === MailSetKind.INBOX) {
			const spamFolder = assertNotNull(folderSystem.getSystemFolderByType(MailSetKind.SPAM))
			this.spamMoveMailData = this.addToMoveBatch(this.spamMoveMailData, spamFolder, mail._id)
			await this.sendMoveMailServiceRequest(this.mailFacade)
			return spamFolder
		}

		// Only an explicit ham prediction may take a mail out of the spam folder. A missing prediction (null)
		// must not be treated as ham, otherwise every mail in spam would be moved to the inbox while no
		// classifier is available.
		if (isSpam === false && sourceFolder.folderType === MailSetKind.SPAM) {
			const hamFolder = assertNotNull(folderSystem.getSystemFolderByType(MailSetKind.INBOX))
			this.hamMoveMailData = this.addToMoveBatch(this.hamMoveMailData, hamFolder, mail._id)
			await this.sendMoveMailServiceRequest(this.mailFacade)
			return hamFolder
		}

		if (mail.processingState !== ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE) {
			this.classifierResultServiceMailIds.push(mail._id)
			await this.sendClassifierResultServiceRequest(this.mailFacade)
		}
		return sourceFolder
	}

	/** Adds the mail to the pending batch, creating a new batch for the target folder if none is pending. */
	private addToMoveBatch(pendingBatch: MoveMailData | null, targetFolder: MailFolder, mailId: IdTuple): MoveMailData {
		if (pendingBatch) {
			pendingBatch.mails.push(mailId)
			return pendingBatch
		}
		return createMoveMailData({
			targetFolder: targetFolder._id,
			mails: [mailId],
			excludeMailSet: null,
			moveReason: ClientClassifierType.CLIENT_CLASSIFICATION,
		})
	}

	/**
	 * Updates the stored spam classification of the mail when it differs from the mail's clientSpamClassifierResult.
	 * Does nothing without a classifier, without a clientSpamClassifierResult, or without a stored classification.
	 */
	public async updateSpamClassificationData(mail: Mail) {
		if (this.spamClassifier == null || mail.clientSpamClassifierResult == null) {
			return
		}

		const storedClassification = await this.spamClassifier.getSpamClassification(mail._id)
		const isSpam = mail.clientSpamClassifierResult.spamDecision === SpamDecision.BLACKLIST
		const isSpamConfidence = getSpamConfidence(mail)

		if (isNotNull(storedClassification) && (isSpam !== storedClassification.isSpam || isSpamConfidence !== storedClassification.isSpamConfidence)) {
			// the model has trained on the mail but the spamFlag was wrong so we refit with higher isSpamConfidence
			await this.spamClassifier.updateSpamClassification(mail._id, isSpam, isSpamConfidence)
		}
	}

	/** Deletes the stored spam classification of the given mail; no-op without a classifier. */
	public async dropClassificationData(mailId: IdTuple) {
		await this.spamClassifier?.deleteSpamClassification(mailId)
	}

	/**
	 * Stores the mail as a training datum using DEFAULT_IS_SPAM and DEFAULT_IS_SPAM_CONFIDENCE.
	 *
	 * @throws if mail._ownerGroup is null (assertNotNull), even when no classifier is available
	 */
	public async storeTrainingDatum(mail: Mail, mailDetails: MailDetails) {
		const spamTrainMailDatum: SpamTrainMailDatum = {
			mailId: mail._id,
			subject: mail.subject,
			body: getMailBodyText(mailDetails.body),
			isSpam: DEFAULT_IS_SPAM,
			isSpamConfidence: DEFAULT_IS_SPAM_CONFIDENCE,
			ownerGroup: assertNotNull(mail._ownerGroup),
		}
		await this.spamClassifier?.storeSpamClassification(spamTrainMailDatum)
	}
}
/**
 * Returns the confidence of the mail's clientSpamClassifierResult converted with Number().
 * Falls back to DEFAULT_IS_SPAM_CONFIDENCE when the result or its confidence is null or undefined.
 */
export function getSpamConfidence(mail: Mail): number {
	const rawConfidence = mail.clientSpamClassifierResult?.confidence
	const confidenceOrDefault = rawConfidence ?? DEFAULT_IS_SPAM_CONFIDENCE
	return Number(confidenceOrDefault)
}

View file

@ -1,5 +1,4 @@
import { EntityClient } from "../../../common/api/common/EntityClient" import { EntityClient } from "../../../common/api/common/EntityClient"
import { UserFacade } from "../../../common/api/worker/facades/UserFacade"
import { assertNotNull, isNotNull, lazyAsync } from "@tutao/tutanota-utils" import { assertNotNull, isNotNull, lazyAsync } from "@tutao/tutanota-utils"
import { MailBag, MailboxGroupRootTypeRef, MailBoxTypeRef, MailFolder, MailFolderTypeRef, MailTypeRef } from "../../../common/api/entities/tutanota/TypeRefs" import { MailBag, MailboxGroupRootTypeRef, MailBoxTypeRef, MailFolder, MailFolderTypeRef, MailTypeRef } from "../../../common/api/entities/tutanota/TypeRefs"
import { getMailSetKind, MailSetKind } from "../../../common/api/common/TutanotaConstants" import { getMailSetKind, MailSetKind } from "../../../common/api/common/TutanotaConstants"
@ -9,6 +8,7 @@ import { getMailBodyText } from "../../../common/api/common/CommonMailUtils"
import { BulkMailLoader, MailWithMailDetails } from "../index/BulkMailLoader" import { BulkMailLoader, MailWithMailDetails } from "../index/BulkMailLoader"
import { hasError } from "../../../common/api/common/utils/ErrorUtils" import { hasError } from "../../../common/api/common/utils/ErrorUtils"
import { SpamTrainMailDatum } from "./SpamClassifier" import { SpamTrainMailDatum } from "./SpamClassifier"
import { getSpamConfidence } from "./SpamClassificationHandler"
const INITIAL_SPAM_CLASSIFICATION_INDEX_INTERVAL_DAYS = 28 const INITIAL_SPAM_CLASSIFICATION_INDEX_INTERVAL_DAYS = 28
@ -22,7 +22,6 @@ export class SpamClassificationInitializer {
constructor( constructor(
private readonly entityClient: EntityClient, private readonly entityClient: EntityClient,
private readonly userFacade: UserFacade,
private readonly offlineStorage: OfflineStoragePersistence, private readonly offlineStorage: OfflineStoragePersistence,
private readonly bulkMailLoader: lazyAsync<BulkMailLoader>, private readonly bulkMailLoader: lazyAsync<BulkMailLoader>,
) {} ) {}
@ -91,16 +90,14 @@ export class SpamClassificationInitializer {
private mailWithDetailsToMailDatum(spamFolder: MailFolder, inboxFolder: MailFolder, { mail, mailDetails }: MailWithMailDetails): SpamTrainMailDatum { private mailWithDetailsToMailDatum(spamFolder: MailFolder, inboxFolder: MailFolder, { mail, mailDetails }: MailWithMailDetails): SpamTrainMailDatum {
const isSpam = mail.sets.some((folderId) => isSameId(folderId, spamFolder._id)) const isSpam = mail.sets.some((folderId) => isSameId(folderId, spamFolder._id))
const isCertain = !mail.unread || !mail.sets.some((folderId) => isSameId(folderId, inboxFolder._id))
return { return {
mailId: mail._id, mailId: mail._id,
subject: mail.subject, subject: mail.subject,
body: getMailBodyText(mailDetails.body), body: getMailBodyText(mailDetails.body),
isSpam: isSpam, isSpam: isSpam,
isSpamConfidence: isCertain ? 1 : 0, isSpamConfidence: getSpamConfidence(mail),
listId: listIdPart(mail._id), listId: listIdPart(mail._id),
elementId: elementIdPart(mail._id), elementId: elementIdPart(mail._id),
// todo: when owner group is null?
ownerGroup: assertNotNull(mail._ownerGroup), ownerGroup: assertNotNull(mail._ownerGroup),
} as SpamTrainMailDatum } as SpamTrainMailDatum
} }

View file

@ -1,5 +1,5 @@
import { assertWorkerOrNode } from "../../../common/api/common/Env" import { assertWorkerOrNode } from "../../../common/api/common/Env"
import { assertNotNull, defer, groupByAndMap, isNotNull, Nullable, promiseMap } from "@tutao/tutanota-utils" import { assertNotNull, defer, groupByAndMap, isNotNull, Nullable, promiseMap, tokenize } from "@tutao/tutanota-utils"
import { DynamicTfVectorizer } from "./DynamicTfVectorizer" import { DynamicTfVectorizer } from "./DynamicTfVectorizer"
import { HashingVectorizer } from "./HashingVectorizer" import { HashingVectorizer } from "./HashingVectorizer"
import { import {
@ -22,11 +22,9 @@ import {
} from "./PreprocessPatterns" } from "./PreprocessPatterns"
import { SpamClassificationInitializer } from "./SpamClassificationInitializer" import { SpamClassificationInitializer } from "./SpamClassificationInitializer"
import { CacheStorage } from "../../../common/api/worker/rest/DefaultEntityRestCache" import { CacheStorage } from "../../../common/api/worker/rest/DefaultEntityRestCache"
import { OfflineStoragePersistence } from "../index/OfflineStoragePersistence" import { htmlToText } from "../../../common/api/common/utils/IndexUtils"
import { filterMailMemberships, htmlToText } from "../../../common/api/common/utils/IndexUtils"
import { import {
dense, dense,
dropout,
fromMemory, fromMemory,
glorotUniform, glorotUniform,
LayersModel, LayersModel,
@ -39,6 +37,7 @@ import {
import type { Tensor } from "@tensorflow/tfjs-core" import type { Tensor } from "@tensorflow/tfjs-core"
import type { ModelArtifacts } from "@tensorflow/tfjs-core/dist/io/types" import type { ModelArtifacts } from "@tensorflow/tfjs-core/dist/io/types"
import type { ModelFitArgs } from "@tensorflow/tfjs-layers" import type { ModelFitArgs } from "@tensorflow/tfjs-layers"
import { OfflineStoragePersistence } from "../index/OfflineStoragePersistence"
assertWorkerOrNode() assertWorkerOrNode()
@ -64,7 +63,7 @@ export type SpamPredMailDatum = {
ownerGroup: Id ownerGroup: Id
} }
const PREDICTION_THRESHOLD = 0.5 const PREDICTION_THRESHOLD = 0.55
export type PreprocessConfiguration = { export type PreprocessConfiguration = {
isPreprocessMails: boolean isPreprocessMails: boolean
@ -92,18 +91,21 @@ export const DEFAULT_PREPROCESS_CONFIGURATION: PreprocessConfiguration = {
isRemoveSpaceBeforeNewLine: true, isRemoveSpaceBeforeNewLine: true,
} }
const TRAINING_INTERVAL = 1000 * 60 * 10 const TRAINING_INTERVAL = 1000 * 60 * 10 // 10 minutes
const FULL_RETRAINING_INTERVAL = 1000 * 60 * 60 * 24 * 7 // 1 week
type TrainingPerformance = { type TrainingPerformance = {
trainingTime: number trainingTime: number
vectorizationTime: number vectorizationTime: number
} }
/** Tokenizer applied to preprocessed mails by the spam classifier; delegates to tokenize from @tutao/tutanota-utils. */
export const spamClassifierTokenizer = (text: string): string[] => {
	return tokenize(text)
}
export class SpamClassifier { export class SpamClassifier {
private readonly classifier: Map<Id, { model: LayersModel; isEnabled: boolean }> private readonly classifier: Map<Id, { model: LayersModel; isEnabled: boolean }>
constructor( constructor(
private readonly offlineStorage: OfflineStoragePersistence | null, private readonly offlineStorage: OfflineStoragePersistence,
private readonly offlineStorageCache: CacheStorage, private readonly offlineStorageCache: CacheStorage,
private readonly initializer: SpamClassificationInitializer, private readonly initializer: SpamClassificationInitializer,
private readonly deterministic: boolean = false, private readonly deterministic: boolean = false,
@ -113,18 +115,18 @@ export class SpamClassifier {
this.classifier = new Map() this.classifier = new Map()
} }
public getEnabledSpamClassifierForOwnerGroup(ownerGroup: Id): Nullable<LayersModel> {
const classifier = this.classifier.get(ownerGroup) ?? null
if (classifier && classifier.isEnabled) {
return classifier.model
}
return null
}
public async initialize(ownerGroup: Id): Promise<void> { public async initialize(ownerGroup: Id): Promise<void> {
const loadedModel = await this.loadModel(ownerGroup) const loadedModel = await this.loadModel(ownerGroup)
const storage = assertNotNull(this.offlineStorageCache) const storage = assertNotNull(this.offlineStorageCache)
setInterval(async () => {
const cutoffDate = Date.now() - FULL_RETRAINING_INTERVAL
const lastFullTrainingTime = await storage.getLastTrainedFromScratchTime()
if (cutoffDate > lastFullTrainingTime) {
await this.retrainModelFromScratch(storage, ownerGroup, cutoffDate)
}
}, FULL_RETRAINING_INTERVAL)
if (isNotNull(loadedModel)) { if (isNotNull(loadedModel)) {
console.log("Loaded existing spam classification model from database") console.log("Loaded existing spam classification model from database")
@ -138,14 +140,19 @@ export class SpamClassifier {
} }
console.log("No existing model found. Training from scratch...") console.log("No existing model found. Training from scratch...")
const data = await this.initializer.init(ownerGroup) await this.trainFromScratch(storage, ownerGroup)
await this.initialTraining(data)
await this.saveModel(ownerGroup)
setInterval(async () => { setInterval(async () => {
await this.updateAndSaveModel(storage, ownerGroup) await this.updateAndSaveModel(storage, ownerGroup)
}, TRAINING_INTERVAL) }, TRAINING_INTERVAL)
} }
/**
 * Trains the model for the owner group on the data provided by the initializer, saves the model
 * and records the current time as the last from-scratch training time in the given storage.
 */
private async trainFromScratch(storage: CacheStorage, ownerGroup: string) {
	const initialTrainingData = await this.initializer.init(ownerGroup)
	await this.initialTraining(initialTrainingData)
	await this.saveModel(ownerGroup)

	const finishedAt = Date.now()
	await storage.setLastTrainedFromScratchTime(finishedAt)
}
// VisibleForTesting // VisibleForTesting
public async updateAndSaveModel(storage: CacheStorage, ownerGroup: Id) { public async updateAndSaveModel(storage: CacheStorage, ownerGroup: Id) {
const isModelUpdated = await this.updateModelFromCutoff(await storage.getLastTrainedTime(), ownerGroup) const isModelUpdated = await this.updateModelFromCutoff(await storage.getLastTrainedTime(), ownerGroup)
@ -216,9 +223,11 @@ export class SpamClassifier {
} }
public async initialTraining(mails: SpamTrainMailDatum[]): Promise<TrainingPerformance> { public async initialTraining(mails: SpamTrainMailDatum[]): Promise<TrainingPerformance> {
const vectorizationStart = performance.now() const preprocessingStart = performance.now()
const tokenizedMails = await promiseMap(mails, (mail) => spamClassifierTokenizer(this.preprocessMail(mail)))
const preprocessingTime = performance.now() - preprocessingStart
const tokenizedMails = await promiseMap(mails, (mail) => assertNotNull(this.offlineStorage).tokenize(this.preprocessMail(mail))) const vectorizationStart = performance.now()
if (this.vectorizer instanceof DynamicTfVectorizer) { if (this.vectorizer instanceof DynamicTfVectorizer) {
this.vectorizer.buildInitialTokenVocabulary(tokenizedMails) this.vectorizer.buildInitialTokenVocabulary(tokenizedMails)
} }
@ -237,11 +246,13 @@ export class SpamClassifier {
epochs: 16, epochs: 16,
batchSize: 32, batchSize: 32,
shuffle: !this.deterministic, shuffle: !this.deterministic,
callbacks: { // callbacks: {
onEpochEnd: async (epoch, logs) => { // onEpochEnd: async (epoch, logs) => {
console.log(`Epoch ${epoch + 1} - Loss: ${logs!.loss.toFixed(4)}`) // if (logs) {
}, // console.log(`Epoch ${epoch + 1} - Loss: ${logs.loss.toFixed(4)}`)
}, // }
// },
// },
}) })
const trainingTime = performance.now() - trainingStart const trainingTime = performance.now() - trainingStart
@ -251,7 +262,9 @@ export class SpamClassifier {
this.classifier.set(mails[0].ownerGroup, { model: classifier, isEnabled: true }) this.classifier.set(mails[0].ownerGroup, { model: classifier, isEnabled: true })
console.log(`### Finished Initial Training ### (total trained mails: ${mails.length})`) console.log(
`### Finished Initial Training ### (total trained mails: ${mails.length}, preprocessing time: ${preprocessingTime}, vectorization time: ${vectorizationTime}ms, training time: ${trainingTime})`,
)
return { vectorizationTime, trainingTime } return { vectorizationTime, trainingTime }
} }
@ -283,10 +296,9 @@ export class SpamClassifier {
const retrainingStart = performance.now() const retrainingStart = performance.now()
const modelToUpdate = assertNotNull(this.classifier.get(ownerGroup)) const modelToUpdate = assertNotNull(this.classifier.get(ownerGroup))
const offlineStorage = assertNotNull(this.offlineStorage)
const tokenizedMailsArray = await promiseMap(newTrainingMails, async (mail) => { const tokenizedMailsArray = await promiseMap(newTrainingMails, async (mail) => {
const preprocessedMail = this.preprocessMail(mail) const preprocessedMail = this.preprocessMail(mail)
const tokenizedMail = await offlineStorage.tokenize(preprocessedMail) const tokenizedMail = spamClassifierTokenizer(preprocessedMail)
return { tokenizedMail, isSpamConfidence: mail.isSpamConfidence, isSpam: mail.isSpam ? 1 : 0 } return { tokenizedMail, isSpamConfidence: mail.isSpamConfidence, isSpam: mail.isSpam ? 1 : 0 }
}) })
@ -319,11 +331,11 @@ export class SpamClassifier {
epochs: 8, epochs: 8,
batchSize: 32, batchSize: 32,
shuffle: !this.deterministic, shuffle: !this.deterministic,
callbacks: { // callbacks: {
onEpochEnd: async (epoch, logs) => { // onEpochEnd: async (epoch, logs) => {
console.log(`Epoch ${epoch + 1} - Loss: ${logs!.loss.toFixed(4)}`) // console.log(`Epoch ${epoch + 1} - Loss: ${logs!.loss.toFixed(4)}`)
}, // },
}, // },
} }
for (let i = 0; i <= isSpamConfidence; i++) { for (let i = 0; i <= isSpamConfidence; i++) {
await modelToUpdate.model.fit(xs, ys, modelFitArgs) await modelToUpdate.model.fit(xs, ys, modelFitArgs)
@ -349,7 +361,7 @@ export class SpamClassifier {
} }
const preprocessedMail = this.preprocessMail(spamPredMailDatum) const preprocessedMail = this.preprocessMail(spamPredMailDatum)
const tokenizedMail = await assertNotNull(this.offlineStorage).tokenize(preprocessedMail) const tokenizedMail = spamClassifierTokenizer(preprocessedMail)
const vectors = await assertNotNull(this.vectorizer).transform([tokenizedMail]) const vectors = await assertNotNull(this.vectorizer).transform([tokenizedMail])
const xs = tensor2d(vectors, [vectors.length, assertNotNull(this.vectorizer).dimension], undefined) const xs = tensor2d(vectors, [vectors.length, assertNotNull(this.vectorizer).dimension], undefined)
@ -357,7 +369,7 @@ export class SpamClassifier {
const predictionData = await predictionTensor.data() const predictionData = await predictionTensor.data()
const prediction = predictionData[0] const prediction = predictionData[0]
console.log(`predicted new mail to be with probability ${prediction.toFixed(2)} spam. Owner Group: ${spamPredMailDatum.ownerGroup}`) // console.log(`predicted new mail to be with probability ${prediction.toFixed(2)} spam. Owner Group: ${spamPredMailDatum.ownerGroup}`)
// When using the webgl backend we need to manually dispose @tensorflow tensors // When using the webgl backend we need to manually dispose @tensorflow tensors
xs.dispose() xs.dispose()
@ -366,97 +378,20 @@ export class SpamClassifier {
return prediction > PREDICTION_THRESHOLD return prediction > PREDICTION_THRESHOLD
} }
/* public getSpamClassification(mailId: IdTuple) {
* TODO: Only for internal release return this.offlineStorage.getSpamClassification(mailId)
*
* Allows to check the accuracy of your currently trained classifier against the content of mailbox itself
* How-to:
* 1) Open console and switch context to worker-bootstrap.js
* 2) Execute this method in console: `locator.spamClassifier.getSpamMetricsForCurrentMailBox()`
* 3) Let execution continue from breakpoint
*
* Since we change constant of this.initializer,
* it's better to restart the client to not have unexpected effect
*/
public async getSpamMetricsForCurrentMailBox(ownerGroup?: Id): Promise<void> {
const { LocalTimeDateProvider } = await import("../../../common/api/worker/DateProvider.js")
const dateProvider = new LocalTimeDateProvider()
const getIdOfClassificationMail = (classificationData: any) => {
return ((classificationData.listId as Id) + "/" + classificationData.elementId) as Id
}
const user = assertNotNull((this.initializer as any).userFacade.getUser())
const firstOwnerGroup = ownerGroup ?? filterMailMemberships(user)[0]._id
console.log(`Testing with ownergroup: ${firstOwnerGroup}`)
const readingAllSpamStart = performance.now()
const trainedMails = await assertNotNull(this.offlineStorage)
.getCertainSpamClassificationTrainingDataAfterCutoff(0, firstOwnerGroup)
.then((mails) => new Set(mails.map(getIdOfClassificationMail)))
console.log(`Done reading ${trainedMails.size} certain training mail data in: ${performance.now() - readingAllSpamStart}ms`)
// since we train with last -28 days, we can test with last -90
;(this.initializer as any).TIME_LIMIT = dateProvider.getStartOfDayShiftedBy(-90)
// if exists, try to test with at 5xleast same number of mails as in training sample
;(this.initializer as any).MIN_MAILS_COUNT = trainedMails.size * 5
// to avoid putting stuff into offline storage
;(this.initializer as any).offlineStorage.storeSpamClassification = async () => {
console.log("not putting classification datum into offline storage")
} }
const downloadingExtraMailsStart = performance.now() public updateSpamClassification(mailId: IdTuple, isSpam: boolean, isSpamConfidence: number) {
const testingMails = (await this.initializer.init(firstOwnerGroup)) return this.offlineStorage.updateSpamClassification(mailId, isSpam, isSpamConfidence)
// do not test with the same mails that was used to train
.filter((classificationData) => !trainedMails.has(getIdOfClassificationMail(classificationData)))
console.log(`Done downloading extra ${testingMails.length} of last 90 days mail data in: ${performance.now() - downloadingExtraMailsStart}ms`)
const testingAllSamplesStart = performance.now()
await this.test(testingMails)
console.log(`Done testing all extra mails sample in: ${performance.now() - testingAllSamplesStart}ms`)
} }
public async test(mails: SpamTrainMailDatum[]): Promise<void> { public storeSpamClassification(spamTrainMailDatum: SpamTrainMailDatum) {
if (!this.classifier) { return this.offlineStorage.storeSpamClassification(spamTrainMailDatum)
throw new Error("Model has not been loaded")
} }
let predictionArray: number[] = [] public deleteSpamClassification(mailId: IdTuple) {
for (let mail of mails) { return this.offlineStorage.deleteSpamClassification(mailId)
const prediction = await this.predict(mail)
predictionArray.push(prediction ? 1 : 0)
}
const ysArray = mails.map((mail) => mail.isSpam)
let tp = 0,
tn = 0,
fp = 0,
fn = 0
for (let i = 0; i < predictionArray.length; i++) {
const predictedSpam = predictionArray[i] > 0.5
const isActuallyASpam = ysArray[i]
if (predictedSpam && isActuallyASpam) tp++
else if (!predictedSpam && !isActuallyASpam) tn++
else if (predictedSpam && !isActuallyASpam) fp++
else if (!predictedSpam && isActuallyASpam) fn++
}
const total = tp + tn + fp + fn
const accuracy = (tp + tn) / total
const precision = tp / (tp + fp + 1e-7)
const recall = tp / (tp + fn + 1e-7)
const f1 = 2 * ((precision * recall) / (precision + recall + 1e-7))
console.log("\n--- Evaluation Metrics ---")
console.log(`Accuracy: \t${(accuracy * 100).toFixed(2)}%`)
console.log(`Precision:\t${(precision * 100).toFixed(2)}%`)
console.log(`Recall: \t${(recall * 100).toFixed(2)}%`)
console.log(`F1 Score: \t${(f1 * 100).toFixed(2)}%`)
console.log("\nConfusion Matrix:")
console.log({
Predicted_Spam: { True_Positive: tp, False_Positive: fp },
Predicted_Ham: { False_Negative: fn, True_Negative: tn },
})
} }
// visibleForTesting // visibleForTesting
@ -538,7 +473,7 @@ export class SpamClassifier {
} }
} }
// VisibleForTesting // visibleForTesting
public async loadModel(ownerGroup: Id): Promise<Nullable<LayersModel>> { public async loadModel(ownerGroup: Id): Promise<Nullable<LayersModel>> {
const model = await assertNotNull(this.offlineStorage).getSpamClassificationModel(ownerGroup) const model = await assertNotNull(this.offlineStorage).getSpamClassificationModel(ownerGroup)
if (model) { if (model) {
@ -565,7 +500,19 @@ export class SpamClassifier {
return concatenated.length > 0 ? concatenated : " " return concatenated.length > 0 ? concatenated : " "
} }
// === Testing methods private async retrainModelFromScratch(storage: CacheStorage, ownerGroup: Id, cutoffTimestamp: number) {
console.log("Model is being re-trained from scratch, deleting old data")
try {
await assertNotNull(this.offlineStorage).deleteSpamClassificationTrainingDataBeforeCutoff(cutoffTimestamp, ownerGroup)
} catch (e) {
console.error("Failed delete old training data: ", e)
return
}
await this.trainFromScratch(storage, ownerGroup)
}
// visibleForTesting
public async cloneClassifier(): Promise<SpamClassifier> { public async cloneClassifier(): Promise<SpamClassifier> {
const newClassifier = new SpamClassifier( const newClassifier = new SpamClassifier(
this.offlineStorage, this.offlineStorage,
@ -587,4 +534,9 @@ export class SpamClassifier {
return newClassifier return newClassifier
} }
// visibleForTesting
public addSpamClassifierForOwner(ownerGroup: Id, model: LayersModel, isEnabled: boolean) {
this.classifier.set(ownerGroup, { model, isEnabled })
}
} }

View file

@ -20,6 +20,7 @@ import {
isIOSApp, isIOSApp,
isOfflineStorageAvailable, isOfflineStorageAvailable,
isTest, isTest,
isWebClient,
} from "../../../common/api/common/Env.js" } from "../../../common/api/common/Env.js"
import { Const } from "../../../common/api/common/TutanotaConstants.js" import { Const } from "../../../common/api/common/TutanotaConstants.js"
import type { BrowserData } from "../../../common/misc/ClientConstants.js" import type { BrowserData } from "../../../common/misc/ClientConstants.js"
@ -269,7 +270,7 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
return new IndexerCore(await db(), browserData) return new IndexerCore(await db(), browserData)
}) })
const mailIndexerAndMailFacade = lazyMemoized(async () => { const mailIndexer = lazyMemoized(async () => {
const { IndexedDbMailIndexerBackend } = await import("../index/IndexedDbMailIndexerBackend") const { IndexedDbMailIndexerBackend } = await import("../index/IndexedDbMailIndexerBackend")
const { OfflineStorageMailIndexerBackend } = await import("../index/OfflineStorageMailIndexerBackend") const { OfflineStorageMailIndexerBackend } = await import("../index/OfflineStorageMailIndexerBackend")
const { MailIndexer } = await import("../index/MailIndexer.js") const { MailIndexer } = await import("../index/MailIndexer.js")
@ -278,30 +279,24 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
const mailFacade = await locator.mail() const mailFacade = await locator.mail()
if (isOfflineStorageAvailable()) { if (isOfflineStorageAvailable()) {
const persistence = await offlineStorageIndexerPersistence() const persistence = await offlineStorageIndexerPersistence()
return { return new MailIndexer(
mailIndexer: new MailIndexer(
mainInterface.infoMessageHandler, mainInterface.infoMessageHandler,
bulkLoaderFactory, bulkLoaderFactory,
locator.cachingEntityClient, locator.cachingEntityClient,
dateProvider, dateProvider,
mailFacade, mailFacade,
() => new OfflineStorageMailIndexerBackend(persistence), () => new OfflineStorageMailIndexerBackend(persistence),
), )
mailFacade: mailFacade,
}
} else { } else {
const core = await indexerCore() const core = await indexerCore()
return { return new MailIndexer(
mailIndexer: new MailIndexer(
mainInterface.infoMessageHandler, mainInterface.infoMessageHandler,
bulkLoaderFactory, bulkLoaderFactory,
locator.cachingEntityClient, locator.cachingEntityClient,
dateProvider, dateProvider,
mailFacade, mailFacade,
(userId) => new IndexedDbMailIndexerBackend(core, userId, typeModelResolver), (userId) => new IndexedDbMailIndexerBackend(core, userId, typeModelResolver),
), )
mailFacade: mailFacade,
}
} }
}) })
@ -335,18 +330,12 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
if (isOfflineStorageAvailable() && !isAdminClient()) { if (isOfflineStorageAvailable() && !isAdminClient()) {
locator.sqlCipherFacade = new SqlCipherFacadeSendDispatcher(locator.native) locator.sqlCipherFacade = new SqlCipherFacadeSendDispatcher(locator.native)
offlineStorageProvider = async () => { offlineStorageProvider = async () => {
if (isDesktop()) {
const { SpamClassifier } = await import("../spamClassification/SpamClassifier") const { SpamClassifier } = await import("../spamClassification/SpamClassifier")
const { SpamClassificationInitializer } = await import("../spamClassification/SpamClassificationInitializer") const { SpamClassificationInitializer } = await import("../spamClassification/SpamClassificationInitializer")
const offlineStorage = await offlineStorageIndexerPersistence() const offlineStorage = await offlineStorageIndexerPersistence()
const spamClassifierInitializer = new SpamClassificationInitializer( const spamClassifierInitializer = new SpamClassificationInitializer(locator.cachingEntityClient, offlineStorage, locator.bulkMailLoader)
locator.cachingEntityClient,
locator.user,
offlineStorage,
locator.bulkMailLoader,
)
locator.spamClassifier = new SpamClassifier(offlineStorage, locator.cacheStorage, spamClassifierInitializer) locator.spamClassifier = new SpamClassifier(offlineStorage, locator.cacheStorage, spamClassifierInitializer)
}
const { KeyVerificationTableDefinitions } = await import("../../../common/api/worker/facades/IdentityKeyTrustDatabase.js") const { KeyVerificationTableDefinitions } = await import("../../../common/api/worker/facades/IdentityKeyTrustDatabase.js")
const { SearchTableDefinitions, SpamClassificationDefinitions } = await import("../index/OfflineStoragePersistence.js") const { SearchTableDefinitions, SpamClassificationDefinitions } = await import("../index/OfflineStoragePersistence.js")
const { AutosaveDraftsTableDefinitions } = await import("../../../common/api/worker/facades/lazy/OfflineStorageAutosaveFacade.js") const { AutosaveDraftsTableDefinitions } = await import("../../../common/api/worker/facades/lazy/OfflineStorageAutosaveFacade.js")
@ -358,7 +347,7 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
}, },
{ {
ref: MailTypeRef, ref: MailTypeRef,
handler: new CustomMailEventCacheHandler(mailIndexerAndMailFacade, offlineStorageIndexerPersistence, locator.cacheStorage), handler: new CustomMailEventCacheHandler(mailIndexer),
}, },
{ ref: UserTypeRef, handler: new CustomUserCacheHandler(locator.cacheStorage) }, { ref: UserTypeRef, handler: new CustomUserCacheHandler(locator.cacheStorage) },
) )
@ -452,12 +441,11 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
locator.indexer = lazyMemoized(async () => { locator.indexer = lazyMemoized(async () => {
const contact = await contactIndexer() const contact = await contactIndexer()
const { mailIndexer, mailFacade } = await mailIndexerAndMailFacade()
if (isOfflineStorageAvailable()) { if (isOfflineStorageAvailable()) {
const { OfflineStorageIndexer } = await import("../index/OfflineStorageIndexer.js") const { OfflineStorageIndexer } = await import("../index/OfflineStorageIndexer.js")
const persistence = await offlineStorageIndexerPersistence() const persistence = await offlineStorageIndexerPersistence()
return new OfflineStorageIndexer(locator.user, persistence, mailIndexer, mainInterface.infoMessageHandler, contact) return new OfflineStorageIndexer(locator.user, persistence, await mailIndexer(), mainInterface.infoMessageHandler, contact)
} else { } else {
const { IndexedDbIndexer } = await import("../index/IndexedDbIndexer.js") const { IndexedDbIndexer } = await import("../index/IndexedDbIndexer.js")
const core = await indexerCore() const core = await indexerCore()
@ -467,7 +455,7 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
core, core,
mainInterface.infoMessageHandler, mainInterface.infoMessageHandler,
locator.cachingEntityClient, locator.cachingEntityClient,
mailIndexer, await mailIndexer(),
contact, contact,
typeModelResolver, typeModelResolver,
locator.keyLoader, locator.keyLoader,
@ -670,16 +658,15 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
) )
locator.search = lazyMemoized(async () => { locator.search = lazyMemoized(async () => {
const { mailIndexer } = await mailIndexerAndMailFacade()
if (isOfflineStorageAvailable()) { if (isOfflineStorageAvailable()) {
const { OfflineStorageSearchFacade } = await import("../index/OfflineStorageSearchFacade.js") const { OfflineStorageSearchFacade } = await import("../index/OfflineStorageSearchFacade.js")
return new OfflineStorageSearchFacade(locator.sqlCipherFacade, mailIndexer, await contactIndexer()) return new OfflineStorageSearchFacade(locator.sqlCipherFacade, await mailIndexer(), await contactIndexer())
} else { } else {
const { IndexedDbSearchFacade } = await import("../index/IndexedDbSearchFacade.js") const { IndexedDbSearchFacade } = await import("../index/IndexedDbSearchFacade.js")
return new IndexedDbSearchFacade( return new IndexedDbSearchFacade(
locator.user, locator.user,
await db(), await db(),
mailIndexer, await mailIndexer(),
await contactSuggestionFacade(), await contactSuggestionFacade(),
browserData, browserData,
locator.cachingEntityClient, locator.cachingEntityClient,
@ -744,6 +731,7 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
}) })
locator.mail = lazyMemoized(async () => { locator.mail = lazyMemoized(async () => {
const { MailFacade } = await import("../../../common/api/worker/facades/lazy/MailFacade.js") const { MailFacade } = await import("../../../common/api/worker/facades/lazy/MailFacade.js")
return new MailFacade( return new MailFacade(
locator.user, locator.user,
locator.cachingEntityClient, locator.cachingEntityClient,
@ -754,7 +742,6 @@ export async function initLocator(worker: WorkerImpl, browserData: BrowserData)
locator.login, locator.login,
locator.keyLoader, locator.keyLoader,
locator.publicEncryptionKeyProvider, locator.publicEncryptionKeyProvider,
locator.spamClassifier,
) )
}) })
const nativePushFacade = new NativePushFacadeSendDispatcher(worker) const nativePushFacade = new NativePushFacadeSendDispatcher(worker)

View file

@ -68,7 +68,6 @@ import "./api/worker/rest/EntityRestClientTest.js"
import "./api/worker/rest/EphemeralCacheStorageTest.js" import "./api/worker/rest/EphemeralCacheStorageTest.js"
import "./api/worker/rest/PatchGeneratorTest.js" import "./api/worker/rest/PatchGeneratorTest.js"
import "./api/worker/rest/ServiceExecutorTest.js" import "./api/worker/rest/ServiceExecutorTest.js"
import "./api/worker/rest/cacheHandler/CustomMailEventCacheHandlerTest.js"
import "./api/worker/search/BulkMailLoaderTest.js" import "./api/worker/search/BulkMailLoaderTest.js"
import "./api/worker/search/ContactIndexerTest.js" import "./api/worker/search/ContactIndexerTest.js"
import "./api/worker/search/EventQueueTest.js" import "./api/worker/search/EventQueueTest.js"
@ -85,7 +84,6 @@ import "./api/worker/facades/KeyVerificationFacadeTest.js"
import "./api/worker/utils/SleepDetectorTest.js" import "./api/worker/utils/SleepDetectorTest.js"
import "./api/worker/utils/spamClassification/TfIdfVectorizerTest.js" import "./api/worker/utils/spamClassification/TfIdfVectorizerTest.js"
import "./api/worker/utils/spamClassification/HashingVectorizerTest.js" import "./api/worker/utils/spamClassification/HashingVectorizerTest.js"
import "./api/worker/utils/spamClassification/SpamClassifierTest.js"
import "./api/worker/utils/spamClassification/PreprocessPatternsTest.js" import "./api/worker/utils/spamClassification/PreprocessPatternsTest.js"
import "./calendar/AlarmSchedulerTest.js" import "./calendar/AlarmSchedulerTest.js"
import "./calendar/CalendarAgendaViewTest.js" import "./calendar/CalendarAgendaViewTest.js"
@ -182,6 +180,7 @@ import "./misc/parsing/ParserCombinatorTest.js"
import "./sharing/GroupSettingsModelTest.js" import "./sharing/GroupSettingsModelTest.js"
import "./mail/editor/OpenLocallySavedDraftActionTest.js" import "./mail/editor/OpenLocallySavedDraftActionTest.js"
import "./mail/search/MailIndexAndSpamClassificationPostLoginActionTest.js" import "./mail/search/MailIndexAndSpamClassificationPostLoginActionTest.js"
import "./mail/SpamClassificationHandlerTest.js"
import * as td from "testdouble" import * as td from "testdouble"
import { random } from "@tutao/tutanota-crypto" import { random } from "@tutao/tutanota-crypto"
@ -213,6 +212,7 @@ async function setupSuite({ integration }: { integration?: boolean }) {
if (typeof process !== "undefined") { if (typeof process !== "undefined") {
// setup the Entropy for all testcases // setup the Entropy for all testcases
await random.addEntropy([{ data: 36, entropy: 256, source: "key" }]) await random.addEntropy([{ data: 36, entropy: 256, source: "key" }])
await import("./api/worker/utils/spamClassification/SpamClassifierTest.js")
await import("./api/worker/offline/OfflineStorageMigratorTest.js") await import("./api/worker/offline/OfflineStorageMigratorTest.js")
await import("./api/worker/offline/OfflineStorageTest.js") await import("./api/worker/offline/OfflineStorageTest.js")
await import("./api/worker/rest/RestClientTest.js") await import("./api/worker/rest/RestClientTest.js")

View file

@ -10,6 +10,7 @@ import {
GroupType, GroupType,
PermissionType, PermissionType,
PresentableKeyVerificationState, PresentableKeyVerificationState,
ProcessingState,
PublicKeyIdentifierType, PublicKeyIdentifierType,
} from "../../../../../src/common/api/common/TutanotaConstants.js" } from "../../../../../src/common/api/common/TutanotaConstants.js"
import { import {
@ -1854,6 +1855,8 @@ o.spec("CryptoFacadeTest", function () {
mailDetailsDraft: null, mailDetailsDraft: null,
sets: [], sets: [],
keyVerificationState: null, keyVerificationState: null,
processingState: ProcessingState.INBOX_RULE_APPLIED,
clientSpamClassifierResult: null,
}) })
// casting here is fine, since we just want to mimic server response data // casting here is fine, since we just want to mimic server response data

View file

@ -18,7 +18,7 @@ import {
import { import {
CryptoProtocolVersion, CryptoProtocolVersion,
MailAuthenticationStatus, MailAuthenticationStatus,
MAX_NBR_MOVE_DELETE_MAIL_SERVICE, MAX_NBR_OF_MAILS_SYNC_OPERATION,
ReportedMailFieldType, ReportedMailFieldType,
} from "../../../../../src/common/api/common/TutanotaConstants.js" } from "../../../../../src/common/api/common/TutanotaConstants.js"
import { matchers, object, when } from "testdouble" import { matchers, object, when } from "testdouble"
@ -72,7 +72,6 @@ o.spec("MailFacade test", function () {
loginFacade = object() loginFacade = object()
keyLoaderFacade = object() keyLoaderFacade = object()
publicEncryptionKeyProvider = object() publicEncryptionKeyProvider = object()
spamClassifier = object()
facade = new MailFacade( facade = new MailFacade(
userFacade, userFacade,
entityClient, entityClient,
@ -83,7 +82,6 @@ o.spec("MailFacade test", function () {
loginFacade, loginFacade,
keyLoaderFacade, keyLoaderFacade,
publicEncryptionKeyProvider, publicEncryptionKeyProvider,
spamClassifier,
) )
}) })
@ -491,7 +489,7 @@ o.spec("MailFacade test", function () {
o.test("batches large amounts of mails", async () => { o.test("batches large amounts of mails", async () => {
const expectedBatches = 4 const expectedBatches = 4
const testIds: IdTuple[] = [] const testIds: IdTuple[] = []
for (let i = 0; i < MAX_NBR_MOVE_DELETE_MAIL_SERVICE * expectedBatches; i++) { for (let i = 0; i < MAX_NBR_OF_MAILS_SYNC_OPERATION * expectedBatches; i++) {
testIds.push([`${i}`, `${i}`]) testIds.push([`${i}`, `${i}`])
} }
await facade.markMails(testIds, true) await facade.markMails(testIds, true)
@ -500,7 +498,7 @@ o.spec("MailFacade test", function () {
serviceExecutor.post( serviceExecutor.post(
UnreadMailStateService, UnreadMailStateService,
matchers.contains({ matchers.contains({
mails: testIds.slice(i * MAX_NBR_MOVE_DELETE_MAIL_SERVICE, (i + 1) * MAX_NBR_MOVE_DELETE_MAIL_SERVICE), mails: testIds.slice(i * MAX_NBR_OF_MAILS_SYNC_OPERATION, (i + 1) * MAX_NBR_OF_MAILS_SYNC_OPERATION),
unread: true, unread: true,
}), }),
), ),

View file

@ -1,320 +0,0 @@
import o from "@tutao/otest"
import { func, matchers, object, verify, when } from "testdouble"
import { lazy, lazyAsync } from "@tutao/tutanota-utils"
import { MailIndexer } from "../../../../../../src/mail-app/workerUtils/index/MailIndexer"
import { MailFacade } from "../../../../../../src/common/api/worker/facades/lazy/MailFacade"
import { OfflineStoragePersistence } from "../../../../../../src/mail-app/workerUtils/index/OfflineStoragePersistence"
import { CacheStorage } from "../../../../../../src/common/api/worker/rest/DefaultEntityRestCache"
import { CustomMailEventCacheHandler } from "../../../../../../src/common/api/worker/rest/cacheHandler/CustomMailEventCacheHandler"
import { Body, Mail, MailDetails, MailFolderTypeRef, MailSetEntryTypeRef, MailTypeRef } from "../../../../../../src/common/api/entities/tutanota/TypeRefs"
import { MailSetKind } from "../../../../../../src/common/api/common/TutanotaConstants"
import { ClientClassifierType } from "../../../../../../src/common/api/common/ClientClassifierType"
import { EntityUpdateData } from "../../../../../../src/common/api/common/utils/EntityUpdateUtils"
import { SpamTrainMailDatum } from "../../../../../../src/mail-app/workerUtils/spamClassification/SpamClassifier"
import { getMailBodyText } from "../../../../../../src/common/api/common/CommonMailUtils"
import { createTestEntity } from "../../../../TestUtils"
/**
 * Tests for how CustomMailEventCacheHandler stores and updates spam classification training data.
 *
 * The assertions below pin the following behavior:
 * - getSpamConfidence: 1 for a mail in the spam folder, 0 for a mail in the trash folder,
 *   and for a mail in the inbox 0 while unread and 1 once read.
 * - onEntityEventCreate: stores a classification datum for the downloaded mail; when the client
 *   classifier yields a prediction that prediction is stored (and the mail is moved to the
 *   matching folder), otherwise the folder-derived classification is kept.
 * - onEntityEventUpdate: updates stored data with confidence 1 when a mail was read in the inbox
 *   or was moved, stores a new datum when none existed, and does nothing otherwise.
 */
o.spec("CustomMailEventCacheHandler", function () {
	let cacheStorageMock: CacheStorage
	let offlineStorageMock: lazy<Promise<OfflineStoragePersistence>>
	let indexerAndMailFacadeMock: lazyAsync<{ mailIndexer: MailIndexer; mailFacade: MailFacade }>

	const inboxFolder = createTestEntity(MailFolderTypeRef, { _id: ["listId", "inbox"], folderType: MailSetKind.INBOX })
	const trashFolder = createTestEntity(MailFolderTypeRef, { _id: ["listId", "trash"], folderType: MailSetKind.TRASH })
	const spamFolder = createTestEntity(MailFolderTypeRef, { _id: ["listId", "spam"], folderType: MailSetKind.SPAM })
	const allFolders = [inboxFolder, trashFolder, spamFolder]

	o.beforeEach(function () {
		cacheStorageMock = object() as CacheStorage
		offlineStorageMock = func() as lazy<Promise<OfflineStoragePersistence>>
		indexerAndMailFacadeMock = func() as lazyAsync<{ mailIndexer: MailIndexer; mailFacade: MailFacade }>
		when(cacheStorageMock.getWholeList(MailFolderTypeRef, matchers.anything())).thenResolve(allFolders)
	})

	o.spec("onEntityEventCreate", function () {
		let mailIndexer: MailIndexer
		let mailFacade: MailFacade
		let body: Body
		let mailDetails: MailDetails
		let mail: Mail

		o.beforeEach(function () {
			// Fresh doubles for every test: stubbings and recorded invocations must not leak
			// from one test into the verifications of another.
			mailIndexer = object() as MailIndexer
			mailFacade = object() as MailFacade
			when(indexerAndMailFacadeMock()).thenResolve({ mailIndexer, mailFacade })

			body = object({ text: "Body Text" }) as Body
			mailDetails = object({ body }) as MailDetails
			mail = createTestEntity(MailTypeRef, {
				_id: ["listId", "elementId"],
				sets: [spamFolder._id],
				subject: "subject",
				_ownerGroup: "owner",
				unread: false,
			})
			when(mailIndexer.downloadNewMailData(matchers.anything())).thenResolve({
				mail,
				mailDetails,
			})
		})

		o("does not process spam e-mails when it fails to download new mail", async function () {
			when(mailIndexer.downloadNewMailData(matchers.anything())).thenResolve(null)
			const offlineStorage = object() as OfflineStoragePersistence
			when(offlineStorageMock()).thenResolve(offlineStorage)

			const cacheHandler = new CustomMailEventCacheHandler(indexerAndMailFacadeMock, offlineStorageMock, cacheStorageMock)
			await cacheHandler.onEntityEventCreate(["listId", "elementId"], [])

			// the folder list is never loaded when there is no mail to process
			verify(cacheStorageMock.getWholeList(MailFolderTypeRef, matchers.anything()), { times: 0 })
		})

		o("processSpam maintains server classification when client classification is not enabled", async function () {
			const offlineStorage = object() as OfflineStoragePersistence
			when(offlineStorageMock()).thenResolve(offlineStorage)
			when(mailFacade.predictSpamResult(mail)).thenResolve(null)

			const cacheHandler = new CustomMailEventCacheHandler(indexerAndMailFacadeMock, offlineStorageMock, cacheStorageMock)
			await cacheHandler.onEntityEventCreate(["listId", "elementId"], [])

			const spamTrainMailDatum: SpamTrainMailDatum = {
				mailId: mail._id,
				subject: mail.subject,
				body: getMailBodyText(body),
				isSpam: true,
				isSpamConfidence: 1,
				ownerGroup: "owner",
			}
			verify(offlineStorage.storeSpamClassification(spamTrainMailDatum), { times: 1 })
		})

		o("processSpam uses client classification when enabled", async function () {
			const offlineStorage = object() as OfflineStoragePersistence
			when(offlineStorageMock()).thenResolve(offlineStorage)
			when(mailFacade.predictSpamResult(mail)).thenResolve(false)

			const cacheHandler = new CustomMailEventCacheHandler(indexerAndMailFacadeMock, offlineStorageMock, cacheStorageMock)
			await cacheHandler.onEntityEventCreate(["listId", "elementId"], [])

			const spamTrainMailDatum: SpamTrainMailDatum = {
				mailId: mail._id,
				subject: mail.subject,
				body: getMailBodyText(body),
				isSpam: false,
				isSpamConfidence: 0,
				ownerGroup: "owner",
			}
			verify(offlineStorage.storeSpamClassification(spamTrainMailDatum), { times: 1 })
		})

		o("processSpam correctly verifies if email is stored in spam folder", async function () {
			mail.sets = [spamFolder._id]
			mail.unread = true
			const offlineStorage = object() as OfflineStoragePersistence
			when(offlineStorageMock()).thenResolve(offlineStorage)
			when(mailFacade.predictSpamResult(mail)).thenResolve(false)

			const cacheHandler = new CustomMailEventCacheHandler(indexerAndMailFacadeMock, offlineStorageMock, cacheStorageMock)
			await cacheHandler.onEntityEventCreate(["listId", "elementId"], [])

			const spamTrainMailDatum: SpamTrainMailDatum = {
				mailId: mail._id,
				subject: mail.subject,
				body: getMailBodyText(body),
				isSpam: false,
				ownerGroup: "owner",
				isSpamConfidence: 0,
			}
			verify(offlineStorage.storeSpamClassification(spamTrainMailDatum), { times: 1 })
		})

		o("getSpamConfidence is 0 for mail in trash folder ", async function () {
			mail.unread = false
			mail.sets = [["listId", "trash"]]
			const cacheHandler = new CustomMailEventCacheHandler(indexerAndMailFacadeMock, offlineStorageMock, cacheStorageMock)
			o(cacheHandler.getSpamConfidence(allFolders, mail).confidence).equals(0)
		})

		o("getSpamConfidence is 1 for mail in spam folder ", async function () {
			mail.unread = true
			mail.sets = [spamFolder._id]
			const cacheHandler = new CustomMailEventCacheHandler(indexerAndMailFacadeMock, offlineStorageMock, cacheStorageMock)
			o(cacheHandler.getSpamConfidence(allFolders, mail).confidence).equals(1)
		})

		o("getSpamConfidence for inbox folder depends on read status", async function () {
			const cacheHandler = new CustomMailEventCacheHandler(indexerAndMailFacadeMock, offlineStorageMock, cacheStorageMock)
			mail.sets = [inboxFolder._id]
			mail.unread = true
			o(cacheHandler.getSpamConfidence(allFolders, mail).confidence).equals(0)
			mail.unread = false
			o(cacheHandler.getSpamConfidence(allFolders, mail).confidence).equals(1)
		})

		o("processSpam moves mail to spam when detected as such and its not already in spam", async function () {
			mail.sets = [inboxFolder._id]
			const offlineStorage = object() as OfflineStoragePersistence
			when(offlineStorageMock()).thenResolve(offlineStorage)
			when(mailFacade.predictSpamResult(mail)).thenResolve(true)
			when(mailFacade.isSpamClassificationEnabled("owner")).thenReturn(true)

			const cacheHandler = new CustomMailEventCacheHandler(indexerAndMailFacadeMock, offlineStorageMock, cacheStorageMock)
			await cacheHandler.onEntityEventCreate(["listId", "elementId"], [])

			const spamTrainMailDatum: SpamTrainMailDatum = {
				mailId: mail._id,
				subject: mail.subject,
				body: getMailBodyText(body),
				isSpam: true,
				isSpamConfidence: 1,
				ownerGroup: "owner",
			}
			verify(offlineStorage.storeSpamClassification(spamTrainMailDatum), { times: 1 })
			verify(mailFacade.simpleMoveMails([["listId", "elementId"]], MailSetKind.SPAM, ClientClassifierType.CLIENT_CLASSIFICATION))
		})

		o("processSpam moves mail to inbox when detected as such and its not already in inbox", async function () {
			const offlineStorage = object() as OfflineStoragePersistence
			when(offlineStorageMock()).thenResolve(offlineStorage)
			mail.sets = [spamFolder._id] // the mail is in spam folder
			when(mailFacade.predictSpamResult(mail)).thenResolve(false)

			const cacheHandler = new CustomMailEventCacheHandler(indexerAndMailFacadeMock, offlineStorageMock, cacheStorageMock)
			await cacheHandler.onEntityEventCreate(["listId", "elementId"], [])

			const spamTrainMailDatum: SpamTrainMailDatum = {
				mailId: mail._id,
				subject: mail.subject,
				body: getMailBodyText(body),
				isSpam: false,
				isSpamConfidence: 0,
				ownerGroup: "owner",
			}
			verify(offlineStorage.storeSpamClassification(spamTrainMailDatum), { times: 1 })
			verify(mailFacade.simpleMoveMails([["listId", "elementId"]], MailSetKind.INBOX, ClientClassifierType.CLIENT_CLASSIFICATION))
		})
	})

	o.spec("onEntityEventUpdate", function () {
		let mailIndexer: MailIndexer
		let mailFacade: MailFacade
		let mail: Mail
		let body: Body
		let mailDetails: MailDetails

		o.beforeEach(function () {
			// Fresh doubles for every test so that the `times: 0` verification on
			// predictSpamResult below only sees invocations made by its own test.
			mailIndexer = object() as MailIndexer
			mailFacade = object() as MailFacade
			when(indexerAndMailFacadeMock()).thenResolve({ mailIndexer, mailFacade })

			body = object({ text: "Body Text" }) as Body
			mailDetails = object({ body }) as MailDetails
			mail = createTestEntity(MailTypeRef, {
				_id: ["listId", "elementId"],
				subject: "subject",
				sets: [inboxFolder._id],
				_ownerGroup: "owner",
			})
			when(mailIndexer.downloadNewMailData(matchers.anything())).thenResolve({
				mail,
				mailDetails,
			})
		})

		o("does nothing if mail has not been read and not moved or had label applied.", async function () {
			const offlineStorage = object() as OfflineStoragePersistence
			when(offlineStorageMock()).thenResolve(offlineStorage)
			mail.unread = true
			when(cacheStorageMock.get(MailTypeRef, "listId", "elementId")).thenResolve(mail)

			const cacheHandler = new CustomMailEventCacheHandler(indexerAndMailFacadeMock, offlineStorageMock, cacheStorageMock)
			await cacheHandler.onEntityEventUpdate(["listId", "elementId"], [])

			verify(offlineStorage.updateSpamClassificationData(matchers.anything(), matchers.anything(), matchers.anything()), { times: 0 })
		})

		o("does nothing if we delete a mail from spam folder", async function () {
			const offlineStorage = object() as OfflineStoragePersistence
			when(offlineStorageMock()).thenResolve(offlineStorage)
			when(cacheStorageMock.get(MailTypeRef, "listId", "elementId")).thenResolve(mail)
			const cacheHandler = new CustomMailEventCacheHandler(indexerAndMailFacadeMock, offlineStorageMock, cacheStorageMock)

			// the mail first arrives in the spam folder ...
			mail.sets = [spamFolder._id]
			await cacheHandler.onEntityEventCreate(["listId", "elementId"], [])
			verify(offlineStorage.storeSpamClassification(matchers.anything()), { times: 1 })

			// ... and is then moved to trash, which must not touch the stored classification
			mail.sets = [trashFolder._id]
			await cacheHandler.onEntityEventUpdate(["listId", "elementId"], [])
			verify(offlineStorage.updateSpamClassificationData(matchers.anything(), matchers.anything(), matchers.anything()), { times: 0 })
		})

		o("does update spam classification data if mail has been read in inbox and not moved", async function () {
			mail.sets = [inboxFolder._id]
			const offlineStorage = object() as OfflineStoragePersistence
			when(offlineStorage.getStoredClassification(mail)).thenResolve({ isSpam: false, isSpamConfidence: 0 })
			when(offlineStorageMock()).thenResolve(offlineStorage)
			mail.unread = false
			when(cacheStorageMock.get(MailTypeRef, "listId", "elementId")).thenResolve(mail)

			const cacheHandler = new CustomMailEventCacheHandler(indexerAndMailFacadeMock, offlineStorageMock, cacheStorageMock)
			await cacheHandler.onEntityEventUpdate(["listId", "elementId"], [])

			verify(offlineStorage.updateSpamClassificationData(["listId", "elementId"], false, 1), { times: 1 })
			verify(mailFacade.predictSpamResult(mail), { times: 0 })
		})

		o("does update spam classification data if mail has not been read but moved", async function () {
			mail.sets = [spamFolder._id]
			const offlineStorage = object() as OfflineStoragePersistence
			when(offlineStorage.getStoredClassification(mail)).thenResolve({ isSpam: false, isSpamConfidence: 0 })
			when(offlineStorageMock()).thenResolve(offlineStorage)
			mail.unread = true
			when(cacheStorageMock.get(MailTypeRef, "listId", "elementId")).thenResolve(mail)
			const event = object({ typeRef: MailSetEntryTypeRef }) as unknown as EntityUpdateData

			const cacheHandler = new CustomMailEventCacheHandler(indexerAndMailFacadeMock, offlineStorageMock, cacheStorageMock)
			await cacheHandler.onEntityEventUpdate(["listId", "elementId"], [event])

			verify(offlineStorage.updateSpamClassificationData(["listId", "elementId"], true, 1), { times: 1 })
		})

		o("does update spam classification data if mail was not previously included", async function () {
			mail.sets = [inboxFolder._id]
			const offlineStorage = object() as OfflineStoragePersistence
			when(offlineStorage.getStoredClassification(mail)).thenResolve(null)
			when(offlineStorageMock()).thenResolve(offlineStorage)
			mail.unread = true
			when(cacheStorageMock.get(MailTypeRef, "listId", "elementId")).thenResolve(mail)
			const event = object({ typeRef: MailSetEntryTypeRef }) as unknown as EntityUpdateData

			const cacheHandler = new CustomMailEventCacheHandler(indexerAndMailFacadeMock, offlineStorageMock, cacheStorageMock)
			await cacheHandler.onEntityEventUpdate(["listId", "elementId"], [event])

			const spamTrainMailDatum: SpamTrainMailDatum = {
				mailId: mail._id,
				subject: mail.subject,
				body: getMailBodyText(body),
				isSpam: false,
				isSpamConfidence: 0,
				ownerGroup: "owner",
			}
			verify(offlineStorage.storeSpamClassification(spamTrainMailDatum), { times: 1 })
		})
	})
})

View file

@ -220,7 +220,7 @@ o.spec("Index Utils", () => {
"this string has <i>html</i> code <!-- ignore comments-->i want to <b>remove</b><br>Link Number 1 -><a href='http://www.bbc.co.uk'>BBC</a> Link Number 1<br><p>Now back to normal text and stuff</p>" "this string has <i>html</i> code <!-- ignore comments-->i want to <b>remove</b><br>Link Number 1 -><a href='http://www.bbc.co.uk'>BBC</a> Link Number 1<br><p>Now back to normal text and stuff</p>"
let plain = "this string has html code i want to remove Link Number 1 -> BBC Link Number 1 Now back to normal text and stuff " let plain = "this string has html code i want to remove Link Number 1 -> BBC Link Number 1 Now back to normal text and stuff "
o(htmlToText(html)).equals(plain) o(htmlToText(html)).equals(plain)
o(htmlToText("<img src='>' >")).equals(" ' >") // TODO handle this case o(htmlToText("<img src='>' >")).equals(" ' >")
o(htmlToText("&nbsp;&amp;&lt;&gt;")).equals(" &<>") o(htmlToText("&nbsp;&amp;&lt;&gt;")).equals(" &<>")
o(htmlToText("&ouml;")).equals("ö") o(htmlToText("&ouml;")).equals("ö")

View file

@ -1186,9 +1186,9 @@ o.spec("IndexedDbIndexer", () => {
o.test("create", async () => { o.test("create", async () => {
await indexer._processEntityEvents(testBatch) await indexer._processEntityEvents(testBatch)
verify(mailIndexer.afterMailCreated(["create", "id-1"], matchers.anything())) verify(mailIndexer.afterMailCreated(["create", "id-1"]))
verify(mailIndexer.afterMailCreated(["create", "id-3"], matchers.anything())) verify(mailIndexer.afterMailCreated(["create", "id-3"]))
verify(mailIndexer.afterMailCreated(matchers.anything(), matchers.anything()), { times: 2 }) verify(mailIndexer.afterMailCreated(matchers.anything()), { times: 2 })
verify(core.writeGroupDataBatchId(testBatch.groupId, testBatch.batchId)) verify(core.writeGroupDataBatchId(testBatch.groupId, testBatch.batchId))
}) })
o.test("update", async () => { o.test("update", async () => {
@ -1207,13 +1207,13 @@ o.spec("IndexedDbIndexer", () => {
}) })
o.test("gracefully handles not found errors", async () => { o.test("gracefully handles not found errors", async () => {
when(mailIndexer.afterMailCreated(["create", "id-1"], matchers.anything())).thenReject(new NotFoundError("Not found :(")) when(mailIndexer.afterMailCreated(["create", "id-1"])).thenReject(new NotFoundError("Not found :("))
when(mailIndexer.afterMailCreated(["update", "id-4"], matchers.anything())).thenReject(new NotFoundError("Not found :(")) when(mailIndexer.afterMailCreated(["update", "id-4"])).thenReject(new NotFoundError("Not found :("))
await indexer._processEntityEvents(testBatch) await indexer._processEntityEvents(testBatch)
verify(mailIndexer.afterMailCreated(["create", "id-1"], matchers.anything())) verify(mailIndexer.afterMailCreated(["create", "id-1"]))
verify(mailIndexer.afterMailCreated(["create", "id-3"], matchers.anything())) verify(mailIndexer.afterMailCreated(["create", "id-3"]))
verify(mailIndexer.afterMailCreated(matchers.anything(), matchers.anything()), { times: 2 }) verify(mailIndexer.afterMailCreated(matchers.anything()), { times: 2 })
verify(mailIndexer.afterMailUpdated(["update", "id-4"])) verify(mailIndexer.afterMailUpdated(["update", "id-4"]))
verify(mailIndexer.afterMailUpdated(["update", "id-6"])) verify(mailIndexer.afterMailUpdated(["update", "id-6"]))
@ -1227,13 +1227,13 @@ o.spec("IndexedDbIndexer", () => {
}) })
o.test("gracefully handles not authorized errors", async () => { o.test("gracefully handles not authorized errors", async () => {
when(mailIndexer.afterMailCreated(["create", "id-1"], matchers.anything())).thenReject(new NotAuthorizedError("You shall not pass :(")) when(mailIndexer.afterMailCreated(["create", "id-1"])).thenReject(new NotAuthorizedError("You shall not pass :("))
when(mailIndexer.afterMailCreated(["update", "id-4"], matchers.anything())).thenReject(new NotAuthorizedError("You shall not pass :(")) when(mailIndexer.afterMailCreated(["update", "id-4"])).thenReject(new NotAuthorizedError("You shall not pass :("))
await indexer._processEntityEvents(testBatch) await indexer._processEntityEvents(testBatch)
verify(mailIndexer.afterMailCreated(["create", "id-1"], matchers.anything())) verify(mailIndexer.afterMailCreated(["create", "id-1"]))
verify(mailIndexer.afterMailCreated(["create", "id-3"], matchers.anything())) verify(mailIndexer.afterMailCreated(["create", "id-3"]))
verify(mailIndexer.afterMailCreated(matchers.anything(), matchers.anything()), { times: 2 }) verify(mailIndexer.afterMailCreated(matchers.anything()), { times: 2 })
verify(mailIndexer.afterMailUpdated(["update", "id-4"])) verify(mailIndexer.afterMailUpdated(["update", "id-4"]))
verify(mailIndexer.afterMailUpdated(["update", "id-6"])) verify(mailIndexer.afterMailUpdated(["update", "id-6"]))

View file

@ -446,21 +446,21 @@ o.spec("MailIndexer", () => {
o.spec("afterMailCreated", () => { o.spec("afterMailCreated", () => {
o.test("no-op if mailIndexing is disabled", async () => { o.test("no-op if mailIndexing is disabled", async () => {
await initWithEnabled(false) await initWithEnabled(false)
await indexer.afterMailCreated(mailIdTuple, null) await indexer.afterMailCreated(mailIdTuple)
verify(backend.onMailCreated(matchers.anything()), { times: 0 }) verify(backend.onMailCreated(matchers.anything()), { times: 0 })
}) })
o.test("no-op if new email is out of index range", async () => { o.test("no-op if new email is out of index range", async () => {
addEntities() addEntities()
setCurrentIndexTimestamp(now + 1) setCurrentIndexTimestamp(now + 1)
await initWithEnabled(true) await initWithEnabled(true)
await indexer.afterMailCreated(mailIdTuple, null) await indexer.afterMailCreated(mailIdTuple)
verify(backend.onMailCreated(matchers.anything()), { times: 0 }) verify(backend.onMailCreated(matchers.anything()), { times: 0 })
}) })
o.test("creates if mailIndexing is enabled", async () => { o.test("creates if mailIndexing is enabled", async () => {
const entities = addEntities() const entities = addEntities()
setCurrentIndexTimestamp(now) setCurrentIndexTimestamp(now)
await initWithEnabled(true) await initWithEnabled(true)
await indexer.afterMailCreated(mailIdTuple, entities) await indexer.afterMailCreated(mailIdTuple)
verify(backend.onMailCreated(entities)) verify(backend.onMailCreated(entities))
}) })
o.test("no-op if draft details fail to download", async () => { o.test("no-op if draft details fail to download", async () => {

View file

@ -1,12 +1,7 @@
import o from "@tutao/otest" import o from "@tutao/otest"
import { HashingVectorizer } from "../../../../../../src/mail-app/workerUtils/spamClassification/HashingVectorizer" import { HashingVectorizer } from "../../../../../../src/mail-app/workerUtils/spamClassification/HashingVectorizer"
import { arrayEquals } from "@tutao/tutanota-utils" import { arrayEquals } from "@tutao/tutanota-utils"
import { spamClassifierTokenizer } from "../../../../../../src/mail-app/workerUtils/spamClassification/SpamClassifier"
/**
 * Minimal whitespace tokenizer used by the vectorizer tests.
 *
 * Lower-cases `text`, splits it on runs of whitespace and keeps only tokens
 * that are longer than one character.
 */
export const tokenize = (text: string): string[] => {
	const tokens: string[] = []
	for (const candidate of text.toLowerCase().split(/\s+/)) {
		// single characters and the empty strings produced by leading/trailing whitespace are dropped
		if (candidate.length > 1) {
			tokens.push(candidate)
		}
	}
	return tokens
}
o.spec("HashingVectorizer", () => { o.spec("HashingVectorizer", () => {
const rawDocuments = [ const rawDocuments = [
@ -17,7 +12,7 @@ o.spec("HashingVectorizer", () => {
"Millions of people choose Tuta to protect their personal and professional communication.", "Millions of people choose Tuta to protect their personal and professional communication.",
] ]
const tokenizedDocuments = rawDocuments.map(tokenize) const tokenizedDocuments = rawDocuments.map(spamClassifierTokenizer)
o("vectorize creates same vector for same tokens", async () => { o("vectorize creates same vector for same tokens", async () => {
const vectorizer = new HashingVectorizer() const vectorizer = new HashingVectorizer()

View file

@ -4,9 +4,9 @@ import { parseCsv } from "../../../../../../src/common/misc/parsing/CsvParser"
import { import {
DEFAULT_PREPROCESS_CONFIGURATION, DEFAULT_PREPROCESS_CONFIGURATION,
SpamClassifier, SpamClassifier,
spamClassifierTokenizer as testTokenize,
SpamTrainMailDatum, SpamTrainMailDatum,
} from "../../../../../../src/mail-app/workerUtils/spamClassification/SpamClassifier" } from "../../../../../../src/mail-app/workerUtils/spamClassification/SpamClassifier"
import { tokenize as testTokenize } from "./HashingVectorizerTest"
import { OfflineStoragePersistence } from "../../../../../../src/mail-app/workerUtils/index/OfflineStoragePersistence" import { OfflineStoragePersistence } from "../../../../../../src/mail-app/workerUtils/index/OfflineStoragePersistence"
import { matchers, object, when } from "testdouble" import { matchers, object, when } from "testdouble"
import { assertNotNull, promiseMap } from "@tutao/tutanota-utils" import { assertNotNull, promiseMap } from "@tutao/tutanota-utils"
@ -16,7 +16,11 @@ import { mockAttribute } from "@tutao/tutanota-test-utils"
import "@tensorflow/tfjs-backend-cpu" import "@tensorflow/tfjs-backend-cpu"
import { HashingVectorizer } from "../../../../../../src/mail-app/workerUtils/spamClassification/HashingVectorizer" import { HashingVectorizer } from "../../../../../../src/mail-app/workerUtils/spamClassification/HashingVectorizer"
import { LayersModel, tensor1d } from "../../../../../../src/mail-app/workerUtils/spamClassification/tensorflow-custom" import { LayersModel, tensor1d } from "../../../../../../src/mail-app/workerUtils/spamClassification/tensorflow-custom"
import { createTestEntity } from "../../../../TestUtils"
import { MailTypeRef } from "../../../../../../src/common/api/entities/tutanota/TypeRefs"
import { Sequential } from "@tensorflow/tfjs-layers"
const { anything } = matchers
export const DATASET_FILE_PATH: string = "./tests/api/worker/utils/spamClassification/spam_classification_test_mails.csv" export const DATASET_FILE_PATH: string = "./tests/api/worker/utils/spamClassification/spam_classification_test_mails.csv"
export async function readMailDataFromCSV(filePath: string): Promise<{ export async function readMailDataFromCSV(filePath: string): Promise<{
@ -50,7 +54,7 @@ export async function readMailDataFromCSV(filePath: string): Promise<{
} }
// Initial training (cutoff by day or amount) // Initial training (cutoff by day or amount)
o.spec("SpamClassifier", () => { o.spec("SpamClassifierTest", () => {
const mockOfflineStorageCache = object<CacheStorage>() const mockOfflineStorageCache = object<CacheStorage>()
const mockOfflineStorage = object<OfflineStoragePersistence>() const mockOfflineStorage = object<OfflineStoragePersistence>()
const mockSpamClassificationInitializer = object<SpamClassificationInitializer>() const mockSpamClassificationInitializer = object<SpamClassificationInitializer>()
@ -68,9 +72,6 @@ o.spec("SpamClassifier", () => {
dataSlice = spamData.concat(hamData) dataSlice = spamData.concat(hamData)
seededShuffle(dataSlice, 42) seededShuffle(dataSlice, 42)
mockOfflineStorage.tokenize = async (text) => {
return testTokenize(text)
}
mockSpamClassificationInitializer.init = async () => { mockSpamClassificationInitializer.init = async () => {
return dataSlice return dataSlice
} }
@ -86,6 +87,48 @@ o.spec("SpamClassifier", () => {
) )
}) })
// Registers a classifier for the owner group with the third argument set to `false`
// (per the test title: client classification not enabled) and expects predict() to
// resolve to null, i.e. no client-side verdict is produced.
o("processSpam maintains server classification when client classification is not enabled", async function () {
	const ownerGroup = "owner"
	const serverClassifiedMail = createTestEntity(MailTypeRef, {
		_id: ["mailListId", "mailId"],
		sets: [["folderList", "serverFolder"]],
	})
	const datum: SpamTrainMailDatum = {
		mailId: serverClassifiedMail._id,
		subject: serverClassifiedMail.subject,
		body: "some body",
		isSpam: true,
		isSpamConfidence: 1,
		ownerGroup,
	}
	// the model is an unstubbed test double; it is not expected to be consulted here
	const disabledModel = object<Sequential>()
	spamClassifier.addSpamClassifierForOwner(datum.ownerGroup, disabledModel, false)

	const verdict = await spamClassifier.predict(datum)

	o(verdict).equals(null)
})
// Registers a classifier for the owner group with the third argument set to `true`
// (per the test title: client classification enabled) whose model is stubbed to return
// the tensor [1], and expects predict() to resolve to true.
o("processSpam uses client classification when enabled", async function () {
	const ownerGroup = "owner"
	const incomingMail = createTestEntity(MailTypeRef, {
		_id: ["mailListId", "mailId"],
		sets: [["folderList", "serverFolder"]],
	})
	const datum: SpamTrainMailDatum = {
		mailId: incomingMail._id,
		subject: incomingMail.subject,
		body: "some body",
		isSpam: false,
		isSpamConfidence: 0,
		ownerGroup,
	}
	// stub the model so that every prediction request yields the tensor [1]
	const enabledModel = object<Sequential>()
	when(enabledModel.predict(anything())).thenReturn(tensor1d([1]))
	spamClassifier.addSpamClassifierForOwner(datum.ownerGroup, enabledModel, true)

	const verdict = await spamClassifier.predict(datum)

	o(verdict).equals(true)
})
o("Initial training only", async () => { o("Initial training only", async () => {
o.timeout(20_000) o.timeout(20_000)
@ -94,13 +137,12 @@ o.spec("SpamClassifier", () => {
const testSet = dataSlice.slice(trainTestSplit) const testSet = dataSlice.slice(trainTestSplit)
await spamClassifier.initialTraining(trainSet) await spamClassifier.initialTraining(trainSet)
await spamClassifier.test(testSet) await testClassifier(spamClassifier, testSet)
}) })
o("Initial training and refitting in multi step", async () => { o("Initial training and refitting in multi step", async () => {
o.timeout(20_000) o.timeout(20_000)
const testStart = Date.now()
const trainTestSplit = dataSlice.length * 0.8 const trainTestSplit = dataSlice.length * 0.8
const trainSet = dataSlice.slice(0, trainTestSplit) const trainSet = dataSlice.slice(0, trainTestSplit)
const testSet = dataSlice.slice(trainTestSplit) const testSet = dataSlice.slice(trainTestSplit)
@ -112,15 +154,15 @@ o.spec("SpamClassifier", () => {
o(await mockSpamClassificationInitializer.init("owner")).deepEquals(trainSetFirstHalf) o(await mockSpamClassificationInitializer.init("owner")).deepEquals(trainSetFirstHalf)
await spamClassifier.initialTraining(dataSlice) await spamClassifier.initialTraining(dataSlice)
console.log(`==> Result when testing with mails in two steps (first step).`) console.log(`==> Result when testing with mails in two steps (first step).`)
await spamClassifier.test(testSet) await testClassifier(spamClassifier, testSet)
await spamClassifier.updateModel("owner", trainSetSecondHalf) await spamClassifier.updateModel("owner", trainSetSecondHalf)
console.log(`==> Result when testing with mails in two steps (second step).`) console.log(`==> Result when testing with mails in two steps (second step).`)
await spamClassifier.test(testSet) await testClassifier(spamClassifier, testSet)
}) })
o("preprocessMail outputs expected tokens for mail content", async () => { o("preprocessMail outputs expected tokens for mail content", async () => {
const classifier = new SpamClassifier(null, object(), object()) const classifier = new SpamClassifier(object(), object(), object())
const mail = { const mail = {
subject: `Sample Tokens and values`, subject: `Sample Tokens and values`,
// prettier-ignore // prettier-ignore
@ -336,18 +378,14 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
o.spec("SpamClassifier - Performance Analysis", () => { o.spec("SpamClassifier - Performance Analysis", () => {
const mockOfflineStorageCache = object<CacheStorage>() const mockOfflineStorageCache = object<CacheStorage>()
const mockOfflineStorage = object<OfflineStoragePersistence>() const mockOfflineStorage = object<OfflineStoragePersistence>()
let classifier = object<SpamClassifier>() let spamClassifier = object<SpamClassifier>()
let dataSlice: SpamTrainMailDatum[] let dataSlice: SpamTrainMailDatum[]
o.beforeEach(() => { o.beforeEach(() => {
mockOfflineStorage.tokenize = async (text) => {
return testTokenize(text)
}
const mockSpamClassificationInitializer = object<SpamClassificationInitializer>() const mockSpamClassificationInitializer = object<SpamClassificationInitializer>()
mockSpamClassificationInitializer.init = async () => { mockSpamClassificationInitializer.init = async () => {
return dataSlice return dataSlice
} }
classifier = new SpamClassifier(mockOfflineStorage, mockOfflineStorageCache, mockSpamClassificationInitializer) spamClassifier = new SpamClassifier(mockOfflineStorage, mockOfflineStorageCache, mockSpamClassificationInitializer)
}) })
o("time to refit", async () => { o("time to refit", async () => {
@ -359,7 +397,7 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
seededShuffle(dataSlice, 42) seededShuffle(dataSlice, 42)
const start = performance.now() const start = performance.now()
await classifier.initialTraining(dataSlice) await spamClassifier.initialTraining(dataSlice)
const initialTrainingDuration = performance.now() - start const initialTrainingDuration = performance.now() - start
console.log(`initial training time ${initialTrainingDuration}ms`) console.log(`initial training time ${initialTrainingDuration}ms`)
@ -367,7 +405,7 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
const nowSpam = [hamSlice[0]] const nowSpam = [hamSlice[0]]
nowSpam.map((formerHam) => (formerHam.isSpam = true)) nowSpam.map((formerHam) => (formerHam.isSpam = true))
const retrainingStart = performance.now() const retrainingStart = performance.now()
await classifier.updateModel("owner", nowSpam) await spamClassifier.updateModel("owner", nowSpam)
const retrainingDuration = performance.now() - retrainingStart const retrainingDuration = performance.now() - retrainingStart
console.log(`retraining time ${retrainingDuration}ms`) console.log(`retraining time ${retrainingDuration}ms`)
} }
@ -381,17 +419,17 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
dataSlice = hamSlice.concat(spamSlice) dataSlice = hamSlice.concat(spamSlice)
// seededShuffle(dataSlice, 42) // seededShuffle(dataSlice, 42)
await classifier.initialTraining(dataSlice) await spamClassifier.initialTraining(dataSlice)
const falseNegatives = spamData const falseNegatives = spamData
.slice(10) .slice(10)
.filter(async (mailDatum) => mailDatum.isSpam !== (await classifier.predict(mailDatum))) .filter(async (mailDatum) => mailDatum.isSpam !== (await spamClassifier.predict(mailDatum)))
.sort() .sort()
.slice(0, 10) .slice(0, 10)
let retrainingNeeded = new Array<number>(falseNegatives.length).fill(0) let retrainingNeeded = new Array<number>(falseNegatives.length).fill(0)
for (let i = 0; i < falseNegatives.length; i++) { for (let i = 0; i < falseNegatives.length; i++) {
const sample = falseNegatives[i] const sample = falseNegatives[i]
const copiedClassifier = await classifier.cloneClassifier() const copiedClassifier = await spamClassifier.cloneClassifier()
let retrainCount = 0 let retrainCount = 0
let predictedSpam = false let predictedSpam = false
@ -458,15 +496,15 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
dataSlice = hamSlice.concat(spamSlice) dataSlice = hamSlice.concat(spamSlice)
// seededShuffle(dataSlice, 42) // seededShuffle(dataSlice, 42)
await classifier.initialTraining(dataSlice) await spamClassifier.initialTraining(dataSlice)
const falsePositive = hamData const falsePositive = hamData
.slice(10) .slice(10)
.filter(async (mailDatum) => mailDatum.isSpam !== (await classifier.predict(mailDatum))) .filter(async (mailDatum) => mailDatum.isSpam !== (await spamClassifier.predict(mailDatum)))
.slice(0, 10) .slice(0, 10)
let retrainingNeeded = new Array<number>(falsePositive.length).fill(0) let retrainingNeeded = new Array<number>(falsePositive.length).fill(0)
for (let i = 0; i < falsePositive.length; i++) { for (let i = 0; i < falsePositive.length; i++) {
const sample = falsePositive[i] const sample = falsePositive[i]
const copiedClassifier = await classifier.cloneClassifier() const copiedClassifier = await spamClassifier.cloneClassifier()
let retrainCount = 0 let retrainCount = 0
let predictedSpam = false let predictedSpam = false
@ -492,16 +530,16 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
dataSlice = hamSlice.concat(spamSlice) dataSlice = hamSlice.concat(spamSlice)
seededShuffle(dataSlice, 42) seededShuffle(dataSlice, 42)
await classifier.initialTraining(dataSlice) await spamClassifier.initialTraining(dataSlice)
const falseNegatives = spamData const falseNegatives = spamData
.slice(10) .slice(10)
.filter(async (mailDatum) => mailDatum.isSpam !== (await classifier.predict(mailDatum))) .filter(async (mailDatum) => mailDatum.isSpam !== (await spamClassifier.predict(mailDatum)))
.slice(0, 10) .slice(0, 10)
let retrainingNeeded = new Array<number>(falseNegatives.length).fill(0) let retrainingNeeded = new Array<number>(falseNegatives.length).fill(0)
for (let i = 0; i < falseNegatives.length; i++) { for (let i = 0; i < falseNegatives.length; i++) {
const sample = falseNegatives[i] const sample = falseNegatives[i]
const copiedClassifier = await classifier.cloneClassifier() const copiedClassifier = await spamClassifier.cloneClassifier()
let retrainCount = 0 let retrainCount = 0
let predictedSpam = false let predictedSpam = false
@ -532,7 +570,7 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
await promiseMap( await promiseMap(
new Array<number>(ITERATION_COUNT).fill(0), new Array<number>(ITERATION_COUNT).fill(0),
async () => { async () => {
const { vectorizationTime, trainingTime } = await classifier.initialTraining(dataSlice) const { vectorizationTime, trainingTime } = await spamClassifier.initialTraining(dataSlice)
trainingTimes.push(trainingTime) trainingTimes.push(trainingTime)
vectorizationTimes.push(vectorizationTime) vectorizationTimes.push(vectorizationTime)
trainingWithoutVectorization.push(trainingTime - vectorizationTime) trainingWithoutVectorization.push(trainingTime - vectorizationTime)
@ -560,6 +598,47 @@ if (DO_RUN_PERFORMANCE_ANALYSIS) {
}) })
}) })
} }
/**
 * Evaluates `classifier` against labelled mails and logs the resulting metrics.
 *
 * Calls `classifier.predict` for every entry of `mails` (sequentially, in order), compares the
 * prediction with the entry's `isSpam` label and prints accuracy, precision, recall, F1 score
 * and the confusion matrix to the console. Nothing is asserted; the output is informational.
 *
 * A falsy prediction (`false` or `null`) is counted as "ham".
 *
 * @param classifier the classifier whose `predict` method is evaluated
 * @param mails labelled data to evaluate on; may be empty, in which case all metrics are 0
 */
async function testClassifier(classifier: SpamClassifier, mails: SpamTrainMailDatum[]): Promise<void> {
	let tp = 0
	let tn = 0
	let fp = 0
	let fn = 0

	for (const mail of mails) {
		const predictedSpam = Boolean(await classifier.predict(mail))
		const isActuallyASpam = Boolean(mail.isSpam)
		if (predictedSpam) {
			if (isActuallyASpam) tp++
			else fp++
		} else {
			if (isActuallyASpam) fn++
			else tn++
		}
	}

	const total = tp + tn + fp + fn
	// guard against 0 / 0 for an empty data set, which would otherwise be logged as "NaN%"
	const accuracy = total === 0 ? 0 : (tp + tn) / total
	// the small epsilon keeps the denominators non-zero when a class never occurs
	const precision = tp / (tp + fp + 1e-7)
	const recall = tp / (tp + fn + 1e-7)
	const f1 = 2 * ((precision * recall) / (precision + recall + 1e-7))

	console.log("\n--- Evaluation Metrics ---")
	console.log(`Accuracy: \t${(accuracy * 100).toFixed(2)}%`)
	console.log(`Precision:\t${(precision * 100).toFixed(2)}%`)
	console.log(`Recall: \t${(recall * 100).toFixed(2)}%`)
	console.log(`F1 Score: \t${(f1 * 100).toFixed(2)}%`)
	console.log("\nConfusion Matrix:")
	console.log({
		Predicted_Spam: { True_Positive: tp, False_Positive: fp },
		Predicted_Ham: { False_Negative: fn, True_Negative: tn },
	})
}
// For testing, we need deterministic shuffling which is not provided by tf.util.shuffle(dataSlice) // For testing, we need deterministic shuffling which is not provided by tf.util.shuffle(dataSlice)
// Seeded Fisher-Yates shuffle // Seeded Fisher-Yates shuffle
function seededShuffle<T>(array: T[], seed: number): void { function seededShuffle<T>(array: T[], seed: number): void {

View file

@ -1,8 +1,18 @@
import o from "@tutao/otest" import o from "@tutao/otest"
import { Notifications } from "../../../src/common/gui/Notifications.js" import { Notifications } from "../../../src/common/gui/Notifications.js"
import { Spy, spy, verify } from "@tutao/tutanota-test-utils" import { mock, Spy, spy, verify } from "@tutao/tutanota-test-utils"
import { MailSetKind, OperationType } from "../../../src/common/api/common/TutanotaConstants.js" import { MailSetKind, OperationType, ProcessingState } from "../../../src/common/api/common/TutanotaConstants.js"
import { Mail, MailFolderTypeRef, MailSetEntryTypeRef, MailTypeRef } from "../../../src/common/api/entities/tutanota/TypeRefs.js" import {
BodyTypeRef,
ClientSpamClassifierResultTypeRef,
Mail,
MailDetails,
MailDetailsBlob,
MailDetailsBlobTypeRef,
MailDetailsTypeRef,
MailFolderTypeRef,
MailTypeRef,
} from "../../../src/common/api/entities/tutanota/TypeRefs.js"
import { EntityClient } from "../../../src/common/api/common/EntityClient.js" import { EntityClient } from "../../../src/common/api/common/EntityClient.js"
import { EntityRestClientMock } from "../api/worker/rest/EntityRestClientMock.js" import { EntityRestClientMock } from "../api/worker/rest/EntityRestClientMock.js"
import { downcast } from "@tutao/tutanota-utils" import { downcast } from "@tutao/tutanota-utils"
@ -11,23 +21,40 @@ import { instance, matchers, object, when } from "testdouble"
import { UserController } from "../../../src/common/api/main/UserController.js" import { UserController } from "../../../src/common/api/main/UserController.js"
import { createTestEntity } from "../TestUtils.js" import { createTestEntity } from "../TestUtils.js"
import { EntityUpdateData, PrefetchStatus } from "../../../src/common/api/common/utils/EntityUpdateUtils.js" import { EntityUpdateData, PrefetchStatus } from "../../../src/common/api/common/utils/EntityUpdateUtils.js"
import { MailboxModel } from "../../../src/common/mailFunctionality/MailboxModel.js" import { MailboxDetail, MailboxModel } from "../../../src/common/mailFunctionality/MailboxModel.js"
import { getElementId, getListId } from "../../../src/common/api/common/utils/EntityUtils.js" import { getElementId, getListId } from "../../../src/common/api/common/utils/EntityUtils.js"
import { MailModel } from "../../../src/mail-app/mail/model/MailModel.js" import { MailModel } from "../../../src/mail-app/mail/model/MailModel.js"
import { EventController } from "../../../src/common/api/main/EventController.js" import { EventController } from "../../../src/common/api/main/EventController.js"
import { MailFacade } from "../../../src/common/api/worker/facades/lazy/MailFacade.js" import { MailFacade } from "../../../src/common/api/worker/facades/lazy/MailFacade.js"
import { ClientModelInfo } from "../../../src/common/api/common/EntityFunctions" import { ClientModelInfo } from "../../../src/common/api/common/EntityFunctions"
import { InboxRuleHandler } from "../../../src/mail-app/mail/model/InboxRuleHandler"
import { SpamClassificationHandler } from "../../../src/mail-app/workerUtils/spamClassification/SpamClassificationHandler"
import { SpamClassifier, SpamTrainMailDatum } from "../../../src/mail-app/workerUtils/spamClassification/SpamClassifier"
import { WebsocketConnectivityModel } from "../../../src/common/misc/WebsocketConnectivityModel"
import { FolderSystem } from "../../../src/common/api/common/mail/FolderSystem"
import { NotAuthorizedError, NotFoundError } from "../../../src/common/api/common/error/RestError"
const { anything } = matchers
o.spec("MailModelTest", function () { o.spec("MailModelTest", function () {
let notifications: Partial<Notifications> let notifications: Partial<Notifications>
let showSpy: Spy let showSpy: Spy
let model: MailModel let model: MailModel
const inboxFolder = createTestEntity(MailFolderTypeRef, { _id: ["folderListId", "inboxId"] }) const inboxFolder = createTestEntity(MailFolderTypeRef, {
inboxFolder.folderType = MailSetKind.INBOX _id: ["folderListId", "inboxId"],
const anotherFolder = createTestEntity(MailFolderTypeRef, { _id: ["folderListId", "archiveId"] }) folderType: MailSetKind.INBOX,
anotherFolder.folderType = MailSetKind.ARCHIVE })
const spamFolder = createTestEntity(MailFolderTypeRef, {
_id: ["folderListId", "spamId"],
folderType: MailSetKind.SPAM,
})
const anotherFolder = createTestEntity(MailFolderTypeRef, {
_id: ["folderListId", "archiveId"],
folderType: MailSetKind.ARCHIVE,
})
let logins: LoginController let logins: LoginController
let mailFacade: MailFacade let mailFacade: MailFacade
let connectivityModel: WebsocketConnectivityModel
const restClient: EntityRestClientMock = new EntityRestClientMock() const restClient: EntityRestClientMock = new EntityRestClientMock()
o.beforeEach(function () { o.beforeEach(function () {
@ -41,6 +68,9 @@ o.spec("MailModelTest", function () {
when(userController.isUpdateForLoggedInUserInstance(matchers.anything(), matchers.anything())).thenReturn(false) when(userController.isUpdateForLoggedInUserInstance(matchers.anything(), matchers.anything())).thenReturn(false)
when(logins.getUserController()).thenReturn(userController) when(logins.getUserController()).thenReturn(userController)
connectivityModel = object<WebsocketConnectivityModel>()
when(connectivityModel.isLeader()).thenReturn(true)
model = new MailModel( model = new MailModel(
downcast({}), downcast({}),
mailboxModel, mailboxModel,
@ -48,40 +78,41 @@ o.spec("MailModelTest", function () {
new EntityClient(restClient, ClientModelInfo.getNewInstanceForTestsOnly()), new EntityClient(restClient, ClientModelInfo.getNewInstanceForTestsOnly()),
logins, logins,
mailFacade, mailFacade,
null, connectivityModel,
() => object(),
() => null, () => null,
) )
}) })
o("doesn't send notification for another folder", async function () { o("doesn't send notification for another folder", async function () {
const mailSetEntry = createTestEntity(MailSetEntryTypeRef, { _id: [anotherFolder.entries, "mailSetEntryId"] }) const mail = createTestEntity(MailTypeRef, { _id: ["mailBagListId", "mailId"], sets: [] })
restClient.addListInstances(mailSetEntry) restClient.addListInstances(mail)
await model.entityEventsReceived([ await model.entityEventsReceived([
makeUpdate({ makeUpdate({
instanceListId: getListId(mailSetEntry) as NonEmptyString, instanceListId: getListId(mail) as NonEmptyString,
instanceId: getElementId(mailSetEntry), instanceId: getElementId(mail),
operation: OperationType.CREATE, operation: OperationType.CREATE,
}), }),
]) ])
o(showSpy.invocations.length).equals(0) o(showSpy.invocations.length).equals(0)
}) })
o("doesn't send notification for move operation", async function () { o("doesn't send notification for move operation", async function () {
const mailSetEntry = createTestEntity(MailSetEntryTypeRef, { _id: [inboxFolder.entries, "mailSetEntryId"] }) const mail = createTestEntity(MailTypeRef, { _id: ["mailBagListId", "mailId"], sets: [] })
restClient.addListInstances(mailSetEntry) restClient.addListInstances(mail)
await model.entityEventsReceived([ await model.entityEventsReceived([
makeUpdate({ makeUpdate({
instanceListId: getListId(mailSetEntry) as NonEmptyString, instanceListId: getListId(mail) as NonEmptyString,
instanceId: getElementId(mailSetEntry), instanceId: getElementId(mail),
operation: OperationType.DELETE, operation: OperationType.DELETE,
}), }),
makeUpdate({ makeUpdate({
instanceListId: getListId(mailSetEntry) as NonEmptyString, instanceListId: getListId(mail) as NonEmptyString,
instanceId: getElementId(mailSetEntry), instanceId: getElementId(mail),
operation: OperationType.CREATE, operation: OperationType.CREATE,
}), }),
]) ])
o(showSpy.invocations.length).equals(0) o(showSpy.invocations.length).equals(0)
}) })
o("markMails", async function () { o("markMails", async function () {
const mailId1: IdTuple = ["mailbag id1", "mail id1"] const mailId1: IdTuple = ["mailbag id1", "mail id1"]
const mailId2: IdTuple = ["mailbag id2", "mail id2"] const mailId2: IdTuple = ["mailbag id2", "mail id2"]
@ -90,6 +121,200 @@ o.spec("MailModelTest", function () {
verify(mailFacade.markMails([mailId1, mailId2, mailId3], true)) verify(mailFacade.markMails([mailId1, mailId2, mailId3], true))
}) })
o.spec("Inbox rule processing and spam prediction", () => {
let inboxRuleHandler: InboxRuleHandler
let spamClassificationHandler: SpamClassificationHandler
let spamClassifier: SpamClassifier
let mailboxModel: MailboxModel
let modelWithSpamAndInboxRule: MailModel
let mail: Mail
let mailDetails: MailDetails
// Per-test fixture: registers an inbox mail (processingState INBOX_RULE_NOT_PROCESSED) together with
// its details blob in the REST mock, and builds a MailModel wired to the spam classification handler
// and the inbox rule handler created here.
o.beforeEach(async () => {
const entityClient = new EntityClient(restClient, ClientModelInfo.getNewInstanceForTestsOnly())
mailboxModel = instance(MailboxModel)
inboxRuleHandler = object<InboxRuleHandler>()
spamClassifier = object<SpamClassifier>()
// a real handler instance, constructed around the spamClassifier test double
spamClassificationHandler = new SpamClassificationHandler(mailFacade, spamClassifier)
mailDetails = createTestEntity(MailDetailsTypeRef, {
_id: "mailDetail",
body: createTestEntity(BodyTypeRef, { text: "some text" }),
})
mail = createTestEntity(MailTypeRef, {
_id: ["mailListId", "mailId"],
_ownerGroup: "mailGroup",
mailDetails: ["detailsList", mailDetails._id],
subject: "subject",
sets: [inboxFolder._id],
processingState: ProcessingState.INBOX_RULE_NOT_PROCESSED,
})
const mailDetailsBlob: MailDetailsBlob = createTestEntity(MailDetailsBlobTypeRef, {
_id: mail.mailDetails!,
details: mailDetails,
})
// make the mail and its details blob loadable through the entity client
restClient.addListInstances(mail)
restClient.addBlobInstances(mailDetailsBlob)
when(mailFacade.loadMailDetailsBlob(mail)).thenResolve(mailDetails)
modelWithSpamAndInboxRule = mock(
new MailModel(
downcast({}),
mailboxModel,
instance(EventController),
entityClient,
logins,
mailFacade,
connectivityModel,
() => spamClassificationHandler,
() => inboxRuleHandler,
),
(m: MailModel) => {
// replace the folder lookup with a fixed folder system; also checks that the lookup is made for "mailGroup"
m.getFolderSystemByGroupId = (groupId) => {
o(groupId).equals("mailGroup")
return new FolderSystem([inboxFolder, spamFolder, anotherFolder])
}
// mailbox details are replaced by an unstubbed test double
m.getMailboxDetailsForMail = async (_: Mail) => object<MailboxDetail>()
},
)
})
o("does not re-apply inbox rules or re-classify mail if the mail is in a final processingState", async function () {
	// A mail whose processingState already says that inbox rules were processed and a spam prediction was made.
	const alreadyClassifiedMail = createTestEntity(MailTypeRef, {
		_id: ["mailListId", "maildIdWithFinalProcessingState"],
		_ownerGroup: "mailGroup",
		mailDetails: ["detailsList", mailDetails._id],
		sets: [inboxFolder._id],
		processingState: ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE,
		clientSpamClassifierResult: createTestEntity(ClientSpamClassifierResultTypeRef),
	})
	restClient.addListInstances(alreadyClassifiedMail)
	when(mailFacade.loadMailDetailsBlob(alreadyClassifiedMail)).thenResolve(mailDetails)

	const alreadyClassifiedMailCreateEvent = makeUpdate({
		instanceListId: "mailListId",
		instanceId: "maildIdWithFinalProcessingState",
		operation: OperationType.CREATE,
	})
	const { processingDone } = await modelWithSpamAndInboxRule.entityEventsReceived([alreadyClassifiedMailCreateEvent])
	await processingDone

	// no inbox rule lookup for the already processed mail
	verify(inboxRuleHandler.findAndApplyMatchingRule(anything(), anything(), anything()), { times: 0 })
	// spamClassificationHandler is a real SpamClassificationHandler (created with `new` in beforeEach), not a
	// test double, so it must not be rehearsed inside verify(): that would execute the real method with matcher
	// arguments. The absence of re-classification is asserted on the mocked classifier instead.
	verify(spamClassifier.storeSpamClassification(anything()), { times: 0 })
	verify(spamClassifier.predict(anything()), { times: 0 })
})
o("does not try to apply inbox rule when downloading of mail fails on create mail event", async function () {
	// loading the mail fixture through the rest client mock now fails
	restClient.setListElementException(mail._id, new NotFoundError("Mail not found"))
	const mailCreateEvent = makeUpdate({
		instanceListId: getListId(mail) as NonEmptyString,
		instanceId: getElementId(mail),
		operation: OperationType.CREATE,
	})

	const { processingDone } = await modelWithSpamAndInboxRule.entityEventsReceived([mailCreateEvent])
	// Wait for the processing triggered by the event, like the other tests of this spec do;
	// otherwise the `times: 0` verification below could pass simply because processing has not run yet.
	await processingDone

	verify(inboxRuleHandler.findAndApplyMatchingRule(anything(), anything(), anything()), { times: 0 })
})
o("spam prediction does not happen when inbox rule is applied", async () => {
	// Arrange: the inbox rule handler reports a matching rule (resolving to a folder);
	// the classifier is stubbed as well but is expected to stay unused.
	when(inboxRuleHandler.findAndApplyMatchingRule(anything(), anything(), anything())).thenResolve(inboxFolder)
	when(spamClassifier.predict(anything())).thenResolve(false)
	const createEvent = makeUpdate({
		operation: OperationType.CREATE,
		instanceListId: "mailListId",
		instanceId: "mailId",
	})
	const storedDatum: SpamTrainMailDatum = {
		mailId: ["mailListId", "mailId"],
		ownerGroup: "mailGroup",
		subject: "subject",
		body: "some text",
		isSpam: false,
		isSpamConfidence: 1,
	}

	// Act: deliver the event and wait until its processing has finished.
	const eventsResult = await modelWithSpamAndInboxRule.entityEventsReceived([createEvent])
	await eventsResult.processingDone

	// Assert: the mail is stored as training data, but no prediction is requested.
	verify(spamClassifier.storeSpamClassification(storedDatum), { times: 1 })
	verify(spamClassifier.predict(anything()), { times: 0 })
})
o("spam prediction happens when inbox rule is not applied", async () => {
	// Arrange: no inbox rule matches (handler resolves to null) and the classifier predicts "not spam".
	when(inboxRuleHandler.findAndApplyMatchingRule(anything(), anything(), anything())).thenResolve(null)
	when(spamClassifier.predict(anything())).thenResolve(false)
	const createEvent = makeUpdate({
		operation: OperationType.CREATE,
		instanceListId: "mailListId",
		instanceId: "mailId",
	})
	const storedDatum: SpamTrainMailDatum = {
		mailId: ["mailListId", "mailId"],
		ownerGroup: "mailGroup",
		subject: "subject",
		body: "some text",
		isSpam: false,
		isSpamConfidence: 1,
	}

	// Act: deliver the event and wait until its processing has finished.
	const eventsResult = await modelWithSpamAndInboxRule.entityEventsReceived([createEvent])
	await eventsResult.processingDone

	// Assert: the mail is stored as training data and exactly one prediction is requested.
	verify(spamClassifier.storeSpamClassification(storedDatum), { times: 1 })
	verify(spamClassifier.predict(anything()), { times: 1 })
})
o("does not try to do spam classification when downloading of mail fails on create mail event", async function () {
	when(inboxRuleHandler.findAndApplyMatchingRule(anything(), anything(), anything())).thenResolve(null)
	const createEvent = makeUpdate({
		operation: OperationType.CREATE,
		instanceListId: "mailListId",
		instanceId: "mailId",
	})

	// Phase 1: loading the mail fails, so nothing may reach the classifier.
	restClient.setListElementException(mail._id, new NotAuthorizedError("blah"))
	const resultWhileMailIsMissing = await modelWithSpamAndInboxRule.entityEventsReceived([createEvent])
	await resultWhileMailIsMissing.processingDone
	verify(spamClassifier.storeSpamClassification(anything()), { times: 0 })
	verify(spamClassifier.predict(anything()), { times: 0 })

	// Phase 2: the mail is registered with the rest client mock again and the same event is delivered once more.
	restClient.addListInstances(mail)
	const resultWhileMailIsPresent = await modelWithSpamAndInboxRule.entityEventsReceived([createEvent])
	await resultWhileMailIsPresent.processingDone
	const storedDatum: SpamTrainMailDatum = {
		mailId: ["mailListId", "mailId"],
		ownerGroup: "mailGroup",
		subject: "subject",
		body: "some text",
		isSpam: false,
		isSpamConfidence: 1,
	}
	verify(spamClassifier.storeSpamClassification(storedDatum), { times: 1 })
	verify(spamClassifier.predict(anything()), { times: 1 })
})
o("deletes a training datum for deleted mail event", async () => {
	// a DELETE event for the mail fixture
	const deleteEvent = makeUpdate({
		operation: OperationType.DELETE,
		instanceListId: "mailListId",
		instanceId: "mailId",
	})

	const eventsResult = await modelWithSpamAndInboxRule.entityEventsReceived([deleteEvent])
	await eventsResult.processingDone

	// the classifier is asked exactly once to drop the datum of that mail
	verify(spamClassifier.deleteSpamClassification(mail._id), { times: 1 })
})
})
function makeUpdate({ function makeUpdate({
instanceId, instanceId,
instanceListId, instanceListId,

View file

@ -0,0 +1,161 @@
import o from "@tutao/otest"
import { matchers, object, verify, when } from "testdouble"
import {
Body,
BodyTypeRef,
ClientSpamClassifierResultTypeRef,
Mail,
MailDetails,
MailDetailsTypeRef,
MailFolderTypeRef,
MailTypeRef,
} from "../../../src/common/api/entities/tutanota/TypeRefs"
import { SpamClassifier, SpamTrainMailDatum } from "../../../src/mail-app/workerUtils/spamClassification/SpamClassifier"
import { getMailBodyText } from "../../../src/common/api/common/CommonMailUtils"
import { MailSetKind, ProcessingState, SpamDecision } from "../../../src/common/api/common/TutanotaConstants"
import { ClientClassifierType } from "../../../src/common/api/common/ClientClassifierType"
import { assert, assertNotNull } from "@tutao/tutanota-utils"
import { MailFacade } from "../../../src/common/api/worker/facades/lazy/MailFacade"
import { createTestEntity } from "../TestUtils"
import { SpamClassificationHandler } from "../../../src/mail-app/workerUtils/spamClassification/SpamClassificationHandler"
import { FolderSystem } from "../../../src/common/api/common/mail/FolderSystem"
import { isSameId } from "../../../src/common/api/common/utils/EntityUtils"
import { any } from "@tensorflow/tfjs-core"
const { anything } = matchers
o.spec("SpamClassificationHandlerTest", function () {
// Created once when the spec is defined (not in beforeEach), so stubbings on it are shared by all tests.
let mailFacade = object<MailFacade>()
// The fixtures below are (re)assigned in beforeEach.
let body: Body
let mail: Mail
let spamClassifier: SpamClassifier
// Handler under test.
let spamHandler: SpamClassificationHandler
let folderSystem: FolderSystem
let mailDetails: MailDetails
// System folders that the stubbed FolderSystem hands out.
const inboxFolder = createTestEntity(MailFolderTypeRef, { _id: ["listId", "inbox"], folderType: MailSetKind.INBOX })
const trashFolder = createTestEntity(MailFolderTypeRef, { _id: ["listId", "trash"], folderType: MailSetKind.TRASH })
const spamFolder = createTestEntity(MailFolderTypeRef, { _id: ["listId", "spam"], folderType: MailSetKind.SPAM })
o.beforeEach(function () {
	spamClassifier = object<SpamClassifier>()

	// Default fixture: an unread, not yet processed mail that sits in the spam folder.
	// Tests overwrite `mail.sets` / `mail.processingState` as needed.
	body = createTestEntity(BodyTypeRef, { text: "Body Text" })
	mailDetails = createTestEntity(MailDetailsTypeRef, { _id: "mailDetail", body })
	mail = createTestEntity(MailTypeRef, {
		_id: ["listId", "elementId"],
		sets: [spamFolder._id],
		subject: "subject",
		_ownerGroup: "owner",
		mailDetails: ["detailsList", mailDetails._id],
		unread: true,
		processingState: ProcessingState.INBOX_RULE_NOT_PROCESSED,
		clientSpamClassifierResult: createTestEntity(ClientSpamClassifierResultTypeRef, { spamDecision: SpamDecision.NONE }),
	})

	folderSystem = object<FolderSystem>()
	when(mailFacade.moveMails(anything(), anything(), anything(), ClientClassifierType.CLIENT_CLASSIFICATION)).thenResolve([])
	when(folderSystem.getSystemFolderByType(MailSetKind.SPAM)).thenReturn(spamFolder)
	when(folderSystem.getSystemFolderByType(MailSetKind.INBOX)).thenReturn(inboxFolder)
	when(folderSystem.getSystemFolderByType(MailSetKind.TRASH)).thenReturn(trashFolder)
	// Resolve the folder from the single set id of the queried mail.
	// The parameter is deliberately not called `mail` so that it does not shadow the fixture.
	when(folderSystem.getFolderByMail(anything())).thenDo((queriedMail: Mail) => {
		assert(queriedMail.sets.length === 1, "Expected exactly one mail set")
		const mailFolderId = assertNotNull(queriedMail.sets[0])
		if (isSameId(mailFolderId, trashFolder._id)) return trashFolder
		else if (isSameId(mailFolderId, spamFolder._id)) return spamFolder
		else if (isSameId(mailFolderId, inboxFolder._id)) return inboxFolder
		else throw new Error("Unknown mail Folder")
	})
	// The second argument of `when` is the stubbing options object, not an argument matcher, so no
	// `anything()` is passed there; the rehearsal below only matches a one-element array holding the fixture mail.
	when(
		mailFacade.loadMailDetailsBlob(
			matchers.argThat((requestedMails: Array<Mail>) => {
				assert(requestedMails.length === 1, "exactly one mail is requested at a time")
				return isSameId(requestedMails[0]._id, mail._id)
			}),
		),
	).thenDo(async () => [{ mail, mailDetails }])

	spamHandler = new SpamClassificationHandler(mailFacade, spamClassifier)
})
o("predictSpamForNewMail does move mail from inbox to spam folder if mail is spam", async function () {
	// inbox mail, classifier says "spam"
	when(spamClassifier.predict(anything())).thenResolve(true)
	mail.sets = [inboxFolder._id]

	const resultingFolder = await spamHandler.predictSpamForNewMail(mail, mailDetails, inboxFolder, folderSystem)

	// only the spam move data holds the mail; the handler reports the spam folder
	const { hamMoveMailData, spamMoveMailData, classifierResultServiceMailIds } = spamHandler
	o(hamMoveMailData).deepEquals(null)
	o(spamMoveMailData?.mails).deepEquals([mail._id])
	o(classifierResultServiceMailIds).deepEquals([])
	o(resultingFolder).deepEquals(spamFolder)
})
o("predictSpamForNewMail does NOT move mail from inbox to spam folder if mail is ham", async function () {
	// inbox mail, classifier says "ham"
	when(spamClassifier.predict(anything())).thenResolve(false)
	mail.sets = [inboxFolder._id]

	const resultingFolder = await spamHandler.predictSpamForNewMail(mail, mailDetails, inboxFolder, folderSystem)

	// no move data at all; the mail id is only recorded in classifierResultServiceMailIds
	const { hamMoveMailData, spamMoveMailData, classifierResultServiceMailIds } = spamHandler
	o(hamMoveMailData).deepEquals(null)
	o(spamMoveMailData).deepEquals(null)
	o(classifierResultServiceMailIds).deepEquals([mail._id])
	o(resultingFolder).deepEquals(inboxFolder)
})
o("predictSpamForNewMail does NOT move mail from spam to inbox folder if mail is spam", async function () {
	// spam-folder mail, classifier agrees that it is spam
	when(spamClassifier.predict(anything())).thenResolve(true)
	mail.sets = [spamFolder._id]

	const resultingFolder = await spamHandler.predictSpamForNewMail(mail, mailDetails, spamFolder, folderSystem)

	// no move data at all; the mail id is only recorded in classifierResultServiceMailIds
	const { hamMoveMailData, spamMoveMailData, classifierResultServiceMailIds } = spamHandler
	o(hamMoveMailData).deepEquals(null)
	o(spamMoveMailData).deepEquals(null)
	o(classifierResultServiceMailIds).deepEquals([mail._id])
	o(resultingFolder).deepEquals(spamFolder)
})
o("predictSpamForNewMail moves mail from spam to inbox folder if mail is ham", async function () {
	// spam-folder mail, classifier says "ham"
	when(spamClassifier.predict(anything())).thenResolve(false)
	mail.sets = [spamFolder._id]

	const resultingFolder = await spamHandler.predictSpamForNewMail(mail, mailDetails, spamFolder, folderSystem)

	// only the ham move data holds the mail; the handler reports the inbox folder
	const { hamMoveMailData, spamMoveMailData, classifierResultServiceMailIds } = spamHandler
	o(hamMoveMailData?.mails).deepEquals([mail._id])
	o(spamMoveMailData).deepEquals(null)
	o(classifierResultServiceMailIds).deepEquals([])
	o(resultingFolder).deepEquals(inboxFolder)
})
o("predictSpamForNewMail does NOT move mail from spam to spam folder if mail is spam", async function () {
	// NOTE(review): set-up and expectations look identical to the "does NOT move mail from spam to inbox folder
	// if mail is spam" test of this spec - confirm whether both are needed.
	when(spamClassifier.predict(anything())).thenResolve(true)
	mail.sets = [spamFolder._id]

	const resultingFolder = await spamHandler.predictSpamForNewMail(mail, mailDetails, spamFolder, folderSystem)

	// no move data at all; the mail id is only recorded in classifierResultServiceMailIds
	const { hamMoveMailData, spamMoveMailData, classifierResultServiceMailIds } = spamHandler
	o(hamMoveMailData).deepEquals(null)
	o(spamMoveMailData).deepEquals(null)
	o(classifierResultServiceMailIds).deepEquals([mail._id])
	o(resultingFolder).deepEquals(spamFolder)
})
o(
	"predictSpamForNewMail does NOT send classifierResultService request if processingState is INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE",
	async function () {
		// inbox mail whose processingState is already final; classifier says "ham"
		when(spamClassifier.predict(anything())).thenResolve(false)
		mail.processingState = ProcessingState.INBOX_RULE_PROCESSED_AND_SPAM_PREDICTION_MADE
		mail.sets = [inboxFolder._id]

		const resultingFolder = await spamHandler.predictSpamForNewMail(mail, mailDetails, inboxFolder, folderSystem)

		// neither move data nor a classifier-result entry is recorded for the mail
		const { hamMoveMailData, spamMoveMailData, classifierResultServiceMailIds } = spamHandler
		o(hamMoveMailData).deepEquals(null)
		o(spamMoveMailData).deepEquals(null)
		o(classifierResultServiceMailIds).deepEquals([])
		o(resultingFolder).deepEquals(inboxFolder)
	},
)
o("update spam classification data on every mail update", async function () {
	// the stored classification says "ham" with confidence 0 ...
	const storedClassification = { isSpam: false, isSpamConfidence: 0 }
	when(spamClassifier.getSpamClassification(anything())).thenResolve(storedClassification)
	// ... while the mail now carries a BLACKLIST decision with confidence "1"
	const blacklistResult = createTestEntity(ClientSpamClassifierResultTypeRef, {
		spamDecision: SpamDecision.BLACKLIST,
		confidence: "1",
	})
	mail.clientSpamClassifierResult = blacklistResult

	await spamHandler.updateSpamClassificationData(mail)

	// the classifier is updated once for the fixture mail with isSpam = true and confidence 1
	verify(spamClassifier.updateSpamClassification(["listId", "elementId"], true, 1), { times: 1 })
})
})

View file

@ -209,7 +209,7 @@ o.spec("ConversationListModel", () => {
verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), { verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), {
times: 0, times: 0,
}) })
verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true, false), { verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true), {
times: 0, times: 0,
}) })
}) })
@ -227,7 +227,7 @@ o.spec("ConversationListModel", () => {
verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), { verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), {
times: 0, times: 0,
}) })
verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true, false), { verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true), {
times: 0, times: 0,
}) })
}) })
@ -246,7 +246,7 @@ o.spec("ConversationListModel", () => {
verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), { verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), {
times: 0, times: 0,
}) })
verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true, false), { verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true), {
times: 0, times: 0,
}) })
o.check(model.loadingStatus).equals(ListLoadingState.Idle) o.check(model.loadingStatus).equals(ListLoadingState.Idle)
@ -266,7 +266,6 @@ o.spec("ConversationListModel", () => {
mailboxDetail, mailboxDetail,
matchers.argThat((mail: Mail) => isSameId(mail._id, makeMailId(25))), matchers.argThat((mail: Mail) => isSameId(mail._id, makeMailId(25))),
true, true,
false,
), ),
).thenResolve({}) ).thenResolve({})
@ -282,7 +281,7 @@ o.spec("ConversationListModel", () => {
verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), { verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), {
times: 1, times: 1,
}) })
verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true, false), { verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true), {
times: 100, times: 100,
}) })
}) })

View file

@ -204,7 +204,7 @@ o.spec("MailListModel", () => {
verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), { verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), {
times: 0, times: 0,
}) })
verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true, false), { verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true), {
times: 0, times: 0,
}) })
}) })
@ -222,7 +222,7 @@ o.spec("MailListModel", () => {
verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), { verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), {
times: 0, times: 0,
}) })
verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true, false), { verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true), {
times: 0, times: 0,
}) })
}) })
@ -236,7 +236,6 @@ o.spec("MailListModel", () => {
mailboxDetail, mailboxDetail,
matchers.argThat((mail: Mail) => isSameId(mail._id, makeMailId(25))), matchers.argThat((mail: Mail) => isSameId(mail._id, makeMailId(25))),
true, true,
false,
), ),
).thenResolve({}) ).thenResolve({})
@ -252,7 +251,7 @@ o.spec("MailListModel", () => {
verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), { verify(mailModel.getMailboxDetailsForMailFolder(matchers.anything()), {
times: 1, times: 1,
}) })
verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true, false), { verify(inboxRuleHandler.findAndApplyMatchingRule(mailboxDetail, matchers.anything(), true), {
times: 100, times: 100,
}) })
}) })

View file

@ -41,9 +41,8 @@ o.spec("MailIndexAndSpamClassificationPostLoginAction", () => {
} as User) } as User)
when(customerFacadeMock.getUser()).thenResolve(user) when(customerFacadeMock.getUser()).thenResolve(user)
await postLoginAction.onPartialLoginSuccess(loggedInEvent) const { asyncAction } = await postLoginAction.onPartialLoginSuccess(loggedInEvent)
// since the resizeMailIndex.then() is not awaited, we resolve all pending promises manually await asyncAction
await new Promise((resolve) => setImmediate(resolve))
verify(spamClassifierMock.initialize("firstMailGroup"), { times: 1 }) verify(spamClassifierMock.initialize("firstMailGroup"), { times: 1 })
verify(spamClassifierMock.initialize("secondMailGroup"), { times: 1 }) verify(spamClassifierMock.initialize("secondMailGroup"), { times: 1 })

View file

@ -1814,7 +1814,13 @@ mod tests {
], ],
)]), )]),
"1465"=> JsonElement::Array(vec![]), "1465"=> JsonElement::Array(vec![]),
"1677"=> JsonElement::Null "1677"=> JsonElement::Null,
"1728"=> JsonElement::String(
"1".to_string(),
),
"1729"=> JsonElement::Array(
vec![],
),
} }
} }

View file

@ -41,5 +41,7 @@
"115": [], "115": [],
"108": "2", "108": "2",
"1465": [], "1465": [],
"1677": null "1677": null,
"1728": "1",
"1729": []
} }

View file

@ -41,5 +41,7 @@
"426": "AR8zeFN4c98e8Ds8AkusyHbPK0iPTHJsnwisT/nzYPQhyVEMEV9SCk4s20/s5YKWdeU960ddEtcAcCpGRBVSS9Y=", "426": "AR8zeFN4c98e8Ds8AkusyHbPK0iPTHJsnwisT/nzYPQhyVEMEV9SCk4s20/s5YKWdeU960ddEtcAcCpGRBVSS9Y=",
"1021": "0", "1021": "0",
"896": "1723113273034", "896": "1723113273034",
"1677": null "1677": null,
"1728": "1",
"1729": []
} }

View file

@ -41,5 +41,7 @@
"115": [], "115": [],
"108": "2", "108": "2",
"1465": [], "1465": [],
"1677": null "1677": null,
"1728": "1",
"1729": []
} }

View file

@ -41,5 +41,7 @@
} }
], ],
"466": "", "466": "",
"1677": null "1677": null,
"1728": "1",
"1729": []
} }